def compute_loss(self, batch_source, batch_pos, batch_neg):
        """Return the BPR-style pairwise loss for one training batch.

        Every input is averaged over dimension 1 and then embedded with
        ``self.embed``. The loss is ``self._bce_loss`` applied to
        ``<source, pos> - <source, neg>`` (row-wise dot products computed by
        ``utils.row_wise_dot``) against the target ``self.ones[:B]``, where
        ``B = len(batch_source)``.

        :param batch_source: batch of source keyphrases. NOTE(review): this is
            passed to ``torch.mean(..., dim=1)``, so it must be a tensor of
            floating-point (or complex) dtype rather than a list of index
            lists as previously documented -- confirm against the caller.
        :param batch_pos: true aliases of the mentions (handled like batch_source)
        :param batch_neg: false aliases of the mentions (handled like batch_source)
        :return: the value returned by ``self._bce_loss``
        """
        n_examples = len(batch_source)

        # Average each input over dim 1, in source / positive / negative order.
        src_mean, pos_mean, neg_mean = (
            torch.mean(part, dim=1) for part in (batch_source, batch_pos, batch_neg)
        )

        # Embed in the same order as the averaging above; the original notes
        # describe each embedding as B by dim.
        src_vec = self.embed(src_mean)
        pos_vec = self.embed(pos_mean)
        neg_vec = self.embed(neg_mean)

        margin = (utils.row_wise_dot(src_vec, pos_vec)
                  - utils.row_wise_dot(src_vec, neg_vec))
        targets = self.ones[:n_examples]
        return self._bce_loss(margin, targets)
    def score_pair(self, source, target, source_len, target_len):
        """Score aligned (source, target) pairs with a row-wise dot product.

        The source side is embedded first, then the target side, each with
        ``self.embed_dev`` and its own length argument.

        :param source: batch of source strings (originally documented as
            Batchsize by Max_String_Length)
        :param target: batch of target strings, same layout as ``source``
        :param source_len: lengths forwarded to ``embed_dev`` with ``source``
        :param target_len: lengths forwarded to ``embed_dev`` with ``target``
        :return: ``utils.row_wise_dot`` of the two embeddings (originally
            documented as Batchsize by 1)
        """
        left = self.embed_dev(source, source_len)
        right = self.embed_dev(target, target_len)
        return utils.row_wise_dot(left, right)
Example #3
0
    def compute_loss(self, batch_source, pos_result, neg_result, batch_lengths,
                     pos_len, neg_len):
        """Return the BPR-style pairwise loss for one batch of examples.

        Source, positive and negative inputs are each embedded with
        ``self.embed(values, lengths)`` (in that order). The loss is
        ``self._bce_loss`` applied to ``<source, pos> - <source, neg>``
        (row-wise dot products from ``utils.row_wise_dot``) against the
        target ``self.ones[:B]``, where ``B = len(batch_source)``.

        :param batch_source: a batch of source keyphrase indices
        :param pos_result: true aliases of the mentions
        :param neg_result: false aliases of the mentions
        :param batch_lengths: one length per sample in ``batch_source``
        :param pos_len: lengths of the positives
        :param neg_len: lengths of the negatives
        :return: the value returned by ``self._bce_loss``
        """
        n_rows = len(batch_source)

        # Embed anchor, positive and negative in that order; the original
        # notes describe each result as B by dim.
        anchor, positive, negative = [
            self.embed(values, lengths)
            for values, lengths in (
                (batch_source, batch_lengths),
                (pos_result, pos_len),
                (neg_result, neg_len),
            )
        ]

        margin = (utils.row_wise_dot(anchor, positive)
                  - utils.row_wise_dot(anchor, negative))
        return self._bce_loss(margin, self.ones[:n_rows])
Example #4
0
    def score_dev_test_batch(self, batch_queries, batch_targets, batch_size):
        """Score aligned query/target pairs for a dev or test batch.

        :param batch_queries: batch of queries, embedded with ``self.embed_dev``
        :param batch_targets: batch of targets, embedded with ``self.embed_dev``
        :param batch_size: number of pairs in this batch; passed explicitly to
            ``embed_dev`` only when it differs from ``self.config.dev_batch_size``
            (presumably to handle a short final batch -- confirm)
        :return: ``utils.row_wise_dot`` of the query and target embeddings,
            with every NaN entry replaced by 0 (modified in place)
        """
        if batch_size == self.config.dev_batch_size:
            # Full-size batch: call embed_dev without an explicit batch_size.
            source_embed = self.embed_dev(batch_queries)
            target_embed = self.embed_dev(batch_targets)
        else:
            source_embed = self.embed_dev(batch_queries, batch_size=batch_size)
            target_embed = self.embed_dev(batch_targets, batch_size=batch_size)

        scores = utils.row_wise_dot(source_embed, target_embed)

        # Zero out NaN scores in place. This is the same mask the former
        # `scores != scores` produced (NaN is the only value unequal to
        # itself), written with the explicit Tensor.isnan() test.
        scores[scores.isnan()] = 0

        return scores