numerical-collection-cpp 0.10.0
A collection of algorithms in numerical analysis implemented in C++
Loading...
Searching...
No Matches
general_spline_equation_solver.h
Go to the documentation of this file.
1/*
2 * Copyright 2024 MusicScience37 (Kenta Kabashima)
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
20#pragma once
21
22#include <cmath>
23#include <limits>
24
25#include <Eigen/Core>
26#include <Eigen/Eigenvalues>
27#include <Eigen/QR>
28#include <fmt/core.h> // IWYU pragma: keep
29
36
37namespace num_collect::rbf::impl {
38
49template <typename KernelValue, typename FunctionValue,
50 kernel_matrix_type KernelMatrixType, bool UsesGlobalLengthParameter>
52
60template <base::concepts::real_scalar KernelValue, typename FunctionValue>
61class general_spline_equation_solver<KernelValue, FunctionValue,
63public:
65 using kernel_matrix_type = Eigen::MatrixX<KernelValue>;
66
68 using additional_matrix_type = Eigen::MatrixX<KernelValue>;
69
71 using vector_type = Eigen::VectorX<FunctionValue>;
72
74 using scalar_type = KernelValue;
75
78
86 void compute(const kernel_matrix_type& kernel_matrix,
87 const additional_matrix_type& additional_matrix,
88 const vector_type& data) {
89 num_variables_ = kernel_matrix.rows();
90 NUM_COLLECT_PRECONDITION(kernel_matrix.cols() == num_variables_,
91 "Kernel matrix must be a square matrix.");
92 NUM_COLLECT_PRECONDITION(additional_matrix.rows() == num_variables_,
93 "Matrix of additional terms must have the same number of rows as "
94 "the kernel matrix.");
95 num_additional_terms_ = additional_matrix.cols();
96 NUM_COLLECT_PRECONDITION(num_variables_ > num_additional_terms_,
97 "The number of variables must be larger than the number of "
98 "additional terms.");
99 kernel_subspace_dimensions_ = num_variables_ - num_additional_terms_;
100
101 qr_decomposition_.compute(additional_matrix);
102 if (qr_decomposition_.rank() != additional_matrix.cols()) {
104 "The matrix of additional terms must have full "
105 "column rank. (columns={}, rand={})",
106 additional_matrix.cols(), qr_decomposition_.rank());
107 }
108 q_matrix_ = qr_decomposition_.householderQ();
109
110 transformed_kernel_matrix_ =
111 q_matrix_.rightCols(kernel_subspace_dimensions_).transpose() *
112 kernel_matrix * q_matrix_.rightCols(kernel_subspace_dimensions_);
113
114 eigen_value_decomposition_.compute(
115 transformed_kernel_matrix_, Eigen::ComputeEigenvectors);
116 data_transformation_matrix_ =
117 eigen_value_decomposition_.eigenvectors().transpose() *
118 q_matrix_.rightCols(kernel_subspace_dimensions_).transpose();
119 spectre_ = data_transformation_matrix_ * data;
120
121 kernel_matrix_ = &kernel_matrix;
122 data_ = &data;
123 }
124
134 void solve(vector_type& kernel_coefficients,
135 vector_type& additional_coefficients, scalar_type reg_param) const {
136 NUM_COLLECT_PRECONDITION(kernel_matrix_ != nullptr && data_ != nullptr,
137 "compute() must be called before solve().");
138
139 reg_param = correct_reg_param_if_needed(reg_param);
140
141 kernel_coefficients = data_transformation_matrix_.transpose() *
142 (eigen_value_decomposition_.eigenvalues().array() + reg_param)
143 .inverse()
144 .matrix()
145 .asDiagonal() *
146 spectre_;
147
148 additional_coefficients = qr_decomposition_.solve(
149 (*data_) - (*kernel_matrix_) * kernel_coefficients);
150 }
151
161 [[nodiscard]] auto calc_mle_objective(scalar_type reg_param) const
162 -> scalar_type {
163 reg_param = correct_reg_param_if_needed(reg_param);
164
165 constexpr scalar_type limit = std::numeric_limits<scalar_type>::max() *
166 static_cast<scalar_type>(1e-20);
167 if (eigen_value_decomposition_.eigenvalues()(0) + reg_param <=
168 static_cast<scalar_type>(0)) {
169 return limit;
170 }
171
172 using std::log;
173 const scalar_type value =
174 static_cast<scalar_type>(kernel_subspace_dimensions_) *
175 log(calc_reg_term(reg_param)) +
176 calc_log_determinant(reg_param);
177 if (value < limit) {
178 return value;
179 }
180 return limit;
181 }
182
183private:
190 [[nodiscard]] auto calc_reg_term(const scalar_type& reg_param) const
191 -> scalar_type {
192 return (spectre_.array().abs2().rowwise().sum() /
193 (eigen_value_decomposition_.eigenvalues().array() + reg_param))
194 .sum();
195 }
196
204 [[nodiscard]] auto calc_log_determinant(const scalar_type& reg_param) const
205 -> scalar_type {
206 return (eigen_value_decomposition_.eigenvalues().array() + reg_param)
207 .log()
208 .sum();
209 }
210
217 [[nodiscard]] auto correct_reg_param_if_needed(
218 const scalar_type& reg_param) const noexcept -> scalar_type {
219 const scalar_type smallest_eigenvalue =
220 eigen_value_decomposition_.eigenvalues()(0);
221 const scalar_type largest_eigenvalue =
222 eigen_value_decomposition_.eigenvalues()(
223 eigen_value_decomposition_.eigenvalues().size() - 1);
224 const scalar_type eigenvalue_safe_limit =
225 largest_eigenvalue * std::numeric_limits<scalar_type>::epsilon();
226 const scalar_type reg_param_safe_limit =
227 eigenvalue_safe_limit - smallest_eigenvalue;
228
229 if (reg_param < reg_param_safe_limit) {
230 return reg_param_safe_limit;
231 }
232 return reg_param;
233 }
234
236 Eigen::SelfAdjointEigenSolver<kernel_matrix_type>
237 eigen_value_decomposition_{};
238
240 Eigen::ColPivHouseholderQR<additional_matrix_type> qr_decomposition_{};
241
244
246 kernel_matrix_type transformed_kernel_matrix_{};
247
249 kernel_matrix_type data_transformation_matrix_{};
250
252 vector_type spectre_{};
253
255 const kernel_matrix_type* kernel_matrix_{nullptr};
256
258 const vector_type* data_{nullptr};
259
261 index_type num_variables_{};
262
264 index_type num_additional_terms_{};
265
267 index_type kernel_subspace_dimensions_{};
268};
269
270} // namespace num_collect::rbf::impl
Class of exception on failure in algorithm.
Definition exception.h:93
void compute(const kernel_matrix_type &kernel_matrix, const additional_matrix_type &additional_matrix, const vector_type &data)
Compute internal matrices.
auto calc_mle_objective(scalar_type reg_param) const -> scalar_type
Calculate the maximum likelihood estimation (MLE) objective function [scheuerer2011].
void solve(vector_type &kernel_coefficients, vector_type &additional_coefficients, scalar_type reg_param) const
Solve the linear equation with a regularization parameter.
auto calc_log_determinant(const scalar_type &reg_param) const -> scalar_type
Calculate the logarithm of the determinant of the kernel matrix plus the regularization parameter.
auto correct_reg_param_if_needed(const scalar_type &reg_param) const noexcept -> scalar_type
Correct regularization parameter if needed.
auto calc_reg_term(const scalar_type &reg_param) const -> scalar_type
Calculate the regularization term.
Class to solve linear equations of kernel matrices and matrices of additional terms in RBF interpolation.
Definition of exceptions.
Definition of index_type type.
Definition of kernel_matrix_type enumeration.
Definition of macros for logging.
#define NUM_COLLECT_LOG_AND_THROW(EXCEPTION_TYPE,...)
Write an error log and throw an exception for an error.
std::ptrdiff_t index_type
Type of indices in this library.
Definition index_type.h:33
Namespace of internal implementations.
kernel_matrix_type
Enumeration of types of kernel matrices.
Definition of NUM_COLLECT_PRECONDITION macro.
#define NUM_COLLECT_PRECONDITION(CONDITION,...)
Check whether a precondition is satisfied and throw an exception if not.
Definition of real_scalar concept.