// Copyright (C) 2009  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
// This code was adapted from code from the JAMA part of NIST's TNT library.
//    See: http://math.nist.gov/tnt/ 
#ifndef DLIB_MATRIX_LU_DECOMPOSITION_H
#define DLIB_MATRIX_LU_DECOMPOSITION_H
#include "matrix.h" 
#include "matrix_utilities.h"
#include "matrix_subexp.h"
#include "matrix_trsm.h"
#include <algorithm>
#ifdef DLIB_USE_LAPACK 
#include "lapack/getrf.h"
#endif
namespace dlib 
{
    template <
        typename matrix_exp_type
        >
    class lu_decomposition
    {
        /*!
            WHAT THIS OBJECT REPRESENTS
                The LU decomposition, with partial (row) pivoting, of a float or
                double matrix A.  After construction the factors satisfy
                    rowm(A,get_pivot()) == get_l()*get_u()
                where get_l() is lower triangular with a unit diagonal (trapezoidal
                when A has more rows than columns) and get_u() is upper triangular
                (trapezoidal when A has more columns than rows).
        !*/
    public:
        static const long NR = matrix_exp_type::NR;
        static const long NC = matrix_exp_type::NC;

        typedef typename matrix_exp_type::type             type;
        typedef typename matrix_exp_type::mem_manager_type mem_manager_type;
        typedef typename matrix_exp_type::layout_type      layout_type;

        typedef matrix<type,0,0,mem_manager_type,layout_type>  matrix_type;
        typedef matrix<type,NR,1,mem_manager_type,layout_type> column_vector_type;
        typedef matrix<long,NR,1,mem_manager_type,layout_type> pivot_column_vector_type;

        // You have supplied an invalid type of matrix_exp_type.  You have
        // to use this object with matrices that contain float or double type data.
        COMPILE_TIME_ASSERT((is_same_type<float, type>::value || 
                             is_same_type<double, type>::value ));

        // Factors A.  Requires A.size() > 0.
        template <typename EXP>
        lu_decomposition (const matrix_exp<EXP> &A);

        // Shape of the factored matrix A.
        bool is_square () const;
        long nr () const;
        long nc () const;

        // Requires is_square().  True when the smallest pivot of U is negligible
        // relative to the largest one.
        bool is_singular () const;

        // The factors and the row permutation of the decomposition.
        const matrix_type get_l () const;
        const matrix_type get_u () const;
        const pivot_column_vector_type& get_pivot () const;

        // Requires is_square().  Determinant of A (exactly 0 when is_singular()).
        type det () const;

        // Requires is_square() and B.nr() == nr().  Solves A*X == B for X using
        // the stored factors.
        template <typename EXP>
        const matrix_type solve (const matrix_exp<EXP> &B) const;

    private:
        // Packed factors: L strictly below the diagonal (its unit diagonal is
        // implicit) and U on and above it.  Always column major, regardless of
        // layout_type.
        matrix<type,0,0,mem_manager_type,column_major_layout>  LU;

        long m;        // number of rows of A
        long n;        // number of columns of A
        long pivsign;  // +1 or -1 depending on the parity of the row interchanges

        // Row permutation: rowm(A,piv) == L*U.
        pivot_column_vector_type piv;
    };
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
//                              Public member functions
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
    template <typename matrix_exp_type>
    template <typename EXP>
    lu_decomposition<matrix_exp_type>::
    lu_decomposition (
        const matrix_exp<EXP>& A
    ) : 
        LU(A),
        m(A.nr()),
        n(A.nc())
    {
        /*
            Factors A in place.  LU starts as a copy of A and is overwritten so that
            its strictly lower part holds the multipliers of L (L's unit diagonal is
            implicit) and its diagonal and upper part hold U.  piv records the row
            permutation so that rowm(A,piv) == L*U.  pivsign starts at 1 and is
            negated on every actual row interchange; det() multiplies by it.
        */
        using namespace std;
        using std::abs;
        COMPILE_TIME_ASSERT((is_same_type<type, typename EXP::type>::value));
        // make sure requires clause is not broken
        DLIB_ASSERT(A.size() > 0,
            "\tlu_decomposition::lu_decomposition(A)"
            << "\n\tInvalid inputs were given to this function"
            << "\n\tA.size(): " << A.size()
            << "\n\tthis:     " << this
            );
#ifdef DLIB_USE_LAPACK
        // LAPACK factors LU in place and reports its row interchanges in piv_temp
        // using 1-based row indices.
        matrix<lapack::integer,0,1,mem_manager_type,layout_type> piv_temp;
        lapack::getrf(LU, piv_temp);
        pivsign = 1;
        // Turn the piv_temp vector into a more useful form.  This way we will have the identity
        // rowm(A,piv) == L*U.  The permutation vector that comes out of LAPACK is somewhat
        // different.
        // The loop treats piv_temp(i) as "row i was interchanged with row
        // piv_temp(i)-1", applied in order, and replays those interchanges on the
        // identity permutation.  Because piv always holds a permutation, the two
        // compared entries are equal only when piv_temp(i)-1 == i.  That is the
        // no-interchange case, which must not flip the sign.
        piv = trans(range(0,m-1));
        for (long i = 0; i < piv_temp.size(); ++i)
        {
            // -1 because FORTRAN is indexed starting with 1 instead of 0
            if (piv(piv_temp(i)-1) != piv(i))
            {
                std::swap(piv(i), piv(piv_temp(i)-1));
                pivsign = -pivsign;
            }
        }
#else
        // Use a "left-looking", dot-product, Crout/Doolittle algorithm.
        // Start from the identity permutation with an even (positive) sign.
        piv = trans(range(0,m-1));
        pivsign = 1;
        // Working copy of the column currently being processed.
        column_vector_type LUcolj(m);
        // Outer loop.
        for (long j = 0; j < n; j++) 
        {
            // Make a copy of the j-th column to localize references.
            LUcolj = colm(LU,j);
            // Apply previous transformations.
            // For row i, s = sum over k < min(i,j) of LU(i,k)*LUcolj(k).  For i <= j
            // the update finishes U(i,j).  For i > j it leaves L(i,j)*U(j,j), which
            // is scaled by the pivot further below.
            for (long i = 0; i < m; i++) 
            {
                // Most of the time is spent in the following dot product.
                const long kmax = std::min(i,j);
                type s;
                if (kmax > 0)
                    s = rowm(LU,i, kmax)*colm(LUcolj,0,kmax);
                else 
                    s = 0;
                // Update both LU and the working copy.  The dot products of later
                // rows in this column read the already-updated LUcolj entries.
                LU(i,j) = LUcolj(i) -= s;
            }
            // Find pivot and exchange if necessary.
            // Partial pivoting: choose the entry of largest magnitude on or below
            // the diagonal of column j.
            long p = j;
            for (long i = j+1; i < m; i++) 
            {
                if (abs(LUcolj(i)) > abs(LUcolj(p))) 
                {
                    p = i;
                }
            }
            if (p != j) 
            {
                // Swap entire rows p and j of LU.  This covers all n columns, so the
                // previously computed L multipliers move with their rows.
                long k=0;
                for (k = 0; k < n; k++) 
                {
                    type t = LU(p,k); 
                    LU(p,k) = LU(j,k); 
                    LU(j,k) = t;
                }
                // Record the interchange in the permutation and flip its parity.
                k = piv(p); 
                piv(p) = piv(j); 
                piv(j) = k;
                pivsign = -pivsign;
            }
            // Compute multipliers.
            // j < m guards wide matrices, where column j has no diagonal entry.
            // A zero pivot is the largest magnitude in the column, so every entry
            // below it is zero too.  Skipping the division then leaves them zero.
            if ((j < m) && (LU(j,j) != 0.0)) 
            {
                for (long i = j+1; i < m; i++) 
                {
                    LU(i,j) /= LU(j,j);
                }
            }
        }
#endif
    }
// ----------------------------------------------------------------------------------------
    template <typename matrix_exp_type>
    bool lu_decomposition<matrix_exp_type>::
    is_square (
    ) const
    {
        // The factored matrix has as many rows as it has columns.
        return nc() == nr();
    }
// ----------------------------------------------------------------------------------------
    template <typename matrix_exp_type>
    long lu_decomposition<matrix_exp_type>::
    nr (
    ) const
    {
        // Number of rows of the matrix given to the constructor.
        return this->m;
    }
// ----------------------------------------------------------------------------------------
    template <typename matrix_exp_type>
    long lu_decomposition<matrix_exp_type>::
    nc (
    ) const
    {
        // Number of columns of the matrix given to the constructor.
        return this->n;
    }
// ----------------------------------------------------------------------------------------
    template <typename matrix_exp_type>
    bool lu_decomposition<matrix_exp_type>::
    is_singular (
    ) const
    {
        /*
            Returns true if the upper triangular factor U (and hence A) is
            numerically singular, and false otherwise.

            U is considered singular when the smallest magnitude on its diagonal
            falls below max(|diag(U)|) * sqrt(epsilon) / 10.  The threshold is
            relative to the largest pivot, so scaling A does not change the answer.
            If every diagonal entry is exactly 0, the threshold becomes 1 and the
            matrix is reported as singular.
        */
        // make sure requires clause is not broken
        DLIB_ASSERT(is_square() == true,
            "\tbool lu_decomposition::is_singular()"
            << "\n\tYou can only use this on square matrices"
            << "\n\tthis: " << this
            );
        type max_val, min_val;
        find_min_and_max (abs(diag(LU)), min_val, max_val);
        type eps = max_val;
        if (eps != 0)
            eps *= std::sqrt(std::numeric_limits<type>::epsilon())/10;
        else
            eps = 1;  // there is no max so just use 1
        return min_val < eps;
    }
// ----------------------------------------------------------------------------------------
    template <typename matrix_exp_type>
    const typename lu_decomposition<matrix_exp_type>::matrix_type lu_decomposition<matrix_exp_type>::
    get_l (
    ) const
    {
        // L lives strictly below the diagonal of LU, and its diagonal is all ones.
        // For a wide matrix (fewer rows than columns) only the leading m x m square
        // belongs to L.  Otherwise L is the whole lower trapezoid of LU.
        if (LU.nr() < LU.nc())
            return lowerm(subm(LU,0,0,m,m), 1.0);

        return lowerm(LU,1.0);
    }
// ----------------------------------------------------------------------------------------
    template <typename matrix_exp_type>
    const typename lu_decomposition<matrix_exp_type>::matrix_type lu_decomposition<matrix_exp_type>::
    get_u (
    ) const 
    {
        // U lives on and above the diagonal of LU.  A wide matrix keeps the whole
        // upper trapezoid.  Otherwise only the leading n x n square is U.
        if (LU.nr() < LU.nc())
            return upperm(LU);

        return upperm(subm(LU,0,0,n,n));
    }
// ----------------------------------------------------------------------------------------
    template <typename matrix_exp_type>
    const typename lu_decomposition<matrix_exp_type>::pivot_column_vector_type& lu_decomposition<matrix_exp_type>::
    get_pivot (
    ) const
    {
        // Row permutation computed by the constructor.
        return this->piv;
    }
// ----------------------------------------------------------------------------------------
    template <typename matrix_exp_type>
    typename lu_decomposition<matrix_exp_type>::type lu_decomposition<matrix_exp_type>::
    det (
    ) const
    {
        // make sure requires clause is not broken
        DLIB_ASSERT(is_square() == true,
            "\ttype lu_decomposition::det()"
            << "\n\tYou can only use this on square matrices"
            << "\n\tthis: " << this
            );

        // det(A) = sign(permutation) * product of U's diagonal.  A numerically
        // singular factorization is reported as exactly zero.  For a large matrix,
        // the product of the remaining pivots can easily swamp one diagonal entry
        // that is effectively zero.
        const type sign = static_cast<type>(pivsign);
        return is_singular() ? static_cast<type>(0) : prod(diag(LU))*sign;
    }
// ----------------------------------------------------------------------------------------
    template <typename matrix_exp_type>
    template <typename EXP>
    const typename lu_decomposition<matrix_exp_type>::matrix_type lu_decomposition<matrix_exp_type>::
    solve (
        const matrix_exp<EXP> &B
    ) const
    {
        COMPILE_TIME_ASSERT((is_same_type<type, typename EXP::type>::value));

        // make sure requires clause is not broken
        DLIB_ASSERT(is_square() == true && B.nr() == nr(),
            "\ttype lu_decomposition::solve()"
            << "\n\tInvalid arguments to this function"
            << "\n\tis_square():   " << (is_square()? "true":"false" )
            << "\n\tB.nr():        " << B.nr() 
            << "\n\tnr():          " << nr() 
            << "\n\tthis:          " << this
            );

        // Since rowm(A,piv) == L*U, solving A*X == B is the same as solving
        // L*U*X == rowm(B,piv).  Start from the row-permuted right hand side and
        // overwrite it in place with two triangular solves.
        matrix<type,0,0,mem_manager_type,column_major_layout> sol(rowm(B, piv));

        using namespace blas_bindings;

        // Forward substitution: L*Y == rowm(B,piv), where L has an implicit unit diagonal.
        triangular_solver(CblasLeft, CblasLower, CblasNoTrans, CblasUnit, LU, sol);
        // Back substitution: U*X == Y.
        triangular_solver(CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, LU, sol);

        return sol;
    }
// ----------------------------------------------------------------------------------------
} 
#endif // DLIB_MATRIX_LU_DECOMPOSITION_H