/* --------------------------------------
Calls:   collect(p, x, <f>)
         collect(p, [x], <f>)
         collect(p, [x1, x2, ...], <f>)

Arguments:
  p: an arithmetical expression
  x, x1, ...: arithmetical expressions (typically: 
              identifiers or indexed identifiers)
  f: procedure

Details: 
 - In collect(p, x) the expression p is treated as a polynomial in x. 
   The coefficients of the terms corresponding to different powers x^i 
   are collected. The result is of the form c0 + c1*x + c2*x^2 + ...
   with coefficients c0, c1, c2, ... that do not depend on x.

   Further, if p contains symbolic function calls x(t1), x(t2), ..., 
   then collect(p, x) automatically regards x(t1), x(t2) as different
   polynomial unknowns.

 - collect(p, [x1, x2, ..]) treats p as a multivariate polynomial in
   x1, x2, .. and collects the coefficients of terms with the same
   exponents.

   Further, if p contains symbolic function calls x1(t11), x1(t12), ..., 
   x2(t21), x2(t22), ..., then collect(p, [x1, x2, ...]) automatically 
   regards x1(t11), x1(t12), ..., x2(t21), x2(t22), ... as different
   polynomial unknowns.

 - If a third argument f is specified, it is applied to the polynomial
   coefficients, i.e.
      collect(p, x, f) = f(c0) + f(c1)*x + f(c2)*x^2  +...,
   where c0, c1, c2, ... are given by
      collect(p, x) = c0 + c1*x + c2*x^2  +....


Examples:

>> collect(a*x + x + b*x^3 + c*x^3 + d*y + e*y, x)

   (a + 1)*x + (b + c)*x^3 + d*y + e*y

>> collect(a*x + x + b*x^3 + c*x^3 + d*y + e*y, [x, y])

   (a + 1)*x + (b + c)*x^3 + (d + e)*y

>> collect(a*x(0) + x(0) + b*x(1) + c*x(1) + d*y + e*y, x)

   (a + 1)*x(0) + (b + c)*x(1) + d*y + e*y

>> collect(diff(besselJ(0, x), x $ 6), besselJ, expand)

(9/x^2 - 60/x^4 - 1)*besselJ(0, x) + (3/x - 33/x^3 + 120/x^5)*besselJ(1, x)

Authors: 

  This is a very old function (many authors were involved)
  Extension by Walter, Jan 2009: symbolic function calls 
  x(t1), x(t2) are now treated as independent variables.
-------------------------------------- */

// collect(p, ll <, fn>): collect the coefficients of p regarded as a
// polynomial in the unknown(s) ll and apply fn (default: id) to every
// coefficient.
//
//   p  - arithmetical expression. If p belongs to a domain providing a
//        slot "collect", the call is delegated to that slot.
//   ll - a single unknown or a list of unknowns. An entry f that occurs
//        in p as the operator of calls f(..) additionally contributes
//        each distinct call f(..) found in p as a separate unknown.
//   fn - optional; applied to each coefficient before
//        generate::sortSums.
//
// Returns a sum of coefficient*monomial terms built with hold(_plus) and
// hold(_mult), so the kernel does not re-simplify it. Terms whose
// (transformed) coefficient is 0 are dropped. If poly(p, unknowns) fails,
// numerator and denominator of normal(p, List) are collected separately
// and returned as a quotient; FAIL is returned if that fails as well.
collect :=
proc(p, ll)
  local l, f, fn, eq,
  unks: DOM_TABLE,
  subs_table: DOM_TABLE,
  i: DOM_INT,
  j: DOM_INT,
  myplus: DOM_PROC,
  mymult: DOM_PROC,
  mydivide: DOM_PROC;
begin
  if args(0) = 0 then
    error("collect called without arguments")
  end_if;

  // domain elements may provide their own implementation
  if p::dom::collect <> FAIL then
    return(p::dom::collect(args()))
  end_if;

  if args(0)<2 or args(0)>=4 then
    error("Expecting 2 or 3 arguments")
  end_if;

  if not testtype(p, Type::Arithmetical) then
    error("First argument must be an arithmetical expression")
  end_if;

  // a single unknown x is handled like the list [x]
  if type(ll) <> DOM_LIST then
    l:=[ll]
  else
    l:=ll
  end_if;

  if args(0)<3 then
    fn := id;
  else
    fn := args(3);
  end_if;

  // no unknowns: the whole expression is the only coefficient
  if nops(l) = 0 then
     return(fn(p));
  end_if;

  //----------------------------------------------
  // collect(expr, [f]) with some function f is
  // equivalent to collect(expr, [f(x1),f(x2),..])
  // where f(x1), f(x2), .. are all the different
  // symbolic calls of f occurring in expr.
  //----------------------------------------------
  unks:= table(): // table to collect symbolic function calls
  // store the argument lists of all calls f(..) in p whose operator
  // f is an entry of l: unks[f] = {[args of call 1], [args of call 2], ..}
  misc::maprec(p, {Type::AnyType} =
               proc(x) local f; begin
                 if op(x, 0) <> FAIL then
                   f:= eval(op(x, 0));
                   if contains(l, f) > 0 then
                     if contains(unks, f) then
                       unks[f]:= unks[f] union {[op(x)]};
                     else
                       unks[f]:= {[op(x)]}
                     end_if;
                   end_if;
                 end_if;
                 return(x);
               end_proc, NoOperators);

  // sort the arguments of the functions calls so that
  // f(1) comes before f(2) comes before f(x) comes before f(y) ...
  unks:= map(unks, x -> sort([op(x)]));

  // introduce unknowns in l corresponding to different
  // function calls f(x) in the expression p (that are
  // stored in unks). The entry f itself is kept, followed
  // by all its calls f(..).
  l:= map(l, proc(f) local x; begin
               if contains(unks, f) then
                 return(f, (f(op(x)) $ x in unks[f]))
               end_if;
               return(f)
             end_proc);

  // we must eliminate duplicates in l (note that poly(.., [x, x])
  // is illegal), but we must not lose the ordering in l !
  // l:= sort([op({op(l)})]); // <- fast, but we'd lose the order
  // Later duplicates are marked with NIL and removed afterwards,
  // so the first occurrence of every unknown keeps its position.
  if nops({op(l)}) < nops(l) then
     for i from 1 to nops(l) do
       for j from i+1 to nops(l) do
         if l[i] = l[j] then
           l[j]:= NIL;
         end_if;
       end_for;
     end_for;
     l:= subs(l, NIL = null());
  end_if:

  // substitute expressions in l that are not legal unknowns
  // for poly by freshly generated identifiers and store these
  // substitutions in some table (for later re-substitution):
  // subs_table[new identifier] = original expression
  subs_table:= table():
  l:= map(l, proc(f) local x; begin
               if traperror(poly(0, [f])) = 0 then
                 return(f)
               else
                 x:= genident("_X_");
                 subs_table[x]:= f;
                 return(x)
               end_if
             end_proc);
  // replace the expressions stored in subs_table by the
  // identifiers stored in subs_table:
  p:= subs(p, [(op(eq, 2) = op(eq, 1)) $ eq in subs_table]);

  // ------------------------------------------------------------
  // end of preprocessing collect(expr, [f]) with some function f
  // ------------------------------------------------------------

  // hold(_plus), but treat one and zero operands correctly and
  // drop zero terms (fn may map a nonzero coefficient to 0, which
  // would otherwise leave a held 0*x in the result).
  myplus := proc()
              local terms;
            begin
              terms := select([args()], _unequal, 0);
              if nops(terms) = 0 then 0
              elif nops(terms) = 1 then terms[1]
              else hold(_plus)(op(terms)); end_if;
            end_proc;

  // hold(_mult), but treat one and zero args correctly,
  // flatten out ones, and let a zero factor annihilate the
  // whole product (see myplus).
  mymult := proc()
              local factors;
            begin
              factors := [args()];
              if contains(factors, 0) > 0 then
                return(0)
              end_if;
              factors := select(factors, _unequal, 1);
              if nops(factors) = 0 then 1
              elif nops(factors) = 1 then factors[1]
              else hold(_mult)(op(factors)); end_if;
            end_proc;

  // held a/b, written as a*b^(-1); takes care of denom = 1
  // and turns a denominator b = c^k into the factor c^(-k)
  mydivide := proc(a, b)
      begin
        if b=1 then
          a
        elif type(b)="_power" then
          mymult(a, subsop(b, 2=-op(b, 2), Unsimplified))
        else
          mymult(a, hold(_power)(b, -1))
        end_if;
      end_proc;

  f:= poly(p, l);
  if f <> FAIL then
    // p is a polynomial in the unknowns
    f:= poly2list(f);
    // undo the substitution of illegal unknowns by _X_ identifiers
    if nops(subs_table) <> 0 then
      [f, l]:= subs([f, l], [op(subs_table)]);
    end_if;
    if nops(l) = 1 then
      // univariate: each entry e is [coefficient, exponent]
      myplus(op(map(f, e -> mymult(generate::sortSums(fn(e[1])), l[1]^e[2]))));
    else
      // multivariate: each entry e is [coefficient, [exp1, exp2, ..]]
      myplus(op(map(f,
                    e -> mymult(generate::sortSums(fn(e[1])),
                                l[i]^e[2][i] $ i = 1..nops(l)))));
    end_if;
  else
    // not polynomial: collect numerator and denominator separately
    f:= normal(p, List);
    f:= map(f, poly, l);
    if contains(f, FAIL) > 0 then
      return(FAIL)
    end_if;
    f:= map(f, poly2list);
    if nops(subs_table) <> 0 then
      [f, l]:= subs([f, l], [op(subs_table)]);
    end_if;
    if nops(l) = 1 then
      mydivide(
        myplus(op(map(f[1],
                    e -> mymult(generate::sortSums(fn(e[1])),
                                l[1]^e[2])))),
        myplus(op(map(f[2],
                    e -> mymult(generate::sortSums(fn(e[1])),
                                l[1]^e[2])))));
    else
      mydivide(
        myplus(op(map(f[1],
                    e -> mymult(generate::sortSums(fn(e[1])),
                                l[i]^e[2][i] $ i = 1..nops(l))))),
        myplus(op(map(f[2],
                    e -> mymult(generate::sortSums(fn(e[1])),
                                l[i]^e[2][i] $ i = 1..nops(l))))));
    end_if;
  end_if;

end_proc:
