# Note that this code needs Python 2.6 or later, unlike much of Modeller
# which works as far back as 2.3, so imports need to be protected by a
# version check

import modeller.features
import modeller.forms
import _modeller
import string


class _AsymIDs(object):
    """Translate a zero-based index into a multi-character asym (chain) ID.
       Indices 0-25 give the single letters A-Z. Later indices give the
       two-letter IDs AA-AZ, BA-BZ, ... ZZ, after which the IDs grow to
       three letters (AAA, AAB, ...) and so on."""
    def __getitem__(self, ind):
        letters = string.ascii_uppercase
        base = len(letters)
        # The ID is assembled from its last character backwards
        tail = ''
        while ind >= base:
            ind, remainder = divmod(ind, base)
            tail = letters[remainder] + tail
            # There is no "zero" letter (Z is followed by AA, not BA), so
            # the remaining prefix is offset by one
            ind -= 1
        return letters[ind] + tail


class _CifAtom(object):
    """Wrap a Modeller Atom, exposing the mmCIF ids that identify it"""
    def __init__(self, atom, asym_map):
        self._atom = atom
        # The residue-level ids are looked up once, at construction time
        residue_ids = asym_map.map_residue(atom.residue)
        self.asym_id, self.entity_id, self.seq_id = residue_ids

    # The names below are only read from the atom when they are requested

    @property
    def atom_id(self):
        """Name of the wrapped atom"""
        return self._atom.name

    @property
    def comp_id(self):
        """Name of the residue that the wrapped atom belongs to"""
        return self._atom.residue.name


class TemplateChain(object):
    """One chain of a template structure.
       This is a plain record with no constructor; its attributes are
       assigned by the code that creates it."""


class TargetChain(object):
    """One chain of the target.
       This is a plain record with no constructor; its attributes are
       assigned by the code that creates it."""


class Alignment(object):
    """A multiple alignment of a single TargetChain with one or more
       TemplateChains.
       This is a plain record with no constructor; its attributes are
       assigned by the code that creates it."""


class Data(object):
    """Base class for a piece of data used in the modeling."""


class TemplateData(Data):
    """Data for a template structure used in the modeling."""
    # Content type string reported for this kind of data
    content_type = 'template structure'


class AlignmentData(Data):
    """Data for a target-template alignment used in the modeling."""
    # Content type string reported for this kind of data
    content_type = 'target-template alignment'


class RestraintData(Data):
    """Data for spatial restraints used in the modeling."""
    # Content type string reported for this kind of data
    content_type = 'spatial restraints'


class TargetData(Data):
    """Data for the target sequence."""
    # Content type string reported for this kind of data
    content_type = 'target'


class ModelData(Data):
    """Data for the final model that is output."""
    # Content type string reported for this kind of data
    content_type = 'model coordinates'


class Restraints(object):
    """Base class for a group of restraints used in the modeling.
       The _num_atoms attribute (the number of atoms that each restraint
       acts on) is not defined here; it must be provided elsewhere, e.g.
       by a subclass."""
    def _get_atom_keywords(self):
        """Return the list of mmCIF keywords that identify the atoms in
           this restraint type (five keywords for each atom)"""
        fields = ('entity_id', 'asym_id', 'seq_id', 'comp_id', 'atom_id')
        return ['%s_%d' % (field, natom)
                for natom in range(1, self._num_atoms + 1)
                for field in fields]

    def _get_granularity_keyword(self):
        """Return the list of granularity keywords. This is empty here,
           since only distance restraints have a granularity in the
           dictionary."""
        return []


class DistanceRestraints(Restraints, list):
    """The list of distance restraints used in the modeling."""
    restraint_type = 'distance restraint'
    category = '_ma_distance_restraints'
    # Each distance restraint acts on a pair of atoms
    _num_atoms = 2

    def _get_granularity_keyword(self):
        """Return the list of granularity keywords for distance restraints"""
        return ['granularity']


class AngleRestraints(Restraints, list):
    """The list of angle restraints used in the modeling."""
    restraint_type = 'angle restraint'
    category = '_ma_angle_restraints'
    # Each angle restraint acts on three atoms
    _num_atoms = 3


class DihedralRestraints(Restraints, list):
    """The list of dihedral restraints used in the modeling."""
    restraint_type = 'dihedral restraint'
    category = '_ma_dihedral_restraints'
    # Each dihedral restraint acts on four atoms
    _num_atoms = 4


class _AsymMap(object):
    """Look up mmCIF asym and entity IDs for residues and chains of a
       model, using the low-level _modeller interface."""
    def __init__(self, mdl):
        self.mdl = mdl
        # Converts the integer asym IDs returned by the low-level code
        # into chain ID strings
        self._asym_from_int_id = _AsymIDs()
        _modeller.mod_model_make_mmcif_asyms(mdl.modpt, mdl.env.libs.modpt)

    def map_residue(self, res):
        """Return an (asym ID, entity ID, sequence ID) tuple for the given
           Residue object"""
        ids = _modeller.mod_model_get_residue_asym(self.mdl.modpt, res._num)
        int_asym_id, entity_id, seq_id = ids
        return self._asym_from_int_id[int_asym_id], entity_id, seq_id

    def map_chain(self, nchain):
        """Return a (polymer asym ID, entity ID) tuple for the given chain
           index, or (None, None) if that chain does not contain a
           polymer."""
        ids = _modeller.mod_model_get_chain_asym(self.mdl.modpt, nchain)
        int_asym_id, entity_id = ids
        # A zero entity ID is treated as "no polymer in this chain"
        if entity_id != 0:
            return self._asym_from_int_id[int_asym_id], entity_id
        return None, None


class CifData(object):
    def __init__(self, mdl):
        """Start collecting mmCIF data for the model `mdl`."""
        self.mdl = mdl
        # Data objects registered so far, in order of registration
        self.data = []
        # Clear the mmCIF file data held by the low-level code for this
        # model before anything new is registered
        _modeller.mod_model_clear_mmcif_file_data(mdl.modpt)
        self._asym_map = _AsymMap(mdl)

    def add_alignment_info(self, aln, knowns, seq):
        """Record the target chains, template chains and alignments.
           `aln` is indexed by `seq` to get the target sequence and by
           each member of `knowns` to get a template."""
        target = aln[seq]
        # Target chains must be set up first, since the template chains
        # and alignments refer to them
        self._add_target_chains(target, self.mdl)
        self._add_template_chains(aln, knowns, target)
        self._add_alignments(aln)

    def add_restraints(self, user_restraints):
        """Sort the supported user restraints into per-type groups.
           self.restraints is set to a list of the non-empty groups
           (distance, angle, dihedral, in that order), each given a 1-based
           id. If there is at least one group, a single RestraintData is
           registered and shared by all of them."""
        self.restraints = []
        # Only handle single-feature Gaussian-type restraints for now
        supported_forms = (modeller.forms.Gaussian,
                           modeller.forms.UpperBound,
                           modeller.forms.LowerBound)
        supported = [r for r in user_restraints
                     if isinstance(r, supported_forms)]
        groupings = ((DistanceRestraints, modeller.features.Distance),
                     (AngleRestraints, modeller.features.Angle),
                     (DihedralRestraints, modeller.features.Dihedral))
        for group_class, feature_class in groupings:
            group = group_class(r for r in supported
                                if isinstance(r._features, feature_class))
            if group:
                self.restraints.append(group)
                group.id = len(self.restraints)

        if not self.restraints:
            return
        # add Data for restraints
        rsr_data = RestraintData()
        rsr_data.name = 'User-provided restraints'
        self._add_data(rsr_data, True)
        for group in self.restraints:
            group.data = rsr_data

    def _add_data(self, data, is_input):
        """Record some data used in the modeling and give it an ID.
           `is_input` is handed to the low-level code together with the
           name and content type of the data."""
        # IDs are 1-based and follow the order of registration
        data.id = len(self.data) + 1
        self.data.append(data)
        # The lower-level Fortran code must also know about this data,
        # since it writes out the ma_data mmCIF table and other tables
        # that need data IDs
        _modeller.mod_model_add_mmcif_file_data(
            self.mdl.modpt, data.name, data.content_type, is_input)

    def _get_gapped_sequence(self, chain):
        """Return the sequence of the given chain as a string, with each
           residue code preceded by one '-' per leading gap"""
        # todo: this is not quite right because leading gaps for the N terminus
        # will include trailing gaps of the C terminus of the previous chain
        pieces = []
        for res in chain.residues:
            pieces.append('-' * res.get_leading_gaps())
            pieces.append(res.code)
        return ''.join(pieces)

    def _add_alignments(self, aln):
        """Build self.alignments, with one Alignment per TargetChain.
           Each TemplateChain is then added to the Alignment of its target
           chain. All alignments share a single AlignmentData, which is
           registered here. The `aln` argument is not used."""
        # add Data for this alignment
        aln_data = AlignmentData()
        aln_data.name = 'Target Template Alignment'
        self._add_data(aln_data, True)

        self.alignments = []
        alignment_for_target = {}
        # Non-polymer chains are stored as None; skip them
        for target in (t for t in self.target_chains if t is not None):
            alignment = Alignment()
            alignment.id = target.id
            alignment.data = aln_data
            alignment.target = target
            alignment.templates = []
            alignment_for_target[target] = alignment
            self.alignments.append(alignment)
        # Add template information
        for template in self.template_chains:
            alignment = alignment_for_target[template.target_chain]
            alignment.templates.append(template)

    def _get_target_chain(self, chain, target):
        """Return the entry of self.target_chains that aligns with the given
           chain, or None if no residue in the chain is aligned with the
           target"""
        # Only the first aligned residue is considered. A template chain
        # that aligns with multiple target chains is therefore missed, but
        # this isn't handled by the dictionary anyway (and is usually a
        # modeling error)
        for res in chain.residues:
            aligned = res.get_aligned_residue(target)
            if not aligned:
                continue
            # Modeller chain index is 1-based
            return self.target_chains[aligned.chain.index - 1]
        return None

    def _add_template_chains(self, aln, knowns, target):
        """Build self.template_chains, with one TemplateChain for each chain
           of each known that aligns with a target chain.
           A TemplateData is registered for every member of `knowns`, and
           `aln` is indexed by that member to get its sequence. The
           TemplateChain ids are 1-based and run across all knowns."""
        self.template_chains = []
        for known in knowns:
            # Add Data for this template
            tmpl_data = TemplateData()
            tmpl_data.name = 'Template Structure'
            self._add_data(tmpl_data, True)

            seq = aln[known]
            for chain in seq.chains:
                target_chain = self._get_target_chain(chain, target)
                # If chain contains only ligands, we can't align it
                if target_chain is None:
                    continue
                # todo: handle non-standard residues
                one_letter = ''.join(r.code for r in chain.residues)
                t = TemplateChain()
                # The next id is one more than the number of chains so far
                t.id = len(self.template_chains) + 1
                t.pdb_accession = seq.pdb_accession
                t.asym_id = chain.name
                t.template_data = tmpl_data
                t.target_chain = target_chain
                t.seq_range = (1, len(chain.residues))
                t.sequence = one_letter
                t.sequence_can = one_letter
                t.gapped_sequence = self._get_gapped_sequence(chain)
                self.template_chains.append(t)

    def _add_target_chains(self, seq, mdl):
        """Build self.target_chains, with one entry for each chain in `seq`.
           Chains that contain a polymer get a TargetChain, with 1-based ids
           counting only those chains; other chains get None. The `mdl`
           argument is not used."""
        self.target_chains = []
        next_id = 1
        for nchain, chain in enumerate(seq.chains):
            asym_id, entity_id = self._asym_map.map_chain(nchain)
            if asym_id is None:
                # No polymer in this Modeller segment. Store None so
                # that the chain index can still be used to access the
                # target_chains list.
                entry = None
            else:
                entry = TargetChain()
                entry.id = next_id
                entry.asym_id = asym_id
                entry.entity_id = entity_id
                entry.sequence = ''.join(r.code for r in chain.residues)
                entry.gapped_sequence = self._get_gapped_sequence(chain)
                entry.seq_range = (1, len(chain.residues))
                next_id += 1
            self.target_chains.append(entry)

    def write_mmcif(self, writer):
        """Register the target sequence and model data, then write every
           table to `writer`.
           The target sequence is registered as input data and the model
           as non-input data, in that order, before anything is written."""
        registrations = ((TargetData, 'Target Sequence', True),
                         (ModelData, 'Target Structure', False))
        for data_class, name, is_input in registrations:
            data = data_class()
            data.name = name
            self._add_data(data, is_input)
        # Template segments must be written before the poly mapping and
        # alignment tables, since writing them assigns the segment IDs
        # that those tables use
        for write_table in (self._write_template_details,
                            self._write_template_segments,
                            self._write_poly_mapping,
                            self._write_alignment,
                            self._write_restraints):
            write_table(writer)

    def _write_template_details(self, writer):
        """Write the template tables to `writer`.
           These are a single identity transformation matrix, then one row
           per template chain in each of the details and poly tables, plus
           a reference database row for each template chain that has a PDB
           accession."""
        dims = (1, 2, 3)
        # Matrix elements are listed with the first index varying fastest
        elements = [(row, col) for col in dims for row in dims]
        matrix_keys = (["id"]
                       + ["rot_matrix[%d][%d]" % rc for rc in elements]
                       + ["tr_vector[%d]" % i for i in dims])
        # Identity rotation (ones on the diagonal) and zero translation
        identity = dict(("rot_matrix%d%d" % (row, col),
                         1.0 if row == col else 0.0)
                        for row, col in elements)
        identity.update(("tr_vector%d" % i, 0.0) for i in dims)
        with writer.loop("_ma_template_trans_matrix", matrix_keys) as lp:
            lp.write(id=1, **identity)

        details_keys = ['ordinal_id', 'template_id', 'template_origin',
                        'template_entity_type', 'template_trans_matrix_id',
                        'template_data_id', 'target_asym_id',
                        'template_model_num', 'template_auth_asym_id']
        with writer.loop('_ma_template_details', details_keys) as lp:
            for tmpl in self.template_chains:
                if tmpl.pdb_accession:
                    origin = 'reference database'
                else:
                    origin = '?'
                lp.write(ordinal_id=tmpl.id, template_id=tmpl.id,
                         template_origin=origin,
                         template_entity_type='polymer',
                         # Every template uses the identity matrix above
                         template_trans_matrix_id=1,
                         template_data_id=tmpl.template_data.id,
                         target_asym_id=tmpl.target_chain.asym_id,
                         template_model_num=1,
                         template_auth_asym_id=tmpl.asym_id)

        ref_db_keys = ['template_id', 'db_name', 'db_accession_code']
        with writer.loop('_ma_template_ref_db_details', ref_db_keys) as lp:
            for tmpl in self.template_chains:
                if not tmpl.pdb_accession:
                    continue
                lp.write(template_id=tmpl.id, db_name='PDB',
                         db_accession_code=tmpl.pdb_accession)

        poly_keys = ['template_id', 'seq_one_letter_code',
                     'seq_one_letter_code_can']
        with writer.loop('_ma_template_poly', poly_keys) as lp:
            for tmpl in self.template_chains:
                lp.write(template_id=tmpl.id,
                         seq_one_letter_code=tmpl.sequence,
                         seq_one_letter_code_can=tmpl.sequence_can)

    def _write_template_segments(self, writer):
        """Write one segment per template chain to `writer`, and store the
           1-based segment ID on each template chain as segment_id"""
        keys = ['id', 'template_id', 'residue_number_begin',
                'residue_number_end']
        with writer.loop('_ma_template_poly_segment', keys) as lp:
            for segment_id, tmpl in enumerate(self.template_chains, 1):
                first, last = tmpl.seq_range[0], tmpl.seq_range[1]
                lp.write(id=segment_id, template_id=tmpl.id,
                         residue_number_begin=first,
                         residue_number_end=last)
                tmpl.segment_id = segment_id

    def _write_poly_mapping(self, writer):
        """Write to `writer` one row for each template of each alignment,
           mapping the template's segment onto the full sequence range of
           the alignment's target chain"""
        keys = ['id', 'template_segment_id', 'target_asym_id',
                'target_seq_id_begin', 'target_seq_id_end']
        # Every (target, template) pair, in alignment order
        pairs = ((a.target, tmpl)
                 for a in self.alignments for tmpl in a.templates)
        with writer.loop('_ma_target_template_poly_mapping', keys) as lp:
            for ordinal, (target, tmpl) in enumerate(pairs, 1):
                lp.write(id=ordinal, template_segment_id=tmpl.segment_id,
                         target_asym_id=target.asym_id,
                         target_seq_id_begin=target.seq_range[0],
                         target_seq_id_end=target.seq_range[1])

    def _write_alignment(self, writer):
        """Write the alignment tables to `writer`.
           Three loops are written: _ma_alignment_info (one row per
           alignment), _ma_alignment_details (one row per template of each
           alignment) and _ma_alignment (for each alignment, one gapped
           sequence row per template followed by one for the target).
           Within each of the last two loops, every row gets its own
           ordinal_id, counting up from 1.

           `writer` must provide a loop(category, keys) method that returns
           a context manager; the object that yields must provide a
           write(**kwargs) method."""
        # todo: populate with info on how the alignment was made
        with writer.loop('_ma_alignment_info',
                         ['alignment_id', 'data_id',
                          'alignment_length', 'alignment_type']) as lp:
            for a in self.alignments:
                lp.write(alignment_id=a.id, data_id=a.data.id,
                         alignment_length=len(a.target.gapped_sequence),
                         alignment_type='target-template pairwise alignment')

        with writer.loop('_ma_alignment_details',
                         ['ordinal_id', 'alignment_id', 'template_segment_id',
                          'target_asym_id']) as lp:
            ordinal = 1
            for a in self.alignments:
                for template in a.templates:
                    lp.write(ordinal_id=ordinal, alignment_id=a.id,
                             template_segment_id=template.segment_id,
                             target_asym_id=a.target.asym_id)
                    # Each row must get a distinct ordinal
                    ordinal += 1

        with writer.loop('_ma_alignment',
                         ['ordinal_id', 'alignment_id', 'target_template_flag',
                          'sequence']) as lp:
            ordinal = 1
            for a in self.alignments:
                for template in a.templates:
                    lp.write(ordinal_id=ordinal, alignment_id=a.id,
                             target_template_flag=2,  # Template
                             sequence=template.gapped_sequence)
                    ordinal += 1
                lp.write(ordinal_id=ordinal, alignment_id=a.id,
                         target_template_flag=1,  # Target
                         sequence=a.target.gapped_sequence)
                # The target row uses up an ordinal too; without this the
                # first row of the next alignment would reuse it
                ordinal += 1

    def _write_restraints(self, writer):
        """Write a summary row for each group in self.restraints to
           `writer`, followed by one table per group listing its
           restraints"""
        summary_keys = ['restraint_id', 'data_id', 'name',
                        'restraint_type', 'details']
        with writer.loop('_ma_restraints', summary_keys) as lp:
            for group in self.restraints:
                # No values are given for 'name' or 'details'
                lp.write(restraint_id=group.id, data_id=group.data.id,
                         restraint_type=group.restraint_type)
        for group in self.restraints:
            self._write_type_restraints(writer, group)

    def _get_rsr_limits(self, r):
        """Return a (lower, upper, lower_esd, upper_esd, rsr_type) tuple
           for the given restraint.
           The limits come from the first two restraint parameters (mean
           and standard deviation). A limit that the restraint form does
           not impose is returned as '.', together with its esd."""
        mean, stddev = r._parameters[0:2]
        unset = '.'
        if isinstance(r, modeller.forms.Gaussian):
            # A Gaussian restrains from both sides
            return mean, mean, stddev, stddev, 'lower and upper bound'
        if isinstance(r, modeller.forms.LowerBound):
            return mean, unset, stddev, unset, 'lower bound'
        # Any other form is reported as an upper bound
        return unset, mean, unset, stddev, 'upper bound'

    def _get_rsr_atom_items(self, r):
        """Return a dict of mmCIF data items for the atoms in a given
           restraint.
           The keys are the five atom fields suffixed with the 1-based
           position of the atom (atom_id_1, comp_id_1, ...)."""
        feature = r._features
        # Only the first element of the atom index information is used
        indices = feature.get_atom_indices()[0]
        cif_atoms = [_CifAtom(a, self._asym_map)
                     for a in feature.indices_to_atoms(self.mdl, indices)]
        fields = ('atom_id', 'comp_id', 'asym_id', 'seq_id', 'entity_id')
        items = {}
        for natom, cif_atom in enumerate(cif_atoms, 1):
            for field in fields:
                items['%s_%d' % (field, natom)] = getattr(cif_atom, field)
        return items

    def _write_type_restraints(self, writer, restraints):
        """Write the table for one group of restraints (distance, angle or
           dihedral) to `writer`, with one row per restraint and 1-based
           ordinals"""
        keys = (['ordinal_id', 'restraint_id', 'restraint_type']
                + restraints._get_atom_keywords()
                + restraints._get_granularity_keyword()
                + ['lower_limit', 'lower_limit_esd',
                   'upper_limit', 'upper_limit_esd'])
        with writer.loop(restraints.category, keys) as lp:
            for ordinal, rsr in enumerate(restraints, 1):
                limits = self._get_rsr_limits(rsr)
                lower, upper, lower_esd, upper_esd, rsr_type = limits
                atom_items = self._get_rsr_atom_items(rsr)
                # granularity is always passed, even for groups whose
                # keyword list does not include it
                lp.write(ordinal_id=ordinal,
                         restraint_id=restraints.id,
                         restraint_type=rsr_type,
                         granularity='by-atom',
                         lower_limit=lower, lower_limit_esd=lower_esd,
                         upper_limit=upper, upper_limit_esd=upper_esd,
                         **atom_items)
