Coverage Report

Created: 2024-10-11 06:23

/builds/MusicScience37Projects/numerical-analysis/numerical-collection-cpp/include/num_collect/regularization/explicit_l_curve.h
Line
Count
Source
1
/*
2
 * Copyright 2021 MusicScience37 (Kenta Kabashima)
3
 *
4
 * Licensed under the Apache License, Version 2.0 (the "License");
5
 * you may not use this file except in compliance with the License.
6
 * You may obtain a copy of the License at
7
 *
8
 *     http://www.apache.org/licenses/LICENSE-2.0
9
 *
10
 * Unless required by applicable law or agreed to in writing, software
11
 * distributed under the License is distributed on an "AS IS" BASIS,
12
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
 * See the License for the specific language governing permissions and
14
 * limitations under the License.
15
 */
16
/*!
17
 * \file
18
 * \brief Definition of explicit_l_curve class.
19
 */
20
#pragma once
21
22
#include <cmath>
23
24
#include "num_collect/opt/function_object_wrapper.h"
25
#include "num_collect/opt/heuristic_global_optimizer.h"
26
#include "num_collect/regularization/concepts/explicit_regularized_solver.h"
27
#include "num_collect/regularization/explicit_param_searcher_base.h"
28
29
namespace num_collect::regularization {
30
31
/*!
32
 * \brief Class of objective function in L-curve.
33
 *
34
 * \warning This class is meant for use in optimizers.
35
 *
36
 * \tparam Solver Type of solvers.
37
 */
38
template <concepts::explicit_regularized_solver Solver>
39
class explicit_l_curve_objective_function {
40
public:
41
    //! Type of solvers.
42
    using solver_type = Solver;
43
44
    //! Type of scalars.
45
    using scalar_type = typename solver_type::scalar_type;
46
47
    /*!
48
     * \brief Constructor.
49
     *
50
     * \param[in] solver Solver.
51
     */
52
    explicit explicit_l_curve_objective_function(const solver_type& solver)
53
1
        : solver_(&solver) {}
54
55
    /*!
56
     * \brief Calculate the curvature of L-curve.
57
     *
58
     * \param[in] log_param Logarithm of a regularization parameter.
59
     * \return Negated value of the curvature.
60
     */
61
    [[nodiscard]] auto operator()(
62
36
        const scalar_type& log_param) const -> scalar_type {
63
36
        using std::pow;
64
36
        const scalar_type param = pow(static_cast<scalar_type>(10),  // NOLINT
65
36
            log_param);
66
36
        return -solver_->l_curve_curvature(param);
67
36
    }
68
69
private:
70
    //! Solver.
71
    const solver_type* solver_;
72
};
73
74
/*!
75
 * \brief Class to search optimal regularization parameter using l-curve.
76
 *
77
 * \tparam Solver Type of solvers.
78
 * \tparam Optimizer Type of optimizers.
79
 */
80
template <concepts::explicit_regularized_solver Solver,
81
    template <typename> typename Optimizer = opt::heuristic_global_optimizer>
82
class explicit_l_curve
83
    : public explicit_param_searcher_base<explicit_l_curve<Solver, Optimizer>,
84
          Solver> {
85
public:
86
    //! Type of base class.
87
    using base_type =
88
        explicit_param_searcher_base<explicit_l_curve<Solver, Optimizer>,
89
            Solver>;
90
91
    using typename base_type::data_type;
92
    using typename base_type::scalar_type;
93
    using typename base_type::solver_type;
94
95
    //! Type of optimizers.
96
    using optimizer_type =
97
        Optimizer<opt::function_object_wrapper<scalar_type(scalar_type),
98
            explicit_l_curve_objective_function<solver_type>>>;
99
100
    /*!
101
     * \brief Constructor.
102
     *
103
     * \param[in] solver Solver.
104
     */
105
    explicit explicit_l_curve(const solver_type& solver)
106
1
        : solver_(&solver),
107
1
          optimizer_(
108
1
              opt::make_function_object_wrapper<scalar_type(scalar_type)>(
109
1
                  explicit_l_curve_objective_function<solver_type>(solver))) {}
110
111
    //! \copydoc explicit_param_searcher_base::search
112
1
    void search() {
113
1
        using std::log10;
114
1
        using std::pow;
115
1
        const auto [min_param, max_param] = solver_->param_search_region();
116
1
        const scalar_type log_min_param = log10(min_param);
117
1
        const scalar_type log_max_param = log10(max_param);
118
1
        optimizer_.init(log_min_param, log_max_param);
119
1
        optimizer_.solve();
120
1
        opt_param_ = pow(static_cast<scalar_type>(10),  // NOLINT
121
1
            optimizer_.opt_variable());
122
1
    }
123
124
    //! \copydoc explicit_param_searcher_base::opt_param
125
1
    [[nodiscard]] auto opt_param() const -> scalar_type { return opt_param_; }
126
127
    //! \copydoc explicit_param_searcher_base::solve
128
1
    void solve(data_type& solution) const {
129
1
        solver_->solve(opt_param_, solution);
130
1
    }
131
132
private:
133
    //! Solver.
134
    const solver_type* solver_;
135
136
    //! Optimizer.
137
    optimizer_type optimizer_;
138
139
    //! Optimal regularization parameter.
140
    scalar_type opt_param_{};
141
};
142
143
}  // namespace num_collect::regularization