#!/usr/bin/python
# This file is part of ModPipe, Copyright 1997-2020 Andrej Sali
#
# ModPipe is free software: you can redistribute it and/or modify
# it under the terms of version 2 of the GNU General Public License
# as published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with ModPipe.  If not, see <http://www.gnu.org/licenses/>.

from __future__ import print_function, division
from optparse import OptionParser
import modpipe.version
import modpipe.config
import modpipe.filesystem
import modpipe.serialize
import modpipe.sequtils
import modpipe.pdbutils
import sys, os

# Override the program name: optparse expands %prog (used in the usage text
# set in main()) from sys.argv[0], so help and error messages show
# 'modpipe benchmark' rather than the path of this script.
sys.argv[0] = 'modpipe benchmark'

def get_native_sequence_range(native_mdl, search_mdl):
    """Locate the model's sequence inside the native structure.

       The residue codes of both models are handed to
       modpipe.sequtils.find_seq_in_seq(); the two indices it returns are
       used to look up residues of `native_mdl`, and the numbers (`.num`)
       of those two residues are returned as a (start, end) pair."""
    native_residues = native_mdl.residues
    native_codes = [res.code for res in native_residues]
    model_codes = [res.code for res in search_mdl.residues]
    first, last = modpipe.sequtils.find_seq_in_seq(model_codes, native_codes)
    first_num = native_residues[first].num
    last_num = native_residues[last].num
    return first_num, last_num

def superpose(modfile, seqid, native, native_chain, native_align_code, pdbrep):
    """Superpose a model on its native structure, and return information
       about the native and the RMS/equivalent positions.

       Arguments:
         modfile            -- path of the model PDB file
         seqid              -- align code to use for the model
         native             -- PDB code/file of the native structure
         native_chain       -- chain of the native structure to compare with
         native_align_code  -- align code to use for the native
         pdbrep             -- search path for PDB files (used as Modeller's
                               atom_files_directory)

       Returns a pair (native_range, output):
         native_range -- (number of residues, first residue number,
                         last residue number) of the part of the native
                         that matches the model
         output       -- list of (rms_cutoff, superposition result) pairs,
                         one per cutoff tried, in the order
                         3.5, 1.0, 2.0, 3.0, 4.0, 5.0"""
    import modeller
    modeller.log.minimal()
    env = modeller.environ()
    env.io.atom_files_directory = pdbrep
    aln = modeller.alignment(env)
    native_mdl = modeller.model(env, file=native,
                                model_segment=('FIRST:%s' % native_chain,
                                               'LAST:%s' % native_chain))
    try:
        mdl = modeller.model(env, file=modfile,
                             model_segment=('FIRST', 'LAST'))
    except Exception:
        # Reading the whole file failed; fall back to reading chain A only.
        # Only ordinary errors are caught, so that Ctrl-C (KeyboardInterrupt)
        # and SystemExit still terminate the script.
        mdl = modeller.model(env, file=modfile,
                             model_segment=('FIRST:A', 'LAST:A'))

    # Since the model may not cover the entire native chain, reread only that
    # section of the native sequence that matches the model
    start, end = get_native_sequence_range(native_mdl, mdl)
    native_mdl = modeller.model(env, file=native,
                                model_segment=('%s:%s' % (start, native_chain),
                                               '%s:%s' % (end, native_chain)))
    aln.append_model(native_mdl, atom_files=native,
                     align_codes=native_align_code)
    aln.append_model(mdl, atom_files=modfile, align_codes=seqid)

    # Superpose on CA atoms only, once per cutoff. The 3.5 cutoff must stay
    # first: main() reads output[0] as the 3.5 result and treats the
    # remaining entries as the cutoff series.
    atmsel = modeller.selection(native_mdl).only_atom_types('CA')
    output = []
    for rms_cutoff in (3.5, 1.0, 2.0, 3.0, 4.0, 5.0):
        r = atmsel.superpose(mdl, aln, rms_cutoff=rms_cutoff)
        output.append((rms_cutoff, r))
    native_range = (len(native_mdl.residues), native_mdl.residues[0].num,
                    native_mdl.residues[-1].num)
    return native_range, output

def _native_benchmark(code, chain, native_range, output):
    """Build the dictionary of benchmarking data stored on a model.

       `native_range` and `output` are the two values returned by
       superpose(). The first entry of `output` (the 3.5 cutoff) is reported
       under its own keys; the remaining entries are listed individually
       and averaged."""
    first = output[0][1]
    others = output[1:]
    num = len(others)
    totrmsd = sum(r.cutoff_rms for (cutoff, r) in others)
    toteqvp = sum(r.num_equiv_cutoff_pos for (cutoff, r) in others)
    return {'code': code, 'chain': chain,
            'length': native_range[0], 'region': list(native_range[1:3]),
            'global_rms': first.rms,
            'num_equiv_pos_35': first.num_equiv_cutoff_pos,
            'cutoff_rms_35': first.cutoff_rms,
            'cutoff_rms': [{'cutoff': cutoff, 'rms': r.cutoff_rms,
                            'num_equiv_pos': r.num_equiv_cutoff_pos}
                           for (cutoff, r) in others],
            'mean_cutoff_rms': totrmsd / num,
            'mean_num_equiv_pos': toteqvp // num}

def _tsvmod_line(nat, model, pct_helix, pct_strand):
    """Format one line of TSVMod training data for a benchmarked model.

       Fields within a group are separated by single spaces, groups by
       ' | ', and the line ends with one trailing space (no newline)."""
    tmpl = model.templates[0]
    ga341 = model.score.ga341
    groups = (
        # native: code+chain, first and last residue number, length
        (nat['code'] + nat['chain'], nat['region'][0], nat['region'][1],
         nat['length']),
        # first template: code+chain, region start/end, region length
        (str(tmpl.code) + str(tmpl.chain), tmpl.region[0], tmpl.region[1],
         tmpl.region[1] - tmpl.region[0] + 1),
        # alignment and model scores
        (float(model.alignment.gap_percentage), ga341.total,
         float(model.highest_sequence_identity), ga341.z_distance,
         ga341.z_surface_area, ga341.z_combined,
         model.score.normalized_dope),
        (model.score.quality,),
        # global RMS and fraction of native positions equivalent at 3.5
        (nat['global_rms'],
         float(nat['num_equiv_pos_35']) / float(nat['length'])),
        (float(pct_helix), float(pct_strand)),
    )
    return " | ".join(" ".join(str(field) for field in group)
                      for group in groups) + " "

def main():
    """Benchmark every model of a sequence against its native structure.

       Reads the models file '<seqid>.mod' from the sequence directory of
       the ModPipe filesystem, superposes each model on the native, attaches
       the results to the model as `native_benchmark` and writes the models
       out again (to --output_filename inside the sequence directory, or to
       standard output). If --tsvmod_file is given, one line of TSVMod
       training data per model is written to that file as well."""
    import shutil
    import tempfile
    import modeller
    from modeller.scripts import complete_pdb
    env = modeller.environ()
    env.libs.topology.read(file='$(LIB)/top_heav.lib')
    env.libs.parameters.read(file='$(LIB)/par.lib')

    parser = OptionParser(version=modpipe.version.message())

    parser.set_usage("""
 This script takes a model file and compares each model against the native.

 Run `%prog -h` for help information
""")

    parser.add_option("-c", "--conf_file", dest="conffile",
                      type="string",
                      help="""ModPipe configuration file. Cannot proceed
                           without this option.""", default=None)
    parser.add_option("-s", "--sequence_id", dest="seqid",
                      type="string",
                      help="""Sequence Id. This is the MD5 digest of the
                           sequence that has been added to the ModPipe
                           filesystem. It will not proceed without this
                           option.""", default=None)
    parser.add_option("-n", "--native", dest="native",
                      type="string",
                      help="""PDB code of the native structure. Cannot
                            proceed without this.""", default=None)
    parser.add_option("--native_chain", dest="chnat",
                      type="string",
                      help="""Chain identifier for the native. Will assume
                            chain id is ' ' if none is specified.""",
                      default=' ')
    parser.add_option("--native_acode", dest="acoden",
                      type="string",
                      help="""Align code for native structure.
                           Default: native-pdb-code + chnid""", default=None)
    pdb = modpipe.pdbutils.get_pdb_repository(include_local=True)
    parser.add_option("--pdb_repository", dest="pdbrep",
                      type="string",
                      help="""Search path for PDB files.
                           Default: """ + str(pdb),
                      default=pdb)
    parser.add_option("--output_filename", dest="outfile",
                      type="string",
                      help="""File to write output data. Default: STDOUT""",
                      default=None)
    parser.add_option("--tsvmod_file", dest="tsvmod_file",
                      type="string",
                      help="""File to write TSVMod training data. Not written if None""",
                      default=None)
    opts, args = parser.parse_args()

    # Check for configuration file
    if not opts.conffile:
        parser.error("Cannot proceed without configuration file")

    # Check for sequence MD5 hash
    if not opts.seqid:
        parser.error("Cannot proceed without sequence id")

    # Check for required options for native overlap
    if not opts.native:
        parser.error("Cannot proceed without native structure")

    # Read in the configuration file and set up filesystem
    with open(opts.conffile, 'r') as fh_conf:
        config = modpipe.config.read_file(fh_conf)
    fs = modpipe.filesystem.FileSystem(config)

    # Set defaults
    if not opts.acoden:
        opts.acoden = (opts.native + opts.chnat).rstrip(' ')

    seqdir = fs.get_sequence_dir(opts.seqid)
    moddat = os.path.join(seqdir, "%s.mod" % opts.seqid)

    if opts.outfile:
        fh_out = open(os.path.join(seqdir, opts.outfile), 'w')
    else:
        fh_out = sys.stdout

    tsvmod_out = None    # stays None when no TSVMod data was requested
    scratch_dir = None
    try:
        if opts.tsvmod_file:
            tsvmodfile = os.path.join(seqdir, opts.tsvmod_file)
            # Diagnostic goes to stderr so that it cannot end up inside the
            # models data when that is written to standard output.
            print("TSVMOd file", tsvmodfile, file=sys.stderr)
            tsvmod_out = open(tsvmodfile, 'w')
            # Private scratch directory for the files written by
            # write_data(); a fixed name in /tmp would collide between
            # concurrent runs or different users.
            scratch_dir = tempfile.mkdtemp(prefix='modpipe_benchmark_')

        append = False
        # Keep the models file open for the whole loop in case the reader
        # yields models lazily.
        with open(moddat, 'r') as fh_models:
            models = modpipe.serialize.read_models_file(fh_models)
            for model in models:
                seqid = model.sequence.id
                modfile = os.path.join(fs.get_model_dir(opts.seqid),
                                       "%s.pdb" % model.id)
                if tsvmod_out is not None:
                    mdl = complete_pdb(env, modfile)
                    # The secondary structure assignment is read back from
                    # the Biso field of the in-memory model; the files
                    # written alongside are not used.
                    mdl.write_data(file=os.path.join(scratch_dir, 'file'),
                                   output='SSM')
                    # Biso of the first atom: 1 is counted as strand,
                    # 2 as helix
                    strand_res = [r for r in mdl.residues
                                  if r.atoms[0].biso == 1]
                    helix_res = [r for r in mdl.residues
                                 if r.atoms[0].biso == 2]
                    nres = len(mdl.residues)
                    pct_helix = len(helix_res) * 100.0 / nres
                    pct_strand = len(strand_res) * 100.0 / nres

                # check here whether salign should be run before superpose
                native_range, output = superpose(modfile, seqid, opts.native,
                                                 opts.chnat, opts.acoden,
                                                 opts.pdbrep)
                # Put benchmarking data into a reasonable structure
                nat = _native_benchmark(opts.native, opts.chnat,
                                        native_range, output)

                model.native_benchmark = nat
                # Write out the new model plus benchmarking data
                modpipe.serialize.write_models_file([model], fh_out, append)
                append = True

                if tsvmod_out is not None:
                    # print tsvmod input in Dave's format
                    print(_tsvmod_line(nat, model, pct_helix, pct_strand),
                          file=tsvmod_out)
    finally:
        if tsvmod_out is not None:
            tsvmod_out.close()
        if fh_out is not sys.stdout:
            fh_out.close()
        if scratch_dir is not None:
            shutil.rmtree(scratch_dir, ignore_errors=True)


# Run the benchmark only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
