#!/usr/bin/perl

###########################################################
#   mutual-info-2D                                        #
#   written by Rachel Karchin and Kevin Karplus  2004     #
#   rachelk@salilab.org                                   #
###########################################################

use lib "/usr/local/site-perl";  #location of Algorithm, Statistics, and 
                                 #RDBtable perl modules on your system
use RDBtable;
use Algorithm::Permute;
use List::Util qw(shuffle);
use Statistics::Descriptive;

# Main program: report the mutual information (MI) between two columns of
# an RDB table, then estimate the MI expected by chance by recomputing it
# after randomly shuffling the second column $numpermutes times
# (a permutation test).  Wrapped in a bare block so its lexicals are not
# visible to the subroutines defined below (several of them use `local`
# on package variables such as $scramble and $i).
{
    process_command_line();

    # Guard here as well so a bad count can never reach the statistics
    # below (with zero trials the mean would be undef).
    die "Error: -numpermutes must be a positive integer\n"
        unless defined $numpermutes && $numpermutes =~ /\A[1-9]\d*\z/;

    my $pairs = RDBtable->new;
    $pairs->init_from_file($inputfile);

    # Each column comes back with its name and a header line in front of
    # the data; drop both.
    my @letters1 = $pairs->get_column($colname1);
    splice @letters1, 0, 2;
    my @letters2 = $pairs->get_column($colname2);
    splice @letters2, 0, 2;

    # compute_mi reads its inputs one character per position, so each
    # column entry is presumably a single character -- TODO confirm.
    my $let1string  = join '', @letters1;
    my $observed_mi = compute_mi($let1string, join('', @letters2), 0);

    # Null distribution: an independent uniform random shuffle of column 2
    # for every trial.  Algorithm::Permute->next enumerates all
    # permutations in a fixed deterministic order, so its first N results
    # are not a random sample and understate the chance MI.
    my @all_random_mi;
    for my $trial (1 .. $numpermutes) {
        my $shuffled = join '', shuffle(@letters2);
        push @all_random_mi, compute_mi($let1string, $shuffled, 1);
    }

    my $stat = Statistics::Descriptive::Full->new();
    $stat->add_data(@all_random_mi);
    my $mean = $stat->mean();
    my $sd   = $stat->standard_deviation();    # printed as the "+-" spread

    printf "Observed MI = %5.3g\n", $observed_mi;
    printf "Expected MI = %5.3g +- %5.3g\n", $mean, $sd;
    printf "Excess MI = %5.3g\n", $observed_mi - $mean;
}

sub compute_mi {
    # Compute the mutual information, in bits, between two equal-length
    # letter strings read position by position: $row_seq supplies the row
    # letter and $col_seq the column letter of each pair.
    #
    # Args:    $row_seq, $col_seq  letter strings of equal length
    #          $scramble           false => also print the raw-count,
    #                              probability and log-odds tables and the
    #                              entropy summary; true => compute silently
    # Returns: the MI estimate in bits (raw counts weighted by log-odds that
    #          use pseudocounted joint probabilities).
    # Dies:    if the strings differ in length or yield no letter pairs.
    #
    # Shared package state: %row_subscript/%col_subscript and
    # $num_rows/$num_cols persist across calls; @row_names/@col_names and
    # the dynamically scoped @row_sum/@col_sum are read by print_table.
    # $shift, %row_alias and %col_alias are optional globals (unset here).
    my ($row_seq, $col_seq, $scramble) = @_;
    local @row_sum;    # must stay `local`: print_table reads them
    local @col_sum;
    my @count;
    my $total_sum = 0;
    my $log2 = log(2);

    if (length($row_seq) != length($col_seq)) {
        my $row_label = defined $row_type ? $row_type : $colname1;
        my $col_label = defined $col_type ? $col_type : $colname2;
        die "Error: $row_label has length " . length($row_seq)
            . " but $col_label has length " . length($col_seq) . "\n";
    }

    # Tally letter pairs, assigning each new letter the next subscript.
    for (my $i = length($row_seq) - 1; $i >= 0; $i--) {
        next if $i + $shift >= length($col_seq);
        last if $i + $shift < 0;
        my $r = substr($row_seq, $i, 1);
        my $c = substr($col_seq, $i + $shift, 1);

        $r = $row_alias{$r} if defined $row_alias{$r};
        $c = $col_alias{$c} if defined $col_alias{$c};

        my $rsub = $row_subscript{$r};
        if (!defined $rsub) {
            $rsub = $num_rows++;
            $row_subscript{$r} = $rsub;
        }
        my $csub = $col_subscript{$c};
        if (!defined $csub) {
            $csub = $num_cols++;
            $col_subscript{$c} = $csub;
        }
        $count[$rsub][$csub]++;
    }

    # Display order for row and column letters (globals for print_table).
    @row_names = sort keys %row_subscript;
    @col_names = sort keys %col_subscript;
    my $nrows = @row_names;
    my $ncols = @col_names;

    # Fill empty cells with 0 and accumulate row, column and grand totals.
    for my $r (0 .. $nrows - 1) {
        for my $c (0 .. $ncols - 1) {
            $count[$r][$c] = 0 unless defined $count[$r][$c];
            my $cnt = $count[$r][$c];
            $row_sum[$r] += $cnt;
            $col_sum[$c] += $cnt;
            $total_sum   += $cnt;
        }
    }
    # Without this the rescaling below dies with "Illegal division by zero".
    die "Error: no letter pairs to compare\n" unless $total_sum > 0;

    if (!$scramble) {
        print "Raw counts:\n";
        print_table("\t%7d", $total_sum, @count);
    }

    # Rescale the marginals to probabilities.
    $_ /= $total_sum for @row_sum, @col_sum;

    # Use row and column probabilities as pseudocounts: cell (r,c) gets
    # P(r)+P(c), which over all cells adds $nrows + $ncols in total.
    my $total_plus_pseudo = $total_sum + $nrows + $ncols;
    my @probs;
    for my $r (0 .. $nrows - 1) {
        for my $c (0 .. $ncols - 1) {
            my $cnt = $count[$r][$c] + $row_sum[$r] + $col_sum[$c];
            $probs[$r][$c] = $cnt / $total_plus_pseudo;
        }
    }

    if (!$scramble) {
        print "\n\nProbabilities:\n";
        print_table("\t%7.4f", 1.0, @probs);
    }

    # Log-odds in bits, log2(P(a,b) / (P(a) P(b))); a zero probability
    # contributes 0 instead of -inf.
    my @logodds;
    for my $r (0 .. $nrows - 1) {
        my $logr = ($row_sum[$r] <= 0) ? 0 : log($row_sum[$r]) / $log2;
        for my $c (0 .. $ncols - 1) {
            my $logc = ($col_sum[$c] <= 0) ? 0 : log($col_sum[$c]) / $log2;
            my $p = $probs[$r][$c];
            $logodds[$r][$c] = ($p <= 0 ? 0 : log($p) / $log2) - $logr - $logc;
        }
    }

    if (!$scramble) {
        print "\n\nLog odds in bits:\n";
        print_table("\t%7.2f", 1.0, @logodds);
    }

    # Mutual information plus the marginal entropies.
    my $mi          = 0;
    my $row_entropy = 0;
    my $col_entropy = 0;
    my $max_mi      = 0;
    for my $c (0 .. $ncols - 1) {
        $col_entropy -= ($col_sum[$c] <= 0) ? 0
            : $col_sum[$c] * log($col_sum[$c]) / $log2;
    }
    for my $r (0 .. $nrows - 1) {
        $row_entropy -= ($row_sum[$r] <= 0) ? 0
            : $row_sum[$r] * log($row_sum[$r]) / $log2;
        for my $c (0 .. $ncols - 1) {
            $mi += $count[$r][$c] * $logodds[$r][$c];

            # NOTE(review): because the column probabilities sum to 1, this
            # accumulates -sum_r log2 P(r).  That is >= the row entropy, so
            # it is a loose bound on MI; the tight bound is
            # min(row entropy, column entropy).  Kept unchanged pending
            # confirmation of the intended formula.
            $max_mi -= ($row_sum[$r] <= 0) ? 0
                : $col_sum[$c] * log($row_sum[$r]) / $log2;
        }
    }
    $mi /= $total_sum;

    if (!$scramble) {
        print "\n\n";
        # Column names go through %s so a '%' in a name cannot break printf.
        printf "%s entropy = %7g bits\n", $colname1, $row_entropy;
        printf "%s entropy = %7g bits\n", $colname2, $col_entropy;
        printf "Mutual information = %7g bits\n", $mi;
        printf "Max mutual information =  %7g bits\n", $max_mi;
    }
    return $mi;
}


sub print_table
{
    # Print a 2-D table with letter labels on both axes plus marginals.
    #
    # Args: $format  printf format (including its leading tab) applied to
    #                every cell and every margin entry
    #       $total   value printed in the bottom-right corner
    #       @tab     array refs indexed [row subscript][column subscript]
    #
    # Reads package globals: @row_names/@col_names (display order),
    # %row_subscript/%col_subscript (letter -> subscript), and the
    # dynamically scoped @row_sum/@col_sum for the margins.  $row_type and
    # $col_type label the axes when set; otherwise $colname1/$colname2 do.
    # $shift, when set and nonzero, annotates the column-axis header.
    my ($format, $total, @tab) = @_;

    my $row_label = defined $row_type ? $row_type : $colname1;
    my $col_label = defined $col_type ? $col_type : $colname2;
    $row_label = '' unless defined $row_label;
    $col_label = '' unless defined $col_label;

    my $offset = '';
    if (defined $shift && $shift != 0) {
        $offset = $shift > 0 ? "[i+$shift]" : "[i$shift]";
    }
    print "\t$col_label$offset\n";

    print join('', map { "\t     $_" } @col_names), "\t  total\n";

    # Separator: the row label padded with dashes to the table width.
    my $rule = $row_label . ('-' x (8 * (@col_names + 2) - length($row_label)));
    print "$rule\n";

    for my $row (@row_names) {
        my $rs = $row_subscript{$row};
        print "  $row";
        printf $format, $tab[$rs][$col_subscript{$_}] for @col_names;
        printf "$format\n", $row_sum[$rs];
    }

    print "  total";
    printf $format, $col_sum[$col_subscript{$_}] for @col_names;
    printf "$format\n", $total;
}



sub process_command_line {
    # Parse @ARGV into the package globals $inputfile, $colname1,
    # $colname2 and $numpermutes.  Each option must match exactly and
    # takes the next argument as its value.  Unrecognized arguments are
    # reported on STDERR and skipped.  Prints usage and calls exit(-1)
    # when a required option is missing or -numpermutes is not a positive
    # integer.
    my %target_for = (
        '-inputfile'   => \$inputfile,
        '-colname1'    => \$colname1,
        '-colname2'    => \$colname2,
        '-numpermutes' => \$numpermutes,
    );

    my @args = @ARGV;
    while (@args) {
        my $arg    = shift @args;
        my $target = $target_for{$arg};
        if (!$target) {
            warn "Ignoring unrecognized argument '$arg'\n";
            next;
        }
        # A trailing option with no value leaves its target undef, which
        # the check below reports.
        $$target = shift @args;
    }

    if (   !defined $inputfile
        || !defined $colname1
        || !defined $colname2
        || !defined $numpermutes
        || $numpermutes !~ /\A[1-9]\d*\z/)
    {
        usage();
        exit(-1);
    }
}

sub usage {
    # Print a one-line synopsis of the four required options to STDOUT.
    my @required = (
        '-inputfile foo.rdb',
        '-colname1 LABEL',
        '-colname2 FEATURE',
        '-numpermutes 100',
    );
    print join(' ', 'Usage: mutual-info-2D', @required), " \n";
}

