/builds/MusicScience37Projects/numerical-analysis/numerical-collection-cpp/include/num_collect/regularization/explicit_l_curve.h
Line | Count | Source |
1 | | /* |
2 | | * Copyright 2021 MusicScience37 (Kenta Kabashima) |
3 | | * |
4 | | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | | * you may not use this file except in compliance with the License. |
6 | | * You may obtain a copy of the License at |
7 | | * |
8 | | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | | * |
10 | | * Unless required by applicable law or agreed to in writing, software |
11 | | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | | * See the License for the specific language governing permissions and |
14 | | * limitations under the License. |
15 | | */ |
16 | | /*! |
17 | | * \file |
18 | | * \brief Definition of explicit_l_curve class. |
19 | | */ |
20 | | #pragma once |
21 | | |
22 | | #include <cmath> |
23 | | |
24 | | #include "num_collect/opt/function_object_wrapper.h" |
25 | | #include "num_collect/opt/heuristic_global_optimizer.h" |
26 | | #include "num_collect/regularization/concepts/explicit_regularized_solver.h" |
27 | | #include "num_collect/regularization/explicit_param_searcher_base.h" |
28 | | |
29 | | namespace num_collect::regularization { |
30 | | |
31 | | /*! |
32 | | * \brief Class of objective function in L-curve. |
33 | | * |
34 | | * \warning This class is meant for use in optimizers. |
35 | | * |
36 | | * \tparam Solver Type of solvers. |
37 | | */ |
38 | | template <concepts::explicit_regularized_solver Solver> |
39 | | class explicit_l_curve_objective_function { |
40 | | public: |
41 | | //! Type of solvers. |
42 | | using solver_type = Solver; |
43 | | |
44 | | //! Type of scalars. |
45 | | using scalar_type = typename solver_type::scalar_type; |
46 | | |
47 | | /*! |
48 | | * \brief Constructor. |
49 | | * |
50 | | * \param[in] solver Solver. |
51 | | */ |
52 | | explicit explicit_l_curve_objective_function(const solver_type& solver) |
53 | 1 | : solver_(&solver) {} |
54 | | |
55 | | /*! |
56 | | * \brief Calculate the curvature of L-curve. |
57 | | * |
58 | | * \param[in] log_param Logarithm of a regularization parameter. |
59 | | * \return Negated value of the curvature. |
60 | | */ |
61 | | [[nodiscard]] auto operator()(const scalar_type& log_param) const |
62 | 36 | -> scalar_type { |
63 | 36 | using std::pow; |
64 | 36 | const scalar_type param = pow(static_cast<scalar_type>(10), // NOLINT |
65 | 36 | log_param); |
66 | 36 | return -solver_->l_curve_curvature(param); |
67 | 36 | } |
68 | | |
69 | | private: |
70 | | //! Solver. |
71 | | const solver_type* solver_; |
72 | | }; |
73 | | |
74 | | /*! |
75 | | * \brief Class to search optimal regularization parameter using l-curve. |
76 | | * |
77 | | * \tparam Solver Type of solvers. |
78 | | * \tparam Optimizer Type of optimizers. |
79 | | */ |
80 | | template <concepts::explicit_regularized_solver Solver, |
81 | | template <typename> typename Optimizer = opt::heuristic_global_optimizer> |
82 | | class explicit_l_curve |
83 | | : public explicit_param_searcher_base<explicit_l_curve<Solver, Optimizer>, |
84 | | Solver> { |
85 | | public: |
86 | | //! Type of base class. |
87 | | using base_type = |
88 | | explicit_param_searcher_base<explicit_l_curve<Solver, Optimizer>, |
89 | | Solver>; |
90 | | |
91 | | using typename base_type::data_type; |
92 | | using typename base_type::scalar_type; |
93 | | using typename base_type::solver_type; |
94 | | |
95 | | //! Type of optimizers. |
96 | | using optimizer_type = |
97 | | Optimizer<opt::function_object_wrapper<scalar_type(scalar_type), |
98 | | explicit_l_curve_objective_function<solver_type>>>; |
99 | | |
100 | | /*! |
101 | | * \brief Constructor. |
102 | | * |
103 | | * \param[in] solver Solver. |
104 | | */ |
105 | | explicit explicit_l_curve(const solver_type& solver) |
106 | 1 | : solver_(&solver), |
107 | 1 | optimizer_( |
108 | 1 | opt::make_function_object_wrapper<scalar_type(scalar_type)>( |
109 | 1 | explicit_l_curve_objective_function<solver_type>(solver))) {} |
110 | | |
111 | | //! \copydoc explicit_param_searcher_base::search |
112 | 1 | void search() { |
113 | 1 | using std::log10; |
114 | 1 | using std::pow; |
115 | 1 | const auto [min_param, max_param] = solver_->param_search_region(); |
116 | 1 | const scalar_type log_min_param = log10(min_param); |
117 | 1 | const scalar_type log_max_param = log10(max_param); |
118 | 1 | optimizer_.init(log_min_param, log_max_param); |
119 | 1 | optimizer_.solve(); |
120 | 1 | opt_param_ = pow(static_cast<scalar_type>(10), // NOLINT |
121 | 1 | optimizer_.opt_variable()); |
122 | 1 | } |
123 | | |
124 | | //! \copydoc explicit_param_searcher_base::opt_param |
125 | 1 | [[nodiscard]] auto opt_param() const -> scalar_type { return opt_param_; } |
126 | | |
127 | | //! \copydoc explicit_param_searcher_base::solve |
128 | 1 | void solve(data_type& solution) const { |
129 | 1 | solver_->solve(opt_param_, solution); |
130 | 1 | } |
131 | | |
132 | | private: |
133 | | //! Solver. |
134 | | const solver_type* solver_; |
135 | | |
136 | | //! Optimizer. |
137 | | optimizer_type optimizer_; |
138 | | |
139 | | //! Optimal regularization parameter. |
140 | | scalar_type opt_param_{}; |
141 | | }; |
142 | | |
143 | | } // namespace num_collect::regularization |