def __call__(self, is_train, scope=None):
    """Build the GRU cell described by this spec.

    The configured activation and the recurrent / kernel / candidate
    initializers are resolved through the keras lookup helpers and handed
    positionally to ``GRUCell`` together with a constant bias initializer.

    ``is_train`` and ``scope`` are accepted for interface compatibility but
    are not used by this method.
    """
    act_fn = get_keras_activation(self.activation)
    rec_init = get_keras_initialization(self.recurrent_initializer)
    kern_init = get_keras_initialization(self.kernel_initializer)
    cand_init = get_keras_initialization(self.candidate_initializer)
    # NOTE: the attribute is spelled `bais_init` on this object.
    bias_init = tf.constant_initializer(self.bais_init)
    return GRUCell(self.num_units, bias_init, kern_init, rec_init,
                   cand_init, act_fn)
def __call__(self, is_train, scope=None):
    """Build the ``InitializedLSTMCell`` described by this spec.

    Raises:
        ValueError: if any of the two activations or two initializers
            resolves to ``None``.
    """
    act_fn = get_keras_activation(self.activation)
    rec_act_fn = get_keras_activation(self.recurrent_activation)
    kern_init = get_keras_initialization(self.kernel_initializer)
    rec_init = get_keras_initialization(self.recurrent_initializer)

    # Every one of these must have resolved to something usable.
    resolved = (act_fn, kern_init, rec_init, rec_act_fn)
    if any(item is None for item in resolved):
        raise ValueError()

    return InitializedLSTMCell(self.num_units, kern_init, rec_init, act_fn,
                               rec_act_fn, self.forget_bias,
                               self.keep_recurrent_probs, is_train, scope)
def apply(self, is_train, x, mask=None):
    """Pool a sequence into ``self.n_encodings`` attention-weighted sums.

    A linear map turns each position's key vector into ``n_encodings``
    logits; a softmax over the sequence axis (axis 1) converts them into
    weights that are used to sum ``x``.

    Args:
        is_train: forwarded to ``key_mapper`` / ``post_process``.
        x: rank-3 tensor; the contraction below is over its axis 2
            (presumably (batch, x_words, feature) -- confirm with callers).
        mask: optional sequence lengths, fed to ``tf.sequence_mask``.

    Returns:
        Tensor produced by ``einsum("ajk,ajn->ank")``, i.e.
        (batch, n_encoding, feature), optionally passed through
        ``post_process``.
    """
    if self.key_mapper is not None:
        with tf.variable_scope("map_keys"):
            keys = self.key_mapper.apply(is_train, x, mask)
    else:
        keys = x

    weights = tf.get_variable(
        "weights", (keys.shape.as_list()[-1], self.n_encodings),
        dtype=tf.float32,
        initializer=get_keras_initialization(self.init))
    # (batch, x_words, n_encoding)
    dist = tf.tensordot(keys, weights, axes=[[2], [0]])

    if self.bias:
        dist += tf.get_variable("bias", (1, 1, self.n_encodings),
                                dtype=tf.float32,
                                initializer=tf.zeros_initializer())

    if mask is not None:
        keep = tf.expand_dims(
            tf.cast(tf.sequence_mask(mask, tf.shape(x)[1]), tf.float32), 2)
        # Keep the real logits at valid positions and push padded positions
        # to a very negative value.  The previous code multiplied the mask
        # by itself here, which discarded the logits entirely and produced
        # a uniform distribution over the valid positions.
        dist = keep * dist + (1 - keep) * VERY_NEGATIVE_NUMBER

    # Normalise over the sequence axis.
    dist = tf.nn.softmax(dist, dim=1)

    # (batch, n_encoding, feature)
    out = tf.einsum("ajk,ajn->ank", x, dist)

    if self.post_process is not None:
        with tf.variable_scope("post_process"):
            out = self.post_process.apply(is_train, out)
    return out
def apply(self, is_train, sentences_rep, iteration1_logits, mask):
    """Project the sentence representations and sum them by attention.

    The attention weights are a softmax of ``iteration1_logits`` (with its
    axis 2 squeezed away and ``exp_mask`` applied using ``mask``).  The
    weighted projections are summed over axis 1.  ``is_train`` is unused.
    """
    init = get_keras_initialization(self.init)
    act_fn = get_keras_activation(self.activation)
    projected = fully_connected(sentences_rep,
                                units=self.n_project,
                                kernel_initializer=init,
                                use_bias=True,
                                activation=act_fn)

    flat_logits = tf.squeeze(iteration1_logits, axis=[2])
    attention = tf.nn.softmax(exp_mask(flat_logits, mask))

    weighted = tf.expand_dims(attention, axis=2) * projected
    return tf.reduce_sum(weighted, axis=1)
def apply(self, is_train, context_embed, answer, context_mask=None,
          **kwargs):
    """Predict span start/end logits and delegate to the span predictor.

    ``self.predictor`` yields one representation per bound; each is mapped
    to a single logit per position by its own linear layer, and the two
    logit tensors are passed to ``self.span_predictor.predict`` along with
    ``answer``, the mask and any extra keyword arguments.
    """
    init_fn = get_keras_initialization(self.init)

    def bound_logits(scope_name, features):
        # One linear unit per position, then drop the trailing size-1 axis.
        with tf.variable_scope(scope_name):
            scores = fully_connected(features, 1, activation_fn=None,
                                     weights_initializer=init_fn)
            return tf.squeeze(scores, axis=[2])

    with tf.variable_scope("bounds_encoding"):
        start_rep, end_rep = self.predictor.apply(is_train, context_embed,
                                                  context_mask)

    start_logits = bound_logits("start_pred", start_rep)
    end_logits = bound_logits("end_pred", end_rep)

    with tf.variable_scope("predict_span"):
        return self.span_predictor.predict(answer, start_logits, end_logits,
                                           mask=context_mask, **kwargs)
def _distance_logits(self, x, keys):
    """Tri-linear similarity between every ``x`` row and every ``keys`` row.

    The score is the sum of three terms: a learned linear function of the
    key, a learned linear function of ``x``, and a dot product between
    ``x`` (scaled element-wise by a learned vector) and the key.  Both
    inputs are contracted on axis 2, so they must be rank 3.
    """
    init = get_keras_initialization(self.init)
    key_dim = keys.shape.as_list()[-1]
    x_dim = x.shape.as_list()[-1]

    key_w = tf.get_variable("key_w", shape=key_dim, initializer=init,
                            dtype=tf.float32)
    key_term = tf.tensordot(keys, key_w, axes=[[2], [0]])  # (batch, key_len)

    x_w = tf.get_variable("input_w", shape=x_dim, initializer=init,
                          dtype=tf.float32)
    x_term = tf.tensordot(x, x_w, axes=[[2], [0]])  # (batch, x_len)

    dot_w = tf.get_variable("dot_w", shape=x_dim, initializer=init,
                            dtype=tf.float32)
    # Scale x by the dot weights first, then batch-multiply with the keys.
    scaled_x = x * tf.expand_dims(tf.expand_dims(dot_w, 0), 0)
    pair_term = tf.matmul(scaled_x, keys, transpose_b=True)

    # Broadcast the two per-sequence terms across the pairwise matrix.
    return pair_term + tf.expand_dims(key_term, 1) + tf.expand_dims(x_term, 2)
def _distance_logits(self, x1, x2):
    """Dot-product similarity between linearly projected ``x1`` and ``x2``.

    Both inputs are projected to ``self.project_size`` features (sharing
    one projection matrix when ``self.share_project`` is set), optionally
    shifted by learned biases, and compared with a batched matrix product.
    The result is divided by ``sqrt(project_size)`` when ``self.scale`` is
    set.

    Raises:
        ValueError: if ``share_project`` is set but the two inputs do not
            have the same feature size, so one matrix cannot serve both.
    """
    init = get_keras_initialization(self.init)

    # Record the input feature sizes before x1 is overwritten by its
    # projection.  The sharing check used to run after the projection and
    # therefore compared x2's size against `project_size` rather than
    # against x1's original feature size.
    x1_dim = x1.shape.as_list()[-1]
    x2_dim = x2.shape.as_list()[-1]

    project1 = tf.get_variable("project1", (x1_dim, self.project_size),
                               initializer=init)
    x1 = tf.tensordot(x1, project1, [[2], [0]])

    if self.share_project:
        if x2_dim != x1_dim:
            raise ValueError(
                "share_project requires inputs with equal feature sizes, "
                "got %s and %s" % (x1_dim, x2_dim))
        project2 = project1
    else:
        project2 = tf.get_variable("project2", (x2_dim, self.project_size),
                                   initializer=init)
    x2 = tf.tensordot(x2, project2, [[2], [0]])

    if self.project_bias:
        x1 += tf.get_variable("bias1", (1, 1, self.project_size),
                              initializer=tf.zeros_initializer())
        x2 += tf.get_variable("bias2", (1, 1, self.project_size),
                              initializer=tf.zeros_initializer())

    dots = tf.matmul(x1, x2, transpose_b=True)
    if self.scale:
        dots /= tf.sqrt(tf.cast(self.project_size, tf.float32))
    return dots
def apply(self, is_train, x, memories, answer: List[Tensor], x_mask=None,
          memory_mask=None):
    """Merge an encoding of ``memories`` into ``x`` and predict a span.

    Pipeline: map the memories, encode them, merge the encoding into
    ``x``, derive two bound representations, turn each into one logit per
    position, and hand both logit tensors to ``self.span_predictor``.
    """
    with tf.variable_scope("map_context"):
        mapped_memories = self.context_mapper.apply(is_train, memories,
                                                    memory_mask)

    with tf.variable_scope("encode_context"):
        memory_encoding = self.context_encoder.apply(
            is_train, mapped_memories, memory_mask)

    with tf.variable_scope("merge"):
        merged = self.merge.apply(is_train, x, memory_encoding, x_mask)

    with tf.variable_scope("predict"):
        start_rep, end_rep = self.bounds_predictor.apply(is_train, merged,
                                                         x_mask)

    init = get_keras_initialization(self.init)

    def to_logits(scope_name, features):
        # Single linear unit per position; squeeze the size-1 feature axis.
        with tf.variable_scope(scope_name):
            scores = fully_connected(features, 1, activation_fn=None,
                                     weights_initializer=init)
            return tf.squeeze(scores, axis=[2])

    start_logits = to_logits("logits1", start_rep)
    end_logits = to_logits("logits2", end_rep)

    with tf.variable_scope("predict_span"):
        return self.span_predictor.predict(answer, start_logits, end_logits,
                                           x_mask)
def apply(self, is_train, x, weights, answer: List):
    """Binary prediction from per-encoding softmaxes mixed by ``weights``.

    Each position along axis 1 of ``x`` has its own linear classifier with
    two outputs.  The per-position softmaxes are multiplied by ``weights``
    and summed over axis 1; the log of that mixture is used as the logits.
    A sparse softmax cross-entropy against ``answer[0]`` (optionally
    re-weighted for positive labels via ``self.pos_weight``) is added to
    the ``LOSSES`` collection.  ``is_train`` is unused.

    Returns:
        ``BinaryPrediction`` wrapping the log-mixture logits.
    """
    init = get_keras_initialization(self.fc_init)
    _, n_encodings, n_features = x.shape.as_list()

    with tf.variable_scope("compute_encoding_logits"):
        encoding_pred_weights = tf.get_variable(
            'encoding_weights', shape=[n_encodings, n_features, 2],
            initializer=init)
        encoding_pred_biases = tf.get_variable(
            'encoding_biases', shape=[n_encodings, 2],
            initializer=tf.zeros_initializer())

        projected = tf.einsum('btd,tdx->btx', x, encoding_pred_weights)
        encoding_logits = projected + encoding_pred_biases
        encoding_probs = tf.nn.softmax(encoding_logits, axis=-1)
        weighted_probs = encoding_probs * tf.expand_dims(weights, -1)
        mixture = tf.reduce_sum(weighted_probs, axis=1)

    # The scope and op name are kept only so the logits tensor has the
    # same name as in older models.
    with tf.variable_scope("compute_logits"):
        logits = tf.log(mixture, name='fully_connected/BiasAdd')

    labels = answer[0]
    cross_ent = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=logits)
    if self.pos_weight is not None:
        # Factor is pos_weight where the label is 1 and 1 where it is 0.
        example_weight = tf.to_float((labels * (self.pos_weight - 1)) + 1)
        cross_ent = cross_ent * example_weight

    loss = tf.reduce_mean(cross_ent)
    tf.add_to_collection(tf.GraphKeys.LOSSES, loss)
    return BinaryPrediction(logits)
def _get_predictions_for(self, is_train, question_embed, question_mask,
                         context_embed, context_mask, answer, question_lm,
                         context_lm, sentence_segments, sentence_mask):
    """Single-context relevance prediction.

    The context-side inputs are unstacked along axis 1 (exactly one entry
    is expected), the question and context are embedded and encoded, the
    sentence encodings are merged with the question encoding, and a
    max-pooled sentence logit is handed to ``self.predictor``.
    """
    # Drop the length-1 axis 1 from every context-side tensor.
    ctx_rep = tf.unstack(context_embed, axis=1, num=1)[0]
    ctx_mask = tf.unstack(context_mask, axis=1, num=1)[0]
    sent_segments = tf.unstack(sentence_segments, axis=1, num=1)[0]
    sent_mask = tf.unstack(sentence_mask, axis=1, num=1)[0]
    q_rep = question_embed

    # Extra positional inputs for the embed mapper, only when ELMo is on.
    q_lm_args, ctx_lm_args = [], []
    if self.use_elmo:
        ctx_lm = tf.unstack(context_lm, axis=1, num=1)[0]
        q_lm_args = [question_lm]
        ctx_lm_args = [ctx_lm]

    if self.embed_mapper is not None:
        with tf.variable_scope("map_embed"):
            ctx_rep = self.embed_mapper.apply(is_train, ctx_rep, ctx_mask,
                                              *ctx_lm_args)
        # Same mapper weights for the question.
        with tf.variable_scope("map_embed", reuse=True):
            q_rep = self.embed_mapper.apply(is_train, q_rep, question_mask,
                                            *q_lm_args)

    with tf.variable_scope("seq_enc"):
        q_rep = self.sequence_encoder.apply(is_train, q_rep, question_mask)

    with tf.variable_scope("sentences_enc"):
        ctx_rep = self.sentences_encoder.apply(ctx_rep, sent_segments,
                                               sent_mask)
        ctx_rep = tf.identity(ctx_rep, name='encode_context')
        tf.add_to_collection(INTERMEDIATE_LAYER_COLLECTION, ctx_rep)

    with tf.variable_scope("merger"):
        merged_rep = self.merger.apply(is_train, tensor=ctx_rep,
                                       fixed_tensor=q_rep, mask=sent_mask)

    if self.post_merger is not None:
        with tf.variable_scope("post_merger"):
            merged_rep = self.post_merger.apply(is_train, merged_rep,
                                                mask=sent_mask)

    with tf.variable_scope("sentence_level_predictions"):
        sentences_logits = fully_connected(
            merged_rep, 1, use_bias=True, activation=None,
            kernel_initializer=get_keras_initialization('glorot_uniform'))
        max_logits = self.max_pool.apply(is_train, sentences_logits,
                                         sent_mask)

    with tf.variable_scope("predictor"):
        return self.predictor.apply(is_train, max_logits, answer)
def _distance_logits(self, x, keys):
    """Additive similarity: a linear score of ``x`` plus one of ``keys``.

    Each input is reduced to one scalar per position by a learned vector
    (contraction over axis 2); the pairwise score is the broadcast sum of
    the two.
    """
    init = get_keras_initialization(self.init)

    key_w = tf.get_variable("key_w", shape=keys.shape.as_list()[-1],
                            initializer=init, dtype=tf.float32)
    key_scores = tf.tensordot(keys, key_w, axes=[[2], [0]])  # (batch, key_len)

    x_w = tf.get_variable("x_w", shape=x.shape.as_list()[-1],
                          initializer=init, dtype=tf.float32)
    x_scores = tf.tensordot(x, x_w, axes=[[2], [0]])  # (batch, x_len)

    # (batch, x_len, 1) + (batch, 1, key_len) -> (batch, x_len, key_len)
    x_column = tf.expand_dims(x_scores, axis=2)
    key_row = tf.expand_dims(key_scores, axis=1)
    return x_column + key_row
def merge_weight_predict(is_train, context_rep, question_rep, context_mask,
                         merger, post_merger, max_pool, predictor, answer,
                         multiply_probs=None):
    """Merge context with question, score sentences, and predict.

    Args:
        multiply_probs: optional probabilities ``p``.  When given, the
            pooled logit ``z`` is replaced by
            ``log(p) - log(1 + exp(-z) - p)`` (each log argument offset by
            ``EPSILON``), which is the logit of ``p * sigmoid(z)``.

    Returns:
        Tuple ``(merged representation, per-sentence logits, prediction)``.
    """
    with tf.variable_scope("merger"):
        merged = merger.apply(is_train, tensor=context_rep,
                              fixed_tensor=question_rep, mask=context_mask)

    if post_merger is not None:
        with tf.variable_scope("post_merger"):
            merged = post_merger.apply(is_train, merged, mask=context_mask)

    with tf.variable_scope("sentence_level_predictions"):
        sentences_logits = fully_connected(
            merged, 1, use_bias=True, activation=None,
            kernel_initializer=get_keras_initialization('glorot_uniform'))
        max_logits = max_pool.apply(is_train, sentences_logits, context_mask)

    if multiply_probs is not None:
        log_numerator = tf.log(multiply_probs + EPSILON)
        log_denominator = tf.log(
            1. + tf.exp(-max_logits) - multiply_probs + EPSILON)
        max_logits = log_numerator - log_denominator

    with tf.variable_scope("predictor"):
        prediction = predictor.apply(is_train, max_logits, answer)

    return merged, sentences_logits, prediction
def apply(self, is_train, x, mask=None):
    """Collapse a sequence into one vector with learned attention.

    A learned vector scores every position's key, ``exp_mask`` applies
    ``mask`` to the scores, and the softmax of the result weights the sum
    of ``x`` over axis 1.
    """
    keys = x
    if self.key_mapper is not None:
        with tf.variable_scope("map_keys"):
            keys = self.key_mapper.apply(is_train, x, mask)

    weights = tf.get_variable(
        "weights", keys.shape.as_list()[-1], dtype=tf.float32,
        initializer=get_keras_initialization(self.init))

    scores = tf.tensordot(keys, weights, axes=[[2], [0]])  # (batch, x_words)
    attention = tf.nn.softmax(exp_mask(scores, mask))

    pooled = tf.einsum("ajk,aj->ak", x, attention)  # (batch, x_dim)

    if self.post_process is not None:
        with tf.variable_scope("post_process"):
            pooled = self.post_process.apply(is_train, pooled)
    return pooled
def _distance_logits(self, x, keys):
    """Additive attention scores with a non-linearity.

    ``x`` and ``keys`` are each projected to ``self.projected_size``
    features (with one shared matrix when ``self.shared_project`` is set),
    summed pairwise, passed through the configured activation, and reduced
    to a scalar per pair by a learned vector.
    """
    init = get_keras_initialization(self.init)

    key_w = tf.get_variable(
        "key_w", shape=(keys.shape.as_list()[-1], self.projected_size),
        initializer=init, dtype=tf.float32)
    # (batch, key_len, projected_size)
    key_proj = tf.tensordot(keys, key_w, axes=[[2], [0]])

    if self.shared_project:
        x_w = key_w
    else:
        x_w = tf.get_variable(
            "x_w", shape=(x.shape.as_list()[-1], self.projected_size),
            initializer=init, dtype=tf.float32)
    # (batch, x_len, projected_size)
    x_proj = tf.tensordot(x, x_w, axes=[[2], [0]])

    # (batch, x_len, 1, size) + (batch, 1, key_len, size)
    #   -> (batch, x_len, key_len, projected_size)
    pairwise = tf.expand_dims(x_proj, axis=2) + tf.expand_dims(key_proj,
                                                              axis=1)
    activated = get_keras_activation(self.activation)(pairwise)

    combine_w = tf.get_variable("combine_w", shape=self.projected_size,
                                initializer=init, dtype=tf.float32)
    # Contract the feature axis: (batch, x_len, key_len)
    return tf.tensordot(activated, combine_w, axes=[[3], [0]])
def apply(self, is_train, context_embed, answer, context_mask=None):
    """Score every span of up to ``self.bound`` words and build the loss.

    Start and end representations come from ``self.mapper``.  For each
    offset ``i`` in ``[0, bound)`` the start vectors are merged with the
    end vectors ``i`` positions later; the merged vectors for all offsets
    are concatenated along axis 1 ("packed" layout) and mapped to one
    logit each.  Logits whose mask entry is false are pushed to a very
    negative value.

    The loss depends on ``answer``:
      * a single ``tf.int32`` tensor: cross-entropy on the packed index of
        the span, or (when ``self.f1_weight`` is non-zero) the negative
        log of the probability mass weighted by an F1 mask;
      * a single tensor of another dtype: treated as a mask over correct
        spans, aggregated with ``self.aggregate`` ("sum" or "max").

    Raises:
        NotImplementedError: for any other ``answer`` length or an unknown
            ``self.aggregate`` value.
    """
    init_fn = get_keras_initialization(self.init)
    bool_mask = tf.sequence_mask(context_mask, tf.shape(context_embed)[1])

    with tf.variable_scope("predict"):
        m1, m2 = self.mapper.apply(is_train, context_embed, context_mask)

    if self.pre_process is not None:
        with tf.variable_scope("pre-process1"):
            m1 = self.pre_process.apply(is_train, m1, context_mask)
        with tf.variable_scope("pre-process2"):
            m2 = self.pre_process.apply(is_train, m2, context_mask)

    # Offset 0 creates the merge variables; every later offset reuses them.
    with tf.variable_scope("merge"):
        span_vector_lst = [self.merge.apply(is_train, m1, m2)]
    mask_lst = [bool_mask]
    for offset in range(1, self.bound):
        with tf.variable_scope("merge", reuse=True):
            span_vector_lst.append(
                self.merge.apply(is_train, m1[:, :-offset], m2[:, offset:]))
        mask_lst.append(bool_mask[:, offset:])

    mask = tf.concat(mask_lst, axis=1)
    # All offsets flattened into one axis of per-span vectors.
    span_vectors = tf.concat(span_vector_lst, axis=1)

    if self.post_process is not None:
        with tf.variable_scope("post-process"):
            span_vectors = self.post_process.apply(is_train, span_vectors)

    with tf.variable_scope("compute_logits"):
        logits = fully_connected(span_vectors, 1, activation_fn=None,
                                 weights_initializer=init_fn)

    logits = tf.squeeze(logits, axis=[2])
    # `mask` is already a single tensor; the extra concat is kept from the
    # original graph construction.
    invalid = 1 - tf.cast(tf.concat(mask, axis=1), tf.float32)
    logits = logits + VERY_NEGATIVE_NUMBER * invalid

    l = tf.shape(context_embed)[1]

    if len(answer) != 1:
        raise NotImplementedError()
    answer = answer[0]

    if answer.dtype == tf.int32:
        if self.f1_weight == 0:
            answer_ix = to_packed_coordinates(answer, l, self.bound)
            loss = tf.reduce_mean(
                tf.nn.sparse_softmax_cross_entropy_with_logits(
                    logits=logits, labels=answer_ix))
        else:
            f1_mask = packed_span_f1_mask(answer, l, self.bound)
            if self.f1_weight < 1:
                f1_mask *= self.f1_weight
                # NOTE(review): the one-hot depth here is `l`, while the
                # logits span all packed offsets -- confirm this matches
                # the shape of `f1_mask` when bound > 1.
                f1_mask += (1 - self.f1_weight) * tf.one_hot(
                    to_packed_coordinates(answer, l, self.bound), l)
            # TODO can we stay in log space? (tricky, since f1_mask can
            # contain zeros)
            probs = tf.nn.softmax(logits)
            loss = -tf.reduce_mean(
                tf.log(tf.reduce_sum(probs * f1_mask, axis=1)))
    else:
        log_norm = tf.reduce_logsumexp(logits, axis=1)
        if self.aggregate == "sum":
            log_score = tf.reduce_logsumexp(
                logits + VERY_NEGATIVE_NUMBER *
                (1 - tf.cast(answer, tf.float32)),
                axis=1)
        elif self.aggregate == "max":
            log_score = tf.reduce_max(
                logits + VERY_NEGATIVE_NUMBER *
                (1 - tf.cast(answer, tf.float32)),
                axis=1)
        else:
            raise NotImplementedError()
        loss = tf.reduce_mean(-(log_score - log_norm))

    tf.add_to_collection(tf.GraphKeys.LOSSES, loss)
    return PackedSpanPrediction(logits, l, self.bound)
def _get_predictions_for(self, is_train, question_embed, question_mask,
                         context_embed, context_mask, answer, question_lm,
                         context_lm, sentence_segments, sentence_mask):
    """Two-context iterative relevance prediction.

    Steps:
      1. Unstack the two contexts from axis 1 of every context-side input.
      2. Embed-map both contexts and the question with shared weights and
         encode the question.
      3. Encode the sentences of each context with shared weights.
      4. Score the first context against the question encoding.
      5. Reformulate the question from the first context using attention
         in both directions followed by re-encoding and a merge.
      6. Score the second context against the reformulated question,
         optionally multiplying in the first iteration's probabilities.

    Returns:
        ``MultipleBinaryPredictions`` over the two per-context predictions.
    """
    q_rep = question_embed
    c1_rep, c2_rep = tf.unstack(context_embed, axis=1, num=2)
    c1_mask, c2_mask = tf.unstack(context_mask, axis=1, num=2)
    c1_segments, c2_segments = tf.unstack(sentence_segments, axis=1, num=2)
    c1_sent_mask, c2_sent_mask = tf.unstack(sentence_mask, axis=1, num=2)

    # Extra positional inputs for the embed mapper, only when ELMo is on.
    q_lm_args, c1_lm_args, c2_lm_args = [], [], []
    if self.use_elmo:
        c1_lm, c2_lm = tf.unstack(context_lm, axis=1, num=2)
        q_lm_args = [question_lm]
        c1_lm_args = [c1_lm]
        c2_lm_args = [c2_lm]

    if self.embed_mapper is not None:
        with tf.variable_scope("map_embed"):
            c1_rep = self.embed_mapper.apply(is_train, c1_rep, c1_mask,
                                             *c1_lm_args)
        # The second context and the question reuse the same weights.
        with tf.variable_scope("map_embed", reuse=True):
            c2_rep = self.embed_mapper.apply(is_train, c2_rep, c2_mask,
                                             *c2_lm_args)
            q_rep = self.embed_mapper.apply(is_train, q_rep, question_mask,
                                            *q_lm_args)

    with tf.variable_scope("seq_enc"):
        question_enc = self.sequence_encoder.apply(is_train, q_rep,
                                                   question_mask)
        question_enc = tf.identity(question_enc, name='encode_question')
        tf.add_to_collection(INTERMEDIATE_LAYER_COLLECTION, question_enc)

    def encode_sentences(context, sentence_segs, sentence_mask, rep_name):
        # Sentence-level encoding, optional mapping, then expose the
        # result under `rep_name` and in the intermediate-layer collection.
        encoded = self.sentences_encoder.apply(context, sentence_segs,
                                               sentence_mask)
        if self.sentence_mapper is not None:
            with tf.variable_scope('sentence_mapper'):
                encoded = self.sentence_mapper.apply(is_train, encoded,
                                                     mask=sentence_mask)
        encoded = tf.identity(encoded, name=rep_name)
        tf.add_to_collection(INTERMEDIATE_LAYER_COLLECTION, encoded)
        return encoded

    with tf.variable_scope('sentences_enc'):
        c1_sent_rep = encode_sentences(c1_rep, c1_segments, c1_sent_mask,
                                       'encode_context1')
    with tf.variable_scope('sentences_enc', reuse=True):
        c2_sent_rep = encode_sentences(c2_rep, c2_segments, c2_sent_mask,
                                       'encode_context2')

    # First iteration (same as in the single context model).
    with tf.variable_scope("context1_relevance"):
        c1_q_merged_rep, _, context1_pred = merge_weight_predict(
            is_train=is_train,
            context_rep=c1_sent_rep,
            question_rep=question_enc,
            context_mask=c1_sent_mask,
            merger=self.merger,
            post_merger=self.post_merger,
            max_pool=self.max_pool,
            predictor=self.predictor,
            answer=[answer[0]])

    # Question reformulation.
    with tf.variable_scope("reformulation"):
        # Each direction re-applies the sequence encoder inside its own
        # scope, so each gets its own encoder variables.
        with tf.variable_scope('c2q'):
            attended_q = self.context_to_question_attention.apply(
                is_train, x=q_rep, keys=c1_rep, memories=c1_rep,
                x_mask=question_mask, memory_mask=c1_mask)
            reread_q_enc = self.sequence_encoder.apply(
                is_train, attended_q, question_mask)
        with tf.variable_scope('q2c'):
            # Attends over the question that was just updated by `c2q`.
            attended_c1 = self.question_to_context_attention.apply(
                is_train, x=c1_rep, keys=attended_q, memories=attended_q,
                x_mask=c1_mask, memory_mask=question_mask)
            reread_c1_enc = self.sequence_encoder.apply(
                is_train, attended_c1, c1_mask)
        with tf.variable_scope('reread_merge'):
            reformulated_q = self.reread_merger.apply(
                is_train, reread_q_enc, reread_c1_enc)
            # Project to the width of the first iteration's merged rep.
            reformulated_q = fully_connected(
                reformulated_q,
                c1_q_merged_rep.shape.as_list()[-1],
                use_bias=True,
                activation=get_keras_activation('relu'),
                kernel_initializer=get_keras_initialization(
                    'glorot_uniform'))
        reformulated_q = tf.identity(reformulated_q,
                                     name='reformulated_question')
        tf.add_to_collection(INTERMEDIATE_LAYER_COLLECTION, reformulated_q)

    # Second iteration.
    with tf.variable_scope("context2_relevance"):
        first_iter_probs = None
        if self.multiply_iteration_probs:
            first_iter_probs = tf.expand_dims(context1_pred.get_probs(),
                                              axis=1)
        _, _, context2_pred = merge_weight_predict(
            is_train=is_train,
            context_rep=c2_sent_rep,
            question_rep=reformulated_q,
            context_mask=c2_sent_mask,
            merger=self.merger,
            post_merger=self.post_merger,
            max_pool=self.max_pool,
            predictor=self.predictor,
            answer=[answer[1]],
            multiply_probs=first_iter_probs)

    return MultipleBinaryPredictions([context1_pred, context2_pred])
def _get_predictions_for(self, is_train, question_embed, question_mask,
                         context_embed, context_mask, answer, question_lm,
                         context_lm, sentence_segments, sentence_mask):
    """Span prediction with yes/no and supporting-fact heads.

    Besides the attention pipeline that produces the final context
    representation, three extra sets of logits are computed and forwarded
    to ``self.predictor`` as keyword arguments:
      * ``yes_no_choice_logits`` (2 outputs) from the embedded question;
      * ``yes_no_answer_logits`` (2 outputs) from the attended context;
      * ``sentence_logits`` (1 output per sentence) from the sentence
        encodings of the attended context.
    """
    # Drop the length-1 axis 1 from every context-side tensor.
    c_rep = tf.unstack(context_embed, axis=1, num=1)[0]
    c_mask = tf.unstack(context_mask, axis=1, num=1)[0]
    c_segments = tf.unstack(sentence_segments, axis=1, num=1)[0]
    c_sent_mask = tf.unstack(sentence_mask, axis=1, num=1)[0]
    q_rep = question_embed

    # Extra positional inputs for the embed mapper, only when ELMo is on.
    q_lm_args, c_lm_args = [], []
    if self.use_elmo:
        c_lm = tf.unstack(context_lm, axis=1, num=1)[0]
        q_lm_args = [question_lm]
        c_lm_args = [c_lm]

    if self.embed_mapper is not None:
        with tf.variable_scope("map_embed"):
            c_rep = self.embed_mapper.apply(is_train, c_rep, c_mask,
                                            *c_lm_args)
        with tf.variable_scope("map_embed", reuse=True):
            q_rep = self.embed_mapper.apply(is_train, q_rep, question_mask,
                                            *q_lm_args)

    with tf.variable_scope('yes_no_question_prediction'):
        yes_no_q_enc = self.yes_no_question_encoder.apply(
            is_train, q_rep, question_mask)
        yes_no_choice_logits = fully_connected(
            yes_no_q_enc, 2, use_bias=True, activation=None,
            kernel_initializer=get_keras_initialization('glorot_uniform'),
            name='yes_no_choice')

    if self.question_mapper is not None:
        with tf.variable_scope("map_question"):
            q_rep = self.question_mapper.apply(is_train, q_rep,
                                               question_mask)

    if self.context_mapper is not None:
        with tf.variable_scope("map_context"):
            c_rep = self.context_mapper.apply(is_train, c_rep, c_mask)

    with tf.variable_scope("buid_memories"):
        keys, memories = self.memory_builder.apply(is_train, q_rep,
                                                   question_mask)

    with tf.variable_scope("apply_attention"):
        c_rep = self.attention.apply(is_train, c_rep, keys, memories,
                                     c_mask, question_mask)

    if self.match_encoder is not None:
        with tf.variable_scope("process_attention"):
            c_rep = self.match_encoder.apply(is_train, c_rep, c_mask)

    with tf.variable_scope('yes_no_answer_prediction'):
        yes_no_c_enc = self.yes_no_context_encoder.apply(
            is_train, c_rep, c_mask)
        yes_no_answer_logits = fully_connected(
            yes_no_c_enc, 2, use_bias=True, activation=None,
            kernel_initializer=get_keras_initialization('glorot_uniform'),
            name='yes_no_answer')

    with tf.variable_scope('supporting_fact_prediction'):
        sp_input = c_rep
        if self.pre_sp_mapper is not None:
            with tf.variable_scope('pre_sp_mapper'):
                sp_input = self.pre_sp_mapper.apply(is_train, sp_input,
                                                    c_mask)
        context_sents = self.sentences_encoder.apply(
            sp_input, c_segments, c_sent_mask)
        context_sents = tf.identity(context_sents, name='debug')
        if self.sentence_mapper is not None:
            with tf.variable_scope('sentence_mapper'):
                context_sents = self.sentence_mapper.apply(
                    is_train, context_sents, mask=c_sent_mask)
        sentences_logits = fully_connected(
            context_sents, 1, use_bias=True, activation=None,
            kernel_initializer=get_keras_initialization('glorot_uniform'),
            name='supporting_fact_fc')

    with tf.variable_scope("predict"):
        return self.predictor.apply(
            is_train, c_rep, answer, c_mask,
            yes_no_choice_logits=yes_no_choice_logits,
            yes_no_answer_logits=yes_no_answer_logits,
            sentence_logits=tf.squeeze(sentences_logits, axis=[2]),
            sentence_mask=c_sent_mask)
def _apply_transposed(self, is_train, x):
    """Run a CuDNN GRU/LSTM over ``x`` with custom-initialised parameters.

    The CuDNN cells keep all parameters in one flat buffer.  To apply the
    configured initializers, the canonical weight/bias shapes are obtained
    by converting a dummy zero buffer, the initial tensors are built per
    weight, and they are packed back into the flat layout.  This relies on
    ``.eval()`` and therefore needs a default session while the graph is
    being built.

    When ``self.keep_recurrent < 1`` drop-connect is applied to the
    recurrent weights during training.

    Raises:
        ValueError: if the last dimension of ``x`` is unknown or
            ``self._kind`` is neither "GRU" nor "LSTM".
        NotImplementedError: for an LSTM with ``learn_initial_states``.
    """
    w_init = get_keras_initialization(self.w_init)
    if self.recurrent_init is None:
        r_init = None
    else:
        r_init = get_keras_initialization(self.recurrent_init)

    x_size = x.shape.as_list()[-1]
    if x_size is None:
        raise ValueError("Last dimension must be defined (have shape %s)" %
                         str(x.shape))

    is_lstm = self._kind == "LSTM"
    if self._kind == "GRU":
        cell_cls = cudnn_rnn_ops.CudnnGRU
    elif is_lstm:
        cell_cls = cudnn_rnn_ops.CudnnLSTM
    else:
        raise ValueError()
    cell = cell_cls(self.n_layers, self.n_units, x_size,
                    input_mode="linear_input")

    n_params = cell.params_size().eval()
    # Only the structure (shapes) of these canonical tensors is used.
    weights, biases = cell.params_to_canonical(tf.zeros([n_params]))

    def init(shape, dtype=None, partition_info=None):
        # This is a bit hacky because the API for these models is awkward:
        # the weight / bias shapes are only available by calling
        # `cell.params_to_canonical` with an unused tensor and then using
        # .eval() to get the concrete shape.  After that the
        # user-requested initializers can be applied.
        if is_lstm:
            recurrent_flags = [False, False, False, False,
                               True, True, True, True]
            forget_bias_flags = [False, True, False, False,
                                 False, True, False, False]
        else:
            recurrent_flags = [False, False, False, True, True, True]
            forget_bias_flags = [False] * 6

        init_biases = []
        for is_forget in forget_bias_flags:
            if is_forget:
                init_biases.append(
                    tf.constant(self.lstm_bias / 2.0, tf.float32,
                                (self.n_units, )))
            else:
                init_biases.append(tf.zeros(self.n_units))

        init_weights = []
        for canonical_w, is_rec in zip(weights, recurrent_flags):
            if is_rec and r_init is not None:
                fresh = tf.reshape(
                    r_init((self.n_units, self.n_units), canonical_w.dtype),
                    tf.shape(canonical_w))
            else:
                fresh = w_init(tf.shape(canonical_w).eval(),
                               canonical_w.dtype)
            init_weights.append(fresh)

        flat = cell.canonical_to_params(init_weights, init_biases)
        flat.set_shape((n_params, ))
        return flat

    parameters = tf.get_variable("gru_parameters", n_params, tf.float32,
                                 initializer=init)

    if self.keep_recurrent < 1:
        # Not super well tested: work out which indices of `parameters`
        # are recurrent weights and drop them.  This implements
        # drop-connect for the recurrent weights.
        half = len(weights) // 2
        is_recurrent = weights[:half] + [
            tf.ones_like(w) for w in weights[half:]
        ]
        # Ones at the recurrent weights.
        recurrent_mask = cell.canonical_to_params(is_recurrent, biases)
        # Ones at non-recurrent params, keep_prob elsewhere.
        recurrent_mask = 1 - recurrent_mask * (1 - self.keep_recurrent)

        def dropped_parameters():
            noise = tf.random_uniform((n_params, ))
            return tf.floor(noise + recurrent_mask) * parameters

        parameters = tf.cond(is_train, dropped_parameters,
                             lambda: parameters)

    state_shape = (self.n_layers, tf.shape(x)[1], self.n_units)

    if is_lstm:
        if self.learn_initial_states:
            raise NotImplementedError()
        initial_state_h = tf.zeros(state_shape, tf.float32)
        initial_state_c = tf.zeros(state_shape, tf.float32)
        return cell(x, initial_state_h, initial_state_c, parameters, True)

    if self.learn_initial_states:
        initial_state = tf.get_variable("initial_state", self.n_units,
                                        tf.float32, tf.zeros_initializer())
        # Tile the learned vector over the layer axis and axis 1 of x.
        initial_state = tf.tile(
            tf.expand_dims(tf.expand_dims(initial_state, 0), 0),
            [self.n_layers, tf.shape(x)[1], 1])
    else:
        initial_state = tf.zeros(state_shape, tf.float32)
    return cell(x, initial_state, parameters, True)
def apply(self, is_train, context_embed, answer, context_mask=None):
    """Predict a span jointly with a "no answer" option.

    Start and end logits are computed per position.  Every (start, end)
    pair receives the sum of its two logits, and one extra "none" logit,
    predicted from attention-pooled start/end features plus a context
    encoding, is appended.  The loss is the negative log of the
    probability mass assigned to correct pairs, or to the "none" slot when
    ``answer[0]`` has no true entry for an example.

    Args:
        answer: pair of boolean tensors marking valid start / end
            positions (they are combined with ``tf.logical_and``).

    Returns:
        ``ConfidencePrediction`` with the span probabilities, the masked
        start/end logits, the none probability, the none logit and the
        mask.
    """
    init_fn = get_keras_initialization(self.init)
    m1, m2 = self.predictor.apply(is_train, context_embed, context_mask)

    # A bound representation that already has one feature per position is
    # used directly as the logits; otherwise it is projected to one unit.
    if m1.shape.as_list()[-1] != 1:
        with tf.variable_scope("start_pred"):
            start_logits = fully_connected(m1, 1, activation_fn=None,
                                           weights_initializer=init_fn)
    else:
        start_logits = m1
    start_logits = tf.squeeze(start_logits, axis=[2])

    # This branch must inspect m2: it previously re-tested m1, so the end
    # logits were handled according to the start representation's width.
    if m2.shape.as_list()[-1] != 1:
        with tf.variable_scope("end_pred"):
            end_logits = fully_connected(m2, 1, activation_fn=None,
                                         weights_initializer=init_fn)
    else:
        end_logits = m2
    end_logits = tf.squeeze(end_logits, axis=[2])

    masked_start_logits = exp_mask(start_logits, context_mask)
    masked_end_logits = exp_mask(end_logits, context_mask)

    # Attention-pooled summaries of the two bound representations.
    start_atten = tf.einsum("ajk,aj->ak", m1,
                            tf.nn.softmax(masked_start_logits))
    end_atten = tf.einsum("ajk,aj->ak", m2,
                          tf.nn.softmax(masked_end_logits))

    with tf.variable_scope("encode_context"):
        enc = self.encoder.apply(is_train, context_embed, context_mask)
    if len(enc.shape) == 3:
        # Flatten multiple encodings into one feature vector per example.
        _, encodings, fe = enc.shape.as_list()
        enc = tf.reshape(enc, (-1, encodings * fe))

    with tf.variable_scope("confidence"):
        conf = [start_atten, end_atten, enc]
        none_logit = self.confidence_predictor.apply(
            is_train, tf.concat(conf, axis=1))
    with tf.variable_scope("confidence_logits"):
        none_logit = fully_connected(none_logit, 1, activation_fn=None,
                                     weights_initializer=init_fn)
        none_logit = tf.squeeze(none_logit, axis=1)

    batch_dim = tf.shape(start_logits)[0]

    # (batch, (l * l)) logits for each (start, end) pair
    all_logits = tf.reshape(
        tf.expand_dims(masked_start_logits, 1) +
        tf.expand_dims(masked_end_logits, 2), (batch_dim, -1))

    # (batch, (l * l) + 1) logits including the none option
    all_logits = tf.concat(
        [all_logits, tf.expand_dims(none_logit, 1)], axis=1)
    log_norms = tf.reduce_logsumexp(all_logits, axis=1)

    # Now build a "correctness" mask in the same format
    correct_mask = tf.logical_and(tf.expand_dims(answer[0], 1),
                                  tf.expand_dims(answer[1], 2))
    correct_mask = tf.reshape(correct_mask, (batch_dim, -1))
    # The none slot is correct exactly when no start position is marked.
    correct_mask = tf.concat([
        correct_mask,
        tf.logical_not(tf.reduce_any(answer[0], axis=1, keep_dims=True))
    ], axis=1)

    # Note: the model is allowed to place weight on "backwards" spans and
    # is rewarded for spans that start and end in different answer spans.
    # This could be fixed by masking part of `all_logits` and using a more
    # accurate `correct_mask`, but it is left this way to stay consistent
    # with the independent bound models, which do the same.  Early tests
    # found proper masking made little difference (or even hurt), though
    # it could still be an avenue for improvement.
    log_correct = tf.reduce_logsumexp(
        all_logits + VERY_NEGATIVE_NUMBER *
        (1 - tf.cast(correct_mask, tf.float32)),
        axis=1)
    loss = tf.reduce_mean(-(log_correct - log_norms))
    probs = tf.nn.softmax(all_logits)
    tf.add_to_collection(tf.GraphKeys.LOSSES, loss)
    return ConfidencePrediction(probs[:, :-1], masked_start_logits,
                                masked_end_logits, probs[:, -1], none_logit,
                                context_mask)
def apply(self, is_train, x, mask=None):
    """Multi-head self-attention over ``x``.

    Queries, keys and memories are per-head linear projections of ``x``.
    Attention scores are kept in (batch, word, head, key) layout
    throughout, which is the layout the mask, the diagonal mask, the
    per-head bias, the softmax and the final einsum all require.  A word
    is never allowed to attend to itself.

    Args:
        is_train: forwarded to ``self.merge``.
        x: rank-3 tensor contracted on axis 2 (presumably
            (batch, word, feature) -- confirm with callers); its last
            dimension must be statically known.
        mask: optional mask input passed to ``compute_attention_mask``.

    Returns:
        Tensor of shape (batch, word, n_heads * mem_size), or the output
        of ``self.merge`` applied to ``x`` and that tensor.

    Raises:
        ValueError: if ``project_size`` is unset and the feature size is
            not divisible by ``n_heads``.
    """
    batch_size = tf.shape(x)[0]
    x_word_dim = tf.shape(x)[1]
    x_feature_dim = x.shape.as_list()[-1]

    project_size = self.project_size
    if project_size is None:
        project_size = x_feature_dim // self.n_heads
        if x_feature_dim % self.n_heads != 0:
            raise ValueError(
                "feature size %s is not divisible by n_heads %s" %
                (x_feature_dim, self.n_heads))

    mem_size = self.memory_size
    if mem_size is None:
        mem_size = project_size

    init = get_keras_initialization(self.init)

    query_proj = tf.get_variable(
        "query_proj", (x_feature_dim, self.n_heads, project_size),
        initializer=init)
    if self.shared_project:
        key_proj = query_proj
    else:
        key_proj = tf.get_variable(
            "key_proj", (x_feature_dim, self.n_heads, project_size),
            initializer=init)
    mem_proj = tf.get_variable(
        "mem_proj", (x_feature_dim, self.n_heads, mem_size),
        initializer=init)

    # (batch, word, n_head, project_size)
    queries = tf.tensordot(x, query_proj, [[2], [0]])
    # (batch, key, n_head, project_size)
    keys = tf.tensordot(x, key_proj, [[2], [0]])

    if self.project_bias:
        queries += tf.get_variable("query_bias",
                                   (1, 1, self.n_heads, project_size),
                                   initializer=tf.zeros_initializer())
        keys += tf.get_variable("key_bias",
                                (1, 1, self.n_heads, project_size),
                                initializer=tf.zeros_initializer())

    # Dot products in (batch, word, head, key) layout.  The previous
    # output layout was (batch, word, key, head), which did not match the
    # masking, bias, softmax and response einsum further down.
    dist_matrix = tf.einsum("bwhd,bkhd->bwhk", queries, keys)

    if self.scale:
        dist_matrix /= tf.sqrt(float(project_size))

    if self.bilinear_comp:
        query_bias_proj = tf.get_variable("query_bias_proj",
                                          (x_feature_dim, self.n_heads),
                                          initializer=init)
        # Gets its own variable name; it used to request
        # "query_bias_proj" a second time, which clashes with the variable
        # created just above.
        key_bias_proj = tf.get_variable("key_bias_proj",
                                        (x_feature_dim, self.n_heads),
                                        initializer=init)
        # (batch, word, head) -> (batch, word, head, 1)
        dist_matrix += tf.expand_dims(
            tf.tensordot(x, query_bias_proj, [[2], [0]]), 3)
        # (batch, key, head) -> (batch, head, key) -> (batch, 1, head, key)
        key_term = tf.transpose(
            tf.tensordot(x, key_bias_proj, [[2], [0]]), [0, 2, 1])
        dist_matrix += tf.expand_dims(key_term, 1)

    joint_mask = compute_attention_mask(mask, mask, x_word_dim, x_word_dim)
    if joint_mask is not None:
        # Insert a head axis so the mask broadcasts over all heads.
        dist_matrix += tf.expand_dims(
            VERY_NEGATIVE_NUMBER *
            (1 - tf.cast(joint_mask, dist_matrix.dtype)), 2)
    # Block each word from attending to itself: (1, word, 1, key).
    dist_matrix += tf.expand_dims(
        tf.expand_dims(tf.eye(x_word_dim) * VERY_NEGATIVE_NUMBER, 0), 2)

    if self.bias:
        bias = tf.get_variable("bias", (1, 1, self.n_heads, 1),
                               initializer=tf.zeros_initializer())
        dist_matrix += bias

    # For each (batch, word, head): a probability distribution over keys.
    select_probs = tf.nn.softmax(dist_matrix)

    # (batch, memory, head, mem_size)
    memories = tf.tensordot(x, mem_proj, [[2], [0]])
    # (batch, word, head, mem_size)
    response = tf.einsum("bwhk,bkhd->bwhd", select_probs, memories)
    # Concatenate the heads.
    response = tf.reshape(
        response, (batch_size, x_word_dim, self.n_heads * mem_size))

    if self.merge is not None:
        with tf.variable_scope("merge"):
            response = self.merge.apply(is_train, x, response)
    return response