// 

/* rewriteTrig.mu
 *
 * Rewrites expressions built from exp in terms of trigonometric and
 * hyperbolic functions (sin, cos, tan, cot, sinh, cosh, tanh, coth).
 */

Simplify::rewriteTrig := proc( X )
      local rewriteMult, rewriteSum, distributiveMult, rewriteTan, rewriteExpToTrig;
begin
      /* distributiveMult( a, b )
       * Description:
       * Computes the product a*b. If a is a sum, b is multiplied into each
       * summand of a and the products are added up again (distributive law);
       * otherwise the plain product a*b is returned.
       */
      distributiveMult := proc( a, b )
      begin
            // anything that is not a sum is multiplied directly
            if not testtype( a, "_plus" ) then
                  return( a*b );
            end_if;
            // multiply every summand of a by b and re-assemble the sum
            return( _plus( map( op(a), term -> term*b ) ) );
      end:

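      /* rewriteExpToTrig( TEXPR, x )
       * Description:
       * Tries to rewrite occurrences of exp depending on x as trigonometric terms depending on x.
       * This function has to be called explicitly with (TEXPR, I*x) to rewrite exp(I*x)-terms.
       *
       * Input:
       * TEXPR: the expression to rewrite, given as a table whose entries have the form
       *        exponent = coefficient (rewriteSum passes the table produced by buildTable)
       * x: The identifier to be tested (can also be a sum or another arithmetic expression).
       *
       * Output:
       * The rewritten expression, or FAIL if nothing could be rewritten.
       */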
      rewriteExpToTrig := proc( TEXPR, x )
            local i, j, summand, T, minExp, maxExp, indexList, avgVal, indx, result,
                  isImaginary, _isImaginary, diffSum, testBinom, getConstant, checkSumTerms, trySubstitution, realFactor;
      begin
            /* getConstant( EXPR )
            * Description:
            * Returns the numeric constant factor of EXPR:
            *  - for a product, the product of its constant factors,
            *  - for a constant, EXPR itself,
            *  - for anything else, 1.
            * Returns null() if EXPR is a product containing factors for which
            * the constant test could not be decided.
            */
            getConstant := proc( EXPR: Type::Arithmetical )
                  local parts;
            begin
                  if not testtype( EXPR, "_mult" ) then
                        // not a product: either EXPR is a constant itself or its factor is 1
                        if testtype( EXPR, Type::Constant ) then
                              return( EXPR );
                        end;
                        return( 1 );
                  end;

                  // parts = [ constant factors, non-constant factors, undecided factors ]
                  parts := split( EXPR, testtype, Type::Constant );
                  if parts[3] = 1 then
                        return( parts[1] );
                  end;
                  // there should be no undecided factors
                  return( null() );
            end:

            /* checkSumTerms( A, B )
            * Description:
            * Tries to find a decomposition, as simple as possible, that describes
            * A*exp(x) + B*exp(-x) in terms of trigonometric functions.
            *
            * Input:
            * A and B: arithmetical expressions such that X := A*exp(x) + B*exp(-x)
            *
            * Output:
            * [ cosTerm, sinTerm ] such that X = (cosTerm+sinTerm)*exp(x) + (cosTerm-sinTerm)*exp(-x) + remainder
            * respectively X = cosTerm*2*cosh(x) + sinTerm*2*sinh(x) + remainder
            * with remainder being as small as possible.
            * [ 0, 0 ] is returned if A or B is 0.
            */
            checkSumTerms := proc( A:Type::Arithmetical, B:Type::Arithmetical )
                  local cosTerm, sinTerm;
            begin
                  // nothing can be paired if one of the two coefficients vanishes
                  if ( A=0 ) or ( B=0 ) then
                        return( [0,0] );
                  end;

                  // wrap a non-sum into a one-element list so that op() below
                  // yields the summands in both cases
                  if not testtype( A, "_plus" ) then
                        A := [ A ];
                  end;
                  if not testtype( B, "_plus" ) then
                        B := [ B ];
                  end;

                  // turn every summand X into the pair [ constant factor, X/constant factor ];
                  // % refers to the previous result, i.e. the value of getConstant(X)
                  // NOTE(review): getConstant can return null(); what % holds in that case should be confirmed
                  A := map( {op(A)}, (proc(X) begin getConstant(X); return( [ %, X/% ] ); end;) );
                  B := map( {op(B)}, (proc(X) begin getConstant(X); return( [ %, X/% ] ); end;) );
                  // replace every pair by two pairs: one carrying the real part of the constant
                  // and one carrying I times its imaginary part
                  A := map( A, X -> ([Re(X[1]),X[2]],[ I*Im(X[1]),X[2] ] ) );
                  B := map( B, X -> ([Re(X[1]),X[2]],[ I*Im(X[1]),X[2] ] ) );

                  // pairs contained in both sets occur with the same constant in A and B
                  cosTerm := _plus( op( map( A intersect B, X->X[1]*X[2] ) ) );
                  // after negating the constants of B, common pairs are those that
                  // occur with opposite constants in A and B
                  B := map( B, X->[ -X[1], X[2] ] );
                  sinTerm := _plus( op( map( A intersect B, X->X[1]*X[2] ) ) );

                  return( [ cosTerm, sinTerm ] );
            end:

            /* realFactor( EXPR, x )
            * Description:
            * Determines the real multiple of x that is contained in EXPR.
            *
            * Input:
            * EXPR: an arithmetical expression containing the identifier x
            * x: the identifier (or expression) whose multiple is looked for
            *
            * Output:
            * A value r such that EXPR can be written as r*x + remainder, with
            * Re(remainder) being (hopefully) independent of Re(x).
            * For products and sums the quotient EXPR/x is inspected (reduced to its
            * first operand if it is a complex number); a sum whose quotient is not a
            * real constant is processed summand by summand. 1 is returned for
            * EXPR = x and 0 in all remaining cases.
            */
            realFactor := proc( EXPR : Type::Arithmetical, x )
                  local ratio;
            begin
                  if testtype( EXPR, "_mult" ) or testtype( EXPR, "_plus" ) then
                        ratio := EXPR/x;
                        // of a complex number only the first operand (the real part) is kept
                        if type(ratio)=DOM_COMPLEX then
                          ratio := op(ratio,1);
                        end_if;
                        if testtype( ratio, Type::Constant ) and is( ratio, Type::Real )=TRUE then
                              return( ratio );
                        end;
                        // a sum is handled summand-wise, a product contributes nothing
                        if testtype( EXPR, "_plus" ) then
                              return( _plus( map( op(EXPR), realFactor, x ) ) );
                        end;
                        return( 0 );
                  end;

                  if EXPR=x then
                        return( 1 );
                  end;
                  return( 0 );
            end;

            // T collects the coefficients of the exp-terms, indexed by the real
            // multiple of x found in the exponent
            T := table();

            /* _isImaginary( x )
             * Description:
             * Purely syntactical test whether x carries a non-real numeric factor:
             *  - a product qualifies if its last operand is a constant that is not
             *    known to be real,
             *  - a sum qualifies if every one of its summands qualifies,
             *  - everything else does not qualify.
             *
             * Output:
             * TRUE or FALSE
             */
            _isImaginary := proc( x : Type::Arithmetical )
                  local lastOp;
            begin
                  if testtype( x, "_plus" ) then
                        // the test has to be mapped onto all summands; calling
                        // _isImaginary( op(x), ... ) directly would look at the first one only
                        return( _and( map( op(x), _isImaginary ) ) );
                  elif not testtype( x, "_mult" ) then
                        return( FALSE );
                  else
                        lastOp := op( x, nops(x) );
                        if not testtype( lastOp, Type::Constant ) then
                              return( FALSE );
                        else
                              return( not bool(is( lastOp, Type::Real )=TRUE) );
                        end;
                  end;
            end;

            isImaginary := _isImaginary( x );

            // initialize variables: indexList collects the indices of T as a sequence,
            // minExp/maxExp track the smallest and the largest index seen
            indexList := null();
            minExp := +infinity;
            maxExp := -infinity;

            // The following loop goes through the entries of TEXPR. Of every entry i, op(i,1) is
            // used as the argument of exp and op(i,2) as the factor multiplying it. The real
            // multiple n of x is split off the argument; the values are stored in the table T
            // with n being the index and the remaining factor being the entry.

            // go through every entry of TEXPR
            // (diffSum, summand, T, indexList, minExp and maxExp are variables of the enclosing procedure)
            map( op(TEXPR), proc( i )
                  local j;
                  begin
                        // real multiple of x contained in the exp-argument
                        diffSum := realFactor( op(i,1), x );
                        // part of the exp-argument not covered by diffSum*x
                        j := op(i,1) - diffSum*x;
                        // the remaining argument is put back into the coefficient as exp-factor(s)
                        if testtype( j, "_plus" ) then
                              summand := op(i,2) * _mult( map( op(j), exp ) );
                        else
                              summand := op(i,2) * exp( j );
                        end;

                        maxExp := max( diffSum, maxExp );
                        minExp := min( diffSum, minExp );
                        // coefficients belonging to the same multiple of x are added up
                        if contains( T, diffSum ) then
                              T[ diffSum ] := T[ diffSum ] + summand;
                        else
                              indexList := indexList, diffSum;
                              T[ diffSum ] := summand;
                        end;
                  end
                  );

            // all entries share the same multiple of x: there is nothing to pair
            if maxExp=minExp then
                  return( FAIL );
            end;

            result := 0;
            // midpoint between the smallest and the largest multiple of x
            avgVal := (maxExp+minExp)/2;

            // turn the collected sequence of indices into a list
            indexList := [ indexList ];

            /* testBinom()
             * Description:
             * Tests if the expression stored in T is a binomial ( exp(x) +- exp(-x) )^n,
             * possibly plus a remainder.
             *
             * Works on the variables T, indexList, minExp, maxExp, avgVal, isImaginary, x,
             * result and i of the enclosing procedure. If a binomial is recognized, result is
             * set to the corresponding power of sin/sinh respectively cos/cosh and the
             * binomial part is subtracted from the entries of T. Nothing is returned.
             */
            testBinom := proc()
                  local n, val, lTab, bestVal, intFact, k;
            begin
                  // candidate exponent of the binomial: number of table entries minus one
                  n := nops( T )-1;
                  if n<2 then
                        return();
                  end;
                  // distance between neighbouring indices if they were equally spaced
                  intFact := (maxExp-minExp) / n;
                  // NOTE(review): n is nops(T)-1 and at least 2 here, so this test looks redundant;
                  // possibly a test on intFact or k was intended - confirm
                  if not is( n, Type::PosInt ) then
                        return();
                  end;
                  // lTab counts how often each candidate [ common factor, sign ] occurs
                  lTab := table();
                  bestVal := null();
                  for i in indexList do
                        // position of index i within the binomial expansion
                        k := (i-minExp) / intFact;
                        // test for sin: coefficients with alternating signs
                        val :=T[i]/binomial(n,k)*_power( -1, n-k );
                        if contains( lTab, [ val, -1 ] ) then
                              lTab[ [val,-1] ] := lTab[ [val,-1] ] + 1;
                        else
                              lTab[ [val,-1] ] := 1;
                        end;
                        if bestVal=null() or lTab[ [val,-1] ] > lTab[ bestVal ] then
                              bestVal := [val,-1];
                        end;
                        // test for cos: coefficients without sign change
                        val :=T[i]/binomial(n,k);
                        if contains( lTab, [ val, 1 ] ) then
                              lTab[ [val,1] ] := lTab[ [val,1] ] + 1;
                        else
                              lTab[ [val,1] ] := 1;
                        end;
                        if lTab[ [val,1] ] > lTab[ bestVal ] then
                              bestVal := [val,1];
                        end;
                  end;
                  // only substitute if the most frequent candidate occurs often enough
                  if lTab[ bestVal ]-1>(nops( T )-1)/2 then
                        if op( bestVal, 2 )=-1 then
                              // alternating signs: power of sin respectively sinh
                              if isImaginary then
                                    result := ( _power( 2*sin(-I*x*intFact/2)*(I), n)*op( bestVal,1 ) )*exp(avgVal*x);
                              else
                                    result := ( _power( 2*sinh(x*intFact/2), n)*op( bestVal,1 ) )*exp(avgVal*x);
                              end;
                              // remove the binomial part from the table
                              for i in indexList do
                                    k := (i-minExp) / intFact;
                                    T[i] := T[i] - binomial(n,k)*_power( -1, n-k )*op(bestVal,1);
                              end;
                        elif op( bestVal, 2 )=1 then
                              // equal signs: power of cos respectively cosh
                              if isImaginary then
                                    result := ( _power( 2*cos(I*x*intFact/2), n)*op( bestVal,1 ) )*exp(avgVal*x);
                              else
                                    result := ( _power( 2*cosh(x*intFact/2), n)*op( bestVal,1 ) )*exp(avgVal*x);
                              end;
                              // remove the binomial part from the table
                              for i in indexList do
                                    k := (i-minExp) / intFact;
                                    T[i] := T[i] - binomial(n,k)*op(bestVal,1);
                              end;
                        end;
                  end;
            end;

            /* trySubstitution( T, avgVal, result, indexList )
             * Description:
             * Replaces pairs of exp-terms whose exponents lie symmetrically around avgVal
             * (index j is paired with index 2*avgVal-j) by trigonometric respectively
             * hyperbolic terms.
             *
             * Input:
             * T: Table containing all summands of EXPR indexed by N with N being the exponent of exp(x) of each summand
             * avgVal: the value around which the exponents are paired
             * result: The result received so far
             * indexList: List of all indices of T
             *
             * Output:
             * [ T, result ]: If any substitution was achieved, the updated table and the new result
             * Returns FAIL, if no substitutions could be found.
             *
             * The sum of result and of T[n]*exp(n*x) over all indices n is left unchanged by this
             * procedure. Terms without a partner therefore stay in T untouched: the caller adds all
             * remaining entries of T to the result, so adding them here as well would count them twice.
             * Uses x, isImaginary, checkSumTerms and distributiveMult of the enclosing procedures.
             */
            trySubstitution := proc( T: DOM_TABLE, avgVal: Type::Rational, result : Type::Arithmetical, indexList : DOM_LIST )
                  local i, j, A, B, found, partner;
            begin
                  found := FALSE;
                  for i from 1 to nops(indexList) do
                        j := op( indexList, i );
                        // only non-zero entries below avgVal are looked at; their partners lie above
                        if not ( T[j] = 0 or is(j>=avgVal)=TRUE ) then
                              // position of the mirrored index in indexList, 0 if there is none
                              partner := contains( indexList, -j+2*avgVal );
                              if partner<>0 then
                                    // B is the part common to both coefficients (cos/cosh),
                                    // A the part occurring with opposite signs (sin/sinh)
                                    [ B, A ] := checkSumTerms( T[j], T[-j+2*avgVal] );
                                    if A<>0 or B<>0 then
                                          if isImaginary then
                                                result := result + distributiveMult( 2*A, sin(-I*x*(j-avgVal))*(I)*exp(avgVal*x) );
                                                result := result + distributiveMult(2*B, cos(I*x*(j-avgVal) )*exp(avgVal*x) );
                                          else
                                                result := result + distributiveMult( 2*A, sinh(x*(j-avgVal))*exp(avgVal*x) );
                                                result := result + distributiveMult(2*B, cosh(x*(j-avgVal) )*exp(avgVal*x) );
                                          end;
                                          found := TRUE;
                                          // remove the substituted parts from both table entries
                                          T[j] := T[j] - A - B;
                                          T[2*avgVal-j] := T[2*avgVal-j] + A - B;
                                    end;
                              end;
                        end;
                  end;
                  if not found then
                        return( FAIL );
                  end;
                  // returned as one list, as documented, so that the caller can compare the
                  // return value with FAIL and index it as a single object
                  return( [ T, result ] );
            end;

            // first look for a binomial pattern; this may already set result and reduce T
            testBinom();

            // pair exponents lying symmetrically around 0
            i := trySubstitution( T, 0, result, indexList );
            if i<>FAIL then
                  T := i[1];
                  result := i[2];
            end;
            // pair exponents lying symmetrically around the midpoint of all exponents
            if testtype( avgVal, Type::Rational ) then
                  i := trySubstitution( T, avgVal, result, indexList );
                  if i<>FAIL then
                        T := i[1];
                        result := i[2];
                  end;
            end;

            // nothing could be rewritten
            if result=0 then
                  return( FAIL );
            end;

            // add whatever is left in the table as plain exp-terms
            for i in indexList do
                  result := result + distributiveMult( T[i], exp(i*x) );
            end;
            return( result );
      end:

      /* rewriteSum( EXPR )
       * Description:
       * Searches for the indeterminates of EXPR and calls rewriteExpToTrig.
       *
       * Input:
       * EXPR: An arithmetical expression, typically a sum to be rewritten.
       *
       * Output:
       * An equivalent expression with compatible parts being substituted by trigonometric functions.
       */
      rewriteSum := proc ( EXPR : Type::Arithmetical )
            local i, T, inds, tmp, toCheck, j, checkIndets, alreadyChecked, doStep, buildTable;
      begin

            /* checkIndets( indet )
             * Description:
             * Produces the candidates that are to be tried for the argument indet:
             *  - sum: indet itself followed by the candidates of all its summands,
             *  - product: with c the constant factor and r the non-constant rest,
             *      I*r         if Re(c) = 0,
             *      r, I*r, indet  if neither Re(c) nor Im(c) is 0,
             *      r           otherwise;
             *    nothing (null()) if the constant test could not be decided for some factor,
             *  - anything else: indet itself.
             */
            checkIndets := proc( indet )
                  local parts, cFactor, core;
            begin
                  if testtype( indet, "_plus" ) then
                        return( indet, map( op(indet), checkIndets ) );
                  end;
                  if not testtype( indet, "_mult" ) then
                        return( indet );
                  end;

                  // parts = [ constant factors, non-constant factors, undecided factors ]
                  parts := split( indet, testtype, Type::Constant );
                  // there should be no undecided factors
                  if parts[3]<>1 then
                        return( null() );
                  end;
                  cFactor := parts[1];
                  core := parts[2];

                  if Re( cFactor ) = 0 then
                        return( I*core );
                  end;
                  if Im( cFactor ) <> 0 then
                        return( core, I*core, indet );
                  end;
                  return( core );
            end;

            /* buildTable( EXPR )
             * Description:
             * Splits EXPR (which is supposed to be a sum) into parts and groups them by the exp()-factor.
             *
             * Input:
             * EXPR: arithmetical expression
             *
             * Output:
             * T: Table with EXPR split into parts, where EXPR is the sum of (exp(index)*entry) over all table entries.
             *    An empty table is returned if EXPR is not a sum.
             */
            buildTable := proc( EXPR )
                  local i, expsum, diffSum, sumList, power, quot,
                            summand, T, diff, j;
            begin
                  T := table();
                  if not testtype( EXPR, "_plus" ) then
                        return( table() );
                  end;

                  // go through every summand in EXPR
                  for summand in op(EXPR) do
                        // sumList holds the factors of the summand
                        if testtype( summand, "_mult" ) then
                              sumList := op( summand );
                        else
                              sumList := [ summand ];
                        end;
                        // NOTE(review): diff and diffSum are assigned but not used in this procedure
                        diff := null();

                        diffSum := 0;
                        // quot collects the factors that do not go into the exponent
                        quot := 1;
                        // go through every multiplicand in summand
                        // expsum collects the exponent of the exp-factor of this summand
                        expsum := 0;
                        for i from 1 to nops(sumList) do
                              power := 1;
                              j := op( sumList, i );
                              // split a power into base j and exponent power
                              if testtype( j, "_power" ) then
                                    power := op( j, 2 );
                                    j := op( j, 1 );
                              end;
                              if testtype( j, "exp" ) then
                                    // exp(a)^power contributes power*a to the exponent
                                    expsum := expsum + power*op(j);
                              else
                                    if power<>1 then
                                          // any other power j^power is treated as exp( ln(j)*power )
                                          expsum := expsum + ln(j)*power;
                                    else
                                          quot := quot * op( sumList, i );
                                    end;
                              end;
                        end;
                        // summands with the same exponent are added up in one entry
                        if contains( T, expsum ) then
                              T[ expsum ] := T[expsum] + quot;
                        else
                              T[ expsum ] := quot;
                        end;
                  end;
                  return( T );
            end;

            // call rewriteMult recursively on every summand;
            // anything that is not a sum is returned unchanged
            if testtype( EXPR, "_plus" ) then
                  EXPR := _plus( map( op(EXPR), rewriteMult ) );
            else
                  return( EXPR );
            end;

            // group the summands by their exp-factor
            T := buildTable( EXPR );

            // with at most one table entry there is nothing that could be paired
            if nops( T )<=1 then
                  return( EXPR );
            end;

            // candidates for x that have been passed to rewriteExpToTrig already
            alreadyChecked := {};

            /* doStep( EXPR )
             * Description:
             * For every pair of entries of T, half the difference of their indices is turned
             * into candidates by checkIndets; every candidate not tried before is passed to
             * rewriteExpToTrig together with T.
             * Works on T and alreadyChecked of the enclosing procedure; the parameter EXPR
             * is not used in the body.
             *
             * Output:
             * The first result of rewriteExpToTrig different from FAIL, or FAIL if there is none.
             */
            doStep := proc( EXPR )
                  local i, j, k, tmp, toCheck;
            begin
                  for i from 1 to nops(T)-1 do
                        for j from i+1 to nops(T) do
                              // op(T,[i,1]) is the index of the i-th table entry
                              toCheck := { checkIndets( (op(T,[i,1])-op(T,[j,1]))/2 ) } minus alreadyChecked;
                              for k in toCheck do
                                    alreadyChecked := alreadyChecked union {k};
                                    tmp := rewriteExpToTrig( T, k );
                                    if tmp<>FAIL then
                                          return( tmp );
                                    end;
                              end;
                        end;
                  end;
                  return( FAIL );
            end;

            // rewrite step by step until no candidate leads to a further rewrite
            repeat
                  tmp := doStep( EXPR );
                  if tmp<>FAIL then
                        EXPR := tmp;
                        // regroup the rewritten expression for the next step
                        T := buildTable( EXPR );
                        if nops( T )<=1 then
                              return( EXPR );
                        end;
                  end;
            until tmp=FAIL end_repeat;
            return( EXPR );
      end:

      /* rewriteMult( X )
       * Description:
       * Applies rewriteSum() to every factor of X; of a factor that is a power,
       * only the base is rewritten and the exponent is kept.
       * X is returned unchanged if it contains neither exp nor a power or if it
       * is not an arithmetical expression. Lists are processed element-wise.
       */
      rewriteMult := proc( X )
            local factor, product, factors;
      begin
            // nothing to do without exp or powers
            if { hold(exp), hold(_power) } intersect indets( X, All ) = {} then
                  return( X );
            end;
            if testtype( X, DOM_LIST ) then
                  return( map( X, rewriteMult ) );
            end;
            if not testtype( X, Type::Arithmetical ) then
                  return( X );
            end;

            // the factors of a product, or X itself as the only factor
            if testtype( X, "_mult" ) then
                  factors := op(X);
            else
                  factors := [ X ];
            end;

            product := 1;
            for factor in factors do
                  if testtype( factor, "_power" ) then
                        product := product*_power( rewriteSum( op(factor,1) ), op(factor,2) );
                  else
                        product := product*rewriteSum( factor );
                  end;
            end;
            return( product );
      end;

      /* rewriteTan( EXPR )
       * Description:
       * Within a product, collects the powers of cos, sin, cosh and sinh by their argument
       * and combines powers of sin and cos (sinh and cosh) with exponents of opposite sign
       * into powers of tan or cot (tanh or coth).
       * A sum is processed summand by summand; any other expression is returned unchanged.
       */
      rewriteTan := proc( EXPR )
            local i, search, base, exponent, m, T, j, res;
      begin
            if testtype( EXPR, "_plus" ) then
                  // handle every summand separately
                  EXPR := _plus( map( EXPR, rewriteTan ) )
            elif testtype( EXPR, "_mult" ) then
                  EXPR := op( EXPR );
                  // res collects all factors that are not powers of cos, sin, cosh or sinh
                  res := 1;
                  // T[argument] = [ exponent of cos, of sin, of cosh, of sinh ]
                  T := table();
                  for i from 1 to nops(EXPR) do
                        if testtype( op(EXPR,i), "_power" ) then
                              base := op(EXPR, [i,1]);
                              exponent := op(EXPR, [i,2]);
                        else
                              base := op(EXPR,i);
                              exponent := 1;
                        end;
                        // j is the position of the function in the list, 0 if base is none of them
                        j := contains( [ "cos", "sin", "cosh", "sinh" ], type( base ) );
                        if j>0 then
                              if not contains( T, op(base) ) then
                                    T[op(base)] := [0,0,0,0];
                              end;
                              T[op(base)] [j] := T[op(base)] [j] + exponent;
                        else
                              res := res* op(EXPR,i);
                        end;
                  end;
                  // every i has the form argument = [ cos, sin, cosh, sinh exponents ]
                  for i in T do
                        // cos and sin exponents of opposite sign: j is the exponent that is
                        // moved to tan (j>0) or cot (j<0)
                        if sign( op(i,[2,1]) * op(i,[2,2] ) )=-1 then
                              j := min( abs(op(i,[2,1])), abs(op(i,[2,2])) ) * sign(op(i,[2,2]));
                        else
                              j := 0;
                        end;
                        res := res * _power( cos(op(i,1)), op(i,[2,1])+j ) *_power( sin(op(i,1)), op(i,[2,2])-j );
                        if  j<0 then
                              res := res* _power( cot(op(i,1)), -j )
                        elif j>0 then
                              res := res* _power( tan(op(i,1)), j )
                        end;

                        // the same for cosh and sinh with tanh and coth
                        if sign( op(i,[2,3]) * op(i,[2,4] ) )=-1 then
                              j := min( abs(op(i,[2,3])), abs(op(i,[2,4])) ) * sign(op(i,[2,4]));
                        else
                              j := 0;
                        end;
                        res := res * _power( cosh(op(i,1)), op(i,[2,3])+j ) *_power( sinh(op(i,1)), op(i,[2,4])-j );
                        if  j<0 then
                              res := res* _power( coth(op(i,1)), -j )
                        elif j>0 then
                              res := res* _power( tanh(op(i,1)), j )
                        end;
                  end;
                  EXPR := res;
            end;
            EXPR;
      end;

      // combine the exp-terms of X, rewrite them by rewriteMult, introduce
      // tan/cot/tanh/coth by rewriteTan and combine the exp-terms that are left
      return( combine( rewriteTan( rewriteMult( combine( X, exp ) ) ), exp ) );
end:
