#!/usr/bin/perl -w

# language modeller
# (c)2003 Stepan Roh
#
# computes smoothed (linearly interpolated) language models and their cross entropy
#
# usage: ./assign1b.pl [-debug] file.txt...

# "precise" sum calculator
# in: @numbers
# out: $sum of @numbers
sub sum(@);
sub sum(@) {
  return 0 if (!@_);
  return $_[0] if (@_ == 1);
  if (@_ % 2 != 0) {
    push (@_, 0);
  }
  my @res = ();
  for (my $i = 0; $i < @_; $i += 2) {
    my $v1 = $_[$i];
    my $v2 = $_[$i + 1];
    push (@res, $v1 + $v2);
  }
  return sum (@res);
}
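# Illustrative example (not executed by the script): sum(1, 2, 3, 4, 5) pads the
# odd-length list with a 0, reduces it pairwise to (3, 7, 5), then to (10, 5),
# and finally returns 15.  Adding operands of similar magnitude in pairs keeps
# the intermediate sums closer in size than a plain left-to-right accumulation.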

# probability test (for probabilities which must sum to 1)
# in: $test name, @probabilities
sub test_prob($@) {
  my ($name, @data) = @_;
  my $s = sum (@data);
  if (($s < 0.999995) || ($s > 1.000005)) {
    die "Testing ", $name, " probability failed: sum is $s\n";
  }
}
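# Illustrative example (not executed by the script): test_prob('unigram', 0.5, 0.25, 0.25)
# passes because the values sum to 1 within the 5e-6 tolerance above, while
# test_prob('unigram', 0.5, 0.25) would die.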

# log() with base 2
# in: $number
# out: base-2 logarithm of $number
sub log2($) {
  return log(shift) / log(2);
}

# process file and output results on stdout
# in: $filename
sub process_file ($) {
  my ($file) = @_;
  
  print "File: $file\n\n";

  # loading input text
  open (IN, $file) || die "Unable to open file $file : $!\n";
  my @input = map { chomp($_); $_ } <IN>;
  close (IN);

  # 20000 words of test data
  my @test_data = @input[(@input - 20000) .. (@input - 1)];
  # 40000 words of heldout data
  my @heldout_data = @input[(@input - 60000) .. (@input - 20001)];
  # rest is training data
  my @training_data = @input[0 .. (@input - 60001)];
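  # Illustrative example (assuming the input has exactly 100000 words):
  #   test data     = words 80000 .. 99999  (last 20000 words)
  #   heldout data  = words 40000 .. 79999  (the 40000 words before the test data)
  #   training data = words     0 .. 39999  (everything else)
  # The input is assumed to have more than 60000 words; otherwise the negative
  # slice indices would count from the end of the array.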

  # compute some statistics (some of them are used later, some are not)
  
  # hash (word from input text -> count)
  my %input_word_count = ();
  # hash (word from test data -> count)
  my %test_word_count = ();
  # hash (word from heldout data -> count)
  my %heldout_word_count = ();

  foreach $w (@input) {
    $input_word_count{$w}++;
  }
  foreach $w (@test_data) {
    $test_word_count{$w}++;
  }
  foreach $w (@heldout_data) {
    $heldout_word_count{$w}++;
  }

  # input text's vocabulary size
  my $input_vsize = scalar keys %input_word_count;
  # test data's vocabulary size
  my $test_vsize = scalar keys %test_word_count;
  # heldout data's vocabulary size
  my $heldout_vsize = scalar keys %heldout_word_count;

  print "Input size: ", scalar @input, " (vocabulary has ", $input_vsize, " words)\n\n";
  print "Test data size: ", scalar @test_data, " (vocabulary has ", $test_vsize, " words)\n";
  print "Heldout data size: ", scalar @heldout_data, " (vocabulary has ", $heldout_vsize, " words)\n";

  # training data word counts
  # - unlike in assign1a.pl I decided to use nested hashes instead of compound keys, because they
  #   are easier to work with for the operations below (see the nested foreach loops)
  my %training_word_count = ();
  my %training_bigram_count = ();
  my %training_trigram_count = ();
  
  # the context starts as two empty words ('') so that bigram and trigram counts at the
  # beginning of the data are well defined
  my ($prev, $prev2) = ('', '');
  foreach $w (@training_data) {
    $training_word_count{$w}++;
    $training_bigram_count{$prev}{$w}++;
    $training_trigram_count{$prev2}{$prev}{$w}++;
    $prev2 = $prev;
    $prev = $w;
  }
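  # Illustrative example (not executed by the script): for training data ('a', 'b', 'a')
  # the loop above produces
  #   %training_word_count    : a => 2, b => 1
  #   %training_bigram_count  : '' => { a => 1 }, a => { b => 1 }, b => { a => 1 }
  #   %training_trigram_count : '' => { '' => { a => 1 }, a => { b => 1 } },
  #                             a  => { b  => { a => 1 } }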
  
  if ($debug) {
    print "training_word_count:\n";
    foreach $w (sort keys %training_word_count) {
      print '  ', $w, '  ', $training_word_count{$w}, "\n";
    }
  }
  
  # training data's vocabulary size
  my $training_vsize = scalar keys %training_word_count;

  print "Training data size: ", scalar @training_data, " (vocabulary has ", $training_vsize, " words)\n";

  # compute "word coverage" (percentage of words in the test data which have been seen in the training data)
  my $word_coverage_count = 0;
  foreach $w (keys %test_word_count) {
    $word_coverage_count++ if (defined $training_word_count{$w});
  }
  
  print "Word coverage (test data / training data): ", $word_coverage_count / $test_vsize * 100, "%\n";

  # compute probabilities from training data
  
  # uniform probability
  my $uniform_prob = 1 / $training_vsize;
  # hash (word -> probability)
  my %unigram_prob = ();
  # hash (word[i-1] -> hash (word[i] -> probability))
  my %bigram_prob = ();
  # hash (word[i-2] -> hash (word[i-1] -> hash (word[i] -> probability)))
  my %trigram_prob = ();

  # computing all probabilities in one pass is a nice optimization and also avoids the problem
  # of the empty context words ('') not being in %training_word_count (each loop only ranges
  # over words actually seen in the training data)
  foreach $w (keys %training_word_count) {
    $unigram_prob{$w} = $training_word_count{$w} / (scalar @training_data);
    foreach $w2 (keys %{$training_bigram_count{$w}}) {
      $bigram_prob{$w}{$w2} = $training_bigram_count{$w}{$w2} / $training_word_count{$w};
      foreach $w3 (keys %{$training_trigram_count{$w}{$w2}}) {
        $trigram_prob{$w}{$w2}{$w3} = $training_trigram_count{$w}{$w2}{$w3} / $training_bigram_count{$w}{$w2};
      }
    }
  }
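  # The loop above computes the usual maximum-likelihood estimates:
  #   p(w)          = c(w)        / T                (T = number of training words)
  #   p(w2 | w1)    = c(w1,w2)    / c(w1)
  #   p(w3 | w1,w2) = c(w1,w2,w3) / c(w1,w2)
  # For example, if 'the' occurred 100 times and 'the cat' 8 times, then
  # p(cat | the) = 8 / 100 = 0.08 (figures are illustrative only).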

  if ($debug) {
    print "uniform_prob: ", $uniform_prob, "\n";
    print "unigram_prob:\n";
    foreach $w (sort keys %training_word_count) {
      print '  ', $w, '  ', $unigram_prob{$w}, "\n";
    }
    print "bigram_prob:\n";
    foreach $w (sort keys %training_bigram_count) {
      foreach $w2 (sort keys %{$training_bigram_count{$w}}) {
        print '  ', $w, ' ', $w2, ' ', $bigram_prob{$w}{$w2}, "\n";
      }
    }
    print "trigram_prob:\n";
    foreach $w (sort keys %training_trigram_count) {
      foreach $w2 (sort keys %{$training_trigram_count{$w}}) {
        foreach $w3 (sort keys %{$training_trigram_count{$w}{$w2}}) {
          print '  ', $w, ' ', $w2, ' ', $w3, ' ', $trigram_prob{$w}{$w2}{$w3}, "\n";
        }
      }
    }
  }
  
  test_prob ('unigram', values %unigram_prob);

  # collect the heldout trigrams used by the EM algorithm
  # - having their counts is only a side effect (the EM loop only needs the distinct trigrams)
  
  # hash (word[i-2] -> hash (word[i-1] -> hash (word[i] -> count)))
  my %heldout_trigram_count = ();

  # the context starts as two empty words ('') so that trigram counts at the
  # beginning of the data are well defined
  ($prev, $prev2) = ('', '');
  foreach $w (@heldout_data) {
    $heldout_trigram_count{$prev2}{$prev}{$w}++;
    $prev2 = $prev;
    $prev = $w;
  }

  # smoothing EM algorithm
  
  print "\nEM algorithm\n\n";
  
  # lambdas
  my @lambda = ( 0, 0, 0, 0 );
  
  # next lambdas (in step 1 these hold the initial values)
  my @next_lambda = ( 0.25, 0.25, 0.25, 0.25 );
  
  # epsilon used in convergence detection
  my $epsilon = 0.01;
  
  print "Convergence epsilon is $epsilon\n\n";

  # lambdas convergence test
  # in: $epsilon, \@old lambdas, \@new lambdas
  # out: 1 if all lambdas differ by less than $epsilon (converged), 0 otherwise
  sub lambda_convergence ($\@\@) {
    my ($epsilon, $lambda_ref, $next_lambda_ref) = @_;
    for (my $i = 0; $i < @{$lambda_ref}; $i++) {
      return 0 if (abs ($$lambda_ref[$i] - $$next_lambda_ref[$i]) >= $epsilon);
    }
    return 1;
  }
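  # Illustrative example (not executed by the script): with my @a = (0.25, 0.25, 0.25, 0.25)
  # and my @b = (0.251, 0.249, 0.25, 0.25), lambda_convergence(0.01, @a, @b) returns 1,
  # because every component differs by less than 0.01.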

  my $step = 1;

  # lambda calculation iteration
  while (!lambda_convergence ($epsilon, @lambda, @next_lambda)) {
    print "Step $step: lambdas = ", join (' ', @next_lambda), "\n";
    @lambda = @next_lambda;

    # computing smoothed probability
    my %sprob = ();
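    # Smoothed (linearly interpolated) probability:
    #   p'(w3 | w1,w2) = l3 * p(w3|w1,w2) + l2 * p(w3|w2) + l1 * p(w3) + l0 * (1/V)
    # where l0..l3 are the current lambdas and V is the training vocabulary size,
    # so l0 * (1/V) is the uniform term.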
    
    foreach $w (keys %heldout_trigram_count) {
      foreach $w2 (keys %{$heldout_trigram_count{$w}}) {
        foreach $w3 (keys %{$heldout_trigram_count{$w}{$w2}}) {
          my $tp = $trigram_prob{$w}{$w2}{$w3};
          $tp = 0 if (!defined $tp);
          my $bp = $bigram_prob{$w2}{$w3};
          $bp = 0 if (!defined $bp);
          my $up = $unigram_prob{$w3};
          $up = 0 if (!defined $up);
          $sprob{$w}{$w2}{$w3} = $lambda[3] * $tp
                               + $lambda[2] * $bp
                               + $lambda[1] * $up
                               + $lambda[0] * $uniform_prob;
        }
      }
    }

    if ($debug) {
      print "sprob:\n";
      foreach $w (sort keys %sprob) {
        foreach $w2 (sort keys %{$sprob{$w}}) {
          foreach $w3 (sort keys %{$sprob{$w}{$w2}}) {
            print '  ', $w, ' ', $w2, ' ', $w3, ' ', $sprob{$w}{$w2}{$w3}, "\n";
          }
        }
      }
    }
    
    # computing expected counts
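    # E-step of the EM algorithm: each model's expected count is accumulated over
    # the heldout data as its lambda-weighted share of the smoothed probability,
    #   c_j += l_j * p_j(w | history) / p'(w | history)
    # where p_0 is the uniform, p_1 the unigram, p_2 the bigram and p_3 the trigram
    # probability.  Missing n-gram probabilities contribute 0, which is why the
    # corresponding terms below are guarded by "if defined".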
    
    my @c = ( 0, 0, 0, 0 );
    my ($prev, $prev2) = ('', '');
    foreach $w (@heldout_data) {
      $c[0] += $lambda[0] * $uniform_prob / $sprob{$prev2}{$prev}{$w};
      $c[1] += $lambda[1] * $unigram_prob{$w} / $sprob{$prev2}{$prev}{$w} if defined $unigram_prob{$w};
      $c[2] += $lambda[2] * $bigram_prob{$prev}{$w} / $sprob{$prev2}{$prev}{$w} if defined $bigram_prob{$prev}{$w};
      $c[3] += $lambda[3] * $trigram_prob{$prev2}{$prev}{$w} / $sprob{$prev2}{$prev}{$w} if defined $trigram_prob{$prev2}{$prev}{$w};
      $prev2 = $prev;
      $prev = $w;
    }

    # computing next lambdas
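    # M-step: the new lambdas are the expected counts normalized to sum to 1,
    #   l_j = c_j / (c_0 + c_1 + c_2 + c_3)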
    
    my $csum = sum (@c);
    for (my $i = 0; $i < 4; $i++) {
      $next_lambda[$i] = $c[$i] / $csum;
    }
    
    $step++;
  }
  
  # @next_lambda holds the better estimate, although it differs from @lambda by less than $epsilon
  @lambda = @next_lambda;
  print "Final step: lambdas = ", join (' ', @lambda), "\n";
  
  # tweaking lambdas
  my @tweaked_lambda = ( \@lambda );
  my @tweak_trigram_inc = ( 10, 20, 30, 40, 50, 60, 70, 80, 90, 95, 99 );
  my @tweak_trigram_dec = ( 90, 80, 70, 60, 50, 40, 30, 20, 10, 0 );
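  # Each tweak moves the trigram lambda by the given percentage and rescales the other
  # three lambdas so that all four still sum to 1.  Illustrative example (not executed
  # by the script): with l3 = 0.6, the increase t = 50 yields l3 = 0.6 + 0.5 * (1 - 0.6) = 0.8
  # and multiplies l0..l2 by 0.5; the decrease t = 50 sets l3 to 50% of its value, i.e. 0.3,
  # and multiplies l0..l2 by 1.75.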

  # increasing trigram lambda
  foreach $t (@tweak_trigram_inc) {
    my @nlambda = @lambda;
    my $inc = (1 - $lambda[3]) * $t / 100;
    $nlambda[3] += $inc;
    my $lsum = $nlambda[0] + $nlambda[1] + $nlambda[2];
    for (my $i = 0; $i < 3; $i++) {
      $nlambda[$i] *= 1 - $inc / $lsum;
    }
    # lambdas are not probabilities, but they must still sum to 1
    test_prob ('lambdas', @nlambda);
    push (@tweaked_lambda, \@nlambda);
  }

  # decreasing trigram lambda
  foreach $t (@tweak_trigram_dec) {
    my @nlambda = @lambda;
    my $dec = $lambda[3] * ((100 - $t) / 100);
    $nlambda[3] -= $dec;
    my $lsum = $nlambda[0] + $nlambda[1] + $nlambda[2];
    for (my $i = 0; $i < 3; $i++) {
      $nlambda[$i] *= 1 + $dec / $lsum;
    }
    # lambdas are not probabilities, but they must still sum to 1
    test_prob ('lambdas', @nlambda);
    push (@tweaked_lambda, \@nlambda);
  }
  
  print "\nCross entropy for tweaked lambdas:\n\n";
  printf "%-8s   %-8s   %-8s   %-8s     %-9s\n", 'l(0)', 'l(1)', 'l(2)', 'l(3)', 'entropy';
  
  # computing cross entropy for each set of lambdas
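  # The cross entropy is
  #   H = - (1 / |test data|) * sum over the test data of log2 p'(w | w[i-2], w[i-1])
  # where p' is the interpolated probability under the current lambdas;
  # lower values mean the smoothed model predicts the test data better.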
  foreach $lambda_ref (@tweaked_lambda) {
    my @cur_lambda = @{$lambda_ref};
    
    my $ent = 0;
    
    ($prev, $prev2) = ('', '');
    foreach $w (@test_data) {
      my $tp = $trigram_prob{$prev2}{$prev}{$w};
      $tp = 0 if (!defined $tp);
      my $bp = $bigram_prob{$prev}{$w};
      $bp = 0 if (!defined $bp);
      my $up = $unigram_prob{$w};
      $up = 0 if (!defined $up);
      $ent += log2 ($cur_lambda[3] * $tp + $cur_lambda[2] * $bp
                  + $cur_lambda[1] * $up + $cur_lambda[0] * $uniform_prob);
      $prev2 = $prev;
      $prev = $w;
    }
    $ent = -$ent / scalar (@test_data);
    
    printf "%8.7f  %8.7f  %8.7f  %8.7f    %9.7f\n", @cur_lambda, $ent;
  }
  
  print "\n";
}

$debug = 0;
if (@ARGV && $ARGV[0] eq '-debug') {
  shift @ARGV;
  $debug = 1;
}

foreach $file (@ARGV) {
  process_file ($file);
}

1;
