Example No. 1
    def create_batches(self):
        order = self.random_state.permutation(self.nb_samples)
        tensor_shuf = self.tensor[order, :]

        _batch_lst = make_batches(self.nb_samples, self.batch_size)
        self.batches = []

        for batch_start, batch_end in _batch_lst:
            batch_size = batch_end - batch_start
            batch = tensor_shuf[batch_start:batch_end, :]

            assert batch.shape[0] == batch_size

            x = np.zeros(shape=(batch_size, self.seq_length))
            y = np.zeros(shape=(batch_size, self.seq_length))

            for i in range(batch_size):
                start_idx = self.random_state.randint(low=0,
                                                      high=self.max_len - 1)
                end_idx = min(start_idx + self.seq_length, self.max_len)

                x[i, 0:(end_idx - start_idx)] = batch[i, start_idx:end_idx]

                start_idx += 1
                end_idx = min(start_idx + self.seq_length, self.max_len)

                y[i, 0:(end_idx - start_idx)] = batch[i, start_idx:end_idx]

            # One {'x', 'y'} pair per batch, appended after all rows are filled
            d = {'x': x, 'y': y}
            self.batches += [d]

        self.num_batches = len(self.batches)
        return
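Both this example and the next rely on a make_batches helper that is not shown here. A minimal sketch, assuming it only needs to return consecutive (start, end) index pairs that cover size items in chunks of batch_size:

def make_batches(size, batch_size):
    # Return (start, end) index pairs covering `size` items; the last
    # batch may be smaller than `batch_size`.
    nb_batches = (size + batch_size - 1) // batch_size
    return [(i * batch_size, min(size, (i + 1) * batch_size))
            for i in range(nb_batches)]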
Example No. 2
def evaluate(session, tensors, placeholders, metric, batch_size=None):
    feed_dict = {placeholders[key]: tensors[key] for key in placeholders}

    if batch_size is None:
        res = session.run(metric, feed_dict=feed_dict)
    else:
        res_lst = []
        tensor_names = [name for name in tensors.keys() if name != 'dropout']
        tensor_name = tensor_names[0]
        nb_instances = tensors[tensor_name].shape[0]
        batches = util.make_batches(size=nb_instances, batch_size=batch_size)
        for batch_start, batch_end in batches:

            def get_batch(tensor):
                return tensor[batch_start:batch_end] if not isinstance(
                    tensor, float) else tensor

            batch_feed_dict = {
                ph: get_batch(tensor)
                for ph, tensor in feed_dict.items()
            }

            batch_res = session.run(metric, feed_dict=batch_feed_dict)
            res_lst += batch_res.tolist()
        res = np.array(res_lst)
    return res
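A hypothetical usage sketch (the names mirror the dictionaries built in Example No. 3, and accuracy is assumed to be a per-instance metric tensor): evaluating in chunks of 256 instances avoids feeding the whole dataset to the session at once.

# Hypothetical usage, assuming `session`, `valid_tensors`, `placeholders`
# and a per-instance `accuracy` tensor as in Example No. 3.
accuracies = evaluate(session, valid_tensors, placeholders,
                      metric=accuracy, batch_size=256)
print('Validation accuracy: {:.4f}'.format(float(np.mean(accuracies))))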
Example No. 3
def main(argv):
    logger.info('Command line: {}'.format(' '.join(arg for arg in argv)))

    def fmt(prog):
        return argparse.HelpFormatter(prog, max_help_position=100, width=200)

    argparser = argparse.ArgumentParser(
        'Regularising RTE/NLI models via Adversarial Training',
        formatter_class=fmt)

    argparser.add_argument('--train',
                           '-t',
                           action='store',
                           type=str,
                           default='data/snli/snli_1.0_train.jsonl.gz')
    argparser.add_argument('--valid',
                           '-v',
                           action='store',
                           type=str,
                           default='data/snli/snli_1.0_dev.jsonl.gz')
    argparser.add_argument('--test',
                           '-T',
                           action='store',
                           type=str,
                           default='data/snli/snli_1.0_test.jsonl.gz')

    argparser.add_argument('--test2', action='store', type=str, default=None)

    argparser.add_argument('--model',
                           '-m',
                           action='store',
                           type=str,
                           default='ff-dam',
                           choices=['cbilstm', 'ff-dam', 'esim'])
    argparser.add_argument('--optimizer',
                           '-o',
                           action='store',
                           type=str,
                           default='adagrad',
                           choices=['adagrad', 'adam'])

    argparser.add_argument('--embedding-size',
                           action='store',
                           type=int,
                           default=300)
    argparser.add_argument('--representation-size',
                           '-r',
                           action='store',
                           type=int,
                           default=200)
    argparser.add_argument('--batch-size',
                           '-b',
                           action='store',
                           type=int,
                           default=32)
    argparser.add_argument('--epochs',
                           '-e',
                           action='store',
                           type=int,
                           default=1)

    argparser.add_argument('--dropout-keep-prob',
                           '-d',
                           action='store',
                           type=float,
                           default=1.0)
    argparser.add_argument('--learning-rate',
                           '--lr',
                           action='store',
                           type=float,
                           default=0.1)
    argparser.add_argument('--clip',
                           '-c',
                           action='store',
                           type=float,
                           default=None)
    argparser.add_argument('--seed', action='store', type=int, default=0)
    argparser.add_argument('--glove', action='store', type=str, default=None)

    argparser.add_argument('--restore', action='store', type=str, default=None)
    argparser.add_argument('--save', action='store', type=str, default=None)

    argparser.add_argument('--check-interval',
                           '--check-every',
                           '-C',
                           action='store',
                           type=int,
                           default=None)

    # The following parameters are devoted to regularization
    for rule_index in range(1, 5 + 1):
        argparser.add_argument('--regularizer{}-weight'.format(rule_index),
                               '-{}'.format(rule_index),
                               action='store',
                               type=float,
                               default=None)

    argparser.add_argument('--regularizer-inputs',
                           '--ri',
                           '-R',
                           nargs='+',
                           type=str)
    argparser.add_argument('--regularizer-nb-samples',
                           '--rns',
                           '-S',
                           type=int,
                           default=0)
    argparser.add_argument('--regularizer-nb-flips',
                           '--rnf',
                           '-F',
                           type=int,
                           default=0)

    args = argparser.parse_args(argv)

    # Command line arguments
    train_path = args.train
    valid_path = args.valid
    test_path = args.test

    test2_path = args.test2

    model_name = args.model
    optimizer_name = args.optimizer

    embedding_size = args.embedding_size
    representation_size = args.representation_size
    batch_size = args.batch_size
    nb_epochs = args.epochs

    dropout_keep_prob = args.dropout_keep_prob
    learning_rate = args.learning_rate
    clip_value = args.clip
    seed = args.seed
    glove_path = args.glove

    restore_path = args.restore
    save_path = args.save

    check_interval = args.check_interval

    # The following parameters are devoted to regularization
    r1_weight = args.regularizer1_weight
    r2_weight = args.regularizer2_weight
    r3_weight = args.regularizer3_weight
    r4_weight = args.regularizer4_weight
    r5_weight = args.regularizer5_weight

    r_input_paths = args.regularizer_inputs or []
    nb_r_samples = args.regularizer_nb_samples
    nb_r_flips = args.regularizer_nb_flips

    r_weights = [r1_weight, r2_weight, r3_weight, r4_weight, r5_weight]
    is_regularized = not all(r_weight is None for r_weight in r_weights)

    np.random.seed(seed)
    rs = np.random.RandomState(seed)
    tf.set_random_seed(seed)

    logger.info('Reading corpus ..')

    snli = SNLI()
    train_is = snli.parse(path=train_path)
    valid_is = snli.parse(path=valid_path)
    test_is = snli.parse(path=test_path)

    test2_is = snli.parse(path=test2_path) if test2_path else None

    # Discrete/symbolic inputs used by the regularizers
    regularizer_is = [
        i for path in r_input_paths for i in snli.parse(path=path)
    ]

    # Keep only the token fields used by the regularizers
    regularizer_is = [{
        k: v
        for k, v in instance.items()
        if k in {'sentence1_parse_tokens', 'sentence2_parse_tokens'}
    } for instance in regularizer_is]

    all_is = train_is + valid_is + test_is

    if test2_is is not None:
        all_is += test2_is

    # Reserved special indices (token enumeration starts at index=4):
    #   index=0 PADDING
    #   index=1 START_OF_SENTENCE
    #   index=2 END_OF_SENTENCE
    #   index=3 UNKNOWN_WORD
    bos_idx, eos_idx, unk_idx = 1, 2, 3

    # Words start at index 4
    start_idx = 1 + 3

    if restore_path is None:
        token_lst = [
            tkn for inst in all_is for tkn in inst['sentence1_parse_tokens'] +
            inst['sentence2_parse_tokens']
        ]

        from collections import Counter
        token_cnt = Counter(token_lst)

        # Sort tokens by descending frequency, breaking ties lexicographically
        sorted_vocabulary = sorted(token_cnt.keys(),
                                   key=lambda t: (-token_cnt[t], t))

        index_to_token = {
            idx: tkn
            for idx, tkn in enumerate(sorted_vocabulary, start=start_idx)
        }
    else:
        vocab_path = '{}_index_to_token.p'.format(restore_path)

        logger.info('Restoring vocabulary from {} ..'.format(vocab_path))
        with open(vocab_path, 'rb') as f:
            index_to_token = pickle.load(f)

    token_to_index = {token: index for index, token in index_to_token.items()}

    entailment_idx, neutral_idx, contradiction_idx = 0, 1, 2

    label_to_index = {
        'entailment': entailment_idx,
        'neutral': neutral_idx,
        'contradiction': contradiction_idx,
    }

    name_to_optimizer = {
        'adagrad': tf.train.AdagradOptimizer,
        'adam': tf.train.AdamOptimizer
    }

    name_to_model = {
        'cbilstm': ConditionalBiLSTM,
        'ff-dam': FeedForwardDAM,
        'ff-damp': FeedForwardDAMP,
        'ff-dams': FeedForwardDAMS,
        'esim': ESIM
    }

    optimizer_class = name_to_optimizer[optimizer_name]
    optimizer = optimizer_class(learning_rate=learning_rate)

    model_class = name_to_model[model_name]

    token_kwargs = dict(bos_idx=bos_idx, eos_idx=eos_idx, unk_idx=unk_idx)

    train_tensors = util.to_tensors(train_is, token_to_index, label_to_index,
                                    **token_kwargs)
    valid_tensors = util.to_tensors(valid_is, token_to_index, label_to_index,
                                    **token_kwargs)
    test_tensors = util.to_tensors(test_is, token_to_index, label_to_index,
                                   **token_kwargs)

    test2_tensors = None
    if test2_is is not None:
        test2_tensors = util.to_tensors(test2_is, token_to_index,
                                        label_to_index, **token_kwargs)

    train_sequence1 = train_tensors['sequence1']
    train_sequence1_len = train_tensors['sequence1_length']

    train_sequence2 = train_tensors['sequence2']
    train_sequence2_len = train_tensors['sequence2_length']

    train_label = train_tensors['label']

    sequence1_ph = tf.placeholder(dtype=tf.int32,
                                  shape=[None, None],
                                  name='sequence1')
    sequence1_len_ph = tf.placeholder(dtype=tf.int32,
                                      shape=[None],
                                      name='sequence1_length')

    sequence2_ph = tf.placeholder(dtype=tf.int32,
                                  shape=[None, None],
                                  name='sequence2')
    sequence2_len_ph = tf.placeholder(dtype=tf.int32,
                                      shape=[None],
                                      name='sequence2_length')

    label_ph = tf.placeholder(dtype=tf.int32, shape=[None], name='label')
    dropout_keep_prob_ph = tf.placeholder(tf.float32, name='dropout_keep_prob')

    placeholders = {
        'sequence1': sequence1_ph,
        'sequence1_length': sequence1_len_ph,
        'sequence2': sequence2_ph,
        'sequence2_length': sequence2_len_ph,
        'label': label_ph,
        'dropout': dropout_keep_prob_ph
    }

    # Disable Dropout at evaluation time
    valid_tensors['dropout'] = 1.0
    test_tensors['dropout'] = 1.0

    if test2_tensors is not None:
        test2_tensors['dropout'] = 1.0

    clipped_sequence1 = tfutil.clip_sentence(sequence1_ph, sequence1_len_ph)
    clipped_sequence2 = tfutil.clip_sentence(sequence2_ph, sequence2_len_ph)

    vocab_size = max(token_to_index.values()) + 1

    logger.info('Initializing the Model')

    discriminator_scope_name = 'discriminator'
    with tf.variable_scope(discriminator_scope_name):

        embedding_matrix_value = E.embedding_matrix(
            nb_tokens=vocab_size,
            embedding_size=embedding_size,
            token_to_index=token_to_index,
            glove_path=glove_path,
            unit_norm=True,
            rs=rs,
            dtype=np.float32)

        embedding_layer = tf.get_variable(
            'embeddings',
            initializer=tf.constant(embedding_matrix_value),
            trainable=False)

        sequence1_embedding = tf.nn.embedding_lookup(embedding_layer,
                                                     clipped_sequence1)
        sequence2_embedding = tf.nn.embedding_lookup(embedding_layer,
                                                     clipped_sequence2)

        model_kwargs = {
            'sequence1': sequence1_embedding,
            'sequence1_length': sequence1_len_ph,
            'sequence2': sequence2_embedding,
            'sequence2_length': sequence2_len_ph,
            'representation_size': representation_size,
            'dropout_keep_prob': dropout_keep_prob_ph
        }

        model = model_class(**model_kwargs)

        logits = model()

        predictions = tf.argmax(logits, axis=1, name='predictions')

        losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=label_ph)
        loss = tf.reduce_mean(losses)

    if is_regularized:
        logger.info('Initializing the Regularizers')

        regularizer_placeholders = R.get_placeholders('regularizer')

        r_sequence1_ph = regularizer_placeholders['sequence1']
        r_sequence1_len_ph = regularizer_placeholders['sequence1_length']

        r_sequence2_ph = regularizer_placeholders['sequence2']
        r_sequence2_len_ph = regularizer_placeholders['sequence2_length']

        r_clipped_sequence1 = tfutil.clip_sentence(r_sequence1_ph,
                                                   r_sequence1_len_ph)
        r_clipped_sequence2 = tfutil.clip_sentence(r_sequence2_ph,
                                                   r_sequence2_len_ph)

        r_sequence1_embedding = tf.nn.embedding_lookup(embedding_layer,
                                                       r_clipped_sequence1)
        r_sequence2_embedding = tf.nn.embedding_lookup(embedding_layer,
                                                       r_clipped_sequence2)

        r_model_kwargs = model_kwargs.copy()
        r_model_kwargs.update({
            'sequence1': r_sequence1_embedding,
            'sequence1_length': r_sequence1_len_ph,
            'sequence2': r_sequence2_embedding,
            'sequence2_length': r_sequence2_len_ph
        })

        r_kwargs = {
            'model_class': model_class,
            'model_kwargs': r_model_kwargs,
            'debug': True
        }

        with tf.variable_scope(discriminator_scope_name):
            if r1_weight:
                r_loss, _ = R.contradiction_acl(is_bi=True, **r_kwargs)
                loss += r1_weight * r_loss
            if r2_weight:
                r_loss, _ = R.entailment_acl(is_bi=True, **r_kwargs)
                loss += r2_weight * r_loss
            if r3_weight:
                r_loss, _ = R.neutral_acl(is_bi=True, **r_kwargs)
                loss += r3_weight * r_loss
            if r4_weight:
                r_loss, _ = R.entailment_reflexive_acl(**r_kwargs)
                loss += r4_weight * r_loss
            if r5_weight:
                r_loss, _ = R.entailment_neutral_acl(is_bi=True, **r_kwargs)
                loss += r5_weight * r_loss

    discriminator_vars = tfutil.get_variables_in_scope(
        discriminator_scope_name)
    discriminator_init_op = tf.variables_initializer(discriminator_vars)

    trainable_discriminator_vars = list(discriminator_vars)
    trainable_discriminator_vars.remove(embedding_layer)

    discriminator_optimizer_scope_name = 'discriminator_optimizer'
    with tf.variable_scope(discriminator_optimizer_scope_name):
        gradients, v = zip(*optimizer.compute_gradients(
            loss, var_list=trainable_discriminator_vars))
        if clip_value:
            gradients, _ = tf.clip_by_global_norm(gradients, clip_value)
        training_step = optimizer.apply_gradients(zip(gradients, v))

    discriminator_optimizer_vars = tfutil.get_variables_in_scope(
        discriminator_optimizer_scope_name)
    discriminator_optimizer_init_op = tf.variables_initializer(
        discriminator_optimizer_vars)

    predictions_int = tf.cast(predictions, tf.int32)
    labels_int = tf.cast(label_ph, tf.int32)
    accuracy = tf.cast(tf.equal(x=predictions_int, y=labels_int), tf.float32)

    saver = tf.train.Saver(discriminator_vars + discriminator_optimizer_vars,
                           max_to_keep=1)

    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True
    # session_config.log_device_placement = True

    nb_r_instances = len(regularizer_is)

    r_sampler = WithoutReplacementSampler(
        nb_instances=nb_r_instances) if is_regularized else None
    r_generator = InstanceGenerator(token_to_index=token_to_index)

    with tf.Session(config=session_config) as session:
        logger.info('Total Parameters: {}'.format(
            tfutil.count_trainable_parameters()))
        logger.info('Total Discriminator Parameters: {}'.format(
            tfutil.count_trainable_parameters(var_list=discriminator_vars)))
        logger.info('Total Trainable Discriminator Parameters: {}'.format(
            tfutil.count_trainable_parameters(
                var_list=trainable_discriminator_vars)))

        if restore_path is not None:
            saver.restore(session, restore_path)
        else:
            session.run(
                [discriminator_init_op, discriminator_optimizer_init_op])

        nb_instances = train_sequence1.shape[0]
        batches = util.make_batches(size=nb_instances, batch_size=batch_size)

        loss_values = []
        best_valid_accuracy = None

        iteration_index = 0
        for epoch in range(1, nb_epochs + 1):
            order = rs.permutation(nb_instances)
            shuf_sequence1 = train_sequence1[order]
            shuf_sequence2 = train_sequence2[order]
            shuf_sequence1_len = train_sequence1_len[order]
            shuf_sequence2_len = train_sequence2_len[order]
            shuf_label = train_label[order]

            # Semi-sort by sequence lengths to reduce padding within batches
            order = util.semi_sort(shuf_sequence1_len, shuf_sequence2_len)
            shuf_sequence1 = shuf_sequence1[order]
            shuf_sequence2 = shuf_sequence2[order]
            shuf_sequence1_len = shuf_sequence1_len[order]
            shuf_sequence2_len = shuf_sequence2_len[order]
            shuf_label = shuf_label[order]

            for batch_idx, (batch_start, batch_end) in enumerate(batches,
                                                                 start=1):
                iteration_index += 1

                batch_sequence1 = shuf_sequence1[batch_start:batch_end]
                batch_sequence2 = shuf_sequence2[batch_start:batch_end]
                batch_sequence1_len = shuf_sequence1_len[batch_start:batch_end]
                batch_sequence2_len = shuf_sequence2_len[batch_start:batch_end]
                batch_label = shuf_label[batch_start:batch_end]

                batch_max_size1 = np.max(batch_sequence1_len)
                batch_max_size2 = np.max(batch_sequence2_len)

                batch_sequence1 = batch_sequence1[:, :batch_max_size1]
                batch_sequence2 = batch_sequence2[:, :batch_max_size2]

                current_batch_size = batch_sequence1.shape[0]

                batch_feed_dict = {
                    sequence1_ph: batch_sequence1,
                    sequence1_len_ph: batch_sequence1_len,
                    sequence2_ph: batch_sequence2,
                    sequence2_len_ph: batch_sequence2_len,
                    label_ph: batch_label,
                    dropout_keep_prob_ph: dropout_keep_prob
                }

                if is_regularized:
                    r_instances = [
                        regularizer_is[index]
                        for index in r_sampler.sample(nb_r_samples)
                    ]

                    c_instances = []
                    for r_instance in r_instances:
                        r_sentence1 = r_instance['sentence1_parse_tokens']
                        r_sentence2 = r_instance['sentence2_parse_tokens']

                        f_sentence1_lst, f_sentence2_lst = r_generator.flip(
                            r_sentence1, r_sentence2, nb_r_flips)

                        for f_sentence1, f_sentence2 in zip(
                                f_sentence1_lst, f_sentence2_lst):
                            c_instance = {
                                'sentence1_parse_tokens': f_sentence1,
                                'sentence2_parse_tokens': f_sentence2
                            }
                            c_instances += [c_instance]

                    r_instances += c_instances
                    r_tensors = util.to_tensors(r_instances, token_to_index,
                                                label_to_index, **token_kwargs)

                    assert len(r_instances) == r_tensors['sequence1'].shape[0]
                    # logging.info('Regularising on {} samples ..'.format(len(r_instances)))

                    batch_feed_dict.update({
                        r_sequence1_ph:
                        r_tensors['sequence1'],
                        r_sequence1_len_ph:
                        r_tensors['sequence1_length'],
                        r_sequence2_ph:
                        r_tensors['sequence2'],
                        r_sequence2_len_ph:
                        r_tensors['sequence2_length'],
                    })

                _, loss_value = session.run([training_step, loss],
                                            feed_dict=batch_feed_dict)
                loss_values += [loss_value / current_batch_size]

                if len(loss_values) >= REPORT_LOSS_INTERVAL:
                    logger.info("Epoch {0}, Batch {1}\tLoss: {2}".format(
                        epoch, batch_idx, util.stats(loss_values)))
                    loss_values = []

                # Every check_interval iterations, check whether validation accuracy improves
                if check_interval is not None and iteration_index % check_interval == 0:
                    accuracies_valid = evaluation.evaluate(session,
                                                           valid_tensors,
                                                           placeholders,
                                                           accuracy,
                                                           batch_size=256)
                    accuracies_test = evaluation.evaluate(session,
                                                          test_tensors,
                                                          placeholders,
                                                          accuracy,
                                                          batch_size=256)

                    accuracies_test2 = None
                    if test2_tensors is not None:
                        accuracies_test2 = evaluation.evaluate(session,
                                                               test2_tensors,
                                                               placeholders,
                                                               accuracy,
                                                               batch_size=256)

                    logger.info(
                        "Epoch {0}\tBatch {1}\tValidation Accuracy: {2}, Test Accuracy: {3}"
                        .format(epoch, batch_idx, util.stats(accuracies_valid),
                                util.stats(accuracies_test)))

                    if accuracies_test2 is not None:
                        logger.info(
                            "Epoch {0}\tBatch {1}\tValidation Accuracy: {2}, Test2 Accuracy: {3}"
                            .format(epoch, batch_idx,
                                    util.stats(accuracies_valid),
                                    util.stats(accuracies_test2)))

                    if best_valid_accuracy is None or best_valid_accuracy < np.mean(
                            accuracies_valid):
                        best_valid_accuracy = np.mean(accuracies_valid)
                        logger.info(
                            "Epoch {0}\tBatch {1}\tBest Validation Accuracy: {2}, Test Accuracy: {3}"
                            .format(epoch, batch_idx,
                                    util.stats(accuracies_valid),
                                    util.stats(accuracies_test)))

                        if accuracies_test2 is not None:
                            logger.info(
                                "Epoch {0}\tBatch {1}\tBest Validation Accuracy: {2}, Test2 Accuracy: {3}"
                                .format(epoch, batch_idx,
                                        util.stats(accuracies_valid),
                                        util.stats(accuracies_test2)))

                        save_model(save_path, saver, session, index_to_token)

            # End of epoch statistics
            accuracies_valid = evaluation.evaluate(session,
                                                   valid_tensors,
                                                   placeholders,
                                                   accuracy,
                                                   batch_size=256)
            accuracies_test = evaluation.evaluate(session,
                                                  test_tensors,
                                                  placeholders,
                                                  accuracy,
                                                  batch_size=256)

            accuracies_test2 = None
            if test2_tensors is not None:
                accuracies_test2 = evaluation.evaluate(session,
                                                       test2_tensors,
                                                       placeholders,
                                                       accuracy,
                                                       batch_size=256)

            logger.info(
                "Epoch {0}\tValidation Accuracy: {1}, Test Accuracy: {2}".
                format(epoch, util.stats(accuracies_valid),
                       util.stats(accuracies_test)))

            if accuracies_test2 is not None:
                logger.info(
                    "Epoch {0}\tValidation Accuracy: {1}, Test2 Accuracy: {2}".
                    format(epoch, util.stats(accuracies_valid),
                           util.stats(accuracies_test2)))

            if best_valid_accuracy is None or best_valid_accuracy < np.mean(
                    accuracies_valid):
                best_valid_accuracy = np.mean(accuracies_valid)
                logger.info(
                    "Epoch {0}\tBest Validation Accuracy: {1}, Test Accuracy: {2}"
                    .format(epoch, util.stats(accuracies_valid),
                            util.stats(accuracies_test)))

                if accuracies_test2 is not None:
                    logger.info(
                        "Epoch {0}\tBest Validation Accuracy: {1}, Test2 Accuracy: {2}"
                        .format(epoch, util.stats(accuracies_valid),
                                util.stats(accuracies_test2)))

                save_model(save_path, saver, session, index_to_token)

    logger.info('Training finished.')
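The save_model helper called above is not shown. A minimal sketch consistent with how these examples restore the checkpoint (saver.restore(session, restore_path)) and the pickled vocabulary ('{}_index_to_token.p'.format(restore_path)); the original helper may do more:

import pickle


def save_model(save_path, saver, session, index_to_token):
    # Sketch only: persist the TensorFlow variables and pickle the
    # vocabulary next to the checkpoint so both can be restored later.
    if save_path is None:
        return
    saver.save(session, save_path)
    with open('{}_index_to_token.p'.format(save_path), 'wb') as f:
        pickle.dump(index_to_token, f)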
Example No. 4
def main(argv):
    logger.info('Command line: {}'.format(' '.join(arg for arg in argv)))

    def fmt(prog):
        return argparse.HelpFormatter(prog, max_help_position=100, width=200)

    argparser = argparse.ArgumentParser(
        'Generating adversarial samples for NLI models', formatter_class=fmt)

    argparser.add_argument('--data',
                           '-d',
                           action='store',
                           type=str,
                           default='data/snli/snli_1.0_train.jsonl.gz')

    argparser.add_argument('--model',
                           '-m',
                           action='store',
                           type=str,
                           default='ff-dam',
                           choices=['cbilstm', 'ff-dam', 'esim1'])

    argparser.add_argument('--embedding-size',
                           action='store',
                           type=int,
                           default=300)
    argparser.add_argument('--representation-size',
                           action='store',
                           type=int,
                           default=200)

    argparser.add_argument('--batch-size',
                           action='store',
                           type=int,
                           default=32)

    argparser.add_argument('--seed', action='store', type=int, default=0)

    argparser.add_argument('--restore',
                           action='store',
                           type=str,
                           default='saved/snli/dam/2/dam')

    # The following parameters are devoted to regularization
    for rule_index in range(1, 5 + 1):
        argparser.add_argument('--regularizer{}-weight'.format(rule_index),
                               '-{}'.format(rule_index),
                               action='store',
                               type=float,
                               default=None)

    # Parameters for adversarial training
    argparser.add_argument('--lm',
                           action='store',
                           type=str,
                           default='saved/snli/lm/1/',
                           help='Language Model')

    # XXX: default to None (disable) - 0.01
    argparser.add_argument('--epsilon',
                           '--eps',
                           action='store',
                           type=float,
                           default=None)
    argparser.add_argument('--nb-corruptions',
                           '--nc',
                           action='store',
                           type=int,
                           default=32)
    argparser.add_argument('--nb-examples-per-batch',
                           '--nepb',
                           action='store',
                           type=int,
                           default=4)

    # XXX: default to -1 (disable) - 4
    argparser.add_argument('--top-k', action='store', type=int, default=-1)

    argparser.add_argument('--flip', '-f', action='store_true', default=False)
    argparser.add_argument('--combine',
                           '-c',
                           action='store_true',
                           default=False)
    argparser.add_argument('--remove',
                           '-r',
                           action='store_true',
                           default=False)
    argparser.add_argument('--scramble',
                           '-s',
                           action='store',
                           type=int,
                           default=None)

    argparser.add_argument('--json', action='store', type=str, default=None)

    args = argparser.parse_args(argv)

    # Command line arguments
    data_path = args.data

    model_name = args.model

    embedding_size = args.embedding_size
    representation_size = args.representation_size

    batch_size = args.batch_size

    seed = args.seed

    restore_path = args.restore

    # The following parameters are devoted to regularization
    r1_weight = args.regularizer1_weight
    r2_weight = args.regularizer2_weight
    r3_weight = args.regularizer3_weight
    r4_weight = args.regularizer4_weight
    r5_weight = args.regularizer5_weight

    lm_path = args.lm
    epsilon = args.epsilon
    nb_corruptions = args.nb_corruptions
    nb_examples_per_batch = args.nb_examples_per_batch
    top_k = args.top_k
    is_flip = args.flip
    is_combine = args.combine
    is_remove = args.remove
    scramble = args.scramble

    json_path = args.json

    np.random.seed(seed)
    tf.set_random_seed(seed)

    logger.debug('Reading corpus ..')

    snli = SNLI()
    data_is = snli.parse(path=data_path)

    logger.info('Data size: {}'.format(len(data_is)))

    # Reserved special indices (token enumeration starts at index=4):
    #   index=0 PADDING
    #   index=1 START_OF_SENTENCE
    #   index=2 END_OF_SENTENCE
    #   index=3 UNKNOWN_WORD
    bos_idx, eos_idx, unk_idx = 1, 2, 3

    # Words start at index 4
    start_idx = 1 + 3

    assert restore_path is not None
    vocab_path = '{}_index_to_token.p'.format(restore_path)
    logger.info('Restoring vocabulary from {} ..'.format(vocab_path))

    with open(vocab_path, 'rb') as f:
        index_to_token = pickle.load(f)

    token_to_index = {token: index for index, token in index_to_token.items()}

    entailment_idx, neutral_idx, contradiction_idx = 0, 1, 2
    label_to_index = {
        'entailment': entailment_idx,
        'neutral': neutral_idx,
        'contradiction': contradiction_idx,
    }
    index_to_label = {index: label for label, index in label_to_index.items()}

    token_kwargs = dict(bos_idx=bos_idx, eos_idx=eos_idx, unk_idx=unk_idx)

    data_tensors = util.to_tensors(data_is, token_to_index, label_to_index,
                                   **token_kwargs)

    sequence1 = data_tensors['sequence1']
    sequence1_len = data_tensors['sequence1_length']

    sequence2 = data_tensors['sequence2']
    sequence2_len = data_tensors['sequence2_length']

    label = data_tensors['label']

    sequence1_ph = tf.placeholder(dtype=tf.int32,
                                  shape=[None, None],
                                  name='sequence1')
    sequence1_len_ph = tf.placeholder(dtype=tf.int32,
                                      shape=[None],
                                      name='sequence1_length')

    sequence2_ph = tf.placeholder(dtype=tf.int32,
                                  shape=[None, None],
                                  name='sequence2')
    sequence2_len_ph = tf.placeholder(dtype=tf.int32,
                                      shape=[None],
                                      name='sequence2_length')

    dropout_keep_prob_ph = tf.placeholder(tf.float32, name='dropout_keep_prob')

    data_tensors['dropout'] = 1.0

    clipped_sequence1 = tfutil.clip_sentence(sequence1_ph, sequence1_len_ph)
    clipped_sequence2 = tfutil.clip_sentence(sequence2_ph, sequence2_len_ph)

    nb_instances = sequence1.shape[0]

    vocab_size = max(token_to_index.values()) + 1

    discriminator_scope_name = 'discriminator'
    with tf.variable_scope(discriminator_scope_name):

        embedding_layer = tf.get_variable('embeddings',
                                          shape=[vocab_size, embedding_size],
                                          initializer=None)

        sequence1_embedding = tf.nn.embedding_lookup(embedding_layer,
                                                     clipped_sequence1)
        sequence2_embedding = tf.nn.embedding_lookup(embedding_layer,
                                                     clipped_sequence2)

        model_kwargs = {
            'sequence1': sequence1_embedding,
            'sequence1_length': sequence1_len_ph,
            'sequence2': sequence2_embedding,
            'sequence2_length': sequence2_len_ph,
            'representation_size': representation_size,
            'dropout_keep_prob': dropout_keep_prob_ph
        }

        model_name_to_class = {
            'cbilstm': ConditionalBiLSTM,
            'ff-dam': FeedForwardDAM,
            'esim1': ESIM
        }

        model_class = model_name_to_class[model_name]

        assert model_class is not None
        model = model_class(**model_kwargs)
        logits = model()
        probabilities = tf.nn.softmax(logits)

        a_pooling_function = tf.reduce_max

        a_model_kwargs = copy.copy(model_kwargs)

        a_sentence1_ph = tf.placeholder(dtype=tf.int32,
                                        shape=[None, None],
                                        name='a_sentence1')
        a_sentence2_ph = tf.placeholder(dtype=tf.int32,
                                        shape=[None, None],
                                        name='a_sentence2')

        a_sentence1_len_ph = tf.placeholder(dtype=tf.int32,
                                            shape=[None],
                                            name='a_sentence1_length')
        a_sentence2_len_ph = tf.placeholder(dtype=tf.int32,
                                            shape=[None],
                                            name='a_sentence2_length')

        a_clipped_sentence1 = tfutil.clip_sentence(a_sentence1_ph,
                                                   a_sentence1_len_ph)
        a_clipped_sentence2 = tfutil.clip_sentence(a_sentence2_ph,
                                                   a_sentence2_len_ph)

        a_sentence1_embedding = tf.nn.embedding_lookup(embedding_layer,
                                                       a_clipped_sentence1)
        a_sentence2_embedding = tf.nn.embedding_lookup(embedding_layer,
                                                       a_clipped_sentence2)

        a_model_kwargs.update({
            'sequence1': a_sentence1_embedding,
            'sequence1_length': a_sentence1_len_ph,
            'sequence2': a_sentence2_embedding,
            'sequence2_length': a_sentence2_len_ph
        })

        a_kwargs = dict(model_class=model_class,
                        model_kwargs=a_model_kwargs,
                        entailment_idx=entailment_idx,
                        contradiction_idx=contradiction_idx,
                        neutral_idx=neutral_idx,
                        pooling_function=a_pooling_function,
                        debug=True)

        a_function_weight_bi_tuple_lst = []

        loss = tf.constant(0.0)

        if r1_weight:
            r_loss, _ = R.contradiction_acl(is_bi=True, **a_kwargs)
            a_function_weight_bi_tuple_lst += [(R.contradiction_acl, r1_weight,
                                                True)]
            loss += r1_weight * r_loss
        if r2_weight:
            r_loss, _ = R.entailment_acl(is_bi=True, **a_kwargs)
            a_function_weight_bi_tuple_lst += [(R.entailment_acl, r2_weight,
                                                True)]
            loss += r2_weight * r_loss
        if r3_weight:
            r_loss, _ = R.neutral_acl(is_bi=True, **a_kwargs)
            a_function_weight_bi_tuple_lst += [(R.neutral_acl, r3_weight, True)
                                               ]
            loss += r3_weight * r_loss
        if r4_weight:
            r_loss, _ = R.entailment_reflexive_acl(**a_kwargs)
            a_function_weight_bi_tuple_lst += [(R.entailment_reflexive_acl,
                                                r4_weight, False)]
            loss += r4_weight * r_loss
        if r5_weight:
            r_loss, _ = R.entailment_neutral_acl(is_bi=True, **a_kwargs)
            a_function_weight_bi_tuple_lst += [(R.entailment_neutral_acl,
                                                r5_weight, True)]
            loss += r5_weight * r_loss

    discriminator_vars = tfutil.get_variables_in_scope(
        discriminator_scope_name)

    trainable_discriminator_vars = list(discriminator_vars)
    trainable_discriminator_vars.remove(embedding_layer)

    saver = tf.train.Saver(discriminator_vars, max_to_keep=1)

    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True

    instance_generator = InstanceGenerator(token_to_index=token_to_index)

    instance_scorer = None
    if top_k >= 0:  # top_k defaults to -1, which disables instance scoring
        with tf.variable_scope(discriminator_scope_name):
            instance_scorer = InstanceScorer(
                embedding_layer=embedding_layer,
                token_to_index=token_to_index,
                model_class=model_class,
                model_kwargs=model_kwargs,
                i_pooling_function=tf.reduce_sum,
                a_function_weight_bi_tuple_lst=a_function_weight_bi_tuple_lst)

    a_batch_size = (nb_corruptions * is_flip) + \
                   (nb_corruptions * is_remove) + \
                   (nb_corruptions * is_combine) + \
                   (nb_corruptions * (scramble is not None))

    lm_scorer_adversarial_batch = lm_scorer_batch = None
    if epsilon is not None:
        lm_scorer_adversarial_batch = LMScorer(embedding_layer=embedding_layer,
                                               token_to_index=token_to_index,
                                               batch_size=a_batch_size)

        lm_scorer_batch = LMScorer(embedding_layer=embedding_layer,
                                   token_to_index=token_to_index,
                                   batch_size=batch_size,
                                   reuse=True)

        lm_vars = lm_scorer_adversarial_batch.get_vars()
        lm_saver = tf.train.Saver(lm_vars, max_to_keep=1)

    A_rs = np.random.RandomState(0)

    sentence_pair_to_original_pair = {}
    original_pair_to_label = {}
    sentence_pair_to_corruption_type = {}

    rs = np.random.RandomState(seed)

    with tf.Session(config=session_config) as session:

        if lm_scorer_adversarial_batch is not None:
            lm_ckpt = tf.train.get_checkpoint_state(lm_path)
            lm_saver.restore(session, lm_ckpt.model_checkpoint_path)

        saver.restore(session, restore_path)

        batches = make_batches(size=nb_instances, batch_size=batch_size)

        for batch_idx, (batch_start, batch_end) in enumerate(batches):
            # order = np.arange(nb_instances)
            order = rs.permutation(nb_instances)

            sentences1, sentences2 = sequence1[order], sequence2[order]
            sizes1, sizes2 = sequence1_len[order], sequence2_len[order]
            labels = label[order]

            batch_sentences1 = sentences1[batch_start:batch_end]
            batch_sentences2 = sentences2[batch_start:batch_end]

            batch_sizes1 = sizes1[batch_start:batch_end]
            batch_sizes2 = sizes2[batch_start:batch_end]

            batch_labels = labels[batch_start:batch_end]

            batch_max_size1 = np.max(batch_sizes1)
            batch_max_size2 = np.max(batch_sizes2)

            batch_sentences1 = batch_sentences1[:, :batch_max_size1]
            batch_sentences2 = batch_sentences2[:, :batch_max_size2]

            # Remove the BOS token from sentences
            o_batch_size = batch_sentences1.shape[0]
            o_sentences1, o_sentences2 = [], []

            for i in range(o_batch_size):
                o_sentences1 += [[
                    idx for idx in batch_sentences1[i, 1:].tolist() if idx != 0
                ]]
                o_sentences2 += [[
                    idx for idx in batch_sentences2[i, 1:].tolist() if idx != 0
                ]]

            # Parameters controlling adversarial example generation:
            # epsilon, nb_corruptions, nb_examples_per_batch, is_flip, is_combine, is_remove, scramble
            selected_sentence1, selected_sentence2 = [], []

            # First, add all training sentences
            selected_sentence1 += o_sentences1
            selected_sentence2 += o_sentences2

            op2idx = {}
            for idx, (o1, o2) in enumerate(zip(o_sentences1, o_sentences2)):
                op2idx[(tuple(o1), tuple(o2))] = idx

            for a, b, c in zip(selected_sentence1, selected_sentence2,
                               batch_labels):
                sentence_pair_to_original_pair[(tuple(a), tuple(b))] = (a, b)
                original_pair_to_label[(tuple(a), tuple(b))] = c
                sentence_pair_to_corruption_type[(tuple(a), tuple(b))] = 'none'

            c_idxs = A_rs.choice(o_batch_size,
                                 nb_examples_per_batch,
                                 replace=False)
            for c_idx in c_idxs:
                o_sentence1 = o_sentences1[c_idx]
                o_sentence2 = o_sentences2[c_idx]

                sentence_pair_to_original_pair[(tuple(o_sentence1),
                                                tuple(o_sentence2))] = (
                                                    o_sentence1, o_sentence2)
                sentence_pair_to_corruption_type[(tuple(o_sentence1),
                                                  tuple(o_sentence2))] = 'none'

                # Generating Corruptions
                c_sentence1_lst, c_sentence2_lst = [], []
                if is_flip:
                    corruptions_1, corruptions_2 = instance_generator.flip(
                        sentence1=o_sentence1,
                        sentence2=o_sentence2,
                        nb_corruptions=nb_corruptions)
                    c_sentence1_lst += corruptions_1
                    c_sentence2_lst += corruptions_2

                    for _corruption_1, _corruption_2 in zip(
                            corruptions_1, corruptions_2):
                        sentence_pair_to_original_pair[(
                            tuple(_corruption_1),
                            tuple(_corruption_2))] = (o_sentence1, o_sentence2)
                        if (tuple(_corruption_1), tuple(_corruption_2)
                            ) not in sentence_pair_to_corruption_type:
                            sentence_pair_to_corruption_type[(
                                tuple(_corruption_1),
                                tuple(_corruption_2))] = 'flip'

                if is_remove:
                    corruptions_1, corruptions_2 = instance_generator.remove(
                        sentence1=o_sentence1,
                        sentence2=o_sentence2,
                        nb_corruptions=nb_corruptions)
                    c_sentence1_lst += corruptions_1
                    c_sentence2_lst += corruptions_2

                    for _corruption_1, _corruption_2 in zip(
                            corruptions_1, corruptions_2):
                        sentence_pair_to_original_pair[(
                            tuple(_corruption_1),
                            tuple(_corruption_2))] = (o_sentence1, o_sentence2)
                        if (tuple(_corruption_1), tuple(_corruption_2)
                            ) not in sentence_pair_to_corruption_type:
                            sentence_pair_to_corruption_type[(
                                tuple(_corruption_1),
                                tuple(_corruption_2))] = 'remove'

                if is_combine:
                    corruptions_1, corruptions_2 = instance_generator.combine(
                        sentence1=o_sentence1,
                        sentence2=o_sentence2,
                        nb_corruptions=nb_corruptions)
                    c_sentence1_lst += corruptions_1
                    c_sentence2_lst += corruptions_2

                    for _corruption_1, _corruption_2 in zip(
                            corruptions_1, corruptions_2):
                        sentence_pair_to_original_pair[(
                            tuple(_corruption_1),
                            tuple(_corruption_2))] = (o_sentence1, o_sentence2)
                        if (tuple(_corruption_1), tuple(_corruption_2)
                            ) not in sentence_pair_to_corruption_type:
                            sentence_pair_to_corruption_type[(
                                tuple(_corruption_1),
                                tuple(_corruption_2))] = 'combine'

                if scramble is not None:
                    corruptions_1, corruptions_2 = instance_generator.scramble(
                        sentence1=o_sentence1,
                        sentence2=o_sentence2,
                        nb_corruptions=nb_corruptions,
                        nb_pooled_sentences=scramble,
                        sentence_pool=o_sentences1 + o_sentences2)
                    c_sentence1_lst += corruptions_1
                    c_sentence2_lst += corruptions_2

                    for _corruption_1, _corruption_2 in zip(
                            corruptions_1, corruptions_2):
                        sentence_pair_to_original_pair[(
                            tuple(_corruption_1),
                            tuple(_corruption_2))] = (o_sentence1, o_sentence2)
                        if (tuple(_corruption_1), tuple(_corruption_2)
                            ) not in sentence_pair_to_corruption_type:
                            sentence_pair_to_corruption_type[(
                                tuple(_corruption_1),
                                tuple(_corruption_2))] = 'scramble'

                if epsilon is not None and lm_scorer_adversarial_batch is not None:
                    # Scoring them against a Language Model
                    log_perp1 = lm_scorer_adversarial_batch.score(
                        session, c_sentence1_lst)
                    log_perp2 = lm_scorer_adversarial_batch.score(
                        session, c_sentence2_lst)

                    log_perp_o1 = lm_scorer_batch.score(session, o_sentences1)
                    log_perp_o2 = lm_scorer_batch.score(session, o_sentences2)

                    low_lperp_idxs = []
                    for i, (c1, c2, _lp1, _lp2) in enumerate(
                            zip(c_sentence1_lst, c_sentence2_lst, log_perp1,
                                log_perp2)):
                        o1, o2 = sentence_pair_to_original_pair[(tuple(c1),
                                                                 tuple(c2))]
                        idx = op2idx[(tuple(o1), tuple(o2))]
                        _log_perp_o1 = log_perp_o1[idx]
                        _log_perp_o2 = log_perp_o2[idx]

                        if (_lp1 + _lp2) <= (_log_perp_o1 + _log_perp_o2 +
                                             epsilon):
                            low_lperp_idxs += [i]
                else:
                    low_lperp_idxs = range(len(c_sentence1_lst))

                selected_sentence1 += [
                    c_sentence1_lst[i] for i in low_lperp_idxs
                ]
                selected_sentence2 += [
                    c_sentence2_lst[i] for i in low_lperp_idxs
                ]

            selected_scores = None

            sentence_pair_to_score = {}

            # Rank the candidate pairs by inconsistency loss and keep the top-k most offending ones
            if top_k >= 0 and instance_scorer is not None:
                iscore_values = instance_scorer.iscore(session,
                                                       selected_sentence1,
                                                       selected_sentence2)

                for a, b, c in zip(selected_sentence1, selected_sentence2,
                                   iscore_values):
                    sentence_pair_to_score[(tuple(a), tuple(b))] = c

                top_k_idxs = np.argsort(iscore_values)[::-1][:top_k]

                selected_sentence1 = [
                    selected_sentence1[i] for i in top_k_idxs
                ]
                selected_sentence2 = [
                    selected_sentence2[i] for i in top_k_idxs
                ]

                selected_scores = [iscore_values[i] for i in top_k_idxs]

            def decode(sentence_ids):
                return ' '.join([index_to_token[idx] for idx in sentence_ids])

            def infer(s1_ids, s2_ids):
                a = np.array([[bos_idx] + s1_ids])
                b = np.array([[bos_idx] + s2_ids])

                c = np.array([1 + len(s1_ids)])
                d = np.array([1 + len(s2_ids)])

                inf_feed = {
                    sequence1_ph: a,
                    sequence2_ph: b,
                    sequence1_len_ph: c,
                    sequence2_len_ph: d,
                    dropout_keep_prob_ph: 1.0
                }
                pv = session.run(probabilities, feed_dict=inf_feed)
                return {
                    'ent': str(pv[0, entailment_idx]),
                    'neu': str(pv[0, neutral_idx]),
                    'con': str(pv[0, contradiction_idx])
                }

            logger.info("No. of generated pairs: {}".format(
                len(selected_sentence1)))

            for i, (s1, s2, score) in enumerate(
                    zip(selected_sentence1, selected_sentence2,
                        selected_scores)):
                o1, o2 = sentence_pair_to_original_pair[(tuple(s1), tuple(s2))]
                lbl = original_pair_to_label[(tuple(o1), tuple(o2))]
                corr = sentence_pair_to_corruption_type[(tuple(s1), tuple(s2))]

                oiscore = sentence_pair_to_score.get((tuple(o1), tuple(o2)),
                                                     1.0)

                print('[{}] Original 1: {}'.format(i, decode(o1)))
                print('[{}] Original 2: {}'.format(i, decode(o2)))
                print('[{}] Original Label: {}'.format(i, index_to_label[lbl]))
                print('[{}] Original Inconsistency Loss: {}'.format(
                    i, oiscore))

                print('[{}] Sentence 1: {}'.format(i, decode(s1)))
                print('[{}] Sentence 2: {}'.format(i, decode(s2)))

                print('[{}] Inconsistency Loss: {}'.format(i, score))

                print('[{}] Corruption type: {}'.format(i, corr))

                print('[{}] (before) s1 -> s2: {}'.format(
                    i, str(infer(o1, o2))))
                print('[{}] (before) s2 -> s1: {}'.format(
                    i, str(infer(o2, o1))))

                print('[{}] (after) s1 -> s2: {}'.format(
                    i, str(infer(s1, s2))))
                print('[{}] (after) s2 -> s1: {}'.format(
                    i, str(infer(s2, s1))))

                jdata = {
                    'original_sentence1': decode(o1),
                    'original_sentence2': decode(o2),
                    'original_inconsistency_loss': str(oiscore),
                    'original_label': index_to_label[lbl],
                    'sentence1': decode(s1),
                    'sentence2': decode(s2),
                    'inconsistency_loss': str(score),
                    'inconsistency_loss_increase': str(score - oiscore),
                    'corruption': str(corr),
                    'probabilities_before_s1_s2': infer(o1, o2),
                    'probabilities_before_s2_s1': infer(o2, o1),
                    'probabilities_after_s1_s2': infer(s1, s2),
                    'probabilities_after_s2_s1': infer(s2, s1)
                }

                if json_path is not None:
                    with open(json_path, 'a') as f:
                        json.dump(jdata, f)
                        f.write('\n')
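Both scripts call tfutil.clip_sentence to trim padded [batch, time] tensors to the longest sequence in the batch before the embedding lookup. A minimal sketch, assuming the helper simply slices along the time axis (the original implementation may differ):

import tensorflow as tf


def clip_sentence(sentence, sizes):
    # Trim the padded [batch, time] int tensor to the maximum actual
    # length found in `sizes`.
    max_size = tf.reduce_max(sizes)
    return tf.slice(sentence, [0, 0], tf.stack([-1, max_size]))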