Recursive Least Squares for ML


2020.10.15 -- I rewrote a lot of this post. The motivation is generic and expanded, and the math has been massaged a bit. This makes \(\ell_2\) regularization now included in both sets of recursive equations, as well as giving expressions that are far less numerically unstable.

Why I'm writing this

Recently I've been working with the recursive formulation of least-squares regression. At the time of writing, there are a lot of solid resources for the signal processing view of Recursive Least Squares (like Wikipedia's RLS article), but I was unable to find a good write-up from the machine learning view (e.g. as a differentiable closed-form solver). This is my attempt to frame things in concepts and language friendlier to machine-learning-minded people like myself.

Pre-requisite knowledge

To follow this post, you probably will want to be somewhat familiar with the following topics:

Additionally, familiarity with the following will probably enhance your understanding:

The ML motivation and context for RLS

Skip this (long) section if you are more interested in the RLS algorithm than its application in ML (you can always come back!).

In the simplest settings of supervised learning, we usually think of a "machine learning problem" as finding a good way to approximate the mapping between input and output for a fixed dataset of observed input-output pairs. Most supervised learning algorithms boil down to coming up with some structure for the input-to-output mapping function with adjustable parameters (e.g. linear models, tree-based models, neural nets, etc.), and then efficiently finding a specific set of parameters which give rise to low average error between our function's output and the true outputs in the observed dataset. Somewhat confusingly, we use the word model to refer to the combination of a functional form and a specific set of parameters, not just the functional form (though in statistical learning theory, we call the functional form by itself the hypothesis class and the selected model a hypothesis, which I find much clearer). In more theoretical terms, this is the approach of directly applying the Empirical Risk Minimization (ERM) principle to a fixed dataset to find a good approximation of the input-to-output function.

In many real-world settings, however, our dataset is far from a static collection of observations. Data is usually generated by an ongoing process, and that process can change. There are a lot of terms for this change -- dataset shift, covariate shift, concept drift, and more. In a nutshell, though, we end up in a situation in which it's best to update our approximation of the input-output function as more data becomes available.

There are a number of ML subfields that provide useful formulations of problems like this. Online Learning and Reinforcement Learning are possibly the most popular and well-established of these. Online Learning re-formulates the problem at hand as finding a function which minimizes the total error when faced with a sequence of (potentially adversarial) error signals, each given in response to the current model's prediction. Often the goal is to find a sequence of model functions, one per time point, whose cumulative error converges to being no worse than that of the best possible single model. Reinforcement Learning re-formulates the problem as back-and-forth interaction between an agent and an environment. The task of learning an input-to-output model turns into the task of learning a policy which maps state to action, while the environment maps that action to a reward and a new state. A good policy is one which achieves a high cumulative reward.

Unfortunately, these subfields' approaches are often overkill for what we see in the real world -- a slight overstepping of the traditional bounds of supervised learning due to an ever-growing dataset and a dash of dataset shift. Why are OL and RL overkill? Online Learning is concerned with a possibly adversarial response and focuses on minimizing regret compared to a single good model used for all time. This makes it an awkward match for many not-so-adversarial data-generating processes, where it can be better to worry about low absolute error than low regret. Reinforcement Learning brings to bear even more mathematical machinery to deal with the way the world reacts to the outputs of our model. When we know the world doesn't particularly care about our predictions, this extra machinery becomes cumbersome.

Faced with the reality of ML problems that slip between the cracks of sub-disciplines, practitioners often turn to some kind of batch learning approach in which they periodically gather a fresh dataset that includes recently-generated data and then learn a new model from scratch using this new dataset. In other cases, especially "big data" applications like ad-tech, practitioners will instead turn to computationally efficient streaming algorithms (which are sometimes called "online" algorithms just to confuse us!). These algorithms can make small updates to the parameters of the input-output approximation function every time a new data point is collected -- all without referring back to historical data. These streaming algorithms often guarantee that the model at any time will be similar to the model one would get by training a batch version of the same algorithm on a large dataset containing all data since the beginning of the streaming. RLS can be seen as one such streaming algorithm. It's certainly not the only one, nor the most popular one, but it does offer a number of compelling properties.

Why RLS?

RLS lets us produce the same sequence of linear models we would get if we batch-learned a new model on our entire cumulative dataset every time a new data point is added, but using about the same amount of computation as fitting just the final model on all the data. Whenever a new data point is acquired, the model for \(i+1\) points can be computed quickly and cheaply using the \(i\)th model, without looking at any data besides the new point.

Compared to online gradient descent, RLS is more computationally costly and is only defined for squared loss (so no generalized linear models like logistic regression). On the plus side, though, RLS is truly a closed-form solver. It doesn't require tuning a learning rate, and it is guaranteed to output the exact model parameters which would be given by fitting linear regression on all the data in one go.

By way of the signal-processing literature, RLS also comes with a forgetting factor, which we can use to exponentially down-weight older examples (effectively to zero, far enough back in history). This lets RLS be somewhat analogous to a batch learning approach which drops old "stale" data points.
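To get a feel for the forgetting factor, here is a small plain-Python sketch of the weight an example receives as it ages. The value 0.99 and the helper name example_weight are illustrative assumptions of mine, not recommendations.

```python
def example_weight(lam: float, age: int) -> float:
    """Weight given to an example that is `age` steps old under forgetting factor lam."""
    return lam ** age

lam = 0.99  # illustrative value only
weights = [example_weight(lam, age) for age in (0, 10, 100, 1000)]

# The total weight of an infinitely long history is a geometric series,
# 1 / (1 - lam), often read as a rough "effective number of examples".
effective_window = 1.0 / (1.0 - lam)

print(weights)
print(effective_window)
```

The newest example always has weight 1, and examples far enough back contribute almost nothing, which is the "dropping stale data" behavior described above.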

The ML setting for RLS

The data

We consider a process that over time generates a sequence of vectors containing predictive information, as well as corresponding scalar measurements describing outcomes we'd like to predict using the information in the vectors. We'll notate these sequences as follows.

\(\pmb{x}_1, \pmb{x}_2, \pmb{x}_3, \ldots\) and \(y_1, y_2, y_3, \ldots\)

For example, perhaps \(\pmb{x}_i\) summarizes pertinent sports statistics for a team just before its \(i\)th game, while \(y_i\) is that team's score in the \(i\)th game. We'd like to predict that team's score before they play, using the statistics available before the game.

The ML objective

We seek a system that can predict outcome measurements using information vectors, with minimal error. While ultimately we care about the \(i\)th model's prediction error on the \(i\)th outcome, when it comes to training we have to settle for using relevant historical examples to quantify the error we want to minimize.

Design choices for RLS

In this setting, there are many model structures we could use, many error metrics we could use, and many ways we could decide on relevant historical examples. RLS arises as the solution in this setting when we make a certain set of "design choices."

Putting this all together, the \(i\)th model is defined by a parameter vector \(\pmb{\theta}_i\) which minimizes exponentially-weighted average error, which we'll call \(L_i(\pmb{\theta})\). Here \(\lambda \in (0, 1]\) is the forgetting factor described above, with \(\lambda = 1\) weighting all examples equally. To make the notation simpler, we'll use the total error rather than the average, but the difference is only a constant that doesn't affect the minimization.

\(\pmb{\theta}_i = \underset{\pmb{\theta}}{\operatorname{argmin}}~L_i(\pmb{\theta}) = \underset{\pmb{\theta}}{\operatorname{argmin}}~\sum_{t=1}^i \lambda^{i - t} (y_t - \pmb{\theta}^\intercal\pmb{x}_t)^2\)
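As a concrete reading of this objective, here is a minimal NumPy sketch of \(L_i(\pmb{\theta})\). The function name weighted_sq_loss, the oldest-first row ordering, and the toy numbers are my own choices for illustration.

```python
import numpy as np

def weighted_sq_loss(theta, X, y, lam):
    """Sum over t of lam**(i - t) * (y_t - theta . x_t)**2, with rows of X oldest first."""
    i = len(y)
    weights = lam ** np.arange(i - 1, -1, -1)  # oldest row gets lam**(i-1), newest gets 1
    residuals = y - X @ theta
    return float(np.sum(weights * residuals ** 2))

X = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
y = np.array([1.0, 2.0, 2.0])
theta = np.array([1.0, 1.0])

# Only the middle example is mispredicted (residual 1), and it carries weight lam**1.
loss = weighted_sq_loss(theta, X, y, lam=0.5)
print(loss)
```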

Generalizing to \(\ell_2\) regularization

One of two generalizations I add in this derivation that I didn't find in RLS derivations on the internet is \(\ell_2\) regularization. In many cases, ML folks have found that constraining the \(\ell_2\) magnitude of a linear model's parameters can reduce overfitting (and also improve the conditioning of some of the matrix inversions used to compute the parameters). Using a Lagrange multiplier \(\alpha\) to incorporate the typical \(\ell_2\) norm constraint into the loss function, we get an expanded definition of \(L_i(\pmb{\theta})\).

\(L_i(\pmb{\theta}) = \sum_{t=1}^i \lambda^{i - t} (y_t - \pmb{\theta}^\intercal\pmb{x}_t)^2 + \alpha ||\pmb{\theta}||_2^2\)

Since we use total error rather than average error in a case in which the dataset grows, we need to break up the regularization term into something added to each term in the sum. To do this, we'll use a new \(\ell_2\) penalty term \(\beta\).

\(L_i(\pmb{\theta}) = \sum_{t=1}^i \lambda^{i - t} \left( (y_t - \pmb{\theta}^\intercal\pmb{x}_t)^2 + \beta ||\pmb{\theta}||_2^2\right)\), where \(\alpha = \beta \sum_{t=1}^i \lambda^{i - t}\)

By defining \(\beta\) in this way, we see that \(\alpha\) does not stay fixed as \(i\) increases, which seems odd. However, this actually accounts for the fact that we're using the weighted sum of error rather than the weighted average.
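To see how the effective multiplier behaves, here is a short plain-Python sketch that evaluates \(\alpha = \beta \sum_{t=1}^i \lambda^{i - t}\) for a few values of \(i\). The values of \(\beta\) and \(\lambda\) and the helper name effective_alpha are assumptions chosen only for illustration.

```python
def effective_alpha(beta: float, lam: float, i: int) -> float:
    """alpha at time i: beta times the total weight of the i examples seen so far."""
    return beta * sum(lam ** (i - t) for t in range(1, i + 1))

beta, lam = 0.1, 0.9  # illustrative values only
alphas = [effective_alpha(beta, lam, i) for i in (1, 2, 10, 1000)]

# For lam < 1 the sum is geometric, so alpha levels off at beta / (1 - lam).
alpha_limit = beta / (1.0 - lam)

print(alphas)
print(alpha_limit)
```

So with forgetting, \(\alpha\) grows at first and then saturates, tracking the total weight of the data in the sum.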

Deriving non-recursive least squares

Let's start by deriving the \(\pmb{\theta} = (X^\intercal W X)^{-1}X^\intercal W Y\) weighted linear regression closed-form solution, but with a number of non-standard choices in expressions to account for \(\ell_2\) regularization.

First, some definitions.

\(\pmb{x}_i \in \mathbb{R}^d, \pmb{\theta}_i \in \mathbb{R}^d\)

\(\pmb{w}'_i = \begin{bmatrix} \lambda^{i-1} & \lambda^{i-2} & \cdots & \lambda & 1 \end{bmatrix}^\intercal, W'_i = \operatorname{diag}(\pmb{w}'_i)\)

\(X'_i = \begin{bmatrix} \pmb{x}_1 & \pmb{x}_2 & \cdots & \pmb{x}_i \end{bmatrix}^\intercal, Y'_i = \begin{bmatrix} y_1 & y_2 & \cdots & y_i \end{bmatrix}^\intercal\)

Note the primes in the above names. We're actually going to put our \(\ell_2\) regularization right into the \(X\), \(Y\), and \(W\) matrices (see this CrossValidated post for inspiration). Not only does this new notation keep the equations short, it also makes life a lot easier when we derive the recursive formulas.

We note that \(\beta \operatorname{Tr}(W_i')\) is our \(\ell_2\) regularization Lagrange multiplier at time \(i\). We'll append \(d\) of these values on the end of our weight vector. We'll also pad \(Y'\) with \(d\) zeros, and \(X'\) with the \(d \times d\) identity matrix. When \(\beta = 0\), the appended weights are all zero, so the padding has no effect on the error or its gradient.

\(v_i = \beta \operatorname{Tr}(W'_i)\)

\(\pmb{v}_i = \begin{bmatrix} v_i & v_i & \cdots & v_i \end{bmatrix}^\intercal\), \(\pmb{v}_i \in \mathbb{R}^d\)

\(\pmb{0} = \begin{bmatrix} 0 & 0 & \cdots & 0 \end{bmatrix}^\intercal\), \(\pmb{0} \in \mathbb{R}^d\)

\(\pmb{w}_i = \begin{bmatrix} \pmb{w}'_i \\ \pmb{v}_i \end{bmatrix}, W_i = \operatorname{diag}(\pmb{w}_i)\)

\(X_i = \begin{bmatrix} X'_i \\ I \end{bmatrix}, Y_i = \begin{bmatrix} Y'_i \\ \pmb{0} \end{bmatrix}\)

Now we can define our loss in matrix notation, and see that the \(\ell_2\) regularization term has been pushed inside the new matrices.

\(L_i(\pmb{\theta}) = (Y'_i - X'_i \pmb{\theta})^\intercal W'_i (Y'_i - X'_i \pmb{\theta}) + \beta \operatorname{Tr}(W'_i) ||\pmb{\theta}||_2^2\)

\(= (Y'_i - X'_i \pmb{\theta})^\intercal W'_i (Y'_i - X'_i \pmb{\theta}) + \beta \operatorname{Tr}(W'_i) \pmb{\theta}^\intercal\pmb{\theta}\)

\(= (Y_i - X_i \pmb{\theta})^\intercal W_i (Y_i - X_i \pmb{\theta})\)
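A quick numerical sanity check of this padding trick, as a NumPy sketch. The helper padded_system, the random test data, and the values of \(\lambda\) and \(\beta\) are assumptions for illustration. It builds \(X_i\), \(Y_i\), and the diagonal of \(W_i\), then compares the padded loss with the explicit loss-plus-penalty form.

```python
import numpy as np

def padded_system(X_raw, y_raw, lam, beta):
    """Build X_i, Y_i and the diagonal of W_i from raw data (rows oldest first)."""
    i, d = X_raw.shape
    w_raw = lam ** np.arange(i - 1, -1, -1)     # w'_i
    v = beta * w_raw.sum()                      # v_i = beta * Tr(W'_i)
    X = np.vstack([X_raw, np.eye(d)])           # pad X' with the identity
    Y = np.concatenate([y_raw, np.zeros(d)])    # pad Y' with d zeros
    w = np.concatenate([w_raw, np.full(d, v)])  # pad w' with d copies of v_i
    return X, Y, w

rng = np.random.default_rng(0)
X_raw, y_raw = rng.normal(size=(6, 3)), rng.normal(size=6)
theta = rng.normal(size=3)
lam, beta = 0.9, 0.1

# Loss written with the padded matrices.
X, Y, w = padded_system(X_raw, y_raw, lam, beta)
r = Y - X @ theta
loss_padded = r @ (w * r)

# Loss written with the original matrices plus an explicit penalty term.
w_raw = lam ** np.arange(len(y_raw) - 1, -1, -1)
r_raw = y_raw - X_raw @ theta
loss_explicit = r_raw @ (w_raw * r_raw) + beta * w_raw.sum() * (theta @ theta)

print(np.isclose(loss_padded, loss_explicit))
```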

Closed-form least-squares solution

By first order optimality conditions, \(\nabla L_i(\pmb{\theta}_i) = 0\). We can use this fact to derive a closed-form solution for \(\pmb{\theta}_i\).

\(\nabla L_i(\pmb{\theta}_i) = -2 X_i ^\intercal W_i Y_i + 2 X_i ^\intercal W_i X_i \pmb{\theta}_i = 0\)

\(X_i ^\intercal W_i X_i \pmb{\theta}_i = X_i ^\intercal W_i Y_i\)

\(\pmb{\theta}_i = (X_i ^\intercal W_i X_i )^{-1} X_i ^\intercal W_i Y_i\)

We can also see what this looks like in the conventional definition of \(X, Y, W\).

\(\pmb{\theta}_i = \left( \begin{bmatrix} X'_i~^\intercal & I \end{bmatrix} \operatorname{diag}\left(\begin{bmatrix} \pmb{w}'_i \\ \pmb{v}_i \end{bmatrix}\right) \begin{bmatrix} X'_i \\ I \end{bmatrix} \right)^{-1} \begin{bmatrix} X'_i~^\intercal & I \end{bmatrix} \operatorname{diag} \left(\begin{bmatrix} \pmb{w}'_i \\ \pmb{v}_i \end{bmatrix}\right) \begin{bmatrix} Y'_i \\ \pmb{0} \end{bmatrix}\)

\(= (X'_i ~^\intercal W'_i X'_i + \beta \operatorname{Tr}(W'_i) I)^{-1} X'_i ~^\intercal W'_i Y'_i\)

Lastly, we'll simplify the expression in terms of \(R_i = X_i ^\intercal W_i X_i\) and \(P_i = R_i^{-1}\). We shall refer to \(R_i\) as the covariance matrix, since when \(\beta =0\) and \(X_i\) has been shifted to have zero mean in all dimensions, it represents the un-normalized time-weighted sample covariance matrix. \(P_i\) is called the precision matrix, as it is the inverse of the covariance matrix.

\(\pmb{\theta}_i = P_i X_i ^\intercal W_i Y_i\)
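The closed form can be checked numerically. The NumPy sketch below uses made-up random data and illustrative values of \(\lambda\) and \(\beta\). It computes \(\pmb{\theta}_i\) from the conventional ridge-style expression and from the padded system (solved as ordinary least squares on rows scaled by the square roots of the weights), then confirms that the gradient vanishes at the solution.

```python
import numpy as np

rng = np.random.default_rng(5)
i, d, lam, beta = 7, 3, 0.9, 0.1
X_raw, y_raw = rng.normal(size=(i, d)), rng.normal(size=i)

w_raw = lam ** np.arange(i - 1, -1, -1)  # diagonal of W'_i
alpha = beta * w_raw.sum()               # beta * Tr(W'_i)

# Conventional form: (X'^T W' X' + beta Tr(W') I)^{-1} X'^T W' Y'
R = X_raw.T @ (w_raw[:, None] * X_raw) + alpha * np.eye(d)
theta_ridge = np.linalg.solve(R, X_raw.T @ (w_raw * y_raw))

# Padded form, solved as ordinary least squares on sqrt(W)-scaled rows.
X = np.vstack([X_raw, np.eye(d)])
Y = np.concatenate([y_raw, np.zeros(d)])
w = np.concatenate([w_raw, np.full(d, alpha)])
sqrt_w = np.sqrt(w)
theta_padded, *_ = np.linalg.lstsq(sqrt_w[:, None] * X, sqrt_w * Y, rcond=None)

# First-order optimality: the gradient -2 X^T W (Y - X theta) should vanish.
grad = -2 * X.T @ (w * (Y - X @ theta_ridge))

print(np.allclose(theta_ridge, theta_padded))
print(np.allclose(grad, 0.0, atol=1e-9))
```

Using np.linalg.solve rather than forming \(P_i\) explicitly is just a habit for one-off solves; the explicit inverse only becomes interesting once we want to update it recursively.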

Deriving a Recursive Solution

As \(i\) increases, so do the sizes of \(X_i\) and \(Y_i\), as well as the cost of computing \(\pmb{\theta}_i\) from them. This is not desirable! With \(n\) data points seen so far, the computational complexity is something like \(\mathcal{O}(d^2n + d^3)\) due to the ever-growing matrix multiplication to compute \(R_i\) and the matrix inverse to invert it. We'd like to slice that down to something much lighter so we can run it frequently as new data becomes available. What we'd like is a recursive expression defining \(\pmb{\theta}_{i+1}\) in terms of \(\pmb{x}_{i+1}\), \(y_{i+1}\), and \(\pmb{\theta}_i\), so that the computation doesn't grow with \(n\). Luckily, it turns out we can get just that if we keep track of the covariance or precision matrix as well as the parameter vector.

Generalizing to batch updates

The second of two generalizations I add in this derivation that I didn't find in RLS derivations on the internet is batch updating. In many settings, it may actually be desirable to update the parameter vector for more than one new point at once. For example, database updates may be done in batches to optimize performance, so that many new points become available simultaneously. Or we may not expect the parameter vector to change much point-to-point, in which case a batch update in a single call to an optimized linear algebra library can run much faster than iterative updates in a loop.

The generalization also happens to not add too much complexity. We just express \(\pmb{\theta}_{i+k}\) in terms of \(\pmb{\theta}_i\) rather than \(\pmb{\theta}_{i+1}\) in terms of \(\pmb{\theta}_i\). To keep the notation clean in batch form, we'll define an implicitly-indexed notation for matrices describing \(k\) recent examples.

\(X' = \begin{bmatrix} \pmb{x}_{i + 1} & \pmb{x}_{i+2} & \cdots & \pmb{x}_{i+k} \end{bmatrix}^\intercal, Y' = \begin{bmatrix} y_{i + 1} & y_{i + 2} & \cdots & y_{i + k} \end{bmatrix}^\intercal, \pmb{w}' = \begin{bmatrix} \lambda^{k - 1} & \lambda^{k - 2} & \cdots & 1 \end{bmatrix}^\intercal, \pmb{v} = \pmb{v}_k\)

\(\pmb{w} = \begin{bmatrix} \pmb{w}' \\ \pmb{v} \end{bmatrix}, W = \operatorname{diag}(\pmb{w}), W' = \operatorname{diag}(\pmb{w}')\)

\(X = \begin{bmatrix} X' \\ I \end{bmatrix}, Y = \begin{bmatrix} Y' \\ \pmb{0} \end{bmatrix}\)

\(R = X ^\intercal W X\)

This lets us cleanly express \(X_{i+k}\), \(Y_{i+k}\), and \(W_{i+k}\) recursively.

\(\pmb{v}_{i+k} = \lambda^k \pmb{v}_i + \pmb{v}\)

\(\pmb{w}_{i+k} = \begin{bmatrix} \lambda^k \pmb{w}'_{i} \\ \pmb{w'} \\ \pmb{v}_{i+k} \end{bmatrix}, W_{i+k} = \operatorname{diag}(\pmb{w}_{i+k})\)

\(X_{i+k} = \begin{bmatrix} X'_i \\ X' \\ I \end{bmatrix}, Y_{i+k} = \begin{bmatrix} Y'_i \\ Y' \\ \pmb{0} \end{bmatrix}\)

The recursion relationship for \(\pmb{\theta}\)

First we show that \(R_{i+k} = \lambda^k R_i + R\).

\(R_{i + k} = X_{i+k} ^\intercal W_{i+k} X_{i+k}\)

\(= \begin{bmatrix} X'_{i}~^\intercal & X'~^\intercal & I \end{bmatrix} \operatorname{diag}\left(\begin{bmatrix} \lambda^k \pmb{w}'_{i} \\ \pmb{w'} \\ \pmb{v}_{i+k} \end{bmatrix}\right) \begin{bmatrix} X'_i \\ X' \\ I \end{bmatrix}\)

\(= X'_i~^\intercal \operatorname{diag}(\lambda^k \pmb{w}'_{i}) X'_i + X'~^\intercal \operatorname{diag}(\pmb{w'})X' + I \operatorname{diag}(\lambda^k \pmb{v}_i + \pmb{v})I\)

\(= \lambda^k(X'_i~^\intercal \operatorname{diag}(\pmb{w}'_{i}) X'_i + I \operatorname{diag}(\pmb{v}_i)I) + (X'~^\intercal \operatorname{diag}(\pmb{w}')X' + I \operatorname{diag}(\pmb{v})I)\)

\(= \lambda^k X_i~^\intercal W_i X_i + X~^\intercal W X\)

\(=\lambda^k R_i + R\)
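Here is a small NumPy check of this recursion on random data. The helper name cov and all numeric values are assumptions for illustration. It builds \(R_{i+k}\) directly from all \(i+k\) rows and compares it with \(\lambda^k R_i + R\), where \(R\) is built from only the \(k\) newest rows.

```python
import numpy as np

def cov(X_raw, lam, beta):
    """X^T W X for the padded system: weighted Gram matrix plus beta * Tr(W') * I."""
    n, d = X_raw.shape
    w = lam ** np.arange(n - 1, -1, -1)
    return X_raw.T @ (w[:, None] * X_raw) + beta * w.sum() * np.eye(d)

rng = np.random.default_rng(1)
lam, beta, i, k = 0.95, 0.05, 8, 3
X_all = rng.normal(size=(i + k, 4))

R_i = cov(X_all[:i], lam, beta)    # state after the first i rows
R_new = cov(X_all[i:], lam, beta)  # R, built from only the k newest rows
R_direct = cov(X_all, lam, beta)   # R_{i+k} computed from scratch
R_recursive = lam ** k * R_i + R_new

print(np.allclose(R_direct, R_recursive))
```

Note that the regularization part works out only because \(\pmb{v}_{i+k} = \lambda^k \pmb{v}_i + \pmb{v}\), which is the same geometric-sum bookkeeping as for the data weights.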

Then we look at the equation for \(\pmb{\theta}_{i+k}\).

\(\pmb{\theta}_{i+k} = P_{i+k} X_{i+k} ^\intercal W_{i+k} Y_{i+k}\)

\(= P_{i+k} \begin{bmatrix} X'_{i}~^\intercal & X'~^\intercal & I \end{bmatrix} \operatorname{diag}\left(\begin{bmatrix} \lambda^k \pmb{w}'_{i} \\ \pmb{w'} \\ \pmb{v}_{i+k} \end{bmatrix}\right) \begin{bmatrix} Y'_i \\ Y' \\ \pmb{0} \end{bmatrix}\)

\(= P_{i+k} \begin{bmatrix} X'_{i}~^\intercal & X'~^\intercal \end{bmatrix} \operatorname{diag}\left(\begin{bmatrix} \lambda^k \pmb{w}'_{i} \\ \pmb{w'} \end{bmatrix}\right) \begin{bmatrix} Y'_i \\ Y' \end{bmatrix} + 0\)

\(= P_{i+k} \left( \lambda^k X'_i~^\intercal W'_i Y'_i + X'~^\intercal W' Y' \right)\)

\(= P_{i+k} \left(\lambda^k X'_i~^\intercal W'_i Y'_i + X~^\intercal W Y \right)\)

Since the above expression still depends on our earlier data matrices through the expression \(X'_i~^\intercal W'_i Y'_i\), we need to massage this a bit further. We note the following.

\({1 \over \lambda^k} (R_{i+k} - R) = R_i\) (since \(R_{i+k} = \lambda^k R_i + R\))

\(X'_i~^\intercal W'_i Y'_i = X_i^\intercal W_i Y_i = R_i R_i^{-1}X_i^\intercal W_i Y_i = R_i \pmb{\theta}_i = {1 \over \lambda^k} (R_{i+k} - R) \pmb{\theta}_i = {1 \over \lambda^k} (R_{i+k} - X^\intercal W X) \pmb{\theta}_i\)

By plugging this equivalent statement in, we are able to drop the dependence on all previous state except the precision matrix and parameter vector.

\(\pmb{\theta}_{i+k} = P_{i+k} \left(\lambda^k X'_i~^\intercal W'_i Y'_i + X~^\intercal W Y \right)\)

\(= P_{i+k} \left( \lambda^k \left( {1 \over \lambda^k} (R_{i+k} - X^\intercal W X) \pmb{\theta}_i \right) + X~^\intercal W Y \right)\)

\(= P_{i+k}\left((R_{i+k} - X^\intercal W X) \pmb{\theta}_i + X~^\intercal W Y\right)\)

\(= (I - P_{i+k}X^\intercal W X) \pmb{\theta}_i + P_{i+k} X~^\intercal W Y\)

\(= \pmb{\theta}_i - P_{i+k}X^\intercal W X \pmb{\theta}_i + P_{i+k} X~^\intercal W Y\)

\(= \pmb{\theta}_i + P_{i+k}X^\intercal W (Y - X \pmb{\theta}_i)\)

Faster recursive least squares

Now we can update \(\pmb{\theta}\) after \(k\) examples by recursively updating the covariance matrix, inverting it, and recursively updating our parameter vector.

\(R_{i+k} = \lambda^k R_i + R\)

\(\pmb{\theta}_{i+k} = \pmb{\theta}_i + R_{i+k}^{-1} X^\intercal W (Y - X \pmb{\theta}_i)\)
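The two updates above can be sketched in NumPy as follows. The function names, the empty-history initialization \(R_0 = 0\), the batch sizes, and the random data are all my assumptions for illustration. The sketch streams the data in uneven batches and compares the result with refitting from scratch.

```python
import numpy as np

def rls_cov_update(R, theta, X_new, y_new, lam, beta):
    """One batch update: R_{i+k} = lam^k R_i + R, then theta_{i+k} via a d x d solve."""
    k, d = X_new.shape
    w = lam ** np.arange(k - 1, -1, -1)  # w' for the new batch, newest row weighted 1
    v = beta * w.sum()                   # v_k, the batch's regularization weight
    # X^T W X for the padded batch equals X'^T W' X' + v_k I.
    R_next = lam ** k * R + X_new.T @ (w[:, None] * X_new) + v * np.eye(d)
    # X^T W (Y - X theta) for the padded batch equals X'^T W' (Y' - X' theta) - v_k theta.
    rhs = X_new.T @ (w * (y_new - X_new @ theta)) - v * theta
    theta_next = theta + np.linalg.solve(R_next, rhs)
    return R_next, theta_next

def batch_theta(X_all, y_all, lam, beta):
    """Reference: refit from scratch on all rows (oldest first)."""
    n, d = X_all.shape
    w = lam ** np.arange(n - 1, -1, -1)
    R = X_all.T @ (w[:, None] * X_all) + beta * w.sum() * np.eye(d)
    return np.linalg.solve(R, X_all.T @ (w * y_all))

rng = np.random.default_rng(3)
lam, beta, d = 0.9, 0.1, 3
X_all, y_all = rng.normal(size=(10, d)), rng.normal(size=10)

R, theta = np.zeros((d, d)), np.zeros(d)  # empty history: R_0 = 0
start = 0
for k in (4, 1, 3, 2):                    # arbitrary batch sizes
    stop = start + k
    R, theta = rls_cov_update(R, theta, X_all[start:stop], y_all[start:stop], lam, beta)
    start = stop

theta_batch = batch_theta(X_all, y_all, lam, beta)
print(np.allclose(theta, theta_batch))
```

With \(\beta > 0\) the first update already yields an invertible \(R\), which is why starting from a zero matrix works in this sketch.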

Using this approach, we incur matrix multiplication computation complexity of \(\mathcal{O}(kd^2 + k^2d)\) and matrix inverse complexity of \(\mathcal{O}(d^3)\). When \(d > k\), inverting the covariance matrix dominates. Can we improve this? Yes! We can actually speed things up with a cleverer way of updating the precision matrix. If we apply the Woodbury matrix identity to the update of the precision matrix we get the classical recursive least squares algorithm. Not only is this \(\mathcal{O}(kd^2 + k^2d + k^3)\) instead of \(\mathcal{O}(d^3 + kd^2 + k^2d)\), but it's also proven to be numerically stable when run on a computer.

The Woodbury matrix identity states that when we have already computed the inverse \(A_0^{-1}\) of matrix \(A_0\), and we wish to compute the inverse of an updated \(A = A_0 + UCV\), we can do so efficiently via the following equation.

\(A^{-1} = A_0^{-1} - A_0^{-1}U(C^{-1} + VA_0^{-1}U)^{-1}VA_0^{-1}\).

For "thin" \(U\) and "short" \(V\), this results in much smaller matrix operations than the direct inversion of \(A\).
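The identity is easy to check numerically. This NumPy sketch uses random matrices of my own choosing: a symmetric positive definite \(A_0\), a thin \(U\), \(V = U^\intercal\), and a positive diagonal \(C\), so that every inverse involved exists.

```python
import numpy as np

rng = np.random.default_rng(2)
d, k = 6, 2

B = rng.normal(size=(d, d))
A0 = B @ B.T + d * np.eye(d)                # symmetric positive definite, so invertible
U = rng.normal(size=(d, k))                 # "thin": d rows, k columns
V = U.T                                     # "short": k rows, d columns
C = np.diag(rng.uniform(0.5, 2.0, size=k))  # positive diagonal

A0_inv = np.linalg.inv(A0)                  # assumed already available

# Only a k x k system is solved here, instead of inverting a d x d matrix.
small = np.linalg.inv(C) + V @ A0_inv @ U
A_inv_woodbury = A0_inv - A0_inv @ U @ np.linalg.solve(small, V @ A0_inv)

A_inv_direct = np.linalg.inv(A0 + U @ C @ V)

print(np.allclose(A_inv_woodbury, A_inv_direct))
```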

If we squint at our equations above, we see there is a match between the covariance matrix update and the Woodbury equation.

\(A = R_{i+k}, A_0 = \lambda^k R_i, U = X^\intercal, C = W, V = X\)

\(A = R_{i+k} = A_0 + UCV = \lambda^k R_i + X^\intercal W X\)

Plugging into the Woodbury result, we get an update for \(P_{i+k}\).

\(A^{-1} = R_{i+k}^{-1} = P_{i+k} = A_0^{-1} - A_0^{-1}U(C^{-1} + VA_0^{-1}U)^{-1}VA_0^{-1}\)

\(=(\lambda^k R_i)^{-1} - (\lambda^k R_i)^{-1} X^\intercal (W^{-1} + X (\lambda^k R_i)^{-1} X^\intercal)^{-1} X (\lambda^k R_i)^{-1}\)

\(={1 \over \lambda^k} P_i - {1 \over \lambda^k} P_i X^\intercal (W^{-1} + X {1 \over \lambda^k} P_i X^\intercal)^{-1} X {1 \over \lambda^k} P_i\)

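We can clean this up a bit. Let's start by defining a new \(U\) matrix (reusing the letter, but unrelated to the \(U\) in the Woodbury identity above) so that \(W^{-1} = {1 \over \lambda^k} U\). This requires \(\beta > 0\), so that every diagonal entry of \(W\) is nonzero and \(W\) is invertible. Let's also look at how to construct \(U\).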

\(\pmb{u'} = \begin{bmatrix} \lambda & \lambda^2 & \cdots & \lambda^k \end{bmatrix}^\intercal = {\lambda^k \over \pmb{w'}}\)

\(\mu = {\lambda^k \over \beta \sum_{t=1}^k\lambda^{k-t}} = {\lambda^k \over v_k}\), \(\pmb{\mu} = \begin{bmatrix} \mu & \mu & \cdots & \mu \end{bmatrix}^\intercal\), \(\pmb{\mu} \in \mathbb{R}^{d}\)

\(\pmb{u} = \begin{bmatrix} \pmb{u}' \\ \pmb{\mu} \end{bmatrix} = {\lambda^k \over \pmb{w}}\), \(U = \operatorname{diag}(\pmb{u}) = \lambda^k W^{-1}\)

Using \(U\) our update for the precision matrix simplifies nicely.

\(P_{i+k} = {1 \over \lambda^k} P_i - {1 \over \lambda^k} P_i X^\intercal (W^{-1} + X {1 \over \lambda^k} P_i X^\intercal)^{-1} X {1 \over \lambda^k} P_i\)

\(= {1 \over \lambda^k} P_i - {1 \over \lambda^k} P_i X^\intercal ({1 \over \lambda^k} U + X {1 \over \lambda^k} P_i X^\intercal)^{-1} X {1 \over \lambda^k} P_i\)

\(= {1 \over \lambda^k} P_i - {1 \over \lambda^k} P_i X^\intercal \lambda^k(U + X P_i X^\intercal)^{-1} X {1 \over \lambda^k} P_i\)

\(= {1 \over \lambda^k} \left( P_i - P_i X^\intercal (U + X P_i X^\intercal)^{-1} X P_i \right)\)

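Here we see that the only matrix inverse remaining is that of \(U + X P_i X^\intercal\), which has one row and column per row of \(X\). Counting only the \(k\) data rows, this is \(\mathcal{O}(k^3)\), with the matrix multiplications still \(\mathcal{O}(kd^2 + k^2d)\). Success! One caveat: when \(\beta > 0\), the \(d\) padding rows in \(X\) make this inner matrix \((k+d) \times (k+d)\), so the full \(\mathcal{O}(k^3)\) saving applies to the unpadded case.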
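To make the precision-matrix update concrete, here is a NumPy sketch that follows the final equation literally, using the padded \(X\), \(Y\), \(W\), and \(U = \lambda^k W^{-1}\). The function names, the choice to initialize \(P\) and \(\pmb{\theta}\) from a direct fit on the first few rows, and the random data are all assumptions of mine. Because the padding is included, the inner system solved here has \(k + d\) rows, and \(\beta\) must be positive.

```python
import numpy as np

def rls_precision_update(P, theta, X_new, y_new, lam, beta):
    """P_{i+k} = (P_i - P_i X^T (U + X P_i X^T)^{-1} X P_i) / lam^k, then the theta update."""
    k, d = X_new.shape
    w_raw = lam ** np.arange(k - 1, -1, -1)
    v = beta * w_raw.sum()                      # v_k; must be > 0 for U to exist
    X = np.vstack([X_new, np.eye(d)])           # padded X, shape (k + d, d)
    Y = np.concatenate([y_new, np.zeros(d)])    # padded Y
    w = np.concatenate([w_raw, np.full(d, v)])  # diagonal of W
    U = np.diag(lam ** k / w)                   # U = lam^k W^{-1}
    inner = U + X @ P @ X.T                     # (k + d) x (k + d)
    P_next = (P - P @ X.T @ np.linalg.solve(inner, X @ P)) / lam ** k
    theta_next = theta + P_next @ (X.T @ (w * (Y - X @ theta)))
    return P_next, theta_next

def batch_fit(X_all, y_all, lam, beta):
    """Reference: precision matrix and parameters from a direct fit on all rows."""
    n, d = X_all.shape
    w = lam ** np.arange(n - 1, -1, -1)
    R = X_all.T @ (w[:, None] * X_all) + beta * w.sum() * np.eye(d)
    return np.linalg.inv(R), np.linalg.solve(R, X_all.T @ (w * y_all))

rng = np.random.default_rng(4)
lam, beta, d = 0.9, 0.1, 3
X_all, y_all = rng.normal(size=(9, d)), rng.normal(size=9)

# Assumed initialization: a direct fit on the first 4 rows.
P, theta = batch_fit(X_all[:4], y_all[:4], lam, beta)
start = 4
for k in (2, 3):
    stop = start + k
    P, theta = rls_precision_update(P, theta, X_all[start:stop], y_all[start:stop], lam, beta)
    start = stop

P_batch, theta_batch = batch_fit(X_all, y_all, lam, beta)
print(np.allclose(P, P_batch), np.allclose(theta, theta_batch))
```

The update never inverts \(R\) directly after initialization; it only solves against the inner matrix, which is the point of routing the recursion through the Woodbury identity.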