Trace Optimization
Trace optimization with orthogonality constraints, taken from: Effrosini Kokiopoulou, Jie Chen, and Yousef Saad. “Trace optimization and eigenproblems in dimension reduction methods.” Numerical Linear Algebra with Applications 18.3 (2011): 565-602.
Problem Description
Consider a symmetric matrix \(A\) of dimension \(n\times n\) and an arbitrary matrix \(V\) of dimension \(n\times d\) with orthonormal columns, i.e. \(V^TV=I\).
The trace of \(V^TAV\) is maximized when the columns of \(V\) form an orthonormal basis of the eigenspace associated with the \(d\) (algebraically) largest eigenvalues.
If the eigenvalues are labeled in decreasing order, \(u_1,...,u_d\) are eigenvectors associated with the first \(d\) eigenvalues \(\lambda_1,...,\lambda_d\), and \(U = [u_1,...,u_d]\) with \(U^TU=I\), then
\[\max_{V\in \mathbb{R}^{n\times d},\ V^TV=I} \mathrm{Tr}(V^TAV) = \mathrm{Tr}(U^TAU) = \lambda_1+\dots+\lambda_d.\]
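The trace identity can be checked by hand in the smallest non-trivial case. The sketch below is illustrative only: it assumes a hypothetical 2x2 symmetric matrix (the entries `a`, `b`, `c` are made up for the example) and takes \(d=1\), so every feasible \(V\) is a unit vector \((\cos\theta, \sin\theta)\) and the maximum can be found by scanning the angle.

```python
import math

# Hypothetical 2x2 symmetric matrix A = [[a, b], [b, c]], chosen only for illustration.
a, b, c = 2.0, 1.0, -1.0

# Closed-form eigenvalues of a symmetric 2x2 matrix: mean +/- radius.
mean = (a + c) / 2.0
radius = math.hypot((a - c) / 2.0, b)
lam_max, lam_min = mean + radius, mean - radius

def quadratic_form(theta):
    """Return v^T A v for the unit vector v = (cos(theta), sin(theta))."""
    x, y = math.cos(theta), math.sin(theta)
    return a * x * x + 2.0 * b * x * y + c * y * y

# With d = 1, V is a single unit vector, so scanning theta over [0, pi)
# covers every feasible V up to sign (the sign does not change v^T A v).
steps = 100000
best = max(quadratic_form(math.pi * k / steps) for k in range(steps))

print("largest eigenvalue:", lam_max)
print("best trace found  :", best)
```

The scanned maximum agrees with the largest eigenvalue up to the resolution of the angle grid, which is the \(d=1\) instance of the statement above.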
Modules Importing
Import all necessary modules and add PyGRANSO src folder to system path.
[1]:
import time
import torch
import sys
## Adding PyGRANSO directories. Should be modified by user
sys.path.append('/home/buyun/Documents/GitHub/PyGRANSO')
from pygranso.pygranso import pygranso
from pygranso.pygransoStruct import pygransoStruct
Data Initialization
Specify the torch device and generate the data.
This example uses a GPU. If no CUDA device is available, set device = torch.device('cpu').
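The CPU fallback can also be selected automatically rather than by editing the cell. A minimal sketch, assuming only the standard `torch.cuda.is_available()` check; the tensor `x` is a hypothetical placeholder used to show that the device and dtype are applied:

```python
import torch

# Pick CUDA when it is available and fall back to the CPU otherwise.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hypothetical placeholder tensor: double precision, created on the selected device.
x = torch.zeros(2, 2, device=device, dtype=torch.double)
print(device, x.dtype)
```

With this choice the same `device` object can be passed to `opts.torch_device`, so the data and the solver options stay consistent on machines with or without a GPU.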
[2]:
device = torch.device('cuda')
n = 5
d = 1
torch.manual_seed(0)
# All the user-provided data (vector/matrix/tensor) must be in torch tensor format.
# As PyTorch tensor is single precision by default, one must explicitly set `dtype=torch.double`.
# Also, please make sure the device of provided torch tensor is the same as opts.torch_device.
A = torch.randn(n,n).to(device=device, dtype=torch.double)
A = (A + A.T)/2
# A is symmetric, so use eigh: it returns real eigenvalues and orthonormal eigenvectors,
# which avoids the complex output of torch.linalg.eig and the cast back to double.
L, U = torch.linalg.eigh(A)
index = torch.argsort(L,descending=True)
U = U[:,index[0:d]]
Function Set-Up
Encode the optimization variables, the objective function, and the constraint functions.
Note: please strictly follow the format of comb_fn, which will be used in the PyGRANSO main algorithm.
[3]:
# variables and corresponding dimensions.
var_in = {"V": [n,d]}
def user_fn(X_struct,A,d):
    V = X_struct.V
    # objective function
    f = -torch.trace(V.T@A@V)
    # inequality constraint, matrix form
    ci = None
    # equality constraint
    ce = pygransoStruct()
    ce.c1 = V.T@V - torch.eye(d).to(device=device, dtype=torch.double)
    return [f,ci,ce]

comb_fn = lambda X_struct : user_fn(X_struct,A,d)
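Before handing such a function to the solver, it can help to evaluate the objective and the constraint residual at a point that is known to be feasible. The following is a sketch under stated assumptions: `SimpleNamespace` is only a stand-in for the structure PyGRANSO passes to comb_fn, and `objective_and_violation` is a hypothetical helper, not part of PyGRANSO. It runs on the CPU and does not call the solver.

```python
from types import SimpleNamespace
import torch

torch.manual_seed(0)
n, d = 5, 2
A = torch.randn(n, n, dtype=torch.double)
A = (A + A.T) / 2

# Orthonormal columns from a reduced QR factorization give a feasible point.
V_feasible, _ = torch.linalg.qr(torch.randn(n, d, dtype=torch.double))

def objective_and_violation(X_struct, A, d):
    """Hypothetical helper: negative trace objective and norm of V^T V - I."""
    V = X_struct.V
    f = -torch.trace(V.T @ A @ V)
    ce = V.T @ V - torch.eye(d, dtype=V.dtype, device=V.device)
    return f, torch.linalg.norm(ce)

# SimpleNamespace stands in for the solver's variable structure (assumption).
f, viol = objective_and_violation(SimpleNamespace(V=V_feasible), A, d)

# Scaling the columns breaks orthonormality: V^T V - I becomes 3*I.
f_bad, viol_bad = objective_and_violation(SimpleNamespace(V=2.0 * V_feasible), A, d)

print("feasible point  : f = {:.6f}, violation = {:.2e}".format(f.item(), viol.item()))
print("infeasible point: f = {:.6f}, violation = {:.2e}".format(f_bad.item(), viol_bad.item()))
```

The infeasible point reaches a lower objective value only by violating the constraint, which is why the equality constraint on \(V^TV\) is needed for the problem to be bounded.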
User Options
Specify user-defined options for PyGRANSO
[4]:
opts = pygransoStruct()
opts.torch_device = device
opts.print_frequency = 1
# opts.opt_tol = 1e-7
opts.maxit = 3000
# opts.mu0 = 10
# opts.steering_c_viol = 0.02
Main Algorithm
[5]:
start = time.time()
soln = pygranso(var_spec = var_in,combined_fn = comb_fn,user_opts = opts)
end = time.time()
print("Total Wall Time: {}s".format(end - start))
V = torch.reshape(soln.final.x,(n,d))
rel_dist = torch.norm(V@V.T - U@U.T)/torch.norm(U@U.T)
print("torch.norm(V@V.T - U@U.T)/torch.norm(U@U.T) = {}".format(rel_dist))
print("torch.trace(V.T@A@V) = {}".format(torch.trace(V.T@A@V)))
print("torch.trace(U.T@A@U) = {}".format(torch.trace(U.T@A@U)))
print("sum of first d eigvals = {}".format(torch.sum(L[index[0:d]])))
print("sorted eigs = {}".format(L[index]))
╔═════ QP SOLVER NOTICE ════════════════════════════════════════════════════════════════════════╗
║ PyGRANSO requires a quadratic program (QP) solver that has a quadprog-compatible interface, ║
║ the default is osqp. Users may provide their own wrapper for the QP solver. ║
║ To disable this notice, set opts.quadprog_info_msg = False ║
╚═══════════════════════════════════════════════════════════════════════════════════════════════╝
═════════════════════════════════════════════════════════════════════════════════════════════════════════════════╗
PyGRANSO: A PyTorch-enabled port of GRANSO with auto-differentiation ║
Version 1.2.0 ║
Licensed under the AGPLv3, Copyright (C) 2021-2022 Tim Mitchell and Buyun Liang ║
═════════════════════════════════════════════════════════════════════════════════════════════════════════════════╣
Problem specifications: ║
# of variables : 5 ║
# of inequality constraints : 0 ║
# of equality constraints : 1 ║
═════╦═══════════════════════════╦════════════════╦═════════════════╦═══════════════════════╦════════════════════╣
║ <--- Penalty Function --> ║ ║ Total Violation ║ <--- Line Search ---> ║ <- Stationarity -> ║
Iter ║ Mu │ Value ║ Objective ║ Ineq │ Eq ║ SD │ Evals │ t ║ Grads │ Value ║
═════╬═══════════════════════════╬════════════════╬═════════════════╬═══════════════════════╬════════════════════╣
0 ║ 1.000000 │ 1.16724813601 ║ -2.14592113443 ║ - │ 3.313169 ║ - │ 1 │ 0.000000 ║ 1 │ 4.801284 ║
1 ║ 1.000000 │ -4.18220640118 ║ -27.2345373313 ║ - │ 23.05233 ║ S │ 1 │ 1.000000 ║ 1 │ 19.60748 ║
2 ║ 0.810000 │ -7.44513089341 ║ -25.7524524611 ║ - │ 13.41436 ║ S │ 1 │ 1.000000 ║ 1 │ 2.119321 ║
3 ║ 0.478297 │ -3.36403570563 ║ -29.7615131867 ║ - │ 10.87080 ║ S │ 1 │ 1.000000 ║ 1 │ 0.807332 ║
4 ║ 0.313811 │ 0.44260646755 ║ -29.8199400438 ║ - │ 9.800420 ║ S │ 1 │ 1.000000 ║ 1 │ 0.181373 ║
5 ║ 0.313811 │ -0.05159283053 ║ -25.9178254403 ║ - │ 8.081695 ║ S │ 2 │ 2.000000 ║ 1 │ 0.266675 ║
6 ║ 0.313811 │ -0.60015959521 ║ -19.8021105233 ║ - │ 5.613953 ║ S │ 2 │ 2.000000 ║ 1 │ 0.129813 ║
7 ║ 0.313811 │ -0.74007491793 ║ -15.7240418418 ║ - │ 4.194296 ║ S │ 1 │ 1.000000 ║ 1 │ 0.073206 ║
8 ║ 0.313811 │ -0.80534295483 ║ -10.5540665925 ║ - │ 2.506635 ║ S │ 3 │ 4.000000 ║ 1 │ 0.096243 ║
9 ║ 0.313811 │ -0.90562463200 ║ -3.96127832617 ║ - │ 0.337466 ║ S │ 3 │ 1.500000 ║ 1 │ 0.109112 ║
10 ║ 0.313811 │ -0.94306331782 ║ -3.09440966016 ║ - │ 0.027995 ║ S │ 1 │ 1.000000 ║ 1 │ 0.036484 ║
11 ║ 0.313811 │ -0.95168907673 ║ -3.04557834593 ║ - │ 0.004046 ║ S │ 1 │ 1.000000 ║ 1 │ 0.016092 ║
12 ║ 0.313811 │ -0.95404206443 ║ -3.04284283001 ║ - │ 8.34e-04 ║ S │ 1 │ 1.000000 ║ 1 │ 0.010167 ║
13 ║ 0.313811 │ -0.95488755011 ║ -3.04425497591 ║ - │ 4.32e-04 ║ S │ 1 │ 1.000000 ║ 1 │ 0.003923 ║
14 ║ 0.313811 │ -0.95497194333 ║ -3.04323398616 ║ - │ 2.71e-05 ║ S │ 1 │ 1.000000 ║ 1 │ 0.001371 ║
15 ║ 0.313811 │ -0.95498874787 ║ -3.04321724294 ║ - │ 5.07e-06 ║ S │ 1 │ 1.000000 ║ 1 │ 8.65e-04 ║
16 ║ 0.313811 │ -0.95499959494 ║ -3.04325102043 ║ - │ 4.82e-06 ║ S │ 1 │ 1.000000 ║ 1 │ 6.98e-04 ║
17 ║ 0.313811 │ -0.95500308923 ║ -3.04325221318 ║ - │ 1.70e-06 ║ S │ 1 │ 1.000000 ║ 1 │ 3.20e-04 ║
18 ║ 0.313811 │ -0.95500354042 ║ -3.04324890133 ║ - │ 2.11e-07 ║ S │ 1 │ 1.000000 ║ 1 │ 8.22e-05 ║
19 ║ 0.313811 │ -0.95500356428 ║ -3.04324833552 ║ - │ 9.93e-09 ║ S │ 1 │ 1.000000 ║ 2 │ 2.46e-05 ║
═════╬═══════════════════════════╬════════════════╬═════════════════╬═══════════════════════╬════════════════════╣
║ <--- Penalty Function --> ║ ║ Total Violation ║ <--- Line Search ---> ║ <- Stationarity -> ║
Iter ║ Mu │ Value ║ Objective ║ Ineq │ Eq ║ SD │ Evals │ t ║ Grads │ Value ║
═════╬═══════════════════════════╬════════════════╬═════════════════╬═══════════════════════╬════════════════════╣
20 ║ 0.313811 │ -0.95500356476 ║ -3.04324830572 ║ - │ 1.06e-10 ║ S │ 7 │ 1.031250 ║ 2 │ 7.11e-06 ║
21 ║ 0.282430 │ -0.85950320836 ║ -3.04324830605 ║ - │ 1.11e-10 ║ S │ 1 │ 1.000000 ║ 3 │ 1.92e-06 ║
22 ║ 0.282430 │ -0.85950320837 ║ -3.04324830571 ║ - │ 8.61e-12 ║ S │ 5 │ 1.125000 ║ 4 │ 2.38e-07 ║
23 ║ 0.282430 │ -0.85950320837 ║ -3.04324830572 ║ - │ 8.20e-12 ║ S │ 3 │ 0.250000 ║ 5 │ 6.46e-07 ║
24 ║ 0.282430 │ -0.85950320837 ║ -3.04324830570 ║ - │ 4.19e-13 ║ S │ 6 │ 1.062500 ║ 6 │ 7.63e-05 ║
25 ║ 0.282430 │ -0.85950320837 ║ -3.04324830570 ║ - │ 6.05e-13 ║ S │ 11 │ 9.77e-04 ║ 7 │ 0.001019 ║
26 ║ 0.282430 │ -0.85950320837 ║ -3.04324830570 ║ - │ 8.78e-13 ║ S │ 2 │ 0.500000 ║ 8 │ 1.00e-06 ║
27 ║ 0.282430 │ -0.85950320845 ║ -3.04324830625 ║ - │ 7.90e-11 ║ SI │ 2 │ 0.500000 ║ 9 │ 2.12e-07 ║
28 ║ 0.282430 │ -0.85950320845 ║ -3.04324830631 ║ - │ 9.62e-11 ║ S │ 3 │ 0.250000 ║ 10 │ 9.83e-07 ║
29 ║ 0.282430 │ -0.85950320845 ║ -3.04324830634 ║ - │ 1.03e-10 ║ S │ 4 │ 0.125000 ║ 10 │ 5.25e-06 ║
30 ║ 0.282430 │ -0.85950320845 ║ -3.04324830635 ║ - │ 1.07e-10 ║ S │ 5 │ 0.062500 ║ 10 │ 5.19e-06 ║
31 ║ 0.282430 │ -0.85950320845 ║ -3.04324830635 ║ - │ 1.08e-10 ║ S │ 7 │ 0.015625 ║ 10 │ 5.37e-06 ║
═════╩═══════════════════════════╩════════════════╩═════════════════╩═══════════════════════╩════════════════════╣
F = final iterate, B = Best (to tolerance), MF = Most Feasible ║
Optimization results: ║
═════╦═══════════════════════════╦════════════════╦═════════════════╦═══════════════════════╦════════════════════╣
F ║ │ ║ -3.04324830635 ║ - │ 1.08e-10 ║ │ │ ║ │ ║
B ║ │ ║ -3.04324890133 ║ - │ 2.11e-07 ║ │ │ ║ │ ║
MF ║ │ ║ -3.04324830570 ║ - │ 8.44e-14 ║ │ │ ║ │ ║
═════╩═══════════════════════════╩════════════════╩═════════════════╩═══════════════════════╩════════════════════╣
Iterations: 31 ║
Function evaluations: 105 ║
PyGRANSO termination code: 6 --- line search bracketed a minimizer but failed to satisfy Wolfe conditions at a ║
feasible point (to tolerances). This may be an indication that approximate stationarity has been attained. ║
═════════════════════════════════════════════════════════════════════════════════════════════════════════════════╝
Total Wall Time: 1.0528507232666016s
torch.norm(V@V.T - U@U.T)/torch.norm(U@U.T) = 3.772472601724718e-06
torch.trace(V.T@A@V) = 3.0432483063508395
torch.trace(U.T@A@U) = 3.0432483060418907
sum of first d eigvals = 3.04324830604189
sorted eigs = tensor([ 3.0432, 0.8890, -0.4730, -0.9598, -1.8722], device='cuda:0',
dtype=torch.float64)
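The reported distance compares the projectors \(VV^T\) and \(UU^T\) rather than \(V\) and \(U\) themselves. One reason this is a sensible measure is that an orthonormal basis of a subspace is not unique: any \(V = UQ\) with \(Q\) orthogonal spans the same subspace and gives the same trace. The sketch below illustrates this on the CPU; the rotation angle `theta` and the sizes are arbitrary choices made for the example.

```python
import math
import torch

torch.manual_seed(0)
n, d = 5, 2
A = torch.randn(n, n, dtype=torch.double)
A = (A + A.T) / 2

# eigh returns real eigenvalues in ascending order for a symmetric input.
evals, evecs = torch.linalg.eigh(A)
U = evecs[:, -d:]  # eigenvectors of the d largest eigenvalues

# Rotate the basis inside the same subspace with a d x d orthogonal matrix Q.
theta = 0.7  # arbitrary angle, for illustration only
Q = torch.tensor([[math.cos(theta), -math.sin(theta)],
                  [math.sin(theta),  math.cos(theta)]], dtype=torch.double)
V = U @ Q

basis_gap = torch.norm(V - U) / torch.norm(U)
projector_gap = torch.norm(V @ V.T - U @ U.T) / torch.norm(U @ U.T)
trace_gap = torch.trace(V.T @ A @ V) - torch.trace(U.T @ A @ U)

print("relative gap between bases     :", basis_gap.item())
print("relative gap between projectors:", projector_gap.item())
print("difference in trace            :", trace_gap.item())
```

The bases differ substantially while the projectors and traces agree to rounding error, so a direct comparison of \(V\) with \(U\) would report a large error for an equally valid solution.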
More Constraints
(Optional) Explore the performance of PyGRANSO with a different number of constraints.
[6]:
device = torch.device('cuda')
n = 5
d = 2
torch.manual_seed(0)
# All the user-provided data (vector/matrix/tensor) must be in torch tensor format.
# As PyTorch tensor is single precision by default, one must explicitly set `dtype=torch.double`.
# Also, please make sure the device of provided torch tensor is the same as opts.torch_device.
A = torch.randn(n,n).to(device=device, dtype=torch.double)
A = (A + A.T)/2
# A is symmetric, so use eigh: it returns real eigenvalues and orthonormal eigenvectors.
L, U = torch.linalg.eigh(A)
index = torch.argsort(L,descending=True)
U = U[:,index[0:d]]
# variables and corresponding dimensions.
var_in = {"V": [n,d]}
def user_fn(X_struct,A,d):
    V = X_struct.V
    # objective function
    f = -torch.trace(V.T@A@V)
    # inequality constraint, matrix form
    ci = None
    # equality constraint
    ce = pygransoStruct()
    ce.c1 = V.T@V - torch.eye(d).to(device=device, dtype=torch.double)
    return [f,ci,ce]

comb_fn = lambda X_struct : user_fn(X_struct,A,d)
opts = pygransoStruct()
opts.torch_device = device
opts.print_frequency = 10
opts.opt_tol = 5e-6
opts.maxit = 1000
# opts.mu0 = 10
# opts.steering_c_viol = 0.02
start = time.time()
soln = pygranso(var_spec = var_in,combined_fn = comb_fn,user_opts = opts)
end = time.time()
print("Total Wall Time: {}s".format(end - start))
V = torch.reshape(soln.final.x,(n,d))
rel_dist = torch.norm(V@V.T - U@U.T)/torch.norm(U@U.T)
print("torch.norm(V@V.T - U@U.T)/torch.norm(U@U.T) = {}".format(rel_dist))
print("torch.trace(V.T@A@V) = {}".format(torch.trace(V.T@A@V)))
print("torch.trace(U.T@A@U) = {}".format(torch.trace(U.T@A@U)))
print("sum of first d eigvals = {}".format(torch.sum(L[index[0:d]])))
print("sorted eigs = {}".format(L[index]))
╔═════ QP SOLVER NOTICE ════════════════════════════════════════════════════════════════════════╗
║ PyGRANSO requires a quadratic program (QP) solver that has a quadprog-compatible interface, ║
║ the default is osqp. Users may provide their own wrapper for the QP solver. ║
║ To disable this notice, set opts.quadprog_info_msg = False ║
╚═══════════════════════════════════════════════════════════════════════════════════════════════╝
═════════════════════════════════════════════════════════════════════════════════════════════════════════════════╗
PyGRANSO: A PyTorch-enabled port of GRANSO with auto-differentiation ║
Version 1.2.0 ║
Licensed under the AGPLv3, Copyright (C) 2021-2022 Tim Mitchell and Buyun Liang ║
═════════════════════════════════════════════════════════════════════════════════════════════════════════════════╣
Problem specifications: ║
# of variables : 10 ║
# of inequality constraints : 0 ║
# of equality constraints : 4 ║
═════╦═══════════════════════════╦════════════════╦═════════════════╦═══════════════════════╦════════════════════╣
║ <--- Penalty Function --> ║ ║ Total Violation ║ <--- Line Search ---> ║ <- Stationarity -> ║
Iter ║ Mu │ Value ║ Objective ║ Ineq │ Eq ║ SD │ Evals │ t ║ Grads │ Value ║
═════╬═══════════════════════════╬════════════════╬═════════════════╬═══════════════════════╬════════════════════╣
0 ║ 1.000000 │ 27.7034556094 ║ 0.81257641642 ║ - │ 8.581446 ║ - │ 1 │ 0.000000 ║ 1 │ 10.40129 ║
10 ║ 0.282430 │ -0.18719099678 ║ -2.46656247013 ║ - │ 0.280276 ║ S │ 2 │ 2.000000 ║ 1 │ 0.486407 ║
20 ║ 0.166772 │ -0.55397439861 ║ -3.46402291008 ║ - │ 0.021852 ║ S │ 1 │ 1.000000 ║ 1 │ 0.092432 ║
30 ║ 0.088629 │ -0.34324393783 ║ -3.87472269962 ║ - │ 1.46e-04 ║ S │ 1 │ 1.000000 ║ 1 │ 0.002322 ║
40 ║ 0.025032 │ -0.09733355831 ║ -3.88881298264 ║ - │ 7.01e-06 ║ S │ 1 │ 1.000000 ║ 1 │ 0.002899 ║
50 ║ 0.025032 │ -0.09765249793 ║ -3.90193007079 ║ - │ 1.09e-05 ║ S │ 1 │ 1.000000 ║ 1 │ 0.003514 ║
60 ║ 0.025032 │ -0.09795121684 ║ -3.91354197166 ║ - │ 7.61e-06 ║ S │ 1 │ 1.000000 ║ 1 │ 0.003132 ║
70 ║ 0.025032 │ -0.09811594986 ║ -3.91992492755 ║ - │ 3.69e-06 ║ S │ 1 │ 1.000000 ║ 1 │ 0.002289 ║
80 ║ 0.002465 │ -0.00966775963 ║ -3.92198806015 ║ - │ 5.13e-08 ║ S │ 3 │ 1.500000 ║ 2 │ 5.79e-04 ║
90 ║ 5.64e-04 │ -0.00221178042 ║ -3.92221123547 ║ - │ 2.74e-08 ║ S │ 3 │ 1.500000 ║ 3 │ 4.65e-06 ║
═════╩═══════════════════════════╩════════════════╩═════════════════╩═══════════════════════╩════════════════════╣
F = final iterate, B = Best (to tolerance), MF = Most Feasible ║
Optimization results: ║
═════╦═══════════════════════════╦════════════════╦═════════════════╦═══════════════════════╦════════════════════╣
F ║ │ ║ -3.92221123547 ║ - │ 2.74e-08 ║ │ │ ║ │ ║
B ║ │ ║ -3.92221510373 ║ - │ 2.72e-08 ║ │ │ ║ │ ║
MF ║ │ ║ -3.92199228401 ║ - │ 1.84e-09 ║ │ │ ║ │ ║
═════╩═══════════════════════════╩════════════════╩═════════════════╩═══════════════════════╩════════════════════╣
Iterations: 90 ║
Function evaluations: 122 ║
PyGRANSO termination code: 0 --- converged to stationarity and feasibility tolerances. ║
═════════════════════════════════════════════════════════════════════════════════════════════════════════════════╝
Total Wall Time: 0.9984183311462402s
torch.norm(V@V.T - U@U.T)/torch.norm(U@U.T) = 0.062441956102622084
torch.trace(V.T@A@V) = 3.922211235466178
torch.trace(U.T@A@U) = 3.932280709191555
sum of first d eigvals = 3.9322807091915544
sorted eigs = tensor([ 3.0432, 0.8890, -0.4730, -0.9598, -1.8722], device='cuda:0',
dtype=torch.float64)