// 

// interval::autodiff -- "automatic differentiation".
//
// interval::autodiff(ex, deg) returns a procedure.  This procedure,
// on input of an interval a...b, returns a list of intervals
// bounding the first <deg+1> Taylor coefficients of ex over a...b.


// AD abbreviates interval::autodiff; it names the helper slots below
// (AD::mult, AD::power, ...) and is used inside hold(...) by the code
// that interval::autodiff generates.
alias(AD = interval::autodiff):

// interval::autodiff(ex, deg) -- code generator for Taylor arithmetic.
//
// ex  - expression in exactly one indeterminate
// deg - highest Taylor coefficient index; the generated code works on
//       lists of deg+1 coefficients (list index 1 = constant term)
//
// Returns a new procedure whose single parameter is the indeterminate of
// ex.  Its body first rebinds that parameter to [interval(x), 1, 0, ..., 0]
// and then executes one assignment per atomic operation of ex, each
// storing a list of deg+1 coefficients in a fresh local.  The generated
// procedure returns the value of its last assignment (presumably the
// equation generate::optimize lists last is the whole expression --
// confirm).
//
// Errors: deg < 2; ex with zero or several indeterminates; any operation
// other than _plus, _mult, _power, exp, ln, sin, cos, sinh, cosh.
interval::autodiff :=
  proc(ex, deg)
    local idents, temps, opt_ex, split_subex,
	  eq, subex, t, ret, i, ops;
  begin
    if deg<2 then
      error("interval::autodiff requires a degree of at least 2.");
    end_if;

    idents := indets(ex);
    if nops(idents) <> 1 then
      error("Currently, interval::autodiff expects an expression in one indeterminate");
    end_if;
    // TODO: generate::optimize performs repeated squaring,
    // which is a BAD THING here.  Option for optimize not to do so?
    opt_ex := generate::optimize(ex, TRUE);
    
    // generate::optimize unfortunately does not break up
    // subexpressions it doesn't need to: (a+1)*(a+2)
    // split everything into atomic operations
    //
    // split_subex(t, ex) returns a sequence of equations ending in
    // t = ex', where every operand of ex' is not a DOM_EXPR.  Each
    // DOM_EXPR operand v of ex is first split recursively under a fresh
    // identifier (genident("t")), and that identifier then replaces v.
    split_subex := proc(t, ex)
		     local non_atomic, ret, v, newt;
		   begin
		     if domtype(ex) <> DOM_EXPR then
		       return(t=ex);
		     end_if;
		     non_atomic := select({op(ex)},
					  t -> domtype(t)=DOM_EXPR);
		     ret := [];
		     for v in non_atomic do
		       newt := genident("t");
		       ret := ret.[split_subex(newt, v)];
		       // v occurs only once in ex,
		       // unless generate::optimize failed miserably
		       ex := subs(ex, v=newt, Unsimplified);
		     end_for;
		     // equations for the operands first, then the one for t
		     op(ret), t=ex;
		   end_proc:
		     
    // each equation becomes one or more equations; the sequences
    // returned by split_subex are flattened into the list
    opt_ex := map(opt_ex, eq -> split_subex(op(eq, 1), op(eq, 2)));
    
    // all temporaries, in the order they are assigned
    temps  := map(opt_ex, lhs);
    
    // ret[1] seeds the indeterminate with [interval(x), 1, 0, ..., 0];
    // ret[i+1] is filled below from the i-th equation of opt_ex
    ret := [hold(_assign)(op(idents),
			  [hold(interval)(op(idents)),
			   1, 0$deg-1]),
	    0$nops(opt_ex)];
    
    for i from 1 to nops(opt_ex) do
      eq := opt_ex[i];
      [t, subex] := [op(eq)];
      // a right-hand side that is not a DOM_EXPR is assigned unchanged
      case domtype(subex)
	of DOM_EXPR do
	  // operands: identifiers (the indeterminate or a temporary) hold
	  // coefficient lists at run time and are kept as they are; any
	  // other operand is a constant and becomes [interval(c), 0, ..., 0]
	  ops := map([op(subex)],
		     t -> case domtype(t)
			    of DOM_IDENT do
			      return(t);
			    otherwise
			      [interval(t), 0$deg];
			  end_case);
	  case op(subex, 0)
	    of hold(_plus) do
	      // coefficient-wise sum, unrolled into deg+1 _plus calls.
	      // NOTE(review): this $-index reuses the for-loop variable i and
	      // relies on $ restoring i afterwards -- confirm.
	      subex := [hold(_plus)(op(map(ops, _index, i)))$i=1..deg+1];
	      break;
	    of hold(_mult) do
	      subex := hold(AD::mult)(ops, deg);
	      break;
	    of hold(_power) do
	      // the raw exponent operand is passed too, so that AD::power can
	      // special-case integer exponents and 1/2
	      subex := hold(AD::power)(ops, op(subex, 2), deg);
	      break;
	    of hold(exp) do
	      subex := hold(AD::exp)(ops[1], deg);
	      break;
	    of hold(ln) do
	      subex := hold(AD::ln)(ops[1], deg);
	      break;
	    of hold(sin) do
	      subex := hold(AD::sin)(ops[1], deg);
	      break;
	    of hold(cos) do
	      subex := hold(AD::cos)(op(ops), deg);
	      break;
	    of hold(sinh) do
	      subex := hold(AD::sinh)(op(ops), deg);
	      break;
	    of hold(cosh) do
	      subex := hold(AD::cosh)(op(ops), deg);
	      break;
//	    of hold(arctan) do
//	      subex := Arctan(op(ops), deg);
//	      break;
//	    of hold(arcsin) do
//	      subex := Arcsin(op(ops), deg);
//	      break;
	    otherwise
	      // TODO: allow overloading
	      error("function ".expr2text(op(subex, 0))." not yet implemented");
	  end_case;
      end_case;
      ret[i+1] := hold(_assign)(t, subex);
    end_for;

    // now, we need to translate all the temps into DOM_VARs:
    // temporaries become DOM_VAR(0, 2), DOM_VAR(0, 3), ... and the
    // indeterminate the DOM_VAR after them.
    // NOTE(review): confirm this numbering against the DOM_VAR layout of
    // procedures with one parameter and nops(temps) locals.
    ret := subs(ret,
		[op(idents) = DOM_VAR(0, nops(temps)+2),
		 (temps[i] = DOM_VAR(0, i+1))$i=1..nops(temps)]);

    // build the procedure: operand 1 = parameter (the indeterminate),
    // operand 2 = locals (the temporaries), operand 4 = body
    subsop(proc() begin end_proc,
	   1=(op(idents)), 
	   2=(op(temps)),
	   // _stmtseq(op(ret)) does not flatten, which is desirable here:
	   4=subsop(hold(f)(op(ret)), 0=_stmtseq));
  end_proc:


// Wrap the generator in a function environment; the AD::... helpers
// below are stored as slots of it.
interval::autodiff := funcenv(interval::autodiff):

// AD::mult2(a, b, deg) -- Cauchy product of two truncated Taylor series.
// a, b - coefficient lists with at least deg+1 entries (index 1 = constant)
// Returns the first deg+1 coefficients c_n = sum_{j=0}^{n} a_j*b_(n-j).
AD::mult2 :=
  proc(a, b, deg)
    local c, n, j;
  begin
    c := [0$deg+1];
    for n from 0 to deg do
      c[n+1] := _plus(a[j+1]*b[n-j+1] $ j = 0..n);
    end_for;
    c;
  end_proc:

// AD::mult(l, deg) -- product of several truncated Taylor series.
// l   - list of coefficient lists (the operands of a _mult)
// deg - highest coefficient index kept
// Returns the first deg+1 coefficients of the product.  The operands are
// folded from left to right with AD::mult2, so the Cauchy product is
// implemented in exactly one place.
AD::mult :=
  proc(l, deg)
    local prod, j;
  begin
    prod := l[1];
    for j from 2 to nops(l) do
      prod := AD::mult2(prod, l[j], deg);
    end_for;
    prod;
  end_proc:

// AD::square(b, deg) -- first deg+1 Taylor coefficients of b^2.
// b is a coefficient list with at least deg+1 entries (index 1 = constant).
// The Cauchy product sum_{i=0}^{n} b_i*b_(n-i) is symmetric, so every
// off-diagonal product is formed once and doubled, and the diagonal term
// b_(n/2) uses ^2.  For interval coefficients X^2 encloses at least as
// tightly as X*X, and half the multiplications are needed.
AD::square :=
  proc(b, deg)
    local sq, n, i;
  begin
    sq := [b[1]^2, 0$deg];
    for n from 1 to deg do
      sq[n+1] := 2*_plus(b[i+1]*b[n-i+1] $ i = 0..((n-1) div 2));
      if n mod 2 = 0 then
        sq[n+1] := sq[n+1] + b[n/2+1]^2;
      end_if;
    end_for;
    sq;
  end_proc:

// AD::power(l, expo, deg) -- Taylor coefficients of b^expo.
// l    - operand list of the _power: l[1] holds the coefficients of the
//        base b, l[2] those of the exponent
// expo - the exponent operand as it appears in the expression (an exact
//        number, or the run-time coefficient list of a temporary)
// deg  - highest coefficient index to compute
// Positive integers use squaring, repeated squaring, or (for large
// exponents) a power recurrence; 0 gives the constant 1; negative
// integers invert the base first; 1/2 uses AD::sqrt; anything else is
// computed as exp(expo*ln(b)).
AD::power :=
  proc(l, expo, deg)
    local b, tt, k, i;
  begin
    if domtype(expo) = DOM_INT then
      b := l[1]; // b ^ expo
      if expo > 0 then
        if expo = 1 then
          b;
        elif expo = 2 then
          AD::square(b, deg);
        elif expo < 2*deg then
          // repeated squaring over the binary digits of expo
          tt := [1, 0$deg];
          while TRUE do
            if expo mod 2 = 1 then
              tt := AD::mult2(tt, b, deg);
            end_if;
            if expo < 2 then break end_if;
            expo := expo div 2;
            b := AD::square(b, deg);
          end_while;
          tt;
        else
          // Recurrence for f = b^expo, from b*f' = expo*b'*f:
          //   f_k = 1/b_0 * sum_{i=1}^{k} ((expo+1)*i/k - 1)*f_(k-i)*b_i.
          // This is worse in terms of overestimation, but definitely
          // faster.  It divides by b[1]; presumably the enclosure becomes
          // unbounded when b[1] contains 0.
          tt := [b[1]^expo, 0$deg];
          for k from 1 to deg do
            tt[k+1] := 1/b[1]*_plus(((expo+1)*i/k-1)*tt[k-i+1]*b[i+1]
                                    $ i = 1..k);
          end_for;
          tt;
        end_if;
      elif expo = 0 then
        [1, 0$deg];
      else
        // b^expo = (1/b)^(-expo)
        AD::power([AD::invert(b, deg), l[2]], -expo, deg);
      end_if;
    elif expo = 1/2 then
      AD::sqrt(l[1], deg);
    else
      // b^expo = exp(expo*ln(b)), using the coefficient list of expo
      AD::exp(AD::mult2(AD::ln(l[1], deg), l[2], deg), deg);
    end_if;
  end_proc:

// AD::invert(l, deg) -- first deg+1 Taylor coefficients of 1/g.
// l holds the coefficients of g (index 1 = constant term).
// Recurrence from g*f = 1:
//   f_n = -(1/g_0) * sum_{j=0}^{n} f_j*g_(n-j),
// where the j = n term vanishes because f_n is still 0 when it is summed.
AD::invert :=
  proc(l, deg)
    local inv, r0, n, j;
  begin
    r0 := 1/l[1];
    inv := [r0, 0$deg];
    for n from 1 to deg do
      inv[n+1] := -r0*_plus(inv[j+1]*l[n-j+1] $ j = 0..n);
    end_for;
    inv;
  end_proc:

// AD::sqrt(l, deg) -- first deg+1 Taylor coefficients of sqrt(g).
// l holds the coefficients of g (index 1 = constant term).
// Recurrence from f^2 = g, written with the weighted sum
//   f_n = (1/f_0) * (g_n/2 - 1/n * sum_{j=1}^{n-1} j*f_(n-j)*f_j),
// which equals the symmetric form (g_n - sum f_j*f_(n-j))/(2*f_0).
AD::sqrt :=
  proc(l, deg)
    local root, rinv, n, j;
  begin
    root := [sqrt(l[1]), 0$deg];
    rinv := 1/root[1];
    for n from 1 to deg do
      root[n+1] := rinv*(1/2*l[n+1] -
                         1/n*_plus(j*root[n-j+1]*root[j+1] $ j = 1..n-1));
    end_for;
    root;
  end_proc:

// AD::exp(l, deg) -- first deg+1 Taylor coefficients of exp(g).
// l holds the coefficients of g (index 1 = constant term).
// The recurrence n*f_n = sum_{j=1}^{n} j*g_j*f_(n-j) is linear in f, so it
// is run with f_0 normalized to 1 and the result is scaled by exp(g_0).
AD::exp :=
  proc(l, deg)
    local e, scale, n, j;
  begin
    e := [1$deg+1];
    for n from 1 to deg do
      e[n+1] := 1/n*_plus(j*l[j+1]*e[n-j+1] $ j = 1..n);
    end_for;
    scale := exp(l[1]);
    [e[j]*scale $ j = 1..deg+1];
  end_proc:

// AD::ln(l, deg) -- first deg+1 Taylor coefficients of ln(g).
// l   - coefficients [g_0, ..., g_deg] of g (index 1 = constant term)
// deg - highest coefficient index to compute
// From g*f' = g' for f = ln(g):
//   f_n = (1/g_0) * (g_n - 1/n * sum_{i=1}^{n-1} i*f_i*g_(n-i)).
// The weight i on each term is required; without it the coefficients
// are wrong from n = 3 on (e.g. ln(1+t) would give f_3 = 1/6, not 1/3).
AD::ln :=
  proc(l, deg)
    local ret, k, i, g0;
  begin
    ret := [0$deg+1];
    g0 := l[1];
    ret[1] := ln(g0);
    for k from 2 to deg+1 do
      // ret[k] is f_(k-1); l[k-i] is g_(k-1-i); ret[i+1] is f_i
      ret[k] := 1/g0*(l[k] -
		      1/(k-1)*_plus(i*l[k-i]*ret[i+1]$i=1..k-2));
    end_for;
    ret;
  end_proc:
  
// AD::sincos(l, deg) -- Taylor coefficients of sin(g) and cos(g) together.
// l holds the coefficients of g (index 1 = constant term).
// Coupled recurrences from s' = g'*c and c' = -g'*s:
//   n*s_n =  sum_{j=1}^{n} j*g_j*c_(n-j)
//   n*c_n = -sum_{j=1}^{n} j*g_j*s_(n-j)
// Returns [sine coefficients, cosine coefficients].
AD::sincos :=
  proc(l, deg)
    local sn, cs, n, j;
  begin
    sn := [sin(l[1]), 0$deg];
    cs := [cos(l[1]), 0$deg];
    for n from 1 to deg do
      sn[n+1] :=  1/n*_plus(j*l[j+1]*cs[n-j+1] $ j = 1..n);
      cs[n+1] := -1/n*_plus(j*l[j+1]*sn[n-j+1] $ j = 1..n);
    end_for;
    [sn, cs];
  end_proc:
  
// Sine and cosine halves of AD::sincos.
AD::sin := proc(l, deg) begin op(AD::sincos(l, deg), 1) end_proc;
AD::cos := proc(l, deg) begin op(AD::sincos(l, deg), 2) end_proc;

// AD::sinhcosh(l, deg) -- Taylor coefficients of sinh(g) and cosh(g).
// l holds the coefficients of g (index 1 = constant term).
// Coupled recurrences from s' = g'*c and c' = g'*s:
//   n*s_n = sum_{j=1}^{n} j*g_j*c_(n-j)
//   n*c_n = sum_{j=1}^{n} j*g_j*s_(n-j)
// Returns [sinh coefficients, cosh coefficients].
AD::sinhcosh :=
  proc(l, deg)
    local sh, ch, n, j;
  begin
    sh := [sinh(l[1]), 0$deg];
    ch := [cosh(l[1]), 0$deg];
    for n from 1 to deg do
      sh[n+1] :=  1/n*_plus(j*l[j+1]*ch[n-j+1] $ j = 1..n);
      ch[n+1] :=  1/n*_plus(j*l[j+1]*sh[n-j+1] $ j = 1..n);
    end_for;
    [sh, ch];
  end_proc:

// Hyperbolic sine and cosine halves of AD::sinhcosh.
AD::sinh := proc(l, deg) begin op(AD::sinhcosh(l, deg), 1) end_proc;
AD::cosh := proc(l, deg) begin op(AD::sinhcosh(l, deg), 2) end_proc;

// TODO: Formulas from 7.pdf, pp. 9ff.

/*

 BUG in the above code: interval::autodiff(1/x^2*exp(-1/x^2), 20)
 creates broken code!

*/