/* 
   ============================================
   METHODS FOR SOLVING MATRIX RICCATI EQUATIONS
   ============================================
 
   REFERENCE: [1] D. Zwillinger: "Handbook of Differential Equations", 
                  Section 87, p. 357.

   DETAILS: 

    ode::matrixRiccati(sys,unk,t) tries to solve the first-order system sys
    with respect to the unknowns unk by recognizing it as a matrix Riccati system.

   EXAMPLE [cf. (87.4) in [1]]
     >> ode::matrixRiccati(
             {diff(x(t),t)-a(t)*(y(t)^2-x(t)^2)-2*b(t)*x(t)*y(t)-2*c*x(t),
              diff(y(t),t)-b(t)*(y(t)^2-x(t)^2)+2*a(t)*x(t)*y(t)-2*c*y(t)},
             {x(t),y(t)},t
        );
*/

// ode::matrixRiccati(sys, unk, t, solveOptions, odeOptions)
//
// Tries to recognize the system sys in the unknowns unk as a matrix Riccati
// equation and to solve it in closed form.  Returns FAIL when the system is
// not recognized or when the final matrix inversion fails.
//
// We assume here that the maximal differential order is 1.
//
//   sys          - set or list of expressions, each implicitly equal to zero;
//                  each must contain exactly one derivative of an unknown
//   unk          - set (or list) of the unknown functions, e.g. {x(t), y(t)}
//   t            - the independent variable
//   solveOptions,
//   odeOptions   - not used by this method
//
// Only n = 2 is currently handled.  With Z = [[x, y], [y, -x]] the system is
// accepted only if it can be written as
//
//       Z' = Z*A*Z + K*Z + Z*transpose(K)
//
// with A = [[alpha, beta], [beta, -alpha]] (the only A for which Z*A*Z again
// has the shape of Z) and K = k*I (the only diagonal K preserving that
// shape).  Substituting Z = Q*W*transpose(Q) with Q' = K*Q gives
// W' = W*(transpose(Q)*A*Q)*W, whose solution is
// W = (W0^(-1) - int(transpose(Q)*A*Q, t))^(-1).
//
// The result has the form {[x(t) = ..., y(t) = ...]} and is expressed in
// x(0), y(0).  Because the integrals are indefinite, x(0) and y(0) act as
// integration constants; they need not be the solution's values at t = 0.
ode::matrixRiccati:= proc(sys,unk,t,solveOptions,odeOptions)
  local Z,A,K,n,i,l,x,j,Q,unkl,same;
begin
  // symbolic equality test that does not depend on the syntactic form
  same:= (u, v) -> iszero(normal(u - v));
  n:= nops(sys);
  if nops(unk) <> n then
    return(FAIL);
  end_if;
  /* 
     First check that each equation is of the form

        diff(xi(t),t) = ... + a[i,j,k]*xj*xk + ... + b[i,j]*xj + ...

     There are n^2*(n+1)/2 coefficients a[i,j,k] and n^2 coefficients b[i,j] 
  */
  sys:= [op(sys)];
  unkl:= [op(unk)]; // to get a fixed order 
  for i from 1 to n do
    if type(sys[i]) = "_plus" then
      l:= [op(sys[i])];
    else
      l:= [sys[i]];
    end_if;
    // keep only the derivative terms of unknowns; exactly one is allowed
    l:= select(l,testtype,"diff");
    l:= select(l,proc()
                 begin
                   contains(unk,op(args(1),1))
                 end_proc);
    if nops(l)<>1 then
      return(FAIL);
    end_if;
    l:= op(l,1);
    x[i]:= op(l,[1]);   // the unknown differentiated in equation i
    sys[i]:= l-sys[i];  // from now on sys[i] is the right-hand side of xi'
    for j from 1 to n do
      if not testtype(sys[i],Type::PolyExpr(unkl[j])) then
        return(FAIL);
      end_if
    end_for;
    // at most quadratic in the unknowns and without a constant term
    if degree(sys[i],unkl) > 2 or ldegree(sys[i],unkl) = 0 then
      return(FAIL);
    end_if
  end_for;
  // every unknown must be differentiated in exactly one equation
  if {x[i] $ i = 1..n} <> {op(unk)} then
    return(FAIL);
  end_if;
  if n <> 2 then
    return(FAIL);
  end_if;

  // n = 2: try Z = [[x,y],[y,-x]]; read A off the quadratic terms
  A:= matrix(2,2);
  A[1,1]:= coeff(poly(sys[1],[x[1]]),2);
  A[2,2]:= coeff(poly(sys[1],[x[2]]),2);
  A[1,2]:= coeff(poly(sys[2],[x[2]]),2);
  A[2,1]:= -coeff(poly(sys[2],[x[1]]),2);
  // Z*A*Z keeps the shape of Z only for A = [[alpha,beta],[beta,-alpha]];
  // the x*y coefficients must then be 2*beta (in x') and 2*alpha (in y')
  if not same(A[1,2], A[2,1]) or not same(A[2,2], -A[1,1]) or
     not same(coeff(coeff(sys[1],x[1],1),x[2],1), A[1,2]+A[2,1]) or
     not same(coeff(coeff(sys[2],x[1],1),x[2],1), A[1,1]-A[2,2]) then
    return(FAIL);
  end_if;
  Z:= matrix([[x[1],x[2]],[x[2],-x[1]]]);

  // linear terms: K*Z + Z*transpose(K) contributes
  //   x' : 2*K11*x + 2*K12*y,   y' : (K21-K12)*x + (K11+K22)*y
  K:= matrix(2,2);
  K[1,1]:= coeff(coeff(sys[1],x[1],1),x[2],0)/2;
  K[1,2]:= coeff(coeff(sys[1],x[1],0),x[2],1)/2;
  K[2,1]:= coeff(coeff(sys[2],x[1],1),x[2],0)+K[1,2];
  K[2,2]:= coeff(coeff(sys[2],x[1],0),x[2],1)-K[1,1];
  // one should call exp(K,t) here; for now only K = k*I is handled,
  // which is also the only diagonal K preserving the shape of Z
  if not same(K[1,2], 0) or not same(K[2,1], 0) or
     not same(K[2,2], K[1,1]) then
    return(FAIL);
  end_if;
  // Q' = K*Q with scalar K(t): Q = exp(int(K11, t))*I (K11 may depend on t)
  Q:= matrix(2,2);
  Q[1,1]:= exp(int(K[1,1], t));
  Q[2,2]:= Q[1,1];

  // now apply formula (77.2)
  A:= map(linalg::transpose(Q)*A*Q, int, t);
  if traperror((A:= (subs(Z,t=0)^(-1)-A)^(-1))) <> 0 or
     domtype(A) = DOM_FAIL then 
    return(FAIL);
  end_if;  
  Z:= Q*A*linalg::transpose(Q);
  return({[x[1] = Z[1,1],
           x[2] = Z[2,1]]});
end_proc:

