/* generate a new rec object and perform some basic argument checking
   (the argument checks are only run when testargs() returns TRUE)
   input: R = recurrence equation or expression, UOFN = function to solve for
          (of the form u(n) for idents u,n),
          INITS = initial condition(s), equation or set of equations
          of the form u(foo)=bar, may be empty or omitted
          all occurrences of u in R must be of the form u(n+i) for integral i
          u must not occur in bar, n must not occur in foo
*/
rec::new :=
proc(R, UOFN, INITS = {})
  local eqn, fn, conds;
begin
  // INITS is optional, so exactly two or three arguments are accepted
  if not (args(0) >= 2 and args(0) <= 3) then
    error("expecting two or three arguments")
  end_if;

  // an equation lhs = rhs is stored as the expression lhs - rhs
  eqn := R;
  if type(eqn) = "_equal" then
    eqn := op(eqn, 1) - op(eqn, 2)
  end_if;

  fn := UOFN;

  // a single initial condition is wrapped into a set
  conds := INITS;
  if type(conds) = "_equal" then
    conds := {conds}
  end_if;

  // a piecewise recurrence: build one rec per branch via piecewise::extmap,
  // passing on the remaining original arguments unchanged
  if domtype(eqn) = piecewise then
    return(piecewise::extmap(eqn, rec, args(2..args(0))))
  end_if;

  if testargs() then
    rec::check_args(eqn, fn, conds)
  end_if;

  new(rec, eqn, fn, conds)
end_proc:


/* rec::subs performs substitution in all three operands, evaluates the
   result, and checks whether the operands are still valid after the
   substitution
*/
rec::subs := proc(R)
  local ops;
begin
  // extop(R) is the expression sequence  r, u(n), inits.  Sequences are
  // flattened in function calls, so subs(extop(R), ...) would pass u(n)
  // and inits to subs as substitution arguments instead of substituting
  // into them.  Wrapping the operands in a list substitutes into all three.
  ops := eval(subs([extop(R)], args(2..args(0))));
  // re-validate the substituted operands, then rebuild the rec object
  rec::check_args(op(ops));
  new(rec, op(ops))
end_proc:


//----------------------------------------------
/* rec::check_args(r, uofn, inits) raises an error unless
   - uofn is u(n) for two distinct identifiers u, n, neither of which is
     a constant identifier such as PI or EULER,
   - inits is {} or a set of equations u(foo) = bar, where foo does not
     contain n and bar does not contain u,
   - r is an expression (or equation) in which u occurs with n in its
     argument, and every such call has the form u(n+i) for an integer i.
     Only calls that show up in indets(r, RatExpr) are inspected.
   It is called for its checks only; callers ignore its result.
*/
rec::check_args := proc(r, uofn, inits)
  local s, islin, u, n, eq;
begin
  /* check second argument */
  if type(uofn) <> "function" then
    error("second argument must be the function to solve for")
  end_if;
  if nops(uofn) <> 1 then
    error("second argument must be a function with one operand")
  end_if;
  u := op(uofn, 0);
  n := op(uofn, 1);
  if map({u ,n}, type) <> {DOM_IDENT} then
    error("illegal second argument")
  end_if;
  // u(u) is meaningless: the index variable would coincide with the unknown
  if u = n then
    error("illegal second argument")
  end_if;
  // check for PI, EULER, ...
  if nops({u, n} intersect Type::ConstantIdents) > 0 then
    error("illegal second argument")
  end_if;

  /* check initial conditions */
  if inits <> {} then
    if domtype(inits) <> DOM_SET then
      error("third argument must be an equation or a set of equations")
    end_if;

    /* Check whether all initial conditions are of the form
       u(foo)=bar, with bar not involving u and foo not involving n */
    for eq in inits do
      if type(eq) <> "_equal" then
        error("third argument must be an equation or a set of equations")
      end_if;
      // test for a function call first, so that op(.., 0) is only applied
      // to calls (conditions of if are evaluated lazily)
      if type(op(eq, 1)) <> "function" or nops(op(eq, 1)) <> 1
         or op(eq, [1,0]) <> u then
        error("left hand side of initial condition must contain the unknown function")
      end_if;
      if has(op(eq, [1,1]), n) then
        error("left hand side of initial condition must not contain the index variable")
      end_if;
      if has(op(eq, 2), u) then
        error("right hand side of initial condition must not contain the unknown function")
      end_if;
    end_for
  end_if;

  /* check first argument */
  if domtype(r) <> DOM_EXPR then
    error("first argument must be an expression or an equation")
  end_if;

  /* check whether all terms u(f(n)) are of the form u(n+i) */
  if type(r) = "_equal" then
    s:= {op(r)};
  else
    s:= r;
  end_if;
  s := select(indets(s, RatExpr), testtype, "function");
  // keep only the calls of u whose arguments involve n
  s := select(s, proc(f) begin
      op(f, 0) = u and has([op(f)], n)
    end_proc);
  if nops(s) = 0 then // no u(..) in equation
    error("first argument must involve the unknown function")
  end_if;
  // a valid shift u(n+i) has exactly one operand; without this guard a call
  // like u(n, 1) would turn op(f) - n into a sequence subtraction and fail
  // with an internal error instead of the message below
  islin := select(s, proc(f) begin
      if nops(f) <> 1 then
        FALSE
      else
        testtype(op(f, 1) - n, DOM_INT)
      end_if
    end_proc);
  if s <> islin then
    error("only integral shifts are allowed in the first argument")
  end_if;
end_proc:

//-------------------------------------
// overload freeIndets
rec::freeIndets:=
proc(f: rec)
  local idx;
begin
  // the index variable n of u(n) (second operand of f) is bound by the
  // recurrence, so it is removed from the free indeterminates of the
  // recurrence (operand 1) and of the initial conditions (operand 3)
  idx := op(f, [2, 1]);
  _union(freeIndets(op(f, 1), args(2..args(0))),
         freeIndets(op(f, 3), args(2..args(0)))) minus {idx}
end_proc:

//-------------------------------------
// overload evalAt
rec::evalAt:=
proc(f:rec, subst)
  local dontsubst, dummy;
begin
  // a single equation x = v is handled like {x = v}; otherwise split
  // would split the equation into its two sides
  if type(subst) = "_equal" then
    subst := {subst}
  end_if;
  // equations for the index variable n (op(f, [2,1])) are not substituted
  // into the rec object; they are collected in dontsubst.
  // The third result of split is not used.
  [subst, dontsubst, dummy]:= split(subst, equ -> op(equ, 1) <> op(f, [2,1]) );
  f:= eval(subs(f, subst, Unsimplified));
  // nops works for both sets and lists (split returns the same container
  // type it was given), unlike a comparison with {}
  if nops(dontsubst) = 0 then
    f
  else
    hold(evalAt)(f, dontsubst)
  end_if;
end_proc:

//-------------------------------------
// overload the function call r(n) where type(r) is
// "rec" and n is a positive integer. We do the 
// recurrence stepwise without solving it symbolically. 
// This functionality already exists as solve::evalAt. 
// Here, we just build a nicer interface:
rec::func_call:= proc(r : rec, n)
local fname, idx, sol;
begin
  n:= context(n);
  // op(r, 2) is the unknown u(n): fname is u, idx is n
  fname:= op(r, [2, 0]);
  idx:= op(r, [2, 1]);
  // only integer arguments are evaluated; anything else stays symbolic
  if domtype(n) <> DOM_INT then
    return(hold(rec)(op(r))(n))
  end_if;
  sol:= solve::evalAt(hold(solve)(r), {idx = n});
  // a unique value is returned directly, otherwise the call u(n) is returned
  if domtype(sol) = DOM_SET and nops(sol) = 1 then
    op(sol)
  else
    fname(n)
  end_if
end_proc:
