// 

// interval evaluation of lambertW

// utility function: enclose a single point
//
// the idea behind the implementation is the same as
// suggested by Corless et al. for floating point
// approximations: First, find a rough approximation
// (in our case, an enclosure) *in the correct branch*,
// then use Newton iteration of w*exp(w)-z=0 to
// improve the value of w found.

// newton_step(appr, z) is a parse-time macro (alias), not a procedure:
// every textual occurrence is replaced by the expression below.
// The expression is the Newton update for w*exp(w) - z = 0,
//   w - (w*exp(w) - z)/((w + 1)*exp(w)) = w + (z*exp(-w) - w)/(w + 1),
// with the base point taken as the center of the enclosure appr.
// NOTE(review): the correction term is evaluated on the whole enclosure
// appr rather than at its center -- confirm that this is the intended
// interval Newton operator.
alias(newton_step(appr, z) =
      DOM_INTERVAL::center(appr)+(z*exp(-appr)-appr)/(appr+1)):
// Sides of a complex interval z.  Each helper collapses one coordinate of
// z onto one of its bounds: op(z, [1, j]) addresses the bounds of the real
// part and op(z, [2, j]) those of the imaginary part (j = 1 is used as the
// lower and j = 2 as the upper bound throughout this file).
// "top" means the upper imaginary bound and "bot" the lower one, which is
// the convention used by ulcorner/urcorner and blcorner/brcorner.
//
// Equivalent alias (macro) formulations, kept for reference:
// alias(left_side(z) = subsop(z, [1,2]=op(z, [1,1]))):
// alias(right_side(z) = subsop(z, [1,1]=op(z, [1,2]))):
// alias(top_side(z) = subsop(z, [2,1]=op(z,[2,2]))):
// alias(bot_side(z) = subsop(z, [2,2]=op(z,[2,1]))):
left_side  := z -> subsop(z, [1,2]=op(z, [1,1])):
right_side := z -> subsop(z, [1,1]=op(z, [1,2])):
top_side   := z -> subsop(z, [2,1]=op(z, [2,2])):
bot_side   := z -> subsop(z, [2,2]=op(z, [2,1])):

// Corner points of a complex interval z, returned as degenerate complex
// intervals.  The first letter selects the imaginary bound (b = lower,
// u = upper), the second the real bound (l = lower, r = upper).

// bottom-left: lower real bound, lower imaginary bound
blcorner := proc(z)
  local re_lo, im_lo;
begin
  re_lo := op(z, [1, 1]);
  im_lo := op(z, [2, 1]);
  subsop(z, [1, 2] = re_lo, [2, 2] = im_lo);
end_proc:

// bottom-right: upper real bound, lower imaginary bound
brcorner := proc(z)
  local re_hi, im_lo;
begin
  re_hi := op(z, [1, 2]);
  im_lo := op(z, [2, 1]);
  subsop(z, [1, 1] = re_hi, [2, 2] = im_lo);
end_proc:

// upper-left: lower real bound, upper imaginary bound
ulcorner := proc(z)
  local re_lo, im_hi;
begin
  re_lo := op(z, [1, 1]);
  im_hi := op(z, [2, 2]);
  subsop(z, [1, 2] = re_lo, [2, 1] = im_hi);
end_proc:

// upper-right: upper real bound, upper imaginary bound
urcorner := proc(z)
  local re_hi, im_hi;
begin
  re_hi := op(z, [1, 2]);
  im_hi := op(z, [2, 2]);
  subsop(z, [1, 1] = re_hi, [2, 1] = im_hi);
end_proc:

// All four corner extractors; calling the list on an interval applies
// each of them to it.
corners := [blcorner, brcorner, ulcorner, urcorner]:

// Encloses lambertW(k, z) for a single point z.
//
// Parameters:
//   k - branch index
//   z - the point at which branch k is evaluated
// Returns: an interval enclosure appr whose width, relative to
//   |center(appr)| + 1, does not exceed 10^(2-DIGITS).
// Raises an error when an intermediate enclosure contains an infinite
//   bound (RD_INF or RD_NINF); that case is not implemented.
//
// NOTE(review): z is declared as DOM_FLOAT, but _encl_lambertW in this file
// passes xmin + I*ymin etc. -- confirm that the declaration is not enforced
// for those calls, or widen it.
DOM_INTERVAL::_lambertW_point :=
proc(k : DOM_INT, z : DOM_FLOAT)
  local appr, prev_appr, starting_region_is_safe;
begin
  // starting enclosure: interval hull of the value lambertW(k, z)
  appr := hull(lambertW(k, z));

  // Phase 1: repeat until the enclosure is accepted as a starting region.
  repeat
    if has(appr, {RD_INF, RD_NINF}) then
      error("ouch.  TODO.");
    end_if;
    // one Newton step (alias newton_step), remembering the previous
    // enclosure for the inclusion test below
    prev_appr := appr;
    appr := newton_step(appr, z);
    // accept if all four bounds of the new enclosure lie strictly inside
    // the previous one
    if lhs(Re(appr)) > lhs(Re(prev_appr)) and
       rhs(Re(appr)) < rhs(Re(prev_appr)) and
       lhs(Im(appr)) > lhs(Im(prev_appr)) and
       rhs(Im(appr)) < rhs(Im(prev_appr)) then
      starting_region_is_safe := TRUE;
    // otherwise accept if 0 lies in the hull of w*exp(w)-z taken at the
    // four corners of the new enclosure
    elif 0 in hull(map(corners(appr), x->x*exp(x)-z)) then
      starting_region_is_safe := TRUE;
    else
      // neither test succeeded: widen the new enclosure by 1e-7 in both
      // the real and the imaginary direction and try again
      starting_region_is_safe := FALSE;
      appr := appr + ((-1...1)+(-I...I))*1e-7;
    end_if;
  until starting_region_is_safe end_repeat;
  
  // Phase 2: keep applying Newton steps until the relative width
  // width/(|center| + 1) is at most 10^(2-DIGITS).
  while DOM_INTERVAL::width(appr)
	/(specfunc::abs(DOM_INTERVAL::center(appr))+1)
	> 10^(2-DIGITS) do
    appr := newton_step(appr, z);
  end_while;
  appr;
end_proc:

// utility functions to find the extrema along the sides.
// See ccr's thesis for the theoretical justification
// Candidate extremum on the vertical line Re(z) = x.
//
// Parameters:
//   k   - branch index
//   x   - the fixed real part of z
//   yiv - interval of imaginary parts covered by the side
// Returns {} when the candidate's imaginary coordinate does not meet yiv,
// otherwise the interval eta*tan(eta)-1 for the refined enclosure eta of
// the root of cos(eta)*exp(eta*tan(eta)-1) = x (presumably the real part
// of W at the extremal point -- see the thesis cited in the comment above).
DOM_INTERVAL::_lambertW_extrema_real_const_x :=
proc(k : DOM_INT, x : DOM_FLOAT, yiv : DOM_INTERVAL)
  local eta, mid, y, f, search;
begin
  // residual whose root in eta is sought
  f := e -> cos(e)*exp(e*tan(e)-1)-x;

  // search range for eta, chosen from the branch and the sign of yiv
  if iszero(k) then
    if yiv >=0 then
      search := 0..PI;
    else
      search := -PI..0;
    end_if;
  elif k * yiv >= 0 then
    search := ((2*k-1/2)*PI)..((2*k+1/2)*PI);
  else
    search := ((2*k-3/2)*PI)..((2*k-1/2)*PI);
  end_if;

  eta := numeric::realroot(f(`#eta`), `#eta` = search);
  if nops(eta) <> 1 then
    warning("Spurious solution (or theory bug)");
  end_if;

  // turn the numerical root into an interval and inflate it until the
  // residual has different signs at the two end points
  eta := hull(op(eta, 1));
  while sign(f(hull(op(eta,1)))) = sign(f(hull(op(eta, 2)))) do
    eta := eta*(1+(-1...1)*1e-5);
  end_while;

  // imaginary coordinate belonging to eta; nothing to report if it
  // misses the side
  y := (eta*cos(eta)+sin(eta)*(eta*tan(eta)-1))*exp(eta*tan(eta)-1);
  if y intersect yiv = {} then
    return({});
  end_if;

  // bisection on the sign of the residual down to width 10^(3-DIGITS)
  while DOM_INTERVAL::width(eta) > 10^(3-DIGITS) do
    mid := (op(eta,1)+op(eta,2))/2;
    if sign(f(hull(mid))) <> sign(f(hull(op(eta,1)))) then
      eta := subsop(eta, 2=mid);
    else
      eta := subsop(eta, 1=mid);
    end_if;
  end_while;

  eta*tan(eta)-1;
end_proc:

// Candidate extremum on the horizontal line Im(z) = y.
//
// Parameters:
//   k   - branch index
//   xiv - interval of real parts covered by the side
//   y   - the fixed imaginary part of z
// Returns {} when the candidate's real coordinate does not meet xiv,
// otherwise the refined enclosure eta of the root of
//   (eta*cos(eta)+sin(eta)*(eta*tan(eta)-1))*exp(eta*tan(eta)-1) = y
// (presumably the imaginary part of W at the extremal point -- see the
// thesis cited for the extrema helpers).
DOM_INTERVAL::_lambertW_extrema_imag_const_y :=
proc(k : DOM_INT, xiv : DOM_INTERVAL, y : DOM_FLOAT)
  local eta, c, x, f;
begin
  // residual whose root in eta is sought
  f := e -> (e*cos(e)+sin(e)*(e*tan(e)-1))*exp(e*tan(e)-1)-y;
  // The search range depends on the branch and on the sign of the
  // imaginary part of z, which on this side is the fixed value y.
  eta := numeric::realroot(f(`#eta`),
                           `#eta` = if iszero(k) then
                                      if y >=0 then
                                        0..PI
                                      else
                                        -PI..0
                                      end_if
                                    else
                                      if k * y >= 0 then
                                        ((2*k-1/2)*PI)..((2*k+1/2)*PI)
                                      else
                                        ((2*k-3/2)*PI)..((2*k-1/2)*PI)
                                      end_if
                                    end_if);
  if nops(eta) <> 1 then
    warning("Spurious solution (or theory bug)");
  end_if;

  // turn the numerical root into an interval and inflate it until the
  // residual has different signs at the two end points
  eta := hull(op(eta, 1));
  while sign(f(hull(op(eta,1)))) = sign(f(hull(op(eta, 2)))) do
    eta := eta*(1+(-1...1)*1e-5);
  end_while;
  
  // real coordinate belonging to eta; nothing to report if it misses
  // the side
  x := cos(eta)*exp(eta*tan(eta)-1);
  if x intersect xiv = {} then
    return({});
  end_if;
  
  // bisection on the sign of the residual down to width 10^(3-DIGITS)
  while DOM_INTERVAL::width(eta) > 10^(3-DIGITS) do
    c := (op(eta,1)+op(eta,2))/2;
    if sign(f(hull(c))) = sign(f(hull(op(eta,1)))) then
      eta := subsop(eta, 1=c);
    else
      eta := subsop(eta, 2=c);
    end_if;
  end_while;
  
  eta;
end_proc:

// Candidate extremum on the horizontal line Im(z) = y.
//
// Parameters:
//   k   - branch index
//   xiv - interval of real parts covered by the side
//   y   - the fixed imaginary part of z
// Returns {} when the candidate's real coordinate does not meet xiv,
// otherwise the interval -eta*cot(eta)-1 for the refined enclosure eta of
// the root of sin(eta)/exp(eta*cot(eta)+1) = y (presumably the real part
// of W at the extremal point -- see the thesis cited for the extrema
// helpers).
DOM_INTERVAL::_lambertW_extrema_real_const_y :=
proc(k : DOM_INT, xiv : DOM_INTERVAL, y : DOM_FLOAT)
  local eta, c, x, f;
begin
  // residual whose root in eta is sought
  f := e -> sin(e)/exp(e*cot(e)+1)-y;
  // search range for eta, chosen from the branch and the sign of y
  eta := numeric::realroot(f(`#eta`),
                           `#eta` = if iszero(k) then
                                      if y >=0 then
                                        0..PI
                                      else
                                        -PI..0
                                      end_if
                                    else
                                      if k * y >= 0 then
                                        ((2*k-2)*PI)..((2*k-1)*PI)
                                      else
                                        ((2*k-1)*PI)..(2*k*PI)
                                      end_if
                                    end_if);
  if nops(eta) <> 1 then
    warning("Spurious solution (or theory bug)");
  end_if;

  // turn the numerical root into an interval and inflate it until the
  // residual has different signs at the two end points
  eta := hull(op(eta, 1));
  while sign(f(hull(op(eta,1)))) = sign(f(hull(op(eta, 2)))) do
    eta := eta*(1+(-1...1)*1e-5);
  end_while;
  
  // Real coordinate belonging to eta.  It is kept in the local x (the
  // parameter y stays untouched) and compared with the real extent xiv
  // of the side; nothing to report if it misses the side.
  x := -(eta*sin(eta)+cos(eta)*(eta*cot(eta)+1))/exp(eta*cot(eta)+1);
  if x intersect xiv = {} then
    return({});
  end_if;
  
  // bisection on the sign of the residual down to width 10^(3-DIGITS)
  while DOM_INTERVAL::width(eta) > 10^(3-DIGITS) do
    c := (op(eta,1)+op(eta,2))/2;
    if sign(f(hull(c))) = sign(f(hull(op(eta,1)))) then
      eta := subsop(eta, 1=c);
    else
      eta := subsop(eta, 2=c);
    end_if;
  end_while;
  
  -eta*cot(eta)-1;
end_proc:

// Candidate extremum on the vertical line Re(z) = x.
//
// Parameters:
//   k   - branch index
//   x   - the fixed real part of z
//   yiv - interval of imaginary parts covered by the side
// Returns {} when the candidate's imaginary coordinate does not meet yiv,
// otherwise the refined enclosure eta of the root of
//   -(eta*sin(eta)+cos(eta)*(eta*cot(eta)+1))/exp(eta*cot(eta)+1) = x
// (presumably the imaginary part of W at the extremal point -- see the
// thesis cited for the extrema helpers).
DOM_INTERVAL::_lambertW_extrema_imag_const_x :=
proc(k : DOM_INT, x : DOM_FLOAT, yiv : DOM_INTERVAL)
  local eta, c, y, f;
begin
  // residual whose root in eta is sought
  f := e -> -(e*sin(e)+cos(e)*(e*cot(e)+1))/exp(e*cot(e)+1)-x;
  // The search range depends on the branch and on the sign of the
  // imaginary part of z, which on this side is the interval yiv
  // (the local y is not assigned until further below).
  eta := numeric::realroot(f(`#eta`),
                           `#eta` = if iszero(k) then
                                      if yiv >=0 then
                                        0..PI
                                      else
                                        -PI..0
                                      end_if
                                    else
                                      if k * yiv >= 0 then
                                        ((2*k-2)*PI)..((2*k-1)*PI)
                                      else
                                        ((2*k-1)*PI)..(2*k*PI)
                                      end_if
                                    end_if);
  if nops(eta) <> 1 then
    warning("Spurious solution (or theory bug)");
  end_if;

  // turn the numerical root into an interval and inflate it until the
  // residual has different signs at the two end points
  eta := hull(op(eta, 1));
  while sign(f(hull(op(eta,1)))) = sign(f(hull(op(eta, 2)))) do
    eta := eta*(1+(-1...1)*1e-5);
  end_while;
  
  // imaginary coordinate belonging to eta; nothing to report if it
  // misses the side
  y := sin(eta)/exp(eta*cot(eta)+1);
  if y intersect yiv = {} then
    return({});
  end_if;
  
  // bisection on the sign of the residual down to width 10^(3-DIGITS)
  while DOM_INTERVAL::width(eta) > 10^(3-DIGITS) do
    c := (op(eta,1)+op(eta,2))/2;
    if sign(f(hull(c))) = sign(f(hull(op(eta,1)))) then
      eta := subsop(eta, 1=c);
    else
      eta := subsop(eta, 2=c);
    end_if;
  end_while;
  
  eta;
end_proc:

// TODO: purely real results for k in {0, -1}

// Encloses branch k of lambertW over the complex interval iv.
//
// The enclosure is assembled from the images of the four corners of iv
// and from the extremum candidates along its four sides; the result is
// finally cut down to the strip
//   (2*k-2)*PI <= Im <= (2*k+1)*PI.
DOM_INTERVAL::_encl_lambertW :=
proc(k, iv)
  local re_part, im_part, x_lo, x_hi, y_lo, y_hi,
        box, re_extra, im_extra, strip;
begin
  // split iv into its real and imaginary extent and their bounds
  re_part := Re(iv);
  im_part := Im(iv);
  [x_lo, x_hi] := [op(re_part)];
  [y_lo, y_hi] := [op(im_part)];

  // step 1: hull of the point enclosures at the four corners
  box := hull(dom::_lambertW_point(k, x_lo + I*y_lo),
              dom::_lambertW_point(k, x_lo + I*y_hi),
              dom::_lambertW_point(k, x_hi + I*y_lo),
              dom::_lambertW_point(k, x_hi + I*y_hi));

  // step 2: extremum candidates of the real part along the sides ...
  re_extra := hull(dom::_lambertW_extrema_real_const_y(k, re_part, y_lo),
                   dom::_lambertW_extrema_real_const_y(k, re_part, y_hi),
                   dom::_lambertW_extrema_real_const_x(k, x_lo, im_part),
                   dom::_lambertW_extrema_real_const_x(k, x_hi, im_part));
  // ... and of the imaginary part
  im_extra := hull(dom::_lambertW_extrema_imag_const_y(k, re_part, y_lo),
                   dom::_lambertW_extrema_imag_const_y(k, re_part, y_hi),
                   dom::_lambertW_extrema_imag_const_x(k, x_lo, im_part),
                   dom::_lambertW_extrema_imag_const_x(k, x_hi, im_part));

  // enlarge the corner hull componentwise by the side candidates
  box := hull(Re(box), re_extra) + I*hull(Im(box), im_extra);

  // restrict the imaginary part to the strip used for branch k
  strip := (RD_NINF...RD_INF) + I*(((2*k-2)*PI)...((2*k+1)*PI));
  box intersect strip;
end_proc:


// Interval evaluation of lambertW(k, iv).
//
// Parameters:
//   k  - branch index; a constant that is not a DOM_INT raises
//        "illegal branch"
//   iv - argument; it is converted with interval() first
// Returns:
//   {}    if the converted argument is the empty set,
//   FAIL  if it is anything else that is not a DOM_INTERVAL,
//   otherwise the enclosure computed by DOM_INTERVAL::_encl_lambertW;
//   a union of intervals is processed component by component and the
//   results are united again.
// NOTE(review): a non-constant (symbolic) k is neither rejected nor
// handled here and is passed on unchanged -- confirm that this is intended.
DOM_INTERVAL::lambertW :=
proc(k, iv)
begin
  iv := interval(iv);
  if iv::dom <> DOM_INTERVAL then
    if iv = {} then return(iv); end_if;
    return(FAIL);
  end_if;
  if k::dom <> DOM_INT then
    if testtype(k, Type::Constant) then
      error("illegal branch");
    end_if;
  end_if;
  if op(iv, 0) = hold(_union) then
    // Collect the components in a list before mapping: passing the bare
    // sequence op(iv) to map would flatten it into map's argument list,
    // so that the second component would be taken as the function to map.
    _union(op(map([op(iv)], a -> DOM_INTERVAL::lambertW(k, a))));
  else
    DOM_INTERVAL::_encl_lambertW(k, iv);
  end_if;
end_proc:

/* Notes:
   
eqx := [x - exp(a)*(a*cos(b)-b*sin(b)),
        y - exp(a)*(b*cos(b)+a*sin(b)),
        x + exp(a)*cos(b)]:
eqy := [x - exp(a)*(a*cos(b)-b*sin(b)),
        y - exp(a)*(b*cos(b)+a*sin(b)),
        y + exp(a)*sin(b)]:
eqx2 := [exp(a)*cos(b) + exp(a)*(a*cos(b)-b*sin(b)),
         y - exp(a)*(b*cos(b)+a*sin(b))]:
eqy2 := [x - exp(a)*(a*cos(b)-b*sin(b)),
         exp(a)*sin(b) + exp(a)*(b*cos(b)+a*sin(b))]:
	
// have a look at lambertW(2, (-1+I)...(200+I)):

// interval::newton can use a little help here,
// since by itself it would choose grossly inferior
// directions for subdivision.  The starting region
// should be enclosed much tighter.
   
numeric::fsolve(eqx | y=1,
  [x = -1..200, a=-500..500,
   b = 3*PI..5*PI]);
   
interval::newton(eqx | y=1,
[x = %[1][2]* (1+(-1e-5...1e-5)),
a = %[2][2] * (1+(-1e-5...1e-5)),
b = %[3][2] * (1+(-1e-5...1e-5))], 1e-6)

// this approximation "probably" has more than sufficient
// round-out error to provide an enclosure, but having a guarantee
// would still be much better.
approx :=
proc(k, z)
  local Log, logLog;
begin
  Log := interval(ln(z)+2*k*PI*I);
  logLog := ln(Log);
  Log - logLog + logLog/Log + logLog*(logLog - 2)/(2*Log^2)
  + logLog * (6-9*logLog + 2*logLog^2)/(6*Log^3)
  + logLog * (3*logLog^3 - 22*logLog^2 + 36*logLog - 12)/(12*Log^4);
end_proc:

   // round-out error way too large:
//approx(2, (-1...200)+I);
//interval::newton(eqx2 | y=1,
//  [a in Re(%), b in Im(%)]);
//interval::newton(eqy | y=1,
//  [x in -1...200, a in Re(%2), b in Im(%2)]);
  
approx(2, (-1...0)+I);
interval::newton(eqx2 | y=1,
  [a in Re(%), b in Im(%)])

interval::newton(eqy | y=1,
  [x = -1...200, a=-500...500,
   b = ((3+k/10)*PI)...((2+(k+1)/10)*PI)],
   1e-3) $ k = 0..19
   
 */
 