Trace Optimization

Trace optimization with orthogonal constraints taken from: Effrosini Kokiopoulou, Jie Chen, and Yousef Saad. “Trace optimization and eigenproblems in dimension reduction methods.” Numerical Linear Algebra with Applications 18.3 (2011): 565-602.

Problem Description

Given a symmetric matrix \(A\) of dimension \(n\times n\), consider matrices \(V\) of dimension \(n\times d\) with orthonormal columns, i.e., \(V^TV=I\).

The trace of \(V^TAV\) is maximized when the columns of \(V\) form an orthonormal basis of the eigenspace associated with the \(d\) (algebraically) largest eigenvalues.

If the eigenvalues are labeled in decreasing order, \(u_1,\dots,u_d\) are eigenvectors associated with the first \(d\) eigenvalues \(\lambda_1,\dots,\lambda_d\), and \(U = [u_1,\dots,u_d]\) satisfies \(U^TU=I\), then

\[\max_{V \in \mathbb{R}^{n\times d},\, V^TV=I} \text{Tr}[V^TAV]=\text{Tr}[U^TAU]=\lambda_1+\dots+\lambda_d\]
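This identity can be checked numerically. The sketch below uses NumPy (rather than the PyTorch setup of this notebook) with an arbitrary symmetric matrix; all names and sizes are illustrative:

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 5, 2

# Random symmetric test matrix.
A = rng.standard_normal((n, n))
A = (A + A.T) / 2

# For symmetric matrices, eigh returns real eigenvalues in ascending order.
evals, evecs = np.linalg.eigh(A)
U = evecs[:, ::-1][:, :d]          # eigenvectors of the d largest eigenvalues
top_sum = evals[::-1][:d].sum()

# Tr[U^T A U] equals the sum of the d largest eigenvalues ...
assert np.isclose(np.trace(U.T @ A @ U), top_sum)

# ... and any other orthonormal V of the same shape does no better.
V, _ = np.linalg.qr(rng.standard_normal((n, d)))
assert np.trace(V.T @ A @ V) <= top_sum + 1e-12
```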

Importing Modules

Import all necessary modules and add PyGRANSO src folder to system path.

[1]:
import time
import torch
import sys
## Adding PyGRANSO directories. Should be modified by user
sys.path.append('/home/buyun/Documents/GitHub/PyGRANSO')
from pygranso.pygranso import pygranso
from pygranso.pygransoStruct import pygransoStruct

Data Initialization

Specify the torch device and generate the data.

Use the GPU for this problem. If no CUDA device is available, set device = torch.device('cpu')
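A portable alternative is to fall back to the CPU automatically when CUDA is absent:

```python
import torch

# Prefer CUDA when available; otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
```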

[2]:
device = torch.device('cuda')
n = 5
d = 1
torch.manual_seed(0)
# All the user-provided data (vector/matrix/tensor) must be in torch tensor format.
# As PyTorch tensor is single precision by default, one must explicitly set `dtype=torch.double`.
# Also, please make sure the device of provided torch tensor is the same as opts.torch_device.
A = torch.randn(n,n).to(device=device, dtype=torch.double)
A = (A + A.T)/2
# eigh (not eig): A is symmetric, so its eigenpairs are real and no
# complex-to-real cast is needed.
L, U = torch.linalg.eigh(A)
index = torch.argsort(L,descending=True)
U = U[:,index[0:d]]

Function Set-Up

Encode the optimization variables and the objective and constraint functions.

Note: please strictly follow the format of comb_fn, which will be used in the PyGRANSO main algorithm.

[3]:
# variables and corresponding dimensions.
var_in = {"V": [n,d]}

def user_fn(X_struct,A,d):
    V = X_struct.V

    # objective function
    f = -torch.trace(V.T@A@V)

    # inequality constraint, matrix form
    ci = None

    # equality constraint
    ce = pygransoStruct()
    ce.c1 = V.T@V - torch.eye(d).to(device=device, dtype=torch.double)

    return [f,ci,ce]

comb_fn = lambda X_struct : user_fn(X_struct,A,d)
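The equality constraint ce.c1 measures how far \(V\) is from having orthonormal columns. A quick stand-alone NumPy check of that residual (outside PyGRANSO, with illustrative values):

```python
import numpy as np

rng = np.random.default_rng(1)
n, d = 5, 2

# A generic V violates the orthonormality constraint ...
V = rng.standard_normal((n, d))
print(np.linalg.norm(V.T @ V - np.eye(d)))   # nonzero residual

# ... while a QR-orthonormalized V satisfies it to machine precision.
Q, _ = np.linalg.qr(V)
assert np.linalg.norm(Q.T @ Q - np.eye(d)) < 1e-12
```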

User Options

Specify user-defined options for PyGRANSO.

[4]:
opts = pygransoStruct()
opts.torch_device = device
opts.print_frequency = 1
# opts.opt_tol = 1e-7
opts.maxit = 3000
# opts.mu0 = 10
# opts.steering_c_viol = 0.02

Main Algorithm

[5]:
start = time.time()
soln = pygranso(var_spec = var_in,combined_fn = comb_fn,user_opts = opts)
end = time.time()
print("Total Wall Time: {}s".format(end - start))

V = torch.reshape(soln.final.x,(n,d))

rel_dist = torch.norm(V@V.T - U@U.T)/torch.norm(U@U.T)
print("torch.norm(V@V.T - U@U.T)/torch.norm(U@U.T) = {}".format(rel_dist))

print("torch.trace(V.T@A@V) = {}".format(torch.trace(V.T@A@V)))
print("torch.trace(U.T@A@U) = {}".format(torch.trace(U.T@A@U)))
print("sum of first d eigvals = {}".format(torch.sum(L[index[0:d]])))
print("sorted eigs = {}".format(L[index]))


╔═════ QP SOLVER NOTICE ════════════════════════════════════════════════════════════════════════╗
║  PyGRANSO requires a quadratic program (QP) solver that has a quadprog-compatible interface,  ║
║  the default is osqp. Users may provide their own wrapper for the QP solver.                  ║
║  To disable this notice, set opts.quadprog_info_msg = False                                   ║
╚═══════════════════════════════════════════════════════════════════════════════════════════════╝
═════════════════════════════════════════════════════════════════════════════════════════════════════════════════╗
PyGRANSO: A PyTorch-enabled port of GRANSO with auto-differentiation                                             ║
Version 1.2.0                                                                                                    ║
Licensed under the AGPLv3, Copyright (C) 2021-2022 Tim Mitchell and Buyun Liang                                  ║
═════════════════════════════════════════════════════════════════════════════════════════════════════════════════╣
Problem specifications:                                                                                          ║
 # of variables                     :   5                                                                        ║
 # of inequality constraints        :   0                                                                        ║
 # of equality constraints          :   1                                                                        ║
═════╦═══════════════════════════╦════════════════╦═════════════════╦═══════════════════════╦════════════════════╣
     ║ <--- Penalty Function --> ║                ║ Total Violation ║ <--- Line Search ---> ║ <- Stationarity -> ║
Iter ║    Mu    │      Value     ║    Objective   ║ Ineq │    Eq    ║ SD │ Evals │     t    ║ Grads │    Value   ║
═════╬═══════════════════════════╬════════════════╬═════════════════╬═══════════════════════╬════════════════════╣
   0 ║ 1.000000 │  1.16724813601 ║ -2.14592113443 ║   -  │ 3.313169 ║ -  │     1 │ 0.000000 ║     1 │ 4.801284   ║
   1 ║ 1.000000 │ -4.18220640118 ║ -27.2345373313 ║   -  │ 23.05233 ║ S  │     1 │ 1.000000 ║     1 │ 19.60679   ║
   2 ║ 0.810000 │ -7.44332699137 ║ -25.7415437768 ║   -  │ 13.40732 ║ S  │     1 │ 1.000000 ║     1 │ 2.115362   ║
   3 ║ 0.478297 │ -3.35759159998 ║ -29.6991742459 ║   -  │ 10.84743 ║ S  │     1 │ 1.000000 ║     1 │ 0.801258   ║
   4 ║ 0.313811 │  0.44031618300 ║ -29.6849843151 ║   -  │ 9.755779 ║ S  │     1 │ 1.000000 ║     1 │ 0.181847   ║
   5 ║ 0.313811 │ -0.05442280769 ║ -25.6951163222 ║   -  │ 8.008977 ║ S  │     2 │ 2.000000 ║     1 │ 0.267744   ║
   6 ║ 0.313811 │ -0.60400882947 ║ -19.5195850208 ║   -  │ 5.521444 ║ S  │     2 │ 2.000000 ║     1 │ 0.129171   ║
   7 ║ 0.313811 │ -0.74392449978 ║ -15.4214205335 ║   -  │ 4.095481 ║ S  │     1 │ 1.000000 ║     1 │ 0.071143   ║
   8 ║ 0.313811 │ -0.80903198043 ║ -10.2208919613 ║   -  │ 2.398392 ║ S  │     3 │ 4.000000 ║     1 │ 0.098483   ║
   9 ║ 0.313811 │ -0.90831604867 ║ -3.84466777888 ║   -  │ 0.298181 ║ S  │     3 │ 1.500000 ║     1 │ 0.110139   ║
  10 ║ 0.313811 │ -0.94420742855 ║ -3.08652285415 ║   -  │ 0.024376 ║ S  │     1 │ 1.000000 ║     1 │ 0.034643   ║
  11 ║ 0.313811 │ -0.95190087009 ║ -3.04454808709 ║   -  │ 0.003511 ║ S  │     1 │ 1.000000 ║     1 │ 0.015493   ║
  12 ║ 0.313811 │ -0.95411578037 ║ -3.04298259421 ║   -  │ 8.04e-04 ║ S  │     1 │ 1.000000 ║     1 │ 0.009830   ║
  13 ║ 0.313811 │ -0.95489372763 ║ -3.04415367869 ║   -  │ 3.94e-04 ║ S  │     1 │ 1.000000 ║     1 │ 0.003801   ║
  14 ║ 0.313811 │ -0.95497336218 ║ -3.04323431616 ║   -  │ 2.58e-05 ║ S  │     1 │ 1.000000 ║     1 │ 0.001336   ║
  15 ║ 0.313811 │ -0.95498942531 ║ -3.04321860994 ║   -  │ 4.82e-06 ║ S  │     1 │ 1.000000 ║     1 │ 8.48e-04   ║
  16 ║ 0.313811 │ -0.95499981196 ║ -3.04325113730 ║   -  │ 4.64e-06 ║ S  │     1 │ 1.000000 ║     1 │ 6.81e-04   ║
  17 ║ 0.313811 │ -0.95500311823 ║ -3.04325200800 ║   -  │ 1.61e-06 ║ S  │     1 │ 1.000000 ║     1 │ 3.10e-04   ║
  18 ║ 0.313811 │ -0.95500354216 ║ -3.04324886242 ║   -  │ 1.97e-07 ║ S  │     1 │ 1.000000 ║     1 │ 7.90e-05   ║
  19 ║ 0.313811 │ -0.95500356434 ║ -3.04324833316 ║   -  │ 9.14e-09 ║ S  │     1 │ 1.000000 ║     2 │ 5.28e-05   ║
═════╬═══════════════════════════╬════════════════╬═════════════════╬═══════════════════════╬════════════════════╣
     ║ <--- Penalty Function --> ║                ║ Total Violation ║ <--- Line Search ---> ║ <- Stationarity -> ║
Iter ║    Mu    │      Value     ║    Objective   ║ Ineq │    Eq    ║ SD │ Evals │     t    ║ Grads │    Value   ║
═════╬═══════════════════════════╬════════════════╬═════════════════╬═══════════════════════╬════════════════════╣
  20 ║ 0.313811 │ -0.95500356478 ║ -3.04324830574 ║   -  │ 9.81e-11 ║ S  │     7 │ 1.031250 ║     2 │ 6.94e-06   ║
  21 ║ 0.313811 │ -0.95500356497 ║ -3.04324830605 ║   -  │ 3.89e-12 ║ SI │     1 │ 1.000000 ║     3 │ 28.67710   ║
  22 ║ 0.313811 │ -0.95500356497 ║ -3.04324830605 ║   -  │ 3.97e-12 ║ SI │     2 │ 0.500000 ║     4 │ 6.23e-06   ║
  23 ║ 0.313811 │ -0.95500356497 ║ -3.04324830605 ║   -  │ 2.25e-12 ║ SI │     2 │ 0.500000 ║     5 │ 3.39e-06   ║
  24 ║ 0.313811 │ -0.95500356497 ║ -3.04324830605 ║   -  │ 2.31e-12 ║ S  │    11 │ 9.77e-04 ║     6 │ 3.85e-04   ║
  25 ║ 0.313811 │ -0.95500356497 ║ -3.04324830604 ║   -  │ 3.60e-14 ║ S  │     8 │ 1.015625 ║     7 │ 2.23e-07   ║
  26 ║ 0.109419 │ -0.33298915332 ║ -3.04324830604 ║   -  │ 1.49e-14 ║ S  │    24 │ 3.58e-07 ║     8 │ 4.571292   ║
  27 ║ 0.109419 │ -0.33298915332 ║ -3.04324830604 ║   -  │ 5.55e-15 ║ S  │    13 │ 2.44e-04 ║     9 │ 4.76e-07   ║
  28 ║ 0.109419 │ -0.33298915332 ║ -3.04324830604 ║   -  │ 4.66e-15 ║ S  │    14 │ 1.22e-04 ║    10 │ 4.75e-07   ║
  29 ║ 0.109419 │ -0.33298915332 ║ -3.04324830604 ║   -  │ 2.53e-14 ║ SI │     1 │ 1.000000 ║    10 │ 3.95e-09   ║
═════╩═══════════════════════════╩════════════════╩═════════════════╩═══════════════════════╩════════════════════╣
F = final iterate, B = Best (to tolerance), MF = Most Feasible                                                   ║
Optimization results:                                                                                            ║
═════╦═══════════════════════════╦════════════════╦═════════════════╦═══════════════════════╦════════════════════╣
   F ║          │                ║ -3.04324830604 ║   -  │ 2.53e-14 ║    │       │          ║       │            ║
   B ║          │                ║ -3.04324886242 ║   -  │ 1.97e-07 ║    │       │          ║       │            ║
  MF ║          │                ║ -3.04324830604 ║   -  │ 0.000000 ║    │       │          ║       │            ║
═════╩═══════════════════════════╩════════════════╩═════════════════╩═══════════════════════╩════════════════════╣
Iterations:              29                                                                                      ║
Function evaluations:    109                                                                                     ║
PyGRANSO termination code: 0 --- converged to stationarity and feasibility tolerances.                           ║
═════════════════════════════════════════════════════════════════════════════════════════════════════════════════╝
Total Wall Time: 1.0340015888214111s
torch.norm(V@V.T - U@U.T)/torch.norm(U@U.T) = 1.4317144979982078e-07
torch.trace(V.T@A@V) = 3.0432483060419457
torch.trace(U.T@A@U) = 3.0432483060418907
sum of first d eigvals = 3.04324830604189
sorted eigs = tensor([ 3.0432,  0.8890, -0.4730, -0.9598, -1.8722], device='cuda:0',
       dtype=torch.float64)
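Because the equality constraint is only satisfied to a tolerance, the returned \(V\) may be very slightly non-orthonormal; if exact orthonormality matters downstream, one QR factorization restores it. A NumPy sketch of this post-processing step, where the perturbed matrix stands in for a solver iterate:

```python
import numpy as np

rng = np.random.default_rng(2)
n, d = 5, 1

# Stand-in for a solver iterate satisfying V^T V = I only approximately.
V, _ = np.linalg.qr(rng.standard_normal((n, d)))
V = V + 1e-6 * rng.standard_normal((n, d))

# One QR factorization restores exact orthonormality ...
Q, _ = np.linalg.qr(V)
assert np.linalg.norm(Q.T @ Q - np.eye(d)) < 1e-12

# ... while leaving the spanned subspace essentially unchanged.
assert np.linalg.norm(Q @ Q.T @ V - V) < 1e-5
```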

More Constraints

(Optional) Explore PyGRANSO's performance as the number of constraints grows: with \(d=2\), the orthonormality condition \(V^TV=I\) contributes four scalar equality constraints instead of one.

[6]:
device = torch.device('cuda')
n = 5
d = 2
torch.manual_seed(0)
# All the user-provided data (vector/matrix/tensor) must be in torch tensor format.
# As PyTorch tensor is single precision by default, one must explicitly set `dtype=torch.double`.
# Also, please make sure the device of provided torch tensor is the same as opts.torch_device.
A = torch.randn(n,n).to(device=device, dtype=torch.double)
A = (A + A.T)/2
# eigh (not eig): A is symmetric, so its eigenpairs are real and no
# complex-to-real cast is needed.
L, U = torch.linalg.eigh(A)
index = torch.argsort(L,descending=True)
U = U[:,index[0:d]]

# variables and corresponding dimensions.
var_in = {"V": [n,d]}

def user_fn(X_struct,A,d):
    V = X_struct.V

    # objective function
    f = -torch.trace(V.T@A@V)

    # inequality constraint, matrix form
    ci = None

    # equality constraint
    ce = pygransoStruct()
    ce.c1 = V.T@V - torch.eye(d).to(device=device, dtype=torch.double)

    return [f,ci,ce]

comb_fn = lambda X_struct : user_fn(X_struct,A,d)

opts = pygransoStruct()
opts.torch_device = device
opts.print_frequency = 10
opts.opt_tol = 5e-6
opts.maxit = 1000
# opts.mu0 = 10
# opts.steering_c_viol = 0.02

start = time.time()
soln = pygranso(var_spec = var_in,combined_fn = comb_fn,user_opts = opts)
end = time.time()
print("Total Wall Time: {}s".format(end - start))

V = torch.reshape(soln.final.x,(n,d))

rel_dist = torch.norm(V@V.T - U@U.T)/torch.norm(U@U.T)
print("torch.norm(V@V.T - U@U.T)/torch.norm(U@U.T) = {}".format(rel_dist))

print("torch.trace(V.T@A@V) = {}".format(torch.trace(V.T@A@V)))
print("torch.trace(U.T@A@U) = {}".format(torch.trace(U.T@A@U)))
print("sum of first d eigvals = {}".format(torch.sum(L[index[0:d]])))
print("sorted eigs = {}".format(L[index]))


╔═════ QP SOLVER NOTICE ════════════════════════════════════════════════════════════════════════╗
║  PyGRANSO requires a quadratic program (QP) solver that has a quadprog-compatible interface,  ║
║  the default is osqp. Users may provide their own wrapper for the QP solver.                  ║
║  To disable this notice, set opts.quadprog_info_msg = False                                   ║
╚═══════════════════════════════════════════════════════════════════════════════════════════════╝
═════════════════════════════════════════════════════════════════════════════════════════════════════════════════╗
PyGRANSO: A PyTorch-enabled port of GRANSO with auto-differentiation                                             ║
Version 1.2.0                                                                                                    ║
Licensed under the AGPLv3, Copyright (C) 2021-2022 Tim Mitchell and Buyun Liang                                  ║
═════════════════════════════════════════════════════════════════════════════════════════════════════════════════╣
Problem specifications:                                                                                          ║
 # of variables                     :   10                                                                       ║
 # of inequality constraints        :    0                                                                       ║
 # of equality constraints          :    4                                                                       ║
═════╦═══════════════════════════╦════════════════╦═════════════════╦═══════════════════════╦════════════════════╣
     ║ <--- Penalty Function --> ║                ║ Total Violation ║ <--- Line Search ---> ║ <- Stationarity -> ║
Iter ║    Mu    │      Value     ║    Objective   ║ Ineq │    Eq    ║ SD │ Evals │     t    ║ Grads │    Value   ║
═════╬═══════════════════════════╬════════════════╬═════════════════╬═══════════════════════╬════════════════════╣
   0 ║ 1.000000 │  27.7034556094 ║  0.81257641642 ║   -  │ 8.581446 ║ -  │     1 │ 0.000000 ║     1 │ 10.40129   ║
  10 ║ 0.282430 │ -0.30396065925 ║ -3.41481873164 ║   -  │ 0.524987 ║ S  │     1 │ 1.000000 ║     1 │ 0.937294   ║
  20 ║ 0.205891 │ -0.74610607144 ║ -3.69972896961 ║   -  │ 0.007700 ║ S  │     1 │ 1.000000 ║     1 │ 0.011535   ║
  30 ║ 0.058150 │ -0.22595631429 ║ -3.89008845464 ║   -  │ 1.69e-04 ║ S  │     2 │ 0.500000 ║     1 │ 0.038402   ║
  40 ║ 0.047101 │ -0.18419411173 ║ -3.91138416923 ║   -  │ 2.92e-05 ║ S  │     1 │ 1.000000 ║     1 │ 0.022714   ║
  50 ║ 0.047101 │ -0.18471088554 ║ -3.92205275507 ║   -  │ 1.73e-05 ║ S  │     1 │ 1.000000 ║     1 │ 0.009407   ║
  60 ║ 0.014781 │ -0.05801100437 ║ -3.92495071255 ║   -  │ 1.75e-06 ║ S  │     1 │ 1.000000 ║     1 │ 0.001456   ║
  70 ║ 0.014781 │ -0.05802583565 ║ -3.92589713488 ║   -  │ 1.05e-06 ║ S  │     1 │ 1.000000 ║     1 │ 8.25e-04   ║
  80 ║ 0.014781 │ -0.05803300583 ║ -3.92636319358 ║   -  │ 9.28e-07 ║ S  │     1 │ 1.000000 ║     1 │ 6.60e-04   ║
  90 ║ 0.014781 │ -0.05803840995 ║ -3.92674667061 ║   -  │ 1.13e-06 ║ S  │     1 │ 1.000000 ║     1 │ 7.99e-04   ║
 100 ║ 0.014781 │ -0.05804249204 ║ -3.92706587993 ║   -  │ 1.49e-06 ║ S  │     1 │ 1.000000 ║     1 │ 8.05e-04   ║
 110 ║ 0.010775 │ -0.04231489018 ║ -3.92730230715 ║   -  │ 1.50e-06 ║ S  │     2 │ 2.000000 ║     1 │ 5.00e-04   ║
 120 ║ 0.001061 │ -0.00416748732 ║ -3.92753388653 ║   -  │ 5.44e-08 ║ SI │     1 │ 1.000000 ║     1 │ 2.77e-05   ║
═════╩═══════════════════════════╩════════════════╩═════════════════╩═══════════════════════╩════════════════════╣
F = final iterate, B = Best (to tolerance), MF = Most Feasible                                                   ║
Optimization results:                                                                                            ║
═════╦═══════════════════════════╦════════════════╦═════════════════╦═══════════════════════╦════════════════════╣
   F ║          │                ║ -3.92761322417 ║   -  │ 3.62e-09 ║    │       │          ║       │            ║
   B ║          │                ║ -3.92772408038 ║   -  │ 5.22e-07 ║    │       │          ║       │            ║
  MF ║          │                ║ -3.92743429333 ║   -  │ 2.74e-09 ║    │       │          ║       │            ║
═════╩═══════════════════════════╩════════════════╩═════════════════╩═══════════════════════╩════════════════════╣
Iterations:              126                                                                                     ║
Function evaluations:    183                                                                                     ║
PyGRANSO termination code: 0 --- converged to stationarity and feasibility tolerances.                           ║
═════════════════════════════════════════════════════════════════════════════════════════════════════════════════╝
Total Wall Time: 1.4121968746185303s
torch.norm(V@V.T - U@U.T)/torch.norm(U@U.T) = 0.04411636799335658
torch.trace(V.T@A@V) = 3.9276132241693347
torch.trace(U.T@A@U) = 3.932280709191555
sum of first d eigvals = 3.9322807091915544
sorted eigs = tensor([ 3.0432,  0.8890, -0.4730, -0.9598, -1.8722], device='cuda:0',
       dtype=torch.float64)
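The relative distance above compares the projectors \(VV^T\) and \(UU^T\) rather than the bases themselves, because any orthonormal basis of the optimal subspace attains the same trace; here the nonzero value (≈0.044) reflects that the \(d=2\) run stopped slightly short of the optimal subspace (trace 3.9276 vs. 3.9323), not a basis mismatch. A NumPy sketch of this rotation invariance, with illustrative values:

```python
import numpy as np

rng = np.random.default_rng(3)
n, d = 5, 2

# Two orthonormal bases of the SAME subspace differ by a d x d rotation ...
U, _ = np.linalg.qr(rng.standard_normal((n, d)))
theta = 0.7
R = np.array([[np.cos(theta), -np.sin(theta)],
              [np.sin(theta),  np.cos(theta)]])
V = U @ R

# ... yet their projectors coincide, so the distance metric is basis-free.
assert np.allclose(V @ V.T, U @ U.T)

# Singular values of U^T V are cosines of the principal angles: all 1 here.
cosines = np.linalg.svd(U.T @ V, compute_uv=False)
assert np.allclose(cosines, 1.0)
```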