Example #1
    def predict(self, answer, start_logits, end_logits, mask) -> Prediction:
        masked_start_logits = exp_mask(start_logits, mask)
        masked_end_logits = exp_mask(end_logits, mask)

        if len(answer) == 3:
            group_ids = answer[2]
            # Turn the ids into segment ids using tf.unique
            _, group_segments = tf.unique(group_ids, out_idx=tf.int32)

            losses = []
            for answer_mask, logits in zip(
                    answer[:2], [masked_start_logits, masked_end_logits]):
                group_norms = segment_logsumexp(logits, group_segments)
                if self.aggregate == "sum":
                    log_score = segment_logsumexp(
                        logits + VERY_NEGATIVE_NUMBER *
                        (1 - tf.cast(answer_mask, tf.float32)), group_segments)
                else:
                    raise ValueError()
                losses.append(tf.reduce_mean(-(log_score - group_norms)))
            loss = tf.add_n(losses)
        else:
            raise NotImplementedError()
        tf.add_to_collection(tf.GraphKeys.LOSSES, loss)
        return BoundaryPrediction(tf.nn.softmax(masked_start_logits),
                                  tf.nn.softmax(masked_end_logits),
                                  masked_start_logits, masked_end_logits, mask)
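
The examples in this listing lean on a few helpers that are not shown (exp_mask, segment_logsumexp, and the VERY_NEGATIVE_NUMBER constant). Below is a minimal sketch of plausible implementations, assuming mask holds per-example sequence lengths (as the tf.sequence_mask call in Example #2 suggests) and that segment_logsumexp computes a log-sum-exp over all logits belonging to the same group; treat it as an illustration, not the original helpers.

import tensorflow as tf

VERY_NEGATIVE_NUMBER = -1e30  # any sufficiently large negative constant


def exp_mask(logits, mask):
    # Push out-of-bounds positions toward -inf so they get ~0 probability under
    # softmax / logsumexp. `mask` is assumed to be a per-example sequence length.
    bool_mask = tf.sequence_mask(mask, tf.shape(logits)[1])
    return logits + VERY_NEGATIVE_NUMBER * (1.0 - tf.cast(bool_mask, tf.float32))


def segment_logsumexp(logits, segments):
    # Numerically stable logsumexp over every element of every row that shares a
    # group id in `segments` (one id per row of `logits`). Note tf.segment_* ops
    # assume the segment ids are sorted, i.e. rows of a group are contiguous.
    group_max = tf.segment_max(tf.reduce_max(logits, axis=1), segments)
    shifted = logits - tf.expand_dims(tf.gather(group_max, segments), 1)
    row_sums = tf.reduce_sum(tf.exp(shifted), axis=1)
    return tf.log(tf.segment_sum(row_sums, segments)) + group_max
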
Example #2
    def predict(self, answer, start_logits, end_logits, mask) -> Prediction:
        l = tf.shape(start_logits)[1]
        masked_start_logits = exp_mask(start_logits, mask)
        masked_end_logits = exp_mask(end_logits, mask)

        # Explicit score for each span
        span_scores = tf.expand_dims(start_logits, 2) + tf.expand_dims(
            end_logits, 1)

        # Mask for in-bound spans, now (batch, start, end) matrix
        mask = tf.sequence_mask(mask, l)
        mask = tf.logical_and(tf.expand_dims(mask, 2), tf.expand_dims(mask, 1))

        # Also mask out invalid spans (end before start, or end more than
        # self.bound positions past the start) by keeping only a band of the upper triangle
        mask = tf.matrix_band_part(mask, 0, self.bound)

        # Apply the mask
        mask = tf.cast(mask, tf.float32)
        span_scores = span_scores * mask + (1 - mask) * VERY_NEGATIVE_NUMBER

        if len(answer) == 1:
            answer = answer[0]
            span_scores = tf.reshape(span_scores,
                                     (tf.shape(start_logits)[0], -1))
            answer = answer[:, 0] * l + answer[:, 1]
            losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=span_scores, labels=answer)
            loss = tf.reduce_mean(losses)
        else:
            raise NotImplementedError()
        tf.add_to_collection(tf.GraphKeys.LOSSES, loss)
        return BoundaryPrediction(tf.nn.softmax(masked_start_logits),
                                  tf.nn.softmax(masked_end_logits),
                                  masked_start_logits, masked_end_logits, mask)
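
Example #2 packs each gold (start, end) pair into a single class index start * l + end before the softmax cross-entropy. A small, illustrative sketch (not part of the original code) of the matching decode step at inference time:

import tensorflow as tf


def decode_flat_span(span_scores, seq_len):
    # span_scores: (batch, seq_len * seq_len), flattened row-major so that
    # flat index = start * seq_len + end, matching the label encoding above.
    best = tf.cast(tf.argmax(span_scores, axis=1), tf.int32)
    return best // seq_len, best % seq_len  # (start, end)
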
Example #3
    def predict(self, answer, start_logits, end_logits, yes_no_logits,
                mask) -> Prediction:
        masked_start_logits = exp_mask(start_logits, mask)
        masked_end_logits = exp_mask(end_logits, mask)

        answer_yes_no = answer[-1]
        losses_yes_no = tf.nn.softmax_cross_entropy_with_logits(
            logits=yes_no_logits, labels=answer_yes_no)

        if len(answer) == 2:
            # answer span is encoded as a sparse int array
            answer_spans = answer[0]
            losses1 = tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=masked_start_logits, labels=answer_spans[:, 0])
            losses2 = tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=masked_end_logits, labels=answer_spans[:, 1])

            loss = tf.add_n([
                tf.reduce_mean(losses1),
                tf.reduce_mean(losses2),
                tf.reduce_mean(losses_yes_no)
            ],
                            name="loss")
        elif len(answer) == 3 and all(x.dtype == tf.int32 for x in answer[:-1]):
            # all correct start/end bounds are marked in a dense bool array
            # In this case there might be multiple answer spans, so we need an aggregation strategy
            losses = []
            for answer_mask, logits in zip(
                    answer[:-1], [masked_start_logits, masked_end_logits]):
                log_norm = tf.reduce_logsumexp(logits, axis=1)
                if self.aggregate == "sum":
                    log_score = tf.reduce_logsumexp(
                        logits + VERY_NEGATIVE_NUMBER *
                        (1 - tf.cast(answer_mask, tf.float32)),
                        axis=1)
                elif self.aggregate == "max":
                    log_score = tf.reduce_max(
                        logits + VERY_NEGATIVE_NUMBER *
                        (1 - tf.cast(answer_mask, tf.float32)),
                        axis=1)
                else:
                    raise ValueError()
                losses.append(tf.reduce_mean(-(log_score - log_norm)))

            losses.append(tf.reduce_mean(losses_yes_no))
            loss = tf.add_n(losses)
        else:
            raise NotImplementedError()

        tf.add_to_collection(tf.GraphKeys.LOSSES, loss)
        return BoundaryAndYesNoPrediction(tf.nn.softmax(masked_start_logits),
                                          tf.nn.softmax(masked_end_logits),
                                          masked_start_logits,
                                          masked_end_logits,
                                          tf.argmax(answer_yes_no,
                                                    axis=-1), mask)
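
In the multi-answer branch of Example #3 (and Example #1), the masked reduce_logsumexp minus the full log-normalizer is just a numerically stable form of the negative log of the total probability assigned to the correct positions. A tiny NumPy check of that identity, for illustration only:

import numpy as np


def marginal_nll(logits, answer_mask):
    # Equals -(logsumexp(logits at correct positions) - logsumexp(all logits)).
    probs = np.exp(logits - logits.max(axis=1, keepdims=True))
    probs /= probs.sum(axis=1, keepdims=True)          # softmax per row
    return -np.log((probs * answer_mask).sum(axis=1))  # prob. mass on answers
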
Example #4
    def predict(self, answer, start_logits, end_logits, mask) -> Prediction:
        bound = self.bound
        f1_weight = self.f1_weight
        aggregate = self.aggregate
        masked_logits1 = exp_mask(start_logits, mask)
        masked_logits2 = exp_mask(end_logits, mask)

        span_logits = []
        for i in range(self.bound):
            if i == 0:
                span_logits.append(masked_logits1 + masked_logits2)
            else:
                span_logits.append(masked_logits1[:, :-i] +
                                   masked_logits2[:, i:])
        span_logits = tf.concat(span_logits, axis=1)
        l = tf.shape(start_logits)[1]

        if len(answer) == 1:
            answer = answer[0]
            if answer.dtype == tf.int32:
                if f1_weight == 0:
                    answer_ix = to_packed_coordinates(answer, l, bound)
                    loss = tf.reduce_mean(
                        tf.nn.sparse_softmax_cross_entropy_with_logits(
                            logits=span_logits, labels=answer_ix))
                else:
                    f1_mask = packed_span_f1_mask(answer, l, bound)
                    if f1_weight < 1:
                        f1_mask *= f1_weight
                        # mix in a one-hot over the packed span coordinates
                        # (same width as span_logits)
                        f1_mask += (1 - f1_weight) * tf.one_hot(
                            to_packed_coordinates(answer, l, bound),
                            tf.shape(span_logits)[1])
                    # TODO can we stay in log space?  (actually it's tricky since f1_mask can have zeros...)
                    probs = tf.nn.softmax(span_logits)
                    loss = -tf.reduce_mean(
                        tf.log(tf.reduce_sum(probs * f1_mask, axis=1)))
            else:
                log_norm = tf.reduce_logsumexp(span_logits, axis=1)
                if aggregate == "sum":
                    log_score = tf.reduce_logsumexp(
                        span_logits + VERY_NEGATIVE_NUMBER *
                        (1 - tf.cast(answer, tf.float32)),
                        axis=1)
                elif aggregate == "max":
                    log_score = tf.reduce_max(
                        span_logits + VERY_NEGATIVE_NUMBER *
                        (1 - tf.cast(answer, tf.float32)),
                        axis=1)
                else:
                    raise NotImplementedError()
                loss = tf.reduce_mean(-(log_score - log_norm))
        else:
            raise NotImplementedError()

        tf.add_to_collection(tf.GraphKeys.LOSSES, loss)
        return PackedSpanPrediction(span_logits, l, bound)
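
to_packed_coordinates and packed_span_f1_mask are not shown in this example. One packing that is consistent with how span_logits is concatenated above (spans taken as inclusive, with end - start < bound) could look like the following; treat it as an illustration rather than the original helper:

import tensorflow as tf


def to_packed_coordinates_sketch(spans, l):
    # spans: (batch, 2) int32 (start, end). Block i of span_logits holds all
    # spans with end - start == i and has size l - i, so a span lands at
    # offset(k) + start, where k = end - start and offset(k) = k*l - k*(k-1)/2.
    k = spans[:, 1] - spans[:, 0]
    return k * l - k * (k - 1) // 2 + spans[:, 0]
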
Example #5
    def apply(self, is_train, context_embed, context_mask=None):
        init_fn = get_keras_initialization(self.init)
        with tf.variable_scope("start_layer"):
            m1 = self.start_layer.apply(is_train, context_embed, context_mask)

        with tf.variable_scope("start_pred"):
            logits1 = fully_connected(tf.concat([m1, context_embed], axis=2), 1,
                                      activation=None, kernel_initializer=init_fn)
            masked_logits1 = exp_mask(tf.squeeze(logits1, squeeze_dims=[2]), context_mask)
            prediction1 = tf.nn.softmax(masked_logits1)

        m2_input = []
        if self.use_original:
            m2_input.append(context_embed)
        if self.use_start_layer:
            m2_input.append(m1)
        if self.soft_select_start_word:
            soft_select = tf.einsum("ai,aik->ak", prediction1, m1)
            soft_select_tiled = tf.tile(tf.expand_dims(soft_select, axis=1), [1, tf.shape(m1)[1], 1])
            m2_input += [soft_select_tiled, soft_select_tiled * m1]

        with tf.variable_scope("end_layer"):
            m2 = self.end_layer.apply(is_train, tf.concat(m2_input, axis=2), context_mask)

        with tf.variable_scope("end_pred"):
            logits2 = fully_connected(tf.concat([m2, context_embed], axis=2), 1,
                                      activation=None, kernel_initializer=init_fn)

        return logits1, logits2
Example #6
    def predict(self, answer, start_logits, end_logits, mask) -> Prediction:
        masked_start_logits = exp_mask(start_logits, mask)
        masked_end_logits = exp_mask(end_logits, mask)
        batch_dim = tf.shape(start_logits)[0]

        if len(answer) == 2 and all(x.dtype == tf.bool for x in answer):
            none_logit = tf.get_variable("none-logit",
                                         initializer=self.non_init,
                                         dtype=tf.float32)
            none_logit = tf.tile(tf.expand_dims(none_logit, 0), [batch_dim])

            all_logits = tf.reshape(
                tf.expand_dims(masked_start_logits, 1) +
                tf.expand_dims(masked_end_logits, 2), (batch_dim, -1))

            # (batch, (l * l) + 1) logits including the none option
            all_logits = tf.concat(
                [all_logits, tf.expand_dims(none_logit, 1)], axis=1)
            log_norms = tf.reduce_logsumexp(all_logits, axis=1)

            # Now build a "correctness" mask in the same format
            correct_mask = tf.logical_and(tf.expand_dims(answer[0], 1),
                                          tf.expand_dims(answer[1], 2))
            correct_mask = tf.reshape(correct_mask, (batch_dim, -1))
            correct_mask = tf.concat([
                correct_mask,
                tf.logical_not(tf.reduce_any(answer[0], axis=1,
                                             keep_dims=True))
            ],
                                     axis=1)

            log_correct = tf.reduce_logsumexp(
                all_logits + VERY_NEGATIVE_NUMBER *
                (1 - tf.cast(correct_mask, tf.float32)),
                axis=1)
            loss = tf.reduce_mean(-(log_correct - log_norms))
            probs = tf.nn.softmax(all_logits)
            tf.add_to_collection(tf.GraphKeys.LOSSES, loss)
            return ConfidencePrediction(probs[:, :-1], masked_start_logits,
                                        masked_end_logits, probs[:, -1],
                                        none_logit)
        else:
            raise NotImplementedError()
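
Given the expand_dims order used to build all_logits above, the flattened span axis is laid out as end * l + start, with the final column holding the "none" option. A hedged sketch of how the resulting ConfidencePrediction probabilities could be decoded (the names and the decision rule are illustrative, not from the original code):

import tensorflow as tf


def decode_confidence_probs(span_probs, none_prob, seq_len):
    # span_probs: (batch, seq_len * seq_len), flat index = end * seq_len + start;
    # none_prob: (batch,) probability that no answer is present.
    best = tf.cast(tf.argmax(span_probs, axis=1), tf.int32)
    start, end = best % seq_len, best // seq_len
    has_answer = tf.reduce_max(span_probs, axis=1) > none_prob
    return start, end, has_answer
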
Example #7
    def predict(self, answer, start_logits, end_logits, mask) -> Prediction:
        masked_start_logits = exp_mask(start_logits, mask)
        masked_end_logits = exp_mask(end_logits, mask)

        if len(answer) == 1:
            raise NotImplementedError()
        elif len(answer) == 2 and all(x.dtype == tf.bool for x in answer):
            losses = []
            for answer_mask, logits in zip(
                    answer, [masked_start_logits, masked_end_logits]):
                answer_mask = tf.cast(answer_mask, tf.float32)
                loss = tf.nn.sigmoid_cross_entropy_with_logits(
                    labels=answer_mask, logits=logits)
                losses.append(loss)
            loss = tf.add_n(losses)
        else:
            raise NotImplementedError()
        tf.add_to_collection(tf.GraphKeys.LOSSES,
                             tf.reduce_mean(loss, name="sigmoid-loss"))
        return BoundaryPrediction(tf.nn.sigmoid(masked_start_logits),
                                  tf.nn.sigmoid(masked_end_logits),
                                  masked_start_logits, masked_end_logits, mask)
Example #8
    def apply(self, is_train, x, mask=None):
        if self.key_mapper is not None:
            with tf.variable_scope("map_keys"):
                keys = self.key_mapper.apply(is_train, x, mask)
        else:
            keys = x

        weights = tf.get_variable("weights", keys.shape.as_list()[-1], dtype=tf.float32,
                                  initializer=get_keras_initialization(self.init))
        dist = tf.tensordot(keys, weights, axes=[[2], [0]])  # (batch, x_words)
        dist = exp_mask(dist, mask)
        dist = tf.nn.softmax(dist)

        out = tf.einsum("ajk,aj->ak", x, dist)  # (batch, x_dim)

        if self.post_process is not None:
            with tf.variable_scope("post_process"):
                out = self.post_process.apply(is_train, out)
        return out
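
The einsum in Example #8 ("ajk,aj->ak") is a learned softmax pooling over the word axis: each example's word vectors are summed with the attention weights in dist. An equivalent, perhaps more readable, formulation of just that reduction (a sketch, not the original code):

import tensorflow as tf


def attention_pool(x, dist):
    # x: (batch, words, dim); dist: (batch, words) softmax weights.
    # Same result as tf.einsum("ajk,aj->ak", x, dist).
    return tf.reduce_sum(x * tf.expand_dims(dist, 2), axis=1)
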
Example #9
    def apply(self, is_train, context_embed, answer, context_mask=None):
        init_fn = get_keras_initialization(self.init)
        m1, m2 = self.predictor.apply(is_train, context_embed, context_mask)

        if m1.shape.as_list()[-1] != 1:
            with tf.variable_scope("start_pred"):
                start_logits = fully_connected(m1,
                                               1,
                                               activation_fn=None,
                                               weights_initializer=init_fn)
        else:
            start_logits = m1
        start_logits = tf.squeeze(start_logits, squeeze_dims=[2])

        if m2.shape.as_list()[-1] != 1:
            with tf.variable_scope("end_pred"):
                end_logits = fully_connected(m2,
                                             1,
                                             activation_fn=None,
                                             weights_initializer=init_fn)
        else:
            end_logits = m2
        end_logits = tf.squeeze(end_logits, squeeze_dims=[2])

        masked_start_logits = exp_mask(start_logits, context_mask)
        masked_end_logits = exp_mask(end_logits, context_mask)

        start_atten = tf.einsum("ajk,aj->ak", m1,
                                tf.nn.softmax(masked_start_logits))
        end_atten = tf.einsum("ajk,aj->ak", m2,
                              tf.nn.softmax(masked_end_logits))
        with tf.variable_scope("encode_context"):
            enc = self.encoder.apply(is_train, context_embed, context_mask)
        if len(enc.shape) == 3:
            _, encodings, fe = enc.shape.as_list()
            enc = tf.reshape(enc, (-1, encodings * fe))

        with tf.variable_scope("confidence"):
            conf = [start_atten, end_atten, enc]
            none_logit = self.confidence_predictor.apply(
                is_train, tf.concat(conf, axis=1))
        with tf.variable_scope("confidence_logits"):
            none_logit = fully_connected(none_logit,
                                         1,
                                         activation_fn=None,
                                         weights_initializer=init_fn)
            none_logit = tf.squeeze(none_logit, axis=1)

        batch_dim = tf.shape(start_logits)[0]

        # (batch, (l * l)) logits for each (start, end) pair
        all_logits = tf.reshape(
            tf.expand_dims(masked_start_logits, 1) +
            tf.expand_dims(masked_end_logits, 2), (batch_dim, -1))

        # (batch, (l * l) + 1) logits including the none option
        all_logits = tf.concat(
            [all_logits, tf.expand_dims(none_logit, 1)], axis=1)
        log_norms = tf.reduce_logsumexp(all_logits, axis=1)

        # Now build a "correctness" mask in the same format
        correct_mask = tf.logical_and(tf.expand_dims(answer[0], 1),
                                      tf.expand_dims(answer[1], 2))
        correct_mask = tf.reshape(correct_mask, (batch_dim, -1))
        correct_mask = tf.concat([
            correct_mask,
            tf.logical_not(tf.reduce_any(answer[0], axis=1, keep_dims=True))
        ],
                                 axis=1)

        # Note that we are happily allowing the model to place weight on "backwards" spans, and also
        # giving it credit for predicting spans that start and end at different answer spans. It would
        # be easy to fix by masking out part of the `all_logits` matrix and specifying a more accurate
        # correct_mask, but in general I left it this way to be consistent with the independent-bound
        # models, which do the same. Some early tests found that properly masking things did not make
        # much difference (or even hurt), but it could still be an avenue for improvement.

        log_correct = tf.reduce_logsumexp(
            all_logits + VERY_NEGATIVE_NUMBER *
            (1 - tf.cast(correct_mask, tf.float32)),
            axis=1)
        loss = tf.reduce_mean(-(log_correct - log_norms))
        probs = tf.nn.softmax(all_logits)
        tf.add_to_collection(tf.GraphKeys.LOSSES, loss)
        return ConfidencePrediction(probs[:, :-1], masked_start_logits,
                                    masked_end_logits, probs[:, -1],
                                    none_logit, context_mask)