numerical-collection-cpp 0.10.0
A collection of algorithms in numerical analysis implemented in C++
Loading...
Searching...
No Matches
node_differentiator.h
Go to the documentation of this file.
1/*
2 * Copyright 2021 MusicScience37 (Kenta Kabashima)
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
20#pragma once
21
22#include <queue>
23#include <unordered_map>
24
29
31
38template <base::concepts::real_scalar Scalar>
40public:
42 using scalar_type = Scalar;
43
48
54 void compute(const node_ptr<scalar_type>& top_node) {
55 NUM_COLLECT_ASSERT(top_node);
56 list_nodes(top_node);
57 compute_coeffs(top_node);
58 }
59
66 [[nodiscard]] auto coeff(const node_ptr<scalar_type>& node) const
67 -> scalar_type {
68 if (const auto iter = info_dict_.find(node); iter != info_dict_.end()) {
69 return iter->second.diff;
70 }
71 return static_cast<scalar_type>(0);
72 }
73
74private:
76 struct node_info {
78 scalar_type diff{static_cast<scalar_type>(0)};
79
82 };
83
89 void list_nodes(const node_ptr<scalar_type>& top_node) {
90 info_dict_.clear();
91 node_queue_ = std::queue<node_ptr<scalar_type>>();
92
93 info_dict_.try_emplace(top_node);
94 node_queue_.push(top_node);
95 while (!node_queue_.empty()) {
96 const auto& node = node_queue_.front();
97 for (const auto& child_node : node->children()) {
98 const auto [iter, inserted] =
99 info_dict_.try_emplace(child_node.node());
100 if (inserted) {
102 }
103 ++iter->second.ref_count;
104 }
105 node_queue_.pop();
106 }
107 }
108
115 void compute_coeffs(const node_ptr<scalar_type>& top_node) {
116 info_dict_[top_node].diff = static_cast<scalar_type>(1);
117 node_queue_.push(top_node);
118 while (!node_queue_.empty()) {
119 const auto& node = node_queue_.front();
120 const auto& info = info_dict_[node];
121 for (const auto& child_node : node->children()) {
122 auto& child_info = info_dict_[child_node.node()];
123 child_info.diff += info.diff * child_node.sensitivity();
124 --child_info.ref_count;
125 if (child_info.ref_count == 0) {
127 }
128 }
129 node_queue_.pop();
130 }
131 }
132
134 std::unordered_map<node_ptr<scalar_type>, node_info> info_dict_{};
135
137 std::queue<node_ptr<scalar_type>> node_queue_{};
138};
139
140} // namespace num_collect::auto_diff::backward::graph
Definition of assertion macros.
#define NUM_COLLECT_ASSERT(CONDITION)
Macro to check whether a condition is satisfied.
Definition assert.h:66
Class to save information of child nodes.
Definition node.h:48
auto node() const noexcept -> const node_ptr< scalar_type > &
Get the child node.
Definition node.h:70
auto sensitivity() const noexcept -> const scalar_type &
Get the partial differential coefficient of the parent node by the child node.
Definition node.h:81
Class to compute differential coefficients for nodes in backward-mode automatic differentiation kubot...
void list_nodes(const node_ptr< scalar_type > &top_node)
List all nodes reachable from the top node and count the references to each of them.
void compute(const node_ptr< scalar_type > &top_node)
Compute differential coefficients.
std::unordered_map< node_ptr< scalar_type >, node_info > info_dict_
Dictionary of information of nodes.
auto coeff(const node_ptr< scalar_type > &node) const -> scalar_type
Get the differential coefficient of a node.
std::queue< node_ptr< scalar_type > > node_queue_
Queue of remaining nodes.
void compute_coeffs(const node_ptr< scalar_type > &top_node)
Compute the differential coefficients using the information generated by list_nodes function.
Class of nodes in graphs of automatic differentiation.
Definition node.h:99
auto children() const noexcept -> const std::vector< child_node< scalar_type > > &
Get the child nodes.
Definition node.h:118
Definition of index_type type.
Namespace of graphs in backward-mode automatic differentiation.
Definition node.h:29
std::shared_ptr< const node< Scalar > > node_ptr
Type of pointers of nodes.
Definition node.h:40
std::ptrdiff_t index_type
Type of indices in this library.
Definition index_type.h:33
Definition of node class.
Definition of real_scalar concept.