numerical-collection-cpp 0.10.0
A collection of algorithms in numerical analysis implemented in C++
Loading...
Searching...
No Matches
variable.h
Go to the documentation of this file.
1/*
2 * Copyright 2021 MusicScience37 (Kenta Kabashima)
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
20#pragma once
21
22#include <utility> // IWYU pragma: keep
23
24#include <Eigen/Core>
25
28
30
/*!
 * \brief Tag type selecting the constructor that builds a constant
 * (a value with no node in the computation graph).
 */
struct constant_tag {};

/*!
 * \brief Tag type selecting the constructor that builds a variable
 * (a value backed by a node in the computation graph).
 */
struct variable_tag {};
40
/*!
 * \brief Class of variables in backward-mode automatic differentiation
 * \cite kubota1998.
 *
 * A variable holds a value and a pointer to a node of the computation graph.
 * Constants are represented by a null node.
 *
 * \tparam Scalar Type of scalars.
 */
47template <base::concepts::real_scalar Scalar>
48class variable {
49public:
    //! Type of scalars.
51 using scalar_type = Scalar;
52
    // NOTE(review): the constructor variable(value, graph::node_ptr) that the
    // constructors below delegate to (original line 59) is not shown in this
    // listing.
61
    /*!
     * \brief Construct a constant (its node is null).
     *
     * \param[in] value Value.
     */
67 variable(const scalar_type& value, const constant_tag& /*tag*/)
68 : variable(value, nullptr) {}
69
    /*!
     * \brief Construct a variable backed by a newly created graph node.
     *
     * \param[in] value Value.
     */
75 variable(const scalar_type& value, const variable_tag& /*tag*/)
76 : variable(value, graph::create_node<scalar_type>()) {}
77
    /*!
     * \brief Construct a constant.
     *
     * Intentionally implicit: the scalar/variable operator overloads rely on
     * scalars converting to variables.
     *
     * \param[in] value Value.
     */
    // NOTE(review): the initializer of this constructor (original line 84)
    // is not shown in this listing.
83 variable(const scalar_type& value) // NOLINT: implicit conversion required
85
    /*!
     * \brief Construct the constant zero.
     */
89 variable() : variable(static_cast<scalar_type>(0)) {}
90
    /*!
     * \brief Get the value.
     *
     * \return Value.
     */
96 [[nodiscard]] auto value() const noexcept -> const scalar_type& {
97 return value_;
98 }
99
    /*!
     * \brief Get the node in the computation graph.
     *
     * \return Node (null for constants).
     */
105 [[nodiscard]] auto node() const noexcept
106 -> const graph::node_ptr<scalar_type>& {
107 return node_;
108 }
109
    /*!
     * \brief Negate this variable.
     *
     * When this variable has a node, the result's node is built from it with
     * coefficient -1; otherwise the result is a constant.
     *
     * \return Negated variable.
     */
115 auto operator-() const -> variable {
116 if (node_) {
117 return variable(-value_,
119 node_, static_cast<scalar_type>(-1)));
120 }
121 return variable(-value_, nullptr);
122 }
123
    /*!
     * \brief Add a variable to this variable.
     *
     * \param[in] right Right-hand-side variable.
     * \return This.
     */
130 auto operator+=(const variable& right) -> variable& {
        // Self-addition (x += x): the result 2x gets a node built from the old
        // node with coefficient 2.
131 if (this == &right) {
133 node_, static_cast<scalar_type>(2));
134 value_ += right.value_;
135 return *this;
136 }
        // Both have nodes: combine them with coefficients 1 and 1.
        // Only right has a node: reuse it as-is. Only this has a node: keep it.
137 if (node_) {
138 if (right.node_) {
140 static_cast<scalar_type>(1), right.node_,
141 static_cast<scalar_type>(1));
142 }
143 } else {
144 node_ = right.node_;
145 }
146 value_ += right.value_;
147 return *this;
148 }
149
    /*!
     * \brief Subtract a variable from this variable.
     *
     * \param[in] right Right-hand-side variable.
     * \return This.
     */
156 auto operator-=(const variable& right) -> variable& {
        // x -= x: x - x is identically zero, so the result is treated as a
        // constant and the node is dropped.
157 if (this == &right) {
158 value_ -= right.value_;
159 node_.reset();
160 return *this;
161 }
        // Coefficient 1 for this node and -1 for right's node; a constant minus
        // a variable gets a node on right's node with coefficient -1 only.
162 if (node_) {
163 if (right.node_) {
165 static_cast<scalar_type>(1), right.node_,
166 static_cast<scalar_type>(-1));
167 }
168 } else {
169 if (right.node_) {
171 right.node_, static_cast<scalar_type>(-1));
172 }
173 }
174 value_ -= right.value_;
175 return *this;
176 }
177
    /*!
     * \brief Multiply this variable by a variable.
     *
     * \param[in] right Right-hand-side variable.
     * \return This.
     */
184 auto operator*=(const variable& right) -> variable& {
        // x *= x: d(x^2)/dx = 2x. value_ still holds the old value when the
        // coefficient is computed, because it is updated afterwards.
185 if (this == &right) {
187 node_, static_cast<scalar_type>(2) * value_);
188 value_ *= right.value_;
189 return *this;
190 }
        // Product rule d(xy) = y dx + x dy, evaluated with the values before
        // the update at the end of this function.
191 if (node_) {
192 if (right.node_) {
194 node_, right.value_, right.node_, value_);
195 } else {
            // NOTE(review): this branch's statement (original line 196) is
            // missing from this listing; for a constant right-hand side the
            // node presumably gets coefficient right.value_ — confirm.
197 }
198 } else {
199 if (right.node_) {
            // NOTE(review): original line 200 is missing from this listing;
            // presumably a node on right.node_ with coefficient value_ —
            // confirm.
201 }
202 }
203 value_ *= right.value_;
204 return *this;
205 }
206
    /*!
     * \brief Divide this variable by a variable.
     *
     * \param[in] right Right-hand-side variable.
     * \return This.
     */
213 auto operator/=(const variable& right) -> variable& {
        // x /= x: x / x is identically one, so the result is treated as a
        // constant and the node is dropped.
214 if (this == &right) {
215 value_ /= right.value_;
216 node_.reset();
217 return *this;
218 }
        // value_ is updated first, so below value_ == left / right and
        // -value_ / right.value_ == -left / right^2 (derivative w.r.t. right);
        // the derivative w.r.t. left is 1 / right.
219 value_ /= right.value_;
220 if (node_) {
221 if (right.node_) {
223 static_cast<scalar_type>(1) / right.value_, right.node_,
224 -value_ / right.value_);
225 } else {
227 node_, static_cast<scalar_type>(1) / right.value_);
228 }
229 } else {
230 if (right.node_) {
232 right.node_, -value_ / right.value_);
233 }
234 }
235 return *this;
236 }
237
238private:
    // NOTE(review): the data members value_ and node_ (read by value() and
    // node() above) are declared in lines not shown in this listing.
241
244};
245
254template <typename Scalar>
255[[nodiscard]] inline auto operator+(const variable<Scalar>& left,
256 const variable<Scalar>& right) -> variable<Scalar> {
257 return variable<Scalar>(left) += right;
258}
259
268template <typename Scalar>
269[[nodiscard]] inline auto operator+(
270 const Scalar& left, const variable<Scalar>& right) -> variable<Scalar> {
271 return variable<Scalar>(left) += right;
272}
273
282template <typename Scalar>
283[[nodiscard]] inline auto operator+(
284 const variable<Scalar>& left, const Scalar& right) -> variable<Scalar> {
285 return variable<Scalar>(left) += right;
286}
287
296template <typename Scalar>
297[[nodiscard]] inline auto operator-(const variable<Scalar>& left,
298 const variable<Scalar>& right) -> variable<Scalar> {
299 return variable<Scalar>(left) -= right;
300}
301
310template <typename Scalar>
311[[nodiscard]] inline auto operator-(
312 const Scalar& left, const variable<Scalar>& right) -> variable<Scalar> {
313 return variable<Scalar>(left) -= right;
314}
315
324template <typename Scalar>
325[[nodiscard]] inline auto operator-(
326 const variable<Scalar>& left, const Scalar& right) -> variable<Scalar> {
327 return variable<Scalar>(left) -= right;
328}
329
338template <typename Scalar>
339[[nodiscard]] inline auto operator*(const variable<Scalar>& left,
340 const variable<Scalar>& right) -> variable<Scalar> {
341 return variable<Scalar>(left) *= right;
342}
343
352template <typename Scalar>
353[[nodiscard]] inline auto operator*(
354 const Scalar& left, const variable<Scalar>& right) -> variable<Scalar> {
355 return variable<Scalar>(left) *= right;
356}
357
366template <typename Scalar>
367[[nodiscard]] inline auto operator*(
368 const variable<Scalar>& left, const Scalar& right) -> variable<Scalar> {
369 return variable<Scalar>(left) *= right;
370}
371
380template <typename Scalar>
381[[nodiscard]] inline auto operator/(const variable<Scalar>& left,
382 const variable<Scalar>& right) -> variable<Scalar> {
383 return variable<Scalar>(left) /= right;
384}
385
394template <typename Scalar>
395[[nodiscard]] inline auto operator/(
396 const Scalar& left, const variable<Scalar>& right) -> variable<Scalar> {
397 return variable<Scalar>(left) /= right;
398}
399
408template <typename Scalar>
409[[nodiscard]] inline auto operator/(
410 const variable<Scalar>& left, const Scalar& right) -> variable<Scalar> {
411 return variable<Scalar>(left) /= right;
412}
413
414} // namespace num_collect::auto_diff::backward
415
416namespace Eigen {
417
/*!
 * \brief Specialization of Eigen::NumTraits so that Eigen matrices and vectors
 * can use num_collect::auto_diff::backward::variable as their scalar type.
 *
 * Numeric limits are forwarded to NumTraits<Scalar>.
 *
 * \tparam Scalar Type of scalars wrapped by the variable.
 */
429template <typename Scalar>
430struct NumTraits<num_collect::auto_diff::backward::variable<Scalar>> {
    // NOTE(review): the declaration of `Real` (and possibly other aliases,
    // original lines 431-437) is not shown in this listing; Real is presumably
    // the variable type itself — confirm against the source.
433
436
    //! Type of literals.
438 using Literal = Real;
439
    //! Nested type required by Eigen's NumTraits interface.
441 using Nested = Real;
442
    //! Properties and relative cost estimates of this type, as used by Eigen.
443 enum { // NOLINT(performance-enum-size): Preserve the same implementation as Eigen library.
    //! Not an integer type.
445 IsInteger = 0, // NOLINT
446
    //! Signed type.
448 IsSigned = 1, // NOLINT
449
    //! Not a complex type.
451 IsComplex = 0, // NOLINT
452
    //! Elements must be constructed: a variable holds its graph node through
    //! graph::node_ptr (a std::shared_ptr), so it is not trivially initialized.
454 RequireInitialization = 1, // NOLINT
455
    //! Relative cost estimate of reading a value.
457 ReadCost = 1, // NOLINT
458
    //! Relative cost estimate of an addition.
460 AddCost = 2, // NOLINT
461
    //! Relative cost estimate of a multiplication.
463 MulCost = 4 // NOLINT
464 };
465
    /*!
     * \brief Get machine epsilon.
     *
     * \return NumTraits<Scalar>::epsilon(), converted to Real.
     */
471 static constexpr auto epsilon() -> Real {
472 return NumTraits<Scalar>::epsilon();
473 }
474
    /*!
     * \brief Get the dummy precision used by Eigen for fuzzy comparisons.
     *
     * \return NumTraits<Scalar>::dummy_precision(), converted to Real.
     */
480 static constexpr auto dummy_precision() -> Real {
481 return NumTraits<Scalar>::dummy_precision();
482 }
483
    /*!
     * \brief Get the highest value.
     *
     * \return NumTraits<Scalar>::highest(), converted to Real.
     */
489 static constexpr auto highest() -> Real {
490 return NumTraits<Scalar>::highest();
491 }
492
    /*!
     * \brief Get the lowest value.
     *
     * \return NumTraits<Scalar>::lowest(), converted to Real.
     */
498 static constexpr auto lowest() -> Real {
499 return NumTraits<Scalar>::lowest();
500 }
501
    /*!
     * \brief Get the number of decimal digits.
     *
     * \return NumTraits<Scalar>::digits10().
     */
507 static constexpr auto digits10() -> int {
508 return NumTraits<Scalar>::digits10();
509 }
510
    /*!
     * \brief Get the infinity.
     *
     * \return NumTraits<Scalar>::infinity(), converted to Real.
     */
516 static constexpr auto infinity() -> Real {
517 return NumTraits<Scalar>::infinity();
518 }
519
    /*!
     * \brief Get the quiet NaN value.
     *
     * \return NumTraits<Scalar>::quiet_NaN(), converted to Real.
     */
525 static constexpr auto quiet_NaN() -> Real { // NOLINT
526 return NumTraits<Scalar>::quiet_NaN();
527 }
528};
529
530} // namespace Eigen
Class of variables in backward-mode automatic differentiation [kubota1998].
Definition variable.h:48
auto operator/=(const variable &right) -> variable &
Divide this variable by a variable.
Definition variable.h:213
auto operator-() const -> variable
Negate this variable.
Definition variable.h:115
auto operator*=(const variable &right) -> variable &
Multiply this variable by a variable.
Definition variable.h:184
graph::node_ptr< scalar_type > node_
Node.
Definition variable.h:243
auto operator-=(const variable &right) -> variable &
Subtract a variable from this variable.
Definition variable.h:156
auto operator+=(const variable &right) -> variable &
Add a variable to this variable.
Definition variable.h:130
variable(const scalar_type &value, const variable_tag &)
Construct variables.
Definition variable.h:75
auto node() const noexcept -> const graph::node_ptr< scalar_type > &
Get the node.
Definition variable.h:105
variable(const scalar_type &value, const constant_tag &)
Construct constants.
Definition variable.h:67
variable(const scalar_type &value)
Construct constants.
Definition variable.h:83
auto value() const noexcept -> const scalar_type &
Get the value.
Definition variable.h:96
variable(const scalar_type &value, graph::node_ptr< scalar_type > node)
Constructor.
Definition variable.h:59
Namespace of Eigen library.
Definition variable.h:416
std::shared_ptr< const node< Scalar > > node_ptr
Type of pointers to nodes.
Definition node.h:40
auto create_node(Args &&... args) -> node_ptr< Scalar >
Create a node.
Definition node.h:174
Namespace of backward-mode automatic differentiation.
auto operator*(const variable< Scalar > &left, const variable< Scalar > &right) -> variable< Scalar >
Multiply two variables.
Definition variable.h:339
auto operator+(const variable< Scalar > &left, const variable< Scalar > &right) -> variable< Scalar >
Add two variables.
Definition variable.h:255
auto operator-(const variable< Scalar > &left, const variable< Scalar > &right) -> variable< Scalar >
Subtract a variable from another variable.
Definition variable.h:297
auto operator/(const variable< Scalar > &left, const variable< Scalar > &right) -> variable< Scalar >
Divide a variable by another variable.
Definition variable.h:381
Namespace of num_collect source code.
STL namespace.
Definition of node class.
Definition of real_scalar concept.
static constexpr auto infinity() -> Real
Get the infinity.
Definition variable.h:516
static constexpr auto highest() -> Real
Get the highest value.
Definition variable.h:489
static constexpr auto digits10() -> int
Get the number of decimal digits.
Definition variable.h:507
static constexpr auto epsilon() -> Real
Get machine epsilon.
Definition variable.h:471
static constexpr auto lowest() -> Real
Get the lowest value.
Definition variable.h:498
static constexpr auto dummy_precision() -> Real
Get dummy precision.
Definition variable.h:480
static constexpr auto quiet_NaN() -> Real
Get the quiet NaN value.
Definition variable.h:525
Tag class to specify constants.
Definition variable.h:34
Tag class to specify variables.
Definition variable.h:39