"""

⭕ To access the source code, click on the [source] button at the right
side or click on
:download:`[mass_matrices_trace.py]</contents/LIBRARY/ptc/mathischeap_ptc/mass_matrices_trace.py>`.
Dependencies may exist. In case of an error, check the imports and install
the required packages or download the required scripts. © mathischeap.com

"""

import numpy as np
from quadrature import Gauss
from mimetic_basis_polynomials_2d import grid2d, MimeticBasisPolynomials2D
from scipy import sparse as spspa
from coordinate_transformation_surface import extract_surface_coordinate_transformations_of




class MassMatricesTrace(object):
    """ The mass matrices of trace spaces
    :math:`\\text{TN}_{N}(\\partial\\Omega)`,
    :math:`\\text{TE}_{N-1}(\\partial\\Omega)` and
    :math:`\\text{TF}_{N-1}(\\partial\\Omega)`.

    :param bf: A :class:`MimeticBasisPolynomials`
        (not :class:`MimeticBasisPolynomials2D`) instance.
    :type bf: MimeticBasisPolynomials
    :param ct: A CoordinateTransformation instance that represents the
        mapping :math:`\\Phi:\\Omega_{\\mathrm{ref}}\\to\\Omega`.
    :type ct: CoordinateTransformation
    :param quad_degree: (default: ``None``) The degree used for the
        numerical integral. It should be a list or tuple of three
        positive integers. If it is ``None``, the degree of ``bf`` plus
        2 is used in each of the three directions.
    :type quad_degree: list, tuple

    :example:

    >>> from coordinate_transformation import CoordinateTransformation
    >>> from coordinate_transformation import Phi, d_Phi
    >>> from mimetic_basis_polynomials import MimeticBasisPolynomials
    >>> ct = CoordinateTransformation(Phi, d_Phi)
    >>> bf = MimeticBasisPolynomials('Lobatto-4', 'Lobatto-3', 'Lobatto-2')
    >>> MMT = MassMatricesTrace(bf, ct)
    >>> MMT.TF.shape # doctest: +ELLIPSIS
    (52, 52)...

    """
    def __init__(self, bf, ct, quad_degree=None):
        assert bf.__class__.__name__ == 'MimeticBasisPolynomials'
        assert ct.__class__.__name__ == 'CoordinateTransformation'
        self.bf = bf
        self.ct = ct
        if quad_degree is None:
            # Default: two degrees above the basis degree in each direction.
            bf_N = bf.degree
            quad_degree = [bf_N[i] + 2 for i in range(3)]
        self.quad_degree = quad_degree
        qn_0, qw_0 = Gauss(quad_degree[0])
        qn_1, qw_1 = Gauss(quad_degree[1])
        qn_2, qw_2 = Gauss(quad_degree[2])
        # Each pair of opposite boundary sides is parametrized by the two
        # reference directions tangential to it:
        #   'NS' -> (1, 2),  'WE' -> (0, 2),  'BF' -> (0, 1).
        # Tensor-product quadrature weights per side pair.
        self.quad_weights = {'NS': np.kron(qw_2, qw_1),
                             'WE': np.kron(qw_2, qw_0),
                             'BF': np.kron(qw_1, qw_0)}
        # 1d quadrature nodes along the two tangential directions.
        self.quad_nodes = {'NS': [qn_1, qn_2],
                           'WE': [qn_0, qn_2],
                           'BF': [qn_0, qn_1]}
        # 2d quadrature grid (from grid2d) passed to the side metrics.
        self._rho_tau_ = {'NS': grid2d(qn_1, qn_2),
                          'WE': grid2d(qn_0, qn_2),
                          'BF': grid2d(qn_0, qn_1)}
        nodes = self.bf.nodes
        # 2d basis on each side pair, built from the matching 1d nodes.
        self.bft = {'NS': MimeticBasisPolynomials2D(nodes[1], nodes[2]),
                    'WE': MimeticBasisPolynomials2D(nodes[0], nodes[2]),
                    'BF': MimeticBasisPolynomials2D(nodes[0], nodes[1])}

    @property
    def TN(self):
        """The mass matrix :math:`\\mathbb{M}_{\\text{N}}`.

        :return: Return a csc_matrix representing the mass matrix.
        """
        raise NotImplementedError("Could you code it?")

    @property
    def TE(self):
        """The mass matrix :math:`\\mathbb{M}_{\\text{E}}`.

        :return: Return a csc_matrix representing the mass matrix.
        """
        raise NotImplementedError("Could you code it?")

    @staticmethod
    def _face_mass_matrix_(bf, g, weights):
        """Mass matrix of the face basis on a single boundary side.

        Computes :math:`M_{ij} = \\sum_m f_i(m) f_j(m) w_m / \\sqrt{g_m}`.

        :param bf: Face basis functions evaluated at the quadrature points,
            a 2d array of shape ``(num_basis, num_points)``.
        :param g: The side's metric evaluated at the quadrature points.
        :param weights: Quadrature weights, one per quadrature point.
        :return: A dense ``(num_basis, num_basis)`` ndarray.
        """
        return np.einsum('im, jm, m -> ij',
                         bf, bf, np.reciprocal(np.sqrt(g)) * weights,
                         optimize='greedy')

    @property
    def TF(self):
        """The mass matrix :math:`\\mathbb{M}_{\\text{F}}`.

        The result is block diagonal, with one block per boundary side in
        the order North, South, West, East, Back, Front.

        :return: Return a csc_matrix representing the mass matrix.
        """
        N, S, W, E, B, F = extract_surface_coordinate_transformations_of(self.ct)

        blocks = []
        # Opposite sides share one reference parametrization, so the face
        # basis is evaluated once per pair. Only the metric differs per side.
        for pair, sides in (('NS', (N, S)), ('WE', (W, E)), ('BF', (B, F))):
            rho, tau = self.quad_nodes[pair]
            bf = self.bft[pair].face_polynomials(rho, tau)
            weights = self.quad_weights[pair]
            for side in sides:
                g = side.metric(*self._rho_tau_[pair])
                blocks.append(spspa.csc_matrix(
                    self._face_mass_matrix_(bf, g, weights)))

        return spspa.block_diag(blocks, format='csc')


if __name__ == '__main__':
    # Run the docstring examples first.
    import doctest
    doctest.testmod()

    # Then build the trace-face mass matrix for a sample mapping and basis,
    # and print it.
    from coordinate_transformation import CoordinateTransformation, Phi, d_Phi
    from mimetic_basis_polynomials import MimeticBasisPolynomials

    mapping = CoordinateTransformation(Phi, d_Phi)
    basis = MimeticBasisPolynomials('Lobatto-4', 'Lobatto-3', 'Lobatto-2')
    trace_mass = MassMatricesTrace(basis, mapping)
    print(trace_mass.TF)

