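/* solvelib::Union(set1, x, set2)

   Given a set set1 whose elements depend on a parameter x (or on several
   parameters, given as a list x) and a set set2, return the set of all
   objects that can be obtained by substituting x in some element of set1
   by some element of set2.
   If x is a list, set2 is a set of vectors whose length matches the
   length of that list.

*/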

// BOUND: largest number of elements of a finite index set (DOM_SET) for which
// solvelib::Union still expands the union element by element when the first
// argument is neither a DOM_SET nor (beyond 4 elements) a piecewise.
alias(BOUND = 20):

/*
   solvelib::Union(set1, x, set2)

   set1 - a set whose elements may depend on x
   x    - an identifier, or a list of identifiers
   set2 - the index set: a set of values for x, or, if x is a list,
          a set of vectors whose length matches the length of x

   Returns the union of all sets obtained from set1 by substituting
   an element of set2 for x. Whenever no branch below is able to
   compute that union, the unevaluated call is returned, either as
   procname(args()) or as hold(solvelib::Union)(...).

   MAXEFFORT is modified while sub-problems are solved; it is declared
   with "save", so the caller's value is restored when the procedure
   is left.
*/
solvelib::Union:=
proc(set1: Type::Set, x: Type::Union(DOM_IDENT, DOM_LIST), set2: Type::Set)
local i, j, l, S,
  imgsets: DOM_LIST,
  maxeffort: DOM_FLOAT,
  UnionOverVectorSet: DOM_PROC;
save MAXEFFORT;
  
begin

  // local method UnionOverVectorSet
  /*  set1 - any set
         l - list
      set2 - set of vectors, matching the length of l

      Handles the case that the index variables are given as a list
      with more than one entry. Returns the union, or the unevaluated
      call hold(solvelib::Union)(set1, l, set2) if no case applies.
  */

  UnionOverVectorSet:=
  proc(set1, l: DOM_LIST, set2)
    local v, vec, i, inds, indices, S;
  begin
    if set1 = {} or set2 = {} then
      return({})
    end_if;

    
    // multiple equal entries in l -> probably name conflict
    assert(nops(l) = nops({op(l)}));


    // leave out components which set1 does not depend on
    inds:= freeIndets(set1);
    indices:= select([$1..nops(l)], i -> contains(inds, l[i]));
    if nops(indices) < nops(l) then
      S:= solvelib::selectIndices(set2, indices);
      if S <> FAIL then
        l:= map(indices, i -> l[i]);
        set2:= S
      end_if
    end_if;

    // after the projection, zero or one variable may be left over
    case nops(l)
      of 0 do
        return(set1)
      of 1 do
        // reduce to the scalar case if the vector set can be converted
        S:= solvelib::vectorSetToNumberSet(set2);
        if S<>FAIL then 
          return(solvelib::Union(set1, op(l, 1), S))
        end_if
    end_case;
    
    case type(set2)
      of DOM_SET do
        // set2 may have been replaced by the projection above
        if set2 = {} then
          return({})
        end_if;
        if map(set2, nops) <> {nops(l)} then
          error("second argument contains vector of wrong dimension")
        end_if;
        // substitute every vector of set2 into set1 and unite the results;
        // the sequence operator $ has to be placed *inside* the call to
        // _union, otherwise a sequence of sets would be returned
        return(_union(subs(set1, zip(l, [op(i)], _equal), EvalChanges)
                      $i in set2))
      of "_union" do
        return(map(set2, s -> UnionOverVectorSet(set1, l, s)))
      of piecewise do
        return(piecewise::extmap(set2, s -> UnionOverVectorSet(set1, l, s)))
      of solvelib::VectorImageSet do
        if type(set1) = DOM_SET then
          // \bigcup_{[v1, v2, ..] \in {[g1(z1, ...), ...]; [z1, z2, ..] in S}}
          // [f1(v1, ..), f2(v1, ..), ...]
          // = {[f1(g(z1, ..), g2(z1, ..)), f2(...), ...], [z1, z2, ..] in S} 
          v:= set2::dom::expr(set2);
          return
          (
           solvelib::solve_union
           (
            solvelib::VectorImageSet
            (
             subs(vec,
                       [l[i] = v[i] $i=1..nops(v)], EvalChanges),
             set2::dom::variables(set2),
             set2::dom::sets(set2)
             )
            $vec in set1
            )
           )
        end_if;
        if type(set1) = solvelib::VectorImageSet then
          // we do only handle one special case: if we can write set1 as a cartesian product *and*  
          // each variable occurs in one factor only,
          // then we may compute the Union componentwise
          v:= solvelib::splitVectorSet(set1);
          if v <> FAIL and _lazy_and(nops(select(v, has, l[i])) <=1 $i=1..nops(l)) then 
            v:= map(v, UnionOverVectorSet, l, set2);
            // use the product only if every factor could be computed
            if contains(map(v, type), "Union") = 0 then
              return(solvelib::cartesianProduct(op(v)))
            end_if
          end_if
        end_if; 
        break
      of solvelib::cartesianPower do
        case type(set1)
          of DOM_SET do
            return(_union(solvelib::VectorImageSet
                          (v, l, [set2::dom::base(set2) $nops(l)])
                          $v in set1
                          ))
          of Dom::ImageSet do
          of solvelib::VectorImageSet do
            // we can write the  union of {f(x, y, z); x in S}
            // where [y, z] ranges over T^2 as
            // {f(x, y, z); x in S, y in T, z in T}
            return(set1::dom::new
                   (
                    expr(set1),
                    set1::dom::variables(set1).l,
                    set1::dom::sets(set1).[set2::dom::base(set2) $nops(l)]
                    )
                   )
        end_case;
        break
    end_case;
    // no case applied: return the unevaluated call
    hold(solvelib::Union)(set1, l, set2)
  end_proc;

  /**********************************
  // main program of solvelib::Union
  **********************************/
  
  if set1 = undefined then
    return(undefined)
  end_if;

  if domtype(x) = DOM_LIST then
    if nops(x) = 1 then
      // a list with a single variable: try to reduce to the scalar case
      x:= op(x, 1);
      set2:= solvelib::vectorSetToNumberSet(set2);
      if set2 = FAIL then
        return(procname(args()))
      end_if
    else
      return(UnionOverVectorSet(set1, x, set2))
    end_if
  end_if;

  // set1 does not depend on x: the union over a set that is known
  // to be nonempty is set1 itself
  if not contains(freeIndets(set1), x) and solvelib::isEmpty(set2) = FALSE then
    return(set1)
  end_if;  


  case type(set2)
    of Dom::Multiset do
      if type(set1) = Dom::Multiset then
        // substitute pairwise; the multiplicities are multiplied
        return
        (
         Dom::Multiset::convert
         (
          [[subs(op(set1, i)[1], x = op(set2, j)[1], EvalChanges),
            op(set1,i)[2]*op(set2, j)[2]] $i=1..nops(set1) $j=1..nops(set2)]
          )
         )
      end_if;
      set2:= coerce(set2, DOM_SET);
      // fall through
    of DOM_SET do
      if type(set1) = DOM_SET then
        // finite over finite: substitute every value into every element;
        // substitutions that raise an error contribute nothing
        return
        (
         { (if traperror(( l:= evalAt(i, x=j))) = 0 then
              l
            end_if)
          $i in set1 $j in set2
          }
         )
      elif nops(set2) <= 4 or (nops(set2) <= BOUND and type(set1) <> piecewise) then
         // small index set: expand explicitly, skipping values for which
         // the substitution fails or gives undefined
         return(_union(
              (if traperror((l:= evalAt(set1, x = i))) = 0 and l <> undefined then 
                 l 
               end_if) 
             $i in set2))
      else
        return(procname(args()))
      end_if
    of piecewise do
      // do case analysis w.r.t. parameter set
      /* if has(set2, x) then
           error("Index set of Union must not depend on index variable")
         end_if; */
      return(piecewise::extmap(set2, z-> solvelib::Union(set1, x, z)))

    of "_union" do
      return(map(set2, S -> solvelib::Union(set1, x, S)))
  end_case;

  if not hastype(set2, "solve") then
      assumeAlso(x in set2);
      // re-evaluate some sub-expressions of set1
      set1:= misc::maprec(set1, {"sign", "abs", "signIm"} =
      proc(X)
      begin
        if has(X, x) then
          eval(X)
        else
          X
        end_if  
      end_proc)  
  end_if;

  
  // try overloading
  if set1::dom::Union <> FAIL then
    return(set1::dom::Union(args()))
  end_if;
  
  case type(set1)
    of DOM_SET do
      if nops(set1) = 0 then
        return({})
      end_if;
      l:=[];
      imgsets:=[];
      // half of the effort is spent on the substitutions, shared
      // equally among the elements of set1
      maxeffort:= MAXEFFORT/2;
      MAXEFFORT:= maxeffort/nops(set1);
      for i from 1 to nops(set1) do
        S:= solvelib::substituteBySet(op(set1,i), x, set2);
        if S= FAIL then
          return(procname(args()))
        end_if;
        // image sets are collected separately, see below
        if type(S) = Dom::ImageSet then
          imgsets:= append(imgsets, S)
        else
          l:= append(l, S)
        end_if;
      end_for;
      MAXEFFORT:= maxeffort;
      l:= _union(op(l));
      case nops(imgsets)
        of 0 do
          return(l)
        of 1 do
          return(l union op(imgsets, 1))
        otherwise
          // avoid computing a union of ImageSets
          return(_union(l, op(imgsets)))
      end_case  
    of "solve" do
    of RootOf do
      return(procname(args()))
    of "_union" do
    of "_intersect" do
      // map over the operands, sharing the effort among them
      MAXEFFORT:= MAXEFFORT/nops(set1);
      return(map(set1, solvelib::Union, x, set2))
    of "_minus" do
      return(procname(args()))
    otherwise
      return(procname(args()))
  end_case
 end_proc:

// turn the procedure into a function environment, so that the slots
// below can be attached to it
solvelib::Union:= funcenv(solvelib::Union):

// type string of unevaluated calls; the methods below declare their
// first argument as S: "Union"
solvelib::Union::type := "Union":
solvelib::Union::info := "solvelib::Union --- union of a system of sets":
solvelib::Union::print:= "solvelib::Union":


/*
   solvelib::Union::freeIndets(S, <options>)

   S - unevaluated call solvelib::Union(set1, x, set2)

   Returns the free identifiers of S: those of the index set set2,
   together with those of set1 that are not bound by the Union.
   The bound variables x may be a single identifier or a list of
   identifiers; in the latter case every entry of the list is bound.
   Any further arguments are passed on to freeIndets unchanged.
*/
solvelib::Union::freeIndets:=
proc(S: "Union"): DOM_SET
  local bound;
begin
  // a list of index variables has to be flattened; {op(S, 2)} alone
  // would give the one-element set {[x1, x2, ..]}, which removes nothing
  if type(op(S, 2)) = DOM_LIST then
    bound:= {op(op(S, 2))}
  else
    bound:= {op(S, 2)}
  end_if;
  freeIndets(op(S, 3), args(2..args(0))) 
  union 
  (
  freeIndets(op(S, 1), args(2..args(0))) 
  minus 
  bound
  )
end_proc:

/*
   solvelib::Union::evalAt(S, subst)

   S     - unevaluated call solvelib::Union(set1, x, set2)
   subst - set of equations

   Applies the substitutions to set1 and set2. Equations whose left
   hand side is an index variable of S are discarded. An index
   variable that occurs elsewhere in the remaining equations is first
   replaced by a new identifier obtained from solvelib::getIdent.
*/
solvelib::Union::evalAt:=
proc(S: "Union", subst: DOM_SET)
  local body, bound, k, fresh;
begin
  body:= op(S, 1);
  bound:= op(S, 2);
  if type(bound) <> DOM_IDENT then
    // several index variables, given as a list
    assert(type(bound) = DOM_LIST);
    subst:= select(subst, eq -> (contains(bound, op(eq, 1)) = 0));
    for k from 1 to nops(bound) do
      if not has(subst, bound[k]) then
        next
      end_if;
      // the new name must also differ from all current index variables
      fresh:= solvelib::getIdent(op(S, 3),
                indets(S) union indets(subst) union {op(bound)});
      body:= subs(body, bound[k] = fresh, Unsimplified);
      bound[k]:= fresh
    end_for
  else
    // a single index variable
    subst:= select(subst, eq -> op(eq, 1) <> bound);
    if has(subst, bound) then
      fresh:= solvelib::getIdent(op(S, 3), indets(S) union indets(subst));
      body:= subs(body, bound = fresh, Unsimplified);
      bound:= fresh
    end_if
  end_if;
  solvelib::Union(evalAt(body, subst), bound, evalAt(op(S, 3), subst))
end_proc:


/*
   solvelib::Union::testtype(x, T)

   Type test for unevaluated Union calls: TRUE for Type::Set,
   FALSE for Type::Arithmetical, and FAIL (no decision made here)
   for any other type T.
*/
solvelib::Union::testtype:=
proc(x, T)
begin
  case T
    of Type::Set do
      return(TRUE)
    of Type::Arithmetical do
      return(FALSE)
  end_case;
  FAIL
end_proc:

/*
   solvelib::Union::avoidAliasProblem(S, vars)

   S    - unevaluated call solvelib::Union(set1, x, set2)
   vars - set of identifiers

   Calls solvelib::avoidAliasProblem on set1 and set2, then replaces
   every index variable of S that is contained in vars by a new
   identifier obtained from solvelib::getIdent. The result is
   returned as an unevaluated call of solvelib::Union.
*/
solvelib::Union::avoidAliasProblem:=
proc(S: "Union", vars: DOM_SET)
  local body, bound, domain, renamed,
  taken: DOM_SET,
  k: DOM_INT;
  
begin
  body:= solvelib::avoidAliasProblem(op(S, 1), vars);
  bound:= op(S, 2);
  domain:= solvelib::avoidAliasProblem(op(S, 3), vars);
  // index variables are kept unless they clash with vars
  renamed:= bound;
  if type(bound) <> DOM_IDENT then
    assert(type(bound) = DOM_LIST);
    if {op(bound)} intersect vars <> {} then
      taken:= vars union indets(body) union indets(domain);
      for k from 1 to nops(bound) do
        if contains(vars, bound[k]) then
          // create a new identifier that occurs nowhere and 
          // has also not been generated in previous iterations
          renamed[k]:= solvelib::getIdent(C_, taken union {op(renamed)});
          body:= subs(body, bound[k] = renamed[k], Unsimplified)
        end_if;
      end_for;
    end_if;
  elif contains(vars, bound) then
    renamed:= solvelib::getIdent(domain,
                vars union indets(body) union indets(domain));
    body:= subs(body, bound = renamed, Unsimplified)
  end_if;
  hold(solvelib::Union)(body, renamed, domain)
end_proc:



/*
   solvelib::Union::Content(Out, data)

   Builds the Content representation of a Union: applies Out::CUnion
   to the converted three operands of data. Index variables given as
   a list are converted to a matrix before being passed to Out.
*/
solvelib::Union::Content :=
proc(Out, data)
  local bodyOut, varsOut, domainOut;
begin
  // convert the operands in their original order
  bodyOut:= Out(extop(data, 1));
  if testtype(extop(data, 2), DOM_LIST) then
    varsOut:= Out(matrix(extop(data, 2)))
  else
    varsOut:= Out(extop(data, 2))
  end_if;
  domainOut:= Out(extop(data, 3));
  Out::Capply(Out::CUnion, bodyOut, varsOut, domainOut)
end_proc:

// Arithmetic with unevaluated Union objects is delegated to the
// stdlib set_plus / set_mult / set_power routines.

// overload _plus
solvelib::Union::_plus := stdlib::set_plus:

// overload _mult
solvelib::Union::_mult := stdlib::set_mult:

// overload _power
solvelib::Union::_power := stdlib::set_power:


// remove the alias BOUND that was defined at the top of this file
unalias(BOUND):
