/builds/MusicScience37Projects/numerical-analysis/numerical-collection-cpp/include/num_collect/opt/newton_optimizer.h
Line | Count | Source |
1 | | /* |
2 | | * Copyright 2021 MusicScience37 (Kenta Kabashima) |
3 | | * |
4 | | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | | * you may not use this file except in compliance with the License. |
6 | | * You may obtain a copy of the License at |
7 | | * |
8 | | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | | * |
10 | | * Unless required by applicable law or agreed to in writing, software |
11 | | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | | * See the License for the specific language governing permissions and |
14 | | * limitations under the License. |
15 | | */ |
16 | | /*! |
17 | | * \file |
18 | | * \brief Definition of newton_optimizer class. |
19 | | */ |
20 | | #pragma once |
21 | | |
22 | | #include <Eigen/Cholesky> |
23 | | |
24 | | #include "num_collect/base/index_type.h" |
25 | | #include "num_collect/logging/iterations/iteration_logger.h" |
26 | | #include "num_collect/logging/log_tag_view.h" |
27 | | #include "num_collect/opt/backtracking_line_searcher.h" |
28 | | #include "num_collect/opt/concepts/line_searcher.h" |
29 | | #include "num_collect/opt/concepts/multi_variate_twice_differentiable_objective_function.h" |
30 | | #include "num_collect/opt/descent_method_base.h" |
31 | | |
32 | | namespace num_collect::opt { |
33 | | |
34 | | //! Log tag of newton_optimizer, passed to the base class by the constructor of newton_optimizer. |
35 | | constexpr auto newton_optimizer_tag = |
36 | | logging::log_tag_view("num_collect::opt::newton_optimizer"); |
37 | | |
38 | | /*! |
39 | | * \brief Class of Newton's method for optimization. |
40 | | * |
41 | | * \tparam ObjectiveFunction Type of the objective function (constrained by concepts::multi_variate_twice_differentiable_objective_function). |
42 | | * \tparam LineSearcher Type of class to perform line search (constrained by concepts::line_searcher; defaults to backtracking_line_searcher). |
43 | | * \tparam HessianSolver Type of solver of linear equations whose coefficient matrix is the Hessian (defaults to Eigen::LLT of the Hessian type). |
44 | | */ |
45 | | template <concepts::multi_variate_twice_differentiable_objective_function |
46 | | ObjectiveFunction, |
47 | | concepts::line_searcher LineSearcher = |
48 | | backtracking_line_searcher<ObjectiveFunction>, |
49 | | typename HessianSolver = |
50 | | Eigen::LLT<typename ObjectiveFunction::hessian_type>> |
51 | | class newton_optimizer |
52 | | : public descent_method_base< |
53 | | newton_optimizer<ObjectiveFunction, LineSearcher, HessianSolver>, |
54 | | LineSearcher> { |
55 | | public: |
56 | | //! Type of this class. |
57 | | using this_type = |
58 | | newton_optimizer<ObjectiveFunction, LineSearcher, HessianSolver>; |
59 | | |
60 | | //! Type of base class (receives this class as its first template argument). |
61 | | using base_type = descent_method_base<this_type, LineSearcher>; |
62 | | |
63 | | using typename base_type::objective_function_type; |
64 | | using typename base_type::variable_type; |
65 | | |
66 | | //! Type of Hessian, taken from the objective function type. |
67 | | using hessian_type = typename objective_function_type::hessian_type; |
68 | | |
69 | | //! Type of solver of linear equations of Hessian. |
70 | | using hessian_solver_type = HessianSolver; |
71 | | |
72 | | /*! |
73 | | * \brief Constructor, which passes newton_optimizer_tag and the objective function to the base class. |
74 | | * |
75 | | * \param[in] obj_fun Objective function (default-constructed when omitted). |
76 | | */ |
77 | | explicit newton_optimizer( |
78 | | const objective_function_type& obj_fun = objective_function_type()) |
79 | 3 | : base_type(newton_optimizer_tag, obj_fun) {} |
80 | | |
81 | | using base_type::evaluations; |
82 | | using base_type::gradient; |
83 | | using base_type::gradient_norm; |
84 | | using base_type::iterations; |
85 | | using base_type::line_searcher; |
86 | | using base_type::opt_value; |
87 | | using typename base_type::value_type; |
88 | | |
89 | | /*! |
90 | | * \brief Get Hessian for current optimal variable. |
91 | | * |
92 | | * \return Hessian for current optimal variable, read from the objective function held by the line searcher. |
93 | | */ |
94 | 2 | [[nodiscard]] auto hessian() const -> const hessian_type& { |
95 | 2 | return line_searcher().obj_fun().hessian(); |
96 | 2 | } |
97 | | |
98 | | /*! |
99 | | * \copydoc num_collect::opt::descent_method_base::calc_direction |
100 | | */ |
101 | 2 | [[nodiscard]] auto calc_direction() -> variable_type { |
102 | 2 | return calc_direction_impl(hessian()); |
103 | 2 | } |
104 | | |
105 | | /*! |
106 | | * \copydoc num_collect::base::iterative_solver_base::configure_iteration_logger |
107 | | */ |
108 | | void configure_iteration_logger( |
109 | | logging::iterations::iteration_logger<this_type>& iteration_logger) |
110 | 1 | const { |
111 | 1 | iteration_logger.template append<index_type>( |
112 | 1 | "Iter.", &base_type::iterations); |
113 | 1 | iteration_logger.template append<index_type>( |
114 | 1 | "Eval.", &base_type::evaluations); |
115 | 1 | iteration_logger.template append<value_type>( |
116 | 1 | "Value", &base_type::opt_value); |
117 | 1 | iteration_logger.template append<value_type>( |
118 | 1 | "Grad.", &base_type::gradient_norm); |
119 | 1 | } |
120 | | |
121 | | private: |
122 | | /*! |
123 | | * \brief Calculate search direction by decomposing the Hessian with the solver and solving for the current gradient. |
124 | | * |
125 | | * \param[in] hessian Hessian given to the solver as the coefficient matrix. |
126 | | * \return Search direction, which is the negated solution of the linear equation with the gradient as the right-hand side. |
127 | | */ |
128 | | [[nodiscard]] auto calc_direction_impl(const hessian_type& hessian) |
129 | 2 | -> variable_type { |
130 | 2 | solver_.compute(hessian); |
131 | 2 | return -solver_.solve(gradient()); |
132 | 2 | } |
133 | | |
134 | | //! Solver of linear equations of Hessian (compute() is called again in every call of calc_direction_impl). |
135 | | hessian_solver_type solver_; |
136 | | }; |
137 | | |
138 | | } // namespace num_collect::opt |