00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050 #include "LinSolvers.h"
00051
00052
00053 const int LinSolvers::LINSOLV_CGS = 1;
00054 const int LinSolvers::LINSOLV_MINRES = 2;
00055 const int LinSolvers::LINSOLV_SYMMLQ = 3;
00056 const int LinSolvers::LINSOLV_QMR = 4;
00057
00058
00059 LinSolvers::LinSolvers(const Epetra_Operator *AA, const Epetra_Operator *PP, double *wTmp,
00060 int _type)
00061 : A(AA), Prec(PP), work_s(wTmp), solverType(_type)
00062 {
00063
00064 }
00065
00066
00067 int LinSolvers::solve(const Epetra_Vector &b, Epetra_Vector &x, double _tol, int _itMax) {
00068
00069 int info;
00070
00071 switch (solverType) {
00072 default:
00073 case LINSOLV_QMR:
00074 info = QMRS(b, x, _tol, _itMax);
00075 break;
00076 }
00077
00078 return info;
00079
00080 }
00081
00082
00083 int LinSolvers::QMRS(const Epetra_Vector &b, Epetra_Vector &x, double _tol, int _itMax) {
00084
00085 int maxIter = _itMax;
00086 double tol = _tol;
00087
00088
00089
00090 int nLocal = b.MyLength();
00091
00092 Epetra_Vector wrk1(View, b.Map(), work_s);
00093 Epetra_Vector p(View, b.Map(), work_s + nLocal);
00094 Epetra_Vector d(View, b.Map(), work_s + 2*nLocal);
00095 Epetra_Vector v1(View, b.Map(), work_s + 3*nLocal);
00096 Epetra_Vector t(View, b.Map(), work_s + 4*nLocal);
00097 Epetra_Vector g(View, b.Map(), work_s + 5*nLocal);
00098
00099 double beta;
00100 double c0, cc, c1;
00101 double delta;
00102 double eps0, eta0;
00103 double res_init, rho0 = 0.0, rho1 = 0.0, rho1inv;
00104 double tau, theta0, theta;
00105 double xi1;
00106
00107 int iter;
00108
00109 memcpy(v1.Values(), b.Values(), nLocal*sizeof(double));
00110 v1.Norm2(&rho0);
00111 tau = rho0;
00112
00113 v1.Scale(1.0/rho0);
00114 p.PutScalar(0.0);
00115 g.PutScalar(0.0);
00116 d.PutScalar(0.0);
00117 x.PutScalar(0.0);
00118
00119 c0 = 1.0;
00120 eps0 = 1.0;
00121 xi1 = 1.0;
00122 theta0 = 0.0;
00123 eta0 = -1.0;
00124 res_init = rho0;
00125
00126 for (iter = 0; iter < maxIter; ++iter) {
00127
00128 if (eps0 == 0.0)
00129 return -2;
00130
00131 if (Prec)
00132 Prec->ApplyInverse(v1, wrk1);
00133 else
00134 memcpy(wrk1.Values(), v1.Values(), nLocal*sizeof(double));
00135
00136 wrk1.Dot(v1, &delta);
00137 if (delta == 0.0)
00138 return -2;
00139
00140 cc = xi1*(delta/eps0);
00141
00142 p.Update(1.0, v1, -cc);
00143 g.Update(1.0, wrk1, -cc);
00144
00145 A->Apply(g, t);
00146
00147 g.Dot(t, &eps0);
00148 beta = eps0/delta;
00149 v1.Update(1.0, t, -beta);
00150
00151 v1.Norm2(&rho1);
00152 xi1 = rho1;
00153
00154 theta = c0*fabs(beta);
00155 if (theta == 0.0)
00156 return -2;
00157 theta = rho1/theta;
00158
00159 c1 = 1.0/sqrt(1.0 + theta*theta);
00160 eta0 = -eta0*rho0*c1*c1/(beta*c0*c0);
00161 tau = tau*theta*c1;
00162
00163 if (rho1 == 0.0)
00164 return -2;
00165
00166 cc = theta0*c1;
00167 cc = cc*cc;
00168 rho1inv = 1.0/rho1;
00169
00170 d.Update(eta0, p, cc);
00171 x.Update(1.0, d, 1.0);
00172 v1.Scale(rho1inv);
00173
00174 if (xi1 == 0)
00175 return -2;
00176
00177 rho0 = rho1;
00178 c0 = c1;
00179 theta0 = theta;
00180
00181 if (tau <= res_init*tol)
00182 break;
00183
00184 }
00185
00186 memcpy(wrk1.Values(), x.Values(), nLocal*sizeof(double));
00187 Prec->ApplyInverse(wrk1, x);
00188
00189 if (iter < maxIter)
00190 iter += 1;
00191
00192 if (tau > res_init*tol)
00193 return -1;
00194
00195 return iter;
00196
00197 }
00198
00199