def predict(self, answer, start_logits, end_logits, mask) -> Prediction:
    masked_start_logits = exp_mask(start_logits, mask)
    masked_end_logits = exp_mask(end_logits, mask)

    if len(answer) == 3:
        # answer[0] and answer[1] mark the correct start/end positions,
        # answer[2] holds a group id per row so spans are marginalized per group
        group_ids = answer[2]
        # Turn the ids into segment ids using tf.unique
        _, group_segments = tf.unique(group_ids, out_idx=tf.int32)

        losses = []
        for answer_mask, logits in zip(answer, [masked_start_logits, masked_end_logits]):
            group_norms = segment_logsumexp(logits, group_segments)
            if self.aggregate == "sum":
                log_score = segment_logsumexp(
                    logits + VERY_NEGATIVE_NUMBER * (1 - tf.cast(answer_mask, tf.float32)),
                    group_segments)
            else:
                raise ValueError()
            losses.append(tf.reduce_mean(-(log_score - group_norms)))
        loss = tf.add_n(losses)
    else:
        raise NotImplementedError()

    tf.add_to_collection(tf.GraphKeys.LOSSES, loss)
    return BoundaryPrediction(tf.nn.softmax(masked_start_logits),
                              tf.nn.softmax(masked_end_logits),
                              masked_start_logits, masked_end_logits, mask)
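# `segment_logsumexp` is not shown in this listing. A minimal sketch of what it
# might look like, assuming `xs` is (batch, seq) and `segments` assigns a sorted
# group id to each batch row (as the contiguous batching above suggests), so each
# group's logsumexp pools over both the sequence axis and the rows of the group:
def segment_logsumexp(xs, segments):
    # Subtract each group's max for numerical stability, mirroring tf.reduce_logsumexp
    maxs = tf.stop_gradient(tf.reduce_max(xs, axis=1))
    segment_maxes = tf.segment_max(maxs, segments)
    xs -= tf.expand_dims(tf.gather(segment_maxes, segments), 1)
    sums = tf.reduce_sum(tf.exp(xs), axis=1)
    return tf.log(tf.segment_sum(sums, segments)) + segment_maxes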
def predict(self, answer, start_logits, end_logits, mask) -> Prediction:
    l = tf.shape(start_logits)[1]
    masked_start_logits = exp_mask(start_logits, mask)
    masked_end_logits = exp_mask(end_logits, mask)

    # Explicit score for each span
    span_scores = tf.expand_dims(start_logits, 2) + tf.expand_dims(end_logits, 1)

    # Mask for in-bound spans, now a (batch, start, end) matrix. Keep `mask`
    # (the sequence lengths) intact for the returned prediction.
    span_mask = tf.sequence_mask(mask, l)
    span_mask = tf.logical_and(tf.expand_dims(span_mask, 2),
                               tf.expand_dims(span_mask, 1))

    # Also mask out negative/inverted spans by keeping only a band of the upper triangle
    span_mask = tf.matrix_band_part(span_mask, 0, self.bound)

    # Apply the mask
    span_mask = tf.cast(span_mask, tf.float32)
    span_scores = span_scores * span_mask + (1 - span_mask) * VERY_NEGATIVE_NUMBER

    if len(answer) == 1:
        answer = answer[0]
        # Flatten the (start, end) grid and train a single softmax over all spans
        span_scores = tf.reshape(span_scores, (tf.shape(start_logits)[0], -1))
        answer = answer[:, 0] * l + answer[:, 1]
        losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=span_scores, labels=answer)
        loss = tf.reduce_mean(losses)
    else:
        raise NotImplementedError()

    tf.add_to_collection(tf.GraphKeys.LOSSES, loss)
    return BoundaryPrediction(tf.nn.softmax(masked_start_logits),
                              tf.nn.softmax(masked_end_logits),
                              masked_start_logits, masked_end_logits, mask)
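# `exp_mask` and `VERY_NEGATIVE_NUMBER` are assumed throughout this listing.
# Since `mask` is also fed to tf.sequence_mask above, it appears to hold
# per-example sequence lengths, so one plausible implementation is:
VERY_NEGATIVE_NUMBER = -1e29

def exp_mask(val, lengths):
    # Push logits at padded positions towards -inf so softmax ignores them
    mask = tf.cast(tf.sequence_mask(lengths, tf.shape(val)[1]), tf.float32)
    return val * mask + (1 - mask) * VERY_NEGATIVE_NUMBER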
def predict(self, answer, start_logits, end_logits, yes_no_logits, mask) -> Prediction:
    masked_start_logits = exp_mask(start_logits, mask)
    masked_end_logits = exp_mask(end_logits, mask)

    answer_yes_no = answer[-1]
    losses_yes_no = tf.nn.softmax_cross_entropy_with_logits(
        logits=yes_no_logits, labels=answer_yes_no)

    if len(answer) == 2:
        # The answer span is encoded as a sparse int array of (start, end) pairs
        answer_spans = answer[0]
        losses1 = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=masked_start_logits, labels=answer_spans[:, 0])
        losses2 = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=masked_end_logits, labels=answer_spans[:, 1])
        loss = tf.add_n([tf.reduce_mean(losses1),
                         tf.reduce_mean(losses2),
                         tf.reduce_mean(losses_yes_no)], name="loss")
    elif len(answer) == 3 and all(x.dtype == tf.int32 for x in answer):
        # All correct start/end bounds are marked in a dense array.
        # In this case there might be multiple answer spans, so we need an
        # aggregation strategy
        losses = []
        for answer_mask, logits in zip(answer[:-1], [masked_start_logits, masked_end_logits]):
            log_norm = tf.reduce_logsumexp(logits, axis=1)
            if self.aggregate == "sum":
                log_score = tf.reduce_logsumexp(
                    logits + VERY_NEGATIVE_NUMBER * (1 - tf.cast(answer_mask, tf.float32)),
                    axis=1)
            elif self.aggregate == "max":
                log_score = tf.reduce_max(
                    logits + VERY_NEGATIVE_NUMBER * (1 - tf.cast(answer_mask, tf.float32)),
                    axis=1)
            else:
                raise ValueError()
            losses.append(tf.reduce_mean(-(log_score - log_norm)))
        losses.append(tf.reduce_mean(losses_yes_no))
        loss = tf.add_n(losses)
    else:
        raise NotImplementedError()

    tf.add_to_collection(tf.GraphKeys.LOSSES, loss)
    # Predict yes/no from the model's logits, not from the gold labels
    return BoundaryAndYesNoPrediction(tf.nn.softmax(masked_start_logits),
                                      tf.nn.softmax(masked_end_logits),
                                      masked_start_logits, masked_end_logits,
                                      tf.argmax(yes_no_logits, axis=-1), mask)
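# Toy check of the two aggregation modes over multiple gold positions
# (hypothetical values, numpy only): "sum" marginalizes over all marked
# positions, while "max" trains only on the single best-scoring gold position.
import numpy as np
logits = np.array([2.0, 1.0, 0.5, -1.0])
answer_mask = np.array([1, 0, 1, 0], dtype=np.float32)  # positions 0 and 2 are gold
log_norm = np.log(np.sum(np.exp(logits)))
masked = logits + (-1e29) * (1 - answer_mask)
sum_loss = -(np.log(np.sum(np.exp(masked))) - log_norm)  # -log p(pos 0 or pos 2)
max_loss = -(np.max(masked) - log_norm)                  # -log p(best gold position)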
def predict(self, answer, start_logits, end_logits, mask) -> Prediction:
    bound = self.bound
    f1_weight = self.f1_weight
    aggregate = self.aggregate

    masked_logits1 = exp_mask(start_logits, mask)
    masked_logits2 = exp_mask(end_logits, mask)

    # Score every span of length < bound, packed by length: all length-1 spans
    # first, then all length-2 spans, and so on
    span_logits = []
    for i in range(self.bound):
        if i == 0:
            span_logits.append(masked_logits1 + masked_logits2)
        else:
            span_logits.append(masked_logits1[:, :-i] + masked_logits2[:, i:])
    span_logits = tf.concat(span_logits, axis=1)

    l = tf.shape(start_logits)[1]

    if len(answer) == 1:
        answer = answer[0]
        if answer.dtype == tf.int32:
            if f1_weight == 0:
                answer_ix = to_packed_coordinates(answer, l, bound)
                loss = tf.reduce_mean(
                    tf.nn.sparse_softmax_cross_entropy_with_logits(
                        logits=span_logits, labels=answer_ix))
            else:
                f1_mask = packed_span_f1_mask(answer, l, bound)
                if f1_weight < 1:
                    f1_mask *= f1_weight
                    # The one-hot depth must match the packed span dimension
                    f1_mask += (1 - f1_weight) * tf.one_hot(
                        to_packed_coordinates(answer, l, bound),
                        tf.shape(span_logits)[1])

                # TODO can we stay in log space? (It's tricky since f1_mask can have zeros...)
                probs = tf.nn.softmax(span_logits)
                loss = -tf.reduce_mean(
                    tf.log(tf.reduce_sum(probs * f1_mask, axis=1)))
        else:
            log_norm = tf.reduce_logsumexp(span_logits, axis=1)
            if aggregate == "sum":
                log_score = tf.reduce_logsumexp(
                    span_logits + VERY_NEGATIVE_NUMBER * (1 - tf.cast(answer, tf.float32)),
                    axis=1)
            elif aggregate == "max":
                log_score = tf.reduce_max(
                    span_logits + VERY_NEGATIVE_NUMBER * (1 - tf.cast(answer, tf.float32)),
                    axis=1)
            else:
                raise NotImplementedError()
            loss = tf.reduce_mean(-(log_score - log_norm))
    else:
        raise NotImplementedError()

    tf.add_to_collection(tf.GraphKeys.LOSSES, loss)
    return PackedSpanPrediction(span_logits, l, bound)
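# `to_packed_coordinates` is not shown here. Given the concat order above (all
# length-1 spans, then all length-2 spans, ...), the flat index of a span is its
# start offset plus the sizes of all shorter-length blocks; a sketch consistent
# with that layout:
def to_packed_coordinates(spans, l, bound):
    # spans: (batch, 2) int32 of inclusive (start, end) with end - start < bound
    lens = spans[:, 1] - spans[:, 0]
    # the block for length offset k starts at k*l - k*(k-1)//2 (sum of l, l-1, ...)
    return spans[:, 0] + lens * l - lens * (lens - 1) // 2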
def apply(self, is_train, context_embed, context_mask=None):
    init_fn = get_keras_initialization(self.init)

    with tf.variable_scope("start_layer"):
        m1 = self.start_layer.apply(is_train, context_embed, context_mask)

    with tf.variable_scope("start_pred"):
        logits1 = fully_connected(tf.concat([m1, context_embed], axis=2),
                                  1, activation=None, kernel_initializer=init_fn)
    masked_logits1 = exp_mask(tf.squeeze(logits1, axis=[2]), context_mask)
    prediction1 = tf.nn.softmax(masked_logits1)

    m2_input = []
    if self.use_original:
        m2_input.append(context_embed)
    if self.use_start_layer:
        m2_input.append(m1)
    if self.soft_select_start_word:
        # Expected start-layer state under the predicted start distribution
        soft_select = tf.einsum("ai,aik->ak", prediction1, m1)
        soft_select_tiled = tf.tile(tf.expand_dims(soft_select, axis=1),
                                    [1, tf.shape(m1)[1], 1])
        m2_input += [soft_select_tiled, soft_select_tiled * m1]

    with tf.variable_scope("end_layer"):
        m2 = self.end_layer.apply(is_train, tf.concat(m2_input, axis=2), context_mask)

    with tf.variable_scope("end_pred"):
        logits2 = fully_connected(tf.concat([m2, context_embed], axis=2),
                                  1, activation=None, kernel_initializer=init_fn)

    return logits1, logits2
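# Quick shape check of the soft-select einsum with toy numpy data: "ai,aik->ak"
# pools the start-layer states m1 by the predicted start distribution.
import numpy as np
p = np.array([[0.25, 0.75]])                         # (batch=1, l=2)
m = np.arange(6, dtype=np.float32).reshape(1, 2, 3)  # (batch, l, k=3)
pooled = np.einsum("ai,aik->ak", p, m)               # == 0.25 * m[0, 0] + 0.75 * m[0, 1]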
def predict(self, answer, start_logits, end_logits, mask) -> Prediction:
    masked_start_logits = exp_mask(start_logits, mask)
    masked_end_logits = exp_mask(end_logits, mask)
    batch_dim = tf.shape(start_logits)[0]

    if len(answer) == 2 and all(x.dtype == tf.bool for x in answer):
        none_logit = tf.get_variable("none-logit", initializer=self.non_init,
                                     dtype=tf.float32)
        none_logit = tf.tile(tf.expand_dims(none_logit, 0), [batch_dim])

        # (batch, l * l) logits for each (start, end) pair
        all_logits = tf.reshape(
            tf.expand_dims(masked_start_logits, 1) +
            tf.expand_dims(masked_end_logits, 2),
            (batch_dim, -1))

        # (batch, (l * l) + 1) logits including the none option
        all_logits = tf.concat([all_logits, tf.expand_dims(none_logit, 1)], axis=1)
        log_norms = tf.reduce_logsumexp(all_logits, axis=1)

        # Now build a "correctness" mask in the same format
        correct_mask = tf.logical_and(tf.expand_dims(answer[0], 1),
                                      tf.expand_dims(answer[1], 2))
        correct_mask = tf.reshape(correct_mask, (batch_dim, -1))
        # The none option is correct iff there is no gold span at all
        correct_mask = tf.concat([
            correct_mask,
            tf.logical_not(tf.reduce_any(answer[0], axis=1, keep_dims=True))
        ], axis=1)

        log_correct = tf.reduce_logsumexp(
            all_logits + VERY_NEGATIVE_NUMBER * (1 - tf.cast(correct_mask, tf.float32)),
            axis=1)
        loss = tf.reduce_mean(-(log_correct - log_norms))
        probs = tf.nn.softmax(all_logits)
        tf.add_to_collection(tf.GraphKeys.LOSSES, loss)
        return ConfidencePrediction(probs[:, :-1], masked_start_logits,
                                    masked_end_logits, probs[:, -1], none_logit)
    else:
        raise NotImplementedError()
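# Decoding sketch (hypothetical helper, not part of the original code): the span
# block of `probs` above flattens a grid whose axis 1 is the end position and
# whose last axis is the start position, so flat index i maps to
# start = i % l and end = i // l.
def decode_confidence_spans(span_probs, l):
    # span_probs: (batch, l * l), e.g. probs[:, :-1]; returns int64 (start, end)
    best = tf.argmax(span_probs, axis=1)
    l = tf.cast(l, tf.int64)
    return best % l, best // l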
def predict(self, answer, start_logits, end_logits, mask) -> Prediction:
    masked_start_logits = exp_mask(start_logits, mask)
    masked_end_logits = exp_mask(end_logits, mask)

    if len(answer) == 1:
        raise NotImplementedError()
    elif len(answer) == 2 and all(x.dtype == tf.bool for x in answer):
        # Treat each position as an independent binary decision
        losses = []
        for answer_mask, logits in zip(answer, [masked_start_logits, masked_end_logits]):
            loss = tf.nn.sigmoid_cross_entropy_with_logits(
                labels=tf.cast(answer_mask, tf.float32), logits=logits)
            losses.append(loss)
        loss = tf.add_n(losses)
    else:
        raise NotImplementedError()

    tf.add_to_collection(tf.GraphKeys.LOSSES,
                         tf.reduce_mean(loss, name="sigmoid-loss"))
    return BoundaryPrediction(tf.nn.sigmoid(masked_start_logits),
                              tf.nn.sigmoid(masked_end_logits),
                              masked_start_logits, masked_end_logits, mask)
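# For reference, tf.nn.sigmoid_cross_entropy_with_logits uses the documented
# numerically stable form: with logit x and label z the loss is
# max(x, 0) - x * z + log(1 + exp(-|x|)). A toy numpy check:
import numpy as np
x, z = 1.5, 1.0
loss = max(x, 0) - x * z + np.log1p(np.exp(-abs(x)))
# equals -log(sigmoid(1.5)) ~= 0.2014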
def apply(self, is_train, x, mask=None):
    if self.key_mapper is not None:
        with tf.variable_scope("map_keys"):
            keys = self.key_mapper.apply(is_train, x, mask)
    else:
        keys = x

    weights = tf.get_variable("weights", keys.shape.as_list()[-1],
                              dtype=tf.float32,
                              initializer=get_keras_initialization(self.init))
    dist = tf.tensordot(keys, weights, axes=[[2], [0]])  # (batch, x_words)
    dist = exp_mask(dist, mask)
    dist = tf.nn.softmax(dist)

    out = tf.einsum("ajk,aj->ak", x, dist)  # (batch, x_dim)

    if self.post_process is not None:
        with tf.variable_scope("post_process"):
            out = self.post_process.apply(is_train, out)
    return out
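# Toy numpy walkthrough of the pooling math above: score each word with a learned
# vector, softmax over words, then take the weighted sum of the word states.
import numpy as np
x = np.random.randn(2, 4, 3).astype(np.float32)  # (batch, words, dim)
w = np.random.randn(3).astype(np.float32)        # the learned "weights" vector
dist = x @ w                                     # tensordot over dim -> (batch, words)
dist = np.exp(dist - dist.max(-1, keepdims=True))
dist /= dist.sum(-1, keepdims=True)              # softmax over words
out = np.einsum("ajk,aj->ak", x, dist)           # (batch, dim)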
def apply(self, is_train, context_embed, answer, context_mask=None):
    init_fn = get_keras_initialization(self.init)

    m1, m2 = self.predictor.apply(is_train, context_embed, context_mask)

    if m1.shape.as_list()[-1] != 1:
        with tf.variable_scope("start_pred"):
            start_logits = fully_connected(m1, 1, activation_fn=None,
                                           weights_initializer=init_fn)
    else:
        start_logits = m1
    start_logits = tf.squeeze(start_logits, axis=[2])

    # Check m2 (not m1) before deciding whether to project the end states
    if m2.shape.as_list()[-1] != 1:
        with tf.variable_scope("end_pred"):
            end_logits = fully_connected(m2, 1, activation_fn=None,
                                         weights_initializer=init_fn)
    else:
        end_logits = m2
    end_logits = tf.squeeze(end_logits, axis=[2])

    masked_start_logits = exp_mask(start_logits, context_mask)
    masked_end_logits = exp_mask(end_logits, context_mask)

    start_atten = tf.einsum("ajk,aj->ak", m1, tf.nn.softmax(masked_start_logits))
    end_atten = tf.einsum("ajk,aj->ak", m2, tf.nn.softmax(masked_end_logits))

    with tf.variable_scope("encode_context"):
        enc = self.encoder.apply(is_train, context_embed, context_mask)
    if len(enc.shape) == 3:
        _, encodings, fe = enc.shape.as_list()
        enc = tf.reshape(enc, (-1, encodings * fe))

    with tf.variable_scope("confidence"):
        conf = [start_atten, end_atten, enc]
        none_logit = self.confidence_predictor.apply(is_train, tf.concat(conf, axis=1))
    with tf.variable_scope("confidence_logits"):
        none_logit = fully_connected(none_logit, 1, activation_fn=None,
                                     weights_initializer=init_fn)
        none_logit = tf.squeeze(none_logit, axis=1)

    batch_dim = tf.shape(start_logits)[0]

    # (batch, l * l) logits for each (start, end) pair
    all_logits = tf.reshape(
        tf.expand_dims(masked_start_logits, 1) +
        tf.expand_dims(masked_end_logits, 2),
        (batch_dim, -1))

    # (batch, (l * l) + 1) logits including the none option
    all_logits = tf.concat([all_logits, tf.expand_dims(none_logit, 1)], axis=1)
    log_norms = tf.reduce_logsumexp(all_logits, axis=1)

    # Now build a "correctness" mask in the same format
    correct_mask = tf.logical_and(tf.expand_dims(answer[0], 1),
                                  tf.expand_dims(answer[1], 2))
    correct_mask = tf.reshape(correct_mask, (batch_dim, -1))
    correct_mask = tf.concat([
        correct_mask,
        tf.logical_not(tf.reduce_any(answer[0], axis=1, keep_dims=True))
    ], axis=1)

    # Note we are happily allowing the model to place weight on "backwards" spans,
    # and also giving it credit for spans whose start and end come from different
    # gold spans. It would be easy to fix by masking out part of the `all_logits`
    # matrix and specifying a more accurate correct_mask, but in general I left it
    # this way to be consistent with the independent-bound models that do the same.
    # Some early tests found that properly masking things did not make much
    # difference (or even hurt), but it could still be an avenue for improvement.
    log_correct = tf.reduce_logsumexp(
        all_logits + VERY_NEGATIVE_NUMBER * (1 - tf.cast(correct_mask, tf.float32)),
        axis=1)
    loss = tf.reduce_mean(-(log_correct - log_norms))
    probs = tf.nn.softmax(all_logits)
    tf.add_to_collection(tf.GraphKeys.LOSSES, loss)
    return ConfidencePrediction(probs[:, :-1], masked_start_logits,
                                masked_end_logits, probs[:, -1],
                                none_logit, context_mask)
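# Toy numpy check of the confidence objective: the loss is the negative log of
# the probability mass the (l*l)+1-way softmax assigns to correct entries
# (all gold spans, or the none option when no span exists).
import numpy as np
all_logits = np.array([1.0, 0.2, -0.5, 0.1])        # 3 span scores + the none logit
correct = np.array([1, 0, 1, 0], dtype=np.float32)  # spans 0 and 2 are gold
log_norm = np.log(np.sum(np.exp(all_logits)))
log_correct = np.log(np.sum(np.exp(all_logits + (-1e29) * (1 - correct))))
loss = -(log_correct - log_norm)  # == -log(p(span0) + p(span2))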