

// Shorthand used throughout this file: bl stands for matchlib::block.
alias(bl=matchlib::block):

// sum::addpattern(pattern, var, result <, pat_vars <, conds>>)
//
// Registers a user-defined summation pattern.
//
//   pattern  - the summand to be matched
//   var      - the summation variable: either an identifier (indefinite
//              sum) or an equation  identifier = a..b  (definite sum)
//   result   - the value of the sum; it is wrapped in matchlib::block
//   pat_vars - optional list of pattern variables (default: [])
//   conds    - optional conditions (default: []); each one is wrapped in
//              matchlib::block, a non-list value is put into a list
//
// All arguments arrive unevaluated (option hold) and are evaluated in the
// context of the caller.  Internally the summation variable is replaced
// by _X.  The finished entry is handed to matchlib::addpatterns, the
// remember table of sum::sum is cleared, and nothing (null()) is returned.
//
// Raises an error if fewer than three arguments are given or if var is
// not a valid variable specification.
sum::addpattern :=
proc(pattern, var, result, pat_vars, conds)
  name sum::addpattern;
  option hold;
  local sumVar, isDefinite, freeOfX;
begin
  if args(0) < 3 then
    error("At least three arguments expected")
  end_if;

  // Evaluate the held arguments in the caller's context, in the order
  // in which they were passed.
  pattern := context(pattern);
  var     := context(var);
  result  := context(hold(bl)(result));
  if args(0) > 3 then
    pat_vars := context(pat_vars);
  else
    pat_vars := [];
  end_if;
  if args(0) > 4 then
    conds := context(hold(map)(hold(hold)(conds), bl));
  else
    conds := [];
  end_if;
  if domtype(conds) <> DOM_LIST then
    conds := [conds];
  end_if;

  // Split the variable specification into the variable itself and the
  // information whether a range was supplied.
  if type(var) = "_equal" then
    isDefinite := TRUE;
    sumVar := lhs(var);
  else
    isDefinite := FALSE;
    sumVar := var;
  end_if;
  if type(sumVar) <> DOM_IDENT or
     (isDefinite and type(rhs(var)) <> "_range") then
    error("illegal variable specification: ".expr2text(var));
  end_if;

  // _X is reserved for the summation variable: if the user's data already
  // contains _X in another role, move that occurrence out of the way.
  if sumVar <> _X and has([pattern, result, pat_vars, conds], _X) then
    [pattern, result, pat_vars, conds] :=
      subs([pattern, result, pat_vars, conds], _X = genident("_X"));
  end_if;

  // Rename the summation variable to _X everywhere; pattern and result
  // additionally pass through extractconstants.
  [pattern, result] :=
    sum::addpattern::extractconstants(subs([pattern, result], sumVar = _X));
  [pat_vars, conds] := subs([pat_vars, conds], sumVar = _X);

  // A definite pattern is marked by wrapping it, together with the two
  // range boundaries, in `#hdef`.
  if isDefinite then
    pattern := `#hdef`(pattern, op(var, [2, 1]), op(var, [2, 2]));
  end_if;

  // One blocked condition "not has(v, _X)" per pattern variable v.
  // NOTE(review): this presumably relies on matchlib::block not
  // evaluating its argument - confirm.
  freeOfX := map(pat_vars, v -> subs(bl(not has(`#v`, _X)), `#v` = v));

  matchlib::addpatterns(sum::userpatternsFSA, sum::userpatterns,
                        [[pattern, result, pat_vars, conds, freeOfX]],
                        sum::addpattern::generalize, FALSE, FAIL);

  // Results remembered by sum::sum may be outdated now.
  sum::sum(Remember, Clear);

  null();
end_proc:

// Wrap the procedure in a function environment; the helper routines
// defined below are stored as entries sum::addpattern::<name> of it.
sum::addpattern := funcenv(sum::addpattern):

// sum::addpattern::extractconstants(pat)
//
// Moves factors that do not depend on the summation variable out of a
// product pattern.  The summation variable has already been renamed
// to _X when this is called.
//
//   pat[1]: the pattern
//   pat[2]: the result
//
// If pat[1] is a product, its factors are divided into those containing
// one of _X, fx, gx, f1x, g1x, px, qx, qx1 and the remaining ones.  When
// the product of the remaining factors differs from 1, the list
//   [product of the dependent factors, blocked (result / constant part)]
// is returned; in every other case pat is returned unchanged.
//
// NOTE(review): the identifiers fx, gx, ... are written without a leading
// `#` here, whereas newlimits tests for `#fx`, `#gx`, ... - confirm that
// this difference is intended.
sum::addpattern::extractconstants :=
proc(pat)
  local dependent, constPart, unknown;
begin
  if type(pat[1]) <> "_mult" then
    return(pat);
  end_if;

  // split yields three lists; only the first two are used below
  [dependent, constPart, unknown] :=
    split([op(pat[1])],
          has, hold([_X, fx, gx, f1x, g1x, px, qx, qx1]));
  constPart := _mult(op(constPart));

  if constPart = 1 then
    return(pat);
  end_if;

  // build the quotient symbolically first, then insert the actual
  // result and constant part into the blocked expression
  [_mult(op(dependent)),
   subs(matchlib::block(`#a`/`#b`),
        `#a` = matchlib::unblock(pat[2]),
        `#b` = constPart)];
end_proc:

// sum::addpattern::extractconstantsEx(pat)
//
// Variant of extractconstants for entries that also carry a range.
//
//   pat[1]: the pattern
//   pat[2]: the range
//   pat[3]: the result
//   pat[4]: the conditions, if available
//
// Pattern and result are passed through extractconstants; the range is
// kept as it is.  The returned list always has four entries, with [] in
// the last position when pat has no more than three entries.
sum::addpattern::extractconstantsEx :=
proc(pat)
  local reduced, conditions;
begin
  reduced := sum::addpattern::extractconstants([pat[1], pat[3]]);
  conditions := (if nops(pat) > 3 then pat[4] else [] end_if);
  [reduced[1], pat[2], reduced[2], conditions];
end_proc:

// One of the generalization steps:
//
// We will generate some new "patterns" automatically by
// changing the limits
// If something like:
//     sum(f(k),k)
//     sum(f(k),k=b..infinity)
// or  sum(f(k),k=-infinity..b)
// is given patterns for:
//     sum(f(k+s),k)
//     sum(f(k+s),k=a..infinity)
// or  sum(f(k+s),k=-infinity..a)
// are generated
//
// The arguments are:
// pat[1]: The pattern
// pat[2]: A list yielding a piecewise branch
// pat[3]: Hard conditions
//
sum::addpattern::newlimits :=
proc(pat)
  local l, r, fpat, getvar, hc, hleft, hright, hashposs, pos, vars;
begin
  // getvar(pref): return pref itself unless pat already contains it; in
  // that case return a fresh identifier from genident, created with the
  // name of pref as prefix.
  getvar :=
  proc(pref)
  begin
    if has(pat, pref) then
      genident("".pref);
    else
      pref
    end_if;
  end:
  
  // hc is the pattern variable used for the shift of the summation variable
  hc := getvar(`#c`);
  
  if op(pat[1],0)<> `#hdef` then // indefinite summation
    // patterns containing one of the function placeholders are returned
    // unchanged; otherwise the original entry is returned together with
    // a copy in which _X is shifted by hc
    if has(pat[1], {`#fx`,`#gx`,`#f1x`,`#g1x`}) then
      return(pat);
    else
      return(pat, [subs(pat[1],_X=_X+`##`(hc))].subs(pat[2..-1], _X=_X+hc));
    end_if:
  end_if:
  // definite summation: pat[1] is `#hdef`(summand, lower, upper)
  l := op(pat[1], 2);
  r := op(pat[1], 3);
  // fp::unapply does not work reliably
  // besides, we want to remove the `##` thingies
  fpat := op(pat[1], 1);
  // positions at which `##` occurs in the summand
  hashposs := [prog::find(fpat, `##`)];
  // the subexpressions found one level above these positions
  // (presumably the `##`(...) calls themselves - TODO confirm)
  vars := {op(fpat, pos[1..-2]) $ pos in hashposs};
  // turn the unblocked summand, with every x in vars replaced by op(x),
  // into a function k -> summand with _X replaced by k
  fpat := subs(k -> subs(`#fn`, _X=k),
               `#fn` = subs(matchlib::unblock(fpat),
               map(vars, x->x=op(x)))); 
  // pattern variables for the new lower and upper boundary
  hleft := getvar(`#c_left`);
  hright := getvar(`#c_right`);

  if r = infinity and domtype(l) = DOM_INT then
    // sum from an integer l to infinity: replace the entry by entries
    // for a symbolic lower boundary hleft; the missing or surplus terms
    // are expressed through sum::sum_fn
    pat[2] := matchlib::unblock(pat[2]);
    pat:= // pat,
      // hleft >= l: subtract the terms l..hleft-1
      [`#hdef`(op(pat[1], 1), `##`(hleft), infinity),
       (matchlib::block@(x->x))
       ([hold(_and)(pat[2][1], hleft>=l),
         hold(_subtract)(pat[2][2],
                         hold(sum::sum_fn)(fpat,l..hleft-1))]),
       pat[3]],
      // hleft <= l: add the terms hleft..l-1
      [`#hdef`(op(pat[1], 1), `##`(hleft), infinity),
       (matchlib::block@(x->x))
       ([hold(_and)(pat[2][1], hleft<=l),
         hold(_plus)(pat[2][2],hold(sum::sum_fn)(fpat, hleft..l-1))]),
       pat[3]],
       // the same two entries for the summand shifted by hc, unless the
       // pattern contains one of the function placeholders
       if not has(pat[1], {`#fx`,`#gx`,`#f1x`,`#g1x`}) then
         [`#hdef`(subs(op(pat[1],1), _X = _X + `##`(hc)), `##`(hleft), infinity),
          (matchlib::block@(x->x))
          ([hold(_and)(pat[2][1], hleft >= l-hc),
            hold(_subtract)(pat[2][2],
                            hold(sum::sum_fn)(fpat, l..hleft+hc-1))]),
          pat[3]],
         [`#hdef`(subs(op(pat[1],1), _X = _X + `##`(hc)), `##`(hleft), infinity),
          (matchlib::block@(x->x))
          ([hold(_and)(pat[2][1], hleft <= l-hc),
            hold(_plus)(pat[2][2],
                        hold(sum::sum_fn)(fpat, hleft+hc..l-1))]),
          pat[3]]
       else
         null();
       end_if:
  elif l = -infinity and domtype(r) = DOM_INT then
    // sum from -infinity to an integer r: the mirrored construction with
    // a symbolic upper boundary hright
    // NOTE(review): unlike the branch above, pat[2] is indexed here
    // without a preceding matchlib::unblock - confirm this is intended
    pat:= // pat,
      // hright <= r: subtract the terms hright+1..r
      [`#hdef`(op(pat[1], 1), -infinity, `##`(hright)),
       (matchlib::block@(x->x))
       ([hold(_and)(pat[2][1], hright<=r),
         hold(_subtract)(pat[2][2],
                         hold(sum::sum_fn)(fpat, hright+1..r))]),
       pat[3]],
      // hright >= r: add the terms r+1..hright
      [`#hdef`(op(pat[1], 1), -infinity, `##`(hright)),
       (matchlib::block@(x->x))
       ([hold(_and)(pat[2][1], hright>=r),
         hold(_plus)(pat[2][2],
                     hold(sum::sum_fn)(fpat, r+1..hright))]),
       pat[3]],
       // shifted variants, as in the branch above
       if not has(pat[1], {`#fx`,`#gx`,`#f1x`,`#g1x`}) then
         [`#hdef`(subs(op(pat[1],1), _X = _X + `##`(hc)), -infinity, `##`(hright)),
          (matchlib::block@(x->x))
          ([hold(_and)(pat[2][1], hright <= r-hc),
            hold(_subtract)(pat[2][2], 
                            hold(sum::sum_fn)(fpat, hright+hc+1..r))]),
          pat[3]],
         [`#hdef`(subs(op(pat[1],1), _X = _X + `##`(hc)), -infinity, `##`(hright)),
          (matchlib::block@(x->x))
          ([hold(_and)(pat[2][1], hright >= r-hc),
            hold(_plus)(pat[2][2],
                        hold(sum::sum_fn)(fpat, r+1..hright+hc))]),
          pat[3]]
       else
         null();
       end_if:
  end_if:
  // for all other boundaries pat is left as it was
  // now, ensure that hc+constant is folded into a single new
  // pattern variable, thus avoiding problems with
  // (_X+hc)/(_X+hc+a) and similar constructs
  pat := op(map([pat],
            proc(p)
              local hcpos, pos, t_ren, summands, substs, subst,
                    condsubsts, patvars, v, found, p1, t, c;
            begin
              // t_ren maps each constant sum to be replaced to the
              // identifier replacing it; `##`(hc) itself is renamed to hc
              t_ren := table(`##`(hc)=hc);
              hcpos := [prog::find(p[1], `##`(hc))];
              for pos in hcpos do
                // only occurrences of `##`(hc) directly inside a sum matter
                if op(p[1], pos[1..-2].[0]) = hold(_plus) then
                  // the summands free of _X that are either a `##` call
                  // or contain no `##` at all
                  summands := _plus(op(select([op(op(p[1], pos[1..-2]))],
                                              t -> not has(t, _X) and
                                                   (op(t,0)=`##` or not has(t, `##`)))));
                  if not contains(t_ren, summands) then
                    t_ren[summands] := genident("#c_");
                  end_if;
                end;
              end;
              // now, for the conditions, we must try for each of
              // the sums we replace to extract one pattern variable
              // we can express with the other new pattern vars
              condsubsts := [];
              // process the replacements in the order of increasing
              // number of `##` operands on their left hand side
              substs := prog::sort([op(t_ren)], e -> nops(select([op(lhs(e))],
                                                       t -> op(t,0)=`##`)));
              for subst in substs do
                // pattern variables other than hc occurring in this sum
                patvars := select([op(lhs(subst))], t -> op(t,0)=`##` and t <> `##`(hc));
                if patvars=[] then
                  // no further pattern variable: the new identifier simply
                  // stands for the sum with the `##` wrappers removed
                  condsubsts := [rhs(subst)=subs(lhs(subst), `##`=id,EvalChanges)].
                    condsubsts;
                  next;
                end;
                found := FALSE;
                for v in patvars do
                  for t in t_ren do
                    if t = subst then next; end_if;
                    // v can be expressed through the new identifiers if
                    // this sum differs from another replaced sum plus v
                    // only by a constant
                    c := lhs(t)+v-lhs(subst);
                    if testtype(c, Type::Constant) then
                      found := TRUE;
                      condsubsts := [op(v)=rhs(subst)-rhs(t)+c].
                        condsubsts;
                    end;
                  end_for;
                end_for;
                if not found then
                  return(p); // TODO: how much can we still do?
                end_if;
              end_for;
              // replace the sums in the pattern, those with the most
              // operands first
              p1 := subsex(p[1], 
                           prog::sort(map([op(t_ren)],
                                          e -> lhs(e)=`##`(rhs(e))),
                                      e -> -nops(lhs(e))));
              // for variables still in p1, we introduce new hard conditions;
              // for variables not in p1 any longer, we substitute
              // in the conditions and the result
              condsubsts := select(condsubsts, e -> lhs(e) <> rhs(e));
              condsubsts := split(condsubsts, s -> has(p1, lhs(s)));
              assert(condsubsts[3]=[]);
              // if hc still occurs in the new pattern, a blocked test
              // that hc is an integer is appended to the hard conditions
              [p1, subs(p[2], condsubsts[2]),
               subs(p[3], condsubsts[2]).condsubsts[1].
               if has(p1, hc) then
                 [subs(bl(testtype(`#hc`, Type::Integer)), `#hc`=hc)]
               else
                 []
               end_if];
            end));
  pat:
end_proc:


// sum::addpattern::generalizesum(p)
//
// Generalizes a pattern that matches a sum (_plus expression): an
// additional summand `#hx` is added to the pattern and a held
// sum(`#hx`, k) is added to the result.  Entries whose pattern is not a
// _plus expression are returned unchanged.
//
// At this point the variables have not been renamed yet.
//
//   p[1]: the pattern
//   p[2]: the range
//   p[3]: the result
//   p[4]: the conditions (appended to the new entry)
//
// NOTE(review): the summation variable in the added sum is the literal
// identifier k - confirm that this is the intended variable.
sum::addpattern::generalizesum :=
proc(p)
begin
  if type(p[1]) <> "_plus" then
    return(p);
  end_if;

  // block@(x->x): the argument is evaluated by the identity function
  // before it reaches matchlib::block
  [p[1] + `##`(`#hx`),
   p[2],
   (matchlib::block@(x->x))
   (hold(_plus)(matchlib::unblock(p[3]),
                hold(sum(`#hx`, k))))
  ].p[4];
end_proc:

// sum::addpattern::generalize(pat)
//
// Generalization hook passed to matchlib::addpatterns by sum::addpattern.
// The entry first goes through generalizesum; newlimits is then applied
// to each member of the resulting set, and the members of the final set
// are returned as a sequence.
sum::addpattern::generalize :=
proc(pat)
begin
  op(map({sum::addpattern::generalizesum(pat)},
         sum::addpattern::newlimits));
end_proc:


// Both procedures use the global identifier _X on purpose; register it
// with prog::setcheckglobals for each of them.
prog::setcheckglobals(sum::addpattern::newlimits, {_X}):
prog::setcheckglobals(sum::addpattern, {_X}):
