def predict(self, answer, start_logits, end_logits, mask) -> Prediction:
    bound = self.bound
    f1_weight = self.f1_weight
    aggregate = self.aggregate

    # Push logits at masked (padded) positions to a very negative value
    masked_logits1 = exp_mask(start_logits, mask)
    masked_logits2 = exp_mask(end_logits, mask)

    # Score every span of length <= `bound`: the score of span (j, j+i) is the
    # start logit at j plus the end logit at j+i. The per-length score matrices
    # are concatenated into a single flattened ("packed") span dimension.
    span_logits = []
    for i in range(self.bound):
        if i == 0:
            span_logits.append(masked_logits1 + masked_logits2)
        else:
            span_logits.append(masked_logits1[:, :-i] + masked_logits2[:, i:])
    span_logits = tf.concat(span_logits, axis=1)

    l = tf.shape(start_logits)[1]

    if len(answer) == 1:
        answer = answer[0]
        if answer.dtype == tf.int32:
            if f1_weight == 0:
                # Cross-entropy against the packed index of the gold (start, end) span
                answer_ix = to_packed_coordinates(answer, l, bound)
                loss = tf.reduce_mean(
                    tf.nn.sparse_softmax_cross_entropy_with_logits(
                        logits=span_logits, labels=answer_ix))
            else:
                # Soft target that credits each candidate span by its F1
                # overlap with the gold span
                f1_mask = packed_span_f1_mask(answer, l, bound)
                if f1_weight < 1:
                    f1_mask *= f1_weight
                    # The one-hot depth must match the packed span dimension
                    # (not `l`), or the addition fails for bound > 1
                    f1_mask += (1 - f1_weight) * tf.one_hot(
                        to_packed_coordinates(answer, l, bound),
                        tf.shape(span_logits)[1])

                # TODO can we stay in log space? (actually it's tricky since f1_mask can have zeros...)
                probs = tf.nn.softmax(span_logits)
                loss = -tf.reduce_mean(
                    tf.log(tf.reduce_sum(probs * f1_mask, axis=1)))
        else:
            # `answer` marks all correct spans; normalize over every span and
            # aggregate the scores of the correct ones
            log_norm = tf.reduce_logsumexp(span_logits, axis=1)
            if aggregate == "sum":
                log_score = tf.reduce_logsumexp(
                    span_logits +
                    VERY_NEGATIVE_NUMBER * (1 - tf.cast(answer, tf.float32)),
                    axis=1)
            elif aggregate == "max":
                log_score = tf.reduce_max(
                    span_logits +
                    VERY_NEGATIVE_NUMBER * (1 - tf.cast(answer, tf.float32)),
                    axis=1)
            else:
                raise NotImplementedError()
            loss = tf.reduce_mean(-(log_score - log_norm))
    else:
        raise NotImplementedError()

    tf.add_to_collection(tf.GraphKeys.LOSSES, loss)
    return PackedSpanPrediction(span_logits, l, bound)
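# A minimal sketch (illustrative, not part of this module) of the packed-coordinate
# scheme the loss above relies on. It assumes `to_packed_coordinates` groups spans
# by length in the same order as the concat in `predict`: all length-1 spans first,
# then length-2, and so on. Verify against the real helper before depending on it.
def packed_index_sketch(start, end, l, bound):
    """Map an inclusive (start, end) token span to its index on the packed span axis."""
    length = end - start  # 0 for single-token spans, matching the i == 0 case above
    if not (0 <= length < bound and 0 <= start and end < l):
        raise ValueError("span is longer than `bound` or falls outside the sequence")
    # The block for length offset k contains (l - k) spans
    block_offset = sum(l - k for k in range(length))
    return block_offset + start

# For example, with l=5 and bound=2, span (2, 3) has length offset 1, so it lands
# after the five length-1 spans at index 5 + 2 == 7.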
def apply(self, is_train, context_embed, answer, context_mask=None):
    init_fn = get_keras_initialization(self.init)
    # `context_mask` holds per-example sequence lengths
    bool_mask = tf.sequence_mask(context_mask, tf.shape(context_embed)[1])

    with tf.variable_scope("predict"):
        m1, m2 = self.mapper.apply(is_train, context_embed, context_mask)

    if self.pre_process is not None:
        with tf.variable_scope("pre-process1"):
            m1 = self.pre_process.apply(is_train, m1, context_mask)
        with tf.variable_scope("pre-process2"):
            m2 = self.pre_process.apply(is_train, m2, context_mask)

    # Build a vector for every span of length <= `bound` by merging the start
    # representation at position j with the end representation at position j+i,
    # reusing the merge variables across span lengths
    span_vector_lst = []
    mask_lst = []
    with tf.variable_scope("merge"):
        span_vector_lst.append(self.merge.apply(is_train, m1, m2))
    mask_lst.append(bool_mask)
    for i in range(1, self.bound):
        with tf.variable_scope("merge", reuse=True):
            span_vector_lst.append(
                self.merge.apply(is_train, m1[:, :-i], m2[:, i:]))
        mask_lst.append(bool_mask[:, i:])

    mask = tf.concat(mask_lst, axis=1)
    span_vectors = tf.concat(
        span_vector_lst, axis=1)  # all logits -> flattened per-span predictions

    if self.post_process is not None:
        with tf.variable_scope("post-process"):
            span_vectors = self.post_process.apply(is_train, span_vectors)

    with tf.variable_scope("compute_logits"):
        logits = fully_connected(span_vectors, 1, activation_fn=None,
                                 weights_initializer=init_fn)
    logits = tf.squeeze(logits, axis=[2])
    # `mask` was already concatenated above, so no second concat is needed here
    logits = logits + VERY_NEGATIVE_NUMBER * (1 - tf.cast(mask, tf.float32))

    l = tf.shape(context_embed)[1]

    if len(answer) == 1:
        answer = answer[0]
        if answer.dtype == tf.int32:
            if self.f1_weight == 0:
                # Cross-entropy against the packed index of the gold (start, end) span
                answer_ix = to_packed_coordinates(answer, l, self.bound)
                loss = tf.reduce_mean(
                    tf.nn.sparse_softmax_cross_entropy_with_logits(
                        logits=logits, labels=answer_ix))
            else:
                # Soft target that credits each candidate span by its F1
                # overlap with the gold span
                f1_mask = packed_span_f1_mask(answer, l, self.bound)
                if self.f1_weight < 1:
                    f1_mask *= self.f1_weight
                    # As in `predict`, the one-hot depth must match the packed
                    # span dimension, not `l`
                    f1_mask += (1 - self.f1_weight) * tf.one_hot(
                        to_packed_coordinates(answer, l, self.bound),
                        tf.shape(logits)[1])

                # TODO can we stay in log space? (actually it's tricky since f1_mask can have zeros...)
                probs = tf.nn.softmax(logits)
                loss = -tf.reduce_mean(
                    tf.log(tf.reduce_sum(probs * f1_mask, axis=1)))
        else:
            # `answer` marks all correct spans; normalize over every span and
            # aggregate the scores of the correct ones
            log_norm = tf.reduce_logsumexp(logits, axis=1)
            if self.aggregate == "sum":
                log_score = tf.reduce_logsumexp(
                    logits +
                    VERY_NEGATIVE_NUMBER * (1 - tf.cast(answer, tf.float32)),
                    axis=1)
            elif self.aggregate == "max":
                log_score = tf.reduce_max(
                    logits +
                    VERY_NEGATIVE_NUMBER * (1 - tf.cast(answer, tf.float32)),
                    axis=1)
            else:
                raise NotImplementedError()
            loss = tf.reduce_mean(-(log_score - log_norm))
    else:
        raise NotImplementedError()

    tf.add_to_collection(tf.GraphKeys.LOSSES, loss)
    return PackedSpanPrediction(logits, l, self.bound)
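# A hedged decoding sketch showing how the flattened `logits` returned above map
# back to (start, end) token offsets. It assumes the same length-major packing as
# the concat in `apply`; the function name and NumPy usage are illustrative, not
# part of this module's API.
import numpy as np

def decode_best_span_sketch(span_scores, l, bound):
    """span_scores: (n_spans,) packed scores for one example -> inclusive (start, end)."""
    ix = int(np.argmax(span_scores))
    for length in range(bound):         # walk the per-length blocks in concat order
        block = l - length              # number of candidate spans with this length offset
        if ix < block:
            return ix, ix + length      # start position, inclusive end position
        ix -= block
    raise ValueError("index is outside the packed span range")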