
solvelib::diophantine:=
proc(eq, var, options = solvelib::getOptions(): DOM_TABLE)
  // Entry point of the Diophantine solver.
  // eq  - one expression (meaning expr = 0), or a list/set of them
  // var - one unknown, or a list/set of unknowns
  // Returns the set of integer solutions. If the unknowns were given as a
  // set, or VectorFormat is off, the result has the form
  // matrix(var) in S; unevaluated solve calls are passed through.
  local isSystem, solutions;
begin
  assert(type(eq) <> "_equal"); // has been handled in solve

  isSystem:= bool(type(eq) = DOM_LIST or type(eq) = DOM_SET);
  if type(eq) = DOM_LIST then
    // systems are handled as sets of equations
    eq:= {op(eq)}
  end_if;

  // bring the unknowns into list form
  if type(var) = DOM_SET then
    options[VectorFormat]:= FALSE;
    var:= [op(var)]
  elif type(var) <> DOM_LIST then
    if not isSystem then
      // a single equation for a single unknown
      if testtype(eq, Type::PolyExpr([var], Type::Integer)) then
        return(solvelib::iroots(eq))
      end_if;
      options[Domain] := Expr;
      return(solve(eq, var, options) intersect Z_)
    end_if;
    var:= [var]
  end_if;

  if isSystem then
    solutions:= solvelib::diophantineSystem(eq, var, options)
  else
    solutions:= solvelib::diophantineSolveEq(eq, var, options)
  end_if;

  if options[VectorFormat] or type(solutions) = "solve" then
    return(solutions)
  end_if;
  return(matrix(var) in solutions)
end_proc:


solvelib::diophantineSystem:=
proc(sys: DOM_SET, vars: DOM_LIST, options)
  // Solves a system of Diophantine equations for the unknowns in vars.
  // Only linear systems are solved; anything else is returned as an
  // unevaluated solve call.
  local linsys;
begin
  // here the result is always produced as a set of vectors
  options[VectorFormat]:= TRUE;

  linsys:= solvelib::solve_islinear(sys, vars);
  if linsys = FALSE then
    // nonlinear systems are not supported yet
    return(hold(solve)(sys, vars, solvelib::nonDefaultOptions(options)))
  end_if;

  solvelib::linearDiophantineSystem(linsys, vars, options)
end_proc:



/***************************************************************

solvelib::linearDiophantineSystem

solves a system of linear diophantine equations

Source: Harold Greenberg, Integer programming, p. 143


****************************************************************/

solvelib::linearDiophantineSystem:=
proc(sys: DOM_SET, vars: DOM_LIST, options)
  // Solves the linear Diophantine system sys for the unknowns vars.
  // Returns {} if there is no integer solution, a one-element set
  // {[x1, ..., xn]} if the solution is unique, and otherwise a
  // Dom::ImageSet parametrized by the integers k1, k2, ...
  // (options is not used here)
  local A, n, i, j, s, minimum, nonzeros, quotient;
begin
  // A is the extended coefficient matrix of the system: one column per
  // variable, plus a final column for the constant coefficients
  A:= linalg::expr2Matrix(sys, vars);

  n:= nops(vars);
  assert(/*dense*/matrix::matdim(A)[2] = n+1);

  // for each variable x_i, add one more row x_i = y_i, where y_i is a new
  // unknown (an n x n identity block with a zero right hand side column).
  // The column operations below act on these rows as well, so they keep
  // expressing the original unknowns in terms of the current ones.
  A:= linalg::stackMatrix(A, /*dense*/matrix(n, n+1, 1, Diagonal));

  // Eliminate the equations one at a time; the equation being processed
  // is always the first row of A. The loop ends when only the nops(vars)
  // rows added above are left.
  while linalg::nrows(A) > nops(vars) do

    repeat
      // look for the nonzero entry of smallest absolute value in the
      // first row
      minimum:= infinity;
      nonzeros:= 0;
      for j from 1 to n do
        if not iszero(A[1, j]) then
          nonzeros:= nonzeros+1;
          if abs(A[1, j]) < abs(minimum) then
            s:= j;
            minimum:= A[1, j]
          end_if
        end_if;
      end_for;

      if nonzeros = 0 then
        // the row reads 0 = A[1, n+1]: either it is the trivial equation
        // 0 = 0, or the system is unsolvable
        if A[1, n+1] <> 0 then
          return({})
        else
          A:= linalg::delRow(A, 1);
          break
        end_if

      elif nonzeros > 1 then
        // reduce the other entries of the first row modulo the smallest
        // one by subtracting integer multiples of column s from the other
        // columns (a unimodular change of the unknowns)
        for j from 1 to n do
          if s <> j then
            quotient:= round(A[1, j] / minimum);
            for i from 1 to linalg::nrows(A) do
              A[i, j] := A[i, j] - quotient * A[i, s]
            end_for;
          end_if;
        end_for;
      end_if

    until nonzeros = 1 end_repeat;

    if nonzeros = 0 then
      // a trivial row has been deleted; continue with the remaining
      // equations (leaving the while loop here would keep them
      // unprocessed in A and corrupt the result)
      next
    end_if;

    // we have an equation A[1, s]*y_s = A[1, n+1], which can only hold
    // if A[1, s] divides the right hand side
    if type((quotient:= A[1, n+1] / A[1, s])) <> DOM_INT then
      return({})
    end_if;

    // substitute y_s = quotient, i.e., for all rows i:
    //   A[i, n+1] := A[i, n+1] - quotient * A[i, s]
    A:= linalg::addCol(A, s, n+1, -quotient);

    // the unknown y_s and the processed equation are no longer needed
    A:= linalg::delCol(A, s);
    A:= linalg::delRow(A, 1);
    n:= n-1;

  end_while;

  // Now the rows of A express the original unknowns: the negated last
  // column is one particular solution, and the other n columns span the
  // solutions of the homogeneous system.
  if n > 0 then
    j:= [hold(k).i $ i = 1..n];
    s:= -linalg::col(A, n+1) + _plus(j[i] * linalg::col(A, i) $ i = 1..n);
    Dom::ImageSet([s[i, 1] $ i = 1..linalg::nrows(s)], j, [Z_ $ n])
  else
    s:= -linalg::col(A, n+1);
    {[s[i, 1] $ i = 1..linalg::nrows(s)]}
  end_if
end_proc:


/****************************************
solvelib::diophantineSolveEq(eq, vars, options)

solve one equation for all variables in vars

****************************************/

solvelib::diophantineSolveEq:=
proc(eq, vars: DOM_LIST, options)
  // Solves the single equation eq = 0 over the integers for all unknowns
  // in vars. Handled cases are:
  // - linear equations, for one solution if PrincipalValue is set;
  // - quadratic equations in one or two unknowns;
  // - x^2 + g(y) = 0 and -y^2 + g(x) = 0 with g of even degree and
  //   leading coefficient -1 and 1, respectively, via solveSzalay.
  // Everything else is returned as an unevaluated solve call.
  local p, g, k, coeffs, target, x, y, swapPairs;
begin
  assert(type(eq) <> "_equal");

  // try to read eq as a polynomial with integer coefficients
  p:= poly(eq, vars, Dom::Integer);
  if p = FAIL then
    userinfo(3, "Equation is not polynomial or does ".
             "not have integer coefficients");
    return(hold(solve)(eq, vars, solvelib::nonDefaultOptions(options)))
  end_if;

  // remove the integer content and make the leading coefficient positive
  p:= polylib::primpart(p);
  if lcoeff(p) < 0 then
    p:= -p
  end_if;

  // linear equations
  if degree(p) = 1 then
    if options[PrincipalValue] then
      // one solution of _plus(coeffs[k]*vars[k]) = target suffices
      coeffs:= [expr(coeff(p, vars[k], 1)) $ k = 1..nops(vars)];
      target:= - ground(p);
      return(solvelib::linearDiophantineEquationPrincipal(coeffs, target))
    end_if;
    return(solvelib::linearDiophantineSystem({eq}, vars, options))
  end_if;

  // quadratic equations in one or two unknowns
  if degree(p) = 2 then
    if nops(vars) = 1 then
      return(solve(p, op(vars, 1)) intersect Z_)
    elif nops(vars) = 2 then
      return(solvelib::quadraticDiophantineBivariate(p, vars))
    end_if
  end_if;

  // from here on, work with expression coefficients
  p:= poly(p, hold(Expr));

  if nops(vars) = 2 then
    [x, y] := vars;

    // p = x^2 + g(y), g of even degree with leading coefficient -1
    if degree(p, x) = 2 and iszero(coeff(p, x, 1)) and
      expr(coeff(p, x, 2)) = 1 and
      degree((g:= coeff(p, x, 0))) mod 2 = 0 and lcoeff(g) = -1 then

      // solveSzalay is called with a polynomial in y here, so the two
      // entries of every solution vector are swapped into the order [x, y]
      swapPairs:=
      proc(S)
      begin
        if type(S) = DOM_SET or type(S) = Dom::ImageSet then
          return(map(S, l -> [op(l, 2), op(l, 1)]))
        end_if;
        error("Unknown set type")
      end_proc;

      return(swapPairs(solvelib::solveSzalay(-g)))
    end_if;

    // p = -y^2 + g(x), g of even degree with leading coefficient 1
    if degree(p, y) = 2 and iszero(coeff(p, y, 1)) and
      expr(coeff(p, y, 2)) = -1 and
      degree((g:= coeff(p, y, 0))) mod 2 = 0 and lcoeff(g) = 1 then
      return(solvelib::solveSzalay(g))
    end_if
  end_if;

  hold(solve)(eq, vars, solvelib::nonDefaultOptions(options))
end_proc:


/*******************************************

solvelib::linearDiophantineEquationPrincipal(a, d)

solves \sum_i a[i]*x[i] = d for x[i]
and returns *one* solution


********************************************/

solvelib::linearDiophantineEquationPrincipal:=
proc(a: Type::ListOf(DOM_INT), d: DOM_INT)
  // Solves _plus(a[i]*x[i]) = d over the integers and returns *one*
  // solution as {[x[1], ..., x[n]]}, or {} if there is none.
  // It uses the extended Euclidean algorithm (igcdex) on the coefficients.
  local gc: DOM_INT, i, igcdx, quot, loes;
begin
  // degenerate input: no coefficients, or all coefficients zero. Then the
  // left hand side vanishes identically and the zero vector is a solution
  // iff d = 0. This also avoids the division by gc = 0 below.
  if ({op(a)} minus {0}) = {} then
    if d = 0 then
      return({[0 $ i = 1..nops(a)]})
    else
      return({})
    end_if
  end_if;

  gc:= a[1];
  loes:= [1];

  // loop invariant:
  // gc = gcd(a[1], .., a[i-1])
  // loes[1]*a[1] + ... + loes[i-1]*a[i-1] = gc
  for i from 2 to nops(a) do
    igcdx:= [igcdex(gc, a[i])];
    // now igcdx[1] = gcd(gc, a[i]) = gcd(a[1], ..., a[i])
    // and igcdx[1] = igcdx[2] * gc + igcdx[3] * a[i]
    //              = igcdx[2]*loes[1]*a[1] + ... + igcdx[2]*loes[i-1]*a[i-1]
    //                + igcdx[3]*a[i]

    loes:= map(loes, _mult, igcdx[2]).[igcdx[3]];
    gc:= igcdx[1]
  end_for;

  // gc <> 0 here, since at least one coefficient is nonzero
  if type((quot:= d/gc)) = DOM_INT then
    loes:= map(loes, _mult, quot)
  else
    return({})
  end_if;

  // now, loes is one solution
  // the general solution could be obtained by adding a solution
  // of the homogeneous equation
  // \sum a[i]*x[i] = 0

  {loes}

end_proc:



/*********************************************

solvelib::quadraticDiophantineBivariate

solves bivariate quadratic diophantine equations

the raw structure of the code is similar to Dario Alpern's
Java implementation of the quadratic diophantine solver

**********************************************/


solvelib::quadraticDiophantineBivariate:=
proc(h: DOM_POLY, vars: DOM_LIST)
  // Solves h = 0 over the integers, where h is a polynomial of total
  // degree 2 in the two unknowns vars = [x, y] with integer coefficients.
  // Returns a set of vectors [x, y], a Dom::ImageSet (or union of them),
  // or an unevaluated solve call for the cases not implemented yet.
  local x, y,
  a: DOM_INT,
  b: DOM_INT,
  c: DOM_INT,
  d: DOM_INT,
  e: DOM_INT,
  f: DOM_INT,
  g: DOM_INT,
  u, t, m, cf, lineSet,
  sqrta, sqrtc, divisors, discriminant, roots, sols;
begin
  assert(nops(vars) = 2);
  [x, y]:= vars;

  // extract the coefficients of
  //   h = a*x^2 + b*x*y + c*y^2 + d*x + e*y + f;
  // the polynomial coefficients w.r.t. x go to the untyped local cf,
  // so that only integers are assigned to the DOM_INT locals
  a:= expr(coeff(h, x, 2));
  c:= expr(coeff(h, y, 2));
  cf:= coeff(h, x, 1);
  [b, d]:= [coeff(cf, y, 1), coeff(cf, y, 0)];
  cf:= coeff(h, x, 0);
  [e, f]:= [coeff(cf, y, 1), coeff(cf, y, 0)];

  assert(h = poly(a*x^2 + b*x*y + c*y^2 + d*x + e*y + f, vars, Dom::Integer));

  // the equation has no solution unless f is a multiple
  // of igcd(a, b, c, d, e); in that case, we may divide all
  // coefficients by that igcd
  g:= igcd(a, b, c, d, e);
  if f mod g <> 0 then
    return({})
  end_if;

  a:= a/g; b:= b/g; c:= c/g; d:= d/g; e:= e/g; f:= f/g;


  if a = 0 and c = 0 then
    // the equation is not linear, so b <> 0
    assert(b <> 0);
    // we have to solve b*x*y + d*x + e*y + f = 0;
    // multiplying by b gives (b*x + e)*(b*y + d) = d*e - b*f
    if (discriminant:= d*e - b*f) = 0 then
      // b*x + e = 0 or b*y + d = 0, i.e., x = -e/b or y = -d/b
      return(
             (if e mod b = 0 then
                Dom::ImageSet([-e/b, hold(k)], hold(k), Z_)
              else
                {}
              end_if
              )
             union
             (if d mod b = 0 then
                Dom::ImageSet([hold(k), -d/b], hold(k), Z_)
              else
                {}
              end_if
              )
             )
    end_if;

    // b*x + e runs through the divisors dv of d*e - b*f, and b*y + d is
    // the complementary divisor. Both x and y must be integers: b | dv - e
    // does not imply b | (d*e - b*f)/dv - d (e.g. 4*x*y + x + 2*y = 0
    // with dv = -2).
    divisors:= numlib::divisors(abs(discriminant));
    divisors:= divisors . map(divisors, _negate);
    sols:= map(divisors,
               proc(dv)
               begin
                 if (dv - e) mod b = 0 and
                   (discriminant/dv - d) mod b = 0 then
                   [(dv - e)/b, (discriminant/dv - d)/b]
                 else
                   null()
                 end_if
               end_proc
               );
    return({op(sols)})
  end_if; // a = c = 0


  discriminant:= b^2 - 4*a*c;


  if discriminant < 0 then
    // elliptic case: here a*c > 0, so c <> 0 and
    //   y = (-(b*x + e) +- sqrt((b*x + e)^2 - 4*c*(a*x^2 + d*x + f)))/(2*c).
    // The radicand is a quadratic in x with negative leading coefficient,
    // so y is real only for x between its real roots.
    roots:= numeric::polyroots
    (discriminant*x^2 + 2*(b*e - 2*c*d)*x + e^2 - 4*c*f);

    if hastype(roots, DOM_COMPLEX) then
      // no real solution
      return({})
    end_if;

    // The roots are floating-point approximations, so an integer root may
    // come out slightly too large or too small. Widen the range by one on
    // both sides; candidates outside the real range give a non-real y and
    // are removed by the integrality test.
    return
    (select
     ({([x, ((-b*x - e) + sqrt( (b*x + e)^2 - 4*c*(a*x^2 + d*x + f) ))/ 2 / c],
        [x, ((-b*x - e) - sqrt( (b*x + e)^2 - 4*c*(a*x^2 + d*x + f) ))/ 2 / c])
       $x = ceil(min(op(roots))) - 1 .. floor(max(op(roots))) + 1},
      pair -> type(op(pair, 2)) = DOM_INT
      )
     );

  end_if; // discriminant < 0


  if discriminant = 0 then

    // parabolic case: b^2 = 4*a*c, so a and c do not have opposite signs,
    // and a*x^2 + b*x*y + c*y^2 = g*(sqrta*x + sqrtc*y)^2

    g:= gcd(a, c);
    // make sure that g has the same sign as the nonzero one(s) of a, c
    if a*g < 0 or c*g < 0 then g:= -g end_if;
    c:= c/g;
    a:= a/g;
    sqrta:= sqrt(a);
    // the mixed term requires 2*g*sqrta*sqrtc = b, which fixes the sign
    // of sqrtc; if b = 0, then a = 0 or c = 0, and no sign is needed
    // (this also avoids dividing by a = 0)
    if b = 0 then
      sqrtc:= sqrt(c)
    else
      sqrtc:= sqrt(c) * sign(b/g)
    end_if;

    assert(a >= 0);
    assert(c >= 0);
    assert(type(sqrta) = DOM_INT);
    assert(type(sqrtc) = DOM_INT);

    if sqrtc*d = sqrta*e then
      // no parabola, but parallel lines: d*x + e*y is a multiple of
      // u = sqrta*x + sqrtc*y, so the equation is a quadratic in u alone.
      // We abuse x for the variable u.
      if sqrta <> 0 then
        // g*u^2 + (d/sqrta)*u + f = 0, multiplied by sqrta
        roots:= solve(sqrta*g*x^2 + d*x + sqrta*f, x)
      else
        // a = 0, hence d = 0: g*u^2 + (e/sqrtc)*u + f = 0, times sqrtc
        roots:= solve(sqrtc*g*x^2 + e*x + sqrtc*f, x)
      end_if;
      roots:= select(roots, testtype, DOM_INT);

      return(
             _union(op(solvelib::diophantine(sqrta*x + sqrtc *y - u, [x, y]),
                       2)
                 $ u in roots));

    end_if;

    // Otherwise m = sqrtc*d - sqrta*e <> 0, and the real solutions are
    //   x = -sqrtc*g*m*t^2 + (e + 2*sqrtc*g*u)*t - X(u)/m,
    //   y =  sqrta*g*m*t^2 - (d + 2*sqrta*g*u)*t + Y(u)/m,
    // with X(u) = sqrtc*g*u^2 + e*u + sqrtc*f,
    //      Y(u) = sqrta*g*u^2 + d*u + sqrta*f,
    // and sqrta*x + sqrtc*y = u - m*t. For integers u, t, the point is
    // integral iff m divides both X(u) and Y(u). Both conditions are
    // periodic in u with period |m|, so u = 0..|m|-1 suffices.
    // We should use quadratic congruences to find u; at the moment, we
    // just test all residues.
    m:= sqrtc*d - sqrta*e;
    roots:= select([$ 0..(abs(m) - 1)],
                   u -> (sqrta*g*u^2 + d*u + sqrta*f) mod abs(m) = 0 and
                        (sqrtc*g*u^2 + e*u + sqrtc*f) mod abs(m) = 0);

    t:= genident();
    return(
    (_union
     (
      Dom::ImageSet(
                    [-sqrtc*g*m*t^2 + (e + 2*sqrtc*g*u)*t -
                     (sqrtc*g*u^2 + e*u + sqrtc*f)/m,
                     sqrta*g*m*t^2 - (d + 2*sqrta*g*u)*t +
                     (sqrta*g*u^2 + d*u + sqrta*f)/m],
                    t, Z_)
      $u in roots)
     ))

  end_if;



  assert(discriminant > 0);
  // hyperbolic case


  if d = 0 and e = 0 then
    // homogeneous hyperbolic case
    sqrta:= sqrt(discriminant);

    if f = 0 then

      if type(sqrta) <> DOM_INT then
        // the quadratic form is irreducible over the rationals
        return({[0, 0]})
      end_if;

      // all integer points of the line p*x + q*y = 0, (p, q) <> (0, 0)
      lineSet:= (p, q) -> Dom::ImageSet([hold(k)*q/igcd(p, q),
                                         -hold(k)*p/igcd(p, q)],
                                        hold(k), Z_);

      if a <> 0 then
        // 4*a*(a*x^2 + b*x*y + c*y^2)
        //   = (2*a*x + (b + sqrta)*y) * (2*a*x + (b - sqrta)*y)
        return(lineSet(2*a, b + sqrta) union lineSet(2*a, b - sqrta))
      else
        // a*x^2 + b*x*y + c*y^2 = y*(b*x + c*y), where b <> 0
        return(lineSet(0, 1) union lineSet(b, c))
      end_if

    else
      // f <> 0

      if type(sqrta) = DOM_INT then

        if a <> 0 then
          // (2*a*x + (b + sqrta)*y) * (2*a*x + (b - sqrta)*y) = -4*a*f
          m:= 4*a*f
        else
          // a = 0, so c <> 0 (a = c = 0 has been handled above); factor
          // with the roles of x and y exchanged:
          // (2*c*y + (b + sqrta)*x) * (2*c*y + (b - sqrta)*x) = -4*c*f
          m:= 4*c*f
        end_if;

        // One factor is a divisor u of m (of either sign), the other one is
        // -m/u. Their difference 2*sqrta*v gives v (v = y if a <> 0,
        // v = x otherwise). Both unknowns must be checked for integrality.
        divisors:= numlib::divisors(abs(m));
        divisors:= {op(divisors . map(divisors, _negate))};

        return
        (
         map(divisors,
             proc(u)
               local v, w;
             begin
               v:= (u + m/u)/2/sqrta;
               if type(v) <> DOM_INT then
                 return(null())
               end_if;
               if a <> 0 then
                 w:= (u - (b + sqrta)*v)/2/a;
                 if type(w) = DOM_INT then [w, v] else null() end_if
               else
                 w:= (u - (b + sqrta)*v)/2/c;
                 if type(w) = DOM_INT then [v, w] else null() end_if
               end_if
             end_proc)
         )

      end_if; // sqrt(discriminant) is an integer

      // case discriminant non-square not yet implemented

    end_if  // f = 0

  end_if; // d = e = 0

  // case d<>0 or e<>0 not yet implemented

  hold(solve)(expr(h), vars, Domain = Z_)

end_proc:
