import unittest
import os
import shutil
import modpipe.test
import modpipe.serialize
from modpipe.main import GatherModMP

# Identifier of the single test sequence. It is written to the .unq file and
# names the sequence's directory and files under DATDIR.
seq_md5 = 'd18b74a08d8b83e15835daaf76471976FCGHGNVV'

class GatherTests(modpipe.test.TestCase):
    """Tests of GatherModMP's selection of final models.

    Each test writes a models file for the single sequence seq_md5 into a
    per-sequence directory tree under the current directory, runs
    GatherModMP with a given --final_models_by criterion, and checks the
    model IDs written to the final-models output file.
    """

    # Directory tree built under DATDIR (the current directory) for the test
    # sequence; 'd18' matches the first three characters of seq_md5.
    _top_dir = 'd18'
    _md5_dir = '%s/%s' % (_top_dir, seq_md5)
    _seq_dir = '%s/sequence' % _md5_dir

    def _make_model(self, model_md5, start=5, end=120, ga341=0.8,
                    zdope=-0.5, seqid=80.0, mpqs=1.0, tsvmod=0.5):
        """Make a single example Model object.

        model_md5 is the model ID; start/end give the modeled region;
        ga341 is the total GA341 score, zdope the normalized DOPE score,
        seqid the highest template sequence identity and mpqs the quality
        score. tsvmod is not used by this method.
        """
        ga341_score = modpipe.serialize.GA341(
                  total=ga341, compactness=0, distance=0, surface_area=0,
                  combined=0, z_distance=0, z_surface_area=0, z_combined=0)
        score = modpipe.serialize.Score(
                 objfunc=0, dope=0, dope_hr=0, normalized_dope=zdope,
                 quality=mpqs, ga341=ga341_score)
        seq = modpipe.serialize.Sequence(id=seq_md5, length=120)
        aln = modpipe.serialize.Alignment(id='1d5263d70500c21bab0dcf7d9d0480e9',
                                          evalue=0, gap_percentage=1,
                                          score_chi_squared=0.0578,
                                          score_ks=0.0120)
        return modpipe.serialize.Model(
                 sequence=seq, alignment=aln,
                 region=[start, end],
                 fold_assignment_method='1000',
                 id=model_md5, hetatms=1, waters=4, score=score,
                 highest_sequence_identity=seqid, rating='111111101',
                 templates=[{'pdb_code': '1kllA', 'region': [6, 123],
                             'sequence_identity': 100}])

    def assert_final_models(self, finfile, expected_models):
        """Read all model MD5s from the file, and check they match expected.

        The file is closed explicitly (not left to garbage collection) so it
        can be removed immediately afterwards on any platform.
        """
        fh = open(finfile, 'r')
        try:
            # Build the ID list before closing, in case the reader is lazy.
            model_ids = [m.id for m in modpipe.serialize.read_models_file(fh)]
        finally:
            fh.close()
        self.assertEqual(model_ids, expected_models)

    def _write_text_file(self, path, text):
        """Write text to path and close the file right away.

        Closing explicitly guarantees the contents are on disk before
        GatherModMP reads them.
        """
        fh = open(path, 'w')
        try:
            fh.write(text)
        finally:
            fh.close()

    def _setup_models_file(self):
        """Create files needed to run GatherModMP.

        Writes the config and .unq files, recreates the sequence directory
        tree from scratch, and returns a write-mode handle to the sequence's
        .mod file; the caller must close it.
        """
        self.conf = 'modpipe.conf'
        self._write_text_file(self.conf, "DATDIR  %s\n" % os.getcwd())
        self.unq = 'testseq.unq'
        self._write_text_file(self.unq, "%s :  testseq\n" % seq_md5)
        self.mod = '%s/%s.mod' % (self._seq_dir, seq_md5)
        if os.path.exists(self._top_dir):
            shutil.rmtree(self._top_dir)
        os.makedirs(self._seq_dir)
        return open(self.mod, 'w')

    def _write_models_file(self, models):
        """Set up GatherModMP inputs with the given models in the .mod file."""
        fh = self._setup_models_file()
        try:
            modpipe.serialize.write_models_file(models, fh)
        finally:
            fh.close()

    def _remove_quietly(self, paths):
        """Best-effort removal of paths and the sequence tree; errors ignored."""
        for path in paths:
            try:
                os.unlink(path)
            except OSError:
                pass
        shutil.rmtree(self._top_dir, ignore_errors=True)

    def _check_selected_models(self, expected_models, final_models_by,
                               by_region='OFF'):
        """Run GatherModMP and check the final models it selects.

        expected_models is the ordered list of model IDs expected in the
        final-models output. final_models_by and by_region are passed as
        --final_models_by and --select_by_region. All inputs and outputs
        are removed afterwards, including when the check fails.
        """
        modfile = 'test-mod.out'
        hitfile = 'test-hit.out'
        selfile = 'test-sel.out'
        finfile = 'test-fin.out'
        created = [self.conf, self.unq, self.mod, modfile, hitfile,
                   selfile, finfile,
                   '%s/%s.fin' % (self._seq_dir, seq_md5),
                   '%s/%s.gather.log' % (self._seq_dir, seq_md5)]
        passed = False
        try:
            self.run_script('main', GatherModMP,
                            ['--conf_file', self.conf, '--unq_file', self.unq,
                             '--output_modfile', modfile,
                             '--output_hitfile', hitfile,
                             '--output_selfile', selfile,
                             '--output_finfile', finfile,
                             '--select_by_region', by_region,
                             '--final_models_by', final_models_by])
            self.assert_final_models(finfile, expected_models)
            passed = True
        finally:
            if not passed:
                # Don't leave stale files behind after a failure; the
                # original error still propagates.
                self._remove_quietly(created)
        # Strict cleanup: unlink/rmdir raise if GatherModMP did not create
        # exactly these files in the sequence directory.
        for f in created:
            os.unlink(f)
        os.rmdir(self._seq_dir)
        os.rmdir(self._md5_dir)
        os.rmdir(self._top_dir)

    def test_gather_mpqs(self):
        """Check gather of models by MPQS score"""
        # Should only select models with MPQS >= 1.0:
        self._write_models_file([self._make_model('MODEL1', mpqs=0.9)])
        self._check_selected_models([], 'MPQS')
        # Should pick model sorted first by MPQS, second by sequence length:
        self._write_models_file([self._make_model('MODEL1', mpqs=10),
                                 self._make_model('MODEL2', mpqs=30),
                                 self._make_model('MODEL3', mpqs=30, start=10),
                                 self._make_model('MODEL4', mpqs=20, start=1)])
        self._check_selected_models(['MODEL2'], 'MPQS')

    def test_gather_dope(self):
        """Check gather of models by DOPE score"""
        # Should only select models with normalized DOPE < 0:
        self._write_models_file([self._make_model('MODEL1', zdope=0)])
        self._check_selected_models([], 'DOPE')
        # Should pick model sorted first by DOPE, second by sequence length:
        self._write_models_file([
                  self._make_model('MODEL1', zdope=-0.2),
                  self._make_model('MODEL2', zdope=-0.4),
                  self._make_model('MODEL3', zdope=-0.4, start=10),
                  self._make_model('MODEL4', zdope=-0.3, start=1)])
        self._check_selected_models(['MODEL2'], 'DOPE')

    def test_gather_ga341(self):
        """Check gather of models by GA341 score"""
        # Should only select models with GA341 >= 0.7:
        self._write_models_file([self._make_model('MODEL1', ga341=0.65)])
        self._check_selected_models([], 'GA341')
        # Should pick model sorted first by GA341, second by sequence length:
        self._write_models_file([
                  self._make_model('MODEL1', ga341=0.7),
                  self._make_model('MODEL2', ga341=0.9),
                  self._make_model('MODEL3', ga341=0.9, start=10),
                  self._make_model('MODEL4', ga341=0.8, start=1)])
        self._check_selected_models(['MODEL2'], 'GA341')

    def test_gather_all(self):
        """Check gather of models by ALL method"""
        # Should return all models regardless of score:
        self._write_models_file([self._make_model('MODEL1', ga341=0.1),
                                 self._make_model('MODEL2', zdope=10.0),
                                 self._make_model('MODEL3', mpqs=-10.0),
                                 self._make_model('MODEL4', seqid=5.0)])
        self._check_selected_models(['MODEL1', 'MODEL2', 'MODEL3', 'MODEL4'],
                                    'ALL')

    def test_longest_dope(self):
        """Check gather of models by LONGEST_DOPE method"""
        # Should only select models with z-dope < 0:
        self._write_models_file([self._make_model('MODEL1', zdope=0)])
        self._check_selected_models([], 'LONGEST_DOPE')
        # Should pick any model with z-dope < 0 sorted by sequence length:
        self._write_models_file([
                  self._make_model('MODEL1', start=1, zdope=-0.1),
                  self._make_model('MODEL2', start=5, zdope=-0.2),
                  self._make_model('MODEL3', start=10, zdope=-0.3)])
        self._check_selected_models(['MODEL1'], 'LONGEST_DOPE')

    def test_longest_ga341(self):
        """Check gather of models by LONGEST_GA341 method"""
        # Should only select models with GA341 >= 0.7:
        self._write_models_file([self._make_model('MODEL1', ga341=0.65)])
        self._check_selected_models([], 'LONGEST_GA341')
        # Should pick any model with GA341 >= 0.7 sorted by sequence length:
        self._write_models_file([
                  self._make_model('MODEL1', start=1, ga341=0.7),
                  self._make_model('MODEL2', start=5, ga341=1.0),
                  self._make_model('MODEL3', start=10, ga341=1.0)])
        self._check_selected_models(['MODEL1'], 'LONGEST_GA341')

    def test_gather_seqid(self):
        """Check gather of models by SEQID score"""
        # Should pick model sorted first by seqid, second by sequence length:
        self._write_models_file([
                  self._make_model('MODEL1', seqid=50),
                  self._make_model('MODEL2', seqid=80),
                  self._make_model('MODEL3', seqid=80, start=10),
                  self._make_model('MODEL4', seqid=20, start=1)])
        self._check_selected_models(['MODEL2'], 'SEQID')

    def test_gather_multiple(self):
        """Check gather of models by multiple methods"""
        self._write_models_file([
                  self._make_model('BESTDOPE', zdope=-1.0, ga341=0.1),
                  self._make_model('BESTGA341', zdope=-0.1, ga341=1.0),
                  self._make_model('MODEL3', zdope=-0.5, ga341=0.5)])
        self._check_selected_models(['BESTDOPE', 'BESTGA341'], 'DOPE,GA341')
        # Should only return one model, not two copies of BESTBOTH:
        self._write_models_file([
                  self._make_model('BESTBOTH', zdope=-1.0, ga341=1.0),
                  self._make_model('MODEL2', zdope=-0.5, ga341=0.5)])
        self._check_selected_models(['BESTBOTH'], 'DOPE,GA341')

    def test_gather_by_region(self):
        """Check gather of models by region"""
        self._write_models_file([
                  self._make_model('MODEL1', zdope=-0.8, start=1, end=60),
                  self._make_model('MODEL2', zdope=-1.0, start=1, end=60),
                  self._make_model('MODEL3', zdope=-0.2, start=61, end=120),
                  self._make_model('MODEL4', zdope=-0.4, start=61, end=120)])
        # Should pick best model for each region, not best DOPE overall:
        self._check_selected_models(['MODEL2', 'MODEL4'],
                                    'DOPE', by_region='ON')
        # MODEL1 always covers 101-400. In every case below the non-overlap
        # is measured on the shorter region (MODEL2's):
        for start, end, separate in (
             (87, 136, False),   # Regions are separate if the non-overlap
             (86, 135, True),    # region is 30% of the shorter region's
                                 # length or more (30% of 50 = 15).
             (72, 271, False),   # Regions are also separate if the non-overlap
             (71, 270, True),    # region is 30 residues or more.
             (150, 152, False),  # A region nested inside the longer one is
                                 # not separate, even though the longer
                                 # region's own non-overlap is large.
             ):
            self._write_models_file([
                      self._make_model('MODEL1', zdope=-1, start=101, end=400),
                      self._make_model('MODEL2', zdope=-2,
                                       start=start, end=end)])
            if separate:
                expected_models = ['MODEL1', 'MODEL2']
            else:
                expected_models = ['MODEL2']
            self._check_selected_models(expected_models, 'DOPE', by_region='ON')

# Allow running this test module directly as a script.
if __name__ == '__main__':
    unittest.main()
