Beispiel #1
0
    def predict(self, answer, start_logits, end_logits, mask) -> Prediction:
        """Score all spans of up to ``self.bound`` tokens and register a loss.

        Span logits are built in "packed" form: for every offset ``i`` in
        ``range(self.bound)`` the masked start logits are added to the masked
        end logits shifted by ``i`` positions, and the per-offset results are
        concatenated along axis 1.

        Args:
            answer: sequence holding exactly one tensor. If that tensor has
                dtype ``tf.int32`` it is handed to ``to_packed_coordinates`` /
                ``packed_span_f1_mask``; otherwise it is cast to float32 and
                used as a 0/1 mask over the packed spans.
            start_logits: per-token start scores; axis 1 is the token axis.
            end_logits: per-token end scores, same layout as ``start_logits``.
            mask: forwarded to ``exp_mask`` together with each logit tensor
                (presumably sequence lengths or a token mask -- TODO confirm).

        Returns:
            ``PackedSpanPrediction(span_logits, l, bound)`` where ``l`` is the
            dynamic size of axis 1 of ``start_logits``.

        Raises:
            NotImplementedError: if ``answer`` does not hold exactly one
                tensor, or ``self.aggregate`` is neither ``"sum"`` nor
                ``"max"`` when a non-int32 answer is given.

        Side effects:
            Adds the scalar loss to the ``tf.GraphKeys.LOSSES`` collection.
        """
        bound = self.bound
        f1_weight = self.f1_weight
        aggregate = self.aggregate
        masked_logits1 = exp_mask(start_logits, mask)
        masked_logits2 = exp_mask(end_logits, mask)

        span_logits = []
        for i in range(bound):
            if i == 0:
                # Single-token spans: start and end index coincide.
                span_logits.append(masked_logits1 + masked_logits2)
            else:
                # Spans whose end lies i tokens after their start.
                span_logits.append(masked_logits1[:, :-i] +
                                   masked_logits2[:, i:])
        span_logits = tf.concat(span_logits, axis=1)
        l = tf.shape(start_logits)[1]

        if len(answer) != 1:
            raise NotImplementedError()

        answer = answer[0]
        if answer.dtype == tf.int32:
            if f1_weight == 0:
                answer_ix = to_packed_coordinates(answer, l, bound)
                loss = tf.reduce_mean(
                    tf.nn.sparse_softmax_cross_entropy_with_logits(
                        logits=span_logits, labels=answer_ix))
            else:
                f1_mask = packed_span_f1_mask(answer, l, bound)
                if f1_weight < 1:
                    # The packed index addresses axis 1 of span_logits, so
                    # the one-hot depth must be the number of packed spans.
                    # Using the sequence length `l` only matches when
                    # bound == 1.
                    n_spans = tf.shape(span_logits)[1]
                    f1_mask *= f1_weight
                    f1_mask += (1 - f1_weight) * tf.one_hot(
                        to_packed_coordinates(answer, l, bound), n_spans)
                # TODO can we stay in log space?  (actually it's tricky
                # since f1_mask can have zeros...)
                probs = tf.nn.softmax(span_logits)
                loss = -tf.reduce_mean(
                    tf.log(tf.reduce_sum(probs * f1_mask, axis=1)))
        else:
            log_norm = tf.reduce_logsumexp(span_logits, axis=1)
            # Push every span that is not marked in `answer` far down so it
            # does not contribute to the aggregated answer score.
            answer_logits = span_logits + VERY_NEGATIVE_NUMBER * (
                1 - tf.cast(answer, tf.float32))
            if aggregate == "sum":
                log_score = tf.reduce_logsumexp(answer_logits, axis=1)
            elif aggregate == "max":
                log_score = tf.reduce_max(answer_logits, axis=1)
            else:
                raise NotImplementedError()
            loss = tf.reduce_mean(-(log_score - log_norm))

        tf.add_to_collection(tf.GraphKeys.LOSSES, loss)
        return PackedSpanPrediction(span_logits, l, bound)
Beispiel #2
0
    def apply(self, is_train, context_embed, answer, context_mask=None):
        """Build per-span vectors, score them, and register a loss.

        The context is mapped to two tensors, which are merged pairwise for
        every offset ``i`` in ``range(self.bound)`` (second tensor shifted by
        ``i`` positions). The merged span vectors are concatenated along
        axis 1 into a "packed" layout, optionally post-processed, and reduced
        to one logit per span by a linear layer. Spans whose end position is
        masked out get ``VERY_NEGATIVE_NUMBER`` added to their logit.

        Args:
            is_train: forwarded unchanged to every sub-layer's ``apply``.
            context_embed: context tensor; axis 1 is the token axis.
            answer: sequence holding exactly one tensor. If that tensor has
                dtype ``tf.int32`` it is handed to ``to_packed_coordinates`` /
                ``packed_span_f1_mask``; otherwise it is cast to float32 and
                used as a 0/1 mask over the packed spans.
            context_mask: passed to ``tf.sequence_mask`` as the lengths
                argument and to the sub-layers.
                NOTE(review): despite the ``None`` default, this is used
                unconditionally -- confirm callers always supply it.

        Returns:
            ``PackedSpanPrediction(logits, l, self.bound)`` where ``l`` is
            the dynamic size of axis 1 of ``context_embed``.

        Raises:
            NotImplementedError: if ``answer`` does not hold exactly one
                tensor, or ``self.aggregate`` is neither ``"sum"`` nor
                ``"max"`` when a non-int32 answer is given.

        Side effects:
            Creates variables under the scopes used below and adds the
            scalar loss to the ``tf.GraphKeys.LOSSES`` collection.
        """
        init_fn = get_keras_initialization(self.init)
        bool_mask = tf.sequence_mask(context_mask, tf.shape(context_embed)[1])

        with tf.variable_scope("predict"):
            m1, m2 = self.mapper.apply(is_train, context_embed, context_mask)

        if self.pre_process is not None:
            with tf.variable_scope("pre-process1"):
                m1 = self.pre_process.apply(is_train, m1, context_mask)
            with tf.variable_scope("pre-process2"):
                m2 = self.pre_process.apply(is_train, m2, context_mask)

        span_vector_lst = []
        mask_lst = []
        with tf.variable_scope("merge"):
            span_vector_lst.append(self.merge.apply(is_train, m1, m2))
        mask_lst.append(bool_mask)
        for i in range(1, self.bound):
            # Reuse the variables created for the offset-0 merge above.
            with tf.variable_scope("merge", reuse=True):
                span_vector_lst.append(
                    self.merge.apply(is_train, m1[:, :-i], m2[:, i:]))
            # A span with offset i is valid when its end token is valid.
            mask_lst.append(bool_mask[:, i:])

        mask = tf.concat(mask_lst, axis=1)
        span_vectors = tf.concat(
            span_vector_lst,
            axis=1)  # all logits -> flattened per-span predictions

        if self.post_process is not None:
            with tf.variable_scope("post-process"):
                span_vectors = self.post_process.apply(is_train, span_vectors)

        with tf.variable_scope("compute_logits"):
            logits = fully_connected(span_vectors,
                                     1,
                                     activation_fn=None,
                                     weights_initializer=init_fn)

        # Drop the size-1 output axis of the linear layer.
        logits = tf.squeeze(logits, axis=[2])
        # `mask` is already one concatenated tensor; no second concat needed.
        logits = logits + VERY_NEGATIVE_NUMBER * (
            1 - tf.cast(mask, tf.float32))

        l = tf.shape(context_embed)[1]

        if len(answer) != 1:
            raise NotImplementedError()

        answer = answer[0]
        if answer.dtype == tf.int32:
            if self.f1_weight == 0:
                answer_ix = to_packed_coordinates(answer, l, self.bound)
                loss = tf.reduce_mean(
                    tf.nn.sparse_softmax_cross_entropy_with_logits(
                        logits=logits, labels=answer_ix))
            else:
                f1_mask = packed_span_f1_mask(answer, l, self.bound)
                if self.f1_weight < 1:
                    # The packed index addresses axis 1 of `logits`, so the
                    # one-hot depth must be the number of packed spans.
                    # Using the sequence length `l` only matches when
                    # bound == 1.
                    n_spans = tf.shape(logits)[1]
                    f1_mask *= self.f1_weight
                    f1_mask += (1 - self.f1_weight) * tf.one_hot(
                        to_packed_coordinates(answer, l, self.bound),
                        n_spans)

                # TODO can we stay in log space?  (actually it's tricky
                # since f1_mask can have zeros...)
                probs = tf.nn.softmax(logits)
                loss = -tf.reduce_mean(
                    tf.log(tf.reduce_sum(probs * f1_mask, axis=1)))
        else:
            log_norm = tf.reduce_logsumexp(logits, axis=1)
            # Push every span that is not marked in `answer` far down so it
            # does not contribute to the aggregated answer score.
            answer_logits = logits + VERY_NEGATIVE_NUMBER * (
                1 - tf.cast(answer, tf.float32))
            if self.aggregate == "sum":
                log_score = tf.reduce_logsumexp(answer_logits, axis=1)
            elif self.aggregate == "max":
                log_score = tf.reduce_max(answer_logits, axis=1)
            else:
                raise NotImplementedError()
            loss = tf.reduce_mean(-(log_score - log_norm))

        tf.add_to_collection(tf.GraphKeys.LOSSES, loss)
        return PackedSpanPrediction(logits, l, self.bound)