numerical-collection-cpp 0.10.0
A collection of algorithms in numerical analysis implemented in C++
Loading...
Searching...
No Matches
gaussian_process_interpolator.h
Go to the documentation of this file.
1/*
2 * Copyright 2024 MusicScience37 (Kenta Kabashima)
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
20#pragma once
21
22#include <algorithm>
23#include <cstddef>
24#include <utility>
25#include <vector>
26
27#include <Eigen/Core>
28
38
39namespace num_collect::rbf {
40
53template <typename FunctionSignature,
54 concepts::rbf RBF =
55 rbfs::gaussian_rbf<impl::get_default_scalar_type<FunctionSignature>>,
57 concepts::distance_function DistanceFunction =
58 distance_functions::euclidean_distance_function<
61 : public global_rbf_interpolator<FunctionSignature, RBF, KernelMatrixType,
62 DistanceFunction> {
63public:
64 static_assert(KernelMatrixType == kernel_matrix_type::dense,
65 "Current implementation does not support sparse kernel matrices.");
66
68 using base_type = global_rbf_interpolator<FunctionSignature, RBF,
69 KernelMatrixType, DistanceFunction>;
70
71 using base_type::coeffs;
72 using base_type::distance_function;
73 using base_type::length_parameter_calculator;
74 using base_type::rbf;
75 using typename base_type::function_value_type;
76 using typename base_type::function_value_vector_type;
77 using typename base_type::kernel_value_type;
78 using typename base_type::variable_type;
79
89 void compute(const std::vector<variable_type>& variables,
90 const function_value_vector_type& function_values) {
91 base_type::compute(variables, function_values);
92 common_coeff_ = function_values.dot(coeffs()) /
93 static_cast<function_value_type>(function_values.size());
94 }
95
103 const variable_type& variable) const
104 -> std::pair<function_value_type, function_value_type> {
105 Eigen::VectorXd kernel_vec;
106 kernel_vec.resize(static_cast<index_type>(variables().size()));
107 for (std::size_t i = 0; i < variables().size(); ++i) {
108 kernel_vec(static_cast<index_type>(i)) =
109 rbf()(distance_function()(variable, variables()[i]) /
110 length_parameter_calculator().length_parameter_at(
111 static_cast<index_type>(i)));
112 }
113
114 const function_value_type mean = kernel_vec.dot(coeffs());
115 const function_value_type center_rbf_value =
116 rbf()(static_cast<kernel_value_type>(0));
117 const function_value_type variance = common_coeff_ *
118 std::max<function_value_type>(center_rbf_value -
119 kernel_matrix_solver().calc_reg_term(kernel_vec, reg_param),
120 static_cast<function_value_type>(0));
121 return {mean, variance};
122 }
123
124private:
125 using base_type::kernel_matrix_solver;
126 using base_type::variables;
127
129 static constexpr auto reg_param = static_cast<kernel_value_type>(0);
130
131 // TODO Implementation of regularization.
132
134 function_value_type common_coeff_{};
135};
136
137} // namespace num_collect::rbf
Class to interpolate using a Gaussian process.
static constexpr auto reg_param
Regularization parameter.
function_value_type common_coeff_
Common coefficient for the RBF kernel (scales the predicted variance).
auto evaluate_mean_and_variance_on(const variable_type &variable) const -> std::pair< function_value_type, function_value_type >
Evaluate the mean and variance of the Gaussian process at a variable.
void compute(const std::vector< variable_type > &variables, const function_value_vector_type &function_values)
Compute parameters for interpolation.
Class to interpolate using RBFs.
Definition of distance_function concept.
Definition of euclidean_distance_function class.
Definition of gaussian_rbf class.
Definition of get_default_scalar_type type.
Definition of get_variable_type class.
Definition of index_type type.
Definition of kernel_matrix_type enumeration.
std::ptrdiff_t index_type
Type of indices in this library.
Definition index_type.h:33
typename get_variable_type< FunctionSignature >::type get_variable_type_t
Get the type of variables from a function signature.
Namespace of RBF interpolation.
kernel_matrix_type
Enumeration of types of kernel matrices.
Definition of rbf concept.
Definition of rbf_interpolator class.