{ "cells": [ { "cell_type": "markdown", "id": "b910d945", "metadata": {}, "source": [ "# Trace Optimization\n", "\n", "Trace optimization with orthogonality constraints, taken from: Effrosini Kokiopoulou, Jie Chen, and Yousef Saad. \"Trace optimization and eigenproblems in dimension reduction methods.\" Numerical Linear Algebra with Applications 18.3 (2011): 565-602." ] }, { "cell_type": "markdown", "id": "13b5ad66", "metadata": {}, "source": [ "## Problem Description\n", "Let $A$ be a symmetric matrix of dimension $n\\times n$, and let $V$ be an arbitrary matrix of dimension $n\\times d$ ($d \\le n$) with orthonormal columns, i.e., $V^TV=I$.\n", "\n", "The trace of $V^TAV$ is maximized when the columns of $V$ form an orthonormal basis of the eigenspace associated with the $d$ (algebraically) largest eigenvalues of $A$.\n", "\n", "If the eigenvalues are labeled in decreasing order, $\\lambda_1 \\ge \\lambda_2 \\ge \\dots \\ge \\lambda_n$, and $u_1,\\dots,u_d$ are eigenvectors associated with the first $d$ eigenvalues $\\lambda_1,\\dots,\\lambda_d$, with $U = [u_1,\\dots,u_d]$ and $U^TU=I$, then\n", "\n", "$$\\max_{V \\in \\mathbb{R}^{n\\times d},\\ V^TV=I} \\text{Tr}[V^TAV]=\\text{Tr}[U^TAU]=\\lambda_1+\\dots+\\lambda_d$$\n" ] }, { "cell_type": "markdown", "id": "73483897", "metadata": {}, "source": [ "## Importing Modules\n", "Import all necessary modules and add the PyGRANSO src folder to the system path." ] }, { "cell_type": "code", "execution_count": 1, "id": "ae68ad56", "metadata": {}, "outputs": [], "source": [ "import time\n", "import torch\n", "import sys\n", "# Add the PyGRANSO directory to the system path. Modify this path to match your local installation.\n", "sys.path.append('/home/buyun/Documents/GitHub/PyGRANSO')\n", "from pygranso.pygranso import pygranso\n", "from pygranso.pygransoStruct import pygransoStruct" ] }, { "cell_type": "markdown", "id": "d3713c13", "metadata": {}, "source": [ "## Data Initialization\n", "Specify the torch device and generate the data.\n", "\n", "This example uses a GPU. 
If no CUDA device is available, set *device = torch.device('cpu')*." ] }, { "cell_type": "code", "execution_count": 2, "id": "f80d015b", "metadata": {}, "outputs": [], "source": [ "device = torch.device('cuda')\n", "n = 5\n", "d = 1\n", "torch.manual_seed(0)\n", "# All user-provided data (vectors/matrices/tensors) must be torch tensors.\n", "# PyTorch tensors are single precision by default, so explicitly set `dtype=torch.double`.\n", "# Also, make sure the provided tensors live on the same device as opts.torch_device.\n", "A = torch.randn(n,n).to(device=device, dtype=torch.double)\n", "A = (A + A.T)/2\n", "# A is symmetric, so eigh returns real eigenvalues and orthonormal eigenvectors\n", "# (torch.linalg.eig would return complex tensors that must be cast back to real).\n", "L, U = torch.linalg.eigh(A)\n", "# Sort the eigenvalues in decreasing order and keep the eigenvectors of the d largest.\n", "index = torch.argsort(L,descending=True)\n", "U = U[:,index[0:d]]" ] }, { "cell_type": "markdown", "id": "174aa2e7", "metadata": {}, "source": [ "## Function Set-Up\n", "\n", "Encode the optimization variables, the objective function, and the constraint functions. Since PyGRANSO minimizes, the objective is $f(V) = -\\text{Tr}[V^TAV]$, and the orthonormality requirement is the equality constraint $V^TV - I = 0$, a $d\\times d$ matrix that PyGRANSO counts as $d^2$ scalar constraints.\n", "\n", "Note: please strictly follow the format of comb_fn, which will be used in the PyGRANSO main algorithm."
] }, { "cell_type": "code", "execution_count": 3, "id": "76877185", "metadata": {}, "outputs": [], "source": [ "# variables and corresponding dimensions.\n", "var_in = {\"V\": [n,d]}\n", "\n", "def user_fn(X_struct,A,d):\n", " V = X_struct.V\n", "\n", " # objective function\n", " f = -torch.trace(V.T@A@V)\n", "\n", " # inequality constraint, matrix form\n", " ci = None\n", "\n", " # equality constraint\n", " ce = pygransoStruct()\n", " ce.c1 = V.T@V - torch.eye(d).to(device=device, dtype=torch.double)\n", "\n", " return [f,ci,ce]\n", "\n", "comb_fn = lambda X_struct : user_fn(X_struct,A,d)" ] }, { "cell_type": "markdown", "id": "2b21c2ec", "metadata": {}, "source": [ "## User Options\n", "Specify user-defined options for PyGRANSO" ] }, { "cell_type": "code", "execution_count": 4, "id": "54137e9f", "metadata": {}, "outputs": [], "source": [ "opts = pygransoStruct()\n", "opts.torch_device = device\n", "opts.print_frequency = 1\n", "# opts.opt_tol = 1e-7\n", "opts.maxit = 3000\n", "# opts.mu0 = 10\n", "# opts.steering_c_viol = 0.02" ] }, { "cell_type": "markdown", "id": "be9ba1d7", "metadata": {}, "source": [ "## Main Algorithm" ] }, { "cell_type": "code", "execution_count": 5, "id": "8ce3b204", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "\u001b[33m╔═════ QP SOLVER NOTICE ════════════════════════════════════════════════════════════════════════╗\n", "\u001b[0m\u001b[33m║ PyGRANSO requires a quadratic program (QP) solver that has a quadprog-compatible interface, ║\n", "\u001b[0m\u001b[33m║ the default is osqp. Users may provide their own wrapper for the QP solver. 
║\n", "\u001b[0m\u001b[33m║ To disable this notice, set opts.quadprog_info_msg = False ║\n", "\u001b[0m\u001b[33m╚═══════════════════════════════════════════════════════════════════════════════════════════════╝\n", "\u001b[0m═════════════════════════════════════════════════════════════════════════════════════════════════════════════════╗\n", "PyGRANSO: A PyTorch-enabled port of GRANSO with auto-differentiation ║ \n", "Version 1.2.0 ║ \n", "Licensed under the AGPLv3, Copyright (C) 2021-2022 Tim Mitchell and Buyun Liang ║ \n", "═════════════════════════════════════════════════════════════════════════════════════════════════════════════════╣\n", "Problem specifications: ║ \n", " # of variables : 5 ║ \n", " # of inequality constraints : 0 ║ \n", " # of equality constraints : 1 ║ \n", "═════╦═══════════════════════════╦════════════════╦═════════════════╦═══════════════════════╦════════════════════╣\n", " ║ <--- Penalty Function --> ║ ║ Total Violation ║ <--- Line Search ---> ║ <- Stationarity -> ║ \n", "Iter ║ Mu │ Value ║ Objective ║ Ineq │ Eq ║ SD │ Evals │ t ║ Grads │ Value ║ \n", "═════╬═══════════════════════════╬════════════════╬═════════════════╬═══════════════════════╬════════════════════╣\n", " 0 ║ 1.000000 │ 1.16724813601 ║ -2.14592113443 ║ - │ 3.313169 ║ - │ 1 │ 0.000000 ║ 1 │ 4.801284 ║ \n", " 1 ║ 1.000000 │ -4.18220640118 ║ -27.2345373313 ║ - │ 23.05233 ║ S │ 1 │ 1.000000 ║ 1 │ 19.60679 ║ \n", " 2 ║ 0.810000 │ -7.44332699137 ║ -25.7415437768 ║ - │ 13.40732 ║ S │ 1 │ 1.000000 ║ 1 │ 2.115362 ║ \n", " 3 ║ 0.478297 │ -3.35759159998 ║ -29.6991742459 ║ - │ 10.84743 ║ S │ 1 │ 1.000000 ║ 1 │ 0.801258 ║ \n", " 4 ║ 0.313811 │ 0.44031618300 ║ -29.6849843151 ║ - │ 9.755779 ║ S │ 1 │ 1.000000 ║ 1 │ 0.181847 ║ \n", " 5 ║ 0.313811 │ -0.05442280769 ║ -25.6951163222 ║ - │ 8.008977 ║ S │ 2 │ 2.000000 ║ 1 │ 0.267744 ║ \n", " 6 ║ 0.313811 │ -0.60400882947 ║ -19.5195850208 ║ - │ 5.521444 ║ S │ 2 │ 2.000000 ║ 1 │ 0.129171 ║ \n", " 7 ║ 0.313811 │ -0.74392449978 ║ -15.4214205335 ║ 
- │ 4.095481 ║ S │ 1 │ 1.000000 ║ 1 │ 0.071143 ║ \n", " 8 ║ 0.313811 │ -0.80903198043 ║ -10.2208919613 ║ - │ 2.398392 ║ S │ 3 │ 4.000000 ║ 1 │ 0.098483 ║ \n", " 9 ║ 0.313811 │ -0.90831604867 ║ -3.84466777888 ║ - │ 0.298181 ║ S │ 3 │ 1.500000 ║ 1 │ 0.110139 ║ \n", " 10 ║ 0.313811 │ -0.94420742855 ║ -3.08652285415 ║ - │ 0.024376 ║ S │ 1 │ 1.000000 ║ 1 │ 0.034643 ║ \n", " 11 ║ 0.313811 │ -0.95190087009 ║ -3.04454808709 ║ - │ 0.003511 ║ S │ 1 │ 1.000000 ║ 1 │ 0.015493 ║ \n", " 12 ║ 0.313811 │ -0.95411578037 ║ -3.04298259421 ║ - │ 8.04e-04 ║ S │ 1 │ 1.000000 ║ 1 │ 0.009830 ║ \n", " 13 ║ 0.313811 │ -0.95489372763 ║ -3.04415367869 ║ - │ 3.94e-04 ║ S │ 1 │ 1.000000 ║ 1 │ 0.003801 ║ \n", " 14 ║ 0.313811 │ -0.95497336218 ║ -3.04323431616 ║ - │ 2.58e-05 ║ S │ 1 │ 1.000000 ║ 1 │ 0.001336 ║ \n", " 15 ║ 0.313811 │ -0.95498942531 ║ -3.04321860994 ║ - │ 4.82e-06 ║ S │ 1 │ 1.000000 ║ 1 │ 8.48e-04 ║ \n", " 16 ║ 0.313811 │ -0.95499981196 ║ -3.04325113730 ║ - │ 4.64e-06 ║ S │ 1 │ 1.000000 ║ 1 │ 6.81e-04 ║ \n", " 17 ║ 0.313811 │ -0.95500311823 ║ -3.04325200800 ║ - │ 1.61e-06 ║ S │ 1 │ 1.000000 ║ 1 │ 3.10e-04 ║ \n", " 18 ║ 0.313811 │ -0.95500354216 ║ -3.04324886242 ║ - │ 1.97e-07 ║ S │ 1 │ 1.000000 ║ 1 │ 7.90e-05 ║ \n", " 19 ║ 0.313811 │ -0.95500356434 ║ -3.04324833316 ║ - │ 9.14e-09 ║ S │ 1 │ 1.000000 ║ 2 │ 5.28e-05 ║ \n", "═════╬═══════════════════════════╬════════════════╬═════════════════╬═══════════════════════╬════════════════════╣\n", " ║ <--- Penalty Function --> ║ ║ Total Violation ║ <--- Line Search ---> ║ <- Stationarity -> ║ \n", "Iter ║ Mu │ Value ║ Objective ║ Ineq │ Eq ║ SD │ Evals │ t ║ Grads │ Value ║ \n", "═════╬═══════════════════════════╬════════════════╬═════════════════╬═══════════════════════╬════════════════════╣\n", " 20 ║ 0.313811 │ -0.95500356478 ║ -3.04324830574 ║ - │ 9.81e-11 ║ S │ 7 │ 1.031250 ║ 2 │ 6.94e-06 ║ \n", " 21 ║ 0.313811 │ -0.95500356497 ║ -3.04324830605 ║ - │ 3.89e-12 ║ \u001b[33mSI\u001b[0m │ 1 │ 1.000000 ║ 3 │ 28.67710 ║ \n", " 22 ║ 0.313811 │ 
-0.95500356497 ║ -3.04324830605 ║ - │ 3.97e-12 ║ \u001b[33mSI\u001b[0m │ 2 │ 0.500000 ║ 4 │ 6.23e-06 ║ \n", " 23 ║ 0.313811 │ -0.95500356497 ║ -3.04324830605 ║ - │ 2.25e-12 ║ \u001b[33mSI\u001b[0m │ 2 │ 0.500000 ║ 5 │ 3.39e-06 ║ \n", " 24 ║ 0.313811 │ -0.95500356497 ║ -3.04324830605 ║ - │ 2.31e-12 ║ S │ 11 │ 9.77e-04 ║ 6 │ 3.85e-04 ║ \n", " 25 ║ 0.313811 │ -0.95500356497 ║ -3.04324830604 ║ - │ 3.60e-14 ║ S │ 8 │ 1.015625 ║ 7 │ 2.23e-07 ║ \n", " 26 ║ 0.109419 │ -0.33298915332 ║ -3.04324830604 ║ - │ 1.49e-14 ║ S │ 24 │ 3.58e-07 ║ 8 │ 4.571292 ║ \n", " 27 ║ 0.109419 │ -0.33298915332 ║ -3.04324830604 ║ - │ 5.55e-15 ║ S │ 13 │ 2.44e-04 ║ 9 │ 4.76e-07 ║ \n", " 28 ║ 0.109419 │ -0.33298915332 ║ -3.04324830604 ║ - │ 4.66e-15 ║ S │ 14 │ 1.22e-04 ║ 10 │ 4.75e-07 ║ \n", " 29 ║ 0.109419 │ -0.33298915332 ║ -3.04324830604 ║ - │ 2.53e-14 ║ \u001b[33mSI\u001b[0m │ 1 │ 1.000000 ║ 10 │ 3.95e-09 ║ \n", "═════╩═══════════════════════════╩════════════════╩═════════════════╩═══════════════════════╩════════════════════╣\n", "F = final iterate, B = Best (to tolerance), MF = Most Feasible ║ \n", "Optimization results: ║ \n", "═════╦═══════════════════════════╦════════════════╦═════════════════╦═══════════════════════╦════════════════════╣\n", " F ║ │ ║ -3.04324830604 ║ - │ 2.53e-14 ║ │ │ ║ │ ║ \n", " B ║ │ ║ -3.04324886242 ║ - │ 1.97e-07 ║ │ │ ║ │ ║ \n", " MF ║ │ ║ -3.04324830604 ║ - │ 0.000000 ║ │ │ ║ │ ║ \n", "═════╩═══════════════════════════╩════════════════╩═════════════════╩═══════════════════════╩════════════════════╣\n", "Iterations: 29 ║ \n", "Function evaluations: 109 ║ \n", "PyGRANSO termination code: 0 --- converged to stationarity and feasibility tolerances. 
║ \n", "═════════════════════════════════════════════════════════════════════════════════════════════════════════════════╝\n", "Total Wall Time: 1.0340015888214111s\n", "torch.norm(V@V.T - U@U.T)/torch.norm(U@U.T) = 1.4317144979982078e-07\n", "torch.trace(V.T@A@V) = 3.0432483060419457\n", "torch.trace(U.T@A@U) = 3.0432483060418907\n", "sum of first d eigvals = 3.04324830604189\n", "sorted eigs = tensor([ 3.0432, 0.8890, -0.4730, -0.9598, -1.8722], device='cuda:0',\n", " dtype=torch.float64)\n" ] } ], "source": [ "start = time.time()\n", "soln = pygranso(var_spec = var_in,combined_fn = comb_fn,user_opts = opts)\n", "end = time.time()\n", "print(\"Total Wall Time: {}s\".format(end - start))\n", "\n", "V = torch.reshape(soln.final.x,(n,d))\n", "\n", "# Compare subspaces via the projectors V V^T and U U^T, which do not depend on the choice of basis.\n", "rel_dist = torch.norm(V@V.T - U@U.T)/torch.norm(U@U.T)\n", "print(\"torch.norm(V@V.T - U@U.T)/torch.norm(U@U.T) = {}\".format(rel_dist))\n", "\n", "print(\"torch.trace(V.T@A@V) = {}\".format(torch.trace(V.T@A@V)))\n", "print(\"torch.trace(U.T@A@U) = {}\".format(torch.trace(U.T@A@U)))\n", "print(\"sum of first d eigvals = {}\".format(torch.sum(L[index[0:d]])))\n", "print(\"sorted eigs = {}\".format(L[index]))" ] }, { "cell_type": "markdown", "id": "f1c12544", "metadata": {}, "source": [ "## More Constraints\n", "**(Optional)** Explore PyGRANSO's performance with more constraints: setting $d=2$ turns $V^TV=I$ into $d^2=4$ scalar equality constraints." ] }, { "cell_type": "code", "execution_count": 6, "id": "2945a9bb", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "\u001b[33m╔═════ QP SOLVER NOTICE ════════════════════════════════════════════════════════════════════════╗\n", "\u001b[0m\u001b[33m║ PyGRANSO requires a quadratic program (QP) solver that has a quadprog-compatible interface, ║\n", "\u001b[0m\u001b[33m║ the default is osqp. Users may provide their own wrapper for the QP solver. 
║\n", "\u001b[0m\u001b[33m║ To disable this notice, set opts.quadprog_info_msg = False ║\n", "\u001b[0m\u001b[33m╚═══════════════════════════════════════════════════════════════════════════════════════════════╝\n", "\u001b[0m═════════════════════════════════════════════════════════════════════════════════════════════════════════════════╗\n", "PyGRANSO: A PyTorch-enabled port of GRANSO with auto-differentiation ║ \n", "Version 1.2.0 ║ \n", "Licensed under the AGPLv3, Copyright (C) 2021-2022 Tim Mitchell and Buyun Liang ║ \n", "═════════════════════════════════════════════════════════════════════════════════════════════════════════════════╣\n", "Problem specifications: ║ \n", " # of variables : 10 ║ \n", " # of inequality constraints : 0 ║ \n", " # of equality constraints : 4 ║ \n", "═════╦═══════════════════════════╦════════════════╦═════════════════╦═══════════════════════╦════════════════════╣\n", " ║ <--- Penalty Function --> ║ ║ Total Violation ║ <--- Line Search ---> ║ <- Stationarity -> ║ \n", "Iter ║ Mu │ Value ║ Objective ║ Ineq │ Eq ║ SD │ Evals │ t ║ Grads │ Value ║ \n", "═════╬═══════════════════════════╬════════════════╬═════════════════╬═══════════════════════╬════════════════════╣\n", " 0 ║ 1.000000 │ 27.7034556094 ║ 0.81257641642 ║ - │ 8.581446 ║ - │ 1 │ 0.000000 ║ 1 │ 10.40129 ║ \n", " 10 ║ 0.282430 │ -0.30396065925 ║ -3.41481873164 ║ - │ 0.524987 ║ S │ 1 │ 1.000000 ║ 1 │ 0.937294 ║ \n", " 20 ║ 0.205891 │ -0.74610607144 ║ -3.69972896961 ║ - │ 0.007700 ║ S │ 1 │ 1.000000 ║ 1 │ 0.011535 ║ \n", " 30 ║ 0.058150 │ -0.22595631429 ║ -3.89008845464 ║ - │ 1.69e-04 ║ S │ 2 │ 0.500000 ║ 1 │ 0.038402 ║ \n", " 40 ║ 0.047101 │ -0.18419411173 ║ -3.91138416923 ║ - │ 2.92e-05 ║ S │ 1 │ 1.000000 ║ 1 │ 0.022714 ║ \n", " 50 ║ 0.047101 │ -0.18471088554 ║ -3.92205275507 ║ - │ 1.73e-05 ║ S │ 1 │ 1.000000 ║ 1 │ 0.009407 ║ \n", " 60 ║ 0.014781 │ -0.05801100437 ║ -3.92495071255 ║ - │ 1.75e-06 ║ S │ 1 │ 1.000000 ║ 1 │ 0.001456 ║ \n", " 70 ║ 0.014781 │ -0.05802583565 ║ 
-3.92589713488 ║ - │ 1.05e-06 ║ S │ 1 │ 1.000000 ║ 1 │ 8.25e-04 ║ \n", " 80 ║ 0.014781 │ -0.05803300583 ║ -3.92636319358 ║ - │ 9.28e-07 ║ S │ 1 │ 1.000000 ║ 1 │ 6.60e-04 ║ \n", " 90 ║ 0.014781 │ -0.05803840995 ║ -3.92674667061 ║ - │ 1.13e-06 ║ S │ 1 │ 1.000000 ║ 1 │ 7.99e-04 ║ \n", " 100 ║ 0.014781 │ -0.05804249204 ║ -3.92706587993 ║ - │ 1.49e-06 ║ S │ 1 │ 1.000000 ║ 1 │ 8.05e-04 ║ \n", " 110 ║ 0.010775 │ -0.04231489018 ║ -3.92730230715 ║ - │ 1.50e-06 ║ S │ 2 │ 2.000000 ║ 1 │ 5.00e-04 ║ \n", " 120 ║ 0.001061 │ -0.00416748732 ║ -3.92753388653 ║ - │ 5.44e-08 ║ \u001b[33mSI\u001b[0m │ 1 │ 1.000000 ║ 1 │ 2.77e-05 ║ \n", "═════╩═══════════════════════════╩════════════════╩═════════════════╩═══════════════════════╩════════════════════╣\n", "F = final iterate, B = Best (to tolerance), MF = Most Feasible ║ \n", "Optimization results: ║ \n", "═════╦═══════════════════════════╦════════════════╦═════════════════╦═══════════════════════╦════════════════════╣\n", " F ║ │ ║ -3.92761322417 ║ - │ 3.62e-09 ║ │ │ ║ │ ║ \n", " B ║ │ ║ -3.92772408038 ║ - │ 5.22e-07 ║ │ │ ║ │ ║ \n", " MF ║ │ ║ -3.92743429333 ║ - │ 2.74e-09 ║ │ │ ║ │ ║ \n", "═════╩═══════════════════════════╩════════════════╩═════════════════╩═══════════════════════╩════════════════════╣\n", "Iterations: 126 ║ \n", "Function evaluations: 183 ║ \n", "PyGRANSO termination code: 0 --- converged to stationarity and feasibility tolerances. 
║ \n", "═════════════════════════════════════════════════════════════════════════════════════════════════════════════════╝\n", "Total Wall Time: 1.4121968746185303s\n", "torch.norm(V@V.T - U@U.T)/torch.norm(U@U.T) = 0.04411636799335658\n", "torch.trace(V.T@A@V) = 3.9276132241693347\n", "torch.trace(U.T@A@U) = 3.932280709191555\n", "sum of first d eigvals = 3.9322807091915544\n", "sorted eigs = tensor([ 3.0432, 0.8890, -0.4730, -0.9598, -1.8722], device='cuda:0',\n", " dtype=torch.float64)\n" ] } ], "source": [ "device = torch.device('cuda')\n", "n = 5\n", "d = 2\n", "torch.manual_seed(0)\n", "# All user-provided data (vectors/matrices/tensors) must be torch tensors.\n", "# PyTorch tensors are single precision by default, so explicitly set `dtype=torch.double`.\n", "# Also, make sure the provided tensors live on the same device as opts.torch_device.\n", "A = torch.randn(n,n).to(device=device, dtype=torch.double)\n", "A = (A + A.T)/2\n", "# A is symmetric, so eigh returns real eigenvalues and orthonormal eigenvectors.\n", "L, U = torch.linalg.eigh(A)\n", "index = torch.argsort(L,descending=True)\n", "U = U[:,index[0:d]]\n", "\n", "# variables and corresponding dimensions.\n", "var_in = {\"V\": [n,d]}\n", "\n", "def user_fn(X_struct,A,d):\n", " V = X_struct.V\n", "\n", " # objective function\n", " f = -torch.trace(V.T@A@V)\n", "\n", " # inequality constraint, matrix form\n", " ci = None\n", "\n", " # equality constraint\n", " ce = pygransoStruct()\n", " ce.c1 = V.T@V - torch.eye(d).to(device=device, dtype=torch.double)\n", "\n", " return [f,ci,ce]\n", "\n", "comb_fn = lambda X_struct : user_fn(X_struct,A,d)\n", "\n", "opts = pygransoStruct()\n", "opts.torch_device = device\n", "opts.print_frequency = 10\n", "opts.opt_tol = 5e-6\n", "opts.maxit = 1000\n", "# opts.mu0 = 10\n", "# opts.steering_c_viol = 0.02\n", "\n", "start = time.time()\n", "soln = pygranso(var_spec = var_in,combined_fn = comb_fn,user_opts = opts)\n", "end = time.time()\n", "print(\"Total Wall Time: 
{}s\".format(end - start))\n", "\n", "V = torch.reshape(soln.final.x,(n,d))\n", "\n", "rel_dist = torch.norm(V@V.T - U@U.T)/torch.norm(U@U.T)\n", "print(\"torch.norm(V@V.T - U@U.T)/torch.norm(U@U.T) = {}\".format(rel_dist))\n", "\n", "print(\"torch.trace(V.T@A@V) = {}\".format(torch.trace(V.T@A@V)))\n", "print(\"torch.trace(U.T@A@U) = {}\".format(torch.trace(U.T@A@U)))\n", "print(\"sum of first d eigvals = {}\".format(torch.sum(L[index[0:d]])))\n", "print(\"sorted eigs = {}\".format(L[index]))" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.7" } }, "nbformat": 4, "nbformat_minor": 5 }