OPAL (Object Oriented Parallel Accelerator Library)  2.2.0
OPAL
FFT.h
Go to the documentation of this file.
1 //
2 // IPPL FFT
3 //
4 // Copyright (c) 2008-2018
5 // Paul Scherrer Institut, Villigen PSI, Switzerland
6 // All rights reserved.
7 //
8 // OPAL is licensed under GNU GPL version 3.
9 //
10 
18 #ifndef IPPL_FFT_FFT_H
19 #define IPPL_FFT_FFT_H
20 
21 #include "FFT/FFTBase.h"
22 
23 #ifdef IPPL_DKS
24 #include "DKSOPAL.h"
25 #endif
26 
27 // forward declarations
28 //template <unsigned Dim> class FieldLayout;
30 template <class T, unsigned Dim> class BareField;
31 template <class T, unsigned Dim> class LField;
32 
33 
34 
// Empty tag type naming the complex-to-complex transform. It is used as the
// first template argument of the FFT<CCTransform,...> specializations.
38 class CCTransform {};
// Empty tag type naming the real/complex transform. It is used as the first
// template argument of the FFT<RCTransform,...> specializations.
42 class RCTransform {};
// Empty tag type naming the sine transform. It is used as the first template
// argument of the FFT<SineTransform,...> specializations.
46 class SineTransform {};
47 
// Primary FFT class template: an empty class deriving from FFTBase<Dim,T>.
// The usable interfaces are provided by the partial specializations on the
// Transform tag (CCTransform, RCTransform, SineTransform) that follow.
51 template <class Transform, size_t Dim, class T>
52 class FFT : public FFTBase<Dim,T> {};
53 
// Complex-to-complex FFT: partial specialization of FFT on the CCTransform
// tag for a general number of dimensions Dim.
// NOTE(review): this listing omits some original lines (the embedded line
// numbers skip, e.g. 67, 69-71, 88, 195, 208), so a few declarations and
// expression tails are not visible here.
57 template <size_t Dim, class T>
58 class FFT<CCTransform,Dim,T> : public FFTBase<Dim,T> {
59 
60 private:
    // DKS handle; only compiled in when IPPL_DKS is defined. It is configured
    // in the inline constructor below.
61 #ifdef IPPL_DKS
62  DKSOPAL base;
63 #endif
64 
65 public:
66 
    // Complex element type built from the precision type T.
68  typedef std::complex<T> Complex_t;
72 
    // Constructor taking the domain and a per-dimension flag array
    // transformTheseDims; compressTemps defaults to false.
    // Declared here, defined out of line.
78  FFT(const Domain_t& cdomain, const bool transformTheseDims[Dim],
79  const bool& compressTemps=false);
80 
    // Constructor that sets up every one of the Dim dimensions as a
    // complex-to-complex transform. Defined inline; it
    //  - records the length of each axis of cdomain,
    //  - sets the normalization factor to 1 / (product of the axis lengths),
    //  - initializes the DKS device when a DKS backend macro is defined,
    //  - configures the FFT engine and then calls setup().
    // NOTE(review): the line that would open the body (embedded line 88) is
    // absent from this listing.
87 FFT(const Domain_t& cdomain, const bool& compressTemps=false)
89 
90  // construct array of axis lengths
91  int lengths[Dim];
92  size_t d;
93  for (d=0; d<Dim; ++d)
94  lengths[d] = cdomain[d].length();
95 
96  // construct array of transform types for FFT Engine, compute normalization
97  int transformTypes[Dim];
    // normFact is a reference, so the writes below update the value that
    // getNormFact() hands out.
98  T& normFact = this->getNormFact();
99  normFact = 1.0;
100  for (d=0; d<Dim; ++d) {
101  transformTypes[d] = FFTBase<Dim,T>::ccFFT; // all transforms are complex-to-complex
102  normFact /= lengths[d];
103  }
104 
    // Optional DKS device initialization, one section per backend macro.
    // NOTE(review): the integer passed next to each string equals that
    // string's length; presumably it is a length argument - confirm against
    // the DKS API.
105 #ifdef IPPL_DKS
106 #ifdef IPPL_DKS_OPENCL
107  INFOMSG("Init DKS base opencl" << endl);
108  base.setAPI("OpenCL", 6);
109  base.setDevice("-gpu", 4);
110  base.initDevice();
111 
112 #endif
113 
114 #ifdef IPPL_DKS_CUDA
115  INFOMSG("Init DKS base cuda" << endl);
116  base.setAPI("Cuda", 4);
117  base.setDevice("-gpu", 4);
118  base.initDevice();
119 #endif
120 
121 #ifdef IPPL_DKS_MIC
122  INFOMSG("Init DKS base MIC" << endl);
123  base.setAPI("OpenMP", 6);
124  base.setDevice("-mic", 4);
125  base.initDevice();
126 #endif
127 #endif
128 
129  // set up FFT Engine
130  this->getEngine().setup(Dim, transformTypes, lengths);
131  // set up the temporary fields
132  setup();
133  }
134 
135 
136  // Destructor
137  ~FFT(void);
138 
    // Two-field transform with an integer direction. Declared here, defined
    // out of line.
    // NOTE(review): presumably f is the input and g the output, with
    // constInput asking that f be left unmodified - confirm in the
    // implementation file.
146  void transform(int direction, ComplexField_t& f, ComplexField_t& g,
147  const bool& constInput=false);
    // Two-field transform with the direction given by name. Declared here,
    // defined out of line.
151  void transform(const char* directionName, ComplexField_t& f,
152  ComplexField_t& g, const bool& constInput=false);
153 
    // Single-field (in-place) transform of f with an integer direction.
    // Declared here, defined out of line.
156  void transform(int direction, ComplexField_t& f);
157 
    // In-place transform of f with the direction given by name.
    // The name is translated to an integer by getDirection(). The dimension
    // loop steps idim by `direction` and stops when idim == enddim, so the
    // value is expected to be +1 or -1.
    // For each transform dimension the data is copied into the matching
    // temporary field (unless the transpose can be skipped), then 1D FFTs are
    // run over every strip of every local field. If the result ended up in a
    // temporary it is copied back into f. Finally f is scaled by the
    // normalization factor when direction == +1.
158  void transform(const char* directionName, ComplexField_t& f) {
159  // invoke in-place transform function using direction name string
160  int direction = this->getDirection(directionName);
161 
162  // Check domain of incoming Field
163  const Layout_t& in_layout = f.getLayout();
164  const Domain_t& in_dom = in_layout.getDomain();
165  PAssert_EQ(this->checkDomain(this->getDomain(),in_dom), true);
166 
167  // Common loop iterate and other vars:
168  size_t d;
169  int idim; // idim loops over the number of transform dims.
170  int begdim, enddim; // beginning and end of transform dim loop
171  size_t nTransformDims = this->numTransformDims();
172  // Field* for temp Field management:
173  ComplexField_t* temp = &f;
174  // Local work array passed to FFT:
175  Complex_t* localdata;
176 
177  // Loop over the dimensions to be transformed:
    // direction == +1 walks the dims 0 .. nTransformDims-1; any other value
    // starts at nTransformDims-1 and ends before -1.
178  begdim = (direction == +1) ? 0 : (nTransformDims-1);
179  enddim = (direction == +1) ? nTransformDims : -1;
180  for (idim = begdim; idim != enddim; idim += direction) {
181 
182  // Now do the serial transforms along this dimension:
183 
184  bool skipTranspose = false;
185  // if this is the first transform dimension, we might be able
186  // to skip the transpose into the first temporary Field
187  if (idim == begdim) {
188  // get domain for comparison
189  const Domain_t& first_dom = tempLayouts_m[idim]->getDomain();
190  // check that zeroth axis is the same and is serial
191  // and that there are no guard cells
    // NOTE(review): the line that completes this expression (embedded
    // line 195) is absent from this listing.
192  skipTranspose = ( (in_dom[0].sameBase(first_dom[0])) &&
193  (in_dom[0].length() == first_dom[0].length()) &&
194  (in_layout.getDistribution(0) == SERIAL) &&
196  }
197 
198  // if this is the last transform dimension, we might be able
199  // to skip the last temporary and transpose right into f
    // idim == enddim-direction identifies the final loop iteration. With a
    // single transform dimension this block also runs on the first
    // iteration and overwrites the value assigned above.
200  if (idim == enddim-direction) {
201  // get domain for comparison
202  const Domain_t& last_dom = tempLayouts_m[idim]->getDomain();
203  // check that zeroth axis is the same and is serial
204  // and that there are no guard cells
    // NOTE(review): the line that completes this expression (embedded
    // line 208) is absent from this listing.
205  skipTranspose = ( (in_dom[0].sameBase(last_dom[0])) &&
206  (in_dom[0].length() == last_dom[0].length()) &&
207  (in_layout.getDistribution(0) == SERIAL) &&
209  }
210 
211  if (!skipTranspose) {
212  // transpose and permute to Field with transform dim first
213  (*tempFields_m[idim])[tempLayouts_m[idim]->getDomain()] =
214  (*temp)[temp->getLayout().getDomain()];
215 
216  // Compress out previous iterate's storage:
    // (never applied to f itself, only to a temporary from an earlier pass)
217  if (this->compressTemps() && temp != &f) *temp = 0;
218  temp = tempFields_m[idim]; // Field* management aid
219  }
220  else if (idim == enddim-direction && temp != &f) {
221  // last transform and we can skip the last temporary field
222  // so do the transpose here using f instead
223 
224  // transpose and permute to Field with transform dim first
225  f[in_dom] = (*temp)[temp->getLayout().getDomain()];
226 
227  // Compress out previous iterate's storage:
228  if (this->compressTemps()) *temp = 0;
229  temp = &f; // Field* management aid
230  }
231 
232 
233 
234  // Loop over all the Vnodes, working on the LField in each.
235  typename ComplexField_t::const_iterator_if l_i, l_end = temp->end_if();
236  for (l_i = temp->begin_if(); l_i != l_end; ++l_i) {
237 
238  // Get the LField
239  ComplexLField_t* ldf = (*l_i).second.get();
240  // make sure we are uncompressed
241  ldf->Uncompress();
242  // get the raw data pointer
243  localdata = ldf->getP();
244 
245  // Do 1D complex-to-complex FFT's on all the strips in the LField:
    // length is the local extent along axis 0; nstrips is the product of
    // the local extents along the remaining axes.
246  int nstrips = 1, length = ldf->size(0);
247  for (d=1; d<Dim; ++d) nstrips *= ldf->size(d);
248  for (int istrip=0; istrip<nstrips; ++istrip) {
249  // Do the 1D FFT:
250  this->getEngine().callFFT(idim, direction, localdata);
251  // advance the data pointer
252  localdata += length;
253  } // loop over 1D strips
254  } // loop over all the LFields
255 
256 
257  } // loop over all transformed dimensions
258 
259  // skip final assignment and compress if we used f as final temporary
260  if (temp != &f) {
261 
262  // Now assign back into original Field, and compress last temp's storage:
263  f[in_dom] = (*temp)[temp->getLayout().getDomain()];
264  if (this->compressTemps()) *temp = 0;
265 
266  }
267 
268  // Normalize:
    // The scaling is applied only for direction == +1.
269  if (direction == +1)
270  f *= Complex_t(this->getNormFact(), 0.0);
271  return;
272  }
273 private:
274 
    // Sets up the temporary fields (per the comment at its call site in the
    // inline constructor). Declared here, defined out of line.
279  void setup(void);
280 
288 
293 
294 };
295 
296 
// Out-of-class inline definition of a name-based, two-field transform
// overload. It converts directionName to an integer direction with
// getDirection() and forwards to transform(dir, f, g, constInput).
// NOTE(review): the line carrying the qualified function name and the lines
// declaring f and g (embedded lines 302, 304, 305) are absent from this
// listing.
300 template <size_t Dim, class T>
301 inline void
303  const char* directionName,
306  const bool& constInput)
307 {
308  int dir = this->getDirection(directionName);
309  transform(dir, f, g, constInput);
310  return;
311 }
312 
313 
// Complex-to-complex FFT: full specialization of the dimension argument to 1U
// for the CCTransform tag. Every member is only declared here.
// NOTE(review): this listing omits some original lines (the embedded line
// numbers skip, e.g. 323, 325-327, 377-387), so the remaining typedefs and
// the private data members are not visible here.
317 template <class T>
318 class FFT<CCTransform,1U,T> : public FFTBase<1U,T> {
319 
320 public:
321 
322  // typedefs
    // Complex element type built from the precision type T.
324  typedef std::complex<T> Complex_t;
328 
329  // Constructors:
330 
    // Constructor taking the domain and a one-element flag array
    // transformTheseDims; compressTemps defaults to false.
336  FFT(const Domain_t& cdomain, const bool transformTheseDims[1U],
337  const bool& compressTemps=false);
    // Constructor taking only the domain; compressTemps defaults to false.
344  FFT(const Domain_t& cdomain, const bool& compressTemps=false);
345 
346  // Destructor
347  ~FFT(void);
348 
    // Two-field transform with an integer direction.
    // NOTE(review): presumably f is the input and g the output - confirm in
    // the implementation file.
356  void transform(int direction, ComplexField_t& f, ComplexField_t& g,
357  const bool& constInput=false);
    // Two-field transform with the direction given by name.
361  void transform(const char* directionName, ComplexField_t& f,
362  ComplexField_t& g, const bool& constInput=false);
363 
    // Single-field transforms of f, by integer direction and by name.
367  void transform(int direction, ComplexField_t& f);
368  void transform(const char* directionName, ComplexField_t& f);
369 
370 private:
371 
    // Internal setup routine; declared here, defined out of line.
376  void setup(void);
377 
382 
387 
388 };
389 
390 
391 // inline function definitions
392 
// Out-of-class inline definition of a name-based, two-field transform
// overload. It converts directionName to an integer direction with
// getDirection() and forwards to transform(dir, f, g, constInput).
// NOTE(review): the line carrying the qualified function name and the lines
// declaring f and g (embedded lines 398, 400, 401) are absent from this
// listing.
396 template <class T>
397 inline void
399  const char* directionName,
402  const bool& constInput)
403 {
404  int dir = this->getDirection(directionName);
405  transform(dir, f, g, constInput);
406  return;
407 }
408 
// Out-of-class inline definition of a name-based, single-field transform
// overload. It converts directionName to an integer direction with
// getDirection() and forwards to transform(dir, f).
// NOTE(review): the line carrying the qualified function name and the line
// declaring f (embedded lines 414, 416) are absent from this listing.
412 template <class T>
413 inline void
415  const char* directionName,
417 {
418  int dir = this->getDirection(directionName);
419  transform(dir, f);
420  return;
421 }
422 
423 
// Real/complex FFT: partial specialization of FFT on the RCTransform tag for
// a general number of dimensions Dim. Every member is only declared here.
// NOTE(review): this listing omits some original lines (the embedded line
// numbers skip, e.g. 435-437, 439-441, 505-536), so the remaining typedefs
// and the private data members are not visible here.
427 template <size_t Dim, class T>
428 class FFT<RCTransform,Dim,T> : public FFTBase<Dim,T> {
429 
430 private:
431 
432 public:
433 
434  // typedefs
    // Complex element type built from the precision type T.
438  typedef std::complex<T> Complex_t;
442 
443  // Constructors:
444 
    // Constructor taking two domains (presumably rdomain for the real field
    // and cdomain for the complex field - confirm) and a per-dimension flag
    // array transformTheseDims; compressTemps defaults to false.
450  FFT(const Domain_t& rdomain, const Domain_t& cdomain,
451  const bool transformTheseDims[Dim], const bool& compressTemps=false);
452 
    // Constructor taking the two domains without a flag array. compressTemps
    // defaults to false and serialAxes defaults to 1; the meaning of
    // serialAxes is not visible in this header.
456  FFT(const Domain_t& rdomain, const Domain_t& cdomain,
457  const bool& compressTemps=false, int serialAxes = 1);
458 
459  // Destructor
460  ~FFT(void);
461 
    // Transform overloads taking a real field f and a complex field g, by
    // integer direction and by name.
469  void transform(int direction, RealField_t& f, ComplexField_t& g,
470  const bool& constInput=false);
471  void transform(const char* directionName, RealField_t& f,
472  ComplexField_t& g, const bool& constInput=false);
473 
    // DKS variant, only declared when IPPL_DKS is defined. It takes raw
    // real/complex pointers, a DKSOPAL handle and a stream id that defaults
    // to -1.
477 #ifdef IPPL_DKS
478  void transformDKSRC(int direction, RealField_t &f, void* real_ptr, void* comp_ptr,
479  DKSOPAL &dksbase, int streamId = -1, const bool& constInput=false);
480 #endif
481 
    // Transform overloads taking a complex field f and a real field g, by
    // integer direction and by name.
484  void transform(int direction, ComplexField_t& f, RealField_t& g,
485  const bool& constInput=false);
486  void transform(const char* directionName, ComplexField_t& f,
487  RealField_t& g, const bool& constInput=false);
488 
    // DKS variant, only declared when IPPL_DKS is defined. It takes the real
    // field g, raw real/complex pointers, a DKSOPAL handle and a stream id
    // that defaults to -1.
492 #ifdef IPPL_DKS
493  void transformDKSCR(int direction, RealField_t& g, void* real_ptr, void* comp_ptr,
494  DKSOPAL &dksbase, int streamId = -1, const bool& constInput=false);
495 #endif
496 
497 private:
498 
    // Internal setup routine; declared here, defined out of line.
503  void setup(void);
504 
511 
516 
521 
526 
532 
537 };
538 
539 // Inline function definitions
540 
// Out-of-class inline definition of a name-based, two-field transform
// overload. It converts directionName to an integer direction with
// getDirection() and forwards to transform(dir, f, g, constInput).
// NOTE(review): the line carrying the qualified function name and the lines
// declaring f and g (embedded lines 546, 548, 549) are absent from this
// listing.
544 template <size_t Dim, class T>
545 inline void
547  const char* directionName,
550  const bool& constInput)
551 {
552  int dir = this->getDirection(directionName);
553  transform(dir, f, g, constInput);
554  return;
555 }
556 
// Out-of-class inline definition of a name-based, two-field transform
// overload. It converts directionName to an integer direction with
// getDirection() and forwards to transform(dir, f, g, constInput).
// NOTE(review): the line carrying the qualified function name and the lines
// declaring f and g (embedded lines 562, 564, 565) are absent from this
// listing.
560 template <size_t Dim, class T>
561 inline void
563  const char* directionName,
566  const bool& constInput)
567 {
568  int dir = this->getDirection(directionName);
569  transform(dir, f, g, constInput);
570  return;
571 }
572 
573 
// Real/complex FFT: full specialization of the dimension argument to 1U for
// the RCTransform tag. Every member is only declared here.
// NOTE(review): this listing omits some original lines (the embedded line
// numbers skip, e.g. 583-585, 587-589, 642-660, 662), so the remaining
// typedefs and the private data members are not visible here.
577 template <class T>
578 class FFT<RCTransform,1U,T> : public FFTBase<1U,T> {
579 
580 public:
581 
582  // typedefs
    // Complex element type built from the precision type T.
586  typedef std::complex<T> Complex_t;
590 
591  // Constructors:
592 
    // Constructor taking two domains (presumably rdomain for the real field
    // and cdomain for the complex field - confirm) and a one-element flag
    // array transformTheseDims; compressTemps defaults to false.
599  FFT(const Domain_t& rdomain, const Domain_t& cdomain,
600  const bool transformTheseDims[1U], const bool& compressTemps=false);
    // Constructor taking the two domains without a flag array.
604  FFT(const Domain_t& rdomain, const Domain_t& cdomain,
605  const bool& compressTemps=false);
606 
    // Destructor; declared here, defined out of line.
610  ~FFT(void);
611 
    // Transform overloads taking a real field f and a complex field g, by
    // integer direction and by name.
620  void transform(int direction, RealField_t& f, ComplexField_t& g,
621  const bool& constInput=false);
622  void transform(const char* directionName, RealField_t& f,
623  ComplexField_t& g, const bool& constInput=false);
624 
    // Transform overloads taking a complex field f and a real field g, by
    // integer direction and by name.
629  void transform(int direction, ComplexField_t& f, RealField_t& g,
630  const bool& constInput=false);
631  void transform(const char* directionName, ComplexField_t& f,
632  RealField_t& g, const bool& constInput=false);
633 
634 private:
635 
    // Internal setup routine; declared here, defined out of line.
640  void setup(void);
641 
646 
651 
656 
661  // const Domain_t& complexDomain_m;
663 };
664 
// Out-of-class inline definition of a name-based, two-field transform
// overload. It converts directionName to an integer direction with
// getDirection() and forwards to transform(dir, f, g, constInput).
// NOTE(review): the line carrying the qualified function name and the lines
// declaring f and g (embedded lines 670, 672, 673) are absent from this
// listing.
668 template <class T>
669 inline void
671  const char* directionName,
674  const bool& constInput)
675 {
676  int dir = this->getDirection(directionName);
677  transform(dir, f, g, constInput);
678  return;
679 }
680 
// Out-of-class inline definition of a name-based, two-field transform
// overload. It converts directionName to an integer direction with
// getDirection() and forwards to transform(dir, f, g, constInput).
// NOTE(review): the line carrying the qualified function name and the lines
// declaring f and g (embedded lines 686, 688, 689) are absent from this
// listing.
684 template <class T>
685 inline void
687  const char* directionName,
690  const bool& constInput)
691 {
692  int dir = this->getDirection(directionName);
693  transform(dir, f, g, constInput);
694  return;
695 }
696 
// Sine-transform FFT: partial specialization of FFT on the SineTransform tag
// for a general number of dimensions Dim. Every member function is only
// declared here.
// NOTE(review): this listing omits some original lines (the embedded line
// numbers skip, e.g. 706-708, 710-712, 804-836), so the remaining typedefs
// and most private data members are not visible here.
700 template <size_t Dim, class T>
701 class FFT<SineTransform,Dim,T> : public FFTBase<Dim,T> {
702 
703 public:
704 
705  // typedefs
    // Complex element type built from the precision type T.
709  typedef std::complex<T> Complex_t;
713 
    // Constructor taking two domains and two per-dimension flag arrays,
    // transformTheseDims and sineTransformDims; compressTemps defaults to
    // false.
    // NOTE(review): presumably sineTransformDims marks the dimensions that
    // use a sine transform - confirm in the implementation file.
721  FFT(const Domain_t& rdomain, const Domain_t& cdomain,
722  const bool transformTheseDims[Dim],
723  const bool sineTransformDims[Dim], const bool& compressTemps=false);
    // Constructor taking two domains and only the sineTransformDims flags.
727  FFT(const Domain_t& rdomain, const Domain_t& cdomain,
728  const bool sineTransformDims[Dim], const bool& compressTemps=false);
    // Constructor taking a single domain and the sineTransformDims flags.
736  FFT(const Domain_t& rdomain, const bool sineTransformDims[Dim],
737  const bool& compressTemps=false);
    // Constructor taking a single domain and no flag arrays.
741  FFT(const Domain_t& rdomain, const bool& compressTemps=false);
742 
    // Destructor; declared here, defined out of line.
743  ~FFT(void);
744 
    // Transform overloads taking a real field f and a complex field g, by
    // integer direction and by name.
755  void transform(int direction, RealField_t& f, ComplexField_t& g,
756  const bool& constInput=false);
757  void transform(const char* directionName, RealField_t& f,
758  ComplexField_t& g, const bool& constInput=false);
759 
    // Transform overloads taking a complex field f and a real field g, by
    // integer direction and by name.
764  void transform(int direction, ComplexField_t& f, RealField_t& g,
765  const bool& constInput=false);
766  void transform(const char* directionName, ComplexField_t& f,
767  RealField_t& g, const bool& constInput=false);
768 
    // Transform overloads taking two real fields f and g, by integer
    // direction and by name.
778  void transform(int direction, RealField_t& f, RealField_t& g,
779  const bool& constInput=false);
780  void transform(const char* directionName, RealField_t& f,
781  RealField_t& g, const bool& constInput=false);
782 
    // Single-field transform overloads on a real field f, by integer
    // direction and by name.
786  void transform(int direction, RealField_t& f);
787  void transform(const char* directionName, RealField_t& f);
788 
789 private:
790 
    // Internal setup routine; declared here, defined out of line.
795  void setup(void);
796 
797 
798 
    // One flag per dimension.
    // NOTE(review): presumably a stored copy of the sineTransformDims
    // constructor argument - confirm in the implementation file.
802  bool sineTransformDims_m[Dim];
803 
808 
817 
822 
827 
832 
837 };
838 
// Out-of-class inline definition of a name-based, two-field transform
// overload. It converts directionName to an integer direction with
// getDirection() and forwards to transform(dir, f, g, constInput).
// NOTE(review): the line carrying the qualified function name and the lines
// declaring f and g (embedded lines 844, 846, 847) are absent from this
// listing.
842 template <size_t Dim, class T>
843 inline void
845  const char* directionName,
848  const bool& constInput)
849 {
850  int dir = this->getDirection(directionName);
851  transform(dir, f, g, constInput);
852  return;
853 }
854 
// Out-of-class inline definition of a name-based, two-field transform
// overload. It converts directionName to an integer direction with
// getDirection() and forwards to transform(dir, f, g, constInput).
// NOTE(review): the line carrying the qualified function name and the lines
// declaring f and g (embedded lines 860, 862, 863) are absent from this
// listing.
858 template <size_t Dim, class T>
859 inline void
861  const char* directionName,
864  const bool& constInput)
865 {
866  int dir = this->getDirection(directionName);
867  transform(dir, f, g, constInput);
868  return;
869 }
870 
// Out-of-class inline definition of a name-based, two-field transform
// overload. It converts directionName to an integer direction with
// getDirection() and forwards to transform(dir, f, g, constInput).
// NOTE(review): the line carrying the qualified function name and the lines
// declaring f and g (embedded lines 876, 878, 879) are absent from this
// listing.
874 template <size_t Dim, class T>
875 inline void
877  const char* directionName,
880  const bool& constInput)
881 {
882  int dir = this->getDirection(directionName);
883  transform(dir, f, g, constInput);
884  return;
885 }
886 
// Out-of-class inline definition of a name-based, single-field transform
// overload. It converts directionName to an integer direction with
// getDirection() and forwards to transform(dir, f).
// NOTE(review): the line carrying the qualified function name and the line
// declaring f (embedded lines 892, 894) are absent from this listing.
890 template <size_t Dim, class T>
891 inline void
893  const char* directionName,
895 {
896  int dir = this->getDirection(directionName);
897  transform(dir, f);
898  return;
899 }
900 
901 
// Sine-transform FFT: full specialization of the dimension argument to 1U for
// the SineTransform tag. Every member is only declared here.
// NOTE(review): this listing omits some original lines (the embedded line
// numbers skip, e.g. 910-913, 959-967), so the typedefs and the private data
// members are not visible here.
905 template <class T>
906 class FFT<SineTransform,1U,T> : public FFTBase<1U,T> {
907 
908 public:
909 
914 
    // Constructor taking the domain and a one-element flag array
    // sineTransformDims; compressTemps defaults to false.
922  FFT(const Domain_t& rdomain, const bool sineTransformDims[1U],
923  const bool& compressTemps=false);
    // Constructor taking only the domain; compressTemps defaults to false.
927  FFT(const Domain_t& rdomain, const bool& compressTemps=false);
928 
929 
    // Destructor; declared here, defined out of line.
930  ~FFT(void);
931 
    // Transform overloads taking two real fields f and g, by integer
    // direction and by name.
940  void transform(int direction, RealField_t& f, RealField_t& g,
941  const bool& constInput=false);
942  void transform(const char* directionName, RealField_t& f,
943  RealField_t& g, const bool& constInput=false);
944 
    // Single-field transform overloads on a real field f, by integer
    // direction and by name.
948  void transform(int direction, RealField_t& f);
949  void transform(const char* directionName, RealField_t& f);
950 
951 private:
952 
    // Internal setup routine; declared here, defined out of line.
957  void setup(void);
958 
963 
968 
969 };
970 
// Out-of-class inline definition of a name-based, two-field transform
// overload. It converts directionName to an integer direction with
// getDirection() and forwards to transform(dir, f, g, constInput).
// NOTE(review): the line carrying the qualified function name and the lines
// declaring f and g (embedded lines 976, 978, 979) are absent from this
// listing.
974 template <class T>
975 inline void
977  const char* directionName,
980  const bool& constInput)
981 {
982  int dir = this->getDirection(directionName);
983  transform(dir, f, g, constInput);
984  return;
985 }
986 
// Out-of-class inline definition of a name-based, single-field transform
// overload. It converts directionName to an integer direction with
// getDirection() and forwards to transform(dir, f).
// NOTE(review): the line carrying the qualified function name and the line
// declaring f (embedded lines 992, 994) are absent from this listing.
990 template <class T>
991 inline void
993  const char* directionName,
995 {
996  int dir = this->getDirection(directionName);
997  transform(dir, f);
998  return;
999 }
1000 #include "FFT/FFT.hpp"
1001 #endif // IPPL_FFT_FFT_H
1002 
1003 // vi: set et ts=4 sw=4 sts=4:
1004 // Local Variables:
1005 // mode:c
1006 // c-basic-offset: 4
1007 // indent-tabs-mode:nil
1008 // End:
std::complex< T > Complex_t
Definition: FFT.h:438
Domain_t complexDomain_m
Definition: FFT.h:531
Layout_t & getLayout() const
Definition: BareField.h:130
void setup(unsigned numTransformDims, const int *transformTypes, const int *axisLengths)
Definition: fftpack_FFT.h:153
Definition: rbendmap.h:8
std::complex< T > Complex_t
Definition: FFT.h:709
const Domain_t * complexDomain_m
Definition: FFT.h:836
std::complex< T > Complex_t
Definition: FFT.h:68
const Domain_t & getDomain(void) const
get our domain
Definition: FFTBase.h:160
std::complex< T > Complex_t
Definition: FFT.h:586
RealField_t * tempRField_m
Definition: FFT.h:525
FFT(const Domain_t &cdomain, const bool &compressTemps=false)
Definition: FFT.h:87
BareField< Complex_t, Dim > ComplexField_t
Definition: FFT.h:69
FFTBase< Dim, T >::Domain_t Domain_t
Definition: FFT.h:71
LField< Complex_t, Dim > ComplexLField_t
Definition: FFT.h:70
FieldLayout< 1U > Layout_t
Definition: FFT.h:910
Layout_t * tempRLayouts_m
Definition: FFT.h:962
Layout_t * tempLayouts_m
Definition: FFT.h:645
LField< T, Dim > RealLField_t
Definition: FFT.h:437
Layout_t ** tempLayouts_m
Definition: FFT.h:816
Definition: FFT.h:31
void Uncompress(bool fill_domain=true)
Definition: LField.h:166
FFTBase< 1U, T >::Domain_t Domain_t
Definition: FFT.h:913
BareField< Complex_t, 1U > ComplexField_t
Definition: FFT.h:325
BareField< Complex_t, 1U > ComplexField_t
Definition: FFT.h:587
void transform(const char *directionName, ComplexField_t &f)
Definition: FFT.h:158
BareField< T, Dim > RealField_t
Definition: FFT.h:436
FFTBase< 1U, T >::Domain_t Domain_t
Definition: FFT.h:327
iterator_if end_if()
Definition: BareField.h:100
unsigned numTransformDims(void) const
query number of transform dimensions
Definition: FFTBase.h:148
#define PAssert_EQ(a, b)
Definition: PAssert.h:119
ComplexField_t ** tempFields_m
Definition: FFT.h:520
FFTBase< 1U, T >::Domain_t Domain_t
Definition: FFT.h:589
ComplexField_t * tempFields_m
Definition: FFT.h:386
Layout_t ** tempLayouts_m
Definition: FFT.h:510
InternalFFT_t & getEngine(void)
access the internal FFT Engine
Definition: FFTBase.h:154
FieldLayout< 1U > Layout_t
Definition: FFT.h:323
#define INFOMSG(msg)
Definition: IpplInfo.h:397
LField< T, 1U > RealLField_t
Definition: FFT.h:585
BareField< T, 1U > RealField_t
Definition: FFT.h:911
LField< T, 1U > RealLField_t
Definition: FFT.h:912
BareField< T, Dim > RealField_t
Definition: FFT.h:707
ComplexField_t * tempFields_m
Definition: FFT.h:650
Definition: FFT.h:52
BareField< T, 1U > RealField_t
Definition: FFT.h:584
Precision_t & getNormFact(void)
get the FFT normalization factor
Definition: FFTBase.h:157
BareField< Complex_t, Dim > ComplexField_t
Definition: FFT.h:439
Definition: FFT.h:30
LField< Complex_t, Dim > ComplexLField_t
Definition: FFT.h:440
std::complex< T > Complex_t
Definition: FFT.h:324
ac_id_larray::const_iterator const_iterator_if
Definition: BareField.h:92
FFTBase< Dim, T >::Domain_t Domain_t
Definition: FFT.h:712
LField< Complex_t, Dim > ComplexLField_t
Definition: FFT.h:711
Layout_t * tempLayouts_m
Definition: FFT.h:381
FieldLayout< Dim > Layout_t
Definition: FFT.h:706
void callFFT(unsigned transformDim, int direction, Complex_t *data)
FieldLayout< 1U > Layout_t
Definition: FFT.h:583
RealField_t * tempRFields_m
Definition: FFT.h:967
const GuardCellSizes< Dim > & getGC() const
Definition: BareField.h:145
RealField_t ** tempRFields_m
Definition: FFT.h:831
FieldLayout< Dim > Layout_t
Definition: FFT.h:67
e_dim_tag getDistribution(unsigned int d) const
Definition: FieldLayout.h:396
BareField< Complex_t, Dim > ComplexField_t
Definition: FFT.h:710
ComplexField_t ** tempFields_m
Definition: FFT.h:292
LField< T, Dim > RealLField_t
Definition: FFT.h:708
Layout_t ** tempLayouts_m
Definition: FFT.h:287
bool compressTemps(void) const
do we compress temps?
Definition: FFTBase.h:166
FFTBase< Dim, T >::Domain_t Domain_t
Definition: FFT.h:441
Layout_t * tempRLayout_m
Definition: FFT.h:655
Layout_t ** tempRLayouts_m
Definition: FFT.h:821
const unsigned Dim
ComplexField_t ** tempFields_m
Definition: FFT.h:826
bool checkDomain(const Domain_t &dom1, const Domain_t &dom2) const
compare indexes of two domains
Definition: FFTBase.h:269
iterator_if begin_if()
Definition: BareField.h:99
int getDirection(const char *directionName) const
translate direction name string into dimension number
Definition: FFTBase.h:225
LField< Complex_t, 1U > ComplexLField_t
Definition: FFT.h:326
Layout_t * tempRLayout_m
Definition: FFT.h:515
T * getP()
Definition: LField.h:94
int size(unsigned d) const
Definition: LField.h:91
const NDIndex< Dim > & getDomain() const
Definition: FieldLayout.h:325
LField< Complex_t, 1U > ComplexLField_t
Definition: FFT.h:588
FieldLayout< Dim > Layout_t
Definition: FFT.h:435
Domain_t complexDomain_m
Definition: FFT.h:662
Inform & endl(Inform &inf)
Definition: Inform.cpp:42