#!/usr/bin/python
# This file is part of ModPipe, Copyright 1997-2020 Andrej Sali
#
# ModPipe is free software: you can redistribute it and/or modify
# it under the terms of version 2 of the GNU General Public License
# as published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with ModPipe.  If not, see <http://www.gnu.org/licenses/>.

from __future__ import print_function
from modeller import *
from optparse import OptionParser
from tempfile import mkstemp
from modpipe.alnutils import *
from modpipe.pdbutils import *
from modpipe.suputils import *
from modpipe.matrix2d import Matrix2D
import modpipe.version
import sys, os

def do_a3d(env, inialn, mdl0, mdl1):
    """Refine each initial alignment with ALIGN3D and return the best one.

    Every alignment in `inialn` is passed through align3d() twice (gap
    penalties (0, 4.0), then (0, 2.5)).  The CA atoms of `mdl0` are then
    superposed on `mdl1` using that alignment with rms_cutoff=3.5, and the
    alignment giving the largest num_equiv_cutoff_pos wins.  On a tie the
    earlier alignment in `inialn` is kept.

    Returns a copy (via make_alignment_copy) of the winning alignment, or
    the empty string if no alignment produced more than zero equivalent
    positions (this includes an empty `inialn`).
    """
    winner = ''
    winner_eqvpos = 0
    # The two refinement passes, in the order they must be applied
    refinement_passes = ((0, 4.0), (0, 2.5))
    for candidate in inialn:
        for penalties in refinement_passes:
            candidate.align3d(gap_penalties_3d=penalties,
                              align3d_repeat=True)
        ca_atoms = selection(mdl0).only_atom_types('CA')
        fit = ca_atoms.superpose(mdl1, candidate, rms_cutoff=3.5)
        eqvpos = fit.num_equiv_cutoff_pos
        # Strict comparison: only a strictly better score replaces the winner
        if eqvpos > winner_eqvpos:
            winner_eqvpos = eqvpos
            winner = make_alignment_copy(env, candidate)
    return winner


def alignment_by_sequence(env, mdl1, code1, file1, mdl2, code2, file2):
    """Build a two-sequence alignment from the models and align it locally
    by sequence, scoring with the blosum62 similarity matrix.

    `mdl1`/`mdl2` are appended (in that order) under the align codes
    `code1`/`code2` and atom files `file1`/`file2`.  Returns the alignment.
    """
    seq_aln = alignment(env)
    entries = ((mdl1, code1, file1), (mdl2, code2, file2))
    for mdl, code, atom_file in entries:
        seq_aln.append_model(mdl, align_codes=code, atom_files=atom_file)

    align_options = dict(rr_file='$(LIB)/blosum62.sim.mat',
                         matrix_offset=-450,
                         gap_penalties_1d=(-500, -50),
                         local_alignment=True)
    seq_aln.align(**align_options)
    return seq_aln


def alignment_by_resconf(env, mdl1, code1, file1, aconf,
                              mdl2, code2, file2, bconf,
                              match, mismatch, gapo, gape):
    """Create an initial alignment using residue conformational states.

    A len(aconf) x len(bconf) weight matrix is built in which cell (i, j)
    holds `match` when aconf[i] == bconf[j] and `mismatch` otherwise.  The
    matrix is written to a temporary file in the current working directory
    and passed to align() as input_weights_file; the file is removed again
    before returning.

    Arguments:
        env:          MODELLER environment used to create the alignment.
        mdl1, mdl2:   the two models, appended in that order.
        code1, code2: align codes for the two entries.
        file1, file2: atom files for the two entries.
        aconf, bconf: per-residue state sequences for mdl1 and mdl2.
        match, mismatch: weights for equal / differing states.
        gapo, gape:   passed together as gap_penalties_1d.

    Returns the alignment object.  A ModellerError raised by align() is
    reported on stdout and not re-raised, so in that case the alignment is
    returned in whatever state align() left it.
    """
    a = alignment(env)
    a.append_model(mdl1, align_codes=code1, atom_files=file1)
    a.append_model(mdl2, align_codes=code2, atom_files=file2)

    mat = Matrix2D(len(aconf), len(bconf))
    for i, astate in enumerate(aconf):
        for j, bstate in enumerate(bconf):
            if astate == bstate:
                mat[i, j] = match
            else:
                mat[i, j] = mismatch

    # Create a temporary file. mkstemp() hands back an already-open
    # descriptor; the matrix is written by file name, so close the
    # descriptor straight away rather than leaking it.
    (tmpfd, tmpfile) = mkstemp(dir=os.getcwd())
    os.close(tmpfd)

    try:
        mat.write(file=tmpfile)
        try:
            a.align(matrix_offset=0, gap_penalties_1d=(gapo, gape),
                    local_alignment=True, input_weights_file=tmpfile)
        except ModellerError as detail:
            # Best effort: the caller still gets the (unrefined) alignment.
            print("Failed to calculate initial alignment using residue "
                  "conformations: %s" % str(detail))
    finally:
        # Remove the temporary file even if an unexpected error escapes
        os.unlink(tmpfile)

    return a


def alignment_by_ressecstr(env, mdl1, code1, file1, asecs,
                                mdl2, code2, file2, bsecs,
                                match, mismatch, gapo, gape):
    """Create an initial alignment using residue secondary structure states.

    A len(asecs) x len(bsecs) weight matrix is built in which cell (i, j)
    holds `match` when asecs[i] == bsecs[j] and `mismatch` otherwise.  The
    matrix is written to a temporary file in the current working directory
    and passed to align() as input_weights_file; the file is removed again
    before returning.

    Arguments:
        env:          MODELLER environment used to create the alignment.
        mdl1, mdl2:   the two models, appended in that order.
        code1, code2: align codes for the two entries.
        file1, file2: atom files for the two entries.
        asecs, bsecs: per-residue state sequences for mdl1 and mdl2.
        match, mismatch: weights for equal / differing states.
        gapo, gape:   passed together as gap_penalties_1d.

    Returns the alignment object.  A ModellerError raised by align() is
    reported on stdout and not re-raised, so in that case the alignment is
    returned in whatever state align() left it.
    """
    a = alignment(env)
    a.append_model(mdl1, align_codes=code1, atom_files=file1)
    a.append_model(mdl2, align_codes=code2, atom_files=file2)

    mat = Matrix2D(len(asecs), len(bsecs))
    for i, astate in enumerate(asecs):
        for j, bstate in enumerate(bsecs):
            if astate == bstate:
                mat[i, j] = match
            else:
                mat[i, j] = mismatch

    # Create a temporary file. mkstemp() hands back an already-open
    # descriptor; the matrix is written by file name, so close the
    # descriptor straight away rather than leaking it.
    (tmpfd, tmpfile) = mkstemp(dir=os.getcwd())
    os.close(tmpfd)

    try:
        mat.write(file=tmpfile)
        try:
            a.align(matrix_offset=0, gap_penalties_1d=(gapo, gape),
                    local_alignment=True, input_weights_file=tmpfile)
        except ModellerError as detail:
            # Best effort: the caller still gets the (unrefined) alignment.
            print("Failed to calculate initial alignment using residue "
                  "secondary structure states: %s" % str(detail))
    finally:
        # Remove the temporary file even if an unexpected error escapes
        os.unlink(tmpfile)

    return a


def main():
    """Parse the command line and return the (opts, args) pair.

    Both --pdb1 and --pdb2 must be given; if either is missing the help
    text is printed and the process exits with status 1.
    """
    parser = OptionParser(version=modpipe.version.message())

    parser.set_usage("""
 This script takes two PDB protein chains and aligns them using ALIGN3D.
 The various structural overlap numbers are also reported.

 Usage: %prog [options]

 Run `%prog -h` for help information
 """)

    # Default PDB repository, also quoted in the help text of -x
    default_repository = modpipe.pdbutils.get_pdb_repository(
                             include_local=True)

    # (flags, keyword arguments) for every option, in registration order
    option_specs = (
        (("-p", "--pdb1"),
         dict(dest="pdb1", type='string', metavar="PDB", default=None,
              help="""PDB code of the first structure.
                              This is a mandatory option.""")),
        (("-c", "--chn1"),
         dict(dest="chn1", type='string', metavar="CHN", default='',
              help="""PDB chain identifier of the first structure.
                              If not specified it will take the first chain.""")),
        (("-q", "--pdb2"),
         dict(dest="pdb2", type='string', metavar="PDB", default=None,
              help="""PDB code of the second structure.
                              This is a mandatory option.""")),
        (("-d", "--chn2"),
         dict(dest="chn2", type='string', metavar="CHN", default='',
              help="""PDB chain identifier of the second structure.
                              If not specified it will take the first chain.""")),
        (("-x", "--pdb_repository"),
         dict(dest="pdbrep", type='string', metavar="DIR",
              default=default_repository,
              help="""Directory containing PDB files. The default
                              value is """ + default_repository)),
        (("-o", "--output_alignment_file"),
         dict(dest="outalif", type='string', metavar="FILE", default='',
              help="""File to store the final alignment.""")),
    )
    for flags, settings in option_specs:
        parser.add_option(*flags, **settings)

    opts, args = parser.parse_args()

    # Both structures are mandatory
    if not (opts.pdb1 and opts.pdb2):
        parser.print_help()
        sys.exit(1)

    return opts, args

def calculate_a3d_alignment(pdbrep, pdb1, chn1, pdb2, chn2, outalif,
                            ll):
    """Align two PDB chains with ALIGN3D and superpose them.

    Arguments:
        pdbrep:  value assigned to env.io.atom_files_directory.
        pdb1, chn1: PDB code and chain identifier of the first structure.
        pdb2, chn2: PDB code and chain identifier of the second structure.
        outalif: if non-empty, the final alignment is written to this file
                 in 'pir' format.
        ll:      sequence whose first five items are passed to log.level().

    Three initial alignments are built (by sequence, by residue
    conformation and by residue secondary structure); do_a3d() picks the
    best one after structural refinement.

    Returns the (header, results) pair produced by
    calculate_pair_superposition().
    """
    # -- Initialize some modeller stuff
    log.level(ll[0], ll[1], ll[2], ll[3], ll[4])
    env = environ()
    env.io.atom_files_directory = pdbrep

    def load_structure(pdb, chn):
        """Fetch the chain, read it as a model and derive both state lists"""
        chain = fetch_PDB_chain(env, pdb, chn)
        segment = ('FIRST:' + chain.name, 'LAST:' + chain.name)
        mdl = model(env, file=pdb, model_segment=segment)
        return chain, mdl, get_list_conformation(mdl), get_secstr_list(mdl)

    chain1, model1, rconf1, rsecs1 = load_structure(pdb1, chn1)
    chain2, model2, rconf2, rsecs2 = load_structure(pdb2, chn2)

    # Align codes are the PDB code followed by the chain name
    code1 = pdb1 + chain1.name
    code2 = pdb2 + chain2.name

    # Starting points for ALIGN3D, in the order they are evaluated
    by_sequence = alignment_by_sequence(env, model1, code1, pdb1,
                                        model2, code2, pdb2)
    by_resconf = alignment_by_resconf(env, model1, code1, pdb1, rconf1,
                                      model2, code2, pdb2, rconf2,
                                      4, -5, -2, -2)
    by_secstr = alignment_by_ressecstr(env, model1, code1, pdb1, rsecs1,
                                       model2, code2, pdb2, rsecs2,
                                       4, -1, -5, -5)

    aln = do_a3d(env, [by_sequence, by_resconf, by_secstr], model1, model2)

    # Writing the alignment is optional
    if outalif:
        aln.write(file=outalif, alignment_format='pir')

    # Create superposition
    header, results = calculate_pair_superposition(env, aln, chain1,
                                                   chain2, 'A3')
    return header, results


if __name__ == "__main__":
    # Script entry point: parse the options, run the alignment and print
    # the header followed by the results.
    opts, args = main()
    # The five values handed to log.level() by calculate_a3d_alignment()
    log_levels = (1, 0, 0, 1, 0)
    header, results = calculate_a3d_alignment(pdbrep=opts.pdbrep,
                                              pdb1=opts.pdb1,
                                              chn1=opts.chn1,
                                              pdb2=opts.pdb2,
                                              chn2=opts.chn2,
                                              outalif=opts.outalif,
                                              ll=log_levels)
    print(header)
    print(results)
