LILAC
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
rk45_tmpl.hpp
Go to the documentation of this file.
1 /*
2 Copyright (c) 2014, Sam Schetterer, Nathan Kutz, University of Washington
3 Authors: Sam Schetterer
4 All rights reserved.
5 
6 Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
7 
8 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
9 
10 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
11 
12 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
13 
14 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
15 
16 */
17 #ifndef RK45_TMPL
18 #define RK45_TMPL
19 #include "integrator.h"
20 #include "utils/comp_funcs.hpp"
21 #include "rk45.h"
22 #include "types/float_traits.hpp"
24 
29 template<class T>
30 class rk45_tmpl:public rk45 {
31  typedef typename float_traits<T>::type real_type;
32  protected:
33  //since these values aren't used to much in actual computation
34  //it's fine to just have them be doubles and easily retrieve them
35  double dt_init, dt_min, dt_max;
36  double relerr, abserr;
37  T* restr f0, * restr f1, * restr f2, * restr f3;
38  T* restr f4, * restr f5, * restr f6;
39  T* restr u_calc;
41  public:
42  virtual const std::type_info& vtype() const;
43  void postprocess(input& dat);
44  std::string type() const;
45  int integrate(ptr_passer u, double t0, double tf);
46  ~rk45_tmpl();
47 };
48 template<class T>
49 const std::type_info& rk45_tmpl<T>::vtype() const {
50  return typeid(T);
51 }
52 template<class T>
53 int rk45_tmpl<T>::integrate(ptr_passer _u0, double t0, double tf){
54  T* restr u0 = _u0;
55  size_t num_fail = 0;
56  real_type dt = dt_init;
57  real_type dtave=0;
58  size_t steps=0;
59 
60 
61  const real_type magic_power = 1.0/6; //found from reference file, reference p.91 Ascher and Petzold
62  const real_type magic_mult = .8;//magic multiplying safety factor for determining the next timestep;
63  // dorman prince integrator parameters
64  MAKE_ALIGNED const static real_type a[] = {0.2, 0.3, 0.8, 8.0/9, 1, 1};
65  MAKE_ALIGNED const static real_type b1 = 0.2;
66  MAKE_ALIGNED const static real_type b2[] = {3.0/40, 9.0/40};
67  MAKE_ALIGNED const static real_type b3[] = {44.0/45, -56.0/15, 32.0/9};
68  MAKE_ALIGNED const static real_type b4[] = {19372.0/6561, -25360.0/2187, 64448.0/6561, -212.0/729};
69  MAKE_ALIGNED const static real_type b5[] = {9017.0/3168, -355.0/33, 46732.0/5247, 49.0/176, -5103.0/18656};
70  MAKE_ALIGNED const static real_type b6[] = {35.0/384, 0, 500.0/1113, 125.0/192, -2187.0/6784, 11.0/84};
71  MAKE_ALIGNED const static real_type c4[] = {5179.0/57600, 0, 7571.0/16695, 393.0/640,
72  -92097.0/339200, 187.0/2100, 1.0/40};
73  MAKE_ALIGNED const static real_type c5[] = {35.0/384, 0, 500.0/1113, 125.0/192, -2187.0/6784, 11.0/84, 0};
74 
75  real_type taui;
76  taui = 0;
77  real_type tcur=t0;
78  real_type tauv = 0;
79  for(size_t i = 0; i < dimension; i++){
80  real_type taum = abs(u0[i]);
81  if(taum > tau){
82  taui = taum;
83  }
84  }
85  ALIGNED(u0);
86  ALIGNED(u_calc);
87  ALIGNED(u_calc2);
88  ALIGNED(f0);
89  ALIGNED(f1);
90  ALIGNED(f2);
91  ALIGNED(f3);
92  ALIGNED(f4);
93  ALIGNED(f5);
94  ALIGNED(f6);
95  tauv = taui;
96  tauv *= relerr;
97  T* restr tmp, * restr swp, * restr u0hld;//this is used for freeing the memory later
98  T* restr tmp6, * restr tmpc;
99  tmp6 = f6;
100  tmpc = u_calc;
101  u0hld=u0;
102  //used to avoid memory problems later on
103  //allows for easy switching of pointers to avoid memory copies
104 
105  tmp=f0;
106  //removes the need to check if the pointers should be swapped or not
107  if((tcur + dt) > tf){
108  dt = tf - tcur;
109  err("t0-tf <= dt, consider decreasing the initial timestep. Otherwise mileage may vary",
110  "rk45::integrate", "integrator/rh45.cpp", (item*)rh_val, WARNING);
111  }
112  int tries=0;
113 
114  //while this seems odd, it helps streamline the inner integrator loop
115  rh_val->dxdt(u0, f6, tcur);
116  while(tcur < tf){
117  //this function=lol at compiler loop unrolling
118  if((tcur + dt) > tf){
119  break;
120  }
121  tries++;
122  swp=f6;
123  f6=f0;
124  f0=swp;
125  real_type ts = tcur;
126 
127  for(size_t i = 0; i < dimension; i++){
128  u_calc[i] = u0[i] + b1*f0[i]*dt;
129  }
130  tcur = ts + a[0]*dt;
131  rh_val->dxdt(u_calc, f1, tcur);
132 
133  for(size_t i = 0; i < dimension; i++){
134  u_calc[i] = u0[i] + dt*(b2[0]*f0[i] + b2[1] * f1[i]);
135  }
136  tcur = ts + a[1]*dt;
137  rh_val->dxdt(u_calc, f2, tcur);
138 
139  for(size_t i = 0; i < dimension; i++){
140  u_calc[i] = u0[i] + dt*(b3[0]*f0[i] + b3[1] * f1[i] + b3[2]*f2[i]);
141  }
142  tcur = ts + a[2]*dt;
143  rh_val->dxdt(u_calc, f3, tcur);
144 
145  for(size_t i = 0; i < dimension; i++){
146  u_calc[i] = u0[i] + dt*(b4[0]*f0[i] + b4[1] * f1[i] + b4[2]*f2[i] + b4[3]*f3[i]);
147  }
148  tcur = ts + a[3]*dt;
149  rh_val->dxdt(u_calc, f4, tcur);
150 
151  for(size_t i = 0; i < dimension; i++){
152  u_calc[i] = u0[i] + dt*(b5[0]*f0[i] + b5[1] * f1[i] + b5[2]*f2[i] +
153  b5[3]*f3[i] + b5[4]*f4[i]);
154  }
155  tcur = ts + a[4]*dt;
156  rh_val->dxdt(u_calc, f5, tcur);
157 
158  for(size_t i = 0; i < dimension; i++){
159  u_calc[i] = u0[i] + dt*(b6[0]*f0[i] + b6[1] * f1[i] + b6[2]*f2[i] +
160  b6[3]*f3[i] + b6[4]*f4[i] + b6[5]*f5[i]);
161  }
162  tcur = ts + a[5]*dt;
163  rh_val->dxdt(u_calc, f6, tcur);
164  tcur = ts;
165  //calculate the magnitude of the error, and the fourth/5th order arrays
166  //The references that I found are for real values
167  //so I substitute the magnitude of complex differences for the absolute real difference
168  //not sure if this is the way to go, but sounds about right
169  //internally, I do calculations with the squares to avoid calculations of sqrt
170  //This is done with the inf norm
171  real_type delta=1E-12;
172 
173  for(size_t i = 0; i < dimension; i++){
174  u_calc[i] = u0[i] + dt*(c5[0]*f0[i] + c5[1] * f1[i] + c5[2] * f2[i] + c5[3] * f3[i] +
175  c5[4]*f4[i] + c5[5]*f5[i] + c5[6]*f6[i]);
176 
177  u_calc2[i] = abs(u0[i] + dt*(c4[0]*f0[i] + c4[1] * f1[i] + c4[2] * f2[i] + c4[3] * f3[i] +
178  c4[4]*f4[i] + c4[5]*f5[i] + c4[6]*f6[i]) - u_calc[i]);
179  }
180 
181  //seperate these two loops to allow vectorization of the first
182  for(size_t i = 0; i < dimension; i++){
183  real_type deltam = u_calc2[i];//square of absolute error
184  if(deltam>delta){
185  delta=deltam;
186  }
187  }
188  //estimated optimal dt, smallest of the two options to ensure accuracy
189  real_type dt_last = dt;
190  dt =dt*magic_mult*std::pow(tauv/delta, magic_power);
191  if(dt < dt_min){
192  num_fail++;
193  if(num_fail > 10){
194  err("Estimated timestep is smaller than minimum timestep too many times, exiting",
195  "rk45::integrate", "integrator/rk45.cpp", (item*)rh_val, WARNING);
196  f0=tmp;
197  f6 = tmp6;
198  u_calc=tmpc;
199  return -1;
200  }
201  //std::cout << "dt= " << dt << ", dt_min=" << dt_min<<std::endl;
202  dt = dt_min;
203  }
204  else{
205  num_fail = 0;
206  }
207  //since dt_max is always a double, explcit specialization is needed
208  dt = std::min<real_type>(dt, dt_max);
209  if(delta>=tauv){
210  swp=f6;
211  f6=f0;
212  f0=swp;
213  continue;
214  }
215  tcur += dt_last;
216  if((tcur + dt) > tf){
217  dt = tf-tcur;
218  }
219  for(size_t i = 0; i < dimension; i++){
220  real_type val = abs(u_calc[i]);
221  if(val > taui){
222  taui = val;
223  }
224  }
225  steps++;
226  dtave+=dt_last;
227  tauv=taui;
228  tauv *= relerr;
229  tries=0;
230  //error is acceptable, continue on
231  swp = u_calc;
232  u_calc = u0;
233  u0 = swp;
234  //instead of copying u_calc to u_0 swap pointers
235  }
236  //check if a memory copy needs to be done on u0
237  //u0 currently points to the most recent update
238  //if u0 does not point to where it originally did, re-copy the memory over
239  if(u0 != u0hld){
240  for(size_t i = 0; i < dimension; i++){
241  u0hld[i] = u0[i];
242  }
243  }
244  //u0, or at least that original address, now holds the final integration values
245  f0=tmp;
246  f6 = tmp6;
247  u_calc=tmpc;
248  return 0;
249 }
250 template<class T>
252 }
253 
254 
255 template<class T>
257  return std::string("rk45_tmpl<") + this->vname() + ">";
258 }
259 template<class T>
262  if(!rh_val->compare<T>()){
263  err("Bad rhs type passed to rk45 integrator", "rk45_tmpl::postprocess",
264  "rhs/rk45_tmpl.h", FATAL_ERROR);
265  }
266  this->add_as_parent(rh_val);
267  dat.retrieve(dt_init, "dt_init", this);
268  if(dt_init <= 0){
269  err("dt_init is invalid, must be >= 0", "rk45::postprocess",
270  "integrator/rk45.cpp", dat["dt_init"], FATAL_ERROR);
271  }
272  dat.retrieve(dt_min, "dt_min", this);
273  if(dt_min <= 0){
274  err("dt_min is invalid, must be >= 0", "rk45::postprocess",
275  "integrator/rk45.cpp", dat["dt_min"], FATAL_ERROR);
276  }
277  dat.retrieve(dt_max, "dt_max", this);
278  if(dt_max <= 0){
279  err("dt_max is invalid, must be >= 0", "rk45::postprocess",
280  "integrator/rk45.cpp", dat["dt_max"], FATAL_ERROR);
281  }
282  dat.retrieve(relerr, "relerr", this);
283  if(relerr <= 0){
284  err("relerr is invalid, must be >= 0", "rk45::postprocess",
285  "integrator/rk45.cpp", dat["relerr"], FATAL_ERROR);
286  }
287  dat.retrieve(abserr, "abserr", this);
288  if(abserr <= 0){
289  err("abserr is invalid, must be >= 0", "rk45::postprocess",
290  "integrator/rk45.cpp", dat["abserr"], FATAL_ERROR);
291  }
292  memp.create(dimension, &f0, &f1, &f2, &f3, &f4, &f5, &f6, &u_calc, &u_calc2);
293 }
294 
295 #endif