src/Index/Index.cpp

Go to the documentation of this file.
00001 // -*- C++ -*-
00002 /***************************************************************************
00003  *
00004  * The IPPL Framework
00005  * 
00006  * This program was prepared by PSI. 
00007  * All rights in the program are reserved by PSI.
00008  * Neither PSI nor the author(s)
00009  * makes any warranty, express or implied, or assumes any liability or
00010  * responsibility for the use of this software
00011  *
00012  * Visit http://www.acl.lanl.gov/POOMS for more details
00013  *
00014  ***************************************************************************/
00015 
00016 // -*- C++ -*-
00017 /***************************************************************************
00018  *
00019  * The IPPL Framework
00020  * 
00021  *
00022  * Visit http://people.web.psi.ch/adelmann/ for more details
00023  *
00024  ***************************************************************************/
00025 
00027 // Major functions and test code for Index.
00028 // See main below for examples of use.
00030 
00031 // include files
00032 #include "Index/Index.h"
00033 #include "Utility/PAssert.h"
00034 #include "Profile/Profiler.h"
00035 
00037 
00038 ostream& operator<<(ostream& out, const Index& I) {
00039   TAU_PROFILE("operator<<()", "ostream (ostream, Index)", TAU_SPARSE | TAU_IO);
00040   out << '[' << I.first() << ':' << I.last() << ':' << I.stride() << ']';
00041   return out;
00042 }
00043 
00044 
00046 // Calculate the least common multipple of s1 and s2.
00047 // put the result in s.
00048 // also calculate m1 = s/s1 and m2 = s/s2.
00049 // This version is optimized for small s1 and s2 and 
00050 // just uses an exhaustive search.
00052 void lcm(int s1, int s2, int &s, int &m1, int &m2)
00053 {
00054   TAU_PROFILE("lcm()", "void (int, int, int, int, int)", TAU_SPARSE);
00055   PAssert(s1>0);   // For simplicity, make some assumptions.
00056   PAssert(s2>0);
00057   int i1=s1;
00058   int i2=s2;
00059   int _m1 = 1;
00060   int _m2 = 1;
00061   if (i2<i1)
00062     while(true)
00063       {
00064         while (i2<i1)
00065           {
00066             i2 += s2;
00067             ++_m2;
00068           }
00069         if (i1==i2)
00070           {
00071             m1 = _m1;
00072             m2 = _m2;
00073             s  = i1;
00074             return;
00075           }
00076         i1 += s1;
00077         ++_m1;
00078       }
00079   else
00080     while(true)
00081       {
00082         while (i1<i2)
00083           {
00084             i1 += s1;
00085             ++_m1;
00086           }
00087         if (i1==i2)
00088           {
00089             m1 = _m1;
00090             m2 = _m2;
00091             s  = i1;
00092             return;
00093           }
00094         i2 += s2;
00095         ++_m2;
00096       }
00097 }
00098 
00100 
00101 //
00102 // Intersect, with the code for the common case of
00103 // both strides equal to one.
00104 //
00105 
00106 Index
00107 Index::intersect(const Index& rhs) const
00108 {
00109   Index ret = DontInitialize() ;
00110   if ( (stride()==1) && (rhs.stride()==1) ) {
00111     int lf = first();
00112     int rf = rhs.first();
00113     int ll = last();
00114     int rl = rhs.last();
00115     int f = lf > rf ? lf : rf;
00116     int l = ll < rl ? ll : rl;
00117     ret.First = f;
00118     ret.Length = ( (l>=f) ? l-f+1 : 0 );
00119     ret.Stride = 1;
00120     ret.BaseFirst = BaseFirst + f - lf;
00121     ret.Base = Base;
00122 
00123 #ifdef UNDEFINED
00124     Index test = general_intersect(rhs);
00125     cout << "inter:  First  Length  Stride  BaseFirst  Base " << endl;
00126     cout << "*this= " << First << "," << Length << "," << Stride << "," << BaseFirst << "," << Base << endl;
00127     cout << "rhs  = " << rhs.First << "," << rhs.Length << "," << rhs.Stride << "," << rhs.BaseFirst << "," << rhs.Base << endl;
00128     cout << "ret  = " << ret.First << "," << ret.Length << "," << ret.Stride << "," << ret.BaseFirst << "," << ret.Base << endl;
00129     cout << "test = " << test.First << "," << test.Length << "," << test.Stride << "," << test.BaseFirst << "," << test.Base << endl;
00130     PAssert( ret.Length    == test.Length );
00131     if ( ret.Length > 0 ) {
00132       PAssert( ret.First     == test.First );
00133       PAssert( ret.Stride    == test.Stride );
00134       PAssert( ret.BaseFirst == test.BaseFirst );
00135       PAssert( ret.Base      == test.Base );
00136     }
00137 #endif // UNDEFINED
00138 
00139   }
00140   else
00141     ret = general_intersect(rhs);
00142   return ret;
00143 }
00144 
00146 
00147 static Index do_intersect(const Index &a, const Index &b)
00148 {
00149   TAU_PROFILE("do_intersect()", "Index (Index, Index)", TAU_SPARSE);
00150   PAssert(a.stride()>0);                // This should be assured by the
00151   PAssert(b.stride()>0);                // caller of this function.
00152 
00153   int newStride;                // The stride for the new index is
00154   int a_mul,b_mul;              // a_mul=newStride/a.stride() ...
00155   lcm(a.stride(),b.stride(),    // The input strides...
00156       newStride,a_mul,b_mul);   // the lcm of the strides of a and b.
00157   
00158   // Find the offset from a.first() in units of newStride
00159   // that puts the ranges close together.
00160   int a_i = (b.first()-a.first())/a.stride();
00161   int a_off = a.first() + a_i*a.stride();
00162   if (a_off < b.first())
00163     {
00164       a_i++;
00165       a_off += a.stride();
00166     }
00167   PAssert(a_off >= b.first());  // make sure I'm understanding this right...
00168 
00169   // Now do an exhaustive search for the first point in common.
00170   // Count over all possible offsets for a.
00171   for (int a_m=0;(a_m<a_mul)&&(a_i<a.length());a_m++,a_i++,a_off+=a.stride())
00172     {
00173       int b_off = b.first();
00174       // Count over all possible offsets for b.
00175       for (int b_m=0; (b_m<b_mul)&&(b_m<b.length()); b_m++,b_off+=b.stride())
00176         if ( a_off == b_off )
00177           {     // If the offsets are the same, we found it!
00178             int am = a.max();   // Find the minimum maximum of a and b...
00179             int bm = b.max();
00180             int m = am < bm ? am : bm;
00181             return Index(a_off, m, newStride);
00182           }
00183     }
00184   return Index(0);              // If we get to here there is no intersection.
00185 }
00186 
00188 
00189 Index Index::general_intersect(const Index& that) const
00190 {
00191   TAU_PROFILE("Index::general_intersect()", "Index (Index)", TAU_SPARSE);
00192   // If they just don't overlap, return null indexes.
00193   if ( (min() > that.max()) || (that.min() > max()) )
00194     return Index(0);
00195   if ( (Stride==0) || (that.Stride==0) )
00196     return Index(0);
00197 
00198   // If one or the other counts -ve, reverse it and intersect result.
00199   if ( that.Stride < 0 )
00200     return intersect(that.reverse());
00201   if ( Stride < 0 )
00202     {
00203       Index r;
00204       r = reverse().intersect(that).reverse();
00205       int diff = (r.First-First)/Stride;
00206       PAssert(diff>=0);
00207       r.BaseFirst = BaseFirst + diff;
00208       return r;
00209     }
00210 
00211   // Getting closer to the real thing: intersect them.
00212   // Pass the one that starts lower as the first argument.
00213   Index r;
00214   if ( First < that.First )
00215     r = do_intersect(*this,that);
00216   else
00217     r = do_intersect(that,*this);
00218 
00219   // Set the base so you can find what parts correspond
00220   // to the original interval.
00221   r.Base = Base;
00222   int diff = (r.First - First)/Stride;
00223   PAssert(diff>=0);
00224   r.BaseFirst = BaseFirst + diff;
00225   return r;
00226 }
00227 
00229 
00230 #ifdef DEBUG_INDEX
00231 int main()
00232 {
00233   TAU_PROFILE("main()", "int ()", TAU_DEFAULT);
00234   const int N  = 16;            // Number of grid points.
00235   const int NP = 4;             // Number of processors.
00236   const int NL = N/NP;          // Grid points per processor.
00237   int p;                        // processor counter.
00238 
00239   Index Ranges[NP];             // an index for each processor.
00240   for (p=0;p<NP;p++)            // On each processor
00241     Ranges[p] = Index(p*NL,(p+1)*NL-1); // Set the local range
00242 
00243   for (p=0;p<NP;p++)            // On each processor
00244     cout << Ranges[p] << endl;
00245 
00246   // work out A[Dest] = B[2*Dest];
00247   // Dest = [0...N/2-1]
00248   // Index Dest(N/2);
00249   // Index Src = 2*Dest;
00250 
00251   // Also try this:
00252   // Index Dest(N);
00253   // Index Src = N-1-Dest;
00254 
00255   // and this
00256   Index Dest(N);
00257   Index Src = Dest - 1;
00258 
00259   // another
00260   // Index Dest(0,N/2,2);
00261   // Index Src = Dest/2;
00262 
00263   // yet another
00264   // Index Dest = N-1-2*Index(N/2);
00265   // Index Src = N-1-Dest;
00266 
00267   cout << "Dest=" << Dest << endl;
00268   cout << "Src =" << Src  << endl;
00269 
00270   // Find out the gets from each processor for that operation.
00271   for (p=0; p<NP; p++)
00272     {
00273       cout << "On vp=" << p << ", range=" << Ranges[p] << endl;
00274 
00275       // Calculate what gets will be done.
00276       Index LDRange = Dest.intersect(Ranges[p]); // Local Destination Range for p
00277       Index SDRange = Src.plugBase(LDRange);     // Where that comes from.
00278       cout << "LDRange = " << LDRange << endl;
00279       cout << "SDRange = " << SDRange << endl;
00280       for (int pp=0; pp<NP; pp++)
00281         {              // Get from pp
00282           Index LSDRange = SDRange.intersect(Ranges[pp]); // what comes from pp
00283           if (!LSDRange.empty())
00284             {
00285               cout << "    from proc=" << pp << ", receive " << LSDRange << endl;
00286             }
00287         }
00288 
00289       // Calculate the puts.
00290       Index LSRange = Src.intersect(Ranges[p]);
00291       Index DSRange = Dest.plugBase(LSRange);    // The destination for that.
00292       cout << "LSRange = " << LSRange << endl;
00293       cout << "DSRange = " << DSRange << endl;
00294       for (pp=0; pp<NP; pp++)
00295         {                      // Put to pp
00296           Index LDSRange = LSRange.plugBase(DSRange.intersect(Ranges[pp]));
00297           if (!LDSRange.empty())
00298             {
00299               cout << "    send to pp=" << pp << ", the range=" << LDSRange << endl;
00300             }
00301         }
00302     }
00303 }
00304 
00305 #endif // DEBUG_INDEX
00306 
00307 /***************************************************************************
00308  * $RCSfile: Index.cpp,v $   $Author: adelmann $
00309  * $Revision: 1.1.1.1 $   $Date: 2003/01/23 07:40:27 $
00310  * IPPL_VERSION_ID: $Id: Index.cpp,v 1.1.1.1 2003/01/23 07:40:27 adelmann Exp $ 
00311  ***************************************************************************/

Generated on Mon Jan 16 13:23:48 2006 for IPPL by  doxygen 1.4.6