# This file is part of ModPipe, Copyright 1997-2020 Andrej Sali
#
# ModPipe is free software: you can redistribute it and/or modify
# it under the terms of version 2 of the GNU General Public License
# as published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with ModPipe.  If not, see <http://www.gnu.org/licenses/>.

from __future__ import print_function
from modpipe.miscutils import get_arch_by_uname
import modpipe.pdbutils
from modpipe.pdbutils import *
from modpipe.alnutils import *
import modpipe
import modpipe.binaries
from shutil import rmtree
import subprocess, os, re, tempfile

class Error(modpipe.Error):
    """Exception raised by the CE wrapper in this module, for example when
    CE's output does not contain a usable alignment."""

class CE:
    """Wrapper around the external CE structural-alignment program.

    The methods depend on each other's results and are meant to be used in
    this order: get_alignment() (or read_program_output()) to obtain raw CE
    output, parse_output() to extract the aligned sequences, fix_alignment()
    to build a Modeller alignment, then write_alignment() and/or
    calculate_superpositions().
    """

    def __init__(self, modenv, pdb0, chn0, pdb1, chn1):
        """Set up a CE run of chain chn0 of pdb0 against chain chn1 of pdb1.

        modenv is the Modeller environment used to read the structures; its
        io.atom_files_directory is where the PDB files are looked up.  The
        given chain ids are replaced by the names of the chains actually
        fetched from the structures.
        """
        self.__get_exe()
        self.pdb0 = pdb0
        self.pdb1 = pdb1
        self.chn0 = chn0
        self.chn1 = chn1
        self.modenv = modenv
        self.__prepare()

        self.output = ''
        self.error = ''
        self.comment = ''

    def __get_exe(self):
        """Get the absolute paths to the CE executable and its mkDB helper."""
        self.exe = modpipe.binaries.get_ce()
        self.mkDB = modpipe.binaries.get_ce_mkdb()

    def __get_pdbs(self, tmpdir):
        """Return uncompressed PDB file paths for pdb0 and pdb1.

        Any uncompressed copies that are needed are placed in tmpdir.
        """
        return [modpipe.pdbutils.get_uncompressed_pdb(
                    x, self.modenv.io.atom_files_directory, tmpdir)
                for x in (self.pdb0, self.pdb1)]

    def __get_reference_pdb_data(self):
        """Use Modeller to create chain and sequence objects for the
        input structures."""
        self.modchain0 = fetch_PDB_chain(self.modenv, self.pdb0, self.chn0)
        self.modchain1 = fetch_PDB_chain(self.modenv, self.pdb1, self.chn1)

        self.modseq0 = get_chain_seq(self.modchain0)
        self.modseq1 = get_chain_seq(self.modchain1)

    def __prepare(self):
        """Load the reference chains and derive chain ids and codes.

        code0/code1 are the PDB code concatenated with the chain id.
        """
        self.__get_reference_pdb_data()
        self.chn0 = self.modchain0.name
        self.chn1 = self.modchain1.name
        self.code0 = self.pdb0 + self.chn0
        self.code1 = self.pdb1 + self.chn1

    def get_alignment(self):
        """Run CE on the two chains and store its raw output.

        Stores CE's stdout as text in self.output.  stderr is not
        redirected, so self.error is set to None by communicate().  The
        scratch directory (used both for uncompressed PDB files and as
        CE's work directory) is always removed, even if CE cannot be run.
        """
        tmpdir = tempfile.mkdtemp()
        try:
            pdbf0, pdbf1 = self.__get_pdbs(tmpdir)

            # CE denotes a blank chain id with '_'
            chn0 = '_' if self.chn0 == ' ' else self.chn0
            chn1 = '_' if self.chn1 == ' ' else self.chn1

            # universal_newlines=True makes output a str on Python 3, which
            # parse_output() and write_program_output() require
            proc = subprocess.Popen([self.exe, "-",
                                     pdbf0, chn0, pdbf1, chn1,
                                     tmpdir, self.mkDB],
                                    stdout=subprocess.PIPE,
                                    universal_newlines=True)
            self.output, self.error = proc.communicate()
        finally:
            rmtree(tmpdir)

    def parse_output(self):
        """Extract the aligned sequences and summary line from self.output.

        Sets self.comment to the (last) line starting with 'Alignment', and
        self.cealnseq0 / self.cealnseq1 to lists of single characters
        collected from the 'Chain 1:' / 'Chain 2:' lines ('-' marks gaps).
        Parsing stops at the first line matching '^\\s+Z2 = '.

        Raises Error if either aligned sequence is empty.
        """
        self.cealnseq0 = []
        self.cealnseq1 = []

        summary_re = re.compile(r'Alignment')
        chain1_re = re.compile(r'^Chain 1:(\s+)(\d+)\s([-\w]+)')
        chain2_re = re.compile(r'^Chain 2:(\s+)(\d+)\s([-\w]+)')
        stop_re = re.compile(r'^\s+Z2 = ')

        for line in self.output.splitlines():
            m0 = summary_re.match(line)
            m1 = chain1_re.match(line)
            m2 = chain2_re.match(line)

            if m0:
                self.comment = line
            elif m1:
                self.cealnseq0.extend(m1.group(3))
            elif m2:
                self.cealnseq1.extend(m2.group(3))

            # Output after the Z2 line is not part of the alignment
            if stop_re.match(line):
                break

        if not self.cealnseq0 or not self.cealnseq1:
            raise Error("CE returned an alignment of length 0")

    def read_program_output(self, file):
        """Load previously saved CE output from file into self.output."""
        with open(file, 'r') as fh:
            self.output = fh.read()

    def write_program_output(self, file):
        """Write the raw CE output (self.output) to file."""
        with open(file, 'w') as fh:
            fh.write(self.output)

    def __patch_alignment(self):
        """Patch the CE-aligned sequences against the Modeller reference
        sequences so the alignment is usable by Modeller."""
        self.cealnseq0 = fix_aligned_sequence(self.modenv,
                                              self.cealnseq0, self.modseq0)
        self.cealnseq1 = fix_aligned_sequence(self.modenv,
                                              self.cealnseq1, self.modseq1)

    def __get_pdb_bounds(self):
        """Locate each aligned sequence in its reference chain and set the
        corresponding alignment entry's range to the PDB residue numbers."""
        beg, end = find_seq_in_seq(self.cealnseq0, self.modseq0)
        start, stop = fetch_PDB_num(self.modchain0, [beg, end])
        self.modaln[0].range = (start, stop)

        beg, end = find_seq_in_seq(self.cealnseq1, self.modseq1)
        start, stop = fetch_PDB_num(self.modchain1, [beg, end])
        self.modaln[1].range = (start, stop)

    def fix_alignment(self):
        """Clean up the CE alignment and store it as a Modeller alignment
        in self.modaln (with residue ranges set)."""
        self.__patch_alignment()
        self.modaln = create_modeller_alignment(
            self.modenv, 'C; ' + self.comment,
            self.cealnseq0, self.cealnseq1,
            self.code0, 'structureX', self.pdb0, ('', ''),
            self.code1, 'structureX', self.pdb1, ('', ''))
        self.__get_pdb_bounds()

    def write_alignment(self, file, format='pir'):
        """Write the current Modeller alignment to file in the given format."""
        self.modaln.write(file=file, alignment_format=format)

    def calculate_superpositions(self):
        """Superpose the aligned structures and print a one-line summary.

        Prints a '# ...' header line followed by the results line, and
        returns the results line.  The CE statistics are taken from
        self.comment by whitespace-separated position (fields 3, 6, 9, 12
        and 19, reported as ce_alnlen, ce_rmsd, ce_zscore, ce_numgaps and
        ce_seqid), so this relies on CE's fixed summary-line layout.
        """
        m1 = model(self.modenv, file=self.pdb0,
                   model_segment=self.modaln[0].range)
        m2 = model(self.modenv, file=self.pdb1,
                   model_segment=self.modaln[1].range)
        sel = selection(m1).only_atom_types('CA')

        columns = ['code1', 'len1', 'alnlen1', 'code2', 'len2',
                   'alnlen2', 'ce_alnlen', 'ce_rmsd', 'ce_zscore',
                   'ce_numgaps', 'ce_seqid', 'mod_grmsd', 'mod_geqvp',
                   '(mod_cutoff, mod_cutoff_rmsd, mod_cutoff_eqvp)*']

        results = "> %-5s %5d %5d %-5s %5d %5d " % (
            self.code0, len(self.modchain0.residues), len(m1.residues),
            self.code1, len(self.modchain1.residues), len(m2.residues))

        fields = self.comment.split(' ')
        # The gap count is followed by a parenthesised percentage; drop it
        gap_pct_re = re.compile(r'\([\.\d%]+\)')
        results += "%5s %6s %5s %5s %8s" % (
            fields[3].strip().rjust(5),
            fields[6].rstrip('A').rjust(5),
            fields[9].strip().rjust(5),
            gap_pct_re.sub('', fields[12]).strip().rjust(5),
            fields[19].rstrip('%').strip().rjust(8))

        # Global Modeller superposition, reported with a 3.5 A cutoff
        main_cutoff = 3.5
        r = sel.superpose(m2, self.modaln, rms_cutoff=main_cutoff)
        results += "%8.4f %6d %5.2f %8.4f %6d " % (
            r.rms, r.num_equiv_pos, main_cutoff,
            r.cutoff_rms, r.num_equiv_cutoff_pos)

        for cutoff in (1.0, 2.0, 3.0, 4.0, 5.0, 8.0):
            r = sel.superpose(m2, self.modaln, rms_cutoff=cutoff)
            results += "%5.2f %8.4f %6d " % (cutoff, r.cutoff_rms,
                                              r.num_equiv_cutoff_pos)

        print('# ' + ' '.join(columns))
        print(results)
        return results
