/*++
The Riemann Zeta function

Calls: zeta(x)    (equivalent to zeta(x, 0))
       zeta(x, n) ( = diff(zeta(x), x $ n) )

   x - an expression
   n - a nonnegative integer
++*/

// zeta(x <, n>) -- evaluation procedure of the Riemann zeta function and
// of its n-th derivative w.r.t. x.
//   x - an expression
//   n - a nonnegative integer (default 0)
// Dispatches to a zeta slot of the domain of x if there is one, checks
// the arguments, evaluates exactly where closed forms are used below
// (integer x with n = 0, x = +-infinity), delegates floating-point
// arguments to zeta::float, maps over sets, and otherwise returns the
// symbolic call (zeta(x, 0) is normalized to zeta(x)).
// Raises "singularity" for x = 1.
zeta :=
proc(x, n = 0)
  option noDebug;
  local t;
begin
  if args(0) = 0 then
    error("no arguments given")
  elif x::dom::zeta <> FAIL then
    // overloading: the domain of x provides its own zeta slot
    return(x::dom::zeta(args()))
  elif args(0) > 2 then
    error("1 or 2 arguments expected");
  end_if;
  // n contains free identifiers: it cannot be checked, stay symbolic
  if indets(n) minus Type::ConstantIdents <> {} then
     return(procname(args()));
  end_if;
  if domtype(n) <> DOM_INT or n < 0 then
      error("2nd argument: expecting a nonnegative integer");
  end_if;
  case type(x)
    of DOM_FLOAT do
      if iszero(n) then
         // first try Apery-like formulas for x=2*n+1 
         // (i.e., odd integral x): look for a slot zeta::float3,
         // zeta::float5, ...; slot returns FAIL if it is not defined
         if rdisfinite(x) then
           if iszero(frac(x)) and round(x) mod 2=1 then
             t := slot(zeta, "float".round(x));
             if t<>FAIL then
               return(t())
             end_if
           end_if;
         end_if;
      end_if;
      return(zeta::float(x, n));

    of DOM_INT do
      if iszero(n) then
        if x < 0 then
          // negative odd x: zeta(x) = -bernoulli(1-x)/(1-x);
          // negative even x: trivial zero
          if x mod 2 = 1 then 
             t:= bernoulli(1 - x);
             if type(t) = "bernoulli" then
                // bernoulli returned unevaluated: leave the case
                // and return zeta(x) symbolically
                break;
             else
                return( -t / (1-x) )
             end_if;
          else 
             return ( 0 )
          end_if
        elif x > 0 then
          // positive even x: zeta(x) = (2*PI)^x*|bernoulli(x)|/(2*x!)
          if x mod 2 = 0 then 
             t:= bernoulli(x);
             if type(t) = "bernoulli" then
                break;
             else
                return( 1/2*(2*PI)^x*abs(t)/fact(x) )
             end_if;
          elif x = 1 then
            error("singularity")
          end_if
          // positive odd x > 1: no closed form, handled below
        else // x = 0
          return( -1/2 )
        end_if;
      end_if;

      // x = 1 is a pole of every derivative zeta(x, n)
      if x = 1 then
        error("singularity")
      end_if;
     
      break;
      
    of DOM_COMPLEX do
      
      // complex numbers with a float component: numerical evaluation
      if domtype(op(x, 1)) = DOM_FLOAT or 
         domtype(op(x, 2)) = DOM_FLOAT then
         return(zeta::float(x, n));
      end_if;
      break;

    // DOM_SET falls through to the "_union" branch: map over the elements
    of DOM_SET do
    of "_union" do
      return(map(x, zeta, n))

    // zeta(infinity) = 1, zeta(infinity, n) = 0 for n > 0,
    // zeta(-infinity, n) = undefined; other infinities fall out of the case
    of stdlib::Infinity do
       if x = infinity then
          if n = 0 then 
             return(1)
          else 
             return(0);
          end_if;
       elif x = -infinity then
          return(undefined);
       end_if;
  end_case;
  
  if not testtype(x,Type::Arithmetical) then
    /* generic handling of sets */
    if testtype(x, Type::Set) then
      if type(x)=Dom::ImageSet then
        return(map(x, zeta, n));
      else
        return(Dom::ImageSet(zeta(#x, n), #x, x));
      end_if;
    end_if;

    error("1st argument must be of 'Type::Arithmetical'")
  end_if;

  // symbolic result; zeta(x, 0) is returned as zeta(x)
  if iszero(n) then
     procname(x)
  else
     procname(args())
  end_if;
end_proc:

// Wrap zeta in a remember table. The second procedure supplies additional
// keys for the table: the properties the arguments depend on,
// Pref::autoExpansionLimit(), DIGITS and the slot-assignment counter of
// zeta. Presumably this keeps cached results from being reused after any
// of these change.
zeta := prog::remember(zeta, 
                       () -> [property::depends(args()),
                              Pref::autoExpansionLimit(), DIGITS,
                              slotAssignCounter("zeta")]):
//----------------------------------------------------
// Preset entries of the remember table:
// zeta(infinity, n) and zeta(0, n) for n = 0, 1
// (zeta(0, 1) = -ln(2*PI)/2).
zeta(infinity):= 1:
zeta(infinity, 0):= 1:
zeta(infinity, 1):= 0:
zeta(0):= -1/2:
zeta(0, 0):= -1/2:
zeta(0, 1):= -1/2 * (ln(2) + ln(PI)):
//----------------------------------------------------
// Turn zeta into a function environment. Operand 2 is copied from
// specfunc::zeta (presumably its output procedure -- confirm).
zeta:= funcenv(zeta):
zeta:= subsop(zeta, 2=op(specfunc::zeta, 2)):
zeta::print:= "zeta":
zeta::info:= "zeta -- the Riemann zeta function [try ?zeta for details]":
zeta::type:= "zeta":
// (disabled: zeta::float is defined further below)
// zeta::float:= specfunc::zeta:
// zeta::f_(n, s, k) -- the k-th derivative of ln(`#x`)^n/`#x`^s
// w.r.t. `#x` (a symbolic expression in `#x`). The derivative is built
// recursively from order k - 1, so every lower order ends up in the
// remember table as well.
zeta::f_:=
proc(n, s, k)
  option remember;
begin
  if k <> 0 then
    return(diff(zeta::f_(n, s, k - 1), `#x`))
  end_if;
  // order 0: the summand function itself
  ln(`#x`)^n/`#x`^s
end_proc:

//----------------------------------------------------
// zeta::diff(f, x) -- diff slot, f = zeta(u) or f = zeta(u, n).
// Chain rule: d/dx zeta(u, n) = zeta(u, n + 1) * diff(u, x).
zeta::diff:=
proc(f, x)
  local u, order;
begin
  u:= op(f, 1);
  if nops(f) = 1 then
    order:= 1
  else
    order:= op(f, 2) + 1
  end_if;
  zeta(u, order)*diff(u, x)
end_proc:
//----------------------------------------------------

// zeta::D -- D slot.
//   D(zeta)                -> x -> zeta(x, 1)
//   D([1, ..., 1], zeta)   -> x -> zeta(x, k), k = number of list entries
//   D([], zeta)            -> zeta
// Differentiation w.r.t. any other argument raises an error.
zeta::D:=
proc(idx, f)
  local order;
begin
  if args(0) = 1 then
    assert(idx = zeta);
    order:= 1;
  elif args(0) = 2 then
    assert(f = zeta);
    assert(type(idx) = DOM_LIST);
    if nops(idx) = 0 then
      return(zeta)
    end_if;
    if {op(idx)} <> {1} then
      error("zeta is only differentiable w.r.t. its first argument")
    end_if;
    order:= nops(idx);
  else
    error("wrong number of arguments")
  end_if;
  fp::unapply(zeta(hold(x), order), hold(x))
end_proc:


// the pole of zeta at x = 1: registered as real and complex
// discontinuity and as a point where zeta is undefined
zeta::realDiscont:= {1}:
zeta::complexDiscont:= {1}:
zeta::undefined:= {1}:

//----------------------------------------------------
// content output slot "Czeta"; the arguments 1, 2 presumably give the
// admissible number of operands -- confirm against stdlib::genOutFunc
zeta::Content := stdlib::genOutFunc("Czeta", 1, 2):

// rectform and series slots: presumably loaded on demand (loadproc)
// from the respective library directories
zeta::rectform:= loadproc(zeta::rectform, pathname("STDLIB", "RECTFORM"),
                          "zeta"):
zeta::series:= loadproc(zeta::series, pathname("SERIES"), "zeta"):

//----------------------------------------------------
/* These procedures give a faster numerical evaluation at x = 3, 5, 7
  using Apery-like formulas. They are currently commented out, so the
  lookup of zeta::float3, zeta::float5, ... in zeta fails and
  zeta::float is used instead.
  Reference: Jonathan Borwein and David Bradley, "Empirically Determined
  Apery-Like Formulae for zeta(4*n+3)", http://www.cecm.sfu.ca
  For DIGITS=300, zeta::float3 is about 8 times faster than the standard
  routine, zeta::float5 is about 9 times faster, and zeta::float7 is about
  7 times faster.
*/
/*
zeta::float3:=
proc()
  local c,s,n;
begin
   c:=2.0; s:=0; 
   for n from 1 to 1+iquo(5*DIGITS,3) do
      s:=s+1/n^3/c; c:=-c*(4*n+2)/(n+1);
   end_for;
   5*s/2
end_proc:

zeta::float5:=
proc() local a,c,s,n,g;
begin
   a:=0; c:=2.0; s:=0;
   for n from 1 to 1+iquo(5*DIGITS,3) do
      g:=1.0/n^2; s:=s+(g-1.25*a)/n^3/c; c:=-c*(4*n+2)/(n+1); a:=a+g;
   end_for;
   2*s
end_proc:

zeta::float7:=
proc() local a,c,s,n,g;
begin
   a:=0; c:=2.0; s:=0;
   for n from 1 to 1+iquo(5*DIGITS,3) do
      g:=1.0/n^4; s:=s+(5*a+g)/n^3/c; c:=-c*(4*n+2)/(n+1); a:=a+g;
   end_for;
   5*s/2   
end_proc:
*/
//-------------------------------------------------
// Warning (10.8.06): the present implementation is
// terribly slow for Re(s) < threshold (=0.7) !!!!
//-------------------------------------------------
// zeta::float(s <, n>) -- numerical value of the n-th derivative of the
// Riemann zeta function at s.
//   s - an expression; float(s) should be a real or complex float
//   n - a nonnegative integer (default 0)
// For n = 0 and a real float(s), PARI's implementation specfunc::zeta is
// used (pass any 3rd argument to bypass this, for debugging).
// For Re(s) < -0.7, the reflection rule (zeta::reflect) is used.
// Otherwise, an Euler-Maclaurin summation is evaluated (see below), with
// more and more DIGITS until both the relative error bound and the
// numerical stability check are satisfied.
// Non-numeric s (or symbolic n) yields the symbolic call zeta(float(s))
// or zeta(float(s), n). Raises "singularity" for s = 1.
zeta::float:= proc(s, n = 0)
local fs, f, pochhammer, sm1, m, K1, K2, zeta1, zeta2, zeta3, 
      requestedDIGITS, boost_DIGITS, newboost, 
      a, a_ref, k, ss, lna, aa, err, r, err_tol, result;
save DIGITS;
begin
  fs:= float(s):
  if iszero(n) then
     // use PARI's implementation of zeta:
     if args(0) < 3 then // for debugging the case n = 0,
                         // just pass any 3rd argument
        if domtype(fs) = DOM_FLOAT then
           return(specfunc::zeta(fs)); 
        elif domtype(fs) <> DOM_COMPLEX then
           // non-numeric argument: return the symbolic call in the
           // normalized form zeta(x), as the general check below does
           return(hold(zeta)(fs));
        end_if:
     end_if;
  end_if;
  if domtype(n) <> DOM_INT then
     return(hold(zeta)(fs, n)); 
  end_if;
  if fs = RD_INF then
     if n = 0 then
        return(float(1))
     else
        return(float(0))
     end_if;
  elif fs = RD_NINF then
     return(RD_NAN);
  elif fs = RD_NAN then
     return(RD_NAN);
  end_if;
  //---------------------------------------------------
  // the relative precision goal:
  //---------------------------------------------------
  err_tol:= 10.0^(-DIGITS):
  //---------------------------------------------------
  // for the evaluation near s = 1, make sure that
  // the difference between s and 1 is computed using
  // an exact representation of s (if s comes in as an
  // exact integer or rational):
  //---------------------------------------------------
  sm1:= float(s-1):
  if not contains({DOM_FLOAT, DOM_COMPLEX}, domtype(fs)) then
     if iszero(n) then
        return(hold(zeta)(fs));
     else
        return(hold(zeta)(fs, n));
     end_if;
  end_if;
  if iszero(fs-1) then
     error("singularity");
  end_if;

  // far enough left of the critical strip: use the reflection rule
  if Re(fs) < -0.7 then
     return(zeta::reflect(args()));
  end_if;

  s:= fs;

  //-------------------------------------------------------------
  // We follow Bejoy K. Choudhury: The Riemann Zeta-Function
  // and Its Derivatives, Proceedings: Mathematical and Physical 
  // Sciences, Vol 450, No 1940 (Sept. 8. 1995), pp. 477 -499.
  //
  // Summary: zeta(s, n) = (-1)^n*(zeta1 + zeta2 - zeta3) + err, where
  // zeta1:=  sum(ln(k)^n/k^s, k = 1..a-1) + 1/2*ln(a)^n/a^s,
  // zeta2:=  n!*a^(1-s)*sum(1/(n-k)!*ln(a)^(n-k)/(s-1)^(k+1),k=0..n),
  // zeta3:= sum(bernoulli(2*k)/(2*k)!* (D@@(2*k-1))(f)(a), k=1..m-1) 
  // with arbitrary positive integers 'a' and m and f(x) = ln(x)^n/x^s.
  // This holds for any s <> 1 with Re(s) > -1.
  //
  // One challenge is to choose a suitable value for m. This serves
  // for efficiency: Small values of m may lead to large values of 'a'.
  // On the other hand, large values of m lead to a rather expensive
  // evaluation of zeta3.
  // Once m is fixed, the parameter 'a' has to be chosen large enough
  // to ensure that the truncation error |err| is small enough.
  // |err| is estimated as given in the code below.
  //-------------------------------------------------------------
  //-------------------------------------------------------------
  // the following f(n, s, k, a) is the k-th derivative of
  // ln(x)^n/x^s w.r.t. x at the point x = a:
  //-------------------------------------------------------------
//f:= (n, s, k, a) -> float(subs(diff(ln(`#x`)^n/`#x`^s, `#x` $ k), `#x`=a));
  f:= (n, s, k, a) -> float(subs(zeta::f_(n, s, k), `#x` = a));
  //-------------------------------------------------------------
  // utility pochhammer: pochhammer(s, m) = s*(s+1)*...*(s+m-1)
  //-------------------------------------------------------------
  pochhammer:= proc(s, m) local i; begin _mult(s + i $ i = 0..m-1): end_proc:
  //-------------------------------------------------------------
  // Heuristics for choosing a suitable value for the parameter m.
  // The values were found by experimenting with this implementation.
  //-------------------------------------------------------------
  if n = 0 then
     m:= max(5, round(DIGITS/3));
  else
     m:= max(3, round(DIGITS/3), round(n/30));
  end_if;
  if Re(s) < 3 and specfunc::abs(Im(s)) > 10 then
     m:= max(m, ceil(5*log(10, specfunc::abs(Im(s)))));
  end_if;

  requestedDIGITS:= DIGITS;
  boost_DIGITS:= 3: // at first, introduce 3 guard digits
  repeat // compute with more and more DIGITS 
    DIGITS:= requestedDIGITS + boost_DIGITS;
    //-----------------------------------------------
    // Minimal value for the parameter a (for smaller
    // values of 'a', the error bound is not guaranteed)
    //-----------------------------------------------
    a:= ceil((2*m - 1/2)*exp(float(EULER))):
    //-----------------------------------------------
    // Generate the first term: a rough approximation
    // using the sum that defines zeta:
    //-----------------------------------------------
    zeta1:= _plus(ln::float(k)^n/k^s $ k = 1..a);
    //------------------------------------------------
    // Find suitable upper summation index 'a' for the
    // Euler-Maclaurin formula. Add further terms
    // to the sum above. Since the evaluation of the
    // error bound is very expensive, double the
    // value of 'a' until a new error bound is computed.
    // Thus, 'a' tends to be too high (but this is still
    // cheaper than computing the error bound for each
    // value of 'a').
    // On the way, we need a value for zeta(s, n) because
    // we aim at a relative precision.
    //-----------------------------------------------
    a_ref:= a + 1;
    repeat
      a:= a + 1;
      // here, zeta1 contains the terms k = 1..a-1
      if a = a_ref then
         userinfo(10, "check error at a = ".expr2text(a));
         if a > max(1000,  2*ceil((2*m - 1/2)*exp(float(EULER)))) and
            m < 100 then
            userinfo(10, "increasing m from ".expr2text(m)." to ".expr2text(2*m));
            m:= 2*m;
         end_if;
         //-------------------------------------------------------------
         // Some constants needed for the error bound
         //-------------------------------------------------------------
         K1:= float(8/sqrt(PI)/(2*PI)^(2*m)):
         K2:= K1 * n * (2*m)!:
         //-------------------------------------------------------------
         // Compute the absolute error bound
         //-------------------------------------------------------------
         aa:= a/exp(float(EULER));
         ss:= Re(s) + 2*m - 1:
         err:= K1/ ss^(n + 1)*specfunc::abs(pochhammer(s,2*m))*igamma(n + 1, ss*ln(float(a)))
              +K2/ a^ss 
                 / ss^n 
                 * _plus(1/r * specfunc::abs(pochhammer(s, 2*m-r))/(2*m -r)! *
                         (aa/(r - 1/2))^ss * igamma(n, ss * ln(aa/(r-1/2)))
                      $ r = 1..2*m);
         //-------------------------------------------------------------
         // Compute an approximation of zeta(s, n), because we want
         // to control the *relative* error
         //-------------------------------------------------------------
         lna:= ln(float(a));
         // zeta2 = n!*a^(-sm1)*_plus(lna^k/(sm1)^(n-k+1)/k! $ k = 0..n);
         //       = 1/(s-1)^(n+1)*igamma(n + 1, ln(a^(s - 1))).
         // Unfortunately, MuPAD's igamma does not allow a complex argument a^(s-1).
         // So we need to do the sum ourselves. For large n, however, it needs to 
         // be stabilized numerically by boosting DIGITS!
         DIGITS:= DIGITS + ceil(n/3):
         zeta2:= n!*a^(-sm1)*_plus(lna^k/(sm1)^(n-k+1)/k! $ k = 0..n);
         DIGITS:= DIGITS - ceil(n/3):
         zeta3:= _plus(bernoulli(2*k)/(2*k)! * f(n, s, 2*k - 1, a) $ k = 1..m-1);
         // the estimate for the final value of zeta is
         // zeta(s, n) = (-1)^n * (zeta1 + ln(a)^n/a^s/2 + zeta2 - zeta3):
         if err <= err_tol*specfunc::abs(zeta1 + lna^n/a^s/2 + zeta2 - zeta3) then
            //---------------------------------------------------------
            // Add last term in the first sum (it bears the factor 1/2)
            //---------------------------------------------------------
            zeta1:= zeta1 + 1/2*ln::float(a)^n/a^s;
            result:= (-1)^n*(zeta1 + zeta2 - zeta3):
            break;
         end_if;
         a_ref:= 2*a_ref;
        end_if;
      // if the precision goal is not achieved, add a further term 
      // to the zeta sum and proceed in the loop (increase 'a')
      zeta1:= zeta1 + ln::float(a)^n/a^s;
    until FALSE end_repeat;
    //----------------------------------------------------------------
    // Check the numerical stability of the sum zeta1 + zeta2 - zeta3:
    // if the summands are much larger than the result, digits were
    // lost by cancellation.
    //----------------------------------------------------------------
    newboost:= ceil(log(10, max(specfunc::abs(zeta1),
                                specfunc::abs(zeta2),
                                specfunc::abs(zeta3))
                              / specfunc::abs(result) ));
    if newboost > boost_DIGITS + 3 then
       // Cancellation problems! Do it again with higher DIGITS.
       boost_DIGITS:= 3*newboost;
       userinfo(10, "boosting DIGITS from ".expr2text(DIGITS).
                   " to ".expr2text(requestedDIGITS + boost_DIGITS));
    else
       break; 
    end_if;
  until FALSE end_repeat;
  //-----------------------------------------
  // The final result was numerically stable:
  //-----------------------------------------
  return(result);
end_proc:

//-------------------------------------------------------------
// Implementation of the reflection rule
// zeta(1-s) = PI^(1/2 - s)*gamma(s/2)/gamma((1-s)/2) * zeta(s)
//           = 2*(2*PI)^(-s)*gamma(s)*cos(PI/2*s) * zeta(s)
// i.e.,
// zeta(s) = PI^(s - 1/2)*gamma((1-s)/2)/gamma(s/2) * zeta(1 - s)
//         = 2*(2*PI)^(s-1)*gamma(1-s)*cos(PI/2*(1-s)) * zeta(1-s)
// used for float evaluation of zeta(s) with Re(s) < 0 (the rule 
// maps the left half plane to the right half plane Re(s) > 1).
//-------------------------------------------------------------
// zeta::reflect(s <, n>) -- n-th derivative of zeta at s, computed from
// the reflection rule in terms of zeta(ss, r), r = 0..n, with ss = 1 - s.
// zeta::float calls it for Re(s) < -0.7, so that Re(ss) > 1.7.
// The sum is recomputed with more DIGITS if its terms cancel.
zeta::reflect:= proc(s, n = 0)
local fPI, z, ss, Gamma, r, C, S, m, 
      result, drmax, dr, 
      newboost, boost_DIGITS, requestedDIGITS;
save DIGITS;
begin
   requestedDIGITS:= DIGITS;
   boost_DIGITS:= 0:
   repeat // compute with more and more DIGITS
     DIGITS:= requestedDIGITS + boost_DIGITS;
     fPI:= float(PI):
     // With z = -ln(2*PI) - I*PI/2, the k-th derivative of
     // (2*PI)^(-ss)*cos(PI/2*ss) w.r.t. ss is
     // (2*PI)^(-ss)*(Re(z^k)*cos(PI/2*ss) + Im(z^k)*sin(PI/2*ss)).
     z:= -ln(2*fPI) - fPI/2*I:
     ss:= float(1 - s):
     // Compute 
     //   Gamma[r+1] = diff(gamma(x), x $ r+1) | x = ss;
     //              = diff(gamma(x)*psi(x, 0), x $ r) | x = ss
     //              = sum(binomial(r, m)
     //                    * diff(gamma(x), x $ r-m)
     //                    * diff(psi(x, 0), x $ m), 
     //                    m = 0..r) | x = ss
     //              = sum(binomial(r, m)*Gamma[r-m]*psi(ss,m)
     Gamma[0]:= gamma(ss):
     // 30 extra digits for the recurrence (presumably to compensate
     // for rounding errors accumulating in it)
     DIGITS:= DIGITS + 30:
     for r from 0 to n-1 do
        Gamma[r+1]:= _plus(binomial(r, m)*Gamma[r-m]*psi(ss, m) $ m = 0..r);
     end_for:
     DIGITS:= DIGITS - 30:
     C:= cos(fPI/2*ss):
     S:= sin(fPI/2*ss):
 
     // Leibniz rule for the n-th derivative w.r.t. ss of
     // (2*PI)^(-ss)*cos(PI/2*ss) * gamma(ss) * zeta(ss):
     // dr collects all terms containing zeta(ss, r)
     // (the factor (2*PI)^(-ss) is applied at the end).
     result:= float(0):
     drmax:= float(0):
     for r from 0 to n do
        DIGITS:= DIGITS + 30:
        dr:= _plus(binomial(n, m)*binomial(m, r)* 
                   (Re(z^(n-m))*C + Im(z^(n-m))*S) *
                   Gamma[m-r] $ m = r..n);
        DIGITS:= DIGITS - 30:
        dr:= zeta(ss, r) * dr;
        drmax:= max(drmax, specfunc::abs(dr));
        result:= result + dr;
     end_for:
     if iszero(drmax) then
        break;
     end_if;
     // digits lost by cancellation in the sum of the terms dr
     newboost:= ceil(log(10, drmax/specfunc::abs(result)));
     if newboost > boost_DIGITS +  3 then
       // recompute with more DIGITS;
       boost_DIGITS:= 2*newboost;
     else
       break;
     end_if;
   until FALSE end_repeat;
   // (-1)^n: derivatives w.r.t. s = 1 - ss change sign with each order;
   // the factor 2 stems from the reflection rule
   return((-1)^n*2/(2*fPI)^ss*result);
end_proc:
//-------------------------------------------------------------

// zeta::expand(z) -- expand slot for z = zeta(x) or z = zeta(x, n).
// The arguments are expanded. If the expanded first argument of zeta(x)
// is an integer with a closed form (even positive, negative, or 0), that
// closed form in terms of bernoulli numbers is returned in expanded form.
// Raises "singularity" for zeta(1).
zeta::expand:=
proc(z)
  local x;
begin
   // z = zeta(x, n): expand both arguments
   if nops(z) = 2 then 
      return(zeta(expand(op(z, 1)), expand(op(z, 2))));
   end_if;
   // z = zeta(x): expand the argument as well, consistent with the
   // two-argument case above (previously z was returned unchanged)
   x:= expand(op(z, 1));
   if domtype(x) <> DOM_INT then
      return(zeta(x))
   end_if;
   if x < 0 then
     // negative odd x: -bernoulli(1-x)/(1-x); negative even x: trivial zero
     if x mod 2 = 1 then
        return(-expand(bernoulli(1-x)/(1-x)));
     else
        return ( 0 )
     end_if
   elif x > 0 then
     // positive even x: (2*PI)^x*|bernoulli(x)|/(2*x!)
     if x mod 2 = 0 then
        return(expand(1/2*(2*PI)^x*abs(bernoulli(x))/fact(x)));
     elif x = 1 then
       error("singularity")
     end_if
   else // x = 0
     return(-1/2)
   end_if;
   // positive odd x > 1: no closed form
   return(zeta(x));
end_proc:
// Remember the results of zeta::expand. The extra keys make the cache
// depend on the properties the arguments depend on and on DIGITS.
zeta::expand := prog::remember(zeta::expand, 
  () -> [property::depends(args()), DIGITS]):
