Exemplo n.º 1
0
    def predict(self, answer, start_logits, end_logits, mask) -> Prediction:
        """Add a grouped (shared-normalisation) boundary loss and predict.

        ``answer`` must have exactly three elements:
        ``(start_mask, end_mask, group_ids)``.

        For each of the start and end logits, rows that share a group id are
        normalised together. The per-group loss is
        ``-(logsumexp(logits at answer positions) - logsumexp(logits))``, and
        the group losses are averaged with ``tf.reduce_mean``. The start and
        end losses are summed, and the result is added to the
        ``tf.GraphKeys.LOSSES`` collection.

        Args:
            answer: ``(start_mask, end_mask, group_ids)``. Each mask is cast
                to float32 and treated as 1 at correct positions. Presumably
                each mask has the same shape as the logits; TODO confirm.
            start_logits: logits for the span start.
            end_logits: logits for the span end.
            mask: passed to ``exp_mask`` for both logits and stored on the
                returned prediction.

        Returns:
            A ``BoundaryPrediction`` built from the softmax of the masked
            start and end logits, the masked logits themselves, and ``mask``.

        Raises:
            NotImplementedError: if ``answer`` does not have exactly three
                elements.
            ValueError: if ``self.aggregate`` is not ``"sum"``.
        """
        masked_start_logits = exp_mask(start_logits, mask)
        masked_end_logits = exp_mask(end_logits, mask)

        if len(answer) != 3:
            raise NotImplementedError(
                "predict only supports answers of the form "
                "(start_mask, end_mask, group_ids)")
        # Validate once, rather than on every pass of the loop below.
        if self.aggregate != "sum":
            raise ValueError("Unsupported aggregate: %r" % (self.aggregate,))

        group_ids = answer[2]
        # tf.unique's second output gives, for every element, the index of its
        # value among the unique values. That turns arbitrary group ids into
        # dense segment ids 0..num_groups-1.
        _, group_segments = tf.unique(group_ids, out_idx=tf.int32)

        losses = []
        # zip stops after two pairs, so answer[2] (the group ids) is skipped.
        for answer_mask, logits in zip(
                answer, [masked_start_logits, masked_end_logits]):
            # Per-group log-normaliser over all of the group's logits.
            group_norms = segment_logsumexp(logits, group_segments)
            # Add VERY_NEGATIVE_NUMBER wherever answer_mask is 0. Assuming that
            # constant is a large negative value, only correct positions then
            # contribute to the log-sum-exp.
            log_score = segment_logsumexp(
                logits + VERY_NEGATIVE_NUMBER *
                (1 - tf.cast(answer_mask, tf.float32)), group_segments)
            # Negative log of the probability mass on correct positions,
            # averaged over groups.
            losses.append(tf.reduce_mean(-(log_score - group_norms)))
        loss = tf.add_n(losses)

        tf.add_to_collection(tf.GraphKeys.LOSSES, loss)
        return BoundaryPrediction(tf.nn.softmax(masked_start_logits),
                                  tf.nn.softmax(masked_end_logits),
                                  masked_start_logits, masked_end_logits, mask)
Exemplo n.º 2
0
    def test_segment_log_sum_exp(self):
        """Check ``segment_logsumexp`` against a NumPy reference.

        Each of 10 trials builds 10 groups of 1-4 random rows, each row of
        length 10. The rows are stacked into one matrix with a matching
        segment-id vector. The op's output is compared with
        ``log(sum(exp(...)))`` taken over every value of each group.
        """
        # Use a local, seeded generator. Failures are reproducible, and the
        # global NumPy random state shared with other tests is left untouched.
        rng = np.random.RandomState(0)
        sess = self.sess
        with sess.as_default():
            for _ in range(10):
                groups = [
                    [rng.normal(0, 2, 10) for _ in range(rng.randint(1, 5))]
                    for _ in range(10)
                ]

                flat_groups = np.stack(flatten_iterable(groups), axis=0)
                # One segment id per stacked row: row j of group ix gets id ix.
                segments = np.array(
                    flatten_iterable([ix] * len(g)
                                     for ix, g in enumerate(groups)))
                actual = sess.run(segment_logsumexp(flat_groups, segments))
                expected = [
                    np.log(np.sum(np.exp(np.concatenate(g, axis=0))))
                    for g in groups
                ]
                # Same tolerances as np.allclose. On failure this reports the
                # mismatching values and does not broadcast away shape errors.
                np.testing.assert_allclose(actual, expected,
                                           rtol=1e-05, atol=1e-08)