Orthogonal RNN¶
Train an orthogonal RNN for MNIST classification, based on this paper.
NOTE: this example is still under development.
Problem Description¶
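For each element in the input sequence, each layer computes the following function:
\[h_t = \tanh(W_{ih} x_t + b_{ih} + W_{hh} h_{t-1} + b_{hh})\]
where \(h_{t}\) is the hidden state at time \(t\), \(x_t\) is the input at time \(t\), and \(h_{t-1}\) is the hidden state of the previous layer at time \(t-1\) or the initial hidden state at time \(0\).
For each layer, we have the orthogonal constraint on the hidden-to-hidden weight matrix:
\[W_{hh}^\top W_{hh} = I\]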
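As a small, dependency-free illustration of what the orthogonal constraint measures, the sketch below computes the largest entry of \(A^\top A - I\) for a square matrix \(A\). It uses plain Python lists rather than tensors so it can be run in isolation; the helper names `transpose`, `matmul` and `orth_residual` are made up for this sketch and are not part of PyGRANSO.

```python
import math

def transpose(m):
    """Transpose a matrix stored as a list of rows."""
    return [list(row) for row in zip(*m)]

def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    bt = transpose(b)
    return [[sum(x * y for x, y in zip(row, col)) for col in bt] for row in a]

def orth_residual(a):
    """Return max |(A^T A - I)_ij|; this is 0 exactly when A is orthogonal."""
    n = len(a)
    ata = matmul(transpose(a), a)
    return max(abs(ata[i][j] - (1.0 if i == j else 0.0))
               for i in range(n) for j in range(n))

theta = 0.3
rotation = [[math.cos(theta), -math.sin(theta)],
            [math.sin(theta),  math.cos(theta)]]
scaled = [[2.0 * v for v in row] for row in rotation]

rot_res = orth_residual(rotation)    # zero up to rounding error
scaled_res = orth_residual(scaled)   # (2R)^T (2R) - I = 3I, so about 3
print(rot_res, scaled_res)
```

A rotation matrix satisfies the constraint, while scaling it by 2 violates it on the diagonal.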
Modules Importing¶
Import all necessary modules and add the PyGRANSO src folder to the system path.
[1]:
import time
import torch
import sys
## Adding PyGRANSO directories. Should be modified by user
sys.path.append('/home/buyun/Documents/GitHub/PyGRANSO')
from pygranso.pygranso import pygranso
from pygranso.pygransoStruct import pygransoStruct
from pygranso.private.getNvar import getNvarTorch
import torch.nn as nn
from torchvision import datasets
from torchvision.transforms import ToTensor
from pygranso.private.getObjGrad import getObjGradDL
Data Initialization¶
Specify torch device, neural network architecture, and generate data.
NOTE: please specify the path for downloading the data.
Use a GPU for this problem. If no CUDA device is available, please set device = torch.device('cpu')
[2]:
device = torch.device('cuda')
sequence_length = 28
input_size = 28
hidden_size = 30
num_layers = 1
num_classes = 10
batch_size = 100
double_precision = torch.double
class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(RNN, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        x = torch.reshape(x,(batch_size,sequence_length,input_size))
        # Set initial hidden state (a plain RNN has no cell state)
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device=device, dtype=double_precision)
        out, hidden = self.rnn(x, h0) # out: tensor of shape (batch_size, seq_length, hidden_size)
        # Keep only the last time step and pass it to the fully connected layer
        out = self.fc(out[:, -1, :])
        return out
torch.manual_seed(0)
model = RNN(input_size, hidden_size, num_layers, num_classes).to(device=device, dtype=double_precision)
model.train()
train_data = datasets.MNIST(
    root = '/home/buyun/Documents/GitHub/PyGRANSO/examples/data/mnist',
    train = True,
    transform = ToTensor(),
    download = True,
)
loaders = {
    'train' : torch.utils.data.DataLoader(train_data,
                                          batch_size=100,
                                          shuffle=True,
                                          num_workers=1),
}
inputs, labels = next(iter(loaders['train']))
inputs, labels = inputs.reshape(-1, sequence_length, input_size).to(device=device, dtype=double_precision), labels.to(device=device)
Function Set-Up¶
Encode the optimization variables, and objective and constraint functions.
Note: please strictly follow the format of comb_fn, which will be used in the PyGRANSO main algorithm.
[3]:
def user_fn(model,inputs,labels):
    # objective function
    logits = model(inputs)
    criterion = nn.CrossEntropyLoss()
    f = criterion(logits, labels)
    # hidden-to-hidden weight matrix (second entry of model.parameters())
    A = list(model.parameters())[1]
    # inequality constraint
    ci = None
    # equality constraint: orthogonality of A
    ce = pygransoStruct()
    ce.c1 = A.T @ A - torch.eye(hidden_size).to(device=device, dtype=double_precision)
    # uncomment to restrict A to the special orthogonal group
    # ce.c2 = torch.det(A) - 1
    # ce = None
    return [f,ci,ce]
comb_fn = lambda model : user_fn(model,inputs,labels)
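One detail of the lambda above is worth knowing: Python resolves the name user_fn when comb_fn is called, not when the lambda is created. If user_fn is redefined later in the same session, the existing comb_fn calls the new definition. A minimal stand-alone sketch of this behavior follows; the string return values are placeholders for illustration only, not real PyGRANSO outputs.

```python
def user_fn(model, inputs, labels):
    # Placeholder for a function returning [f, ci, ce]
    return ["first definition", None, None]

# Same shape as the comb_fn in this example: data is bound by the closure,
# and only the model is left as an argument.
comb_fn = lambda model: user_fn(model, None, None)
before = comb_fn(None)[0]

def user_fn(model, inputs, labels):
    # Placeholder for a redefined function, e.g. one that also returns gradients
    return ["second definition", None, None, None, None, None]

# comb_fn was not rebuilt, yet it now calls the redefined user_fn
after = comb_fn(None)[0]
print(before, "->", after)
```

If this late binding is not wanted, rebuilding comb_fn after each redefinition of user_fn makes the dependency explicit.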
User Options¶
Specify user-defined options for PyGRANSO
[4]:
opts = pygransoStruct()
opts.torch_device = device
nvar = getNvarTorch(model.parameters())
opts.x0 = torch.nn.utils.parameters_to_vector(model.parameters()).detach().reshape(nvar,1)
opts.opt_tol = 3e-4
opts.viol_eq_tol = 3e-4
opts.maxit = 150
# opts.fvalquit = 1e-6
opts.print_level = 1
opts.print_frequency = 10
# opts.print_ascii = True
# opts.limited_mem_size = 100
opts.double_precision = True
opts.mu0 = 200
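The number of optimization variables can be checked by hand. The sketch below assumes the usual parameter layout of a one-layer nn.RNN followed by nn.Linear (weight_ih, weight_hh, bias_ih, bias_hh, then the linear weight and bias) and reproduces the 2110 variables and 900 equality constraints reported in the solver log of this example. The dictionary of offsets is an illustration, not something PyGRANSO exposes.

```python
from math import prod

input_size, hidden_size, num_classes = 28, 30, 10

# Parameter shapes in the order model.parameters() is assumed to yield them
shapes = [
    ("rnn.weight_ih_l0", (hidden_size, input_size)),
    ("rnn.weight_hh_l0", (hidden_size, hidden_size)),
    ("rnn.bias_ih_l0", (hidden_size,)),
    ("rnn.bias_hh_l0", (hidden_size,)),
    ("fc.weight", (num_classes, hidden_size)),
    ("fc.bias", (num_classes,)),
]

# Start and end index of each parameter in the flattened variable vector
offsets = {}
start = 0
for name, shape in shapes:
    offsets[name] = (start, start + prod(shape))
    start += prod(shape)

nvar = start                       # total number of optimization variables
n_eq = hidden_size * hidden_size   # one equality per entry of A^T A - I

print(nvar, n_eq, offsets["rnn.weight_hh_l0"])
```

The hidden-to-hidden matrix starts at index input_size*hidden_size = 840 of the flattened vector.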
Initial Test¶
Check initial accuracy of the RNN model
[5]:
logits = model(inputs)
_, predicted = torch.max(logits.data, 1)
correct = (predicted == labels).sum().item()
print("Initial acc = {:.2f}%".format((100 * correct/len(inputs))))
Initial acc = 10.00%
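The accuracy computation above takes the index of the largest logit in each row as the predicted class and counts matches with the labels. A plain-Python sketch of the same logic on made-up numbers (the function name `accuracy_percent` and the toy data are hypothetical):

```python
def accuracy_percent(logits, labels):
    """Percentage of rows whose largest logit sits at the labelled class."""
    predicted = [max(range(len(row)), key=row.__getitem__) for row in logits]
    correct = sum(int(p == y) for p, y in zip(predicted, labels))
    return 100.0 * correct / len(labels)

# Four toy samples with three classes each
logits = [
    [0.1, 2.0, -1.0],   # predicted class 1
    [1.5, 0.2, 0.3],    # predicted class 0
    [0.0, 0.1, 0.9],    # predicted class 2
    [0.4, 0.3, 0.2],    # predicted class 0
]
labels = [1, 0, 1, 2]

acc = accuracy_percent(logits, labels)
print("acc = {:.2f}%".format(acc))
```

An initial accuracy of 10% on a ten-class problem is what random guessing would be expected to give on balanced classes.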
Main Algorithm¶
[6]:
start = time.time()
soln = pygranso(var_spec= model, combined_fn = comb_fn, user_opts = opts)
end = time.time()
print("Total Wall Time: {}s".format(end - start))
╔═════ QP SOLVER NOTICE ════════════════════════════════════════════════════════════════════════╗
║ PyGRANSO requires a quadratic program (QP) solver that has a quadprog-compatible interface, ║
║ the default is osqp. Users may provide their own wrapper for the QP solver. ║
║ To disable this notice, set opts.quadprog_info_msg = False ║
╚═══════════════════════════════════════════════════════════════════════════════════════════════╝
═════════════════════════════════════════════════════════════════════════════════════════════════════════════════╗
PyGRANSO: A PyTorch-enabled port of GRANSO with auto-differentiation ║
Version 1.2.0 ║
Licensed under the AGPLv3, Copyright (C) 2021-2022 Tim Mitchell and Buyun Liang ║
═════════════════════════════════════════════════════════════════════════════════════════════════════════════════╣
Problem specifications: ║
# of variables : 2110 ║
# of inequality constraints : 0 ║
# of equality constraints : 900 ║
═════╦═══════════════════════════╦════════════════╦═════════════════╦═══════════════════════╦════════════════════╣
║ <--- Penalty Function --> ║ ║ Total Violation ║ <--- Line Search ---> ║ <- Stationarity -> ║
Iter ║ Mu │ Value ║ Objective ║ Ineq │ Eq ║ SD │ Evals │ t ║ Grads │ Value ║
═════╬═══════════════════════════╬════════════════╬═════════════════╬═══════════════════════╬════════════════════╣
0 ║ 200.0000 │ 521.858725680 ║ 2.28955713337 ║ - │ 0.746091 ║ - │ 1 │ 0.000000 ║ 1 │ 22.68972 ║
10 ║ 62.76212 │ 128.098797617 ║ 1.73642166416 ║ - │ 0.957309 ║ S │ 1 │ 1.000000 ║ 1 │ 1.225705 ║
20 ║ 41.17823 │ 50.8284833052 ║ 1.06189375932 ║ - │ 1.568086 ║ S │ 1 │ 1.000000 ║ 1 │ 0.864339 ║
30 ║ 27.01703 │ 20.8374358771 ║ 0.65467158735 ║ - │ 1.286537 ║ S │ 1 │ 1.000000 ║ 1 │ 0.340082 ║
40 ║ 21.88380 │ 10.3384384935 ║ 0.38950779214 ║ - │ 0.627510 ║ S │ 2 │ 0.500000 ║ 1 │ 1.333055 ║
50 ║ 15.95329 │ 4.70122280552 ║ 0.25233410359 ║ - │ 0.241137 ║ S │ 1 │ 1.000000 ║ 1 │ 0.715793 ║
60 ║ 3.649601 │ 0.73125619485 ║ 0.16914007692 ║ - │ 0.001355 ║ S │ 1 │ 1.000000 ║ 1 │ 0.310019 ║
70 ║ 2.394503 │ 0.32977304232 ║ 0.12641808772 ║ - │ 1.86e-04 ║ S │ 1 │ 1.000000 ║ 1 │ 0.267671 ║
80 ║ 0.751420 │ 0.08588708751 ║ 0.10989096115 ║ - │ 5.43e-05 ║ S │ 1 │ 1.000000 ║ 1 │ 0.026373 ║
90 ║ 0.028668 │ 0.00476881762 ║ 0.10632401897 ║ - │ 5.62e-05 ║ S │ 1 │ 1.000000 ║ 1 │ 0.001171 ║
═════╩═══════════════════════════╩════════════════╩═════════════════╩═══════════════════════╩════════════════════╣
Optimization results: ║
F = final iterate, B = Best (to tolerance), MF = Most Feasible ║
═════╦═══════════════════════════╦════════════════╦═════════════════╦═══════════════════════╦════════════════════╣
F ║ │ ║ 0.10599965581 ║ - │ 4.89e-05 ║ │ │ ║ │ ║
B ║ │ ║ 0.10528524892 ║ - │ 1.41e-04 ║ │ │ ║ │ ║
MF ║ │ ║ 0.10851492951 ║ - │ 1.34e-05 ║ │ │ ║ │ ║
═════╩═══════════════════════════╩════════════════╩═════════════════╩═══════════════════════╩════════════════════╣
Iterations: 93 ║
Function evaluations: 136 ║
PyGRANSO termination code: 0 --- converged to stationarity and feasibility tolerances. ║
═════════════════════════════════════════════════════════════════════════════════════════════════════════════════╝
Total Wall Time: 160.9356348514557s
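The penalty-function columns of the log can be related to the objective with a little arithmetic. The sketch below assumes, following the GRANSO penalty formulation, that the penalty value equals Mu times the objective plus the sum of absolute constraint violations, and that the Eq column shows the largest single violation. Neither assumption is stated in this example, so treat the result as a consistency check rather than a specification.

```python
n_eq = 900  # number of equality constraints reported in the log

# iteration: (Mu, penalty value, objective, displayed Eq violation), copied from the log
rows = {
    0: (200.0000, 521.858725680, 2.28955713337, 0.746091),
    90: (0.028668, 0.00476881762, 0.10632401897, 5.62e-05),
}

# Violation sum implied by: penalty = Mu * objective + violation sum
implied = {it: value - mu * obj for it, (mu, value, obj, _) in rows.items()}

for it, (mu, value, obj, shown) in rows.items():
    # A sum of n_eq absolute values lies between the largest one and n_eq times it
    consistent = shown <= implied[it] <= n_eq * shown
    print(it, round(implied[it], 6), consistent)
```

Under these assumptions the implied violation sum drops from about 63.9 at iteration 0 to about 0.0017 at iteration 90, and both values lie within the bounds given by the displayed Eq column.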
Train Accuracy¶
[7]:
torch.nn.utils.vector_to_parameters(soln.final.x, model.parameters())
logits = model(inputs)
_, predicted = torch.max(logits.data, 1)
correct = (predicted == labels).sum().item()
print("Final acc = {:.2f}%".format((100 * correct/len(inputs))))
Final acc = 99.00%
Partial Auto-Differentiation¶
(Optional) The following example shows how to use the partial auto-differentiation feature (with user-provided gradients) to accelerate PyGRANSO.
[8]:
# partial AD
def user_fn(model,inputs,labels):
    # objective function
    logits = model(inputs)
    criterion = nn.CrossEntropyLoss()
    f = criterion(logits, labels)
    # hidden-to-hidden weight matrix (second entry of model.parameters())
    A = list(model.parameters())[1]
    # get f_grad by AD
    n = getNvarTorch(model.parameters())
    f_grad = getObjGradDL(nvar=n,model=model,f=f, torch_device=device, double_precision=True)
    f = f.detach().item()
    # inequality constraint
    ci = None
    ci_grad = None
    # equality constraint: orthogonality of A, flattened to a column vector
    ce = A.T @ A - torch.eye(hidden_size).to(device=device, dtype=double_precision)
    ce = ce.detach()
    nconstr = hidden_size*hidden_size
    ce = torch.reshape(ce,(nconstr,1))
    # user-provided constraint gradient of shape (n, nconstr); only the rows belonging to A are nonzero
    ce_grad = torch.zeros((n,nconstr)).to(device=device, dtype=double_precision)
    M = torch.zeros((nconstr,nconstr)).to(device=device, dtype=double_precision)
    for i in range(hidden_size):
        for j in range(hidden_size):
            # derivative of A.T @ A with respect to A[i,j]
            J_ij = torch.zeros((hidden_size,hidden_size)).to(device=device, dtype=double_precision)
            J_ij[i,j] = 1
            tmp = A.T@J_ij + J_ij.T@A
            M[hidden_size*i+j,:] = tmp.reshape((1,hidden_size*hidden_size))
    # A starts at index input_size*hidden_size of the flattened parameter vector
    ce_grad[input_size*hidden_size:input_size*hidden_size+hidden_size*hidden_size,:] = M
    ce_grad = ce_grad.detach()
    return [f,f_grad,ci,ci_grad,ce,ce_grad]
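The loop above fills each row of M with the derivative of A.T @ A with respect to one entry A[i,j], namely A.T @ J_ij + J_ij.T @ A. A hand-written gradient like this is easy to get subtly wrong, so it can help to compare it against finite differences on a small matrix. The following is a plain-Python sketch of such a check; all helper names and the 3x3 test matrix are hypothetical, and it does not call PyGRANSO or PyTorch.

```python
def transpose(m):
    return [list(row) for row in zip(*m)]

def matmul(a, b):
    bt = transpose(b)
    return [[sum(x * y for x, y in zip(row, col)) for col in bt] for row in a]

def constraint(a):
    """Entries of A^T A - I, flattened row by row."""
    n = len(a)
    ata = matmul(transpose(a), a)
    return [ata[k][l] - (1.0 if k == l else 0.0)
            for k in range(n) for l in range(n)]

def analytic_row(a, i, j):
    """Derivative of every constraint entry w.r.t. A[i][j]: A^T J_ij + J_ij^T A."""
    n = len(a)
    J = [[1.0 if (r == i and c == j) else 0.0 for c in range(n)] for r in range(n)]
    left = matmul(transpose(a), J)
    right = matmul(transpose(J), a)
    return [left[k][l] + right[k][l] for k in range(n) for l in range(n)]

def finite_diff_row(a, i, j, eps=1e-6):
    """Central-difference estimate of the same derivative."""
    plus = [row[:] for row in a]
    minus = [row[:] for row in a]
    plus[i][j] += eps
    minus[i][j] -= eps
    return [(p - m) / (2 * eps)
            for p, m in zip(constraint(plus), constraint(minus))]

A = [[0.5, -1.0, 2.0],
     [1.5, 0.25, -0.75],
     [-0.3, 0.8, 1.1]]

# Largest disagreement between the analytic and numerical derivatives
max_err = max(abs(x - y)
              for i in range(3) for j in range(3)
              for x, y in zip(analytic_row(A, i, j), finite_diff_row(A, i, j)))
print(max_err)
```

Because the constraint is quadratic in A, the central difference is exact up to rounding, so any sizeable disagreement would point to an indexing or transposition mistake.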
Important: set opts.globalAD = False to disable global auto-differentiation
[9]:
opts.globalAD = False # disable global auto-differentiation
# reinitialize model
torch.manual_seed(0)
model = RNN(input_size, hidden_size, num_layers, num_classes).to(device=device, dtype=double_precision)
model.train()
# initial acc
logits = model(inputs)
_, predicted = torch.max(logits.data, 1)
correct = (predicted == labels).sum().item()
print("Initial acc = {:.2f}%".format((100 * correct/len(inputs))))
# Main Algorithm
start = time.time()
soln = pygranso(var_spec= model, combined_fn = comb_fn, user_opts = opts)
end = time.time()
print("Total Wall Time: {}s".format(end - start))
# Train acc
torch.nn.utils.vector_to_parameters(soln.final.x, model.parameters())
logits = model(inputs)
_, predicted = torch.max(logits.data, 1)
correct = (predicted == labels).sum().item()
print("Final acc = {:.2f}%".format((100 * correct/len(inputs))))
Initial acc = 10.00%
╔═════ QP SOLVER NOTICE ════════════════════════════════════════════════════════════════════════╗
║ PyGRANSO requires a quadratic program (QP) solver that has a quadprog-compatible interface, ║
║ the default is osqp. Users may provide their own wrapper for the QP solver. ║
║ To disable this notice, set opts.quadprog_info_msg = False ║
╚═══════════════════════════════════════════════════════════════════════════════════════════════╝
═════════════════════════════════════════════════════════════════════════════════════════════════════════════════╗
PyGRANSO: A PyTorch-enabled port of GRANSO with auto-differentiation ║
Version 1.2.0 ║
Licensed under the AGPLv3, Copyright (C) 2021-2022 Tim Mitchell and Buyun Liang ║
═════════════════════════════════════════════════════════════════════════════════════════════════════════════════╣
Problem specifications: ║
# of variables : 2110 ║
# of inequality constraints : 0 ║
# of equality constraints : 900 ║
═════╦═══════════════════════════╦════════════════╦═════════════════╦═══════════════════════╦════════════════════╣
║ <--- Penalty Function --> ║ ║ Total Violation ║ <--- Line Search ---> ║ <- Stationarity -> ║
Iter ║ Mu │ Value ║ Objective ║ Ineq │ Eq ║ SD │ Evals │ t ║ Grads │ Value ║
═════╬═══════════════════════════╬════════════════╬═════════════════╬═══════════════════════╬════════════════════╣
0 ║ 200.0000 │ 521.858725680 ║ 2.28955713337 ║ - │ 0.746091 ║ - │ 1 │ 0.000000 ║ 1 │ 22.68972 ║
10 ║ 62.76212 │ 128.098797617 ║ 1.73642166416 ║ - │ 0.957309 ║ S │ 1 │ 1.000000 ║ 1 │ 1.225705 ║
20 ║ 41.17823 │ 50.8284833052 ║ 1.06189375932 ║ - │ 1.568086 ║ S │ 1 │ 1.000000 ║ 1 │ 0.864339 ║
30 ║ 27.01703 │ 20.8374358771 ║ 0.65467158735 ║ - │ 1.286537 ║ S │ 1 │ 1.000000 ║ 1 │ 0.340082 ║
40 ║ 21.88380 │ 10.3384384935 ║ 0.38950779214 ║ - │ 0.627510 ║ S │ 2 │ 0.500000 ║ 1 │ 1.333055 ║
50 ║ 15.95329 │ 4.70122280552 ║ 0.25233410359 ║ - │ 0.241137 ║ S │ 1 │ 1.000000 ║ 1 │ 0.715793 ║
60 ║ 3.649601 │ 0.73125619485 ║ 0.16914007692 ║ - │ 0.001355 ║ S │ 1 │ 1.000000 ║ 1 │ 0.310019 ║
70 ║ 2.394503 │ 0.32977304232 ║ 0.12641808772 ║ - │ 1.86e-04 ║ S │ 1 │ 1.000000 ║ 1 │ 0.267671 ║
80 ║ 0.751420 │ 0.08588708751 ║ 0.10989096115 ║ - │ 5.43e-05 ║ S │ 1 │ 1.000000 ║ 1 │ 0.026373 ║
90 ║ 0.028668 │ 0.00476881762 ║ 0.10632401897 ║ - │ 5.62e-05 ║ S │ 1 │ 1.000000 ║ 1 │ 0.001171 ║
═════╩═══════════════════════════╩════════════════╩═════════════════╩═══════════════════════╩════════════════════╣
Optimization results: ║
F = final iterate, B = Best (to tolerance), MF = Most Feasible ║
═════╦═══════════════════════════╦════════════════╦═════════════════╦═══════════════════════╦════════════════════╣
F ║ │ ║ 0.10599965581 ║ - │ 4.89e-05 ║ │ │ ║ │ ║
B ║ │ ║ 0.10528524892 ║ - │ 1.41e-04 ║ │ │ ║ │ ║
MF ║ │ ║ 0.10851492951 ║ - │ 1.34e-05 ║ │ │ ║ │ ║
═════╩═══════════════════════════╩════════════════╩═════════════════╩═══════════════════════╩════════════════════╣
Iterations: 93 ║
Function evaluations: 136 ║
PyGRANSO termination code: 0 --- converged to stationarity and feasibility tolerances. ║
═════════════════════════════════════════════════════════════════════════════════════════════════════════════════╝
Total Wall Time: 94.66318130493164s
Final acc = 99.00%