# This file is part of ModPipe, Copyright 1997-2010 Andrej Sali
#
# ModPipe is free software: you can redistribute it and/or modify
# it under the terms of version 2 of the GNU General Public License
# as published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with ModPipe.  If not, see <http://www.gnu.org/licenses/>.

"""Handling of amino acid sequences and common file formats"""

import modpipe
import re
import os
# Use hashlib module if available (Python 2.5 or later) since md5 is deprecated
try:
    import hashlib
except ImportError:
    import md5 as hashlib

class Sequence(object):
    """Representation of a single amino acid sequence.

    The annotation attributes mirror the colon-separated fields of a PIR
    header (see PIRFile) and default to empty strings when unknown:
    prottyp, atom_file, range, name, source, resolution, rfactor.
    `range` holds the four range fields of the header as
    [[field3, field4], [field5, field6]] (presumably start residue/chain and
    end residue/chain - confirm against the PIR format spec).
    `code` is the entry identifier and `primary` the one-letter residue
    string; both are normally filled in by the file readers.
    """

    # Removed by clean(): non-word characters (this already covers
    # whitespace and '*'), underscores (which \W does not match) and digits.
    _junk_re = re.compile(r'[\W_\d]+')
    # Anything that is not one of the 20 standard residue letters
    _nonstandard_re = re.compile('[^ACDEFGHIKLMNPQRSTVWY]')

    def __init__(self):
        self.prottyp = 'sequence'
        self.range = [['', ''], ['', '']]
        self.atom_file = self.name = self.source = ''
        self.resolution = self.rfactor = ''
        # Normally set by the readers; default to empty so that a freshly
        # built object can be hashed or written without AttributeError.
        self.code = ''
        self.primary = ''

    def get_id(self):
        """Return the ModPipe sequence identifier.

        The identifier is the hex MD5 digest of the primary sequence,
        followed by its first four and its last four residues."""
        primary = self.primary
        m = hashlib.md5()
        try:
            m.update(primary)
        except TypeError:
            # Python 3 text strings must be encoded before hashing; on
            # Python 2 byte strings are accepted directly.
            m.update(primary.encode('utf-8'))
        return m.hexdigest() + primary[:4] + primary[-4:]

    def clean(self):
        """Clean up the primary sequence in place.

        Non-letter characters (punctuation, whitespace, '*', underscores,
        digits) are removed and the rest is uppercased. ASX (B) becomes
        ASN (N), GLX (Z) becomes GLN (Q), and every other non-standard
        residue letter becomes GLY (G)."""
        s = self._junk_re.sub('', self.primary)
        s = s.upper()
        s = s.replace('B', 'N')     # convert ASX to ASN
        s = s.replace('Z', 'Q')     # convert GLX to GLN
        s = self._nonstandard_re.sub('G', s)
        self.primary = s


class FASTAFile(object):
    """Representation of a FASTA-format file"""

    def read(self, fh):
        """Read sequences from the given stream in FASTA format.

           This is a generator: each sequence is yielded as a
           :class:`Sequence` object once it has been fully read. The text
           after '>' on a header line (trailing whitespace stripped) becomes
           the sequence's `code`. All following non-blank lines up to the
           next header are concatenated to form its `primary` sequence.
           Blank lines are ignored anywhere in the stream.

           :raises modpipe.FileFormatError: if sequence data appears before
                   the first header line."""
        seq = None
        for (num, line) in enumerate(fh):
            if line.startswith('>'):
                if seq is not None:
                    yield seq
                seq = Sequence()
                seq.primary = ''
                seq.code = line.rstrip()[1:]
                continue
            line = line.rstrip()
            if not line:
                # Skip blank lines; without this guard, a blank line before
                # the first header would dereference seq while it is None.
                continue
            if seq is None:
                raise modpipe.FileFormatError(
                    "Found FASTA sequence before first header at line %d: %s"
                    % (num + 1, line))
            seq.primary += line
        if seq is not None:
            yield seq

    def write(self, fh, seq, width=70):
        """Write a single :class:`Sequence` object to the given stream in
           FASTA format.

           The header line is '>' followed by `seq.code`. The primary
           sequence is then written wrapped at `width` characters per
           line."""
        # fh.write (rather than the print statement) works on Python 2 and 3
        fh.write(">" + seq.code + "\n")
        primary = seq.primary
        for pos in range(0, len(primary), width):
            fh.write(primary[pos:pos + width] + "\n")


class PIRFile(object):
    """Representation of a PIR-format file"""

    def _parse_pir_header(self, num, line, seq):
        """Parse the colon-separated PIR header `line` into `seq`.

           The header must have exactly 10 colon-separated fields. They are
           stored as prottyp, atom_file, the four range fields, name,
           source, resolution and rfactor. An empty prottyp defaults to
           'sequence'. `seq.primary` is reset to an empty string so that
           sequence lines can be appended to it. `num` is the 0-based line
           index, used only for error messages.

           :raises modpipe.FileFormatError: if the field count is not 10."""
        seq.primary = ''
        spl = line.rstrip().split(':')
        if len(spl) != 10:
            raise modpipe.FileFormatError(
"Invalid PIR header at line %d (expecting 10 fields split by colons): %s"
                % (num + 1, line))
        (seq.prottyp, seq.atom_file, seq.range[0][0], seq.range[0][1],
         seq.range[1][0], seq.range[1][1], seq.name, seq.source,
         seq.resolution, seq.rfactor) = spl
        if seq.prottyp == '':
            seq.prottyp = 'sequence'

    def read(self, fh):
        """Read sequences from the given stream in PIR format.

           This is a generator yielding one :class:`Sequence` per entry.
           Each entry is a '>P1;code' line, then a colon-separated header
           line, then sequence lines. The entry ends at the first line whose
           last non-blank character is '*'. Lines starting with 'C;' or 'R;'
           are skipped.

           :raises modpipe.FileFormatError: on a malformed header, on
                   sequence data with no preceding '>P1;' line, or on an
                   entry that is not terminated by '*'."""
        seq = None
        # The '*' terminator, plus any trailing whitespace, at end of line
        terminator = re.compile(r'\*\s*$')
        for (num, line) in enumerate(fh):
            if line.startswith('C;') or line.startswith('R;'):
                # Skip comment lines
                continue
            elif line.startswith('>P1;'):
                if seq is not None:
                    raise modpipe.FileFormatError(
                "PIR sequence without terminating * at line %d: %s"
                        % (num + 1, line))
                seq = Sequence()
                # None marks "header line not read yet"
                seq.primary = None
                seq.code = line.rstrip()[4:]
            elif seq is not None and seq.primary is None:
                # The line directly after '>P1;' is the header
                self._parse_pir_header(num, line, seq)
            else:
                line = line.rstrip()
                if not line:
                    continue
                if seq is None:
                    raise modpipe.FileFormatError(
                "PIR sequence found without a preceding header at line %d: %s"
                        % (num + 1, line))
                (line, count) = terminator.subn("", line)
                seq.primary += line
                # A stripped terminator means this was the last line
                if count == 1:
                    yield seq
                    seq = None
        if seq is not None:
            raise modpipe.FileFormatError(
                     "PIR sequence without terminating * at end of file")

    def write(self, fh, seq, width=70):
        """Write a single :class:`Sequence` object to the given stream in
           PIR format.

           The output is a '>P1;' line, a 10-field colon-separated header,
           the primary sequence wrapped at `width` characters per line, and
           a final '*' line."""
        start, end = seq.range
        fields = [seq.prottyp, seq.atom_file, start[0], start[1], end[0],
                  end[1], seq.name, seq.source, seq.resolution, seq.rfactor]
        # fh.write (rather than the print statement) works on Python 2 and 3
        fh.write(">P1;" + seq.code + "\n")
        fh.write(":".join(str(x) for x in fields) + "\n")
        primary = seq.primary
        for pos in range(0, len(primary), width):
            fh.write(primary[pos:pos + width] + "\n")
        fh.write("*\n")


class SPTRFile(object):
    """Representation of a file containing UniProtKB/SwissProt or TrEMBL
       database entries"""

    def read(self, fh):
        """Read sequences from the given stream in SPTR format.

           This is a generator yielding one :class:`Sequence` per entry.
           The entry's `code` is the first accession on its first AC line.
           Its `primary` sequence is built from the lines after the
           'SQ   SEQUENCE' line, with all spaces removed. An entry ends at a
           '//' line. Consecutive AC lines (used when an entry has many
           accessions) belong to the same entry.

           :raises modpipe.FileFormatError: if a new entry's AC record is
                   seen before the previous entry was terminated."""
        ac_re = re.compile(r'AC   (\w+);')
        seq = None
        # True while reading the (possibly multi-line) AC block of seq
        in_ac_block = False
        for (num, line) in enumerate(fh):
            m = ac_re.match(line)
            if m:
                if in_ac_block:
                    # Continuation AC line of the current entry; only the
                    # first accession of the first line is kept as code.
                    continue
                if seq is not None:
                    raise modpipe.FileFormatError(
"SPTR file contains AC record before end of previous sequence at line %d: %s"
                        % (num + 1, line))
                seq = Sequence()
                seq.code = m.group(1)
                # None marks "SQ line not seen yet"
                seq.primary = None
                in_ac_block = True
                continue
            in_ac_block = False
            if line.startswith('SQ   SEQUENCE') and seq is not None:
                seq.primary = ''
            elif line.startswith('//') and seq is not None:
                # NOTE(review): an entry without an SQ section is yielded
                # here with primary None, whereas at end of file such an
                # entry is dropped - confirm which behaviour is intended.
                yield seq
                seq = None
            elif seq is not None and seq.primary is not None:
                seq.primary += line.rstrip().replace(' ', '')
        # Tolerate a final entry that lacks the '//' terminator
        if seq is not None and seq.primary is not None:
            yield seq


class UniqueFile(object):
    """Mapping file from alignment codes to ModPipe IDs"""

    def __init__(self):
        # ModPipe ID -> list of alignment codes, in the order they were added
        self._unqseq = {}

    def add_sequence(self, modpipe_id, align_code):
        """Add a single mapping from an align code to a ModPipe ID"""
        # setdefault replaces dict.has_key, which Python 3 removed
        self._unqseq.setdefault(modpipe_id, []).append(align_code)

    def file_name_from_seqfile(self, seqfile):
        """Given an input sequence file, return a suitable name for the
           unique file: the file's base name with its extension replaced
           by '.unq' (any directory part is dropped)."""
        return os.path.splitext(os.path.basename(seqfile))[0] + '.unq'

    def write(self, fh):
        """Write the code-ID mapping to a stream.

           Each line has the form '<modpipe_id> : <code> <code> ...'.
           Lines are sorted by ModPipe ID so the output is reproducible
           (plain dict iteration order is arbitrary on Python 2)."""
        for key in sorted(self._unqseq):
            fh.write("%s : %s\n" % (key, " ".join(self._unqseq[key])))
