Dictionary Learning

Solve the orthogonal dictionary learning problem taken from: Yu Bai, Qijia Jiang, and Ju Sun. “Subgradient descent learns orthogonal dictionaries.” arXiv preprint arXiv:1810.10702 (2018).

Problem Description

Given data \(\{y_i \}_{i \in[m]}\) generated as \(y_i = A x_i\), where \(A \in R^{n \times n}\) is a fixed unknown orthogonal matrix and each \(x_i \in R^n\) is an iid Bernoulli-Gaussian random vector with parameter \(\theta \in (0,1)\), recover \(A\).

Write \(Y \doteq [y_1, \ldots, y_m]\) and \(X \doteq [x_1, \ldots, x_m]\). To find a column of \(A\), one can solve the following optimization problem:

\[\min_{q \in R^n} f(q) \doteq \frac{1}{m} \|q^T Y\|_{1} = \frac{1}{m} \sum_{i=1}^m |q^T y_i|,\]
\[\text{s.t.}\quad \|q\|_2 = 1.\]

This problem is nonconvex due to the spherical constraint and nonsmooth due to the \(\ell_1\) objective. In the code below, the constraint \(\|q\|_2 = 1\) is imposed as the equivalent smooth equality \(q^T q - 1 = 0\).

Based on the above statistical model, \(q^T Y = q^T A X\) is sparsest when \(q\) is a column of \(A\) (up to sign), since then \(q^T A\) is 1-sparse.
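In the cells below, the data are drawn directly from the Bernoulli-Gaussian model, which corresponds to taking \(A = I_n\); a recovered \(q\) should therefore be a signed standard basis vector. For a general orthogonal \(A\), the data could be generated along the following lines (a sketch, not part of this example; scipy.stats.ortho_group is one convenient way to draw a random orthogonal matrix):

import numpy as np
from scipy.stats import norm, ortho_group

n, m, theta = 30, 10*30**2, 0.3
A = ortho_group.rvs(n)    # random orthogonal mixing matrix
X = norm.ppf(np.random.rand(n,m)) * (norm.ppf(np.random.rand(n,m)) <= theta)  # Bernoulli-Gaussian coefficients
Y = A @ X                 # observed data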

Modules Importing

Import all necessary modules.

[1]:
import time
import numpy as np
import torch
import numpy.linalg as la
from scipy.stats import norm
from pygranso.pygranso import pygranso
from pygranso.pygransoStruct import pygransoStruct

Data Initialization

Specify the torch device and generate the data.

Use the GPU for this problem. If no CUDA device is available, please set device = torch.device('cpu')
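If you prefer an automatic fallback instead of editing the line by hand, one option (not in the original example) is:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')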

[2]:
device = torch.device('cuda')
n = 30

np.random.seed(1)
m = 10*n**2   # sample complexity
theta = 0.3   # sparsity level
Y = norm.ppf(np.random.rand(n,m)) * (norm.ppf(np.random.rand(n,m)) <= theta)  # Bernoulli-Gaussian model
# All user-provided data (vectors/matrices/tensors) must be in torch tensor format.
# As PyTorch tensors are single precision by default, one must explicitly set `dtype=torch.double`.
# Also, please make sure the device of the provided torch tensors matches opts.torch_device.
Y = torch.from_numpy(Y).to(device=device, dtype=torch.double)

Function Set-Up

Encode the optimization variables, and objective and constraint functions.

Note: please strictly follow the format of comb_fn, which will be used in the PyGRANSO main algorithm.

[3]:
# variables and corresponding dimensions.
var_in = {"q": [n,1]}


def user_fn(X_struct,Y):
    q = X_struct.q

    # objective function
    qtY = q.T @ Y
    f = 1/m * torch.norm(qtY, p = 1)

    # inequality constraint, matrix form
    ci = None

    # equality constraint
    ce = pygransoStruct()
    ce.c1 = q.T @ q - 1

    return [f,ci,ce]

comb_fn = lambda X_struct : user_fn(X_struct,Y)
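
As a quick optional sanity check (a sketch, not part of the example; the names s and q_test are illustrative), comb_fn can be evaluated at any unit-norm point by wrapping it in a pygransoStruct:

s = pygransoStruct()
q_test = torch.randn(n, 1, device=device, dtype=torch.double)
s.q = q_test / torch.norm(q_test)   # a feasible point on the unit sphere
f, ci, ce = comb_fn(s)
print(f.item(), ce.c1.item())       # objective value and (near-zero) equality constraint value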

User Options

Specify user-defined options for PyGRANSO

[4]:
opts = pygransoStruct()
opts.torch_device = device
opts.maxit = 500
np.random.seed(1)
x0 = norm.ppf(np.random.rand(n,1))
x0 /= la.norm(x0,2)
opts.x0 = torch.from_numpy(x0).to(device=device, dtype=torch.double)

opts.print_frequency = 10

Main Algorithm

[5]:
start = time.time()
soln = pygranso(var_spec = var_in,combined_fn = comb_fn,user_opts = opts)
end = time.time()
print("Total Wall Time: {}s".format(end - start))
print(max(abs(soln.final.x))) # should be close to 1


╔═════ QP SOLVER NOTICE ════════════════════════════════════════════════════════════════════════╗
║  PyGRANSO requires a quadratic program (QP) solver that has a quadprog-compatible interface,  ║
║  the default is osqp. Users may provide their own wrapper for the QP solver.                  ║
║  To disable this notice, set opts.quadprog_info_msg = False                                   ║
╚═══════════════════════════════════════════════════════════════════════════════════════════════╝
═════════════════════════════════════════════════════════════════════════════════════════════════════════════════╗
PyGRANSO: A PyTorch-enabled port of GRANSO with auto-differentiation                                             ║
Version 1.2.0                                                                                                    ║
Licensed under the AGPLv3, Copyright (C) 2021-2022 Tim Mitchell and Buyun Liang                                  ║
═════════════════════════════════════════════════════════════════════════════════════════════════════════════════╣
Problem specifications:                                                                                          ║
 # of variables                     :   30                                                                       ║
 # of inequality constraints        :    0                                                                       ║
 # of equality constraints          :    1                                                                       ║
═════╦═══════════════════════════╦════════════════╦═════════════════╦═══════════════════════╦════════════════════╣
     ║ <--- Penalty Function --> ║                ║ Total Violation ║ <--- Line Search ---> ║ <- Stationarity -> ║
Iter ║    Mu    │      Value     ║    Objective   ║ Ineq │    Eq    ║ SD │ Evals │     t    ║ Grads │    Value   ║
═════╬═══════════════════════════╬════════════════╬═════════════════╬═══════════════════════╬════════════════════╣
   0 ║ 1.000000 │  0.61751624522 ║  0.61751624522 ║   -  │ 0.000000 ║ -  │     1 │ 0.000000 ║     1 │ 0.054664   ║
  10 ║ 1.000000 │  0.60573380055 ║  0.60513582468 ║   -  │ 5.98e-04 ║ S  │     1 │ 1.000000 ║     1 │ 0.024968   ║
  20 ║ 1.000000 │  0.58456516016 ║  0.58301955756 ║   -  │ 0.001546 ║ S  │     1 │ 1.000000 ║     1 │ 0.043517   ║
  30 ║ 1.000000 │  0.50113197499 ║  0.49475409554 ║   -  │ 0.006378 ║ S  │     3 │ 0.250000 ║     1 │ 0.121253   ║
  40 ║ 1.000000 │  0.49278124194 ║  0.49260444460 ║   -  │ 1.77e-04 ║ S  │     4 │ 0.125000 ║     1 │ 0.037304   ║
  50 ║ 1.000000 │  0.49225009818 ║  0.49217494722 ║   -  │ 7.52e-05 ║ S  │     5 │ 0.062500 ║     1 │ 0.032163   ║
  60 ║ 1.000000 │  0.49212731751 ║  0.49208854433 ║   -  │ 3.88e-05 ║ S  │     4 │ 0.125000 ║     1 │ 0.051779   ║
  70 ║ 1.000000 │  0.49203371691 ║  0.49201049130 ║   -  │ 2.32e-05 ║ S  │     4 │ 0.125000 ║     1 │ 0.054529   ║
  80 ║ 1.000000 │  0.49197689465 ║  0.49197679422 ║   -  │ 1.00e-07 ║ S  │     2 │ 0.500000 ║     1 │ 0.001300   ║
  90 ║ 1.000000 │  0.49194701030 ║  0.49194698105 ║   -  │ 2.93e-08 ║ S  │     5 │ 0.062500 ║     5 │ 6.92e-05   ║
 100 ║ 1.000000 │  0.49194382838 ║  0.49194381415 ║   -  │ 1.42e-08 ║ S  │     6 │ 0.031250 ║    10 │ 1.28e-05   ║
 110 ║ 1.000000 │  0.49194277900 ║  0.49194277111 ║   -  │ 7.88e-09 ║ S  │     5 │ 0.062500 ║    18 │ 6.93e-07   ║
 120 ║ 1.000000 │  0.49194243076 ║  0.49194242538 ║   -  │ 5.38e-09 ║ S  │     6 │ 0.031250 ║    27 │ 1.99e-07   ║
 130 ║ 1.000000 │  0.49194218055 ║  0.49194217869 ║   -  │ 1.87e-09 ║ S  │     4 │ 0.125000 ║    37 │ 5.60e-08   ║
═════╩═══════════════════════════╩════════════════╩═════════════════╩═══════════════════════╩════════════════════╣
F = final iterate, B = Best (to tolerance), MF = Most Feasible                                                   ║
Optimization results:                                                                                            ║
═════╦═══════════════════════════╦════════════════╦═════════════════╦═══════════════════════╦════════════════════╣
   F ║          │                ║  0.49194215780 ║   -  │ 1.30e-09 ║    │       │          ║       │            ║
   B ║          │                ║  0.49194215780 ║   -  │ 1.30e-09 ║    │       │          ║       │            ║
  MF ║          │                ║  0.61751624522 ║   -  │ 0.000000 ║    │       │          ║       │            ║
═════╩═══════════════════════════╩════════════════╩═════════════════╩═══════════════════════╩════════════════════╣
Iterations:              134                                                                                     ║
Function evaluations:    525                                                                                     ║
PyGRANSO termination code: 0 --- converged to stationarity and feasibility tolerances.                           ║
═════════════════════════════════════════════════════════════════════════════════════════════════════════════════╝
Total Wall Time: 3.2532527446746826s
tensor([1.0000], device='cuda:0', dtype=torch.float64)
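
Because the data here correspond to \(A = I_n\), the recovered \(q\) should be a signed standard basis vector, which is why \(\max_i |q_i|\) should be close to 1. One can also count the large entries of the recovered vector directly (a sketch, not part of the original example):

q_final = soln.final.x                            # recovered unit vector, shape (n, 1)
print(int((q_final.abs() > 0.1).sum().item()))    # count of large entries; should be 1 for a (near) standard basis vector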

Various Options

(Optional) Set fvalquit: quit if the objective value drops below this value at a feasible iterate (that is, one satisfying the feasibility tolerances opts.viol_ineq_tol and opts.viol_eq_tol).

In the example below, we get termination code 2 since the target objective is reached at a point that is feasible to tolerances.

[6]:
opts = pygransoStruct()
opts.torch_device = device
opts.maxit = 500
np.random.seed(1)
x0 = norm.ppf(np.random.rand(n,1))
x0 /= la.norm(x0,2)
opts.x0 = torch.from_numpy(x0).to(device=device, dtype=torch.double)
opts.print_frequency = 10
opts.print_ascii = True


opts.fvalquit = 0.4963

soln = pygranso(var_spec = var_in,combined_fn = comb_fn,user_opts = opts)
print(max(abs(soln.final.x))) # should be close to 1


###### QP SOLVER NOTICE #########################################################################
#  PyGRANSO requires a quadratic program (QP) solver that has a quadprog-compatible interface,  #
#  the default is osqp. Users may provide their own wrapper for the QP solver.                  #
#  To disable this notice, set opts.quadprog_info_msg = False                                   #
#################################################################################################
==================================================================================================================
PyGRANSO: A PyTorch-enabled port of GRANSO with auto-differentiation                                             |
Version 1.2.0                                                                                                    |
Licensed under the AGPLv3, Copyright (C) 2021-2022 Tim Mitchell and Buyun Liang                                  |
==================================================================================================================
Problem specifications:                                                                                          |
 # of variables                     :   30                                                                       |
 # of inequality constraints        :    0                                                                       |
 # of equality constraints          :    1                                                                       |
==================================================================================================================
     | <--- Penalty Function --> |                | Total Violation | <--- Line Search ---> | <- Stationarity -> |
Iter |    Mu    |      Value     |    Objective   | Ineq |    Eq    | SD | Evals |     t    | Grads |    Value   |
=====|===========================|================|=================|=======================|====================|
   0 | 1.000000 |  0.61751624522 |  0.61751624522 |   -  | 0.000000 | -  |     1 | 0.000000 |     1 | 0.054664   |
  10 | 1.000000 |  0.60573380055 |  0.60513582468 |   -  | 5.98e-04 | S  |     1 | 1.000000 |     1 | 0.024968   |
  20 | 1.000000 |  0.58456516016 |  0.58301955756 |   -  | 0.001546 | S  |     1 | 1.000000 |     1 | 0.043517   |
  30 | 1.000000 |  0.50113197499 |  0.49475409554 |   -  | 0.006378 | S  |     3 | 0.250000 |     1 | 0.121253   |
  40 | 1.000000 |  0.49278124194 |  0.49260444460 |   -  | 1.77e-04 | S  |     4 | 0.125000 |     1 | 0.037304   |
  50 | 1.000000 |  0.49225009818 |  0.49217494722 |   -  | 7.52e-05 | S  |     5 | 0.062500 |     1 | 0.032163   |
  60 | 1.000000 |  0.49212731751 |  0.49208854433 |   -  | 3.88e-05 | S  |     4 | 0.125000 |     1 | 0.051779   |
  70 | 1.000000 |  0.49203371691 |  0.49201049130 |   -  | 2.32e-05 | S  |     4 | 0.125000 |     1 | 0.054529   |
==================================================================================================================
F = final iterate, B = Best (to tolerance), MF = Most Feasible                                                   |
Optimization results:                                                                                            |
==================================================================================================================
   F |          |                |  0.49205285240 |   -  | 3.35e-07 |    |       |          |       |            |
   B |          |                |  0.49205285240 |   -  | 3.35e-07 |    |       |          |       |            |
  MF |          |                |  0.61751624522 |   -  | 0.000000 |    |       |          |       |            |
==================================================================================================================
Iterations:              77                                                                                      |
Function evaluations:    256                                                                                     |
PyGRANSO termination code: 2 --- target objective reached at point feasible to tolerances.                       |
==================================================================================================================
tensor([1.0000], device='cuda:0', dtype=torch.float64)

Set opt_tol, the tolerance for reaching (approximate) optimality/stationarity. See opts.ngrad, opts.evaldist, and the description of PyGRANSO’s output argument soln (specifically the subfield .dnorm) for more information.

In the result below, PyGRANSO terminates once the stationarity value drops below 1e-4.

[7]:
opts = pygransoStruct()
opts.torch_device = device
opts.maxit = 500
np.random.seed(1)
x0 = norm.ppf(np.random.rand(n,1))
x0 /= la.norm(x0,2)
opts.x0 = torch.from_numpy(x0).to(device=device, dtype=torch.double)
opts.print_frequency = 10
opts.print_ascii = True

opts.opt_tol = 1e-4 # default 1e-8

soln = pygranso(var_spec = var_in,combined_fn = comb_fn,user_opts = opts)
print(max(abs(soln.final.x))) # should be close to 1


###### QP SOLVER NOTICE #########################################################################
#  PyGRANSO requires a quadratic program (QP) solver that has a quadprog-compatible interface,  #
#  the default is osqp. Users may provide their own wrapper for the QP solver.                  #
#  To disable this notice, set opts.quadprog_info_msg = False                                   #
#################################################################################################
==================================================================================================================
PyGRANSO: A PyTorch-enabled port of GRANSO with auto-differentiation                                             |
Version 1.2.0                                                                                                    |
Licensed under the AGPLv3, Copyright (C) 2021-2022 Tim Mitchell and Buyun Liang                                  |
==================================================================================================================
Problem specifications:                                                                                          |
 # of variables                     :   30                                                                       |
 # of inequality constraints        :    0                                                                       |
 # of equality constraints          :    1                                                                       |
==================================================================================================================
     | <--- Penalty Function --> |                | Total Violation | <--- Line Search ---> | <- Stationarity -> |
Iter |    Mu    |      Value     |    Objective   | Ineq |    Eq    | SD | Evals |     t    | Grads |    Value   |
=====|===========================|================|=================|=======================|====================|
   0 | 1.000000 |  0.61751624522 |  0.61751624522 |   -  | 0.000000 | -  |     1 | 0.000000 |     1 | 0.054664   |
  10 | 1.000000 |  0.60573380055 |  0.60513582468 |   -  | 5.98e-04 | S  |     1 | 1.000000 |     1 | 0.024968   |
  20 | 1.000000 |  0.58456516016 |  0.58301955756 |   -  | 0.001546 | S  |     1 | 1.000000 |     1 | 0.043517   |
  30 | 1.000000 |  0.50113197499 |  0.49475409554 |   -  | 0.006378 | S  |     3 | 0.250000 |     1 | 0.121253   |
  40 | 1.000000 |  0.49278124194 |  0.49260444460 |   -  | 1.77e-04 | S  |     4 | 0.125000 |     1 | 0.037304   |
  50 | 1.000000 |  0.49225009818 |  0.49217494722 |   -  | 7.52e-05 | S  |     5 | 0.062500 |     1 | 0.032163   |
  60 | 1.000000 |  0.49212731751 |  0.49208854433 |   -  | 3.88e-05 | S  |     4 | 0.125000 |     1 | 0.051779   |
  70 | 1.000000 |  0.49203371691 |  0.49201049130 |   -  | 2.32e-05 | S  |     4 | 0.125000 |     1 | 0.054529   |
  80 | 1.000000 |  0.49197689465 |  0.49197679422 |   -  | 1.00e-07 | S  |     2 | 0.500000 |     1 | 0.001300   |
  90 | 1.000000 |  0.49194701030 |  0.49194698105 |   -  | 2.93e-08 | S  |     5 | 0.062500 |     5 | 6.92e-05   |
==================================================================================================================
F = final iterate, B = Best (to tolerance), MF = Most Feasible                                                   |
Optimization results:                                                                                            |
==================================================================================================================
   F |          |                |  0.49194698105 |   -  | 2.93e-08 |    |       |          |       |            |
   B |          |                |  0.49194698105 |   -  | 2.93e-08 |    |       |          |       |            |
  MF |          |                |  0.61751624522 |   -  | 0.000000 |    |       |          |       |            |
==================================================================================================================
Iterations:              90                                                                                      |
Function evaluations:    299                                                                                     |
PyGRANSO termination code: 0 --- converged to stationarity and feasibility tolerances.                           |
==================================================================================================================
tensor([1.0000], device='cuda:0', dtype=torch.float64)

There are multiple other settings; please uncomment the corresponding lines in the cell below to try them. A detailed description can be found by typing:

import pygransoOptionsAdvanced

help(pygransoOptionsAdvanced)

[8]:
opts = pygransoStruct()
opts.torch_device = device
opts.maxit = 500
np.random.seed(1)
x0 = norm.ppf(np.random.rand(n,1))
x0 /= la.norm(x0,2)
opts.x0 = torch.from_numpy(x0).to(device=device, dtype=torch.double)
opts.print_frequency = 10

# Please uncomment to try different settings

# Tolerance for determining when the relative decrease in the penalty
# function is sufficiently small.  PyGRANSO will terminate when
# the relative decrease in the penalty function is at or below this
# tolerance and the current iterate is feasible to tolerances.
# Generally, we don't recommend using this feature since small steps
# are not necessarily indicative of being near a stationary point,
# particularly for nonsmooth problems.

# Termination Code 1
# opts.rel_tol = 1e-2 # default 0

# Tolerance for how small of a step the line search will attempt
# before terminating.

# Termination Code 6 or 7
# opts.step_tol = 1e-6 # default 1e-12
# opts.step_tol = 1e-3

# Acceptable total violation tolerance of the equality constraints.
# opts.viol_eq_tol = 1e-12 # default 1e-6; making it smaller makes it harder for the current point to be considered feasible

# Quit if the elapsed wall clock time exceeds this limit (unit: seconds).
# opts.maxclocktime = 1.

# Number of characters wide to print values for the penalty function,
# the objective function, and the total violations of the inequality
# and equality constraints.
# opts.print_width = 9

# PyGRANSO uses orange printing to highlight pertinent information.
# However, the user is given the option to disable it, since support
# cannot be guaranteed (it relies on an undocumented feature).
# opts.print_use_orange = False

# opts.init_step_size = 1e-2
# opts.search_direction_rescaling = True

soln = pygranso(var_spec = var_in,combined_fn = comb_fn,user_opts = opts)
print(max(abs(soln.final.x))) # should be close to 1


╔═════ QP SOLVER NOTICE ════════════════════════════════════════════════════════════════════════╗
║  PyGRANSO requires a quadratic program (QP) solver that has a quadprog-compatible interface,  ║
║  the default is osqp. Users may provide their own wrapper for the QP solver.                  ║
║  To disable this notice, set opts.quadprog_info_msg = False                                   ║
╚═══════════════════════════════════════════════════════════════════════════════════════════════╝
═════════════════════════════════════════════════════════════════════════════════════════════════════════════════╗
PyGRANSO: A PyTorch-enabled port of GRANSO with auto-differentiation                                             ║
Version 1.2.0                                                                                                    ║
Licensed under the AGPLv3, Copyright (C) 2021-2022 Tim Mitchell and Buyun Liang                                  ║
═════════════════════════════════════════════════════════════════════════════════════════════════════════════════╣
Problem specifications:                                                                                          ║
 # of variables                     :   30                                                                       ║
 # of inequality constraints        :    0                                                                       ║
 # of equality constraints          :    1                                                                       ║
═════╦═══════════════════════════╦════════════════╦═════════════════╦═══════════════════════╦════════════════════╣
     ║ <--- Penalty Function --> ║                ║ Total Violation ║ <--- Line Search ---> ║ <- Stationarity -> ║
Iter ║    Mu    │      Value     ║    Objective   ║ Ineq │    Eq    ║ SD │ Evals │     t    ║ Grads │    Value   ║
═════╬═══════════════════════════╬════════════════╬═════════════════╬═══════════════════════╬════════════════════╣
   0 ║ 1.000000 │  0.61751624522 ║  0.61751624522 ║   -  │ 0.000000 ║ -  │     1 │ 0.000000 ║     1 │ 0.054664   ║
  10 ║ 1.000000 │  0.60573380055 ║  0.60513582468 ║   -  │ 5.98e-04 ║ S  │     1 │ 1.000000 ║     1 │ 0.024968   ║
  20 ║ 1.000000 │  0.58456516016 ║  0.58301955756 ║   -  │ 0.001546 ║ S  │     1 │ 1.000000 ║     1 │ 0.043517   ║
  30 ║ 1.000000 │  0.50113197499 ║  0.49475409554 ║   -  │ 0.006378 ║ S  │     3 │ 0.250000 ║     1 │ 0.121253   ║
  40 ║ 1.000000 │  0.49278124194 ║  0.49260444460 ║   -  │ 1.77e-04 ║ S  │     4 │ 0.125000 ║     1 │ 0.037304   ║
  50 ║ 1.000000 │  0.49225009818 ║  0.49217494722 ║   -  │ 7.52e-05 ║ S  │     5 │ 0.062500 ║     1 │ 0.032163   ║
  60 ║ 1.000000 │  0.49212731751 ║  0.49208854433 ║   -  │ 3.88e-05 ║ S  │     4 │ 0.125000 ║     1 │ 0.051779   ║
  70 ║ 1.000000 │  0.49203371691 ║  0.49201049130 ║   -  │ 2.32e-05 ║ S  │     4 │ 0.125000 ║     1 │ 0.054529   ║
  80 ║ 1.000000 │  0.49197689465 ║  0.49197679422 ║   -  │ 1.00e-07 ║ S  │     2 │ 0.500000 ║     1 │ 0.001300   ║
  90 ║ 1.000000 │  0.49194701030 ║  0.49194698105 ║   -  │ 2.93e-08 ║ S  │     5 │ 0.062500 ║     5 │ 6.92e-05   ║
 100 ║ 1.000000 │  0.49194382838 ║  0.49194381415 ║   -  │ 1.42e-08 ║ S  │     6 │ 0.031250 ║    10 │ 1.28e-05   ║
 110 ║ 1.000000 │  0.49194277900 ║  0.49194277111 ║   -  │ 7.88e-09 ║ S  │     5 │ 0.062500 ║    18 │ 6.93e-07   ║
 120 ║ 1.000000 │  0.49194243076 ║  0.49194242538 ║   -  │ 5.38e-09 ║ S  │     6 │ 0.031250 ║    27 │ 1.99e-07   ║
 130 ║ 1.000000 │  0.49194218055 ║  0.49194217869 ║   -  │ 1.87e-09 ║ S  │     4 │ 0.125000 ║    37 │ 5.60e-08   ║
═════╩═══════════════════════════╩════════════════╩═════════════════╩═══════════════════════╩════════════════════╣
F = final iterate, B = Best (to tolerance), MF = Most Feasible                                                   ║
Optimization results:                                                                                            ║
═════╦═══════════════════════════╦════════════════╦═════════════════╦═══════════════════════╦════════════════════╣
   F ║          │                ║  0.49194215780 ║   -  │ 1.30e-09 ║    │       │          ║       │            ║
   B ║          │                ║  0.49194215780 ║   -  │ 1.30e-09 ║    │       │          ║       │            ║
  MF ║          │                ║  0.61751624522 ║   -  │ 0.000000 ║    │       │          ║       │            ║
═════╩═══════════════════════════╩════════════════╩═════════════════╩═══════════════════════╩════════════════════╣
Iterations:              134                                                                                     ║
Function evaluations:    525                                                                                     ║
PyGRANSO termination code: 0 --- converged to stationarity and feasibility tolerances.                           ║
═════════════════════════════════════════════════════════════════════════════════════════════════════════════════╝
tensor([1.0000], device='cuda:0', dtype=torch.float64)

(For Advanced Users) Users can specify analytical gradients instead of using the auto-differentiation feature. Please check the documentation in pygranso.py for the required format of combined_fn.
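
For reference, the analytical gradients used in the cell below are the standard ones: for the \(\ell_1\) objective a subgradient is taken (with \(\mathrm{sign}(0)=0\)), and the equality constraint is smooth:

\[\partial f(q) \ni \frac{1}{m} Y \,\mathrm{sign}(Y^T q), \qquad \nabla c(q) = 2q \quad \text{for } c(q) \doteq q^T q - 1.\]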

[9]:
# Without AD
def comb_fn(X_struct):
    q = X_struct.q

    # objective function
    qtY = q.T @ Y
    f = 1/m * torch.norm(qtY, p = 1).item()
    f_grad = 1/m*Y@torch.sign(Y.T@q)

    # inequality constraint, matrix form
    ci = None
    ci_grad = None

    # equality constraint
    ce = q.T @ q - 1
    ce_grad = 2*q

    return [f,f_grad,ci,ci_grad,ce,ce_grad]

opts = pygransoStruct()
opts.torch_device = device
opts.maxit = 500
np.random.seed(1)
x0 = norm.ppf(np.random.rand(n,1))
x0 /= la.norm(x0,2)
opts.x0 = torch.from_numpy(x0).to(device=device, dtype=torch.double)
opts.print_frequency = 10
opts.globalAD = False # disable global auto-differentiation

start = time.time()
soln = pygranso(var_spec = var_in,combined_fn = comb_fn,user_opts = opts)
end = time.time()
print("Total Wall Time: {}s".format(end - start))
print(max(abs(soln.final.x))) # should be close to 1


╔═════ QP SOLVER NOTICE ════════════════════════════════════════════════════════════════════════╗
║  PyGRANSO requires a quadratic program (QP) solver that has a quadprog-compatible interface,  ║
║  the default is osqp. Users may provide their own wrapper for the QP solver.                  ║
║  To disable this notice, set opts.quadprog_info_msg = False                                   ║
╚═══════════════════════════════════════════════════════════════════════════════════════════════╝
═════════════════════════════════════════════════════════════════════════════════════════════════════════════════╗
PyGRANSO: A PyTorch-enabled port of GRANSO with auto-differentiation                                             ║
Version 1.2.0                                                                                                    ║
Licensed under the AGPLv3, Copyright (C) 2021-2022 Tim Mitchell and Buyun Liang                                  ║
═════════════════════════════════════════════════════════════════════════════════════════════════════════════════╣
Problem specifications:                                                                                          ║
 # of variables                     :   30                                                                       ║
 # of inequality constraints        :    0                                                                       ║
 # of equality constraints          :    1                                                                       ║
═════╦═══════════════════════════╦════════════════╦═════════════════╦═══════════════════════╦════════════════════╣
     ║ <--- Penalty Function --> ║                ║ Total Violation ║ <--- Line Search ---> ║ <- Stationarity -> ║
Iter ║    Mu    │      Value     ║    Objective   ║ Ineq │    Eq    ║ SD │ Evals │     t    ║ Grads │    Value   ║
═════╬═══════════════════════════╬════════════════╬═════════════════╬═══════════════════════╬════════════════════╣
   0 ║ 1.000000 │  0.61751624522 ║  0.61751624522 ║   -  │ 0.000000 ║ -  │     1 │ 0.000000 ║     1 │ 0.054664   ║
  10 ║ 1.000000 │  0.60573380055 ║  0.60513582468 ║   -  │ 5.98e-04 ║ S  │     1 │ 1.000000 ║     1 │ 0.024968   ║
  20 ║ 1.000000 │  0.58456516016 ║  0.58301955756 ║   -  │ 0.001546 ║ S  │     1 │ 1.000000 ║     1 │ 0.043517   ║
  30 ║ 1.000000 │  0.50113197499 ║  0.49475409554 ║   -  │ 0.006378 ║ S  │     3 │ 0.250000 ║     1 │ 0.121253   ║
  40 ║ 1.000000 │  0.49278124194 ║  0.49260444460 ║   -  │ 1.77e-04 ║ S  │     4 │ 0.125000 ║     1 │ 0.037304   ║
  50 ║ 1.000000 │  0.49225009818 ║  0.49217494722 ║   -  │ 7.52e-05 ║ S  │     5 │ 0.062500 ║     1 │ 0.032163   ║
  60 ║ 1.000000 │  0.49212731751 ║  0.49208854433 ║   -  │ 3.88e-05 ║ S  │     4 │ 0.125000 ║     1 │ 0.051779   ║
  70 ║ 1.000000 │  0.49203371691 ║  0.49201049130 ║   -  │ 2.32e-05 ║ S  │     4 │ 0.125000 ║     1 │ 0.054529   ║
  80 ║ 1.000000 │  0.49197689465 ║  0.49197679422 ║   -  │ 1.00e-07 ║ S  │     2 │ 0.500000 ║     1 │ 0.001300   ║
  90 ║ 1.000000 │  0.49194701030 ║  0.49194698105 ║   -  │ 2.93e-08 ║ S  │     5 │ 0.062500 ║     5 │ 6.92e-05   ║
 100 ║ 1.000000 │  0.49194382838 ║  0.49194381415 ║   -  │ 1.42e-08 ║ S  │     6 │ 0.031250 ║    10 │ 1.28e-05   ║
 110 ║ 1.000000 │  0.49194277900 ║  0.49194277111 ║   -  │ 7.88e-09 ║ S  │     5 │ 0.062500 ║    18 │ 6.93e-07   ║
 120 ║ 1.000000 │  0.49194243076 ║  0.49194242538 ║   -  │ 5.38e-09 ║ S  │     6 │ 0.031250 ║    27 │ 1.99e-07   ║
 130 ║ 1.000000 │  0.49194218055 ║  0.49194217869 ║   -  │ 1.87e-09 ║ S  │     4 │ 0.125000 ║    37 │ 5.66e-08   ║
═════╩═══════════════════════════╩════════════════╩═════════════════╩═══════════════════════╩════════════════════╣
F = final iterate, B = Best (to tolerance), MF = Most Feasible                                                   ║
Optimization results:                                                                                            ║
═════╦═══════════════════════════╦════════════════╦═════════════════╦═══════════════════════╦════════════════════╣
   F ║          │                ║  0.49194215780 ║   -  │ 1.30e-09 ║    │       │          ║       │            ║
   B ║          │                ║  0.49194215780 ║   -  │ 1.30e-09 ║    │       │          ║       │            ║
  MF ║          │                ║  0.61751624522 ║   -  │ 0.000000 ║    │       │          ║       │            ║
═════╩═══════════════════════════╩════════════════╩═════════════════╩═══════════════════════╩════════════════════╣
Iterations:              134                                                                                     ║
Function evaluations:    525                                                                                     ║
PyGRANSO termination code: 0 --- converged to stationarity and feasibility tolerances.                           ║
═════════════════════════════════════════════════════════════════════════════════════════════════════════════════╝
Total Wall Time: 2.052833318710327s
tensor([1.0000], device='cuda:0', dtype=torch.float64)

(For Advanced Users) Alternatively, users can apply auto-differentiation only partially: in the example below, the objective gradient is supplied analytically, the constraint gradient is obtained via the getCiGradVec helper, and the constraint value returned to PyGRANSO is detached from the computational graph.

[10]:
# import the AD function
from pygranso.private.tensor2vec import getCiGradVec

# partial AD
def comb_fn(X_struct):
    q = X_struct.q
    q.requires_grad_(True)

    # objective function
    q_tmp = q.detach().clone()
    qtY = q_tmp.T @ Y
    f = 1/m * torch.norm(qtY, p = 1).item()
    f_grad = 1/m*Y@torch.sign(Y.T@q_tmp)

    # inequality constraint, matrix form
    ci = None
    ci_grad = None

    # equality constraint
    ce = q.T @ q - 1
    # ce_grad = 2*q
    ce_grad = getCiGradVec(nvar=n,nconstr_ci_total=1,var_dim_map=var_in,X=X_struct,ci_vec_torch=ce,torch_device=device,double_precision=torch.double)

    # return value must be detached from the computational graph
    ce = ce.detach()

    return [f,f_grad,ci,ci_grad,ce,ce_grad]

opts = pygransoStruct()
opts.torch_device = device
opts.maxit = 500
np.random.seed(1)
x0 = norm.ppf(np.random.rand(n,1))
x0 /= la.norm(x0,2)
opts.x0 = torch.from_numpy(x0).to(device=device, dtype=torch.double)
opts.print_frequency = 10
opts.globalAD = False # disable global auto-differentiation


start = time.time()
soln = pygranso(var_spec = var_in,combined_fn = comb_fn,user_opts = opts)
end = time.time()
print("Total Wall Time: {}s".format(end - start))
print(max(abs(soln.final.x))) # should be close to 1


╔═════ QP SOLVER NOTICE ════════════════════════════════════════════════════════════════════════╗
║  PyGRANSO requires a quadratic program (QP) solver that has a quadprog-compatible interface,  ║
║  the default is osqp. Users may provide their own wrapper for the QP solver.                  ║
║  To disable this notice, set opts.quadprog_info_msg = False                                   ║
╚═══════════════════════════════════════════════════════════════════════════════════════════════╝
═════════════════════════════════════════════════════════════════════════════════════════════════════════════════╗
PyGRANSO: A PyTorch-enabled port of GRANSO with auto-differentiation                                             ║
Version 1.2.0                                                                                                    ║
Licensed under the AGPLv3, Copyright (C) 2021-2022 Tim Mitchell and Buyun Liang                                  ║
═════════════════════════════════════════════════════════════════════════════════════════════════════════════════╣
Problem specifications:                                                                                          ║
 # of variables                     :   30                                                                       ║
 # of inequality constraints        :    0                                                                       ║
 # of equality constraints          :    1                                                                       ║
═════╦═══════════════════════════╦════════════════╦═════════════════╦═══════════════════════╦════════════════════╣
     ║ <--- Penalty Function --> ║                ║ Total Violation ║ <--- Line Search ---> ║ <- Stationarity -> ║
Iter ║    Mu    │      Value     ║    Objective   ║ Ineq │    Eq    ║ SD │ Evals │     t    ║ Grads │    Value   ║
═════╬═══════════════════════════╬════════════════╬═════════════════╬═══════════════════════╬════════════════════╣
   0 ║ 1.000000 │  0.61751624522 ║  0.61751624522 ║   -  │ 0.000000 ║ -  │     1 │ 0.000000 ║     1 │ 0.054664   ║
  10 ║ 1.000000 │  0.60573380055 ║  0.60513582468 ║   -  │ 5.98e-04 ║ S  │     1 │ 1.000000 ║     1 │ 0.024968   ║
  20 ║ 1.000000 │  0.58456516016 ║  0.58301955756 ║   -  │ 0.001546 ║ S  │     1 │ 1.000000 ║     1 │ 0.043517   ║
  30 ║ 1.000000 │  0.50113197499 ║  0.49475409554 ║   -  │ 0.006378 ║ S  │     3 │ 0.250000 ║     1 │ 0.121253   ║
  40 ║ 1.000000 │  0.49278124194 ║  0.49260444460 ║   -  │ 1.77e-04 ║ S  │     4 │ 0.125000 ║     1 │ 0.037304   ║
  50 ║ 1.000000 │  0.49225009818 ║  0.49217494722 ║   -  │ 7.52e-05 ║ S  │     5 │ 0.062500 ║     1 │ 0.032163   ║
  60 ║ 1.000000 │  0.49212731751 ║  0.49208854433 ║   -  │ 3.88e-05 ║ S  │     4 │ 0.125000 ║     1 │ 0.051779   ║
  70 ║ 1.000000 │  0.49203371691 ║  0.49201049130 ║   -  │ 2.32e-05 ║ S  │     4 │ 0.125000 ║     1 │ 0.054529   ║
  80 ║ 1.000000 │  0.49197689465 ║  0.49197679422 ║   -  │ 1.00e-07 ║ S  │     2 │ 0.500000 ║     1 │ 0.001300   ║
  90 ║ 1.000000 │  0.49194701030 ║  0.49194698105 ║   -  │ 2.93e-08 ║ S  │     5 │ 0.062500 ║     5 │ 6.92e-05   ║
 100 ║ 1.000000 │  0.49194382838 ║  0.49194381415 ║   -  │ 1.42e-08 ║ S  │     6 │ 0.031250 ║    10 │ 1.28e-05   ║
 110 ║ 1.000000 │  0.49194277900 ║  0.49194277111 ║   -  │ 7.88e-09 ║ S  │     5 │ 0.062500 ║    18 │ 6.93e-07   ║
 120 ║ 1.000000 │  0.49194243076 ║  0.49194242538 ║   -  │ 5.38e-09 ║ S  │     6 │ 0.031250 ║    27 │ 1.99e-07   ║
 130 ║ 1.000000 │  0.49194218055 ║  0.49194217869 ║   -  │ 1.87e-09 ║ S  │     4 │ 0.125000 ║    37 │ 5.66e-08   ║
═════╩═══════════════════════════╩════════════════╩═════════════════╩═══════════════════════╩════════════════════╣
F = final iterate, B = Best (to tolerance), MF = Most Feasible                                                   ║
Optimization results:                                                                                            ║
═════╦═══════════════════════════╦════════════════╦═════════════════╦═══════════════════════╦════════════════════╣
   F ║          │                ║  0.49194215780 ║   -  │ 1.30e-09 ║    │       │          ║       │            ║
   B ║          │                ║  0.49194215780 ║   -  │ 1.30e-09 ║    │       │          ║       │            ║
  MF ║          │                ║  0.61751624522 ║   -  │ 0.000000 ║    │       │          ║       │            ║
═════╩═══════════════════════════╩════════════════╩═════════════════╩═══════════════════════╩════════════════════╣
Iterations:              134                                                                                     ║
Function evaluations:    525                                                                                     ║
PyGRANSO termination code: 0 --- converged to stationarity and feasibility tolerances.                           ║
═════════════════════════════════════════════════════════════════════════════════════════════════════════════════╝
Total Wall Time: 2.0349695682525635s
tensor([1.0000], device='cuda:0', dtype=torch.float64)

Different Set-Up

(Optional) Use torch.nn to set up the user-provided function. This setting is important for solving constrained deep learning problems.

Modules Importing

Import all necessary modules.

[11]:
import time
import numpy as np
import torch
import numpy.linalg as la
from scipy.stats import norm
from pygranso.pygranso import pygranso
from pygranso.pygransoStruct import pygransoStruct

from pygranso.private.getNvar import getNvarTorch
import torch.nn as nn

Initialization

Specify the torch device, create the torch model, and generate the data.

Use the GPU for this problem. If no CUDA device is available, please set device = torch.device('cpu')

[12]:
device = torch.device('cuda')

class Dict_Learning(nn.Module):

    def __init__(self,n):
        super().__init__()
        np.random.seed(1)
        q0 = norm.ppf(np.random.rand(n,1))
        q0 /= la.norm(q0,2)
        self.q = nn.Parameter( torch.from_numpy(q0) )

    def forward(self, Y,m):
        qtY = self.q.T @ Y
        f = 1/m * torch.norm(qtY, p = 1)
        return f

## Data initialization
n = 30
np.random.seed(1)
m = 10*n**2   # sample complexity
theta = 0.3   # sparsity level
Y = norm.ppf(np.random.rand(n,m)) * (norm.ppf(np.random.rand(n,m)) <= theta)  # Bernoulli-Gaussian model
# All user-provided data (vectors/matrices/tensors) must be in torch tensor format.
# As PyTorch tensors are single precision by default, one must explicitly set `dtype=torch.double`.
# Also, please make sure the device of the provided torch tensors matches opts.torch_device.
Y = torch.from_numpy(Y).to(device=device, dtype=torch.double)

torch.manual_seed(0)

model = Dict_Learning(n).to(device=device, dtype=torch.double)
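
As a quick optional check, the forward pass of the model can be evaluated directly; since the model is initialized at the same normalized q0 as before, the value should match the iteration-0 objective reported in the log below:

print(model(Y, m).item())   # initial objective value; should be about 0.6175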

Function Set-Up

Encode the optimization variables, and objective and constraint functions.

Note: please strictly follow the format of comb_fn, which will be used in the PyGRANSO main algorithm.

[13]:
def user_fn(model,Y,m):
    # objective function
    f = model(Y,m)

    q = list(model.parameters())[0]

    # inequality constraint
    ci = None

    # equality constraint
    ce = pygransoStruct()
    ce.c1 = q.T @ q - 1

    return [f,ci,ce]

comb_fn = lambda model : user_fn(model,Y,m)

User Options

Specify user-defined options for PyGRANSO

[14]:
opts = pygransoStruct()
opts.torch_device = device
opts.maxit = 500
np.random.seed(1)
nvar = getNvarTorch(model.parameters())
opts.x0 = torch.nn.utils.parameters_to_vector(model.parameters()).detach().reshape(nvar,1)

opts.print_frequency = 10

Main Algorithm

[15]:
start = time.time()
soln = pygranso(var_spec= model, combined_fn = comb_fn, user_opts = opts)
end = time.time()
print("Total Wall Time: {}s".format(end - start))
print(max(abs(soln.final.x))) # should be close to 1


╔═════ QP SOLVER NOTICE ════════════════════════════════════════════════════════════════════════╗
║  PyGRANSO requires a quadratic program (QP) solver that has a quadprog-compatible interface,  ║
║  the default is osqp. Users may provide their own wrapper for the QP solver.                  ║
║  To disable this notice, set opts.quadprog_info_msg = False                                   ║
╚═══════════════════════════════════════════════════════════════════════════════════════════════╝
═════════════════════════════════════════════════════════════════════════════════════════════════════════════════╗
PyGRANSO: A PyTorch-enabled port of GRANSO with auto-differentiation                                             ║
Version 1.2.0                                                                                                    ║
Licensed under the AGPLv3, Copyright (C) 2021-2022 Tim Mitchell and Buyun Liang                                  ║
═════════════════════════════════════════════════════════════════════════════════════════════════════════════════╣
Problem specifications:                                                                                          ║
 # of variables                     :   30                                                                       ║
 # of inequality constraints        :    0                                                                       ║
 # of equality constraints          :    1                                                                       ║
═════╦═══════════════════════════╦════════════════╦═════════════════╦═══════════════════════╦════════════════════╣
     ║ <--- Penalty Function --> ║                ║ Total Violation ║ <--- Line Search ---> ║ <- Stationarity -> ║
Iter ║    Mu    │      Value     ║    Objective   ║ Ineq │    Eq    ║ SD │ Evals │     t    ║ Grads │    Value   ║
═════╬═══════════════════════════╬════════════════╬═════════════════╬═══════════════════════╬════════════════════╣
   0 ║ 1.000000 │  0.61751624522 ║  0.61751624522 ║   -  │ 0.000000 ║ -  │     1 │ 0.000000 ║     1 │ 0.054664   ║
  10 ║ 1.000000 │  0.60573380055 ║  0.60513582468 ║   -  │ 5.98e-04 ║ S  │     1 │ 1.000000 ║     1 │ 0.024968   ║
  20 ║ 1.000000 │  0.58456516016 ║  0.58301955756 ║   -  │ 0.001546 ║ S  │     1 │ 1.000000 ║     1 │ 0.043517   ║
  30 ║ 1.000000 │  0.50113197499 ║  0.49475409554 ║   -  │ 0.006378 ║ S  │     3 │ 0.250000 ║     1 │ 0.121253   ║
  40 ║ 1.000000 │  0.49278124194 ║  0.49260444460 ║   -  │ 1.77e-04 ║ S  │     4 │ 0.125000 ║     1 │ 0.037304   ║
  50 ║ 1.000000 │  0.49225009818 ║  0.49217494722 ║   -  │ 7.52e-05 ║ S  │     5 │ 0.062500 ║     1 │ 0.032163   ║
  60 ║ 1.000000 │  0.49212731751 ║  0.49208854433 ║   -  │ 3.88e-05 ║ S  │     4 │ 0.125000 ║     1 │ 0.051779   ║
  70 ║ 1.000000 │  0.49203371691 ║  0.49201049130 ║   -  │ 2.32e-05 ║ S  │     4 │ 0.125000 ║     1 │ 0.054529   ║
  80 ║ 1.000000 │  0.49197689465 ║  0.49197679422 ║   -  │ 1.00e-07 ║ S  │     2 │ 0.500000 ║     1 │ 0.001300   ║
  90 ║ 1.000000 │  0.49194701030 ║  0.49194698105 ║   -  │ 2.93e-08 ║ S  │     5 │ 0.062500 ║     5 │ 6.92e-05   ║
 100 ║ 1.000000 │  0.49194382838 ║  0.49194381415 ║   -  │ 1.42e-08 ║ S  │     6 │ 0.031250 ║    10 │ 1.28e-05   ║
 110 ║ 1.000000 │  0.49194277900 ║  0.49194277111 ║   -  │ 7.88e-09 ║ S  │     5 │ 0.062500 ║    18 │ 6.93e-07   ║
 120 ║ 1.000000 │  0.49194243076 ║  0.49194242538 ║   -  │ 5.38e-09 ║ S  │     6 │ 0.031250 ║    27 │ 1.99e-07   ║
 130 ║ 1.000000 │  0.49194218055 ║  0.49194217869 ║   -  │ 1.87e-09 ║ S  │     4 │ 0.125000 ║    37 │ 5.60e-08   ║
═════╩═══════════════════════════╩════════════════╩═════════════════╩═══════════════════════╩════════════════════╣
F = final iterate, B = Best (to tolerance), MF = Most Feasible                                                   ║
Optimization results:                                                                                            ║
═════╦═══════════════════════════╦════════════════╦═════════════════╦═══════════════════════╦════════════════════╣
   F ║          │                ║  0.49194215780 ║   -  │ 1.30e-09 ║    │       │          ║       │            ║
   B ║          │                ║  0.49194215780 ║   -  │ 1.30e-09 ║    │       │          ║       │            ║
  MF ║          │                ║  0.61751624522 ║   -  │ 0.000000 ║    │       │          ║       │            ║
═════╩═══════════════════════════╩════════════════╩═════════════════╩═══════════════════════╩════════════════════╣
Iterations:              134                                                                                     ║
Function evaluations:    525                                                                                     ║
PyGRANSO termination code: 0 --- converged to stationarity and feasibility tolerances.                           ║
═════════════════════════════════════════════════════════════════════════════════════════════════════════════════╝
Total Wall Time: 2.3208694458007812s
tensor([1.0000], device='cuda:0', dtype=torch.float64)