Orthogonal RNN

Train an orthogonal RNN for MNIST classification, based on this paper.

NOTE: this example is still under development.

Problem Description

For each element in the input sequence, each layer computes the following function:

\[h_t=\tanh(W_{ih}x_t+b_{ih}+W_{hh}h_{t-1}+b_{hh})\]

where \(h_t\) is the hidden state at time \(t\), \(x_t\) is the input at time \(t\), and \(h_{t-1}\) is the hidden state of the same layer at time \(t-1\), or the initial hidden state at time \(0\).
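To make the recurrence concrete, the sketch below unrolls it by hand for a single layer and checks the result against `torch.nn.RNN`, whose default nonlinearity is tanh. The sizes match this example; the random input and the batch of 4 are assumptions for illustration only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
input_size, hidden_size, seq_len, batch = 28, 30, 28, 4  # batch of 4 is illustrative

rnn = nn.RNN(input_size, hidden_size, num_layers=1, batch_first=True).double()
x = torch.randn(batch, seq_len, input_size, dtype=torch.double)
h0 = torch.zeros(1, batch, hidden_size, dtype=torch.double)

# Reference result computed by PyTorch
out_ref, h_ref = rnn(x, h0)

# Manual unrolling of h_t = tanh(W_ih x_t + b_ih + W_hh h_{t-1} + b_hh)
W_ih, W_hh = rnn.weight_ih_l0, rnn.weight_hh_l0
b_ih, b_hh = rnn.bias_ih_l0, rnn.bias_hh_l0
h = h0[0]  # (batch, hidden_size)
outs = []
for t in range(seq_len):
    # row-vector convention: x_t W_ih^T is the batched form of W_ih x_t
    h = torch.tanh(x[:, t] @ W_ih.T + b_ih + h @ W_hh.T + b_hh)
    outs.append(h)
out_manual = torch.stack(outs, dim=1)  # (batch, seq_len, hidden_size)

print(torch.allclose(out_manual, out_ref))  # True
```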

For each layer, we impose an orthogonality constraint on the square hidden-to-hidden weight matrix \(W_{hh}\):

\[W_{hh}^T W_{hh} = I\]
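As an illustration of what the constraint demands, the Q factor of a QR decomposition satisfies it while a generic random matrix does not, and an orthogonal \(W_{hh}\) leaves the Euclidean norm of the vector it multiplies unchanged. The random matrices here are stand-ins, not the trained weights.

```python
import torch

torch.manual_seed(0)
n = 30  # hidden_size in this example

# A random square matrix is generally not orthogonal
W_rand = torch.randn(n, n, dtype=torch.double)

# The Q factor of a QR decomposition is orthogonal
W_orth, _ = torch.linalg.qr(W_rand)

I = torch.eye(n, dtype=torch.double)
res_rand = torch.linalg.matrix_norm(W_rand.T @ W_rand - I)  # Frobenius norm, large
res_orth = torch.linalg.matrix_norm(W_orth.T @ W_orth - I)  # round-off level
print(res_rand.item(), res_orth.item())

# Orthogonal matrices preserve the Euclidean norm: ||W h|| = ||h||
h = torch.randn(n, dtype=torch.double)
print(torch.linalg.vector_norm(W_orth @ h).item(), torch.linalg.vector_norm(h).item())
```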

Modules Importing

Import all necessary modules and add the PyGRANSO src folder to the system path.

[1]:
import time
import torch
import sys
## Adding PyGRANSO directories. Should be modified by user
sys.path.append('/home/buyun/Documents/GitHub/PyGRANSO')
from pygranso.pygranso import pygranso
from pygranso.pygransoStruct import pygransoStruct
from pygranso.private.getNvar import getNvarTorch
import torch.nn as nn
from torchvision import datasets
from torchvision.transforms import ToTensor
from pygranso.private.getObjGrad import getObjGradDL

Data Initialization

Specify torch device, neural network architecture, and generate data.

NOTE: please specify the path for downloading the data (the root argument of datasets.MNIST).

This problem runs on a GPU. If no CUDA device is available, set device = torch.device('cpu').

[2]:
device = torch.device('cuda')

sequence_length = 28
input_size = 28
hidden_size = 30
num_layers = 1
num_classes = 10
batch_size = 100


double_precision = torch.double

class RNN(nn.Module):

    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(RNN, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # vanilla (tanh) RNN; its weight_hh_l0 is the W_hh that carries the orthogonality constraint
        self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        # treat each 28x28 image as a sequence of 28 rows with 28 pixels each
        x = torch.reshape(x, (-1, sequence_length, input_size))
        # initial hidden state h_0 = 0 (a vanilla RNN has no cell state)
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device=device, dtype=double_precision)
        out, hidden = self.rnn(x, h0)  # out: tensor of shape (batch_size, seq_length, hidden_size)
        # classify from the output at the last time step
        out = self.fc(out[:, -1, :])
        return out

torch.manual_seed(0)

model = RNN(input_size, hidden_size, num_layers, num_classes).to(device=device, dtype=double_precision)
model.train()

train_data = datasets.MNIST(
    root = '/home/buyun/Documents/GitHub/PyGRANSO/examples/data/mnist',
    train = True,
    transform = ToTensor(),
    download = True,
)

loaders = {
    'train' : torch.utils.data.DataLoader(train_data,
                                        batch_size=batch_size,
                                        shuffle=True,
                                        num_workers=1),
}

inputs, labels = next(iter(loaders['train']))
inputs, labels = inputs.reshape(-1, sequence_length, input_size).to(device=device, dtype=double_precision), labels.to(device=device)
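The choice sequence_length = input_size = 28 means each MNIST image is read row by row: time step \(t\) feeds pixel row \(t\) to the RNN. A small sketch of that reshape, using a random tensor as a stand-in for one batch from the DataLoader:

```python
import torch

sequence_length, input_size, batch = 28, 28, 100

# Stand-in for one MNIST batch: (batch, channels, height, width)
images = torch.rand(batch, 1, 28, 28)

# Each image becomes 28 time steps; step t is pixel row t (28 values)
seq = images.reshape(-1, sequence_length, input_size)
print(seq.shape)                                # torch.Size([100, 28, 28])
print(torch.equal(seq[0, 5], images[0, 0, 5]))  # True: x_5 is row 5 of image 0
```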

Function Set-Up

Encode the optimization variables, and objective and constraint functions.

Note: please strictly follow the format of comb_fn, which is used in the PyGRANSO main algorithm.

[3]:
def user_fn(model,inputs,labels):
    # objective function
    logits = model(inputs)
    criterion = nn.CrossEntropyLoss()
    f = criterion(logits, labels)

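    # recurrent weight matrix W_hh (model.rnn.weight_hh_l0), the second tensor in model.parameters()
    A = list(model.parameters())[1]

    # inequality constraint
    ci = None

    # equality constraint
    # orthogonal group: A^T A = I
    # (adding ce.c2 below would further restrict A to the special orthogonal group, det(A) = 1)

    ce = pygransoStruct()

    c1_vec = (A.T @ A
              - torch.eye(hidden_size)
              .to(device=device, dtype=double_precision)
             ).reshape(1,-1)

    ce.c1 = torch.linalg.vector_norm(c1_vec,2) # l2 folding: combine the hidden_size**2 equality constraints into one scalar
    # ce.c2 = torch.det(A) - 1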

    # ce = None

    return [f,ci,ce]

comb_fn = lambda model : user_fn(model,inputs,labels)
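The l2 folding in user_fn replaces the hidden_size² scalar equations \((W_{hh}^T W_{hh} - I)_{ij} = 0\) with a single one: the 2-norm of the flattened residual is zero exactly when every entry is zero, and it equals the Frobenius norm of the residual matrix. A minimal sketch on CPU, where `folded_violation` is a hypothetical helper mirroring ce.c1:

```python
import torch

torch.manual_seed(0)
n = 30
I = torch.eye(n, dtype=torch.double)
A, _ = torch.linalg.qr(torch.randn(n, n, dtype=torch.double))  # orthogonal

def folded_violation(A):
    # hypothetical helper: same computation as ce.c1 in user_fn
    c_vec = (A.T @ A - I).reshape(1, -1)
    return torch.linalg.vector_norm(c_vec, 2)  # dim=None: norm over all entries

A_pert = A + 1e-2 * torch.randn(n, n, dtype=torch.double)
print(folded_violation(A).item())       # round-off level: constraint satisfied
print(folded_violation(A_pert).item())  # clearly positive: constraint violated

# The folded value equals the Frobenius norm of the residual matrix
print(torch.allclose(folded_violation(A_pert),
                     torch.linalg.matrix_norm(A_pert.T @ A_pert - I)))  # True
```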

User Options

Specify user-defined options for PyGRANSO.

[4]:
opts = pygransoStruct()
opts.torch_device = device
nvar = getNvarTorch(model.parameters())
opts.x0 = torch.nn.utils.parameters_to_vector(model.parameters()).detach().reshape(nvar,1)
opts.opt_tol = 1e-3
opts.viol_eq_tol = 1e-4
# opts.maxit = 150
# opts.fvalquit = 1e-6
opts.print_level = 1
opts.print_frequency = 50
# opts.print_ascii = True
# opts.limited_mem_size = 100
opts.double_precision = True

opts.mu0 = 1

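The initial point opts.x0 is simply the model's parameters flattened in the order model.parameters() yields them, which is also why list(model.parameters())[1] in user_fn selects \(W_{hh}\). The sketch below uses a trimmed-down CPU copy of the same architecture (no forward pass) as an assumption for illustration, and shows that writing the vector back with vector_to_parameters leaves the parameters unchanged.

```python
import torch
import torch.nn as nn

# Same layers as the example: 1-layer RNN (28 -> 30) plus a 30 -> 10 linear head
class RNN(nn.Module):
    def __init__(self, input_size=28, hidden_size=30, num_classes=10):
        super().__init__()
        self.rnn = nn.RNN(input_size, hidden_size, 1, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

model = RNN().double()

# Registration order fixes the layout of the flattened vector
for name, p in model.named_parameters():
    print(name, tuple(p.shape))

x0 = torch.nn.utils.parameters_to_vector(model.parameters()).detach().reshape(-1, 1)
print(x0.shape)  # torch.Size([2110, 1]), matching "# of variables" in the log

# Round trip: writing x0 back leaves the parameters unchanged
w_hh_before = model.rnn.weight_hh_l0.detach().clone()
torch.nn.utils.vector_to_parameters(x0, model.parameters())
print(torch.equal(model.rnn.weight_hh_l0.detach(), w_hh_before))  # True
print(list(model.parameters())[1] is model.rnn.weight_hh_l0)      # True
```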
Initial Test

Check the initial accuracy of the RNN model on the single training batch of 100 images used throughout this example.

[5]:
logits = model(inputs)
_, predicted = torch.max(logits.data, 1)
correct = (predicted == labels).sum().item()
print("Initial acc = {:.2f}%".format((100 * correct/len(inputs))))
Initial acc = 10.00%

Main Algorithm

[6]:
start = time.time()
soln = pygranso(var_spec= model, combined_fn = comb_fn, user_opts = opts)
end = time.time()
print("Total Wall Time: {}s".format(end - start))


╔═════ QP SOLVER NOTICE ════════════════════════════════════════════════════════════════════════╗
║  PyGRANSO requires a quadratic program (QP) solver that has a quadprog-compatible interface,  ║
║  the default is osqp. Users may provide their own wrapper for the QP solver.                  ║
║  To disable this notice, set opts.quadprog_info_msg = False                                   ║
╚═══════════════════════════════════════════════════════════════════════════════════════════════╝
═════════════════════════════════════════════════════════════════════════════════════════════════════════════════╗
PyGRANSO: A PyTorch-enabled port of GRANSO with auto-differentiation                                             ║
Version 1.2.0                                                                                                    ║
Licensed under the AGPLv3, Copyright (C) 2021-2022 Tim Mitchell and Buyun Liang                                  ║
═════════════════════════════════════════════════════════════════════════════════════════════════════════════════╣
Problem specifications:                                                                                          ║
 # of variables                     :   2110                                                                     ║
 # of inequality constraints        :      0                                                                     ║
 # of equality constraints          :      1                                                                     ║
═════╦═══════════════════════════╦════════════════╦═════════════════╦═══════════════════════╦════════════════════╣
     ║ <--- Penalty Function --> ║                ║ Total Violation ║ <--- Line Search ---> ║ <- Stationarity -> ║
Iter ║    Mu    │      Value     ║    Objective   ║ Ineq │    Eq    ║ SD │ Evals │     t    ║ Grads │    Value   ║
═════╬═══════════════════════════╬════════════════╬═════════════════╬═══════════════════════╬════════════════════╣
   0 ║ 1.000000 │  6.34479003561 ║  2.28955713337 ║   -  │ 4.055233 ║ -  │     1 │ 0.000000 ║     1 │ 0.727032   ║
  50 ║ 1.000000 │  0.44342852064 ║  0.26057998535 ║   -  │ 0.182849 ║ S  │     2 │ 0.500000 ║     1 │ 1.012217   ║
 100 ║ 1.000000 │  0.03502773133 ║  0.01024156848 ║   -  │ 0.024786 ║ S  │     3 │ 0.250000 ║     1 │ 0.072379   ║
 150 ║ 1.000000 │  0.00964913672 ║  0.00333188716 ║   -  │ 0.006317 ║ S  │     4 │ 0.125000 ║     1 │ 0.460388   ║
 200 ║ 1.000000 │  0.00402016055 ║  0.00218694435 ║   -  │ 0.001833 ║ S  │     3 │ 0.250000 ║     1 │ 0.185655   ║
 250 ║ 1.000000 │  0.00259692636 ║  0.00156567317 ║   -  │ 0.001031 ║ S  │     3 │ 0.250000 ║     1 │ 0.278221   ║
 300 ║ 1.000000 │  0.00177902934 ║  0.00124617387 ║   -  │ 5.33e-04 ║ S  │     6 │ 0.031250 ║     1 │ 0.013047   ║
 350 ║ 1.000000 │  0.00142805944 ║  0.00111832264 ║   -  │ 3.10e-04 ║ S  │    10 │ 0.001953 ║     1 │ 0.377399   ║
 400 ║ 1.000000 │  0.00125541329 ║  0.00104496642 ║   -  │ 2.10e-04 ║ S  │     7 │ 0.015625 ║     1 │ 0.151672   ║
 450 ║ 1.000000 │  0.00111109651 ║  9.8199316e-04 ║   -  │ 1.29e-04 ║ S  │    11 │ 9.77e-04 ║     1 │ 0.239003   ║
 500 ║ 1.000000 │  0.00104336348 ║  9.4179093e-04 ║   -  │ 1.02e-04 ║ S  │    13 │ 2.44e-04 ║     2 │ 0.012828   ║
 550 ║ 1.000000 │  9.4210490e-04 ║  8.5316606e-04 ║   -  │ 8.89e-05 ║ S  │     8 │ 0.007812 ║     1 │ 0.003028   ║
 600 ║ 1.000000 │  8.9698591e-04 ║  7.9341462e-04 ║   -  │ 1.04e-04 ║ S  │     6 │ 0.031250 ║     1 │ 0.228958   ║
 650 ║ 1.000000 │  8.3518341e-04 ║  7.5023136e-04 ║   -  │ 8.50e-05 ║ S  │     8 │ 0.007812 ║     1 │ 0.249420   ║
 700 ║ 1.000000 │  7.9166675e-04 ║  7.3029674e-04 ║   -  │ 6.14e-05 ║ S  │     7 │ 0.015625 ║     1 │ 0.156740   ║
 750 ║ 1.000000 │  7.5802236e-04 ║  7.1762811e-04 ║   -  │ 4.04e-05 ║ S  │     8 │ 0.007812 ║     2 │ 0.003298   ║
 800 ║ 1.000000 │  7.4858692e-04 ║  7.0752411e-04 ║   -  │ 4.11e-05 ║ S  │    15 │ 6.10e-05 ║     2 │ 0.142087   ║
 850 ║ 1.000000 │  7.3196260e-04 ║  6.9034953e-04 ║   -  │ 4.16e-05 ║ S  │     6 │ 0.031250 ║     1 │ 0.671820   ║
 900 ║ 1.000000 │  7.1470716e-04 ║  6.8116904e-04 ║   -  │ 3.35e-05 ║ S  │     5 │ 0.062500 ║     1 │ 0.045029   ║
 950 ║ 1.000000 │  6.4488437e-04 ║  6.0576357e-04 ║   -  │ 3.91e-05 ║ S  │     4 │ 0.125000 ║     1 │ 0.036849   ║
═════╬═══════════════════════════╬════════════════╬═════════════════╬═══════════════════════╬════════════════════╣
     ║ <--- Penalty Function --> ║                ║ Total Violation ║ <--- Line Search ---> ║ <- Stationarity -> ║
Iter ║    Mu    │      Value     ║    Objective   ║ Ineq │    Eq    ║ SD │ Evals │     t    ║ Grads │    Value   ║
═════╬═══════════════════════════╬════════════════╬═════════════════╬═══════════════════════╬════════════════════╣
1000 ║ 1.000000 │  4.8682921e-04 ║  4.5088613e-04 ║   -  │ 3.59e-05 ║ S  │     4 │ 0.125000 ║     1 │ 0.036078   ║
═════╩═══════════════════════════╩════════════════╩═════════════════╩═══════════════════════╩════════════════════╣
Optimization results:                                                                                            ║
F = final iterate, B = Best (to tolerance), MF = Most Feasible                                                   ║
═════╦═══════════════════════════╦════════════════╦═════════════════╦═══════════════════════╦════════════════════╣
   F ║          │                ║  4.5088613e-04 ║   -  │ 3.59e-05 ║    │       │          ║       │            ║
   B ║          │                ║  4.4406815e-04 ║   -  │ 8.24e-05 ║    │       │          ║       │            ║
  MF ║          │                ║  6.7149885e-04 ║   -  │ 2.58e-05 ║    │       │          ║       │            ║
═════╩═══════════════════════════╩════════════════╩═════════════════╩═══════════════════════╩════════════════════╣
Iterations:              1000                                                                                    ║
Function evaluations:    6636                                                                                    ║
PyGRANSO termination code: 4 --- max iterations reached.                                                         ║
═════════════════════════════════════════════════════════════════════════════════════════════════════════════════╝
Total Wall Time: 56.05571532249451s
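The log can be sanity-checked by hand: with no inequality constraints and one equality constraint, each logged penalty value agrees, up to the printed precision, with Mu × Objective + Eq violation, consistent with a penalty function of the form \(\phi = \mu f + v\). The rows below are copied from the table; the tolerance is an assumption chosen to absorb the rounding of the printed violation.

```python
# (iteration, mu, penalty value, objective, equality violation), copied from the log
rows = [
    (0,    1.0, 6.34479003561,  2.28955713337,  4.055233),
    (50,   1.0, 0.44342852064,  0.26057998535,  0.182849),
    (100,  1.0, 0.03502773133,  0.01024156848,  0.024786),
    (1000, 1.0, 4.8682921e-04,  4.5088613e-04,  3.59e-05),
]
for it, mu, penalty, f, viol in rows:
    reconstructed = mu * f + viol
    # violations are printed with only a few significant digits, so compare loosely
    assert abs(penalty - reconstructed) < 1e-5, it
    print(it, penalty, reconstructed)
```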

Train Accuracy

[7]:
torch.nn.utils.vector_to_parameters(soln.final.x, model.parameters())
logits = model(inputs)
_, predicted = torch.max(logits.data, 1)
correct = (predicted == labels).sum().item()
print("Final acc = {:.2f}%".format((100 * correct/len(inputs))))
print("final feasibility = {}".format(soln.final.tve))
Final acc = 100.00%
final feasibility = 3.59430769855918e-05
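Besides the total violation reported by PyGRANSO, orthogonality of a weight matrix can be inspected through its singular values, which are all exactly 1 for an orthogonal matrix. The sketch uses a synthetic nearly orthogonal matrix as a stand-in (the perturbation size is an assumption); after training, model.rnn.weight_hh_l0 would take its place.

```python
import torch

torch.manual_seed(0)
n = 30
Q, _ = torch.linalg.qr(torch.randn(n, n, dtype=torch.double))
# Stand-in for a trained, nearly orthogonal W_hh
W = Q + 1e-6 * torch.randn(n, n, dtype=torch.double)

residual = torch.linalg.matrix_norm(W.T @ W - torch.eye(n, dtype=torch.double))
sv = torch.linalg.svdvals(W)
print(f"||W^T W - I||_F = {residual.item():.2e}")
print(f"singular values in [{sv.min().item():.6f}, {sv.max().item():.6f}]")
```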