PYTHON 15
Discomt_baseline.py Guest on 21st November 2020 05:33:49 PM
  1. #!/usr/bin/python
  2. # -*- coding:UTF-8 -*-
  3. import sys
  4. import os
  5. import re
  6. import math
  7. import optparse
  8. import kenlm
  9. from collections import defaultdict
  10. from gzip import GzipFile
  11.  
  12. '''
  13. takes a classification file (optionally gzipped) as input
  14. and writes a file with two columns (replacement+target)
  15.  
  16. Input format:
  17.  - classes (ignored)
  18.  - true replacements (ignored)
  19.  - source text (currently ignored)
  20.  - target text
  21.  - alignments (currently ignored)
  22.  
  23. Output format (for fmt=replace):
  24.  - predicted classes
  25.  - predicted replacements
  26.  - original source text
  27.  - original target text
  28.  - alignments
  29. '''
  30.  
  31. oparse = optparse.OptionParser(usage='%prog [options] input_file')
  32. oparse.add_option('--lm', dest='lm',
  33.                   help='language model file',
  34.                   default='corpus.5.fr.trie.kenlm')
  35. oparse.add_option('--fmt', dest='fmt',
  36.                   choices=['replace', 'predicted', 'both', 'compare', 'scores'],
  37.                   help='format (replace, predicted, both, compare, scores)',
  38.                   default='replace')
  39. oparse.add_option('--none-penalty', dest='none_penalty',
  40.                   type='float', default=0.0,
  41.                   help='penalty for empty filler')
  42.  
  43. replace_re = re.compile('REPLACE_[0-9]+')
  44.  
  45. all_fillers = [
  46.     ['il'], ['elle'],
  47.     ['ils'], ['elles'],
  48.     ["c'"], ["ce"], ["ça"], ['cela'], ["on"]]
  49.  
  50. non_fillers = [[w] for w in
  51.                '''
  52.               le l' se s' y en qui que qu' tout
  53.               faire ont fait est parler comprendre chose choses
  54.               ne pas dessus dedans
  55.               '''.strip().split()]
  56.  
  57. def map_class(x):
  58.     if [x] in non_fillers:
  59.         return 'OTHER'
  60.     elif x == 'NONE':
  61.         return 'OTHER'
  62.     elif x == "c'":
  63.         return 'ce'
  64.     else:
  65.         return x
  66.  
  67. NONE_PENALTY = 0
  68.  
  69. def gen_items(contexts, prev_contexts):
  70.     '''
  71.    extends the items from *prev_contexts* with
  72.    fillers and the additional bits of context from
  73.    *contexts*
  74.  
  75.    returns a list of (text, score, fillers) tuples,
  76.    and expects prev_contexts to have the same shape.
  77.    '''
  78.     if len(contexts) == 1:
  79.         return [(x+contexts[0], y, z)
  80.                 for (x,y,z) in prev_contexts]
  81.     else:
  82.         #print >>sys.stderr, "gen_items %s %s"%(contexts, prev_contexts)
  83.         context = contexts[0]
  84.         next_contexts = []
  85.         for filler in all_fillers:
  86.             next_contexts += [(x+context+filler, y, z+filler)
  87.                               for (x,y,z) in prev_contexts]
  88.         for filler in non_fillers:
  89.             next_contexts += [(x+context+filler, y, z+filler)
  90.                               for (x,y,z) in prev_contexts]
  91.         next_contexts += [(x+context, y+NONE_PENALTY, z+['NONE'])
  92.                             for (x,y,z) in prev_contexts]
  93.         if len(next_contexts) > 5000:
  94.             print >>sys.stderr, "Too many alternatives, pruning some..."
  95.             next_contexts = next_contexts[:200]
  96.             next_contexts.sort(key=score_item, reverse=True)
  97.         return gen_items(contexts[1:], next_contexts)
  98.  
  99. def score_item(x):
  100.     model_score = model.score(' '.join(x[0]))
  101.     return model_score + x[1]
  102.  
  103. def main(argv=None):
  104.     global model, NONE_PENALTY
  105.     opts, args = oparse.parse_args(argv)
  106.     if not args:
  107.         oparse.print_help()
  108.         sys.exit(1)
  109.     NONE_PENALTY = opts.none_penalty
  110.     discomt_file = args[0]
  111.     print >>sys.stderr, "Loading language model..."
  112.     model = kenlm.LanguageModel(opts.lm)
  113.     mode = opts.fmt
  114.     print >>sys.stderr, "Processing stuff..."
  115.     if discomt_file.endswith('.gz'):
  116.         f_input = GzipFile(discomt_file)
  117.     else:
  118.         f_input = file(discomt_file)
  119.     for i, l in enumerate(f_input):
  120.         if l[0] == '\t':
  121.             if mode == 'replace':
  122.                 print l,
  123.                 continue
  124.             elif mode != 'scores':
  125.                 continue
  126.         classes_str, target, text_src, text, text_align = l.rstrip().split('\t')
  127.         if mode == 'scores':
  128.             print '%d\tTEXT\t%s\t%s\t%s' % (i, text_src, text, text_align)
  129.             if l[0] == '\t':
  130.                 continue
  131.         text = replace_re.sub('REPLACE', text)
  132.         targets = [x.strip() for x in target.split(' ')]
  133.         classes = [x.strip() for x in classes_str.split(' ')]
  134.         contexts = [x.strip().split() for x in text.split('REPLACE')]
  135.         #print "TARGETs:", target
  136.         #print "CONTEXTs: ", contexts
  137.         if len(contexts) > 5:
  138.             print >>sys.stderr, "#contexts:", len(contexts)
  139.         items = gen_items(contexts, [([], 0.0, [])])
  140.         items.sort(key = score_item, reverse=True)
  141.         pred_fillers = items[0][2]
  142.         pred_classes = [map_class(x) for x in pred_fillers]
  143.         if mode == 'scores':
  144.             #TODO compute individual scores for each slot
  145.             # and convert the scores to probabilities
  146.             scored_items = []
  147.             for item in items:
  148.                 words, penalty, fillers = item
  149.                 scored_items.append((words, score_item(item), fillers))
  150.             best_penalty = max([x[1] for x in items])
  151.             dists = [defaultdict(float) for k in items[0][2]]
  152.             for words, penalty, fillers in scored_items:
  153.                 exp_pty = math.exp(penalty - best_penalty)
  154.                 for j, w in enumerate(fillers):
  155.                     dists[j][w] += exp_pty
  156.             for j in xrange(len(items[0][2])):
  157.                 sum_all = sum(dists[j].values())
  158.                 if sum_all == 0:
  159.                     sum_all = 1.0
  160.                 items = [(k, v/sum_all) for k,v in dists[j].iteritems()]
  161.                 items.sort(key=lambda x: -x[1])
  162.                 print "%s\tITEM %d\t%s"%(
  163.                     i, j, ' '.join([
  164.                         '%s %.4f'%(x[0], x[1])
  165.                         for x in items if x[1] > 0.001]))
  166.         elif mode == 'both':
  167.             print "%s\t%s"%(target, ' '.join(pred_fillers))
  168.         elif mode == 'predicted':
  169.             print "%s\t%s"%(
  170.                 ' '.join(pred_classes),
  171.                 ' '.join(pred_fillers))
  172.         elif mode == 'replace':
  173.             print "%s\t%s\t%s\t%s\t%s"%(
  174.                 ' '.join(pred_classes),
  175.                 ' '.join(pred_fillers),
  176.                 text_src, text, text_align)
  177.         elif mode == 'compare':
  178.             assert len(classes) == len(pred_classes), (classes, pred_classes)
  179.             for gold, syst in zip(classes, pred_classes):
  180.                 print "%s\t%s"%(gold, syst)
  181.  
  182. if __name__ == '__main__':
  183.     main()

Paste is for source code and general debugging text.

Login or Register to edit, delete and keep track of your pastes and more.

Raw Paste

Login or Register to edit or fork this paste. It's free.