

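/* 
testeq(ex, options)          - same as testeq(ex, 0, options)
testeq(ex1, ex2, options)
testeq(ex1 = ex2, options)

Tests whether ex1 is mathematically equal to ex2 for *all* possible parameter values.
Returns TRUE, FALSE, or UNKNOWN.

ex1, ex2 - arithmetical expressions (lists are compared elementwise,
           polynomials by their operands)
options: 
NumberOfRandomTests    - total number of random tests to be performed; this count
                         includes the substitutions of 0 and of 1 for all identifiers
NumberOfRandomRatTests - number of rational tests (included in the total number of tests above)
ZeroTestMode           - force testeq to return TRUE or FALSE, but never UNKNOWN
IgnoreProperties       - also accept random substitutions that are inconsistent with
                         the assumptions on the identifiers
Steps, Seconds, IgnoreAnalyticConstraints, RuleBase, Goal, KernelSteps
                       - passed on to the simplifier; KernelSteps also limits the
                         evaluation of each random substitution
*/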


// Upper border used when testeq generates rational test values: the Q_/R_/C_
// case of rationalRandomValueGenerator scales its value by RATRIGHTBORDER, and
// the Dom::Interval case clips interval borders to lie within 0..RATRIGHTBORDER
// where possible.
// NOTE(review): alias is a parse-time abbreviation, so every later occurrence
// of RATRIGHTBORDER is read as 9/4; presumably it is removed again (unalias)
// further down the file - confirm.
alias(RATRIGHTBORDER = 9/4): // random rational numbers are chosen from the interval (0, RATRIGHTBORDER)


/*
  testeq(ex1, rhs, <options>)

  Decides whether ex1 and the right-hand side are equal for all parameter values.
  ex1 - expression, or an equation ex1 = ex2 (type "_equal" or "_approx");
        in the equation form the options start at the second argument
  rhs - right-hand side (default 0); only used when ex1 is not an equation
  Returns TRUE, FALSE, or UNKNOWN, unless an overloading slot
  ex1::dom::testeq or ex2::dom::testeq takes over.

  Strategy, in this order: domain overloads; lists and polynomials;
  a numerical test for constant differences; random substitutions
  (numericalTest); simplification of the difference; simplification of
  both sides followed by is() and a second random-substitution test.
*/
testeq :=
proc(ex1, rhs = 0)
  local OPT: DOM_TABLE, ex, newex1, newex2, n, ex2,
  k: DOM_INT, elementResult, listResult,
  numericalTest: DOM_PROC,
  SimplifyOptions: DOM_TABLE;
  
begin

  // numericalTest(ex)
  // Random-substitution test for (possibly non-constant) expressions ex.
  // If replacing the identifiers in ex by values leads to a nonzero
  // result, ex cannot be zero-equivalent. A substitution only counts if
  // it is consistent with the properties of the identifiers (unless
  // IgnoreProperties is set).
  // Returns FALSE if some admissible substitution gives a provably nonzero
  // value. In ZeroTestMode it returns TRUE if no such substitution was
  // found, and FALSE if too few substitutions were admissible. Otherwise
  // (and always when NumberOfRandomTests = 0) it returns UNKNOWN.
  
  numericalTest:=
  proc(ex)
    name testeq; // for userinfo
    local i:DOM_INT,
    j: DOM_INT,
    failed: DOM_INT,
    successful: DOM_INT,
    substitutions,
    ratTests: DOM_INT,
    irratTests: DOM_INT,
    inds: DOM_LIST,
    proplist: DOM_LIST,
    props,
    replaceAllIntegrals: DOM_PROC,
    replaceIndex: DOM_PROC,
    replaceOperators: DOM_PROC,
    numericCheckFails: DOM_PROC,
    setImplicitAssumptions: DOM_PROC,
    rationalRandomValueGenerator: DOM_PROC,
    randomValueGenerator:DOM_PROC;
    save MAXEFFORT;

    
  begin

    // local method rationalRandomValueGenerator(S, i, j, X)
    // returns a rational "random" element from S, with "seed" i and j
    // used to create a value for the j-th variable in the i-th try
    // may also return X if no random value could be created
    rationalRandomValueGenerator:=
    proc(S, i, j, X)
      local r, l, n;
    begin
      case type(S)
        of "_union" do 
          // pick one operand of the union, depending on the seed i
          n:= nops(S);
          return(rationalRandomValueGenerator(op(S, modp(i, n) + 1), i div n, j, X))
        of solvelib::BasicSet do
          case S
            of Z_ do
              return(i+j);
            of R_ do
            of C_ do
            of Q_ do
              // R_ and C_ deliberately fall through to Q_: a rational value is admissible for all three
              return(((7*i+4*j) mod (2*i+3*j+1))/(2*i+3*j+1)* RATRIGHTBORDER)
          end_case;
          // any other basic set: do not fall through into the Dom::Interval branch
          break
        of Dom::Interval do
          r:= S::dom::right(S);
          l:= S::dom::left(S);
          if {type(l), type(r)} subset {DOM_RAT, DOM_INT, stdlib::Infinity} then 
            // adjust borders such that they are between 0 and RATRIGHTBORDER, but avoid making an empty interval 
            if l < RATRIGHTBORDER and r > RATRIGHTBORDER then
              r:= RATRIGHTBORDER
            end_if;
            if l < 0 and r > 0 then 
              l:= 0
            end_if;
            if l = -infinity then 
            // we have that r <= 0, such that we could not adjust the borders; 
            // take care that l has rather small absolute value
              l:= r-1
            elif r = infinity then 
              // l >= RATRIGHTBORDER; make r rather small
              r:= l+1
            end_if
          else
            // irrational borders; we do not check them. Maybe we create values that are inconsistent with the properties
            l:= 0;
            r:= RATRIGHTBORDER
          end_if;  
          return(l + ((7*i+3*j) mod (2*i+3*j+1))/(2*i+3*j+1) * (r-l) )
      end_case;
      X
    end_proc;
    
    
    // local method randomValueGenerator(S, i, j, X)
    // returns a "random" (generally irrational) element from S, with "seed" i and j
    // may also return X if no random value could be created
    randomValueGenerator:=
    proc(S, i, j, X)
      local n;
    begin
      case type(S)
        of "_union" do 
          n:= nops(S);
          return(randomValueGenerator(op(S, modp(i, n) + 1), i div n, j, X))
        of solvelib::BasicSet do
          case S
            of Z_ do
              return( round(
              (-1)^(j)*exp(5/6*j + 3/7*i +123/100)
                     ) );
            of R_ do
              return(
              -j/2*sin(j^2 + i + 4567/10000)
                     )
            of C_ do
              // perform real tests in one third of all cases
              if j mod 3 = 0 then
                 return(
              -j/2*sin(j^2+i+4567/10000)
                     )
              else
                return(
                       (-j*cos(7/12 - i + j^3) + I*j/2*sin(i+ j^2+ 4567/10000))
                       )
              end_if
            of Q_ do
              // irrational values are not admissible for identifiers in Q_,
              // so use rational values. The rational test phase uses the
              // seeds j = 1..2*ratTests; shift j so that those values are
              // not simply repeated.
              return(rationalRandomValueGenerator(S, i, j + 2*ratTests, X))
          end_case;
          // any other basic set: do not fall through into the Dom::Interval branch
          break
        of Dom::Interval do
          if S::dom::left(S) = -infinity then
            if S::dom::right(S) = infinity then
              return(
              -j/2*sin(j^2+ i + 4567/10000)
                     )
            else
              return(
              S::dom::right(S) - j/2*abs(sin(j^2 + i + 4567/10000))
                     )
            end_if;
          else
            if S::dom::right(S) = infinity then
              return(
              S::dom::left(S) + j/2*abs(sin(j^2 + i + 4567/10000))
                     )
            else
              return(
                  S::dom::left(S) +
                  (S::dom::right(S) - S::dom::left(S))*791/792*abs(sin(i+j^2))
                 )
            end_if;
          end_if;
      end_case;
      X
    end_proc;


    /*
      numericCheckFails(substitutions)

      substitutions - list of substitutions x1=a1, x2=a2, ... for replacing the identifiers in ex

      returns
        TRUE    if ex <> 0 has been proven by this substitution: ex evaluates to a
                provably nonzero value and (unless IgnoreProperties is set) is()
                confirms that the substituted values satisfy props,
        FALSE   if the value of ex is not provably nonzero,
        UNKNOWN if evaluating ex (limited to KernelSteps) or props raised an
                error, or the properties could not be confirmed
    */
    
    numericCheckFails:=
    proc(substitutions: DOM_LIST): DOM_BOOL
      local res;
    begin
          if traperror((
                    res:=evalAt(ex, substitutions)
                    ), MaxSteps = OPT[KernelSteps]) <> 0 then
                    UNKNOWN
          elif numeric::isnonzero(res) <> TRUE then
            FALSE          
          
          elif OPT[IgnoreProperties] then
            TRUE
          elif 
            traperror((
            res:= _and(op(evalAt(props, substitutions)))
            )) <> 0 then
            UNKNOWN
          else 
            is(res, Goal = TRUE) or UNKNOWN                 
          end_if;                
    end_proc;
    
    
    /* 
       setImplicitAssumptions(a)

       adds the assumptions that must be satisfied for expressions like fact(n)
       to be defined at all: the argument of fact/fact2 and the first argument
       of polylog are integers, as are the branch index k in lambertW(k, x)
       and the order n in zeta(s, n)
    */
    
    setImplicitAssumptions:=
    proc(a)
      local result;
    begin
      result:= TRUE;
      misc::maprec(a, 
      {"fact", "fact2", "polylog"} = (X -> (result:= result and op(X, 1) in Z_; X)),
      // only the two-argument form lambertW(k, x) has an integer branch index k
      {"lambertW"} = (X -> (if nops(X)=2 then result:= result and op(X, 1) in Z_ end_if; X)),
      {"zeta"} = (X -> (if nops(X)=2 then result:= result and op(X, 2) in Z_ end_if; X))
      );
      
      if result <> TRUE then
        context(hold(assumeAlso)(result))
      end_if  
    end_proc;

    // returns a generated identifier for the given argument(s);
    // option remember returns the same identifier again for the same argument
    replaceOperators:=
    proc()
      option remember;
    begin
      genident("testeq")
    end_proc;  
    

    // We replace X[i1, i2, i3, ...] by Y^2 + 2*i1 + 3*i2 + 4*i3 + ...,
    // where Y = replaceOperators(X).
    // By the distinct factors 2, 3, 4, ...
    // we want to avoid X[i, j] - X[j, i] -> 0
    replaceIndex:=
    proc(X)
      local i;
    begin
      replaceOperators(X)^2 + _plus(i*args(i) $i=2..args(0))
    end_proc;
    

    if OPT[NumberOfRandomTests] = 0 then
      // do not do any random tests
      return(UNKNOWN)
    end_if;  

    // We do NumberOfRandomTests tests, and split them as follows:
    // the first two tests consist of plugging in zero and one, respectively;
    // then the specified number of rationals is plugged in;
    // the remaining tests consist of plugging in irrational values.
    // irratTests is clamped at 0: a negative count (NumberOfRandomTests <
    // NumberOfRandomRatTests + 2) would bypass the ZeroTestMode check after
    // the rational tests and could yield TRUE without any admissible test.

    ratTests:= OPT[hold(NumberOfRandomRatTests)];
    irratTests:= max(0, OPT[NumberOfRandomTests] - ratTests - 2);

  
    // Replace all indefinite integrals and sums by new identifiers
    // (definite ones are kept).
    // Replace all expressions f(x1, x2, x3, ...) whose 0-th operand f is an
    // unknown function by
    //   Y^2 + x1 + 2*x2 + 3*x3 + ...
    // where Y = replaceOperators(f) is a newly generated identifier (a different one for each f).
    // Alternatively, we might want to accept false results for testeq(f(sin(x)^2 + cos(x)^2), f(1)).
    // dirac is always replaced. In ZeroTestMode, every function that is not a
    // function environment with a float slot is replaced; otherwise only
    // functions that are neither function environments nor procedures.
    // Walter: suggests to replace definite integrals and sums, too (for speed)
    
    replaceAllIntegrals:=
    proc(ex)
    begin 
      misc::maprec(ex, 
      {"int", "sum"} = 
      proc(J)
        option remember;
      begin
        if type(op(J, 2)) in {"_equal", "_in"} then
        // definite integral or sum, must not replace this
          J
        else
          genident("testeq_")       
        end_if  
      end_proc,
      {O} = id,
      {"function"} =
      proc(F)
        local opF;
      begin
        opF:= eval(op(F, 0));
        if OPT[hold(ZeroTestMode)] then
          if type(opF) <> DOM_FUNC_ENV or opF::float = FAIL or op(F, 0) = hold(dirac) then
            return(replaceOperators(op(F, 0))^2 + _plus(i*op(F, i) $i=1..nops(F)))
          end_if
        elif not contains({DOM_FUNC_ENV, DOM_PROC}, type(opF)) or op(F, 0) = hold(dirac) then
          return(replaceOperators(op(F, 0))^2 + _plus(i*op(F, i) $i=1..nops(F)))
        end_if;
        // functions like sin, cos, etc. must remain unchanged
        F
      end_proc
      )
    end_proc;
    
    props:= property::showprops(ex);
    // keep only properties all of whose identifiers occur in ex
    props:= select(props, X -> freeIndets(X) minus freeIndets(ex) = {});
   
    
    [ex, props]:= replaceAllIntegrals([ex, props]);
    [ex, props]:= subs([ex, props], hold(_index) = replaceIndex, EvalChanges);
       
    inds:= [op(freeIndets(ex))];
    
    
    // plug in zero 
       
    userinfo(30, "Trying  zero substitution");
    if numericCheckFails(map(inds, _equal, 0)) = TRUE then 
      userinfo(10, "Substituted values: ".expr2text(map(inds, _equal, 0)));                             
      return(FALSE);
    end_if;  
       
    // plug in one
    userinfo(30, "Trying  to substitute ones");
    if numericCheckFails(map(inds, _equal, 1)) = TRUE then 
      userinfo(10, "Substituted values: ".expr2text(map(inds, _equal, 1)));                             
      return(FALSE);
    end_if;  
    
    // the sets from which the random values for the identifiers are drawn
    if OPT[IgnoreProperties] then
      proplist:= [C_ $nops(inds)]
    else  
      setImplicitAssumptions(ex);
      proplist:= map(inds, getprop)
    end_if;  
    
    // share the effort among the random tests
    if ratTests + irratTests > 1 then
      MAXEFFORT:= MAXEFFORT / (ratTests + irratTests)
    end_if;  
   
    failed:= 0;
    successful:= 0;
    // rational tests: at most 2*ratTests attempts to get ratTests admissible ones
   
    j:= 1;
    while j<= 2*ratTests and successful < ratTests do
      userinfo(30, "Trying ".expr2text(j)."th rational random substitution");
      substitutions:= [op(inds, i) = rationalRandomValueGenerator(proplist[i], i, j, op(inds, i)) $i=1..nops(inds)];
      case numericCheckFails(substitutions) 
        of TRUE do
          userinfo(10, "Substituted values: ".expr2text(substitutions));                             
          return(FALSE)
        of UNKNOWN do 
          failed:= failed + 1;
          break
        of FALSE do
          // if we had a substitution of the form X = X, we cannot rely on it 
          if _lazy_or(op(substitutions)) then
            failed:= failed+1
          else
            successful:= successful + 1
          end_if
      end_case;
      j:= j+1;  
    end_while;  
    if OPT[hold(ZeroTestMode)] and irratTests = 0 and successful < ratTests then
      // too many tests failed
      return(FALSE)
    end_if;  
    
    userinfo(20, "Rational tests: ".expr2text(successful)." admissible substitutions, ".expr2text(failed)." not admissible");
    
    // irrational tests; they also make up for inadmissible rational tests
    j:= 1;
    while successful < ratTests + irratTests do
      userinfo(30, "Trying ".expr2text(j)."th random substitution");
      substitutions:= [op(inds, i) = randomValueGenerator(proplist[i], i, j, op(inds, i)) $i=1..nops(inds)];
      case numericCheckFails(substitutions) 
        of TRUE do
          userinfo(10, "Substituted values: ".expr2text(substitutions));                             
          return(FALSE)
        of UNKNOWN do
          failed:= failed + 1;
          break
        of FALSE do
          // if we had a substitution of the form X = X, we cannot rely on it 
          if _lazy_or(op(substitutions)) then
            failed:= failed + 1
          else
            successful:= successful + 1
          end_if
      end_case;
      j:= j+1;
      if failed  >  ratTests + irratTests then
        if OPT[hold(ZeroTestMode)] then
        // already too many tests failed. We give up
        // in doubt, we believe that the expression is not zero
          return(FALSE)
        else
          return(UNKNOWN)
        end_if  
      end_if  
    end_while;
    
    if OPT[hold(ZeroTestMode)] then
      TRUE
    else  
      UNKNOWN
    end_if  
  end_proc:


  /**********************************************/
  // m a i n   p r o g r a m   o f    t e s t e q
  /**********************************************/

  // an approximate equation that evaluates to TRUE/FALSE decides directly
  if type(ex1) = "_approx" and (ex:= bool(ex1)) <> UNKNOWN then
    return(ex)
  end_if;

  // accept one equation as input
  if type(ex1) = "_equal" or type(ex1) = "_approx" then
    [ex1, ex2] := [op(ex1)];
    n:= 1 // options start at argument n+1
  else
    ex2:= rhs;
    n:= 2
  end_if;

  // domain overloading
  if ex1::dom::testeq <> FAIL then
    return(ex1::dom::testeq(ex1, ex2, args(n+1..args(0))))
  elif ex2::dom::testeq <> FAIL then
    return(ex2::dom::testeq(ex2, ex1, args(n+1..args(0))))
  end_if;


    
  OPT := prog::getOptions(n+1, [args()],
   table(Steps = 100,                         // number of steps to be done by Simplify
         Seconds = infinity,                  // number of seconds to be spent by Simplify 
         IgnoreAnalyticConstraints = FALSE,   // passed on to Simplify 
         IgnoreProperties = FALSE,            // also try random substitutions that are not consistent with the assumptions
         KernelSteps = 10,                    // kernel steps per simplification step; and also per random substitution
         RuleBase = Simplify,                 // passed on to Simplify
         NumberOfRandomTests = 5,             // total number of random tests, including rational tests
         hold(NumberOfRandomRatTests) = 2,    // number of rational tests (included in the total number of tests above)
         Goal = 0,                             // passed on to Simplify 
         hold(ZeroTestMode) = FALSE         
         ),                          
   TRUE,
   table(Steps = Type::Union(Type::NonNegInt, stdlib::Infinity),
         Seconds = Type::Union(Type::Positive, stdlib::Infinity),
         IgnoreAnalyticConstraints = DOM_BOOL,
         IgnoreProperties = DOM_BOOL, 
         KernelSteps = Type::AnyType,
         RuleBase    = DOM_DOMAIN,
         NumberOfRandomTests = Type::NonNegInt,
         hold(NumberOfRandomRatTests) = Type::NonNegInt,
         Goal = Type::AnyType,
         hold(ZeroTestMode) = DOM_BOOL
        ))[1];

  

  // non-arithmetical input: only syntactic equality, lists and polynomials are handled
  if {domtype(ex1), domtype(ex2)} minus
    {DOM_INT, DOM_RAT, DOM_COMPLEX, DOM_FLOAT, DOM_EXPR, DOM_IDENT,
     piecewise, O} <> {} or
     not testtype(ex1, Type::Arithmetical) or
     not testtype(ex2, Type::Arithmetical)
    then
    if ex1 = ex2 then
      return(TRUE)
    end_if;
    if domtype(ex1) <> domtype(ex2) then
      return(FALSE)
    end_if;

    // handling of particular datatypes; the options (arguments n+1, ...)
    // are passed on to the recursive calls
    case domtype(ex1)
      of DOM_LIST do
        // two lists [a1, a2, ...] and [b1, b2, ..] are equal iff they have
        // the same length and a_i = b_i for all i
        if nops(ex1) <> nops(ex2) then
          return(FALSE)
        end_if;
        listResult:= TRUE;
        for k from 1 to nops(ex1) do
          elementResult:= testeq(ex1[k], ex2[k], args(n+1..args(0)));
          if elementResult = FALSE then
            return(FALSE)
          elif elementResult <> TRUE then
            listResult:= UNKNOWN
          end_if
        end_for;
        return(listResult)
      of DOM_POLY do
        return(_lazy_and(op(ex1, 2) = op(ex2, 2),
                         op(ex1, 3) = op(ex2, 3),
                         testeq(op(ex1, 1), op(ex2, 1), args(n+1..args(0)))))
    end_case;
    // default
    if OPT[hold(ZeroTestMode)] then
      return(FALSE)
    else  
      return(UNKNOWN)
    end_if  
  end_if;


  ex := ex1 - ex2;


  // numerical test
  // works only for constant expressions
  if (n := numeric::isnonzero(ex)) <> UNKNOWN then
    userinfo(1, "isnonzero returns ".expr2text(n)." [testeq]");
    return(not n)
  end_if;
  
  // general numerical test
  if (n:= numericalTest(ex)) <> UNKNOWN then
    userinfo(1, "random substitution succeeds");
    return(n)
  end_if;
  
  if OPT[Steps] = 0 then
    // do not attempt to simplify
    return(UNKNOWN)
  end_if;

  // the remaining options are those understood by Simplify
  SimplifyOptions:= OPT;
  delete SimplifyOptions[NumberOfRandomTests];
  delete SimplifyOptions[hold(NumberOfRandomRatTests)];
  delete SimplifyOptions[IgnoreProperties];
  delete SimplifyOptions[hold(ZeroTestMode)];
  
  // simplification of the expanded difference
  ex := Simplify::defaultSimplifier(ex, SimplifyOptions);
  userinfo(1, "Simplification of difference returns ".
           expr2text(ex)." [testeq]");
  if (n := is(ex = 0)) <> UNKNOWN then
    userinfo(1, "'is = 0' returns ".expr2text(n)." [testeq]");
    return(n)
  end_if;



  // simplification
  newex1 := Simplify::defaultSimplifier(ex1, SimplifyOptions);
  newex2 := Simplify::defaultSimplifier(ex2, SimplifyOptions);
  userinfo(1, "Simplification returns ".expr2text(newex1, newex2)." [testeq]");

  if (n := is(newex1 = newex2)) <> UNKNOWN then
    userinfo(1, "'is =' returns ".expr2text(n)." [testeq]");
    return(n)
  elif (n := is(newex1 <> newex2)) <> UNKNOWN then
    userinfo(1, "'is <>' returns ".expr2text(n)." [testeq]");
    return(not n)
  end_if;

  if newex1 <> ex1 or newex2 <> ex2 then
    // second numerical test after simplification
    numericalTest(newex1 - newex2)
  else
    UNKNOWN
  end_if

end_proc:
