// LBFGS++: BFGSMat.h
// Copyright (C) 2020-2023 Yixuan Qiu <yixuan.qiu@cos.name>
// Under MIT license
3
4#ifndef LBFGSPP_BFGS_MAT_H
5#define LBFGSPP_BFGS_MAT_H
6
7#include <vector>
8#include <Eigen/Core>
9#include "BKLDLT.h"
10
12
13namespace LBFGSpp {
14
15//
16// An *implicit* representation of the BFGS approximation to the Hessian matrix B
17//
18// B = theta * I - W * M * W'
19// H = inv(B)
20//
21// Reference:
22// [1] D. C. Liu and J. Nocedal (1989). On the limited memory BFGS method for large scale optimization.
23// [2] R. H. Byrd, P. Lu, and J. Nocedal (1995). A limited memory algorithm for bound constrained optimization.
24//
25template <typename Scalar, bool LBFGSB = false>
26class BFGSMat
27{
28private:
29 using Vector = Eigen::Matrix<Scalar, Eigen::Dynamic, 1>;
30 using Matrix = Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic>;
31 using RefConstVec = Eigen::Ref<const Vector>;
32 using IndexSet = std::vector<int>;
33
34 int m_m; // Maximum number of correction vectors
35 Scalar m_theta; // theta * I is the initial approximation to the Hessian matrix
36 Matrix m_s; // History of the s vectors
37 Matrix m_y; // History of the y vectors
38 Vector m_ys; // History of the s'y values
39 Vector m_alpha; // Temporary values used in computing H * v
40 int m_ncorr; // Number of correction vectors in the history, m_ncorr <= m
41 int m_ptr; // A Pointer to locate the most recent history, 1 <= m_ptr <= m
42 // Details: s and y vectors are stored in cyclic order.
43 // For example, if the current s-vector is stored in m_s[, m-1],
44 // then in the next iteration m_s[, 0] will be overwritten.
45 // m_s[, m_ptr-1] points to the most recent history,
46 // and m_s[, m_ptr % m] points to the most distant one.
47
48 //========== The following members are only used in L-BFGS-B algorithm ==========//
49 Matrix m_permMinv; // Permutated M inverse
50 BKLDLT<Scalar> m_permMsolver; // Represents the permutated M matrix
51
52public:
53 // Constructor
54 BFGSMat() {}
55
56 // Reset internal variables
57 // n: dimension of the vector to be optimized
58 // m: maximum number of corrections to approximate the Hessian matrix
59 inline void reset(int n, int m)
60 {
61 m_m = m;
62 m_theta = Scalar(1);
63 m_s.resize(n, m);
64 m_y.resize(n, m);
65 m_ys.resize(m);
66 m_alpha.resize(m);
67 m_ncorr = 0;
68 m_ptr = m; // This makes sure that m_ptr % m == 0 in the first step
69
70 if (LBFGSB)
71 {
72 m_permMinv.resize(2 * m, 2 * m);
73 m_permMinv.setZero();
74 m_permMinv.diagonal().setOnes();
75 }
76 }
77
78 // Add correction vectors to the BFGS matrix
79 inline void add_correction(const RefConstVec& s, const RefConstVec& y)
80 {
81 const int loc = m_ptr % m_m;
82
83 m_s.col(loc).noalias() = s;
84 m_y.col(loc).noalias() = y;
85
86 // ys = y's = 1/rho
87 const Scalar ys = m_s.col(loc).dot(m_y.col(loc));
88 m_ys[loc] = ys;
89
90 m_theta = m_y.col(loc).squaredNorm() / ys;
91
92 if (m_ncorr < m_m)
93 m_ncorr++;
94
95 m_ptr = loc + 1;
96
97 if (LBFGSB)
98 {
99 // Minv = [-D L']
100 // [ L theta*S'S]
101
102 // Copy -D
103 // Let S=[s[0], ..., s[m-1]], Y=[y[0], ..., y[m-1]]
104 // D = [s[0]'y[0], ..., s[m-1]'y[m-1]]
105 m_permMinv(loc, loc) = -ys;
106
107 // Update S'S
108 // We only store S'S in Minv, and multiply theta when LU decomposition is performed
109 Vector Ss = m_s.leftCols(m_ncorr).transpose() * m_s.col(loc);
110 m_permMinv.block(m_m + loc, m_m, 1, m_ncorr).noalias() = Ss.transpose();
111 m_permMinv.block(m_m, m_m + loc, m_ncorr, 1).noalias() = Ss;
112
113 // Compute L
114 // L = [ 0 ]
115 // [ s[1]'y[0] 0 ]
116 // [ s[2]'y[0] s[2]'y[1] ]
117 // ...
118 // [s[m-1]'y[0] ... ... ... ... ... s[m-1]'y[m-2] 0]
119 //
120 // L_next = [ 0 ]
121 // [s[2]'y[1] 0 ]
122 // [s[3]'y[1] s[3]'y[2] ]
123 // ...
124 // [s[m]'y[1] ... ... ... ... ... s[m]'y[m-1] 0]
125 const int len = m_ncorr - 1;
126 // First zero out the column of oldest y
127 if (m_ncorr >= m_m)
128 m_permMinv.block(m_m, loc, m_m, 1).setZero();
129 // Compute the row associated with new s
130 // The current row is loc
131 // End with column (loc + m - 1) % m
132 // Length is len
133 int yloc = (loc + m_m - 1) % m_m;
134 for (int i = 0; i < len; i++)
135 {
136 m_permMinv(m_m + loc, yloc) = m_s.col(loc).dot(m_y.col(yloc));
137 yloc = (yloc + m_m - 1) % m_m;
138 }
139
140 // Matrix LDLT factorization
141 m_permMinv.block(m_m, m_m, m_m, m_m) *= m_theta;
142 m_permMsolver.compute(m_permMinv);
143 m_permMinv.block(m_m, m_m, m_m, m_m) /= m_theta;
144 }
145 }
146
147 // Recursive formula to compute a * H * v, where a is a scalar, and v is [n x 1]
148 // H0 = (1/theta) * I is the initial approximation to H
149 // Algorithm 7.4 of Nocedal, J., & Wright, S. (2006). Numerical optimization.
150 inline void apply_Hv(const Vector& v, const Scalar& a, Vector& res)
151 {
152 res.resize(v.size());
153
154 // L-BFGS two-loop recursion
155
156 // Loop 1
157 res.noalias() = a * v;
158 int j = m_ptr % m_m;
159 for (int i = 0; i < m_ncorr; i++)
160 {
161 j = (j + m_m - 1) % m_m;
162 m_alpha[j] = m_s.col(j).dot(res) / m_ys[j];
163 res.noalias() -= m_alpha[j] * m_y.col(j);
164 }
165
166 // Apply initial H0
167 res /= m_theta;
168
169 // Loop 2
170 for (int i = 0; i < m_ncorr; i++)
171 {
172 const Scalar beta = m_y.col(j).dot(res) / m_ys[j];
173 res.noalias() += (m_alpha[j] - beta) * m_s.col(j);
174 j = (j + 1) % m_m;
175 }
176 }
177
178 //========== The following functions are only used in L-BFGS-B algorithm ==========//
179
180 // Return the value of theta
181 inline Scalar theta() const { return m_theta; }
182
183 // Return current number of correction vectors
184 inline int num_corrections() const { return m_ncorr; }
185
186 // W = [Y, theta * S]
187 // W [n x (2*ncorr)], v [n x 1], res [(2*ncorr) x 1]
188 // res preserves the ordering of Y and S columns
189 inline void apply_Wtv(const Vector& v, Vector& res) const
190 {
191 res.resize(2 * m_ncorr);
192 res.head(m_ncorr).noalias() = m_y.leftCols(m_ncorr).transpose() * v;
193 res.tail(m_ncorr).noalias() = m_theta * m_s.leftCols(m_ncorr).transpose() * v;
194 }
195
196 // The b-th row of the W matrix
197 // Preserves the ordering of Y and S columns
198 // Return as a column vector
199 inline Vector Wb(int b) const
200 {
201 Vector res(2 * m_ncorr);
202 for (int j = 0; j < m_ncorr; j++)
203 {
204 res[j] = m_y(b, j);
205 res[m_ncorr + j] = m_s(b, j);
206 }
207 res.tail(m_ncorr) *= m_theta;
208 return res;
209 }
210
211 // Extract rows of W
212 inline Matrix Wb(const IndexSet& b) const
213 {
214 const int nb = b.size();
215 const int* bptr = b.data();
216 Matrix res(nb, 2 * m_ncorr);
217
218 for (int j = 0; j < m_ncorr; j++)
219 {
220 const Scalar* Yptr = &m_y(0, j);
221 const Scalar* Sptr = &m_s(0, j);
222 Scalar* resYptr = res.data() + j * nb;
223 Scalar* resSptr = resYptr + m_ncorr * nb;
224 for (int i = 0; i < nb; i++)
225 {
226 const int row = bptr[i];
227 resYptr[i] = Yptr[row];
228 resSptr[i] = Sptr[row];
229 }
230 }
231 return res;
232 }
233
234 // M is [(2*ncorr) x (2*ncorr)], v is [(2*ncorr) x 1]
235 inline void apply_Mv(const Vector& v, Vector& res) const
236 {
237 res.resize(2 * m_ncorr);
238 if (m_ncorr < 1)
239 return;
240
241 Vector vpadding = Vector::Zero(2 * m_m);
242 vpadding.head(m_ncorr).noalias() = v.head(m_ncorr);
243 vpadding.segment(m_m, m_ncorr).noalias() = v.tail(m_ncorr);
244
245 // Solve linear equation
246 m_permMsolver.solve_inplace(vpadding);
247
248 res.head(m_ncorr).noalias() = vpadding.head(m_ncorr);
249 res.tail(m_ncorr).noalias() = vpadding.segment(m_m, m_ncorr);
250 }
251
252 // Compute W'Pv
253 // W [n x (2*ncorr)], v [nP x 1], res [(2*ncorr) x 1]
254 // res preserves the ordering of Y and S columns
255 // Returns false if the result is known to be zero
256 inline bool apply_WtPv(const IndexSet& P_set, const Vector& v, Vector& res, bool test_zero = false) const
257 {
258 const int* Pptr = P_set.data();
259 const Scalar* vptr = v.data();
260 int nP = P_set.size();
261
262 // Remove zeros in v to save computation
263 IndexSet P_reduced;
264 std::vector<Scalar> v_reduced;
265 if (test_zero)
266 {
267 P_reduced.reserve(nP);
268 for (int i = 0; i < nP; i++)
269 {
270 if (vptr[i] != Scalar(0))
271 {
272 P_reduced.push_back(Pptr[i]);
273 v_reduced.push_back(vptr[i]);
274 }
275 }
276 Pptr = P_reduced.data();
277 vptr = v_reduced.data();
278 nP = P_reduced.size();
279 }
280
281 res.resize(2 * m_ncorr);
282 if (m_ncorr < 1 || nP < 1)
283 {
284 res.setZero();
285 return false;
286 }
287
288 for (int j = 0; j < m_ncorr; j++)
289 {
290 Scalar resy = Scalar(0), ress = Scalar(0);
291 const Scalar* yptr = &m_y(0, j);
292 const Scalar* sptr = &m_s(0, j);
293 for (int i = 0; i < nP; i++)
294 {
295 const int row = Pptr[i];
296 resy += yptr[row] * vptr[i];
297 ress += sptr[row] * vptr[i];
298 }
299 res[j] = resy;
300 res[m_ncorr + j] = ress;
301 }
302 res.tail(m_ncorr) *= m_theta;
303 return true;
304 }
305
306 // Compute s * P'WMv
307 // Assume that v[2*ncorr x 1] has the same ordering (permutation) as W and M
308 // Returns false if the result is known to be zero
309 inline bool apply_PtWMv(const IndexSet& P_set, const Vector& v, Vector& res, const Scalar& scale) const
310 {
311 const int nP = P_set.size();
312 res.resize(nP);
313 res.setZero();
314 if (m_ncorr < 1 || nP < 1)
315 return false;
316
317 Vector Mv;
318 apply_Mv(v, Mv);
319 // WP * Mv
320 Mv.tail(m_ncorr) *= m_theta;
321 for (int j = 0; j < m_ncorr; j++)
322 {
323 const Scalar* yptr = &m_y(0, j);
324 const Scalar* sptr = &m_s(0, j);
325 const Scalar Mvy = Mv[j], Mvs = Mv[m_ncorr + j];
326 for (int i = 0; i < nP; i++)
327 {
328 const int row = P_set[i];
329 res[i] += Mvy * yptr[row] + Mvs * sptr[row];
330 }
331 }
332 res *= scale;
333 return true;
334 }
335 // If the P'W matrix has been explicitly formed, do a direct matrix multiplication
336 inline bool apply_PtWMv(const Matrix& WP, const Vector& v, Vector& res, const Scalar& scale) const
337 {
338 const int nP = WP.rows();
339 res.resize(nP);
340 if (m_ncorr < 1 || nP < 1)
341 {
342 res.setZero();
343 return false;
344 }
345
346 Vector Mv;
347 apply_Mv(v, Mv);
348 // WP * Mv
349 Mv.tail(m_ncorr) *= m_theta;
350 res.noalias() = scale * (WP * Mv);
351 return true;
352 }
353
354 // Compute F'BAb = -(F'W)M(W'AA'd)
355 // W'd is known, and AA'+FF'=I, so W'AA'd = W'd - W'FF'd
356 // Usually d contains many zeros, so we fist compute number of nonzero elements in A set and F set,
357 // denoted as nnz_act and nnz_fv, respectively
358 // If nnz_act is smaller, compute W'AA'd = WA' (A'd) directly
359 // If nnz_fv is smaller, compute W'AA'd = W'd - WF' * (F'd)
360 inline void compute_FtBAb(
361 const Matrix& WF, const IndexSet& fv_set, const IndexSet& newact_set, const Vector& Wd, const Vector& drt,
362 Vector& res) const
363 {
364 const int nact = newact_set.size();
365 const int nfree = WF.rows();
366 res.resize(nfree);
367 if (m_ncorr < 1 || nact < 1 || nfree < 1)
368 {
369 res.setZero();
370 return;
371 }
372
373 // W'AA'd
374 Vector rhs(2 * m_ncorr);
375 if (nact <= nfree)
376 {
377 // Construct A'd
378 Vector Ad(nfree);
379 for (int i = 0; i < nact; i++)
380 Ad[i] = drt[newact_set[i]];
381 apply_WtPv(newact_set, Ad, rhs);
382 }
383 else
384 {
385 // Construct F'd
386 Vector Fd(nfree);
387 for (int i = 0; i < nfree; i++)
388 Fd[i] = drt[fv_set[i]];
389 // Compute W'AA'd = W'd - WF' * (F'd)
390 rhs.noalias() = WF.transpose() * Fd;
391 rhs.tail(m_ncorr) *= m_theta;
392 rhs.noalias() = Wd - rhs;
393 }
394
395 apply_PtWMv(WF, rhs, res, Scalar(-1));
396 }
397
398 // Compute inv(P'BP) * v
399 // P represents an index set
400 // inv(P'BP) * v = v / theta + WP * inv(inv(M) - WP' * WP / theta) * WP' * v / theta^2
401 //
402 // v is [nP x 1]
403 inline void solve_PtBP(const Matrix& WP, const Vector& v, Vector& res) const
404 {
405 const int nP = WP.rows();
406 res.resize(nP);
407 if (m_ncorr < 1 || nP < 1)
408 {
409 res.noalias() = v / m_theta;
410 return;
411 }
412
413 // Compute the matrix in the middle (only the lower triangular part is needed)
414 // Remember that W = [Y, theta * S], but we do not store theta in WP
415 Matrix mid(2 * m_ncorr, 2 * m_ncorr);
416 // [0:(ncorr - 1), 0:(ncorr - 1)]
417 for (int j = 0; j < m_ncorr; j++)
418 {
419 mid.col(j).segment(j, m_ncorr - j).noalias() = m_permMinv.col(j).segment(j, m_ncorr - j) -
420 WP.block(0, j, nP, m_ncorr - j).transpose() * WP.col(j) / m_theta;
421 }
422 // [ncorr:(2 * ncorr - 1), 0:(ncorr - 1)]
423 mid.block(m_ncorr, 0, m_ncorr, m_ncorr).noalias() = m_permMinv.block(m_m, 0, m_ncorr, m_ncorr) -
424 WP.rightCols(m_ncorr).transpose() * WP.leftCols(m_ncorr);
425 // [ncorr:(2 * ncorr - 1), ncorr:(2 * ncorr - 1)]
426 for (int j = 0; j < m_ncorr; j++)
427 {
428 mid.col(m_ncorr + j).segment(m_ncorr + j, m_ncorr - j).noalias() = m_theta *
429 (m_permMinv.col(m_m + j).segment(m_m + j, m_ncorr - j) - WP.rightCols(m_ncorr - j).transpose() * WP.col(m_ncorr + j));
430 }
431 // Factorization
432 BKLDLT<Scalar> midsolver(mid);
433 // Compute the final result
434 Vector WPv = WP.transpose() * v;
435 WPv.tail(m_ncorr) *= m_theta;
436 midsolver.solve_inplace(WPv);
437 WPv.tail(m_ncorr) *= m_theta;
438 res.noalias() = v / m_theta + (WP * WPv) / (m_theta * m_theta);
439 }
440
441 // Compute P'BQv, where P and Q are two mutually exclusive index selection operators
442 // P'BQv = -WP * M * WQ' * v
443 // Returns false if the result is known to be zero
444 inline bool apply_PtBQv(const Matrix& WP, const IndexSet& Q_set, const Vector& v, Vector& res, bool test_zero = false) const
445 {
446 const int nP = WP.rows();
447 const int nQ = Q_set.size();
448 res.resize(nP);
449 if (m_ncorr < 1 || nP < 1 || nQ < 1)
450 {
451 res.setZero();
452 return false;
453 }
454
455 Vector WQtv;
456 bool nonzero = apply_WtPv(Q_set, v, WQtv, test_zero);
457 if (!nonzero)
458 {
459 res.setZero();
460 return false;
461 }
462
463 Vector MWQtv;
464 apply_Mv(WQtv, MWQtv);
465 MWQtv.tail(m_ncorr) *= m_theta;
466 res.noalias() = -WP * MWQtv;
467 return true;
468 }
469 // If the Q'W matrix has been explicitly formed, do a direct matrix multiplication
470 inline bool apply_PtBQv(const Matrix& WP, const Matrix& WQ, const Vector& v, Vector& res) const
471 {
472 const int nP = WP.rows();
473 const int nQ = WQ.rows();
474 res.resize(nP);
475 if (m_ncorr < 1 || nP < 1 || nQ < 1)
476 {
477 res.setZero();
478 return false;
479 }
480
481 // Remember that W = [Y, theta * S], so we need to multiply theta to the second half
482 Vector WQtv = WQ.transpose() * v;
483 WQtv.tail(m_ncorr) *= m_theta;
484 Vector MWQtv;
485 apply_Mv(WQtv, MWQtv);
486 MWQtv.tail(m_ncorr) *= m_theta;
487 res.noalias() = -WP * MWQtv;
488 return true;
489 }
490};
491
492} // namespace LBFGSpp
493
495
496#endif // LBFGSPP_BFGS_MAT_H