// 

// interval::Taylor -- domain for automatic differentiation.
// Starting from the trivial Taylor series of constants and of the
// independent variable, this domain derives truncated Taylor series
// for more complicated expressions.

// First implementation: instead of expanding around a point and
// bounding the remainder over a possibly different interval, it bounds
// the truncated Taylor series over the whole interval (historically the
// first extension of Taylor series to interval arithmetic).

// Most of the recursion formulas are from O. Aberth: Precise Numerical
// Methods Using C++, Academic Press, 1998.

domain interval::Taylor(order, x, xRange)
  inherits Dom::BaseDomain;

  // Representation: an element stores, as its only operand, a univariate
  // poly in the placeholder variable X.  The coefficient of X^k encloses
  // the k-th Taylor coefficient f^(k)(x)/k! of the represented function
  // for all x in xRange (this is what set2expr checks).  Terms of degree
  // greater than `order` are never kept.

  // accessors
  // taylorPoly(t): the underlying poly in X.
  taylorPoly := t -> extop(t, 1);
  // _index(t, i), i.e. t[i]: the coefficient of X^i.
  _index := (t, i) -> coeff(dom::taylorPoly(t), i);
    
  // constructors
  // constant(c): series of a constant; interval(c) is the only coefficient.
  constant := c  -> new(dom, poly([[interval(c), 0]], [hold(X)]));
  // variable(): series of the independent variable x, i.e. xRange + 1*X.
  variable := () -> new(dom, poly([[interval(xRange), 0], [1, 1]], [hold(X)]));

  // T(ex) is the same as T::convert(ex).
  new := ex -> dom::convert(ex);

  // convert(ex): Taylor series of ex with respect to x.
  // - elements of this domain are returned unchanged;
  // - expressions are converted operand-wise and the operator is
  //   re-applied, so that the overloaded methods of this domain do the
  //   work.  For powers, only the base is converted; the exponent is
  //   passed unchanged to _power, so integer exponents stay integers;
  // - the identifier x becomes variable(); any other identifier becomes
  //   the constant series of its hull;
  // - polys are converted via their expression form;
  // - other constants become constant series; anything else gives FAIL.
  convert :=
  proc(ex)
  begin
    case domtype(ex)
      of dom do
        return(ex);
        break;
      of DOM_EXPR do
        case op(ex,0)
          of hold(_power) do
            dom::convert(op(ex, 1)) ^ op(ex, 2);
            break;
          otherwise
            eval(op(ex,0)(op(map([op(ex)], dom::convert))));
        end_case;
        break;
      of DOM_IDENT do
        if ex = x then
          dom::variable();
        else
//          error("independent variable not permitted");
          dom::constant(hull(ex));
        end_if;
        break;
      of DOM_POLY do
        dom::convert(expr(ex));
        break;
      otherwise
        if testtype(ex, Type::Constant) then
          dom::constant(ex);
        else
          FAIL;
        end_if;
    end_case;
  end_proc;

  // operations.

  // t1 + t2: coefficient-wise sum (arguments are converted first).
  binPlus := (t1, t2) -> new(dom,
                             dom::taylorPoly(dom::convert(t1))
                             + dom::taylorPoly(dom::convert(t2)));

  // t1 * t2: poly product, truncated to degree `order` by taking the
  // remainder modulo X^(order+1).
  binMult :=
  proc(t1, t2)
    local t3;
  begin
    t1 := dom::convert(t1);
    t2 := dom::convert(t2);
    t3 := dom::taylorPoly(t1) * dom::taylorPoly(t2);
    // cut off higher exponents
    new(dom, divide(t3, poly([[1,order+1]], [hold(X)]), Rem));
  end_proc;

  // t^n.
  // - positive integer n: full poly power, then terms of degree > order
  //   are dropped;
  // - negative integer n: (1/t)^(-n);
  // - n = 0: the constant 1;
  // - n = 1/2: the specialized (narrower) sqrt;
  // - any other n: exp(ln(t)*n).
  // TODO: For large n, explicit formulas will be
  // faster.  Try how large n must be for this.
  _power :=
  proc(t, n)
    local t2;
  begin
    if domtype(n) = DOM_INT then
      if n > 0 then
        t2 := poly2list(dom::taylorPoly(t)^n);
        // keep only the terms [coefficient, exponent] up to `order`
        t2 := select(t2, term -> term[2] <= order);
        new(dom, poly(t2, [hold(X)]));
      elif n < 0 then
        dom::_invert(t)^(-n);
      else
        assert(0 = n);
        dom::constant(1);
      end_if;
    else // n is not an integer
      if n = 1/2 then
        dom::sqrt(t);
      else
        exp(ln(t)*dom::convert(n));
      end_if;
    end_if;
  end_proc;
  
  // 1/t via the recursion (g = taylorPoly(t), h = 1/g):
  //   h_0 = 1/g_0,  h_k = -h_0 * sum_{i=0}^{k-1} h_i g_{k-i}.
  // The argument is converted first (a no-op for elements of this domain).
  _invert :=
  proc(t)
   local pt, pt1, i, k;
  begin
   pt := dom::taylorPoly(dom::convert(t));
   pt1 := [[0,i]$i=0..order];
 // If this fails, "division by zero" is the correct error message
   pt1[1][1] := 1/pt(0);
   for k from 1 to order do
    pt1[k+1][1] := -pt1[1][1] * _plus((pt1[i+1][1]*coeff(pt,k-i))$i=0..k-1);
   end_for;
   new(dom, poly(pt1, [hold(X)]));
  end_proc;

  // t1/t2 via the recursion (f = t1, g = t2, q = f/g):
  //   q_0 = f_0/g_0,  q_k = (f_k - sum_{i=0}^{k-1} q_i g_{k-i}) / g_0.
  // Both arguments are converted first, like in binPlus/binMult, so that
  // mixed calls such as _divide(1, t) work.
  _divide :=
  proc(t1, t2)
   local pt1, pt2, pt3, h0, i, k;
  begin
    pt1 := dom::taylorPoly(dom::convert(t1));
    pt2 := dom::taylorPoly(dom::convert(t2));
    pt3 := [[0,i]$i=0..order];
 // If this fails, "division by zero" is the correct error message
    pt3[1][1] := pt1(0)/pt2(0);
    h0 := 1/pt2(0);
    for k from 1 to order do
      pt3[k+1][1] := h0 * (coeff(pt1,k) -
                           _plus((pt3[i+1][1]*coeff(pt2,k-i))$i=0..k-1));
    end_for;
    new(dom, poly(pt3, [hold(X)]));
  end_proc;

  // exp(t): h = exp(f) satisfies h' = f' h, i.e.
  //   k h_k = sum_{i=1}^{k} i f_i h_{k-i}.
  // The recursion is run with h_0 = 1 (i.e. for exp(f - f_0)).
  // Multiplying by exp(coeff(pt, 0)) after the loop
  // is better because of subdistributivity.
  exp :=
  proc(t: dom)
    local pt, pt1, i, k, exp_a0;
  begin
    pt := dom::taylorPoly(t);
    pt1 := [[0,i]$i=0..order];
    pt1[1][1] := 1;
    for k from 1 to order do
      pt1[k+1][1] := 1/k*_plus(i*coeff(pt,i)*pt1[k-i+1][1]$i=1..k);
    end_for;
    exp_a0 := exp(coeff(pt, 0));
    for k from 0 to order do
      pt1[k+1][1] := pt1[k+1][1] * exp_a0;
    end_for;
    new(dom, poly(pt1, [hold(X)]));
  end_proc;

  // ln(t): h = ln(f) satisfies f h' = f', i.e.
  //   h_k = (f_k - 1/k * sum_{i=1}^{k-1} i f_{k-i} h_i) / f_0.
  ln :=
  proc(t: dom)
    local pt, pt1, i, k, g0;
  begin
    pt := dom::taylorPoly(t);
    pt1 := [[0,i]$i=0..order];
    g0 := coeff(pt, 0);
    pt1[1][1] := ln(g0);
    for k from 1 to order do
      pt1[k+1][1] := 1/g0*(coeff(pt, k) -
                           1/k*_plus(i*coeff(pt,k-i)*pt1[i+1][1]$i=1..k-1));
    end_for;
    new(dom, poly(pt1, [hold(X)]));
  end_proc;  

  // sincos(t): [sin(t), cos(t)].
  // Sine and cosine must be computed simultaneously:
  // s' = f' c and c' = -f' s, i.e.
  //   k s_k =  sum_{i=1}^{k} i f_i c_{k-i},
  //   k c_k = -sum_{i=1}^{k} i f_i s_{k-i}.
  sincos :=
  proc(t: dom)
    local pt, pts, ptc, i, k;
  begin
    pt := dom::taylorPoly(t);
    pts := [[0,i]$i=0..order];
    ptc := [[0,i]$i=0..order];
    
    pts[1][1] := sin(coeff(pt, 0));
    ptc[1][1] := cos(coeff(pt, 0));
    for k from 1 to order do
      pts[k+1][1] := 1/k*_plus(i*coeff(pt,i)*ptc[k-i+1][1]$i=1..k);
      ptc[k+1][1] := -1/k*_plus(i*coeff(pt,i)*pts[k-i+1][1]$i=1..k);
    end_for;
    [new(dom, poly(pts, [hold(X)])),new(dom, poly(ptc, [hold(X)]))];
  end_proc;  

  sin := t -> dom::sincos(t)[1];
  cos := t -> dom::sincos(t)[2];

  // sinhcosh(t): [sinh(t), cosh(t)], computed the same way as sincos,
  // with s' = f' c and c' = f' s (no sign change).
  sinhcosh :=
  proc(t: dom)
    local pt, pts, ptc, i, k;
  begin
    pt := dom::taylorPoly(t);
    pts := [[0,i]$i=0..order];
    ptc := [[0,i]$i=0..order];
    
    pts[1][1] := sinh(coeff(pt, 0));
    ptc[1][1] := cosh(coeff(pt, 0));
    for k from 1 to order do
      pts[k+1][1] := 1/k*_plus(i*coeff(pt,i)*ptc[k-i+1][1]$i=1..k);
      ptc[k+1][1] := 1/k*_plus(i*coeff(pt,i)*pts[k-i+1][1]$i=1..k);
    end_for;
    [new(dom, poly(pts, [hold(X)])),new(dom, poly(ptc, [hold(X)]))];
  end_proc;  

  sinh := t -> dom::sinhcosh(t)[1];
  cosh := t -> dom::sinhcosh(t)[2];

  // arctan(t): h' = f'/d with d = 1 + f^2, i.e. d h' = f', so
  //   h_k = (f_k - 1/k * sum_{i=1}^{k-1} i d_{k-i} h_i) / d_0.
  arctan :=
  proc(t: dom)
    local pt, pt1, i, k, denom_fn, d0;
  begin
    pt := dom::taylorPoly(t);
    pt1 := [[0,i]$i=0..order];
    denom_fn := dom::taylorPoly(t*t+dom::constant(1));
    d0 := 1/coeff(denom_fn, 0);
    pt1[1][1] := arctan(coeff(pt, 0));
    for k from 1 to order do
      pt1[k+1][1] := d0*(coeff(pt, k) -
                         1/k*_plus(i*coeff(denom_fn,k-i)*pt1[i+1][1]$i=1..k-1));
    end_for;
    new(dom, poly(pt1, [hold(X)]));
  end_proc;  
  
  // arcsin(t): same recursion as arctan, with d = sqrt(1 - f^2).
  arcsin :=
  proc(t: dom)
    local pt, pt1, i, k, denom_fn, d0;
  begin
    pt := dom::taylorPoly(t);
    pt1 := [[0,i]$i=0..order];
    denom_fn := dom::taylorPoly((dom::constant(1)-t*t)^(1/2));
    d0 := 1/coeff(denom_fn, 0);
    pt1[1][1] := arcsin(coeff(pt, 0));
    for k from 1 to order do
      pt1[k+1][1] := d0*(coeff(pt, k) -
                         1/k*_plus(i*coeff(denom_fn,k-i)*pt1[i+1][1]$i=1..k-1));
    end_for;
    new(dom, poly(pt1, [hold(X)]));
  end_proc;  

// sqrt(t): from h^2 = f,
//   h_k = (f_k/2 - 1/k * sum_{i=1}^{k-1} i h_{k-i} h_i) / h_0.
// Found manually, later also seen in Neumaier, p. 46.
// Special case for sqrt, much narrower than general _power-results.
  sqrt :=
  proc(t: dom)
    local pt, pt1, i, k, g0;
  begin
    pt := dom::taylorPoly(t);
    pt1 := [[0,i]$i=0..order];
    g0 := sqrt(coeff(pt, 0));
    pt1[1][1] := g0;
    g0 := 1/g0;
    for k from 1 to order do
      pt1[k+1][1] := g0*(1/2*coeff(pt, k) -
                          1/k*_plus(i*pt1[k-i+1][1]*pt1[i+1][1]$i=1..k-1));
    end_for;
    new(dom, poly(pt1, [hold(X)]));
  end_proc;  
  

  ////////////////////////////////////////////////////////////  
  // print(t): the coefficient list (degree 0..order), wrapped in an
  // extra list, followed by `x in xRange`.
  print :=
  proc(t)
    local pl, i;
  begin
    pl := dom::taylorPoly(t);
    [[(coeff(pl, i))$i=0..order]], hold(_in)(x, xRange);
  end_proc;

  // TeX(t, pri): TeX form of the coefficient list with
  // "x in xRange" as a subscript.
  TeX := proc(t, pri) 
    local p;
  begin
    p := dom::print(t);
    generate::tex(op(p,1), pri) . "{}_{|_{".
    generate::tex(op(p,2), output::Priority::Noop)."}}";
  end_proc;
  
  // testtype: elements of this domain are of Type::Arithmetical;
  // any other question is answered with FAIL (left to the default).
  testtype :=
  proc(ex, T)
  begin
    if ex::dom = dom and T in {Type::Arithmetical} then
      return(TRUE);
    end_if;
    FAIL
  end_proc;

  // set2expr(T, f): TRUE if the coefficients of T enclose those of f,
  // FALSE otherwise.  f is either a list of order+1 values (compared
  // with the coefficients of degree 0..order), or an expression in x,
  // whose i-th Taylor coefficient f^(i)/i! is evaluated over xRange.
  set2expr :=
  proc(T, f)
    local ff, i, p;
  begin
    p := dom::taylorPoly(T);
    if domtype(f) = DOM_LIST and nops(f) = order+1 then
      for i from 1 to order+1 do
        if not f[i] in coeff(p, i-1) then
          return(FALSE);
        end_if;
      end_for;
    else
      ff := f;
      for i from 0 to order do
        if not eval(subs(ff, x=xRange, Unsimplified)) in coeff(p, i) then
          return(FALSE);
        end_if;
        // next Taylor coefficient: f^(i+1)/(i+1)!
        ff := diff(ff, x)/(i+1);
      end_for;
    end_if;
    TRUE;
  end_proc;

  // initDomain: neutral elements and n-ary + and * built from the
  // binary versions.
  initDomain :=
  proc()
  begin
    dom::zero := dom::constant(0);
    dom::one := dom::constant(1);
    dom::_mult := misc::genassop(dom::binMult, dom::one);
    dom::_plus := misc::genassop(dom::binPlus, dom::zero);
  end_proc;

// domain constructor: interval::Taylor(order, x, xRange) with a
// non-negative integer order, an identifier x and an interval xRange.
begin
  if args(0) <> 3 then error("Wrong number of args"); end_if;
  assert(testtype(order, Type::NonNegInt));
  assert(testtype(x, DOM_IDENT));
  assert(testtype(xRange, Dom::FloatIV));
end_domain:

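/*
Example: compare the Taylor enclosure of fn(x^3+1) over r with the
interval evaluation of the derivatives of fn(x^3+1) over r.  Note that
the i-th coefficient of epf encloses the i-th derivative divided by i!.

fn := sqrt: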
n := 5:
r := 0...1:

T := interval::Taylor(n, x, r):
f := x^3+1:
pf := T(f):
epf:=fn(pf);
eval(subs(diff(fn(f),x$i), x=r))$i=0..n;

*/
