numerical-collection-cpp 0.10.0
A collection of algorithms in numerical analysis implemented in C++
Loading...
Searching...
No Matches
fista.h
Go to the documentation of this file.
1/*
2 * Copyright 2021 MusicScience37 (Kenta Kabashima)
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
20#pragma once
21
22#include <algorithm>
23#include <cmath>
24#include <limits>
25#include <utility>
26
34#include "num_collect/regularization/impl/weak_coeff_param.h" // IWYU pragma: keep
36
38
40constexpr auto fista_tag =
41 logging::log_tag_view("num_collect::regularization::fista");
42
54template <typename Coeff, typename Data>
55class fista
56 : public iterative_regularized_solver_base<fista<Coeff, Data>, Data> {
57public:
60
63
64 using typename base_type::data_type;
65 using typename base_type::scalar_type;
66
68 using coeff_type = Coeff;
69
74
85 void compute(const Coeff& coeff, const Data& data) {
86 coeff_ = &coeff;
87 data_ = &data;
88 inv_max_eigen_ = static_cast<scalar_type>(1) /
91 this->logger(), "inv_max_eigen={}", inv_max_eigen_);
92 }
93
// Initialize the solver state before the first call to iterate().
//
// param    : regularization parameter; not used by this function.
// solution : initial solution; it is copied into the auxiliary vector y_.
//
// The coefficient matrix and data pointers are dereferenced here, so they
// must already have been set (compute() stores both).
// NUM_COLLECT_PRECONDITION throws an exception when a size check fails.
95 void init(const scalar_type& param, data_type& solution) {
96 (void)param;
97
// Dimension checks: rows(coeff) == rows(data), cols(coeff) == rows(solution),
// cols(data) == cols(solution).
98 NUM_COLLECT_PRECONDITION(coeff_->rows() == data_->rows(),
99 this->logger(),
100 "Coefficient matrix and data vector must have the same number of "
101 "rows.");
102 NUM_COLLECT_PRECONDITION(coeff_->cols() == solution.rows(),
103 this->logger(),
104 "The number of columns in the coefficient matrix must match the "
105 "number of rows in solution vector.");
106 NUM_COLLECT_PRECONDITION(data_->cols() == solution.cols(),
107 this->logger(),
108 "Data and solution must have the same number of columns.");
109
// Reset the iteration counter, the parameter t_ and the auxiliary vector.
110 iterations_ = 0;
111 t_ = static_cast<scalar_type>(1);
112 y_ = solution;
// Start from an infinite update norm so that the "update < tolerance" test
// in is_stop_criteria_satisfied() cannot pass before the first iteration.
113 update_ = std::numeric_limits<scalar_type>::infinity();
114 }
115
// Iterate the algorithm once.
//
// param    : regularization parameter; scales the truncation threshold.
// solution : current solution, overwritten element by element with the new
//            solution.
//
// Also updates t_, y_, residual_, update_ and iterations_.
117 void iterate(const scalar_type& param, data_type& solution) {
// New t_ = (1 + sqrt(1 + 4 * t_before^2)) / 2.
118 const scalar_type t_before = t_;
119 using std::sqrt;
120 t_ = static_cast<scalar_type>(0.5) * // NOLINT
121 (static_cast<scalar_type>(1) + // NOLINT
122 sqrt(static_cast<scalar_type>(1) + // NOLINT
123 static_cast<scalar_type>(4) // NOLINT
124 * t_before * t_before));
// Weight applied to the latest change of the solution when forming y_.
125 const scalar_type coeff_update =
126 (t_before - static_cast<scalar_type>(1)) / t_;
127
// twice_step multiplies coeff^T * residual below; the threshold uses half
// of it, scaled by the regularization parameter.
128 const scalar_type twice_step = inv_max_eigen_;
129 const scalar_type step = static_cast<scalar_type>(0.5) * twice_step;
130 const scalar_type trunc_thresh = param * step;
131
// residual_ is built as coeff * y_ - data: start from -data, then add the
// columns of coeff weighted by y_ inside the parallel region.
132 residual_ = -(*data_);
133 auto update_sum2 = static_cast<scalar_type>(0);
134#pragma omp parallel default(shared)
135 {
136 const index_type size = solution.size();
// Per-thread partial sum of y_(i) * column i of coeff.
137 data_type temp_res = Data::Zero(residual_.size());
138#pragma omp for nowait
139 for (index_type i = 0; i < size; ++i) {
140 using std::abs;
// Columns whose weight y_(i) is exactly zero contribute nothing; skip them.
141 if (abs(y_(i)) > static_cast<scalar_type>(0)) {
142 temp_res += y_(i) * coeff_->col(i);
143 }
144 }
// Merge the per-thread partial sums into the shared residual_ one thread at
// a time, then wait until every thread has merged before residual_ is read.
145#pragma omp critical
146 residual_ += temp_res;
147#pragma omp barrier
148
// Element-wise update; the squared changes are summed via the reduction.
149#pragma omp for reduction(+ : update_sum2)
150 for (index_type i = 0; i < size; ++i) {
// Step from y_(i) along -(column i of coeff) . residual_.
151 scalar_type cur_next_sol =
152 y_(i) - twice_step * coeff_->col(i).dot(residual_);
153
// Soft thresholding: shrink toward zero by trunc_thresh, and set values
// within [-trunc_thresh, trunc_thresh] to zero.
154 if (cur_next_sol > trunc_thresh) {
155 cur_next_sol = cur_next_sol - trunc_thresh;
156 } else if (cur_next_sol < -trunc_thresh) {
157 cur_next_sol = cur_next_sol + trunc_thresh;
158 } else {
159 cur_next_sol = static_cast<scalar_type>(0);
160 }
161
162 const scalar_type current_update = cur_next_sol - solution(i);
163 update_sum2 += current_update * current_update;
164
// y_ is the new solution extrapolated along the latest change.
165 y_(i) = cur_next_sol + coeff_update * current_update;
166 solution(i) = cur_next_sol;
167 }
168 }
169
// Norm of the change of the solution in this iteration.
170 update_ = sqrt(update_sum2);
171 ++iterations_;
172 }
173
// Determine if stopping criteria of the algorithm are satisfied.
//
// Returns true when the iteration count is strictly greater than
// max_iterations(), or when the last update norm is smaller than
// tol_update_rate() times the norm of `solution`.
175 [[nodiscard]] auto is_stop_criteria_satisfied(
176 const data_type& solution) const -> bool {
177 return (iterations() > max_iterations()) ||
178 (update() < tol_update_rate() * solution.norm());
179 }
180
184 const {
185 iteration_logger.template append<index_type>(
186 "Iter.", &this_type::iterations);
187 iteration_logger.template append<scalar_type>(
188 "Update", &this_type::update);
189 iteration_logger.template append<scalar_type>(
190 "Res.Rate", &this_type::residual_norm_rate);
191 }
192
// Calculate the squared norm of the residual coeff * solution - data.
// Note that, despite the name, the returned value is the squared norm.
194 [[nodiscard]] auto residual_norm(const data_type& solution) const
195 -> scalar_type {
196 return ((*coeff_) * solution - (*data_)).squaredNorm();
197 }
198
// Calculate the regularization term: the value of lpNorm<1>() (the L1 norm)
// of `solution`.
200 [[nodiscard]] auto regularization_term(const data_type& solution) const
201 -> scalar_type {
202 return solution.template lpNorm<1>();
203 }
204
// Change data. Only the address of `data` is stored (no copy is made), so
// the referenced object must stay alive while this solver uses it.
206 void change_data(const data_type& data) { data_ = &data; }
207
// Calculate data for a solution: writes coeff * solution into `data`.
209 void calculate_data_for(const data_type& solution, data_type& data) const {
210 data = (*coeff_) * solution;
211 }
212
// Get the size of data, i.e. size() of the currently referenced data object.
214 [[nodiscard]] auto data_size() const -> index_type { return data_->size(); }
215
217 [[nodiscard]] auto param_search_region() const
218 -> std::pair<scalar_type, scalar_type> {
219 const scalar_type max_sol_est =
220 (coeff_->transpose() * (*data_)).cwiseAbs().maxCoeff();
221 NUM_COLLECT_LOG_TRACE(this->logger(), "max_sol_est={}", max_sol_est);
222 constexpr auto tol_update_coeff_multiplier =
223 static_cast<scalar_type>(10);
224 return {max_sol_est *
226 tol_update_coeff_multiplier * tol_update_rate_),
228 }
229
// Get the number of iterations performed since the last call to init().
235 [[nodiscard]] auto iterations() const noexcept -> index_type {
236 return iterations_;
237 }
238
// Get the norm of the update of the solution in the last iteration.
// The value is infinity after init() until iterate() has run once.
244 [[nodiscard]] auto update() const noexcept -> scalar_type {
245 return update_;
246 }
247
// Get the rate of the last residual norm: squared norm of residual_ divided
// by the squared norm of the data. residual_ is the one computed inside
// iterate(), i.e. it was evaluated from y_ at the start of that iteration.
// There is no guard against data whose squared norm is zero.
253 [[nodiscard]] auto residual_norm_rate() const -> scalar_type {
254 return residual_.squaredNorm() / data_->squaredNorm();
255 }
256
// Get the maximum number of iterations.
262 [[nodiscard]] auto max_iterations() const -> index_type {
263 return max_iterations_;
264 }
265
273 NUM_COLLECT_PRECONDITION(value > 0, this->logger(),
274 "Maximum number of iterations must be a positive integer.");
275 max_iterations_ = value;
276 return *this;
277 }
278
// Get the tolerance of update rate of the solution. This is the factor
// multiplied by the solution norm in is_stop_criteria_satisfied().
284 [[nodiscard]] auto tol_update_rate() const -> scalar_type {
285 return tol_update_rate_;
286 }
287
295 NUM_COLLECT_PRECONDITION(value > static_cast<scalar_type>(0),
296 this->logger(),
297 "Tolerance of update rate of the solution must be a positive "
298 "value.");
299 tol_update_rate_ = value;
300 return *this;
301 }
302
303private:
305 const coeff_type* coeff_{nullptr};
306
308 const data_type* data_{nullptr};
309
315
318
321
324
327
330
332 static constexpr index_type default_max_iterations = 1000;
333
336
338 static constexpr auto default_tol_update_rate =
339 static_cast<scalar_type>(1e-4);
340
343};
344
345} // namespace num_collect::regularization
Definition of max_eigen_aat function.
Class of tags of logs without memory management.
auto logger() const noexcept -> const num_collect::logging::logger &
Access to the logger.
Class for fast iterative shrinkage-thresholding algorithm (FISTA) [beck2009] for L1-regularization of linear equations.
Definition fista.h:56
auto tol_update_rate(scalar_type value) -> fista &
Set the tolerance of update rate of the solution.
Definition fista.h:294
auto residual_norm_rate() const -> scalar_type
Get the rate of the last residual norm.
Definition fista.h:253
auto tol_update_rate() const -> scalar_type
Get the tolerance of update rate of the solution.
Definition fista.h:284
void init(const scalar_type &param, data_type &solution)
Initialize.
Definition fista.h:95
auto regularization_term(const data_type &solution) const -> scalar_type
Calculate the regularization term.
Definition fista.h:200
void configure_iteration_logger(logging::iterations::iteration_logger< this_type > &iteration_logger) const
Configure an iteration logger.
Definition fista.h:182
scalar_type t_
Parameter for step size of y_.
Definition fista.h:320
void compute(const Coeff &coeff, const Data &data)
Compute internal parameters.
Definition fista.h:85
Coeff coeff_type
Type of coefficient matrices.
Definition fista.h:68
scalar_type tol_update_rate_
Tolerance of update rate of the solution.
Definition fista.h:342
auto iterations() const noexcept -> index_type
Get the number of iterations.
Definition fista.h:235
data_type residual_
Residual vector.
Definition fista.h:326
auto data_size() const -> index_type
Get the size of data.
Definition fista.h:214
scalar_type inv_max_eigen_
Inverse of maximum eigenvalue of $A A^T$ for coefficient matrix $A$.
Definition fista.h:314
auto update() const noexcept -> scalar_type
Get the norm of the update of the solution in the last iteration.
Definition fista.h:244
void iterate(const scalar_type &param, data_type &solution)
Iterate the algorithm once.
Definition fista.h:117
index_type iterations_
Number of iterations.
Definition fista.h:317
const coeff_type * coeff_
Coefficient matrix.
Definition fista.h:305
auto max_iterations(index_type value) -> fista &
Set the maximum number of iterations.
Definition fista.h:272
index_type max_iterations_
Maximum number of iterations.
Definition fista.h:335
auto is_stop_criteria_satisfied(const data_type &solution) const -> bool
Determine if stopping criteria of the algorithm are satisfied.
Definition fista.h:175
data_type y_
Another vector to update in FISTA.
Definition fista.h:323
void calculate_data_for(const data_type &solution, data_type &data) const
Calculate data for a solution.
Definition fista.h:209
auto max_iterations() const -> index_type
Get the maximum number of iterations.
Definition fista.h:262
const data_type * data_
Data vector.
Definition fista.h:308
auto param_search_region() const -> std::pair< scalar_type, scalar_type >
Get the default region to search for the optimal regularization parameter.
Definition fista.h:217
static constexpr index_type default_max_iterations
Default maximum number of iterations.
Definition fista.h:332
scalar_type update_
Norm of the update of the solution in the last iteration.
Definition fista.h:329
auto residual_norm(const data_type &solution) const -> scalar_type
Calculate the squared norm of the residual.
Definition fista.h:194
void change_data(const data_type &data)
Change data.
Definition fista.h:206
static constexpr auto default_tol_update_rate
Default tolerance of update rate of the solution.
Definition fista.h:338
Base class of solvers using iterative formulas for regularization.
typename Eigen::NumTraits< typename data_type::Scalar >::Real scalar_type
Type of scalars.
Definition of exceptions.
Definition of index_type type.
Definition of iteration_logger class.
Definition of iterative_regularized_solver_base class.
Definition of log_tag_view class.
Definition of macros for logging.
#define NUM_COLLECT_LOG_TRACE(LOGGER,...)
Write a trace log.
std::ptrdiff_t index_type
Type of indices in this library.
Definition index_type.h:33
constexpr auto weak_coeff_min_param
Coefficient (minimum parameter to be searched) / (maximum singular value or eigen value).
auto approximate_max_eigen_aat(const Matrix &matrix) -> typename Matrix::Scalar
Approximate the maximum eigenvalue of $A A^T$ for a matrix $A$.
constexpr auto weak_coeff_max_param
Coefficient (maximum parameter to be searched) / (maximum singular value or eigen value).
Namespace of regularization algorithms.
constexpr auto fista_tag
Tag of fista.
Definition fista.h:40
STL namespace.
Definition of NUM_COLLECT_PRECONDITION macro.
#define NUM_COLLECT_PRECONDITION(CONDITION,...)
Check whether a precondition is satisfied and throw an exception if not.
Definition of coefficients for regularization parameters.