{ "cells": [ { "cell_type": "markdown", "id": "5257bb27", "metadata": {}, "source": [ "# Orthogonal RNN\n", "\n", "Train Orthogonal RNN for MNIST classification based on [this Paper](https://arxiv.org/pdf/1901.08428.pdf)\n", "\n", "NOTE: this example is still under development. " ] }, { "cell_type": "markdown", "id": "c859c154", "metadata": {}, "source": [ "## Problem Description" ] }, { "cell_type": "markdown", "id": "b96269c7", "metadata": {}, "source": [ "For each element in the input sequence, each layer computes the following function:\n", "$$h_t=\\tanh(W_{ih}x_t+b_{ih}+W_{hh}h_{t-1}+b_hh)$$\n", "\n", "where $h_{t}$ is the hidden state at time $t$, and $h_{t-1}$ is the hidden state of the previous layer at time $t-1$ or the initial hidden state at time $o$. \n", "\n", "For each layer, we have the orthogonal constraint:\n", "$$ W_{hh}^T W_{hh} = I $$" ] }, { "cell_type": "markdown", "id": "08dfdd50", "metadata": {}, "source": [ "## Modules Importing\n", "Import all necessary modules and add PyGRANSO src folder to system path. " ] }, { "cell_type": "code", "execution_count": 1, "id": "90ed32f9", "metadata": {}, "outputs": [], "source": [ "import time\n", "import torch\n", "import sys\n", "## Adding PyGRANSO directories. Should be modified by user\n", "sys.path.append('/home/buyun/Documents/GitHub/PyGRANSO')\n", "from pygranso.pygranso import pygranso\n", "from pygranso.pygransoStruct import pygransoStruct \n", "from pygranso.private.getNvar import getNvarTorch\n", "import torch.nn as nn\n", "from torchvision import datasets\n", "from torchvision.transforms import ToTensor\n", "from pygranso.private.getObjGrad import getObjGradDL" ] }, { "cell_type": "markdown", "id": "17a1b7fe", "metadata": {}, "source": [ "## Data Initialization \n", "Specify torch device, neural network architecture, and generate data.\n", "\n", "NOTE: please specify path for downloading data.\n", "\n", "Use GPU for this problem. If no cuda device available, please set *device = torch.device('cpu')*" ] }, { "cell_type": "code", "execution_count": 2, "id": "8b4842e1", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/buyun/anaconda3/envs/cuosqp_pygranso/lib/python3.9/site-packages/torchvision/datasets/mnist.py:498: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /opt/conda/conda-bld/pytorch_1623448255797/work/torch/csrc/utils/tensor_numpy.cpp:180.)\n", " return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)\n" ] } ], "source": [ "device = torch.device('cuda')\n", "\n", "sequence_length = 28\n", "input_size = 28\n", "hidden_size = 30\n", "num_layers = 1\n", "num_classes = 10\n", "batch_size = 100\n", "\n", "\n", "double_precision = torch.double\n", "\n", "class RNN(nn.Module):\n", " \n", " def __init__(self, input_size, hidden_size, num_layers, num_classes):\n", " super(RNN, self).__init__()\n", " self.hidden_size = hidden_size\n", " self.num_layers = num_layers\n", " # self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)\n", " self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)\n", " self.fc = nn.Linear(hidden_size, num_classes)\n", " pass\n", " \n", " def forward(self, x):\n", " x = torch.reshape(x,(batch_size,sequence_length,input_size))\n", " # Set initial hidden and cell states \n", " h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device=device, dtype=double_precision)\n", " out, hidden = self.rnn(x, h0) # out: tensor of shape (batch_size, seq_length, hidden_size)\n", " #Reshaping the outputs such that it can be fit into the fully connected layer\n", " out = self.fc(out[:, -1, :])\n", " return out\n", " \n", "torch.manual_seed(0)\n", "\n", "model = RNN(input_size, hidden_size, num_layers, num_classes).to(device=device, dtype=double_precision)\n", "model.train()\n", "\n", "train_data = datasets.MNIST(\n", " root = '/home/buyun/Documents/GitHub/PyGRANSO/examples/data/mnist',\n", " train = True, \n", " transform = ToTensor(), \n", " download = True, \n", ") \n", "\n", "loaders = {\n", " 'train' : torch.utils.data.DataLoader(train_data, \n", " batch_size=100, \n", " shuffle=True, \n", " num_workers=1),\n", "}\n", "\n", "inputs, labels = next(iter(loaders['train']))\n", "inputs, labels = inputs.reshape(-1, sequence_length, input_size).to(device=device, dtype=double_precision), labels.to(device=device)" ] }, { "cell_type": "markdown", "id": "ec80716b", "metadata": {}, "source": [ "## Function Set-Up\n", "\n", "Encode the optimization variables, and objective and constraint functions.\n", "\n", "Note: please strictly follow the format of comb_fn, which will be used in the PyGRANSO main algortihm." ] }, { "cell_type": "code", "execution_count": 3, "id": "fb360e75", "metadata": {}, "outputs": [], "source": [ "def user_fn(model,inputs,labels):\n", " # objective function \n", " logits = model(inputs)\n", " criterion = nn.CrossEntropyLoss()\n", " f = criterion(logits, labels)\n", "\n", " A = list(model.parameters())[1]\n", "\n", " # inequality constraint\n", " ci = None\n", "\n", " # equality constraint \n", " # special orthogonal group\n", " \n", " ce = pygransoStruct()\n", "\n", " c1_vec = (A.T @ A \n", " - torch.eye(hidden_size)\n", " .to(device=device, dtype=double_precision)\n", " ).reshape(1,-1)\n", " \n", " ce.c1 = torch.linalg.vector_norm(c1_vec,2) # l2 folding to reduce the total number of constraints\n", " # ce.c2 = torch.det(A) - 1\n", "\n", " # ce = None\n", "\n", " return [f,ci,ce]\n", "\n", "comb_fn = lambda model : user_fn(model,inputs,labels)" ] }, { "cell_type": "markdown", "id": "f0f55ace", "metadata": {}, "source": [ "## User Options\n", "Specify user-defined options for PyGRANSO " ] }, { "cell_type": "code", "execution_count": 4, "id": "f3a65b57", "metadata": {}, "outputs": [], "source": [ "opts = pygransoStruct()\n", "opts.torch_device = device\n", "nvar = getNvarTorch(model.parameters())\n", "opts.x0 = torch.nn.utils.parameters_to_vector(model.parameters()).detach().reshape(nvar,1)\n", "opts.opt_tol = 1e-3\n", "opts.viol_eq_tol = 1e-4\n", "# opts.maxit = 150\n", "# opts.fvalquit = 1e-6\n", "opts.print_level = 1\n", "opts.print_frequency = 50\n", "# opts.print_ascii = True\n", "# opts.limited_mem_size = 100\n", "opts.double_precision = True\n", "\n", "opts.mu0 = 1" ] }, { "cell_type": "markdown", "id": "754ba30a", "metadata": {}, "source": [ "## Initial Test \n", "Check initial accuracy of the RNN model" ] }, { "cell_type": "code", "execution_count": 5, "id": "711f0e9c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Initial acc = 10.00%\n" ] } ], "source": [ "logits = model(inputs)\n", "_, predicted = torch.max(logits.data, 1)\n", "correct = (predicted == labels).sum().item()\n", "print(\"Initial acc = {:.2f}%\".format((100 * correct/len(inputs)))) " ] }, { "cell_type": "markdown", "id": "8bca18c7", "metadata": {}, "source": [ "## Main Algorithm" ] }, { "cell_type": "code", "execution_count": 6, "id": "632976b3", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "\u001b[33m╔═════ QP SOLVER NOTICE ════════════════════════════════════════════════════════════════════════╗\n", "\u001b[0m\u001b[33m║ PyGRANSO requires a quadratic program (QP) solver that has a quadprog-compatible interface, ║\n", "\u001b[0m\u001b[33m║ the default is osqp. Users may provide their own wrapper for the QP solver. ║\n", "\u001b[0m\u001b[33m║ To disable this notice, set opts.quadprog_info_msg = False ║\n", "\u001b[0m\u001b[33m╚═══════════════════════════════════════════════════════════════════════════════════════════════╝\n", "\u001b[0m═════════════════════════════════════════════════════════════════════════════════════════════════════════════════╗\n", "PyGRANSO: A PyTorch-enabled port of GRANSO with auto-differentiation ║ \n", "Version 1.2.0 ║ \n", "Licensed under the AGPLv3, Copyright (C) 2021-2022 Tim Mitchell and Buyun Liang ║ \n", "═════════════════════════════════════════════════════════════════════════════════════════════════════════════════╣\n", "Problem specifications: ║ \n", " # of variables : 2110 ║ \n", " # of inequality constraints : 0 ║ \n", " # of equality constraints : 1 ║ \n", "═════╦═══════════════════════════╦════════════════╦═════════════════╦═══════════════════════╦════════════════════╣\n", " ║ <--- Penalty Function --> ║ ║ Total Violation ║ <--- Line Search ---> ║ <- Stationarity -> ║ \n", "Iter ║ Mu │ Value ║ Objective ║ Ineq │ Eq ║ SD │ Evals │ t ║ Grads │ Value ║ \n", "═════╬═══════════════════════════╬════════════════╬═════════════════╬═══════════════════════╬════════════════════╣\n", " 0 ║ 1.000000 │ 6.34479003561 ║ 2.28955713337 ║ - │ 4.055233 ║ - │ 1 │ 0.000000 ║ 1 │ 0.727032 ║ \n", " 50 ║ 1.000000 │ 0.44342852064 ║ 0.26057998535 ║ - │ 0.182849 ║ S │ 2 │ 0.500000 ║ 1 │ 1.012217 ║ \n", " 100 ║ 1.000000 │ 0.03502773133 ║ 0.01024156848 ║ - │ 0.024786 ║ S │ 3 │ 0.250000 ║ 1 │ 0.072379 ║ \n", " 150 ║ 1.000000 │ 0.00964913672 ║ 0.00333188716 ║ - │ 0.006317 ║ S │ 4 │ 0.125000 ║ 1 │ 0.460388 ║ \n", " 200 ║ 1.000000 │ 0.00402016055 ║ 0.00218694435 ║ - │ 0.001833 ║ S │ 3 │ 0.250000 ║ 1 │ 0.185655 ║ \n", " 250 ║ 1.000000 │ 0.00259692636 ║ 0.00156567317 ║ - │ 0.001031 ║ S │ 3 │ 0.250000 ║ 1 │ 0.278221 ║ \n", " 300 ║ 1.000000 │ 0.00177902934 ║ 0.00124617387 ║ - │ 5.33e-04 ║ S │ 6 │ 0.031250 ║ 1 │ 0.013047 ║ \n", " 350 ║ 1.000000 │ 0.00142805944 ║ 0.00111832264 ║ - │ 3.10e-04 ║ S │ 10 │ 0.001953 ║ 1 │ 0.377399 ║ \n", " 400 ║ 1.000000 │ 0.00125541329 ║ 0.00104496642 ║ - │ 2.10e-04 ║ S │ 7 │ 0.015625 ║ 1 │ 0.151672 ║ \n", " 450 ║ 1.000000 │ 0.00111109651 ║ 9.8199316e-04 ║ - │ 1.29e-04 ║ S │ 11 │ 9.77e-04 ║ 1 │ 0.239003 ║ \n", " 500 ║ 1.000000 │ 0.00104336348 ║ 9.4179093e-04 ║ - │ 1.02e-04 ║ S │ 13 │ 2.44e-04 ║ 2 │ 0.012828 ║ \n", " 550 ║ 1.000000 │ 9.4210490e-04 ║ 8.5316606e-04 ║ - │ 8.89e-05 ║ S │ 8 │ 0.007812 ║ 1 │ 0.003028 ║ \n", " 600 ║ 1.000000 │ 8.9698591e-04 ║ 7.9341462e-04 ║ - │ 1.04e-04 ║ S │ 6 │ 0.031250 ║ 1 │ 0.228958 ║ \n", " 650 ║ 1.000000 │ 8.3518341e-04 ║ 7.5023136e-04 ║ - │ 8.50e-05 ║ S │ 8 │ 0.007812 ║ 1 │ 0.249420 ║ \n", " 700 ║ 1.000000 │ 7.9166675e-04 ║ 7.3029674e-04 ║ - │ 6.14e-05 ║ S │ 7 │ 0.015625 ║ 1 │ 0.156740 ║ \n", " 750 ║ 1.000000 │ 7.5802236e-04 ║ 7.1762811e-04 ║ - │ 4.04e-05 ║ S │ 8 │ 0.007812 ║ 2 │ 0.003298 ║ \n", " 800 ║ 1.000000 │ 7.4858692e-04 ║ 7.0752411e-04 ║ - │ 4.11e-05 ║ S │ 15 │ 6.10e-05 ║ 2 │ 0.142087 ║ \n", " 850 ║ 1.000000 │ 7.3196260e-04 ║ 6.9034953e-04 ║ - │ 4.16e-05 ║ S │ 6 │ 0.031250 ║ 1 │ 0.671820 ║ \n", " 900 ║ 1.000000 │ 7.1470716e-04 ║ 6.8116904e-04 ║ - │ 3.35e-05 ║ S │ 5 │ 0.062500 ║ 1 │ 0.045029 ║ \n", " 950 ║ 1.000000 │ 6.4488437e-04 ║ 6.0576357e-04 ║ - │ 3.91e-05 ║ S │ 4 │ 0.125000 ║ 1 │ 0.036849 ║ \n", "═════╬═══════════════════════════╬════════════════╬═════════════════╬═══════════════════════╬════════════════════╣\n", " ║ <--- Penalty Function --> ║ ║ Total Violation ║ <--- Line Search ---> ║ <- Stationarity -> ║ \n", "Iter ║ Mu │ Value ║ Objective ║ Ineq │ Eq ║ SD │ Evals │ t ║ Grads │ Value ║ \n", "═════╬═══════════════════════════╬════════════════╬═════════════════╬═══════════════════════╬════════════════════╣\n", "1000 ║ 1.000000 │ 4.8682921e-04 ║ 4.5088613e-04 ║ - │ 3.59e-05 ║ S │ 4 │ 0.125000 ║ 1 │ 0.036078 ║ \n", "═════╩═══════════════════════════╩════════════════╩═════════════════╩═══════════════════════╩════════════════════╣\n", "Optimization results: ║ \n", "F = final iterate, B = Best (to tolerance), MF = Most Feasible ║ \n", "═════╦═══════════════════════════╦════════════════╦═════════════════╦═══════════════════════╦════════════════════╣\n", " F ║ │ ║ 4.5088613e-04 ║ - │ 3.59e-05 ║ │ │ ║ │ ║ \n", " B ║ │ ║ 4.4406815e-04 ║ - │ 8.24e-05 ║ │ │ ║ │ ║ \n", " MF ║ │ ║ 6.7149885e-04 ║ - │ 2.58e-05 ║ │ │ ║ │ ║ \n", "═════╩═══════════════════════════╩════════════════╩═════════════════╩═══════════════════════╩════════════════════╣\n", "Iterations: 1000 ║ \n", "Function evaluations: 6636 ║ \n", "PyGRANSO termination code: 4 --- max iterations reached. ║ \n", "═════════════════════════════════════════════════════════════════════════════════════════════════════════════════╝\n", "Total Wall Time: 56.05571532249451s\n" ] } ], "source": [ "start = time.time()\n", "soln = pygranso(var_spec= model, combined_fn = comb_fn, user_opts = opts)\n", "end = time.time()\n", "print(\"Total Wall Time: {}s\".format(end - start))" ] }, { "cell_type": "markdown", "id": "21bff5fd", "metadata": {}, "source": [ "## Train Accuracy" ] }, { "cell_type": "code", "execution_count": 7, "id": "8d846f87", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Final acc = 100.00%\n", "final feasibility = 3.59430769855918e-05\n" ] } ], "source": [ "torch.nn.utils.vector_to_parameters(soln.final.x, model.parameters())\n", "logits = model(inputs)\n", "_, predicted = torch.max(logits.data, 1)\n", "correct = (predicted == labels).sum().item()\n", "print(\"Final acc = {:.2f}%\".format((100 * correct/len(inputs)))) \n", "print(\"final feasibility = {}\".format(soln.final.tve))" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.7" } }, "nbformat": 4, "nbformat_minor": 5 }