Пример #1
0
 def __init__(self, opt, reuse=False):
     """Build the 'test'-mode language model and expose it as ``self.model``.

     Args:
         opt: options object forwarded unchanged to ``LanguageModel``.
         reuse: forwarded to ``LanguageModel`` as its ``reuse`` keyword.
     """
     # The inference ops live in their own name scope so they stay separate
     # from the training ops (this matters for the train/validation steps).
     with tf.name_scope('inference'):
         inference_lm = LanguageModel(opt, 'test', reuse=reuse)
         inference_lm.build()
     self.model = inference_lm
Пример #2
0
def main():
    """Generate sentences with an n-gram model and score each of them.

    All settings (dataset folder, n-gram order, number of sentences and the
    maximum sentence length) are hard-coded; the interactive prompts that
    once asked for them are disabled.  Every generated sentence is printed
    and then passed to the model's ``PPL`` method (presumably perplexity --
    confirm against the model class).
    """
    # Fixed run configuration.
    path = "E:\\Universite\\8yy\\497\\NLP_Homeworks\\generateSentence"
    count_of_sentence = 3
    maxlength_of_sentence = 30
    NN_grams = 3

    # Load the corpus first, then train the model on it.
    sentences = dataset(path)
    ngram_language_model = LanguageModel(NN_grams)
    ngram_language_model.LoadDataset2Model(sentences)

    generated_sentences = ngram_language_model.Generate(
        maxlength_of_sentence, count_of_sentence)

    for index in range(count_of_sentence):
        sentence = generated_sentences[index]
        print(str(index) + ".sentence :", sentence)
        ngram_language_model.PPL(sentence)
def get_model():
    """Rebuild the language-model wrapper and load its trained model.

    Reads ``saved_models/for_server.pkl``; the first three items of the
    loaded object are used as the maximum sequence length, the characters
    and the character mapping handed to ``LanguageModel``.

    Returns:
        tuple: ``(language_model, model)`` where ``model`` is whatever
        ``language_model.load_model()`` returns.
    """
    # Use a context manager so the file handle is closed deterministically
    # (previously the handle returned by open() was never closed).
    # NOTE(review): `load` is presumably pickle.load -- unpickling executes
    # arbitrary code, so this file must come from a trusted source.
    with open('saved_models/for_server.pkl', 'rb') as saved_file:
        for_server = load(saved_file)

    sequence_max_len = for_server[0]
    chars = for_server[1]
    chars_mapping = for_server[2]

    language_model = LanguageModel(sequence_max_len, chars, chars_mapping)
    model = language_model.load_model()

    return language_model, model
Пример #4
0
def build_models(sentences, delta=0.5, sigma=0.1):
    """Build the two language models derived from ``count_grams(sentences)``.

    Args:
        sentences: input forwarded to ``count_grams``.
        delta: third argument of the model built from ``trs_counts``/``tags``.
        sigma: third argument of the model built from ``ems_counts``/``words``.

    Returns:
        tuple: ``(trs_model, ems_model)`` (presumably transition and
        emission models -- confirm against ``count_grams``).
    """
    trs_counts, ems_counts, tags, words = count_grams(sentences)
    # Tuple elements are evaluated left to right, so the construction order
    # is the same as building the two models one after the other.
    return (
        LanguageModel(trs_counts, tags, delta),
        LanguageModel(ems_counts, words, sigma),
    )
Пример #5
0
  def build_lang_models(self, delta = 0.1):
    """Build one language model per review label from feature counts.

    Features of reviews whose ``label`` is truthy are counted for
    ``self.pos_model``; all other reviews feed ``self.neg_model``.  Both
    models share ``self.vocab`` and ``delta``.
    """
    # One Counter per truth value of the label.
    counts_by_label = {True: Counter(), False: Counter()}
    for review in self.corpus:
      counts_by_label[bool(review.label)].update(review.features)

    self.pos_model = LanguageModel(counts_by_label[True], self.vocab, delta)
    self.neg_model = LanguageModel(counts_by_label[False], self.vocab, delta)
Пример #6
0
    def __init__(self, raw_segments, observation_sequences, label_sequences):
        """Store the training data, create the helper objects and retrain.

        Args:
            raw_segments: stored on ``self.raw_segments``.
            observation_sequences: stored on ``self.observation_sequences``.
            label_sequences: stored on ``self.label_sequences``.

        Note that construction is not cheap: the last statement immediately
        calls ``self.retrain_with_boosting_features()``.
        """
        super(Retrainer, self).__init__()
        self.raw_segments = raw_segments
        self.observation_sequences = observation_sequences
        self.label_sequences = label_sequences
        self.hmm_new = None
        self.feature_entity_list = FeatureEntityList()
        self.lm = LanguageModel()
        self.boosting_feature_generator = BoostingFeatureGenerator()

        self.DOMINANT_RATIO = 0.85  # dominant label ratio: set empirically

        # Retraining is kicked off as part of construction.
        self.retrain_with_boosting_features()
Пример #7
0
class TestLanguageModel(unittest.TestCase):
    """Unit tests for ``LanguageModel`` built from a file and from a dict.

    ``bk_model`` is loaded from ``term_occurrences.txt`` and ``doc_model``
    from an in-memory dictionary; both are expected to hold the same
    counts (hello 10, world 20, goodbye 10).
    """

    def setUp(self):
        """Create the expected counts and both models under test."""
        self.expected = {'hello': 10, 'world': 20, 'goodbye': 10}
        self.logger = logging.getLogger("TestLanguageModel")
        self.bk_model = LanguageModel(file='term_occurrences.txt')
        self.doc_model = LanguageModel(term_dict=self.expected)

    def test_populate_occurrences(self):
        """Reading the file fills ``occurrence_dict`` with the expected counts."""
        self.logger.debug("Test Populate Occurrences")
        self.bk_model._populate_occurrences('term_occurrences.txt')
        # check it's not null and contains the following
        #   hello 10
        #   world 20
        #   goodbye 10
        self.assertDictEqual(self.bk_model.occurrence_dict, self.expected)

    # The doc and bk models probably do not need separate tests: the bk model
    # is read into a dict, and that path is covered above.

    def test_get_number_occurrences_doc(self):
        """A known term returns its count from the dict-backed model."""
        self.logger.debug("Test Get Number of Occurrences with dictionary")
        hello_count = self.doc_model.get_num_occurrences('hello')
        hello_expected = 10
        # assertEqual: the assertEquals alias is deprecated and was removed
        # in Python 3.12.
        self.assertEqual(hello_count, hello_expected)

    def test_get_number_occurrences_bk(self):
        """A known term returns its count from the file-backed model."""
        self.logger.debug("Test Get Number of Occurrences with file")
        # same test as _doc but with bk model
        hello_count = self.bk_model.get_num_occurrences('hello')
        hello_expected = 10
        self.assertEqual(hello_count, hello_expected)

    def test_when_no_occurrences_bk(self):
        """An unknown term yields a count of zero."""
        self.logger.debug("Test when term does not exist")
        # same test as _doc but with bk model
        term_count = self.bk_model.get_num_occurrences('garble')
        term_expected = 0
        self.assertEqual(term_count, term_expected)

    def test_calc_total_occurrences_doc(self):
        """The dict-backed model sums all counts (10 + 20 + 10)."""
        self.logger.debug("Test calculate total occurrences with dictionary")
        actual_total = self.doc_model.get_total_occurrences()
        expected_total = 40
        self.assertEqual(expected_total, actual_total)

    def test_calc_total_occurrences_bk(self):
        """The file-backed model sums all counts (10 + 20 + 10)."""
        self.logger.debug("Test calculate total occurrences with file")
        actual_total = self.bk_model.get_total_occurrences()
        expected_total = 40
        self.assertEqual(expected_total, actual_total)

    def test_get_term_probability(self):
        """The probability of a term is its count over the total count."""
        self.logger.debug("Test get term probability")
        # probability of hello is 10/40 = 0.25
        expected = 0.25
        actual = self.bk_model.get_term_prob('hello')
        self.assertEqual(expected, actual)
Пример #8
0
class TestLanguageModel(unittest.TestCase):
    """Tests ``LanguageModel`` when loaded from a file and from a dictionary.

    The file fixture ``term_occurrences.txt`` is expected to carry the same
    counts as ``self.expected`` (hello 10, world 20, goodbye 10).
    """

    def setUp(self):
        """Build the file-backed (``bk_model``) and dict-backed (``doc_model``) models."""
        self.expected = {'hello': 10, 'world': 20, 'goodbye': 10}
        self.logger = logging.getLogger("TestLanguageModel")
        self.bk_model = LanguageModel(file='term_occurrences.txt')
        self.doc_model = LanguageModel(term_dict=self.expected)

    def test_populate_occurrences(self):
        """``_populate_occurrences`` loads exactly the expected counts."""
        self.logger.debug("Test Populate Occurrences")
        self.bk_model._populate_occurrences('term_occurrences.txt')
        # check it's not null and contains the following
        #   hello 10
        #   world 20
        #   goodbye 10
        self.assertDictEqual(self.bk_model.occurrence_dict, self.expected)

    # Separate doc/bk tests are probably unnecessary: the bk model is read
    # into a dict and that step has been tested above.

    def test_get_number_occurrences_doc(self):
        """Dict-backed model: count of a known term."""
        self.logger.debug("Test Get Number of Occurrences with dictionary")
        hello_count = self.doc_model.get_num_occurrences('hello')
        hello_expected = 10
        # assertEquals is a deprecated alias (removed in Python 3.12).
        self.assertEqual(hello_count, hello_expected)

    def test_get_number_occurrences_bk(self):
        """File-backed model: count of a known term."""
        self.logger.debug("Test Get Number of Occurrences with file")
        # same test as _doc but with bk model
        hello_count = self.bk_model.get_num_occurrences('hello')
        hello_expected = 10
        self.assertEqual(hello_count, hello_expected)

    def test_when_no_occurrences_bk(self):
        """File-backed model: an unseen term counts as zero."""
        self.logger.debug("Test when term does not exist")
        # same test as _doc but with bk model
        term_count = self.bk_model.get_num_occurrences('garble')
        term_expected = 0
        self.assertEqual(term_count, term_expected)

    def test_calc_total_occurrences_doc(self):
        """Dict-backed model: total of all counts is 40."""
        self.logger.debug("Test calculate total occurrences with dictionary")
        actual_total = self.doc_model.get_total_occurrences()
        expected_total = 40
        self.assertEqual(expected_total, actual_total)

    def test_calc_total_occurrences_bk(self):
        """File-backed model: total of all counts is 40."""
        self.logger.debug("Test calculate total occurrences with file")
        actual_total = self.bk_model.get_total_occurrences()
        expected_total = 40
        self.assertEqual(expected_total, actual_total)

    def test_get_term_probability(self):
        """Term probability equals count / total."""
        self.logger.debug("Test get term probability")
        # probability of hello is 10/40 = 0.25
        expected = 0.25
        actual = self.bk_model.get_term_prob('hello')
        self.assertEqual(expected, actual)
Пример #9
0
def run_comprehension_experiment(dataset, experiment_paths, experiment_config, image_ids=None):
  """Run the comprehension experiment named in the config and save its results.

  The captioner and experiment classes are chosen from
  ``experiment_config.exp_name``: ``'baseline'`` or a name starting with
  ``'max_margin'`` uses ``LanguageModel`` / ``ComprehensionExperiment``; a
  name starting with ``'mil_context'`` uses the ``MILContext*`` pair.

  Results are written as JSON under ``experiment_paths.retrieval_results``:
  one file per method when the experiment returns a dict, otherwise a
  single file.

  Args:
    dataset: dataset object; its ``dataset_name`` is used in file names.
    experiment_paths: provides ``retrieval_results`` (output directory).
    experiment_config: provides ``exp_name``, ``vocab_file`` and ``test``.
    image_ids: forwarded to the experiment class.

  Raises:
    ValueError: if ``exp_name`` matches no known experiment.  (On Python 2
      ``ValueError`` derives from ``StandardError``, so handlers written for
      the previous exception type still catch it.)
  """
  exp_name = experiment_config.exp_name
  # Single dispatch point: the captioner and the experiment always go
  # together, so they are selected as a pair.
  if exp_name == 'baseline' or exp_name.startswith('max_margin'):
    use_mil_context = False
  elif exp_name.startswith('mil_context'):
    use_mil_context = True
  else:
    raise ValueError("Unknown experiment name: %s" % exp_name)

  if use_mil_context:
    captioner_cls, experimenter_cls = MILContextLanguageModel, MILContextComprehensionExperiment
  else:
    captioner_cls, experimenter_cls = LanguageModel, ComprehensionExperiment

  test_config = experiment_config.test
  captioner = captioner_cls(test_config.lstm_model_file, test_config.lstm_net_file,
                            experiment_config.vocab_file, device_id=0)
  experimenter = experimenter_cls(captioner, dataset, image_ids=image_ids)

  results = experimenter.comprehension_experiment(experiment_paths,
                                                  proposal_source=test_config.proposal_source,
                                                  visualize=test_config.visualize)

  if isinstance(results, dict):
    for method in results:
      # Single-argument print(...) behaves the same on Python 2 and 3.
      print("Results for method: %s" % method)
      results_filename = '%s/%s_%s_%s_results.json' % (experiment_paths.retrieval_results, dataset.dataset_name,
                                                       test_config.tag, method)
      with open(results_filename, 'w') as f:
        json.dump(results[method], f)
  else:
    results_filename = '%s/%s_%s_results.json' % (experiment_paths.retrieval_results, dataset.dataset_name,
                                                  test_config.tag)
    with open(results_filename, 'w') as f:
      json.dump(results, f)
Пример #10
0
def decode_bigram(pinyins):
    """Decode every pinyin sequence with the bigram Viterbi decoder.

    The language model comes from ``LanguageModel.load_from_trained()``
    with its default arguments.  Each input is printed, followed by the
    result of ``viterbi_bigram`` for it.
    """
    bigram_lm = LanguageModel.load_from_trained()
    for sequence in pinyins:
        print(sequence)
        decoded = viterbi_bigram(bigram_lm, sequence)
        print(decoded)
Пример #11
0
    def setUp(self):
        """Create a small document model and a larger collection model.

        ``self.doc`` / ``self.col`` hold the raw term counts, and
        ``self.docLM`` / ``self.colLM`` are the ``LanguageModel`` instances
        built from them.
        """
        # Term counts of the single document.
        self.doc = {'hello': 1, 'world': 2, 'help': 1}

        # Term counts of the whole collection.
        self.col = {
            'hello': 20, 'world': 5, 'good': 5, 'bye': 15,
            'free': 1, 'code': 1, 'source': 1, 'compile': 1, 'error': 1,
        }

        # The collection model is built before the document model.
        self.colLM = LanguageModel(term_dict=self.col)
        self.docLM = LanguageModel(term_dict=self.doc)
Пример #12
0
def decode_trigram(pinyins):
    """Decode every pinyin sequence with the trigram Viterbi decoder.

    Loads the model stored at ``models/3-lm.pkl``, reports its n-gram order
    and then prints each input followed by the ``viterbi_trigram`` result.
    """
    trigram_lm = LanguageModel.load_from_trained(model_path="models/3-lm.pkl")
    print("load {}-gram model successfully.".format(trigram_lm.ngram))

    for sequence in pinyins:
        print(sequence)
        decoded = viterbi_trigram(trigram_lm, sequence)
        print(decoded)
Пример #13
0
 def __init__(self, threshold=0.96):
     """Load the language model, spaCy, Hunspell and the word lists.

     Args:
         threshold: stored unchanged on ``self.threshold``.
     """
     # Resource files live next to this module.
     module_dir = os.path.dirname(os.path.realpath(__file__))
     spelling_dir = module_dir + '/resources/spelling/'
     inflection_file = module_dir + "/resources/agid-2016.01.19/infl.txt"

     self.lm = LanguageModel()
     # spaCy pipeline.
     self.nlp = spacy.load("en")
     # Hunspell spellchecker: https://pypi.python.org/pypi/CyHunspell
     # (CyHunspell seems more accurate than Aspell in PyEnchant, but is a
     # bit slower.)
     self.gb = Hunspell("en_GB-large", hunspell_data_dir=spelling_dir)
     # Inflection forms: http://wordlist.aspell.net/other/
     self.gb_infl = loadWordFormDict(inflection_file)
     # Common determiners; the empty string is also a member.
     self.determiners = {"", "the", "a", "an"}
     # Common prepositions; the empty string is also a member.
     self.prepositions = {
         "", "about", "at", "by", "for", "from",
         "in", "of", "on", "to", "with",
     }
     self.threshold = threshold
Пример #14
0
    def __init__(self, raw_segments, observation_sequences, label_sequences):
        """Keep the training data, build the helpers and retrain right away.

        Args:
            raw_segments: stored on ``self.raw_segments``.
            observation_sequences: stored on ``self.observation_sequences``.
            label_sequences: stored on ``self.label_sequences``.

        The constructor ends by calling
        ``self.retrain_with_boosting_features()``, so creating an instance
        already performs the retraining.
        """
        super(Retrainer, self).__init__()
        self.raw_segments = raw_segments
        self.observation_sequences = observation_sequences
        self.label_sequences = label_sequences
        self.hmm_new = None
        self.feature_entity_list = FeatureEntityList()
        self.lm = LanguageModel()
        self.boosting_feature_generator = BoostingFeatureGenerator()

        self.DOMINANT_RATIO = 0.85  # dominant label ratio: set empirically

        # Construction triggers the retraining step.
        self.retrain_with_boosting_features()
 def _create_lm_from_counts(self, smoothing):
     """Turn the collected n-gram counts into a smoothed ``LanguageModel``.

     Only contexts followed by more than two distinct notes are added.  For
     those, each note gets ``(count + smoothing) / (total + 48 * smoothing)``
     and an extra ``"<UNK>"`` entry gets ``smoothing`` over the same
     denominator.

     Args:
         smoothing: additive smoothing constant.

     Returns:
         The populated ``LanguageModel``.
     """
     model = LanguageModel(part=self._part)
     model.set_ngram_size(self._ngram_size)
     for context in self._lm_counts:
         followers = self._lm_counts[context]
         total_notes_after_context = sum(followers.values())
         if len(followers.keys()) <= 2:
             # Too few distinct continuations: context is left out.
             continue
         # approximately 4 octaves in our vocabulary (48 notes)
         denominator = float(total_notes_after_context + (48 * smoothing))
         for note, count in followers.items():
             model.add_to_model(context, note, (count + smoothing) / denominator)
         model.add_to_model(context, "<UNK>", (smoothing / denominator))
     return model
Пример #16
0
def cur_cost(current_hyp, source, phrase_probs, lang_model: LanguageModel):
    """Score a partial translation hypothesis.

    The score is the product of three factors: the language-model
    probability of the translated tokens (left-padded only), the distortion
    probability ``d(start, previous_end)`` between consecutive phrases, and
    the phrase translation probabilities.

    Args:
        current_hyp: iterable of ``(start_f, end_f, e)`` triples, where
            ``source[start_f:end_f + 1]`` is the foreign phrase and ``e`` the
            translated phrase.
        source: list of foreign tokens.
        phrase_probs: nested mapping ``phrase_probs[foreign_phrase][e]``.
        lang_model: object providing ``get_prob_sentance``.

    Returns:
        The combined score.
    """
    # Flatten the translated phrases into a single token list.
    target_tokens = []
    for _, _, target_phrase in current_hyp:
        target_tokens.extend(target_phrase.split(" "))
    lm_score = lang_model.get_prob_sentance(
        target_tokens, padded_left=True, padded_right=False)

    phrase_score = 1
    reordering_score = 1
    previous_end = -1  # -1 means "no phrase seen yet"
    for start_f, end_f, e in current_hyp:
        foreign_phrase = " ".join(source[start_f:end_f + 1])
        phrase_score *= phrase_probs[foreign_phrase][e]
        # The first phrase has no predecessor, hence no distortion factor.
        if previous_end != -1:
            reordering_score *= d(start_f, previous_end)
        previous_end = end_f

    return lm_score * reordering_score * phrase_score
Пример #17
0
    def __init__(self, opt, dict):
        """Restore or build the model(s), then set up loss and optimizer(s).

        Args:
            opt: options object; this method reads ``restore``, ``hier``,
                ``model``, ``glove``, ``bowDim``, ``hiddenSize`` and
                ``learningRate`` from it.
            dict: vocabulary mapping; ``dict["w2i"]`` and ``dict["i2w"]`` are
                read in the hierarchical branch.  (The name shadows the
                builtin but is part of the call signature.)
        """
        super(Trainer, self).__init__()
        self.opt = opt
        self.dict = dict
        self.hier = opt.hier

        if opt.restore:
            # Resume from a saved checkpoint at opt.model.
            if opt.hier:
                # Hierarchical checkpoints hold an (mlp, encoder) pair.
                self.mlp, self.encoder = torch.load(opt.model)
            else:
                self.mlp = torch.load(opt.model)
                self.encoder = self.mlp.encoder

            # Continue with the epoch after the one stored in the checkpoint.
            self.mlp.epoch += 1
            print("Restoring MLP {} with epoch {}".format(
                opt.model, self.mlp.epoch))
        else:
            # Fresh start: build new model(s).
            if opt.hier:
                # GloVe weights are only built when opt.glove is set.
                glove_weights = build_glove(dict["w2i"]) if opt.glove else None

                self.encoder = apply_cuda(HierAttnEncoder(len(dict["i2w"]), opt.bowDim, opt.hiddenSize, opt, glove_weights))
                self.mlp = apply_cuda(HierAttnDecoder(len(dict["i2w"]), opt.bowDim, opt.hiddenSize, opt, glove_weights))
            else:
                self.mlp = apply_cuda(LanguageModel(self.dict, opt))
                self.encoder = self.mlp.encoder

            self.mlp.epoch = 0

        # Targets equal to index 0 are ignored by the loss
        # (presumably the padding index -- TODO confirm).
        self.loss = apply_cuda(nn.NLLLoss(ignore_index=0))
        self.decoder_embedding = self.mlp.context_embedding

        if opt.hier:
            # Encoder and decoder get separate RMSprop optimizers over their
            # trainable parameters.
            # NOTE(review): the same constant c = 0.9 is used for both
            # momentum and weight_decay -- confirm that a weight decay of
            # 0.9 is intended.
            c = 0.9
            self.encoder_optimizer = torch.optim.RMSprop(filter(lambda p: p.requires_grad, self.encoder.parameters()), self.opt.learningRate,
                                                                    momentum=c, weight_decay=c)
            self.optimizer = torch.optim.RMSprop(filter(lambda p: p.requires_grad, self.mlp.parameters()), self.opt.learningRate,
                                                                    momentum=c, weight_decay=c)
        else:
            self.optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, self.mlp.parameters()), self.opt.learningRate)  # NOTE(review): earlier comment said "Half learning rate", but opt.learningRate is passed unchanged
Пример #18
0
def train(train_x, train_y, word_dict, args):
    """Train the language model and checkpoint it periodically.

    Builds a ``LanguageModel`` over ``word_dict``, optimises ``model.loss``
    with Adam (learning rate 0.001) on gradients clipped to a global norm
    of 5.0, writes summaries to the ``AutoEncoder`` directory and saves a
    checkpoint every 5000 steps.

    Args:
        train_x: training inputs, batched by ``batch_iter``.
        train_y: passed to ``batch_iter``; the targets it yields are unused.
        word_dict: vocabulary handed to ``LanguageModel``.
        args: unused in this function.
    """
    with tf.compat.v1.Session() as sess:
        model = LanguageModel(word_dict, MAX_DOCUMENT_LEN)
        global_steps = tf.Variable(0, trainable=False)
        params = tf.trainable_variables()
        gradients = tf.gradients(model.loss, params)
        clipped_gradients, _ = tf.clip_by_global_norm(gradients, 5.0)
        optimizer = tf.train.AdamOptimizer(0.001)
        # Pass global_step so the optimizer increments the counter on every
        # update.  Without it the counter stays at 0, which made the loss
        # print on every batch and a checkpoint get written on every batch.
        train_op = optimizer.apply_gradients(zip(clipped_gradients, params),
                                             global_step=global_steps)

        # Register the loss summary so merge_all() picks it up.
        tf.summary.scalar("loss", model.loss)
        summary_op = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter("AutoEncoder", sess.graph)

        saver = tf.train.Saver(tf.global_variables())

        sess.run(tf.global_variables_initializer())

        def train_step(batch_x):
            """Run one optimisation step and log its summary."""
            feed_dict = {model.x: batch_x}
            _, step, summaries, loss = sess.run([train_op, global_steps, summary_op, model.loss],
                                                feed_dict=feed_dict)
            summary_writer.add_summary(summaries, step)

            if step % 100 == 0:
                print("step {0} : loss = {1}".format(step, loss))

        batches = batch_iter(train_x, train_y, BATCH_SIZE, NUM_EPOCHS)

        for batch_x, _ in batches:
            train_step(batch_x)
            step = tf.train.global_step(sess, global_steps)

            if step % 5000 == 0:
                saver.save(sess, os.path.join("AutoEncoder", "model", "model.ckpt"), global_step=step)
Пример #19
0
    def language_modeling(self):
        """Compute and print the perplexity of the saved encoder/decoder.

        Both saved models must exist under ``self.options.result_dir``.
        Examples whose first sequence has at most one element are skipped;
        for the others the language-model score and the length of the first
        sequence are accumulated and passed to ``compute_perplexity``.
        """
        result_dir = self.options.result_dir

        decoder = self.create_decoder()
        assert (os.path.exists(result_dir + 'model_dec'))
        self.load_decoder(decoder)

        encoder = self.create_encoder()
        assert (os.path.exists(result_dir + 'model_enc'))
        self.load_encoder(encoder)

        print('computing language model score...')

        test = self.reader.next_example(2)
        lm = LanguageModel(encoder, decoder)

        total_ll = 0
        total_tokens = 0
        for data in test:
            s1, s2, s3, pos = data[0], data[1], data[2], data[3]
            _act = data[4]  # read as before, not used for scoring
            if len(s1) > 1:
                total_ll += lm(s1, s2, s3, pos, self.options.nsamples)
                total_tokens += len(s1)

        perp = compute_perplexity(total_ll, total_tokens)
        print('perplexity: {}'.format(perp))
Пример #20
0
    config = p.parse_args()

    return config


if __name__ == '__main__':
    # Script entry point: parse the command line, build the data loader,
    # the model and the loss criterion, then move them to the GPU if asked.
    config = define_argparser()

    loader = DataLoader(config.train,
                        config.valid,
                        batch_size=config.batch_size,
                        device=config.gpu_id,
                        max_length=config.max_length)
    # The model's first argument is the vocabulary size.
    model = LM(len(loader.text.vocab),
               word_vec_dim=config.word_vec_dim,
               hidden_size=config.hidden_size,
               n_layers=config.n_layers,
               dropout_p=config.dropout,
               max_length=config.max_length)

    # Give PAD a loss weight of zero so the criterion does not count PAD as a
    # correct prediction; PAD is trivially easy to predict.
    # NOTE(review): PAD is read from `data_loader` (presumably the module),
    # not from the `loader` instance created above -- confirm.
    loss_weight = torch.ones(len(loader.text.vocab))
    loss_weight[data_loader.PAD] = 0
    # NOTE(review): `size_average` is deprecated in newer PyTorch releases in
    # favour of `reduction='sum'` -- confirm the targeted version.
    criterion = nn.NLLLoss(weight=loss_weight, size_average=False)

    print(model)
    print(criterion)

    # A negative gpu_id keeps the model and the criterion on the CPU.
    if config.gpu_id >= 0:
        model.cuda(config.gpu_id)
        criterion.cuda(config.gpu_id)
Пример #21
0
class Retrainer(object):
    def __init__(self, raw_segments, observation_sequences, label_sequences):
        """Store the training data, create the helpers and retrain at once.

        Args:
            raw_segments: stored on ``self.raw_segments``.
            observation_sequences: stored on ``self.observation_sequences``
                (later replaced by ``retrain_with_boosting_features``).
            label_sequences: stored on ``self.label_sequences``.

        The constructor finishes by calling
        ``self.retrain_with_boosting_features()``.
        """
        super(Retrainer, self).__init__()
        self.raw_segments = raw_segments
        self.observation_sequences = observation_sequences
        self.label_sequences = label_sequences
        self.hmm_new = None  # set by retrain() / retrain_with_boosting_features()
        self.feature_entity_list = FeatureEntityList()
        self.lm = LanguageModel()
        # NOTE(review): created without arguments here, while
        # retrain_with_boosting_features() constructs it with three
        # arguments -- confirm the no-argument form is supported.
        self.boosting_feature_generator = BoostingFeatureGenerator()

        self.DOMINANT_RATIO = 0.85  # dominant label ratio: set empirically

        # Retraining happens as part of construction.
        self.retrain_with_boosting_features()
    
    def retrain(self):
        """Train a fresh HMM on the stored observation and label sequences."""
        retrained = HMM('retrainer', 6)
        self.hmm_new = retrained
        # Important: Laplace smoothing must stay disabled here.
        retrained.train(
            self.observation_sequences,
            self.label_sequences,
            useLaplaceRule=False,
        )
    
    # With new features
    def retrain_with_boosting_features(self):
        """Retrain the HMM on boosting features derived from the raw segments.

        First every (token, label) pair is added to ``self.lm``; its
        prettified model becomes ``self.token_BGM`` (``self.pattern_BGM`` is
        set to ``None``).  Then a new HMM is trained on the features that
        ``BoostingFeatureGenerator`` produces per raw segment, and those
        features replace ``self.observation_sequences``.
        """
        # Build language model
        for segment, labels in zip(self.raw_segments, self.label_sequences):
            segment_tokens = Tokens(segment).tokens
            for token, label in zip(segment_tokens, labels):
                self.lm.add(token, label)
        self.lm.prettify()
        self.token_BGM = self.lm.prettify_model
        self.pattern_BGM = None

        # Retrain
        self.hmm_new = HMM('retrainer', 6)
        boosted_features = [
            BoostingFeatureGenerator(segment, self.token_BGM, self.pattern_BGM).features
            for segment in self.raw_segments
        ]
        self.hmm_new.train(boosted_features, self.label_sequences, useLaplaceRule=False)
        self.observation_sequences = boosted_features


    def run(self):
        i = 0
        self.new_labels = []
        for raw_segment, label_sequence in zip(self.raw_segments, self.label_sequences):
            new_labels = self.hmm_new.decode(raw_segment)[1]
            self.new_labels.append(new_labels)
            tokens = Tokens(raw_segment).tokens
            feature_vectors = FeatureGenerator(raw_segment).features
            print i, ':  ', raw_segment
            for token, old_label, new_label, feature_vector in zip(tokens, label_sequence, new_labels, feature_vectors):
                print to_label(old_label), '\t', to_label(new_label), '\t', token
                self.feature_entity_list.add_entity(feature_vector, old_label, token)   #???? Old label first
            print '\n'
            i+=1

    def find_pattern(self):
        """Print all entities of the retrained HMM's feature entity list."""
        entity_list = self.hmm_new.feature_entity_list
        entity_list.print_all_entity()


    # Find the first tokens at VN boundaries
    def find_venue_boundary_tokens(self):
        recorder = {}
        for raw_segment, observation_sequence, label_sequence in zip(self.raw_segments, self.observation_sequences, self.label_sequences):
            first_target_label_flag = True
            tokens = Tokens(raw_segment).tokens
            for token, feature_vector, label in zip(tokens, observation_sequence, label_sequence):
                # First meet a VN label
                if label == 4 and first_target_label_flag:
                    key = token.lower()
                    if not key.islower():
                        continue
                    if recorder.has_key(key):
                        recorder[key] += 1
                    else:
                        recorder[key] = 1
                    first_target_label_flag = False

                elif (first_target_label_flag is False) and label in [0,1,3]:
                    first_target_label_flag = True

        for k,v in recorder.iteritems():
            print k, '\t', v
        return recorder


    # Learn the general order of structure of publications before moving forward
    def find_majority_structure(self):
        first_bit_counter = {'0': 0, '3': 0, '4':0, '5':0}
        overall_pattern_counter = {}
        for label_sequence in self.label_sequences:
            label = label_sequence[0]
            if label == 2:
                continue
            elif label == 5:
                continue
            elif label in [0,1]:
                first_bit_counter['0'] += 1
            else:
                first_bit_counter[str(label)] += 1

            pattern = []
            for label in label_sequence:
                if label in [2,5]:
                    continue
                elif label in [0,1]:
                    if 0 in pattern:
                        continue
                    else:
                        pattern.append(0)
                elif label == 3:
                    if 3 in pattern:
                        continue
                    else:
                        pattern.append(3)
                elif label == 4:
                    if 4 in pattern:
                        continue
                    else:
                        pattern.append(4)
            key = str(pattern)
            if overall_pattern_counter.has_key(key):
                overall_pattern_counter[key] += 1
            else:
                overall_pattern_counter[key] = 1

        # Inducing the structure
        sorted_firstbit_counter = sorted(first_bit_counter.iteritems(), key=operator.itemgetter(1), reverse=True)
        sorted_pattern_counter = sorted(overall_pattern_counter.iteritems(), key=operator.itemgetter(1), reverse=True)
        print '===========================================', sorted_pattern_counter
        return int(sorted_firstbit_counter[0][0]), ast.literal_eval(sorted_pattern_counter[0][0]), (float(sorted_pattern_counter[0][1]))/len(self.label_sequences)

    def run_with_boosting_features(self):
        i = 0
        self.new_labels = []
        self.combined_labels = []

        for raw_segment, label_sequence in zip(self.raw_segments, self.label_sequences):
            feature_vectors, new_labels = self.hmm_new.decode(raw_segment, True, True, self.token_BGM, self.pattern_BGM)
            self.new_labels.append(new_labels)
            tokens = Tokens(raw_segment).tokens
            print i, ':  ', raw_segment

            # Combination step: 
            tmp_combined_labels = []    # the decided combined labels so far
            for token, old_label, new_label, feature_vector in zip(tokens, label_sequence, new_labels, feature_vectors):

                # Combine old and new labels to come out a combined label, and deciding...
                combined_label = -1
                
                if old_label == new_label:
                    combined_label = new_label
                    tmp_combined_labels.append(new_label)
                
                # Combine compatible labels: FN and LN
                elif old_label in [0,1] and new_label in [0,1]:
                    combined_label = old_label
                    tmp_combined_labels.append(new_label)
                
                # Combine labels that are not compatible
                else:   
                    tmp_feature_entity = self.hmm_new.feature_entity_list.lookup(feature_vector)    # Get the Background knowledge provided the feature vector: the language feature model
                    sorted_label_distribution = sorted(tmp_feature_entity.label_distribution.iteritems(), key=operator.itemgetter(1), reverse=True)
                    total_label_occurence = float(sum(tmp[1] for tmp in sorted_label_distribution))

                    

                    # ============================================================================================
                    # ============================================================================================
                    # ???? Experimenting: removing the low prob label distribution; FAILURE; ARCHIVED HERE AND DEPRECATED 
                    # sorted_label_distribution = []
                    # sum_prob = 0.0
                    # for pair in tmp_sorted_label_distribution:
                    #     sorted_label_distribution.append(pair)
                    #     sum_prob += pair[1]
                    #     if sum_prob/total_label_occurence >= 0.90:
                    #         break
                    # ============================================================================================
                    # ============================================================================================



                    # Dominant label case: Iterate from the highest label stats according to this feature vector:
                    for label_frequency in sorted_label_distribution:
                        if int(label_frequency[0]) in [old_label, new_label] and (label_frequency[1]/total_label_occurence)>=self.DOMINANT_RATIO:
                            print 'Dominant labels'
                            # Check for constraint:
                            tmp_label_to_check = int(label_frequency[0])
                            
                            # Find last occurence position of this label
                            if tmp_label_to_check not in [0,1]:
                                last_occurence = ''.join([str(c) for c in tmp_combined_labels]).rfind(str(tmp_label_to_check))
                            elif tmp_label_to_check in [0,1]:
                                last_occurence_0 = ''.join([str(c) for c in tmp_combined_labels]).rfind('0')
                                last_occurence_1 = ''.join([str(c) for c in tmp_combined_labels]).rfind('1')
                                last_occurence = max(last_occurence_0, last_occurence_1)

                            # Checking constraints by simplifying what we did in viterbi
                            if last_occurence == -1 or last_occurence == (len(tmp_combined_labels)-1):  # Never occurred, or last occurence is the last label
                                # When we are deciding the first label
                                if len(tmp_combined_labels) == 0:
                                    first_bit = self.find_majority_structure()[0]
                                    if first_bit == 0 and tmp_label_to_check not in [0,1]:
                                        continue
                                    if first_bit == 3 and tmp_label_to_check != 3:
                                        continue

                                # VN CANNOT FOLLOW TI W/O DL constraint
                                if tmp_label_to_check == 4 and tmp_combined_labels[-1] == 3:
                                    continue
                            elif tmp_label_to_check in [0,1]:
                                flag = False
                                for j in range(last_occurence, len(tmp_combined_labels)):
                                    if tmp_combined_labels[j] not in [0,1,2]:
                                        flag = True
                                        break
                                if flag:
                                    continue
                            elif tmp_label_to_check == 3:
                                continue
                            elif tmp_label_to_check == 4:
                                if tmp_combined_labels[-1] == 3:    #????
                                    continue

                            combined_label = tmp_label_to_check
                            tmp_combined_labels.append(tmp_label_to_check)
                            break
                    
                    # No dominance case OR Dominance-fail-due-to-constraint case: Find relatively if the label with higher possibility follow the constraint of publication order
                    if combined_label == -1:
                        # Iterate from the highest label stats according to this feature vector:

                        for label_frequency in sorted_label_distribution:
                            breakout_flag = False
                            #Test against constraints
                            # 1. DL separate labels principle
                            # 2. AU-TI-VN Order 
                            if int(label_frequency[0]) in [old_label, new_label]:
                                tmp_label_to_check = int(label_frequency[0])
                                
                                # find structure of the order, and find what have appeared, and so predict what to be appear next
                                structure_overview = []     #will record the order in big sense: 0,3,4/4,0,3
                                for tmp_combined_label in tmp_combined_labels:
                                    if tmp_combined_label in [2,5]:
                                        continue                                            
                                    elif tmp_combined_label in [0,1]:
                                        if 0 in structure_overview:
                                            continue
                                        else:
                                            structure_overview.append(0)
                                    elif tmp_combined_label == 3:
                                        if 3 in structure_overview:
                                            continue
                                        else:
                                            structure_overview.append(3)
                                    elif tmp_combined_label == 4:
                                        if 4 in structure_overview:
                                            continue
                                        else:
                                            structure_overview.append(4)
                                # Based on the structure overview, find what should appear next
                                appear_next = []
                                if structure_overview == [0]:
                                    appear_next = [0,1,3,2,5]
                                elif structure_overview == [3]:
                                    appear_next = [3,0,1,2,5]
                                elif structure_overview == [0,3]:
                                    appear_next = [3,4,2,5]
                                elif structure_overview == [3,0]:
                                    appear_next = [0,1,4,2,5]
                                elif structure_overview == [0,3,4]:
                                    appear_next = [4,2,5]
                                elif structure_overview == [3,0,4]:
                                    appear_next = [4,2,5]
                                else:   #weird case
                                    print 'Weird structure! Weird case!'
                                    if tmp_feature_entity.label_distribution[str(old_label)] > tmp_feature_entity.label_distribution[str(new_label)]:
                                        tmp_label_to_check_list = [old_label, new_label]
                                    else:
                                        tmp_label_to_check_list = [new_label, old_label]
                                    # Apply constraints here too
                                    for tmp_label_to_check in tmp_label_to_check_list:
                                        if tmp_label_to_check not in [0,1]:
                                            last_occurence = ''.join([str(c) for c in tmp_combined_labels]).rfind(str(tmp_label_to_check))
                                        elif tmp_label_to_check in [0,1]:
                                            last_occurence_0 = ''.join([str(c) for c in tmp_combined_labels]).rfind('0')
                                            last_occurence_1 = ''.join([str(c) for c in tmp_combined_labels]).rfind('1')
                                            last_occurence = max(last_occurence_0, last_occurence_1)

                                        # Checking constraints by simplifying what we did in viterbi
                                        if last_occurence == -1 or last_occurence == (len(tmp_combined_labels)-1):
                                            # When we are deciding the first label
                                            if len(tmp_combined_labels) == 0:
                                                first_bit = self.find_majority_structure()[0]
                                                if first_bit == 0 and tmp_label_to_check not in [0,1]:
                                                    continue
                                                if first_bit == 3 and tmp_label_to_check != 3:
                                                    continue
                                            try:
                                                if tmp_label_to_check == 4 and tmp_combined_labels[-1] == 3:
                                                    continue
                                            except:
                                                continue
                                        elif tmp_label_to_check in [0,1]:
                                            flag = False
                                            for j in range(last_occurence, len(tmp_combined_labels)):
                                                if tmp_combined_labels[j] not in [0,1,2]:
                                                    flag = True
                                                    break
                                            if flag:
                                                continue
                                        elif tmp_label_to_check == 3:
                                            continue
                                        elif tmp_label_to_check == 4:
                                            if tmp_combined_labels[-1] == 3:
                                                continue

                                        combined_label = tmp_label_to_check
                                        tmp_combined_labels.append(combined_label)
                                        breakout_flag = True
                                        break

                                if breakout_flag:
                                    break
                                if tmp_label_to_check in appear_next:
                                    # Then check constraint. find last occurence, DL constraints
                                    # Just need to check DL constraints, no need to verify more on tokens, assume token verification is done in the first iteration
                                    if tmp_label_to_check not in [0,1]:
                                        last_occurence = ''.join([str(c) for c in tmp_combined_labels]).rfind(str(tmp_label_to_check))
                                    elif tmp_label_to_check in [0,1]:
                                        last_occurence_0 = ''.join([str(c) for c in tmp_combined_labels]).rfind('0')
                                        last_occurence_1 = ''.join([str(c) for c in tmp_combined_labels]).rfind('1')
                                        last_occurence = max(last_occurence_0, last_occurence_1)

                                    # Checking constraints by simplifying what we did in viterbi
                                    if last_occurence == -1 or last_occurence == (len(tmp_combined_labels)-1):
                                        if tmp_label_to_check == 4 and tmp_combined_labels[-1] == 3: #Hardcode rule [2013/07/23]: For VN, cannot directly follow a TI without DL???? may remove on real effect
                                            continue
                                    elif tmp_label_to_check in [0,1]:
                                        flag = False
                                        for j in range(last_occurence, len(tmp_combined_labels)):
                                            if tmp_combined_labels[j] not in [0,1,2]:
                                                flag = True
                                                break
                                        if flag:
                                            continue

                                    elif tmp_label_to_check == 3:
                                        continue
                                        # flag = False
                                        # for j in range(last_occurence, len(tmp_combined_labels)):
                                        #     if tmp_combined_labels[j] not in [3,2]:
                                        #         flag = True
                                        #         break
                                        # if flag:
                                        #     continue

                                    elif tmp_label_to_check == 4:
                                        if tmp_combined_labels[-1] == 3:    #????
                                            continue

                                    # elif tmp_label_to_check == 2:
                                    # elif tmp_label_to_check == 5:
                                    
                                    # Otherwise, pass
                                    log_err('\t\t' + str(i) + 'Should combine this one')
                                    combined_label = tmp_label_to_check
                                    tmp_combined_labels.append(tmp_label_to_check)
                                    # combined_label = (tmp_label_to_check, sorted_label_distribution)
                                    break
                                    
                                else:
                                    continue

                        # Debug
                        if combined_label == -1:
                            log_err(str(i) + 'problem')
                            combined_label = (appear_next, sorted_label_distribution)
                            tmp_combined_labels.append(-1)


            # Final check the accordance with the major order, ideally, all records under one domain should have the same order... PS very ugly code I admit
            print '==========================tmp_combined_labels', tmp_combined_labels
            majority_order_structure = self.find_majority_structure()[1]
            majority_rate = self.find_majority_structure()[2]
            tmp_combined_labels_length = len(tmp_combined_labels)
            if majority_rate > 0.80 and majority_order_structure == [0,3,4]:
                # p1(phase1): author segments
                for p1 in range(tmp_combined_labels_length):
                    if tmp_combined_labels[p1] in [0,1,2,5]:
                        continue
                    else:
                        break

                # p2(phase2): title segments
                for p2 in range(p1, tmp_combined_labels_length):
                    if tmp_combined_labels[p2] == 3:
                        continue
                    else:
                        break

                #p3(phase3): venue segments
                for p3 in range(p2, tmp_combined_labels_length):
                    if tmp_combined_labels[p3] in [2,5,4]:
                        continue
                    else:
                        break

                # Decision
                if p1 == 0:
                    print 'Houston we got a SERIOUS problem!'
                    log_err('Houston we got a SERIOUS problem!!!!!!!!')

                if p2 == p1:
                    print 'Houston we got a problem!'
                    for sp2 in range(p2, tmp_combined_labels_length):
                        if tmp_combined_labels[sp2] != 2:
                            tmp_combined_labels[sp2] = 3
                        else:
                            break   # should fix common mislabeling at this point now??????????


            # elif majority_rate > 0.80 and majority_order_structure == [3,0,4]:    # ???? not sure if this is normal
            #     # p1(phase1): title segments
            #     for p1 in range(tmp_combined_labels_length):
            #         if tmp_combined_labels[p1] in [3]:
            #             continue
            #         else:
            #             break

            #     # p2(phase2): author segments
            #     for p2 in range(p1, tmp_combined_labels_length):
            #         if tmp_combined_labels[p2] == 3:
            #             continue
            #         else:
            #             break

            #     #p3(phase3): venue segments
            #     for p3 in range(p2, tmp_combined_labels_length):
            #         if tmp_combined_labels[p3] in [2,5,4]:
            #             continue
            #         else:
            #             break

            #     # Decision
            #     if p1 == 0:
            #         print 'Houston we got a SERIOUS problem!'
            #         log_err('Houston we got a SERIOUS problem!!!!!!!!')

            #     if p2 == p1:
            #         print 'Houston we got a problem!'
            #         for sp2 in range(p2, tmp_combined_labels_length):
            #             if tmp_combined_labels[sp2] != 2:
            #                 tmp_combined_labels[sp2] = 3
            #             else:
            #                 break
            for old_label, new_label, tmp_combined_label, token, feature_vector in zip(label_sequence, new_labels, tmp_combined_labels, tokens, feature_vectors):
                print to_label(old_label), '\t', to_label(new_label), '\t', to_label(tmp_combined_label), '\t', token, '\t', feature_vector
            print '\n'
            i+=1
Пример #22
0
# Persist the vocabulary under the serialization directory.
vocab.save_to_files(args.serialization_path + "/vocabulary")

# One embedding row per entry of the 'tokens' vocabulary namespace.
token_embedding = Embedding(num_embeddings=vocab.get_vocab_size('tokens'),
                            embedding_dim=EMBEDDING_DIM)

# The same embedder object is passed to both language models built below.
word_embeddings = BasicTextFieldEmbedder({"tokens": token_embedding})

# Contextualizer 1: a batch-first LSTM wrapped for use as a seq2seq encoder.
# NOTE(review): torch.nn.LSTM applies `dropout` only between stacked layers;
# num_layers is left at its default here, so this setting probably has no
# effect -- confirm against the PyTorch docs.
lstm = PytorchSeq2SeqWrapper(
    torch.nn.LSTM(EMBEDDING_DIM,
                  HIDDEN_DIM,
                  batch_first=True,
                  dropout=args.drop))

lstm_model = LanguageModel(contextualizer=lstm,
                           text_field_embedder=word_embeddings,
                           vocab=vocab)

# Contextualizer 2: multi-head self-attention with 2 heads and hard-coded
# attention/value dimensions of 16.
transformer = MultiHeadSelfAttention(attention_dim=16,
                                     input_dim=EMBEDDING_DIM,
                                     num_heads=2,
                                     values_dim=16,
                                     attention_dropout_prob=args.drop)

transformer_model = LanguageModel(contextualizer=transformer,
                                  text_field_embedder=word_embeddings,
                                  vocab=vocab)

stacked_transformer = StackedSelfAttentionEncoder(
    input_dim=EMBEDDING_DIM,
    hidden_dim=HIDDEN_DIM,
Пример #23
0
        lines += [line]

    return lines


if __name__ == '__main__':
    # Script entry point: build the data loader, the language model and the
    # loss criterion from the parsed command-line configuration.
    config = define_argparser()

    loader = DataLoader(config.train,
                        config.valid,
                        batch_size=config.batch_size,
                        device=config.gpu_id,
                        max_length=config.max_length)
    # The first argument is the vocabulary size taken from the loader.
    model = LM(len(loader.text.vocab),
               word_vec_dim=config.word_vec_dim,
               hidden_size=config.hidden_size,
               n_layers=config.n_layers,
               dropout_p=config.dropout,
               max_length=config.max_length)

    # Give the PAD token zero weight so that predicting padding (which is
    # trivially easy) does not contribute to the loss.
    loss_weight = torch.ones(len(loader.text.vocab))
    loss_weight[data_loader.PAD] = 0
    # reduction='sum' is the current spelling of the deprecated
    # size_average=False: per-element losses are summed, not averaged.
    criterion = nn.NLLLoss(weight=loss_weight, reduction='sum')

    print(model)
    print(criterion)

    # A negative gpu_id leaves everything where it was created; otherwise
    # move both the model and the criterion to the selected GPU.
    if config.gpu_id >= 0:
        model.cuda(config.gpu_id)
        criterion.cuda(config.gpu_id)
Пример #24
0
# Remove every character of `punctuation` (used here as a regex character
# class) from each corrected query, then split the query on whitespace.
# NOTE(review): `punctuation` is presumably string.punctuation -- confirm.
fixed_queries_to_words = pd.Series(fixed_queries).replace(
    '[' + punctuation + ']', '', regex=True).str.split()
# presumably a flat list of all words; flatten_list is defined elsewhere.
fixed_words = flatten_list(fixed_queries_to_words)

# Same normalisation for the original (uncorrected) queries.
original_queries_to_words = pd.Series(original_queries).replace(
    '[' + punctuation + ']', '', regex=True).str.split()
original_words = flatten_list(original_queries_to_words)

error_model = ErrorModel()

# Queries are paired positionally; within a pair only the overlapping prefix
# of words is compared, so extra trailing words of the longer query are
# ignored.
for original, fixed in zip(original_queries_to_words, fixed_queries_to_words):
    number_of_words = min(len(original), len(fixed))
    for i in range(number_of_words):
        error_model.update_statistics(original[i], fixed[i])

error_model.calculate_weights()

language_model = LanguageModel()

# The language model is fed every word of the corrected queries only.
for fixed in fixed_queries_to_words:
    for word in fixed:
        language_model.update_statistics(word)

language_model.calculate_weights()

# Write both models out as JSON files in the current working directory.
error_model.store_json('error.json')
language_model.store_json('language.json')

# In[ ]:
Пример #25
0
class Experiment(object):

    def __init__(self):
        """Load the data splits and build every sub-module of the experiment.

        Everything is driven by the module-level ``config`` object. The
        constructor creates the datasets, the evaluation classifier, the
        embedding, classifier, gate and edit-operation modules, optionally
        the four language models (only when ``config.train_mode == 'pto'``),
        and the per-scope learning rates. Optimizers and loss criteria are
        not created here; the criteria are initialised to ``None``.
        """

        print('----- Loading data -----')
        # NOTE(review): the meaning of the second Amazon() argument is not
        # visible here; it is False by default and True in 'aux-cls-only' mode.
        self.train_set = Amazon('train', False)
        self.test_set = Amazon('test', False)
        self.val_set = Amazon('dev', False)
        print('The train set has {} items'.format(len(self.train_set)))
        print('The test set has {} items'.format(len(self.test_set)))
        print('The val set has {} items'.format(len(self.val_set)))

        # The vocabulary is taken from the training split loaded above.
        self.vocab = self.train_set.vocab

        # Load pretrained classifier for evaluation.
        self.clseval = ClassifierEval(config.test_cls, config.dataset)
        self.clseval.restore_model()

        print('----- Loading model -----')
        # Every module below receives its own clone of the vocabulary's
        # embedding tensor, so they do not share the tensor object.
        embedding = self.vocab.embedding
        self.Emb = nn.Embedding.from_pretrained(embedding.clone(), freeze=False)
        self.Classifier = AttenClassifier(emb_dim=config.emb_dim,
                                          dim_h=config.dim_h,
                                          n_layers=config.n_layers,
                                          dropout=config.dropout,
                                          bi=config.bidirectional)
        # Modules come in pairs with suffix 0 / 1.
        # NOTE(review): presumably one per sentiment / transfer direction -- confirm.
        self.Gate0 = Gate(dim_h=config.dim_h, n_layers=config.n_layers, dropout=config.dropout, bi=config.bidirectional,
                          temperature=config.temp_gate, embedding=embedding.clone())
        self.Gate1 = Gate(dim_h=config.dim_h, n_layers=config.n_layers, dropout=config.dropout, bi=config.bidirectional,
                          temperature=config.temp_gate, embedding=embedding.clone())
        self.InsFront0 = InsFront(dim_h=config.dim_h, embedding=embedding.clone(), n_layers=config.n_layers,
                                  dropout=config.dropout, bi=config.bidirectional, voc_size=self.vocab.size,
                                  temperature=config.temp_sub)
        self.InsFront1 = InsFront(dim_h=config.dim_h, embedding=embedding.clone(), n_layers=config.n_layers,
                                  dropout=config.dropout, bi=config.bidirectional, voc_size=self.vocab.size,
                                  temperature=config.temp_sub)
        self.InsBehind0 = InsBehind(dim_h=config.dim_h, embedding=embedding.clone(), n_layers=config.n_layers,
                                    dropout=config.dropout, bi=config.bidirectional, voc_size=self.vocab.size,
                                    temperature=config.temp_sub)
        self.InsBehind1 = InsBehind(dim_h=config.dim_h, embedding=embedding.clone(), n_layers=config.n_layers,
                                    dropout=config.dropout, bi=config.bidirectional, voc_size=self.vocab.size,
                                    temperature=config.temp_sub)
        self.Replace0 = Replace(dim_h=config.dim_h, embedding=embedding.clone(), n_layers=config.n_layers,
                                dropout=config.dropout, bi=config.bidirectional, voc_size=self.vocab.size,
                                temperature=config.temp_sub)
        self.Replace1 = Replace(dim_h=config.dim_h, embedding=embedding.clone(), n_layers=config.n_layers,
                                dropout=config.dropout, bi=config.bidirectional, voc_size=self.vocab.size,
                                temperature=config.temp_sub)
        self.Del0 = Delete()  # Not a module.
        self.Del1 = Delete()  # Not a module.
        # Language models.
        # Only built in 'pto' mode: forward/backward direction x sentiment 0/1,
        # each loaded immediately via model_load().
        if config.train_mode == 'pto':
            self.LMf0 = LanguageModel(config.dataset, direction='forward', sentiment=0)
            self.LMf0.model_load()
            self.LMf1 = LanguageModel(config.dataset, direction='forward', sentiment=1)
            self.LMf1.model_load()
            self.LMb0 = LanguageModel(config.dataset, direction='backward', sentiment=0)
            self.LMb0.model_load()
            self.LMb1 = LanguageModel(config.dataset, direction='backward', sentiment=1)
            self.LMb1.model_load()
        # Auxiliary classifier.
        self.Aux_Emb = nn.Embedding.from_pretrained(embedding.clone(), freeze=False)
        self.Aux_Classifier = AttenClassifier(emb_dim=config.emb_dim,
                                              dim_h=config.dim_h,
                                              n_layers=config.n_layers,
                                              dropout=config.dropout,
                                              bi=config.bidirectional)

        # Attribute names of the modules that are printed and passed through
        # gpu_wrapper below. Del0/Del1 and the language models are not listed.
        self.modules = ['Emb', 'Classifier',
                        'Gate0', 'Gate1',
                        'InsFront0', 'InsFront1',
                        'InsBehind0', 'InsBehind1',
                        'Replace0', 'Replace1',
                        'Aux_Classifier', 'Aux_Emb']
        for module in self.modules:
            print('--- {}: '.format(module))
            print(getattr(self, module))
            # Replace the attribute with whatever gpu_wrapper returns.
            setattr(self, module, gpu_wrapper(getattr(self, module)))

        # Scope name -> names of the modules grouped under it. Each scope
        # gets its own learning rate attribute `<scope>_lr`.
        self.scopes = {
            'emb': ['Emb'],
            'cls': ['Classifier'],
            'aux_cls': ['Aux_Classifier', 'Aux_Emb'],
            'gate': ['Gate0', 'Gate1'],
            'oprt': ['InsFront0', 'InsFront1', 'InsBehind0', 'InsBehind1', 'Replace0', 'Replace1'],
        }
        for scope in self.scopes.keys():
            # Initial learning rate is read from config.<scope>_lr.
            setattr(self, scope + '_lr', getattr(config, scope + '_lr'))

        # iter_num starts at -1; train_epoch increments it before each batch.
        self.iter_num = -1
        self.logger = None
        if config.train_mode == 'pto':
            pass
        elif config.train_mode == 'aux-cls-only':
            # This mode reloads all three splits with the flag set to True.
            # self.vocab still refers to the vocabulary of the first load.
            self.train_set = Amazon('train', True)
            self.test_set = Amazon('test', True)
            self.val_set = Amazon('dev', True)
            self.best_acc = 0
        elif config.train_mode == 'cls-only':
            self.best_acc = 0
        # Loss criteria are placeholders here; train() assigns the real ones.
        self.criterionSeq, self.criterionCls, self.criterionRL, self.criterionBack = None, None, None, None

    def restore_model(self, modules):
        """Load the ``best-<name>.ckpt`` checkpoint into each named sub-module.

        Args:
            modules: iterable of attribute names; every name must refer to an
                object with ``load_state_dict`` and have a matching checkpoint
                file inside ``config.save_model_dir``.
        """
        print('Loading the trained best models...')

        def keep_storage(storage, loc):
            # Identity map_location: hand back the storage that was passed in.
            return storage

        for name in modules:
            checkpoint_path = os.path.join(config.save_model_dir, 'best-{}.ckpt'.format(name))
            load_into = getattr(self, name).load_state_dict
            load_into(torch.load(checkpoint_path, map_location=keep_storage))

    def build_tensorboard(self):
        """Create the tensorboard logger and store it on ``self.logger``.

        The logger writes to ``config.log_dir``. The import is local to this
        method, so ``utils.logger`` is only loaded when logging is requested.
        """
        from utils.logger import Logger as TensorboardLogger
        log_directory = config.log_dir
        self.logger = TensorboardLogger(log_directory)

    def log_step(self, loss):
        # Log loss.
        for loss_name, value in loss.items():
            self.logger.scalar_summary(loss_name, value, self.iter_num)
        # Log learning rate.
        for scope in self.scopes:
            self.logger.scalar_summary('{}/lr'.format(scope), getattr(self, scope + '_lr'), self.iter_num)

    def save_step(self, modules, use_iter=False):
        """Write the state dict of each named sub-module to a checkpoint file.

        Args:
            modules: iterable of attribute names to save.
            use_iter: when true the file is ``<iter_num>-<name>.ckpt``;
                otherwise it is ``best-<name>.ckpt``.

        All files go into ``config.save_model_dir``.
        """
        for name in modules:
            if use_iter:
                filename = '{}-{}.ckpt'.format(self.iter_num, name)
            else:
                filename = 'best-{}.ckpt'.format(name)
            target = os.path.join(config.save_model_dir, filename)
            torch.save(getattr(self, name).state_dict(), target)
        print('Saved model checkpoints into {}...\n\n\n\n\n\n\n\n\n\n\n\n'.format(config.save_model_dir))

    def zero_grad(self):
        for scope in self.scopes:
            getattr(self, scope + '_optim').zero_grad()

    def step(self, scopes, clip_norm=float('inf'), clip_value=float('inf')):
        trainable = []
        for scope in scopes:
            trainable.extend(getattr(self, 'trainable_' + scope))
        # Clip on all parameters.
        if clip_norm < float('inf'):
            clip_grad_norm_(parameters=trainable, max_norm=config.clip_norm)
        if clip_value < float('inf'):
            clip_value = float(config.clip_value)
            for p in filter(lambda p: p.grad is not None, trainable):
                p.grad.data.clamp_(min=-clip_value, max=clip_value)
        # Backward.
        for scope in scopes:
            getattr(self, scope + '_optim').step()

    def update_lr(self):
        """Decay the learning rate of every scope and push it to its optimizer.

        Each ``<scope>_lr`` is multiplied by
        ``1 - config.lr_decay_step / config.num_iters_decay``; the result is
        stored back on the instance and written into every parameter group of
        ``<scope>_optim``.
        """
        for scope_name in self.scopes:
            lr_attr = scope_name + '_lr'
            decayed = getattr(self, lr_attr) * (1 - float(config.lr_decay_step / config.num_iters_decay))
            setattr(self, lr_attr, decayed)
            optimizer = getattr(self, scope_name + '_optim')
            for group in optimizer.param_groups:
                group['lr'] = getattr(self, lr_attr)

    def set_requires_grad(self, modules, requires_grad):
        if not isinstance(modules, list):
            modules = [modules]
        for module in modules:
            for param in getattr(self, module).parameters():
                param.requires_grad = requires_grad

    def set_training(self, mode):
        for module in self.modules:
            getattr(self, module).train(mode=mode)

    def restore_pretrained(self, modules):
        """Load ``pretrained/pretrained-<name>.ckpt`` into each named sub-module.

        Args:
            modules: iterable of attribute names. The checkpoint path is
                relative to the current working directory.
        """
        def keep_storage(storage, loc):
            # Identity map_location: hand back the storage that was passed in.
            return storage

        for name in modules:
            checkpoint_path = os.path.join('pretrained/pretrained-{}.ckpt'.format(name))
            load_into = getattr(self, name).load_state_dict
            load_into(torch.load(checkpoint_path, map_location=keep_storage))

    def train(self):
        """Run training until ``config.num_iters`` is reached, then test.

        Steps, in order: optionally create the tensorboard logger, load
        pretrained classifier weights in 'pto' mode, build one Adam optimizer
        per scope over the parameters that require gradients, create the loss
        criteria, loop over epochs, and finally call ``self.test()``.

        Raises:
            ValueError: if ``config.train_mode`` is not one of 'cls-only',
                'pto' or 'aux-cls-only'.
        """
        # Logging.
        if config.use_tensorboard:
            self.build_tensorboard()

        # Load pretrained: only 'pto' loads weights here; the two
        # classifier-only modes load nothing, and anything else is an error.
        mode = config.train_mode
        if mode == 'pto':
            self.restore_pretrained(['Classifier', 'Emb', 'Aux_Classifier', 'Aux_Emb'])
        elif mode not in ('cls-only', 'aux-cls-only'):
            raise ValueError()

        # Collect the trainable tensors of every scope and give each scope
        # its own optimizer.
        for scope, module_names in self.scopes.items():
            params = []
            for name in module_names:
                named_tensors = getattr(self, name).state_dict(keep_vars=True)
                for key, tensor in named_tensors.items():
                    # key is the parameter name; tensor is the parameter value.
                    if not tensor.requires_grad:
                        print("[{} Frozen:]".format(name), key)
                        continue
                    params.append(tensor)
                    print("[{} Trainable:]".format(name), key)
            optimizer = Adam(params, getattr(self, scope + '_lr'), [config.beta1, config.beta2])
            setattr(self, scope + '_optim', optimizer)
            setattr(self, 'trainable_' + scope, params)

        # Build criterion.
        dataset = self.train_set
        self.criterionSeq = SeqLoss(voc_size=dataset.vocab.size, pad=dataset.pad,
                                    end=dataset.eos, unk=dataset.unk)
        self.criterionCls = nn.BCELoss()
        self.criterionBack = BackLoss(reduce=False)
        self.criterionRL = RewardCriterion()

        # Train: always run at least one epoch, and stop after the first
        # epoch that brings iter_num up to the configured budget.
        epoch = 0
        unfinished = True
        while unfinished:
            self.train_epoch(epoch_idx=epoch)
            epoch += 1
            unfinished = not (self.iter_num >= config.num_iters)

        self.test()

    def test(self):
        """Evaluate on the test split with the pretrained checkpoints.

        Overwrites ``config.batch_size`` with 500 (the change stays in effect
        on the shared config object), loads the pretrained weights of the
        classifier, embedding and edit-operation modules, and runs
        ``valtest`` in 'multi-steps' mode on the test set.
        """
        config.batch_size = 500

        pretrained_modules = ['Classifier', 'Emb', 'Replace0', 'Replace1',
                              'InsFront0', 'InsFront1', 'InsBehind0', 'InsBehind1',
                              'Aux_Classifier', 'Aux_Emb']
        self.restore_pretrained(pretrained_modules)

        # Reciprocals of the two entries of config.s, passed as `_pow_lm`.
        lm_powers = (1 / config.s[0], 1 / config.s[1])
        self.valtest(val_or_test="test",
                     mode='multi-steps',
                     _pow_lm=lm_powers,
                     cls_stop=config.cls_stop,
                     max_iter=config.max_iter)
    def train_epoch(self, epoch_idx):

        loader = DataLoader(self.train_set, batch_size=config.batch_size, shuffle=True, num_workers=config.num_workers, drop_last=True)
        self.set_training(mode=True)

        with tqdm(loader) as pbar:
            for data in pbar:
                self.iter_num += 1
                loss = {}

                # =================================================================================== #
                #                             1. Preprocess input data                                #
                # =================================================================================== #
                bare_0, _, _, len_0, y_0, _, bare_1, _, _, len_1, y_1, _ = self.preprocess_data(data)
                null_mask_0 = bare_0.eq(self.train_set.pad)
                null_mask_1 = bare_1.eq(self.train_set.pad)

                # =================================================================================== #
                #                                       2. cls-only                                   #
                # =================================================================================== #
                if config.train_mode == 'cls-only':
                    # ----- Forward pass of classification -----
                    emb_bare_0 = self.Emb(bare_0)
                    emb_bare_1 = self.Emb(bare_1)
                    cls_0, att_0 = self.Classifier(emb_bare_0, len_0, null_mask_0)
                    cls_1, att_1 = self.Classifier(emb_bare_1, len_1, null_mask_1)
                    # ----- Classification loss -----
                    cls_loss_0 = self.criterionCls(cls_0, y_0)
                    cls_loss_1 = self.criterionCls(cls_1, y_1)
                    cls_loss = cls_loss_0 + cls_loss_1
                    # ----- Logging -----
                    loss['Cls/L-0'] = round(cls_loss_0.item(), ROUND)
                    loss['Cls/L-1'] = round(cls_loss_1.item(), ROUND)
                    # ----- Backward for scopes: ['emb', 'cls'] -----
                    self.zero_grad()
                    cls_loss.backward()
                    self.step(['emb', 'cls'])
                elif config.train_mode == 'aux-cls-only':
                    # ----- Forward pass of classification -----
                    emb_bare_0 = self.Aux_Emb(bare_0)
                    emb_bare_1 = self.Aux_Emb(bare_1)
                    cls_0, att_0 = self.Aux_Classifier(emb_bare_0, len_0, null_mask_0)
                    cls_1, att_1 = self.Aux_Classifier(emb_bare_1, len_1, null_mask_1)
                    # ----- Classification loss -----
                    cls_loss_0 = self.criterionCls(cls_0, y_0)
                    cls_loss_1 = self.criterionCls(cls_1, y_1)
                    cls_loss = cls_loss_0 + cls_loss_1
                    # ----- Logging -----
                    loss['Cls/L-0'] = round(cls_loss_0.item(), ROUND)
                    loss['Cls/L-1'] = round(cls_loss_1.item(), ROUND)
                    # ----- Backward for scopes: ['emb', 'cls'] -----
                    self.zero_grad()
                    cls_loss.backward()
                    self.step(['aux_cls'])

                # =================================================================================== #
                #                                        3. pto                                       #
                # =================================================================================== #
                elif config.train_mode == 'pto':

                    # ----- Forward pass of classification -----
                    emb_bare_0 = self.Emb(bare_0)
                    emb_bare_1 = self.Emb(bare_1)
                    cls_0, att_0 = self.Classifier(emb_bare_0, len_0, null_mask_0)
                    cls_1, att_1 = self.Classifier(emb_bare_1, len_1, null_mask_1)
                    # ----- Classification loss -----
                    cls_loss_0 = self.criterionCls(cls_0, y_0)
                    cls_loss_1 = self.criterionCls(cls_1, y_1)
                    cls_loss = cls_loss_0 + cls_loss_1
                    # ----- Logging -----
                    loss['Cls/L-0'] = round(cls_loss_0.item(), ROUND)
                    loss['Cls/L-1'] = round(cls_loss_1.item(), ROUND)
                    # ----- Backward for scopes: ['emb', 'cls'] -----
                    self.zero_grad()
                    cls_loss.backward()
                    self.step(['emb', 'cls'])

                    #################
                    # 0 --> 1 --> 0 #
                    #################
                    att_pg_0, Rep_xbar_pg_0, Rep_XE2_0, IF_xbar_pg_0, IB_xbar_pg_0, Del_ib_XE2_0, Del_if_XE2_0 = self.forward_pto(bare_0, len_0, null_mask_0, direction=0, loss=loss)

                    #################
                    # 1 --> 0 --> 1 #
                    #################
                    att_pg_1, Rep_xbar_pg_1, Rep_XE2_1, IF_xbar_pg_1, IB_xbar_pg_1, Del_ib_XE2_1, Del_if_XE2_1 = self.forward_pto(bare_1, len_1, null_mask_1, direction=1, loss=loss)

                    #######################
                    # Combine and logging #
                    #######################
                    tot_loss = att_pg_0 + att_pg_1 + \
                               Rep_xbar_pg_0 + Rep_XE2_0 + \
                               Rep_xbar_pg_1 + Rep_XE2_1 + \
                               IF_xbar_pg_0 + IB_xbar_pg_0 + \
                               Del_ib_XE2_0 + Del_if_XE2_0 + \
                               IF_xbar_pg_1 + IB_xbar_pg_1 + \
                               Del_ib_XE2_1 + Del_if_XE2_1

                    # ----- Backward for scopes: ['emb', 'cls', 'oprt'] -----
                    self.zero_grad()
                    tot_loss.backward()
                    self.step(['emb', 'cls', 'oprt'])

                else:
                    raise ValueError()

                # =================================================================================== #
                #                                 4. Miscellaneous                                    #
                # =================================================================================== #

                verbose = False
                if verbose:
                    display = ', '.join([key + ':' + pretty_string(loss[key]) for key in loss.keys()])
                    pbar.set_description_str(display)

                # Print out training information.
                if self.iter_num % config.log_step == 0 and config.use_tensorboard:
                    self.log_step(loss)

                # Validation.
                if self.iter_num % config.sample_step == 0:
                    if config.train_mode == 'pto':
                        self.valtest(val_or_test="val",
                                     mode='multi-steps',
                                     _pow_lm=(1 / config.s[0], 1 / config.s[1]),
                                     cls_stop=config.cls_stop,
                                     max_iter=config.max_iter)
                    elif config.train_mode == 'aux-cls-only':
                        self.valtest('val', 'aux-cls')
                    elif config.train_mode == 'cls-only':
                        self.valtest('val', 'cls')
                    else:
                        raise ValueError()

                # Decay learning rates.
                if self.iter_num % config.lr_decay_step == 0 and \
                        self.iter_num > (config.num_iters - config.num_iters_decay):
                    self.update_lr()

    def forward_pto(self, bare, seq_len, null_mask, direction, loss):
        """Compute the training losses of the edit operators for one transfer direction.

        Every row of the batch is assigned one operation purely by its position in
        the batch (InsFront, InsBehind, Replace or Delete). The operation is applied
        at a position sampled from the classifier's attention, and the resulting
        sentence is scored by the classifier, by the forward/backward language
        models of the opposite direction, and by how well the opposite-direction
        operators reconstruct the original token.

        Args:
            bare: Batch of token ids, indexed as ``bare[rows, :]``
                (presumably (batch, max_len) -- TODO confirm).
            seq_len: Per-row sequence lengths; ``seq_len.shape[0]`` is used as the
                batch size.
            null_mask: Mask handed to the classifier together with ``bare`` (the
                mask for the edited batch is rebuilt here as ``bare_bar.eq(pad)``).
            direction: 0 or 1. Selects the ``<Operator><direction>`` modules and the
                per-direction config entries; language models and
                back-reconstruction operators use ``1 - direction``.
            loss: Dict that is updated in place with scalar logging values.

        Returns:
            Tuple ``(att_pg, Rep_xbar_pg, Rep_XE2.mean(), IF_xbar_pg, IB_xbar_pg,
            Del_ib_XE2, Del_if_XE2)``. Note that ``Del_ib_XE2`` comes before
            ``Del_if_XE2``.
        """
        # ----- Classify -----
        # Confidence on the unedited input; detached so it only serves as a reward signal.
        cls, att = self.Classifier(self.Emb(bare), seq_len, null_mask)
        cls = cls.detach()

        # ----- Sample attention -----
        # hard_att holds one sampled position per row (it is used below as a per-row
        # column index); att_prob is passed to criterionRL as the sample probability.
        hard_att, att_prob = sample_2d(probs=att, temperature=config.temp_att)

        # ----- Fix oprt_idx -----
        # Operations are assigned by batch position in roughly 1/8, 1/8, 5/8, 1/8
        # proportions, not sampled.
        B = seq_len.shape[0]
        oprt_idx = torch.zeros_like(seq_len)
        oprt_idx[:B // 8] = 0  # InsFront.
        oprt_idx[B // 8: 2 * B // 8] = 1  # InsBehind.
        oprt_idx[2 * B // 8: 7 * B // 8] = 2  # Replace.
        oprt_idx[7 * B // 8:] = 3  # Delete.

        # ----- Process each operation -----
        # bare_bar / T_bar collect the edited token ids and their new lengths.
        bare_bar = torch.zeros_like(bare)
        T_bar = torch.zeros_like(seq_len)

        # ----- InsFront -----
        IF_idx = torch.nonzero(oprt_idx == 0).view(-1)
        IF_bare_bar, IF_xbar_prob, _ = getattr(self, 'InsFront' + str(direction))(bare[IF_idx, :], hard_att[IF_idx],
                                                                                  seq_len[IF_idx], sample=True)
        bare_bar[IF_idx, :] = IF_bare_bar
        T_bar[IF_idx] = seq_len[IF_idx] + 1  # Insertion lengthens the sentence by one.

        # ----- InsBehind -----
        IB_idx = torch.nonzero(oprt_idx == 1).view(-1)
        IB_bare_bar, IB_xbar_prob, _ = getattr(self, 'InsBehind' + str(direction))(bare[IB_idx, :], hard_att[IB_idx],
                                                                                   seq_len[IB_idx], sample=True)
        bare_bar[IB_idx, :] = IB_bare_bar
        T_bar[IB_idx] = seq_len[IB_idx] + 1

        # ----- Del -----
        # Deletion has no sampled word, hence no probability output.
        Del_idx = torch.nonzero(oprt_idx == 3).view(-1)
        Del_bare_bar = getattr(self, 'Del' + str(direction))(bare[Del_idx, :], hard_att[Del_idx])
        bare_bar[Del_idx, :] = Del_bare_bar
        T_bar[Del_idx] = seq_len[Del_idx] - 1  # Deletion shortens the sentence by one.

        # ----- Replace -----
        Rep_idx = torch.nonzero(oprt_idx == 2).view(-1)
        Rep_bare_bar, Rep_xbar_prob, _ = getattr(self, 'Replace' + str(direction))(bare[Rep_idx, :], hard_att[Rep_idx],
                                                                                   seq_len[Rep_idx], sample=True)
        bare_bar[Rep_idx, :] = Rep_bare_bar
        T_bar[Rep_idx] = seq_len[Rep_idx]

        # ----- Update null_mask -----
        null_mask_bar = bare_bar.eq(self.train_set.pad)

        # ----- Classify -----
        # Sort and re-sort.
        # s_idx orders rows by edited length (descending); res_idx is the inverse
        # permutation that restores the original row order afterwards.
        # NOTE(review): presumably the classifier expects length-sorted input -- confirm.
        s_idx = [ix for ix, l in sorted(enumerate(T_bar.cpu()), key=lambda x: x[1], reverse=True)]
        res_idx = [a for a, b in sorted(enumerate(s_idx), key=lambda x: x[1])]
        s_cls_bare_bar, _ = self.Classifier(self.Emb(bare_bar[s_idx, :]), T_bar[s_idx], null_mask_bar[s_idx, :])
        cls_bare_bar = s_cls_bare_bar[res_idx].detach()

        # ----- Reward/PG for att (w.r.t. confidence) -----
        rwd_att_CRITIC = config.beta_att[direction]
        # Great difference -> position is important.
        rwd_att = torch.abs(cls - cls_bare_bar) - rwd_att_CRITIC
        att_pg = self.criterionRL(sample_probs=att_prob, reward=rwd_att) * config.lambda_att_conf

        # ----- Reward/PG for xbar (w.r.t. confidence) -----
        # The sign of the confidence change that is rewarded flips with the direction:
        # direction 0 rewards an increase of the classifier output, direction 1 a decrease.
        rwd_xbar_conf_CRITIC = config.beta_xbar_conf[direction]
        if direction == 0:
            IF_rwd_xbar_conf = cls_bare_bar[IF_idx] - cls[IF_idx] - rwd_xbar_conf_CRITIC
            IB_rwd_xbar_conf = cls_bare_bar[IB_idx] - cls[IB_idx] - rwd_xbar_conf_CRITIC
            Rep_rwd_xbar_conf = cls_bare_bar[Rep_idx] - cls[Rep_idx] - rwd_xbar_conf_CRITIC
        else:
            IF_rwd_xbar_conf = cls[IF_idx] - cls_bare_bar[IF_idx] - rwd_xbar_conf_CRITIC
            IB_rwd_xbar_conf = cls[IB_idx] - cls_bare_bar[IB_idx] - rwd_xbar_conf_CRITIC
            Rep_rwd_xbar_conf = cls[Rep_idx] - cls_bare_bar[Rep_idx] - rwd_xbar_conf_CRITIC
        IF_xbar_conf_pg = self.criterionRL(sample_probs=IF_xbar_prob, reward=IF_rwd_xbar_conf)
        IB_xbar_conf_pg = self.criterionRL(sample_probs=IB_xbar_prob, reward=IB_rwd_xbar_conf)
        Rep_xbar_conf_pg = self.criterionRL(sample_probs=Rep_xbar_prob, reward=Rep_rwd_xbar_conf)

        # ----- Star Indices -----
        # star_index is the position handed to the language models and to the
        # back-reconstruction operators for each row of the edited batch.
        star_index = torch.zeros_like(oprt_idx)
        star_index[IF_idx] = hard_att[IF_idx]
        star_index[IB_idx] = hard_att[IB_idx] + 1
        star_index[Rep_idx] = hard_att[Rep_idx]
        # Deleted rows are split in half: the first half is reconstructed with
        # InsFront, the second half with InsBehind (one position earlier).
        n_del = Del_idx.shape[0]
        Del_if_idx = Del_idx[: n_del // 2]
        Del_ib_idx = Del_idx[n_del // 2:]
        star_index[Del_if_idx] = hard_att[Del_if_idx]
        star_index[Del_ib_idx] = hard_att[Del_ib_idx] - 1

        # ----- Reward/PG for xbar (w.r.t. language model) -----
        # For each word-producing operation the reward is the geometric mean of the
        # forward and backward LM outputs (LMs of the opposite direction), minus a
        # baseline, scaled by config.lambda_lm.
        rwd_xbar_lm_CRITIC = config.beta_xbar_lm[direction]

        IF_word_prob_f = getattr(self, 'LMf{}'.format(1 - direction)).inference(bare_bar[IF_idx], star_index[IF_idx], T_bar[IF_idx])  # shape = (n_IF, )
        IF_word_prob_b = getattr(self, 'LMb{}'.format(1 - direction)).inference(bare_bar[IF_idx], star_index[IF_idx], T_bar[IF_idx])  # shape = (n_IF, )
        IF_word_prob = torch.sqrt(IF_word_prob_f * IF_word_prob_b)  # shape = (n_IF, )
        IF_rwd_xbar_lm = IF_word_prob - rwd_xbar_lm_CRITIC
        IF_rwd_xbar_lm = IF_rwd_xbar_lm * config.lambda_lm
        IF_xbar_lm_pg = self.criterionRL(sample_probs=IF_xbar_prob, reward=IF_rwd_xbar_lm)

        IB_word_prob_f = getattr(self, 'LMf{}'.format(1 - direction)).inference(bare_bar[IB_idx], star_index[IB_idx], T_bar[IB_idx])  # shape = (n_IB, )
        IB_word_prob_b = getattr(self, 'LMb{}'.format(1 - direction)).inference(bare_bar[IB_idx], star_index[IB_idx], T_bar[IB_idx])  # shape = (n_IB, )
        IB_word_prob = torch.sqrt(IB_word_prob_f * IB_word_prob_b)  # shape = (n_IB, )
        IB_rwd_xbar_lm = IB_word_prob - rwd_xbar_lm_CRITIC
        IB_rwd_xbar_lm = IB_rwd_xbar_lm * config.lambda_lm
        IB_xbar_lm_pg = self.criterionRL(sample_probs=IB_xbar_prob, reward=IB_rwd_xbar_lm)

        Rep_word_prob_f = getattr(self, 'LMf{}'.format(1 - direction)).inference(bare_bar[Rep_idx], star_index[Rep_idx], T_bar[Rep_idx])  # shape = (n_Rep, )
        Rep_word_prob_b = getattr(self, 'LMb{}'.format(1 - direction)).inference(bare_bar[Rep_idx], star_index[Rep_idx], T_bar[Rep_idx])  # shape = (n_Rep, )
        Rep_word_prob = torch.sqrt(Rep_word_prob_f * Rep_word_prob_b)  # shape = (n_Rep, )
        Rep_rwd_xbar_lm = Rep_word_prob - rwd_xbar_lm_CRITIC
        Rep_rwd_xbar_lm = Rep_rwd_xbar_lm * config.lambda_lm
        Rep_xbar_lm_pg = self.criterionRL(sample_probs=Rep_xbar_prob, reward=Rep_rwd_xbar_lm)

        # ------ Supervision for xbar ------
        # Original token at the sampled position of every row; used as the
        # reconstruction target below.
        ori = bare.gather(1, hard_att.unsqueeze(1)).squeeze(1)

        # ----- Del_if (back) -----
        # Opposite-direction InsFront must re-insert the deleted token; only its
        # third output (compared against the target by criterionBack) is used.
        _, _, Del_if_lgt = getattr(self, 'InsFront' + str(1 - direction))(bare_bar[Del_if_idx, :],
                                                                          star_index[Del_if_idx],
                                                                          T_bar[Del_if_idx], sample=True)
        Del_if_tgt = ori[Del_if_idx]
        Del_if_XE2 = self.criterionBack(Del_if_lgt, Del_if_tgt)

        # ----- Del_ib (back) -----
        _, _, Del_ib_lgt = getattr(self, 'InsBehind' + str(1 - direction))(bare_bar[Del_ib_idx, :],
                                                                           star_index[Del_ib_idx],
                                                                           T_bar[Del_ib_idx], sample=True)
        Del_ib_tgt = ori[Del_ib_idx]
        Del_ib_XE2 = self.criterionBack(Del_ib_lgt, Del_ib_tgt)

        # ----- Replace (back) -----
        # Opposite-direction Replace must restore the replaced token.
        _, _, Rep_lgt = getattr(self, 'Replace' + str(1 - direction))(bare_bar[Rep_idx, :],
                                                                      star_index[Rep_idx],
                                                                      T_bar[Rep_idx], sample=True)
        Rep_tgt = ori[Rep_idx]
        Rep_XE2 = self.criterionBack(Rep_lgt, Rep_tgt)

        # ----- Summarize -----
        IF_xbar_pg = IF_xbar_conf_pg * config.lambda_ins_conf + IF_xbar_lm_pg
        IB_xbar_pg = IB_xbar_conf_pg * config.lambda_ins_conf + IB_xbar_lm_pg
        Del_ib_XE2 = Del_ib_XE2.mean()
        Del_if_XE2 = Del_if_XE2.mean()

        # ----- Reward/PG for Replace's xbar (w.r.t. XE2) -----
        # A lower reconstruction loss gives a higher reward; Rep_XE2 is detached so
        # this term only trains through Rep_xbar_prob. The 0.05 weight is hard-coded.
        rwd_xbar_XE2_CRITIC = config.beta_xbar_XE2[direction]
        Rep_rwd_xbar_XE2 = config.subtract_XE2[direction] - Rep_XE2.detach() - rwd_xbar_XE2_CRITIC
        Rep_xbar_XE2_pg = self.criterionRL(sample_probs=Rep_xbar_prob, reward=Rep_rwd_xbar_XE2) * 0.05

        # ----- Combine -----
        Rep_xbar_pg = Rep_xbar_XE2_pg + Rep_xbar_conf_pg + Rep_xbar_lm_pg

        # Logging.
        # Scalars are written into the caller's dict, keyed by direction.
        loss['Att/R-{}'.format(direction)] = rwd_att.mean().item()
        loss['Rep/R-conf-{}'.format(direction)] = Rep_rwd_xbar_conf.mean().item()
        loss['Rep/R-XE2-{}'.format(direction)] = Rep_rwd_xbar_XE2.mean().item()
        loss['Rep/XE2-{}'.format(direction)] = Rep_XE2.mean().item()
        loss['Rep/R-lm-{}'.format(direction)] = Rep_rwd_xbar_lm.mean().item()
        loss['IF/R-conf-{}'.format(direction)] = IF_rwd_xbar_conf.mean().item()
        loss['IF/R-lm-{}'.format(direction)] = IF_rwd_xbar_lm.mean().item()
        loss['IB/R-conf-{}'.format(direction)] = IB_rwd_xbar_conf.mean().item()
        loss['IB/R-lm-{}'.format(direction)] = IB_rwd_xbar_lm.mean().item()
        loss['Del_f/XE2-{}'.format(direction)] = Del_if_XE2.mean().item()
        loss['Del_b/XE2-{}'.format(direction)] = Del_ib_XE2.mean().item()

        return att_pg, Rep_xbar_pg, Rep_XE2.mean(), IF_xbar_pg, IB_xbar_pg, Del_ib_XE2, Del_if_XE2

    def valtest(self, val_or_test, mode, _pow_lm=None, cls_stop=None, max_iter=None, ablation=None):
        """Evaluate the model on the validation or the test split.

        Switches the model to evaluation mode, runs the whole split without
        gradients and switches back to training mode before returning.

        Args:
            val_or_test: ``"val"`` or ``"test"``; any other value raises ``KeyError``.
            mode: ``'multi-steps'`` runs the iterative editing procedure and
                evaluates the generated sentences; ``'aux-cls'`` / ``'cls'`` only
                measure the accuracy of the auxiliary / main classifier and
                checkpoint the corresponding modules on a new best accuracy.
            _pow_lm: Pair indexed by direction, forwarded to
                ``valtest_forward_multi_steps_quick`` (``'multi-steps'`` only).
            cls_stop: Pair indexed by direction (``'multi-steps'`` only).
            max_iter: Pair indexed by direction (``'multi-steps'`` only).
            ablation: Forwarded unchanged to ``valtest_forward_multi_steps_quick``.

        Returns:
            None. Results are printed; on the test split the generated sentences
            are also written to ``config.sample_dir``.

        Raises:
            ValueError: If ``mode`` is not one of the supported values (raised
                while processing the first batch).
        """
        dataset = {
                "test": self.test_set,
                "val": self.val_set
            }[val_or_test]

        loader = DataLoader(dataset, batch_size=2048, shuffle=False, num_workers=config.num_workers)

        self.set_training(mode=False)

        fake_sents_0, fake_sents_1, clss_0, clss_1 = [], [], [], []
        with tqdm(loader) as pbar, torch.no_grad():
            for data in pbar:
                bare_0, _, _, len_0, y_0, res_idx_0, bare_1, _, _, len_1, y_1, res_idx_1 = self.preprocess_data(data)
                null_mask_0 = bare_0.eq(self.train_set.pad)
                null_mask_1 = bare_1.eq(self.train_set.pad)

                if mode == 'multi-steps':
                    cls_0, bare_bar_0 = self.valtest_forward_multi_steps_quick(bare_0, len_0, null_mask_0, 0, _pow_lm[0], cls_stop[0], max_iter[0], ablation, res_idx_0)
                    cls_1, bare_bar_1 = self.valtest_forward_multi_steps_quick(bare_1, len_1, null_mask_1, 1, _pow_lm[1], cls_stop[1], max_iter[1], ablation, res_idx_1)

                    # res_idx_* undoes the length sort applied by preprocess_data.
                    clss_0.append((cls_0 > 0.5)[res_idx_0])
                    clss_1.append((cls_1 > 0.5)[res_idx_1])
                    bare_bar_0 = strip_pad(
                        [[self.vocab.id2word[i.data.cpu().numpy()] for i in sent] for sent in bare_bar_0])
                    bare_bar_1 = strip_pad(
                        [[self.vocab.id2word[i.data.cpu().numpy()] for i in sent] for sent in bare_bar_1])
                    # Sentences transferred from style 0 are "fake" style-1 sentences and vice versa.
                    fake_sents_1.extend([bare_bar_0[i] for i in res_idx_0])
                    fake_sents_0.extend([bare_bar_1[i] for i in res_idx_1])
                elif mode == 'aux-cls':
                    cls_0, _ = self.Aux_Classifier(self.Aux_Emb(bare_0), len_0, null_mask_0)
                    cls_1, _ = self.Aux_Classifier(self.Aux_Emb(bare_1), len_1, null_mask_1)
                    clss_0.append((cls_0 > 0.5)[res_idx_0])
                    clss_1.append((cls_1 > 0.5)[res_idx_1])
                elif mode == 'cls':
                    cls_0, _ = self.Classifier(self.Emb(bare_0), len_0, null_mask_0)
                    cls_1, _ = self.Classifier(self.Emb(bare_1), len_1, null_mask_1)
                    clss_0.append((cls_0 > 0.5)[res_idx_0])
                    clss_1.append((cls_1 > 0.5)[res_idx_1])
                else:
                    raise ValueError()

        # Classifier-only modes share the same accuracy computation and differ only
        # in which modules are checkpointed on a new best accuracy.
        cls_only_modules = {
            'aux-cls': ['Aux_Classifier', 'Aux_Emb'],
            'cls': ['Classifier', 'Emb'],
        }
        if mode in cls_only_modules:
            clss_0 = torch.cat(clss_0, dim=0).float()
            clss_1 = torch.cat(clss_1, dim=0).float()
            # ----- Attention Classifier Acc -----
            # Style-0 inputs predicted as 1 and style-1 inputs predicted as 0 are errors.
            n_wrong = clss_0.sum().item() + (1 - clss_1).sum().item()
            n_all = clss_0.shape[0] + clss_1.shape[0]
            acc = (n_all - n_wrong) / n_all
            print('\nAttention classifier accuracy =\n', acc)
            if acc > self.best_acc:
                self.best_acc = acc
                self.save_step(cls_only_modules[mode])
            self.set_training(mode=True)
            return None

        # Transfer oriented.
        fake_sents_0 = fake_sents_0[:dataset.l1]
        fake_sents_1 = fake_sents_1[:dataset.l0]
        if val_or_test == 'test':
            ori_0, ref_0, ori_1, ref_1 = dataset.get_references()
            assert len(ref_0) == len(fake_sents_1), str(len(ref_0)) + ' ' + str(len(fake_sents_1))
            assert len(ref_1) == len(fake_sents_0)
        else:
            ori_0, ori_1 = dataset.get_val_ori()
            assert len(ori_1) == len(fake_sents_0)
            assert len(ori_0) == len(fake_sents_1)

        if val_or_test == 'test':
            # ---- Moses BLEU -----
            log_dir = 'outputs/temp_results/{}'.format(config.beta_xbar_lm[0])
            # makedirs also creates missing parent directories and, with exist_ok,
            # avoids the check-then-create race of exists() + mkdir().
            os.makedirs(log_dir, exist_ok=True)
            multi_BLEU = calc_bleu_score([' '.join(sent) for sent in fake_sents_1] + [' '.join(sent) for sent in fake_sents_0],
                                         [[' '.join(ref)] for ref in ref_0] + [[' '.join(ref)] for ref in ref_1],
                                         log_dir=log_dir,
                                         multi_ref=True)
            print('moses BLEU = {}'.format(round(multi_BLEU, ROUND)))

            with open(os.path.join(config.sample_dir, 'sentiment.test.0.ours'), 'w') as f0:
                for sent in fake_sents_1:
                    f0.write(' '.join(sent) + '\n')
            with open(os.path.join(config.sample_dir, 'sentiment.test.1.ours'), 'w') as f1:
                for sent in fake_sents_0:
                    f1.write(' '.join(sent) + '\n')

            # ----- Classifier Acc -----
            acc = self.clseval.class_score(fake_sents_0 + fake_sents_1, labels=[0] * len(fake_sents_0) + [1] * len(fake_sents_1))
            print('classification accuracy = {}'.format(round(acc, ROUND)))

        else:
            # ----- Classifier Acc -----
            acc = self.clseval.class_score(fake_sents_0 + fake_sents_1, labels=[0] * len(fake_sents_0) + [1] * len(fake_sents_1))
            print('\n\n\n\nclassification accuracy = {}'.format(round(acc, ROUND)))
            # ----- Validation -----
            # Show a few originals next to their transferred versions for eyeballing.
            peep_num = 50
            print('\n1')
            for sent in ori_1[:peep_num]:
                print(' '.join(sent))
            print('\n1 -> 0')
            for sent in fake_sents_0[:peep_num]:
                print(' '.join(sent))

        self.set_training(mode=True)
        return None

    def valtest_forward_multi_steps_quick(self, bare, seq_len, null_mask, direction, _pow_lm, cls_stop, max_iter, ablation, hl_res_idx=None):
        """Iteratively edit each sentence, one operation per step, at inference time.

        In every step the still-active rows get seven candidate edits at their
        most-attended position (InsFront, InsBehind, Replace, delete this / previous /
        next token, no change). Each candidate is scored by the language models of
        the opposite direction combined with the classifier output, and the best
        candidate replaces the row.

        Args:
            bare: Batch of token ids (not modified; working copies are made).
            seq_len: Per-row lengths. NOTE: updated in place as rows are edited.
            null_mask: Mask fed to the classifiers. NOTE: updated in place.
            direction: 0 or 1; selects the ``<Operator><direction>`` modules, while
                the language models of ``1 - direction`` are used for scoring.
            _pow_lm: Exponent applied to the combined language-model score.
            cls_stop: Threshold on ``|1 - direction - aux_cls|``; rows at or below it
                stop being edited.
            max_iter: Maximum number of editing steps.
            ablation: ``'mask-att'`` zeroes ``null_mask`` every step; otherwise it is
                tested with ``in`` against operation indices 0..6 (see the note at
                the ``try`` block below).
            hl_res_idx: Unused in this method.

        Returns:
            ``(_cls, bare_bar)``: the classifier output on the ORIGINAL input and the
            edited token ids.

        Notes:
            The code is VERY messy and poorly commented.
            Variable names may NOT reflects their semantic meaning, due to the incremental changes of methodology.
        """
        # Classifier output on the unedited input; returned as-is. `att` is not used.
        _cls, att = self.Classifier(self.Emb(bare), seq_len, null_mask)

        # `mask` is a copy of the input in which already-edited positions get
        # overwritten with the unk id; `bare_bar` is the working copy that
        # receives the chosen edits.
        mask = torch.zeros_like(bare).copy_(bare)
        bare_bar = torch.zeros_like(bare).copy_(bare)
        # Row indices (into the full batch) that are still being edited.
        active_indices = gpu_wrapper(torch.LongTensor(list(range(bare.shape[0]))))

        for j in range(max_iter):
            if ablation == 'mask-att':
                null_mask[:, :] = 0
            # Length-descending order of the active rows and its inverse permutation.
            s_idx = [ix for ix, l in sorted(enumerate(seq_len[active_indices].cpu()), key=lambda x: x[1], reverse=True)]
            res_idx = [a for a, b in sorted(enumerate(s_idx), key=lambda x: x[1])]
            # Stop criterion uses the auxiliary classifier; the edit position uses
            # the main classifier's attention. Both are evaluated on `mask`.
            cls_mask, _ = self.Aux_Classifier(self.Aux_Emb(mask[active_indices][s_idx, :]), seq_len[active_indices][s_idx], null_mask[active_indices][s_idx, :])  # shape = (pre_n_active, ), (pre_n_active, max_len)
            _, att_mask = self.Classifier(self.Emb(mask[active_indices][s_idx, :]), seq_len[active_indices][s_idx], null_mask[active_indices][s_idx, :])  # shape = (pre_n_active, ), (pre_n_active, max_len)
            cls_mask = cls_mask[res_idx]
            att_mask = att_mask[res_idx]
            # Keep only rows whose aux-classifier output is still further than
            # cls_stop from (1 - direction).
            __active_indices = torch.nonzero(torch.abs(1 - direction - cls_mask) > cls_stop).view(-1)
            if __active_indices.shape[0] == 0:
                break
            active_indices = active_indices[__active_indices]  # shape = (n_active, )
            att_mask = att_mask[__active_indices]  # shape = (n_active, max_len)
            # Recompute the sort permutations for the reduced set of active rows.
            s_idx = [ix for ix, l in sorted(enumerate(seq_len[active_indices].cpu()), key=lambda x: x[1], reverse=True)]
            res_idx = [a for a, b in sorted(enumerate(s_idx), key=lambda x: x[1])]
            # Edit position: most-attended token of each active row.
            i = torch.argmax(att_mask, dim=1)
            # Word-producing operators run on length-sorted rows (sample=False) and
            # are un-sorted again; the Del operators are called on unsorted rows.
            bare_bar_InsFront, _, _ = getattr(self, 'InsFront' + str(direction))(bare_bar[active_indices][s_idx], i[s_idx], seq_len[active_indices][s_idx], sample=False)
            bare_bar_InsFront = bare_bar_InsFront[res_idx]
            bare_bar_InsBehind, _, _ = getattr(self, 'InsBehind' + str(direction))(bare_bar[active_indices][s_idx], i[s_idx], seq_len[active_indices][s_idx], sample=False)
            bare_bar_InsBehind = bare_bar_InsBehind[res_idx]
            bare_bar_Replace, _, _ = getattr(self, 'Replace' + str(direction))(bare_bar[active_indices][s_idx], i[s_idx], seq_len[active_indices][s_idx], sample=False)
            bare_bar_Replace = bare_bar_Replace[res_idx]
            # NOTE(review): i - 1 can be -1 and i + 1 can reach past the sentence;
            # how Del handles that is not visible here -- confirm.
            bare_bar_Delthis = getattr(self, 'Del' + str(direction))(bare_bar[active_indices], i)
            bare_bar_Delbefore = getattr(self, 'Del' + str(direction))(bare_bar[active_indices], i - 1)
            bare_bar_Delafter = getattr(self, 'Del' + str(direction))(bare_bar[active_indices], i + 1)
            bare_bar_NotChange = bare_bar[active_indices]

            # Candidate order defines the operation index used below (0..6).
            bare_bars = [bare_bar_InsFront,
                         bare_bar_InsBehind,
                         bare_bar_Replace,
                         bare_bar_Delthis,
                         bare_bar_Delbefore,
                         bare_bar_Delafter,
                         bare_bar_NotChange]
            # Resulting lengths, in the same order as the candidates.
            seq_lens = [seq_len[active_indices] + 1,
                        seq_len[active_indices] + 1,
                        seq_len[active_indices],
                        seq_len[active_indices] - 1,
                        seq_len[active_indices] - 1,
                        seq_len[active_indices] - 1,
                        seq_len[active_indices]]
            sent_probs = []
            for _bare_bar, _seq_len in zip(bare_bars, seq_lens):
                sent_prob_f = getattr(self, 'LMf{}'.format(1 - direction)).inference_whole(_bare_bar,  _seq_len)
                sent_prob_b = getattr(self, 'LMb{}'.format(1 - direction)).inference_whole(_bare_bar, _seq_len)
                # s_idx was computed from the pre-edit lengths; it stays a valid
                # descending order because each candidate shifts all lengths by the
                # same amount.
                cls, _ = self.Classifier(self.Emb(_bare_bar[s_idx, :]), _seq_len[s_idx], _bare_bar.eq(self.train_set.pad)[s_idx, :])
                cls = cls[res_idx]
                cls = torch.abs(direction - cls)
                # Score = (geometric mean of forward/backward LM outputs) ** _pow_lm
                # times the classifier's distance from `direction`.
                sent_probs.append(torch.pow(torch.sqrt(sent_prob_f * sent_prob_b), _pow_lm) * cls)

            sent_probs = torch.stack(sent_probs, dim=1)
            # Disable every operation whose index is not contained in `ablation`.
            # NOTE(review): when `ablation` is None or a string (e.g. 'mask-att') the
            # membership test raises TypeError on the first iteration, which the
            # broad except swallows, so no operation is disabled. Presumably
            # intentional, but it also hides any other error raised here.
            try:
                for abl in range(7):
                    if abl not in ablation:
                        sent_probs[:, abl] = - float('inf')
            except Exception:
                pass
            # Best-scoring operation per active row.
            oprt = torch.argmax(sent_probs, dim=1)
            bare_bars = torch.stack(bare_bars, dim=2)
            # Commit the chosen candidate and its length (seq_len is the caller's
            # tensor and is modified here).
            for __index, index in enumerate(active_indices):
                bare_bar[index, :] = bare_bars[__index, :, oprt[__index]]
                seq_len[index] = seq_lens[oprt[__index].item()][__index]

            # Rows grouped by the chosen operation. The three delete operations
            # (3, 4, 5) get no mask update below.
            __infront_indices = torch.nonzero(oprt == 0).view(-1)
            __insbehind_indices = torch.nonzero(oprt == 1).view(-1)
            __replace_indices = torch.nonzero(oprt == 2).view(-1)
            __notchange_indices = torch.nonzero(oprt == 6).view(-1)

            # With window hard-coded to False, only the edited position itself is
            # marked; the neighbour-marking branches below never run.
            window = False
            # InsFront: mark position i in null_mask and overwrite it with unk in `mask`.
            for __index in __infront_indices:
                null_mask[active_indices[__index], i[__index]] = 1
                if window:
                    before_indices = i[__index] - 1
                    if len(before_indices.shape) > 0:
                        before_indices[torch.nonzero(before_indices < 0).view(-1)] = 0
                    null_mask[active_indices[__index], before_indices] = 1
                    null_mask[active_indices[__index], i[__index] + 1] = 1
                mask[active_indices[__index], i[__index]] = self.train_set.unk
                if window:
                    before_indices = i[__index] - 1
                    if len(before_indices.shape) > 0:
                        before_indices[torch.nonzero(before_indices < 0).view(-1)] = 0
                    mask[active_indices[__index], before_indices] = self.train_set.unk
                    mask[active_indices[__index], i[__index] + 1] = self.train_set.unk
            # InsBehind: the marked position is i + 1.
            for __index in __insbehind_indices:
                null_mask[active_indices[__index], i[__index] + 1] = 1
                if window:
                    null_mask[active_indices[__index], i[__index]]     = 1
                    null_mask[active_indices[__index], i[__index] + 2] = 1
                mask[active_indices[__index], i[__index] + 1] = self.train_set.unk
                if window:
                    mask[active_indices[__index], i[__index]] = self.train_set.unk
                    mask[active_indices[__index], i[__index] + 2] = self.train_set.unk
            # Replace: mark position i.
            for __index in __replace_indices:
                null_mask[active_indices[__index], i[__index]] = 1
                if window:
                    before_indices = i[__index] - 1
                    if len(before_indices.shape) > 0:
                        before_indices[torch.nonzero(before_indices < 0).view(-1)] = 0
                    null_mask[active_indices[__index], before_indices] = 1
                    null_mask[active_indices[__index], i[__index] + 1] = 1

                mask[active_indices[__index], i[__index]] = self.train_set.unk
                if window:
                    before_indices = i[__index] - 1
                    if len(before_indices.shape) > 0:
                        before_indices[torch.nonzero(before_indices < 0).view(-1)] = 0
                    mask[active_indices[__index], before_indices] = self.train_set.unk
                    mask[active_indices[__index], i[__index] + 1] = self.train_set.unk
            # NotChange: position i is still marked, so the next step's attention is
            # computed with this position masked.
            for __index in __notchange_indices:
                null_mask[active_indices[__index], i[__index]] = 1
                if window:
                    before_indices = i[__index] - 1
                    if len(before_indices.shape) > 0:
                        before_indices[torch.nonzero(before_indices < 0).view(-1)] = 0
                    null_mask[active_indices[__index], before_indices] = 1
                    null_mask[active_indices[__index], i[__index] + 1] = 1

                mask[active_indices[__index], i[__index]] = self.train_set.unk
                if window:
                    before_indices = i[__index] - 1
                    if len(before_indices.shape) > 0:
                        before_indices[torch.nonzero(before_indices < 0).view(-1)] = 0
                    mask[active_indices[__index], before_indices] = self.train_set.unk
                    mask[active_indices[__index], i[__index] + 1] = self.train_set.unk

        return _cls, bare_bar

    def preprocess_data(self, data):
        """Length-sort both halves of a paired batch and wrap them with ``gpu_wrapper``.

        Args:
            data: 8-tuple ``(bare_0, go_0, eos_0, len_0, bare_1, go_1, eos_1, len_1)``.

        Returns:
            12-tuple ``(bare_0, go_0, eos_0, len_0, y_0, res_idx_0,
            bare_1, go_1, eos_1, len_1, y_1, res_idx_1)``. Rows of each half are
            ordered by descending length; ``res_idx_*`` is the list of positions
            that restores the original row order; ``y_0`` is all zeros and ``y_1``
            all ones, both sized by the batch dimension of ``bare_0``.
        """
        raw_bare_0, raw_go_0, raw_eos_0, raw_len_0, raw_bare_1, raw_go_1, raw_eos_1, raw_len_1 = data
        batch_size = raw_bare_0.shape[0]

        def _length_sorted(bare, go, eos, lengths):
            # Row positions in descending length order (stable for ties).
            order = sorted(range(len(lengths)), key=lambda row: lengths[row], reverse=True)
            # Inverse permutation: indexing the sorted rows with it undoes the sort.
            restore = sorted(range(len(order)), key=order.__getitem__)
            return (gpu_wrapper(bare[order, :]),
                    gpu_wrapper(go[order, :]),
                    gpu_wrapper(eos[order, :]),
                    gpu_wrapper(lengths[order]),
                    restore)

        # Style-0 half, labelled with zeros.
        bare_0, go_0, eos_0, len_0, res_idx_0 = _length_sorted(raw_bare_0, raw_go_0, raw_eos_0, raw_len_0)
        y_0 = gpu_wrapper(torch.zeros(batch_size))

        # Style-1 half, labelled with ones.
        bare_1, go_1, eos_1, len_1, res_idx_1 = _length_sorted(raw_bare_1, raw_go_1, raw_eos_1, raw_len_1)
        y_1 = gpu_wrapper(torch.ones(batch_size))

        return bare_0, go_0, eos_0, len_0, y_0, res_idx_0, bare_1, go_1, eos_1, len_1, y_1, res_idx_1
Пример #26
0
class Retrainer(object):
    def __init__(self, raw_segments, observation_sequences, label_sequences):
        """Keep the training data, build the helper objects and retrain at once.

        Args:
            raw_segments: raw segments; the other methods iterate them in
                step with ``label_sequences``.
            observation_sequences: observation sequences used by ``retrain``;
                overwritten by ``retrain_with_boosting_features``.
            label_sequences: one label sequence per raw segment.
        """
        super(Retrainer, self).__init__()

        # Training data exactly as supplied by the caller.
        self.raw_segments, self.observation_sequences, self.label_sequences = (
            raw_segments, observation_sequences, label_sequences)

        # Filled in by retrain() / retrain_with_boosting_features().
        self.hmm_new = None

        # Helper objects.
        self.feature_entity_list = FeatureEntityList()
        self.lm = LanguageModel()
        self.boosting_feature_generator = BoostingFeatureGenerator()

        # dominant label ratio: set empirically
        self.DOMINANT_RATIO = 0.85

        # Constructing a Retrainer immediately retrains with boosting features.
        self.retrain_with_boosting_features()

    def retrain(self):
        """Train a fresh ``HMM('retrainer', 6)`` on the stored sequences.

        The new model is stored in ``self.hmm_new`` before training starts and
        is trained on ``self.observation_sequences`` /
        ``self.label_sequences`` with the Laplace rule switched off.
        """
        model = HMM('retrainer', 6)
        self.hmm_new = model
        # important to set laplace to be no
        model.train(self.observation_sequences,
                    self.label_sequences,
                    useLaplaceRule=False)

    # With new features
    def retrain_with_boosting_features(self):
        """Rebuild the language model, then retrain the HMM on boosting features.

        Every (token, label) pair of the training data is added to
        ``self.lm``.  After ``prettify()``, ``self.lm.prettify_model`` is
        stored as ``self.token_BGM`` and ``self.pattern_BGM`` is set to
        ``None``.  A fresh ``HMM('retrainer', 6)`` is then trained, with the
        Laplace rule off, on the ``features`` of one
        ``BoostingFeatureGenerator`` per raw segment; those features also
        replace ``self.observation_sequences``.
        """
        # Build language model
        language_model = self.lm
        for raw_segment, label_sequence in zip(self.raw_segments,
                                               self.label_sequences):
            segment_tokens = Tokens(raw_segment).tokens
            for token, label in zip(segment_tokens, label_sequence):
                language_model.add(token, label)
        language_model.prettify()
        self.token_BGM = language_model.prettify_model
        self.pattern_BGM = None

        # Retrain
        self.hmm_new = HMM('retrainer', 6)
        boosted_sequences = [
            BoostingFeatureGenerator(raw_segment, self.token_BGM,
                                     self.pattern_BGM).features
            for raw_segment in self.raw_segments
        ]
        self.hmm_new.train(boosted_sequences,
                           self.label_sequences,
                           useLaplaceRule=False)
        self.observation_sequences = boosted_sequences

    def run(self):
        i = 0
        self.new_labels = []
        for raw_segment, label_sequence in zip(self.raw_segments,
                                               self.label_sequences):
            new_labels = self.hmm_new.decode(raw_segment)[1]
            self.new_labels.append(new_labels)
            tokens = Tokens(raw_segment).tokens
            feature_vectors = FeatureGenerator(raw_segment).features
            print i, ':  ', raw_segment
            for token, old_label, new_label, feature_vector in zip(
                    tokens, label_sequence, new_labels, feature_vectors):
                print to_label(old_label), '\t', to_label(
                    new_label), '\t', token
                self.feature_entity_list.add_entity(
                    feature_vector, old_label, token)  #???? Old label first
            print '\n'
            i += 1

    def find_pattern(self):
        """Call ``print_all_entity()`` on the retrained HMM's feature entity list."""
        entity_list = self.hmm_new.feature_entity_list
        entity_list.print_all_entity()

    # Find the first tokens at VN boundaries
    def find_venue_boundary_tokens(self):
        recorder = {}
        for raw_segment, observation_sequence, label_sequence in zip(
                self.raw_segments, self.observation_sequences,
                self.label_sequences):
            first_target_label_flag = True
            tokens = Tokens(raw_segment).tokens
            for token, feature_vector, label in zip(tokens,
                                                    observation_sequence,
                                                    label_sequence):
                # First meet a VN label
                if label == 4 and first_target_label_flag:
                    key = token.lower()
                    if not key.islower():
                        continue
                    if recorder.has_key(key):
                        recorder[key] += 1
                    else:
                        recorder[key] = 1
                    first_target_label_flag = False

                elif (first_target_label_flag is False) and label in [0, 1, 3]:
                    first_target_label_flag = True

        for k, v in recorder.iteritems():
            print k, '\t', v
        return recorder

    # Learn the general order of structure of publications before moving forward
    def find_majority_structure(self):
        first_bit_counter = {'0': 0, '3': 0, '4': 0, '5': 0}
        overall_pattern_counter = {}
        for label_sequence in self.label_sequences:
            label = label_sequence[0]
            if label == 2:
                continue
            elif label == 5:
                continue
            elif label in [0, 1]:
                first_bit_counter['0'] += 1
            else:
                first_bit_counter[str(label)] += 1

            pattern = []
            for label in label_sequence:
                if label in [2, 5]:
                    continue
                elif label in [0, 1]:
                    if 0 in pattern:
                        continue
                    else:
                        pattern.append(0)
                elif label == 3:
                    if 3 in pattern:
                        continue
                    else:
                        pattern.append(3)
                elif label == 4:
                    if 4 in pattern:
                        continue
                    else:
                        pattern.append(4)
            key = str(pattern)
            if overall_pattern_counter.has_key(key):
                overall_pattern_counter[key] += 1
            else:
                overall_pattern_counter[key] = 1

        # Inducing the structure
        sorted_firstbit_counter = sorted(first_bit_counter.iteritems(),
                                         key=operator.itemgetter(1),
                                         reverse=True)
        sorted_pattern_counter = sorted(overall_pattern_counter.iteritems(),
                                        key=operator.itemgetter(1),
                                        reverse=True)
        print '===========================================', sorted_pattern_counter
        return int(sorted_firstbit_counter[0][0]), ast.literal_eval(
            sorted_pattern_counter[0][0]), (float(
                sorted_pattern_counter[0][1])) / len(self.label_sequences)

    def run_with_boosting_features(self):
        """Decode each segment with boosting features and merge old/new labels.

        For every (raw_segment, label_sequence) pair:

        1. ``self.hmm_new.decode(raw_segment, True, True, self.token_BGM,
           self.pattern_BGM)`` supplies the feature vectors and the new
           labels; the new labels are appended to ``self.new_labels``.
        2. Token by token, the old and the new label are merged into
           ``tmp_combined_labels``:
             * identical labels are taken as they are;
             * labels that are both in [0, 1] (FN/LN) are treated as
               compatible and the new label is stored;
             * otherwise the label distribution recorded for the token's
               feature vector decides, first through the "dominant label"
               rule (share >= ``self.DOMINANT_RATIO``), then through the
               expected field order, each subject to ordering constraints.
               If nothing is accepted, ``-1`` is stored and an error logged.
        3. If the majority structure is ``[0, 3, 4]`` with a rate above 0.80,
           a fix-up pass may overwrite part of ``tmp_combined_labels`` with
           label 3.
        4. Old, new and combined label are printed for every token.

        Going by the inline comments, 0/1 are FN/LN, 3 is TI and 4 is VN;
        2 appears to be the delimiter (DL) label -- TODO confirm 2 and 5.

        NOTE(review): ``self.combined_labels`` is reset to ``[]`` here but
        nothing in this method appends to it; the merged labels only live in
        the per-segment local ``tmp_combined_labels``.
        """
        i = 0
        self.new_labels = []
        self.combined_labels = []

        for raw_segment, label_sequence in zip(self.raw_segments,
                                               self.label_sequences):
            feature_vectors, new_labels = self.hmm_new.decode(
                raw_segment, True, True, self.token_BGM, self.pattern_BGM)
            self.new_labels.append(new_labels)
            tokens = Tokens(raw_segment).tokens
            print i, ':  ', raw_segment

            # Combination step:
            tmp_combined_labels = []  # the decided combined labels so far
            for token, old_label, new_label, feature_vector in zip(
                    tokens, label_sequence, new_labels, feature_vectors):

                # Combine old and new labels to come out a combined label, and deciding...
                # combined_label is only ever tested against -1 ("undecided");
                # the value that is kept is the one appended to
                # tmp_combined_labels.
                combined_label = -1

                if old_label == new_label:
                    combined_label = new_label
                    tmp_combined_labels.append(new_label)

                # Combine compatible labels: FN and LN
                elif old_label in [0, 1] and new_label in [0, 1]:
                    # NOTE(review): combined_label is set to the old label but
                    # the NEW label is what gets stored -- confirm which one
                    # is intended.
                    combined_label = old_label
                    tmp_combined_labels.append(new_label)

                # Combine labels that are not compatible
                else:
                    tmp_feature_entity = self.hmm_new.feature_entity_list.lookup(
                        feature_vector
                    )  # Get the Background knowledge provided the feature vector: the language feature model
                    # (label, count) pairs, most frequent label first.
                    sorted_label_distribution = sorted(
                        tmp_feature_entity.label_distribution.iteritems(),
                        key=operator.itemgetter(1),
                        reverse=True)
                    total_label_occurence = float(
                        sum(tmp[1] for tmp in sorted_label_distribution))

                    # ============================================================================================
                    # ============================================================================================
                    # ???? Experimenting: removing the low prob label distribution; FAILURE; ARCHIVED HERE AND DEPRECATED
                    # sorted_label_distribution = []
                    # sum_prob = 0.0
                    # for pair in tmp_sorted_label_distribution:
                    #     sorted_label_distribution.append(pair)
                    #     sum_prob += pair[1]
                    #     if sum_prob/total_label_occurence >= 0.90:
                    #         break
                    # ============================================================================================
                    # ============================================================================================

                    # Dominant label case: Iterate from the highest label stats according to this feature vector:
                    # Every `continue` below rejects the current candidate and
                    # moves on to the next entry of the distribution.
                    for label_frequency in sorted_label_distribution:
                        if int(label_frequency[0]) in [
                                old_label, new_label
                        ] and (label_frequency[1] /
                               total_label_occurence) >= self.DOMINANT_RATIO:
                            print 'Dominant labels'
                            # Check for constraint:
                            tmp_label_to_check = int(label_frequency[0])

                            # Find last occurence position of this label
                            # The labels decided so far are joined into one
                            # string of digits and searched with rfind(), so
                            # the result is a character offset (-1 if absent).
                            # NOTE(review): that offset equals the list index
                            # only while every stored label is a single
                            # character; a stored -1 would shift it.
                            if tmp_label_to_check not in [0, 1]:
                                last_occurence = ''.join([
                                    str(c) for c in tmp_combined_labels
                                ]).rfind(str(tmp_label_to_check))
                            elif tmp_label_to_check in [0, 1]:
                                last_occurence_0 = ''.join([
                                    str(c) for c in tmp_combined_labels
                                ]).rfind('0')
                                last_occurence_1 = ''.join([
                                    str(c) for c in tmp_combined_labels
                                ]).rfind('1')
                                last_occurence = max(last_occurence_0,
                                                     last_occurence_1)

                            # Checking constraints by simplifying what we did in viterbi
                            if last_occurence == -1 or last_occurence == (
                                    len(tmp_combined_labels) - 1
                            ):  # Never occurred, or last occurence is the last label
                                # When we are deciding the first label
                                if len(tmp_combined_labels) == 0:
                                    first_bit = self.find_majority_structure(
                                    )[0]
                                    if first_bit == 0 and tmp_label_to_check not in [
                                            0, 1
                                    ]:
                                        continue
                                    if first_bit == 3 and tmp_label_to_check != 3:
                                        continue

                                # VN CANNOT FOLLOW TI W/O DL constraint
                                # NOTE(review): if tmp_combined_labels is still
                                # empty and the candidate is 4, the [-1] below
                                # raises IndexError; the corresponding check in
                                # the "weird case" branch further down is
                                # wrapped in try/except.
                                if tmp_label_to_check == 4 and tmp_combined_labels[
                                        -1] == 3:
                                    continue
                            elif tmp_label_to_check in [0, 1]:
                                # Reject 0/1 if anything other than 0, 1 or 2
                                # was decided since its last occurrence.
                                flag = False
                                for j in range(last_occurence,
                                               len(tmp_combined_labels)):
                                    if tmp_combined_labels[j] not in [0, 1, 2]:
                                        flag = True
                                        break
                                if flag:
                                    continue
                            elif tmp_label_to_check == 3:
                                # Label 3 occurred earlier and is not the most
                                # recent label: never reopened.
                                continue
                            elif tmp_label_to_check == 4:
                                if tmp_combined_labels[-1] == 3:  #????
                                    continue

                            combined_label = tmp_label_to_check
                            tmp_combined_labels.append(tmp_label_to_check)
                            break

                    # No dominance case OR Dominance-fail-due-to-constraint case: Find relatively if the label with higher possibility follow the constraint of publication order
                    if combined_label == -1:
                        # Iterate from the highest label stats according to this feature vector:

                        for label_frequency in sorted_label_distribution:
                            breakout_flag = False
                            #Test against constraints
                            # 1. DL separate labels principle
                            # 2. AU-TI-VN Order
                            if int(label_frequency[0]) in [
                                    old_label, new_label
                            ]:
                                tmp_label_to_check = int(label_frequency[0])

                                # find structure of the order, and find what have appeared, and so predict what to be appear next
                                structure_overview = [
                                ]  #will record the order in big sense: 0,3,4/4,0,3
                                for tmp_combined_label in tmp_combined_labels:
                                    if tmp_combined_label in [2, 5]:
                                        continue
                                    elif tmp_combined_label in [0, 1]:
                                        if 0 in structure_overview:
                                            continue
                                        else:
                                            structure_overview.append(0)
                                    elif tmp_combined_label == 3:
                                        if 3 in structure_overview:
                                            continue
                                        else:
                                            structure_overview.append(3)
                                    elif tmp_combined_label == 4:
                                        if 4 in structure_overview:
                                            continue
                                        else:
                                            structure_overview.append(4)
                                # Based on the structure overview, find what should appear next
                                appear_next = []
                                if structure_overview == [0]:
                                    appear_next = [0, 1, 3, 2, 5]
                                elif structure_overview == [3]:
                                    appear_next = [3, 0, 1, 2, 5]
                                elif structure_overview == [0, 3]:
                                    appear_next = [3, 4, 2, 5]
                                elif structure_overview == [3, 0]:
                                    appear_next = [0, 1, 4, 2, 5]
                                elif structure_overview == [0, 3, 4]:
                                    appear_next = [4, 2, 5]
                                elif structure_overview == [3, 0, 4]:
                                    appear_next = [4, 2, 5]
                                else:  #weird case
                                    # Any other overview (including an empty
                                    # one) lands here; appear_next stays [].
                                    print 'Weird structure! Weird case!'
                                    # Try the label with the larger count
                                    # first, then the other one.
                                    if tmp_feature_entity.label_distribution[str(
                                            old_label
                                    )] > tmp_feature_entity.label_distribution[
                                            str(new_label)]:
                                        tmp_label_to_check_list = [
                                            old_label, new_label
                                        ]
                                    else:
                                        tmp_label_to_check_list = [
                                            new_label, old_label
                                        ]
                                    # Apply constraints here too
                                    # Here `continue` moves on to the other
                                    # candidate of tmp_label_to_check_list.
                                    for tmp_label_to_check in tmp_label_to_check_list:
                                        if tmp_label_to_check not in [0, 1]:
                                            last_occurence = ''.join([
                                                str(c)
                                                for c in tmp_combined_labels
                                            ]).rfind(str(tmp_label_to_check))
                                        elif tmp_label_to_check in [0, 1]:
                                            last_occurence_0 = ''.join([
                                                str(c)
                                                for c in tmp_combined_labels
                                            ]).rfind('0')
                                            last_occurence_1 = ''.join([
                                                str(c)
                                                for c in tmp_combined_labels
                                            ]).rfind('1')
                                            last_occurence = max(
                                                last_occurence_0,
                                                last_occurence_1)

                                        # Checking constraints by simplifying what we did in viterbi
                                        if last_occurence == -1 or last_occurence == (
                                                len(tmp_combined_labels) - 1):
                                            # When we are deciding the first label
                                            if len(tmp_combined_labels) == 0:
                                                first_bit = self.find_majority_structure(
                                                )[0]
                                                if first_bit == 0 and tmp_label_to_check not in [
                                                        0, 1
                                                ]:
                                                    continue
                                                if first_bit == 3 and tmp_label_to_check != 3:
                                                    continue
                                            # NOTE(review): the bare except
                                            # turns any error (e.g. IndexError
                                            # on an empty list) into "reject
                                            # this candidate".
                                            try:
                                                if tmp_label_to_check == 4 and tmp_combined_labels[
                                                        -1] == 3:
                                                    continue
                                            except:
                                                continue
                                        elif tmp_label_to_check in [0, 1]:
                                            flag = False
                                            for j in range(
                                                    last_occurence,
                                                    len(tmp_combined_labels)):
                                                if tmp_combined_labels[
                                                        j] not in [0, 1, 2]:
                                                    flag = True
                                                    break
                                            if flag:
                                                continue
                                        elif tmp_label_to_check == 3:
                                            continue
                                        elif tmp_label_to_check == 4:
                                            if tmp_combined_labels[-1] == 3:
                                                continue

                                        combined_label = tmp_label_to_check
                                        tmp_combined_labels.append(
                                            combined_label)
                                        breakout_flag = True
                                        break

                                # A label was accepted in the weird-case
                                # branch: stop scanning the distribution.
                                if breakout_flag:
                                    break
                                if tmp_label_to_check in appear_next:
                                    # Then check constraint. find last occurence, DL constraints
                                    # Just need to check DL constraints, no need to verify more on tokens, assume token verification is done in the first iteration
                                    if tmp_label_to_check not in [0, 1]:
                                        last_occurence = ''.join([
                                            str(c) for c in tmp_combined_labels
                                        ]).rfind(str(tmp_label_to_check))
                                    elif tmp_label_to_check in [0, 1]:
                                        last_occurence_0 = ''.join([
                                            str(c) for c in tmp_combined_labels
                                        ]).rfind('0')
                                        last_occurence_1 = ''.join([
                                            str(c) for c in tmp_combined_labels
                                        ]).rfind('1')
                                        last_occurence = max(
                                            last_occurence_0, last_occurence_1)

                                    # Checking constraints by simplifying what we did in viterbi
                                    if last_occurence == -1 or last_occurence == (
                                            len(tmp_combined_labels) - 1):
                                        if tmp_label_to_check == 4 and tmp_combined_labels[
                                                -1] == 3:  #Hardcode rule [2013/07/23]: For VN, cannot directly follow a TI without DL???? may remove on real effect
                                            continue
                                    elif tmp_label_to_check in [0, 1]:
                                        flag = False
                                        for j in range(
                                                last_occurence,
                                                len(tmp_combined_labels)):
                                            if tmp_combined_labels[j] not in [
                                                    0, 1, 2
                                            ]:
                                                flag = True
                                                break
                                        if flag:
                                            continue

                                    elif tmp_label_to_check == 3:
                                        continue
                                        # flag = False
                                        # for j in range(last_occurence, len(tmp_combined_labels)):
                                        #     if tmp_combined_labels[j] not in [3,2]:
                                        #         flag = True
                                        #         break
                                        # if flag:
                                        #     continue

                                    elif tmp_label_to_check == 4:
                                        if tmp_combined_labels[-1] == 3:  #????
                                            continue

                                    # elif tmp_label_to_check == 2:
                                    # elif tmp_label_to_check == 5:

                                    # Otherwise, pass
                                    log_err('\t\t' + str(i) +
                                            'Should combine this one')
                                    combined_label = tmp_label_to_check
                                    tmp_combined_labels.append(
                                        tmp_label_to_check)
                                    # combined_label = (tmp_label_to_check, sorted_label_distribution)
                                    break

                                else:
                                    continue

                        # Debug
                        # Nothing was accepted: log it and store -1 as a
                        # placeholder so the list stays aligned with the tokens.
                        # NOTE(review): appear_next is only assigned inside the
                        # loop above when a distribution entry matches the old
                        # or new label; if none does, it is unbound here (or
                        # left over from an earlier token).
                        if combined_label == -1:
                            log_err(str(i) + 'problem')
                            combined_label = (appear_next,
                                              sorted_label_distribution)
                            tmp_combined_labels.append(-1)

            # Final check the accordance with the major order, ideally, all records under one domain should have the same order... PS very ugly code I admit
            print '==========================tmp_combined_labels', tmp_combined_labels
            # NOTE(review): find_majority_structure() reads only
            # self.label_sequences, yet it is recomputed twice here for every
            # segment (and once more per first-token decision above).
            majority_order_structure = self.find_majority_structure()[1]
            majority_rate = self.find_majority_structure()[2]
            tmp_combined_labels_length = len(tmp_combined_labels)
            if majority_rate > 0.80 and majority_order_structure == [0, 3, 4]:
                # p1(phase1): author segments
                # p1 ends on the first label outside [0, 1, 2, 5], or on the
                # last index if there is none.
                # NOTE(review): if tmp_combined_labels is empty, p1 is not
                # assigned by this loop before it is used below.
                for p1 in range(tmp_combined_labels_length):
                    if tmp_combined_labels[p1] in [0, 1, 2, 5]:
                        continue
                    else:
                        break

                # p2(phase2): title segments
                for p2 in range(p1, tmp_combined_labels_length):
                    if tmp_combined_labels[p2] == 3:
                        continue
                    else:
                        break

                #p3(phase3): venue segments
                # NOTE(review): p3 is computed but not read afterwards.
                for p3 in range(p2, tmp_combined_labels_length):
                    if tmp_combined_labels[p3] in [2, 5, 4]:
                        continue
                    else:
                        break

                # Decision
                if p1 == 0:
                    print 'Houston we got a SERIOUS problem!'
                    log_err('Houston we got a SERIOUS problem!!!!!!!!')

                # No label 3 was found at p1: relabel everything from p2 up to
                # (not including) the next label 2 as 3.
                if p2 == p1:
                    print 'Houston we got a problem!'
                    for sp2 in range(p2, tmp_combined_labels_length):
                        if tmp_combined_labels[sp2] != 2:
                            tmp_combined_labels[sp2] = 3
                        else:
                            break  # should fix common mislabeling at this point now??????????

            # elif majority_rate > 0.80 and majority_order_structure == [3,0,4]:    # ???? not sure if this is normal
            #     # p1(phase1): title segments
            #     for p1 in range(tmp_combined_labels_length):
            #         if tmp_combined_labels[p1] in [3]:
            #             continue
            #         else:
            #             break

            #     # p2(phase2): author segments
            #     for p2 in range(p1, tmp_combined_labels_length):
            #         if tmp_combined_labels[p2] == 3:
            #             continue
            #         else:
            #             break

            #     #p3(phase3): venue segments
            #     for p3 in range(p2, tmp_combined_labels_length):
            #         if tmp_combined_labels[p3] in [2,5,4]:
            #             continue
            #         else:
            #             break

            #     # Decision
            #     if p1 == 0:
            #         print 'Houston we got a SERIOUS problem!'
            #         log_err('Houston we got a SERIOUS problem!!!!!!!!')

            #     if p2 == p1:
            #         print 'Houston we got a problem!'
            #         for sp2 in range(p2, tmp_combined_labels_length):
            #             if tmp_combined_labels[sp2] != 2:
            #                 tmp_combined_labels[sp2] = 3
            #             else:
            #                 break
            # Report: old label, new label, combined label, token and feature
            # vector for every token of this segment.
            for old_label, new_label, tmp_combined_label, token, feature_vector in zip(
                    label_sequence, new_labels, tmp_combined_labels, tokens,
                    feature_vectors):
                print to_label(old_label), '\t', to_label(
                    new_label), '\t', to_label(
                        tmp_combined_label), '\t', token, '\t', feature_vector
            print '\n'
            i += 1
Пример #27
0
def main():
    """Run the command-line interface: parse the configuration and evaluate.

    Two positional arguments are read from ``argv``:

    * ``argv[1]`` -- a dash-separated configuration string. Field 0 is the
      technique (``hard_terms``, ``soft_terms``, ``hard_topics`` or
      ``soft_topics``). For the soft techniques, field 1 is the term
      similarity and field 2 the soft matrices. ``soft_terms`` then expects
      weighting, normalization and rounding (fields 3-5) followed by
      ``w2v_min_count``, ``m_knn`` and ``m_threshold`` (fields 6-8), whereas
      ``soft_topics`` expects ``w2v_min_count``, ``m_knn`` and
      ``m_threshold`` directly in fields 3-5.
    * ``argv[2]`` -- the run selector: ``dry_run``, ``dev``, ``2016`` or
      ``2017``.

    With ``dry_run`` the language model is only constructed and the function
    returns. Otherwise the predictions file is produced unless it already
    exists, and the test directory, gold base file name and base output file
    name are printed on one line.

    Raises:
        AssertionError: If a configuration field holds an unsupported value.
    """
    # Parse input configuration.
    year = argv[2]
    assert year in ("dry_run", "dev", "2016", "2017")
    config = argv[1].split('-', 8)
    technique = config[0]
    assert technique in ("hard_terms", "soft_terms", "hard_topics",
                         "soft_topics")

    # Set up the document similarity model.
    if technique in ("hard_topics", "soft_topics"):
        similarity_model = TopicCosineSimilarity()
    if technique in ("soft_topics", "soft_terms"):
        term_similarity = config[1]
        assert term_similarity in ("w2v.ql", "w2v.googlenews",
                                   "glove.enwiki_gigaword5",
                                   "glove.common_crawl",
                                   "glove.twitter", "fasttext.enwiki")

        soft_matrices_string = config[2]
        assert soft_matrices_string in ("mrel", "mlev", "mrel_mlev")
        if soft_matrices_string == "mrel":
            soft_matrices = [("mrel", 1.0)]
        elif soft_matrices_string == "mlev":
            soft_matrices = [("mlev", 1.0)]
        else:
            soft_matrices = [("mrel", 0.5), ("mlev", 0.5)]

    if technique == "hard_terms":
        similarity_model = TermHardCosineSimilarity()
        kwargs = {}
    elif technique == "hard_topics":
        kwargs = {}
    elif technique == "soft_terms":
        # The literal "none" in the configuration string stands for None.
        weighting_string = config[3]
        assert weighting_string in ("early", "late", "none")
        weighting = None if weighting_string == "none" else weighting_string

        normalization_string = config[4]
        assert normalization_string in ("soft", "hard", "none")
        normalization = (None if normalization_string == "none"
                         else normalization_string)

        rounding_string = config[5]
        assert rounding_string in ("none", "round", "floor", "ceil")
        rounding = None if rounding_string == "none" else rounding_string

        similarity_model = TermSoftCosineSimilarity(
            weighting=weighting, rounding=rounding,
            normalization=normalization)

        w2v_min_count = int(config[6])
        m_knn = int(config[7])
        m_threshold = float(config[8])
        kwargs = {
            "soft_matrices": soft_matrices,
            "w2v_min_count": w2v_min_count,
            "m_knn": m_knn,
            "m_threshold": m_threshold,
            "term_similarity": term_similarity,
        }
    elif technique == "soft_topics":
        w2v_min_count = int(config[3])
        m_knn = int(config[4])
        m_threshold = float(config[5])
        kwargs = {
            "soft_matrices": soft_matrices,
            "w2v_min_count": w2v_min_count,
            "m_knn": m_knn,
            "m_threshold": m_threshold,
            "term_similarity": term_similarity,
        }

    if year == "dry_run":
        # Prepare the language model and exit prematurely.
        LanguageModel(similarity=similarity_model,
                      technique=technique,
                      **kwargs)
        return

    # Determine directory and file names.
    if year == "dev":
        # The dev run reuses the 2016 test directories with the dev dataset.
        test_dirname = TEST2016_DIRNAME
        test_predictions_dirname = TEST2016_PREDICTIONS_DIRNAME
        gold_base_fname = DEV_GOLD_BASE_FNAME
        test_dataset_fname = DEV_DATASET_FNAME
    elif year == "2016":
        test_dirname = TEST2016_DIRNAME
        test_predictions_dirname = TEST2016_PREDICTIONS_DIRNAME
        gold_base_fname = TEST2016_GOLD_BASE_FNAME
        test_dataset_fname = TEST2016_DATASET_FNAME
    elif year == "2017":
        test_dirname = TEST2017_DIRNAME
        test_predictions_dirname = TEST2017_PREDICTIONS_DIRNAME
        gold_base_fname = TEST2017_GOLD_BASE_FNAME
        test_dataset_fname = TEST2017_DATASET_FNAME

    output_fname = "%s/subtask_B_%s-%s.txt" % (test_predictions_dirname,
                                               argv[1], argv[2])
    base_output_fname = "%s/subtask_B_%s-%s.txt" % (
        TEST_PREDICTIONS_BASE_DIRNAME, argv[1], argv[2])

    # Perform the evaluation (skipped when the predictions already exist).
    if not path.exists(output_fname):
        LOGGER.info("Producing %s ...", output_fname)
        file_handler = logging.FileHandler("%s.log" % output_fname,
                                           encoding='utf8')
        logging.getLogger().addHandler(file_handler)
        try:
            start_time = time()

            language_model = LanguageModel(similarity=similarity_model,
                                           technique=technique,
                                           **kwargs)
            evaluate(language_model, [test_dataset_fname], output_fname)

            LOGGER.info("Time elapsed: %s",
                        timedelta(seconds=time() - start_time))
        finally:
            # Detach and close the per-run log file even when the evaluation
            # fails, so the handler does not stay on the root logger and the
            # file descriptor is released.
            logging.getLogger().removeHandler(file_handler)
            file_handler.close()
    print("%s %s %s" % (test_dirname, gold_base_fname, base_output_fname))
Пример #28
0
    def __init__(self):
        """Load the data splits, build all sub-modules and set up training state.

        Everything is driven by the module-level ``config`` object. In order:
        the three ``Amazon`` splits are loaded, the pretrained evaluation
        classifier is restored, the embedding, classifier, gate and operator
        modules are created, the four language models are loaded (only when
        ``config.train_mode == 'pto'``), the auxiliary classifier is created,
        every attribute named in ``self.modules`` is passed through
        ``gpu_wrapper``, and one ``<scope>_lr`` attribute per entry of
        ``self.scopes`` is copied from ``config``.
        """
        print('----- Loading data -----')
        self.train_set = Amazon('train', False)
        self.test_set = Amazon('test', False)
        self.val_set = Amazon('dev', False)
        print('The train set has {} items'.format(len(self.train_set)))
        print('The test set has {} items'.format(len(self.test_set)))
        print('The val set has {} items'.format(len(self.val_set)))

        self.vocab = self.train_set.vocab

        # Pretrained classifier, used for evaluation.
        self.clseval = ClassifierEval(config.test_cls, config.dataset)
        self.clseval.restore_model()

        print('----- Loading model -----')
        embedding = self.vocab.embedding
        self.Emb = nn.Embedding.from_pretrained(embedding.clone(), freeze=False)

        # Keyword arguments shared by every recurrent sub-module below.
        rnn_kwargs = dict(dim_h=config.dim_h,
                          n_layers=config.n_layers,
                          dropout=config.dropout,
                          bi=config.bidirectional)

        self.Classifier = AttenClassifier(emb_dim=config.emb_dim, **rnn_kwargs)

        # Every module receives its own clone of the pretrained embedding.
        for gate_name in ('Gate0', 'Gate1'):
            setattr(self, gate_name,
                    Gate(temperature=config.temp_gate,
                         embedding=embedding.clone(),
                         **rnn_kwargs))

        # Edit operators, two instances (suffix 0 and 1) of each kind.
        operator_specs = (
            ('InsFront0', InsFront), ('InsFront1', InsFront),
            ('InsBehind0', InsBehind), ('InsBehind1', InsBehind),
            ('Replace0', Replace), ('Replace1', Replace),
        )
        for operator_name, operator_cls in operator_specs:
            setattr(self, operator_name,
                    operator_cls(embedding=embedding.clone(),
                                 voc_size=self.vocab.size,
                                 temperature=config.temp_sub,
                                 **rnn_kwargs))

        self.Del0 = Delete()  # Not a module.
        self.Del1 = Delete()  # Not a module.

        # Language models: one per (direction, sentiment) pair, 'pto' mode only.
        if config.train_mode == 'pto':
            lm_specs = (
                ('LMf0', 'forward', 0), ('LMf1', 'forward', 1),
                ('LMb0', 'backward', 0), ('LMb1', 'backward', 1),
            )
            for lm_name, direction, sentiment in lm_specs:
                language_model = LanguageModel(config.dataset,
                                               direction=direction,
                                               sentiment=sentiment)
                setattr(self, lm_name, language_model)
                language_model.model_load()

        # Auxiliary classifier with its own embedding copy.
        self.Aux_Emb = nn.Embedding.from_pretrained(embedding.clone(), freeze=False)
        self.Aux_Classifier = AttenClassifier(emb_dim=config.emb_dim, **rnn_kwargs)

        self.modules = [
            'Emb', 'Classifier',
            'Gate0', 'Gate1',
            'InsFront0', 'InsFront1',
            'InsBehind0', 'InsBehind1',
            'Replace0', 'Replace1',
            'Aux_Classifier', 'Aux_Emb',
        ]
        # Print each module, then replace it by its gpu_wrapper-ed version.
        for module_name in self.modules:
            print('--- {}: '.format(module_name))
            module_obj = getattr(self, module_name)
            print(module_obj)
            setattr(self, module_name, gpu_wrapper(module_obj))

        # Groups of module names; each group has a learning rate in config.
        self.scopes = {
            'emb': ['Emb'],
            'cls': ['Classifier'],
            'aux_cls': ['Aux_Classifier', 'Aux_Emb'],
            'gate': ['Gate0', 'Gate1'],
            'oprt': ['InsFront0', 'InsFront1', 'InsBehind0', 'InsBehind1', 'Replace0', 'Replace1'],
        }
        for scope_name in self.scopes:
            lr_attr = scope_name + '_lr'
            setattr(self, lr_attr, getattr(config, lr_attr))

        self.iter_num = -1
        self.logger = None

        train_mode = config.train_mode
        if train_mode == 'aux-cls-only':
            # This mode reloads the three splits with the second Amazon
            # argument set to True instead of False.
            self.train_set = Amazon('train', True)
            self.test_set = Amazon('test', True)
            self.val_set = Amazon('dev', True)
        if train_mode in ('aux-cls-only', 'cls-only'):
            self.best_acc = 0

        # Loss criteria start unset.
        self.criterionSeq = None
        self.criterionCls = None
        self.criterionRL = None
        self.criterionBack = None
Пример #29
0
def main():
    """Build, train and evaluate a ``LanguageModel`` in one of two modes.

    When ``parameters.FULL_DATA_MODE`` is truthy the data is loaded with
    ``load_data`` from ``parameters.SMALL_DATA_PATH``, the model is fitted and
    evaluated directly, and a generated sequence seeded with '女' is printed.
    Otherwise the train/validation/test paths are handed to
    ``prepare_for_generator`` and the generator-based fit and evaluation
    methods are used.
    """
    full_mode = bool(parameters.FULL_DATA_MODE)
    lm = LanguageModel()

    # Step 1: data preparation differs between the two modes.
    if full_mode:
        lm.load_data(train_data_path=parameters.SMALL_DATA_PATH)
    else:
        lm.prepare_for_generator(train_data_path=parameters.TRAIN_DATA_PATH,
                                 val_data_path=parameters.VAL_DATA_PATH,
                                 test_data_path=parameters.TEST_DATA_PATH)

    # Step 2: model definition and compilation are common to both modes.
    lm.define_model()
    lm.compile_model()

    # Step 3: training and evaluation.
    if not full_mode:
        lm.fit_model_with_generator()
        lm.evaluate_model_with_generator()
        return

    lm.fit_model()
    lm.evaluate_model()
    print(lm.generate_seq('女', 5))
Пример #30
0
def gradients_clipping(grads_params):
    """Clip each gradient element-wise to ``[-FLAGS.clip_value, FLAGS.clip_value]``.

    Args:
        grads_params: Iterable of ``(gradient, parameter)`` pairs.

    Returns:
        A new list of ``(clipped_gradient, parameter)`` pairs in input order.
        A pair whose gradient is ``None`` (a parameter that received no
        gradient) is passed through unchanged, because ``tf.clip_by_value``
        cannot take ``None`` as its input.
    """
    return [
        (g if g is None
         else tf.clip_by_value(g, -FLAGS.clip_value, FLAGS.clip_value), p)
        for g, p in grads_params
    ]

models = []
grads = []
with g.as_default():
         
    # build the model
    for i in xrange(FLAGS.ngpu):
        with tf.device('/device:GPU:{:d}'.format(i)),tf.name_scope('model{:d}'.format(i)):
            reuse = i>0
            models.append(LanguageModel(opts,'train',reuse))
            models[i].build()
                  
    # create a function to validate
    val_fns, generators = [],[]
    with tf.device('/gpu:0'.format(i)):
        # don't use the numpy version generator, use tensorflow version generator instead
        val_fn, _ = create_val_fn(batch_size = 100)
        val_fns.append(val_fn)
        #generators.append(generator) 
    
    batch_size = FLAGS.batch_size*FLAGS.ngpu
    start_decay_steps = int(opts.nImgs//batch_size*opts.start_decay_epoches)
    decay_steps = int(opts.nImgs//batch_size*opts.decay_epoches)
    decayed_learning_rate = tf.train.exponential_decay(opts.learning_rate,
                                                       tf.maximum(models[0].step-start_decay_steps,0),
Пример #31
0
            instance.serialize_class_data()
        instance.log('Run %d of %d:' % (i + 1, n_runs))
        instance.create_model()
        instance.compile_model()
        instance.train()
        instance.results()
        instance.serialize_model()
        instance.serialize_results()


def intervening(dep):
    """Return whether ``dep['n_intervening']`` is at least 1.

    Used as the ``criterion`` of the ``'predict_number_targeted'`` entry
    below.
    """
    return dep['n_intervening'] >= 1


# Experiment registry: maps a task name to a configured model instance. Every
# entry is built from ``filenames.deps``; the entries differ in model class,
# training proportion and optional extras (criterion, RNN class).
models = {
    'grammaticality':
    CorruptAgreement(filenames.deps, prop_train=0.1),
    'predict_number':
    PredictVerbNumber(filenames.deps, prop_train=0.1),
    'language_model':
    LanguageModel(filenames.deps, prop_train=0.1),
    'inflect_verb':
    InflectVerb(filenames.deps, prop_train=0.1),
    'predict_number_targeted':
    PredictVerbNumber(filenames.deps, prop_train=0.2, criterion=intervening),
    'predict_number_only_nouns':
    PredictVerbNumberOnlyNouns(filenames.deps, prop_train=0.1),
    'predict_number_only_generalized_nouns':
    PredictVerbNumberOnlyGeneralizedNouns(filenames.deps, prop_train=0.1),
    'predict_number_srn':
    PredictVerbNumber(filenames.deps, prop_train=0.1, rnn_class=SimpleRNN),
}
Пример #32
0
 def setUp(self):
     """Create the fixtures shared by the tests.

     Sets the expected term counts, a logger named ``TestLanguageModel``, a
     model built from the file ``term_occurrences.txt`` (``bk_model``) and a
     model built from the expected term dictionary (``doc_model``).
     """
     term_counts = {'hello': 10, 'world': 20, 'goodbye': 10}
     self.expected = term_counts
     self.logger = logging.getLogger("TestLanguageModel")
     self.bk_model = LanguageModel(file='term_occurrences.txt')
     # doc_model receives the very same dictionary object as self.expected.
     self.doc_model = LanguageModel(term_dict=term_counts)
Пример #33
0
class UnsupervisedGrammarCorrector:
    """Correct sentences by greedy, language-model-guided token substitution.

    Each pass proposes single-token replacements (spelling suggestions,
    inflected forms, alternative determiners and prepositions), scores every
    resulting sentence with the language model and keeps the best one. Passes
    repeat until no candidate scores strictly higher than the current
    sentence.
    """

    def __init__(self, threshold=0.96):
        """Load the language model, the spaCy pipeline and the word lists.

        Args:
            threshold: Stored as ``self.threshold``. No method of this class
                reads it.
        """
        # Resources are resolved relative to the directory of this file.
        resource_root = os.path.dirname(os.path.realpath(__file__))
        self.lm = LanguageModel()
        # Load spaCy
        self.nlp = spacy.load("en")
        # Hunspell spellchecker: https://pypi.python.org/pypi/CyHunspell
        # CyHunspell seems to be more accurate than Aspell in PyEnchant, but a bit slower.
        self.gb = Hunspell("en_GB-large",
                           hunspell_data_dir=resource_root + '/resources/spelling/')
        # Inflection forms: http://wordlist.aspell.net/other/
        self.gb_infl = loadWordFormDict(resource_root +
                                        "/resources/agid-2016.01.19/infl.txt")
        # List of common determiners ("" stands for "no determiner").
        self.determiners = {"", "the", "a", "an"}
        # List of common prepositions ("" stands for "no preposition").
        self.prepositions = {
            "", "about", "at", "by", "for", "from", "in", "of", "on", "to",
            "with"
        }
        self.threshold = threshold

    def correct(self, sentence):
        """Return the best-scoring rewrite of *sentence*.

        Args:
            sentence: Text to correct. A falsy value (e.g. an empty line)
                yields ``""``.

        Returns:
            The corrected sentence, or *sentence* itself when no candidate
            scores higher.
        """
        # If the line is empty, preserve the newline in output and continue
        if not sentence:
            return ""
        best = sentence
        score = self.lm.score(best)

        # Keep applying single-token corrections while each pass strictly
        # improves the language-model score.
        while True:
            new_best, new_score = self.process(best)
            if new_best and new_score > score:
                best = new_best
                score = new_score
            else:
                break

        return best

    def process(self, sentence: str) -> Tuple[str, float]:
        """Run one correction pass over *sentence*.

        Every token may contribute replacement candidates; each candidate
        sentence differs from the input in exactly one token position.

        Args:
            sentence: The sentence to improve.

        Returns:
            A ``(best_sentence, best_score)`` pair, where the score is the
            value returned by ``self.lm.score`` (presumably a float -- confirm
            against ``LanguageModel``). When no candidate beats the input,
            *sentence* and its own score are returned.
        """
        # Process sent with spacy
        proc_sent = self.nlp.tokenizer(sentence)
        self.nlp.tagger(proc_sent)
        # Score of the sentence so far; candidates have to beat it.
        orig_prob = self.lm.score(proc_sent.text)
        # Store all the candidate corrected sentences here
        candidates = []
        # Process each token.
        for tok in proc_sent:
            candidate_tokens = set()
            lower_cased_token = tok.lower_

            # SPELLCHECKING
            # Spell check: tok must be alphabetical and not a real word.
            if (lower_cased_token.isalpha()
                    and not self.gb.spell(lower_cased_token)):
                candidate_tokens |= set(self.gb.suggest(lower_cased_token))
            # MORPHOLOGY
            if tok.lemma_ in self.gb_infl:
                candidate_tokens |= self.gb_infl[tok.lemma_]
            # DETERMINERS
            if lower_cased_token in self.determiners:
                candidate_tokens |= self.determiners
            # PREPOSITIONS
            if lower_cased_token in self.prepositions:
                candidate_tokens |= self.prepositions

            # Keep only the replacements the spellchecker accepts.
            candidate_tokens = [
                c for c in candidate_tokens if self.gb.spell(c)
            ]

            if candidate_tokens:
                # Mirror the casing of the token being replaced.
                if tok.is_title:
                    candidate_tokens = [c.title() for c in candidate_tokens]
                elif tok.is_upper:
                    candidate_tokens = [c.upper() for c in candidate_tokens]

                candidates.extend(
                    self._generate_candidates(tok.i, candidate_tokens,
                                              proc_sent))

        best_prob = orig_prob
        best = sentence

        for candidate in candidates:
            # Score the candidate sentence exactly once; the language model
            # call is the expensive part of this loop.
            cand_prob = self.lm.score(candidate.text)
            # Debug trace, kept in its historical three-column layout (the
            # score appears twice) without scoring the candidate a second time.
            print(candidate.text, cand_prob, cand_prob)

            # Keep the candidate only if it beats the best score seen so far
            # (which starts at the score of the unmodified sentence).
            if cand_prob > best_prob:
                best_prob = cand_prob
                best = candidate.text
        # Return the best sentence together with its score.
        return best, best_prob

    def _generate_candidates(self, tok_id, candidate_tokens,
                             tokenized_sentence) -> List["spacy.tokens.Doc"]:
        """Build one candidate sentence per replacement for token *tok_id*.

        Args:
            tok_id: Index of the token to replace in *tokenized_sentence*.
            candidate_tokens: Replacement strings; an empty string deletes the
                token.
            tokenized_sentence: The tokenized input sentence.

        Returns:
            The candidate sentences, each re-tokenized with
            ``self.nlp.tokenizer``, in the order of *candidate_tokens*.
        """
        # Save candidates here.
        candidates = []

        prefix = tokenized_sentence[:tok_id]
        suffix = tokenized_sentence[tok_id + 1:]
        # Loop through the input alternative candidates
        for token in candidate_tokens:
            candidate = prefix.text_with_ws
            if token:
                candidate += token + " "
            candidate += suffix.text_with_ws
            candidate = self.nlp.tokenizer(candidate)
            candidates.append(candidate)
        return candidates
    path_to_lm = path_to_root + 'resources/en-70k-0.2.lm'

# Load Word2Vec (takes approx. 8G RAM)
# NOTE: print is called with a single parenthesised argument throughout, which
# produces the same output on Python 2 and Python 3.
print("loading GoogleNews...")
start = time.time()
# vectors = Word2Vec(size=3e2, min_count=1)
# vectors.build_vocab([item for sublist in lists_of_tokens.values() for item in sublist])
# vectors.intersect_word2vec_format(path_to_wv, binary=True)
wv = gensim.models.KeyedVectors.load_word2vec_format(path_to_wv, binary=True)
# vectors = Word2Vec.load_word2vec_format(path_to_wv, binary=True)
print("finish loading GoogleNews, time_cost = %.2fs" % (time.time() - start))

# Load language model (takes approx. 8G RAM)
print("loading language model...")
start = time.time()
lm = LanguageModel(model_path=path_to_lm)
print("finish loading language model, time_cost = %.2fs" % (time.time() -
                                                            start))

# ######################
# ### PARAMETER GRID ###
# ######################
# Names of the systems for which a parameter grid is built.
system_name_list = ['filippova', 'boudin', 'mehdad', 'tixier']
system_params_dict = {}

for system_name in system_name_list:
    # pos_filtering_grid = [True, False] if system_name == 'tixier' or system_name == 'mehdad' else [False]
    # cr_w_grid = [3, 10, 20] if system_name == 'tixier' else [3]
    cr_w_grid = [6, 12] if system_name == 'tixier' else [3]
    cr_overspanning_grid = [True, False
                            ] if system_name == 'tixier' else [False]
Пример #35
0
    def print_or_value(id, calculated, value):
        """Print ``True`` on a match, otherwise the check id and the actual value.

        Args:
            id: Label of the check, printed only on a mismatch.
            calculated: The value that was computed.
            value: The expected value, compared with ``==``.
        """
        matches = value == calculated
        if not matches:
            # Mismatch: show which check failed and what was computed,
            # followed by an empty line.
            print(id, calculated)
            print()
            return
        print(True)

    sentance_pairs = [(["la", "casa"],["the","big","house"]),(["casa", "pez","verde"],["green","house"]),(["casa"],["shop"])]
    t_f_given_e = ibmmodel1.train(sentance_pairs, 100)
    reversed = [(x,y) for y,x in sentance_pairs]
    t_e_given_f = ibmmodel1.train(reversed, 100)
    alignments = [ibmmodel1.get_phrase_alignment(t_f_given_e, t_e_given_f, fs, es) for fs, es in sentance_pairs]

    phrase_table = ibmmodel1.get_phrase_probabilities(alignments, sentance_pairs)
    lang_model = LanguageModel([e for _,e in sentance_pairs], n=2)
    ibmmodel1.print_phrase_table(phrase_table)
    # Tests:
    foreign_sentence = "la casa".split(" ")
    print_or_value(1, cur_cost([], foreign_sentence, phrase_table, lang_model), 1)
    print_or_value(2, cur_cost([(0,0,"the big")], foreign_sentence, phrase_table, lang_model), 0.041666666666666664)
    print_or_value(3, cur_cost([(1,1,"shop")], foreign_sentence, phrase_table, lang_model), 0.125)
    print_or_value(4, cur_cost([(0,0,"the big house")], foreign_sentence, phrase_table, lang_model), 0.013888888888888888)
    print_or_value(5, cur_cost([(0,0,"the big"),(1,1,"shop")], foreign_sentence, phrase_table, lang_model), 0.003472222222222222)

    phrase_to_max_prob = get_phrase_to_max_prob(phrase_table)
    print_or_value(6, future_cost([], foreign_sentence, phrase_to_max_prob), 0.25)
    print_or_value(7, future_cost([(0,0,"the big")], foreign_sentence, phrase_to_max_prob), 0.5)
    print_or_value(8, future_cost([(1,1,"shop")], foreign_sentence, phrase_to_max_prob), 0.5)
    print_or_value(9, future_cost([(0,0,"the big house")], foreign_sentence, phrase_to_max_prob), 0.5)
    print_or_value(10, future_cost([(0,0,"the big"),(1,1,"shop")], foreign_sentence, phrase_to_max_prob), 1)