#!/usr/bin/python
# This file is part of ModPipe, Copyright 1997-2020 Andrej Sali
#
# ModPipe is free software: you can redistribute it and/or modify
# it under the terms of version 2 of the GNU General Public License
# as published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with ModPipe.  If not, see <http://www.gnu.org/licenses/>.

from __future__ import print_function
from modeller import *
from optparse import OptionParser
from modpipe.alnutils import *
import modpipe.pdbutils
import modpipe.version
import sys, os

def doSALIGN(aln, fw, ogp1d, egp1d, ogp3d, egp3d):
    """Run SALIGN on an alignment with caller-chosen weights and penalties.

    aln    -- alignment object; its salign() method is invoked (and, since
              the callers keep using aln afterwards, is expected to update
              the alignment in place).
    fw     -- feature weights, forwarded as 'feature_weights'.
    ogp1d, egp1d -- forwarded together as the 'gap_penalties_1d' pair.
    ogp3d, egp3d -- forwarded together as the 'gap_penalties_3d' pair.

    Every other SALIGN setting is fixed here.  Returns whatever
    aln.salign() returns."""

    # Settings that are the same for every run.
    settings = {
        'rms_cutoff': 3.5,
        'normalize_pp_scores': False,
        'rr_file': '$(LIB)/as1.sim.mat',
        'overhang': 0,
        'auto_overhang': True,
        'overhang_auto_limit': 5,
        'overhang_factor': 1,
        'local_alignment': False,
        'matrix_offset': -0.2,
        'gap_gap_score': 0,
        'gap_residue_score': 0,
        'dendrogram_file': 'salign.tree',
        'alignment_type': 'tree',
        'nsegm': 2,
        'improve_alignment': True,
        'fit': True,
        'write_fit': False,
        'output': 'ALIGNMENT QUALITY',
    }

    # Settings supplied by the caller.
    settings['feature_weights'] = fw
    settings['gap_penalties_1d'] = (ogp1d, egp1d)
    settings['gap_penalties_3d'] = (ogp3d, egp3d)

    return aln.salign(**settings)


def frange(start, end=None, step=None):
    """Return a list of evenly spaced numbers, like 'range' but accepting
    floating point arguments.

    frange(end)              -> values from 0.0 up to (not including) end
    frange(start, end)       -> step defaults to 1.0
    frange(start, end, step) -> start, start+step, ... (end is excluded)

    Each element is computed as start + i*step, so the elements are floats
    whenever any argument is a float; all-integer arguments give integers.
    An empty list is returned when stepping from start can never reach end.
    A step of zero raises ZeroDivisionError."""

    if end is None:
        end = start + 0.0
        start = 0.0

    if step is None:
        step = 1.0

    # Number of steps that fit between start and end.
    ratio = (end - start) / float(step)

    # If the ratio is an integer up to rounding noise, use that integer:
    # an exact floating point comparison against 'end' could otherwise
    # admit one extra element that is numerically equal to 'end'.
    nearest = round(ratio)
    if abs(ratio - nearest) <= 1e-9 * max(1.0, abs(nearest)):
        count = int(nearest)
    else:
        # Ceiling of the ratio.  int() truncates toward zero, so only
        # positive fractional ratios are bumped up; ratios in (-1, 0)
        # correctly give zero elements.
        count = int(ratio)
        if count < ratio:
            count += 1

    # range() of a non-positive count is empty.
    return [start + i * step for i in range(count)]


def _scan_gap_penalties(env, fw, open_penalties, extend_penalties,
                        best_aln, qmax, fixed_3d=None, seed_aln=None):
    """Try every (opening, extension) gap penalty pair and keep the best
    SALIGN result.

    env       -- Modeller environment used when copying alignments.
    fw        -- feature weights handed to doSALIGN.
    open_penalties, extend_penalties -- re-iterable sequences of values;
                 the full grid of pairs is tried, opening penalty outermost.
    best_aln  -- best alignment found so far.
    qmax      -- quality score (qscorepct) of best_aln.
    fixed_3d  -- (opening, extension) 3D gap penalties to use for every run;
                 if None, the scanned pair is used for both the 1D and the
                 3D gap penalties.
    seed_aln  -- alignment each run starts from; if None, each run starts
                 from the best alignment found so far.

    A run whose score is >= qmax replaces the best alignment.  A run that
    raises ModellerError is reported and skipped.  Returns the (possibly
    updated) tuple (best_aln, qmax)."""

    for ogp in open_penalties:
        for egp in extend_penalties:
            ogp3d, egp3d = (ogp, egp) if fixed_3d is None else fixed_3d
            try:
                source = best_aln if seed_aln is None else seed_aln
                aln = make_alignment_copy(env, source)
                q = doSALIGN(aln, fw, ogp, egp, ogp3d, egp3d)
                if q.qscorepct >= qmax:
                    qmax = q.qscorepct
                    best_aln = make_alignment_copy(env, aln)
                print("Qlty scrs [%8.2f %8.2f]: %8.2f" % (ogp, egp, q.qscorepct))
            except ModellerError as detail:
                print("Parameter set [%8.2f %8.2f] resulted in the following error: %s" % \
                   (ogp, egp, str(detail)))

    return best_aln, qmax


def main():
    """Align two PDB chains with SALIGN and print superposition statistics.

    Reads the two PDB codes (and optional chain identifiers) from the
    command line, scans several grids of gap penalties keeping the
    alignment with the highest SALIGN quality score, optionally writes
    that alignment in PIR format, and prints one 'SA>' line with the
    chain lengths and the superposition results at several cutoffs.
    Prints the help text and exits with status 1 if either PDB code is
    missing."""

    # Parse command line options
    parser = OptionParser(version=modpipe.version.message())

    # Set the usage message
    parser.set_usage("""
 This script takes two PDB protein chains and aligns them using SALIGN.
 The various structural overlap numbers are also reported.

 Usage: %prog [options]

 Run `%prog -h` for help information
 """)

    # Populate options list
    parser.add_option("-p", "--pdb1", dest="pdb1", type='string',
                      help="""PDB code of the first structure.
                              This is a mandatory option.""",
                      metavar="PDB", default=None)
    parser.add_option("-c", "--chn1", dest="chn1", type='string',
                      help="""PDB chain identifier of the first structure.
                              If not specified it will take the first chain.""",
                      metavar="CHN", default='')
    parser.add_option("-q", "--pdb2", dest="pdb2", type='string',
                      help="""PDB code of the second structure.
                              This is a mandatory option.""",
                      metavar="PDB", default=None)
    parser.add_option("-d", "--chn2", dest="chn2", type='string',
                      help="""PDB chain identifier of the second structure.
                              If not specified it will take the first chain.""",
                      metavar="CHN", default='')
    pdb = modpipe.pdbutils.get_pdb_repository(include_local=True)
    parser.add_option("-x", "--pdb_repository",
                      dest="pdbrep", type='string',
                      help="""Directories containing PDB files. The default
                              value is """ + pdb, default=pdb,
                      metavar="DIR")
    parser.add_option("-o", "--output_alignment_file",
                      dest="outalif", type='string',
                      help="""File to store the final alignment from SALIGN.""",
                      metavar="FILE", default=None)

    # Check mandatory options
    opts, args = parser.parse_args()

    if not opts.pdb1 or not opts.pdb2:
        parser.print_help()
        sys.exit(1)

    # Search for PDB in specified path.  The returned values are not used
    # below; the calls are kept so that whatever checks locate_PDB performs
    # still run before Modeller is set up.
    # NOTE(review): presumably this fails early when a file cannot be
    # found -- confirm against modpipe.pdbutils.
    modpipe.pdbutils.locate_PDB(opts.pdb1, opts.pdbrep)
    modpipe.pdbutils.locate_PDB(opts.pdb2, opts.pdbrep)

    # -- Initialize some modeller stuff
    log.minimal()
    env = Environ()
    env.io.atom_files_directory = opts.pdbrep

    # Fetch the PDB chain objects for the two structures
    chain1 = modpipe.pdbutils.fetch_PDB_chain(env, opts.pdb1, opts.chn1)
    chain2 = modpipe.pdbutils.fetch_PDB_chain(env, opts.pdb2, opts.chn2)

    # Create an empty model & alignment objects
    mdl = Model(env)
    aln = Alignment(env)

    # Read in each structure in turn (restricted to the selected chain)
    # and add it to the alignment; the same model object is reused.
    for code, chain in ((opts.pdb1, chain1), (opts.pdb2, chain2)):
        mdl.read(file=code,
                 model_segment=('FIRST:'+chain.name, 'LAST:'+chain.name))
        aln.append_model(mdl, align_codes=code+chain.name,
                         atom_files=code)

    # Initialize some variables
    qmax = 0.0
    fw1 = (1., 0., 0., 0., 1., 0.)
    fw2 = (0., 1., 0., 0., 0., 0.)
    fw3 = (0., 0., 0., 0., 1., 0.)

    # Make copies of alignment: ini_aln stays untouched as a starting
    # point, ref_aln always holds the best alignment found so far.
    ini_aln = make_alignment_copy(env, aln)
    ref_aln = make_alignment_copy(env, aln)

    # Iterate over 1D gap penalties to get initial alignment
    ref_aln, qmax = _scan_gap_penalties(
        env, fw1, frange(-150, 1, 50), frange(-50, 1, 50),
        ref_aln, qmax, fixed_3d=(0, 3))

    # Iterate over 3D gap penalties to get final alignment
    ref_aln, qmax = _scan_gap_penalties(
        env, fw2, frange(0, 3, 1), range(2, 5, 1), ref_aln, qmax)

    # Do some extra trials if the quality is still low
    if qmax <= 70:
        # These runs always restart from the initial alignment
        ref_aln, qmax = _scan_gap_penalties(
            env, fw3, frange(0.0, 2.2, 0.3), frange(0.1, 2.3, 0.3),
            ref_aln, qmax, fixed_3d=(0, 3), seed_aln=ini_aln)

        # Try some 3D parameters
        ref_aln, qmax = _scan_gap_penalties(
            env, fw2, frange(0, 3, 1), range(2, 5, 1), ref_aln, qmax)

    # Print final max quality
    print("Final quality measure: %8.2f" % qmax)

    # Reassign the best alignment to the aln object
    aln = make_alignment_copy(env, ref_aln)

    if opts.outalif:
        aln.write(file=opts.outalif, alignment_format='pir')

    # Create superposition
    m1 = Model(env, file=opts.pdb1,
                model_segment=aln[0].range)
    m2 = Model(env, file=opts.pdb2,
                model_segment=aln[1].range)
    sel = Selection(m1).only_atom_types('CA')

    # Column names for the header line printed above the results
    columns = ['code1', 'len1', 'alnlen1', 'code2', 'len2',
               'alnlen2', 'mod_grmsd', 'mod_geqvp',
               '(mod_cutoff, mod_cutoff_rmsd, mod_cutoff_eqvp)*']

    # Format the results; the pieces are collected and joined at the end
    pieces = ["SA> %-5s %5d %5d %-5s %5d %5d " % (opts.pdb1+chain1.name,
                                      len(chain1.residues),
                                      len(m1.residues),
                                      opts.pdb2+chain2.name,
                                      len(chain2.residues),
                                      len(m2.residues))]

    # Now add the modeller numbers
    rms_cutoff = 3.5
    r = sel.superpose(m2, aln, rms_cutoff=rms_cutoff)
    pieces.append("%8.4f %6d %5.2f %8.4f %6d " % (r.rms,
                  r.num_equiv_pos, rms_cutoff, r.cutoff_rms,
                  r.num_equiv_cutoff_pos))

    # Repeat the superposition at a series of cutoffs
    cuts = [1.0, 2.0, 3.0, 4.0, 5.0, 8.0]
    for c in cuts:
        r = sel.superpose(m2, aln, rms_cutoff=c)
        pieces.append("%5.2f %8.4f %6d " % (c, r.cutoff_rms,
                                            r.num_equiv_cutoff_pos))

    print('# ' + ' '.join(columns))
    print(''.join(pieces))


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported.
if __name__ == "__main__":
    main()
