00001
00002
00003 #include <cassert>
00004 #include "blas.h"
00005
extern "C" {

/* Single-precision complex value: a (re, im) pair of floats. */
typedef struct {
    float re;
    float im;
} complex8_t;

/* Double-precision complex value: a (re, im) pair of doubles. */
typedef struct {
    double re;
    double im;
} complex16_t;

}
00010
00011 namespace blas {
00012
00013 using namespace colarray;
00014
00015
00016 extern "C" float sdot(const int_t*, const float*, const int_t*, const float*, const int_t*);
00017
00018 float dot(const Vector< float >& x, const Vector< float >& y) {
00019 assert(x._n == y._n);
00020 return sdot(&x._n, x._v, &x._inc, y._v, &y._inc);
00021 }
00022
00023 extern "C" double ddot(const int_t*, const double*, const int_t*, const double*, const int_t*);
00024
00025 double dot(const Vector< double >& x, const Vector< double >& y) {
00026 assert(x._n == y._n);
00027 return ddot(&x._n, x._v, &x._inc, y._v, &y._inc);
00028 }
00029
00030 extern "C" float snrm2(const int_t*, const float*, const int_t*);
00031
00032 float nrm2(const Vector< float >& x) {
00033 return snrm2(&x._n, x._v, &x._inc);
00034 }
00035
00036 extern "C" double dnrm2(const int_t*, const double*, const int_t*);
00037
00038 double nrm2(const Vector< double >& x) {
00039 return dnrm2(&x._n, x._v, &x._inc);
00040 }
00041
00042 extern "C" float scnrm2(const int_t*, const std::complex<float>*, const int_t*);
00043
00044 float nrm2(const Vector< std::complex<float> >& x) {
00045 return scnrm2(&x._n, x._v, &x._inc);
00046 }
00047
00048 extern "C" double dznrm2(const int_t*, const std::complex<double>*, const int_t*);
00049
00050 double nrm2(const Vector< std::complex<double> >& x) {
00051 return dznrm2(&x._n, x._v, &x._inc);
00052 }
00053
00054 extern "C"
00055 void saxpy(const int_t*, const float*,
00056 const float*, const int_t*,
00057 float*, const int_t*);
00058
00059 void axpy(const float alpha, const Vector< float >& x, Vector< float >& y)
00060 {
00061 assert(x._n == y._n);
00062 saxpy(&x._n, &alpha, x._v, &x._inc, y._v, &y._inc);
00063 }
00064
00065 extern "C"
00066 void daxpy(const int_t*, const double*,
00067 const double*, const int_t*,
00068 double*, const int_t*);
00069
00070 void axpy(const double alpha, const Vector< double >& x, Vector< double >& y)
00071 {
00072 assert(x._n == y._n);
00073 daxpy(&x._n, &alpha, x._v, &x._inc, y._v, &y._inc);
00074 }
00075
00076 extern "C"
00077 void caxpy(const int_t*, const std::complex<float>*,
00078 const std::complex<float>*, const int_t*,
00079 std::complex<float>*, const int_t*);
00080
00081 void axpy(const std::complex<float> alpha, const Vector< std::complex<float> >& x, Vector< std::complex<float> >& y)
00082 {
00083 assert(x._n == y._n);
00084 caxpy(&x._n, &alpha, x._v, &x._inc, y._v, &y._inc);
00085 }
00086
00087 extern "C"
00088 void zaxpy(const int_t*, const std::complex<double>*,
00089 const std::complex<double>*, const int_t*,
00090 std::complex<double>*, const int_t*);
00091
00092 void axpy(const std::complex<double> alpha, const Vector< std::complex<double> >& x, Vector< std::complex<double> >& y)
00093 {
00094 assert(x._n == y._n);
00095 zaxpy(&x._n, &alpha, x._v, &x._inc, y._v, &y._inc);
00096 }
00097
00098 extern "C"
00099 void sgemv(const char*,
00100 const int_t*, const int_t*,
00101 const float*, const float*, const int_t*,
00102 const float*, const int_t*,
00103 const float*, float*, const int_t*,
00104 const int_t);
00105
00106
00107 void gemv(char trans,
00108 const Matrix< float >& A,
00109 const Vector< float >& x,
00110 Vector< float >& y,
00111 float alpha, float beta)
00112
00113 {
00114 if (tolower(trans) == 'n') {
00115 assert(A._m == x._n);
00116 assert(A._n == y._n);
00117 } else {
00118 assert(A._n == x._n);
00119 assert(A._m == y._n);
00120 }
00121
00122 sgemv(&trans,
00123 &A._m, &A._n,
00124 &alpha, A._v, &A._ld,
00125 x._v, &x._inc,
00126 &beta, y._v, &y._inc, 1);
00127 }
00128
00129 extern "C"
00130 void dgemv(const char*,
00131 const int_t*, const int_t*,
00132 const double*, const double*, const int_t*,
00133 const double*, const int_t*,
00134 const double*, double*, const int_t*,
00135 const int_t);
00136
00137
00138 void gemv(char trans,
00139 const Matrix< double >& A,
00140 const Vector< double >& x,
00141 Vector< double >& y,
00142 double alpha, double beta)
00143
00144 {
00145 if (tolower(trans) == 'n') {
00146 assert(A._m == x._n);
00147 assert(A._n == y._n);
00148 } else {
00149 assert(A._n == x._n);
00150 assert(A._m == y._n);
00151 }
00152
00153 dgemv(&trans,
00154 &A._m, &A._n,
00155 &alpha, A._v, &A._ld,
00156 x._v, &x._inc,
00157 &beta, y._v, &y._inc, 1);
00158 }
00159
00160 extern "C"
00161 void cgemv(const char*,
00162 const int_t*, const int_t*,
00163 const std::complex<float>*, const std::complex<float>*, const int_t*,
00164 const std::complex<float>*, const int_t*,
00165 const std::complex<float>*, std::complex<float>*, const int_t*,
00166 const int_t);
00167
00168
00169 void gemv(char trans,
00170 const Matrix< std::complex<float> >& A,
00171 const Vector< std::complex<float> >& x,
00172 Vector< std::complex<float> >& y,
00173 std::complex<float> alpha, std::complex<float> beta)
00174
00175 {
00176 if (tolower(trans) == 'n') {
00177 assert(A._m == x._n);
00178 assert(A._n == y._n);
00179 } else {
00180 assert(A._n == x._n);
00181 assert(A._m == y._n);
00182 }
00183
00184 cgemv(&trans,
00185 &A._m, &A._n,
00186 &alpha, A._v, &A._ld,
00187 x._v, &x._inc,
00188 &beta, y._v, &y._inc, 1);
00189 }
00190
00191 extern "C"
00192 void zgemv(const char*,
00193 const int_t*, const int_t*,
00194 const std::complex<double>*, const std::complex<double>*, const int_t*,
00195 const std::complex<double>*, const int_t*,
00196 const std::complex<double>*, std::complex<double>*, const int_t*,
00197 const int_t);
00198
00199
00200 void gemv(char trans,
00201 const Matrix< std::complex<double> >& A,
00202 const Vector< std::complex<double> >& x,
00203 Vector< std::complex<double> >& y,
00204 std::complex<double> alpha, std::complex<double> beta)
00205
00206 {
00207 if (tolower(trans) == 'n') {
00208 assert(A._m == x._n);
00209 assert(A._n == y._n);
00210 } else {
00211 assert(A._n == x._n);
00212 assert(A._m == y._n);
00213 }
00214
00215 zgemv(&trans,
00216 &A._m, &A._n,
00217 &alpha, A._v, &A._ld,
00218 x._v, &x._inc,
00219 &beta, y._v, &y._inc, 1);
00220 }
00221
00222 extern "C" void sgemm(
00223 const char*, const char*,
00224 const int_t*, const int_t*, const int_t*,
00225 const float*, const float*, const int_t*,
00226 const float*, const int_t*,
00227 const float*, float*, const int_t*,
00228 const int_t, const int_t);
00229
00230
00231 void gemm(char transA, char transB,
00232 const Matrix< float >& A,
00233 const Matrix< float >& B,
00234 Matrix< float >& C,
00235 float alpha, float beta)
00236 {
00237 int_t m, n, k1, k2;
00238 if (tolower(transA) == 'n') {
00239 m = A._m;
00240 k1 = A._n;
00241 } else {
00242 m = A._n;
00243 k1 = A._m;
00244 }
00245 if (tolower(transB) == 'n') {
00246 n = B._n;
00247 k2 = B._m;
00248 } else {
00249 n = B._m;
00250 k2 = B._n;
00251 }
00252
00253 assert(m = C._m);
00254 assert(n = C._n);
00255 assert(k1 = k2);
00256
00257 sgemm(
00258 &transA, &transB,
00259 &m, &n, &k1,
00260 &alpha, A._v, &A._ld,
00261 B._v, &B._ld,
00262 &beta, C._v, &C._ld, 1, 1);
00263 }
00264
00265 extern "C" void dgemm(
00266 const char*, const char*,
00267 const int_t*, const int_t*, const int_t*,
00268 const double*, const double*, const int_t*,
00269 const double*, const int_t*,
00270 const double*, double*, const int_t*,
00271 const int_t, const int_t);
00272
00273
00274 void gemm(char transA, char transB,
00275 const Matrix< double >& A,
00276 const Matrix< double >& B,
00277 Matrix< double >& C,
00278 double alpha, double beta)
00279 {
00280 int_t m, n, k1, k2;
00281 if (tolower(transA) == 'n') {
00282 m = A._m;
00283 k1 = A._n;
00284 } else {
00285 m = A._n;
00286 k1 = A._m;
00287 }
00288 if (tolower(transB) == 'n') {
00289 n = B._n;
00290 k2 = B._m;
00291 } else {
00292 n = B._m;
00293 k2 = B._n;
00294 }
00295
00296 assert(m = C._m);
00297 assert(n = C._n);
00298 assert(k1 = k2);
00299
00300 dgemm(
00301 &transA, &transB,
00302 &m, &n, &k1,
00303 &alpha, A._v, &A._ld,
00304 B._v, &B._ld,
00305 &beta, C._v, &C._ld, 1, 1);
00306 }
00307
00308 extern "C" void cgemm(
00309 const char*, const char*,
00310 const int_t*, const int_t*, const int_t*,
00311 const std::complex<float>*, const std::complex<float>*, const int_t*,
00312 const std::complex<float>*, const int_t*,
00313 const std::complex<float>*, std::complex<float>*, const int_t*,
00314 const int_t, const int_t);
00315
00316
00317 void gemm(char transA, char transB,
00318 const Matrix< std::complex<float> >& A,
00319 const Matrix< std::complex<float> >& B,
00320 Matrix< std::complex<float> >& C,
00321 std::complex<float> alpha, std::complex<float> beta)
00322 {
00323 int_t m, n, k1, k2;
00324 if (tolower(transA) == 'n') {
00325 m = A._m;
00326 k1 = A._n;
00327 } else {
00328 m = A._n;
00329 k1 = A._m;
00330 }
00331 if (tolower(transB) == 'n') {
00332 n = B._n;
00333 k2 = B._m;
00334 } else {
00335 n = B._m;
00336 k2 = B._n;
00337 }
00338
00339 assert(m = C._m);
00340 assert(n = C._n);
00341 assert(k1 = k2);
00342
00343 cgemm(
00344 &transA, &transB,
00345 &m, &n, &k1,
00346 &alpha, A._v, &A._ld,
00347 B._v, &B._ld,
00348 &beta, C._v, &C._ld, 1, 1);
00349 }
00350
00351 extern "C" void zgemm(
00352 const char*, const char*,
00353 const int_t*, const int_t*, const int_t*,
00354 const std::complex<double>*, const std::complex<double>*, const int_t*,
00355 const std::complex<double>*, const int_t*,
00356 const std::complex<double>*, std::complex<double>*, const int_t*,
00357 const int_t, const int_t);
00358
00359
00360 void gemm(char transA, char transB,
00361 const Matrix< std::complex<double> >& A,
00362 const Matrix< std::complex<double> >& B,
00363 Matrix< std::complex<double> >& C,
00364 std::complex<double> alpha, std::complex<double> beta)
00365 {
00366 int_t m, n, k1, k2;
00367 if (tolower(transA) == 'n') {
00368 m = A._m;
00369 k1 = A._n;
00370 } else {
00371 m = A._n;
00372 k1 = A._m;
00373 }
00374 if (tolower(transB) == 'n') {
00375 n = B._n;
00376 k2 = B._m;
00377 } else {
00378 n = B._m;
00379 k2 = B._n;
00380 }
00381
00382 assert(m = C._m);
00383 assert(n = C._n);
00384 assert(k1 = k2);
00385
00386 zgemm(
00387 &transA, &transB,
00388 &m, &n, &k1,
00389 &alpha, A._v, &A._ld,
00390 B._v, &B._ld,
00391 &beta, C._v, &C._ld, 1, 1);
00392 }
00393
00394 extern "C" void sgesv(
00395 const int_t* n, const int_t* nrhs, float a[], const int_t* lda,
00396 int_t ipiv[],
00397 float b[], const int_t* ldb, int_t* info);
00398
00399
00400 void gesv(const Matrix< float >& A, Matrix< float >& B, Vector< int_t >& ipiv, int_t& info)
00401 {
00402 assert(A._m == A._n);
00403 assert(A._m == B._m);
00404
00405 sgesv(&A._m, &B._n, A._v, &A._ld, ipiv._v, B._v, &B._ld, &info);
00406 }
00407
00408 extern "C" void dgesv(
00409 const int_t* n, const int_t* nrhs, double a[], const int_t* lda,
00410 int_t ipiv[],
00411 double b[], const int_t* ldb, int_t* info);
00412
00413
00414 void gesv(const Matrix< double >& A, Matrix< double >& B, Vector< int_t >& ipiv, int_t& info)
00415 {
00416 assert(A._m == A._n);
00417 assert(A._m == B._m);
00418
00419 dgesv(&A._m, &B._n, A._v, &A._ld, ipiv._v, B._v, &B._ld, &info);
00420 }
00421
00422 extern "C" void cgesv(
00423 const int_t* n, const int_t* nrhs, std::complex<float> a[], const int_t* lda,
00424 int_t ipiv[],
00425 std::complex<float> b[], const int_t* ldb, int_t* info);
00426
00427
00428 void gesv(const Matrix< std::complex<float> >& A, Matrix< std::complex<float> >& B, Vector< int_t >& ipiv, int_t& info)
00429 {
00430 assert(A._m == A._n);
00431 assert(A._m == B._m);
00432
00433 cgesv(&A._m, &B._n, A._v, &A._ld, ipiv._v, B._v, &B._ld, &info);
00434 }
00435
00436 extern "C" void zgesv(
00437 const int_t* n, const int_t* nrhs, std::complex<double> a[], const int_t* lda,
00438 int_t ipiv[],
00439 std::complex<double> b[], const int_t* ldb, int_t* info);
00440
00441
00442 void gesv(const Matrix< std::complex<double> >& A, Matrix< std::complex<double> >& B, Vector< int_t >& ipiv, int_t& info)
00443 {
00444 assert(A._m == A._n);
00445 assert(A._m == B._m);
00446
00447 zgesv(&A._m, &B._n, A._v, &A._ld, ipiv._v, B._v, &B._ld, &info);
00448 }
00449
00450 extern "C" void sgetrf(
00451 const int_t* m, const int_t* n, float a[], const int_t* lda,
00452 int_t ipiv[],
00453 int_t* info);
00454
00455
00456 void getrf(const Matrix< float >& A, Vector< int_t >& ipiv, int_t& info)
00457 {
00458 assert(A._m == A._n);
00459 assert(A._m == ipiv._n);
00460
00461 sgetrf(&A._m, &A._n, A._v, &A._ld, ipiv._v, &info);
00462 }
00463
00464 extern "C" void dgetrf(
00465 const int_t* m, const int_t* n, double a[], const int_t* lda,
00466 int_t ipiv[],
00467 int_t* info);
00468
00469
00470 void getrf(const Matrix< double >& A, Vector< int_t >& ipiv, int_t& info)
00471 {
00472 assert(A._m == A._n);
00473 assert(A._m == ipiv._n);
00474
00475 dgetrf(&A._m, &A._n, A._v, &A._ld, ipiv._v, &info);
00476 }
00477
00478 extern "C" void cgetrf(
00479 const int_t* m, const int_t* n, std::complex<float> a[], const int_t* lda,
00480 int_t ipiv[],
00481 int_t* info);
00482
00483
00484 void getrf(const Matrix< std::complex<float> >& A, Vector< int_t >& ipiv, int_t& info)
00485 {
00486 assert(A._m == A._n);
00487 assert(A._m == ipiv._n);
00488
00489 cgetrf(&A._m, &A._n, A._v, &A._ld, ipiv._v, &info);
00490 }
00491
00492 extern "C" void zgetrf(
00493 const int_t* m, const int_t* n, std::complex<double> a[], const int_t* lda,
00494 int_t ipiv[],
00495 int_t* info);
00496
00497
00498 void getrf(const Matrix< std::complex<double> >& A, Vector< int_t >& ipiv, int_t& info)
00499 {
00500 assert(A._m == A._n);
00501 assert(A._m == ipiv._n);
00502
00503 zgetrf(&A._m, &A._n, A._v, &A._ld, ipiv._v, &info);
00504 }
00505
00506 extern "C" void ssyev(
00507 char* jobz, char* uplo,
00508 int_t* n, float a[], int_t* lda,
00509 float w[],
00510 float work[], int_t* lwork, int_t* info,
00511 int_t len_jobz, int_t len_uplo);
00512
00513
00514 void syev(char jobz, char uplo, Matrix< float >& A, ColumnVector< float >&w, int_t& info)
00515 {
00516 assert(A._m == A._n);
00517 int_t nb = 8;
00518 int_t lwork = (nb + 2)*A._n;
00519 Vector< float > work(lwork);
00520
00521 ssyev(&jobz, &uplo,
00522 &A._m, A._v, &A._ld,
00523 w._v,
00524 work._v, &lwork, &info,
00525 1, 1);
00526 }
00527
00528 extern "C" void dsyev(
00529 char* jobz, char* uplo,
00530 int_t* n, double a[], int_t* lda,
00531 double w[],
00532 double work[], int_t* lwork, int_t* info,
00533 int_t len_jobz, int_t len_uplo);
00534
00535
00536 void syev(char jobz, char uplo, Matrix< double >& A, ColumnVector< double >&w, int_t& info)
00537 {
00538 assert(A._m == A._n);
00539 int_t nb = 8;
00540 int_t lwork = (nb + 2)*A._n;
00541 Vector< double > work(lwork);
00542
00543 dsyev(&jobz, &uplo,
00544 &A._m, A._v, &A._ld,
00545 w._v,
00546 work._v, &lwork, &info,
00547 1, 1);
00548 }
00549
00550 }