• Main Page
  • Namespaces
  • Data Structures
  • Files
  • File List

/export/development/ViennaMath/viennamath/runtime/variable.hpp

Go to the documentation of this file.
00001 #ifndef VIENNAMATH_RUNTIME_VARIABLE_HPP
00002 #define VIENNAMATH_RUNTIME_VARIABLE_HPP
00003 
00004 /* =======================================================================
00005    Copyright (c) 2012, Institute for Microelectronics,
00006                        Institute for Analysis and Scientific Computing,
00007                        TU Wien.
00008                              -----------------
00009                ViennaMath - Symbolic and Numerical Math in C++
00010                              -----------------
00011 
00012    Author:     Karl Rupp                          rupp@iue.tuwien.ac.at
00013 
00014    License:    MIT (X11), see file LICENSE in the ViennaMath base directory
00015 ======================================================================= */
00016 
00017 
00018 
00019 
00020 #include "viennamath/compiletime/ct_vector.hpp"
00021 #include "viennamath/exception.hpp"
00022 
00023 #include "viennamath/runtime/constant.hpp"
00024 #include "viennamath/runtime/expression_interface.hpp"
00025 
00026 #include <assert.h>
00027 
00033 namespace viennamath
00034 {
00035 
00036   
00041   template <typename VectorType>
00042   typename VectorType::value_type get_from_vector(VectorType const & vec, id_type i)
00043   {
00044     return vec[i];
00045   }
00046   
00051   template <typename T1>
00052   default_numeric_type get_from_vector(ct_vector_1<T1> const & vec, id_type i)
00053   {
00054     return vec[ct_index<0>()]; 
00055   }
00056   
00061   template <typename T1, typename T2>
00062   default_numeric_type get_from_vector(ct_vector_2<T1, T2> const & vec, id_type i)
00063   {
00064     if (i == 0)
00065       return vec[ct_index<0>()];
00066     else if (i == 1)
00067       return vec[ct_index<1>()];
00068     return 0;
00069   }
00070 
00075   template <typename T1, typename T2, typename T3>
00076   default_numeric_type get_from_vector(ct_vector_3<T1, T2, T3> const & vec, id_type i)
00077   {
00078     if (i == 0)
00079       return vec[ct_index<0>()];
00080     else if (i == 1)
00081       return vec[ct_index<1>()];
00082     else if (i == 2)
00083       return vec[ct_index<2>()];
00084     return 0;
00085   }
00086 
00087   
00093   template <typename InterfaceType /* see forwards.h for default argument */>
00094   class rt_variable : public InterfaceType
00095   {
00096       typedef rt_variable<InterfaceType>                 self_type;
00097     
00098     public:
00099       typedef typename InterfaceType::numeric_type    numeric_type;
00100       
00101       explicit rt_variable(id_type my_id) : id_(my_id) {};
00102 
00103       id_type id() const { return id_; }
00104       
00106       
00108       numeric_type operator()(numeric_type value) const
00109       {
00110         assert(id_ == 0 && "Evaluation of variable with nonzero index by a scalar attempted!");
00111         
00112         return value;
00113       }
00114 
00116       template <typename ScalarType>
00117       rt_constant<ScalarType> operator()(rt_constant<ScalarType> const & other) const
00118       {
00119         if (id_ > 0)
00120           throw variable_index_out_of_bounds_exception(id_, 0);
00121         return rt_constant<ScalarType>(static_cast<ScalarType>(other));
00122       }
00123 
00125       template <long value>
00126       long operator()(ct_constant<value> const & other) const
00127       {
00128         if (id_ > 0)
00129           throw variable_index_out_of_bounds_exception(id_, 0);
00130         return value;
00131       }
00132 
00133       //Vector argument (can be of type std::vector)
00134       
00136       template <typename VectorType>
00137       numeric_type 
00138       operator()(VectorType const & v) const
00139       {
00140         if(id_ >= v.size())
00141           throw variable_index_out_of_bounds_exception(id_, v.size());
00142         return get_from_vector(v, id_);
00143       }
00144       
00145       //
00146       // interface requirements:
00147       //
00148       
00151       InterfaceType * clone() const { return new variable(id_); }
00152       
00154       numeric_type eval(std::vector<double> const & v) const
00155       {
00156         if (id_ >= v.size())
00157           throw variable_index_out_of_bounds_exception(id_, v.size());
00158         
00159         return (*this)(v);
00160       }
00161       
00163       numeric_type eval(numeric_type val) const
00164       {
00165         if (id_ > 0)
00166           throw variable_index_out_of_bounds_exception(id_, 1);
00167         
00168         return val;
00169       }
00170       
00172       std::string deep_str() const
00173       {
00174         std::stringstream ss;
00175         ss << "variable(" << id_ << ")";
00176         return ss.str();      
00177       }
00178       
00180       numeric_type unwrap() const
00181       {
00182         throw expression_not_unwrappable_exception();
00183         return 0;
00184       }
00185         
00186       //protected:
00188       InterfaceType * substitute(const InterfaceType * e,
00189                                  const InterfaceType * repl) const
00190       {
00191         //std::cout << "Comparing variable<" << id << "> with " << e->str() << ", result: ";
00192         if (deep_equal(e))
00193           return repl->clone();
00194         
00195         //std::cout << "FALSE" << std::endl;
00196         return clone();
00197       };    
00198 
00200       InterfaceType * substitute(std::vector<const InterfaceType *> const &  e,
00201                                  std::vector<const InterfaceType *> const &  repl) const
00202       {
00203         //std::cout << "Comparing variable<" << id << "> with " << e->str() << ", result: ";
00204         for (std::size_t i=0; i<e.size(); ++i)
00205           if (deep_equal(e[i]))
00206             return repl[i]->clone();
00207         
00208         //std::cout << "FALSE" << std::endl;
00209         return clone();
00210       };    
00211       
00213       bool deep_equal(const InterfaceType * other) const
00214       {
00215         const self_type * ptr = dynamic_cast< const self_type *>(other);
00216         if (ptr != NULL)
00217           return ptr->id() == id_;
00218 
00219         return false;
00220       }
00221 
00223       bool shallow_equal(const InterfaceType * other) const
00224       {
00225         return dynamic_cast< const self_type * >(other) != NULL;
00226       }
00227 
00229       InterfaceType * diff(const InterfaceType * diff_var) const
00230       {
00231         const rt_variable<InterfaceType> * ptr = dynamic_cast< const rt_variable<InterfaceType> *>(diff_var);
00232         if (ptr != NULL)
00233         {
00234           //std::cout << "diff variable<" << id << ">: TRUE" << std::endl;
00235           if (ptr->id() == id_)
00236             return new rt_constant<numeric_type, InterfaceType>(1);
00237         }
00238         //std::cout << "diff variable<" << id << ">: FALSE, is: " << diff_var.get()->str() << std::endl;
00239         return new rt_constant<numeric_type, InterfaceType>(0);
00240       }
00241     
00242     private: 
00243       id_type id_;
00244   }; //variable
00245 
00246 
00248   template <typename InterfaceType>
00249   std::ostream& operator<<(std::ostream & stream, rt_variable<InterfaceType> const & u)
00250   {
00251     stream << "variable(" << u.id() << ")";
00252     return stream;
00253   }
00254 
00255 
00256 }
00257 
00258 #endif

Generated on Wed Feb 29 2012 21:50:43 for ViennaMath by  doxygen 1.7.1