00001 
00002 
00003 #include <cassert>
00004 #include "blas.h"
00005 
extern "C" {
// C-layout single- and double-precision complex types: a real part followed
// by an imaginary part, declared with C linkage.
// NOTE(review): neither typedef is referenced in this file; the wrappers
// below use std::complex<float> / std::complex<double> directly. Possibly
// kept for other translation units — confirm before removing.
typedef struct { float  re, im; } complex8_t;
typedef struct { double re, im; } complex16_t;
}
00010 
00011 namespace blas {
00012 
00013 using namespace colarray;
00014 
00015 
00016 extern "C" double sdot_(const int_t*, const float*, const int_t*, const float*, const int_t*);
00017 
00018 float dot(const Vector< float >& x, const Vector< float >& y) {
00019   assert(x._n == y._n);
00020   return sdot_(&x._n, x._v, &x._inc, y._v, &y._inc);
00021 }
00022 
00023 extern "C" double ddot_(const int_t*, const double*, const int_t*, const double*, const int_t*);
00024 
00025 double dot(const Vector< double >& x, const Vector< double >& y) {
00026   assert(x._n == y._n);
00027   return ddot_(&x._n, x._v, &x._inc, y._v, &y._inc);
00028 }
00029 
00030 extern "C" double snrm2_(const int_t*, const float*, const int_t*);
00031 
00032 double nrm2(const Vector< float >& x) {
00033   return snrm2_(&x._n, x._v, &x._inc);
00034 }
00035 
00036 extern "C" double dnrm2_(const int_t*, const double*, const int_t*);
00037 
00038 double nrm2(const Vector< double >& x) {
00039   return dnrm2_(&x._n, x._v, &x._inc);
00040 }
00041 
00042 extern "C" double scnrm2_(const int_t*, const std::complex<float>*, const int_t*);
00043 
00044 double nrm2(const Vector< std::complex<float> >& x) {
00045   return scnrm2_(&x._n, x._v, &x._inc);
00046 }
00047 
00048 extern "C" double dznrm2_(const int_t*, const std::complex<double>*, const int_t*);
00049 
00050 double nrm2(const Vector< std::complex<double> >& x) {
00051   return dznrm2_(&x._n, x._v, &x._inc);
00052 }
00053 
00054 extern "C"
00055 void saxpy_(const int_t*, const float*,
00056         const float*, const int_t*,
00057         float*, const int_t*);
00058 
00059 void axpy(const float alpha, const Vector< float >& x, Vector< float >& y)
00060 {
00061   assert(x._n == y._n);
00062   saxpy_(&x._n, &alpha, x._v, &x._inc, y._v, &y._inc);
00063 }
00064 
00065 extern "C"
00066 void daxpy_(const int_t*, const double*,
00067         const double*, const int_t*,
00068         double*, const int_t*);
00069 
00070 void axpy(const double alpha, const Vector< double >& x, Vector< double >& y)
00071 {
00072   assert(x._n == y._n);
00073   daxpy_(&x._n, &alpha, x._v, &x._inc, y._v, &y._inc);
00074 }
00075 
00076 extern "C"
00077 void caxpy_(const int_t*, const std::complex<float>*,
00078         const std::complex<float>*, const int_t*,
00079         std::complex<float>*, const int_t*);
00080 
00081 void axpy(const std::complex<float> alpha, const Vector< std::complex<float> >& x, Vector< std::complex<float> >& y)
00082 {
00083   assert(x._n == y._n);
00084   caxpy_(&x._n, &alpha, x._v, &x._inc, y._v, &y._inc);
00085 }
00086 
00087 extern "C"
00088 void zaxpy_(const int_t*, const std::complex<double>*,
00089         const std::complex<double>*, const int_t*,
00090         std::complex<double>*, const int_t*);
00091 
00092 void axpy(const std::complex<double> alpha, const Vector< std::complex<double> >& x, Vector< std::complex<double> >& y)
00093 {
00094   assert(x._n == y._n);
00095   zaxpy_(&x._n, &alpha, x._v, &x._inc, y._v, &y._inc);
00096 }
00097 
00098 extern "C"
00099 void sgemv_(const char*,
00100         const int_t*, const int_t*,
00101         const float*, const float*, const int_t*,
00102         const float*, const int_t*,
00103         const float*, float*, const int_t*,
00104         const int_t);
00105 
00106 
00107 void gemv(char trans,
00108           const Matrix< float >& A, 
00109           const Vector< float >& x, 
00110           Vector< float >& y, 
00111           float alpha, float beta)
00112     
00113 {
00114   if (tolower(trans) == 'n') {
00115     assert(A._m == x._n); 
00116     assert(A._n == y._n);
00117   } else {
00118     assert(A._n == x._n); 
00119     assert(A._m == y._n);
00120   }
00121 
00122   sgemv_(&trans, 
00123      &A._m, &A._n, 
00124      &alpha, A._v, &A._ld, 
00125      x._v, &x._inc, 
00126      &beta, y._v, &y._inc, 1); 
00127 }
00128 
00129 extern "C"
00130 void dgemv_(const char*,
00131         const int_t*, const int_t*,
00132         const double*, const double*, const int_t*,
00133         const double*, const int_t*,
00134         const double*, double*, const int_t*,
00135         const int_t);
00136 
00137 
00138 void gemv(char trans,
00139           const Matrix< double >& A, 
00140           const Vector< double >& x, 
00141           Vector< double >& y, 
00142           double alpha, double beta)
00143     
00144 {
00145   if (tolower(trans) == 'n') {
00146     assert(A._m == x._n); 
00147     assert(A._n == y._n);
00148   } else {
00149     assert(A._n == x._n); 
00150     assert(A._m == y._n);
00151   }
00152 
00153   dgemv_(&trans, 
00154      &A._m, &A._n, 
00155      &alpha, A._v, &A._ld, 
00156      x._v, &x._inc, 
00157      &beta, y._v, &y._inc, 1); 
00158 }
00159 
00160 extern "C"
00161 void cgemv_(const char*,
00162         const int_t*, const int_t*,
00163         const std::complex<float>*, const std::complex<float>*, const int_t*,
00164         const std::complex<float>*, const int_t*,
00165         const std::complex<float>*, std::complex<float>*, const int_t*,
00166         const int_t);
00167 
00168 
00169 void gemv(char trans,
00170           const Matrix< std::complex<float> >& A, 
00171           const Vector< std::complex<float> >& x, 
00172           Vector< std::complex<float> >& y, 
00173           std::complex<float> alpha, std::complex<float> beta)
00174     
00175 {
00176   if (tolower(trans) == 'n') {
00177     assert(A._m == x._n); 
00178     assert(A._n == y._n);
00179   } else {
00180     assert(A._n == x._n); 
00181     assert(A._m == y._n);
00182   }
00183 
00184   cgemv_(&trans, 
00185      &A._m, &A._n, 
00186      &alpha, A._v, &A._ld, 
00187      x._v, &x._inc, 
00188      &beta, y._v, &y._inc, 1); 
00189 }
00190 
00191 extern "C"
00192 void zgemv_(const char*,
00193         const int_t*, const int_t*,
00194         const std::complex<double>*, const std::complex<double>*, const int_t*,
00195         const std::complex<double>*, const int_t*,
00196         const std::complex<double>*, std::complex<double>*, const int_t*,
00197         const int_t);
00198 
00199 
00200 void gemv(char trans,
00201           const Matrix< std::complex<double> >& A, 
00202           const Vector< std::complex<double> >& x, 
00203           Vector< std::complex<double> >& y, 
00204           std::complex<double> alpha, std::complex<double> beta)
00205     
00206 {
00207   if (tolower(trans) == 'n') {
00208     assert(A._m == x._n); 
00209     assert(A._n == y._n);
00210   } else {
00211     assert(A._n == x._n); 
00212     assert(A._m == y._n);
00213   }
00214 
00215   zgemv_(&trans, 
00216      &A._m, &A._n, 
00217      &alpha, A._v, &A._ld, 
00218      x._v, &x._inc, 
00219      &beta, y._v, &y._inc, 1); 
00220 }
00221 
00222 extern "C" void sgemm_(
00223   const char*, const char*,
00224   const int_t*, const int_t*, const int_t*,
00225   const float*, const float*, const int_t*,
00226   const float*, const int_t*,
00227   const float*, float*, const int_t*,
00228   const int_t, const int_t);
00229     
00230 
00231 void gemm(char transA, char transB,
00232           const Matrix< float >& A,
00233           const Matrix< float >& B,
00234           Matrix< float >& C,
00235           float alpha, float beta)
00236 {
00237   int_t m, n, k1, k2;
00238   if (tolower(transA) == 'n') {
00239     m = A._m;
00240     k1 = A._n;
00241   } else {
00242     m = A._n;
00243     k1 = A._m;
00244   }
00245   if (tolower(transB) == 'n') {
00246     n = B._n;
00247     k2 = B._m;
00248   } else {
00249     n = B._m;
00250     k2 = B._n;
00251   }
00252   
00253   assert(m = C._m);
00254   assert(n = C._n);
00255   assert(k1 = k2);
00256     
00257   sgemm_(
00258       &transA, &transB,
00259       &m, &n, &k1,
00260       &alpha, A._v, &A._ld,
00261       B._v, &B._ld,
00262       &beta, C._v, &C._ld, 1, 1);
00263 }
00264 
00265 extern "C" void dgemm_(
00266   const char*, const char*,
00267   const int_t*, const int_t*, const int_t*,
00268   const double*, const double*, const int_t*,
00269   const double*, const int_t*,
00270   const double*, double*, const int_t*,
00271   const int_t, const int_t);
00272     
00273 
00274 void gemm(char transA, char transB,
00275           const Matrix< double >& A,
00276           const Matrix< double >& B,
00277           Matrix< double >& C,
00278           double alpha, double beta)
00279 {
00280   int_t m, n, k1, k2;
00281   if (tolower(transA) == 'n') {
00282     m = A._m;
00283     k1 = A._n;
00284   } else {
00285     m = A._n;
00286     k1 = A._m;
00287   }
00288   if (tolower(transB) == 'n') {
00289     n = B._n;
00290     k2 = B._m;
00291   } else {
00292     n = B._m;
00293     k2 = B._n;
00294   }
00295   
00296   assert(m = C._m);
00297   assert(n = C._n);
00298   assert(k1 = k2);
00299     
00300   dgemm_(
00301       &transA, &transB,
00302       &m, &n, &k1,
00303       &alpha, A._v, &A._ld,
00304       B._v, &B._ld,
00305       &beta, C._v, &C._ld, 1, 1);
00306 }
00307 
00308 extern "C" void cgemm_(
00309   const char*, const char*,
00310   const int_t*, const int_t*, const int_t*,
00311   const std::complex<float>*, const std::complex<float>*, const int_t*,
00312   const std::complex<float>*, const int_t*,
00313   const std::complex<float>*, std::complex<float>*, const int_t*,
00314   const int_t, const int_t);
00315     
00316 
00317 void gemm(char transA, char transB,
00318           const Matrix< std::complex<float> >& A,
00319           const Matrix< std::complex<float> >& B,
00320           Matrix< std::complex<float> >& C,
00321           std::complex<float> alpha, std::complex<float> beta)
00322 {
00323   int_t m, n, k1, k2;
00324   if (tolower(transA) == 'n') {
00325     m = A._m;
00326     k1 = A._n;
00327   } else {
00328     m = A._n;
00329     k1 = A._m;
00330   }
00331   if (tolower(transB) == 'n') {
00332     n = B._n;
00333     k2 = B._m;
00334   } else {
00335     n = B._m;
00336     k2 = B._n;
00337   }
00338   
00339   assert(m = C._m);
00340   assert(n = C._n);
00341   assert(k1 = k2);
00342     
00343   cgemm_(
00344       &transA, &transB,
00345       &m, &n, &k1,
00346       &alpha, A._v, &A._ld,
00347       B._v, &B._ld,
00348       &beta, C._v, &C._ld, 1, 1);
00349 }
00350 
00351 extern "C" void zgemm_(
00352   const char*, const char*,
00353   const int_t*, const int_t*, const int_t*,
00354   const std::complex<double>*, const std::complex<double>*, const int_t*,
00355   const std::complex<double>*, const int_t*,
00356   const std::complex<double>*, std::complex<double>*, const int_t*,
00357   const int_t, const int_t);
00358     
00359 
00360 void gemm(char transA, char transB,
00361           const Matrix< std::complex<double> >& A,
00362           const Matrix< std::complex<double> >& B,
00363           Matrix< std::complex<double> >& C,
00364           std::complex<double> alpha, std::complex<double> beta)
00365 {
00366   int_t m, n, k1, k2;
00367   if (tolower(transA) == 'n') {
00368     m = A._m;
00369     k1 = A._n;
00370   } else {
00371     m = A._n;
00372     k1 = A._m;
00373   }
00374   if (tolower(transB) == 'n') {
00375     n = B._n;
00376     k2 = B._m;
00377   } else {
00378     n = B._m;
00379     k2 = B._n;
00380   }
00381   
00382   assert(m = C._m);
00383   assert(n = C._n);
00384   assert(k1 = k2);
00385     
00386   zgemm_(
00387       &transA, &transB,
00388       &m, &n, &k1,
00389       &alpha, A._v, &A._ld,
00390       B._v, &B._ld,
00391       &beta, C._v, &C._ld, 1, 1);
00392 }
00393 
00394 extern "C" void sgesv_(
00395   const int_t* n, const int_t* nrhs, float a[], const int_t* lda,
00396   int_t ipiv[],
00397   float b[], const int_t* ldb, int_t* info);
00398 
00399 
00400 void gesv(const Matrix< float >& A, Matrix< float >& B, Vector< int_t >& ipiv, int_t& info)
00401 {
00402   assert(A._m == A._n);
00403   assert(A._m == B._m);
00404 
00405   sgesv_(&A._m, &B._n, A._v, &A._ld, ipiv._v, B._v, &B._ld, &info);
00406 }
00407 
00408 extern "C" void dgesv_(
00409   const int_t* n, const int_t* nrhs, double a[], const int_t* lda,
00410   int_t ipiv[],
00411   double b[], const int_t* ldb, int_t* info);
00412 
00413 
00414 void gesv(const Matrix< double >& A, Matrix< double >& B, Vector< int_t >& ipiv, int_t& info)
00415 {
00416   assert(A._m == A._n);
00417   assert(A._m == B._m);
00418 
00419   dgesv_(&A._m, &B._n, A._v, &A._ld, ipiv._v, B._v, &B._ld, &info);
00420 }
00421 
00422 extern "C" void cgesv_(
00423   const int_t* n, const int_t* nrhs, std::complex<float> a[], const int_t* lda,
00424   int_t ipiv[],
00425   std::complex<float> b[], const int_t* ldb, int_t* info);
00426 
00427 
00428 void gesv(const Matrix< std::complex<float> >& A, Matrix< std::complex<float> >& B, Vector< int_t >& ipiv, int_t& info)
00429 {
00430   assert(A._m == A._n);
00431   assert(A._m == B._m);
00432 
00433   cgesv_(&A._m, &B._n, A._v, &A._ld, ipiv._v, B._v, &B._ld, &info);
00434 }
00435 
00436 extern "C" void zgesv_(
00437   const int_t* n, const int_t* nrhs, std::complex<double> a[], const int_t* lda,
00438   int_t ipiv[],
00439   std::complex<double> b[], const int_t* ldb, int_t* info);
00440 
00441 
00442 void gesv(const Matrix< std::complex<double> >& A, Matrix< std::complex<double> >& B, Vector< int_t >& ipiv, int_t& info)
00443 {
00444   assert(A._m == A._n);
00445   assert(A._m == B._m);
00446 
00447   zgesv_(&A._m, &B._n, A._v, &A._ld, ipiv._v, B._v, &B._ld, &info);
00448 }
00449 
00450 extern "C" void sgetrf_(
00451   const int_t* m, const int_t* n, float a[], const int_t* lda,
00452   int_t ipiv[],
00453   int_t* info);
00454 
00455 
00456 void getrf(const Matrix< float >& A, Vector< int_t >& ipiv, int_t& info)
00457 {
00458   assert(A._m == A._n);
00459   assert(A._m == ipiv._n);
00460   
00461   sgetrf_(&A._m, &A._n, A._v, &A._ld, ipiv._v, &info);
00462 }
00463 
00464 extern "C" void dgetrf_(
00465   const int_t* m, const int_t* n, double a[], const int_t* lda,
00466   int_t ipiv[],
00467   int_t* info);
00468 
00469 
00470 void getrf(const Matrix< double >& A, Vector< int_t >& ipiv, int_t& info)
00471 {
00472   assert(A._m == A._n);
00473   assert(A._m == ipiv._n);
00474   
00475   dgetrf_(&A._m, &A._n, A._v, &A._ld, ipiv._v, &info);
00476 }
00477 
00478 extern "C" void cgetrf_(
00479   const int_t* m, const int_t* n, std::complex<float> a[], const int_t* lda,
00480   int_t ipiv[],
00481   int_t* info);
00482 
00483 
00484 void getrf(const Matrix< std::complex<float> >& A, Vector< int_t >& ipiv, int_t& info)
00485 {
00486   assert(A._m == A._n);
00487   assert(A._m == ipiv._n);
00488   
00489   cgetrf_(&A._m, &A._n, A._v, &A._ld, ipiv._v, &info);
00490 }
00491 
00492 extern "C" void zgetrf_(
00493   const int_t* m, const int_t* n, std::complex<double> a[], const int_t* lda,
00494   int_t ipiv[],
00495   int_t* info);
00496 
00497 
00498 void getrf(const Matrix< std::complex<double> >& A, Vector< int_t >& ipiv, int_t& info)
00499 {
00500   assert(A._m == A._n);
00501   assert(A._m == ipiv._n);
00502   
00503   zgetrf_(&A._m, &A._n, A._v, &A._ld, ipiv._v, &info);
00504 }
00505 
00506 extern "C" void ssyev_(
00507   char* jobz, char* uplo,
00508   int_t* n, float a[], int_t* lda,
00509   float w[],
00510   float work[], int_t* lwork, int_t* info,
00511   int_t len_jobz, int_t len_uplo);
00512 
00513 
00514 void syev(char jobz, char uplo, Matrix< float >& A, ColumnVector< float >&w, int_t& info)
00515 {
00516   assert(A._m == A._n);
00517   int_t nb = 8;
00518   int_t lwork = (nb + 2)*A._n;
00519   Vector< float > work(lwork);
00520 
00521   ssyev_(&jobz, &uplo,
00522                &A._m, A._v, &A._ld,
00523                w._v,
00524                work._v, &lwork, &info,
00525                1, 1);
00526 }
00527 
00528 extern "C" void dsyev_(
00529   char* jobz, char* uplo,
00530   int_t* n, double a[], int_t* lda,
00531   double w[],
00532   double work[], int_t* lwork, int_t* info,
00533   int_t len_jobz, int_t len_uplo);
00534 
00535 
00536 void syev(char jobz, char uplo, Matrix< double >& A, ColumnVector< double >&w, int_t& info)
00537 {
00538   assert(A._m == A._n);
00539   int_t nb = 8;
00540   int_t lwork = (nb + 2)*A._n;
00541   Vector< double > work(lwork);
00542 
00543   dsyev_(&jobz, &uplo,
00544                &A._m, A._v, &A._ld,
00545                w._v,
00546                work._v, &lwork, &info,
00547                1, 1);
00548 }
00549 
00550 }