# Orthogonal RNN

Train an orthogonal RNN for MNIST classification based on [this paper](https://arxiv.org/pdf/1901.08428.pdf) " ] }, { "cell_type": "markdown", "id": "c859c154", "metadata": {}, "source": [ "## Problem Description" ] }, { "cell_type": "markdown", "id": "b96269c7", "metadata": {}, "source": [ "For each element in the input sequence, each layer computes the following function:\n", "$$h_t=\\tanh(W_{ih}x_t+b_{ih}+W_{hh}h_{t-1}+b_{hh})$$\n", "\n", "where $h_{t}$ is the hidden state at time $t$, and $h_{t-1}$ is the hidden state of the same layer at time $t-1$, or the initial hidden state at time $0$. \n", "\n", "For each layer, we have the orthogonality constraint:\n", "$$ W_{hh}^T W_{hh} = I $$" ] }, { "cell_type": "markdown", "id": "08dfdd50", "metadata": {}, "source": [ "## Modules Importing\n", "Import all necessary modules and add the PyGRANSO src folder to the system path. " ] }, { "cell_type": "code", "execution_count": 1, "id": "90ed32f9", "metadata": {}, "outputs": [], "source": [ "import time\n", "import torch\n", "import sys\n", "## Adding PyGRANSO directories. Should be modified by user\n", "sys.path.append('/home/buyun/Documents/GitHub/PyGRANSO')\n", "from pygranso.pygranso import pygranso\n", "from pygranso.pygransoStruct import pygransoStruct \n", "from pygranso.private.getNvar import getNvarTorch\n", "import torch.nn as nn\n", "from torchvision import datasets\n", "from torchvision.transforms import ToTensor\n", "from pygranso.private.getObjGrad import getObjGradDL" ] }, { "cell_type": "markdown", "id": "17a1b7fe", "metadata": {}, "source": [ "## Data Initialization \n", "Specify torch device, neural network architecture, and generate data.\n", "\n", "NOTE: please specify the path for downloading data.\n", "\n", "Use GPU for this problem. 
If no CUDA device is available, please set *device = torch.device('cpu')*" ] }, { "cell_type": "code", "execution_count": 2, "id": "8b4842e1", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/buyun/anaconda3/envs/cuosqp_pygranso/lib/python3.9/site-packages/torchvision/datasets/mnist.py:498: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /opt/conda/conda-bld/pytorch_1623448255797/work/torch/csrc/utils/tensor_numpy.cpp:180.)\n", " return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)\n" ] } ], "source": [ "device = torch.device('cuda')\n", "\n", "sequence_length = 28\n", "input_size = 28\n", "hidden_size = 30\n", "num_layers = 1\n", "num_classes = 10\n", "batch_size = 100\n", "\n", "\n", "double_precision = torch.double\n", "\n", "class RNN(nn.Module):\n", " \n", " def __init__(self, input_size, hidden_size, num_layers, num_classes):\n", " super(RNN, self).__init__()\n", " self.hidden_size = hidden_size\n", " self.num_layers = num_layers\n", " # self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)\n", " self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)\n", " self.fc = nn.Linear(hidden_size, num_classes)\n", " pass\n", " \n", " def forward(self, x):\n", " x = torch.reshape(x,(batch_size,sequence_length,input_size))\n", " # Set initial hidden state (a plain RNN has no cell state)\n", " h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device=device, dtype=double_precision)\n", " out, hidden = self.rnn(x, h0) # out: tensor of shape (batch_size, seq_length, hidden_size)\n", " #Reshaping the outputs such that it can be fit into 
the fully connected layer\n", " out = self.fc(out[:, -1, :])\n", " return out\n", " \n", "torch.manual_seed(0)\n", "\n", "model = RNN(input_size, hidden_size, num_layers, num_classes).to(device=device, dtype=double_precision)\n", "model.train()\n", "\n", "train_data = datasets.MNIST(\n", " root = '/home/buyun/Documents/GitHub/PyGRANSO/examples/data/mnist',\n", " train = True, \n", " transform = ToTensor(), \n", " download = True, \n", ") \n", "\n", "loaders = {\n", " 'train' : torch.utils.data.DataLoader(train_data, \n", " batch_size=100, \n", " shuffle=True, \n", " num_workers=1),\n", "}\n", "\n", "inputs, labels = next(iter(loaders['train']))\n", "inputs, labels = inputs.reshape(-1, sequence_length, input_size).to(device=device, dtype=double_precision), labels.to(device=device)" ] }, { "cell_type": "markdown", "id": "ec80716b", "metadata": {}, "source": [ "## Function Set-Up\n", "\n", "Encode the optimization variables, and objective and constraint functions.\n", "\n", "Note: please strictly follow the format of comb_fn, which will be used in the PyGRANSO main algorithm." ] }, { "cell_type": "code", "execution_count": 3, "id": "fb360e75", "metadata": {}, "outputs": [], "source": [ "def user_fn(model,inputs,labels):\n", " # objective function \n", " logits = model(inputs)\n", " criterion = nn.CrossEntropyLoss()\n", " f = criterion(logits, labels)\n", "\n", " A = list(model.parameters())[1] # hidden-to-hidden weights W_hh (rnn.weight_hh_l0)\n", "\n", " # inequality constraint\n", " ci = None\n", "\n", " # equality constraint \n", " # special orthogonal group\n", " \n", " ce = pygransoStruct()\n", "\n", " c1_vec = (A.T @ A \n", " - torch.eye(hidden_size)
 .to(device=device, dtype=double_precision)
 ).reshape(1,-1)
 
 ce.c1 = torch.linalg.vector_norm(c1_vec,2) # l2 folding to reduce the total number of constraints
 # ce.c2 = torch.det(A) - 1

 # ce = None

 return [f,ci,ce]

comb_fn = lambda model : user_fn(model,inputs,labels)

║ Users may provide their own wrapper for the QP solver. ║
║ To disable this notice, set opts.quadprog_info_msg = False ║
╚═══════════════════════════════════════════════════════════════════════════════════════════════╝
═════════════════════════════════════════════════════════════════════════════════════════════════════════════════╗
PyGRANSO: A PyTorch-enabled port of GRANSO with auto-differentiation ║ 
Version 1.2.0 ║ 
Licensed under the AGPLv3, Copyright (C) 2021-2022 Tim Mitchell and Buyun Liang ║ 
═════════════════════════════════════════════════════════════════════════════════════════════════════════════════╣
Problem specifications: ║ 
 # of variables : 2110 ║ 
 # of inequality constraints : 0 ║ 
 # of equality constraints : 1 ║ 
═════╦═══════════════════════════╦════════════════╦═════════════════╦═══════════════════════╦════════════════════╣
 ║ <--- Penalty Function --> ║ ║ Total Violation ║ <--- Line Search ---> ║ <- Stationarity -> ║ 
Iter ║ Mu │ Value ║ Objective ║ Ineq │ Eq ║ SD │ Evals │ t ║ Grads │ Value ║ 
═════╬═══════════════════════════╬════════════════╬═════════════════╬═══════════════════════╬════════════════════╣
 ... │ 4.055233 ║ - │ 1 │ 0.000000 ║ 1 │ 0.727032 ║ 
 ... │ 1.02e-04 ║ S │ 13 │ 2.44e-04 ║ 2 │ 0.012828 ║ 
 550 ║ 1.000000 │ 9.4210490e-04 ║ 8.5316606e-04 ║ - │ 8.89e-05 ║ S │ 8 │ 0.007812 ║ 1 │ 0.003028 ║ 
 600 ║ 1.000000 │ 8.9698591e-04 ║ 7.9341462e-04 ║ - │ 1.04e-04 ║ S │ 6 │ 0.031250 ║ 1 │ 0.228958 ║ 
 650 ║ 1.000000 │ 8.3518341e-04 ║ 7.5023136e-04 ║ - │ 8.50e-05 ║ S │ 8 │ 0.007812 ║ 1 │ 0.249420 ║ 
 700 ║ 1.000000 │ 7.9166675e-04 ║ 7.3029674e-04 ║ - │ 6.14e-05 ║ S │ 7 │ 0.015625 ║ 1 │ 0.156740 ║ 
 750 ║ 1.000000 │ 7.5802236e-04 ║ 7.1762811e-04 ║ - │ 4.04e-05 ║ S │ 8 │ 0.007812 ║ 2 │ 0.003298 ║ 
 800 ║ 1.000000 │ 7.4858692e-04 ║ 7.0752411e-04 ║ - │ 4.11e-05 ║ S │ 15 │ 6.10e-05 ║ 2 │ 0.142087 ║ 
 850 ║ 1.000000 │ 7.3196260e-04 ║ 6.9034953e-04 ║ - │ 4.16e-05 ║ S │ 6 │ 0.031250 ║ 1 │ 0.671820 ║ 
 900 ║ 1.000000 │ 7.1470716e-04 ║ 6.8116904e-04 ║ - │ 3.35e-05 ║ S │ 5 │ 0.062500 ║ 1 │ 0.045029 ║ 
 950 ║ 1.000000 │ 6.4488437e-04 ║ 6.0576357e-04 ║ - │ 3.91e-05 ║ S │ 4 │ 0.125000 ║ 1 │ 0.036849 ║ 
═════╬═══════════════════════════╬════════════════╬═════════════════╬═══════════════════════╬════════════════════╣
 ║ <--- Penalty Function --> ║ ║ Total Violation ║ <--- Line Search ---> ║ <- Stationarity -> ║ 
Iter ║ Mu │ Value ║ Objective ║ Ineq │ Eq ║ SD │ Evals │ t ║ Grads │ Value ║ 
═════╬═══════════════════════════╬════════════════╬═════════════════╬═══════════════════════╬════════════════════╣
1000 ║ 1.000000 │ 4.8682921e-04 ║ 4.5088613e-04 ║ - │ 3.59e-05 ║ S │ 4 │ 0.125000 ║ 1 │ 0.036078 ║ 
═════╩═══════════════════════════╩════════════════╩═════════════════╩═══════════════════════╩════════════════════╣
Optimization results: ║ 
F = final iterate, B = Best (to tolerance), MF = Most Feasible ║ 
═════╦═══════════════════════════╦════════════════╦═════════════════╦═══════════════════════╦════════════════════╣
 F ║ │ ║ 4.5088613e-04 ║ - │ 3.59e-05 ║ │ │ ║ │ ║ 
 B ║ │ ║ 4.4406815e-04 ║ - │ 8.24e-05 ║ │ │ ║ │ ║ 
 MF ║ │ ║ 6.7149885e-04 ║ - │ 2.58e-05 ║ │ │ ║ │ ║ 
═════╩═══════════════════════════╩════════════════╩═════════════════╩═══════════════════════╩════════════════════╣
Iterations: 1000 ║ 
Function evaluations: 6636 ║ 
PyGRANSO termination code: 4 --- max iterations reached. ║ 
═════════════════════════════════════════════════════════════════════════════════════════════════════════════════╝
Total Wall Time: 56.05571532249451s 
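
The problem specifications reported in the log above (2110 variables, 1 equality constraint) can be checked by hand. The sketch below is a dependency-free illustration in plain Python rather than PyTorch; the helper names `rnn_classifier_num_params` and `orthogonality_violation` are hypothetical and not part of PyGRANSO. It assumes the single-layer `nn.RNN` plus `nn.Linear` architecture defined earlier, and it treats the folded constraint `ce.c1` as the Frobenius norm of $W_{hh}^T W_{hh} - I$, which is what the 2-norm of the flattened matrix amounts to.

```python
import math


def rnn_classifier_num_params(input_size, hidden_size, num_classes):
    """Count the parameters of a one-layer RNN followed by a linear classifier."""
    rnn = hidden_size * input_size      # W_ih
    rnn += hidden_size * hidden_size    # W_hh
    rnn += 2 * hidden_size              # b_ih and b_hh
    fc = num_classes * hidden_size + num_classes  # linear weight and bias
    return rnn + fc


def orthogonality_violation(A):
    """Frobenius norm of A^T A - I for a square matrix given as nested lists."""
    n = len(A)
    total = 0.0
    for i in range(n):
        for j in range(n):
            # (A^T A)[i][j] is the dot product of columns i and j of A
            gram_ij = sum(A[k][i] * A[k][j] for k in range(n))
            target = 1.0 if i == j else 0.0
            total += (gram_ij - target) ** 2
    return math.sqrt(total)


# Sizes taken from the data-initialization cell of this notebook
n_vars = rnn_classifier_num_params(input_size=28, hidden_size=30, num_classes=10)

# A 2x2 rotation is orthogonal, so its violation is zero up to rounding
theta = 0.3
rotation = [[math.cos(theta), -math.sin(theta)],
            [math.sin(theta), math.cos(theta)]]
rot_violation = orthogonality_violation(rotation)

# 2*I is not orthogonal: A^T A - I = 3*I, whose Frobenius norm is sqrt(18)
scaled_violation = orthogonality_violation([[2.0, 0.0], [0.0, 2.0]])
```

Folding all `hidden_size * hidden_size` entries of the residual into one scalar is why the solver reports a single equality constraint; the scalar is zero exactly when every entry of the residual is zero.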