Orthogonal RNN
Train an orthogonal RNN for MNIST classification, based on this paper.
NOTE: this example is still under development.
Problem Description
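For each element in the input sequence, each layer computes the following function:
\[ h_{t} = \tanh\left(W_{ih} x_{t} + b_{ih} + W_{hh} h_{t-1} + b_{hh}\right) \]
where \(h_{t}\) is the hidden state at time \(t\), \(x_{t}\) is the input at time \(t\), \(h_{t-1}\) is the hidden state at time \(t-1\) (the initial hidden state at time \(0\) for the first step), and \(W_{ih}, b_{ih}\) and \(W_{hh}, b_{hh}\) are the input-to-hidden and hidden-to-hidden weights and biases.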
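To make the recurrence concrete, the sketch below unrolls \(h_t = \tanh(W_{ih} x_t + b_{ih} + W_{hh} h_{t-1} + b_{hh})\) by hand. This is the update a single-layer `nn.RNN` computes with its default tanh nonlinearity. The sketch then compares the result with PyTorch's own forward pass. The small sizes are illustrative only and are not the sizes used later in this example.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch, seq_len, n_in, n_hidden = 4, 5, 3, 6   # illustrative sizes only

rnn = nn.RNN(n_in, n_hidden, num_layers=1, batch_first=True).double()
x = torch.randn(batch, seq_len, n_in, dtype=torch.double)
h0 = torch.zeros(1, batch, n_hidden, dtype=torch.double)

# PyTorch's built-in forward pass
out_ref, h_ref = rnn(x, h0)

# Manual unrolling of h_t = tanh(W_ih x_t + b_ih + W_hh h_{t-1} + b_hh)
W_ih, W_hh = rnn.weight_ih_l0, rnn.weight_hh_l0
b_ih, b_hh = rnn.bias_ih_l0, rnn.bias_hh_l0
h = h0[0]                      # shape (batch, n_hidden)
outputs = []
for t in range(seq_len):
    # Row-vector form: x_t @ W_ih^T applies W_ih to every sample in the batch
    h = torch.tanh(x[:, t] @ W_ih.T + b_ih + h @ W_hh.T + b_hh)
    outputs.append(h)
out_manual = torch.stack(outputs, dim=1)   # (batch, seq_len, n_hidden)

print(torch.allclose(out_manual, out_ref))  # expected: True
```

In the row-vector form, which is how PyTorch stores batched data, `x_t @ W_ih.T` is the same as \(W_{ih} x_t\) applied to each sample.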
For each layer, we impose the orthogonality constraint
\[ W_{hh}^{\top} W_{hh} = I \]
on the hidden-to-hidden weight matrix \(W_{hh}\).
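As a quick sanity check of what the constraint measures, the sketch below computes the l2-folded residual \(\|\mathrm{vec}(W^{\top}W - I)\|_2\), which is the Frobenius norm of \(W^{\top}W - I\). This is the same quantity the objective code uses as its single equality constraint. The sketch evaluates it for an orthogonal matrix (the Q factor of a QR decomposition) and for a random Gaussian matrix. The helper name `orth_residual` is hypothetical and introduced only for this illustration.

```python
import torch

def orth_residual(W: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: l2 norm of all entries of W^T W - I (one scalar)."""
    n = W.shape[1]
    I = torch.eye(n, dtype=W.dtype, device=W.device)
    return torch.linalg.vector_norm((W.T @ W - I).reshape(-1), 2)

torch.manual_seed(0)
n = 30  # same as hidden_size in this example
G = torch.randn(n, n, dtype=torch.double)
Q, _ = torch.linalg.qr(G)      # Q has orthonormal columns, so Q^T Q = I

print(orth_residual(Q))   # round-off level, i.e. essentially 0
print(orth_residual(G))   # large: a random matrix is far from orthogonal
```

Folding the \(30^2 = 900\) entrywise equations into one norm leaves the feasible set unchanged, because the norm is zero exactly when every entry is zero. It only reduces the number of constraints the solver sees to one.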
Importing Modules
Import all necessary modules and add the PyGRANSO src folder to the system path.
[1]:
import time
import torch
import sys
## Adding PyGRANSO directories. Should be modified by user
sys.path.append('/home/buyun/Documents/GitHub/PyGRANSO')
from pygranso.pygranso import pygranso
from pygranso.pygransoStruct import pygransoStruct
from pygranso.private.getNvar import getNvarTorch
import torch.nn as nn
from torchvision import datasets
from torchvision.transforms import ToTensor
from pygranso.private.getObjGrad import getObjGradDL
Data Initialization
Specify the torch device and the neural network architecture, and generate the data.
NOTE: please specify the path for downloading the data.
This problem uses a GPU. If no CUDA device is available, set device = torch.device('cpu').
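To run the notebook on either kind of machine without editing it, you can pick the device at runtime with `torch.cuda.is_available()`. This is a convenience sketch and not part of the original example.

```python
import torch

# Fall back to the CPU when no CUDA device is visible to PyTorch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
```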
[2]:
device = torch.device('cuda')
sequence_length = 28
input_size = 28
hidden_size = 30
num_layers = 1
num_classes = 10
batch_size = 100
double_precision = torch.double
class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(RNN, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        # Read each 28x28 image as a sequence of 28 rows (time steps) of 28 features
        x = torch.reshape(x, (-1, sequence_length, input_size))
        # Initial hidden state (a plain RNN has no cell state)
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device=x.device, dtype=x.dtype)
        out, hidden = self.rnn(x, h0)  # out: tensor of shape (batch_size, seq_length, hidden_size)
        # Feed the hidden state of the last time step into the fully connected layer
        out = self.fc(out[:, -1, :])
        return out
torch.manual_seed(0)
model = RNN(input_size, hidden_size, num_layers, num_classes).to(device=device, dtype=double_precision)
model.train()
train_data = datasets.MNIST(
root = '/home/buyun/Documents/GitHub/PyGRANSO/examples/data/mnist',
train = True,
transform = ToTensor(),
download = True,
)
loaders = {
'train' : torch.utils.data.DataLoader(train_data,
batch_size=batch_size,
shuffle=True,
num_workers=1),
}
inputs, labels = next(iter(loaders['train']))
inputs, labels = inputs.reshape(-1, sequence_length, input_size).to(device=device, dtype=double_precision), labels.to(device=device)
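The reshape in the cell above treats every image as a sequence. `ToTensor()` yields images of shape (1, 28, 28), and reshaping a batch to (batch, 28, 28) makes row \(t\) of each image the input \(x_t\) at time step \(t\). The sketch uses a random tensor as a stand-in for one batch, so it runs without downloading MNIST.

```python
import torch

batch_size, sequence_length, input_size = 100, 28, 28

# Stand-in for one DataLoader batch: ToTensor() yields images of shape (1, 28, 28)
images = torch.rand(batch_size, 1, 28, 28)

# Same reshape as in the data-loading cell: one time step per image row
seq = images.reshape(-1, sequence_length, input_size)
print(seq.shape)  # torch.Size([100, 28, 28])

# Time step t of sample i is exactly row t of image i
t, i = 5, 3
print(torch.equal(seq[i, t], images[i, 0, t]))  # True
```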
Function Set-Up
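Encode the optimization variables and the objective and constraint functions.
Note: please strictly follow the format of comb_fn, which is used by the PyGRANSO main algorithm. The user function returns [f, ci, ce]: the objective, the inequality constraints, and the equality constraints, with None for an absent constraint type.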
[3]:
def user_fn(model, inputs, labels):
    # objective function: cross-entropy loss on the training batch
    logits = model(inputs)
    criterion = nn.CrossEntropyLoss()
    f = criterion(logits, labels)

    # hidden-to-hidden weight matrix W_hh (model.rnn.weight_hh_l0)
    A = list(model.parameters())[1]

    # inequality constraint
    ci = None

    # equality constraint
    # orthogonal group: A^T A = I
    ce = pygransoStruct()
    c1_vec = (A.T @ A
              - torch.eye(hidden_size)
              .to(device=device, dtype=double_precision)
              ).reshape(1, -1)
    ce.c1 = torch.linalg.vector_norm(c1_vec, 2)  # l2 folding: one constraint instead of hidden_size**2
    # ce.c2 = torch.det(A) - 1  # uncomment to also enforce det(A) = 1 (special orthogonal group)
    # ce = None
    return [f, ci, ce]

comb_fn = lambda model: user_fn(model, inputs, labels)
User Options
Specify user-defined options for PyGRANSO.
[4]:
opts = pygransoStruct()
opts.torch_device = device
nvar = getNvarTorch(model.parameters())
opts.x0 = torch.nn.utils.parameters_to_vector(model.parameters()).detach().reshape(nvar,1)
opts.opt_tol = 1e-3
opts.viol_eq_tol = 1e-4
# opts.maxit = 150
# opts.fvalquit = 1e-6
opts.print_level = 1
opts.print_frequency = 50
# opts.print_ascii = True
# opts.limited_mem_size = 100
opts.double_precision = True
opts.mu0 = 1
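PyGRANSO's problem specifications in the main algorithm output report 2110 variables. That count is the total number of entries in the model's parameters, which are flattened into the single column vector `opts.x0`. The sketch rebuilds the same layers with plain `torch.nn` (no PyGRANSO needed) and checks the count. It also shows that the 30 x 30 matrix at index 1 is the one the objective code constrains.

```python
import torch
import torch.nn as nn

input_size, hidden_size, num_classes = 28, 30, 10

# Same layers as the RNN class above: one nn.RNN layer plus a linear classifier
rnn = nn.RNN(input_size, hidden_size, num_layers=1, batch_first=True)
fc = nn.Linear(hidden_size, num_classes)
params = list(rnn.parameters()) + list(fc.parameters())

for p in params:
    print(tuple(p.shape), p.numel())
# Prints, in order (names added here for reference):
#   (30, 28) 840    weight_ih_l0
#   (30, 30) 900    weight_hh_l0  <- the matrix constrained to be orthogonal
#   (30,) 30        bias_ih_l0
#   (30,) 30        bias_hh_l0
#   (10, 30) 300    fc.weight
#   (10,) 10        fc.bias

x0 = torch.nn.utils.parameters_to_vector(params).detach().reshape(-1, 1)
print(x0.shape)   # torch.Size([2110, 1])
```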
Initial Test
Check the initial accuracy of the RNN model on the training batch.
[5]:
logits = model(inputs)
_, predicted = torch.max(logits.data, 1)
correct = (predicted == labels).sum().item()
print("Initial acc = {:.2f}%".format((100 * correct/len(inputs))))
Initial acc = 10.00%
Main Algorithm
[6]:
start = time.time()
soln = pygranso(var_spec= model, combined_fn = comb_fn, user_opts = opts)
end = time.time()
print("Total Wall Time: {}s".format(end - start))
╔═════ QP SOLVER NOTICE ════════════════════════════════════════════════════════════════════════╗
║ PyGRANSO requires a quadratic program (QP) solver that has a quadprog-compatible interface, ║
║ the default is osqp. Users may provide their own wrapper for the QP solver. ║
║ To disable this notice, set opts.quadprog_info_msg = False ║
╚═══════════════════════════════════════════════════════════════════════════════════════════════╝
═════════════════════════════════════════════════════════════════════════════════════════════════════════════════╗
PyGRANSO: A PyTorch-enabled port of GRANSO with auto-differentiation ║
Version 1.2.0 ║
Licensed under the AGPLv3, Copyright (C) 2021-2022 Tim Mitchell and Buyun Liang ║
═════════════════════════════════════════════════════════════════════════════════════════════════════════════════╣
Problem specifications: ║
# of variables : 2110 ║
# of inequality constraints : 0 ║
# of equality constraints : 1 ║
═════╦═══════════════════════════╦════════════════╦═════════════════╦═══════════════════════╦════════════════════╣
║ <--- Penalty Function --> ║ ║ Total Violation ║ <--- Line Search ---> ║ <- Stationarity -> ║
Iter ║ Mu │ Value ║ Objective ║ Ineq │ Eq ║ SD │ Evals │ t ║ Grads │ Value ║
═════╬═══════════════════════════╬════════════════╬═════════════════╬═══════════════════════╬════════════════════╣
0 ║ 1.000000 │ 6.34479003561 ║ 2.28955713337 ║ - │ 4.055233 ║ - │ 1 │ 0.000000 ║ 1 │ 0.727032 ║
50 ║ 1.000000 │ 0.44342852064 ║ 0.26057998535 ║ - │ 0.182849 ║ S │ 2 │ 0.500000 ║ 1 │ 1.012217 ║
100 ║ 1.000000 │ 0.03502773133 ║ 0.01024156848 ║ - │ 0.024786 ║ S │ 3 │ 0.250000 ║ 1 │ 0.072379 ║
150 ║ 1.000000 │ 0.00964913672 ║ 0.00333188716 ║ - │ 0.006317 ║ S │ 4 │ 0.125000 ║ 1 │ 0.460388 ║
200 ║ 1.000000 │ 0.00402016055 ║ 0.00218694435 ║ - │ 0.001833 ║ S │ 3 │ 0.250000 ║ 1 │ 0.185655 ║
250 ║ 1.000000 │ 0.00259692636 ║ 0.00156567317 ║ - │ 0.001031 ║ S │ 3 │ 0.250000 ║ 1 │ 0.278221 ║
300 ║ 1.000000 │ 0.00177902934 ║ 0.00124617387 ║ - │ 5.33e-04 ║ S │ 6 │ 0.031250 ║ 1 │ 0.013047 ║
350 ║ 1.000000 │ 0.00142805944 ║ 0.00111832264 ║ - │ 3.10e-04 ║ S │ 10 │ 0.001953 ║ 1 │ 0.377399 ║
400 ║ 1.000000 │ 0.00125541329 ║ 0.00104496642 ║ - │ 2.10e-04 ║ S │ 7 │ 0.015625 ║ 1 │ 0.151672 ║
450 ║ 1.000000 │ 0.00111109651 ║ 9.8199316e-04 ║ - │ 1.29e-04 ║ S │ 11 │ 9.77e-04 ║ 1 │ 0.239003 ║
500 ║ 1.000000 │ 0.00104336348 ║ 9.4179093e-04 ║ - │ 1.02e-04 ║ S │ 13 │ 2.44e-04 ║ 2 │ 0.012828 ║
550 ║ 1.000000 │ 9.4210490e-04 ║ 8.5316606e-04 ║ - │ 8.89e-05 ║ S │ 8 │ 0.007812 ║ 1 │ 0.003028 ║
600 ║ 1.000000 │ 8.9698591e-04 ║ 7.9341462e-04 ║ - │ 1.04e-04 ║ S │ 6 │ 0.031250 ║ 1 │ 0.228958 ║
650 ║ 1.000000 │ 8.3518341e-04 ║ 7.5023136e-04 ║ - │ 8.50e-05 ║ S │ 8 │ 0.007812 ║ 1 │ 0.249420 ║
700 ║ 1.000000 │ 7.9166675e-04 ║ 7.3029674e-04 ║ - │ 6.14e-05 ║ S │ 7 │ 0.015625 ║ 1 │ 0.156740 ║
750 ║ 1.000000 │ 7.5802236e-04 ║ 7.1762811e-04 ║ - │ 4.04e-05 ║ S │ 8 │ 0.007812 ║ 2 │ 0.003298 ║
800 ║ 1.000000 │ 7.4858692e-04 ║ 7.0752411e-04 ║ - │ 4.11e-05 ║ S │ 15 │ 6.10e-05 ║ 2 │ 0.142087 ║
850 ║ 1.000000 │ 7.3196260e-04 ║ 6.9034953e-04 ║ - │ 4.16e-05 ║ S │ 6 │ 0.031250 ║ 1 │ 0.671820 ║
900 ║ 1.000000 │ 7.1470716e-04 ║ 6.8116904e-04 ║ - │ 3.35e-05 ║ S │ 5 │ 0.062500 ║ 1 │ 0.045029 ║
950 ║ 1.000000 │ 6.4488437e-04 ║ 6.0576357e-04 ║ - │ 3.91e-05 ║ S │ 4 │ 0.125000 ║ 1 │ 0.036849 ║
═════╬═══════════════════════════╬════════════════╬═════════════════╬═══════════════════════╬════════════════════╣
║ <--- Penalty Function --> ║ ║ Total Violation ║ <--- Line Search ---> ║ <- Stationarity -> ║
Iter ║ Mu │ Value ║ Objective ║ Ineq │ Eq ║ SD │ Evals │ t ║ Grads │ Value ║
═════╬═══════════════════════════╬════════════════╬═════════════════╬═══════════════════════╬════════════════════╣
1000 ║ 1.000000 │ 4.8682921e-04 ║ 4.5088613e-04 ║ - │ 3.59e-05 ║ S │ 4 │ 0.125000 ║ 1 │ 0.036078 ║
═════╩═══════════════════════════╩════════════════╩═════════════════╩═══════════════════════╩════════════════════╣
Optimization results: ║
F = final iterate, B = Best (to tolerance), MF = Most Feasible ║
═════╦═══════════════════════════╦════════════════╦═════════════════╦═══════════════════════╦════════════════════╣
F ║ │ ║ 4.5088613e-04 ║ - │ 3.59e-05 ║ │ │ ║ │ ║
B ║ │ ║ 4.4406815e-04 ║ - │ 8.24e-05 ║ │ │ ║ │ ║
MF ║ │ ║ 6.7149885e-04 ║ - │ 2.58e-05 ║ │ │ ║ │ ║
═════╩═══════════════════════════╩════════════════╩═════════════════╩═══════════════════════╩════════════════════╣
Iterations: 1000 ║
Function evaluations: 6636 ║
PyGRANSO termination code: 4 --- max iterations reached. ║
═════════════════════════════════════════════════════════════════════════════════════════════════════════════════╝
Total Wall Time: 56.05571532249451s
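Reading the log: in every row checked below, the "Penalty Function Value" column equals Mu times the objective plus the equality violation. This is consistent with an exact penalty function of the form \(\mu f(x) + v(x)\), where \(v\) is the total constraint violation. The sketch checks this against three rows copied from the log. Each comparison allows for the rounding of the printed violation. At iteration 1000, the implied violation also matches the final feasibility printed in the next cell.

```python
# Rows copied from the PyGRANSO log above:
# (iter, mu, penalty value, objective f, printed eq. violation,
#  half a unit in the last printed digit of that violation)
rows = [
    (0,    1.0, 6.34479003561, 2.28955713337, 4.055233, 5e-7),
    (50,   1.0, 0.44342852064, 0.26057998535, 0.182849, 5e-7),
    (1000, 1.0, 4.8682921e-04, 4.5088613e-04, 3.59e-05, 5e-8),
]

for it, mu, penalty, f, viol, half_ulp in rows:
    # Violation implied by assuming penalty = mu * f + violation
    implied_viol = penalty - mu * f
    # Agrees with the printed violation up to its display rounding
    ok = abs(implied_viol - viol) <= half_ulp
    print(f"iter {it}: implied violation {implied_viol:.6e}, printed {viol}, consistent={ok}")
```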
Train Accuracy
[7]:
torch.nn.utils.vector_to_parameters(soln.final.x, model.parameters())
logits = model(inputs)
_, predicted = torch.max(logits.data, 1)
correct = (predicted == labels).sum().item()
print("Final acc = {:.2f}%".format((100 * correct/len(inputs))))
print("final feasibility = {}".format(soln.final.tve))
Final acc = 100.00%
final feasibility = 3.59430769855918e-05
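The accuracy above is measured on the same single batch of 100 training images that defines the objective, so it shows how well that batch was fitted rather than how well the model generalizes. The reported feasibility is the folded residual. To inspect the learned recurrent matrix itself, you can slice it out of the flat solution vector using the parameter layout: 840 entries of W_ih first, then 900 entries of W_hh. The sketch demonstrates the slicing on a freshly initialized model of the same shape so that it runs standalone. In the notebook you would pass `soln.final.x` instead. The helper `extract_w_hh` is hypothetical.

```python
import torch
import torch.nn as nn

input_size, hidden_size, num_classes = 28, 30, 10

def extract_w_hh(x_flat: torch.Tensor, input_size: int, hidden_size: int) -> torch.Tensor:
    """Hypothetical helper: slice W_hh out of the flattened parameter vector.

    Assumes the parameter order of the RNN class in this example:
    weight_ih_l0 (hidden x input), then weight_hh_l0 (hidden x hidden), ...
    """
    start = hidden_size * input_size
    stop = start + hidden_size * hidden_size
    return x_flat.reshape(-1)[start:stop].reshape(hidden_size, hidden_size)

# Stand-in for soln.final.x: flatten a freshly initialized model of the same shape
torch.manual_seed(0)
rnn = nn.RNN(input_size, hidden_size, num_layers=1, batch_first=True).double()
fc = nn.Linear(hidden_size, num_classes).double()
params = list(rnn.parameters()) + list(fc.parameters())
x_flat = torch.nn.utils.parameters_to_vector(params).detach().reshape(-1, 1)

W_hh = extract_w_hh(x_flat, input_size, hidden_size)
print(torch.equal(W_hh, rnn.weight_hh_l0.detach()))   # True: slicing matches the layer

# Frobenius-norm orthogonality residual, the same quantity as ce.c1 in user_fn
I = torch.eye(hidden_size, dtype=W_hh.dtype)
print(torch.linalg.vector_norm((W_hh.T @ W_hh - I).reshape(-1), 2))
```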