def build_model(self):
    """Build the full TF1 graph for the question-generation model.

    Constructs, in order: input placeholders; word/copy embeddings (GloVe-initialised);
    answer-position features; a bidirectional context encoder; a bidirectional answer
    ("condition") encoder; the decoder initial state; a copy-attention decoder unrolled
    twice (teacher-forced for training, beam search for inference) with shared variables;
    string-decoding outputs and summaries; the training losses; and the optimizer.

    Side effects: assigns many attributes on `self` (placeholders, tensors, losses,
    `self.optimizer`, `self.accuracy`) and appends to `self._train_summaries` /
    `self._output_summaries`. Reads configuration from the module-level `FLAGS` and
    helpers from `loader`, `ops`, `copy_layer`, `copy_attention_wrapper`.

    NOTE(review): SOURCE formatting was collapsed; the indentation/scope nesting below is
    a faithful reconstruction of the statement order, but exact `with`-block membership
    of a few statements (which only affects device placement, not variables) should be
    confirmed against version control.
    """
    # ------------------------------------------------------------------
    # Input placeholders. Pinned to CPU (string ops and feeds live there).
    # ------------------------------------------------------------------
    with tf.device('/cpu:*'):
        self.context_raw = tf.placeholder(tf.string, [None, None], name='placeholder_context_raw')  # source vectors of unknown size
        self.question_raw = tf.placeholder(tf.string, [None, None])  # target vectors of unknown size
        self.answer_raw = tf.placeholder(tf.string, [None, None], name='placeholder_ans_raw')  # target vectors of unknown size

        self.context_ids = tf.placeholder(tf.int32, [None, None], name='placeholder_context_ids')  # source vectors of unknown size
        self.context_copy_ids = tf.placeholder(tf.int32, [None, None], name='placeholder_context_cp_ids')  # source vectors of unknown size
        self.context_length = tf.placeholder(tf.int32, [None], name='placeholder_context_len')  # size(source)
        self.context_vocab_size = tf.placeholder(tf.int32, [None], name='placeholder_context_vocsize')  # size(source_vocab)

        self.question_ids = tf.placeholder(tf.int32, [None, None])  # target vectors of unknown size
        self.question_onehot = tf.placeholder(tf.float32, [None, None, None])  # target vectors of unknown size
        self.question_length = tf.placeholder(tf.int32, [None])  # size(source)

        self.answer_ids = tf.placeholder(tf.int32, [None, None], name='placeholder_ans_ids')  # target vectors of unknown size
        self.answer_length = tf.placeholder(tf.int32, [None], name='placeholder_ans_len')
        # NOTE(review): this placeholder reuses the name 'placeholder_ans_len' (same as
        # answer_length above) — almost certainly a copy-paste slip; TF will silently
        # uniquify it to 'placeholder_ans_len_1'. Renaming would change graph tensor
        # names (and could break name-based feeds/restores), so flagged rather than fixed.
        self.answer_locs = tf.placeholder(tf.int32, [None, None], name='placeholder_ans_len')

        self.original_ix = tf.placeholder(tf.int32, [None])  # unused - gives the index of the input in the unshuffled dataset

        # Boolean switch: when True, copy probabilities over answer tokens are masked out below.
        self.hide_answer_in_copy = tf.placeholder_with_default(False, (), "hide_answer_in_copy")

        # Grouped tuples mirroring the batch structure produced by the data pipeline.
        self.context_in = (self.context_raw, self.context_ids, self.context_copy_ids, self.context_length, self.context_vocab_size)
        self.question_in = (self.question_raw, self.question_ids, self.question_onehot, self.question_length)
        self.answer_in = (self.answer_raw, self.answer_ids, self.answer_length, self.answer_locs)
        self.input_batch = (self.context_in, self.question_in, self.answer_in, self.original_ix)

    # Dynamic batch size (taken from the answer ids feed).
    curr_batch_size = tf.shape(self.answer_ids)[0]

    with tf.variable_scope('input_pipeline'):
        # build teacher output - coerce to vocab and pad with SOS/EOS
        # also build output for loss - one hot over vocab+context
        # NOTE(review): question_teach appears unused below (question_teach_ids is
        # rebuilt identically later and used instead) — presumably kept for inspection.
        self.question_teach = tf.concat([tf.tile(tf.constant(self.vocab[SOS], shape=[1, 1]), [curr_batch_size, 1]), self.question_ids[:, :-1]], axis=1)
        # self.question_teach_oh = tf.concat([tf.one_hot(tf.tile(tf.constant(self.vocab[SOS], shape=[1, 1]), [curr_batch_size,1]), depth=len(self.vocab)+FLAGS.max_copy_size), self.question_onehot[:,:-1,:]], axis=1)

        # init embeddings
        with tf.device('/cpu:*'):
            glove_embeddings = loader.load_glove(FLAGS.data_path, d=FLAGS.embedding_size)
            embeddings_init = tf.constant(loader.get_embeddings(self.vocab, glove_embeddings, D=FLAGS.embedding_size))
            self.embeddings = tf.get_variable('word_embeddings', initializer=embeddings_init, dtype=tf.float32)
            if FLAGS.loc_embeddings:
                # Learnable per-position embeddings for copied tokens.
                self.copy_embeddings = tf.get_variable('copy_embeddings', shape=(FLAGS.max_copy_size, FLAGS.embedding_size), dtype=tf.float32)
            else:
                # All copy slots share the OOV embedding.
                self.copy_embeddings = tf.nn.embedding_lookup(self.embeddings, tf.tile([self.vocab[OOV]], [FLAGS.max_copy_size]))
            # Shortlist vocab rows followed by max_copy_size copy rows.
            self.full_embeddings = tf.concat([self.embeddings, self.copy_embeddings], axis=0)
            assert self.embeddings.shape == [len(self.vocab), self.embedding_size]

            # this uses a load of memory, dont create unless it's actually needed
            if self.use_embedding_loss:
                self.glove_vocab = loader.get_glove_vocab(FLAGS.data_path, size=-1, d=FLAGS.embedding_size, filter_to_squad=True)
                extended_embeddings_init = tf.constant(loader.get_embeddings(self.glove_vocab, glove_embeddings, D=FLAGS.embedding_size))
                # Frozen (trainable=False): used only as a lookup table for the embedding loss.
                self.extended_embeddings = tf.get_variable('full_word_embeddings', initializer=extended_embeddings_init, dtype=tf.float32, trainable=False)

            # Teacher-forcing inputs: <SOS> prepended, last gold token dropped.
            self.question_teach_ids = tf.concat([tf.tile(tf.constant(self.vocab[SOS], shape=[1, 1]), [curr_batch_size, 1]), self.question_ids[:, :-1]], axis=1)
            self.question_teach_embedded = tf.nn.embedding_lookup(self.full_embeddings, self.question_teach_ids)

            # Free the (large) raw GloVe dict now that all init constants are built.
            del glove_embeddings

        # First, coerce them to the shortlist vocab. Then embed
        # (ids >= len(vocab) are context-copy ids; replace them with OOV before lookup)
        self.context_coerced = tf.where(tf.greater_equal(self.context_ids, len(self.vocab)), tf.tile(tf.constant([[self.vocab[OOV]]]), tf.shape(self.context_ids)), self.context_ids)
        self.context_embedded = tf.nn.embedding_lookup(self.embeddings, self.context_coerced)

        self.answer_coerced = tf.where(tf.greater_equal(self.answer_ids, len(self.vocab)), tf.tile(tf.constant([[self.vocab[OOV]]]), tf.shape(self.answer_ids)), self.answer_ids)
        self.answer_embedded = tf.nn.embedding_lookup(self.embeddings, self.answer_coerced)  # batch x seq x embed

        # Is context token in answer?
        # Binary feature: 1.0 where the context position lies inside the answer span
        # [answer_locs[:,0], answer_locs[:,0] + answer_length).
        max_context_len = tf.reduce_max(self.context_length)
        context_ix = tf.tile(tf.expand_dims(tf.range(max_context_len), axis=0), [curr_batch_size, 1])
        gt_start = tf.greater_equal(context_ix, tf.tile(tf.expand_dims(self.answer_locs[:, 0], axis=1), [1, max_context_len]))
        lt_end = tf.less(context_ix, tf.tile(tf.expand_dims(self.answer_locs[:, 0] + self.answer_length, axis=1), [1, max_context_len]))
        self.in_answer_feature = tf.expand_dims(tf.cast(tf.logical_and(gt_start, lt_end), tf.float32), axis=2)

        embed_feats = [self.context_embedded, self.in_answer_feature]
        if FLAGS.begin_ans_feat:
            # Extra one-hot feature marking the first token of the answer span.
            self.begin_ans_feat = tf.expand_dims(tf.one_hot(self.answer_locs[:, 0], depth=max_context_len), axis=2)
            embed_feats.append(self.begin_ans_feat)

        # augment embedding
        self.context_embedded = tf.concat(embed_feats, axis=2)

    # Build encoder for context
    # Build RNN cell for encoder
    with tf.variable_scope('context_encoder'):
        # Variational dropout (same mask across timesteps) on inputs, state and outputs;
        # keep-prob is 1.0 at eval time via the is_training cond.
        context_encoder_cell_fwd = tf.contrib.rnn.DropoutWrapper(
            cell=tf.nn.rnn_cell.MultiRNNCell([tf.contrib.rnn.BasicLSTMCell(num_units=self.context_encoder_units) for n in range(FLAGS.ctxt_encoder_depth)]),
            input_keep_prob=(tf.cond(self.is_training, lambda: 1.0 - self.dropout_prob, lambda: 1.)),
            state_keep_prob=(tf.cond(self.is_training, lambda: 1.0 - self.dropout_prob, lambda: 1.)),
            output_keep_prob=(tf.cond(self.is_training, lambda: 1.0 - self.dropout_prob, lambda: 1.)),
            input_size=self.embedding_size + 1 + (1 if FLAGS.begin_ans_feat else 0),  # embedding + in-answer feat (+ begin-answer feat)
            variational_recurrent=True,
            dtype=tf.float32)

        context_encoder_cell_bwd = tf.contrib.rnn.DropoutWrapper(
            cell=tf.nn.rnn_cell.MultiRNNCell([tf.contrib.rnn.BasicLSTMCell(num_units=self.context_encoder_units) for n in range(FLAGS.ctxt_encoder_depth)]),
            input_keep_prob=(tf.cond(self.is_training, lambda: 1.0 - self.dropout_prob, lambda: 1.)),
            state_keep_prob=(tf.cond(self.is_training, lambda: 1.0 - self.dropout_prob, lambda: 1.)),
            output_keep_prob=(tf.cond(self.is_training, lambda: 1.0 - self.dropout_prob, lambda: 1.)),
            input_size=self.embedding_size + 1 + (1 if FLAGS.begin_ans_feat else 0),
            variational_recurrent=True,
            dtype=tf.float32)

        # Unroll encoder RNN
        context_encoder_output_parts, context_encoder_state = tf.nn.bidirectional_dynamic_rnn(
            context_encoder_cell_fwd, context_encoder_cell_bwd, self.context_embedded,
            sequence_length=self.context_length, dtype=tf.float32)
        self.context_encoder_output = tf.concat([context_encoder_output_parts[0], context_encoder_output_parts[1]], axis=2)  # batch x seq x 2*units

    # Build encoder for mean(encoder(context)) + answer
    # Build RNN cell for encoder
    with tf.variable_scope('a_encoder'):
        # To build the "extractive condition encoding" input, take embeddings of answer words concated with encoded context at that position
        # This is super involved! Even though we have the right indices we have to do a LOT of massaging to get them in the right shape
        seq_length = tf.reduce_max(self.answer_length)
        self.indices = self.answer_locs
        # cap the indices to be valid
        self.indices = tf.minimum(self.indices, tf.tile(tf.expand_dims(self.context_length - 1, axis=1), [1, tf.reduce_max(self.answer_length)]))

        # Build (batch, answer_pos) -> (batch, context_pos) index pairs for gather_nd.
        batch_ix = tf.expand_dims(tf.transpose(tf.tile(tf.expand_dims(tf.range(curr_batch_size), axis=0), [seq_length, 1]), [1, 0]), axis=2)
        full_ix = tf.concat([batch_ix, tf.expand_dims(self.indices, axis=-1)], axis=2)
        self.context_condition_encoding = tf.gather_nd(self.context_encoder_output, full_ix)

        # Encoder input: context encoding at each answer position ++ answer word embedding.
        self.full_condition_encoding = tf.concat([self.context_condition_encoding, self.answer_embedded], axis=2)

        a_encoder_cell_fwd = tf.contrib.rnn.DropoutWrapper(
            cell=tf.nn.rnn_cell.MultiRNNCell([tf.contrib.rnn.BasicLSTMCell(num_units=self.answer_encoder_units) for n in range(FLAGS.ans_encoder_depth)]),
            input_keep_prob=(tf.cond(self.is_training, lambda: 1.0 - self.dropout_prob, lambda: 1.)),
            state_keep_prob=(tf.cond(self.is_training, lambda: 1.0 - self.dropout_prob, lambda: 1.)),
            output_keep_prob=(tf.cond(self.is_training, lambda: 1.0 - self.dropout_prob, lambda: 1.)),
            input_size=self.context_encoder_units * 2 + self.embedding_size,
            variational_recurrent=True,
            dtype=tf.float32)

        a_encoder_cell_bwd = tf.contrib.rnn.DropoutWrapper(
            cell=tf.nn.rnn_cell.MultiRNNCell([tf.contrib.rnn.BasicLSTMCell(num_units=self.answer_encoder_units) for n in range(FLAGS.ans_encoder_depth)]),
            input_keep_prob=(tf.cond(self.is_training, lambda: 1.0 - self.dropout_prob, lambda: 1.)),
            state_keep_prob=(tf.cond(self.is_training, lambda: 1.0 - self.dropout_prob, lambda: 1.)),
            output_keep_prob=(tf.cond(self.is_training, lambda: 1.0 - self.dropout_prob, lambda: 1.)),
            input_size=self.context_encoder_units * 2 + self.embedding_size,
            variational_recurrent=True,
            dtype=tf.float32)

        # Unroll encoder RNN
        a_encoder_output_parts, a_encoder_state_parts = tf.nn.bidirectional_dynamic_rnn(
            a_encoder_cell_fwd, a_encoder_cell_bwd, self.full_condition_encoding,
            sequence_length=self.answer_length, dtype=tf.float32)

        # This is actually wrong! It should take last element of the fwd RNN, and first element of the bwd RNN. It doesn't seem to matter in experiments, and fixing it would be a breaking change.
        # self.a_encoder_final_state = tf.concat([ops.get_last_from_seq(a_encoder_output_parts[0], self.answer_length-1), ops.get_last_from_seq(a_encoder_output_parts[1], self.answer_length-1)], axis=1)
        # Fixed!
        self.a_encoder_final_state = tf.concat([ops.get_last_from_seq(a_encoder_output_parts[0], self.answer_length - 1), a_encoder_output_parts[1][:, 0, :]], axis=1)

    # build init state
    with tf.variable_scope('decoder_initial_state'):
        L = tf.get_variable('decoder_L', [self.context_encoder_units * 2, self.context_encoder_units * 2], initializer=tf.glorot_uniform_initializer(), dtype=tf.float32)
        W0 = tf.get_variable('decoder_W0', [self.context_encoder_units * 2, self.decoder_units], initializer=tf.glorot_uniform_initializer(), dtype=tf.float32)
        b0 = tf.get_variable('decoder_b0', [self.decoder_units], initializer=tf.zeros_initializer(), dtype=tf.float32)

        # This is a bit cheeky - this should be injected by the more advanced model. Consider refactoring into separate methods then overloading the one that handles this
        if self.advanced_condition_encoding:
            self.context_encoding = self.a_encoder_final_state  # this would be the maluuba model
        else:
            self.context_encoding = tf.reduce_mean(self.context_condition_encoding, axis=1)  # this is the baseline model

        # s0 = tanh( mean(context encoding) + L * condition encoding, projected by W0 ).
        r = tf.reduce_sum(self.context_encoder_output, axis=1) / tf.expand_dims(tf.cast(self.context_length, tf.float32), axis=1) + tf.matmul(self.context_encoding, L)
        self.s0 = tf.nn.tanh(tf.matmul(r, W0) + b0)

    if self.advanced_condition_encoding and FLAGS.full_context_encoding:
        # for Maluuba model, decoder inputs are concat of context and answer encoding
        # Strictly speaking this is still wrong - the attn mech uses only the context encoding
        self.context_encoder_output = tf.concat([self.context_encoder_output, tf.tile(tf.expand_dims(self.a_encoder_final_state, axis=1), [1, max_context_len, 1])], axis=2)

    # decode
    # Two parallel decoder graphs are built: "train_*" (batch-sized, teacher forced)
    # and "beam_*" (batch*beam_width-sized, for beam search); variables are shared.
    with tf.variable_scope('decoder_init'):
        beam_memory = tf.contrib.seq2seq.tile_batch(self.context_encoder_output, multiplier=FLAGS.beam_width)
        beam_memory_sequence_length = tf.contrib.seq2seq.tile_batch(self.context_length, multiplier=FLAGS.beam_width)
        s0_tiled = tf.contrib.seq2seq.tile_batch(self.s0, multiplier=FLAGS.beam_width)
        beam_init_state = tf.contrib.rnn.LSTMStateTuple(s0_tiled, tf.contrib.seq2seq.tile_batch(tf.zeros([curr_batch_size, self.decoder_units]), multiplier=FLAGS.beam_width))

        train_memory = self.context_encoder_output
        train_memory_sequence_length = self.context_length
        train_init_state = tf.contrib.rnn.LSTMStateTuple(self.s0, tf.zeros([curr_batch_size, self.decoder_units]))

    with tf.variable_scope('attn_mech') as scope:
        train_attention_mechanism = copy_attention_wrapper.BahdanauAttention(
            num_units=self.decoder_units, memory=train_memory,
            memory_sequence_length=train_memory_sequence_length, name='bahdanau_attn')

        if FLAGS.separate_copy_mech:
            train_copy_mechanism = copy_attention_wrapper.BahdanauAttention(
                num_units=self.decoder_units, memory=train_memory,
                memory_sequence_length=train_memory_sequence_length, name='bahdanau_attn_copy')
        else:
            # Copy distribution reuses the attention weights.
            train_copy_mechanism = train_attention_mechanism

    with tf.variable_scope('decoder_cell'):
        train_decoder_cell = tf.contrib.rnn.DropoutWrapper(
            cell=tf.contrib.rnn.BasicLSTMCell(num_units=self.decoder_units),
            input_keep_prob=(tf.cond(self.is_training, lambda: 1.0 - self.dropout_prob, lambda: 1.)),
            state_keep_prob=(tf.cond(self.is_training, lambda: 1.0 - self.dropout_prob, lambda: 1.)),
            output_keep_prob=(tf.cond(self.is_training, lambda: 1.0 - self.dropout_prob, lambda: 1.)),
            input_size=self.embedding_size + self.decoder_units // 2,  # word embedding + attention context
            variational_recurrent=True,
            dtype=tf.float32)

        train_decoder_cell = copy_attention_wrapper.CopyAttentionWrapper(
            train_decoder_cell, train_attention_mechanism,
            attention_layer_size=self.decoder_units / 2,
            alignment_history=False,
            copy_mechanism=train_copy_mechanism,
            output_attention=True,
            initial_cell_state=train_init_state,
            name='copy_attention_wrapper')

        train_init_state = train_decoder_cell.zero_state(curr_batch_size * (1), tf.float32).clone(cell_state=train_init_state)

    # copy_mechanism = copy_attention_wrapper.BahdanauAttention(
    #         num_units=self.decoder_units, memory=memory,
    #         memory_sequence_length=memory_sequence_length)

    # Beam-search copies of the attention mechanism and cell, sharing variables via reuse.
    with tf.variable_scope('attn_mech', reuse=True) as scope:
        scope.reuse_variables()
        beam_attention_mechanism = copy_attention_wrapper.BahdanauAttention(
            num_units=self.decoder_units, memory=beam_memory,
            memory_sequence_length=beam_memory_sequence_length, name='bahdanau_attn')

        if FLAGS.separate_copy_mech:
            beam_copy_mechanism = copy_attention_wrapper.BahdanauAttention(
                num_units=self.decoder_units, memory=beam_memory,
                memory_sequence_length=beam_memory_sequence_length, name='bahdanau_attn_copy')
        else:
            beam_copy_mechanism = beam_attention_mechanism

    with tf.variable_scope('decoder_cell', reuse=True):
        beam_decoder_cell = tf.contrib.rnn.DropoutWrapper(
            cell=tf.contrib.rnn.BasicLSTMCell(num_units=self.decoder_units),
            input_keep_prob=(tf.cond(self.is_training, lambda: 1.0 - self.dropout_prob, lambda: 1.)),
            state_keep_prob=(tf.cond(self.is_training, lambda: 1.0 - self.dropout_prob, lambda: 1.)),
            output_keep_prob=(tf.cond(self.is_training, lambda: 1.0 - self.dropout_prob, lambda: 1.)),
            input_size=self.embedding_size + self.decoder_units // 2,
            variational_recurrent=True,
            dtype=tf.float32)

        beam_decoder_cell = copy_attention_wrapper.CopyAttentionWrapper(
            beam_decoder_cell, beam_attention_mechanism,
            attention_layer_size=self.decoder_units / 2,
            alignment_history=False,
            copy_mechanism=beam_copy_mechanism,
            output_attention=True,
            initial_cell_state=beam_init_state,
            name='copy_attention_wrapper')

        beam_init_state = beam_decoder_cell.zero_state(curr_batch_size * (FLAGS.beam_width), tf.float32).clone(cell_state=beam_init_state)

    # We have to make two copies of the layer as beam search uses different shapes - but force them to share variables
    with tf.variable_scope('copy_layer') as scope:
        # Mask that zeroes copy probability over answer tokens when hide_answer_in_copy is fed True.
        ans_mask = 1 - tf.reshape(self.in_answer_feature, [curr_batch_size, -1])
        self.answer_mask = tf.cond(self.hide_answer_in_copy, lambda: ans_mask, lambda: tf.ones(tf.shape(ans_mask)))

        train_projection_layer = copy_layer.CopyLayer(
            FLAGS.decoder_units // 2, FLAGS.max_context_len,
            switch_units=FLAGS.switch_units,
            source_provider=lambda: self.context_copy_ids if FLAGS.context_as_set else self.context_ids,
            source_provider_sl=lambda: self.context_ids,
            condition_encoding=lambda: self.context_encoding,
            vocab_size=len(self.vocab),
            training_mode=self.is_training,
            output_mask=lambda: self.answer_mask,
            context_as_set=FLAGS.context_as_set,
            max_copy_size=FLAGS.max_copy_size,
            mask_oovs=tf.logical_not(self.is_training),
            name="copy_layer")

        scope.reuse_variables()

        answer_mask_beam = tf.contrib.seq2seq.tile_batch(self.answer_mask, multiplier=FLAGS.beam_width)

        beam_projection_layer = copy_layer.CopyLayer(
            FLAGS.decoder_units // 2, FLAGS.max_context_len,
            switch_units=FLAGS.switch_units,
            source_provider=lambda: self.context_copy_ids if FLAGS.context_as_set else self.context_ids,
            source_provider_sl=lambda: self.context_ids,
            condition_encoding=lambda: self.context_encoding,
            vocab_size=len(self.vocab),
            training_mode=self.is_training,
            output_mask=lambda: answer_mask_beam,
            context_as_set=FLAGS.context_as_set,
            max_copy_size=FLAGS.max_copy_size,
            mask_oovs=tf.logical_not(self.is_training),
            name="copy_layer")

    with tf.variable_scope('decoder_unroll') as scope:
        # Helper - training
        training_helper = tf.contrib.seq2seq.TrainingHelper(
            self.question_teach_embedded, self.question_length)
        # self.question_teach_oh, self.question_length)
        # decoder_emb_inp, length(decoder_emb_inp)+1)

        # Decoder - training
        training_decoder = tf.contrib.seq2seq.BasicDecoder(
            train_decoder_cell, training_helper,
            initial_state=train_init_state,
            # initial_state=encoder_state
            # TODO: hardcoded FLAGS.max_copy_size is longest context in SQuAD - this will need changing for a new dataset!!!
            output_layer=train_projection_layer)

        # Unroll the decoder
        training_outputs, training_decoder_states, training_out_lens = tf.contrib.seq2seq.dynamic_decode(
            training_decoder, impute_finished=True, maximum_iterations=tf.reduce_max(self.question_length))

        training_probs = training_outputs.rnn_output

    with tf.variable_scope(scope, reuse=True):
        start_tokens = tf.tile(tf.constant([self.vocab[SOS]], dtype=tf.int32), [curr_batch_size])
        end_token = self.vocab[EOS]

        # DBS degrades to normal BS with groups=1, but my implementation is 1) probably slower and 2) wont receive updates from upstream
        if FLAGS.diverse_bs:
            beam_decoder = DiverseBeamSearchDecoder(
                cell=beam_decoder_cell,
                embedding=self.full_embeddings,
                start_tokens=start_tokens,
                end_token=end_token,
                initial_state=beam_init_state,
                beam_width=FLAGS.beam_width,
                output_layer=beam_projection_layer,
                length_penalty_weight=FLAGS.length_penalty,
                num_groups=FLAGS.beam_groups,
                diversity_param=FLAGS.beam_diversity)
        else:
            beam_decoder = tf.contrib.seq2seq.BeamSearchDecoder(
                cell=beam_decoder_cell,
                embedding=self.full_embeddings,
                start_tokens=start_tokens,
                end_token=end_token,
                initial_state=beam_init_state,
                beam_width=FLAGS.beam_width,
                output_layer=beam_projection_layer,
                length_penalty_weight=FLAGS.length_penalty)

        # NOTE(review): maximum_iterations=40 hardcodes max generated question length.
        beam_outputs, beam_decoder_states, beam_out_lens = tf.contrib.seq2seq.dynamic_decode(
            beam_decoder, impute_finished=False, maximum_iterations=40)

        beam_pred_ids = beam_outputs.predicted_ids[:, :, 0]  # best beam only

        # tf1.4 (and maybe others) return -1 for parts of the sequence outside the valid length, replace this with PAD (0)
        beam_mask = tf.sequence_mask(beam_out_lens[:, 0], tf.shape(beam_pred_ids)[1], dtype=tf.int32)
        beam_pred_ids = beam_pred_ids * beam_mask
        beam_pred_scores = beam_outputs.beam_search_decoder_output.scores

        # pred_ids = debug_shape(pred_ids, "pred ids")
        beam_probs = tf.one_hot(beam_pred_ids, depth=len(self.vocab) + FLAGS.max_copy_size)

    self.q_hat = training_probs  # tf.nn.softmax(logits, dim=2)

    # because we've done a few logs of softmaxes, there can be some precision problems that lead to non zero probability outside of the valid vocab, fix it here:
    # NOTE(review): output_mask is computed but the masking line below is commented out,
    # so max_vocab_size/output_mask are currently dead — presumably kept deliberately.
    self.max_vocab_size = tf.tile(tf.expand_dims(self.context_vocab_size + len(self.vocab), axis=1), [1, tf.shape(self.question_onehot)[1]])
    output_mask = tf.sequence_mask(self.max_vocab_size, FLAGS.max_copy_size + len(self.vocab), dtype=tf.float32)
    # self.q_hat = self.q_hat*output_mask

    with tf.variable_scope('output'), tf.device('/cpu:*'):
        self.q_hat_ids = tf.argmax(self.q_hat, axis=2, output_type=tf.int32)
        self.a_string = ops.id_tensor_to_string(self.answer_coerced, self.rev_vocab, self.context_raw, context_as_set=FLAGS.context_as_set)
        self.q_hat_string = ops.id_tensor_to_string(self.q_hat_ids, self.rev_vocab, self.context_raw, context_as_set=FLAGS.context_as_set)

        self.q_hat_beam_ids = beam_pred_ids
        self.q_hat_beam_string = ops.id_tensor_to_string(self.q_hat_beam_ids, self.rev_vocab, self.context_raw, context_as_set=FLAGS.context_as_set)

        # All beams, each masked to its own valid length.
        self.q_hat_full_beam_str = [ops.id_tensor_to_string(ids * tf.sequence_mask(beam_out_lens[:, i], tf.shape(beam_pred_ids)[1], dtype=tf.int32), self.rev_vocab, self.context_raw, context_as_set=FLAGS.context_as_set) for i, ids in enumerate(tf.unstack(beam_outputs.predicted_ids, axis=2))]
        # NOTE(review): the comprehension variable shadows the builtin `len` — harmless
        # (comprehension scope) but worth renaming when this file is next touched.
        self.q_hat_full_beam_lens = [len for len in tf.unstack(beam_out_lens, axis=1)]
        self.q_hat_beam_lens = beam_out_lens[:, 0]
        self.q_gold = ops.id_tensor_to_string(self.question_ids, self.rev_vocab, self.context_raw, context_as_set=FLAGS.context_as_set)

        self._output_summaries.extend(
            [tf.summary.text("q_hat", self.q_hat_string),
             tf.summary.text("q_gold", self.q_gold),
             tf.summary.text("answer", self.answer_raw)])

    with tf.variable_scope('train_loss'):
        self.target_weights = tf.sequence_mask(
            self.question_length, tf.shape(self.q_hat)[1], dtype=tf.float32)

        # NOTE(review): q_hat already looks like a probability distribution (copy layer
        # output); taking its log and feeding it to sparse_softmax_cross_entropy_with_logits
        # applies softmax *again* to log-probs — presumably intentional, confirm.
        logits = ops.safe_log(self.q_hat)

        # if the switch variable is fully latent, this gets a bit fiddly - we have to sum probabilities over all correct tokens, *then* take CE loss
        # otherwise the built in fn is fine (and almost certainly faster)
        if FLAGS.latent_switch:
            self.crossent = -1 * ops.safe_log(tf.reduce_sum(self.q_hat * self.question_onehot, axis=2))
        else:
            self.crossent = tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=self.question_ids, logits=logits)

        qlen_float = tf.cast(self.question_length, tf.float32)
        # Length-normalised cross-entropy, averaged over the batch.
        self.xe_loss = tf.reduce_mean(tf.reduce_sum(self.crossent * self.target_weights, axis=1) / qlen_float, axis=0)
        self.nll = tf.reduce_sum(self.crossent * self.target_weights, axis=1)

        # TODO: Check these should be included in baseline?
        # get sum of all probabilities for words that are also in answer
        answer_oh = tf.one_hot(self.answer_ids, depth=len(self.vocab) + FLAGS.max_copy_size)
        answer_mask = tf.tile(tf.reduce_sum(answer_oh, axis=1, keep_dims=True), [1, tf.reduce_max(self.question_length), 1])
        self.suppression_loss = tf.reduce_mean(tf.reduce_sum(tf.reduce_sum(answer_mask * self.q_hat, axis=2) * self.target_weights, axis=1) / qlen_float, axis=0)

        # entropy maximiser
        self.entropy_loss = tf.reduce_mean(tf.reduce_sum(tf.reduce_sum(self.q_hat * ops.safe_log(self.q_hat), axis=2) * self.target_weights, axis=1) / qlen_float, axis=0)

        if self.use_embedding_loss:
            # Embedding-space loss: compare expected embedding of q_hat against the gold
            # question's embedding in the frozen extended (full-GloVe) embedding table.
            vocab_cap = tf.tile(tf.expand_dims(self.context_vocab_size + len(self.vocab) - 1, axis=1), [1, FLAGS.max_copy_size + len(self.vocab)])
            with tf.device('/cpu:*'):
                # Map every local (shortlist+copy) id to its surface string, then to an id
                # in the extended vocab, then to its embedding.
                self.local_vocab_string = ops.id_tensor_to_string(tf.minimum(tf.tile(tf.expand_dims(tf.range(FLAGS.max_copy_size + len(self.vocab)), axis=0), [curr_batch_size, 1]), vocab_cap), self.rev_vocab, self.context_raw, context_as_set=FLAGS.context_as_set)
                self.local_vocab_to_extended = ops.string_tensor_to_id(self.local_vocab_string, self.glove_vocab)
                self.local_embeddings = tf.reshape(tf.nn.embedding_lookup(self.extended_embeddings, self.local_vocab_to_extended), [curr_batch_size, FLAGS.max_copy_size + len(self.vocab), self.embedding_size])
                self.q_gold_ids_extended = ops.string_tensor_to_id(self.question_raw, self.glove_vocab)

            # self.q_hat_extended = tf.matmul(self.q_hat, tf.stop_gradient(self.local_vocab_to_extended)) # batch x seq x ext_vocab
            self.q_gold_embedded_extended = tf.nn.embedding_lookup(self.extended_embeddings, self.q_gold_ids_extended)
            # Expected embedding of the predicted distribution.
            self.q_hat_embedded_extended = tf.matmul(self.q_hat, self.local_embeddings)
            # self.q_hat_embedded_extended = tf.matmul(self.extended_embeddings, tf.cast(self.q_hat_ids_extended, tf.int32), b_is_sparse=True)

            # Cosine similarity per timestep; 1e-5 guards against zero norms.
            self.similarity = tf.reduce_sum(self.q_hat_embedded_extended * tf.stop_gradient(self.q_gold_embedded_extended), axis=-1) / (1e-5 + tf.norm(self.q_gold_embedded_extended, axis=-1) * tf.norm(self.q_hat_embedded_extended, axis=-1))  # batch x seq
            self.dist = tf.reduce_sum(tf.square(self.q_hat_embedded_extended - tf.stop_gradient(self.q_gold_embedded_extended)), axis=-1)

            self._train_summaries.append(tf.summary.scalar("debug/similarities", tf.reduce_mean(tf.reduce_sum(self.similarity * self.target_weights, axis=1) / qlen_float)))
            self._train_summaries.append(tf.summary.scalar("debug/dist", tf.reduce_mean(tf.reduce_sum(self.dist * self.target_weights, axis=1) / qlen_float)))

            # Loss is the (length-normalised) angular distance between embeddings.
            self.loss = tf.reduce_mean(tf.reduce_sum(tf.abs(tf.acos(self.similarity)) * self.target_weights, axis=1) / qlen_float, axis=0)
            # self.loss = tf.abs(tf.acos(self.similarity)
        else:
            self.loss = self.xe_loss + FLAGS.suppression_weight * self.suppression_loss + FLAGS.entropy_weight * self.entropy_loss

        # Diagnostics: how much probability mass goes to the shortlist vs the copy mechanism.
        self.shortlist_prob = tf.reduce_sum(self.q_hat[:, :, :len(self.vocab)], axis=2) * self.target_weights
        self.copy_prob = tf.reduce_sum(self.q_hat[:, :, len(self.vocab):], axis=2) * self.target_weights
        self.mean_copy_prob = tf.reduce_sum(self.copy_prob, axis=1) / qlen_float
        self._train_summaries.append(tf.summary.scalar("debug/shortlist_prob", tf.reduce_mean(tf.reduce_sum(self.shortlist_prob, axis=1) / qlen_float)))
        self._train_summaries.append(tf.summary.scalar("debug/copy_prob", tf.reduce_mean(tf.reduce_sum(self.copy_prob, axis=1) / qlen_float)))

        self._train_summaries.append(tf.summary.scalar('train_loss/xe_loss', self.xe_loss))
        self._train_summaries.append(tf.summary.scalar('train_loss/entropy_loss', self.entropy_loss))
        self._train_summaries.append(tf.summary.scalar('train_loss/suppr_loss', self.suppression_loss))

    # dont bother calculating gradients if not training
    if self.training_mode:
        # Calculate and clip gradients
        params = tf.trainable_variables()
        gradients = tf.gradients(self.loss, params)
        clipped_gradients, _ = tf.clip_by_global_norm(gradients, 5)  # global-norm clip at 5

        # Optimization
        if FLAGS.opt_type == "sgd":
            # Step decay: lr = 1 * 0.5^((step-8000)/1000), flat for the first 8k steps.
            self.global_step = tf.train.create_global_step(self.graph)
            self.sgd_lr = 1 * tf.pow(0.5, tf.cast(tf.maximum(0, tf.cast(self.global_step, tf.int32) - 8000) / 1000, tf.float32))
            self._train_summaries.append(tf.summary.scalar('debug/sgd_lr', self.sgd_lr))
            self.optimizer = tf.train.GradientDescentOptimizer(self.sgd_lr).apply_gradients(
                zip(clipped_gradients, params)) if self.training_mode else tf.no_op()
        elif FLAGS.opt_type == "sgd_mom":
            # Step decay: lr = 0.1 * 0.5^((step-16000)/2000), flat for the first 16k steps.
            self.global_step = tf.train.create_global_step(self.graph)
            self.sgd_lr = 0.1 * tf.pow(0.5, tf.cast(tf.maximum(0, tf.cast(self.global_step, tf.int32) - 16000) / 2000, tf.float32))
            self._train_summaries.append(tf.summary.scalar('debug/sgd_lr', self.sgd_lr))
            momentum = tf.Variable(0.1, trainable=False)
            self.optimizer = tf.train.MomentumOptimizer(self.sgd_lr, momentum).apply_gradients(
                zip(clipped_gradients, params)) if self.training_mode else tf.no_op()
        else:
            self.optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate).apply_gradients(
                zip(clipped_gradients, params)) if self.training_mode else tf.no_op()

    # Token-level accuracy of the greedy (teacher-forced) decode vs the gold one-hots,
    # masked to valid question positions.
    self.accuracy = tf.reduce_mean(tf.cast(tf.reduce_sum(self.question_onehot * tf.contrib.seq2seq.hardmax(self.q_hat), axis=-1), tf.float32) * self.target_weights)
def main(_):
    # Entry point for training a question-generation model (SEQ2SEQ or MALUUBA
    # variants) on SQuAD, optionally with policy-gradient fine-tuning against
    # LM-perplexity, QA-F1 and discriminator rewards.
    if FLAGS.testing:
        # Shrink the model and batch sizes so a smoke-test run fits quickly.
        print('TEST MODE - reducing model size')
        FLAGS.context_encoder_units = 100
        FLAGS.answer_encoder_units = 100
        FLAGS.decoder_units = 100
        FLAGS.batch_size = 8
        FLAGS.eval_batch_size = 8
        # FLAGS.embedding_size=50

    # Each run gets a unique timestamp-derived id used for checkpoint and log dirs.
    run_id = str(int(time.time()))
    chkpt_path = FLAGS.model_dir + 'qgen/' + FLAGS.model_type + '/' + run_id
    restore_path = FLAGS.model_dir + 'qgen/' + FLAGS.restore_path if FLAGS.restore_path is not None else None  #'MALUUBA-CROP-LATENT'+'/'+'1534123959'
    # restore_path=FLAGS.model_dir+'saved/qgen-maluuba-crop-glove-smart'
    # Pretrained discriminator checkpoint, only used when policy_gradient is on.
    disc_path = FLAGS.model_dir + 'saved/discriminator-trained-latent'

    print("Run ID is ", run_id)
    print("Model type is ", FLAGS.model_type)

    if not os.path.exists(chkpt_path):
        os.makedirs(chkpt_path)

    # load dataset
    train_data = loader.load_squad_triples(FLAGS.data_path, False)
    dev_data = loader.load_squad_triples(FLAGS.data_path, True)

    # Keep the *unfiltered* contexts/answers around: the PG branch looks them up
    # by original index (train_batch[3]) so the QA model sees the full context.
    train_contexts_unfilt, _, ans_text_unfilt, ans_pos_unfilt = zip(
        *train_data)
    dev_contexts_unfilt, _, dev_ans_text_unfilt, dev_ans_pos_unfilt = zip(
        *dev_data)

    if FLAGS.testing:
        train_data = train_data[:1000]
        num_dev_samples = 100
    else:
        num_dev_samples = FLAGS.num_dev_samples

    if FLAGS.filter_window_size_before > -1:
        # Crop contexts to a window around the answer to bound sequence length.
        train_data = preprocessing.filter_squad(
            train_data, window_size_before=FLAGS.filter_window_size_before,
            window_size_after=FLAGS.filter_window_size_after,
            max_tokens=FLAGS.filter_max_tokens)
        dev_data = preprocessing.filter_squad(
            dev_data, window_size_before=FLAGS.filter_window_size_before,
            window_size_after=FLAGS.filter_window_size_after,
            max_tokens=FLAGS.filter_max_tokens)

    print('Loaded SQuAD with ', len(train_data), ' triples')
    train_contexts, train_qs, train_as, train_a_pos = zip(*train_data)

    # Vocab: reuse the restored run's vocab, or build one (from GloVe or from
    # the training text) and persist it next to the checkpoints.
    if FLAGS.restore:
        if restore_path is None:
            exit('You need to specify a restore path!')
        with open(restore_path + '/vocab.json', encoding="utf-8") as f:
            vocab = json.load(f)
    elif FLAGS.glove_vocab:
        vocab = loader.get_glove_vocab(FLAGS.data_path, size=FLAGS.vocab_size,
                                       d=FLAGS.embedding_size)
        with open(chkpt_path + '/vocab.json', 'w', encoding="utf-8") as outfile:
            json.dump(vocab, outfile)
    else:
        vocab = loader.get_vocab(train_contexts + train_qs, FLAGS.vocab_size)
        with open(chkpt_path + '/vocab.json', 'w', encoding="utf-8") as outfile:
            json.dump(vocab, outfile)

    # Create model
    if FLAGS.model_type[:7] == "SEQ2SEQ":
        model = Seq2SeqModel(vocab, training_mode=True,
                             use_embedding_loss=FLAGS.embedding_loss)
    elif FLAGS.model_type[:7] == "MALUUBA":
        # TEMP
        # Without PG the QA/LM reward heads are unused, so zero their weights.
        if not FLAGS.policy_gradient:
            FLAGS.qa_weight = 0
            FLAGS.lm_weight = 0
        model = MaluubaModel(vocab, training_mode=True,
                             use_embedding_loss=FLAGS.embedding_loss)
        # if FLAGS.model_type[:10] == "MALUUBA_RL":
        #     qa_vocab=model.qa.vocab
        #     lm_vocab=model.lm.vocab
        if FLAGS.policy_gradient:
            discriminator = DiscriminatorInstance(trainable=FLAGS.disc_train,
                                                  path=disc_path)
    else:
        exit("Unrecognised model type: " + FLAGS.model_type)

    # create data streamer
    # Train streamer cycles for num_epochs; dev streamer does a single pass.
    with SquadStreamer(vocab, FLAGS.batch_size, FLAGS.num_epochs,
                       shuffle=True) as train_data_source, SquadStreamer(
                           vocab, FLAGS.eval_batch_size, 1,
                           shuffle=True) as dev_data_source:

        with model.graph.as_default():
            saver = tf.train.Saver(max_to_keep=1, save_relative_paths=True)

        # change visible devices if using RL models
        # NOTE(review): mem_limit is not defined in this function — presumably a
        # module-level constant; confirm before refactoring.
        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=mem_limit,
                                    visible_device_list='0', allow_growth=True)
        with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options,
                                              allow_soft_placement=False),
                        graph=model.graph) as sess:
            summary_writer = tf.summary.FileWriter(
                FLAGS.log_dir + 'qgen/' + FLAGS.model_type + '/' + run_id,
                sess.graph)

            train_data_source.initialise(train_data)

            num_steps_train = len(train_data) // FLAGS.batch_size
            num_steps_dev = num_dev_samples // FLAGS.eval_batch_size

            if FLAGS.restore:
                saver.restore(sess, tf.train.latest_checkpoint(restore_path))
                start_e = 15  #FLAGS.num_epochs
                print('Loaded model')
            else:
                start_e = 0
                sess.run(tf.global_variables_initializer())
                # sess.run(model.glove_init_ops)

            # Seed the dev-perf curves with a zero point at step 0.
            f1summary = tf.Summary(value=[
                tf.Summary.Value(tag="dev_perf/f1", simple_value=0.0)
            ])
            bleusummary = tf.Summary(value=[
                tf.Summary.Value(tag="dev_perf/bleu", simple_value=0.0)
            ])

            summary_writer.add_summary(f1summary, global_step=0)
            summary_writer.add_summary(bleusummary, global_step=0)

            # Initialise the dataset
            # sess.run(model.iterator.initializer, feed_dict={model.context_ph: train_contexts,
            #     model.qs_ph: train_qs, model.as_ph: train_as, model.a_pos_ph: train_a_pos})

            # Track best out-of-sample NLL for checkpoint selection.
            best_oos_nll = 1e6

            # Running mean/variance of each reward stream, used to whiten
            # rewards before they are fed to the PG loss.
            lm_score_moments = online_moments.OnlineMoment()
            qa_score_moments = online_moments.OnlineMoment()
            disc_score_moments = online_moments.OnlineMoment()

            # for e in range(start_e,start_e+FLAGS.num_epochs):
            # Train for one epoch
            for i in tqdm(range(num_steps_train * FLAGS.num_epochs),
                          desc='Training'):
                # Get a batch
                train_batch, curr_batch_size = train_data_source.get_batch()

                # Are we doing policy gradient? Do a forward pass first, then build the PG batch and do an update step
                if FLAGS.model_type[:
                                    10] == "MALUUBA_RL" and FLAGS.policy_gradient:
                    # do a fwd pass first, get the score, then do another pass and optimize
                    qhat_str, qhat_ids, qhat_lens = sess.run(
                        [
                            model.q_hat_beam_string, model.q_hat_beam_ids,
                            model.q_hat_beam_lens
                        ],
                        feed_dict={
                            model.input_batch: train_batch,
                            model.is_training: FLAGS.pg_dropout,
                            model.hide_answer_in_copy: True
                        })

                    # The output is as long as the max allowed len - remove the pointless extra padding
                    qhat_ids = qhat_ids[:, :np.max(qhat_lens)]
                    qhat_str = qhat_str[:, :np.max(qhat_lens)]

                    # qhat_lens - 1 drops the end-of-sequence token from the string.
                    pred_str = byte_token_array_to_str(qhat_str, qhat_lens - 1)
                    gold_q_str = byte_token_array_to_str(
                        train_batch[1][0], train_batch[1][3])

                    # Get reward values
                    lm_score = (-1 * model.lm.get_seq_perplexity(pred_str)
                                ).tolist()  # lower perplexity is better

                    # retrieve the uncropped context for QA evaluation
                    unfilt_ctxt_batch = [
                        train_contexts_unfilt[ix] for ix in train_batch[3]
                    ]
                    ans_text_batch = [
                        ans_text_unfilt[ix] for ix in train_batch[3]
                    ]
                    ans_pos_batch = [
                        ans_pos_unfilt[ix] for ix in train_batch[3]
                    ]

                    # QA reward: F1 of the QA model's answer to the generated question.
                    # NOTE(review): qa_pred_gold is computed but never used below.
                    qa_pred = model.qa.get_ans(unfilt_ctxt_batch, pred_str)
                    qa_pred_gold = model.qa.get_ans(unfilt_ctxt_batch,
                                                    gold_q_str)

                    # gold_str=[]
                    # pred_str=[]
                    qa_f1s = []
                    gold_ans_str = byte_token_array_to_str(train_batch[2][0],
                                                           train_batch[2][2],
                                                           is_array=False)

                    qa_f1s.extend([
                        metrics.f1(metrics.normalize_answer(gold_ans_str[b]),
                                   metrics.normalize_answer(qa_pred[b]))
                        for b in range(curr_batch_size)
                    ])

                    disc_scores = discriminator.get_pred(
                        unfilt_ctxt_batch, pred_str, ans_text_batch,
                        ans_pos_batch)

                    # Start accumulating reward statistics halfway through burn-in,
                    # so the moments have warmed up before PG updates begin.
                    if i > FLAGS.pg_burnin // 2:
                        lm_score_moments.push(lm_score)
                        qa_score_moments.push(qa_f1s)
                        disc_score_moments.push(disc_scores)

                    # print(disc_scores)
                    # print((e-start_e)*num_steps_train+i, flags.pg_burnin)
                    if i > FLAGS.pg_burnin:
                        # A variant of popart
                        # Whiten each reward stream with its running moments.
                        qa_score_whitened = (
                            qa_f1s - qa_score_moments.mean
                        ) / np.sqrt(qa_score_moments.variance + 1e-6)
                        lm_score_whitened = (
                            lm_score - lm_score_moments.mean
                        ) / np.sqrt(lm_score_moments.variance + 1e-6)
                        disc_score_whitened = (
                            disc_scores - disc_score_moments.mean
                        ) / np.sqrt(disc_score_moments.variance + 1e-6)

                        # Log raw and whitened reward means for this step.
                        lm_summary = tf.Summary(value=[
                            tf.Summary.Value(tag="rl_rewards/lm",
                                             simple_value=np.mean(lm_score))
                        ])
                        summary_writer.add_summary(lm_summary,
                                                   global_step=(i))
                        qa_summary = tf.Summary(value=[
                            tf.Summary.Value(tag="rl_rewards/qa",
                                             simple_value=np.mean(qa_f1s))
                        ])
                        summary_writer.add_summary(qa_summary,
                                                   global_step=(i))
                        disc_summary = tf.Summary(value=[
                            tf.Summary.Value(tag="rl_rewards/disc",
                                             simple_value=np.mean(disc_scores))
                        ])
                        summary_writer.add_summary(disc_summary,
                                                   global_step=(i))

                        lm_white_summary = tf.Summary(value=[
                            tf.Summary.Value(tag="rl_rewards/lm_white",
                                             simple_value=np.mean(
                                                 lm_score_whitened))
                        ])
                        summary_writer.add_summary(lm_white_summary,
                                                   global_step=(i))
                        qa_white_summary = tf.Summary(value=[
                            tf.Summary.Value(tag="rl_rewards/qa_white",
                                             simple_value=np.mean(
                                                 qa_score_whitened))
                        ])
                        summary_writer.add_summary(qa_white_summary,
                                                   global_step=(i))
                        disc_white_summary = tf.Summary(value=[
                            tf.Summary.Value(tag="rl_rewards/disc_white",
                                             simple_value=np.mean(
                                                 disc_score_whitened))
                        ])
                        summary_writer.add_summary(disc_white_summary,
                                                   global_step=(i))

                    # Build a combined batch - half ground truth for MLE, half generated for PG
                    train_batch_ext = duplicate_batch_and_inject(
                        train_batch, qhat_ids, qhat_str, qhat_lens)

                    # print(qhat_ids)
                    # print(qhat_lens)
                    # print(train_batch_ext[2][2])

                    # First half of each score vector carries the whitened PG rewards
                    # for the generated questions; the second half covers the ground
                    # truth copies (ML weight on the lm slot, zero elsewhere).
                    # NOTE(review): before burn-in ends, *_score_whitened are
                    # unbound/stale here — presumably burn-in steps still reach this
                    # point; confirm intended behavior for i <= pg_burnin.
                    rl_dict = {
                        model.lm_score:
                        np.asarray((lm_score_whitened *
                                    FLAGS.lm_weight).tolist() + [
                                        FLAGS.pg_ml_weight
                                        for b in range(curr_batch_size)
                                    ]),
                        model.qa_score:
                        np.asarray((qa_score_whitened *
                                    FLAGS.qa_weight).tolist() +
                                   [0 for b in range(curr_batch_size)]),
                        model.disc_score:
                        np.asarray((disc_score_whitened *
                                    FLAGS.disc_weight).tolist() +
                                   [0 for b in range(curr_batch_size)]),
                        model.rl_lm_enabled: True,
                        model.rl_qa_enabled: True,
                        model.rl_disc_enabled: FLAGS.disc_weight > 0,
                        model.step: i - FLAGS.pg_burnin,
                        model.hide_answer_in_copy: True
                    }

                    # perform a policy gradient step, but combine with a XE step by using appropriate rewards
                    ops = [
                        model.pg_optimizer, model.train_summary,
                        model.q_hat_string
                    ]
                    # res_offset shifts the lm/qa loss indices when the extra
                    # eval-freq fetches are appended to ops.
                    if i % FLAGS.eval_freq == 0:
                        ops.extend([
                            model.q_hat_ids, model.question_ids,
                            model.copy_prob, model.question_raw,
                            model.question_length
                        ])
                        res_offset = 5
                    else:
                        res_offset = 0
                    ops.extend([model.lm_loss, model.qa_loss])
                    res = sess.run(ops,
                                   feed_dict={
                                       model.input_batch: train_batch_ext,
                                       model.is_training: False,
                                       **rl_dict
                                   })
                    summary_writer.add_summary(res[1], global_step=(i))

                    # Log only the first half of the PG related losses
                    lm_loss_summary = tf.Summary(value=[
                        tf.Summary.Value(
                            tag="train_loss/lm",
                            simple_value=np.mean(res[3 + res_offset]
                                                 [:curr_batch_size]))
                    ])
                    summary_writer.add_summary(lm_loss_summary,
                                               global_step=(i))
                    qa_loss_summary = tf.Summary(value=[
                        tf.Summary.Value(
                            tag="train_loss/qa",
                            simple_value=np.mean(res[4 + res_offset]
                                                 [:curr_batch_size]))
                    ])
                    summary_writer.add_summary(qa_loss_summary,
                                               global_step=(i))

                    # TODO: more principled scheduling here than alternating steps
                    if FLAGS.disc_train:
                        # Train the discriminator on a random mix of generated
                        # (label 0) and gold (label 1) questions.
                        ixs = np.round(
                            np.random.binomial(1, 0.5, curr_batch_size))
                        qbatch = [
                            pred_str[ix].replace(" </Sent>", "").replace(
                                " <PAD>", "")
                            if ixs[ix] < 0.5 else
                            gold_q_str[ix].replace(
                                " </Sent>", "").replace(" <PAD>", "")
                            for ix in range(curr_batch_size)
                        ]
                        loss = discriminator.train_step(unfilt_ctxt_batch,
                                                        qbatch,
                                                        ans_text_batch,
                                                        ans_pos_batch,
                                                        ixs,
                                                        step=(i))
                else:
                    # Normal single pass update step. If model has PG capability, fill in the placeholders with empty values
                    if FLAGS.model_type[:
                                        7] == "MALUUBA" and not FLAGS.policy_gradient:
                        rl_dict = {
                            model.lm_score:
                            [0 for b in range(curr_batch_size)],
                            model.qa_score:
                            [0 for b in range(curr_batch_size)],
                            model.disc_score:
                            [0 for b in range(curr_batch_size)],
                            model.rl_lm_enabled: False,
                            model.rl_qa_enabled: False,
                            model.rl_disc_enabled: False,
                            model.hide_answer_in_copy: False
                        }
                    else:
                        rl_dict = {}

                    # Perform a normal optimizer step
                    ops = [
                        model.optimizer, model.train_summary,
                        model.q_hat_string
                    ]
                    if i % FLAGS.eval_freq == 0:
                        ops.extend([
                            model.q_hat_ids, model.question_ids,
                            model.copy_prob, model.question_raw,
                            model.question_length
                        ])
                    res = sess.run(ops,
                                   feed_dict={
                                       model.input_batch: train_batch,
                                       model.is_training: True,
                                       **rl_dict
                                   })
                    summary_writer.add_summary(res[1], global_step=(i))

                # Dump some output periodically
                # Relies on the eval-freq fetches above: res[3..7] only exist on
                # steps where i % eval_freq == 0, which this guard guarantees.
                if i > 0 and i % FLAGS.eval_freq == 0 and (
                        i > FLAGS.pg_burnin or not FLAGS.policy_gradient):
                    with open(FLAGS.log_dir + 'out.htm', 'w',
                              encoding='utf-8') as fp:
                        fp.write(
                            output_pretty(res[2].tolist(), res[3], res[4],
                                          res[5], 0, i))
                    gold_batch = res[6]
                    gold_lens = res[7]
                    f1s = []
                    bleus = []
                    for b, pred in enumerate(res[2]):
                        pred_str = tokens_to_string(pred[:gold_lens[b] - 1])
                        gold_str = tokens_to_string(
                            gold_batch[b][:gold_lens[b] - 1])
                        f1s.append(metrics.f1(gold_str, pred_str))
                        bleus.append(metrics.bleu(gold_str, pred_str))

                    f1summary = tf.Summary(value=[
                        tf.Summary.Value(tag="train_perf/f1",
                                         simple_value=sum(f1s) / len(f1s))
                    ])
                    bleusummary = tf.Summary(value=[
                        tf.Summary.Value(tag="train_perf/bleu",
                                         simple_value=sum(bleus) / len(bleus))
                    ])

                    summary_writer.add_summary(f1summary, global_step=(i))
                    summary_writer.add_summary(bleusummary, global_step=(i))

                    # Evaluate against dev set
                    f1s = []
                    bleus = []
                    nlls = []

                    # Evaluate on a fresh random subset of dev each time.
                    np.random.shuffle(dev_data)
                    dev_subset = dev_data[:num_dev_samples]
                    dev_data_source.initialise(dev_subset)
                    for j in tqdm(range(num_steps_dev), desc='Eval ' + str(i)):
                        dev_batch, curr_batch_size = dev_data_source.get_batch(
                        )
                        pred_batch, pred_ids, pred_lens, gold_batch, gold_lens, ctxt, ctxt_len, ans, ans_len, nll = sess.run(
                            [
                                model.q_hat_beam_string, model.q_hat_beam_ids,
                                model.q_hat_beam_lens, model.question_raw,
                                model.question_length, model.context_raw,
                                model.context_length, model.answer_locs,
                                model.answer_length, model.nll
                            ],
                            feed_dict={
                                model.input_batch: dev_batch,
                                model.is_training: False
                            })

                        nlls.extend(nll.tolist())
                        # out_str="<h1>"+str(e)+' - '+str(datetime.datetime.now())+'</h1>'
                        for b, pred in enumerate(pred_batch):
                            # Strip the sentence-end marker and padding before scoring.
                            pred_str = tokens_to_string(
                                pred[:pred_lens[b] - 1]).replace(
                                    ' </Sent>', "").replace(" <PAD>", "")
                            gold_str = tokens_to_string(
                                gold_batch[b][:gold_lens[b] - 1])
                            f1s.append(metrics.f1(gold_str, pred_str))
                            bleus.append(metrics.bleu(gold_str, pred_str))
                            # out_str+=pred_str.replace('>','&gt;').replace('<','&lt;')+"<br/>"+gold_str.replace('>','&gt;').replace('<','&lt;')+"<hr/>"

                        # Dump an HTML report for the first eval batch only.
                        if j == 0:
                            title = chkpt_path
                            out_str = output_eval(title, pred_batch, pred_ids,
                                                  pred_lens, gold_batch,
                                                  gold_lens, ctxt, ctxt_len,
                                                  ans, ans_len)
                            with open(FLAGS.log_dir + 'out_eval_' +
                                      FLAGS.model_type + '.htm', 'w',
                                      encoding='utf-8') as fp:
                                fp.write(out_str)

                    f1summary = tf.Summary(value=[
                        tf.Summary.Value(tag="dev_perf/f1",
                                         simple_value=sum(f1s) / len(f1s))
                    ])
                    bleusummary = tf.Summary(value=[
                        tf.Summary.Value(tag="dev_perf/bleu",
                                         simple_value=sum(bleus) / len(bleus))
                    ])
                    nllsummary = tf.Summary(value=[
                        tf.Summary.Value(tag="dev_perf/nll",
                                         simple_value=sum(nlls) / len(nlls))
                    ])

                    summary_writer.add_summary(f1summary, global_step=i)
                    summary_writer.add_summary(bleusummary, global_step=i)
                    summary_writer.add_summary(nllsummary, global_step=i)

                    # Checkpoint on best dev NLL; under PG, save regardless since
                    # reward-driven progress may not lower NLL.
                    mean_nll = sum(nlls) / len(nlls)
                    if mean_nll < best_oos_nll:
                        print("New best NLL! ", mean_nll, " Saving...")
                        best_oos_nll = mean_nll
                        saver.save(sess, chkpt_path + '/model.checkpoint',
                                   global_step=i)
                    else:
                        print("NLL not improved ", mean_nll)
                        if FLAGS.policy_gradient:
                            print("Saving anyway")
                            saver.save(sess,
                                       chkpt_path + '/model.checkpoint',
                                       global_step=i)

                    if FLAGS.disc_train:
                        print("Saving disc")
                        discriminator.save_to_chkpt(FLAGS.model_dir, i)