import _modeller
import model
import profile
import util.top as top
import util.modutil as modutil
import util.modlist as modlist
from modeller.util.modobject import modobject
from modeller.coordinates import coordinates
from modeller import sequence

class alignment(modobject):
    """Holds an alignment of protein sequences"""
    # Handle to the underlying _modeller alignment. Stays at this class-level
    # default of None until __new__ has allocated it.
    __modpt = None
    env = None
    top = None

    def __new__(cls, *args, **vars):
        # Allocate the underlying alignment here, before __init__ runs.
        obj = modobject.__new__(cls)
        obj.__modpt = _modeller.new_alignment(obj)
        return obj

    def __init__(self, env, **vars):
        """Create an alignment that works on a copy of |env|. Any keyword
           arguments are forwarded to append() to read initial sequences."""
        self.env = env.copy()
        self.top = top.top(self.env)
        if len(vars) > 0:
            self.append(**vars)

    def __repr__(self):
        if len(self) == 0:
            return "Empty alignment"
        else:
            return "Alignment of " + ", ".join([repr(s) for s in self])

    def __str__(self):
        if len(self) == 0:
            return "<Empty alignment>"
        else:
            return "<Alignment of " + ", ".join([str(s) for s in self]) + ">"

    def __del__(self):
        # If _modeller.new_alignment() raised inside __new__, the handle was
        # never assigned and is still the class default (None). There is
        # nothing to free then, and passing None to the C layer would fail.
        modpt = self.modpt
        if modpt is not None:
            _modeller.free_alignment(modpt)

    def __get_modpt(self):
        return self.__modpt

    def __len__(self):
        """Number of sequences in the alignment"""
        return _modeller.alignment_nseq_get(self.modpt)

    def clear(self):
        """Remove all sequences from the alignment"""
        return _modeller.delete_alignment(self.modpt)

    def _as_list(val):
        """Return |val| as a new list. Lists and tuples (including
           subclasses) are copied; any other value is wrapped in a
           one-element list."""
        if isinstance(val, (list, tuple)):
            return list(val)
        return [val]
    # Function-call form (not a decorator) keeps this usable on old Pythons.
    _as_list = staticmethod(_as_list)

    def append(self, io=None, **vars):
        """Add sequence(s) from an alignment file.

           The reader is passed the codes and atom files of the sequences
           already in the alignment, followed by the requested ones.
           |align_codes| defaults to 'all' and |atom_files| to none. Each may
           be a single value, a list or a tuple."""
        if io is None:
            io = self.env.io
        align_codes = vars.get('align_codes', 'all')
        atom_files = vars.get('atom_files')
        # Start from the sequences that are already present.
        vars['align_codes'] = [seq.code for seq in self]
        vars['atom_files'] = [seq.atom_file for seq in self]
        if align_codes is not None:
            vars['align_codes'] = vars['align_codes'] \
                                  + self._as_list(align_codes)
        if atom_files is not None:
            vars['atom_files'] = vars['atom_files'] \
                                 + self._as_list(atom_files)
        return self.top.read_alignment('alignment.append', aln=self.modpt,
                                       io=io.modpt, libs=self.env.libs.modpt,
                                       **vars)

    def read(self, **vars):
        """Read sequence(s) from an alignment file, replacing any existing
           sequences (see append() for the accepted arguments)"""
        self.clear()
        return self.append(**vars)

    def read_one(self, **vars):
        """Read sequences from file, one by one. Return False when no more
           can be read."""
        # '@oNe@' is a sentinel code passed through to the reader, and the
        # file is kept open between calls.
        vars['align_codes'] = '@oNe@'
        vars['close_file'] = False
        return (self.read(**vars) == 0)

    def append_model(self, mdl, align_codes, atom_files=''):
        """Add a sequence from a model"""
        return self.top.sequence_to_ali('alignment.append_model',
                                        aln=self.modpt, mdl=mdl.modpt,
                                        libs=self.env.libs.modpt,
                                        align_codes=align_codes,
                                        atom_files=atom_files)

    def append_sequence(self, sequence):
        """Add a new sequence, as a string of one-letter codes"""
        return _modeller.append_sequence(self.modpt, sequence,
                                         self.env.libs.modpt)

    def compare_with(self, aln):
        """Compare with another alignment"""
        return _modeller.compare_alignments(aln=self.modpt, aln2=aln.modpt)

    def compare_structures(self, edat=None, io=None, **vars):
        """Compare 3D structures"""
        if io is None:
            io = self.env.io
        if edat is None:
            edat = self.env.edat
        return self.top.compare('alignment.compare_structures', aln=self.modpt,
                                edat=edat.modpt, io=io.modpt,
                                libs=self.env.libs.modpt, **vars)

    def compare_sequences(self, mdl, **vars):
        """Compare sequences in alignment"""
        return self.top.sequence_comparison('alignment.compare_sequences',
                                            aln=self.modpt, mdl=mdl.modpt,
                                            libs=self.env.libs.modpt, **vars)

    def segment_matching(self, **vars):
        """Align segments"""
        return self.top.segment_matching('alignment.segment_matching',
                                         aln=self.modpt,
                                         libs=self.env.libs.modpt, **vars)

    def edit(self, io=None, **vars):
        """Edit overhangs in alignment"""
        if io is None:
            io = self.env.io
        return self.top.edit_alignment('alignment.edit', aln=self.modpt,
                                       io=io.modpt, libs=self.env.libs.modpt,
                                       **vars)

    def append_profile(self, prf):
        """Add sequences from a profile"""
        return self.top.prof_to_aln('alignment.append_profile', prf=prf.modpt,
                                    aln=self.modpt, libs=self.env.libs.modpt)

    def to_profile(self):
        """Converts the alignment to profile format"""
        prf = profile.profile(self.env)
        _modeller.aln_to_prof(aln=self.modpt, prf=prf.modpt)
        return prf

    def check(self, io=None):
        """Check alignment for modeling"""
        if io is None:
            io = self.env.io
        return _modeller.check_alignment(aln=self.modpt, io=io.modpt,
                                         libs=self.env.libs.modpt)

    def describe(self, io=None):
        """Describe proteins"""
        if io is None:
            io = self.env.io
        return _modeller.describe(aln=self.modpt, io=io.modpt,
                                  libs=self.env.libs.modpt)

    def consensus(self, **vars):
        """Produce a consensus alignment"""
        return self.top.align_consensus('alignment.consensus', aln=self.modpt,
                                        libs=self.env.libs.modpt, **vars)

    def align(self, **vars):
        """Align two (blocks of) sequences"""
        return self.top.align('alignment.align', aln=self.modpt,
                              libs=self.env.libs.modpt, **vars)

    def malign(self, **vars):
        """Align two or more sequences"""
        return self.top.malign('alignment.malign', aln=self.modpt,
                               libs=self.env.libs.modpt, **vars)

    def align2d(self, io=None, **vars):
        """Align sequences with structures"""
        if io is None:
            io = self.env.io
        return self.top.align2d('alignment.align2d', aln=self.modpt,
                                io=io.modpt, libs=self.env.libs.modpt, **vars)

    def align3d(self, io=None, **vars):
        """Align two structures"""
        if io is None:
            io = self.env.io
        return self.top.align3d('alignment.align3d', aln=self.modpt,
                                io=io.modpt, libs=self.env.libs.modpt, **vars)

    def malign3d(self, io=None, **vars):
        """Align two or more structures"""
        if io is None:
            io = self.env.io
        return self.top.malign3d('alignment.malign3d', aln=self.modpt,
                                 io=io.modpt, libs=self.env.libs.modpt, **vars)

    def salign(self, io=None, **vars):
        """Align two or more proteins"""
        if io is None:
            io = self.env.io
        return self.top.salign('alignment.salign', aln=self.modpt, io=io.modpt,
                               libs=self.env.libs.modpt, **vars)

    def write(self, file, **vars):
        """Write the alignment to a file"""
        return self.top.write_alignment('alignment.write', aln=self.modpt,
                                        libs=self.env.libs.modpt, file=file,
                                        **vars)

    def id_table(self, matrix_file):
        """Calculate percentage sequence identities"""
        from modeller.id_table import id_table
        return id_table(self, matrix_file)

    def __contains__(self, code):
        """True if a sequence with alignment code |code| is present"""
        return _modeller.find_alignment_code(self.modpt, code) >= 0

    def keys(self):
        """List of the alignment codes of all sequences, in order"""
        return [seq.code for seq in self]

    def __delitem__(self, indx):
        # |indx| may be a number or an alignment code; codes are resolved
        # via _modeller.find_alignment_code.
        ret = modutil.handle_seq_indx(self, indx, _modeller.find_alignment_code,
                                      (self.modpt,))
        _modeller.alignment_del_seq(self.modpt, ret)

    def __getitem__(self, indx):
        # A single index yields a 'structure' if the sequence has one,
        # otherwise an 'alnsequence'. Any other result from handle_seq_indx
        # (e.g. for a slice) is a sequence of indices and yields a list.
        ret = modutil.handle_seq_indx(self, indx, _modeller.find_alignment_code,
                                      (self.modpt,))
        if type(ret) is int:
            if _modeller.alignment_has_structure(self.modpt, ret):
                return structure(self, ret)
            else:
                return alnsequence(self, ret)
        else:
            return [self[ind] for ind in ret]

    def __get_comments(self):
        return modlist.simple_varlist(self.modpt,
                                      _modeller.alignment_ncomment_get,
                                      _modeller.alignment_ncomment_set,
                                      _modeller.alignment_comment_get,
                                      _modeller.alignment_comment_set)
    def __set_comments(self, obj):
        modutil.set_varlist(self.comments, obj)
    def __del_comments(self):
        modutil.del_varlist(self.comments)
    def __get_positions(self):
        return alnposlist(self)

    modpt = property(__get_modpt)
    comments = property(__get_comments, __set_comments, __del_comments,
                        doc="Alignment file comments")
    positions = property(__get_positions, doc="Alignment positions")


class alnsequence(sequence.sequence):
    """One sequence of an alignment, addressed by its index in that alignment"""
    # The alignment this sequence belongs to.
    aln = None
    # Index of this sequence within 'aln'.
    _num = None
    # Environment, taken from the owning alignment.
    env = None

    def __init__(self, aln, num):
        self.aln = aln
        self.env = aln.env
        self._num = num
        sequence.sequence.__init__(self)

    def __len__(self):
        # Length is the residue count (nres), not the gapped alignment length.
        return self.nres

    def __repr__(self):
        return "Sequence %s" % repr(self.code)

    def __str__(self):
        return "<%s>" % repr(self)

    def transfer_res_prop(self):
        """Transfer residue properties of predicted secondary structure"""
        return _modeller.transfer_res_prop(self.aln.modpt, self._num)

    def get_num_equiv(self, seq):
        """Count residues of this sequence that are aligned with a residue of
           the same type in |seq|."""
        count = 0
        for res in self.residues:
            partner = res.get_aligned_residue(seq)
            if partner is None:
                # Aligned against a gap in |seq|.
                continue
            if res.type == partner.type:
                count += 1
        return count

    def get_sequence_identity(self, seq):
        """Percentage sequence identity with |seq|: identical aligned
           residues divided by the length of the shorter of the two."""
        nequiv = self.get_num_equiv(seq)
        return 100.0 * nequiv / min(len(self), len(seq))

    def __get_naln(self):
        return _modeller.alignment_naln_get(self.aln.modpt)

    def __get_code(self):
        return _modeller.alignment_codes_get(self.aln.modpt, self._num)

    def __set_code(self, val):
        _modeller.alignment_codes_set(self.aln.modpt, self._num, val)

    def __get_atom_file(self):
        return _modeller.alignment_atom_files_get(self.aln.modpt, self._num)

    def __set_atom_file(self, val):
        _modeller.alignment_atom_files_set(self.aln.modpt, self._num, val)

    def __get_residues(self):
        return residuelist(self)

    def __get_seqpt(self):
        return _modeller.alignment_sequence_get(self.aln.modpt, self._num)

    code = property(__get_code, __set_code, doc="Alignment code")
    atom_file = property(__get_atom_file, __set_atom_file, doc="PDB file name")
    naln = property(__get_naln, doc="Length of alignment (including gaps)")
    residues = property(__get_residues, doc="List of residues")
    seqpt = property(__get_seqpt)


class structure(alnsequence, coordinates):
    """An alignment sequence with an associated structure. Its template
       coordinates are read on the first access to 'cdpt' through this
       object."""
    # Index of this sequence in the alignment (used when reading coordinates).
    __num = None
    # Set once read_template_structure has been called for this object.
    __read_coord = False

    def __init__(self, aln, num):
        self.__num = num
        alnsequence.__init__(self, aln, num)
        coordinates.__init__(self)

    def __str__(self):
        return "<Structure %s>" % repr(self.code)

    def __get_cdpt(self):
        aln = self.aln
        if not self.__read_coord:
            # Load the coordinates lazily, at most once per object.
            _modeller.read_template_structure(aln.modpt, self.__num,
                                              aln.env.io.modpt,
                                              aln.env.libs.modpt)
            self.__read_coord = True
        struc = _modeller.alignment_structure_get(aln.modpt, self.__num)
        return _modeller.structure_cd_get(struc)

    cdpt = property(__get_cdpt)


class residuelist(object):
    """Indexable list of the residues of an aligned sequence.

       Indices are shifted by |offset|. If |length| is given it overrides the
       reported length; otherwise len(seq) is used."""

    def __init__(self, seq, offset=0, length=None):
        self.seq = seq
        self.offset = offset
        self.length = length

    def __len__(self):
        if self.length is None:
            return len(self.seq)
        return self.length

    def __getitem__(self, indx):
        ret = modutil.handle_seq_indx(self, indx)
        if type(ret) is not int:
            # Not a single index (e.g. a slice): build one residue per index.
            return [self[i] for i in ret]
        return alignment_residue(self.seq, self.offset + ret)


class alignment_residue(sequence.sequence_residue):
    """A single residue in an aligned sequence"""

    def get_position(self):
        """Get the position in the alignment of this residue"""
        # invaln is indexed by (residue index, sequence index); the stored
        # value is used 1-based, hence the -1 when indexing positions.
        invaln = _modeller.alignment_invaln_get(self.mdl.aln.modpt)
        num = _modeller.f_int2_get(invaln, self._num, self.mdl._num)
        return self.mdl.aln.positions[num - 1]

    def get_aligned_residue(self, seq):
        """Get the residue in |seq| that is aligned with this one, or None"""
        alnpos = self.get_position()
        return alnpos.get_residue(seq)

    def get_leading_gaps(self):
        """Get the number of gaps in the alignment immediately preceding this
           residue."""
        mypos = self.get_position().num
        # Check the index explicitly rather than relying on IndexError: for
        # the first residue, residues[-1] may wrap around to the last residue
        # and produce a wrong (negative) gap count.
        if self._num > 0:
            prepos = self.mdl.residues[self._num - 1].get_position().num
        else:
            # Treat the first residue as preceded by a virtual position -1,
            # so every position before it counts as a gap.
            prepos = -1
        return mypos - prepos - 1

    def get_trailing_gaps(self):
        """Get the number of gaps in the alignment immediately following this
           residue."""
        mypos = self.get_position().num
        try:
            postpos = self.mdl.residues[self._num + 1].get_position()
            postpos = postpos.num
        except IndexError:
            # Last residue: count every position up to the alignment length.
            postpos = self.mdl.naln
        return postpos - mypos - 1


class alnposition(object):
    """An alignment position"""

    def __init__(self, aln, indx):
        self.__aln = aln
        self.__indx = indx

    def get_residue(self, seq):
        """Get the residue in |seq| that is at this alignment position, or None
           if a gap is present.

           Raises TypeError if |seq| is not an alnsequence, and ValueError if
           it is not a sequence of this position's alignment."""
        aln = self.__aln
        # Call-form raise works in both Python 2 and Python 3.
        if not isinstance(seq, alnsequence):
            raise TypeError("Expected an 'alnsequence' object for seq")
        if seq.aln != aln:
            raise ValueError("seq must be a sequence in the same alignment")
        # ialn is indexed by (position, sequence). A value of 0 marks a gap;
        # otherwise it is a 1-based residue index.
        ialn = _modeller.alignment_ialn_get(aln.modpt)
        ires = _modeller.f_int2_get(ialn, self.__indx, seq._num)
        if ires == 0:
            return None
        else:
            return seq.residues[ires - 1]

    def __get_num(self):
        return self.__indx

    def __get_prof(self, typ):
        # Column |typ| of the alignment profile at this position:
        # 0=helix, 1=strand, 2=buried, 3=straight (see the properties below).
        prof = _modeller.alignment_prof_get(self.__aln.modpt)
        return _modeller.f_float2_get(prof, self.__indx, typ)

    def __get_helix(self):
        return self.__get_prof(0)

    def __get_strand(self):
        return self.__get_prof(1)

    def __get_buried(self):
        return self.__get_prof(2)

    def __get_straight(self):
        return self.__get_prof(3)

    num = property(__get_num, doc="Index of this position in the alignment")
    helix = property(__get_helix, doc="Helix secondary structure")
    strand = property(__get_strand, doc="Strand secondary structure")
    buried = property(__get_buried, doc="Buriedness")
    straight = property(__get_straight, doc="Straightness")


class alnposlist(modlist.fixlist):
    """List-like view of the positions of an alignment, giving one
       alnposition object per position."""

    def __init__(self, aln):
        # The alignment is stored before the base class is initialized.
        self.__aln = aln
        modlist.fixlist.__init__(self)

    def __len__(self):
        # Alignment length, i.e. the number of positions (including gaps).
        modpt = self.__aln.modpt
        return _modeller.alignment_naln_get(modpt)

    def _getfunc(self, indx):
        # Wrap position |indx| of the stored alignment.
        return alnposition(self.__aln, indx)
