//   
/* 


partfrac(f, x <, options>) 

f - arithmetical expression that must be rational in x
x - variable; may be omitted if f contains exactly one indeterminate

Possible options:

Adjoin = ...      this option is just passed to factor when factoring the denominator
Domain = ...      just passed to factor
Full              factor the denominator completely into linear factors
List              return the two lists [c1, ..., cn], [f1, ..., fn] such that c1/f1 + ... + cn/fn = f
Mapcoeffs = F     map F to each numerator (including the polynomial part) before returning the partial fraction decomposition
MaxDegree = n     passed to factor; with Full, also used as MaxDegree when solving for the roots of the denominator factors. Default: n=2
*/

partfrac :=
proc(f, x)
  local g, d, i, n, l, m, p, options, optiont, r, s, minpolys,
  full, mapCoeffs, ListFormat;
begin

  // argument checking  
  if args(0)=0 then
    error("No argument given")
  end_if;

  // let the domain of f handle the call if it provides its own partfrac
  if f::dom::partfrac <> FAIL then
    return(f::dom::partfrac(args()))
  end_if;
  
  // no variable given: use the only indeterminate of f;
  // an f without indeterminates is returned unchanged
  if args(0)<2 then
    if nops((x:= numeric::indets(f))) < 1 then
      return(f)
    elif nops(x) > 1 then
      error("Second argument must be a variable")
    end_if;
    x:= op(x);
  end_if;

  if testargs() then
    if traperror((p:=poly(1, [x]))) <> 0 then
      error("Illegal indeterminate")
    end_if;
  end_if;

  // if there is any float in the expression, the options are ignored
  // and the decomposition is done numerically
  if stdlib::hasfloat(f) then
     return(numeric::partfrac(f, x));
  end_if:

  options:=
  table(MaxDegree = 2,
        Adjoin = {},
        Full = FALSE,
        Domain = Expr,
        Mapcoeffs = id,
        List = FALSE
        );
  
  optiont:=
  table(MaxDegree = Type::PosInt,
        Adjoin = Type::AnyType,
        Full = DOM_BOOL,
        Domain = Type::AnyType,
        Mapcoeffs = Type::AnyType,
        List = DOM_BOOL
        );

  // options start at the third argument
  options:= prog::getOptions(3, [args()], options, TRUE, optiont)[1];
  // Mapcoeffs is applied at the very end; it must not reach factor
  mapCoeffs := options[Mapcoeffs];
  delete options[Mapcoeffs];
  

  // if x is not an identifier, replace it by a fresh one and remember
  // the substitution that restores it
  if type(x) <> DOM_IDENT then 
     p:= genident("x");
     f:= subs(f, x = p);
     g:= {p=x};
     x:= p
  else
     g:= {}
  end_if;   
    
  // first, we rationalize the input; s collects the substitutions
  // that are applied backwards at the end
  [r, s, minpolys]:= [rationalize(f, FindRelations = ["_power"], MinimalPolynomials)];
  s:= subs(s, g) union g;

  // write r = n/d, where n and d are polynomial expressions
  [n, d]:= normal(r, List);

  // split off the polynomial part: n/d = quotient + remainder/d
  l:= divide(n, d, [x]);
  p:= l[1]; // polynomial part
  if iszero(p) then p:= [[], []] else p:= [[p], [1]] end_if;
  n:= l[2];


  full:= options[Full];
  // factor is always called with Full = FALSE; the Full option is
  // handled by stdlib::splitComponent below
  options[Full]:= FALSE;
  ListFormat:= options[List];
  delete options[List];
  // should we do the following?
  // options[Adjoin]:= {op(options[Adjoin])} union map(s, op, 2);

  // l = [c, f1, e1, f2, e2, ...]: constant c, factors fi, exponents ei
  l:= map([op(factor(poly(d, [x]), op(options)))], expr);
  
  // g[i] is the numerator belonging to the factor fi^ei
  g:= stdlib::partfracSimple(n, [l[2*i]^l[2*i+1] $i=1..nops(l) div 2], x);
  if full then
    m:= [stdlib::splitComponent(g[i], l[2*i], l[2*i+1], x, options) $i=1..nops(g)]
  else  
    m:= [stdlib::partfracFadic(g[i], l[2*i], l[2*i+1], x, mapCoeffs) $i=1..nops(g)]
  end_if;
  // now m is a list of lists; each contains two operands (list of numerators,
  // list of denominators); concatenate numerators and denominators.
  // The leading [] keeps n and d lists when m is empty, i.e. when the
  // denominator has no non-constant factor (f is a polynomial in x).
  n:= _concat([], op(map(m, op, 1)));
  d:= _concat([], op(map(m, op, 2)));
  // account for the constant factor c of the denominator
  n:= map(n, _divide, l[1]);
  
  n:= p[1].n;
  d:= p[2].d;
  
  // back - substitute in denominators
  d:= subs(d, s, EvalChanges);
  // we use only those minpolys that describe algebraic numbers
  minpolys:= select(minpolys, X -> nops(indets(X)) = 1);
  // reduce numerators by minimal polynomials
  if nops(minpolys) > 0 then
    n:= map(n, property::normalGroebner, [op(minpolys)])
  end_if;  
  // back - substitute in numerators
  n:= subs(n, s, EvalChanges);
  // apply user-defined function to numerators
  n:= map(n, mapCoeffs);
  if ListFormat then
    n, d
  else
    _plus(op(zip(n, d, _divide)))
  end_if
end_proc:


// auxiliary function stdlib::partfracSimple(u, f, x)
// u - polynomial expression in x (the numerator)
// f - list [f1, ..., fn] of polynomial expressions in x; assumed to be
//     pairwise coprime (partfrac passes powers of distinct factors)
// x - variable
// Returns [g1, ..., gn] such that
//   u/(f1*...*fn) = g1/f1 + ... + gn/fn.
// Derivation: multiplying by f1*...*fn gives
//   u = g1*m1 + ... + gn*mn,  mi := product of all fj with j <> i.
// Since fi divides every mj with j <> i, this means gi*mi = u (mod fi).
// Hence gi is found by solving si*mi + t*fi = (u mod fi) for si and
// reducing si modulo fi. For n = 1, u is returned unchanged.
stdlib::partfracSimple:=
proc(u, f: DOM_LIST, x)
  local k, cof, rems, sol, other, i, j;
begin
  k:= nops(f);
  // a single factor: nothing to split
  if k = 1 then
    return([u])
  end_if;
  // cof[i] = product of all f[j] with j <> i
  cof:= [_mult(f[j] $j=1..i-1, f[j] $j=i+1..k) $i=1..k];
  // rems[i] = u mod f[i]
  rems:= [divide(u, f[i], [x], Rem) $i=1..k];
  sol:= [FAIL $k];
  for i from 1 to k do
    // sol[i]*cof[i] + other*f[i] = rems[i]
    [sol[i], other]:= [solvelib::pdioe(cof[i], f[i], rems[i], x)]
  end_for;
  map([divide(normal(sol[i], Expand = FALSE), f[i], [x], Rem) $i=1..k],
      normal, Expand = FALSE)
end_proc:


// auxiliary function stdlib::partfracFadic(g, f, k, x, F)
// g - polynomial expression in x
// f - polynomial expression in x
// k - positive integer
// x - variable w.r.t. which the decomposition is done
// F - accepted for interface compatibility; it is not referenced here
// Writes g/f^k as g1/f + g2/f^2 + ... + gk/f^k with deg(gi) < deg(f).
// Multiplying by f^k, this is the f-adic expansion
//   g = g1*f^(k-1) + ... + g(k-1)*f + gk,
// obtained by dividing by f repeatedly: the successive remainders are
// gk, g(k-1), ..., g1. Assumes deg(g) < k*deg(f); any quotient left
// after k divisions is discarded.
// Returns [[g1, ..., gk], [f, f^2, ..., f^k]].

stdlib::partfracFadic:=
proc(g, f, k, x, F)
  local pieces, current, j;
begin
  pieces:= [FAIL $k];
  current:= g;
  j:= k;
  while j > 0 do
    [current, pieces[j]]:= [divide(current, f, [x])];
    j:= j - 1
  end_while;
  [pieces, [f^j $j=1..k]]
end_proc:



// auxiliary function stdlib::splitComponent(f, d, n, x, options)
// f       - numerator, polynomial expression in x
// d       - a factor of the denominator as returned by factor
// n       - positive integer, the exponent of d
// x       - variable
// options - option table; only options[MaxDegree] is read here
// Decomposes f/d^n over the linear factors (x - alpha), where alpha runs
// through the roots of d.
// Returns [numerators, denominators]:
//  - if solve returns a set of roots, the terms are listed for each root;
//  - otherwise each numerator is a sum over alpha in the root set and
//    every denominator is 1.
stdlib::splitComponent:=
proc(f, d, n, x, options)
  local i, alpha, dalpha, F, R, A, B, RootOfd, nums, dens, degB;
begin
  // alpha stands for an arbitrary root of d
  alpha:= genident("alpha");
  dalpha:= evalp(d, x=alpha);

  // use the rationals as ground field if all coefficients of f and d are
  // integers or rationals, otherwise an expression field
  if ((map({coeff(f, [x])}, domtype) union map({coeff(d, [x])}, domtype))
      minus {DOM_INT, DOM_RAT}) = {} then
    F:= Dom::AlgebraicExtension(Dom::Rational, dalpha, alpha)
  else
    F:= Dom::AlgebraicExtension(Dom::ExpressionField(normal, iszero),
                                dalpha, alpha)
  end_if;

  // let d = (x-alpha)*R
  R:= divide(poly(d, [x], F), poly(x-alpha, [x], F), Exact);
  assert(R <> FAIL);
  // write f = A*(x-alpha)^n + B*R^n;
  // then f / d^n = f / ((x-alpha)^n * R^n) = A/R^n + B/(x-alpha)^n
  [A, B]:= [solvelib::pdioe(poly(x-alpha, [x], F)^n, R^n, poly(f, [x], F))];
  // in order to find b0, b1, ... with
  //   B(x)/(x-alpha)^n = b0/(x-alpha)^n + b1/(x-alpha)^(n-1) + ...,
  // we make a linear shift:
  //   B(x+alpha)/x^n = b0/x^n + b1/x^(n-1) + ...
  B:= polylib::shift(B, F(alpha));
  degB:= degree(B);
  // nums[i+1] = b_i, dens[i+1] = (x-alpha)^(n-i) for i = 0..degB
  nums:= [expr(coeff(B, i)) $i=0..degB];
  dens:= [(x-alpha)^(n-i) $i=0..degB];

  RootOfd:= solve(d, x, MaxDegree = options[MaxDegree], IgnoreSpecialCases);

  if type(RootOfd) = RootOf then
    [[freeze(sum)(nums[i]/dens[i], alpha in RootOfd) $i=1..(degB + 1)],
     [1 $(degB + 1)]]
  elif type(RootOfd) = DOM_SET then
    // explicit roots: instantiate alpha with each of them
    [_concat(evalAt(nums, alpha = op(RootOfd, i)) $i=1..nops(RootOfd)),
     _concat(evalAt(dens, alpha = op(RootOfd, i)) $i=1..nops(RootOfd))]
  else
    [[sum(nums[i]/dens[i], alpha in RootOfd) $i=1..(degB + 1)],
     [1 $(degB + 1)]]
  end_if;
end_proc:
  
