00001 #ifndef VIENNAMATH_RUNTIME_BINARY_EXPRESSION_HPP
00002 #define VIENNAMATH_RUNTIME_BINARY_EXPRESSION_HPP
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020 #include <ostream>
00021 #include <sstream>
00022 #include <memory>
00023 #include "viennamath/forwards.h"
00024 #include "viennamath/runtime/constant.hpp"
00025 #include "viennamath/runtime/binary_operators.hpp"
00026 #include "viennamath/compiletime/ct_binary_expr.hpp"
00027 #include "viennamath/runtime/unary_expr.hpp"
00028 #include "viennamath/runtime/op_interface.hpp"
00029 #include "viennamath/runtime/expression_interface.hpp"
00030 #include "viennamath/runtime/unary_operators.hpp"
00031
00036 namespace viennamath
00037 {
00038
00043 template <typename InterfaceType >
00044 class rt_binary_expr : public InterfaceType
00045 {
00046 typedef op_interface<InterfaceType> op_interface_type;
00047 typedef op_unary<op_id<typename InterfaceType::numeric_type>, InterfaceType> op_unary_id_type;
00048
00049 typedef rt_binary_expr<InterfaceType> self_type;
00050
00051 public:
00052 typedef typename InterfaceType::numeric_type numeric_type;
00053
00054
00055
00056 rt_binary_expr() {}
00057
00058 explicit rt_binary_expr(InterfaceType * lhs,
00059 op_interface_type * op,
00060 InterfaceType * rhs) : lhs_(lhs),
00061 op_(op),
00062 rhs_(rhs) {}
00063
00064 template <typename LHS, typename OP, typename RHS>
00065 explicit rt_binary_expr(ct_binary_expr<LHS, OP, RHS> const & other) : op_(new op_binary<OP, InterfaceType>())
00066 {
00067
00068 lhs_ = std::auto_ptr<InterfaceType>(new rt_binary_expr<InterfaceType>(other.lhs()));
00069 rhs_ = std::auto_ptr<InterfaceType>(new rt_binary_expr<InterfaceType>(other.rhs()));
00070 }
00071
00073 template <typename LHS, typename OP, long value>
00074 explicit rt_binary_expr(ct_binary_expr<LHS, OP, ct_constant<value> > const & other) : op_(new op_binary<OP, InterfaceType>())
00075 {
00076
00077 lhs_ = std::auto_ptr<InterfaceType>(new rt_binary_expr<InterfaceType>(other.lhs()));
00078 rhs_ = std::auto_ptr<InterfaceType>(new rt_constant<numeric_type, InterfaceType>(value));
00079 }
00080
00081 template <long value, typename OP, typename RHS>
00082 explicit rt_binary_expr(ct_binary_expr<ct_constant<value>, OP, RHS > const & other) : op_(new op_binary<OP, InterfaceType>())
00083 {
00084
00085 lhs_ = std::auto_ptr<InterfaceType>(new rt_constant<numeric_type, InterfaceType>(value));
00086 rhs_ = std::auto_ptr<InterfaceType>(new rt_binary_expr<InterfaceType>(other.rhs()));
00087 }
00088
00089 template <long value1, typename OP, long value2>
00090 explicit rt_binary_expr(ct_binary_expr<ct_constant<value1>, OP, ct_constant<value2> > const & other) : op_(new op_binary<OP, InterfaceType>())
00091 {
00092
00093 lhs_ = std::auto_ptr<InterfaceType>(new rt_constant<numeric_type, InterfaceType>(value1));
00094 rhs_ = std::auto_ptr<InterfaceType>(new rt_constant<numeric_type, InterfaceType>(value2));
00095 }
00096
00098 template <typename LHS, typename OP, id_type id>
00099 explicit rt_binary_expr(ct_binary_expr<LHS, OP, ct_variable<id> > const & other) : op_(new op_binary<OP, InterfaceType>())
00100 {
00101
00102 lhs_ = std::auto_ptr<InterfaceType>(new rt_binary_expr<InterfaceType>(other.lhs()));
00103 rhs_ = std::auto_ptr<InterfaceType>(new rt_variable<InterfaceType>(id));
00104 }
00105
00106 template <id_type id, typename OP, typename RHS>
00107 explicit rt_binary_expr(ct_binary_expr<ct_variable<id>, OP, RHS > const & other) : op_(new op_binary<OP, InterfaceType>())
00108 {
00109
00110 lhs_ = std::auto_ptr<InterfaceType>(new rt_variable<InterfaceType>(id));
00111 rhs_ = std::auto_ptr<InterfaceType>(new rt_binary_expr<InterfaceType>(other.rhs()));
00112 }
00113
00114 template <id_type id1, typename OP, id_type id2>
00115 explicit rt_binary_expr(ct_binary_expr<ct_variable<id1>, OP, ct_variable<id2> > const & other) : op_(new op_binary<OP, InterfaceType>())
00116 {
00117
00118 lhs_ = std::auto_ptr<InterfaceType>(new rt_variable<InterfaceType>(id1));
00119 rhs_ = std::auto_ptr<InterfaceType>(new rt_variable<InterfaceType>(id2));
00120 }
00121
00122
00123 template <id_type id, typename OP, long value>
00124 explicit rt_binary_expr(ct_binary_expr<ct_variable<id>, OP, ct_constant<value> > const & other) : op_(new op_binary<OP, InterfaceType>())
00125 {
00126
00127 lhs_ = std::auto_ptr<InterfaceType>(new rt_variable<InterfaceType>(id));
00128 rhs_ = std::auto_ptr<InterfaceType>(new rt_constant<numeric_type, InterfaceType>(value));
00129 }
00130
00131 template <id_type id, typename OP, long value>
00132 explicit rt_binary_expr(ct_binary_expr<ct_constant<value>, OP, ct_variable<id> > const & other) : op_(new op_binary<OP, InterfaceType>())
00133 {
00134
00135 lhs_ = std::auto_ptr<InterfaceType>(new rt_constant<numeric_type, InterfaceType>(value));
00136 rhs_ = std::auto_ptr<InterfaceType>(new rt_variable<InterfaceType>(id));
00137 }
00138
00139
00140
00141
00142
00143
00144
00145
00146
00147
00148
00149
00150
00151
00152
00153
00154
00155 rt_binary_expr(binary_expr const & other) : lhs_(other.lhs_->clone()),
00156 op_(other.op_->clone()),
00157 rhs_(other.rhs_->clone()) {}
00158
00159
00160 template <typename LHS, typename OP, typename RHS>
00161 rt_binary_expr & operator=(ct_binary_expr<LHS, OP, RHS> const & other)
00162 {
00163 lhs_ = std::auto_ptr<InterfaceType>(new rt_binary_expr<InterfaceType>(other.lhs()));
00164 op_ = std::auto_ptr<op_interface_type>(new op_binary<OP, InterfaceType>());
00165 rhs_ = std::auto_ptr<InterfaceType>(new rt_binary_expr<InterfaceType>(other.rhs()));
00166 return *this;
00167 }
00168
00169 template <typename LHS, typename OP, long value>
00170 rt_binary_expr & operator=(ct_binary_expr<LHS, OP, ct_constant<value> > const & other)
00171 {
00172
00173 lhs_ = std::auto_ptr<InterfaceType>(new rt_binary_expr<InterfaceType>(other.lhs()));
00174 op_ = std::auto_ptr<op_interface_type>(new op_binary<OP, InterfaceType>());
00175 rhs_ = std::auto_ptr<InterfaceType>(new rt_constant<numeric_type, InterfaceType>(value));
00176 return *this;
00177 }
00178
00179 template <long value, typename OP, typename RHS>
00180 rt_binary_expr & operator=(ct_binary_expr<ct_constant<value>, OP, RHS > const & other)
00181 {
00182
00183 lhs_ = std::auto_ptr<InterfaceType>(new rt_constant<numeric_type, InterfaceType>(value));
00184 op_ = std::auto_ptr<op_interface_type>(new op_binary<OP, InterfaceType>());
00185 rhs_ = std::auto_ptr<InterfaceType>(new rt_binary_expr<InterfaceType>(other.rhs()));
00186 return *this;
00187 }
00188
00189 template <long value1, typename OP, long value2>
00190 rt_binary_expr & operator=(ct_binary_expr<ct_constant<value1>, OP, ct_constant<value2> > const & other)
00191 {
00192
00193 lhs_ = std::auto_ptr<InterfaceType>(new rt_constant<numeric_type, InterfaceType>(OP().apply(value1, value2)));
00194 op_ = std::auto_ptr<op_interface_type>(new op_unary_id_type());
00195 rhs_ = std::auto_ptr<InterfaceType>(new rt_constant<numeric_type, InterfaceType>(OP().apply(value1, value2)));
00196 return *this;
00197 }
00198
00199
00200 rt_binary_expr & operator=(rt_binary_expr const & other)
00201 {
00202 lhs_ = std::auto_ptr<InterfaceType>(other.lhs()->clone());
00203 op_ = std::auto_ptr<op_interface_type>(other.op()->clone());
00204 rhs_ = std::auto_ptr<InterfaceType>(other.rhs()->clone());
00205 return *this;
00206 }
00207
00208 template <typename ScalarType>
00209 rt_binary_expr & operator=(rt_constant<ScalarType> const & other)
00210 {
00211 lhs_ = std::auto_ptr<InterfaceType>(other.clone());
00212 op_ = std::auto_ptr<op_interface_type>(new op_unary_id_type());
00213 rhs_ = std::auto_ptr<InterfaceType>(other.clone());
00214 return *this;
00215 }
00216
00217 template <long value>
00218 rt_binary_expr & operator=(ct_constant<value> const & other)
00219 {
00220 return *this = value;
00221 }
00222
00223 rt_binary_expr & operator=(numeric_type value)
00224 {
00225 lhs_ = std::auto_ptr<InterfaceType>(new rt_constant<numeric_type, InterfaceType>(value));
00226 op_ = std::auto_ptr<op_interface_type>(new op_unary_id_type());
00227 rhs_ = std::auto_ptr<InterfaceType>(new rt_constant<numeric_type, InterfaceType>(value));
00228 return *this;
00229 }
00230
00231 const InterfaceType * lhs() const { return lhs_.get(); }
00232 const op_interface_type * op() const { return op_.get(); }
00233 const InterfaceType * rhs() const { return rhs_.get(); }
00234
00236
00237
00238 numeric_type operator()(numeric_type val) const
00239 {
00240 return this->eval(val);
00241 }
00242
00243 template <typename ScalarType>
00244 numeric_type operator()(rt_constant<ScalarType> val) const
00245 {
00246 return this->eval(static_cast<numeric_type>(val));
00247 }
00248
00249 template <long value>
00250 numeric_type operator()(ct_constant<value> val) const
00251 {
00252 return this->eval(value);
00253 }
00254
00255 template <typename VectorType>
00256 numeric_type operator()(VectorType const & v) const
00257 {
00258 std::vector<double> stl_v(v.size());
00259 for (std::size_t i=0; i<v.size(); ++i)
00260 stl_v[i] = v[i];
00261
00262 return this->eval(stl_v);
00263 }
00264
00265 numeric_type operator()(std::vector<numeric_type> const & stl_v) const
00266 {
00267 return this->eval(stl_v);
00268 }
00269
00270 template <typename T0>
00271 numeric_type operator()(viennamath::ct_vector_1<T0> const & v) const
00272 {
00273 std::vector<double> stl_v(1);
00274 stl_v[0] = v[ct_index<0>()];
00275 return this->eval(stl_v);
00276 }
00277
00278 template <typename T0, typename T1>
00279 numeric_type operator()(viennamath::ct_vector_2<T0, T1> const & v) const
00280 {
00281 std::vector<double> stl_v(2);
00282 stl_v[0] = v[ct_index<0>()];
00283 stl_v[1] = v[ct_index<1>()];
00284 return this->eval(stl_v);
00285 }
00286
00287 template <typename T0, typename T1, typename T2>
00288 numeric_type operator()(viennamath::ct_vector_3<T0, T1, T2> const & v) const
00289 {
00290 std::vector<double> stl_v(3);
00291 stl_v[0] = v[ct_index<0>()];
00292 stl_v[1] = v[ct_index<1>()];
00293 stl_v[2] = v[ct_index<2>()];
00294 return this->eval(stl_v);
00295 }
00296
00297
00298 numeric_type eval(std::vector<double> const & v) const
00299 {
00300 return op_->apply(lhs_.get()->eval(v), rhs_.get()->eval(v));
00301 }
00302
00303 numeric_type eval(numeric_type val) const
00304 {
00305 return op_->apply(lhs_.get()->eval(val), rhs_.get()->eval(val));
00306 }
00307
00309
00310
00312 InterfaceType * simplify() const
00313 {
00314 if (lhs_->is_constant() && rhs_->is_constant())
00315 return new rt_constant<numeric_type, InterfaceType>( unwrap() );
00316
00317
00318 return op_->simplify(lhs_.get(), rhs_.get());
00319 }
00320
00322 bool can_simplify() const
00323 {
00324 if (lhs_->is_constant() && rhs_->is_constant())
00325 {
00326
00327 return true;
00328 }
00329 return op_->can_simplify(lhs_.get(), rhs_.get());
00330 }
00331
00333
00334 InterfaceType * clone() const { return new rt_binary_expr(lhs_->clone(), op_->clone(), rhs_->clone()); }
00335
00337 std::string deep_str() const
00338 {
00339 std::stringstream ss;
00340 ss << "(";
00341 ss << lhs_->deep_str();
00342 ss << op_->str();
00343 ss << rhs_->deep_str();
00344 ss << ")";
00345 return ss.str();
00346 }
00347
00349 std::string shallow_str() const
00350 {
00351 return std::string("binary_expr");
00352 }
00353
00355 bool is_unary() const { return false; }
00356
00358 numeric_type unwrap() const
00359 {
00360
00361
00362 return op_->apply(lhs_->unwrap(), rhs_->unwrap());
00363 }
00364
00366 bool is_constant() const { return lhs_->is_constant() && rhs_->is_constant(); };
00367
00369 InterfaceType * substitute(const InterfaceType * e,
00370 const InterfaceType * repl) const
00371 {
00372 if (deep_equal(e))
00373 return repl->clone();
00374
00375 return new rt_binary_expr(lhs_->substitute(e, repl),
00376 op_->clone(),
00377 rhs_->substitute(e, repl) );
00378 };
00379
00381 InterfaceType * substitute(std::vector<const InterfaceType *> const & e,
00382 std::vector<const InterfaceType *> const & repl) const
00383 {
00384 for (size_t i=0; i<e.size(); ++i)
00385 if (deep_equal(e[i]))
00386 return repl[i]->clone();
00387
00388 return new rt_binary_expr(lhs_->substitute(e, repl),
00389 op_->clone(),
00390 rhs_->substitute(e, repl) );
00391 };
00392
00393
00395 bool deep_equal(const InterfaceType * other) const
00396 {
00397 if (dynamic_cast< const rt_binary_expr * >(other) != NULL)
00398 {
00399 const rt_binary_expr * temp = dynamic_cast< const rt_binary_expr * >(other);
00400 return lhs_->deep_equal(temp->lhs())
00401 && op_->equal(temp->op())
00402 && rhs_->deep_equal(temp->rhs());
00403 }
00404 return false;
00405
00406 }
00407
00409 bool shallow_equal(const InterfaceType * other) const
00410 {
00411 return dynamic_cast< const self_type * >(other) != NULL;
00412 }
00413
00415 InterfaceType * diff(const InterfaceType * diff_var) const
00416 {
00417 return op_->diff(lhs_.get(), rhs_.get(), diff_var);
00418 }
00419
00421 InterfaceType * recursive_manipulation(rt_manipulation_wrapper<InterfaceType> const & fw) const
00422 {
00423 if (fw.modifies(this))
00424 return fw(this);
00425
00426 return new rt_binary_expr(lhs_->recursive_manipulation(fw),
00427 op_->clone(),
00428 rhs_->recursive_manipulation(fw) );
00429 }
00430
00432 void recursive_traversal(rt_traversal_wrapper<InterfaceType> const & fw) const
00433 {
00434 if (fw.step_into(this))
00435 {
00436 lhs_->recursive_traversal(fw);
00437 fw(this);
00438 rhs_->recursive_traversal(fw);
00439 }
00440 else
00441 fw(this);
00442 }
00443
00444 private:
00446 std::auto_ptr<InterfaceType> lhs_;
00448 std::auto_ptr<op_interface_type> op_;
00450 std::auto_ptr<InterfaceType> rhs_;
00451 };
00452
00453
00455 template <typename InterfaceType>
00456 std::ostream& operator<<(std::ostream & stream, rt_binary_expr<InterfaceType> const & e)
00457 {
00458 stream << "expr"
00459 << e.deep_str()
00460 << "";
00461 return stream;
00462 }
00463
00464 template <typename T, typename InterfaceType>
00465 InterfaceType * op_unary<T, InterfaceType>::simplify(const InterfaceType * lhs,
00466 const InterfaceType * rhs) const
00467 {
00468 return new rt_binary_expr<InterfaceType>(lhs->clone(),
00469 new op_unary<T, InterfaceType>(),
00470 rhs->clone());
00471 }
00472
00473 }
00474
00475 #endif