{ "cells": [ { "cell_type": "markdown", "id": "5257bb27", "metadata": {}, "source": [ "# Dictionary Learning\n", "\n", "Solve orthogonal dictionary learning problem taken from: Yu Bai, Qijia Jiang, and Ju Sun. \n", "[\"Subgradient descent learns orthogonal dictionaries.\"](https://arxiv.org/abs/1810.10702) arXiv preprint arXiv:1810.10702 (2018)." ] }, { "cell_type": "markdown", "id": "d5b1ab10", "metadata": {}, "source": [ "## Problem Description" ] }, { "cell_type": "markdown", "id": "28993e76", "metadata": {}, "source": [ "Given data $\\{y_i \\}_{i \\in[m]}$ generated as $y_i = A x_i$, where $A \\in R^{n \\times n}$ is a fixed unknown orthogonal matrix and each $x_i \\in R^n$ is an iid Bernoulli-Gaussian random vector with parameter $\\theta \\in (0,1)$, recover $A$. \n", "\n", "Write $Y \\doteq [y_1,...,y_m]$ and $X \\doteq [x_1,...,x_m]$. To find the column of $A$, one can perform the following optimization:\n", "\n", "$$\\min_{q \\in R^n} f(q) \\doteq \\frac{1}{m} ||q^T Y||_{1} = \\frac{1}{m} \\sum_{i=1}^m |q^T y_i|,$$\n", "$$\\text{s.t.} ||q||_2 = 1$$\n", "\n", "This problem is nonconvex due to the constraint and nonsmooth due to the objective.\n", "\n", "Based on the above statistical model, $q^T Y = q^T A X$ has the highest sparsity when $q$ is a column of $A$ (up to sign) so that $q^T A$ is 1-sparse. " ] }, { "cell_type": "markdown", "id": "08dfdd50", "metadata": {}, "source": [ "## Modules Importing\n", "Import all necessary modules and add PyGRANSO src folder to system path." ] }, { "cell_type": "code", "execution_count": 1, "id": "90ed32f9", "metadata": {}, "outputs": [], "source": [ "import time\n", "import numpy as np\n", "import torch\n", "import numpy.linalg as la\n", "from scipy.stats import norm\n", "from pygranso.pygranso import pygranso\n", "from pygranso.pygransoStruct import pygransoStruct\n", "\n", "from pygranso.private.getNvar import getNvarTorch\n", "import torch.nn as nn" ] }, { "cell_type": "markdown", "id": "17a1b7fe", "metadata": {}, "source": [ "## Initialization \n", "Specify torch device, create torch model and generate data\n", "\n", "Use GPU for this problem. 
{ "cell_type": "markdown", "id": "17a1b7fe", "metadata": {}, "source": [ "## Initialization \n", "Specify the torch device, create the torch model, and generate data.\n", "\n", "Use GPU for this problem. If no CUDA device is available, please set *device = torch.device('cpu')*." ] },
{ "cell_type": "code", "execution_count": 2, "id": "8b4842e1", "metadata": {}, "outputs": [], "source": [ "device = torch.device('cuda')\n", "\n", "class Dict_Learning(nn.Module):\n", "    \n", "    def __init__(self,n):\n", "        super().__init__()\n", "        np.random.seed(1)\n", "        q0 = norm.ppf(np.random.rand(n,1))\n", "        q0 /= la.norm(q0,2)\n", "        self.q = nn.Parameter( torch.from_numpy(q0) )\n", "    \n", "    def forward(self, Y,m):\n", "        qtY = self.q.T @ Y\n", "        f = 1/m * torch.norm(qtY, p = 1)\n", "        return f\n", "\n", "## Data initialization\n", "n = 30\n", "np.random.seed(1)\n", "m = 10*n**2 # sample complexity\n", "theta = 0.3 # sparsity level\n", "Y = norm.ppf(np.random.rand(n,m)) * (norm.ppf(np.random.rand(n,m)) <= theta) # Bernoulli-Gaussian model\n", "# All the user-provided data (vector/matrix/tensor) must be in torch tensor format.\n", "# As PyTorch tensors are single precision by default, one must explicitly set `dtype=torch.double`.\n", "# Also, please make sure the device of the provided torch tensor is the same as opts.torch_device.\n", "Y = torch.from_numpy(Y).to(device=device, dtype=torch.double)\n", "\n", "torch.manual_seed(0)\n", "\n", "model = Dict_Learning(n).to(device=device, dtype=torch.double)" ] },
{ "cell_type": "markdown", "id": "ec80716b", "metadata": {}, "source": [ "## Function Set-Up\n", "\n", "Encode the optimization variables, and the objective and constraint functions.\n", "\n", "Note: please strictly follow the format of comb_fn, which will be used in the PyGRANSO main algorithm." ] },
{ "cell_type": "code", "execution_count": 3, "id": "fb360e75", "metadata": {}, "outputs": [], "source": [ "def user_fn(model,Y,m):\n", "    # objective function \n", "    f = model(Y,m)\n", "\n", "    q = list(model.parameters())[0]\n", "\n", "    # inequality constraint\n", "    ci = None\n", "\n", "    # equality constraint \n", "    ce = pygransoStruct()\n", "    ce.c1 = q.T @ q - 1\n", "\n", "    return [f,ci,ce]\n", "\n", "comb_fn = lambda model : user_fn(model,Y,m)" ] },
{ "cell_type": "markdown", "id": "f0f55ace", "metadata": {}, "source": [ "## User Options\n", "Specify user-defined options for PyGRANSO." ] },
{ "cell_type": "code", "execution_count": 4, "id": "f3a65b57", "metadata": {}, "outputs": [], "source": [ "opts = pygransoStruct()\n", "opts.torch_device = device\n", "opts.maxit = 500\n", "np.random.seed(1)\n", "nvar = getNvarTorch(model.parameters())\n", "opts.x0 = torch.nn.utils.parameters_to_vector(model.parameters()).detach().reshape(nvar,1)\n", "\n", "opts.print_frequency = 10" ] },
{ "cell_type": "markdown", "id": "8bca18c7", "metadata": {}, "source": [ "## Main Algorithm" ] },
{ "cell_type": "code", "execution_count": 5, "id": "632976b3", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "\u001b[33m╔═════ QP SOLVER NOTICE ══════════════════════════════════════════════════╗\n", "\u001b[0m\u001b[33m║ PyGRANSO requires a quadratic program (QP) solver that has a quadprog-compatible interface, ║\n", "\u001b[0m\u001b[33m║ the default is osqp. Users may provide their own wrapper for the QP solver. 
║\n", "\u001b[0m\u001b[33m║ To disable this notice, set opts.quadprog_info_msg = False ║\n", "\u001b[0m\u001b[33m╚═══════════════════════════════════════════════════════════════════════════════════════════════╝\n", "\u001b[0m═════════════════════════════════════════════════════════════════════════════════════════════════════════════════╗\n", "PyGRANSO: A PyTorch-enabled port of GRANSO with auto-differentiation ║ \n", "Version 1.0.0 ║ \n", "Licensed under the AGPLv3, Copyright (C) 2021 Tim Mitchell and Buyun Liang ║ \n", "═════════════════════════════════════════════════════════════════════════════════════════════════════════════════╣\n", "Problem specifications: ║ \n", " # of variables : 30 ║ \n", " # of inequality constraints : 0 ║ \n", " # of equality constraints : 1 ║ \n", "═════╦═══════════════════════════╦════════════════╦═════════════════╦═══════════════════════╦════════════════════╣\n", " ║ <--- Penalty Function --> ║ ║ Total Violation ║ <--- Line Search ---> ║ <- Stationarity -> ║ \n", "Iter ║ Mu │ Value ║ Objective ║ Ineq │ Eq ║ SD │ Evals │ t ║ Grads │ Value ║ \n", "═════╬═══════════════════════════╬════════════════╬═════════════════╬═══════════════════════╬════════════════════╣\n", " 0 ║ 1.000000 │ 0.61751624522 ║ 0.61751624522 ║ - │ 0.000000 ║ - │ 1 │ 0.000000 ║ 1 │ 0.054664 ║ \n", " 10 ║ 1.000000 │ 0.60573380055 ║ 0.60513582468 ║ - │ 5.98e-04 ║ S │ 1 │ 1.000000 ║ 1 │ 0.024968 ║ \n", " 20 ║ 1.000000 │ 0.58456516016 ║ 0.58301955756 ║ - │ 0.001546 ║ S │ 1 │ 1.000000 ║ 1 │ 0.043517 ║ \n", " 30 ║ 1.000000 │ 0.50113197499 ║ 0.49475409554 ║ - │ 0.006378 ║ S │ 3 │ 0.250000 ║ 1 │ 0.121253 ║ \n", " 40 ║ 1.000000 │ 0.49278124194 ║ 0.49260444460 ║ - │ 1.77e-04 ║ S │ 4 │ 0.125000 ║ 1 │ 0.037304 ║ \n", " 50 ║ 1.000000 │ 0.49225009818 ║ 0.49217494723 ║ - │ 7.52e-05 ║ S │ 5 │ 0.062500 ║ 1 │ 0.032163 ║ \n", " 60 ║ 1.000000 │ 0.49212731751 ║ 0.49208854433 ║ - │ 3.88e-05 ║ S │ 4 │ 0.125000 ║ 1 │ 0.051779 ║ \n", " 70 ║ 1.000000 │ 0.49203371691 ║ 0.49201049130 ║ - │ 2.32e-05 ║ S │ 4 │ 0.125000 ║ 1 │ 0.054529 ║ \n", " 80 ║ 1.000000 │ 0.49197689465 ║ 0.49197679422 ║ - │ 1.00e-07 ║ S │ 2 │ 0.500000 ║ 1 │ 0.001300 ║ \n", " 90 ║ 1.000000 │ 0.49194701030 ║ 0.49194698105 ║ - │ 2.93e-08 ║ S │ 5 │ 0.062500 ║ 5 │ 1.02e-04 ║ \n", " 100 ║ 1.000000 │ 0.49194382838 ║ 0.49194381415 ║ - │ 1.42e-08 ║ S │ 6 │ 0.031250 ║ 10 │ 5.71e-05 ║ \n", " 110 ║ 1.000000 │ 0.49194277900 ║ 0.49194277111 ║ - │ 7.88e-09 ║ S │ 5 │ 0.062500 ║ 18 │ 9.98e-06 ║ \n", " 120 ║ 1.000000 │ 0.49194243076 ║ 0.49194242538 ║ - │ 5.38e-09 ║ S │ 6 │ 0.031250 ║ 27 │ 2.47e-06 ║ \n", " 130 ║ 1.000000 │ 0.49194218055 ║ 0.49194217869 ║ - │ 1.87e-09 ║ S │ 4 │ 0.125000 ║ 37 │ 5.14e-07 ║ \n", " 140 ║ 1.000000 │ 0.49194213249 ║ 0.49194213160 ║ - │ 8.82e-10 ║ S │ 5 │ 0.062500 ║ 40 │ 1.15e-07 ║ \n", " 150 ║ 1.000000 │ 0.49194211795 ║ 0.49194211747 ║ - │ 4.78e-10 ║ S │ 5 │ 0.062500 ║ 40 │ 4.44e-08 ║ \n", " 160 ║ 1.000000 │ 0.49194211356 ║ 0.49194211328 ║ - │ 2.77e-10 ║ S │ 5 │ 0.062500 ║ 40 │ 2.16e-08 ║ \n", "═════╩═══════════════════════════╩════════════════╩═════════════════╩═══════════════════════╩════════════════════╣\n", "Optimization results: ║ \n", "F = final iterate, B = Best (to tolerance), MF = Most Feasible ║ \n", "═════╦═══════════════════════════╦════════════════╦═════════════════╦═══════════════════════╦════════════════════╣\n", " F ║ │ ║ 0.49194211312 ║ - │ 2.59e-10 ║ │ │ ║ │ ║ \n", " B ║ │ ║ 0.49194211312 ║ - │ 2.59e-10 ║ │ │ ║ │ ║ \n", " MF ║ │ ║ 0.61751624522 ║ - │ 0.000000 ║ │ │ ║ │ ║ \n", 
"═════╩═══════════════════════════╩════════════════╩═════════════════╩═══════════════════════╩════════════════════╣\n", "Iterations: 161 ║ \n", "Function evaluations: 664 ║ \n", "PyGRANSO termination code: 0 --- converged to stationarity and feasibility tolerances. ║ \n", "═════════════════════════════════════════════════════════════════════════════════════════════════════════════════╝\n", "Total Wall Time: 2.669196605682373s\n", "tensor([1.0000], device='cuda:0', dtype=torch.float64)\n" ] } ], "source": [ "start = time.time()\n", "soln = pygranso(var_spec= model, combined_fn = comb_fn, user_opts = opts)\n", "end = time.time()\n", "print(\"Total Wall Time: {}s\".format(end - start))\n", "print(max(abs(soln.final.x))) # should be close to 1" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.7" } }, "nbformat": 4, "nbformat_minor": 5 }