Example #1
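# Evaluate the RTE model on a dataset in mini-batches and return the overall accuracy
# together with per-class accuracies for contradiction (C), entailment (E) and neutral (N).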
def accuracy(session, dataset, name,
             sentence1_ph, sentence1_length_ph, sentence2_ph, sentence2_length_ph, label_ph, dropout_keep_prob_ph,
             predictions_int, labels_int, contradiction_idx, entailment_idx, neutral_idx, batch_size):

    nb_eval_instances = len(dataset['sentence1'])
    eval_batches = make_batches(size=nb_eval_instances, batch_size=batch_size)
    p_vals, l_vals = [], []

    for e_batch_start, e_batch_end in eval_batches:
        feed_dict = {
            sentence1_ph: dataset['sentence1'][e_batch_start:e_batch_end],
            sentence1_length_ph: dataset['sentence1_length'][e_batch_start:e_batch_end],
            sentence2_ph: dataset['sentence2'][e_batch_start:e_batch_end],
            sentence2_length_ph: dataset['sentence2_length'][e_batch_start:e_batch_end],
            label_ph: dataset['label'][e_batch_start:e_batch_end],
            dropout_keep_prob_ph: 1.0
        }

        p_val, l_val = session.run([predictions_int, labels_int], feed_dict=feed_dict)

        p_vals += p_val.tolist()
        l_vals += l_val.tolist()

    matches = np.equal(p_vals, l_vals)
    acc = np.mean(matches)

    acc_c = np.mean(matches[np.where(np.array(l_vals) == contradiction_idx)])
    acc_e = np.mean(matches[np.where(np.array(l_vals) == entailment_idx)])
    acc_n = np.mean(matches[np.where(np.array(l_vals) == neutral_idx)])

    if name:
        logger.debug('{0} Accuracy: {1:.4f} - C: {2:.4f}, E: {3:.4f}, N: {4:.4f}'.format(
            name, acc * 100, acc_c * 100, acc_e * 100, acc_n * 100))

    return acc, acc_c, acc_e, acc_n
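# Note: make_batches is not defined in these snippets. A minimal sketch of the helper they
# appear to assume (splitting `size` items into (start, end) index pairs of at most
# batch_size elements each) might look like this; the repo's actual implementation may differ:
def make_batches(size, batch_size):
    # Ceiling division: the last batch may be smaller than batch_size.
    nb_batches = (size + batch_size - 1) // batch_size
    return [(i * batch_size, min(size, (i + 1) * batch_size)) for i in range(nb_batches)]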
Example #2
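    # Shuffle the samples, split them into mini-batches, and for each instance sample a
    # random window of length seq_length; y holds the same window shifted one step ahead
    # of x (next-token targets).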
    def create_batches(self):
        order = self.random_state.permutation(self.nb_samples)
        tensor_shuf = self.tensor[order, :]

        _batch_lst = make_batches(self.nb_samples, self.batch_size)
        self.batches = []

        for batch_start, batch_end in _batch_lst:
            batch_size = batch_end - batch_start
            batch = tensor_shuf[batch_start:batch_end, :]

            assert batch.shape[0] == batch_size

            x = np.zeros(shape=(batch_size, self.seq_length))
            y = np.zeros(shape=(batch_size, self.seq_length))

            for i in range(batch_size):
                start_idx = self.random_state.randint(low=0,
                                                      high=self.max_len - 1)
                end_idx = min(start_idx + self.seq_length, self.max_len)

                x[i, 0:(end_idx - start_idx)] = batch[i, start_idx:end_idx]

                start_idx += 1
                end_idx = min(start_idx + self.seq_length, self.max_len)

                y[i, 0:(end_idx - start_idx)] = batch[i, start_idx:end_idx]

            # Append one {'x', 'y'} pair per mini-batch (outside the per-instance loop).
            d = {'x': x, 'y': y}
            self.batches += [d]

        self.num_batches = len(self.batches)
        return
Example #3
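# End-to-end training script: builds an RTE model over SNLI, optionally adds the
# adversarial-set regularisers (rules 1-8), and alternates discriminator and adversary
# training with periodic Dev/Test accuracy reports.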
def main(argv):
    logger.info('Command line: {}'.format(' '.join(argv)))

    def fmt(prog):
        return argparse.HelpFormatter(prog, max_help_position=100, width=200)

    argparser = argparse.ArgumentParser(
        'Regularising RTE via Adversarial Sets Regularisation',
        formatter_class=fmt)

    argparser.add_argument('--train',
                           '-t',
                           action='store',
                           type=str,
                           default='data/snli/snli_1.0_train.jsonl.gz')
    argparser.add_argument('--valid',
                           '-v',
                           action='store',
                           type=str,
                           default='data/snli/snli_1.0_dev.jsonl.gz')
    argparser.add_argument('--test',
                           '-T',
                           action='store',
                           type=str,
                           default='data/snli/snli_1.0_test.jsonl.gz')

    argparser.add_argument(
        '--model',
        '-m',
        action='store',
        type=str,
        default='cbilstm',
        choices=['cbilstm', 'ff-dam', 'ff-damp', 'ff-dams', 'esim1'])
    argparser.add_argument('--optimizer',
                           '-o',
                           action='store',
                           type=str,
                           default='adagrad',
                           choices=['adagrad', 'adam'])

    argparser.add_argument('--embedding-size',
                           action='store',
                           type=int,
                           default=300)
    argparser.add_argument('--representation-size',
                           action='store',
                           type=int,
                           default=200)

    argparser.add_argument('--batch-size',
                           action='store',
                           type=int,
                           default=1024)

    argparser.add_argument('--nb-epochs',
                           '-e',
                           action='store',
                           type=int,
                           default=1000)
    argparser.add_argument('--nb-discriminator-epochs',
                           '-D',
                           action='store',
                           type=int,
                           default=1)
    argparser.add_argument('--nb-adversary-epochs',
                           '-A',
                           action='store',
                           type=int,
                           default=1000)

    argparser.add_argument('--dropout-keep-prob',
                           action='store',
                           type=float,
                           default=1.0)
    argparser.add_argument('--learning-rate',
                           action='store',
                           type=float,
                           default=0.1)
    argparser.add_argument('--clip',
                           '-c',
                           action='store',
                           type=float,
                           default=None)
    argparser.add_argument('--nb-words',
                           action='store',
                           type=int,
                           default=None)
    argparser.add_argument('--seed', action='store', type=int, default=0)
    argparser.add_argument('--std-dev',
                           action='store',
                           type=float,
                           default=0.01)

    argparser.add_argument('--has-bos',
                           action='store_true',
                           default=False,
                           help='Has <Beginning Of Sentence> token')
    argparser.add_argument('--has-eos',
                           action='store_true',
                           default=False,
                           help='Has <End Of Sentence> token')
    argparser.add_argument('--has-unk',
                           action='store_true',
                           default=False,
                           help='Has <Unknown Word> token')
    argparser.add_argument('--lower',
                           '-l',
                           action='store_true',
                           default=False,
                           help='Lowercase the corpus')

    argparser.add_argument('--initialize-embeddings',
                           '-i',
                           action='store',
                           type=str,
                           default=None,
                           choices=['normal', 'uniform'])

    argparser.add_argument('--fixed-embeddings', '-f', action='store_true')
    argparser.add_argument('--normalize-embeddings', '-n', action='store_true')
    argparser.add_argument('--only-use-pretrained-embeddings',
                           '-p',
                           action='store_true',
                           help='Only use pre-trained word embeddings')
    argparser.add_argument('--train-special-token-embeddings',
                           '-s',
                           action='store_true')
    argparser.add_argument('--semi-sort', '-S', action='store_true')

    argparser.add_argument('--save', action='store', type=str, default=None)
    argparser.add_argument('--hard-save',
                           action='store',
                           type=str,
                           default=None)
    argparser.add_argument('--restore', action='store', type=str, default=None)

    argparser.add_argument('--glove', action='store', type=str, default=None)
    argparser.add_argument('--word2vec',
                           action='store',
                           type=str,
                           default=None)

    argparser.add_argument('--rule0-weight',
                           '-0',
                           action='store',
                           type=float,
                           default=None)
    argparser.add_argument('--rule1-weight',
                           '-1',
                           action='store',
                           type=float,
                           default=None)
    argparser.add_argument('--rule2-weight',
                           '-2',
                           action='store',
                           type=float,
                           default=None)
    argparser.add_argument('--rule3-weight',
                           '-3',
                           action='store',
                           type=float,
                           default=None)
    argparser.add_argument('--rule4-weight',
                           '-4',
                           action='store',
                           type=float,
                           default=None)
    argparser.add_argument('--rule5-weight',
                           '-5',
                           action='store',
                           type=float,
                           default=None)
    argparser.add_argument('--rule6-weight',
                           '-6',
                           action='store',
                           type=float,
                           default=None)
    argparser.add_argument('--rule7-weight',
                           '-7',
                           action='store',
                           type=float,
                           default=None)
    argparser.add_argument('--rule8-weight',
                           '-8',
                           action='store',
                           type=float,
                           default=None)

    argparser.add_argument('--adversarial-batch-size',
                           '-B',
                           action='store',
                           type=int,
                           default=32)
    argparser.add_argument('--adversarial-sentence-length',
                           '-L',
                           action='store',
                           type=int,
                           default=10)

    argparser.add_argument('--adversarial-pooling',
                           '-P',
                           default='max',
                           choices=['sum', 'max', 'mean', 'logsumexp'])
    argparser.add_argument(
        '--adversarial-smart-init',
        '-I',
        action='store_true',
        default=False,
        help='Initialize sentence embeddings with actual word embeddings')

    argparser.add_argument(
        '--report',
        '-r',
        default=100,
        type=int,
        help='Number of batches between performance reports')
    argparser.add_argument('--report-loss',
                           default=100,
                           type=int,
                           help='Number of batches between loss reports')

    argparser.add_argument(
        '--memory-limit',
        default=None,
        type=int,
        help='The maximum area (in bytes) of address space which may be taken by the process.')
    argparser.add_argument('--universum', '-U', action='store_true')

    args = argparser.parse_args(argv)

    # Command line arguments
    train_path, valid_path, test_path = args.train, args.valid, args.test

    model_name = args.model
    optimizer_name = args.optimizer

    embedding_size = args.embedding_size
    representation_size = args.representation_size

    batch_size = args.batch_size

    nb_epochs = args.nb_epochs
    nb_discriminator_epochs = args.nb_discriminator_epochs
    nb_adversary_epochs = args.nb_adversary_epochs

    dropout_keep_prob = args.dropout_keep_prob
    learning_rate = args.learning_rate
    clip_value = args.clip
    seed = args.seed
    std_dev = args.std_dev

    has_bos = args.has_bos
    has_eos = args.has_eos
    has_unk = args.has_unk
    is_lower = args.lower

    initialize_embeddings = args.initialize_embeddings

    is_fixed_embeddings = args.fixed_embeddings
    is_normalize_embeddings = args.normalize_embeddings
    is_only_use_pretrained_embeddings = args.only_use_pretrained_embeddings
    is_train_special_token_embeddings = args.train_special_token_embeddings
    is_semi_sort = args.semi_sort

    logger.info('has_bos: {}, has_eos: {}, has_unk: {}'.format(
        has_bos, has_eos, has_unk))
    logger.info(
        'is_lower: {}, is_fixed_embeddings: {}, is_normalize_embeddings: {}'.
        format(is_lower, is_fixed_embeddings, is_normalize_embeddings))
    logger.info(
        'is_only_use_pretrained_embeddings: {}, is_train_special_token_embeddings: {}, is_semi_sort: {}'
        .format(is_only_use_pretrained_embeddings,
                is_train_special_token_embeddings, is_semi_sort))

    save_path = args.save
    hard_save_path = args.hard_save
    restore_path = args.restore

    glove_path = args.glove
    word2vec_path = args.word2vec

    # Experimental RTE regularizers
    rule0_weight = args.rule0_weight
    rule1_weight = args.rule1_weight
    rule2_weight = args.rule2_weight
    rule3_weight = args.rule3_weight
    rule4_weight = args.rule4_weight
    rule5_weight = args.rule5_weight
    rule6_weight = args.rule6_weight
    rule7_weight = args.rule7_weight
    rule8_weight = args.rule8_weight

    adversarial_batch_size = args.adversarial_batch_size
    adversarial_sentence_length = args.adversarial_sentence_length
    adversarial_pooling_name = args.adversarial_pooling
    adversarial_smart_init = args.adversarial_smart_init

    name_to_adversarial_pooling = {
        'sum': tf.reduce_sum,
        'max': tf.reduce_max,
        'mean': tf.reduce_mean,
        'logsumexp': tf.reduce_logsumexp
    }

    report_interval = args.report
    report_loss_interval = args.report_loss

    memory_limit = args.memory_limit
    is_universum = args.universum

    if memory_limit:
        import resource
        soft, hard = resource.getrlimit(resource.RLIMIT_AS)
        logger.info('Current memory limit: {}, {}'.format(soft, hard))
        resource.setrlimit(resource.RLIMIT_AS, (memory_limit, memory_limit))
        soft, hard = resource.getrlimit(resource.RLIMIT_AS)
        logger.info('New memory limit: {}, {}'.format(soft, hard))

    np.random.seed(seed)
    random_state = np.random.RandomState(seed)
    tf.set_random_seed(seed)

    logger.debug('Reading corpus ..')
    train_is, dev_is, test_is = util.SNLI.generate(train_path=train_path,
                                                   valid_path=valid_path,
                                                   test_path=test_path,
                                                   is_lower=is_lower)

    logger.info('Train size: {}\tDev size: {}\tTest size: {}'.format(
        len(train_is), len(dev_is), len(test_is)))
    all_is = train_is + dev_is + test_is

    # Token enumeration starts after the reserved indices:
    # index=0 PADDING, index=1 START_OF_SENTENCE, index=2 END_OF_SENTENCE, index=3 UNKNOWN_WORD
    bos_idx, eos_idx, unk_idx = 1, 2, 3
    start_idx = 1 + (1 if has_bos else 0) + (1 if has_eos else 0) + (1 if has_unk else 0)

    if not restore_path:
        # Create a sequence of tokens containing all sentences in the dataset
        token_seq = []
        for instance in all_is:
            token_seq += instance['sentence1_parse_tokens'] + instance[
                'sentence2_parse_tokens']

        token_set = set(token_seq)
        allowed_words = None
        if is_only_use_pretrained_embeddings:
            assert (glove_path is not None) or (word2vec_path is not None)
            if glove_path:
                logger.info('Loading GloVe words from {}'.format(glove_path))
                assert os.path.isfile(glove_path)
                allowed_words = load_glove_words(path=glove_path,
                                                 words=token_set)
            elif word2vec_path:
                logger.info(
                    'Loading word2vec words from {}'.format(word2vec_path))
                assert os.path.isfile(word2vec_path)
                allowed_words = load_word2vec_words(path=word2vec_path,
                                                    words=token_set)
            logger.info('Number of allowed words: {}'.format(
                len(allowed_words)))

        # Count the number of occurrences of each token
        token_counts = dict()
        for token in token_seq:
            if (allowed_words is None) or (token in allowed_words):
                if token not in token_counts:
                    token_counts[token] = 0
                token_counts[token] += 1

        # Sort the tokens according to their frequency and lexicographic ordering
        sorted_vocabulary = sorted(token_counts.keys(),
                                   key=lambda t: (-token_counts[t], t))

        index_to_token = {
            index: token
            for index, token in enumerate(sorted_vocabulary, start=start_idx)
        }
    else:
        with open('{}_index_to_token.p'.format(restore_path), 'rb') as f:
            index_to_token = pickle.load(f)

    token_to_index = {token: index for index, token in index_to_token.items()}

    entailment_idx, neutral_idx, contradiction_idx, none_idx = 0, 1, 2, 3
    label_to_index = {
        'entailment': entailment_idx,
        'neutral': neutral_idx,
        'contradiction': contradiction_idx,
    }

    if is_universum:
        label_to_index['none'] = none_idx

    max_len = None
    optimizer_name_to_class = {
        'adagrad': tf.train.AdagradOptimizer,
        'adam': tf.train.AdamOptimizer
    }

    optimizer_class = optimizer_name_to_class[optimizer_name]
    assert optimizer_class

    optimizer = optimizer_class(learning_rate=learning_rate)

    args = dict(has_bos=has_bos,
                has_eos=has_eos,
                has_unk=has_unk,
                bos_idx=bos_idx,
                eos_idx=eos_idx,
                unk_idx=unk_idx,
                max_len=max_len)

    train_dataset = util.instances_to_dataset(train_is, token_to_index,
                                              label_to_index, **args)
    dev_dataset = util.instances_to_dataset(dev_is, token_to_index,
                                            label_to_index, **args)
    test_dataset = util.instances_to_dataset(test_is, token_to_index,
                                             label_to_index, **args)

    sentence1 = train_dataset['sentence1']
    sentence1_length = train_dataset['sentence1_length']
    sentence2 = train_dataset['sentence2']
    sentence2_length = train_dataset['sentence2_length']
    label = train_dataset['label']

    sentence1_ph = tf.placeholder(dtype=tf.int32,
                                  shape=[None, None],
                                  name='sentence1')
    sentence2_ph = tf.placeholder(dtype=tf.int32,
                                  shape=[None, None],
                                  name='sentence2')

    sentence1_len_ph = tf.placeholder(dtype=tf.int32,
                                      shape=[None],
                                      name='sentence1_length')
    sentence2_len_ph = tf.placeholder(dtype=tf.int32,
                                      shape=[None],
                                      name='sentence2_length')

    clipped_sentence1 = tfutil.clip_sentence(sentence1_ph, sentence1_len_ph)
    clipped_sentence2 = tfutil.clip_sentence(sentence2_ph, sentence2_len_ph)

    label_ph = tf.placeholder(dtype=tf.int32, shape=[None], name='label')

    token_set = set(token_to_index.keys())
    vocab_size = max(token_to_index.values()) + 1

    nb_words = len(token_to_index)
    nb_special_tokens = vocab_size - nb_words

    token_to_embedding = dict()
    if not restore_path:
        if glove_path:
            logger.info(
                'Loading GloVe word embeddings from {}'.format(glove_path))
            assert os.path.isfile(glove_path)
            token_to_embedding = load_glove(glove_path, token_set)
        elif word2vec_path:
            logger.info('Loading word2vec word embeddings from {}'.format(
                word2vec_path))
            assert os.path.isfile(word2vec_path)
            token_to_embedding = load_word2vec(word2vec_path, token_set)

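    # Build the discriminator (the RTE model): embedding initialisation, the sentence-pair
    # model selected by --model, and the cross-entropy training loss.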
    discriminator_scope_name = 'discriminator'
    with tf.variable_scope(discriminator_scope_name):
        if initialize_embeddings == 'normal':
            logger.info('Initializing the embeddings with 𝓝(0, 1)')
            embedding_initializer = tf.random_normal_initializer(0.0, 1.0)
        elif initialize_embeddings == 'uniform':
            logger.info('Initializing the embeddings with 𝒰(-1, 1)')
            embedding_initializer = tf.random_uniform_initializer(minval=-1.0,
                                                                  maxval=1.0)
        else:
            logger.info(
                'Initializing the embeddings with Xavier initialization')
            embedding_initializer = tf.contrib.layers.xavier_initializer()

        if is_train_special_token_embeddings:
            embedding_layer_special = tf.get_variable(
                'special_embeddings',
                shape=[nb_special_tokens, embedding_size],
                initializer=embedding_initializer,
                trainable=True)
            embedding_layer_words = tf.get_variable(
                'word_embeddings',
                shape=[nb_words, embedding_size],
                initializer=embedding_initializer,
                trainable=not is_fixed_embeddings)
            embedding_layer = tf.concat(
                values=[embedding_layer_special, embedding_layer_words],
                axis=0)
        else:
            embedding_layer = tf.get_variable(
                'embeddings',
                shape=[vocab_size, embedding_size],
                initializer=embedding_initializer,
                trainable=not is_fixed_embeddings)

        sentence1_embedding = tf.nn.embedding_lookup(embedding_layer,
                                                     clipped_sentence1)
        sentence2_embedding = tf.nn.embedding_lookup(embedding_layer,
                                                     clipped_sentence2)

        dropout_keep_prob_ph = tf.placeholder(tf.float32,
                                              name='dropout_keep_prob')

        model_kwargs = dict(sequence1=sentence1_embedding,
                            sequence1_length=sentence1_len_ph,
                            sequence2=sentence2_embedding,
                            sequence2_length=sentence2_len_ph,
                            representation_size=representation_size,
                            dropout_keep_prob=dropout_keep_prob_ph)

        if is_universum:
            model_kwargs['nb_classes'] = 4

        if model_name in {'ff-dam', 'ff-damp', 'ff-dams'}:
            model_kwargs['init_std_dev'] = std_dev

        model_name_to_class = {
            'cbilstm': ConditionalBiLSTM,
            'ff-dam': FeedForwardDAM,
            'ff-damp': FeedForwardDAMP,
            'ff-dams': FeedForwardDAMS,
            'esim1': ESIMv1
        }

        model_class = model_name_to_class[model_name]

        assert model_class is not None
        model = model_class(**model_kwargs)

        logits = model()
        predictions = tf.argmax(logits, axis=1, name='predictions')

        losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=label_ph)
        loss = tf.reduce_mean(losses)

        if rule0_weight:
            loss += rule0_weight * contradiction_symmetry_l2(
                model_class, model_kwargs, contradiction_idx=contradiction_idx)

    discriminator_vars = tfutil.get_variables_in_scope(
        discriminator_scope_name)
    discriminator_init_op = tf.variables_initializer(discriminator_vars)

    trainable_discriminator_vars = list(discriminator_vars)
    if is_fixed_embeddings:
        if is_train_special_token_embeddings:
            trainable_discriminator_vars.remove(embedding_layer_words)
        else:
            trainable_discriminator_vars.remove(embedding_layer)

    discriminator_optimizer_scope_name = 'discriminator_optimizer'
    with tf.variable_scope(discriminator_optimizer_scope_name):
        if clip_value:
            gradients, v = zip(*optimizer.compute_gradients(
                loss, var_list=trainable_discriminator_vars))
            gradients, _ = tf.clip_by_global_norm(gradients, clip_value)
            training_step = optimizer.apply_gradients(zip(gradients, v))
        else:
            training_step = optimizer.minimize(
                loss, var_list=trainable_discriminator_vars)

    discriminator_optimizer_vars = tfutil.get_variables_in_scope(
        discriminator_optimizer_scope_name)
    discriminator_optimizer_init_op = tf.variables_initializer(
        discriminator_optimizer_vars)

    token_idx_ph = tf.placeholder(dtype=tf.int32, name='word_idx')
    token_embedding_ph = tf.placeholder(dtype=tf.float32,
                                        shape=[None],
                                        name='word_embedding')

    if is_train_special_token_embeddings:
        assign_token_embedding = embedding_layer_words[
            token_idx_ph - nb_special_tokens, :].assign(token_embedding_ph)
    else:
        assign_token_embedding = embedding_layer[token_idx_ph, :].assign(
            token_embedding_ph)

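    # Optional projection steps that re-normalise embedding rows onto the unit sphere,
    # applied once after initialisation and again after every parameter update.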
    init_projection_steps = []
    learning_projection_steps = []

    if is_normalize_embeddings:
        if is_train_special_token_embeddings:
            special_embeddings_projection = constraints.unit_sphere(
                embedding_layer_special, norm=1.0)
            word_embeddings_projection = constraints.unit_sphere(
                embedding_layer_words, norm=1.0)

            init_projection_steps += [special_embeddings_projection]
            init_projection_steps += [word_embeddings_projection]

            learning_projection_steps += [special_embeddings_projection]
            if not is_fixed_embeddings:
                learning_projection_steps += [word_embeddings_projection]
        else:
            embeddings_projection = constraints.unit_sphere(embedding_layer,
                                                            norm=1.0)
            init_projection_steps += [embeddings_projection]

            if not is_fixed_embeddings:
                learning_projection_steps += [embeddings_projection]

    predictions_int = tf.cast(predictions, tf.int32)
    labels_int = tf.cast(label_ph, tf.int32)

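    # Setting any of the rule{1..8} weights switches on the adversarial-set regulariser.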
    use_adversarial_training = rule1_weight or rule2_weight or rule3_weight or rule4_weight or rule5_weight or rule6_weight or rule7_weight or rule8_weight

    if use_adversarial_training:
        adversary_scope_name = discriminator_scope_name
        with tf.variable_scope(adversary_scope_name):
            adversarial = AdversarialSets(
                model_class=model_class,
                model_kwargs=model_kwargs,
                embedding_size=embedding_size,
                scope_name='adversary',
                batch_size=adversarial_batch_size,
                sequence_length=adversarial_sentence_length,
                entailment_idx=entailment_idx,
                contradiction_idx=contradiction_idx,
                neutral_idx=neutral_idx)

            adversary_loss = tf.constant(0.0, dtype=tf.float32)
            adversary_vars = []

            adversarial_pooling = name_to_adversarial_pooling[
                adversarial_pooling_name]

            if rule1_weight:
                rule1_loss, rule1_vars = adversarial.rule1_loss()
                adversary_loss += rule1_weight * adversarial_pooling(
                    rule1_loss)
                adversary_vars += rule1_vars
            if rule2_weight:
                rule2_loss, rule2_vars = adversarial.rule2_loss()
                adversary_loss += rule2_weight * adversarial_pooling(
                    rule2_loss)
                adversary_vars += rule2_vars
            if rule3_weight:
                rule3_loss, rule3_vars = adversarial.rule3_loss()
                adversary_loss += rule3_weight * adversarial_pooling(
                    rule3_loss)
                adversary_vars += rule3_vars
            if rule4_weight:
                rule4_loss, rule4_vars = adversarial.rule4_loss()
                adversary_loss += rule4_weight * adversarial_pooling(
                    rule4_loss)
                adversary_vars += rule4_vars
            if rule5_weight:
                rule5_loss, rule5_vars = adversarial.rule5_loss()
                adversary_loss += rule5_weight * adversarial_pooling(
                    rule5_loss)
                adversary_vars += rule5_vars
            if rule6_weight:
                rule6_loss, rule6_vars = adversarial.rule6_loss()
                adversary_loss += rule6_weight * adversarial_pooling(
                    rule6_loss)
                adversary_vars += rule6_vars
            if rule7_weight:
                rule7_loss, rule7_vars = adversarial.rule7_loss()
                adversary_loss += rule7_weight * adversarial_pooling(
                    rule7_loss)
                adversary_vars += rule7_vars
            if rule8_weight:
                rule8_loss, rule8_vars = adversarial.rule8_loss()
                adversary_loss += rule8_weight * adversarial_pooling(
                    rule8_loss)
                adversary_vars += rule8_vars

            loss += adversary_loss

            assert len(adversary_vars) > 0
            for adversary_var in adversary_vars:
                assert adversary_var.name.startswith(
                    'discriminator/adversary/rule')

            adversary_var_to_assign_op = dict()
            adversary_var_value_ph = tf.placeholder(dtype=tf.float32,
                                                    shape=[None, None, None],
                                                    name='adversary_var_value')
            for a_var in adversary_vars:
                adversary_var_to_assign_op[a_var] = a_var.assign(
                    adversary_var_value_ph)

        adversary_init_op = tf.variables_initializer(adversary_vars)

        adv_opt_scope_name = 'adversary_optimizer'
        with tf.variable_scope(adv_opt_scope_name):
            adversary_optimizer = optimizer_class(learning_rate=learning_rate)
            adversary_training_step = adversary_optimizer.minimize(
                -adversary_loss, var_list=adversary_vars)

            adversary_optimizer_vars = tf.get_collection(
                tf.GraphKeys.GLOBAL_VARIABLES, scope=adv_opt_scope_name)
            adversary_optimizer_init_op = tf.variables_initializer(
                adversary_optimizer_vars)

        logger.info(
            'Adversarial Batch Size: {}'.format(adversarial_batch_size))

        adversary_projection_steps = []
        for var in adversary_vars:
            if is_normalize_embeddings:
                unit_sphere_adversarial_embeddings = constraints.unit_sphere(
                    var, norm=1.0, axis=-1)
                adversary_projection_steps += [
                    unit_sphere_adversarial_embeddings
                ]

            assert adversarial_batch_size == var.get_shape()[0].value

            def token_init_op(_var, _token_idx, target_idx):
                token_emb = tf.nn.embedding_lookup(embedding_layer, _token_idx)
                tiled_token_emb = tf.tile(tf.expand_dims(token_emb, 0),
                                          (adversarial_batch_size, 1))
                return _var[:, target_idx, :].assign(tiled_token_emb)

            if has_bos:
                adversary_projection_steps += [token_init_op(var, bos_idx, 0)]

    saver = tf.train.Saver(discriminator_vars + discriminator_optimizer_vars,
                           max_to_keep=1)

    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True
    # session_config.log_device_placement = True

    with tf.Session(config=session_config) as session:
        logger.info('Total Parameters: {}'.format(
            tfutil.count_trainable_parameters()))
        logger.info('Total Discriminator Parameters: {}'.format(
            tfutil.count_trainable_parameters(var_list=discriminator_vars)))
        logger.info('Total Trainable Discriminator Parameters: {}'.format(
            tfutil.count_trainable_parameters(
                var_list=trainable_discriminator_vars)))

        if use_adversarial_training:
            session.run([adversary_init_op, adversary_optimizer_init_op])

        if restore_path:
            saver.restore(session, restore_path)
        else:
            session.run(
                [discriminator_init_op, discriminator_optimizer_init_op])

            # Initialising pre-trained embeddings
            logger.info('Initialising the embeddings with pre-trained vectors ..')
            for token in token_to_embedding:
                token_idx, token_embedding = token_to_index[
                    token], token_to_embedding[token]
                assert embedding_size == len(token_embedding)
                session.run(assign_token_embedding,
                            feed_dict={
                                token_idx_ph: token_idx,
                                token_embedding_ph: token_embedding
                            })
            logger.info('Done!')

            for projection_step in init_projection_steps:
                session.run([projection_step])

        nb_instances = sentence1.shape[0]
        batches = make_batches(size=nb_instances, batch_size=batch_size)

        best_dev_acc, best_test_acc = None, None
        discriminator_batch_counter = 0

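        # Alternate between nb_discriminator_epochs of discriminator training and, when the
        # adversarial regulariser is enabled, nb_adversary_epochs of adversary training.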
        for epoch in range(1, nb_epochs + 1):

            for d_epoch in range(1, nb_discriminator_epochs + 1):
                order = random_state.permutation(nb_instances)

                sentences1, sentences2 = sentence1[order], sentence2[order]
                sizes1, sizes2 = sentence1_length[order], sentence2_length[
                    order]
                labels = label[order]

                if is_semi_sort:
                    order = util.semi_sort(sizes1, sizes2)
                    sentences1, sentences2 = sentence1[order], sentence2[order]
                    sizes1, sizes2 = sentence1_length[order], sentence2_length[
                        order]
                    labels = label[order]

                loss_values, epoch_loss_values = [], []
                for batch_idx, (batch_start, batch_end) in enumerate(batches):
                    discriminator_batch_counter += 1

                    batch_sentences1, batch_sentences2 = sentences1[
                        batch_start:batch_end], sentences2[
                            batch_start:batch_end]
                    batch_sizes1, batch_sizes2 = sizes1[
                        batch_start:batch_end], sizes2[batch_start:batch_end]
                    batch_labels = labels[batch_start:batch_end]

                    batch_max_size1 = np.max(batch_sizes1)
                    batch_max_size2 = np.max(batch_sizes2)

                    batch_sentences1 = batch_sentences1[:, :batch_max_size1]
                    batch_sentences2 = batch_sentences2[:, :batch_max_size2]

                    batch_feed_dict = {
                        sentence1_ph: batch_sentences1,
                        sentence1_len_ph: batch_sizes1,
                        sentence2_ph: batch_sentences2,
                        sentence2_len_ph: batch_sizes2,
                        label_ph: batch_labels,
                        dropout_keep_prob_ph: dropout_keep_prob
                    }

                    _, loss_value = session.run([training_step, loss],
                                                feed_dict=batch_feed_dict)

                    logger.debug('Epoch {0}/{1}/{2}\tLoss: {3}'.format(
                        epoch, d_epoch, batch_idx, loss_value))

                    cur_batch_size = batch_sentences1.shape[0]
                    loss_values += [loss_value / cur_batch_size]
                    epoch_loss_values += [loss_value / cur_batch_size]

                    for projection_step in learning_projection_steps:
                        session.run([projection_step])

                    if discriminator_batch_counter % report_loss_interval == 0:
                        logger.info(
                            'Epoch {0}/{1}/{2}\tLoss Stats: {3}'.format(
                                epoch, d_epoch, batch_idx, stats(loss_values)))
                        loss_values = []

                    if discriminator_batch_counter % report_interval == 0:
                        accuracy_args = [
                            sentence1_ph, sentence1_len_ph, sentence2_ph,
                            sentence2_len_ph, label_ph, dropout_keep_prob_ph,
                            predictions_int, labels_int, contradiction_idx,
                            entailment_idx, neutral_idx, batch_size
                        ]
                        dev_acc, _, _, _ = accuracy(session, dev_dataset,
                                                    'Dev', *accuracy_args)
                        test_acc, _, _, _ = accuracy(session, test_dataset,
                                                     'Test', *accuracy_args)

                        logger.info(
                            'Epoch {0}/{1}/{2}\tDev Acc: {3:.2f}\tTest Acc: {4:.2f}'
                            .format(epoch, d_epoch, batch_idx, dev_acc * 100,
                                    test_acc * 100))

                        if best_dev_acc is None or dev_acc > best_dev_acc:
                            best_dev_acc, best_test_acc = dev_acc, test_acc

                            if save_path:
                                with open('{}_index_to_token.p'.format(save_path), 'wb') as f:
                                    pickle.dump(index_to_token, f)

                                saved_path = saver.save(session, save_path)
                                logger.info('Model saved in file: {}'.format(
                                    saved_path))

                        logger.info(
                            'Epoch {0}/{1}/{2}\tBest Dev Accuracy: {3:.2f}\tBest Test Accuracy: {4:.2f}'
                            .format(epoch, d_epoch, batch_idx,
                                    best_dev_acc * 100, best_test_acc * 100))

                logger.info('Epoch {0}/{1}\tEpoch Loss Stats: {2}'.format(
                    epoch, d_epoch, stats(epoch_loss_values)))

                if hard_save_path:
                    with open('{}_index_to_token.p'.format(hard_save_path),
                              'wb') as f:
                        pickle.dump(index_to_token, f)

                    hard_saved_path = saver.save(session, hard_save_path)
                    logger.info(
                        'Model saved in file: {}'.format(hard_saved_path))

            if use_adversarial_training:
                session.run([adversary_init_op, adversary_optimizer_init_op])

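                # Optionally initialise the adversarial sentence embeddings from the word
                # embeddings of randomly sampled vocabulary tokens.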
                if adversarial_smart_init:
                    _token_indices = np.array(sorted(index_to_token.keys()))

                    for a_var in adversary_vars:
                        # Create a [batch size, sentence length, embedding size] NumPy tensor of sentence embeddings
                        a_word_idx = _token_indices[random_state.randint(
                            low=0,
                            high=len(_token_indices),
                            size=[
                                adversarial_batch_size,
                                adversarial_sentence_length
                            ])]
                        np_embedding_layer = session.run(embedding_layer)
                        np_adversarial_embeddings = np_embedding_layer[
                            a_word_idx]
                        assert np_adversarial_embeddings.shape == (
                            adversarial_batch_size,
                            adversarial_sentence_length, embedding_size)

                        assert a_var in adversary_var_to_assign_op
                        assign_op = adversary_var_to_assign_op[a_var]

                        logger.info(
                            'Clever initialization of the adversarial embeddings ..'
                        )
                        session.run(assign_op,
                                    feed_dict={
                                        adversary_var_value_ph:
                                        np_adversarial_embeddings
                                    })

                for a_epoch in range(1, nb_adversary_epochs + 1):
                    adversary_feed_dict = {dropout_keep_prob_ph: 1.0}
                    _, adversary_loss_value = session.run(
                        [adversary_training_step, adversary_loss],
                        feed_dict=adversary_feed_dict)
                    logger.info('Adversary Epoch {0}/{1}\tLoss: {2}'.format(
                        epoch, a_epoch, adversary_loss_value))

                    for adversary_projection_step in adversary_projection_steps:
                        session.run(adversary_projection_step)

    logger.info('Training finished.')
Example #4
def main(argv):
    logger.info('Command line: {}'.format(' '.join(argv)))

    def fmt(prog):
        return argparse.HelpFormatter(prog, max_help_position=100, width=200)

    argparser = argparse.ArgumentParser('Regularising RTE via Adversarial Sets Regularisation', formatter_class=fmt)

    argparser.add_argument('--train', '-t', action='store', type=str, default='data/snli/snli_1.0_train.jsonl.gz')
    argparser.add_argument('--valid', '-v', action='store', type=str, default='data/snli/snli_1.0_dev.jsonl.gz')
    argparser.add_argument('--test', '-T', action='store', type=str, default='data/snli/snli_1.0_test.jsonl.gz')

    argparser.add_argument('--model', '-m', action='store', type=str, default='cbilstm',
                           choices=['cbilstm', 'ff-dam', 'ff-damp', 'ff-dams', 'esim1'])
    argparser.add_argument('--optimizer', '-o', action='store', type=str, default='adagrad',
                           choices=['adagrad', 'adam'])

    argparser.add_argument('--embedding-size', action='store', type=int, default=300)
    argparser.add_argument('--representation-size', action='store', type=int, default=200)

    argparser.add_argument('--batch-size', action='store', type=int, default=1024)

    argparser.add_argument('--nb-epochs', '-e', action='store', type=int, default=1000)
    argparser.add_argument('--nb-discriminator-epochs', '-D', action='store', type=int, default=1)
    argparser.add_argument('--nb-adversary-epochs', '-A', action='store', type=int, default=1000)

    argparser.add_argument('--dropout-keep-prob', action='store', type=float, default=1.0)
    argparser.add_argument('--learning-rate', action='store', type=float, default=0.1)
    argparser.add_argument('--clip', '-c', action='store', type=float, default=None)
    argparser.add_argument('--nb-words', action='store', type=int, default=None)
    argparser.add_argument('--seed', action='store', type=int, default=0)
    argparser.add_argument('--std-dev', action='store', type=float, default=0.01)

    argparser.add_argument('--has-bos', action='store_true', default=False, help='Has <Beginning Of Sentence> token')
    argparser.add_argument('--has-eos', action='store_true', default=False, help='Has <End Of Sentence> token')
    argparser.add_argument('--has-unk', action='store_true', default=False, help='Has <Unknown Word> token')
    argparser.add_argument('--lower', '-l', action='store_true', default=False, help='Lowercase the corpus')

    argparser.add_argument('--initialize-embeddings', '-i', action='store', type=str, default=None,
                           choices=['normal', 'uniform'])

    argparser.add_argument('--fixed-embeddings', '-f', action='store_true')
    argparser.add_argument('--normalize-embeddings', '-n', action='store_true')
    argparser.add_argument('--only-use-pretrained-embeddings', '-p', action='store_true',
                           help='Only use pre-trained word embeddings')
    argparser.add_argument('--semi-sort', '-S', action='store_true')

    argparser.add_argument('--save', action='store', type=str, default=None)
    argparser.add_argument('--hard-save', action='store', type=str, default=None)
    argparser.add_argument('--restore', action='store', type=str, default=None)

    argparser.add_argument('--glove', action='store', type=str, default=None)

    argparser.add_argument('--rule00-weight', '--00', action='store', type=float, default=None)
    argparser.add_argument('--rule01-weight', '--01', action='store', type=float, default=None)
    argparser.add_argument('--rule02-weight', '--02', action='store', type=float, default=None)
    argparser.add_argument('--rule03-weight', '--03', action='store', type=float, default=None)

    for i in range(1, 9):
        argparser.add_argument('--rule{}-weight'.format(i), '-{}'.format(i), action='store', type=float, default=None)

    argparser.add_argument('--adversarial-batch-size', '-B', action='store', type=int, default=32)
    argparser.add_argument('--adversarial-pooling', '-P', default='max', choices=['sum', 'max', 'mean', 'logsumexp'])

    argparser.add_argument('--report', '-r', default=100, type=int, help='Number of batches between performance reports')
    argparser.add_argument('--report-loss', default=100, type=int, help='Number of batches between loss reports')

    argparser.add_argument('--eval', '-E', nargs='+', type=str, help='Evaluate on these additional sets')

    args = argparser.parse_args(argv)

    # Command line arguments
    train_path, valid_path, test_path = args.train, args.valid, args.test

    model_name = args.model
    optimizer_name = args.optimizer

    embedding_size = args.embedding_size
    representation_size = args.representation_size

    batch_size = args.batch_size

    nb_epochs = args.nb_epochs
    nb_discriminator_epochs = args.nb_discriminator_epochs

    dropout_keep_prob = args.dropout_keep_prob
    learning_rate = args.learning_rate
    clip_value = args.clip
    seed = args.seed
    std_dev = args.std_dev

    has_bos = args.has_bos
    has_eos = args.has_eos
    has_unk = args.has_unk
    is_lower = args.lower

    initialize_embeddings = args.initialize_embeddings

    is_fixed_embeddings = args.fixed_embeddings
    is_normalize_embeddings = args.normalize_embeddings
    is_only_use_pretrained_embeddings = args.only_use_pretrained_embeddings
    is_semi_sort = args.semi_sort

    logger.info('has_bos: {}, has_eos: {}, has_unk: {}'.format(has_bos, has_eos, has_unk))
    logger.info('is_lower: {}, is_fixed_embeddings: {}, is_normalize_embeddings: {}'
                .format(is_lower, is_fixed_embeddings, is_normalize_embeddings))
    logger.info('is_only_use_pretrained_embeddings: {}, is_semi_sort: {}'
                .format(is_only_use_pretrained_embeddings, is_semi_sort))

    save_path = args.save
    hard_save_path = args.hard_save
    restore_path = args.restore

    glove_path = args.glove

    # Experimental RTE regularizers
    rule00_weight = args.rule00_weight
    rule01_weight = args.rule01_weight
    rule02_weight = args.rule02_weight
    rule03_weight = args.rule03_weight

    rule1_weight = args.rule1_weight
    rule2_weight = args.rule2_weight
    rule3_weight = args.rule3_weight
    rule4_weight = args.rule4_weight
    rule5_weight = args.rule5_weight
    rule6_weight = args.rule6_weight
    rule7_weight = args.rule7_weight
    rule8_weight = args.rule8_weight

    a_batch_size = args.adversarial_batch_size
    adversarial_pooling_name = args.adversarial_pooling

    name_to_adversarial_pooling = {
        'sum': tf.reduce_sum,
        'max': tf.reduce_max,
        'mean': tf.reduce_mean,
        'logsumexp': tf.reduce_logsumexp
    }

    report_interval = args.report
    report_loss_interval = args.report_loss

    eval_paths = args.eval

    np.random.seed(seed)
    rs = np.random.RandomState(seed)
    tf.set_random_seed(seed)

    logger.debug('Reading corpus ..')
    train_is, dev_is, test_is = util.SNLI.generate(train_path=train_path, valid_path=valid_path, test_path=test_path, is_lower=is_lower)

    logger.info('Train size: {}\tDev size: {}\tTest size: {}'.format(len(train_is), len(dev_is), len(test_is)))
    all_is = train_is + dev_is + test_is

    # Token enumeration starts after the reserved indices:
    # index=0 PADDING, index=1 START_OF_SENTENCE, index=2 END_OF_SENTENCE, index=3 UNKNOWN_WORD
    bos_idx, eos_idx, unk_idx = 1, 2, 3
    start_idx = 1 + (1 if has_bos else 0) + (1 if has_eos else 0) + (1 if has_unk else 0)

    if not restore_path:
        # Create a sequence of tokens containing all sentences in the dataset
        token_seq = []
        for instance in all_is:
            token_seq += instance['sentence1_parse_tokens'] + instance['sentence2_parse_tokens']

        token_set = set(token_seq)
        allowed_words = None
        if is_only_use_pretrained_embeddings:
            assert glove_path is not None
            logger.info('Loading GloVe words from {}'.format(glove_path))
            assert os.path.isfile(glove_path)
            allowed_words = load_glove_words(path=glove_path, words=token_set)
            logger.info('Number of allowed words: {}'.format(len(allowed_words)))

        # Count the number of occurrences of each token
        token_counts = dict()
        for token in token_seq:
            if (allowed_words is None) or (token in allowed_words):
                if token not in token_counts:
                    token_counts[token] = 0
                token_counts[token] += 1

        # Sort the tokens according to their frequency and lexicographic ordering
        sorted_vocabulary = sorted(token_counts.keys(), key=lambda t: (- token_counts[t], t))

        index_to_token = {index: token for index, token in enumerate(sorted_vocabulary, start=start_idx)}
    else:
        with open('{}_index_to_token.p'.format(restore_path), 'rb') as f:
            index_to_token = pickle.load(f)

    token_to_index = {token: index for index, token in index_to_token.items()}

    entailment_idx, neutral_idx, contradiction_idx = 0, 1, 2
    label_to_index = {
        'entailment': entailment_idx,
        'neutral': neutral_idx,
        'contradiction': contradiction_idx,
    }

    max_len = None
    optimizer_name_to_class = {
        'adagrad': tf.train.AdagradOptimizer,
        'adam': tf.train.AdamOptimizer
    }

    optimizer_class = optimizer_name_to_class[optimizer_name]
    assert optimizer_class

    optimizer = optimizer_class(learning_rate=learning_rate)

    args = dict(has_bos=has_bos, has_eos=has_eos, has_unk=has_unk,
                bos_idx=bos_idx, eos_idx=eos_idx, unk_idx=unk_idx,
                max_len=max_len)

    train_dataset = util.instances_to_dataset(train_is, token_to_index, label_to_index, **args)
    dev_dataset = util.instances_to_dataset(dev_is, token_to_index, label_to_index, **args)
    test_dataset = util.instances_to_dataset(test_is, token_to_index, label_to_index, **args)

    sentence1 = train_dataset['sentence1']
    sentence1_length = train_dataset['sentence1_length']
    sentence2 = train_dataset['sentence2']
    sentence2_length = train_dataset['sentence2_length']
    label = train_dataset['label']

    sentence1_ph = tf.placeholder(dtype=tf.int32, shape=[None, None], name='sentence1')
    sentence2_ph = tf.placeholder(dtype=tf.int32, shape=[None, None], name='sentence2')

    sentence1_len_ph = tf.placeholder(dtype=tf.int32, shape=[None], name='sentence1_length')
    sentence2_len_ph = tf.placeholder(dtype=tf.int32, shape=[None], name='sentence2_length')

    clipped_sentence1 = tfutil.clip_sentence(sentence1_ph, sentence1_len_ph)
    clipped_sentence2 = tfutil.clip_sentence(sentence2_ph, sentence2_len_ph)

    label_ph = tf.placeholder(dtype=tf.int32, shape=[None], name='label')

    token_set = set(token_to_index.keys())
    vocab_size = max(token_to_index.values()) + 1

    token_to_embedding = dict()
    if not restore_path:
        if glove_path:
            logger.info('Loading GloVe word embeddings from {}'.format(glove_path))
            assert os.path.isfile(glove_path)
            token_to_embedding = load_glove(glove_path, token_set)

    discriminator_scope_name = 'discriminator'
    with tf.variable_scope(discriminator_scope_name):
        if initialize_embeddings == 'normal':
            logger.info('Initializing the embeddings with 𝓝(0, 1)')
            embedding_initializer = tf.random_normal_initializer(0.0, 1.0)
        elif initialize_embeddings == 'uniform':
            logger.info('Initializing the embeddings with 𝒰(-1, 1)')
            embedding_initializer = tf.random_uniform_initializer(minval=-1.0, maxval=1.0)
        else:
            logger.info('Initializing the embeddings with Xavier initialization')
            embedding_initializer = tf.contrib.layers.xavier_initializer()

        embedding_layer = tf.get_variable('embeddings', shape=[vocab_size, embedding_size],
                                          initializer=embedding_initializer, trainable=not is_fixed_embeddings)

        sentence1_embedding = tf.nn.embedding_lookup(embedding_layer, clipped_sentence1)
        sentence2_embedding = tf.nn.embedding_lookup(embedding_layer, clipped_sentence2)

        dropout_keep_prob_ph = tf.placeholder(tf.float32, name='dropout_keep_prob')

        model_kwargs = dict(
            sequence1=sentence1_embedding, sequence1_length=sentence1_len_ph,
            sequence2=sentence2_embedding, sequence2_length=sentence2_len_ph,
            representation_size=representation_size, dropout_keep_prob=dropout_keep_prob_ph)

        if model_name in {'ff-dam', 'ff-damp', 'ff-dams'}:
            model_kwargs['init_std_dev'] = std_dev

        model_name_to_class = {
            'cbilstm': ConditionalBiLSTM,
            'ff-dam': FeedForwardDAM,
            'ff-damp': FeedForwardDAMP,
            'ff-dams': FeedForwardDAMS,
            'esim1': ESIMv1
        }

        model_class = model_name_to_class[model_name]

        assert model_class is not None
        model = model_class(**model_kwargs)

        logits = model()
        predictions = tf.argmax(logits, axis=1, name='predictions')

        losses = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label_ph)
        loss = tf.reduce_mean(losses)

        a_pooling_function = name_to_adversarial_pooling[adversarial_pooling_name]

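        # Rules 00-03 add contradiction-consistency regularisers (L2/L1 symmetry,
        # Kullback-Leibler and Jensen-Shannon divergences) directly to the discriminator loss.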
        a_losses = None
        if rule00_weight:
            a_loss, a_losses = contradiction_symmetry_l2(model_class, model_kwargs,
                                                         contradiction_idx=contradiction_idx,
                                                         pooling_function=a_pooling_function,
                                                         debug=True)
            loss += rule00_weight * a_loss

        if rule01_weight:
            a_loss, a_losses = contradiction_symmetry_l1(model_class, model_kwargs,
                                                         contradiction_idx=contradiction_idx,
                                                         pooling_function=a_pooling_function,
                                                         debug=True)
            loss += rule01_weight * a_loss

        if rule02_weight:
            a_loss, a_losses = contradiction_kullback_leibler(model_class, model_kwargs,
                                                              contradiction_idx=contradiction_idx,
                                                              pooling_function=a_pooling_function,
                                                              debug=True)
            loss += rule02_weight * a_loss

        if rule03_weight:
            a_loss, a_losses = contradiction_jensen_shannon(model_class, model_kwargs,
                                                            contradiction_idx=contradiction_idx,
                                                            pooling_function=a_pooling_function,
                                                            debug=True)
            loss += rule03_weight * a_loss

    discriminator_vars = tfutil.get_variables_in_scope(discriminator_scope_name)
    discriminator_init_op = tf.variables_initializer(discriminator_vars)

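    # When the embeddings are fixed, exclude the embedding matrix from the set of
    # variables updated by the optimizer.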
    trainable_discriminator_vars = list(discriminator_vars)
    if is_fixed_embeddings:
        trainable_discriminator_vars.remove(embedding_layer)

    discriminator_optimizer_scope_name = 'discriminator_optimizer'
    with tf.variable_scope(discriminator_optimizer_scope_name):
        if clip_value:
            gradients, v = zip(*optimizer.compute_gradients(loss, var_list=trainable_discriminator_vars))
            gradients, _ = tf.clip_by_global_norm(gradients, clip_value)
            training_step = optimizer.apply_gradients(zip(gradients, v))
        else:
            training_step = optimizer.minimize(loss, var_list=trainable_discriminator_vars)

    discriminator_optimizer_vars = tfutil.get_variables_in_scope(discriminator_optimizer_scope_name)
    discriminator_optimizer_init_op = tf.variables_initializer(discriminator_optimizer_vars)

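    # Placeholders and op used to overwrite single rows of the embedding matrix with
    # pre-trained word vectors (run once per token before training starts).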
    token_idx_ph = tf.placeholder(dtype=tf.int32, name='word_idx')
    token_embedding_ph = tf.placeholder(dtype=tf.float32, shape=[None], name='word_embedding')

    assign_token_embedding = embedding_layer[token_idx_ph, :].assign(token_embedding_ph)

    init_projection_steps = []
    learning_projection_steps = []

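    # Optionally project the embeddings onto the unit sphere: once after initialisation
    # and, if the embeddings are trainable, after every training step.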
    if is_normalize_embeddings:
        embeddings_projection = constraints.unit_sphere(embedding_layer, norm=1.0)
        init_projection_steps += [embeddings_projection]

        if not is_fixed_embeddings:
            learning_projection_steps += [embeddings_projection]

    predictions_int = tf.cast(predictions, tf.int32)
    labels_int = tf.cast(label_ph, tf.int32)

    use_adversarial_training = any([rule1_weight, rule2_weight, rule3_weight, rule4_weight,
                                    rule5_weight, rule6_weight, rule7_weight, rule8_weight])

    rule_id_to_placeholders = dict()

    if use_adversarial_training:
        adversary_scope_name = discriminator_scope_name
        with tf.variable_scope(adversary_scope_name):

            adversarial = AdversarialSets3(model_class=model_class, model_kwargs=model_kwargs, scope_name='adversary',
                                           entailment_idx=entailment_idx, contradiction_idx=contradiction_idx, neutral_idx=neutral_idx)

            adversary_loss = tf.constant(0.0, dtype=tf.float32)

            adversarial_pooling = name_to_adversarial_pooling[adversarial_pooling_name]

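            # f(rule_idx) creates a (sentence, length) placeholder pair for each sequence
            # required by the rule, embeds the sentences, stores the placeholders so they
            # can be fed later, and returns the per-instance adversarial loss for that rule.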
            def f(rule_idx):
                nb_sequences = adversarial.rule_nb_sequences(rule_idx)
                a_args, rule_placeholders = [], []

                for seq_id in range(nb_sequences):
                    a_sentence_ph = tf.placeholder(dtype=tf.int32, shape=[None, None],
                                                   name='a_rule{}_sentence{}'.format(rule_idx, seq_id))
                    a_sentence_len_ph = tf.placeholder(dtype=tf.int32, shape=[None],
                                                       name='a_rule{}_sentence{}_length'.format(rule_idx, seq_id))

                    a_clipped_sentence = tfutil.clip_sentence(a_sentence_ph, a_sentence_len_ph)
                    a_sentence_embedding = tf.nn.embedding_lookup(embedding_layer, a_clipped_sentence)

                    a_args += [a_sentence_embedding, a_sentence_len_ph]
                    rule_placeholders += [(a_sentence_ph, a_sentence_len_ph)]

                rule_id_to_placeholders[rule_idx] = rule_placeholders
                # Assumes the per-rule losses follow the rule{N}_loss naming convention.
                rule_loss = getattr(adversarial, 'rule{}_loss'.format(rule_idx))(*a_args)
                return rule_loss

            if rule1_weight:
                r_loss = f(1)
                adversary_loss += rule1_weight * adversarial_pooling(r_loss)
            if rule2_weight:
                r_loss = f(2)
                adversary_loss += rule2_weight * adversarial_pooling(r_loss)
            if rule3_weight:
                r_loss = f(3)
                adversary_loss += rule3_weight * adversarial_pooling(r_loss)
            if rule4_weight:
                r_loss = f(4)
                adversary_loss += rule4_weight * adversarial_pooling(r_loss)
            if rule5_weight:
                r_loss = f(5)
                adversary_loss += rule5_weight * adversarial_pooling(r_loss)
            if rule6_weight:
                r_loss = f(6)
                adversary_loss += rule6_weight * adversarial_pooling(r_loss)
            if rule7_weight:
                r_loss = f(7)
                adversary_loss += rule7_weight * adversarial_pooling(r_loss)
            if rule8_weight:
                r_loss = f(8)
                adversary_loss += rule8_weight * adversarial_pooling(r_loss)

            loss += adversary_loss

        logger.info('Adversarial Batch Size: {}'.format(a_batch_size))

    a_feed_dict = dict()
    a_rs = np.random.RandomState(seed)

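    # Pool premises and hypotheses from the training set into a single matrix of sentences;
    # the adversarial batches are sampled from this pool.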
    d_sentence1, d_sentence2 = train_dataset['sentence1'], train_dataset['sentence2']
    d_sentence1_len, d_sentence2_len = train_dataset['sentence1_length'], train_dataset['sentence2_length']
    d_label = train_dataset['label']

    nb_train_instances = d_label.shape[0]

    max_sentence_len = max(d_sentence1.shape[1], d_sentence2.shape[1])
    d_sentence = np.zeros(shape=(nb_train_instances * 2, max_sentence_len), dtype=np.int)
    d_sentence[0:d_sentence1.shape[0], 0:d_sentence1.shape[1]] = d_sentence1
    d_sentence[d_sentence1.shape[0]:, 0:d_sentence2.shape[1]] = d_sentence2

    d_sentence_len = np.concatenate((d_sentence1_len, d_sentence2_len), axis=0)

    nb_train_sentences = d_sentence_len.shape[0]

    saver = tf.train.Saver(discriminator_vars + discriminator_optimizer_vars, max_to_keep=1)

    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True

    with tf.Session(config=session_config) as session:
        logger.info('Total Parameters: {}'.format(tfutil.count_trainable_parameters()))
        logger.info('Total Discriminator Parameters: {}'.format(
            tfutil.count_trainable_parameters(var_list=discriminator_vars)))
        logger.info('Total Trainable Discriminator Parameters: {}'.format(
            tfutil.count_trainable_parameters(var_list=trainable_discriminator_vars)))

        if restore_path:
            saver.restore(session, restore_path)
        else:
            session.run([discriminator_init_op, discriminator_optimizer_init_op])

            # Initialising pre-trained embeddings
            logger.info('Initialising the embeddings with pre-trained vectors ..')
            for token in token_to_embedding:
                token_idx, token_embedding = token_to_index[token], token_to_embedding[token]
                assert embedding_size == len(token_embedding)
                session.run(assign_token_embedding,
                            feed_dict={
                                token_idx_ph: token_idx,
                                token_embedding_ph: token_embedding
                            })
            logger.info('Done!')

            for adversary_projection_step in init_projection_steps:
                session.run([adversary_projection_step])

        nb_instances = sentence1.shape[0]
        batches = make_batches(size=nb_instances, batch_size=batch_size)

        best_dev_acc, best_test_acc = None, None
        discriminator_batch_counter = 0

        for epoch in range(1, nb_epochs + 1):

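            # At the beginning of each epoch, sample a fresh batch of training sentences for
            # every placeholder of every active adversarial rule.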
            if use_adversarial_training:
                for rule_idx, rule_placeholders in rule_id_to_placeholders.items():

                    a_idxs = a_rs.choice(nb_train_sentences, a_batch_size)
                    for a_sentence_ph, a_sentence_len_ph in rule_placeholders:
                        # Select a random batch of sentences from the training set
                        a_sentence_batch = d_sentence[a_idxs]
                        a_sentence_len_batch = d_sentence_len[a_idxs]

                        a_feed_dict[a_sentence_ph] = a_sentence_batch
                        a_feed_dict[a_sentence_len_ph] = a_sentence_len_batch

            for d_epoch in range(1, nb_discriminator_epochs + 1):
                order = rs.permutation(nb_instances)

                sentences1, sentences2 = sentence1[order], sentence2[order]
                sizes1, sizes2 = sentence1_length[order], sentence2_length[order]
                labels = label[order]

                if is_semi_sort:
                    order = util.semi_sort(sizes1, sizes2)
                    sentences1, sentences2 = sentence1[order], sentence2[order]
                    sizes1, sizes2 = sentence1_length[order], sentence2_length[order]
                    labels = label[order]

                loss_values, epoch_loss_values = [], []
                for batch_idx, (batch_start, batch_end) in enumerate(batches):
                    discriminator_batch_counter += 1

                    batch_sentences1, batch_sentences2 = sentences1[batch_start:batch_end], sentences2[batch_start:batch_end]
                    batch_sizes1, batch_sizes2 = sizes1[batch_start:batch_end], sizes2[batch_start:batch_end]
                    batch_labels = labels[batch_start:batch_end]

                    batch_max_size1 = np.max(batch_sizes1)
                    batch_max_size2 = np.max(batch_sizes2)

                    batch_sentences1 = batch_sentences1[:, :batch_max_size1]
                    batch_sentences2 = batch_sentences2[:, :batch_max_size2]

                    batch_feed_dict = {
                        sentence1_ph: batch_sentences1, sentence1_len_ph: batch_sizes1,
                        sentence2_ph: batch_sentences2, sentence2_len_ph: batch_sizes2,
                        label_ph: batch_labels, dropout_keep_prob_ph: dropout_keep_prob
                    }

                    # Adding the adversaries
                    batch_feed_dict.update(a_feed_dict)

                    _, loss_value = session.run([training_step, loss], feed_dict=batch_feed_dict)

                    logger.debug('Epoch {0}/{1}/{2}\tLoss: {3}'.format(epoch, d_epoch, batch_idx, loss_value))

                    cur_batch_size = batch_sentences1.shape[0]
                    loss_values += [loss_value / cur_batch_size]
                    epoch_loss_values += [loss_value / cur_batch_size]

                    for adversary_projection_step in learning_projection_steps:
                        session.run([adversary_projection_step])

                    if discriminator_batch_counter % report_loss_interval == 0:
                        logger.info('Epoch {0}/{1}/{2}\tLoss Stats: {3}'.format(epoch, d_epoch, batch_idx, stats(loss_values)))
                        loss_values = []

                    if discriminator_batch_counter % report_interval == 0:
                        accuracy_args = [sentence1_ph, sentence1_len_ph, sentence2_ph, sentence2_len_ph,
                                         label_ph, dropout_keep_prob_ph, predictions_int, labels_int,
                                         contradiction_idx, entailment_idx, neutral_idx, batch_size]
                        dev_acc, _, _, _ = accuracy(session, dev_dataset, 'Dev', *accuracy_args)
                        test_acc, _, _, _ = accuracy(session, test_dataset, 'Test', *accuracy_args)

                        logger.info('Epoch {0}/{1}/{2}\tDev Acc: {3:.2f}\tTest Acc: {4:.2f}'
                                    .format(epoch, d_epoch, batch_idx, dev_acc * 100, test_acc * 100))

                        if best_dev_acc is None or dev_acc > best_dev_acc:
                            best_dev_acc, best_test_acc = dev_acc, test_acc

                            if save_path:
                                with open('{}_index_to_token.p'.format(save_path), 'wb') as f:
                                    pickle.dump(index_to_token, f)

                                saved_path = saver.save(session, save_path)
                                logger.info('Model saved in file: {}'.format(saved_path))

                        logger.info('Epoch {0}/{1}/{2}\tBest Dev Accuracy: {3:.2f}\tBest Test Accuracy: {4:.2f}'
                                    .format(epoch, d_epoch, batch_idx, best_dev_acc * 100, best_test_acc * 100))

                        for eval_path in eval_paths:
                            eval_path_acc = eutil.evaluate(session, eval_path, label_to_index, token_to_index, predictions, batch_size,
                                                           sentence1_ph, sentence2_ph, sentence1_len_ph, sentence2_len_ph, dropout_keep_prob_ph,
                                                           has_bos=has_bos, has_eos=has_eos, has_unk=has_unk, is_lower=is_lower,
                                                           bos_idx=bos_idx, eos_idx=eos_idx, unk_idx=unk_idx)
                            logger.info('Epoch {0}/{1}/{2}\tAccuracy on {3} is {4}'.format(epoch, d_epoch, batch_idx,
                                                                                           eval_path, eval_path_acc))

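                        # Log the input pairs with the highest per-instance regulariser loss,
                        # for inspection.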
                        if a_losses is not None:
                            t_feed_dict = a_feed_dict
                            if len(t_feed_dict) == 0:
                                t_feed_dict = {
                                    sentence1_ph: sentences1[:1024], sentence1_len_ph: sizes1[:1024],
                                    sentence2_ph: sentences2[:1024], sentence2_len_ph: sizes2[:1024],
                                    dropout_keep_prob_ph: 1.0
                                }
                            a_losses_value = session.run(a_losses, feed_dict=t_feed_dict)

                            a_input_idxs = np.argsort(- a_losses_value)
                            for i in a_input_idxs[:10]:
                                t_sentence1 = t_feed_dict[sentence1_ph][i]
                                t_sentence2 = t_feed_dict[sentence2_ph][i]

                                logger.info('[ {} / {} ] Sentence1: {}'.format(i, a_losses_value[i], ' '.join([index_to_token[x] for x in t_sentence1 if x not in [0, 1, 2]])))
                                logger.info('[ {} / {} ] Sentence2: {}'.format(i, a_losses_value[i], ' '.join([index_to_token[x] for x in t_sentence2 if x not in [0, 1, 2]])))

                logger.info('Epoch {0}/{1}\tEpoch Loss Stats: {2}'.format(epoch, d_epoch, stats(epoch_loss_values)))

                if hard_save_path:
                    with open('{}_index_to_token.p'.format(hard_save_path), 'wb') as f:
                        pickle.dump(index_to_token, f)

                    hard_saved_path = saver.save(session, hard_save_path)
                    logger.info('Model saved in file: {}'.format(hard_saved_path))

    logger.info('Training finished.')
def main(argv):
    logger.info('Command line: {}'.format(' '.join(arg for arg in argv)))

    def fmt(prog):
        return argparse.HelpFormatter(prog, max_help_position=100, width=200)

    argparser = argparse.ArgumentParser('Regularising RTE via Adversarial Sets Regularisation', formatter_class=fmt)

    argparser.add_argument('--data', '-d', action='store', type=str, default='data/snli/snli_1.0_train.jsonl.gz')
    argparser.add_argument('--model', '-m', action='store', type=str, default='ff-dam',
                           choices=['cbilstm', 'ff-dam', 'ff-damp', 'ff-dams', 'esim1'])

    argparser.add_argument('--embedding-size', action='store', type=int, default=300)
    argparser.add_argument('--representation-size', action='store', type=int, default=200)

    argparser.add_argument('--batch-size', action='store', type=int, default=32)

    argparser.add_argument('--seed', action='store', type=int, default=0)

    argparser.add_argument('--has-bos', action='store_true', default=False, help='Has <Beginning Of Sentence> token')
    argparser.add_argument('--has-eos', action='store_true', default=False, help='Has <End Of Sentence> token')
    argparser.add_argument('--has-unk', action='store_true', default=False, help='Has <Unknown Word> token')
    argparser.add_argument('--lower', '-l', action='store_true', default=False, help='Lowercase the corpus')

    argparser.add_argument('--restore', action='store', type=str, default=None)
    argparser.add_argument('--lm', action='store', type=str, default='models/lm/')

    argparser.add_argument('--corrupt', '-c', action='store_true', default=False,
                           help='Corrupt examples so as to maximise their inconsistency')
    argparser.add_argument('--most-violating', '-M', action='store_true', default=False,
                           help='Show most violating examples')

    argparser.add_argument('--epsilon', '-e', action='store', type=float, default=1e-4)
    argparser.add_argument('--lambda-weight', '-L', action='store', type=float, default=1.0)

    argparser.add_argument('--inconsistency', '-i', action='store', type=str, default='contradiction')

    args = argparser.parse_args(argv)

    # Command line arguments
    data_path = args.data

    model_name = args.model

    embedding_size = args.embedding_size
    representation_size = args.representation_size

    batch_size = args.batch_size

    seed = args.seed

    has_bos = args.has_bos
    has_eos = args.has_eos
    has_unk = args.has_unk
    is_lower = args.lower

    restore_path = args.restore
    lm_path = args.lm

    is_corrupt = args.corrupt
    is_most_violating = args.most_violating

    epsilon = args.epsilon
    lambda_w = args.lambda_weight

    inconsistency_name = args.inconsistency

    iloss = None
    if inconsistency_name == 'contradiction':
        iloss = contradiction_loss
    elif inconsistency_name == 'neutral':
        iloss = neutral_loss
    elif inconsistency_name == 'entailment':
        iloss = entailment_loss

    assert iloss is not None

    np.random.seed(seed)
    tf.set_random_seed(seed)

    logger.debug('Reading corpus ..')
    data_is, _, _ = util.SNLI.generate(train_path=data_path, valid_path=None, test_path=None, is_lower=is_lower)
    logger.info('Data size: {}'.format(len(data_is)))

    # Enumeration of tokens start at index=3:
    # index=0 PADDING, index=1 START_OF_SENTENCE, index=2 END_OF_SENTENCE, index=3 UNKNOWN_WORD
    bos_idx, eos_idx, unk_idx = 1, 2, 3
    entailment_idx, neutral_idx, contradiction_idx = 0, 1, 2

    global index_to_token, token_to_index
    with open('{}_index_to_token.p'.format(restore_path), 'rb') as f:
        index_to_token = pickle.load(f)

    index_to_token.update({0: '<PAD>', 1: '<BOS>', 2: '<EOS>', 3: '<UNK>'})

    token_to_index = {token: index for index, token in index_to_token.items()}

    with open('{}/config.json'.format(lm_path), 'r') as f:
        config = json.load(f)

    seq_length = 1
    lm_batch_size = batch_size
    rnn_size = config['rnn_size']
    num_layers = config['num_layers']

    label_to_index = {
        'entailment': entailment_idx,
        'neutral': neutral_idx,
        'contradiction': contradiction_idx,
    }

    max_len = None

    args = dict(
        has_bos=has_bos, has_eos=has_eos, has_unk=has_unk,
        bos_idx=bos_idx, eos_idx=eos_idx, unk_idx=unk_idx,
        max_len=max_len)

    dataset = util.instances_to_dataset(data_is, token_to_index, label_to_index, **args)

    sentence1, sentence1_length = dataset['sentence1'], dataset['sentence1_length']
    sentence2, sentence2_length = dataset['sentence2'], dataset['sentence2_length']
    label = dataset['label']

    # Placeholders mirroring the other entry points; declared global so the helper functions can reuse them.
    global sentence1_ph, sentence2_ph, sentence1_len_ph, sentence2_len_ph, dropout_keep_prob_ph
    sentence1_ph = tf.placeholder(dtype=tf.int32, shape=[None, None], name='sentence1')
    sentence2_ph = tf.placeholder(dtype=tf.int32, shape=[None, None], name='sentence2')
    sentence1_len_ph = tf.placeholder(dtype=tf.int32, shape=[None], name='sentence1_length')
    sentence2_len_ph = tf.placeholder(dtype=tf.int32, shape=[None], name='sentence2_length')
    dropout_keep_prob_ph = tf.placeholder(tf.float32, name='dropout_keep_prob')

    clipped_sentence1 = tfutil.clip_sentence(sentence1_ph, sentence1_len_ph)
    clipped_sentence2 = tfutil.clip_sentence(sentence2_ph, sentence2_len_ph)

    vocab_size = max(token_to_index.values()) + 1

    discriminator_scope_name = 'discriminator'
    with tf.variable_scope(discriminator_scope_name):
        embedding_layer = tf.get_variable('embeddings', shape=[vocab_size, embedding_size], trainable=False)
        sentence1_embedding = tf.nn.embedding_lookup(embedding_layer, clipped_sentence1)
        sentence2_embedding = tf.nn.embedding_lookup(embedding_layer, clipped_sentence2)

        model_kwargs = dict(
            sequence1=sentence1_embedding, sequence1_length=sentence1_len_ph,
            sequence2=sentence2_embedding, sequence2_length=sentence2_len_ph,
            representation_size=representation_size, dropout_keep_prob=dropout_keep_prob_ph)

        if model_name in {'ff-dam', 'ff-damp', 'ff-dams'}:
            model_kwargs['init_std_dev'] = 0.01

        model_name_to_class = {
            'cbilstm': ConditionalBiLSTM,
            'ff-dam': FeedForwardDAM,
            'ff-damp': FeedForwardDAMP,
            'ff-dams': FeedForwardDAMS,
            'esim1': ESIMv1
        }
        model_class = model_name_to_class[model_name]
        assert model_class is not None

        model = model_class(**model_kwargs)
        logits = model()

        global probabilities
        probabilities = tf.nn.softmax(logits)

        predictions = tf.argmax(logits, axis=1, name='predictions')

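    # Re-build the language model whose checkpoint is restored below; its architecture
    # (rnn_size, num_layers) is read from config.json in the --lm directory.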
    lm_scope_name = 'language_model'
    with tf.variable_scope(lm_scope_name):
        cell_fn = rnn.BasicLSTMCell
        cells = [cell_fn(rnn_size) for _ in range(num_layers)]

        global lm_cell
        lm_cell = rnn.MultiRNNCell(cells)

        global lm_input_data_ph, lm_targets_ph, lm_initial_state
        lm_input_data_ph = tf.placeholder(tf.int32, [None, seq_length], name='input_data')
        lm_targets_ph = tf.placeholder(tf.int32, [None, seq_length], name='targets')
        lm_initial_state = lm_cell.zero_state(lm_batch_size, tf.float32)

        with tf.variable_scope('rnnlm'):
            lm_W = tf.get_variable(name='W', shape=[rnn_size, vocab_size],
                                   initializer=tf.contrib.layers.xavier_initializer())

            lm_b = tf.get_variable(name='b', shape=[vocab_size],
                                   initializer=tf.zeros_initializer())

            lm_emb_lookup = tf.nn.embedding_lookup(embedding_layer, lm_input_data_ph)
            lm_emb_projection = tf.contrib.layers.fully_connected(inputs=lm_emb_lookup, num_outputs=rnn_size,
                                                                  weights_initializer=tf.contrib.layers.xavier_initializer(),
                                                                  biases_initializer=tf.zeros_initializer())

            lm_inputs = tf.split(lm_emb_projection, seq_length, 1)
            lm_inputs = [tf.squeeze(input_, [1]) for input_ in lm_inputs]

        lm_outputs, lm_last_state = legacy_seq2seq.rnn_decoder(decoder_inputs=lm_inputs, initial_state=lm_initial_state,
                                                               cell=lm_cell, loop_function=None, scope='rnnlm')

        lm_output = tf.reshape(tf.concat(lm_outputs, 1), [-1, rnn_size])

        lm_logits = tf.matmul(lm_output, lm_W) + lm_b
        lm_probabilities = tf.nn.softmax(lm_logits)

        global lm_loss, lm_cost, lm_final_state
        lm_loss = legacy_seq2seq.sequence_loss_by_example(logits=[lm_logits], targets=[tf.reshape(lm_targets_ph, [-1])],
                                                          weights=[tf.ones([lm_batch_size * seq_length])])
        lm_cost = tf.reduce_sum(lm_loss) / lm_batch_size / seq_length
        lm_final_state = lm_last_state

    discriminator_vars = tfutil.get_variables_in_scope(discriminator_scope_name)
    lm_vars = tfutil.get_variables_in_scope(lm_scope_name)

    predictions_int = tf.cast(predictions, tf.int32)

    saver = tf.train.Saver(discriminator_vars, max_to_keep=1)
    lm_saver = tf.train.Saver(lm_vars, max_to_keep=1)

    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True

    global session
    with tf.Session(config=session_config) as session:
        logger.info('Total Parameters: {}'.format(tfutil.count_trainable_parameters()))

        saver.restore(session, restore_path)

        lm_ckpt = tf.train.get_checkpoint_state(lm_path)
        lm_saver.restore(session, lm_ckpt.model_checkpoint_path)

        nb_instances = sentence1.shape[0]
        batches = make_batches(size=nb_instances, batch_size=batch_size)

        order = np.arange(nb_instances)

        sentences1 = sentence1[order]
        sentences2 = sentence2[order]

        sizes1 = sentence1_length[order]
        sizes2 = sentence2_length[order]

        labels = label[order]

        logger.info('Number of examples: {}'.format(labels.shape[0]))

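        # For every training pair, collect the model prediction and the per-example
        # contradiction / entailment / neutral inconsistency losses.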
        predictions_int_value = []
        c_losses, e_losses, n_losses = [], [], []

        for batch_idx, (batch_start, batch_end) in enumerate(batches):
            batch_sentences1 = sentences1[batch_start:batch_end]
            batch_sentences2 = sentences2[batch_start:batch_end]

            batch_sizes1 = sizes1[batch_start:batch_end]
            batch_sizes2 = sizes2[batch_start:batch_end]

            batch_feed_dict = {
                sentence1_ph: batch_sentences1,
                sentence1_len_ph: batch_sizes1,

                sentence2_ph: batch_sentences2,
                sentence2_len_ph: batch_sizes2,

                dropout_keep_prob_ph: 1.0
            }

            batch_predictions_int = session.run(predictions_int, feed_dict=batch_feed_dict)
            predictions_int_value += batch_predictions_int.tolist()

            batch_c_loss = contradiction_loss(batch_sentences1, batch_sizes1, batch_sentences2, batch_sizes2)
            c_losses += batch_c_loss.tolist()

            batch_e_loss = entailment_loss(batch_sentences1, batch_sizes1, batch_sentences2, batch_sizes2)
            e_losses += batch_e_loss.tolist()

            batch_n_loss = neutral_loss(batch_sentences1, batch_sizes1, batch_sentences2, batch_sizes2)
            n_losses += batch_n_loss.tolist()

            if is_corrupt:
                search(sentences1=batch_sentences1, sizes1=batch_sizes1,
                       sentences2=batch_sentences2, sizes2=batch_sizes2,
                       batch_size=batch_size, epsilon=epsilon, lambda_w=lambda_w,
                       inconsistency_loss=iloss)

        train_accuracy_value = np.mean(labels == np.array(predictions_int_value))
        logger.info('Accuracy: {0:.4f}'.format(train_accuracy_value))

        if is_most_violating:
            c_ranking = np.argsort(np.array(c_losses))[::-1]
            assert c_ranking.shape[0] == len(data_is)

            for i in range(min(1024, c_ranking.shape[0])):
                idx = c_ranking[i]
                print('[C/{}/{}] {} ({})'.format(i, idx, data_is[idx]['sentence1'], c_losses[idx]))
                print('[C/{}/{}] {} ({})'.format(i, idx, data_is[idx]['sentence2'], c_losses[idx]))

            e_ranking = np.argsort(np.array(e_losses))[::-1]
            assert e_ranking.shape[0] == len(data_is)

            for i in range(min(1024, e_ranking.shape[0])):
                idx = e_ranking[i]
                print('[E/{}/{}] {} ({})'.format(i, idx, data_is[idx]['sentence1'], e_losses[idx]))
                print('[E/{}/{}] {} ({})'.format(i, idx, data_is[idx]['sentence2'], e_losses[idx]))

            n_ranking = np.argsort(np.array(n_losses))[::-1]
            assert n_ranking.shape[0] == len(data_is)

            for i in range(min(1024, n_ranking.shape[0])):
                idx = n_ranking[i]
                print('[N/{}/{}] {} ({})'.format(i, idx, data_is[idx]['sentence1'], n_losses[idx]))
                print('[N/{}/{}] {} ({})'.format(i, idx, data_is[idx]['sentence2'], n_losses[idx]))
def search(sentences1, sizes1, sentences2, sizes2,
           lambda_w=0.1, inconsistency_loss=contradiction_loss,
           epsilon=1e-4, batch_size=32,
           nb_corruptions=1024, nb_words=256):

    loss_value, iloss_value, logperp_value = loss(sentences1=sentences1, sizes1=sizes1,
                                                  sentences2=sentences2, sizes2=sizes2,
                                                  lambda_w=lambda_w, inconsistency_loss=inconsistency_loss)

    # Find examples that have a nearly-zero inconsistency loss, and only work on making those more "adversarial"
    low_iloss_idxs = np.where(iloss_value < 1e-6)[0]

    for low_iloss_idx in low_iloss_idxs.tolist():
        sentence1, size1 = sentences1[low_iloss_idx, :], sizes1[low_iloss_idx]
        sentence2, size2 = sentences2[low_iloss_idx, :], sizes2[low_iloss_idx]

        sample_loss_value, sample_iloss_value, sample_logperp_value = \
            loss_value[low_iloss_idx], iloss_value[low_iloss_idx], logperp_value[low_iloss_idx]

        sentence1_str = ' '.join([index_to_token[tidx] for tidx in sentence1 if tidx != 0])
        sentence2_str = ' '.join([index_to_token[tidx] for tidx in sentence2 if tidx != 0])

        print('SENTENCE 1 (inconsistency loss: {} / log-perplexity: {}): {}'
              .format(sample_iloss_value, sample_logperp_value, sentence1_str))
        print('SENTENCE 2 (inconsistency loss: {} / log-perplexity: {}): {}'
              .format(sample_iloss_value, sample_logperp_value, sentence2_str))

        # Generate mutations that do not increase the perplexity too much, and maximise the inconsistency loss
        corruptions1, corruption_sizes1, corruptions2, corruption_sizes2 = \
            corrupt(sentence1=sentence1, size1=size1, sentence2=sentence2, size2=size2,
                    nb_corruptions=nb_corruptions, nb_words=nb_words)

        # Compute all relevant metrics for the corruptions
        nb_corruptions = corruptions1.shape[0]
        batches = make_batches(size=nb_corruptions, batch_size=batch_size)

        corruption_loss_values, corruption_iloss_values, corruption_logperp_values = [], [], []
        for batch_start, batch_end in batches:
            batch_corruptions1 = corruptions1[batch_start:batch_end, :]
            batch_corruption_sizes1 = corruption_sizes1[batch_start:batch_end]

            batch_corruptions2 = corruptions2[batch_start:batch_end, :]
            batch_corruption_sizes2 = corruption_sizes2[batch_start:batch_end]

            batch_loss_values, batch_iloss_values, batch_logperp_values = \
                loss(sentences1=batch_corruptions1, sizes1=batch_corruption_sizes1,
                     sentences2=batch_corruptions2, sizes2=batch_corruption_sizes2,
                     lambda_w=lambda_w, inconsistency_loss=inconsistency_loss)

            corruption_loss_values += batch_loss_values.tolist()
            corruption_iloss_values += batch_iloss_values.tolist()
            corruption_logperp_values += batch_logperp_values.tolist()

        corruption_loss_values = np.array(corruption_loss_values)
        corruption_iloss_values = np.array(corruption_iloss_values)
        corruption_logperp_values = np.array(corruption_logperp_values)

        # Sort the corruptions by their inconsistency loss:
        corruptions_order = np.argsort(corruption_iloss_values)[::-1]

        # Select corruptions that did not increase the log-perplexity too much
        low_perplexity_mask = corruption_logperp_values <= logperp_value[low_iloss_idx] + epsilon

        counter = 0
        for idx in corruptions_order.tolist():
            if low_perplexity_mask[idx]:
                if counter < 10:
                    corruption_str = ' '.join([index_to_token[tidx] for tidx in corruptions2[idx] if tidx != 0])
                    msg = '[{}] CORRUPTION 2 (inconsistency loss: {} / log-perplexity: {}): {}'\
                        .format(counter, corruption_iloss_values[idx], corruption_logperp_values[idx], corruption_str)

                    print(msg)

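                    # Run the classifier on the corrupted pair in both directions and print
                    # the predicted probabilities.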
                    _sentence1 = np.array([sentence1])
                    _size1 = np.array([size1])

                    _sentence2 = np.array([corruptions2[idx]])
                    _size2 = np.array([size2])

                    probabilities_1 = inference(_sentence1, _size1, _sentence2, _size2)
                    probabilities_2 = inference(_sentence2, _size2, _sentence1, _size1)

                    msg = 'A -> B: {}\tB -> A: {}'.format(str(probabilities_1), str(probabilities_2))

                    print(msg)
                counter += 1

    return
Example #7
0
def evaluate(session, eval_path, label_to_index, token_to_index, predictions_op, batch_size,
             sentence1_ph, sentence2_ph, sentence1_len_ph, sentence2_len_ph, dropout_keep_prob_ph,
             has_bos=False, has_eos=False, has_unk=False, is_lower=False,
             bos_idx=1, eos_idx=2, unk_idx=3):
    sentence1_all = []
    sentence2_all = []
    gold_label_all = []

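    # Read the gzipped SNLI-style JSONL file and turn each labelled sentence pair into
    # sequences of token ids, honouring the optional BOS/EOS/UNK conventions.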
    with gzip.open(eval_path, 'rb') as f:
        for line in f:
            decoded_line = line.decode('utf-8')

            if is_lower:
                decoded_line = decoded_line.lower()

            obj = json.loads(decoded_line)

            gold_label = obj['gold_label']

            if gold_label in ['contradiction', 'entailment', 'neutral']:
                gold_label_all += [label_to_index[gold_label]]

                sentence1_parse = obj['sentence1_parse']
                sentence2_parse = obj['sentence2_parse']

                sentence1_tree = nltk.Tree.fromstring(sentence1_parse)
                sentence2_tree = nltk.Tree.fromstring(sentence2_parse)

                sentence1_tokens = sentence1_tree.leaves()
                sentence2_tokens = sentence2_tree.leaves()

                sentence1_ids = []
                sentence2_ids = []

                if has_bos:
                    sentence1_ids += [bos_idx]
                    sentence2_ids += [bos_idx]

                for token in sentence1_tokens:
                    if token in token_to_index:
                        sentence1_ids += [token_to_index[token]]
                    elif has_unk:
                        sentence1_ids += [unk_idx]

                for token in sentence2_tokens:
                    if token in token_to_index:
                        sentence2_ids += [token_to_index[token]]
                    elif has_unk:
                        sentence2_ids += [unk_idx]

                if has_eos:
                    sentence1_ids += [eos_idx]
                    sentence2_ids += [eos_idx]

                sentence1_all += [sentence1_ids]
                sentence2_all += [sentence2_ids]

    sentence1_all_len = [len(s) for s in sentence1_all]
    sentence2_all_len = [len(s) for s in sentence2_all]

    np_sentence1 = util.pad_sequences(sequences=sentence1_all)
    np_sentence2 = util.pad_sequences(sequences=sentence2_all)

    np_sentence1_len = np.array(sentence1_all_len)
    np_sentence2_len = np.array(sentence2_all_len)

    gold_label = np.array(gold_label_all)

    nb_instances = gold_label.shape[0]
    batches = make_batches(size=nb_instances, batch_size=batch_size)

    predictions = []

    for batch_idx, (batch_start, batch_end) in enumerate(batches):
        feed_dict = {
            sentence1_ph: np_sentence1[batch_start:batch_end],
            sentence2_ph: np_sentence2[batch_start:batch_end],

            sentence1_len_ph: np_sentence1_len[batch_start:batch_end],
            sentence2_len_ph: np_sentence2_len[batch_start:batch_end],

            dropout_keep_prob_ph: 1.0
        }

        _predictions = session.run(predictions_op, feed_dict=feed_dict)
        predictions += _predictions.tolist()

    matches = np.array(predictions) == gold_label
    return np.mean(matches)
def main(argv):
    def formatter(prog):
        return argparse.HelpFormatter(prog, max_help_position=100, width=200)

    argparser = argparse.ArgumentParser('NLI Service',
                                        formatter_class=formatter)

    argparser.add_argument(
        '--model',
        '-m',
        action='store',
        type=str,
        default='cbilstm',
        choices=['cbilstm', 'ff-dam', 'ff-damp', 'ff-dams', 'esim1'])

    argparser.add_argument('--embedding-size',
                           '-e',
                           action='store',
                           type=int,
                           default=300)
    argparser.add_argument('--representation-size',
                           '-r',
                           action='store',
                           type=int,
                           default=200)

    argparser.add_argument('--has-bos',
                           action='store_true',
                           default=False,
                           help='Has <Beginning Of Sentence> token')
    argparser.add_argument('--has-eos',
                           action='store_true',
                           default=False,
                           help='Has <End Of Sentence> token')
    argparser.add_argument('--has-unk',
                           action='store_true',
                           default=False,
                           help='Has <Unknown Word> token')
    argparser.add_argument('--lower',
                           '-l',
                           action='store_true',
                           default=False,
                           help='Lowercase the corpus')

    argparser.add_argument('--restore',
                           '-R',
                           action='store',
                           type=str,
                           default=None,
                           required=True)

    argparser.add_argument('--eval', action='store', default=None, type=str)
    argparser.add_argument('--batch-size',
                           '-b',
                           action='store',
                           default=32,
                           type=int)

    args = argparser.parse_args(argv)

    model_name = args.model

    embedding_size = args.embedding_size
    representation_size = args.representation_size

    has_bos = args.has_bos
    has_eos = args.has_eos
    has_unk = args.has_unk
    is_lower = args.lower

    restore_path = args.restore

    eval_path = args.eval
    batch_size = args.batch_size

    with open('{}_index_to_token.p'.format(restore_path), 'rb') as f:
        index_to_token = pickle.load(f)

    token_to_index = {token: index for index, token in index_to_token.items()}

    # Enumeration of tokens start at index=3:
    # index=0 PADDING, index=1 START_OF_SENTENCE, index=2 END_OF_SENTENCE, index=3 UNKNOWN_WORD
    bos_idx, eos_idx, unk_idx = 1, 2, 3

    entailment_idx, neutral_idx, contradiction_idx = 0, 1, 2
    label_to_index = {
        'entailment': entailment_idx,
        'neutral': neutral_idx,
        'contradiction': contradiction_idx,
    }
    vocab_size = max(token_to_index.values()) + 1

    sentence1_ph = tf.placeholder(dtype=tf.int32,
                                  shape=[None, None],
                                  name='sentence1')
    sentence2_ph = tf.placeholder(dtype=tf.int32,
                                  shape=[None, None],
                                  name='sentence2')

    sentence1_len_ph = tf.placeholder(dtype=tf.int32,
                                      shape=[None],
                                      name='sentence1_length')
    sentence2_len_ph = tf.placeholder(dtype=tf.int32,
                                      shape=[None],
                                      name='sentence2_length')

    dropout_keep_prob_ph = tf.placeholder(tf.float32, name='dropout_keep_prob')

    clipped_sentence1 = tfutil.clip_sentence(sentence1_ph, sentence1_len_ph)
    clipped_sentence2 = tfutil.clip_sentence(sentence2_ph, sentence2_len_ph)

    discriminator_scope_name = 'discriminator'
    with tf.variable_scope(discriminator_scope_name):

        embedding_layer = tf.get_variable('embeddings',
                                          shape=[vocab_size, embedding_size])

        sentence1_embedding = tf.nn.embedding_lookup(embedding_layer,
                                                     clipped_sentence1)
        sentence2_embedding = tf.nn.embedding_lookup(embedding_layer,
                                                     clipped_sentence2)

        model_kwargs = dict(sequence1=sentence1_embedding,
                            sequence1_length=sentence1_len_ph,
                            sequence2=sentence2_embedding,
                            sequence2_length=sentence2_len_ph,
                            representation_size=representation_size,
                            dropout_keep_prob=dropout_keep_prob_ph)

        model_name_to_class = {
            'cbilstm': ConditionalBiLSTM,
            'ff-dam': FeedForwardDAM,
            'ff-damp': FeedForwardDAMP,
            'ff-dams': FeedForwardDAMS,
            'esim1': ESIMv1
        }

        model_class = model_name_to_class[model_name]

        assert model_class is not None
        model = model_class(**model_kwargs)

        logits = model()
        predictions_op = tf.argmax(logits, axis=1, name='predictions')

    discriminator_vars = tfutil.get_variables_in_scope(
        discriminator_scope_name)

    sentence1_all = []
    sentence2_all = []
    gold_label_all = []

    with gzip.open(eval_path, 'rb') as f:
        for line in f:
            decoded_line = line.decode('utf-8')

            if is_lower:
                decoded_line = decoded_line.lower()

            obj = json.loads(decoded_line)

            gold_label = obj['gold_label']

            if gold_label in ['contradiction', 'entailment', 'neutral']:
                gold_label_all += [label_to_index[gold_label]]

                sentence1_parse = obj['sentence1_parse']
                sentence2_parse = obj['sentence2_parse']

                sentence1_tree = nltk.Tree.fromstring(sentence1_parse)
                sentence2_tree = nltk.Tree.fromstring(sentence2_parse)

                sentence1_tokens = sentence1_tree.leaves()
                sentence2_tokens = sentence2_tree.leaves()

                sentence1_ids = []
                sentence2_ids = []

                if has_bos:
                    sentence1_ids += [bos_idx]
                    sentence2_ids += [bos_idx]

                for token in sentence1_tokens:
                    if token in token_to_index:
                        sentence1_ids += [token_to_index[token]]
                    elif has_unk:
                        sentence1_ids += [unk_idx]

                for token in sentence2_tokens:
                    if token in token_to_index:
                        sentence2_ids += [token_to_index[token]]
                    elif has_unk:
                        sentence2_ids += [unk_idx]

                if has_eos:
                    sentence1_ids += [eos_idx]
                    sentence2_ids += [eos_idx]

                sentence1_all += [sentence1_ids]
                sentence2_all += [sentence2_ids]

    sentence1_all_len = [len(s) for s in sentence1_all]
    sentence2_all_len = [len(s) for s in sentence2_all]

    np_sentence1 = util.pad_sequences(sequences=sentence1_all)
    np_sentence2 = util.pad_sequences(sequences=sentence2_all)

    np_sentence1_len = np.array(sentence1_all_len)
    np_sentence2_len = np.array(sentence2_all_len)

    gold_label = np.array(gold_label_all)

    with tf.Session() as session:
        saver = tf.train.Saver(discriminator_vars, max_to_keep=1)
        saver.restore(session, restore_path)

        from inferbeddings.models.training.util import make_batches
        nb_instances = gold_label.shape[0]
        batches = make_batches(size=nb_instances, batch_size=batch_size)

        predictions = []

        for batch_idx, (batch_start, batch_end) in enumerate(batches):
            feed_dict = {
                sentence1_ph: np_sentence1[batch_start:batch_end],
                sentence2_ph: np_sentence2[batch_start:batch_end],
                sentence1_len_ph: np_sentence1_len[batch_start:batch_end],
                sentence2_len_ph: np_sentence2_len[batch_start:batch_end],
                dropout_keep_prob_ph: 1.0
            }

            _predictions = session.run(predictions_op, feed_dict=feed_dict)
            predictions += _predictions.tolist()

        matches = np.array(predictions) == gold_label
        print(np.mean(matches))
def main(argv):
    logger.info('Command line: {}'.format(' '.join(arg for arg in argv)))

    def fmt(prog):
        return argparse.HelpFormatter(prog, max_help_position=100, width=200)

    argparser = argparse.ArgumentParser(
        'Regularising RTE via Adversarial Sets Regularisation',
        formatter_class=fmt)

    argparser.add_argument('--data',
                           '-d',
                           action='store',
                           type=str,
                           default='data/snli/snli_1.0_train.jsonl.gz')
    argparser.add_argument(
        '--model',
        '-m',
        action='store',
        type=str,
        default='ff-dam',
        choices=['cbilstm', 'ff-dam', 'ff-damp', 'ff-dams', 'esim1'])

    argparser.add_argument('--embedding-size',
                           action='store',
                           type=int,
                           default=300)
    argparser.add_argument('--representation-size',
                           action='store',
                           type=int,
                           default=200)

    argparser.add_argument('--batch-size',
                           action='store',
                           type=int,
                           default=32)

    argparser.add_argument('--seed', action='store', type=int, default=0)

    argparser.add_argument('--has-bos',
                           action='store_true',
                           default=False,
                           help='Has <Beginning Of Sentence> token')
    argparser.add_argument('--has-eos',
                           action='store_true',
                           default=False,
                           help='Has <End Of Sentence> token')
    argparser.add_argument('--has-unk',
                           action='store_true',
                           default=False,
                           help='Has <Unknown Word> token')
    argparser.add_argument('--lower',
                           '-l',
                           action='store_true',
                           default=False,
                           help='Lowercase the corpus')

    argparser.add_argument('--restore', action='store', type=str, default=None)

    argparser.add_argument('--check-transitivity',
                           '-x',
                           action='store_true',
                           default=False)

    args = argparser.parse_args(argv)

    # Command line arguments
    data_path = args.data

    model_name = args.model

    embedding_size = args.embedding_size
    representation_size = args.representation_size

    batch_size = args.batch_size

    seed = args.seed

    has_bos = args.has_bos
    has_eos = args.has_eos
    has_unk = args.has_unk
    is_lower = args.lower

    restore_path = args.restore

    is_check_transitivity = args.check_transitivity

    np.random.seed(seed)
    rs = np.random.RandomState(seed)
    tf.set_random_seed(seed)

    logger.debug('Reading corpus ..')
    data_is, _, _ = util.SNLI.generate(train_path=data_path,
                                       valid_path=None,
                                       test_path=None,
                                       is_lower=is_lower)

    logger.info('Data size: {}'.format(len(data_is)))

    # Enumeration of tokens start at index=3:
    # index=0 PADDING, index=1 START_OF_SENTENCE, index=2 END_OF_SENTENCE, index=3 UNKNOWN_WORD
    bos_idx, eos_idx, unk_idx = 1, 2, 3

    with open('{}_index_to_token.p'.format(restore_path), 'rb') as f:
        index_to_token = pickle.load(f)

    token_to_index = {token: index for index, token in index_to_token.items()}

    entailment_idx, neutral_idx, contradiction_idx = 0, 1, 2
    label_to_index = {
        'entailment': entailment_idx,
        'neutral': neutral_idx,
        'contradiction': contradiction_idx,
    }

    max_len = None

    args = dict(has_bos=has_bos,
                has_eos=has_eos,
                has_unk=has_unk,
                bos_idx=bos_idx,
                eos_idx=eos_idx,
                unk_idx=unk_idx,
                max_len=max_len)

    dataset = util.instances_to_dataset(data_is, token_to_index,
                                        label_to_index, **args)

    sentence1 = dataset['sentence1']
    sentence1_length = dataset['sentence1_length']
    sentence2 = dataset['sentence2']
    sentence2_length = dataset['sentence2_length']
    label = dataset['label']

    sentence1_ph = tf.placeholder(dtype=tf.int32,
                                  shape=[None, None],
                                  name='sentence1')
    sentence2_ph = tf.placeholder(dtype=tf.int32,
                                  shape=[None, None],
                                  name='sentence2')

    sentence1_len_ph = tf.placeholder(dtype=tf.int32,
                                      shape=[None],
                                      name='sentence1_length')
    sentence2_len_ph = tf.placeholder(dtype=tf.int32,
                                      shape=[None],
                                      name='sentence2_length')

    clipped_sentence1 = tfutil.clip_sentence(sentence1_ph, sentence1_len_ph)
    clipped_sentence2 = tfutil.clip_sentence(sentence2_ph, sentence2_len_ph)

    token_set = set(token_to_index.keys())
    vocab_size = max(token_to_index.values()) + 1

    discriminator_scope_name = 'discriminator'
    with tf.variable_scope(discriminator_scope_name):
        embedding_layer = tf.get_variable('embeddings',
                                          shape=[vocab_size, embedding_size],
                                          trainable=False)

        sentence1_embedding = tf.nn.embedding_lookup(embedding_layer,
                                                     clipped_sentence1)
        sentence2_embedding = tf.nn.embedding_lookup(embedding_layer,
                                                     clipped_sentence2)

        dropout_keep_prob_ph = tf.placeholder(tf.float32,
                                              name='dropout_keep_prob')

        model_kwargs = dict(sequence1=sentence1_embedding,
                            sequence1_length=sentence1_len_ph,
                            sequence2=sentence2_embedding,
                            sequence2_length=sentence2_len_ph,
                            representation_size=representation_size,
                            dropout_keep_prob=dropout_keep_prob_ph)

        if model_name in {'ff-dam', 'ff-damp', 'ff-dams'}:
            model_kwargs['init_std_dev'] = 0.01

        model_name_to_class = {
            'cbilstm': ConditionalBiLSTM,
            'ff-dam': FeedForwardDAM,
            'ff-damp': FeedForwardDAMP,
            'ff-dams': FeedForwardDAMS,
            'esim1': ESIMv1
        }

        model_class = model_name_to_class[model_name]

        assert model_class is not None
        model = model_class(**model_kwargs)

        logits = model()
        probabilities = tf.nn.softmax(logits)

        predictions = tf.argmax(logits, axis=1, name='predictions')

    discriminator_vars = tfutil.get_variables_in_scope(
        discriminator_scope_name)

    trainable_discriminator_vars = list(discriminator_vars)

    predictions_int = tf.cast(predictions, tf.int32)

    saver = tf.train.Saver(discriminator_vars, max_to_keep=1)

    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True

    with tf.Session(config=session_config) as session:
        logger.info('Total Parameters: {}'.format(
            tfutil.count_trainable_parameters()))

        logger.info('Total Discriminator Parameters: {}'.format(
            tfutil.count_trainable_parameters(var_list=discriminator_vars)))

        logger.info('Total Trainable Discriminator Parameters: {}'.format(
            tfutil.count_trainable_parameters(
                var_list=trainable_discriminator_vars)))

        saver.restore(session, restore_path)

        nb_instances = sentence1.shape[0]
        batches = make_batches(size=nb_instances, batch_size=batch_size)

        order = np.arange(nb_instances)

        sentences1 = sentence1[order]
        sentences2 = sentence2[order]

        sizes1 = sentence1_length[order]
        sizes2 = sentence2_length[order]

        labels = label[order]

        a_predictions_int_value = []
        b_predictions_int_value = []

        a_probabilities_value = []
        b_probabilities_value = []

        for batch_idx, (batch_start,
                        batch_end) in tqdm(list(enumerate(batches))):
            batch_sentences1 = sentences1[batch_start:batch_end]
            batch_sentences2 = sentences2[batch_start:batch_end]
            batch_sizes1 = sizes1[batch_start:batch_end]
            batch_sizes2 = sizes2[batch_start:batch_end]

            batch_a_feed_dict = {
                sentence1_ph: batch_sentences1,
                sentence1_len_ph: batch_sizes1,
                sentence2_ph: batch_sentences2,
                sentence2_len_ph: batch_sizes2,
                dropout_keep_prob_ph: 1.0
            }

            batch_a_predictions_int_value, batch_a_probabilities_value = session.run(
                [predictions_int, probabilities], feed_dict=batch_a_feed_dict)

            a_predictions_int_value += batch_a_predictions_int_value.tolist()
            for i in range(batch_a_probabilities_value.shape[0]):
                a_probabilities_value += [{
                    'neutral':
                    batch_a_probabilities_value[i, neutral_idx],
                    'contradiction':
                    batch_a_probabilities_value[i, contradiction_idx],
                    'entailment':
                    batch_a_probabilities_value[i, entailment_idx]
                }]

            batch_b_feed_dict = {
                sentence1_ph: batch_sentences2,
                sentence1_len_ph: batch_sizes2,
                sentence2_ph: batch_sentences1,
                sentence2_len_ph: batch_sizes1,
                dropout_keep_prob_ph: 1.0
            }

            batch_b_predictions_int_value, batch_b_probabilities_value = session.run(
                [predictions_int, probabilities], feed_dict=batch_b_feed_dict)
            b_predictions_int_value += batch_b_predictions_int_value.tolist()
            for i in range(batch_b_probabilities_value.shape[0]):
                b_probabilities_value += [{
                    'neutral':
                    batch_b_probabilities_value[i, neutral_idx],
                    'contradiction':
                    batch_b_probabilities_value[i, contradiction_idx],
                    'entailment':
                    batch_b_probabilities_value[i, entailment_idx]
                }]

        for i, instance in enumerate(data_is):
            instance.update({
                'a': a_probabilities_value[i],
                'b': b_probabilities_value[i],
            })

        logger.info('Number of examples: {}'.format(labels.shape[0]))

        train_accuracy_value = np.mean(
            labels == np.array(a_predictions_int_value))
        logger.info('Accuracy: {0:.4f}'.format(train_accuracy_value))

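        # Check how often the predictions violate the symmetry of contradiction:
        # if S1 contradicts S2, the model should also predict that S2 contradicts S1.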
        s1s2_con = (np.array(a_predictions_int_value) == contradiction_idx)
        s2s1_con = (np.array(b_predictions_int_value) == contradiction_idx)

        assert s1s2_con.shape == s2s1_con.shape

        s1s2_ent = (np.array(a_predictions_int_value) == entailment_idx)
        s2s1_ent = (np.array(b_predictions_int_value) == entailment_idx)

        s1s2_neu = (np.array(a_predictions_int_value) == neutral_idx)
        s2s1_neu = (np.array(b_predictions_int_value) == neutral_idx)

        a = np.logical_xor(s1s2_con, s2s1_con)
        logger.info('(S1 contradicts S2) XOR (S2 contradicts S1): {0}'.format(
            a.sum()))

        b = s1s2_con
        logger.info('(S1 contradicts S2): {0}'.format(b.sum()))
        c = np.logical_and(s1s2_con, np.logical_not(s2s1_con))
        logger.info(
            '(S1 contradicts S2) AND NOT(S2 contradicts S1): {0} ({1:.4f})'.
            format(c.sum(),
                   c.sum() / b.sum()))

        with open('c.p', 'wb') as f:
            tmp = [data_is[i] for i in np.where(c)[0].tolist()]
            pickle.dump(tmp, f)

        d = s1s2_ent
        logger.info('(S1 entailment S2): {0}'.format(d.sum()))
        e = np.logical_and(s1s2_ent, s2s1_con)
        logger.info(
            '(S1 entailment S2) AND (S2 contradicts S1): {0} ({1:.4f})'.format(
                e.sum(),
                e.sum() / d.sum()))

        with open('e.p', 'wb') as f:
            tmp = [data_is[i] for i in np.where(e)[0].tolist()]
            pickle.dump(tmp, f)

        f = s1s2_neu
        logger.info('(S1 neutral S2): {0}'.format(f.sum()))
        g = np.logical_and(s1s2_neu, s2s1_con)
        logger.info(
            '(S1 neutral S2) AND (S2 contradicts S1): {0} ({1:.4f})'.format(
                g.sum(),
                g.sum() / f.sum()))

        with open('g.p', 'wb') as f:
            tmp = [data_is[i] for i in np.where(g)[0].tolist()]
            pickle.dump(tmp, f)

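        # Optionally check the transitivity rule
        # entails(S1, S2) AND entails(S2, S3) => entails(S1, S3).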
        if is_check_transitivity:
            # Find S1, S2 such that entails(S1, S2)
            logger.debug('s1s2_ent: {}'.format(type(s1s2_ent)))

            c_predictions_int_value = []
            c_probabilities_value = []

            d_predictions_int_value = []
            d_probabilities_value = []

            # Find candidate S3 sentences
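            # Here order is the identity permutation, so each S3 is the hypothesis of
            # the instance at the same position; sentence2 and sentence2_length are
            # assumed to be the full hypothesis arrays defined earlier in the script.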
            order = np.arange(nb_instances)

            sentences3 = sentence2[order]
            sizes3 = sentence2_length[order]

            for batch_idx, (batch_start,
                            batch_end) in tqdm(list(enumerate(batches))):
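                # Slice the current batch and score the (S2, S3) direction.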
                batch_sentences2 = sentences2[batch_start:batch_end]
                batch_sentences3 = sentences3[batch_start:batch_end]

                batch_sizes2 = sizes2[batch_start:batch_end]
                batch_sizes3 = sizes3[batch_start:batch_end]

                batch_c_feed_dict = {
                    sentence1_ph: batch_sentences2,
                    sentence1_len_ph: batch_sizes2,
                    sentence2_ph: batch_sentences3,
                    sentence2_len_ph: batch_sizes3,
                    dropout_keep_prob_ph: 1.0
                }

                batch_c_predictions_int_value, batch_c_probabilities_value = session.run(
                    [predictions_int, probabilities],
                    feed_dict=batch_c_feed_dict)

                c_predictions_int_value += batch_c_predictions_int_value.tolist()
                for i in range(batch_c_probabilities_value.shape[0]):
                    c_probabilities_value += [{
                        'neutral': batch_c_probabilities_value[i, neutral_idx],
                        'contradiction': batch_c_probabilities_value[i, contradiction_idx],
                        'entailment': batch_c_probabilities_value[i, entailment_idx]
                    }]

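                # Slice and score the (S1, S3) direction.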
                batch_sentences1 = sentences1[batch_start:batch_end]
                batch_sentences3 = sentences3[batch_start:batch_end]

                batch_sizes1 = sizes1[batch_start:batch_end]
                batch_sizes3 = sizes3[batch_start:batch_end]

                batch_d_feed_dict = {
                    sentence1_ph: batch_sentences1,
                    sentence1_len_ph: batch_sizes1,
                    sentence2_ph: batch_sentences3,
                    sentence2_len_ph: batch_sizes3,
                    dropout_keep_prob_ph: 1.0
                }

                batch_d_predictions_int_value, batch_d_probabilities_value = session.run(
                    [predictions_int, probabilities],
                    feed_dict=batch_d_feed_dict)

                d_predictions_int_value += batch_d_predictions_int_value.tolist()
                for i in range(batch_d_probabilities_value.shape[0]):
                    d_probabilities_value += [{
                        'neutral': batch_d_probabilities_value[i, neutral_idx],
                        'contradiction': batch_d_probabilities_value[i, contradiction_idx],
                        'entailment': batch_d_probabilities_value[i, entailment_idx]
                    }]

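            # body = entails(S1, S2) AND entails(S2, S3); head = entails(S1, S3).
            # Transitivity violations are the cases where the body holds but the head does not.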
            s2s3_ent = (np.array(c_predictions_int_value) == entailment_idx)
            s1s3_ent = (np.array(d_predictions_int_value) == entailment_idx)

            body = np.logical_and(s1s2_ent, s2s3_ent)
            body_not_head = np.logical_and(body, np.logical_not(s1s3_ent))

            logger.info('(S1 entails S2) AND (S2 entails S3): {0}'.format(body.sum()))
            logger.info('body AND NOT(head): {0} ({1:.4f})'.format(
                body_not_head.sum(), body_not_head.sum() / body.sum()))

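            # Serialise the (S1, S2, S3) triples that violate transitivity for inspection.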
            with open('h.p', 'wb') as f:
                tmp = []
                for idx in np.where(body_not_head)[0].tolist():
                    s1 = data_is[idx]['sentence1']
                    s2 = data_is[idx]['sentence2']
                    s3 = data_is[order[idx]]['sentence2']
                    tmp += [{'s1': s1, 's2': s2, 's3': s3}]
                pickle.dump(tmp, f)