Coverage Report

Created: 2024-12-20 06:23

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/builds/MusicScience37Projects/numerical-analysis/numerical-collection-cpp/include/num_collect/opt/newton_optimizer.h
Line
Count
Source
1
/*
2
 * Copyright 2021 MusicScience37 (Kenta Kabashima)
3
 *
4
 * Licensed under the Apache License, Version 2.0 (the "License");
5
 * you may not use this file except in compliance with the License.
6
 * You may obtain a copy of the License at
7
 *
8
 *     http://www.apache.org/licenses/LICENSE-2.0
9
 *
10
 * Unless required by applicable law or agreed to in writing, software
11
 * distributed under the License is distributed on an "AS IS" BASIS,
12
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
 * See the License for the specific language governing permissions and
14
 * limitations under the License.
15
 */
16
/*!
17
 * \file
18
 * \brief Definition of newton_optimizer class.
19
 */
20
#pragma once
21
22
#include <Eigen/Cholesky>
23
24
#include "num_collect/base/index_type.h"
25
#include "num_collect/logging/iterations/iteration_logger.h"
26
#include "num_collect/logging/log_tag_view.h"
27
#include "num_collect/opt/backtracking_line_searcher.h"
28
#include "num_collect/opt/concepts/line_searcher.h"
29
#include "num_collect/opt/concepts/multi_variate_twice_differentiable_objective_function.h"
30
#include "num_collect/opt/descent_method_base.h"
31
32
namespace num_collect::opt {
33
34
//! Tag of newton_optimizer.
35
constexpr auto newton_optimizer_tag =
36
    logging::log_tag_view("num_collect::opt::newton_optimizer");
37
38
/*!
39
 * \brief Class of newton method for optimization.
40
 *
41
 * \tparam ObjectiveFunction Type of the objective function.
42
 * \tparam LineSearcher Type of class to perform line search.
43
 * \tparam HessianSolver Type of solvers of linear equation of Hessian.
44
 */
45
template <concepts::multi_variate_twice_differentiable_objective_function
46
              ObjectiveFunction,
47
    concepts::line_searcher LineSearcher =
48
        backtracking_line_searcher<ObjectiveFunction>,
49
    typename HessianSolver =
50
        Eigen::LLT<typename ObjectiveFunction::hessian_type>>
51
class newton_optimizer
52
    : public descent_method_base<
53
          newton_optimizer<ObjectiveFunction, LineSearcher, HessianSolver>,
54
          LineSearcher> {
55
public:
56
    //! This class.
57
    using this_type =
58
        newton_optimizer<ObjectiveFunction, LineSearcher, HessianSolver>;
59
60
    //! Type of base class.
61
    using base_type = descent_method_base<this_type, LineSearcher>;
62
63
    using typename base_type::objective_function_type;
64
    using typename base_type::variable_type;
65
66
    //! Type of Hessian.
67
    using hessian_type = typename objective_function_type::hessian_type;
68
69
    //! Type of solvers of linear equation of Hessian.
70
    using hessian_solver_type = HessianSolver;
71
72
    /*!
73
     * \brief Constructor.
74
     *
75
     * \param[in] obj_fun Objective function.
76
     */
77
    explicit newton_optimizer(
78
        const objective_function_type& obj_fun = objective_function_type())
79
3
        : base_type(newton_optimizer_tag, obj_fun) {}
80
81
    using base_type::evaluations;
82
    using base_type::gradient;
83
    using base_type::gradient_norm;
84
    using base_type::iterations;
85
    using base_type::line_searcher;
86
    using base_type::opt_value;
87
    using typename base_type::value_type;
88
89
    /*!
90
     * \brief Get Hessian for current optimal variable.
91
     *
92
     * \return Hessian for current optimal variable.
93
     */
94
2
    [[nodiscard]] auto hessian() const -> const hessian_type& {
95
2
        return line_searcher().obj_fun().hessian();
96
2
    }
97
98
    /*!
99
     * \copydoc num_collect::opt::descent_method_base::calc_direction
100
     */
101
2
    [[nodiscard]] auto calc_direction() -> variable_type {
102
2
        return calc_direction_impl(hessian());
103
2
    }
104
105
    /*!
106
     * \copydoc num_collect::base::iterative_solver_base::configure_iteration_logger
107
     */
108
    void configure_iteration_logger(
109
        logging::iterations::iteration_logger<this_type>& iteration_logger)
110
1
        const {
111
1
        iteration_logger.template append<index_type>(
112
1
            "Iter.", &base_type::iterations);
113
1
        iteration_logger.template append<index_type>(
114
1
            "Eval.", &base_type::evaluations);
115
1
        iteration_logger.template append<value_type>(
116
1
            "Value", &base_type::opt_value);
117
1
        iteration_logger.template append<value_type>(
118
1
            "Grad.", &base_type::gradient_norm);
119
1
    }
120
121
private:
122
    /*!
123
     * \brief Calculate search direction.
124
     *
125
     * \param[in] hessian Hessian.
126
     * \return Search direction.
127
     */
128
    [[nodiscard]] auto calc_direction_impl(const hessian_type& hessian)
129
2
        -> variable_type {
130
2
        solver_.compute(hessian);
131
2
        return -solver_.solve(gradient());
132
2
    }
133
134
    //! Solver of linear equation of Hessian.
135
    hessian_solver_type solver_;
136
};
137
138
}  // namespace num_collect::opt