# This file is part of ModPipe, Copyright 1997-2010 Andrej Sali
#
# ModPipe is free software: you can redistribute it and/or modify
# it under the terms of version 2 of the GNU General Public License
# as published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with ModPipe.  If not, see <http://www.gnu.org/licenses/>.

"""Various methods for fold assignment"""

import modeller
import modpipe
import modpipe.profile
import modpipe.pdbutils
import modpipe.sequtils
import modpipe.sequence
import modpipe.resutils as resutils
import time
import re
import sys
import os
import copy
import tempfile
from StringIO import StringIO

class _Base(object):
    def search(self):
        """Return a list of :class:`Hit` objects."""
        pass

class Hit(object):
    """A single target-template hit detected by a fold assignment method.

    The class defines no constructor; :meth:`get_modeling_alignment` reads
    the ``env``, ``opts``, ``target`` and ``template`` attributes, so these
    must have been assigned on the instance beforehand.
    """

    def get_modeling_alignment(self):
        """Return a Modeller alignment object suitable for modeling.

        The target and template records stored on this hit are updated in
        place (trimmed sequences, residue ranges and alignment header
        fields) before being handed to :func:`_read_alignment`.
        """
        opts = self.opts
        tgt = self.target
        tmpl = self.template

        # Trim the alignment to match aligned positions
        trimmed = _trim_alignment(tgt.primary, tmpl.primary)
        tgt.primary, tmpl.primary = trimmed

        if not (opts.natpdb and opts.assumepdb):
            # No native structure to consult: use the positions already
            # stored on the template, with an empty chain identifier
            tgt.range = [[tmpl.target_start, ''],
                         [tmpl.target_stop, '']]
        else:
            # Get PDB numbering for target using native structure
            native_chain = modpipe.pdbutils.fetch_PDB_chain(self.env,
                                                            opts.natpdb,
                                                            opts.natchn)
            native_seq = modpipe.pdbutils.get_chain_seq(native_chain)
            first, last = modpipe.sequtils.find_seq_in_seq(
                                          list(tgt.primary), native_seq)
            if opts.nattyp != 'sequence':
                tgt.atom_file = opts.natpdb
                tgt.prottyp = 'structureX'
            parse_num = resutils.parse_residue_number
            tmpl.target_start = parse_num(native_chain.residues[first].num)
            tmpl.target_stop = parse_num(native_chain.residues[last].num)
            tgt.range = [[tmpl.target_start, opts.natchn],
                         [tmpl.target_stop, opts.natchn]]

        if opts.assumepdb:
            # Assume a PDB template: the first four characters of the code
            # become the atom file and the remainder the chain identifier
            code = tmpl.code
            tmpl.atom_file = code[:4]
            chain_id = code[4:]
            tmpl.range = [['.', chain_id], ['.', chain_id]]

        # Populate other PIR alignment fields
        tag = opts.customtag
        tgt.source = tag
        tmpl.source = tag
        tgt.name = tgt.code
        # NOTE(review): this unconditionally replaces target.atom_file,
        # including the opts.natpdb value assigned in the native-structure
        # branch above -- confirm that this is intended.
        tgt.atom_file = tgt.code
        tmpl.name = tmpl.code

        return _read_alignment(self.env, [tgt, tmpl])


def _read_alignment(env, seqs):
    """Round-trip *seqs* through a temporary file to get a Modeller alignment.

    Each record in *seqs* is written with :class:`modpipe.sequence.PIRFile`
    to a file in a freshly created temporary directory, which is then read
    back with Modeller. The temporary file and directory are removed again
    whether or not writing or reading succeeds.

    After a successful read, the residue number in each end of every
    record's ``range`` is overwritten with the value Modeller determined.

    :param env: Modeller environment passed to ``modeller.alignment``.
    :param seqs: sequence records to write; modified in place.
    :return: the Modeller alignment object.
    :raise modeller.SequenceMismatchError: re-raised after printing the
           offending alignment to standard output.
    """
    parser = modpipe.sequence.PIRFile()
    tmpdir = tempfile.mkdtemp()
    tmpfile = os.path.join(tmpdir, 'temp.ali')
    try:
        fh = open(tmpfile, 'w')
        try:
            for seq in seqs:
                parser.write(fh, seq)
        finally:
            # Close even if a record cannot be written, so that the
            # cleanup below can always remove the file
            fh.close()

        # Read alignment with Modeller.
        # This will automatically determine the residue starting and ending
        # positions if assume_pdb is on.
        try:
            a = modeller.alignment(env, file=tmpfile, allow_alternates=True)
        except modeller.SequenceMismatchError:
            print("Failed alignment is:")
            for seq in seqs:
                parser.write(sys.stdout, seq)
            raise
    finally:
        # The file is absent if open() itself failed; do not let the
        # cleanup mask the original exception in that case
        if os.path.exists(tmpfile):
            os.unlink(tmpfile)
        os.rmdir(tmpdir)

    # Update sequence ranges to match those determined by Modeller
    for (seq, alnseq) in zip(seqs, a):
        aln_range = alnseq.range
        seq.range[0][0] = resutils.parse_residue_number(
                                          aln_range[0].split(':')[0])
        seq.range[1][0] = resutils.parse_residue_number(
                                          aln_range[1].split(':')[0])
    return a

class SequenceSequence(_Base):
    """Sequence-sequence search for templates against the given database."""

    def __init__(self, env, database, opts):
        """Store the Modeller environment, database file and options."""
        self.env = env
        self.database = database
        self.opts = opts

    def search(self, target_code, sequence):
        """Look up the single sequence *target_code*, taken from the
           alignment file *sequence*, in the database and return the list
           of all hits found."""
        opts = self.opts
        env = self.env
        profile = self._read_profile(env, sequence, opts.seqfmt)
        built, chi2, kstat = _build_profile(env, profile, target_code,
                                            self.database, opts,
                                            "HitsSeqSeq.py")
        return _parse_profile(built, target_code, chi2, kstat, env, opts)

    def _read_profile(self, env, sequence, seqfmt):
        """Read the alignment file *sequence* (in format *seqfmt*) and
           return it converted to a profile."""
        return modeller.alignment(env, file=sequence,
                                  alignment_format=seqfmt).to_profile()



class ProfileSequence(_Base):
    """Profile-sequence search for templates against the given database."""

    def __init__(self, env, database, opts):
        """Store the Modeller environment, database file and options."""
        self.env = env
        self.database = database
        self.opts = opts

    def search(self, target_code, prffile):
        """Search the database with the profile of the sequence
           *target_code*, read from *prffile*, and return the list of all
           hits found."""
        opts = self.opts
        env = self.env
        profile = self._read_profile(env, prffile, opts.proffmt)
        built, chi2, kstat = _build_profile(env, profile, target_code,
                                            self.database, opts,
                                            "HitsPrfSeq.py")
        return _parse_profile(built, target_code, chi2, kstat, env, opts)

    def _read_profile(self, env, prffile, proffmt):
        """Return the contents of *prffile* as a Modeller profile.

           A *proffmt* of 'PROFILE' means the file is read directly as a
           text-format profile; any other value is treated as an alignment
           format, and the alignment read is converted to a profile."""
        if proffmt != 'PROFILE':
            return modeller.alignment(env, file=prffile,
                                      alignment_format=proffmt).to_profile()
        return modeller.profile(env, file=prffile, profile_format='TEXT')


def _do_build_profile(env, prf, database, opts):
    """Build the profile *prf* against *database* and write it to disk.

    The sequence database is read in full (``chains_list='ALL'``), a single
    profile-building iteration is run with the scoring parameters taken
    from *opts*, and the result is written in text format to
    ``opts.outfile``.

    :return: the name of the file written (``opts.outfile``).
    """
    # Load the sequence database to scan
    seq_db = modeller.sequence_db(env)
    seq_db.read(seq_database_file=database,
                seq_database_format=opts.dbfmt,
                chains_list='ALL')

    # Run one iteration of profile building
    gap_penalties = (opts.gap_open, opts.gap_extension)
    use_statistics = (opts.score_statistics == 'ON')
    prf.build(seq_db,
              n_prof_iterations=1,
              matrix_offset=opts.matrix_offset,
              rr_file=opts.matrix,
              gap_penalties_1d=gap_penalties,
              max_aln_evalue=opts.evalue,
              output_score_file=opts.score_file,
              check_profile=False,
              pssm_weights_type=opts.pssm_weighting,
              gaps_in_target=True,
              score_statistics=use_statistics)

    # Save the built profile and report where it went
    result_file = opts.outfile
    prf.write(file=result_file, profile_format='TEXT')
    return result_file


def _check_code_in_profile(env, prf, target_code):
    for seq in prf:
        if seq.code == target_code:
            return
    raise modpipe.Error("Cannot find target code %s in input sequence file" \
                        % target_code)

def _parse_profile_log(fh):
    n_iter = 0
    chi2 = []
    kstat = []
    div = False
    statre = re.compile('Iteration, Chi2, nbins, KS\-Stat.*\]\s+=\s+\d+,' + \
                        '\s+([\d.e+-]+),\s+\d+,\s+([\d.e+-]+)')
    fh.seek(0)
    for line in fh:
        s = statre.search(line)
        if s:
            n_iter += 1
            chi2.append(float(s.group(1)))
            kstat.append(float(s.group(2)))
        div = div and 'Profile appears to be diverging' in line
    return n_iter, chi2, kstat, div


def _build_profile(env, prf, target_code, database, opts, prefix):
    """Build the profile while capturing Modeller's log, and report on it.

    After checking that *target_code* is present in *prf*, standard output
    is redirected to a buffer and Modeller logging is switched to verbose
    for the duration of the build; both are restored afterwards even if
    the build fails. The captured log is parsed for fit statistics, which
    are printed with each line prefixed by *prefix*. The captured log
    itself is also printed when ``opts.verbose`` is greater than zero.

    :return: tuple of the value returned by :func:`_do_build_profile`, the
             list of Chi2 values and the list of KS-Stat values.
    """
    _check_code_in_profile(env, prf, target_code)
    started = time.time()
    real_stdout = sys.stdout
    captured = StringIO()
    sys.stdout = captured
    # Need to get Modeller log output to parse
    log = modeller.log
    saved_levels = (log.output, log.notes, log.warnings, log.errors,
                    log.memory)
    log.verbose()
    try:
        prf = _do_build_profile(env, prf, database, opts)
    finally:
        sys.stdout = real_stdout
        (log.output, log.notes, log.warnings, log.errors,
         log.memory) = saved_levels
    elapsed = time.time() - started

    n_iter, chi2, kstat, div = _parse_profile_log(captured)
    if opts.verbose > 0:
        print(captured.getvalue())
    print(prefix + "__M> MODELLER Runtime: %.2f" % elapsed)
    print(prefix + "__M> No. of iterations: %d" % n_iter)
    print(prefix + "__M> Chi2 fit of distributions: %.4f" % min(chi2))
    print(prefix + "__M> KS-Stat for the distributions: %.4f" % min(kstat))
    if div:
        print(prefix + "__W> Fit b/w observed & expected distributions is bad")
    return prf, chi2, kstat

def _parse_profile(prf, target_code, chi2, kstat, env, opts):
    """Turn the profile file *prf* into a list of :class:`Hit` objects.

    The file is read with :class:`modpipe.profile.ProfileParser`. The
    record whose code equals *target_code* becomes the target; every other
    record whose ``prottyp`` starts with 'structure' becomes a template.
    One hit is created per template, each carrying its own shallow copy of
    the target together with the smallest Chi2 and KS-Stat values.

    :param prf: name of the profile file to read.
    :param chi2: list of Chi2 values; its minimum is stored on each hit.
    :param kstat: list of KS-Stat values; its minimum is stored on each hit.
    :return: list of hits (empty if the file contains no templates).
    """
    hits = []
    target = None
    templates = []
    parser = modpipe.profile.ProfileParser()
    fh = open(prf)
    try:
        for seq in parser.read(fh):
            if seq.code == target_code:
                target = seq
            elif seq.prottyp.startswith('structure'):
                templates.append(seq)
    finally:
        # Always release the file handle, even if parsing fails
        fh.close()
    for template in templates:
        h = Hit()
        h.env = env
        h.opts = opts
        # Evaluated only when there is at least one template, so an empty
        # statistics list is not an error for a profile without templates
        h.chi2 = min(chi2)
        h.kstat = min(kstat)
        # Shallow copy: each hit may rebind attributes of its own target
        # without affecting the other hits
        h.target = copy.copy(target)
        h.template = template
        hits.append(h)
    return hits

def _trim_alignment(target, template):
    if len(target) != len(template):
        raise modpipe.Error("Target/Template don't match in length: %s, %s" \
                            % (target, template))
    # Get the beginning and end position of template
    template_beg = 0
    for ch in template:
        if ch == '-':
            template_beg += 1
        else:
            break
    template_end = len(template)
    for i in range(len(template) - 1, 0, -1):
        if template[i] == '-':
            template_end -= 1
        else:
            break
    target = target[template_beg:template_end]
    template = template[template_beg:template_end]

    # Remove aligned gaps
    target_ls = []
    template_ls = []
    for (a,b) in zip(target, template):
        if not (a == '-' and b == 'a'):
            target_ls.append(a)
            template_ls.append(b)
    target = "".join(target_ls)
    template = "".join(template_ls)
    return target, template
