import unittest
import os
import modpipe.test
from modpipe.scripts import ClusterPDB

class ClusterPDBTests(modpipe.test.TestCase):
    """Tests for the ClusterPDB script."""

    def test_seqfile_exists(self):
        """Non-existing seqlist file should make ClusterPDB raise IOError"""
        self.assertRaises(IOError, self.run_script, 'scripts', ClusterPDB,
                          ['/does/not/exist'])

    def test_required_arguments(self):
        """Check required ClusterPDB arguments"""
        # Running with an empty argument list is expected to exit
        # (SystemExit) instead of performing any clustering.
        self.assertRaises(SystemExit, self.run_script, 'scripts',
                          ClusterPDB, [])

    def test_cluster_pdb(self):
        """Check PDB clustering"""
        # Representative chain of every expected cluster.
        codes = ['3agzA', '1opaA', '1u2kA', '1crbA', '1g7nA', '1llpA', '1oafA',
                 '1jd0A', '1ecsA', '1ftpA', '1kopA', '1apxA', '1mdcA', '1f9zA',
                 '1fljA', '1iynA', '1kqwA', '1lpjA', '1kllA', '1itkA', '1o8vA',
                 '1gglA', '1hmrA', '1hcbA', '1cbiA', '1n8yC', '1bwyA', '1sj2A',
                 '1ub2A', '1fdqA', '1keqA', '1b56A', '1rj5A', '1zncA', '1zbyA',
                 '1zncB', '1j4wA', '1vyfA', '1cbsA', '1h0zA', '1pmpA', '1v9eA',
                 '1qipA', '1touA', '1mwvA', '1lugA', '3dh4A']
        # groups[i] lists the additional members of the cluster represented
        # by codes[i]; the two lists must stay index-aligned.
        groups = [['3agzB'], ['1opaB'], [], [], [], [], [], ['1jd0B'],
                  ['1ecsB'], ['1ftpB'],
                  ['1kopB'], ['1apxB', '1apxC', '1apxD'], [], ['1f9zB'], [],
                  [], [],
                  [], [], ['1itkB'], [], ['1gglB'], [], [], ['1cbiB'], [], [],
                  ['1sj2B'], [], ['1fdqB'], ['1keqB'], [], ['1rj5B'], [], [],
                  [], [], [], [], [], ['1pmpB', '1pmpC'], ['1v9eB'],
                  ['1qipB', '1qipC', '1qipD'], [],
                  ['1mwvB'], [], ['3dh4B', '3dh4C', '3dh4D']]
        # zip() below would silently drop entries if the lists drifted apart.
        self.assertEqual(len(codes), len(groups))
        self.maxDiff = None

        # Files produced by the run, relative to the current directory: the
        # pdb_95.* outputs (named via '-f pdb_95') plus intermediate files.
        generated = ['pdb_95.cod', 'pdb_95.grp', 'pdb_95.pir',
                     'test-pdb.fsa', 'test-pdb95.clstr',
                     'test-pdb95.bak.clstr', 'test-pdb95']
        try:
            self.run_script('scripts', ClusterPDB,
                            ['-f', 'pdb_95', '-t', '95',
                             '../db/test-pdb.pir'])

            with open('pdb_95.cod') as fh:
                # Clusters are dict keys, so we need to sort them to compare
                cod_codes = sorted(line.rstrip('\r\n') for line in fh)
            self.assertEqual(cod_codes, sorted(codes))

            # Each .grp line is "<rep> : <rep> <member> ..."; trailing
            # whitespace is ignored and line order is not significant.
            expected_groups = sorted(
                "%s : %s" % (code, " ".join([code] + group))
                for code, group in zip(codes, groups))
            with open('pdb_95.grp') as fh:
                grp_lines = sorted(line.rstrip('\r\n ') for line in fh)
            self.assertEqual(grp_lines, expected_groups)

            # The .pir output must hold exactly one '>P1;' header per
            # representative code.
            with open('pdb_95.pir') as fh:
                pir_codes = sorted(line[4:].rstrip('\r\n') for line in fh
                                   if line.startswith('>P1;'))
            self.assertEqual(sorted(codes), pir_codes)
        finally:
            # Remove outputs even when an assertion or the run itself fails,
            # so stale files never leak into later tests. Files the run did
            # not create are skipped.
            for fname in generated:
                if os.path.exists(fname):
                    os.unlink(fname)

if __name__ == '__main__':
    # Running this file directly executes every test case it defines.
    unittest.main()
