#!/usr/bin/python
# This file is part of ModPipe, Copyright 1997-2010 Andrej Sali
#
# ModPipe is free software: you can redistribute it and/or modify
# it under the terms of version 2 of the GNU General Public License
# as published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with ModPipe.  If not, see <http://www.gnu.org/licenses/>.

from optparse import OptionParser
import modpipe.version
import modpipe.config
import modpipe.filesystem
import modpipe.serialize
import modpipe.sequtils
import sys, os, re

def model_len(model):
    """Return the inclusive length of the model's region.

    ``model.region`` holds the first and last position of the model; both
    are converted with ``int`` and the result is ``last - first + 1``.
    """
    last = int(model.region[1])
    first = int(model.region[0])
    return last - first + 1

def get_all_models(models):
    """Selection method that keeps every model.

    The sequence passed in is returned as-is (the same object, not a copy).
    """
    return models

def longest_dope(models):
    """Longest model that scores well by normalized DOPE.

    Keeps the models whose ``score.normalized_dope`` is negative and returns
    a list holding the longest of them (as measured by ``model_len``), or an
    empty list if no model qualifies.  A model whose score cannot be read is
    reported on stdout and skipped.
    """
    candidates = []
    for m in models:
        try:
            if m.score.normalized_dope < 0:
                candidates.append(m)
        except Exception:
            # Ordinary errors mean "score missing"; KeyboardInterrupt and
            # SystemExit are deliberately allowed to propagate.
            print("LONGEST_DOPE missing for seq_id %s model_id %s " % (m.sequence.id, m.id))

    if not candidates:
        return []
    # The sort is stable, so among equally long models the first one wins.
    candidates.sort(key=model_len, reverse=True)
    return candidates[0:1]

def best_tsvmod(models):
    """Model that scores well by TSVMod.

    Keeps the models whose ``score.tsvmod.predicted_no35`` is at least 0.4
    and returns a list holding the longest of them (as measured by
    ``model_len``), or an empty list if no model qualifies.  A model whose
    score cannot be read is reported on stdout and skipped.
    """
    candidates = []
    for m in models:
        try:
            if m.score.tsvmod.predicted_no35 >= 0.4:
                candidates.append(m)
        except Exception:
            # Ordinary errors mean "score missing"; KeyboardInterrupt and
            # SystemExit are deliberately allowed to propagate.
            print("TSVMOD missing for seq_id %s model_id %s " % (m.sequence.id, m.id))

    if not candidates:
        return []
    # The sort is stable, so among equally long models the first one wins.
    candidates.sort(key=model_len, reverse=True)
    return candidates[0:1]

def longest_ga341(models):
    """Longest model that scores well by GA341.

    Keeps the models whose ``score.ga341.total`` is at least 0.7 and returns
    a list holding the longest of them (as measured by ``model_len``), or an
    empty list if no model qualifies.  A model whose score cannot be read is
    reported on stdout and skipped.
    """
    candidates = []
    for m in models:
        try:
            if m.score.ga341.total >= 0.7:
                candidates.append(m)
        except Exception:
            # Ordinary errors mean "score missing"; KeyboardInterrupt and
            # SystemExit are deliberately allowed to propagate.
            print("LONGEST_GA341 missing for seq_id %s model_id %s "% (m.sequence.id, m.id))

    if not candidates:
        return []
    # The sort is stable, so among equally long models the first one wins.
    candidates.sort(key=model_len, reverse=True)
    return candidates[0:1]

def best_seqid(models):
    """Model with the highest sequence identity.

    Models are ranked by ``highest_sequence_identity`` and, for equal
    identity, by length; a list holding the top-ranked model is returned
    (empty if there are no models).  The input list is left untouched.
    """
    def rank(candidate):
        return (candidate.highest_sequence_identity, model_len(candidate))

    ranked = sorted(models, key=rank, reverse=True)
    return ranked[:1]

def best_dope(models):
    """Model with the best normalized DOPE score.

    Keeps the models whose ``score.normalized_dope`` is negative and returns
    a list holding the one with the lowest score (the longer model wins a
    tie), or an empty list if no model qualifies.  A model whose score
    cannot be read is reported on stdout and skipped.
    """
    candidates = []
    for m in models:
        try:
            if m.score.normalized_dope < 0:
                candidates.append(m)
        except Exception:
            # Ordinary errors mean "score missing"; KeyboardInterrupt and
            # SystemExit are deliberately allowed to propagate.
            print("DOPE missing for seq_id %s model_id %s "% (m.sequence.id, m.id))

    # Ascending sort: lowest score first, negated length breaks ties in
    # favour of the longer model.
    candidates.sort(key=lambda a: (a.score.normalized_dope, -model_len(a)))
    return candidates[0:1]

def best_ga341(models):
    """Model with the best GA341 score.

    Keeps the models whose ``score.ga341.total`` is at least 0.7 and returns
    a list holding the one with the highest score (the longer model wins a
    tie), or an empty list if no model qualifies.  A model whose score
    cannot be read is reported on stdout and skipped.
    """
    candidates = []
    for m in models:
        try:
            if m.score.ga341.total >= 0.7:
                candidates.append(m)
        except Exception:
            # Ordinary errors mean "score missing"; KeyboardInterrupt and
            # SystemExit are deliberately allowed to propagate.
            print("GA341 missing for seq_id %s model_id %s "% (m.sequence.id, m.id))

    candidates.sort(key=lambda a: (a.score.ga341.total, model_len(a)),
                    reverse=True)
    return candidates[0:1]

def best_mpqs(models):
    """Model with the best MPQS score.

    Keeps the models whose ``score.quality`` is at least 1.0 and returns a
    list holding the one with the highest score (the longer model wins a
    tie), or an empty list if no model qualifies.  A model whose score
    cannot be read is reported on stdout and skipped.
    """
    candidates = []
    for m in models:
        try:
            if m.score.quality >= 1.0:
                candidates.append(m)
        except Exception:
            # Ordinary errors mean "score missing"; KeyboardInterrupt and
            # SystemExit are deliberately allowed to propagate.
            print("MPQS missing for seq_id %s model_id %s "% (m.sequence.id, m.id))

    candidates.sort(key=lambda a: (a.score.quality, model_len(a)),
                    reverse=True)
    return candidates[0:1]

def best_input_template(models,input_template_code,input_template_chain):
    """Model with the highest sequence identity from input template.

    A model qualifies when one of its templates has the requested code and
    chain; both are compared via ``str.capitalize()``, i.e. ignoring case.
    Qualifying models are ranked by ``highest_sequence_identity`` and then
    by length, and a list holding the top-ranked one is returned (empty if
    nothing matches).

    Side effect: the ``code`` and ``chain`` attributes of every template
    visited are replaced by their ``str()`` form.
    """
    wanted_code = str(input_template_code).capitalize()
    wanted_chain = str(input_template_chain).capitalize()

    matching = []
    for candidate in models:
        for tmpl in candidate.templates:
            # Normalise in place so the comparisons below work on strings.
            tmpl.code = str(tmpl.code)
            tmpl.chain = str(tmpl.chain)
            if tmpl.code.capitalize() != wanted_code:
                continue
            if tmpl.chain.capitalize() == wanted_chain:
                matching.append(candidate)

    ranked = sorted(matching,
                    key=lambda a: (a.highest_sequence_identity, model_len(a)),
                    reverse=True)
    return ranked[:1]

# Dispatch table: method name accepted by --final_models_by -> the function
# that picks models for it.  Every function takes the list of models; only
# INPUT_TEMPLATE additionally takes the template code and chain.
allowed_methods = {
    'LONGEST_DOPE': longest_dope,
    'INPUT_TEMPLATE': best_input_template,
    'LONGEST_GA341': longest_ga341,
    'SEQID': best_seqid,
    'DOPE': best_dope,
    'GA341': best_ga341,
    'MPQS': best_mpqs,
    'TSVMOD': best_tsvmod,
    'ALL': get_all_models,
}

def pick_final_models(models, methods,input_template_code,input_template_chain):
    """Apply each requested selection method and collect what it picks.

    ``methods`` is a sequence of keys of ``allowed_methods``.  The results
    are concatenated in method order; duplicates are not removed here.
    Only the INPUT_TEMPLATE method receives the template code and chain.
    """
    chosen = []
    for name in methods:
        selector = allowed_methods[name]
        if name != "INPUT_TEMPLATE":
            picks = selector(models)
        else:
            picks = selector(models, input_template_code,
                             input_template_chain)
        chosen.extend(picks)
    return chosen

def compare_regions(rep, model):
    """Return True iff the two models overlap sufficiently in sequence"""
    rep_span = [int(pos) for pos in rep.region]
    model_span = [int(pos) for pos in model.region]

    # get_overlap returns five values; only the two describing the
    # non-overlapping part are needed here.
    (_overlap, _pct_overlap, nonoverlap, pct_nonoverlap, _overlap_region) = \
        modpipe.sequtils.get_overlap(rep_span, model_span)

    # Regions cluster only if BOTH limits hold: the non-overlapping part is
    # at most 30 (presumably residues) and its percentage is at most 30.
    if not (nonoverlap <= 30):
        return False
    return pct_nonoverlap <= 30

def cluster_models_by_region(models):
    """Group models that cover roughly the same region of the sequence.

    The models are ordered longest first.  The longest unassigned model
    becomes the representative of a new cluster, and every remaining model
    that ``compare_regions`` accepts against that representative joins it.
    This repeats until no model is left.  Returns a list of clusters, each a
    list of models whose first element is the representative.  The input
    list is not modified.
    """
    pending = sorted(models, key=model_len, reverse=True)
    clusters = []

    while pending:
        representative = pending[0]
        members = [representative]
        leftover = []

        # Models are only ever compared with the representative, never with
        # other members of the cluster.
        for candidate in pending[1:]:
            if compare_regions(representative, candidate):
                members.append(candidate)
            else:
                leftover.append(candidate)

        clusters.append(members)
        pending = leftover

    return clusters

def get_options():
    """Parse and validate the command line.

    Returns the optparse values object with these post-processed fields:

    * ``final_modby``: list of method names (comma-separated values are
      split); defaults to ``['MPQS']``.
    * ``input_template_code`` / ``input_template_chain``: always set.  When
      INPUT_TEMPLATE is requested they hold the 4-character code and the
      chain (``"A"`` if none was given) taken from ``--template``; otherwise
      the chain is ``""`` and a missing code becomes ``""``.
    * ``gatherfast``: a real boolean, True only for the literal string
      ``"True"``.

    Any invalid combination is reported through ``parser.error``, which
    prints the message and exits the program.
    """
    parser = OptionParser(version=modpipe.version.message())

    parser.set_usage("""
 This script selects models based on one or more quality scores.

 Run `%prog -h` for help information
""")

    parser.add_option("-c", "--conf_file", dest="conffile",
                      type="string",
                      help="""ModPipe configuration file. Cannot proceed
                           without this option.""", default=None)
    parser.add_option("--unq_file", dest="unqfile",
                      type="string",
                      help="""The file containing the sequence MD5 ids;
                           usually produced by AddSeqMP.py.
                           You either have to use this option, or --seq_id
                           for a single sequence.""",
                      default=None)
    parser.add_option("--seq_id", dest="input_seqid", type="string",
                      help="""For gathering of a single sequence only""",
                      default=None)
    parser.add_option("--gather_fast",dest="gatherfast",type="string",
                      help="""Instead of parsing the yaml output files,
                           just concatinates the mod/fin/sels/hits files,
                           assumes that the ModPipe option --final_models_by
                           has been used previously, and the .fin files are
                           present in the local sequence directories.""",
                           default=False)
    parser.add_option("--select_by_region", dest="selbyregion",
                      type="string",
                      help="""If specified, the models are clustered
                           to identify those that span different
                           regions of the target sequence. The
                           criteria for selection, specified below,
                           is then applied to each region instead of
                           the whole sequence.""", default='ON')
    parser.add_option("--output_modfile", dest="outmodfile", type="string",
                      help="""Filename for storing the models data.
                           Default: models.dat""", default='models.dat')
    parser.add_option("--output_hitfile", dest="outhitfile", type="string",
                      help="""Filename for storing the hits data.
                           Default: hits.dat""", default='hits.dat')
    parser.add_option("--output_selfile", dest="outselfile", type="string",
                      help="""Filename for storing the selected hits data.
                           Default: sels.dat""", default='sels.dat')
    parser.add_option("--output_finfile", dest="outfinfile", type="string",
                      help="""Filename for storing the final models.
                           Default: models.fin""", default='models.fin')
    parser.add_option("--local_only", dest="localflag", type="string",
                      help="""Writes only the local .fin file for
                           the input sequences, no complete
                           .mod/.hits/.sels files created""", default= False)
    parser.add_option("--final_models_by", dest="final_modby", type="string",
                      help="""Final models for each sequence can be
                           selected by one of the following methods: """ \
                           + ", ".join(allowed_methods.keys()) \
                           + """.
                           Multiple options can be specified by
                           multiple copies of the command line switch.
                           For example, "--final_models_by LONGEST_DOPE
                           --final_models_by SEQID" will return two models.
                           Default: MPQS""", action='append')
    parser.add_option("--template", dest="input_template_code",type="string",
                      help="""Needed for option INPUT_TEMPLATE
                           for inclusion of the \"best\" models
                           based on this given template. If no chain is given,
                           \"A\" is assumed.""", default=None)
    opts, args = parser.parse_args()
    if opts.final_modby is None:
        opts.final_modby = ['MPQS']
    # Allow for comma-separated options like DOPE,GA341
    requested = opts.final_modby
    opts.final_modby = []
    for meth in requested:
        if ',' in meth:
            opts.final_modby.extend([x.strip() for x in meth.split(',')])
        else:
            opts.final_modby.append(meth)
    for meth in opts.final_modby:
        if meth not in allowed_methods:
            parser.error("--final_models_by (%s) must be one of %s" \
                         % (meth, ", ".join(allowed_methods.keys())))

    # Split --template into code and chain exactly once, however often
    # INPUT_TEMPLATE was requested; re-slicing an already shortened code
    # would wipe the chain.
    template = opts.input_template_code
    if "INPUT_TEMPLATE" in opts.final_modby:
        if template is None:
            parser.error("--template must be giving for --final_models_by INPUT_TEMPLATE")
        opts.input_template_code = template[0:4]
        if (len(opts.input_template_code) != 4):
            parser.error("--template "+opts.input_template_code+" isn't a PDB code\n")
        # Slicing yields "" (never None) when no chain was typed, so test
        # for emptiness to apply the documented default.
        opts.input_template_chain = template[4:5] or "A"
    else:
        if template is None:
            opts.input_template_code = ""
        # Always define the chain: main() reads it even when unused.
        opts.input_template_chain = ""

    if opts.selbyregion not in ('ON', 'OFF'):
        parser.error("--select_by_region (%s) must be either ON or OFF" \
                     % opts.selbyregion)

    # Check for configuration file
    if not opts.conffile:
        parser.error("Cannot proceed without configuration file")

    # Only the literal string "True" switches fast gathering on
    opts.gatherfast = (opts.gatherfast == "True")

    # Check for unq file
    if opts.gatherfast and not opts.unqfile:
        parser.error("Cannot gather fast without unq file")

    if not opts.unqfile and not opts.input_seqid:
        parser.error("Cannot proceed without unq file or sequence id")

    return opts


def main():
    opts = get_options()
    # Read in the configuration file and set up filesystem
    config = modpipe.config.read_file(file(opts.conffile, 'r'))
    fs = modpipe.filesystem.FileSystem(config)
    # Open streams
    seqids=[]
    if opts.unqfile and not opts.input_seqid:
        unqfh = file(opts.unqfile, 'r')
        # Processing the UNQ file
        for line in unqfh.readlines():
            seqids.append(line.split()[0])
    else:
        seqids.append(opts.input_seqid)

    if opts.localflag is False:
        fh={}
        fh['mod'] = file(opts.outmodfile, 'w')
        fh['hit'] = file(opts.outhitfile, 'w')
        fh['sel'] = file(opts.outselfile, 'w')
        fh['fin'] = file(opts.outfinfile, 'w')
    append = False

    # Start processing the UNQ file
    tot={}
    tot['mod']=tot['hit']=tot['fin']=tot['sel']=seqcnt=0
    version={}
    version['mod']=version['hit']=version['fin']=version['sel']=False
    for seqid in seqids:
        seqcnt += 1
        seqdir = fs.get_sequence_dir(seqid)
        yfile={}
        yfile['mod'] = os.path.join(seqdir, seqid + '.mod')
        yfile['fin'] = os.path.join(seqdir, seqid + '.fin')
        yfile['hit'] = os.path.join(seqdir, seqid + '.hit')
        yfile['sel'] = fs.get_selected_file(seqid)
        gatherlog  = os.path.join(seqdir, seqid + '.gather.log')
        if opts.gatherfast is True:
            # collect the log file:
            gatherlog  = os.path.join(seqdir, seqid + '.gather.log')
            try:
                gatherfh = file(gatherlog, 'r')
                for line in gatherfh:
                    print line,
            except:
                print "GatherModMP.py__M> %s: Gather log file not found for seqid " % seqid
            for filetype in ('mod', 'hit', 'sel', 'fin'):
                try:
                    localfh=file(yfile[filetype], 'r')
                except:
                    print "GatherModMP.py__M> %s: Local file not found for seqid " % seqid
                    break
                try:
                    linecnt=0
                    for line in localfh.readlines():
                        if re.match(r'^\- \!\<ModPipeVersion\>',line):
                            if version[filetype] is False:
                                print >>fh[filetype] , line,
                            version[filetype]=True
                        else:
                            print >>fh[filetype] , line,
                            if re.match(r'^\- \!\<',line):
                                tot[filetype] += 1
                except IOError:
                    print "GatherModMP.py__M> %s: Local file not found for seqid " % seqid

        else:
            localfinfh = file(yfile['fin'], 'w')
            gatherfh = file(gatherlog, 'w')

        # Get the contents of the models file
            try:
                models = list(modpipe.serialize.read_models_file(file(yfile['mod'])))
            except :
                models = []
                print "GatherModMP.py__W>  error in mod file (or missing or empty mod file) for %s " % seqid
                print >>gatherfh, "GatherModMP.py__W>  error in mod file (or missing or empty mod file) for %s " % seqid

            # Get the contents of the hits file
            try:
                hits = list(modpipe.serialize.read_hits_file(file(yfile['hit'])))
            except IOError:
                hits = []

            # Get the contents of the sel file
            try:
                selhits = list(modpipe.serialize.read_hits_file(file(yfile['sel'])))
            except IOError:
                selhits = []

            # Cluster models for overlapping regions
            if opts.selbyregion == 'ON':
                regions = cluster_models_by_region(models)
            else:
                regions = [models]

            # Select models by criteria, avoiding duplicates
            final = []
            for region in regions:
                for model in pick_final_models(region,opts.final_modby,
                              opts.input_template_code,opts.input_template_chain):
                    if model not in final:
                        final.append(model)

            # Report the number of models found
            print "GatherModMP.py__M>  %s: " % seqid + \
                  "REGIONS: %5d HITS: %5d SEL: %5d MODELS: %5d FINAL: %5d" \
                  % (len(regions), len(hits), len(selhits), len(models), len(final))
            print >>gatherfh, "GatherModMP.py__M>  %s: " % seqid + \
                  "REGIONS: %5d HITS: %5d SEL: %5d MODELS: %5d FINAL: %5d" \
                  % (len(regions), len(hits), len(selhits), len(models), len(final))

            tot['mod'] += len(models)
            tot['hit'] += len(hits)
            tot['sel'] += len(selhits)
            tot['fin'] += len(final)

            # Write out the contents to outfile
            if opts.localflag is False:
                modpipe.serialize.write_models_file(models, fh['mod'], append=append)
                modpipe.serialize.write_models_file(final, fh['fin'], append=append)
                modpipe.serialize.write_hits_file(hits, fh['hit'], append=append)
                modpipe.serialize.write_hits_file(selhits, fh['sel'], append=append)
            modpipe.serialize.write_models_file(final, localfinfh, append= False)
            append = True

    # Report final statement
    print "GatherModMP.py__M> Gathered data for: " \
          + "%8d sequences %8d hits, %8d selected hits, " \
            % (seqcnt, tot['hit'], tot['sel']) \
          + "%8d models and %8d were selected" % (tot['mod'], tot['fin'])


# Run the gathering only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == '__main__':
    main()
