Orthogonal RNN

Train an orthogonal RNN for MNIST classification, based on this paper.

NOTE: this example is still under development.

Problem Description

For each element in the input sequence, each layer computes the following function:

\[h_t=\tanh(W_{ih}x_t+b_{ih}+W_{hh}h_{t-1}+b_{hh})\]

where \(h_{t}\) is the hidden state at time \(t\), \(x_t\) is the input at time \(t\), and \(h_{t-1}\) is the hidden state of the previous layer at time \(t-1\) or the initial hidden state at time \(0\).

For each layer, we have the orthogonal constraint:

\[W_{hh}^T W_{hh} = I\]
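
Training the network thus amounts to the constrained optimization problem

\[\min_{\theta}\ \mathcal{L}_{\mathrm{CE}}(\theta)\quad\text{subject to}\quad W_{hh}^T W_{hh}=I,\]

where \(\theta\) collects all network weights and biases and \(\mathcal{L}_{\mathrm{CE}}\) is the cross-entropy loss on the MNIST batch; both are defined in the code below.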

Module Imports

Import all necessary modules and add the PyGRANSO source folder to the system path.

[1]:
import time
import torch
import sys
# Add the PyGRANSO directory to the system path; modify this path for your setup
sys.path.append('/home/buyun/Documents/GitHub/PyGRANSO')
from pygranso.pygranso import pygranso
from pygranso.pygransoStruct import pygransoStruct
from pygranso.private.getNvar import getNvarTorch
import torch.nn as nn
from torchvision import datasets
from torchvision.transforms import ToTensor
from pygranso.private.getObjGrad import getObjGradDL

Data Initialization

Specify torch device, neural network architecture, and generate data.

NOTE: please specify the path for downloading the data.

Use the GPU for this problem. If no CUDA device is available, set device = torch.device('cpu').
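
A portable alternative (a minimal sketch, not part of the original cell) selects the device automatically:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')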

[2]:
device = torch.device('cuda')

sequence_length = 28
input_size = 28
hidden_size = 30
num_layers = 1
num_classes = 10
batch_size = 100


double_precision = torch.double

class RNN(nn.Module):

    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(RNN, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        x = torch.reshape(x,(batch_size,sequence_length,input_size))
        # Set initial hidden state (a vanilla RNN has no cell state)
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device=device, dtype=double_precision)
        out, hidden = self.rnn(x, h0)  # out: tensor of shape (batch_size, seq_length, hidden_size)
        # Feed the hidden state at the last time step into the fully connected layer
        out = self.fc(out[:, -1, :])
        return out

torch.manual_seed(0)

model = RNN(input_size, hidden_size, num_layers, num_classes).to(device=device, dtype=double_precision)
model.train()

train_data = datasets.MNIST(
    root = '/home/buyun/Documents/GitHub/PyGRANSO/examples/data/mnist',
    train = True,
    transform = ToTensor(),
    download = True,
)

loaders = {
    'train' : torch.utils.data.DataLoader(train_data,
                                        batch_size=100,
                                        shuffle=True,
                                        num_workers=1),
}

inputs, labels = next(iter(loaders['train']))
inputs, labels = inputs.reshape(-1, sequence_length, input_size).to(device=device, dtype=double_precision), labels.to(device=device)
/home/buyun/anaconda3/envs/cuosqp_pygranso/lib/python3.9/site-packages/torchvision/datasets/mnist.py:498: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at  /opt/conda/conda-bld/pytorch_1623448255797/work/torch/csrc/utils/tensor_numpy.cpp:180.)
  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)
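
As a quick sanity check (an illustrative addition, not in the original notebook), the batch shapes can be inspected; with the settings above, the batch holds 100 sequences of 28 rows with 28 pixels each:

print(inputs.shape)   # torch.Size([100, 28, 28])
print(labels.shape)   # torch.Size([100])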

Function Set-Up

Encode the optimization variables, the objective function, and the constraint functions.

Note: please strictly follow the format of comb_fn, which will be used in the PyGRANSO main algorithm.

[3]:
def user_fn(model,inputs,labels):
    # objective function
    logits = model(inputs)
    criterion = nn.CrossEntropyLoss()
    f = criterion(logits, labels)

    A = list(model.parameters())[1]  # recurrent weight matrix W_hh (parameter index 1 of nn.RNN)

    # inequality constraint
    ci = None

    # equality constraint: orthogonality of the recurrent weight matrix
    # (uncomment ce.c2 below to restrict to the special orthogonal group)

    ce = pygransoStruct()

    ce.c1 = A.T @ A - torch.eye(hidden_size).to(device=device, dtype=double_precision)
    # ce.c2 = torch.det(A) - 1

    # ce = None

    return [f,ci,ce]

comb_fn = lambda model : user_fn(model,inputs,labels)
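
Before running the solver, it can be helpful to measure how far the initial recurrent weight matrix is from orthogonality; a minimal sketch using the model defined above (the names W_hh and viol are illustrative):

W_hh = list(model.parameters())[1]  # recurrent weight matrix
viol = torch.norm(W_hh.T @ W_hh - torch.eye(hidden_size, device=device, dtype=double_precision))
print("Initial ||W_hh^T W_hh - I||_F = {:.4f}".format(viol.item()))

PyTorch's default RNN initialization is not orthogonal, so the initial violation is nonzero, as reflected in the Eq column at iteration 0 of the solver log below.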

User Options

Specify user-defined options for PyGRANSO.

[4]:
opts = pygransoStruct()
opts.torch_device = device
nvar = getNvarTorch(model.parameters())
opts.x0 = torch.nn.utils.parameters_to_vector(model.parameters()).detach().reshape(nvar,1)
opts.opt_tol = 3e-4
opts.viol_eq_tol = 3e-4
opts.maxit = 150
# opts.fvalquit = 1e-6
opts.print_level = 1
opts.print_frequency = 10
# opts.print_ascii = True
# opts.limited_mem_size = 100
opts.double_precision = True

opts.mu0 = 200
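
As a sanity check (an illustrative addition), the variable count can be verified against the architecture: the RNN contributes \(28\times 30 + 30\times 30 + 30 + 30 = 1800\) parameters and the linear layer \(30\times 10 + 10 = 310\), for a total of 2110, matching the problem specification printed by the solver below:

for name, p in model.named_parameters():
    print(name, tuple(p.shape))
print("nvar =", nvar)  # 840 + 900 + 30 + 30 + 300 + 10 = 2110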

Initial Test

Check the initial accuracy of the RNN model.

[5]:
logits = model(inputs)
_, predicted = torch.max(logits.data, 1)
correct = (predicted == labels).sum().item()
print("Initial acc = {:.2f}%".format((100 * correct/len(inputs))))
Initial acc = 10.00%

Main Algorithm

[6]:
start = time.time()
soln = pygranso(var_spec= model, combined_fn = comb_fn, user_opts = opts)
end = time.time()
print("Total Wall Time: {}s".format(end - start))


╔═════ QP SOLVER NOTICE ════════════════════════════════════════════════════════════════════════╗
║  PyGRANSO requires a quadratic program (QP) solver that has a quadprog-compatible interface,  ║
║  the default is osqp. Users may provide their own wrapper for the QP solver.                  ║
║  To disable this notice, set opts.quadprog_info_msg = False                                   ║
╚═══════════════════════════════════════════════════════════════════════════════════════════════╝
═════════════════════════════════════════════════════════════════════════════════════════════════════════════════╗
PyGRANSO: A PyTorch-enabled port of GRANSO with auto-differentiation                                             ║
Version 1.2.0                                                                                                    ║
Licensed under the AGPLv3, Copyright (C) 2021-2022 Tim Mitchell and Buyun Liang                                  ║
═════════════════════════════════════════════════════════════════════════════════════════════════════════════════╣
Problem specifications:                                                                                          ║
 # of variables                     :   2110                                                                     ║
 # of inequality constraints        :      0                                                                     ║
 # of equality constraints          :    900                                                                     ║
═════╦═══════════════════════════╦════════════════╦═════════════════╦═══════════════════════╦════════════════════╣
     ║ <--- Penalty Function --> ║                ║ Total Violation ║ <--- Line Search ---> ║ <- Stationarity -> ║
Iter ║    Mu    │      Value     ║    Objective   ║ Ineq │    Eq    ║ SD │ Evals │     t    ║ Grads │    Value   ║
═════╬═══════════════════════════╬════════════════╬═════════════════╬═══════════════════════╬════════════════════╣
   0 ║ 200.0000 │  521.858725680 ║  2.28955713337 ║   -  │ 0.746091 ║ -  │     1 │ 0.000000 ║     1 │ 22.68972   ║
  10 ║ 62.76212 │  128.098797617 ║  1.73642166416 ║   -  │ 0.957309 ║ S  │     1 │ 1.000000 ║     1 │ 1.225705   ║
  20 ║ 41.17823 │  50.8284833052 ║  1.06189375932 ║   -  │ 1.568086 ║ S  │     1 │ 1.000000 ║     1 │ 0.864339   ║
  30 ║ 27.01703 │  20.8374358771 ║  0.65467158735 ║   -  │ 1.286537 ║ S  │     1 │ 1.000000 ║     1 │ 0.340082   ║
  40 ║ 21.88380 │  10.3384384935 ║  0.38950779214 ║   -  │ 0.627510 ║ S  │     2 │ 0.500000 ║     1 │ 1.333055   ║
  50 ║ 15.95329 │  4.70122280552 ║  0.25233410359 ║   -  │ 0.241137 ║ S  │     1 │ 1.000000 ║     1 │ 0.715793   ║
  60 ║ 3.649601 │  0.73125619485 ║  0.16914007692 ║   -  │ 0.001355 ║ S  │     1 │ 1.000000 ║     1 │ 0.310019   ║
  70 ║ 2.394503 │  0.32977304232 ║  0.12641808772 ║   -  │ 1.86e-04 ║ S  │     1 │ 1.000000 ║     1 │ 0.267671   ║
  80 ║ 0.751420 │  0.08588708751 ║  0.10989096115 ║   -  │ 5.43e-05 ║ S  │     1 │ 1.000000 ║     1 │ 0.026373   ║
  90 ║ 0.028668 │  0.00476881762 ║  0.10632401897 ║   -  │ 5.62e-05 ║ S  │     1 │ 1.000000 ║     1 │ 0.001171   ║
═════╩═══════════════════════════╩════════════════╩═════════════════╩═══════════════════════╩════════════════════╣
Optimization results:                                                                                            ║
F = final iterate, B = Best (to tolerance), MF = Most Feasible                                                   ║
═════╦═══════════════════════════╦════════════════╦═════════════════╦═══════════════════════╦════════════════════╣
   F ║          │                ║  0.10599965581 ║   -  │ 4.89e-05 ║    │       │          ║       │            ║
   B ║          │                ║  0.10528524892 ║   -  │ 1.41e-04 ║    │       │          ║       │            ║
  MF ║          │                ║  0.10851492951 ║   -  │ 1.34e-05 ║    │       │          ║       │            ║
═════╩═══════════════════════════╩════════════════╩═════════════════╩═══════════════════════╩════════════════════╣
Iterations:              93                                                                                      ║
Function evaluations:    136                                                                                     ║
PyGRANSO termination code: 0 --- converged to stationarity and feasibility tolerances.                           ║
═════════════════════════════════════════════════════════════════════════════════════════════════════════════════╝
Total Wall Time: 160.9356348514557s

Train Accuracy

[7]:
torch.nn.utils.vector_to_parameters(soln.final.x, model.parameters())
logits = model(inputs)
_, predicted = torch.max(logits.data, 1)
correct = (predicted == labels).sum().item()
print("Final acc = {:.2f}%".format((100 * correct/len(inputs))))
Final acc = 99.00%
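
The trained recurrent weight matrix can also be checked directly against the orthogonality constraint; a minimal sketch (the residual should be small, on the order of the final equality violation reported above):

W_hh = list(model.parameters())[1]
res = W_hh.T @ W_hh - torch.eye(hidden_size, device=device, dtype=double_precision)
print("||W_hh^T W_hh - I||_F = {:.2e}".format(torch.norm(res).item()))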

PyGRANSO Restarting

(Optional) The following example shows how to use the partial auto-differentiation feature (with user-provided gradients) to accelerate PyGRANSO.

[8]:
# partial AD
def user_fn(model,inputs,labels):
    # objective function
    logits = model(inputs)
    criterion = nn.CrossEntropyLoss()
    f = criterion(logits, labels)

    A = list(model.parameters())[1]  # recurrent weight matrix W_hh (parameter index 1 of nn.RNN)

    # get f_grad by AD
    n = getNvarTorch(model.parameters())
    f_grad = getObjGradDL(nvar=n,model=model,f=f, torch_device=device, double_precision=True)
    f = f.detach().item()

    # inequality constraint
    ci = None
    ci_grad = None

    # equality constraint: orthogonality of the recurrent weight matrix
    # (a determinant constraint det(A) = 1 would further restrict to the special orthogonal group)

    ce = A.T @ A - torch.eye(hidden_size).to(device=device, dtype=double_precision)
    ce = ce.detach()
    nconstr = hidden_size*hidden_size
    ce = torch.reshape(ce,(nconstr,1))
    # analytic Jacobian of the constraints: ce_grad[v,k] = d ce[k] / d x[v]
    ce_grad = torch.zeros((n,nconstr)).to(device=device, dtype=double_precision)
    M = torch.zeros((nconstr,nconstr)).to(device=device, dtype=double_precision)

    # d(A^T A)/dA_ij = A^T J_ij + J_ij^T A, with J_ij the single-entry matrix
    for i in range(hidden_size):
        for j in range(hidden_size):
            J_ij = torch.zeros((hidden_size,hidden_size)).to(device=device, dtype=double_precision)
            J_ij[i,j] = 1
            tmp = A.T@J_ij + J_ij.T@A
            M[hidden_size*i+j,:] = tmp.reshape((1,hidden_size*hidden_size))

    # place M at the rows corresponding to W_hh in the flattened parameter vector
    ce_grad[input_size*hidden_size:input_size*hidden_size+hidden_size*hidden_size,:] = M
    ce_grad = ce_grad.detach()

    return [f,f_grad,ci,ci_grad,ce,ce_grad]
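
The hand-coded Jacobian above can be validated against automatic differentiation on a random test matrix; an illustrative sketch, not part of the original notebook (A_test, constraint, and M_test are hypothetical helper names):

A_test = torch.randn(hidden_size, hidden_size, device=device, dtype=double_precision)

def constraint(A):
    # flattened c(A) = vec(A^T A - I), matching the row-major reshape in user_fn
    return (A.T @ A - torch.eye(hidden_size, device=device, dtype=double_precision)).reshape(-1)

J_auto = torch.autograd.functional.jacobian(constraint, A_test)  # shape (900, 30, 30)
J_auto = J_auto.reshape(hidden_size**2, hidden_size**2).T        # rows now index entries of A_test

M_test = torch.zeros_like(J_auto)
for i in range(hidden_size):
    for j in range(hidden_size):
        J_ij = torch.zeros(hidden_size, hidden_size, device=device, dtype=double_precision)
        J_ij[i, j] = 1
        M_test[hidden_size*i + j, :] = (A_test.T @ J_ij + J_ij.T @ A_test).reshape(-1)

print(torch.allclose(J_auto, M_test))  # expect True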

Important: set opts.globalAD = False to disable global auto-differentiation

[9]:
opts.globalAD = False # disable global auto-differentiation

# reinitialize model
torch.manual_seed(0)

model = RNN(input_size, hidden_size, num_layers, num_classes).to(device=device, dtype=double_precision)
model.train()

# initial acc
logits = model(inputs)
_, predicted = torch.max(logits.data, 1)
correct = (predicted == labels).sum().item()
print("Initial acc = {:.2f}%".format((100 * correct/len(inputs))))

# Main Algorithm
# Note: comb_fn (defined earlier) picks up the redefined user_fn via Python's late binding
start = time.time()
soln = pygranso(var_spec= model, combined_fn = comb_fn, user_opts = opts)
end = time.time()
print("Total Wall Time: {}s".format(end - start))

# Train acc
torch.nn.utils.vector_to_parameters(soln.final.x, model.parameters())
logits = model(inputs)
_, predicted = torch.max(logits.data, 1)
correct = (predicted == labels).sum().item()
print("Final acc = {:.2f}%".format((100 * correct/len(inputs))))
Initial acc = 10.00%


╔═════ QP SOLVER NOTICE ════════════════════════════════════════════════════════════════════════╗
║  PyGRANSO requires a quadratic program (QP) solver that has a quadprog-compatible interface,  ║
║  the default is osqp. Users may provide their own wrapper for the QP solver.                  ║
║  To disable this notice, set opts.quadprog_info_msg = False                                   ║
╚═══════════════════════════════════════════════════════════════════════════════════════════════╝
═════════════════════════════════════════════════════════════════════════════════════════════════════════════════╗
PyGRANSO: A PyTorch-enabled port of GRANSO with auto-differentiation                                             ║
Version 1.2.0                                                                                                    ║
Licensed under the AGPLv3, Copyright (C) 2021-2022 Tim Mitchell and Buyun Liang                                  ║
═════════════════════════════════════════════════════════════════════════════════════════════════════════════════╣
Problem specifications:                                                                                          ║
 # of variables                     :   2110                                                                     ║
 # of inequality constraints        :      0                                                                     ║
 # of equality constraints          :    900                                                                     ║
═════╦═══════════════════════════╦════════════════╦═════════════════╦═══════════════════════╦════════════════════╣
     ║ <--- Penalty Function --> ║                ║ Total Violation ║ <--- Line Search ---> ║ <- Stationarity -> ║
Iter ║    Mu    │      Value     ║    Objective   ║ Ineq │    Eq    ║ SD │ Evals │     t    ║ Grads │    Value   ║
═════╬═══════════════════════════╬════════════════╬═════════════════╬═══════════════════════╬════════════════════╣
   0 ║ 200.0000 │  521.858725680 ║  2.28955713337 ║   -  │ 0.746091 ║ -  │     1 │ 0.000000 ║     1 │ 22.68972   ║
  10 ║ 62.76212 │  128.098797617 ║  1.73642166416 ║   -  │ 0.957309 ║ S  │     1 │ 1.000000 ║     1 │ 1.225705   ║
  20 ║ 41.17823 │  50.8284833052 ║  1.06189375932 ║   -  │ 1.568086 ║ S  │     1 │ 1.000000 ║     1 │ 0.864339   ║
  30 ║ 27.01703 │  20.8374358771 ║  0.65467158735 ║   -  │ 1.286537 ║ S  │     1 │ 1.000000 ║     1 │ 0.340082   ║
  40 ║ 21.88380 │  10.3384384935 ║  0.38950779214 ║   -  │ 0.627510 ║ S  │     2 │ 0.500000 ║     1 │ 1.333055   ║
  50 ║ 15.95329 │  4.70122280552 ║  0.25233410359 ║   -  │ 0.241137 ║ S  │     1 │ 1.000000 ║     1 │ 0.715793   ║
  60 ║ 3.649601 │  0.73125619485 ║  0.16914007692 ║   -  │ 0.001355 ║ S  │     1 │ 1.000000 ║     1 │ 0.310019   ║
  70 ║ 2.394503 │  0.32977304232 ║  0.12641808772 ║   -  │ 1.86e-04 ║ S  │     1 │ 1.000000 ║     1 │ 0.267671   ║
  80 ║ 0.751420 │  0.08588708751 ║  0.10989096115 ║   -  │ 5.43e-05 ║ S  │     1 │ 1.000000 ║     1 │ 0.026373   ║
  90 ║ 0.028668 │  0.00476881762 ║  0.10632401897 ║   -  │ 5.62e-05 ║ S  │     1 │ 1.000000 ║     1 │ 0.001171   ║
═════╩═══════════════════════════╩════════════════╩═════════════════╩═══════════════════════╩════════════════════╣
Optimization results:                                                                                            ║
F = final iterate, B = Best (to tolerance), MF = Most Feasible                                                   ║
═════╦═══════════════════════════╦════════════════╦═════════════════╦═══════════════════════╦════════════════════╣
   F ║          │                ║  0.10599965581 ║   -  │ 4.89e-05 ║    │       │          ║       │            ║
   B ║          │                ║  0.10528524892 ║   -  │ 1.41e-04 ║    │       │          ║       │            ║
  MF ║          │                ║  0.10851492951 ║   -  │ 1.34e-05 ║    │       │          ║       │            ║
═════╩═══════════════════════════╩════════════════╩═════════════════╩═══════════════════════╩════════════════════╣
Iterations:              93                                                                                      ║
Function evaluations:    136                                                                                     ║
PyGRANSO termination code: 0 --- converged to stationarity and feasibility tolerances.                           ║
═════════════════════════════════════════════════════════════════════════════════════════════════════════════════╝
Total Wall Time: 94.66318130493164s
Final acc = 99.00%