/*===========================================================================

sum -- (in)definite summation

Calling sequence:
sum(f, x)

Parameters:
f -- an expression (rational function)
x -- identifier or indexed identifier, or an equation x = a..b,
     x = RootOf(...), or x in S

Summary:
sum(f,x) computes the indefinite sum of f(x) with respect to x.

--------------------------------------------------------------------------

Examples for indefinite sums:

1) Polynomial case:

>> sum(1,k);

                                     k

>> sum(k^2,k);

                                     2    3
                                k   k    k
                                - - -- + --
                                6   2    3

>> sum(a*x^2+b*x+c,x);

                    3
                 a x      / a   b     \    2 /   a   b  \
                 ---- + x | - - - + c | + x  | - - + -  |
                  3       \ 6   2     /      \   2   2  /


2) Fractional case:

a) Decomposition into the polynomial part and the rational part:

>> sum(k/(k+1),k);

                              k - psi(k + 1)

>> sum((x^2+1)/(x-1),x);

                              2
                         x   x
                         - + -- + 2 psi(x - 1)
                         2   2


b) Reduced rational expressions:
   (degree of numerator is less than degree of denominator)

>> sum(1/k/(k-2),k);

                                1           1
                          - --------- - ---------
                            2 (k - 1)   2 (k - 2)


>> sum(1/k/(2*k-1),k);

                          - psi(k) + psi(k - 1/2)


==========================================================================*/

/*

Calling sequence:

sum(f, x <,"LookupOnly">)

The option "LookupOnly" is undocumented because it is intended for internal
use only. It means that only the lookup mechanism should be used.

*/

// sum(f, x <, options> <, "LookupOnly">): main entry point.
// Validates the arguments, prepares assumptions for definite sums,
// delegates the actual work to sum::sum and finally tries to fill
// uncovered cases of a piecewise result.
sum :=
proc(s, x)
  local k, l, r, a, y, t, myE, myK, branches, conditions,
        lookupOnly, options: DOM_TABLE;

begin
  if args(0) = 0 then
    error("Wrong number of arguments")
  end_if;

  // the internal marker "LookupOnly" is always the last argument
  if args(args(0)) = "LookupOnly" then
    lookupOnly:= TRUE ;
  else
    lookupOnly:= FALSE ;
  end_if;

  // overloading: delegate to the "sum" slot of the domain of s and
  // return its result (the internal marker is not forwarded)
  if s::dom::sum <> FAIL then
    if lookupOnly then
      return(s::dom::sum(args(1..args(0)-1)));
    else
      return(s::dom::sum(args()));
    end_if;
  end_if;

  if args(0) < 2 then
    error("at least two arguments expected")
  end_if;

  if not contains({"_equal", "_in", DOM_IDENT, "_index"}, type(x)) then
    error("identifier, indexed identifier, or equation expected ".
            "as 2nd argument")
  end_if;

  if type(x) = "_equal" or type(x) = "_in" then
    if domtype(op(x, 1)) <> DOM_IDENT and type(op(x,1)) <> "_index" then
      error("summation variable must be an identifier or indexed identifier")
    end_if
  end_if;

  if testargs() then
    if type(x) = "_equal" then
      if type(op(x, 2)) = "_range" then
        // change the limits for wrong sums like sum(n,n=1..n) so that
        // we have something useful like  sum(n,n=1..k)
        if has(op(x, 2), op(x, 1)) then
          warning("summation variable occurs in rhs") ;
        end_if:
        if map(op(x, 2), testtype, Type::Arithmetical) <> (TRUE..TRUE) then
          error("bounds must be of Type::Arithmetical")
        end_if;
        if is(op(x, [2, 1]) <= op(x, [2, 2])) = FALSE then
          error("Left border must not exceed right border")
        end_if;
      elif domtype(op(x, 2)) <> RootOf then
        error("range or RootOf expected as rhs")
      elif has(op(x, 2), op(x, 1)) then
        error("summation variable occurs in rhs")
      elif stdlib::hasfloat(op(x, 2)) then
        error("floating-point numbers not allowed in rhs")
      end_if;
    elif type(x) = "_in" then
      if testtype(op(x, 2), Type::Set) = FALSE then
        error("summation variable must range over a set")
      elif has(op(x, 2), op(x, 1)) then
        warning("summation variable occurs in rhs")
      end_if
    end_if
  end_if;

  // Definite summation over a range.
  // If the summation variable also occurs in the range (as in
  // sum(n, n = 1..n)), its occurrences in the range are renamed to myK
  // (_K, or a fresh identifier if _K is already used there); they are
  // renamed back after the summation.  myK = FALSE: no renaming done.
  myK:= FALSE ;
  if type(x) = "_equal" and type(op(x, 2)) = "_range" then
    if has(op(x, 2), op(x, 1)) then
      if has(op(x, 2), _K) then
        myK:= genident("_K") ;
      else
        myK:= _K ;
      end_if:
      x:= op(x,1) = subs(op(x, 2), op(x,1) = myK);
    end_if:
    // make assumptions about the summation variable if possible
    k := op(x, 1);
    [l, r]:= [op(op(x, 2))];
    if protected(k) = hold(None) then
      save k;
      assume(k in Dom::Interval([l], [r]) intersect Z_ );
    end_if;

    a:= r - l;
    for y in select(indets(a) minus Type::ConstantIdents,
                    ind -> protected(ind) = hold(None)) do
      // we solve the condition a >= 0 for y, provided that
      // a is linear in y
      if (t:= Type::Linear(a, [y])) <> FALSE then
        // a = t[1] * y + t[2]
        // thus a >= 0 iff t[1] is positive and y >= -t[2]/t[1]
        // or t[1] is negative and y <= -t[2]/t[1] or t[1] = 0
        if is(t[1] > 0) = TRUE then
          save y;
          assume(y >= -t[2]/t[1], _and)
        elif is(t[1] < 0) = TRUE then
          save y;
          assume(y <= -t[2]/t[1], _and)
        end_if
      end_if;
    end_for;

    // re-evaluate the summand under the new assumptions
    s:= eval(s)
  end_if;

  // special case: the sum of 0 is 0
  if s = 0 then
    return(s);
  end_if:

  if testargs() then
    if domtype(s) <> DOM_POLY and not testtype(s, Type::Arithmetical) then
      error("1st argument must be of Type::Arithmetical")
    end_if
  end_if;

  // Options are searched from position 3 on; the "LookupOnly" marker
  // (always last) is not part of them.
  if lookupOnly then
    options:= prog::getOptions(3, [args(1..args(0)-1)], sum::options, TRUE,
                               sum::optionTypes)[1];
    options["LookupOnly"] := TRUE ;
  else
    options:= prog::getOptions(3, [args()], sum::options, TRUE,
                               sum::optionTypes)[1];
    options["LookupOnly"] := FALSE ;
  end_if:

  if domtype(s) = DOM_POLY then
    s:= expr(s)
  end_if;

  // In sum::sum, exp(x) is written as `#E`^x.
  // Do not mistake this with `#E` used by the user.
  myE:= genident("E"):
  if has(s, `#E`) then
     s:= subs(s, `#E` = myE);
  end_if;

  a:= sum::sum(s, x, options);
  if ( myK <> FALSE ) then
    a:= subs(a, myK = op(x,1))
  end_if:

  // In sum::sum, exp(x) is written as `#E`^x.
  // Upon return, this should have been cleaned
  // by sum::sum, but sometimes this does not
  // happen.
  if has(a, `#E`) then
     a:= subs(a, `#E` = exp(1));
  end_if;
  a:= subs(a, myE = `#E`);

  if domtype(a) <> piecewise then
    return(a)
  end_if;

  // The recursive calls below must see the summand and the range as the
  // user wrote them: undo the renamings to myE and myK, so that these
  // temporary identifiers cannot leak into the inserted branches.
  s:= subs(s, myE = `#E`);
  if myK <> FALSE then
    x:= op(x, 1) = subs(op(x, 2), myK = op(x, 1))
  end_if;

  // fill "uncovered case" if the result is a piecewise
  // and the condition allows to replace an identifier
  t:= not _or(op(map([op(a)], op, 1)));
  if type(t) = "_or" then
    branches:=[op(t)]
  else
    branches:= [t]
  end_if;

  conditions:= map({extop(a)}, op, 1);
  for t in branches do
    if testtype(t, "_equal") and
       testtype(op(t, 1), DOM_IDENT) and
       testtype(op(t, 2), Type::Constant) and
       traperror
       ((
         t:= piecewise::flattenBranch
         (stdlib::branch(t, sum(subs(s, t, EvalChanges),
                                subs(x, t, EvalChanges), options)))))=0
       then
      for y in [t] do
        if not contains(conditions, op(y, 1)) then
          conditions:= conditions union {op(y, 1)};
          a:= piecewise::insert(a, y)
        end_if
      end_for;
    end_if;
  end_for;

  a

end_proc:

// Attach the 2D output procedure for unevaluated sums. It returns the
// layout string of sum(f, spec) when PRETTYPRINT is TRUE, and FAIL
// otherwise (FAIL also when the final layout check fails).
sum:= funcenv(sum,
proc(a)
  local spec, bounds, sumSign, head, res;
begin
  if PRETTYPRINT <> TRUE then
    return(FAIL)
  end_if;

  // the summation sign, one string per output line
  sumSign := "---",
             "\\  ",
             "/  ",
             "---":
  spec:= op(a, 2);
  if type(spec) <> "_equal" then
    // indefinite sum: the second operand is printed below the sign
    head:= stdlib::align(sumSign, expr2text(spec))
  else
    bounds:= op(spec, 2);
    if domtype(bounds) = RootOf then
      // sum over roots: "x = RootOf(...)" below the sign
      head:= stdlib::align(sumSign, expr2text(op(spec, 1) = bounds))
    else
      // definite sum: upper bound above, "x = lower bound" below the sign
      head:= stdlib::align(expr2text(op(bounds, 2)),
                           sumSign,
                           expr2text(op(spec, 1) = op(bounds, 1)))
    end_if
  end_if;

  res:= strprint(All, _outputSequencePlus(stdlib::Exposed(head),
                                          stdlib::Exposed(" "), op(a, 1)));
  // NOTE(review): strprint(All, ...) yields a list; the layout is
  // rejected when its entries 2 and 4 differ (presumably a size
  // consistency check -- confirm)
  if res[2] <> res[4] then
    return(FAIL)
  end_if;
  res[1]
end_proc):
// Set the "type" and "print" slots of the function environment to the
// string "sum" (the procedures below declare parameters of type "sum").
sum:= slot(sum, "type", "sum"):
sum:= slot(sum, "print", "sum"):

// Default values of the options understood by sum (parsed in sum via
// prog::getOptions) and the types these options must satisfy.
sum::options:= table(MaxOrder = 2,
                     ExpandFinite = 1000):
sum::optionTypes:= table(MaxOrder = Type::PosInt,
                         ExpandFinite = Type::NonNegInt):




// sum::evalAt(J, subst): apply the set of substitutions subst to the
// sum J, respecting the bound summation variable of definite sums.
sum::evalAt:=
proc(J: "sum", subst: Type::SetOf("_equal"))
  local body, spec, bvar, inner, withVar, withoutVar, unknown;
begin
  // only substitutions whose lhs occurs in J are relevant
  subst := select(subst, eq -> has(J, op(eq, 1)));
  body:= op(J, 1);
  spec:= op(J, 2);

  if type(spec) <> "_equal" and type(spec) <> "_in" then
    // indefinite summation: substitutions involving the summation
    // variable cannot be pushed inside the sum
    [withVar, withoutVar, unknown] := split(subst, has, spec);
    if withVar = {} then
      return(eval(hold(sum)(evalAt(body, withoutVar), op(J, 2..nops(J)))))
    end_if;
    if withoutVar = {} then
      return(hold(evalAt)(args()))
    end_if;
    return(hold(evalAt)(eval(hold(sum)(evalAt(body, withoutVar),
                                       op(J, 2..nops(J)))),
                        withVar))
  end_if;

  // definite summation or summation over a set:
  // in sum(f(x), x=0..x) (which is legal in mupad!), a substitution x=...
  // must only act on the x outside the summand
  bvar:= op(spec, 1);
  inner:= select(subst, eq -> op(eq, 1) <> bvar);
  // a substituted value containing the bound variable would be captured:
  // rename the bound variable first
  if has(inner, bvar) then
    inner := inner union {bvar = solvelib::getIdent(Any, indets([args()]))};
  end_if;
  sum(evalAt(body, inner),
      subsop(spec,
             1 = evalAt(bvar, inner),
             2 = evalAt(op(spec, 2), subst)),
      op(J, 3..nops(J)))
end_proc:



// sum::int(s, x): integrate the sum s = sum(f, spec) with respect to x,
// where x is an identifier or an equation x = a..b.
// The integral is moved inside the sum only when the integration
// variable does not occur in the summation spec.  Otherwise, e.g. for
// sum(x^k, k = 1..x) or an indefinite sum over x itself, the interchange
// would be wrong, and the integral is returned unevaluated.
sum::int:=
proc(s, x)
  local v;
begin
  if type(x) = "_equal" then
    v:= op(x, 1)
  else
    v:= x
  end_if;
  if has(op(s, 2), v) then
    return(hold(int)(args()))
  end_if;
  sum(int(op(s, 1), x), op(s, 2))
end_proc:


// sum::limit(f, x, lp, dir, options): limit of the sum f for x -> lp.
// Interchanges limit and summation where this is valid, otherwise
// returns the limit unevaluated.
sum::limit:=
proc(f: "sum", x: DOM_IDENT, lp: Type::Arithmetical, dir: DOM_IDENT,
     options: DOM_TABLE)
  local spec, k, s, t;
begin
  if options[Intervals] then
    // cannot sum over intervals
    return(hold(limit)(f, x=lp, dir, Intervals))
  end_if;

  spec:= op(f, 2);

  if not has(spec, x) then
    // The summation spec does not depend on x.  Assumptions on the
    // summation variable are made after `save`, so that they are undone
    // when this procedure returns instead of changing global properties.
    if type(spec) = DOM_IDENT then
      k:= spec;
      if protected(k) = hold(None) then
        save k;
        assumeAlso(k in Z_)
      end_if
    elif type(spec) = "_equal" and type(op(spec, 2)) = "_range" then
      if op(spec, [2, 1]) = -infinity or
         op(spec, [2, 2]) = infinity then
        // cannot interchange summation and limit
        return(hold(limit)(f, x=lp, dir))
      end_if;
      k:= op(spec, 1);
      if protected(k) = hold(None) then
        save k;
        assumeAlso(k in Z_ intersect Dom::Interval(op(spec, 2)))
      end_if
    elif type(spec) = "_in" and domtype(op(spec, 2)) <> DOM_SET then
      // the index set may be infinite:
      // cannot interchange summation and limit
      return(hold(limit)(f, x=lp, dir))
    end_if;
    // limit(sum(...)) = sum(limit(...)) since the sum is finite
    // (or indefinite)
    return(sum(limit(op(f, 1), x=lp, dir), spec))
  end_if;

  // only the bounds of the range depend on x
  if not has(op(f, 1), x) and type(spec) = "_equal" and
     type(op(spec, 2)) = "_range" then
    s:= limit(op(spec, [2, 1]), x=lp, dir);
    t:= limit(op(spec, [2, 2]), x=lp, dir);
    if [s, t] = [-infinity, infinity] then
      s:= sum(op(f, 1), op(spec, 1) = -infinity..infinity);
      if not hastype(s, "sum") then
        return(s)
      end_if
    elif not hastype([s, t], "limit") and s <> infinity
      and t <> - infinity then
      return(sum(op(f, 1), op(spec, 1) = s..t))
    end_if;
  end_if;
  return(hold(limit)(f, x=lp, dir))
end_proc:


// sum::rectform(s, x, ...): rectform of sum(s, x, ...), computed as the
// sum of each part of rectform(s), using the same summation arguments.
sum::rectform:=
proc(s, x)
  local k, rf;
begin
  // Constrain the summation variable (unless it is protected) so that
  // rectform can exploit the assumptions; `save` restores the previous
  // properties on exit from this procedure.
  if type(x) = DOM_IDENT then
    // indefinite sum: integer summation variable
    if protected(x) = hold(None) then
      save x;
      assume(x in Z_)
    end_if
  elif type(x) = "_in" then
    // summation over a set
    k:= op(x, 1);
    if protected(k) = hold(None) then
      save k;
      assume(k in op(x, 2))
    end_if
  elif type(x) = "_equal" and type(op(x, 2)) = "_range" then
    // definite summation: integer in the range
    k:= op(x, 1);
    if protected(k) = hold(None) then
      save k;
      assume(k in Dom::Interval([op(x, [2, 1])], [op(x, [2, 2])]) intersect Z_ );
    end_if
  end_if;

  rf:= rectform(s);
  new(rectform, op(map([op(rf)], sum, args(2..args(0)))))
end_proc:

// sum::freeIndets(S <, All>): free indeterminates of the sum S.
// With All, the operator name sum is included as well.
sum::freeIndets:=
proc(S: "sum"): DOM_SET
  local res, spec;
begin
  if testargs() then
    if args(0) > 2 then
      error("wrong number of arguments") ;
    end_if;
    if args(0) = 2 and args(2) <> All then
      error("illegal argument") ;
    end_if:
  end_if:

  spec:= op(S, 2);
  if type(spec) = "_equal" or type(spec) = "_in" then
    // The summation variable is bound only in the summand.  Occurrences
    // in the range or set are free (sum(f(n), n = 1..n) is legal), so
    // the variable is removed before the range indeterminates are added.
    // Without All, an indexed variable x[i] removes just x, not _index.
    res:= (freeIndets(op(S, 1), args(2..args(0))) minus
           freeIndets(op(spec, 1))) union
          freeIndets(op(spec, 2), args(2..args(0)))
  else
    // indefinite sum: the summation variable is free
    res:= freeIndets(op(S, 1), args(2..args(0))) union
          freeIndets(spec, args(2..args(0)))
  end_if:

  if args(0) = 2 then
    res union {hold(sum)};
  else
    res;
  end_if:
end_proc:

// Only operand 1 (the summand) is listed for simplification --
// presumably so that simplify leaves the summation spec untouched.
sum::operandsToSimplify:= [1]:


// Content MathML output of sum(f, spec):
//   x = a..b             -> <sum/> with bvar, lowlimit and uplimit
//   x = RootOf(..), x in S -> <sum/> with bvar and condition "x in ..."
//   x                    -> <sum/> with bvar only (indefinite sum)
// Calls with more than two operands use the generic function output.
sum::Content      :=
  proc(Out, data)
    local spec, var;
  begin
    if nops(data) <> 2 then
      return(Out::stdFunc(data));
    end_if;
    spec:= op(data, 2);
    if type(spec) = "_equal" and type(op(spec, 2)) = "_range" then
      Out::Capply(Out::Csum,
                  Out::Cbvar(Out(op(spec, 1))),
                  Out::Clowlimit(Out(op(spec, [2, 1]))),
                  Out::Cuplimit( Out(op(spec, [2, 2]))),
                  Out(op(data, 1)));
    elif type(spec) = "_equal" or type(spec) = "_in" then
      // the condition <in/> refers to the variable itself,
      // not to the <bvar> element
      var:= Out(op(spec, 1));
      Out::Capply(Out::Csum,
                  Out::Cbvar(var),
                  Out::Ccondition(Out::Capply(Out::Cin,
                                              var,
                                              Out(op(spec, 2)))),
                  Out(op(data, 1)));
    else
      Out::Capply(Out::Csum,
                  Out::Cbvar(Out(spec)),
                  Out(op(data, 1)));
    end_if;
  end_proc:

// numerical evaluation of sums is loaded on demand from NUMERIC/floatsum
sum::float := loadproc(sum::float, pathname("NUMERIC"), "floatsum"):

// sum::diff(e, v1, v2, ...): derivative of e = sum(f, spec).
// Differentiation is moved inside the sum only if all variables are
// indeterminates that do not occur in the summation spec; otherwise
// the derivative is returned unevaluated.
sum::diff :=
proc(e)
  local vars, spec;
begin
  vars:= [args(2..args(0))];
  if map({op(vars)}, testtype, Type::Indeterminate) <> {TRUE} then
    return(hold(diff)(e, op(vars)))
  end_if;
  spec:= op(e, 2);
  if has(spec, vars) then
    return(hold(diff)(e, op(vars)))
  end_if;
  // summation independent of differentiation
  sum(diff(op(e, 1), op(vars)), spec)
end_proc:

// sum::expand(S <, options>): expand the sum S.
//  - a definite sum with an integer number of terms is written out
//    explicitly;
//  - a summand that expands to a sum gives a sum of (expanded) sums;
//  - in a product, factors free of the summation variable are pulled
//    out of the sum.
// Intermediate results are kept in named locals rather than read back
// through `%`, whose meaning changes whenever a statement is inserted.
sum::expand:=
proc(S: "sum")
  local summand, spec, var, lo, width, j, parts, dep, indep, unknown;
begin
  summand:= op(S, 1);
  spec:= op(S, 2);

  if type(spec) = "_equal" and type(op(spec, 2)) = "_range" then
    width:= op(spec, [2, 2]) - op(spec, [2, 1]);
    if domtype(width) = DOM_INT then
      // left border must not exceed right border, this should
      // have been checked before
      assert(width >= 0);
      lo:= op(spec, [2, 1]);
      var:= op(spec, 1);
      // shift the range to start at 0, then add the terms var = 0..width
      parts:= subs(summand, var = var + lo);
      parts:= _plus(subs(parts, var = j, EvalChanges) $ j = 0..width);
      return(eval(parts))
    end_if
  end_if;

  parts:= expand(summand, args(2..args(0)));
  if type(parts) = "_plus" then
    return(_plus(map(parts, expand@sum, spec)))
  end_if;
  if type(parts) = "_mult" then
    if type(spec) = DOM_IDENT then
      var:= spec
    else
      var:= op(spec, 1)
    end_if;
    // dep: factors containing the summation variable, indep: the rest
    [dep, indep, unknown]:= split(parts, has, var);
    return(indep*hold(sum)(dep, spec))
  end_if;
  S
end_proc:

// Z-transform of sum(f, spec): transform the summand and keep the sum.
// The interchange is only valid if the summation spec does not involve
// the transform variable k.  For example, ztrans(sum(f(j), j = 0..k), k, z)
// is not sum(ztrans(f(j), k, z), j = 0..k).  FAIL is returned in that case.
// NOTE(review): assumes transform::ztrans treats a FAIL result of the slot
// as "cannot handle" -- confirm.
sum::ztrans:= proc(s, k, z)
begin
   if has([extop(s, 2..extnops(s))], k) then
     return(FAIL)
   end_if;
   sum(transform::ztrans(extop(s, 1), k, z), extop(s, 2..extnops(s)));
end_proc:

// Inverse Z-transform of sum(f, spec), with the same restriction: the
// summation spec must not involve the transform variable z.
// NOTE(review): same FAIL protocol assumption as for sum::ztrans.
sum::invztrans:= proc(s, z, k)
begin
   if has([extop(s, 2..extnops(s))], z) then
     return(FAIL)
   end_if;
   sum(transform::invztrans(extop(s, 1), z, k), extop(s, 2..extnops(s)));
end_proc:

// Procedures loaded on demand from an explicitly named file.
sum::simplify   := loadproc(sum::simplify,   pathname("STDLIB","SIMPLIFY"), "sum"):
sum::evalAtPoint:= loadproc(sum::evalAtPoint, pathname("SUM"), "sum"):
sum::gosper2    := loadproc(sum::gosper2,    pathname("SUM"), "gosper"):
sum::patternFSA := loadproc(sum::patternFSA, pathname("SUM"), "load_patterns"):
sum::sum_fn     := loadproc(sum::sum_fn,     pathname("SUM"), "sum"):

// The remaining helper slots are registered with autoload.
autoload(sum::addpattern):
autoload(sum::dispersion):
autoload(sum::factor):
autoload(sum::gosper):
autoload(sum::indefinite):
autoload(sum::lookup):
autoload(sum::myeval):
autoload(sum::normal):
autoload(sum::poly):
autoload(sum::rat):
autoload(sum::ratio):
autoload(sum::rootOf):
autoload(sum::sum):
autoload(sum::support):
autoload(sum::zeilberger):

// user-defined summation patterns (presumably maintained by
// sum::addpattern) and the public interface of the sum library
sum::userpatterns := []:
sum::userpatternsFSA := FAIL:
sum::interface := {hold(addpattern)}:

prog::setcheckglobals(sum, {_X,_K}):

null():
