Example #1
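Scores a fixed token sequence with a pre-trained RNN language model: the vocabulary and a frozen embedding matrix are restored from a DAM checkpoint, a multi-layer LSTM language model is rebuilt and restored, and the sequence is fed one token per step while the per-step cross-entropy is summed and printed. Project-level helpers such as tfutil and logger are assumed to be defined elsewhere in the module.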
def main(argv):
    vocabulary_path = 'models/snli/dam_1/dam_1_index_to_token.p'
    checkpoint_path = 'models/snli/dam_1/dam_1'
    lm_path = 'models/lm/'

    with open(vocabulary_path, 'rb') as f:
        index_to_token = pickle.load(f)

    index_to_token.update({0: '<PAD>', 1: '<BOS>', 2: '<UNK>'})

    token_to_index = {token: index for index, token in index_to_token.items()}

    with open('{}/config.json'.format(lm_path), 'r') as f:
        config = json.load(f)

    seq_length = 1
    batch_size = 1
    rnn_size = config['rnn_size']
    num_layers = config['num_layers']

    vocab_size = len(token_to_index)
    assert vocab_size == config['vocab_size']

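    # The embedding matrix is created in the discriminator scope so it can be
    # restored from the DAM checkpoint below; it is kept frozen here.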
    discriminator_scope_name = 'discriminator'
    with tf.variable_scope(discriminator_scope_name):
        embedding_layer = tf.get_variable(
            'embeddings',
            shape=[vocab_size, config['embedding_size']],
            initializer=tf.contrib.layers.xavier_initializer(),
            trainable=False)

    lm_scope_name = 'language_model'
    with tf.variable_scope(lm_scope_name):
        cell_fn = rnn.BasicLSTMCell
        cells = [cell_fn(rnn_size) for _ in range(num_layers)]

        cell = rnn.MultiRNNCell(cells)

        input_data = tf.placeholder(tf.int32, [None, seq_length])
        targets = tf.placeholder(tf.int32, [None, seq_length])
        initial_state = cell.zero_state(batch_size, tf.float32)

        with tf.variable_scope('rnnlm'):
            W = tf.get_variable(
                name='W',
                shape=[rnn_size, vocab_size],
                initializer=tf.contrib.layers.xavier_initializer())
            b = tf.get_variable(name='b',
                                shape=[vocab_size],
                                initializer=tf.zeros_initializer())

            emb_lookup = tf.nn.embedding_lookup(embedding_layer, input_data)
            emb_projection = tf.contrib.layers.fully_connected(
                inputs=emb_lookup,
                num_outputs=rnn_size,
                weights_initializer=tf.contrib.layers.xavier_initializer(),
                biases_initializer=tf.zeros_initializer())

            inputs = tf.split(emb_projection, seq_length, 1)
            inputs = [tf.squeeze(input_, [1]) for input_ in inputs]

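        # Run the decoder one step at a time; the recurrent state is threaded
        # through initial_state / last_state.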
        outputs, last_state = legacy_seq2seq.rnn_decoder(
            decoder_inputs=inputs,
            initial_state=initial_state,
            cell=cell,
            loop_function=None,
            scope='rnnlm')
        output = tf.reshape(tf.concat(outputs, 1), [-1, rnn_size])

        logits = tf.matmul(output, W) + b
        probabilities = tf.nn.softmax(logits)

        loss = legacy_seq2seq.sequence_loss_by_example(
            logits=[logits],
            targets=[tf.reshape(targets, [-1])],
            weights=[tf.ones([batch_size * seq_length])])

        cost = tf.reduce_sum(loss) / batch_size / seq_length
        final_state = last_state

    saver = tf.train.Saver(tf.global_variables())
    emb_saver = tf.train.Saver([embedding_layer], max_to_keep=1)

    logger.info('Creating the session ..')

    with tf.Session() as session:
        emb_saver.restore(session, checkpoint_path)

        ckpt = tf.train.get_checkpoint_state(lm_path)
        assert ckpt is not None
        saver.restore(session, ckpt.model_checkpoint_path)

        logger.info('Total Parameters: {}'.format(
            tfutil.count_trainable_parameters()))

        sequence = [
            token_to_index[w]
            for w in ['A', 'happy', 'girl', 'flies', 'in', 'a', 'gun']
        ]

        state = session.run(cell.zero_state(1, tf.float32))

        x = np.zeros(shape=(1, seq_length))
        y = np.zeros(shape=(1, seq_length))

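        # Feed the sequence token by token and sum the per-step cross-entropy,
        # i.e. the sequence's total negative log-likelihood.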
        log_perplexity = 0.0
        for i in range(len(sequence)):
            if i + 1 < len(sequence):
                x[0, 0] = sequence[i]
                y[0, 0] = sequence[i + 1]
                feed = {input_data: x, targets: y, initial_state: state}
                cost_value, state = session.run([cost, final_state],
                                                feed_dict=feed)
                log_perplexity += cost_value

        print(log_perplexity)
Example #2
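Trains an SNLI textual-entailment model (cBiLSTM, the feed-forward DAM variants, or ESIM) with optional rule-based regularisation: rules 00-03 add contradiction-symmetry penalties on the current batch, while rules 1-8 create extra placeholders that are fed with sentences sampled from the training set. It relies on project helpers (util, tfutil, constraints, the model classes and the rule loss functions) defined elsewhere in the package.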
def main(argv):
    logger.info('Command line: {}'.format(' '.join(arg for arg in argv)))

    def fmt(prog):
        return argparse.HelpFormatter(prog, max_help_position=100, width=200)

    argparser = argparse.ArgumentParser('Regularising RTE via Adversarial Sets Regularisation', formatter_class=fmt)

    argparser.add_argument('--train', '-t', action='store', type=str, default='data/snli/snli_1.0_train.jsonl.gz')
    argparser.add_argument('--valid', '-v', action='store', type=str, default='data/snli/snli_1.0_dev.jsonl.gz')
    argparser.add_argument('--test', '-T', action='store', type=str, default='data/snli/snli_1.0_test.jsonl.gz')

    argparser.add_argument('--model', '-m', action='store', type=str, default='cbilstm',
                           choices=['cbilstm', 'ff-dam', 'ff-damp', 'ff-dams', 'esim1'])
    argparser.add_argument('--optimizer', '-o', action='store', type=str, default='adagrad',
                           choices=['adagrad', 'adam'])

    argparser.add_argument('--embedding-size', action='store', type=int, default=300)
    argparser.add_argument('--representation-size', action='store', type=int, default=200)

    argparser.add_argument('--batch-size', action='store', type=int, default=1024)

    argparser.add_argument('--nb-epochs', '-e', action='store', type=int, default=1000)
    argparser.add_argument('--nb-discriminator-epochs', '-D', action='store', type=int, default=1)
    argparser.add_argument('--nb-adversary-epochs', '-A', action='store', type=int, default=1000)

    argparser.add_argument('--dropout-keep-prob', action='store', type=float, default=1.0)
    argparser.add_argument('--learning-rate', action='store', type=float, default=0.1)
    argparser.add_argument('--clip', '-c', action='store', type=float, default=None)
    argparser.add_argument('--nb-words', action='store', type=int, default=None)
    argparser.add_argument('--seed', action='store', type=int, default=0)
    argparser.add_argument('--std-dev', action='store', type=float, default=0.01)

    argparser.add_argument('--has-bos', action='store_true', default=False, help='Has <Beginning Of Sentence> token')
    argparser.add_argument('--has-eos', action='store_true', default=False, help='Has <End Of Sentence> token')
    argparser.add_argument('--has-unk', action='store_true', default=False, help='Has <Unknown Word> token')
    argparser.add_argument('--lower', '-l', action='store_true', default=False, help='Lowercase the corpus')

    argparser.add_argument('--initialize-embeddings', '-i', action='store', type=str, default=None,
                           choices=['normal', 'uniform'])

    argparser.add_argument('--fixed-embeddings', '-f', action='store_true')
    argparser.add_argument('--normalize-embeddings', '-n', action='store_true')
    argparser.add_argument('--only-use-pretrained-embeddings', '-p', action='store_true',
                           help='Only use pre-trained word embeddings')
    argparser.add_argument('--semi-sort', '-S', action='store_true')

    argparser.add_argument('--save', action='store', type=str, default=None)
    argparser.add_argument('--hard-save', action='store', type=str, default=None)
    argparser.add_argument('--restore', action='store', type=str, default=None)

    argparser.add_argument('--glove', action='store', type=str, default=None)

    argparser.add_argument('--rule00-weight', '--00', action='store', type=float, default=None)
    argparser.add_argument('--rule01-weight', '--01', action='store', type=float, default=None)
    argparser.add_argument('--rule02-weight', '--02', action='store', type=float, default=None)
    argparser.add_argument('--rule03-weight', '--03', action='store', type=float, default=None)

    for i in range(1, 9):
        argparser.add_argument('--rule{}-weight'.format(i), '-{}'.format(i), action='store', type=float, default=None)

    argparser.add_argument('--adversarial-batch-size', '-B', action='store', type=int, default=32)
    argparser.add_argument('--adversarial-pooling', '-P', default='max', choices=['sum', 'max', 'mean', 'logsumexp'])

    argparser.add_argument('--report', '-r', default=100, type=int, help='Number of batches between performance reports')
    argparser.add_argument('--report-loss', default=100, type=int, help='Number of batches between loss reports')

    argparser.add_argument('--eval', '-E', nargs='+', type=str, help='Evaluate on these additional sets')

    args = argparser.parse_args(argv)

    # Command line arguments
    train_path, valid_path, test_path = args.train, args.valid, args.test

    model_name = args.model
    optimizer_name = args.optimizer

    embedding_size = args.embedding_size
    representation_size = args.representation_size

    batch_size = args.batch_size

    nb_epochs = args.nb_epochs
    nb_discriminator_epochs = args.nb_discriminator_epochs

    dropout_keep_prob = args.dropout_keep_prob
    learning_rate = args.learning_rate
    clip_value = args.clip
    seed = args.seed
    std_dev = args.std_dev

    has_bos = args.has_bos
    has_eos = args.has_eos
    has_unk = args.has_unk
    is_lower = args.lower

    initialize_embeddings = args.initialize_embeddings

    is_fixed_embeddings = args.fixed_embeddings
    is_normalize_embeddings = args.normalize_embeddings
    is_only_use_pretrained_embeddings = args.only_use_pretrained_embeddings
    is_semi_sort = args.semi_sort

    logger.info('has_bos: {}, has_eos: {}, has_unk: {}'.format(has_bos, has_eos, has_unk))
    logger.info('is_lower: {}, is_fixed_embeddings: {}, is_normalize_embeddings: {}'
                .format(is_lower, is_fixed_embeddings, is_normalize_embeddings))
    logger.info('is_only_use_pretrained_embeddings: {}, is_semi_sort: {}'
                .format(is_only_use_pretrained_embeddings, is_semi_sort))

    save_path = args.save
    hard_save_path = args.hard_save
    restore_path = args.restore

    glove_path = args.glove

    # Experimental RTE regularizers
    rule00_weight = args.rule00_weight
    rule01_weight = args.rule01_weight
    rule02_weight = args.rule02_weight
    rule03_weight = args.rule03_weight

    rule1_weight = args.rule1_weight
    rule2_weight = args.rule2_weight
    rule3_weight = args.rule3_weight
    rule4_weight = args.rule4_weight
    rule5_weight = args.rule5_weight
    rule6_weight = args.rule6_weight
    rule7_weight = args.rule7_weight
    rule8_weight = args.rule8_weight

    a_batch_size = args.adversarial_batch_size
    adversarial_pooling_name = args.adversarial_pooling

    name_to_adversarial_pooling = {
        'sum': tf.reduce_sum,
        'max': tf.reduce_max,
        'mean': tf.reduce_mean,
        'logsumexp': tf.reduce_logsumexp
    }

    report_interval = args.report
    report_loss_interval = args.report_loss

    eval_paths = args.eval

    np.random.seed(seed)
    rs = np.random.RandomState(seed)
    tf.set_random_seed(seed)

    logger.debug('Reading corpus ..')
    train_is, dev_is, test_is = util.SNLI.generate(train_path=train_path, valid_path=valid_path, test_path=test_path, is_lower=is_lower)

    logger.info('Train size: {}\tDev size: {}\tTest size: {}'.format(len(train_is), len(dev_is), len(test_is)))
    all_is = train_is + dev_is + test_is

    # Token enumeration starts right after the reserved special symbols:
    # index=0 PADDING, index=1 START_OF_SENTENCE, index=2 END_OF_SENTENCE, index=3 UNKNOWN_WORD
    bos_idx, eos_idx, unk_idx = 1, 2, 3
    start_idx = 1 + (1 if has_bos else 0) + (1 if has_eos else 0) + (1 if has_unk else 0)

    if not restore_path:
        # Create a sequence of tokens containing all sentences in the dataset
        token_seq = []
        for instance in all_is:
            token_seq += instance['sentence1_parse_tokens'] + instance['sentence2_parse_tokens']

        token_set = set(token_seq)
        allowed_words = None
        if is_only_use_pretrained_embeddings:
            assert glove_path is not None
            logger.info('Loading GloVe words from {}'.format(glove_path))
            assert os.path.isfile(glove_path)
            allowed_words = load_glove_words(path=glove_path, words=token_set)
            logger.info('Number of allowed words: {}'.format(len(allowed_words)))

        # Count the number of occurrences of each token
        token_counts = dict()
        for token in token_seq:
            if (allowed_words is None) or (token in allowed_words):
                if token not in token_counts:
                    token_counts[token] = 0
                token_counts[token] += 1

        # Sort the tokens according to their frequency and lexicographic ordering
        sorted_vocabulary = sorted(token_counts.keys(), key=lambda t: (- token_counts[t], t))

        index_to_token = {index: token for index, token in enumerate(sorted_vocabulary, start=start_idx)}
    else:
        with open('{}_index_to_token.p'.format(restore_path), 'rb') as f:
            index_to_token = pickle.load(f)

    token_to_index = {token: index for index, token in index_to_token.items()}

    entailment_idx, neutral_idx, contradiction_idx = 0, 1, 2
    label_to_index = {
        'entailment': entailment_idx,
        'neutral': neutral_idx,
        'contradiction': contradiction_idx,
    }

    max_len = None
    optimizer_name_to_class = {
        'adagrad': tf.train.AdagradOptimizer,
        'adam': tf.train.AdamOptimizer
    }

    optimizer_class = optimizer_name_to_class[optimizer_name]
    assert optimizer_class

    optimizer = optimizer_class(learning_rate=learning_rate)

    args = dict(has_bos=has_bos, has_eos=has_eos, has_unk=has_unk,
                bos_idx=bos_idx, eos_idx=eos_idx, unk_idx=unk_idx,
                max_len=max_len)

    train_dataset = util.instances_to_dataset(train_is, token_to_index, label_to_index, **args)
    dev_dataset = util.instances_to_dataset(dev_is, token_to_index, label_to_index, **args)
    test_dataset = util.instances_to_dataset(test_is, token_to_index, label_to_index, **args)

    sentence1 = train_dataset['sentence1']
    sentence1_length = train_dataset['sentence1_length']
    sentence2 = train_dataset['sentence2']
    sentence2_length = train_dataset['sentence2_length']
    label = train_dataset['label']

    sentence1_ph = tf.placeholder(dtype=tf.int32, shape=[None, None], name='sentence1')
    sentence2_ph = tf.placeholder(dtype=tf.int32, shape=[None, None], name='sentence2')

    sentence1_len_ph = tf.placeholder(dtype=tf.int32, shape=[None], name='sentence1_length')
    sentence2_len_ph = tf.placeholder(dtype=tf.int32, shape=[None], name='sentence2_length')

    clipped_sentence1 = tfutil.clip_sentence(sentence1_ph, sentence1_len_ph)
    clipped_sentence2 = tfutil.clip_sentence(sentence2_ph, sentence2_len_ph)

    label_ph = tf.placeholder(dtype=tf.int32, shape=[None], name='label')

    token_set = set(token_to_index.keys())
    vocab_size = max(token_to_index.values()) + 1

    token_to_embedding = dict()
    if not restore_path:
        if glove_path:
            logger.info('Loading GloVe word embeddings from {}'.format(glove_path))
            assert os.path.isfile(glove_path)
            token_to_embedding = load_glove(glove_path, token_set)

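    # Discriminator: the embedding matrix plus the RTE model that classifies
    # (premise, hypothesis) pairs as entailment, neutral or contradiction.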
    discriminator_scope_name = 'discriminator'
    with tf.variable_scope(discriminator_scope_name):
        if initialize_embeddings == 'normal':
            logger.info('Initializing the embeddings with 𝓝(0, 1)')
            embedding_initializer = tf.random_normal_initializer(0.0, 1.0)
        elif initialize_embeddings == 'uniform':
            logger.info('Initializing the embeddings with 𝒰(-1, 1)')
            embedding_initializer = tf.random_uniform_initializer(minval=-1.0, maxval=1.0)
        else:
            logger.info('Initializing the embeddings with Xavier initialization')
            embedding_initializer = tf.contrib.layers.xavier_initializer()

        embedding_layer = tf.get_variable('embeddings', shape=[vocab_size, embedding_size],
                                          initializer=embedding_initializer, trainable=not is_fixed_embeddings)

        sentence1_embedding = tf.nn.embedding_lookup(embedding_layer, clipped_sentence1)
        sentence2_embedding = tf.nn.embedding_lookup(embedding_layer, clipped_sentence2)

        dropout_keep_prob_ph = tf.placeholder(tf.float32, name='dropout_keep_prob')

        model_kwargs = dict(
            sequence1=sentence1_embedding, sequence1_length=sentence1_len_ph,
            sequence2=sentence2_embedding, sequence2_length=sentence2_len_ph,
            representation_size=representation_size, dropout_keep_prob=dropout_keep_prob_ph)

        if model_name in {'ff-dam', 'ff-damp', 'ff-dams'}:
            model_kwargs['init_std_dev'] = std_dev

        mode_name_to_class = {
            'cbilstm': ConditionalBiLSTM,
            'ff-dam': FeedForwardDAM,
            'ff-damp': FeedForwardDAMP,
            'ff-dams': FeedForwardDAMS,
            'esim1': ESIMv1
        }

        model_class = mode_name_to_class[model_name]

        assert model_class is not None
        model = model_class(**model_kwargs)

        logits = model()
        predictions = tf.argmax(logits, axis=1, name='predictions')

        losses = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label_ph)
        loss = tf.reduce_mean(losses)

        a_pooling_function = name_to_adversarial_pooling[adversarial_pooling_name]

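        # Rules 00-03: contradiction-symmetry regularisers (L2, L1, KL, JS)
        # computed on the current batch and added to the discriminator loss.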
        a_losses = None
        if rule00_weight:
            a_loss, a_losses = contradiction_symmetry_l2(model_class, model_kwargs,
                                                         contradiction_idx=contradiction_idx,
                                                         pooling_function=a_pooling_function,
                                                         debug=True)
            loss += rule00_weight * a_loss

        if rule01_weight:
            a_loss, a_losses = contradiction_symmetry_l1(model_class, model_kwargs,
                                                         contradiction_idx=contradiction_idx,
                                                         pooling_function=a_pooling_function,
                                                         debug=True)
            loss += rule01_weight * a_loss

        if rule02_weight:
            a_loss, a_losses = contradiction_kullback_leibler(model_class, model_kwargs,
                                                              contradiction_idx=contradiction_idx,
                                                              pooling_function=a_pooling_function,
                                                              debug=True)
            loss += rule02_weight * a_loss

        if rule03_weight:
            a_loss, a_losses = contradiction_jensen_shannon(model_class, model_kwargs,
                                                            contradiction_idx=contradiction_idx,
                                                            pooling_function=a_pooling_function,
                                                            debug=True)
            loss += rule03_weight * a_loss

    discriminator_vars = tfutil.get_variables_in_scope(discriminator_scope_name)
    discriminator_init_op = tf.variables_initializer(discriminator_vars)

    trainable_discriminator_vars = list(discriminator_vars)
    if is_fixed_embeddings:
        trainable_discriminator_vars.remove(embedding_layer)

    discriminator_optimizer_scope_name = 'discriminator_optimizer'
    with tf.variable_scope(discriminator_optimizer_scope_name):
        if clip_value:
            gradients, v = zip(*optimizer.compute_gradients(loss, var_list=trainable_discriminator_vars))
            gradients, _ = tf.clip_by_global_norm(gradients, clip_value)
            training_step = optimizer.apply_gradients(zip(gradients, v))
        else:
            training_step = optimizer.minimize(loss, var_list=trainable_discriminator_vars)

    discriminator_optimizer_vars = tfutil.get_variables_in_scope(discriminator_optimizer_scope_name)
    discriminator_optimizer_init_op = tf.variables_initializer(discriminator_optimizer_vars)

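    # Placeholders and assign op used to copy one pre-trained vector into a
    # single row of the embedding matrix.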
    token_idx_ph = tf.placeholder(dtype=tf.int32, name='word_idx')
    token_embedding_ph = tf.placeholder(dtype=tf.float32, shape=[None], name='word_embedding')

    assign_token_embedding = embedding_layer[token_idx_ph, :].assign(token_embedding_ph)

    init_projection_steps = []
    learning_projection_steps = []

    if is_normalize_embeddings:
        embeddings_projection = constraints.unit_sphere(embedding_layer, norm=1.0)
        init_projection_steps += [embeddings_projection]

        if not is_fixed_embeddings:
            learning_projection_steps += [embeddings_projection]

    predictions_int = tf.cast(predictions, tf.int32)
    labels_int = tf.cast(label_ph, tf.int32)

    use_adversarial_training = any([rule1_weight, rule2_weight, rule3_weight, rule4_weight,
                                    rule5_weight, rule6_weight, rule7_weight, rule8_weight])

    rule_id_to_placeholders = dict()

    if use_adversarial_training:
        adversary_scope_name = discriminator_scope_name
        with tf.variable_scope(adversary_scope_name):

            adversarial = AdversarialSets3(model_class=model_class, model_kwargs=model_kwargs, scope_name='adversary',
                                           entailment_idx=entailment_idx, contradiction_idx=contradiction_idx, neutral_idx=neutral_idx)

            adversary_loss = tf.constant(0.0, dtype=tf.float32)

            adversarial_pooling = name_to_adversarial_pooling[adversarial_pooling_name]

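            # For each sentence slot required by the rule, create placeholders
            # and embed them; then compute the rule's loss on those embeddings.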
            def f(rule_idx):
                nb_sequences = adversarial.rule_nb_sequences(rule_idx)
                a_args, rule_placeholders = [], []

                for seq_id in range(nb_sequences):
                    a_sentence_ph = tf.placeholder(dtype=tf.int32, shape=[None, None],
                                                   name='a_rule{}_sentence{}'.format(rule_idx, seq_id))
                    a_sentence_len_ph = tf.placeholder(dtype=tf.int32, shape=[None],
                                                       name='a_rule{}_sentence{}_length'.format(rule_idx, seq_id))

                    a_clipped_sentence = tfutil.clip_sentence(a_sentence_ph, a_sentence_len_ph)
                    a_sentence_embedding = tf.nn.embedding_lookup(embedding_layer, a_clipped_sentence)

                    a_args += [a_sentence_embedding, a_sentence_len_ph]
                    rule_placeholders += [(a_sentence_ph, a_sentence_len_ph)]

                rule_id_to_placeholders[rule_idx] = rule_placeholders
                # Dispatch to this rule's loss (assumes rule{n}_loss methods on AdversarialSets3)
                rule_loss = getattr(adversarial, 'rule{}_loss'.format(rule_idx))(*a_args)
                return rule_loss

            if rule1_weight:
                r_loss = f(1)
                adversary_loss += rule1_weight * adversarial_pooling(r_loss)
            if rule2_weight:
                r_loss = f(2)
                adversary_loss += rule2_weight * adversarial_pooling(r_loss)
            if rule3_weight:
                r_loss = f(3)
                adversary_loss += rule3_weight * adversarial_pooling(r_loss)
            if rule4_weight:
                r_loss = f(4)
                adversary_loss += rule4_weight * adversarial_pooling(r_loss)
            if rule5_weight:
                r_loss = f(5)
                adversary_loss += rule5_weight * adversarial_pooling(r_loss)
            if rule6_weight:
                r_loss = f(6)
                adversary_loss += rule6_weight * adversarial_pooling(r_loss)
            if rule7_weight:
                r_loss = f(7)
                adversary_loss += rule7_weight * adversarial_pooling(r_loss)
            if rule8_weight:
                r_loss = f(8)
                adversary_loss += rule8_weight * adversarial_pooling(r_loss)

            loss += adversary_loss

        logger.info('Adversarial Batch Size: {}'.format(a_batch_size))

    a_feed_dict = dict()
    a_rs = np.random.RandomState(seed)

    d_sentence1, d_sentence2 = train_dataset['sentence1'], train_dataset['sentence2']
    d_sentence1_len, d_sentence2_len = train_dataset['sentence1_length'], train_dataset['sentence2_length']
    d_label = train_dataset['label']

    nb_train_instances = d_label.shape[0]

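    # Stack premises and hypotheses into a single matrix so adversarial
    # batches can be drawn from either side of the training pairs.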
    max_sentence_len = max(d_sentence1.shape[1], d_sentence2.shape[1])
    d_sentence = np.zeros(shape=(nb_train_instances * 2, max_sentence_len), dtype=np.int)
    d_sentence[0:d_sentence1.shape[0], 0:d_sentence1.shape[1]] = d_sentence1
    d_sentence[d_sentence1.shape[0]:, 0:d_sentence2.shape[1]] = d_sentence2

    d_sentence_len = np.concatenate((d_sentence1_len, d_sentence2_len), axis=0)

    nb_train_sentences = d_sentence_len.shape[0]

    saver = tf.train.Saver(discriminator_vars + discriminator_optimizer_vars, max_to_keep=1)

    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True

    with tf.Session(config=session_config) as session:
        logger.info('Total Parameters: {}'.format(tfutil.count_trainable_parameters()))
        logger.info('Total Discriminator Parameters: {}'.format(
            tfutil.count_trainable_parameters(var_list=discriminator_vars)))
        logger.info('Total Trainable Discriminator Parameters: {}'.format(
            tfutil.count_trainable_parameters(var_list=trainable_discriminator_vars)))

        if restore_path:
            saver.restore(session, restore_path)
        else:
            session.run([discriminator_init_op, discriminator_optimizer_init_op])

            # Initialising pre-trained embeddings
            logger.info('Initialising the embeddings with pre-trained vectors ..')
            for token in token_to_embedding:
                token_idx, token_embedding = token_to_index[token], token_to_embedding[token]
                assert embedding_size == len(token_embedding)
                session.run(assign_token_embedding,
                            feed_dict={
                                token_idx_ph: token_idx,
                                token_embedding_ph: token_embedding
                            })
            logger.info('Done!')

            for adversary_projection_step in init_projection_steps:
                session.run([adversary_projection_step])

        nb_instances = sentence1.shape[0]
        batches = make_batches(size=nb_instances, batch_size=batch_size)

        best_dev_acc, best_test_acc = None, None
        discriminator_batch_counter = 0

        for epoch in range(1, nb_epochs + 1):

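            # Resample the sentences fed to the adversarial rule placeholders
            # once per epoch.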
            if use_adversarial_training:
                for rule_idx, rule_placeholders in rule_id_to_placeholders.items():

                    a_idxs = a_rs.choice(nb_train_sentences, a_batch_size)
                    for a_sentence_ph, a_sentence_len_ph in rule_placeholders:
                        # Select a random sentence from the training set
                        a_sentence_batch = d_sentence[a_idxs]
                        a_sentence_len_batch = d_sentence_len[a_idxs]

                        a_feed_dict[a_sentence_ph] = a_sentence_batch
                        a_feed_dict[a_sentence_len_ph] = a_sentence_len_batch

            for d_epoch in range(1, nb_discriminator_epochs + 1):
                order = rs.permutation(nb_instances)

                sentences1, sentences2 = sentence1[order], sentence2[order]
                sizes1, sizes2 = sentence1_length[order], sentence2_length[order]
                labels = label[order]

                if is_semi_sort:
                    order = util.semi_sort(sizes1, sizes2)
                    sentences1, sentences2 = sentence1[order], sentence2[order]
                    sizes1, sizes2 = sentence1_length[order], sentence2_length[order]
                    labels = label[order]

                loss_values, epoch_loss_values = [], []
                for batch_idx, (batch_start, batch_end) in enumerate(batches):
                    discriminator_batch_counter += 1

                    batch_sentences1, batch_sentences2 = sentences1[batch_start:batch_end], sentences2[batch_start:batch_end]
                    batch_sizes1, batch_sizes2 = sizes1[batch_start:batch_end], sizes2[batch_start:batch_end]
                    batch_labels = labels[batch_start:batch_end]

                    batch_max_size1 = np.max(batch_sizes1)
                    batch_max_size2 = np.max(batch_sizes2)

                    batch_sentences1 = batch_sentences1[:, :batch_max_size1]
                    batch_sentences2 = batch_sentences2[:, :batch_max_size2]

                    batch_feed_dict = {
                        sentence1_ph: batch_sentences1, sentence1_len_ph: batch_sizes1,
                        sentence2_ph: batch_sentences2, sentence2_len_ph: batch_sizes2,
                        label_ph: batch_labels, dropout_keep_prob_ph: dropout_keep_prob
                    }

                    # Adding the adversaries
                    batch_feed_dict.update(a_feed_dict)

                    _, loss_value = session.run([training_step, loss], feed_dict=batch_feed_dict)

                    logger.debug('Epoch {0}/{1}/{2}\tLoss: {3}'.format(epoch, d_epoch, batch_idx, loss_value))

                    cur_batch_size = batch_sentences1.shape[0]
                    loss_values += [loss_value / cur_batch_size]
                    epoch_loss_values += [loss_value / cur_batch_size]

                    for adversary_projection_step in learning_projection_steps:
                        session.run([adversary_projection_step])

                    if discriminator_batch_counter % report_loss_interval == 0:
                        logger.info('Epoch {0}/{1}/{2}\tLoss Stats: {3}'.format(epoch, d_epoch, batch_idx, stats(loss_values)))
                        loss_values = []

                    if discriminator_batch_counter % report_interval == 0:
                        accuracy_args = [sentence1_ph, sentence1_len_ph, sentence2_ph, sentence2_len_ph,
                                         label_ph, dropout_keep_prob_ph, predictions_int, labels_int,
                                         contradiction_idx, entailment_idx, neutral_idx, batch_size]
                        dev_acc, _, _, _ = accuracy(session, dev_dataset, 'Dev', *accuracy_args)
                        test_acc, _, _, _ = accuracy(session, test_dataset, 'Test', *accuracy_args)

                        logger.info('Epoch {0}/{1}/{2}\tDev Acc: {3:.2f}\tTest Acc: {4:.2f}'
                                    .format(epoch, d_epoch, batch_idx, dev_acc * 100, test_acc * 100))

                        if best_dev_acc is None or dev_acc > best_dev_acc:
                            best_dev_acc, best_test_acc = dev_acc, test_acc

                            if save_path:
                                with open('{}_index_to_token.p'.format(save_path), 'wb') as f:
                                    pickle.dump(index_to_token, f)

                                saved_path = saver.save(session, save_path)
                                logger.info('Model saved in file: {}'.format(saved_path))

                        logger.info('Epoch {0}/{1}/{2}\tBest Dev Accuracy: {3:.2f}\tBest Test Accuracy: {4:.2f}'
                                    .format(epoch, d_epoch, batch_idx, best_dev_acc * 100, best_test_acc * 100))

                        for eval_path in (eval_paths or []):
                            eval_path_acc = eutil.evaluate(session, eval_path, label_to_index, token_to_index, predictions, batch_size,
                                                           sentence1_ph, sentence2_ph, sentence1_len_ph, sentence2_len_ph, dropout_keep_prob_ph,
                                                           has_bos=has_bos, has_eos=has_eos, has_unk=has_unk, is_lower=is_lower,
                                                           bos_idx=bos_idx, eos_idx=eos_idx, unk_idx=unk_idx)
                            logger.info('Epoch {0}/{1}/{2}\tAccuracy on {3} is {4}'.format(epoch, d_epoch, batch_idx,
                                                                                           eval_path, eval_path_acc))

                        if a_losses is not None:
                            t_feed_dict = a_feed_dict
                            if len(t_feed_dict) == 0:
                                t_feed_dict = {
                                    sentence1_ph: sentences1[:1024], sentence1_len_ph: sizes1[:1024],
                                    sentence2_ph: sentences2[:1024], sentence2_len_ph: sizes2[:1024],
                                    dropout_keep_prob_ph: 1.0
                                }
                            a_losses_value = session.run(a_losses, feed_dict=t_feed_dict)

                            a_input_idxs = np.argsort(- a_losses_value)
                            for i in a_input_idxs[:10]:
                                t_sentence1 = t_feed_dict[sentence1_ph][i]
                                t_sentence2 = t_feed_dict[sentence2_ph][i]

                                logger.info('[ {} / {} ] Sentence1: {}'.format(i, a_losses_value[i], ' '.join([index_to_token[x] for x in t_sentence1 if x not in [0, 1, 2]])))
                                logger.info('[ {} / {} ] Sentence2: {}'.format(i, a_losses_value[i], ' '.join([index_to_token[x] for x in t_sentence2 if x not in [0, 1, 2]])))

                logger.info('Epoch {0}/{1}\tEpoch Loss Stats: {2}'.format(epoch, d_epoch, stats(epoch_loss_values)))

                if hard_save_path:
                    with open('{}_index_to_token.p'.format(hard_save_path), 'wb') as f:
                        pickle.dump(index_to_token, f)

                    hard_saved_path = saver.save(session, hard_save_path)
                    logger.info('Model saved in file: {}'.format(hard_saved_path))

    logger.info('Training finished.')
Example #3
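A variant of Example #2 with additional options: word2vec embeddings as an alternative to GloVe, separately trainable special-token embeddings, an optional fourth 'none' label (--universum), an optional process memory limit, and an AdversarialSets adversary whose rule losses also return trainable variables of a fixed adversarial batch size and sentence length.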
def main(argv):
    logger.info('Command line: {}'.format(' '.join(arg for arg in argv)))

    def fmt(prog):
        return argparse.HelpFormatter(prog, max_help_position=100, width=200)

    argparser = argparse.ArgumentParser(
        'Regularising RTE via Adversarial Sets Regularisation',
        formatter_class=fmt)

    argparser.add_argument('--train',
                           '-t',
                           action='store',
                           type=str,
                           default='data/snli/snli_1.0_train.jsonl.gz')
    argparser.add_argument('--valid',
                           '-v',
                           action='store',
                           type=str,
                           default='data/snli/snli_1.0_dev.jsonl.gz')
    argparser.add_argument('--test',
                           '-T',
                           action='store',
                           type=str,
                           default='data/snli/snli_1.0_test.jsonl.gz')

    argparser.add_argument(
        '--model',
        '-m',
        action='store',
        type=str,
        default='cbilstm',
        choices=['cbilstm', 'ff-dam', 'ff-damp', 'ff-dams', 'esim1'])
    argparser.add_argument('--optimizer',
                           '-o',
                           action='store',
                           type=str,
                           default='adagrad',
                           choices=['adagrad', 'adam'])

    argparser.add_argument('--embedding-size',
                           action='store',
                           type=int,
                           default=300)
    argparser.add_argument('--representation-size',
                           action='store',
                           type=int,
                           default=200)

    argparser.add_argument('--batch-size',
                           action='store',
                           type=int,
                           default=1024)

    argparser.add_argument('--nb-epochs',
                           '-e',
                           action='store',
                           type=int,
                           default=1000)
    argparser.add_argument('--nb-discriminator-epochs',
                           '-D',
                           action='store',
                           type=int,
                           default=1)
    argparser.add_argument('--nb-adversary-epochs',
                           '-A',
                           action='store',
                           type=int,
                           default=1000)

    argparser.add_argument('--dropout-keep-prob',
                           action='store',
                           type=float,
                           default=1.0)
    argparser.add_argument('--learning-rate',
                           action='store',
                           type=float,
                           default=0.1)
    argparser.add_argument('--clip',
                           '-c',
                           action='store',
                           type=float,
                           default=None)
    argparser.add_argument('--nb-words',
                           action='store',
                           type=int,
                           default=None)
    argparser.add_argument('--seed', action='store', type=int, default=0)
    argparser.add_argument('--std-dev',
                           action='store',
                           type=float,
                           default=0.01)

    argparser.add_argument('--has-bos',
                           action='store_true',
                           default=False,
                           help='Has <Beginning Of Sentence> token')
    argparser.add_argument('--has-eos',
                           action='store_true',
                           default=False,
                           help='Has <End Of Sentence> token')
    argparser.add_argument('--has-unk',
                           action='store_true',
                           default=False,
                           help='Has <Unknown Word> token')
    argparser.add_argument('--lower',
                           '-l',
                           action='store_true',
                           default=False,
                           help='Lowercase the corpus')

    argparser.add_argument('--initialize-embeddings',
                           '-i',
                           action='store',
                           type=str,
                           default=None,
                           choices=['normal', 'uniform'])

    argparser.add_argument('--fixed-embeddings', '-f', action='store_true')
    argparser.add_argument('--normalize-embeddings', '-n', action='store_true')
    argparser.add_argument('--only-use-pretrained-embeddings',
                           '-p',
                           action='store_true',
                           help='Only use pre-trained word embeddings')
    argparser.add_argument('--train-special-token-embeddings',
                           '-s',
                           action='store_true')
    argparser.add_argument('--semi-sort', '-S', action='store_true')

    argparser.add_argument('--save', action='store', type=str, default=None)
    argparser.add_argument('--hard-save',
                           action='store',
                           type=str,
                           default=None)
    argparser.add_argument('--restore', action='store', type=str, default=None)

    argparser.add_argument('--glove', action='store', type=str, default=None)
    argparser.add_argument('--word2vec',
                           action='store',
                           type=str,
                           default=None)

    argparser.add_argument('--rule0-weight',
                           '-0',
                           action='store',
                           type=float,
                           default=None)
    argparser.add_argument('--rule1-weight',
                           '-1',
                           action='store',
                           type=float,
                           default=None)
    argparser.add_argument('--rule2-weight',
                           '-2',
                           action='store',
                           type=float,
                           default=None)
    argparser.add_argument('--rule3-weight',
                           '-3',
                           action='store',
                           type=float,
                           default=None)
    argparser.add_argument('--rule4-weight',
                           '-4',
                           action='store',
                           type=float,
                           default=None)
    argparser.add_argument('--rule5-weight',
                           '-5',
                           action='store',
                           type=float,
                           default=None)
    argparser.add_argument('--rule6-weight',
                           '-6',
                           action='store',
                           type=float,
                           default=None)
    argparser.add_argument('--rule7-weight',
                           '-7',
                           action='store',
                           type=float,
                           default=None)
    argparser.add_argument('--rule8-weight',
                           '-8',
                           action='store',
                           type=float,
                           default=None)

    argparser.add_argument('--adversarial-batch-size',
                           '-B',
                           action='store',
                           type=int,
                           default=32)
    argparser.add_argument('--adversarial-sentence-length',
                           '-L',
                           action='store',
                           type=int,
                           default=10)

    argparser.add_argument('--adversarial-pooling',
                           '-P',
                           default='max',
                           choices=['sum', 'max', 'mean', 'logsumexp'])
    argparser.add_argument(
        '--adversarial-smart-init',
        '-I',
        action='store_true',
        default=False,
        help='Initialize sentence embeddings with actual word embeddings')

    argparser.add_argument(
        '--report',
        '-r',
        default=100,
        type=int,
        help='Number of batches between performance reports')
    argparser.add_argument('--report-loss',
                           default=100,
                           type=int,
                           help='Number of batches between loss reports')

    argparser.add_argument(
        '--memory-limit',
        default=None,
        type=int,
        help=
        'The maximum area (in bytes) of address space which may be taken by the process.'
    )
    argparser.add_argument('--universum', '-U', action='store_true')

    args = argparser.parse_args(argv)

    # Command line arguments
    train_path, valid_path, test_path = args.train, args.valid, args.test

    model_name = args.model
    optimizer_name = args.optimizer

    embedding_size = args.embedding_size
    representation_size = args.representation_size

    batch_size = args.batch_size

    nb_epochs = args.nb_epochs
    nb_discriminator_epochs = args.nb_discriminator_epochs
    nb_adversary_epochs = args.nb_adversary_epochs

    dropout_keep_prob = args.dropout_keep_prob
    learning_rate = args.learning_rate
    clip_value = args.clip
    seed = args.seed
    std_dev = args.std_dev

    has_bos = args.has_bos
    has_eos = args.has_eos
    has_unk = args.has_unk
    is_lower = args.lower

    initialize_embeddings = args.initialize_embeddings

    is_fixed_embeddings = args.fixed_embeddings
    is_normalize_embeddings = args.normalize_embeddings
    is_only_use_pretrained_embeddings = args.only_use_pretrained_embeddings
    is_train_special_token_embeddings = args.train_special_token_embeddings
    is_semi_sort = args.semi_sort

    logger.info('has_bos: {}, has_eos: {}, has_unk: {}'.format(
        has_bos, has_eos, has_unk))
    logger.info(
        'is_lower: {}, is_fixed_embeddings: {}, is_normalize_embeddings: {}'.
        format(is_lower, is_fixed_embeddings, is_normalize_embeddings))
    logger.info(
        'is_only_use_pretrained_embeddings: {}, is_train_special_token_embeddings: {}, is_semi_sort: {}'
        .format(is_only_use_pretrained_embeddings,
                is_train_special_token_embeddings, is_semi_sort))

    save_path = args.save
    hard_save_path = args.hard_save
    restore_path = args.restore

    glove_path = args.glove
    word2vec_path = args.word2vec

    # Experimental RTE regularizers
    rule0_weight = args.rule0_weight
    rule1_weight = args.rule1_weight
    rule2_weight = args.rule2_weight
    rule3_weight = args.rule3_weight
    rule4_weight = args.rule4_weight
    rule5_weight = args.rule5_weight
    rule6_weight = args.rule6_weight
    rule7_weight = args.rule7_weight
    rule8_weight = args.rule8_weight

    adversarial_batch_size = args.adversarial_batch_size
    adversarial_sentence_length = args.adversarial_sentence_length
    adversarial_pooling_name = args.adversarial_pooling
    adversarial_smart_init = args.adversarial_smart_init

    name_to_adversarial_pooling = {
        'sum': tf.reduce_sum,
        'max': tf.reduce_max,
        'mean': tf.reduce_mean,
        'logsumexp': tf.reduce_logsumexp
    }

    report_interval = args.report
    report_loss_interval = args.report_loss

    memory_limit = args.memory_limit
    is_universum = args.universum

    if memory_limit:
        import resource
        soft, hard = resource.getrlimit(resource.RLIMIT_AS)
        logging.info('Current memory limit: {}, {}'.format(soft, hard))
        resource.setrlimit(resource.RLIMIT_AS, (memory_limit, memory_limit))
        soft, hard = resource.getrlimit(resource.RLIMIT_AS)
        logging.info('New memory limit: {}, {}'.format(soft, hard))

    np.random.seed(seed)
    random_state = np.random.RandomState(seed)
    tf.set_random_seed(seed)

    logger.debug('Reading corpus ..')
    train_is, dev_is, test_is = util.SNLI.generate(train_path=train_path,
                                                   valid_path=valid_path,
                                                   test_path=test_path,
                                                   is_lower=is_lower)

    logger.info('Train size: {}\tDev size: {}\tTest size: {}'.format(
        len(train_is), len(dev_is), len(test_is)))
    all_is = train_is + dev_is + test_is

    # Token enumeration starts right after the reserved special symbols:
    # index=0 PADDING, index=1 START_OF_SENTENCE, index=2 END_OF_SENTENCE, index=3 UNKNOWN_WORD
    bos_idx, eos_idx, unk_idx = 1, 2, 3
    start_idx = 1 + (1 if has_bos else 0) + (1 if has_eos else
                                             0) + (1 if has_unk else 0)

    if not restore_path:
        # Create a sequence of tokens containing all sentences in the dataset
        token_seq = []
        for instance in all_is:
            token_seq += instance['sentence1_parse_tokens'] + instance[
                'sentence2_parse_tokens']

        token_set = set(token_seq)
        allowed_words = None
        if is_only_use_pretrained_embeddings:
            assert (glove_path is not None) or (word2vec_path is not None)
            if glove_path:
                logger.info('Loading GloVe words from {}'.format(glove_path))
                assert os.path.isfile(glove_path)
                allowed_words = load_glove_words(path=glove_path,
                                                 words=token_set)
            elif word2vec_path:
                logger.info(
                    'Loading word2vec words from {}'.format(word2vec_path))
                assert os.path.isfile(word2vec_path)
                allowed_words = load_word2vec_words(path=word2vec_path,
                                                    words=token_set)
            logger.info('Number of allowed words: {}'.format(
                len(allowed_words)))

        # Count the number of occurrences of each token
        token_counts = dict()
        for token in token_seq:
            if (allowed_words is None) or (token in allowed_words):
                if token not in token_counts:
                    token_counts[token] = 0
                token_counts[token] += 1

        # Sort the tokens according to their frequency and lexicographic ordering
        sorted_vocabulary = sorted(token_counts.keys(),
                                   key=lambda t: (-token_counts[t], t))

        index_to_token = {
            index: token
            for index, token in enumerate(sorted_vocabulary, start=start_idx)
        }
    else:
        with open('{}_index_to_token.p'.format(restore_path), 'rb') as f:
            index_to_token = pickle.load(f)

    token_to_index = {token: index for index, token in index_to_token.items()}

    entailment_idx, neutral_idx, contradiction_idx, none_idx = 0, 1, 2, 3
    label_to_index = {
        'entailment': entailment_idx,
        'neutral': neutral_idx,
        'contradiction': contradiction_idx,
    }

    if is_universum:
        label_to_index['none'] = none_idx

    max_len = None
    optimizer_name_to_class = {
        'adagrad': tf.train.AdagradOptimizer,
        'adam': tf.train.AdamOptimizer
    }

    optimizer_class = optimizer_name_to_class[optimizer_name]
    assert optimizer_class

    optimizer = optimizer_class(learning_rate=learning_rate)

    args = dict(has_bos=has_bos,
                has_eos=has_eos,
                has_unk=has_unk,
                bos_idx=bos_idx,
                eos_idx=eos_idx,
                unk_idx=unk_idx,
                max_len=max_len)

    train_dataset = util.instances_to_dataset(train_is, token_to_index,
                                              label_to_index, **args)
    dev_dataset = util.instances_to_dataset(dev_is, token_to_index,
                                            label_to_index, **args)
    test_dataset = util.instances_to_dataset(test_is, token_to_index,
                                             label_to_index, **args)

    sentence1 = train_dataset['sentence1']
    sentence1_length = train_dataset['sentence1_length']
    sentence2 = train_dataset['sentence2']
    sentence2_length = train_dataset['sentence2_length']
    label = train_dataset['label']

    sentence1_ph = tf.placeholder(dtype=tf.int32,
                                  shape=[None, None],
                                  name='sentence1')
    sentence2_ph = tf.placeholder(dtype=tf.int32,
                                  shape=[None, None],
                                  name='sentence2')

    sentence1_len_ph = tf.placeholder(dtype=tf.int32,
                                      shape=[None],
                                      name='sentence1_length')
    sentence2_len_ph = tf.placeholder(dtype=tf.int32,
                                      shape=[None],
                                      name='sentence2_length')

    clipped_sentence1 = tfutil.clip_sentence(sentence1_ph, sentence1_len_ph)
    clipped_sentence2 = tfutil.clip_sentence(sentence2_ph, sentence2_len_ph)

    label_ph = tf.placeholder(dtype=tf.int32, shape=[None], name='label')

    token_set = set(token_to_index.keys())
    vocab_size = max(token_to_index.values()) + 1

    nb_words = len(token_to_index)
    nb_special_tokens = vocab_size - nb_words

    token_to_embedding = dict()
    if not restore_path:
        if glove_path:
            logger.info(
                'Loading GloVe word embeddings from {}'.format(glove_path))
            assert os.path.isfile(glove_path)
            token_to_embedding = load_glove(glove_path, token_set)
        elif word2vec_path:
            logger.info('Loading word2vec word embeddings from {}'.format(
                word2vec_path))
            assert os.path.isfile(word2vec_path)
            token_to_embedding = load_word2vec(word2vec_path, token_set)

    discriminator_scope_name = 'discriminator'
    with tf.variable_scope(discriminator_scope_name):
        if initialize_embeddings == 'normal':
            logger.info('Initializing the embeddings with 𝓝(0, 1)')
            embedding_initializer = tf.random_normal_initializer(0.0, 1.0)
        elif initialize_embeddings == 'uniform':
            logger.info('Initializing the embeddings with 𝒰(-1, 1)')
            embedding_initializer = tf.random_uniform_initializer(minval=-1.0,
                                                                  maxval=1.0)
        else:
            logger.info(
                'Initializing the embeddings with Xavier initialization')
            embedding_initializer = tf.contrib.layers.xavier_initializer()

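        # Optionally keep special-token embeddings in a separate, always-trainable
        # matrix while word embeddings follow --fixed-embeddings.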
        if is_train_special_token_embeddings:
            embedding_layer_special = tf.get_variable(
                'special_embeddings',
                shape=[nb_special_tokens, embedding_size],
                initializer=embedding_initializer,
                trainable=True)
            embedding_layer_words = tf.get_variable(
                'word_embeddings',
                shape=[nb_words, embedding_size],
                initializer=embedding_initializer,
                trainable=not is_fixed_embeddings)
            embedding_layer = tf.concat(
                values=[embedding_layer_special, embedding_layer_words],
                axis=0)
        else:
            embedding_layer = tf.get_variable(
                'embeddings',
                shape=[vocab_size, embedding_size],
                initializer=embedding_initializer,
                trainable=not is_fixed_embeddings)

        sentence1_embedding = tf.nn.embedding_lookup(embedding_layer,
                                                     clipped_sentence1)
        sentence2_embedding = tf.nn.embedding_lookup(embedding_layer,
                                                     clipped_sentence2)

        dropout_keep_prob_ph = tf.placeholder(tf.float32,
                                              name='dropout_keep_prob')

        model_kwargs = dict(sequence1=sentence1_embedding,
                            sequence1_length=sentence1_len_ph,
                            sequence2=sentence2_embedding,
                            sequence2_length=sentence2_len_ph,
                            representation_size=representation_size,
                            dropout_keep_prob=dropout_keep_prob_ph)

        if is_universum:
            model_kwargs['nb_classes'] = 4

        if model_name in {'ff-dam', 'ff-damp', 'ff-dams'}:
            model_kwargs['init_std_dev'] = std_dev

        mode_name_to_class = {
            'cbilstm': ConditionalBiLSTM,
            'ff-dam': FeedForwardDAM,
            'ff-damp': FeedForwardDAMP,
            'ff-dams': FeedForwardDAMS,
            'esim1': ESIMv1
        }

        model_class = mode_name_to_class[model_name]

        assert model_class is not None
        model = model_class(**model_kwargs)

        logits = model()
        predictions = tf.argmax(logits, axis=1, name='predictions')

        losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=label_ph)
        loss = tf.reduce_mean(losses)

        if rule0_weight:
            loss += rule0_weight * contradiction_symmetry_l2(
                model_class, model_kwargs, contradiction_idx=contradiction_idx)

    discriminator_vars = tfutil.get_variables_in_scope(
        discriminator_scope_name)
    discriminator_init_op = tf.variables_initializer(discriminator_vars)

    trainable_discriminator_vars = list(discriminator_vars)
    if is_fixed_embeddings:
        if is_train_special_token_embeddings:
            trainable_discriminator_vars.remove(embedding_layer_words)
        else:
            trainable_discriminator_vars.remove(embedding_layer)

    discriminator_optimizer_scope_name = 'discriminator_optimizer'
    with tf.variable_scope(discriminator_optimizer_scope_name):
        if clip_value:
            gradients, v = zip(*optimizer.compute_gradients(
                loss, var_list=trainable_discriminator_vars))
            gradients, _ = tf.clip_by_global_norm(gradients, clip_value)
            training_step = optimizer.apply_gradients(zip(gradients, v))
        else:
            training_step = optimizer.minimize(
                loss, var_list=trainable_discriminator_vars)

    discriminator_optimizer_vars = tfutil.get_variables_in_scope(
        discriminator_optimizer_scope_name)
    discriminator_optimizer_init_op = tf.variables_initializer(
        discriminator_optimizer_vars)

    token_idx_ph = tf.placeholder(dtype=tf.int32, name='word_idx')
    token_embedding_ph = tf.placeholder(dtype=tf.float32,
                                        shape=[None],
                                        name='word_embedding')

    if is_train_special_token_embeddings:
        assign_token_embedding = embedding_layer_words[
            token_idx_ph - nb_special_tokens, :].assign(token_embedding_ph)
    else:
        assign_token_embedding = embedding_layer[token_idx_ph, :].assign(
            token_embedding_ph)

    init_projection_steps = []
    learning_projection_steps = []

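    # Unit-sphere projections applied at initialisation and, for trainable embeddings, after each update.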
    if is_normalize_embeddings:
        if is_train_special_token_embeddings:
            special_embeddings_projection = constraints.unit_sphere(
                embedding_layer_special, norm=1.0)
            word_embeddings_projection = constraints.unit_sphere(
                embedding_layer_words, norm=1.0)

            init_projection_steps += [special_embeddings_projection]
            init_projection_steps += [word_embeddings_projection]

            learning_projection_steps += [special_embeddings_projection]
            if not is_fixed_embeddings:
                learning_projection_steps += [word_embeddings_projection]
        else:
            embeddings_projection = constraints.unit_sphere(embedding_layer,
                                                            norm=1.0)
            init_projection_steps += [embeddings_projection]

            if not is_fixed_embeddings:
                learning_projection_steps += [embeddings_projection]

    predictions_int = tf.cast(predictions, tf.int32)
    labels_int = tf.cast(label_ph, tf.int32)

    use_adversarial_training = (rule1_weight or rule2_weight or rule3_weight
                                or rule4_weight or rule5_weight or rule6_weight
                                or rule7_weight or rule8_weight)

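    # Adversarial Sets Regularisation: add the weighted rule losses to the discriminator loss.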
    if use_adversarial_training:
        adversary_scope_name = discriminator_scope_name
        with tf.variable_scope(adversary_scope_name):
            adversarial = AdversarialSets(
                model_class=model_class,
                model_kwargs=model_kwargs,
                embedding_size=embedding_size,
                scope_name='adversary',
                batch_size=adversarial_batch_size,
                sequence_length=adversarial_sentence_length,
                entailment_idx=entailment_idx,
                contradiction_idx=contradiction_idx,
                neutral_idx=neutral_idx)

            adversary_loss = tf.constant(0.0, dtype=tf.float32)
            adversary_vars = []

            adversarial_pooling = name_to_adversarial_pooling[
                adversarial_pooling_name]

            rule_weights = [
                rule1_weight, rule2_weight, rule3_weight, rule4_weight,
                rule5_weight, rule6_weight, rule7_weight, rule8_weight
            ]
            for rule_no, rule_weight in enumerate(rule_weights, start=1):
                if rule_weight:
                    rule_loss, rule_vars = getattr(
                        adversarial, 'rule{}_loss'.format(rule_no))()
                    adversary_loss += rule_weight * adversarial_pooling(
                        rule_loss)
                    adversary_vars += rule_vars

            loss += adversary_loss

            assert len(adversary_vars) > 0
            for adversary_var in adversary_vars:
                assert adversary_var.name.startswith(
                    'discriminator/adversary/rule')

            adversary_var_to_assign_op = dict()
            adversary_var_value_ph = tf.placeholder(dtype=tf.float32,
                                                    shape=[None, None, None],
                                                    name='adversary_var_value')
            for a_var in adversary_vars:
                adversary_var_to_assign_op[a_var] = a_var.assign(
                    adversary_var_value_ph)

        adversary_init_op = tf.variables_initializer(adversary_vars)

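        # The adversary maximises its loss, so we minimise its negation.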
        adv_opt_scope_name = 'adversary_optimizer'
        with tf.variable_scope(adv_opt_scope_name):
            adversary_optimizer = optimizer_class(learning_rate=learning_rate)
            adversary_training_step = adversary_optimizer.minimize(
                -adversary_loss, var_list=adversary_vars)

            adversary_optimizer_vars = tf.get_collection(
                tf.GraphKeys.GLOBAL_VARIABLES, scope=adv_opt_scope_name)
            adversary_optimizer_init_op = tf.variables_initializer(
                adversary_optimizer_vars)

        logger.info(
            'Adversarial Batch Size: {}'.format(adversarial_batch_size))

        adversary_projection_steps = []
        for var in adversary_vars:
            if is_normalize_embeddings:
                unit_sphere_adversarial_embeddings = constraints.unit_sphere(
                    var, norm=1.0, axis=-1)
                adversary_projection_steps += [
                    unit_sphere_adversarial_embeddings
                ]

            assert adversarial_batch_size == var.get_shape()[0].value

            def token_init_op(_var, _token_idx, target_idx):
                token_emb = tf.nn.embedding_lookup(embedding_layer, _token_idx)
                tiled_token_emb = tf.tile(tf.expand_dims(token_emb, 0),
                                          (adversarial_batch_size, 1))
                return _var[:, target_idx, :].assign(tiled_token_emb)

            if has_bos:
                adversary_projection_steps += [token_init_op(var, bos_idx, 0)]

    saver = tf.train.Saver(discriminator_vars + discriminator_optimizer_vars,
                           max_to_keep=1)

    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True
    # session_config.log_device_placement = True

    with tf.Session(config=session_config) as session:
        logger.info('Total Parameters: {}'.format(
            tfutil.count_trainable_parameters()))
        logger.info('Total Discriminator Parameters: {}'.format(
            tfutil.count_trainable_parameters(var_list=discriminator_vars)))
        logger.info('Total Trainable Discriminator Parameters: {}'.format(
            tfutil.count_trainable_parameters(
                var_list=trainable_discriminator_vars)))

        if use_adversarial_training:
            session.run([adversary_init_op, adversary_optimizer_init_op])

        if restore_path:
            saver.restore(session, restore_path)
        else:
            session.run(
                [discriminator_init_op, discriminator_optimizer_init_op])

            # Initialising pre-trained embeddings
            logger.info('Initialising the embeddings with pre-trained vectors ..')
            for token in token_to_embedding:
                token_idx = token_to_index[token]
                token_embedding = token_to_embedding[token]
                assert embedding_size == len(token_embedding)
                session.run(assign_token_embedding,
                            feed_dict={
                                token_idx_ph: token_idx,
                                token_embedding_ph: token_embedding
                            })
            logger.info('Done!')

            for adversary_projection_step in init_projection_steps:
                session.run([adversary_projection_step])

        nb_instances = sentence1.shape[0]
        batches = make_batches(size=nb_instances, batch_size=batch_size)

        best_dev_acc, best_test_acc = None, None
        discriminator_batch_counter = 0

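        # Outer loop: alternate between discriminator epochs and (when enabled) adversary epochs.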
        for epoch in range(1, nb_epochs + 1):

            for d_epoch in range(1, nb_discriminator_epochs + 1):
                order = random_state.permutation(nb_instances)

                sentences1, sentences2 = sentence1[order], sentence2[order]
                sizes1 = sentence1_length[order]
                sizes2 = sentence2_length[order]
                labels = label[order]

                if is_semi_sort:
                    order = util.semi_sort(sizes1, sizes2)
                    sentences1, sentences2 = sentence1[order], sentence2[order]
                    sizes1 = sentence1_length[order]
                    sizes2 = sentence2_length[order]
                    labels = label[order]

                loss_values, epoch_loss_values = [], []
                for batch_idx, (batch_start, batch_end) in enumerate(batches):
                    discriminator_batch_counter += 1

                    batch_sentences1 = sentences1[batch_start:batch_end]
                    batch_sentences2 = sentences2[batch_start:batch_end]
                    batch_sizes1 = sizes1[batch_start:batch_end]
                    batch_sizes2 = sizes2[batch_start:batch_end]
                    batch_labels = labels[batch_start:batch_end]

                    batch_max_size1 = np.max(batch_sizes1)
                    batch_max_size2 = np.max(batch_sizes2)

                    batch_sentences1 = batch_sentences1[:, :batch_max_size1]
                    batch_sentences2 = batch_sentences2[:, :batch_max_size2]

                    batch_feed_dict = {
                        sentence1_ph: batch_sentences1,
                        sentence1_len_ph: batch_sizes1,
                        sentence2_ph: batch_sentences2,
                        sentence2_len_ph: batch_sizes2,
                        label_ph: batch_labels,
                        dropout_keep_prob_ph: dropout_keep_prob
                    }

                    _, loss_value = session.run([training_step, loss],
                                                feed_dict=batch_feed_dict)

                    logger.debug('Epoch {0}/{1}/{2}\tLoss: {3}'.format(
                        epoch, d_epoch, batch_idx, loss_value))

                    cur_batch_size = batch_sentences1.shape[0]
                    loss_values += [loss_value / cur_batch_size]
                    epoch_loss_values += [loss_value / cur_batch_size]

                    for adversary_projection_step in learning_projection_steps:
                        session.run([adversary_projection_step])

                    if discriminator_batch_counter % report_loss_interval == 0:
                        logger.info(
                            'Epoch {0}/{1}/{2}\tLoss Stats: {3}'.format(
                                epoch, d_epoch, batch_idx, stats(loss_values)))
                        loss_values = []

                    if discriminator_batch_counter % report_interval == 0:
                        accuracy_args = [
                            sentence1_ph, sentence1_len_ph, sentence2_ph,
                            sentence2_len_ph, label_ph, dropout_keep_prob_ph,
                            predictions_int, labels_int, contradiction_idx,
                            entailment_idx, neutral_idx, batch_size
                        ]
                        dev_acc, _, _, _ = accuracy(session, dev_dataset,
                                                    'Dev', *accuracy_args)
                        test_acc, _, _, _ = accuracy(session, test_dataset,
                                                     'Test', *accuracy_args)

                        logger.info(
                            'Epoch {0}/{1}/{2}\tDev Acc: {3:.2f}\tTest Acc: {4:.2f}'
                            .format(epoch, d_epoch, batch_idx, dev_acc * 100,
                                    test_acc * 100))

                        if best_dev_acc is None or dev_acc > best_dev_acc:
                            best_dev_acc, best_test_acc = dev_acc, test_acc

                            if save_path:
                                with open(
                                        '{}_index_to_token.p'.format(
                                            save_path), 'wb') as f:
                                    pickle.dump(index_to_token, f)

                                saved_path = saver.save(session, save_path)
                                logger.info('Model saved in file: {}'.format(
                                    saved_path))

                        logger.info(
                            'Epoch {0}/{1}/{2}\tBest Dev Accuracy: {3:.2f}\tBest Test Accuracy: {4:.2f}'
                            .format(epoch, d_epoch, batch_idx,
                                    best_dev_acc * 100, best_test_acc * 100))

                logger.info('Epoch {0}/{1}\tEpoch Loss Stats: {2}'.format(
                    epoch, d_epoch, stats(epoch_loss_values)))

                if hard_save_path:
                    with open('{}_index_to_token.p'.format(hard_save_path),
                              'wb') as f:
                        pickle.dump(index_to_token, f)

                    hard_saved_path = saver.save(session, hard_save_path)
                    logger.info(
                        'Model saved in file: {}'.format(hard_saved_path))

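            # Adversary phase: re-initialise the adversarial sequences and optimise them against the fixed discriminator.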
            if use_adversarial_training:
                session.run([adversary_init_op, adversary_optimizer_init_op])

                if adversarial_smart_init:
                    _token_indices = np.array(sorted(index_to_token.keys()))

                    for a_var in adversary_vars:
                        # Create a [batch size, sentence length, embedding size] NumPy tensor of sentence embeddings
                        a_word_idx = _token_indices[random_state.randint(
                            low=0,
                            high=len(_token_indices),
                            size=[
                                adversarial_batch_size,
                                adversarial_sentence_length
                            ])]
                        np_embedding_layer = session.run(embedding_layer)
                        np_adversarial_embeddings = np_embedding_layer[
                            a_word_idx]
                        assert np_adversarial_embeddings.shape == (
                            adversarial_batch_size,
                            adversarial_sentence_length, embedding_size)

                        assert a_var in adversary_var_to_assign_op
                        assign_op = adversary_var_to_assign_op[a_var]

                        logger.info(
                            'Clever initialization of the adversarial embeddings ..'
                        )
                        session.run(assign_op,
                                    feed_dict={
                                        adversary_var_value_ph:
                                        np_adversarial_embeddings
                                    })

                for a_epoch in range(1, nb_adversary_epochs + 1):
                    adversary_feed_dict = {dropout_keep_prob_ph: 1.0}
                    _, adversary_loss_value = session.run(
                        [adversary_training_step, adversary_loss],
                        feed_dict=adversary_feed_dict)
                    logger.info('Adversary Epoch {0}/{1}\tLoss: {2}'.format(
                        epoch, a_epoch, adversary_loss_value))

                    for adversary_projection_step in adversary_projection_steps:
                        session.run(adversary_projection_step)

    logger.info('Training finished.')
Example #4
def train(args):
    data_loader = TextLoader(args.data, args.batch_size, args.seq_length,
                             args.input_encoding)
    vocab_size = data_loader.vocab_size

    config = {
        'model': args.model,
        'seq_length': args.seq_length,
        'batch_size': args.batch_size,
        'vocab_size': vocab_size,
        'embedding_size': args.embedding_size,
        'rnn_size': args.rnn_size,
        'num_layers': args.num_layers
    }

    # check compatibility if training is continued from previously saved model
    if args.init_from is not None:
        # check if all necessary files exist
        assert os.path.isdir(args.init_from)
        assert os.path.isfile(os.path.join(args.init_from, "config.pkl"))
        assert os.path.isfile(os.path.join(args.init_from, "words_vocab.pkl"))

        ckpt = tf.train.get_checkpoint_state(args.init_from)

        assert ckpt, "No checkpoint found"
        assert ckpt.model_checkpoint_path, "No model path found in checkpoint"

        # open old config and check if models are compatible
        with open(os.path.join(args.init_from, 'config.pkl'), 'rb') as f:
            saved_config = pickle.load(f)

        need_be_same = ["model", "rnn_size", "num_layers", "seq_length"]
        for checkme in need_be_same:
            assert saved_config[checkme] == config[
                checkme], "Disagreement on '{}'".format(checkme)

        # open saved vocab/dict and check if vocabs/dicts are compatible
        with open(os.path.join(args.init_from, 'words_vocab.pkl'), 'rb') as f:
            saved_words, saved_vocab = pickle.load(f)

        assert saved_words == data_loader.words, "Data and loaded model disagree on word set!"
        assert saved_vocab == data_loader.vocab, "Data and loaded model disagree on dictionary mappings!"

    with open(os.path.join(args.save, 'config.pkl'), 'wb') as f:
        pickle.dump(config, f)

    with open(os.path.join(args.save, 'words_vocab.pkl'), 'wb') as f:
        pickle.dump((data_loader.words, data_loader.vocab), f)

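    # Frozen embedding matrix (trainable=False) shared with the language model.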
    embedding_layer = tf.get_variable(
        'embeddings',
        shape=[vocab_size, args.embedding_size],
        initializer=tf.contrib.layers.xavier_initializer(),
        trainable=False)

    model = LanguageModel(model=config['model'],
                          seq_length=config['seq_length'],
                          batch_size=config['batch_size'],
                          rnn_size=config['rnn_size'],
                          num_layers=config['num_layers'],
                          vocab_size=config['vocab_size'],
                          embedding_layer=embedding_layer,
                          infer=False)

    tvars = tf.trainable_variables()
    grads, _ = tf.clip_by_global_norm(tf.gradients(model.cost, tvars),
                                      args.grad_clip)

    optimizer = tf.train.AdagradOptimizer(args.learning_rate)
    train_op = optimizer.apply_gradients(zip(grads, tvars))

    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True

    init_op = tf.global_variables_initializer()
    saver = tf.train.Saver(tf.global_variables())

    with tf.Session(config=session_config) as session:
        logger.info('Total Parameters: {}'.format(
            tfutil.count_trainable_parameters()))
        session.run(init_op)

        # restore model
        if args.init_from is not None:
            saver.restore(session, ckpt.model_checkpoint_path)

        for epoch_id in range(0, args.num_epochs):
            logger.debug('Epoch: {}'.format(epoch_id))

            data_loader.reset_batch_pointer()
            state = session.run(model.initial_state)

            for batch_id in range(data_loader.pointer,
                                  data_loader.num_batches):
                logger.debug('Epoch: {}\tBatch: {}'.format(epoch_id, batch_id))
                x, y = data_loader.next_batch()

                feed_dict = {
                    model.input_data: x,
                    model.targets: y,
                    model.initial_state: state
                }
                train_loss, state, _ = session.run(
                    [model.cost, model.final_state, train_op],
                    feed_dict=feed_dict)

                if (epoch_id * data_loader.num_batches +
                        batch_id) % args.batch_size == 0:
                    logger.info("{}/{} (epoch {}), train_loss = {:.3f}".format(
                        epoch_id * data_loader.num_batches + batch_id,
                        args.num_epochs * data_loader.num_batches, epoch_id,
                        train_loss))

                if (epoch_id * data_loader.num_batches + batch_id) % args.save_every == 0 or \
                        (epoch_id == args.num_epochs - 1 and batch_id == data_loader.num_batches - 1):

                    checkpoint_path = os.path.join(args.save, 'model.ckpt')
                    saver.save(session,
                               checkpoint_path,
                               global_step=epoch_id * data_loader.num_batches +
                               batch_id)
                    logger.info("model saved to {}".format(checkpoint_path))
def sample(args):
    vocabulary_path = args.vocabulary
    checkpoint_path = args.checkpoint

    with open(vocabulary_path, 'rb') as f:
        index_to_token = pickle.load(f)

    index_to_token.update({0: '<PAD>', 1: '<BOS>', 2: '<UNK>'})

    token_to_index = {token: index for index, token in index_to_token.items()}

    with open(os.path.join(args.save, 'config.json'), 'r') as f:
        config = json.load(f)

    logger.info('Config: {}'.format(str(config)))

    vocab_size = len(token_to_index)

    discriminator_scope_name = 'discriminator'
    with tf.variable_scope(discriminator_scope_name):
        embedding_layer = tf.get_variable(
            'embeddings',
            shape=[vocab_size, config['embedding_size']],
            initializer=tf.contrib.layers.xavier_initializer(),
            trainable=False)

    lm_scope_name = 'language_model'
    with tf.variable_scope(lm_scope_name):
        model = LanguageModel(model=config['model'],
                              seq_length=config['seq_length'],
                              batch_size=config['batch_size'],
                              rnn_size=config['rnn_size'],
                              num_layers=config['num_layers'],
                              vocab_size=config['vocab_size'],
                              embedding_layer=embedding_layer,
                              infer=True)

    init_op = tf.global_variables_initializer()

    saver = tf.train.Saver(tf.global_variables())
    emb_saver = tf.train.Saver([embedding_layer], max_to_keep=1)

    logger.info('Creating the session ..')

    with tf.Session() as session:
        logger.info('Total Parameters: {}'.format(
            tfutil.count_trainable_parameters()))
        session.run(init_op)

        emb_saver.restore(session, checkpoint_path)

        ckpt = tf.train.get_checkpoint_state(args.save)

        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(session, ckpt.model_checkpoint_path)

            for _ in range(10):
                sample_value = model.sample(session, index_to_token,
                                            token_to_index, args.nb_words,
                                            args.prime, args.sample, args.pick,
                                            args.width)
                logger.info('Sample: \"{}\"'.format(sample_value))
Example #6
def main(argv):
    logger.info('Command line: {}'.format(' '.join(arg for arg in argv)))

    def fmt(prog):
        return argparse.HelpFormatter(prog, max_help_position=100, width=200)

    argparser = argparse.ArgumentParser('Regularising RTE via Adversarial Sets Regularisation', formatter_class=fmt)

    argparser.add_argument('--data', '-d', action='store', type=str, default='data/snli/snli_1.0_train.jsonl.gz')
    argparser.add_argument('--model', '-m', action='store', type=str, default='ff-dam',
                           choices=['cbilstm', 'ff-dam', 'ff-damp', 'ff-dams', 'esim1'])

    argparser.add_argument('--embedding-size', action='store', type=int, default=300)
    argparser.add_argument('--representation-size', action='store', type=int, default=200)

    argparser.add_argument('--batch-size', action='store', type=int, default=32)

    argparser.add_argument('--seed', action='store', type=int, default=0)

    argparser.add_argument('--has-bos', action='store_true', default=False, help='Has <Beginning Of Sentence> token')
    argparser.add_argument('--has-eos', action='store_true', default=False, help='Has <End Of Sentence> token')
    argparser.add_argument('--has-unk', action='store_true', default=False, help='Has <Unknown Word> token')
    argparser.add_argument('--lower', '-l', action='store_true', default=False, help='Lowercase the corpus')

    argparser.add_argument('--restore', action='store', type=str, default=None)
    argparser.add_argument('--lm', action='store', type=str, default='models/lm/')

    argparser.add_argument('--corrupt', '-c', action='store_true', default=False,
                           help='Corrupt examples so to maximise their inconsistency')
    argparser.add_argument('--most-violating', '-M', action='store_true', default=False,
                           help='Show most violating examples')

    argparser.add_argument('--epsilon', '-e', action='store', type=float, default=1e-4)
    argparser.add_argument('--lambda-weight', '-L', action='store', type=float, default=1.0)

    argparser.add_argument('--inconsistency', '-i', action='store', type=str, default='contradiction')

    args = argparser.parse_args(argv)

    # Command line arguments
    data_path = args.data

    model_name = args.model

    embedding_size = args.embedding_size
    representation_size = args.representation_size

    batch_size = args.batch_size

    seed = args.seed

    has_bos = args.has_bos
    has_eos = args.has_eos
    has_unk = args.has_unk
    is_lower = args.lower

    restore_path = args.restore
    lm_path = args.lm

    is_corrupt = args.corrupt
    is_most_violating = args.most_violating

    epsilon = args.epsilon
    lambda_w = args.lambda_weight

    inconsistency_name = args.inconsistency

    iloss = None
    if inconsistency_name == 'contradiction':
        iloss = contradiction_loss
    elif inconsistency_name == 'neutral':
        iloss = neutral_loss
    elif inconsistency_name == 'entailment':
        iloss = entailment_loss

    assert iloss is not None

    np.random.seed(seed)
    tf.set_random_seed(seed)

    logger.debug('Reading corpus ..')
    data_is, _, _ = util.SNLI.generate(train_path=data_path, valid_path=None, test_path=None, is_lower=is_lower)
    logger.info('Data size: {}'.format(len(data_is)))

    # Token enumeration starts at index=3:
    # index=0 PADDING, index=1 START_OF_SENTENCE, index=2 END_OF_SENTENCE, index=3 UNKNOWN_WORD
    bos_idx, eos_idx, unk_idx = 1, 2, 3

    global index_to_token, token_to_index
    with open('{}_index_to_token.p'.format(restore_path), 'rb') as f:
        index_to_token = pickle.load(f)

    index_to_token.update({0: '<PAD>', 1: '<BOS>', 2: '<UNK>'})

    token_to_index = {token: index for index, token in index_to_token.items()}

    with open('{}/config.json'.format(lm_path), 'r') as f:
        config = json.load(f)

    seq_length = 1
    lm_batch_size = batch_size
    rnn_size = config['rnn_size']
    num_layers = config['num_layers']

    label_to_index = {
        'entailment': entailment_idx,
        'neutral': neutral_idx,
        'contradiction': contradiction_idx,
    }

    max_len = None

    args = dict(
        has_bos=has_bos, has_eos=has_eos, has_unk=has_unk,
        bos_idx=bos_idx, eos_idx=eos_idx, unk_idx=unk_idx,
        max_len=max_len)

    dataset = util.instances_to_dataset(data_is, token_to_index, label_to_index, **args)

    sentence1, sentence1_length = dataset['sentence1'], dataset['sentence1_length']
    sentence2, sentence2_length = dataset['sentence2'], dataset['sentence2_length']
    label = dataset['label']

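    # Note: sentence1_ph, sentence2_ph, the corresponding length placeholders and dropout_keep_prob_ph are assumed to be defined at module level.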
    clipped_sentence1 = tfutil.clip_sentence(sentence1_ph, sentence1_len_ph)
    clipped_sentence2 = tfutil.clip_sentence(sentence2_ph, sentence2_len_ph)

    vocab_size = max(token_to_index.values()) + 1

    discriminator_scope_name = 'discriminator'
    with tf.variable_scope(discriminator_scope_name):
        embedding_layer = tf.get_variable('embeddings', shape=[vocab_size, embedding_size], trainable=False)
        sentence1_embedding = tf.nn.embedding_lookup(embedding_layer, clipped_sentence1)
        sentence2_embedding = tf.nn.embedding_lookup(embedding_layer, clipped_sentence2)

        model_kwargs = dict(
            sequence1=sentence1_embedding, sequence1_length=sentence1_len_ph,
            sequence2=sentence2_embedding, sequence2_length=sentence2_len_ph,
            representation_size=representation_size, dropout_keep_prob=dropout_keep_prob_ph)

        if model_name in {'ff-dam', 'ff-damp', 'ff-dams'}:
            model_kwargs['init_std_dev'] = 0.01

        mode_name_to_class = {
            'cbilstm': ConditionalBiLSTM,
            'ff-dam': FeedForwardDAM,
            'ff-damp': FeedForwardDAMP,
            'ff-dams': FeedForwardDAMS,
            'esim1': ESIMv1
        }
        model_class = mode_name_to_class[model_name]
        assert model_class is not None

        model = model_class(**model_kwargs)
        logits = model()

        global probabilities
        probabilities = tf.nn.softmax(logits)

        predictions = tf.argmax(logits, axis=1, name='predictions')

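    # Single-step RNN language model sharing the frozen embedding matrix with the discriminator.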
    lm_scope_name = 'language_model'
    with tf.variable_scope(lm_scope_name):
        cell_fn = rnn.BasicLSTMCell
        cells = [cell_fn(rnn_size) for _ in range(num_layers)]

        global lm_cell
        lm_cell = rnn.MultiRNNCell(cells)

        global lm_input_data_ph, lm_targets_ph, lm_initial_state
        lm_input_data_ph = tf.placeholder(tf.int32, [None, seq_length], name='input_data')
        lm_targets_ph = tf.placeholder(tf.int32, [None, seq_length], name='targets')
        lm_initial_state = lm_cell.zero_state(lm_batch_size, tf.float32)

        with tf.variable_scope('rnnlm'):
            lm_W = tf.get_variable(name='W', shape=[rnn_size, vocab_size],
                                   initializer=tf.contrib.layers.xavier_initializer())

            lm_b = tf.get_variable(name='b', shape=[vocab_size],
                                   initializer=tf.zeros_initializer())

            lm_emb_lookup = tf.nn.embedding_lookup(embedding_layer, lm_input_data_ph)
            lm_emb_projection = tf.contrib.layers.fully_connected(inputs=lm_emb_lookup, num_outputs=rnn_size,
                                                                  weights_initializer=tf.contrib.layers.xavier_initializer(),
                                                                  biases_initializer=tf.zeros_initializer())

            lm_inputs = tf.split(lm_emb_projection, seq_length, 1)
            lm_inputs = [tf.squeeze(input_, [1]) for input_ in lm_inputs]

        lm_outputs, lm_last_state = legacy_seq2seq.rnn_decoder(decoder_inputs=lm_inputs, initial_state=lm_initial_state,
                                                               cell=lm_cell, loop_function=None, scope='rnnlm')

        lm_output = tf.reshape(tf.concat(lm_outputs, 1), [-1, rnn_size])

        lm_logits = tf.matmul(lm_output, lm_W) + lm_b
        lm_probabilities = tf.nn.softmax(lm_logits)

        global lm_loss, lm_cost, lm_final_state
        lm_loss = legacy_seq2seq.sequence_loss_by_example(logits=[lm_logits], targets=[tf.reshape(lm_targets_ph, [-1])],
                                                          weights=[tf.ones([lm_batch_size * seq_length])])
        lm_cost = tf.reduce_sum(lm_loss) / lm_batch_size / seq_length
        lm_final_state = lm_last_state

    discriminator_vars = tfutil.get_variables_in_scope(discriminator_scope_name)
    lm_vars = tfutil.get_variables_in_scope(lm_scope_name)

    predictions_int = tf.cast(predictions, tf.int32)

    saver = tf.train.Saver(discriminator_vars, max_to_keep=1)
    lm_saver = tf.train.Saver(lm_vars, max_to_keep=1)

    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True

    global session
    with tf.Session(config=session_config) as session:
        logger.info('Total Parameters: {}'.format(tfutil.count_trainable_parameters()))

        saver.restore(session, restore_path)

        lm_ckpt = tf.train.get_checkpoint_state(lm_path)
        lm_saver.restore(session, lm_ckpt.model_checkpoint_path)

        nb_instances = sentence1.shape[0]
        batches = make_batches(size=nb_instances, batch_size=batch_size)

        order = np.arange(nb_instances)

        sentences1 = sentence1[order]
        sentences2 = sentence2[order]

        sizes1 = sentence1_length[order]
        sizes2 = sentence2_length[order]

        labels = label[order]

        logger.info('Number of examples: {}'.format(labels.shape[0]))

        predictions_int_value = []
        c_losses, e_losses, n_losses = [], [], []

        for batch_idx, (batch_start, batch_end) in enumerate(batches):
            batch_sentences1 = sentences1[batch_start:batch_end]
            batch_sentences2 = sentences2[batch_start:batch_end]

            batch_sizes1 = sizes1[batch_start:batch_end]
            batch_sizes2 = sizes2[batch_start:batch_end]

            batch_feed_dict = {
                sentence1_ph: batch_sentences1,
                sentence1_len_ph: batch_sizes1,

                sentence2_ph: batch_sentences2,
                sentence2_len_ph: batch_sizes2,

                dropout_keep_prob_ph: 1.0
            }

            batch_predictions_int = session.run(predictions_int, feed_dict=batch_feed_dict)
            predictions_int_value += batch_predictions_int.tolist()

            batch_c_loss = contradiction_loss(batch_sentences1, batch_sizes1, batch_sentences2, batch_sizes2)
            c_losses += batch_c_loss.tolist()

            batch_e_loss = entailment_loss(batch_sentences1, batch_sizes1, batch_sentences2, batch_sizes2)
            e_losses += batch_e_loss.tolist()

            batch_n_loss = neutral_loss(batch_sentences1, batch_sizes1, batch_sentences2, batch_sizes2)
            n_losses += batch_n_loss.tolist()

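            # Optionally corrupt the batch so as to maximise the selected inconsistency loss.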
            if is_corrupt:
                search(sentences1=batch_sentences1, sizes1=batch_sizes1,
                       sentences2=batch_sentences2, sizes2=batch_sizes2,
                       batch_size=batch_size, epsilon=epsilon, lambda_w=lambda_w,
                       inconsistency_loss=iloss)

        train_accuracy_value = np.mean(labels == np.array(predictions_int_value))
        logger.info('Accuracy: {0:.4f}'.format(train_accuracy_value))

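        # Rank examples by each inconsistency loss and print the most violating sentence pairs.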
        if is_most_violating:
            c_ranking = np.argsort(np.array(c_losses))[::-1]
            assert c_ranking.shape[0] == len(data_is)

            for i in range(min(1024, c_ranking.shape[0])):
                idx = c_ranking[i]
                print('[C/{}/{}] {} ({})'.format(i, idx, data_is[idx]['sentence1'], c_losses[idx]))
                print('[C/{}/{}] {} ({})'.format(i, idx, data_is[idx]['sentence2'], c_losses[idx]))

            e_ranking = np.argsort(np.array(e_losses))[::-1]
            assert e_ranking.shape[0] == len(data_is)

            for i in range(min(1024, e_ranking.shape[0])):
                idx = e_ranking[i]
                print('[E/{}/{}] {} ({})'.format(i, idx, data_is[idx]['sentence1'], e_losses[idx]))
                print('[E/{}/{}] {} ({})'.format(i, idx, data_is[idx]['sentence2'], e_losses[idx]))

            n_ranking = np.argsort(np.array(n_losses))[::-1]
            assert n_ranking.shape[0] == len(data_is)

            for i in range(min(1024, n_ranking.shape[0])):
                idx = n_ranking[i]
                print('[N/{}/{}] {} ({})'.format(i, idx, data_is[idx]['sentence1'], n_losses[idx]))
                print('[N/{}/{}] {} ({})'.format(i, idx, data_is[idx]['sentence2'], n_losses[idx]))
Example #7
def test_lm_overfit():
    parser = argparse.ArgumentParser()
    parser.add_argument('--data',
                        '-d',
                        type=str,
                        default='data/snli/snli_1.0_test.jsonl.gz')

    parser.add_argument('--vocabulary',
                        type=str,
                        default='models/snli/dam_1/dam_1_index_to_token.p')
    parser.add_argument('--checkpoint',
                        type=str,
                        default='models/snli/dam_1/dam_1')

    parser.add_argument('--save',
                        type=str,
                        default='./models/lm/',
                        help='directory to store checkpointed models')

    parser.add_argument('--embedding-size',
                        type=int,
                        default=300,
                        help='embedding size')
    parser.add_argument('--rnn-size',
                        type=int,
                        default=256,
                        help='size of RNN hidden state')
    parser.add_argument('--num-layers',
                        type=int,
                        default=1,
                        help='number of layers in the RNN')

    parser.add_argument('--model',
                        type=str,
                        default='lstm',
                        help='rnn, gru, or lstm')

    parser.add_argument('--batch-size',
                        type=int,
                        default=128,
                        help='minibatch size')
    parser.add_argument('--seq-length',
                        type=int,
                        default=8,
                        help='RNN sequence length')
    parser.add_argument('--num-epochs',
                        type=int,
                        default=100,
                        help='number of epochs')

    parser.add_argument('--report-every',
                        '-r',
                        type=int,
                        default=10,
                        help='report loss frequency')
    parser.add_argument('--save-every',
                        '-s',
                        type=int,
                        default=100,
                        help='save frequency')

    parser.add_argument('--grad-clip',
                        type=float,
                        default=5.,
                        help='clip gradients at this value')
    parser.add_argument('--learning-rate',
                        '--lr',
                        type=float,
                        default=0.001,
                        help='learning rate')

    args = parser.parse_args('')

    vocabulary_path = args.vocabulary
    checkpoint_path = args.checkpoint

    with open(vocabulary_path, 'rb') as f:
        index_to_token = pickle.load(f)

    token_to_index = {token: index for index, token in index_to_token.items()}

    logger.info('Loading the dataset ..')

    loader = SNLILoader(path=args.data,
                        token_to_index=token_to_index,
                        batch_size=args.batch_size,
                        seq_length=args.seq_length)
    vocab_size = len(token_to_index)

    config = {
        'model': args.model,
        'seq_length': args.seq_length,
        'batch_size': args.batch_size,
        'vocab_size': vocab_size,
        'embedding_size': args.embedding_size,
        'rnn_size': args.rnn_size,
        'num_layers': args.num_layers
    }

    logger.info('Generating the computational graph ..')

    discriminator_scope_name = 'discriminator'
    with tf.variable_scope(discriminator_scope_name):
        embedding_layer = tf.get_variable(
            'embeddings',
            shape=[vocab_size + 3, args.embedding_size],
            initializer=tf.contrib.layers.xavier_initializer(),
            trainable=False)

    lm_scope_name = 'language_model'
    with tf.variable_scope(lm_scope_name):
        model = LanguageModel(model=config['model'],
                              seq_length=config['seq_length'],
                              batch_size=config['batch_size'],
                              rnn_size=config['rnn_size'],
                              num_layers=config['num_layers'],
                              vocab_size=config['vocab_size'],
                              embedding_layer=embedding_layer,
                              infer=False)

    tvars = tf.trainable_variables()
    grads, _ = tf.clip_by_global_norm(tf.gradients(model.cost, tvars),
                                      args.grad_clip)

    optimizer = tf.train.AdamOptimizer(args.learning_rate)
    train_op = optimizer.apply_gradients(zip(grads, tvars))

    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True

    init_op = tf.global_variables_initializer()

    emb_saver = tf.train.Saver([embedding_layer], max_to_keep=1)

    logger.info('Creating the session ..')

    with tf.Session(config=session_config) as session:
        logger.info('Total Parameters: {}'.format(
            tfutil.count_trainable_parameters()))
        session.run(init_op)

        emb_saver.restore(session, checkpoint_path)

        loss_values = []

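        # Overfitting sanity check: the reported loss should keep decreasing across epochs.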
        for epoch_id in range(0, args.num_epochs):
            logger.debug('Epoch: {}'.format(epoch_id))

            loader.reset_batch_pointer()
            state = session.run(model.initial_state)

            for batch_id in range(loader.pointer, loader.num_batches):
                x, y = loader.next_batch()

                feed_dict = {
                    model.input_data: x,
                    model.targets: y,
                    model.initial_state: state
                }

                loss_value, state, _ = session.run(
                    [model.cost, model.final_state, train_op],
                    feed_dict=feed_dict)
                loss_values += [loss_value]

                if (epoch_id * loader.num_batches +
                        batch_id) % args.report_every == 0:
                    a = epoch_id * loader.num_batches + batch_id
                    b = args.num_epochs * loader.num_batches
                    logger.info("{}/{} (epoch {}), loss = {}".format(
                        a, b, epoch_id, stats(loss_values)))
                    loss_values = []
Example #8
def main(argv):
    logger.info('Command line: {}'.format(' '.join(arg for arg in argv)))

    def fmt(prog):
        return argparse.HelpFormatter(prog, max_help_position=100, width=200)

    argparser = argparse.ArgumentParser(
        'Regularising RTE via Adversarial Sets Regularisation',
        formatter_class=fmt)

    argparser.add_argument('--data',
                           '-d',
                           action='store',
                           type=str,
                           default='data/snli/snli_1.0_train.jsonl.gz')
    argparser.add_argument(
        '--model',
        '-m',
        action='store',
        type=str,
        default='ff-dam',
        choices=['cbilstm', 'ff-dam', 'ff-damp', 'ff-dams', 'esim1'])

    argparser.add_argument('--embedding-size',
                           action='store',
                           type=int,
                           default=300)
    argparser.add_argument('--representation-size',
                           action='store',
                           type=int,
                           default=200)

    argparser.add_argument('--batch-size',
                           '-b',
                           action='store',
                           type=int,
                           default=32)
    argparser.add_argument('--seq-length', action='store', type=int, default=5)

    argparser.add_argument('--seed', action='store', type=int, default=0)

    argparser.add_argument('--has-bos',
                           action='store_true',
                           default=False,
                           help='Has <Beginning Of Sentence> token')
    argparser.add_argument('--has-eos',
                           action='store_true',
                           default=False,
                           help='Has <End Of Sentence> token')
    argparser.add_argument('--has-unk',
                           action='store_true',
                           default=False,
                           help='Has <Unknown Word> token')
    argparser.add_argument('--lower',
                           '-l',
                           action='store_true',
                           default=False,
                           help='Lowercase the corpus')

    argparser.add_argument('--restore', action='store', type=str, default=None)
    argparser.add_argument('--lm',
                           action='store',
                           type=str,
                           default='models/lm/')

    args = argparser.parse_args(argv)

    # Command line arguments
    data_path = args.data

    model_name = args.model

    embedding_size = args.embedding_size
    representation_size = args.representation_size

    batch_size = args.batch_size

    seed = args.seed

    has_bos = args.has_bos
    has_eos = args.has_eos
    has_unk = args.has_unk
    is_lower = args.lower

    restore_path = args.restore
    lm_path = args.lm

    np.random.seed(seed)
    tf.set_random_seed(seed)

    logger.debug('Reading corpus ..')
    data_is, _, _ = util.SNLI.generate(train_path=data_path,
                                       valid_path=None,
                                       test_path=None,
                                       is_lower=is_lower)
    logger.info('Data size: {}'.format(len(data_is)))

    # Token enumeration starts at index=3:
    # index=0 PADDING, index=1 START_OF_SENTENCE, index=2 END_OF_SENTENCE, index=3 UNKNOWN_WORD
    bos_idx, eos_idx, unk_idx = 1, 2, 3

    global index_to_token, token_to_index
    with open('{}_index_to_token.p'.format(restore_path), 'rb') as f:
        index_to_token = pickle.load(f)

    index_to_token.update({0: '<PAD>', 1: '<BOS>', 2: '<UNK>'})

    token_to_index = {token: index for index, token in index_to_token.items()}

    with open('{}/config.json'.format(lm_path), 'r') as f:
        config = json.load(f)

    seq_length = 1
    lm_batch_size = batch_size
    rnn_size = config['rnn_size']
    num_layers = config['num_layers']

    label_to_index = {
        'entailment': entailment_idx,
        'neutral': neutral_idx,
        'contradiction': contradiction_idx
    }
    max_len = None

    args = dict(has_bos=has_bos,
                has_eos=has_eos,
                has_unk=has_unk,
                bos_idx=bos_idx,
                eos_idx=eos_idx,
                unk_idx=unk_idx,
                max_len=max_len)

    dataset = util.instances_to_dataset(data_is, token_to_index,
                                        label_to_index, **args)

    sentence1 = dataset['sentence1']
    sentence1_length = dataset['sentence1_length']

    sentence2 = dataset['sentence2']
    sentence2_length = dataset['sentence2_length']

    label = dataset['label']

    clipped_sentence1 = tfutil.clip_sentence(sentence1_ph, sentence1_len_ph)
    clipped_sentence2 = tfutil.clip_sentence(sentence2_ph, sentence2_len_ph)

    vocab_size = max(token_to_index.values()) + 1

    discriminator_scope_name = 'discriminator'
    with tf.variable_scope(discriminator_scope_name):
        embedding_layer = tf.get_variable('embeddings',
                                          shape=[vocab_size, embedding_size],
                                          trainable=False)
        sentence1_embedding = tf.nn.embedding_lookup(embedding_layer,
                                                     clipped_sentence1)
        sentence2_embedding = tf.nn.embedding_lookup(embedding_layer,
                                                     clipped_sentence2)

        model_kwargs = dict(sequence1=sentence1_embedding,
                            sequence1_length=sentence1_len_ph,
                            sequence2=sentence2_embedding,
                            sequence2_length=sentence2_len_ph,
                            representation_size=representation_size,
                            dropout_keep_prob=dropout_keep_prob_ph)

        if model_name in {'ff-dam', 'ff-damp', 'ff-dams'}:
            model_kwargs['init_std_dev'] = 0.01

        mode_name_to_class = {
            'cbilstm': ConditionalBiLSTM,
            'ff-dam': FeedForwardDAM,
            'ff-damp': FeedForwardDAMP,
            'ff-dams': FeedForwardDAMS,
            'esim1': ESIMv1
        }
        model_class = mode_name_to_class[model_name]
        assert model_class is not None

        model = model_class(**model_kwargs)
        logits = model()

        global probabilities
        probabilities = tf.nn.softmax(logits)

        predictions = tf.argmax(logits, axis=1, name='predictions')

    lm_scope_name = 'language_model'
    with tf.variable_scope(lm_scope_name):
        cell_fn = rnn.BasicLSTMCell
        cells = [cell_fn(rnn_size) for _ in range(num_layers)]

        global lm_cell
        lm_cell = rnn.MultiRNNCell(cells)

        global lm_input_data_ph, lm_targets_ph, lm_initial_state
        lm_input_data_ph = tf.placeholder(tf.int32, [None, seq_length],
                                          name='input_data')
        lm_targets_ph = tf.placeholder(tf.int32, [None, seq_length],
                                       name='targets')

        lm_initial_state = lm_cell.zero_state(lm_batch_size, tf.float32)

        with tf.variable_scope('rnnlm'):
            lm_W = tf.get_variable(
                name='W',
                shape=[rnn_size, vocab_size],
                initializer=tf.contrib.layers.xavier_initializer())

            lm_b = tf.get_variable(name='b',
                                   shape=[vocab_size],
                                   initializer=tf.zeros_initializer())

            lm_emb_lookup = tf.nn.embedding_lookup(embedding_layer,
                                                   lm_input_data_ph)

            lm_emb_projection = tf.contrib.layers.fully_connected(
                inputs=lm_emb_lookup,
                num_outputs=rnn_size,
                weights_initializer=tf.contrib.layers.xavier_initializer(),
                biases_initializer=tf.zeros_initializer())

            lm_inputs = tf.split(lm_emb_projection, seq_length, 1)
            lm_inputs = [tf.squeeze(input_, [1]) for input_ in lm_inputs]

        lm_outputs, lm_last_state = legacy_seq2seq.rnn_decoder(
            decoder_inputs=lm_inputs,
            initial_state=lm_initial_state,
            cell=lm_cell,
            loop_function=None,
            scope='rnnlm')

        lm_output = tf.reshape(tf.concat(lm_outputs, 1), [-1, rnn_size])

        lm_logits = tf.matmul(lm_output, lm_W) + lm_b
        lm_probabilities = tf.nn.softmax(lm_logits)

        global lm_loss, lm_cost, lm_final_state
        lm_loss = legacy_seq2seq.sequence_loss_by_example(
            logits=[lm_logits],
            targets=[tf.reshape(lm_targets_ph, [-1])],
            weights=[tf.ones([lm_batch_size * seq_length])])
        lm_cost = tf.reduce_sum(lm_loss) / lm_batch_size / seq_length
        lm_final_state = lm_last_state

    discriminator_vars = tfutil.get_variables_in_scope(
        discriminator_scope_name)
    lm_vars = tfutil.get_variables_in_scope(lm_scope_name)

    predictions_int = tf.cast(predictions, tf.int32)

    saver = tf.train.Saver(discriminator_vars, max_to_keep=1)
    lm_saver = tf.train.Saver(lm_vars, max_to_keep=1)

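    # Build the adversary inside the discriminator scope and take the rule 6 loss together with its trainable sequences.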
    adversary_scope_name = discriminator_scope_name
    with tf.variable_scope(adversary_scope_name):
        adversary = AdversarialSets(model_class=model_class,
                                    model_kwargs=model_kwargs,
                                    embedding_size=embedding_size,
                                    scope_name='adversary',
                                    batch_size=1,
                                    sequence_length=10,
                                    entailment_idx=entailment_idx,
                                    contradiction_idx=contradiction_idx,
                                    neutral_idx=neutral_idx)

        a_loss, a_sequence_lst = adversary.rule6_loss()

    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True

    text = ['The', 'girl', 'runs', 'on', 'the', 'plane', '.']
    sentence_ids = [token_to_index[token] for token in text]

    global session
    with tf.Session(config=session_config) as session:
        logger.info('Total Parameters: {}'.format(
            tfutil.count_trainable_parameters()))

        saver.restore(session, restore_path)

        lm_ckpt = tf.train.get_checkpoint_state(lm_path)
        lm_saver.restore(session, lm_ckpt.model_checkpoint_path)

        embedding_layer_value = session.run(embedding_layer)
        assert embedding_layer_value.shape == (vocab_size, embedding_size)

        sentences = np.array([sentence_ids])
        sizes = np.array([len(sentence_ids)])
        assert log_perplexity(sentences, sizes) >= 0.0

        feed = {sentence1_ph: sentences, sentence1_len_ph: sizes}
        sentence_embedding = session.run(sentence1_embedding, feed_dict=feed)
        assert sentence_embedding.shape == (1, len(sentence_ids),
                                            embedding_size)

        print(a_sequence_lst)
Example #9
def train(args):
    vocabulary_path = args.vocabulary
    checkpoint_path = args.checkpoint

    with open(vocabulary_path, 'rb') as f:
        index_to_token = pickle.load(f)

    index_to_token.update({0: '<PAD>', 1: '<BOS>', 2: '<UNK>'})

    token_to_index = {token: index for index, token in index_to_token.items()}

    logger.info('Loading the dataset ..')

    loader = SNLILoader(path=args.train,
                        token_to_index=token_to_index,
                        batch_size=args.batch_size,
                        seq_length=args.seq_length,
                        shuffle=True)

    valid_loader = SNLILoader(path=args.valid,
                              token_to_index=token_to_index,
                              batch_size=args.batch_size,
                              seq_length=args.seq_length,
                              shuffle=False)

    vocab_size = len(token_to_index)

    config = {
        'model': args.model,
        'seq_length': args.seq_length,
        'batch_size': args.batch_size,
        'vocab_size': vocab_size,
        'embedding_size': args.embedding_size,
        'rnn_size': args.rnn_size,
        'num_layers': args.num_layers
    }

    config_path = os.path.join(args.save, 'config.json')
    with open(config_path, 'w') as f:
        json.dump(config, f)

    logger.info('Generating the computational graph ..')

    print(max(index_to_token.keys()), vocab_size)
    assert max(index_to_token.keys()) + 1 == vocab_size

    discriminator_scope_name = 'discriminator'
    with tf.variable_scope(discriminator_scope_name):
        embedding_layer = tf.get_variable(
            'embeddings',
            shape=[vocab_size, args.embedding_size],
            initializer=tf.contrib.layers.xavier_initializer(),
            trainable=False)

    lm_scope_name = 'language_model'
    with tf.variable_scope(lm_scope_name) as scope:
        model = LanguageModel(model=config['model'],
                              seq_length=config['seq_length'],
                              batch_size=config['batch_size'],
                              rnn_size=config['rnn_size'],
                              num_layers=config['num_layers'],
                              vocab_size=config['vocab_size'],
                              embedding_layer=embedding_layer,
                              infer=False)

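        # Inference-mode copy sharing the same weights, used for sampling during training.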
        scope.reuse_variables()
        imodel = LanguageModel(model=config['model'],
                               seq_length=config['seq_length'],
                               batch_size=config['batch_size'],
                               rnn_size=config['rnn_size'],
                               num_layers=config['num_layers'],
                               vocab_size=config['vocab_size'],
                               embedding_layer=embedding_layer,
                               infer=True)

    optimizer = tf.train.AdagradOptimizer(args.learning_rate)

    tvars = tf.trainable_variables()
    train_op = optimizer.minimize(model.cost, var_list=tvars)

    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True

    init_op = tf.global_variables_initializer()

    saver = tf.train.Saver(tf.global_variables())
    emb_saver = tf.train.Saver([embedding_layer], max_to_keep=1)

    logger.info('Creating the session ..')

    with tf.Session(config=session_config) as session:
        logger.info('Trainable Parameters: {}'.format(
            tfutil.count_trainable_parameters(var_list=tvars)))
        session.run(init_op)

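        # Only the embedding matrix is restored from the discriminator
        # checkpoint; every other variable keeps its fresh initialisation.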
        emb_saver.restore(session, checkpoint_path)

        loss_values = []
        best_valid_log_perplexity = None

        for epoch_id in range(0, args.num_epochs):
            logger.debug('Epoch: {}'.format(epoch_id))

            loader.reset_batch_pointer()
            state = session.run(model.initial_state)

            for batch_id in range(loader.pointer, loader.num_batches):
                x, y = loader.next_batch()

                feed_dict = {
                    model.input_data: x,
                    model.targets: y,
                    model.initial_state: state
                }

                loss_value, state, _ = session.run(
                    [model.cost, model.final_state, train_op],
                    feed_dict=feed_dict)
                loss_values += [loss_value]

                if (epoch_id * loader.num_batches +
                        batch_id) % args.report_every == 0:
                    a = epoch_id * loader.num_batches + batch_id
                    b = args.num_epochs * loader.num_batches
                    logger.info("{}/{} (epoch {}), loss = {}".format(
                        a, b, epoch_id, stats(loss_values)))
                    loss_values = []

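                    # Draw a short sample from the inference-mode model; the
                    # positional arguments are presumably sampling
                    # hyper-parameters (e.g. sample length and prime token).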
                    sample_value = imodel.sample(session, index_to_token,
                                                 token_to_index, 10, 'A', 0, 1,
                                                 4)
                    logger.info('Sample: {}'.format(sample_value))

                if (epoch_id * loader.num_batches +
                        batch_id) % args.save_every == 0:
                    valid_loader.reset_batch_pointer()
                    # Keep a separate recurrent state for validation so the
                    # training state is not clobbered.
                    valid_state = session.run(model.initial_state)

                    valid_log_perplexity = 0.0
                    valid_log_perplexities = []

                    for valid_batch_id in range(valid_loader.pointer,
                                                valid_loader.num_batches):
                        x, y = valid_loader.next_batch()

                        feed_dict = {
                            model.input_data: x,
                            model.targets: y,
                            model.initial_state: valid_state
                        }

                        batch_valid_log_perplexity, valid_state = session.run(
                            [model.cost, model.final_state],
                            feed_dict=feed_dict)
                        valid_log_perplexity += batch_valid_log_perplexity
                        valid_log_perplexities += [batch_valid_log_perplexity]

                    if best_valid_log_perplexity is None or valid_log_perplexity < best_valid_log_perplexity:
                        lm_checkpoint_path = os.path.join(args.save, 'lm.ckpt')
                        saver.save(session,
                                   lm_checkpoint_path,
                                   global_step=epoch_id * loader.num_batches +
                                   batch_id)
                        logger.info("Language model saved to {}".format(
                            lm_checkpoint_path))

                        logger.info(
                            'Validation Log-Perplexity: {0:.4f}'.format(
                                valid_log_perplexity))
                        logger.info('Validation Log-Perplexities: {0}'.format(
                            stats(valid_log_perplexities)))

                        best_valid_log_perplexity = valid_log_perplexity
                        config['valid_log_perplexity'] = best_valid_log_perplexity
                        with open(config_path, 'w') as f:
                            json.dump(config, f)
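
A minimal sketch, not part of the original example, of how train might be invoked programmatically. Every path and hyper-parameter value below is an assumption, chosen only to illustrate which fields the function reads from args.

from argparse import Namespace

# Hypothetical settings; adjust paths and sizes to your own setup.
args = Namespace(
    vocabulary='models/snli/dam_1/dam_1_index_to_token.p',  # assumed
    checkpoint='models/snli/dam_1/dam_1',                   # assumed
    train='data/snli/snli_1.0_train.jsonl.gz',              # assumed
    valid='data/snli/snli_1.0_dev.jsonl.gz',                # assumed
    save='models/lm/',                                      # assumed
    model='lstm',          # assumed value understood by LanguageModel
    seq_length=16,
    batch_size=32,
    embedding_size=300,
    rnn_size=512,
    num_layers=2,
    learning_rate=0.1,
    num_epochs=10,
    report_every=100,
    save_every=1000)

train(args)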
Example #10
def main(argv):
    logger.info('Command line: {}'.format(' '.join(argv)))

    def fmt(prog):
        return argparse.HelpFormatter(prog, max_help_position=100, width=200)

    argparser = argparse.ArgumentParser(
        'Regularising RTE via Adversarial Sets Regularisation',
        formatter_class=fmt)

    argparser.add_argument('--data',
                           '-d',
                           action='store',
                           type=str,
                           default='data/snli/snli_1.0_train.jsonl.gz')
    argparser.add_argument(
        '--model',
        '-m',
        action='store',
        type=str,
        default='ff-dam',
        choices=['cbilstm', 'ff-dam', 'ff-damp', 'ff-dams', 'esim1'])

    argparser.add_argument('--embedding-size',
                           action='store',
                           type=int,
                           default=300)
    argparser.add_argument('--representation-size',
                           action='store',
                           type=int,
                           default=200)

    argparser.add_argument('--batch-size',
                           action='store',
                           type=int,
                           default=32)

    argparser.add_argument('--seed', action='store', type=int, default=0)

    argparser.add_argument('--has-bos',
                           action='store_true',
                           default=False,
                           help='Has <Beginning Of Sentence> token')
    argparser.add_argument('--has-eos',
                           action='store_true',
                           default=False,
                           help='Has <End Of Sentence> token')
    argparser.add_argument('--has-unk',
                           action='store_true',
                           default=False,
                           help='Has <Unknown Word> token')
    argparser.add_argument('--lower',
                           '-l',
                           action='store_true',
                           default=False,
                           help='Lowercase the corpus')

    argparser.add_argument('--restore', action='store', type=str, default=None)

    argparser.add_argument('--check-transitivity',
                           '-x',
                           action='store_true',
                           default=False)

    args = argparser.parse_args(argv)

    # Command line arguments
    data_path = args.data

    model_name = args.model

    embedding_size = args.embedding_size
    representation_size = args.representation_size

    batch_size = args.batch_size

    seed = args.seed

    has_bos = args.has_bos
    has_eos = args.has_eos
    has_unk = args.has_unk
    is_lower = args.lower

    restore_path = args.restore

    is_check_transitivity = args.check_transitivity

    np.random.seed(seed)
    rs = np.random.RandomState(seed)
    tf.set_random_seed(seed)

    logger.debug('Reading corpus ..')
    data_is, _, _ = util.SNLI.generate(train_path=data_path,
                                       valid_path=None,
                                       test_path=None,
                                       is_lower=is_lower)

    logger.info('Data size: {}'.format(len(data_is)))

    # Special token indices (regular tokens are enumerated after these):
    # index=0 PADDING, index=1 START_OF_SENTENCE, index=2 END_OF_SENTENCE, index=3 UNKNOWN_WORD
    bos_idx, eos_idx, unk_idx = 1, 2, 3

    with open('{}_index_to_token.p'.format(restore_path), 'rb') as f:
        index_to_token = pickle.load(f)

    token_to_index = {token: index for index, token in index_to_token.items()}

    entailment_idx, neutral_idx, contradiction_idx = 0, 1, 2
    label_to_index = {
        'entailment': entailment_idx,
        'neutral': neutral_idx,
        'contradiction': contradiction_idx,
    }

    max_len = None

    dataset_kwargs = dict(has_bos=has_bos,
                          has_eos=has_eos,
                          has_unk=has_unk,
                          bos_idx=bos_idx,
                          eos_idx=eos_idx,
                          unk_idx=unk_idx,
                          max_len=max_len)

    dataset = util.instances_to_dataset(data_is, token_to_index,
                                        label_to_index, **dataset_kwargs)

    sentence1 = dataset['sentence1']
    sentence1_length = dataset['sentence1_length']
    sentence2 = dataset['sentence2']
    sentence2_length = dataset['sentence2_length']
    label = dataset['label']

    sentence1_ph = tf.placeholder(dtype=tf.int32,
                                  shape=[None, None],
                                  name='sentence1')
    sentence2_ph = tf.placeholder(dtype=tf.int32,
                                  shape=[None, None],
                                  name='sentence2')

    sentence1_len_ph = tf.placeholder(dtype=tf.int32,
                                      shape=[None],
                                      name='sentence1_length')
    sentence2_len_ph = tf.placeholder(dtype=tf.int32,
                                      shape=[None],
                                      name='sentence2_length')

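    # clip_sentence presumably trims the padded sentence matrices down to the
    # longest length in each batch before the embedding lookup.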
    clipped_sentence1 = tfutil.clip_sentence(sentence1_ph, sentence1_len_ph)
    clipped_sentence2 = tfutil.clip_sentence(sentence2_ph, sentence2_len_ph)

    token_set = set(token_to_index.keys())
    vocab_size = max(token_to_index.values()) + 1

    discriminator_scope_name = 'discriminator'
    with tf.variable_scope(discriminator_scope_name):
        embedding_layer = tf.get_variable('embeddings',
                                          shape=[vocab_size, embedding_size],
                                          trainable=False)

        sentence1_embedding = tf.nn.embedding_lookup(embedding_layer,
                                                     clipped_sentence1)
        sentence2_embedding = tf.nn.embedding_lookup(embedding_layer,
                                                     clipped_sentence2)

        dropout_keep_prob_ph = tf.placeholder(tf.float32,
                                              name='dropout_keep_prob')

        model_kwargs = dict(sequence1=sentence1_embedding,
                            sequence1_length=sentence1_len_ph,
                            sequence2=sentence2_embedding,
                            sequence2_length=sentence2_len_ph,
                            representation_size=representation_size,
                            dropout_keep_prob=dropout_keep_prob_ph)

        if model_name in {'ff-dam', 'ff-damp', 'ff-dams'}:
            model_kwargs['init_std_dev'] = 0.01

        model_name_to_class = {
            'cbilstm': ConditionalBiLSTM,
            'ff-dam': FeedForwardDAM,
            'ff-damp': FeedForwardDAMP,
            'ff-dams': FeedForwardDAMS,
            'esim1': ESIMv1
        }

        model_class = model_name_to_class[model_name]

        assert model_class is not None
        model = model_class(**model_kwargs)

        logits = model()
        probabilities = tf.nn.softmax(logits)

        predictions = tf.argmax(logits, axis=1, name='predictions')

    discriminator_vars = tfutil.get_variables_in_scope(
        discriminator_scope_name)

    trainable_discriminator_vars = list(discriminator_vars)

    predictions_int = tf.cast(predictions, tf.int32)

    saver = tf.train.Saver(discriminator_vars, max_to_keep=1)

    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True

    with tf.Session(config=session_config) as session:
        logger.info('Total Parameters: {}'.format(
            tfutil.count_trainable_parameters()))

        logger.info('Total Discriminator Parameters: {}'.format(
            tfutil.count_trainable_parameters(var_list=discriminator_vars)))

        logger.info('Total Trainable Discriminator Parameters: {}'.format(
            tfutil.count_trainable_parameters(
                var_list=trainable_discriminator_vars)))

        saver.restore(session, restore_path)

        nb_instances = sentence1.shape[0]
        batches = make_batches(size=nb_instances, batch_size=batch_size)

        order = np.arange(nb_instances)

        sentences1 = sentence1[order]
        sentences2 = sentence2[order]

        sizes1 = sentence1_length[order]
        sizes2 = sentence2_length[order]

        labels = label[order]

        a_predictions_int_value = []
        b_predictions_int_value = []

        a_probabilities_value = []
        b_probabilities_value = []

        for batch_idx, (batch_start,
                        batch_end) in tqdm(list(enumerate(batches))):
            batch_sentences1 = sentences1[batch_start:batch_end]
            batch_sentences2 = sentences2[batch_start:batch_end]
            batch_sizes1 = sizes1[batch_start:batch_end]
            batch_sizes2 = sizes2[batch_start:batch_end]

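            # Forward direction: S1 as premise, S2 as hypothesis.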
            batch_a_feed_dict = {
                sentence1_ph: batch_sentences1,
                sentence1_len_ph: batch_sizes1,
                sentence2_ph: batch_sentences2,
                sentence2_len_ph: batch_sizes2,
                dropout_keep_prob_ph: 1.0
            }

            batch_a_predictions_int_value, batch_a_probabilities_value = session.run(
                [predictions_int, probabilities], feed_dict=batch_a_feed_dict)

            a_predictions_int_value += batch_a_predictions_int_value.tolist()
            for i in range(batch_a_probabilities_value.shape[0]):
                a_probabilities_value += [{
                    'neutral':
                    batch_a_probabilities_value[i, neutral_idx],
                    'contradiction':
                    batch_a_probabilities_value[i, contradiction_idx],
                    'entailment':
                    batch_a_probabilities_value[i, entailment_idx]
                }]

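            # Reverse direction: swap S1 and S2 to probe whether contradiction
            # is predicted symmetrically.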
            batch_b_feed_dict = {
                sentence1_ph: batch_sentences2,
                sentence1_len_ph: batch_sizes2,
                sentence2_ph: batch_sentences1,
                sentence2_len_ph: batch_sizes1,
                dropout_keep_prob_ph: 1.0
            }

            batch_b_predictions_int_value, batch_b_probabilities_value = session.run(
                [predictions_int, probabilities], feed_dict=batch_b_feed_dict)
            b_predictions_int_value += batch_b_predictions_int_value.tolist()
            for i in range(batch_b_probabilities_value.shape[0]):
                b_probabilities_value += [{
                    'neutral':
                    batch_b_probabilities_value[i, neutral_idx],
                    'contradiction':
                    batch_b_probabilities_value[i, contradiction_idx],
                    'entailment':
                    batch_b_probabilities_value[i, entailment_idx]
                }]

        for i, instance in enumerate(data_is):
            instance.update({
                'a': a_probabilities_value[i],
                'b': b_probabilities_value[i],
            })

        logger.info('Number of examples: {}'.format(labels.shape[0]))

        train_accuracy_value = np.mean(
            labels == np.array(a_predictions_int_value))
        logger.info('Accuracy: {0:.4f}'.format(train_accuracy_value))

        s1s2_con = (np.array(a_predictions_int_value) == contradiction_idx)
        s2s1_con = (np.array(b_predictions_int_value) == contradiction_idx)

        assert s1s2_con.shape == s2s1_con.shape

        s1s2_ent = (np.array(a_predictions_int_value) == entailment_idx)
        s2s1_ent = (np.array(b_predictions_int_value) == entailment_idx)

        s1s2_neu = (np.array(a_predictions_int_value) == neutral_idx)
        s2s1_neu = (np.array(b_predictions_int_value) == neutral_idx)

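        # A contradiction prediction that holds in only one of the two
        # directions violates the symmetry of contradiction.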
        a = np.logical_xor(s1s2_con, s2s1_con)
        logger.info('(S1 contradicts S2) XOR (S2 contradicts S1): {0}'.format(
            a.sum()))

        b = s1s2_con
        logger.info('(S1 contradicts S2): {0}'.format(b.sum()))
        c = np.logical_and(s1s2_con, np.logical_not(s2s1_con))
        logger.info(
            '(S1 contradicts S2) AND NOT(S2 contradicts S1): {0} ({1:.4f})'.
            format(c.sum(),
                   c.sum() / b.sum()))

        with open('c.p', 'wb') as f:
            tmp = [data_is[i] for i in np.where(c)[0].tolist()]
            pickle.dump(tmp, f)

        d = s1s2_ent
        logger.info('(S1 entails S2): {0}'.format(d.sum()))
        e = np.logical_and(s1s2_ent, s2s1_con)
        logger.info(
            '(S1 entails S2) AND (S2 contradicts S1): {0} ({1:.4f})'.format(
                e.sum(),
                e.sum() / d.sum()))

        with open('e.p', 'wb') as f:
            tmp = [data_is[i] for i in np.where(e)[0].tolist()]
            pickle.dump(tmp, f)

        f = s1s2_neu
        logger.info('(S1 neutral S2): {0}'.format(f.sum()))
        g = np.logical_and(s1s2_neu, s2s1_con)
        logger.info(
            '(S1 neutral S2) AND (S2 contradicts S1): {0} ({1:.4f})'.format(
                g.sum(),
                g.sum() / f.sum()))

        with open('g.p', 'wb') as f:
            tmp = [data_is[i] for i in np.where(g)[0].tolist()]
            pickle.dump(tmp, f)

        if is_check_transitivity:
            # Find pairs (S1, S2) such that S1 entails S2
            logger.debug('s1s2_ent type: {}'.format(type(s1s2_ent)))

            c_predictions_int_value = []
            c_probabilities_value = []

            d_predictions_int_value = []
            d_probabilities_value = []

            # Find candidate S3 sentences
            order = np.arange(nb_instances)

            sentences3 = sentence2[order]
            sizes3 = sentence2_length[order]
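            # With the identity order used here, sentences3 coincides with
            # sentences2: the candidate S3 for each pair is simply its S2.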

            for batch_idx, (batch_start,
                            batch_end) in tqdm(list(enumerate(batches))):
                batch_sentences2 = sentences2[batch_start:batch_end]
                batch_sentences3 = sentences3[batch_start:batch_end]

                batch_sizes2 = sizes2[batch_start:batch_end]
                batch_sizes3 = sizes3[batch_start:batch_end]

                batch_c_feed_dict = {
                    sentence1_ph: batch_sentences2,
                    sentence1_len_ph: batch_sizes2,
                    sentence2_ph: batch_sentences3,
                    sentence2_len_ph: batch_sizes3,
                    dropout_keep_prob_ph: 1.0
                }

                batch_c_predictions_int_value, batch_c_probabilities_value = session.run(
                    [predictions_int, probabilities],
                    feed_dict=batch_c_feed_dict)

                c_predictions_int_value += batch_c_predictions_int_value.tolist()
                for i in range(batch_c_probabilities_value.shape[0]):
                    c_probabilities_value += [{
                        'neutral':
                        batch_c_probabilities_value[i, neutral_idx],
                        'contradiction':
                        batch_c_probabilities_value[i, contradiction_idx],
                        'entailment':
                        batch_c_probabilities_value[i, entailment_idx]
                    }]

                batch_sentences1 = sentences1[batch_start:batch_end]
                batch_sentences3 = sentences3[batch_start:batch_end]

                batch_sizes1 = sizes1[batch_start:batch_end]
                batch_sizes3 = sizes3[batch_start:batch_end]

                batch_d_feed_dict = {
                    sentence1_ph: batch_sentences1,
                    sentence1_len_ph: batch_sizes1,
                    sentence2_ph: batch_sentences3,
                    sentence2_len_ph: batch_sizes3,
                    dropout_keep_prob_ph: 1.0
                }

                batch_d_predictions_int_value, batch_d_probabilities_value = session.run(
                    [predictions_int, probabilities],
                    feed_dict=batch_d_feed_dict)

                d_predictions_int_value += batch_d_predictions_int_value.tolist()
                for i in range(batch_d_probabilities_value.shape[0]):
                    d_probabilities_value += [{
                        'neutral':
                        batch_d_probabilities_value[i, neutral_idx],
                        'contradiction':
                        batch_d_probabilities_value[i, contradiction_idx],
                        'entailment':
                        batch_d_probabilities_value[i, entailment_idx]
                    }]

            s2s3_ent = (np.array(c_predictions_int_value) == entailment_idx)
            s1s3_ent = (np.array(d_predictions_int_value) == entailment_idx)

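            # Transitivity check: (S1 entails S2) and (S2 entails S3) should
            # imply (S1 entails S3); body_not_head counts the violations.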
            body = np.logical_and(s1s2_ent, s2s3_ent)
            body_not_head = np.logical_and(body, np.logical_not(s1s3_ent))

            logger.info('(S1 entails S2) and (S2 entails S3): {0}'.format(
                body.sum()))
            logger.info('body AND NOT(head): {0} ({1:.4f})'.format(
                body_not_head.sum(),
                body_not_head.sum() / body.sum()))

            with open('h.p', 'wb') as f:
                tmp = []
                for idx in np.where(body_not_head)[0].tolist():
                    s1 = data_is[idx]['sentence1']
                    s2 = data_is[idx]['sentence2']
                    s3 = data_is[order[idx]]['sentence2']
                    tmp += [{'s1': s1, 's2': s2, 's3': s3}]
                pickle.dump(tmp, f)
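
A minimal sketch, again not part of the original example, of how this script might be launched. The entry point and the checkpoint prefix are assumptions; the prefix must have a matching <prefix>_index_to_token.p file next to it, as required by the code above.

# Hypothetical launch; all paths below are assumptions.
if __name__ == '__main__':
    main(['--data', 'data/snli/snli_1.0_train.jsonl.gz',
          '--model', 'ff-dam',
          '--restore', 'models/snli/dam_1/dam_1',
          '--check-transitivity'])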