/* 
   =====================================
   METHODS FOR SOLVING RICCATI EQUATIONS
   =====================================
 
   REFERENCE: [1] D. Zwillinger: "Handbook of Differential Equations", 
                  Section 86, p. 354.

   DETAILS: 

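     ode::riccati(eq,y,x) tries to solve the first-order ODE eq=0 with respect
     to y(x). It first looks for a particular solution, which reduces the
     problem to a Bernoulli equation; failing that, it rewrites the equation
     as a linear second-order ODE.
  
     Applicable to: equations of the form y' = a(x) y^2 + b(x) y + c(x).
     The substitution y = -w'/(a w) turns them into
     w'' - (a'/a + b) w' + a c w = 0.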

   EXAMPLES: 

     >> ode::riccati(diff(y(x),x)-exp(x)*y(x)^2+y(x)-exp(-x),y,x);

     >> ode::riccati(diff(y(x),x)-y(x)^2+x*y(x)-1,y,x);

*/

// ode::riccati(eq, y, x, solveOptions, odeOptions)
//
// Tries to solve the Riccati equation eq = 0, i.e. y' = a(x)*y^2 + b(x)*y + c(x),
// for the unknown function y(x).
//
// Strategy:
//   1. Look for a particular solution s of the form k*x or k/x (k free of x),
//      or -- when b = 0 and a is free of x -- k*f(x), where f(x) is the only
//      call with the single argument x occurring in c. With y = s + z the
//      equation becomes the Bernoulli equation z' = (2*a*s + b)*z + a*z^2
//      (cf. [1], (86.4)), which is handed to ode::bernoulli.
//   2. Otherwise substitute y = -w'/(a*w). This gives the linear equation
//      w'' + d*w' + e*w = 0 with d = -(a'/a + b) and e = a*c. Solve it with
//      the ode:: helpers and transform the result back.
//
// Parameters:
//   eq           - expression in y(x) and diff(y(x), x); the ODE is eq = 0
//   y            - identifier of the unknown function
//   x            - identifier of the independent variable
//   solveOptions - options passed on to solve and to the ode:: helpers;
//                  IgnoreSpecialCases / IgnoreAnalyticConstraints are also
//                  forwarded to int, and IgnoreAnalyticConstraints to
//                  expand / simplify
//   odeOptions   - if it contains Type = Riccati, a warning is issued when
//                  the procedure falls through without a result
//
// Returns: a set of solutions, or FAIL.
ode::riccati:= proc(eq,y,x,solveOptions,odeOptions) 
  local a,b,c,d,e,dd,ee,eq0,parameq,s,zz,C, yp, w, sol, optIgnoreAnalyticConstraints,
        intOptions;
begin
  // option forwarded to expand/simplify when computing d and e
  optIgnoreAnalyticConstraints:= if has(solveOptions, IgnoreAnalyticConstraints) then 
                                   IgnoreAnalyticConstraints;
                                 else
                                   null();
                                 end_if;
  yp:= genident();
  // options forwarded to int in the reduction-of-order step
  intOptions:= null();            
  if has(solveOptions, IgnoreSpecialCases) then 
    intOptions:= intOptions,IgnoreSpecialCases;
  end_if;
  if has(solveOptions, IgnoreAnalyticConstraints) then   
    intOptions:= intOptions,IgnoreAnalyticConstraints;
  end_if;   
  w:= genident();
  eq0:= eq; 
  // free parameters of the input; used later to distinguish them from the
  // integration constants introduced by ode::solve_eq
  parameq:= indets(eq) minus {x,y,PI,EULER,CATALAN};
  userinfo(2,"trying to recognize a Riccati equation");
  // replace y'(x) by the identifier yp and y(x) by the plain identifier y
  eq:= subs(eq,diff(y(x),x)=yp,y(x)=y,EvalChanges);
  if testtype(eq,Type::PolyExpr(yp)) then
    if degree((eq:= poly(eq,[yp])))=1 then
      // the equation is linear in yp: solve it for yp
      eq:= -coeff(eq,0)/coeff(eq,1); // should match a*y^2+b*y+c
      if testtype(eq,Type::PolyExpr(y)) then
        if degree((eq:=poly(eq,[y])))=2 then
          userinfo(1,"Riccati equation");
          a:= coeff(eq,2); 
          b:= coeff(eq,1); 
          c:= coeff(eq,0);
          // first try to find a particular solution
          userinfo(2,"try to find a particular solution");
          // ansatz y = w*x with w free of x
          eq:= solve(subs(eq0,y(x)=w*x,EvalChanges),w,IgnoreSpecialCases,op(solveOptions));
          if type(eq)=DOM_SET then
            for s in eq do
              if has(s,x) or domtype(s)=RootOf then
                next;
              end_if;
              s:= s*x; // particular solution
              userinfo(1,"Riccati method worked");
              // y = s + z reduces the equation to a Bernoulli equation for z
              sol:= {s} union
                    map(ode::bernoulli(-diff(y(x),x)+(2*a*s+b)*y(x)+a*y(x)^2, // cf. [1], (86.4)
                                       y,x,solveOptions,{}),
                          _plus,s);
              if not has(sol, FAIL) then 
                return(sol);
              end_if; 
            end_for
          end_if;
          // ansatz y = w/x with w free of x
          eq:= solve(subs(eq0,y(x)=w/x,EvalChanges),w,IgnoreSpecialCases,op(solveOptions));
          if type(eq)=DOM_SET then
            for s in eq do
              if has(s,x) or domtype(s)=RootOf then
                next;
              end_if;
              s:= s/x; // particular solution
              userinfo(1,"Riccati method worked");
              sol:= {s} union
                    map(ode::bernoulli(-diff(y(x),x) + (2*a*s+b)*y(x) + a*y(x)^2,
                                       y,x,solveOptions,{}),
                          _plus,s);
              if not has(sol, FAIL) then 
                return(sol);
              end_if; 
            end_for
          end_if;
          //  Another way to find a particular solution in a particular case:
          //  b = 0, a free of x, and c contains exactly one call f(x);
          //  try the ansatz y = C*f(x) with a constant C.
          if iszero(b) and not has(a,x) then
            zz:= {};
            // collect all calls with the single argument x occurring in c
            misc::maprec(c, {"function"} = (t-> (if nops(t)=1 and op(t,1)=x then
                                                   zz:=zz union {t}
                                                 end_if;
                         t)));
            if nops(zz)=1 then
              C:= genident();
              s:= C*zz[1];
              zz:= solve(diff(s,x)-a*s^2-c,C,IgnoreSpecialCases,op(solveOptions));           
              if type(zz) <> DOM_SET then 
                zz:= {};
              end_if;  
              // only constants C are usable
              zz:= select(zz,not has,x);
              if zz <> {} then
                s:= subs(s,C=op(zz,1),EvalChanges);
                userinfo(1,"Riccati method worked");
                sol:= {s} union
                      map(ode::bernoulli(-diff(y(x),x) + (2*a*s)*y(x) + a*y(x)^2,
                                       y,x,solveOptions,{}),
                            _plus,s);
                if not has(sol, FAIL) then 
                  return(sol);
                end_if; 
              end_if;
            end_if;
          end_if;
          /* 
             otherwise try to solve the associated second order 
             linear equation w'' + d*w' + e*w = 0, where
             d = -(a'/a + b) and e = a*c
          */
          dd:= expand(-(diff(a,x)/a+b),optIgnoreAnalyticConstraints); 
          ee:= expand(a*c,optIgnoreAnalyticConstraints);
          if has(dd,{sin,cos}) or has(ee,{sin,cos}) then
            d:= simplify(dd,optIgnoreAnalyticConstraints);
            e:= simplify(ee,optIgnoreAnalyticConstraints);
          else
            d:= ode::normal(dd); 
            e:= ode::normal(ee);
          end_if;
          // check, if constant coefficients 
          if not has([d,e],x) then
            eq:= ode::solve_eq(diff(w(x),x,x)+d*diff(w(x),x)+e*w(x),w,x,{},solveOptions,{});
            if not has(eq,FAIL) then
              userinfo(1,"Riccati transformation worked");
              // integration constants introduced by ode::solve_eq
              d:= indets(eq) minus {x,PI,EULER,CATALAN} minus parameq;
              // y = -w'/(a*w) is unchanged when w is scaled, so one of
              // the two constants can be fixed to 1
              if nops(d)=2 then 
                eq:=subs(eq,op(d,2)=1);
              end_if;
              return(map(eq,()->ode::normal(-diff(args(1),x)/args(1)/a)));
            end_if;
          end_if;
          userinfo(2,"particular solution failed");
          userinfo(2,"do Riccati transformation"); 
          eq:=ode::lookUp2ndOrderLinear(diff(w(x),x,x)+d*diff(w(x),x)+e*w(x),w,x,
                                        solveOptions,{});
          // If the lookup failed and the coefficients were simplified because
          // they are trigonometric, retry with the unsimplified dd, ee.
          // The parentheses matter: 'and' binds tighter than 'or'.
          if eq={} and (has(dd, {sin,cos}) or has(ee, {sin,cos})) then 
            eq:=ode::lookUp2ndOrderLinear(diff(w(x),x,x)+dd*diff(w(x),x)+ee*w(x),w,x,
                                          solveOptions,{});
          end_if;  
          /* 
             if the coefficients are rational functions in x, try the
             special-function and general second-order solvers
          */
          if map({d,e}, testtype, Type::RatExpr(x))={TRUE} and eq={} then
            eq:= ode::specfunc(diff(w(x),x,x)+d*diff(w(x),x)+e*w(x),w,x,
                               solveOptions,{});
          end_if;
          if map({d,e}, testtype, Type::RatExpr(x))={TRUE} and eq={} then
            eq:=ode::secondOrder(diff(w(x),x,x)+d*diff(w(x),x)+e*w(x),w,x,
                                 solveOptions,{});
          end_if;
          // if only one solution w1 has been found, generate the second one
          // by reduction of order: w2 = w1*int(exp(-int(d))/w1^2)
          if nops(eq) = 1 and not has(eq,FAIL) then 
            sol:= eq[1]; 
            eq:= eq union {sol*int(exp(-int(d,x,intOptions))/sol^2,x,intOptions)};
          end_if;
          // do not back-transform a failed or partially failed result
          if eq <> {} and not has(eq, FAIL) then
            userinfo(1,"Riccati transformation worked");
            if hastype(eval(eq), RootOf) then
              if ode::printWarningsFlag then
                ode::odeWarning("Cannot do back transformation,\n".
                "solution of transformed equation \n".
                expr2text(diff(w(x),x,x)+d*diff(w(x),x)+e*w(x)).
                "\n is \n".expr2text(eq)."\n solution of ".
                "given equation is -diff(Y,".expr2text(x).
                ")/Y/".expr2text(a));
              end_if;
              return(FAIL);
            else
              // combine the two solutions as C*w1 + w2; as above, scaling w
              // does not change y, so one free constant suffices
              eq:= {op(eq,1)*genident("C")+op(eq,2)}
            end_if;
            return(map(eq,()->ode::normal(-diff(args(1),x)/args(1)/a)));
          end_if;
          userinfo(1,"Riccati method failed");
        end_if;
      end_if;
    end_if;
  end_if;
  if contains(odeOptions, Type = Riccati) then 
    ode::odeWarning("cannot detect Riccati equation");
  end_if;  
  
  return(FAIL);
end_proc:

