def accuracy(session, dataset, name,
             sentence1_ph, sentence1_length_ph, sentence2_ph, sentence2_length_ph,
             label_ph, dropout_keep_prob_ph, predictions_int, labels_int,
             contradiction_idx, entailment_idx, neutral_idx, batch_size):
    nb_eval_instances = len(dataset['sentence1'])
    eval_batches = make_batches(size=nb_eval_instances, batch_size=batch_size)

    p_vals, l_vals = [], []
    for e_batch_start, e_batch_end in eval_batches:
        feed_dict = {
            sentence1_ph: dataset['sentence1'][e_batch_start:e_batch_end],
            sentence1_length_ph: dataset['sentence1_length'][e_batch_start:e_batch_end],
            sentence2_ph: dataset['sentence2'][e_batch_start:e_batch_end],
            sentence2_length_ph: dataset['sentence2_length'][e_batch_start:e_batch_end],
            label_ph: dataset['label'][e_batch_start:e_batch_end],
            dropout_keep_prob_ph: 1.0
        }
        p_val, l_val = session.run([predictions_int, labels_int], feed_dict=feed_dict)
        p_vals += p_val.tolist()
        l_vals += l_val.tolist()

    matches = np.equal(p_vals, l_vals)
    acc = np.mean(matches)
    acc_c = np.mean(matches[np.where(np.array(l_vals) == contradiction_idx)])
    acc_e = np.mean(matches[np.where(np.array(l_vals) == entailment_idx)])
    acc_n = np.mean(matches[np.where(np.array(l_vals) == neutral_idx)])

    if name:
        logger.debug('{0} Accuracy: {1:.4f} - C: {2:.4f}, E: {3:.4f}, N: {4:.4f}'.format(
            name, acc * 100, acc_c * 100, acc_e * 100, acc_n * 100))
    return acc, acc_c, acc_e, acc_n
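# `make_batches` is imported from the project's utilities and is not part of this listing.
# Everywhere in this file it is used as if it returned a list of (batch_start, batch_end)
# index pairs covering `size` items in chunks of at most `batch_size`. A minimal sketch
# under that assumption (hypothetical helper, shown only for reference):
def make_batches_sketch(size, batch_size):
    nb_batches = (size + batch_size - 1) // batch_size
    return [(i * batch_size, min(size, (i + 1) * batch_size)) for i in range(nb_batches)]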
def create_batches(self):
    order = self.random_state.permutation(self.nb_samples)
    tensor_shuf = self.tensor[order, :]

    _batch_lst = make_batches(self.nb_samples, self.batch_size)

    self.batches = []
    for batch_start, batch_end in _batch_lst:
        batch_size = batch_end - batch_start
        batch = tensor_shuf[batch_start:batch_end, :]
        assert batch.shape[0] == batch_size

        x = np.zeros(shape=(batch_size, self.seq_length))
        y = np.zeros(shape=(batch_size, self.seq_length))

        for i in range(batch_size):
            start_idx = self.random_state.randint(low=0, high=self.max_len - 1)
            end_idx = min(start_idx + self.seq_length, self.max_len)
            x[i, 0:(end_idx - start_idx)] = batch[i, start_idx:end_idx]

            start_idx += 1
            end_idx = min(start_idx + self.seq_length, self.max_len)
            y[i, 0:(end_idx - start_idx)] = batch[i, start_idx:end_idx]

        d = {'x': x, 'y': y}
        self.batches += [d]

    self.num_batches = len(self.batches)
    return
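# Minimal sketch of how the batches built by `create_batches` are consumed when training a
# next-token language model: `y` is the input window `x` shifted one position to the right,
# so every position is trained to predict the following token. The names `session`, `x_ph`,
# `y_ph` and `train_step` below are hypothetical, not part of this code base.
def run_lm_epoch(session, loader, x_ph, y_ph, train_step):
    loader.create_batches()
    for batch in loader.batches:
        session.run(train_step, feed_dict={x_ph: batch['x'], y_ph: batch['y']})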
def main(argv): logger.info('Command line: {}'.format(' '.join(arg for arg in argv))) def fmt(prog): return argparse.HelpFormatter(prog, max_help_position=100, width=200) argparser = argparse.ArgumentParser( 'Regularising RTE via Adversarial Sets Regularisation', formatter_class=fmt) argparser.add_argument('--train', '-t', action='store', type=str, default='data/snli/snli_1.0_train.jsonl.gz') argparser.add_argument('--valid', '-v', action='store', type=str, default='data/snli/snli_1.0_dev.jsonl.gz') argparser.add_argument('--test', '-T', action='store', type=str, default='data/snli/snli_1.0_test.jsonl.gz') argparser.add_argument( '--model', '-m', action='store', type=str, default='cbilstm', choices=['cbilstm', 'ff-dam', 'ff-damp', 'ff-dams', 'esim1']) argparser.add_argument('--optimizer', '-o', action='store', type=str, default='adagrad', choices=['adagrad', 'adam']) argparser.add_argument('--embedding-size', action='store', type=int, default=300) argparser.add_argument('--representation-size', action='store', type=int, default=200) argparser.add_argument('--batch-size', action='store', type=int, default=1024) argparser.add_argument('--nb-epochs', '-e', action='store', type=int, default=1000) argparser.add_argument('--nb-discriminator-epochs', '-D', action='store', type=int, default=1) argparser.add_argument('--nb-adversary-epochs', '-A', action='store', type=int, default=1000) argparser.add_argument('--dropout-keep-prob', action='store', type=float, default=1.0) argparser.add_argument('--learning-rate', action='store', type=float, default=0.1) argparser.add_argument('--clip', '-c', action='store', type=float, default=None) argparser.add_argument('--nb-words', action='store', type=int, default=None) argparser.add_argument('--seed', action='store', type=int, default=0) argparser.add_argument('--std-dev', action='store', type=float, default=0.01) argparser.add_argument('--has-bos', action='store_true', default=False, help='Has <Beginning Of Sentence> token') argparser.add_argument('--has-eos', action='store_true', default=False, help='Has <End Of Sentence> token') argparser.add_argument('--has-unk', action='store_true', default=False, help='Has <Unknown Word> token') argparser.add_argument('--lower', '-l', action='store_true', default=False, help='Lowercase the corpus') argparser.add_argument('--initialize-embeddings', '-i', action='store', type=str, default=None, choices=['normal', 'uniform']) argparser.add_argument('--fixed-embeddings', '-f', action='store_true') argparser.add_argument('--normalize-embeddings', '-n', action='store_true') argparser.add_argument('--only-use-pretrained-embeddings', '-p', action='store_true', help='Only use pre-trained word embeddings') argparser.add_argument('--train-special-token-embeddings', '-s', action='store_true') argparser.add_argument('--semi-sort', '-S', action='store_true') argparser.add_argument('--save', action='store', type=str, default=None) argparser.add_argument('--hard-save', action='store', type=str, default=None) argparser.add_argument('--restore', action='store', type=str, default=None) argparser.add_argument('--glove', action='store', type=str, default=None) argparser.add_argument('--word2vec', action='store', type=str, default=None) argparser.add_argument('--rule0-weight', '-0', action='store', type=float, default=None) argparser.add_argument('--rule1-weight', '-1', action='store', type=float, default=None) argparser.add_argument('--rule2-weight', '-2', action='store', type=float, default=None) argparser.add_argument('--rule3-weight', 
'-3', action='store', type=float, default=None) argparser.add_argument('--rule4-weight', '-4', action='store', type=float, default=None) argparser.add_argument('--rule5-weight', '-5', action='store', type=float, default=None) argparser.add_argument('--rule6-weight', '-6', action='store', type=float, default=None) argparser.add_argument('--rule7-weight', '-7', action='store', type=float, default=None) argparser.add_argument('--rule8-weight', '-8', action='store', type=float, default=None) argparser.add_argument('--adversarial-batch-size', '-B', action='store', type=int, default=32) argparser.add_argument('--adversarial-sentence-length', '-L', action='store', type=int, default=10) argparser.add_argument('--adversarial-pooling', '-P', default='max', choices=['sum', 'max', 'mean', 'logsumexp']) argparser.add_argument( '--adversarial-smart-init', '-I', action='store_true', default=False, help='Initialize sentence embeddings with actual word embeddings') argparser.add_argument( '--report', '-r', default=100, type=int, help='Number of batches between performance reports') argparser.add_argument('--report-loss', default=100, type=int, help='Number of batches between loss reports') argparser.add_argument( '--memory-limit', default=None, type=int, help= 'The maximum area (in bytes) of address space which may be taken by the process.' ) argparser.add_argument('--universum', '-U', action='store_true') args = argparser.parse_args(argv) # Command line arguments train_path, valid_path, test_path = args.train, args.valid, args.test model_name = args.model optimizer_name = args.optimizer embedding_size = args.embedding_size representation_size = args.representation_size batch_size = args.batch_size nb_epochs = args.nb_epochs nb_discriminator_epochs = args.nb_discriminator_epochs nb_adversary_epochs = args.nb_adversary_epochs dropout_keep_prob = args.dropout_keep_prob learning_rate = args.learning_rate clip_value = args.clip seed = args.seed std_dev = args.std_dev has_bos = args.has_bos has_eos = args.has_eos has_unk = args.has_unk is_lower = args.lower initialize_embeddings = args.initialize_embeddings is_fixed_embeddings = args.fixed_embeddings is_normalize_embeddings = args.normalize_embeddings is_only_use_pretrained_embeddings = args.only_use_pretrained_embeddings is_train_special_token_embeddings = args.train_special_token_embeddings is_semi_sort = args.semi_sort logger.info('has_bos: {}, has_eos: {}, has_unk: {}'.format( has_bos, has_eos, has_unk)) logger.info( 'is_lower: {}, is_fixed_embeddings: {}, is_normalize_embeddings: {}'. 
format(is_lower, is_fixed_embeddings, is_normalize_embeddings)) logger.info( 'is_only_use_pretrained_embeddings: {}, is_train_special_token_embeddings: {}, is_semi_sort: {}' .format(is_only_use_pretrained_embeddings, is_train_special_token_embeddings, is_semi_sort)) save_path = args.save hard_save_path = args.hard_save restore_path = args.restore glove_path = args.glove word2vec_path = args.word2vec # Experimental RTE regularizers rule0_weight = args.rule0_weight rule1_weight = args.rule1_weight rule2_weight = args.rule2_weight rule3_weight = args.rule3_weight rule4_weight = args.rule4_weight rule5_weight = args.rule5_weight rule6_weight = args.rule6_weight rule7_weight = args.rule7_weight rule8_weight = args.rule8_weight adversarial_batch_size = args.adversarial_batch_size adversarial_sentence_length = args.adversarial_sentence_length adversarial_pooling_name = args.adversarial_pooling adversarial_smart_init = args.adversarial_smart_init name_to_adversarial_pooling = { 'sum': tf.reduce_sum, 'max': tf.reduce_max, 'mean': tf.reduce_mean, 'logsumexp': tf.reduce_logsumexp } report_interval = args.report report_loss_interval = args.report_loss memory_limit = args.memory_limit is_universum = args.universum if memory_limit: import resource soft, hard = resource.getrlimit(resource.RLIMIT_AS) logging.info('Current memory limit: {}, {}'.format(soft, hard)) resource.setrlimit(resource.RLIMIT_AS, (memory_limit, memory_limit)) soft, hard = resource.getrlimit(resource.RLIMIT_AS) logging.info('New memory limit: {}, {}'.format(soft, hard)) np.random.seed(seed) random_state = np.random.RandomState(seed) tf.set_random_seed(seed) logger.debug('Reading corpus ..') train_is, dev_is, test_is = util.SNLI.generate(train_path=train_path, valid_path=valid_path, test_path=test_path, is_lower=is_lower) logger.info('Train size: {}\tDev size: {}\tTest size: {}'.format( len(train_is), len(dev_is), len(test_is))) all_is = train_is + dev_is + test_is # Enumeration of tokens start at index=3: # index=0 PADDING, index=1 START_OF_SENTENCE, index=2 END_OF_SENTENCE, index=3 UNKNOWN_WORD bos_idx, eos_idx, unk_idx = 1, 2, 3 start_idx = 1 + (1 if has_bos else 0) + (1 if has_eos else 0) + (1 if has_unk else 0) if not restore_path: # Create a sequence of tokens containing all sentences in the dataset token_seq = [] for instance in all_is: token_seq += instance['sentence1_parse_tokens'] + instance[ 'sentence2_parse_tokens'] token_set = set(token_seq) allowed_words = None if is_only_use_pretrained_embeddings: assert (glove_path is not None) or (word2vec_path is not None) if glove_path: logger.info('Loading GloVe words from {}'.format(glove_path)) assert os.path.isfile(glove_path) allowed_words = load_glove_words(path=glove_path, words=token_set) elif word2vec_path: logger.info( 'Loading word2vec words from {}'.format(word2vec_path)) assert os.path.isfile(word2vec_path) allowed_words = load_word2vec_words(path=word2vec_path, words=token_set) logger.info('Number of allowed words: {}'.format( len(allowed_words))) # Count the number of occurrences of each token token_counts = dict() for token in token_seq: if (allowed_words is None) or (token in allowed_words): if token not in token_counts: token_counts[token] = 0 token_counts[token] += 1 # Sort the tokens according to their frequency and lexicographic ordering sorted_vocabulary = sorted(token_counts.keys(), key=lambda t: (-token_counts[t], t)) index_to_token = { index: token for index, token in enumerate(sorted_vocabulary, start=start_idx) } else: with 
open('{}_index_to_token.p'.format(restore_path), 'rb') as f: index_to_token = pickle.load(f) token_to_index = {token: index for index, token in index_to_token.items()} entailment_idx, neutral_idx, contradiction_idx, none_idx = 0, 1, 2, 3 label_to_index = { 'entailment': entailment_idx, 'neutral': neutral_idx, 'contradiction': contradiction_idx, } if is_universum: label_to_index['none'] = none_idx max_len = None optimizer_name_to_class = { 'adagrad': tf.train.AdagradOptimizer, 'adam': tf.train.AdamOptimizer } optimizer_class = optimizer_name_to_class[optimizer_name] assert optimizer_class optimizer = optimizer_class(learning_rate=learning_rate) args = dict(has_bos=has_bos, has_eos=has_eos, has_unk=has_unk, bos_idx=bos_idx, eos_idx=eos_idx, unk_idx=unk_idx, max_len=max_len) train_dataset = util.instances_to_dataset(train_is, token_to_index, label_to_index, **args) dev_dataset = util.instances_to_dataset(dev_is, token_to_index, label_to_index, **args) test_dataset = util.instances_to_dataset(test_is, token_to_index, label_to_index, **args) sentence1 = train_dataset['sentence1'] sentence1_length = train_dataset['sentence1_length'] sentence2 = train_dataset['sentence2'] sentence2_length = train_dataset['sentence2_length'] label = train_dataset['label'] sentence1_ph = tf.placeholder(dtype=tf.int32, shape=[None, None], name='sentence1') sentence2_ph = tf.placeholder(dtype=tf.int32, shape=[None, None], name='sentence2') sentence1_len_ph = tf.placeholder(dtype=tf.int32, shape=[None], name='sentence1_length') sentence2_len_ph = tf.placeholder(dtype=tf.int32, shape=[None], name='sentence2_length') clipped_sentence1 = tfutil.clip_sentence(sentence1_ph, sentence1_len_ph) clipped_sentence2 = tfutil.clip_sentence(sentence2_ph, sentence2_len_ph) label_ph = tf.placeholder(dtype=tf.int32, shape=[None], name='label') token_set = set(token_to_index.keys()) vocab_size = max(token_to_index.values()) + 1 nb_words = len(token_to_index) nb_special_tokens = vocab_size - nb_words token_to_embedding = dict() if not restore_path: if glove_path: logger.info( 'Loading GloVe word embeddings from {}'.format(glove_path)) assert os.path.isfile(glove_path) token_to_embedding = load_glove(glove_path, token_set) elif word2vec_path: logger.info('Loading word2vec word embeddings from {}'.format( word2vec_path)) assert os.path.isfile(word2vec_path) token_to_embedding = load_word2vec(word2vec_path, token_set) discriminator_scope_name = 'discriminator' with tf.variable_scope(discriminator_scope_name): if initialize_embeddings == 'normal': logger.info('Initializing the embeddings with 𝓝(0, 1)') embedding_initializer = tf.random_normal_initializer(0.0, 1.0) elif initialize_embeddings == 'uniform': logger.info('Initializing the embeddings with 𝒰(-1, 1)') embedding_initializer = tf.random_uniform_initializer(minval=-1.0, maxval=1.0) else: logger.info( 'Initializing the embeddings with Xavier initialization') embedding_initializer = tf.contrib.layers.xavier_initializer() if is_train_special_token_embeddings: embedding_layer_special = tf.get_variable( 'special_embeddings', shape=[nb_special_tokens, embedding_size], initializer=embedding_initializer, trainable=True) embedding_layer_words = tf.get_variable( 'word_embeddings', shape=[nb_words, embedding_size], initializer=embedding_initializer, trainable=not is_fixed_embeddings) embedding_layer = tf.concat( values=[embedding_layer_special, embedding_layer_words], axis=0) else: embedding_layer = tf.get_variable( 'embeddings', shape=[vocab_size, embedding_size], 
initializer=embedding_initializer, trainable=not is_fixed_embeddings) sentence1_embedding = tf.nn.embedding_lookup(embedding_layer, clipped_sentence1) sentence2_embedding = tf.nn.embedding_lookup(embedding_layer, clipped_sentence2) dropout_keep_prob_ph = tf.placeholder(tf.float32, name='dropout_keep_prob') model_kwargs = dict(sequence1=sentence1_embedding, sequence1_length=sentence1_len_ph, sequence2=sentence2_embedding, sequence2_length=sentence2_len_ph, representation_size=representation_size, dropout_keep_prob=dropout_keep_prob_ph) if is_universum: model_kwargs['nb_classes'] = 4 if model_name in {'ff-dam', 'ff-damp', 'ff-dams'}: model_kwargs['init_std_dev'] = std_dev mode_name_to_class = { 'cbilstm': ConditionalBiLSTM, 'ff-dam': FeedForwardDAM, 'ff-damp': FeedForwardDAMP, 'ff-dams': FeedForwardDAMS, 'esim1': ESIMv1 } model_class = mode_name_to_class[model_name] assert model_class is not None model = model_class(**model_kwargs) logits = model() predictions = tf.argmax(logits, axis=1, name='predictions') losses = tf.nn.sparse_softmax_cross_entropy_with_logits( logits=logits, labels=label_ph) loss = tf.reduce_mean(losses) if rule0_weight: loss += rule0_weight * contradiction_symmetry_l2( model_class, model_kwargs, contradiction_idx=contradiction_idx) discriminator_vars = tfutil.get_variables_in_scope( discriminator_scope_name) discriminator_init_op = tf.variables_initializer(discriminator_vars) trainable_discriminator_vars = list(discriminator_vars) if is_fixed_embeddings: if is_train_special_token_embeddings: trainable_discriminator_vars.remove(embedding_layer_words) else: trainable_discriminator_vars.remove(embedding_layer) discriminator_optimizer_scope_name = 'discriminator_optimizer' with tf.variable_scope(discriminator_optimizer_scope_name): if clip_value: gradients, v = zip(*optimizer.compute_gradients( loss, var_list=trainable_discriminator_vars)) gradients, _ = tf.clip_by_global_norm(gradients, clip_value) training_step = optimizer.apply_gradients(zip(gradients, v)) else: training_step = optimizer.minimize( loss, var_list=trainable_discriminator_vars) discriminator_optimizer_vars = tfutil.get_variables_in_scope( discriminator_optimizer_scope_name) discriminator_optimizer_init_op = tf.variables_initializer( discriminator_optimizer_vars) token_idx_ph = tf.placeholder(dtype=tf.int32, name='word_idx') token_embedding_ph = tf.placeholder(dtype=tf.float32, shape=[None], name='word_embedding') if is_train_special_token_embeddings: assign_token_embedding = embedding_layer_words[ token_idx_ph - nb_special_tokens, :].assign(token_embedding_ph) else: assign_token_embedding = embedding_layer[token_idx_ph, :].assign( token_embedding_ph) init_projection_steps = [] learning_projection_steps = [] if is_normalize_embeddings: if is_train_special_token_embeddings: special_embeddings_projection = constraints.unit_sphere( embedding_layer_special, norm=1.0) word_embeddings_projection = constraints.unit_sphere( embedding_layer_words, norm=1.0) init_projection_steps += [special_embeddings_projection] init_projection_steps += [word_embeddings_projection] learning_projection_steps += [special_embeddings_projection] if not is_fixed_embeddings: learning_projection_steps += [word_embeddings_projection] else: embeddings_projection = constraints.unit_sphere(embedding_layer, norm=1.0) init_projection_steps += [embeddings_projection] if not is_fixed_embeddings: learning_projection_steps += [embeddings_projection] predictions_int = tf.cast(predictions, tf.int32) labels_int = tf.cast(label_ph, tf.int32) 
use_adversarial_training = rule1_weight or rule2_weight or rule3_weight or rule4_weight or rule5_weight or rule6_weight or rule7_weight or rule8_weight if use_adversarial_training: adversary_scope_name = discriminator_scope_name with tf.variable_scope(adversary_scope_name): adversarial = AdversarialSets( model_class=model_class, model_kwargs=model_kwargs, embedding_size=embedding_size, scope_name='adversary', batch_size=adversarial_batch_size, sequence_length=adversarial_sentence_length, entailment_idx=entailment_idx, contradiction_idx=contradiction_idx, neutral_idx=neutral_idx) adversary_loss = tf.constant(0.0, dtype=tf.float32) adversary_vars = [] adversarial_pooling = name_to_adversarial_pooling[ adversarial_pooling_name] if rule1_weight: rule1_loss, rule1_vars = adversarial.rule1_loss() adversary_loss += rule1_weight * adversarial_pooling( rule1_loss) adversary_vars += rule1_vars if rule2_weight: rule2_loss, rule2_vars = adversarial.rule2_loss() adversary_loss += rule2_weight * adversarial_pooling( rule2_loss) adversary_vars += rule2_vars if rule3_weight: rule3_loss, rule3_vars = adversarial.rule3_loss() adversary_loss += rule3_weight * adversarial_pooling( rule3_loss) adversary_vars += rule3_vars if rule4_weight: rule4_loss, rule4_vars = adversarial.rule4_loss() adversary_loss += rule4_weight * adversarial_pooling( rule4_loss) adversary_vars += rule4_vars if rule5_weight: rule5_loss, rule5_vars = adversarial.rule5_loss() adversary_loss += rule5_weight * adversarial_pooling( rule5_loss) adversary_vars += rule5_vars if rule6_weight: rule6_loss, rule6_vars = adversarial.rule6_loss() adversary_loss += rule6_weight * adversarial_pooling( rule6_loss) adversary_vars += rule6_vars if rule7_weight: rule7_loss, rule7_vars = adversarial.rule7_loss() adversary_loss += rule7_weight * adversarial_pooling( rule7_loss) adversary_vars += rule7_vars if rule8_weight: rule8_loss, rule8_vars = adversarial.rule8_loss() adversary_loss += rule8_weight * adversarial_pooling( rule8_loss) adversary_vars += rule8_vars loss += adversary_loss assert len(adversary_vars) > 0 for adversary_var in adversary_vars: assert adversary_var.name.startswith( 'discriminator/adversary/rule') adversary_var_to_assign_op = dict() adversary_var_value_ph = tf.placeholder(dtype=tf.float32, shape=[None, None, None], name='adversary_var_value') for a_var in adversary_vars: adversary_var_to_assign_op[a_var] = a_var.assign( adversary_var_value_ph) adversary_init_op = tf.variables_initializer(adversary_vars) adv_opt_scope_name = 'adversary_optimizer' with tf.variable_scope(adv_opt_scope_name): adversary_optimizer = optimizer_class(learning_rate=learning_rate) adversary_training_step = adversary_optimizer.minimize( -adversary_loss, var_list=adversary_vars) adversary_optimizer_vars = tf.get_collection( tf.GraphKeys.GLOBAL_VARIABLES, scope=adv_opt_scope_name) adversary_optimizer_init_op = tf.variables_initializer( adversary_optimizer_vars) logger.info( 'Adversarial Batch Size: {}'.format(adversarial_batch_size)) adversary_projection_steps = [] for var in adversary_vars: if is_normalize_embeddings: unit_sphere_adversarial_embeddings = constraints.unit_sphere( var, norm=1.0, axis=-1) adversary_projection_steps += [ unit_sphere_adversarial_embeddings ] assert adversarial_batch_size == var.get_shape()[0].value def token_init_op(_var, _token_idx, target_idx): token_emb = tf.nn.embedding_lookup(embedding_layer, _token_idx) tiled_token_emb = tf.tile(tf.expand_dims(token_emb, 0), (adversarial_batch_size, 1)) return _var[:, target_idx, 
:].assign(tiled_token_emb) if has_bos: adversary_projection_steps += [token_init_op(var, bos_idx, 0)] saver = tf.train.Saver(discriminator_vars + discriminator_optimizer_vars, max_to_keep=1) session_config = tf.ConfigProto() session_config.gpu_options.allow_growth = True # session_config.log_device_placement = True with tf.Session(config=session_config) as session: logger.info('Total Parameters: {}'.format( tfutil.count_trainable_parameters())) logger.info('Total Discriminator Parameters: {}'.format( tfutil.count_trainable_parameters(var_list=discriminator_vars))) logger.info('Total Trainable Discriminator Parameters: {}'.format( tfutil.count_trainable_parameters( var_list=trainable_discriminator_vars))) if use_adversarial_training: session.run([adversary_init_op, adversary_optimizer_init_op]) if restore_path: saver.restore(session, restore_path) else: session.run( [discriminator_init_op, discriminator_optimizer_init_op]) # Initialising pre-trained embeddings logger.info('Initialising the embeddings pre-trained vectors ..') for token in token_to_embedding: token_idx, token_embedding = token_to_index[ token], token_to_embedding[token] assert embedding_size == len(token_embedding) session.run(assign_token_embedding, feed_dict={ token_idx_ph: token_idx, token_embedding_ph: token_embedding }) logger.info('Done!') for adversary_projection_step in init_projection_steps: session.run([adversary_projection_step]) nb_instances = sentence1.shape[0] batches = make_batches(size=nb_instances, batch_size=batch_size) best_dev_acc, best_test_acc = None, None discriminator_batch_counter = 0 for epoch in range(1, nb_epochs + 1): for d_epoch in range(1, nb_discriminator_epochs + 1): order = random_state.permutation(nb_instances) sentences1, sentences2 = sentence1[order], sentence2[order] sizes1, sizes2 = sentence1_length[order], sentence2_length[ order] labels = label[order] if is_semi_sort: order = util.semi_sort(sizes1, sizes2) sentences1, sentences2 = sentence1[order], sentence2[order] sizes1, sizes2 = sentence1_length[order], sentence2_length[ order] labels = label[order] loss_values, epoch_loss_values = [], [] for batch_idx, (batch_start, batch_end) in enumerate(batches): discriminator_batch_counter += 1 batch_sentences1, batch_sentences2 = sentences1[ batch_start:batch_end], sentences2[ batch_start:batch_end] batch_sizes1, batch_sizes2 = sizes1[ batch_start:batch_end], sizes2[batch_start:batch_end] batch_labels = labels[batch_start:batch_end] batch_max_size1 = np.max(batch_sizes1) batch_max_size2 = np.max(batch_sizes2) batch_sentences1 = batch_sentences1[:, :batch_max_size1] batch_sentences2 = batch_sentences2[:, :batch_max_size2] batch_feed_dict = { sentence1_ph: batch_sentences1, sentence1_len_ph: batch_sizes1, sentence2_ph: batch_sentences2, sentence2_len_ph: batch_sizes2, label_ph: batch_labels, dropout_keep_prob_ph: dropout_keep_prob } _, loss_value = session.run([training_step, loss], feed_dict=batch_feed_dict) logger.debug('Epoch {0}/{1}/{2}\tLoss: {3}'.format( epoch, d_epoch, batch_idx, loss_value)) cur_batch_size = batch_sentences1.shape[0] loss_values += [loss_value / cur_batch_size] epoch_loss_values += [loss_value / cur_batch_size] for adversary_projection_step in learning_projection_steps: session.run([adversary_projection_step]) if discriminator_batch_counter % report_loss_interval == 0: logger.info( 'Epoch {0}/{1}/{2}\tLoss Stats: {3}'.format( epoch, d_epoch, batch_idx, stats(loss_values))) loss_values = [] if discriminator_batch_counter % report_interval == 0: accuracy_args = [ 
sentence1_ph, sentence1_len_ph, sentence2_ph, sentence2_len_ph, label_ph, dropout_keep_prob_ph, predictions_int, labels_int, contradiction_idx, entailment_idx, neutral_idx, batch_size ] dev_acc, _, _, _ = accuracy(session, dev_dataset, 'Dev', *accuracy_args) test_acc, _, _, _ = accuracy(session, test_dataset, 'Test', *accuracy_args) logger.info( 'Epoch {0}/{1}/{2}\tDev Acc: {3:.2f}\tTest Acc: {4:.2f}' .format(epoch, d_epoch, batch_idx, dev_acc * 100, test_acc * 100)) if best_dev_acc is None or dev_acc > best_dev_acc: best_dev_acc, best_test_acc = dev_acc, test_acc if save_path: with open( '{}_index_to_token.p'.format( save_path), 'wb') as f: pickle.dump(index_to_token, f) saved_path = saver.save(session, save_path) logger.info('Model saved in file: {}'.format( saved_path)) logger.info( 'Epoch {0}/{1}/{2}\tBest Dev Accuracy: {3:.2f}\tBest Test Accuracy: {4:.2f}' .format(epoch, d_epoch, batch_idx, best_dev_acc * 100, best_test_acc * 100)) logger.info('Epoch {0}/{1}\tEpoch Loss Stats: {2}'.format( epoch, d_epoch, stats(epoch_loss_values))) if hard_save_path: with open('{}_index_to_token.p'.format(hard_save_path), 'wb') as f: pickle.dump(index_to_token, f) hard_saved_path = saver.save(session, hard_save_path) logger.info( 'Model saved in file: {}'.format(hard_saved_path)) if use_adversarial_training: session.run([adversary_init_op, adversary_optimizer_init_op]) if adversarial_smart_init: _token_indices = np.array(sorted(index_to_token.keys())) for a_var in adversary_vars: # Create a [batch size, sentence length, embedding size] NumPy tensor of sentence embeddings a_word_idx = _token_indices[random_state.randint( low=0, high=len(_token_indices), size=[ adversarial_batch_size, adversarial_sentence_length ])] np_embedding_layer = session.run(embedding_layer) np_adversarial_embeddings = np_embedding_layer[ a_word_idx] assert np_adversarial_embeddings.shape == ( adversarial_batch_size, adversarial_sentence_length, embedding_size) assert a_var in adversary_var_to_assign_op assign_op = adversary_var_to_assign_op[a_var] logger.info( 'Clever initialization of the adversarial embeddings ..' ) session.run(assign_op, feed_dict={ adversary_var_value_ph: np_adversarial_embeddings }) for a_epoch in range(1, nb_adversary_epochs + 1): adversary_feed_dict = {dropout_keep_prob_ph: 1.0} _, adversary_loss_value = session.run( [adversary_training_step, adversary_loss], feed_dict=adversary_feed_dict) logger.info('Adversary Epoch {0}/{1}\tLoss: {2}'.format( epoch, a_epoch, adversary_loss_value)) for adversary_projection_step in adversary_projection_steps: session.run(adversary_projection_step) logger.info('Training finished.')
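# `constraints.unit_sphere(var, norm=1.0, axis=...)` comes from the project's own code and is
# not shown in this listing. From the way its result is passed to `session.run`, it is assumed
# to return an assign op that re-projects each embedding (each slice along `axis`) onto a
# sphere of the given norm. A rough sketch of that behaviour (hypothetical helper, not the
# project's implementation):
def unit_sphere_sketch(var, norm=1.0, axis=1):
    norms = tf.sqrt(tf.reduce_sum(tf.square(var), axis=axis, keep_dims=True))
    return var.assign(norm * var / tf.maximum(norms, 1e-12))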
def main(argv): logger.info('Command line: {}'.format(' '.join(arg for arg in argv))) def fmt(prog): return argparse.HelpFormatter(prog, max_help_position=100, width=200) argparser = argparse.ArgumentParser('Regularising RTE via Adversarial Sets Regularisation', formatter_class=fmt) argparser.add_argument('--train', '-t', action='store', type=str, default='data/snli/snli_1.0_train.jsonl.gz') argparser.add_argument('--valid', '-v', action='store', type=str, default='data/snli/snli_1.0_dev.jsonl.gz') argparser.add_argument('--test', '-T', action='store', type=str, default='data/snli/snli_1.0_test.jsonl.gz') argparser.add_argument('--model', '-m', action='store', type=str, default='cbilstm', choices=['cbilstm', 'ff-dam', 'ff-damp', 'ff-dams', 'esim1']) argparser.add_argument('--optimizer', '-o', action='store', type=str, default='adagrad', choices=['adagrad', 'adam']) argparser.add_argument('--embedding-size', action='store', type=int, default=300) argparser.add_argument('--representation-size', action='store', type=int, default=200) argparser.add_argument('--batch-size', action='store', type=int, default=1024) argparser.add_argument('--nb-epochs', '-e', action='store', type=int, default=1000) argparser.add_argument('--nb-discriminator-epochs', '-D', action='store', type=int, default=1) argparser.add_argument('--nb-adversary-epochs', '-A', action='store', type=int, default=1000) argparser.add_argument('--dropout-keep-prob', action='store', type=float, default=1.0) argparser.add_argument('--learning-rate', action='store', type=float, default=0.1) argparser.add_argument('--clip', '-c', action='store', type=float, default=None) argparser.add_argument('--nb-words', action='store', type=int, default=None) argparser.add_argument('--seed', action='store', type=int, default=0) argparser.add_argument('--std-dev', action='store', type=float, default=0.01) argparser.add_argument('--has-bos', action='store_true', default=False, help='Has <Beginning Of Sentence> token') argparser.add_argument('--has-eos', action='store_true', default=False, help='Has <End Of Sentence> token') argparser.add_argument('--has-unk', action='store_true', default=False, help='Has <Unknown Word> token') argparser.add_argument('--lower', '-l', action='store_true', default=False, help='Lowercase the corpus') argparser.add_argument('--initialize-embeddings', '-i', action='store', type=str, default=None, choices=['normal', 'uniform']) argparser.add_argument('--fixed-embeddings', '-f', action='store_true') argparser.add_argument('--normalize-embeddings', '-n', action='store_true') argparser.add_argument('--only-use-pretrained-embeddings', '-p', action='store_true', help='Only use pre-trained word embeddings') argparser.add_argument('--semi-sort', '-S', action='store_true') argparser.add_argument('--save', action='store', type=str, default=None) argparser.add_argument('--hard-save', action='store', type=str, default=None) argparser.add_argument('--restore', action='store', type=str, default=None) argparser.add_argument('--glove', action='store', type=str, default=None) argparser.add_argument('--rule00-weight', '--00', action='store', type=float, default=None) argparser.add_argument('--rule01-weight', '--01', action='store', type=float, default=None) argparser.add_argument('--rule02-weight', '--02', action='store', type=float, default=None) argparser.add_argument('--rule03-weight', '--03', action='store', type=float, default=None) for i in range(1, 9): argparser.add_argument('--rule{}-weight'.format(i), '-{}'.format(i), action='store', 
type=float, default=None) argparser.add_argument('--adversarial-batch-size', '-B', action='store', type=int, default=32) argparser.add_argument('--adversarial-pooling', '-P', default='max', choices=['sum', 'max', 'mean', 'logsumexp']) argparser.add_argument('--report', '-r', default=100, type=int, help='Number of batches between performance reports') argparser.add_argument('--report-loss', default=100, type=int, help='Number of batches between loss reports') argparser.add_argument('--eval', '-E', nargs='+', type=str, help='Evaluate on these additional sets') args = argparser.parse_args(argv) # Command line arguments train_path, valid_path, test_path = args.train, args.valid, args.test model_name = args.model optimizer_name = args.optimizer embedding_size = args.embedding_size representation_size = args.representation_size batch_size = args.batch_size nb_epochs = args.nb_epochs nb_discriminator_epochs = args.nb_discriminator_epochs dropout_keep_prob = args.dropout_keep_prob learning_rate = args.learning_rate clip_value = args.clip seed = args.seed std_dev = args.std_dev has_bos = args.has_bos has_eos = args.has_eos has_unk = args.has_unk is_lower = args.lower initialize_embeddings = args.initialize_embeddings is_fixed_embeddings = args.fixed_embeddings is_normalize_embeddings = args.normalize_embeddings is_only_use_pretrained_embeddings = args.only_use_pretrained_embeddings is_semi_sort = args.semi_sort logger.info('has_bos: {}, has_eos: {}, has_unk: {}'.format(has_bos, has_eos, has_unk)) logger.info('is_lower: {}, is_fixed_embeddings: {}, is_normalize_embeddings: {}' .format(is_lower, is_fixed_embeddings, is_normalize_embeddings)) logger.info('is_only_use_pretrained_embeddings: {}, is_semi_sort: {}' .format(is_only_use_pretrained_embeddings, is_semi_sort)) save_path = args.save hard_save_path = args.hard_save restore_path = args.restore glove_path = args.glove # Experimental RTE regularizers rule00_weight = args.rule00_weight rule01_weight = args.rule01_weight rule02_weight = args.rule02_weight rule03_weight = args.rule03_weight rule1_weight = args.rule1_weight rule2_weight = args.rule2_weight rule3_weight = args.rule3_weight rule4_weight = args.rule4_weight rule5_weight = args.rule5_weight rule6_weight = args.rule6_weight rule7_weight = args.rule7_weight rule8_weight = args.rule8_weight a_batch_size = args.adversarial_batch_size adversarial_pooling_name = args.adversarial_pooling name_to_adversarial_pooling = { 'sum': tf.reduce_sum, 'max': tf.reduce_max, 'mean': tf.reduce_mean, 'logsumexp': tf.reduce_logsumexp } report_interval = args.report report_loss_interval = args.report_loss eval_paths = args.eval np.random.seed(seed) rs = np.random.RandomState(seed) tf.set_random_seed(seed) logger.debug('Reading corpus ..') train_is, dev_is, test_is = util.SNLI.generate(train_path=train_path, valid_path=valid_path, test_path=test_path, is_lower=is_lower) logger.info('Train size: {}\tDev size: {}\tTest size: {}'.format(len(train_is), len(dev_is), len(test_is))) all_is = train_is + dev_is + test_is # Enumeration of tokens start at index=3: # index=0 PADDING, index=1 START_OF_SENTENCE, index=2 END_OF_SENTENCE, index=3 UNKNOWN_WORD bos_idx, eos_idx, unk_idx = 1, 2, 3 start_idx = 1 + (1 if has_bos else 0) + (1 if has_eos else 0) + (1 if has_unk else 0) if not restore_path: # Create a sequence of tokens containing all sentences in the dataset token_seq = [] for instance in all_is: token_seq += instance['sentence1_parse_tokens'] + instance['sentence2_parse_tokens'] token_set = set(token_seq) allowed_words 
= None if is_only_use_pretrained_embeddings: assert glove_path is not None logger.info('Loading GloVe words from {}'.format(glove_path)) assert os.path.isfile(glove_path) allowed_words = load_glove_words(path=glove_path, words=token_set) logger.info('Number of allowed words: {}'.format(len(allowed_words))) # Count the number of occurrences of each token token_counts = dict() for token in token_seq: if (allowed_words is None) or (token in allowed_words): if token not in token_counts: token_counts[token] = 0 token_counts[token] += 1 # Sort the tokens according to their frequency and lexicographic ordering sorted_vocabulary = sorted(token_counts.keys(), key=lambda t: (- token_counts[t], t)) index_to_token = {index: token for index, token in enumerate(sorted_vocabulary, start=start_idx)} else: with open('{}_index_to_token.p'.format(restore_path), 'rb') as f: index_to_token = pickle.load(f) token_to_index = {token: index for index, token in index_to_token.items()} entailment_idx, neutral_idx, contradiction_idx = 0, 1, 2 label_to_index = { 'entailment': entailment_idx, 'neutral': neutral_idx, 'contradiction': contradiction_idx, } max_len = None optimizer_name_to_class = { 'adagrad': tf.train.AdagradOptimizer, 'adam': tf.train.AdamOptimizer } optimizer_class = optimizer_name_to_class[optimizer_name] assert optimizer_class optimizer = optimizer_class(learning_rate=learning_rate) args = dict(has_bos=has_bos, has_eos=has_eos, has_unk=has_unk, bos_idx=bos_idx, eos_idx=eos_idx, unk_idx=unk_idx, max_len=max_len) train_dataset = util.instances_to_dataset(train_is, token_to_index, label_to_index, **args) dev_dataset = util.instances_to_dataset(dev_is, token_to_index, label_to_index, **args) test_dataset = util.instances_to_dataset(test_is, token_to_index, label_to_index, **args) sentence1 = train_dataset['sentence1'] sentence1_length = train_dataset['sentence1_length'] sentence2 = train_dataset['sentence2'] sentence2_length = train_dataset['sentence2_length'] label = train_dataset['label'] sentence1_ph = tf.placeholder(dtype=tf.int32, shape=[None, None], name='sentence1') sentence2_ph = tf.placeholder(dtype=tf.int32, shape=[None, None], name='sentence2') sentence1_len_ph = tf.placeholder(dtype=tf.int32, shape=[None], name='sentence1_length') sentence2_len_ph = tf.placeholder(dtype=tf.int32, shape=[None], name='sentence2_length') clipped_sentence1 = tfutil.clip_sentence(sentence1_ph, sentence1_len_ph) clipped_sentence2 = tfutil.clip_sentence(sentence2_ph, sentence2_len_ph) label_ph = tf.placeholder(dtype=tf.int32, shape=[None], name='label') token_set = set(token_to_index.keys()) vocab_size = max(token_to_index.values()) + 1 token_to_embedding = dict() if not restore_path: if glove_path: logger.info('Loading GloVe word embeddings from {}'.format(glove_path)) assert os.path.isfile(glove_path) token_to_embedding = load_glove(glove_path, token_set) discriminator_scope_name = 'discriminator' with tf.variable_scope(discriminator_scope_name): if initialize_embeddings == 'normal': logger.info('Initializing the embeddings with 𝓝(0, 1)') embedding_initializer = tf.random_normal_initializer(0.0, 1.0) elif initialize_embeddings == 'uniform': logger.info('Initializing the embeddings with 𝒰(-1, 1)') embedding_initializer = tf.random_uniform_initializer(minval=-1.0, maxval=1.0) else: logger.info('Initializing the embeddings with Xavier initialization') embedding_initializer = tf.contrib.layers.xavier_initializer() embedding_layer = tf.get_variable('embeddings', shape=[vocab_size, embedding_size], 
initializer=embedding_initializer, trainable=not is_fixed_embeddings) sentence1_embedding = tf.nn.embedding_lookup(embedding_layer, clipped_sentence1) sentence2_embedding = tf.nn.embedding_lookup(embedding_layer, clipped_sentence2) dropout_keep_prob_ph = tf.placeholder(tf.float32, name='dropout_keep_prob') model_kwargs = dict( sequence1=sentence1_embedding, sequence1_length=sentence1_len_ph, sequence2=sentence2_embedding, sequence2_length=sentence2_len_ph, representation_size=representation_size, dropout_keep_prob=dropout_keep_prob_ph) if model_name in {'ff-dam', 'ff-damp', 'ff-dams'}: model_kwargs['init_std_dev'] = std_dev mode_name_to_class = { 'cbilstm': ConditionalBiLSTM, 'ff-dam': FeedForwardDAM, 'ff-damp': FeedForwardDAMP, 'ff-dams': FeedForwardDAMS, 'esim1': ESIMv1 } model_class = mode_name_to_class[model_name] assert model_class is not None model = model_class(**model_kwargs) logits = model() predictions = tf.argmax(logits, axis=1, name='predictions') losses = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label_ph) loss = tf.reduce_mean(losses) a_pooling_function = name_to_adversarial_pooling[adversarial_pooling_name] a_losses = None if rule00_weight: a_loss, a_losses = contradiction_symmetry_l2(model_class, model_kwargs, contradiction_idx=contradiction_idx, pooling_function=a_pooling_function, debug=True) loss += rule00_weight * a_loss if rule01_weight: a_loss, a_losses = contradiction_symmetry_l1(model_class, model_kwargs, contradiction_idx=contradiction_idx, pooling_function=a_pooling_function, debug=True) loss += rule01_weight * a_loss if rule02_weight: a_loss, a_losses = contradiction_kullback_leibler(model_class, model_kwargs, contradiction_idx=contradiction_idx, pooling_function=a_pooling_function, debug=True) loss += rule02_weight * a_loss if rule03_weight: a_loss, a_losses = contradiction_jensen_shannon(model_class, model_kwargs, contradiction_idx=contradiction_idx, pooling_function=a_pooling_function, debug=True) loss += rule03_weight * a_loss discriminator_vars = tfutil.get_variables_in_scope(discriminator_scope_name) discriminator_init_op = tf.variables_initializer(discriminator_vars) trainable_discriminator_vars = list(discriminator_vars) if is_fixed_embeddings: trainable_discriminator_vars.remove(embedding_layer) discriminator_optimizer_scope_name = 'discriminator_optimizer' with tf.variable_scope(discriminator_optimizer_scope_name): if clip_value: gradients, v = zip(*optimizer.compute_gradients(loss, var_list=trainable_discriminator_vars)) gradients, _ = tf.clip_by_global_norm(gradients, clip_value) training_step = optimizer.apply_gradients(zip(gradients, v)) else: training_step = optimizer.minimize(loss, var_list=trainable_discriminator_vars) discriminator_optimizer_vars = tfutil.get_variables_in_scope(discriminator_optimizer_scope_name) discriminator_optimizer_init_op = tf.variables_initializer(discriminator_optimizer_vars) token_idx_ph = tf.placeholder(dtype=tf.int32, name='word_idx') token_embedding_ph = tf.placeholder(dtype=tf.float32, shape=[None], name='word_embedding') assign_token_embedding = embedding_layer[token_idx_ph, :].assign(token_embedding_ph) init_projection_steps = [] learning_projection_steps = [] if is_normalize_embeddings: embeddings_projection = constraints.unit_sphere(embedding_layer, norm=1.0) init_projection_steps += [embeddings_projection] if not is_fixed_embeddings: learning_projection_steps += [embeddings_projection] predictions_int = tf.cast(predictions, tf.int32) labels_int = tf.cast(label_ph, tf.int32) 
    use_adversarial_training = rule1_weight or rule2_weight or rule3_weight or rule4_weight \
        or rule5_weight or rule6_weight or rule7_weight or rule8_weight

    rule_id_to_placeholders = dict()

    if use_adversarial_training:
        adversary_scope_name = discriminator_scope_name
        with tf.variable_scope(adversary_scope_name):
            adversarial = AdversarialSets3(model_class=model_class, model_kwargs=model_kwargs,
                                           scope_name='adversary', entailment_idx=entailment_idx,
                                           contradiction_idx=contradiction_idx, neutral_idx=neutral_idx)

            adversary_loss = tf.constant(0.0, dtype=tf.float32)
            adversarial_pooling = name_to_adversarial_pooling[adversarial_pooling_name]

            def f(rule_idx):
                nb_sequences = adversarial.rule_nb_sequences(rule_idx)

                a_args, rule_placeholders = [], []
                for seq_id in range(nb_sequences):
                    a_sentence_ph = tf.placeholder(dtype=tf.int32, shape=[None, None],
                                                   name='a_rule{}_sentence{}'.format(rule_idx, seq_id))
                    a_sentence_len_ph = tf.placeholder(dtype=tf.int32, shape=[None],
                                                       name='a_rule{}_sentence{}_length'.format(rule_idx, seq_id))

                    a_clipped_sentence = tfutil.clip_sentence(a_sentence_ph, a_sentence_len_ph)
                    a_sentence_embedding = tf.nn.embedding_lookup(embedding_layer, a_clipped_sentence)

                    a_args += [a_sentence_embedding, a_sentence_len_ph]
                    rule_placeholders += [(a_sentence_ph, a_sentence_len_ph)]

                rule_id_to_placeholders[rule_idx] = rule_placeholders
                # Assumes AdversarialSets3 exposes one loss per rule (rule1_loss .. rule8_loss),
                # mirroring AdversarialSets in the script above
                rule_loss = getattr(adversarial, 'rule{}_loss'.format(rule_idx))(*a_args)
                return rule_loss

            if rule1_weight:
                r_loss = f(1)
                adversary_loss += rule1_weight * adversarial_pooling(r_loss)
            if rule2_weight:
                r_loss = f(2)
                adversary_loss += rule2_weight * adversarial_pooling(r_loss)
            if rule3_weight:
                r_loss = f(3)
                adversary_loss += rule3_weight * adversarial_pooling(r_loss)
            if rule4_weight:
                r_loss = f(4)
                adversary_loss += rule4_weight * adversarial_pooling(r_loss)
            if rule5_weight:
                r_loss = f(5)
                adversary_loss += rule5_weight * adversarial_pooling(r_loss)
            if rule6_weight:
                r_loss = f(6)
                adversary_loss += rule6_weight * adversarial_pooling(r_loss)
            if rule7_weight:
                r_loss = f(7)
                adversary_loss += rule7_weight * adversarial_pooling(r_loss)
            if rule8_weight:
                r_loss = f(8)
                adversary_loss += rule8_weight * adversarial_pooling(r_loss)

            loss += adversary_loss

        logger.info('Adversarial Batch Size: {}'.format(a_batch_size))

    a_feed_dict = dict()
    a_rs = np.random.RandomState(seed)

    d_sentence1, d_sentence2 = train_dataset['sentence1'], train_dataset['sentence2']
    d_sentence1_len, d_sentence2_len = train_dataset['sentence1_length'], train_dataset['sentence2_length']
    d_label = train_dataset['label']

    nb_train_instances = d_label.shape[0]
    max_sentence_len = max(d_sentence1.shape[1], d_sentence2.shape[1])

    d_sentence = np.zeros(shape=(nb_train_instances * 2, max_sentence_len), dtype=np.int)
    d_sentence[0:d_sentence1.shape[0], 0:d_sentence1.shape[1]] = d_sentence1
    d_sentence[d_sentence1.shape[0]:, 0:d_sentence2.shape[1]] = d_sentence2

    d_sentence_len = np.concatenate((d_sentence1_len, d_sentence2_len), axis=0)
    nb_train_sentences = d_sentence_len.shape[0]

    saver = tf.train.Saver(discriminator_vars + discriminator_optimizer_vars, max_to_keep=1)

    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True

    with tf.Session(config=session_config) as session:
        logger.info('Total Parameters: {}'.format(tfutil.count_trainable_parameters()))
        logger.info('Total Discriminator Parameters: {}'.format(
            tfutil.count_trainable_parameters(var_list=discriminator_vars)))
        logger.info('Total Trainable Discriminator Parameters: {}'.format(
            tfutil.count_trainable_parameters(var_list=trainable_discriminator_vars)))

        if restore_path:
            saver.restore(session, restore_path)
        else:
            session.run([discriminator_init_op, discriminator_optimizer_init_op])

            # Initialising pre-trained embeddings
            logger.info('Initialising the embeddings with pre-trained vectors ..')
            for token in token_to_embedding:
                token_idx, token_embedding = token_to_index[token], token_to_embedding[token]
                assert embedding_size == len(token_embedding)
                session.run(assign_token_embedding,
                            feed_dict={
                                token_idx_ph: token_idx,
                                token_embedding_ph: token_embedding
                            })
            logger.info('Done!')

        for adversary_projection_step in init_projection_steps:
            session.run([adversary_projection_step])

        nb_instances = sentence1.shape[0]
        batches = make_batches(size=nb_instances, batch_size=batch_size)

        best_dev_acc, best_test_acc = None, None
        discriminator_batch_counter = 0

        for epoch in range(1, nb_epochs + 1):
            if use_adversarial_training:
                for rule_idx, rule_placeholders in rule_id_to_placeholders.items():
                    a_idxs = a_rs.choice(nb_train_sentences, a_batch_size)
                    for a_sentence_ph, a_sentence_len_ph in rule_placeholders:
                        # Select a random batch of sentences from the training set
                        a_sentence_batch = d_sentence[a_idxs]
                        a_sentence_len_batch = d_sentence_len[a_idxs]

                        a_feed_dict[a_sentence_ph] = a_sentence_batch
                        a_feed_dict[a_sentence_len_ph] = a_sentence_len_batch

            for d_epoch in range(1, nb_discriminator_epochs + 1):
                order = rs.permutation(nb_instances)
                sentences1, sentences2 = sentence1[order], sentence2[order]
                sizes1, sizes2 = sentence1_length[order], sentence2_length[order]
                labels = label[order]

                if is_semi_sort:
                    order = util.semi_sort(sizes1, sizes2)
                    sentences1, sentences2 = sentence1[order], sentence2[order]
                    sizes1, sizes2 = sentence1_length[order], sentence2_length[order]
                    labels = label[order]

                loss_values, epoch_loss_values = [], []

                for batch_idx, (batch_start, batch_end) in enumerate(batches):
                    discriminator_batch_counter += 1

                    batch_sentences1, batch_sentences2 = sentences1[batch_start:batch_end], sentences2[batch_start:batch_end]
                    batch_sizes1, batch_sizes2 = sizes1[batch_start:batch_end], sizes2[batch_start:batch_end]
                    batch_labels = labels[batch_start:batch_end]

                    batch_max_size1 = np.max(batch_sizes1)
                    batch_max_size2 = np.max(batch_sizes2)

                    batch_sentences1 = batch_sentences1[:, :batch_max_size1]
                    batch_sentences2 = batch_sentences2[:, :batch_max_size2]

                    batch_feed_dict = {
                        sentence1_ph: batch_sentences1, sentence1_len_ph: batch_sizes1,
                        sentence2_ph: batch_sentences2, sentence2_len_ph: batch_sizes2,
                        label_ph: batch_labels, dropout_keep_prob_ph: dropout_keep_prob
                    }

                    # Adding the adversaries
                    batch_feed_dict.update(a_feed_dict)

                    _, loss_value = session.run([training_step, loss], feed_dict=batch_feed_dict)
                    logger.debug('Epoch {0}/{1}/{2}\tLoss: {3}'.format(epoch, d_epoch, batch_idx, loss_value))

                    cur_batch_size = batch_sentences1.shape[0]
                    loss_values += [loss_value / cur_batch_size]
                    epoch_loss_values += [loss_value / cur_batch_size]

                    for adversary_projection_step in learning_projection_steps:
                        session.run([adversary_projection_step])

                    if discriminator_batch_counter % report_loss_interval == 0:
                        logger.info('Epoch {0}/{1}/{2}\tLoss Stats: {3}'.format(
                            epoch, d_epoch, batch_idx, stats(loss_values)))
                        loss_values = []

                    if discriminator_batch_counter % report_interval == 0:
                        accuracy_args = [sentence1_ph, sentence1_len_ph, sentence2_ph, sentence2_len_ph,
                                         label_ph, dropout_keep_prob_ph, predictions_int, labels_int,
                                         contradiction_idx, entailment_idx, neutral_idx, batch_size]

                        dev_acc, _, _, _ = accuracy(session, dev_dataset, 'Dev', *accuracy_args)
                        test_acc, _, _, _ = accuracy(session, test_dataset, 'Test', *accuracy_args)

                        logger.info('Epoch {0}/{1}/{2}\tDev Acc: {3:.2f}\tTest Acc: {4:.2f}'
                                    .format(epoch, d_epoch, batch_idx, dev_acc * 100, test_acc * 100))

                        if best_dev_acc is None or dev_acc > best_dev_acc:
                            best_dev_acc, best_test_acc = dev_acc, test_acc
                            if save_path:
                                with open('{}_index_to_token.p'.format(save_path), 'wb') as f:
                                    pickle.dump(index_to_token, f)
                                saved_path = saver.save(session, save_path)
                                logger.info('Model saved in file: {}'.format(saved_path))

                        logger.info('Epoch {0}/{1}/{2}\tBest Dev Accuracy: {3:.2f}\tBest Test Accuracy: {4:.2f}'
                                    .format(epoch, d_epoch, batch_idx, best_dev_acc * 100, best_test_acc * 100))

                        for eval_path in (eval_paths or []):
                            eval_path_acc = eutil.evaluate(session, eval_path, label_to_index, token_to_index,
                                                           predictions, batch_size,
                                                           sentence1_ph, sentence2_ph,
                                                           sentence1_len_ph, sentence2_len_ph,
                                                           dropout_keep_prob_ph,
                                                           has_bos=has_bos, has_eos=has_eos, has_unk=has_unk,
                                                           is_lower=is_lower,
                                                           bos_idx=bos_idx, eos_idx=eos_idx, unk_idx=unk_idx)
                            logger.info('Epoch {0}/{1}/{2}\tAccuracy on {3} is {4}'.format(
                                epoch, d_epoch, batch_idx, eval_path, eval_path_acc))

                        if a_losses is not None:
                            t_feed_dict = a_feed_dict
                            if len(t_feed_dict) == 0:
                                t_feed_dict = {
                                    sentence1_ph: sentences1[:1024], sentence1_len_ph: sizes1[:1024],
                                    sentence2_ph: sentences2[:1024], sentence2_len_ph: sizes2[:1024],
                                    dropout_keep_prob_ph: 1.0
                                }

                            a_losses_value = session.run(a_losses, feed_dict=t_feed_dict)
                            a_input_idxs = np.argsort(- a_losses_value)

                            for i in a_input_idxs[:10]:
                                t_sentence1 = t_feed_dict[sentence1_ph][i]
                                t_sentence2 = t_feed_dict[sentence2_ph][i]

                                logger.info('[ {} / {} ] Sentence1: {}'.format(
                                    i, a_losses_value[i],
                                    ' '.join([index_to_token[x] for x in t_sentence1 if x not in [0, 1, 2]])))
                                logger.info('[ {} / {} ] Sentence2: {}'.format(
                                    i, a_losses_value[i],
                                    ' '.join([index_to_token[x] for x in t_sentence2 if x not in [0, 1, 2]])))

                logger.info('Epoch {0}/{1}\tEpoch Loss Stats: {2}'.format(epoch, d_epoch, stats(epoch_loss_values)))

                if hard_save_path:
                    with open('{}_index_to_token.p'.format(hard_save_path), 'wb') as f:
                        pickle.dump(index_to_token, f)
                    hard_saved_path = saver.save(session, hard_save_path)
                    logger.info('Model saved in file: {}'.format(hard_saved_path))

    logger.info('Training finished.')
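# The rule00/rule01 regularisers used above (contradiction_symmetry_l2 / _l1) encode the fact
# that contradiction is symmetric: swapping premise and hypothesis should not change the
# probability assigned to the `contradiction` class. A hedged sketch of the L2 variant,
# assuming that `swapped_kwargs` (a hypothetical name) is a copy of `model_kwargs` with the
# two sequences and their lengths swapped, and that the second model call shares variables
# with the first (e.g. via scope reuse):
def contradiction_symmetry_l2_sketch(model_class, model_kwargs, swapped_kwargs,
                                     contradiction_idx=2, pooling_function=tf.reduce_sum):
    p_fwd = tf.nn.softmax(model_class(**model_kwargs)())[:, contradiction_idx]
    p_bwd = tf.nn.softmax(model_class(**swapped_kwargs)())[:, contradiction_idx]
    # Per-instance violation of the symmetry constraint, plus its pooled value (cf. debug=True)
    losses = tf.square(p_fwd - p_bwd)
    return pooling_function(losses), losses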
def main(argv): logger.info('Command line: {}'.format(' '.join(arg for arg in argv))) def fmt(prog): return argparse.HelpFormatter(prog, max_help_position=100, width=200) argparser = argparse.ArgumentParser('Regularising RTE via Adversarial Sets Regularisation', formatter_class=fmt) argparser.add_argument('--data', '-d', action='store', type=str, default='data/snli/snli_1.0_train.jsonl.gz') argparser.add_argument('--model', '-m', action='store', type=str, default='ff-dam', choices=['cbilstm', 'ff-dam', 'ff-damp', 'ff-dams', 'esim1']) argparser.add_argument('--embedding-size', action='store', type=int, default=300) argparser.add_argument('--representation-size', action='store', type=int, default=200) argparser.add_argument('--batch-size', action='store', type=int, default=32) argparser.add_argument('--seed', action='store', type=int, default=0) argparser.add_argument('--has-bos', action='store_true', default=False, help='Has <Beginning Of Sentence> token') argparser.add_argument('--has-eos', action='store_true', default=False, help='Has <End Of Sentence> token') argparser.add_argument('--has-unk', action='store_true', default=False, help='Has <Unknown Word> token') argparser.add_argument('--lower', '-l', action='store_true', default=False, help='Lowercase the corpus') argparser.add_argument('--restore', action='store', type=str, default=None) argparser.add_argument('--lm', action='store', type=str, default='models/lm/') argparser.add_argument('--corrupt', '-c', action='store_true', default=False, help='Corrupt examples so to maximise their inconsistency') argparser.add_argument('--most-violating', '-M', action='store_true', default=False, help='Show most violating examples') argparser.add_argument('--epsilon', '-e', action='store', type=float, default=1e-4) argparser.add_argument('--lambda-weight', '-L', action='store', type=float, default=1.0) argparser.add_argument('--inconsistency', '-i', action='store', type=str, default='contradiction') args = argparser.parse_args(argv) # Command line arguments data_path = args.data model_name = args.model embedding_size = args.embedding_size representation_size = args.representation_size batch_size = args.batch_size seed = args.seed has_bos = args.has_bos has_eos = args.has_eos has_unk = args.has_unk is_lower = args.lower restore_path = args.restore lm_path = args.lm is_corrupt = args.corrupt is_most_violating = args.most_violating epsilon = args.epsilon lambda_w = args.lambda_weight inconsistency_name = args.inconsistency iloss = None if inconsistency_name == 'contradiction': iloss = contradiction_loss elif inconsistency_name == 'neutral': iloss = neutral_loss elif inconsistency_name == 'entailment': iloss = entailment_loss assert iloss is not None np.random.seed(seed) tf.set_random_seed(seed) logger.debug('Reading corpus ..') data_is, _, _ = util.SNLI.generate(train_path=data_path, valid_path=None, test_path=None, is_lower=is_lower) logger.info('Data size: {}'.format(len(data_is))) # Enumeration of tokens start at index=3: # index=0 PADDING, index=1 START_OF_SENTENCE, index=2 END_OF_SENTENCE, index=3 UNKNOWN_WORD bos_idx, eos_idx, unk_idx = 1, 2, 3 global index_to_token, token_to_index with open('{}_index_to_token.p'.format(restore_path), 'rb') as f: index_to_token = pickle.load(f) index_to_token.update({0: '<PAD>', 1: '<BOS>', 2: '<UNK>'}) token_to_index = {token: index for index, token in index_to_token.items()} with open('{}/config.json'.format(lm_path), 'r') as f: config = json.load(f) seq_length = 1 lm_batch_size = batch_size rnn_size = 
config['rnn_size'] num_layers = config['num_layers'] label_to_index = { 'entailment': entailment_idx, 'neutral': neutral_idx, 'contradiction': contradiction_idx, } max_len = None args = dict( has_bos=has_bos, has_eos=has_eos, has_unk=has_unk, bos_idx=bos_idx, eos_idx=eos_idx, unk_idx=unk_idx, max_len=max_len) dataset = util.instances_to_dataset(data_is, token_to_index, label_to_index, **args) sentence1, sentence1_length = dataset['sentence1'], dataset['sentence1_length'] sentence2, sentence2_length = dataset['sentence2'], dataset['sentence2_length'] label = dataset['label'] clipped_sentence1 = tfutil.clip_sentence(sentence1_ph, sentence1_len_ph) clipped_sentence2 = tfutil.clip_sentence(sentence2_ph, sentence2_len_ph) vocab_size = max(token_to_index.values()) + 1 discriminator_scope_name = 'discriminator' with tf.variable_scope(discriminator_scope_name): embedding_layer = tf.get_variable('embeddings', shape=[vocab_size, embedding_size], trainable=False) sentence1_embedding = tf.nn.embedding_lookup(embedding_layer, clipped_sentence1) sentence2_embedding = tf.nn.embedding_lookup(embedding_layer, clipped_sentence2) model_kwargs = dict( sequence1=sentence1_embedding, sequence1_length=sentence1_len_ph, sequence2=sentence2_embedding, sequence2_length=sentence2_len_ph, representation_size=representation_size, dropout_keep_prob=dropout_keep_prob_ph) if model_name in {'ff-dam', 'ff-damp', 'ff-dams'}: model_kwargs['init_std_dev'] = 0.01 mode_name_to_class = { 'cbilstm': ConditionalBiLSTM, 'ff-dam': FeedForwardDAM, 'ff-damp': FeedForwardDAMP, 'ff-dams': FeedForwardDAMS, 'esim1': ESIMv1 } model_class = mode_name_to_class[model_name] assert model_class is not None model = model_class(**model_kwargs) logits = model() global probabilities probabilities = tf.nn.softmax(logits) predictions = tf.argmax(logits, axis=1, name='predictions') lm_scope_name = 'language_model' with tf.variable_scope(lm_scope_name): cell_fn = rnn.BasicLSTMCell cells = [cell_fn(rnn_size) for _ in range(num_layers)] global lm_cell lm_cell = rnn.MultiRNNCell(cells) global lm_input_data_ph, lm_targets_ph, lm_initial_state lm_input_data_ph = tf.placeholder(tf.int32, [None, seq_length], name='input_data') lm_targets_ph = tf.placeholder(tf.int32, [None, seq_length], name='targets') lm_initial_state = lm_cell.zero_state(lm_batch_size, tf.float32) with tf.variable_scope('rnnlm'): lm_W = tf.get_variable(name='W', shape=[rnn_size, vocab_size], initializer=tf.contrib.layers.xavier_initializer()) lm_b = tf.get_variable(name='b', shape=[vocab_size], initializer=tf.zeros_initializer()) lm_emb_lookup = tf.nn.embedding_lookup(embedding_layer, lm_input_data_ph) lm_emb_projection = tf.contrib.layers.fully_connected(inputs=lm_emb_lookup, num_outputs=rnn_size, weights_initializer=tf.contrib.layers.xavier_initializer(), biases_initializer=tf.zeros_initializer()) lm_inputs = tf.split(lm_emb_projection, seq_length, 1) lm_inputs = [tf.squeeze(input_, [1]) for input_ in lm_inputs] lm_outputs, lm_last_state = legacy_seq2seq.rnn_decoder(decoder_inputs=lm_inputs, initial_state=lm_initial_state, cell=lm_cell, loop_function=None, scope='rnnlm') lm_output = tf.reshape(tf.concat(lm_outputs, 1), [-1, rnn_size]) lm_logits = tf.matmul(lm_output, lm_W) + lm_b lm_probabilities = tf.nn.softmax(lm_logits) global lm_loss, lm_cost, lm_final_state lm_loss = legacy_seq2seq.sequence_loss_by_example(logits=[lm_logits], targets=[tf.reshape(lm_targets_ph, [-1])], weights=[tf.ones([lm_batch_size * seq_length])]) lm_cost = tf.reduce_sum(lm_loss) / lm_batch_size / seq_length 
lm_final_state = lm_last_state discriminator_vars = tfutil.get_variables_in_scope(discriminator_scope_name) lm_vars = tfutil.get_variables_in_scope(lm_scope_name) predictions_int = tf.cast(predictions, tf.int32) saver = tf.train.Saver(discriminator_vars, max_to_keep=1) lm_saver = tf.train.Saver(lm_vars, max_to_keep=1) session_config = tf.ConfigProto() session_config.gpu_options.allow_growth = True global session with tf.Session(config=session_config) as session: logger.info('Total Parameters: {}'.format(tfutil.count_trainable_parameters())) saver.restore(session, restore_path) lm_ckpt = tf.train.get_checkpoint_state(lm_path) lm_saver.restore(session, lm_ckpt.model_checkpoint_path) nb_instances = sentence1.shape[0] batches = make_batches(size=nb_instances, batch_size=batch_size) order = np.arange(nb_instances) sentences1 = sentence1[order] sentences2 = sentence2[order] sizes1 = sentence1_length[order] sizes2 = sentence2_length[order] labels = label[order] logger.info('Number of examples: {}'.format(labels.shape[0])) predictions_int_value = [] c_losses, e_losses, n_losses = [], [], [] for batch_idx, (batch_start, batch_end) in enumerate(batches): batch_sentences1 = sentences1[batch_start:batch_end] batch_sentences2 = sentences2[batch_start:batch_end] batch_sizes1 = sizes1[batch_start:batch_end] batch_sizes2 = sizes2[batch_start:batch_end] batch_feed_dict = { sentence1_ph: batch_sentences1, sentence1_len_ph: batch_sizes1, sentence2_ph: batch_sentences2, sentence2_len_ph: batch_sizes2, dropout_keep_prob_ph: 1.0 } batch_predictions_int = session.run(predictions_int, feed_dict=batch_feed_dict) predictions_int_value += batch_predictions_int.tolist() batch_c_loss = contradiction_loss(batch_sentences1, batch_sizes1, batch_sentences2, batch_sizes2) c_losses += batch_c_loss.tolist() batch_e_loss = entailment_loss(batch_sentences1, batch_sizes1, batch_sentences2, batch_sizes2) e_losses += batch_e_loss.tolist() batch_n_loss = neutral_loss(batch_sentences1, batch_sizes1, batch_sentences2, batch_sizes2) n_losses += batch_n_loss.tolist() if is_corrupt: search(sentences1=batch_sentences1, sizes1=batch_sizes1, sentences2=batch_sentences2, sizes2=batch_sizes2, batch_size=batch_size, epsilon=epsilon, lambda_w=lambda_w, inconsistency_loss=iloss) train_accuracy_value = np.mean(labels == np.array(predictions_int_value)) logger.info('Accuracy: {0:.4f}'.format(train_accuracy_value)) if is_most_violating: c_ranking = np.argsort(np.array(c_losses))[::-1] assert c_ranking.shape[0] == len(data_is) for i in range(min(1024, c_ranking.shape[0])): idx = c_ranking[i] print('[C/{}/{}] {} ({})'.format(i, idx, data_is[idx]['sentence1'], c_losses[idx])) print('[C/{}/{}] {} ({})'.format(i, idx, data_is[idx]['sentence2'], c_losses[idx])) e_ranking = np.argsort(np.array(e_losses))[::-1] assert e_ranking.shape[0] == len(data_is) for i in range(min(1024, e_ranking.shape[0])): idx = e_ranking[i] print('[E/{}/{}] {} ({})'.format(i, idx, data_is[idx]['sentence1'], e_losses[idx])) print('[E/{}/{}] {} ({})'.format(i, idx, data_is[idx]['sentence2'], e_losses[idx])) n_ranking = np.argsort(np.array(n_losses))[::-1] assert n_ranking.shape[0] == len(data_is) for i in range(min(1024, n_ranking.shape[0])): idx = n_ranking[i] print('[N/{}/{}] {} ({})'.format(i, idx, data_is[idx]['sentence1'], n_losses[idx])) print('[N/{}/{}] {} ({})'.format(i, idx, data_is[idx]['sentence2'], n_losses[idx]))
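# --- Hedged sketch (not part of the original script) -------------------------
# The loop above calls contradiction_loss / entailment_loss / neutral_loss and,
# via search(), an inference() helper; none of these are defined in this excerpt.
# The sketch below shows one plausible way such per-example scores could be
# computed from the global `session`, `probabilities`, placeholders and label
# indices created in main(); the exact formulation used in the original code may
# differ. entailment_loss and neutral_loss would follow the same pattern with
# their respective symmetry/exclusion rules.

def inference(sentences1, sizes1, sentences2, sizes2):
    # Run the restored discriminator on a batch and return the class probabilities.
    feed_dict = {
        sentence1_ph: sentences1, sentence1_len_ph: sizes1,
        sentence2_ph: sentences2, sentence2_len_ph: sizes2,
        dropout_keep_prob_ph: 1.0
    }
    return session.run(probabilities, feed_dict=feed_dict)


def contradiction_loss(sentences1, sizes1, sentences2, sizes2):
    # Contradiction should be symmetric: if S1 contradicts S2, then S2 contradicts S1.
    # Penalise, per example, the probability mass that violates this rule.
    p_fwd = inference(sentences1, sizes1, sentences2, sizes2)
    p_bwd = inference(sentences2, sizes2, sentences1, sizes1)
    return np.maximum(0.0, p_fwd[:, contradiction_idx] - p_bwd[:, contradiction_idx])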
def search(sentences1, sizes1, sentences2, sizes2, lambda_w=0.1, inconsistency_loss=contradiction_loss, epsilon=1e-4, batch_size=32, nb_corruptions=1024, nb_words=256): loss_value, iloss_value, logperp_value = loss(sentences1=sentences1, sizes1=sizes1, sentences2=sentences2, sizes2=sizes2, lambda_w=lambda_w, inconsistency_loss=inconsistency_loss) # Find examples that have a nearly-zero inconsistency loss, and only work on making those more "adversarial" low_iloss_idxs = np.where(iloss_value < 1e-6)[0] for low_iloss_idx in low_iloss_idxs.tolist(): sentence1, size1 = sentences1[low_iloss_idx, :], sizes1[low_iloss_idx] sentence2, size2 = sentences2[low_iloss_idx, :], sizes2[low_iloss_idx] sample_loss_value, sample_iloss_value, sample_logperp_value = \ loss_value[low_iloss_idx], iloss_value[low_iloss_idx], logperp_value[low_iloss_idx] sentence1_str = ' '.join([index_to_token[tidx] for tidx in sentence1 if tidx != 0]) sentence2_str = ' '.join([index_to_token[tidx] for tidx in sentence2 if tidx != 0]) print('SENTENCE 1 (inconsistency loss: {} / log-perplexity: {}): {}' .format(sample_iloss_value, sample_logperp_value, sentence1_str)) print('SENTENCE 2 (inconsistency loss: {} / log-perplexity: {}): {}' .format(sample_iloss_value, sample_logperp_value, sentence2_str)) # Generate mutations that do not increase the perplexity too much, and maximise the inconsistency loss corruptions1, corruption_sizes1, corruptions2, corruption_sizes2 = \ corrupt(sentence1=sentence1, size1=size1, sentence2=sentence2, size2=size2, nb_corruptions=nb_corruptions, nb_words=nb_words) # Compute all relevant metrics for the corruptions nb_corruptions = corruptions1.shape[0] batches = make_batches(size=nb_corruptions, batch_size=batch_size) corruption_loss_values, corruption_iloss_values, corruption_logperp_values = [], [], [] for batch_start, batch_end in batches: batch_corruptions1 = corruptions1[batch_start:batch_end, :] batch_corruption_sizes1 = corruption_sizes1[batch_start:batch_end] batch_corruptions2 = corruptions2[batch_start:batch_end, :] batch_corruption_sizes2 = corruption_sizes2[batch_start:batch_end] batch_loss_values, batch_iloss_values, batch_logperp_values = \ loss(sentences1=batch_corruptions1, sizes1=batch_corruption_sizes1, sentences2=batch_corruptions2, sizes2=batch_corruption_sizes2, lambda_w=lambda_w, inconsistency_loss=inconsistency_loss) corruption_loss_values += batch_loss_values.tolist() corruption_iloss_values += batch_iloss_values.tolist() corruption_logperp_values += batch_logperp_values.tolist() corruption_loss_values = np.array(corruption_loss_values) corruption_iloss_values = np.array(corruption_iloss_values) corruption_logperp_values = np.array(corruption_logperp_values) # Sort the corruptions by their inconsistency loss: corruptions_order = np.argsort(corruption_iloss_values)[::-1] # Select corruptions that did not increase the log-perplexity too much low_perplexity_mask = corruption_logperp_values <= logperp_value[low_iloss_idx] + epsilon counter = 0 for idx in corruptions_order.tolist(): if idx in np.where(low_perplexity_mask)[0].tolist(): if counter < 10: corruption_str = ' '.join([index_to_token[tidx] for tidx in corruptions2[idx] if tidx != 0]) msg = '[{}] CORRUPTION 2 (inconsistency loss: {} / log-perplexity: {}): {}'\ .format(counter, corruption_iloss_values[idx], corruption_logperp_values[idx], corruption_str) print(msg) _sentence1 = np.array([sentence1]) _size1 = np.array([size1]) _sentence2 = np.array([corruptions2[idx]]) _size2 = np.array([size2]) probabilities_1 = 
inference(_sentence1, _size1, _sentence2, _size2) probabilities_2 = inference(_sentence2, _size2, _sentence1, _size1) msg = 'A -> B: {}\tB -> A: {}'.format(str(probabilities_1), str(probabilities_2)) print(msg) counter += 1 return
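# --- Hedged sketch (not part of the original script) -------------------------
# search() relies on a corrupt() helper that is not defined in this excerpt.
# A minimal sketch, assuming each corruption substitutes a single random position
# of sentence2 with a random word index drawn from the first `nb_words`
# non-special vocabulary entries, while sentence1 is left untouched. numpy is
# assumed to be imported at module level, as elsewhere in these scripts.

def corrupt(sentence1, size1, sentence2, size2, nb_corruptions=1024, nb_words=256):
    corruptions1 = np.tile(sentence1, (nb_corruptions, 1))
    corruption_sizes1 = np.full(nb_corruptions, size1, dtype=np.int32)
    corruptions2 = np.tile(sentence2, (nb_corruptions, 1))
    corruption_sizes2 = np.full(nb_corruptions, size2, dtype=np.int32)
    for i in range(nb_corruptions):
        position = np.random.randint(low=0, high=size2)
        # Indices 0..3 are reserved for <PAD>, <BOS>, <EOS>, <UNK>
        corruptions2[i, position] = np.random.randint(low=4, high=4 + nb_words)
    return corruptions1, corruption_sizes1, corruptions2, corruption_sizes2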
def evaluate(session, eval_path, label_to_index, token_to_index, predictions_op, batch_size, sentence1_ph, sentence2_ph, sentence1_len_ph, sentence2_len_ph, dropout_keep_prob_ph, has_bos=False, has_eos=False, has_unk=False, is_lower=False, bos_idx=1, eos_idx=2, unk_idx=3): sentence1_all = [] sentence2_all = [] gold_label_all = [] with gzip.open(eval_path, 'rb') as f: for line in f: decoded_line = line.decode('utf-8') if is_lower: decoded_line = decoded_line.lower() obj = json.loads(decoded_line) gold_label = obj['gold_label'] if gold_label in ['contradiction', 'entailment', 'neutral']: gold_label_all += [label_to_index[gold_label]] sentence1_parse = obj['sentence1_parse'] sentence2_parse = obj['sentence2_parse'] sentence1_tree = nltk.Tree.fromstring(sentence1_parse) sentence2_tree = nltk.Tree.fromstring(sentence2_parse) sentence1_tokens = sentence1_tree.leaves() sentence2_tokens = sentence2_tree.leaves() sentence1_ids = [] sentence2_ids = [] if has_bos: sentence1_ids += [bos_idx] sentence2_ids += [bos_idx] for token in sentence1_tokens: if token in token_to_index: sentence1_ids += [token_to_index[token]] elif has_unk: sentence1_ids += [unk_idx] for token in sentence2_tokens: if token in token_to_index: sentence2_ids += [token_to_index[token]] elif has_unk: sentence2_ids += [unk_idx] if has_eos: sentence1_ids += [eos_idx] sentence2_ids += [eos_idx] sentence1_all += [sentence1_ids] sentence2_all += [sentence2_ids] sentence1_all_len = [len(s) for s in sentence1_all] sentence2_all_len = [len(s) for s in sentence2_all] np_sentence1 = util.pad_sequences(sequences=sentence1_all) np_sentence2 = util.pad_sequences(sequences=sentence2_all) np_sentence1_len = np.array(sentence1_all_len) np_sentence2_len = np.array(sentence2_all_len) gold_label = np.array(gold_label_all) nb_instances = gold_label.shape[0] batches = make_batches(size=nb_instances, batch_size=batch_size) predictions = [] for batch_idx, (batch_start, batch_end) in enumerate(batches): feed_dict = { sentence1_ph: np_sentence1[batch_start:batch_end], sentence2_ph: np_sentence2[batch_start:batch_end], sentence1_len_ph: np_sentence1_len[batch_start:batch_end], sentence2_len_ph: np_sentence2_len[batch_start:batch_end], dropout_keep_prob_ph: 1.0 } _predictions = session.run(predictions_op, feed_dict=feed_dict) predictions += _predictions.tolist() matches = np.array(predictions) == gold_label return np.mean(matches)
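# --- Hedged sketch (not part of the original script) -------------------------
# evaluate() pads the tokenised sentences with util.pad_sequences, which is not
# shown in this excerpt. A minimal implementation consistent with how the padded
# arrays are used above (index 0 is PADDING, sequences are right-padded to the
# longest sentence in the set):

def pad_sequences(sequences, padding_value=0):
    max_length = max(len(seq) for seq in sequences)
    padded = np.full((len(sequences), max_length), padding_value, dtype=np.int32)
    for i, seq in enumerate(sequences):
        padded[i, :len(seq)] = seq
    return padded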
def main(argv): def formatter(prog): return argparse.HelpFormatter(prog, max_help_position=100, width=200) argparser = argparse.ArgumentParser('NLI Service', formatter_class=formatter) argparser.add_argument( '--model', '-m', action='store', type=str, default='cbilstm', choices=['cbilstm', 'ff-dam', 'ff-damp', 'ff-dams', 'esim1']) argparser.add_argument('--embedding-size', '-e', action='store', type=int, default=300) argparser.add_argument('--representation-size', '-r', action='store', type=int, default=200) argparser.add_argument('--has-bos', action='store_true', default=False, help='Has <Beginning Of Sentence> token') argparser.add_argument('--has-eos', action='store_true', default=False, help='Has <End Of Sentence> token') argparser.add_argument('--has-unk', action='store_true', default=False, help='Has <Unknown Word> token') argparser.add_argument('--lower', '-l', action='store_true', default=False, help='Lowercase the corpus') argparser.add_argument('--restore', '-R', action='store', type=str, default=None, required=True) argparser.add_argument('--eval', action='store', default=None, type=str) argparser.add_argument('--batch-size', '-b', action='store', default=32, type=int) args = argparser.parse_args(argv) model_name = args.model embedding_size = args.embedding_size representation_size = args.representation_size has_bos = args.has_bos has_eos = args.has_eos has_unk = args.has_unk is_lower = args.lower restore_path = args.restore eval_path = args.eval batch_size = args.batch_size with open('{}_index_to_token.p'.format(restore_path), 'rb') as f: index_to_token = pickle.load(f) token_to_index = {token: index for index, token in index_to_token.items()} # Enumeration of tokens start at index=3: # index=0 PADDING, index=1 START_OF_SENTENCE, index=2 END_OF_SENTENCE, index=3 UNKNOWN_WORD bos_idx, eos_idx, unk_idx = 1, 2, 3 entailment_idx, neutral_idx, contradiction_idx = 0, 1, 2 label_to_index = { 'entailment': entailment_idx, 'neutral': neutral_idx, 'contradiction': contradiction_idx, } vocab_size = max(token_to_index.values()) + 1 sentence1_ph = tf.placeholder(dtype=tf.int32, shape=[None, None], name='sentence1') sentence2_ph = tf.placeholder(dtype=tf.int32, shape=[None, None], name='sentence2') sentence1_len_ph = tf.placeholder(dtype=tf.int32, shape=[None], name='sentence1_length') sentence2_len_ph = tf.placeholder(dtype=tf.int32, shape=[None], name='sentence2_length') dropout_keep_prob_ph = tf.placeholder(tf.float32, name='dropout_keep_prob') clipped_sentence1 = tfutil.clip_sentence(sentence1_ph, sentence1_len_ph) clipped_sentence2 = tfutil.clip_sentence(sentence2_ph, sentence2_len_ph) discriminator_scope_name = 'discriminator' with tf.variable_scope(discriminator_scope_name): embedding_layer = tf.get_variable('embeddings', shape=[vocab_size, embedding_size]) sentence1_embedding = tf.nn.embedding_lookup(embedding_layer, clipped_sentence1) sentence2_embedding = tf.nn.embedding_lookup(embedding_layer, clipped_sentence2) model_kwargs = dict(sequence1=sentence1_embedding, sequence1_length=sentence1_len_ph, sequence2=sentence2_embedding, sequence2_length=sentence2_len_ph, representation_size=representation_size, dropout_keep_prob=dropout_keep_prob_ph) mode_name_to_class = { 'cbilstm': ConditionalBiLSTM, 'ff-dam': FeedForwardDAM, 'ff-damp': FeedForwardDAMP, 'ff-dams': FeedForwardDAMS, 'esim1': ESIMv1 } model_class = mode_name_to_class[model_name] assert model_class is not None model = model_class(**model_kwargs) logits = model() predictions_op = tf.argmax(logits, axis=1, name='predictions') 
discriminator_vars = tfutil.get_variables_in_scope( discriminator_scope_name) sentence1_all = [] sentence2_all = [] gold_label_all = [] with gzip.open(eval_path, 'rb') as f: for line in f: decoded_line = line.decode('utf-8') if is_lower: decoded_line = decoded_line.lower() obj = json.loads(decoded_line) gold_label = obj['gold_label'] if gold_label in ['contradiction', 'entailment', 'neutral']: gold_label_all += [label_to_index[gold_label]] sentence1_parse = obj['sentence1_parse'] sentence2_parse = obj['sentence2_parse'] sentence1_tree = nltk.Tree.fromstring(sentence1_parse) sentence2_tree = nltk.Tree.fromstring(sentence2_parse) sentence1_tokens = sentence1_tree.leaves() sentence2_tokens = sentence2_tree.leaves() sentence1_ids = [] sentence2_ids = [] if has_bos: sentence1_ids += [bos_idx] sentence2_ids += [bos_idx] for token in sentence1_tokens: if token in token_to_index: sentence1_ids += [token_to_index[token]] elif has_unk: sentence1_ids += [unk_idx] for token in sentence2_tokens: if token in token_to_index: sentence2_ids += [token_to_index[token]] elif has_unk: sentence2_ids += [unk_idx] if has_eos: sentence1_ids += [eos_idx] sentence2_ids += [eos_idx] sentence1_all += [sentence1_ids] sentence2_all += [sentence2_ids] sentence1_all_len = [len(s) for s in sentence1_all] sentence2_all_len = [len(s) for s in sentence2_all] np_sentence1 = util.pad_sequences(sequences=sentence1_all) np_sentence2 = util.pad_sequences(sequences=sentence2_all) np_sentence1_len = np.array(sentence1_all_len) np_sentence2_len = np.array(sentence2_all_len) gold_label = np.array(gold_label_all) with tf.Session() as session: saver = tf.train.Saver(discriminator_vars, max_to_keep=1) saver.restore(session, restore_path) from inferbeddings.models.training.util import make_batches nb_instances = gold_label.shape[0] batches = make_batches(size=nb_instances, batch_size=batch_size) predictions = [] for batch_idx, (batch_start, batch_end) in enumerate(batches): feed_dict = { sentence1_ph: np_sentence1[batch_start:batch_end], sentence2_ph: np_sentence2[batch_start:batch_end], sentence1_len_ph: np_sentence1_len[batch_start:batch_end], sentence2_len_ph: np_sentence2_len[batch_start:batch_end], dropout_keep_prob_ph: 1.0 } _predictions = session.run(predictions_op, feed_dict=feed_dict) predictions += _predictions.tolist() matches = np.array(predictions) == gold_label print(np.mean(matches))
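# --- Hedged sketch (not part of the original script) -------------------------
# Every model in these scripts clips the padded sentence batches with
# tfutil.clip_sentence before the embedding lookup. A plausible implementation
# simply drops the padding columns beyond the longest sentence in the batch;
# the actual tfutil helper may differ in detail.

def clip_sentence(sentence, sizes):
    # sentence: [batch, time] int32 tensor of token ids; sizes: [batch] int32 tensor of lengths
    max_length = tf.reduce_max(sizes)
    return tf.slice(sentence, begin=[0, 0], size=tf.stack([-1, max_length]))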
def main(argv): logger.info('Command line: {}'.format(' '.join(arg for arg in argv))) def fmt(prog): return argparse.HelpFormatter(prog, max_help_position=100, width=200) argparser = argparse.ArgumentParser( 'Regularising RTE via Adversarial Sets Regularisation', formatter_class=fmt) argparser.add_argument('--data', '-d', action='store', type=str, default='data/snli/snli_1.0_train.jsonl.gz') argparser.add_argument( '--model', '-m', action='store', type=str, default='ff-dam', choices=['cbilstm', 'ff-dam', 'ff-damp', 'ff-dams', 'esim1']) argparser.add_argument('--embedding-size', action='store', type=int, default=300) argparser.add_argument('--representation-size', action='store', type=int, default=200) argparser.add_argument('--batch-size', action='store', type=int, default=32) argparser.add_argument('--seed', action='store', type=int, default=0) argparser.add_argument('--has-bos', action='store_true', default=False, help='Has <Beginning Of Sentence> token') argparser.add_argument('--has-eos', action='store_true', default=False, help='Has <End Of Sentence> token') argparser.add_argument('--has-unk', action='store_true', default=False, help='Has <Unknown Word> token') argparser.add_argument('--lower', '-l', action='store_true', default=False, help='Lowercase the corpus') argparser.add_argument('--restore', action='store', type=str, default=None) argparser.add_argument('--check-transitivity', '-x', action='store_true', default=False) args = argparser.parse_args(argv) # Command line arguments data_path = args.data model_name = args.model embedding_size = args.embedding_size representation_size = args.representation_size batch_size = args.batch_size seed = args.seed has_bos = args.has_bos has_eos = args.has_eos has_unk = args.has_unk is_lower = args.lower restore_path = args.restore is_check_transitivity = args.check_transitivity np.random.seed(seed) rs = np.random.RandomState(seed) tf.set_random_seed(seed) logger.debug('Reading corpus ..') data_is, _, _ = util.SNLI.generate(train_path=data_path, valid_path=None, test_path=None, is_lower=is_lower) logger.info('Data size: {}'.format(len(data_is))) # Enumeration of tokens start at index=3: # index=0 PADDING, index=1 START_OF_SENTENCE, index=2 END_OF_SENTENCE, index=3 UNKNOWN_WORD bos_idx, eos_idx, unk_idx = 1, 2, 3 with open('{}_index_to_token.p'.format(restore_path), 'rb') as f: index_to_token = pickle.load(f) token_to_index = {token: index for index, token in index_to_token.items()} entailment_idx, neutral_idx, contradiction_idx = 0, 1, 2 label_to_index = { 'entailment': entailment_idx, 'neutral': neutral_idx, 'contradiction': contradiction_idx, } max_len = None args = dict(has_bos=has_bos, has_eos=has_eos, has_unk=has_unk, bos_idx=bos_idx, eos_idx=eos_idx, unk_idx=unk_idx, max_len=max_len) dataset = util.instances_to_dataset(data_is, token_to_index, label_to_index, **args) sentence1 = dataset['sentence1'] sentence1_length = dataset['sentence1_length'] sentence2 = dataset['sentence2'] sentence2_length = dataset['sentence2_length'] label = dataset['label'] sentence1_ph = tf.placeholder(dtype=tf.int32, shape=[None, None], name='sentence1') sentence2_ph = tf.placeholder(dtype=tf.int32, shape=[None, None], name='sentence2') sentence1_len_ph = tf.placeholder(dtype=tf.int32, shape=[None], name='sentence1_length') sentence2_len_ph = tf.placeholder(dtype=tf.int32, shape=[None], name='sentence2_length') clipped_sentence1 = tfutil.clip_sentence(sentence1_ph, sentence1_len_ph) clipped_sentence2 = tfutil.clip_sentence(sentence2_ph, sentence2_len_ph) 
token_set = set(token_to_index.keys()) vocab_size = max(token_to_index.values()) + 1 discriminator_scope_name = 'discriminator' with tf.variable_scope(discriminator_scope_name): embedding_layer = tf.get_variable('embeddings', shape=[vocab_size, embedding_size], trainable=False) sentence1_embedding = tf.nn.embedding_lookup(embedding_layer, clipped_sentence1) sentence2_embedding = tf.nn.embedding_lookup(embedding_layer, clipped_sentence2) dropout_keep_prob_ph = tf.placeholder(tf.float32, name='dropout_keep_prob') model_kwargs = dict(sequence1=sentence1_embedding, sequence1_length=sentence1_len_ph, sequence2=sentence2_embedding, sequence2_length=sentence2_len_ph, representation_size=representation_size, dropout_keep_prob=dropout_keep_prob_ph) if model_name in {'ff-dam', 'ff-damp', 'ff-dams'}: model_kwargs['init_std_dev'] = 0.01 mode_name_to_class = { 'cbilstm': ConditionalBiLSTM, 'ff-dam': FeedForwardDAM, 'ff-damp': FeedForwardDAMP, 'ff-dams': FeedForwardDAMS, 'esim1': ESIMv1 } model_class = mode_name_to_class[model_name] assert model_class is not None model = model_class(**model_kwargs) logits = model() probabilities = tf.nn.softmax(logits) predictions = tf.argmax(logits, axis=1, name='predictions') discriminator_vars = tfutil.get_variables_in_scope( discriminator_scope_name) trainable_discriminator_vars = list(discriminator_vars) predictions_int = tf.cast(predictions, tf.int32) saver = tf.train.Saver(discriminator_vars, max_to_keep=1) session_config = tf.ConfigProto() session_config.gpu_options.allow_growth = True with tf.Session(config=session_config) as session: logger.info('Total Parameters: {}'.format( tfutil.count_trainable_parameters())) logger.info('Total Discriminator Parameters: {}'.format( tfutil.count_trainable_parameters(var_list=discriminator_vars))) logger.info('Total Trainable Discriminator Parameters: {}'.format( tfutil.count_trainable_parameters( var_list=trainable_discriminator_vars))) saver.restore(session, restore_path) nb_instances = sentence1.shape[0] batches = make_batches(size=nb_instances, batch_size=batch_size) order = np.arange(nb_instances) sentences1 = sentence1[order] sentences2 = sentence2[order] sizes1 = sentence1_length[order] sizes2 = sentence2_length[order] labels = label[order] a_predictions_int_value = [] b_predictions_int_value = [] a_probabilities_value = [] b_probabilities_value = [] for batch_idx, (batch_start, batch_end) in tqdm(list(enumerate(batches))): batch_sentences1 = sentences1[batch_start:batch_end] batch_sentences2 = sentences2[batch_start:batch_end] batch_sizes1 = sizes1[batch_start:batch_end] batch_sizes2 = sizes2[batch_start:batch_end] batch_a_feed_dict = { sentence1_ph: batch_sentences1, sentence1_len_ph: batch_sizes1, sentence2_ph: batch_sentences2, sentence2_len_ph: batch_sizes2, dropout_keep_prob_ph: 1.0 } batch_a_predictions_int_value, batch_a_probabilities_value = session.run( [predictions_int, probabilities], feed_dict=batch_a_feed_dict) a_predictions_int_value += batch_a_predictions_int_value.tolist() for i in range(batch_a_probabilities_value.shape[0]): a_probabilities_value += [{ 'neutral': batch_a_probabilities_value[i, neutral_idx], 'contradiction': batch_a_probabilities_value[i, contradiction_idx], 'entailment': batch_a_probabilities_value[i, entailment_idx] }] batch_b_feed_dict = { sentence1_ph: batch_sentences2, sentence1_len_ph: batch_sizes2, sentence2_ph: batch_sentences1, sentence2_len_ph: batch_sizes1, dropout_keep_prob_ph: 1.0 } batch_b_predictions_int_value, batch_b_probabilities_value = session.run( 
                [predictions_int, probabilities], feed_dict=batch_b_feed_dict)
            b_predictions_int_value += batch_b_predictions_int_value.tolist()

            for i in range(batch_b_probabilities_value.shape[0]):
                b_probabilities_value += [{
                    'neutral': batch_b_probabilities_value[i, neutral_idx],
                    'contradiction': batch_b_probabilities_value[i, contradiction_idx],
                    'entailment': batch_b_probabilities_value[i, entailment_idx]
                }]

        for i, instance in enumerate(data_is):
            instance.update({
                'a': a_probabilities_value[i],
                'b': b_probabilities_value[i],
            })

        logger.info('Number of examples: {}'.format(labels.shape[0]))

        train_accuracy_value = np.mean(labels == np.array(a_predictions_int_value))
        logger.info('Accuracy: {0:.4f}'.format(train_accuracy_value))

        s1s2_con = (np.array(a_predictions_int_value) == contradiction_idx)
        s2s1_con = (np.array(b_predictions_int_value) == contradiction_idx)
        assert s1s2_con.shape == s2s1_con.shape

        s1s2_ent = (np.array(a_predictions_int_value) == entailment_idx)
        s2s1_ent = (np.array(b_predictions_int_value) == entailment_idx)

        s1s2_neu = (np.array(a_predictions_int_value) == neutral_idx)
        s2s1_neu = (np.array(b_predictions_int_value) == neutral_idx)

        a = np.logical_xor(s1s2_con, s2s1_con)
        logger.info('(S1 contradicts S2) XOR (S2 contradicts S1): {0}'.format(a.sum()))

        b = s1s2_con
        logger.info('(S1 contradicts S2): {0}'.format(b.sum()))

        c = np.logical_and(s1s2_con, np.logical_not(s2s1_con))
        logger.info('(S1 contradicts S2) AND NOT(S2 contradicts S1): {0} ({1:.4f})'.format(c.sum(), c.sum() / b.sum()))

        with open('c.p', 'wb') as f:
            tmp = [data_is[i] for i in np.where(c)[0].tolist()]
            pickle.dump(tmp, f)

        d = s1s2_ent
        logger.info('(S1 entails S2): {0}'.format(d.sum()))

        e = np.logical_and(s1s2_ent, s2s1_con)
        logger.info('(S1 entails S2) AND (S2 contradicts S1): {0} ({1:.4f})'.format(e.sum(), e.sum() / d.sum()))

        with open('e.p', 'wb') as f:
            tmp = [data_is[i] for i in np.where(e)[0].tolist()]
            pickle.dump(tmp, f)

        # This mask must use the neutral predictions, not the contradiction ones
        f_neu = s1s2_neu
        logger.info('(S1 neutral S2): {0}'.format(f_neu.sum()))

        g = np.logical_and(s1s2_neu, s2s1_con)
        logger.info('(S1 neutral S2) AND (S2 contradicts S1): {0} ({1:.4f})'.format(g.sum(), g.sum() / f_neu.sum()))

        with open('g.p', 'wb') as f:
            tmp = [data_is[i] for i in np.where(g)[0].tolist()]
            pickle.dump(tmp, f)

        if is_check_transitivity:
            # Find S1, S2 such that entails(S1, S2)
            c_predictions_int_value = []
            c_probabilities_value = []
            d_predictions_int_value = []
            d_probabilities_value = []

            # Find candidate S3 sentences
            order = np.arange(nb_instances)
            sentences3 = sentence2[order]
            sizes3 = sentence2_length[order]

            for batch_idx, (batch_start, batch_end) in tqdm(list(enumerate(batches))):
                batch_sentences2 = sentences2[batch_start:batch_end]
                batch_sentences3 = sentences3[batch_start:batch_end]
                batch_sizes2 = sizes2[batch_start:batch_end]
                batch_sizes3 = sizes3[batch_start:batch_end]

                batch_c_feed_dict = {
                    sentence1_ph: batch_sentences2, sentence1_len_ph: batch_sizes2,
                    sentence2_ph: batch_sentences3, sentence2_len_ph: batch_sizes3,
                    dropout_keep_prob_ph: 1.0
                }

                batch_c_predictions_int_value, batch_c_probabilities_value = session.run(
                    [predictions_int, probabilities], feed_dict=batch_c_feed_dict)

                c_predictions_int_value += batch_c_predictions_int_value.tolist()
                for i in range(batch_c_probabilities_value.shape[0]):
                    c_probabilities_value += [{
                        'neutral': batch_c_probabilities_value[i, neutral_idx],
                        'contradiction': batch_c_probabilities_value[i, contradiction_idx],
                        'entailment': batch_c_probabilities_value[i, entailment_idx]
                    }]

                batch_sentences1 = sentences1[batch_start:batch_end]
                batch_sentences3 = sentences3[batch_start:batch_end]
                batch_sizes1 = sizes1[batch_start:batch_end]
                batch_sizes3 = sizes3[batch_start:batch_end]

                batch_d_feed_dict = {
                    sentence1_ph: batch_sentences1, sentence1_len_ph: batch_sizes1,
                    sentence2_ph: batch_sentences3, sentence2_len_ph: batch_sizes3,
                    dropout_keep_prob_ph: 1.0
                }

                batch_d_predictions_int_value, batch_d_probabilities_value = session.run(
                    [predictions_int, probabilities], feed_dict=batch_d_feed_dict)

                d_predictions_int_value += batch_d_predictions_int_value.tolist()
                for i in range(batch_d_probabilities_value.shape[0]):
                    d_probabilities_value += [{
                        'neutral': batch_d_probabilities_value[i, neutral_idx],
                        'contradiction': batch_d_probabilities_value[i, contradiction_idx],
                        'entailment': batch_d_probabilities_value[i, entailment_idx]
                    }]

            s2s3_ent = (np.array(c_predictions_int_value) == entailment_idx)
            # The S1 -> S3 predictions come from the `d` feed dict, not the `c` one
            s1s3_ent = (np.array(d_predictions_int_value) == entailment_idx)

            body = np.logical_and(s1s2_ent, s2s3_ent)
            body_not_head = np.logical_and(body, np.logical_not(s1s3_ent))

            logger.info('(S1 entails S2) AND (S2 entails S3): {0}'.format(body.sum()))
            logger.info('body AND NOT(head): {0} ({1:.4f})'.format(body_not_head.sum(), body_not_head.sum() / body.sum()))

            with open('h.p', 'wb') as f:
                tmp = []
                for idx in np.where(body_not_head)[0].tolist():
                    s1 = data_is[idx]['sentence1']
                    s2 = data_is[idx]['sentence2']
                    s3 = data_is[order[idx]]['sentence2']
                    tmp += [{'s1': s1, 's2': s2, 's3': s3}]
                pickle.dump(tmp, f)
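# --- Hedged sketch (not part of the original script) -------------------------
# All of the scripts above iterate over (batch_start, batch_end) pairs produced
# by make_batches (imported from inferbeddings.models.training.util). A minimal
# version consistent with that usage, covering `size` instances in contiguous
# slices of at most `batch_size` elements:

def make_batches(size, batch_size):
    nb_batches = (size + batch_size - 1) // batch_size
    return [(i * batch_size, min(size, (i + 1) * batch_size)) for i in range(nb_batches)]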