Пример #1
0
    def eval(
        self,
        scores: Scores,
        batch_lhs: EntityList,
        batch_rhs: EntityList,
        batch_rel: Union[int, LongTensorType],
    ) -> Stats:
        """Mask out known edges in the negative scores, then defer to the parent.

        For each edge ``(lhs, rel, rhs)`` of the batch, the entities stored in
        ``self.rhs_map[rhs, rel]`` and ``self.lhs_map[lhs, rel]`` are looked up
        and the matching entries of the left- and right-hand negative scores
        are overwritten **in place** with a large negative value. The
        (mutated) ``scores`` are then handed to ``super().eval``.

        Args:
            scores: A 4-tuple whose last two items are the left-hand and
                right-hand negative scores; both are indexed by the position
                of the edge in the batch and are modified in place.
            batch_lhs: Left-hand side entities of the batch (assumed
                non-featurized, as ``to_tensor()`` is called on it).
            batch_rhs: Right-hand side entities of the batch (assumed
                non-featurized, as ``to_tensor()`` is called on it).
            batch_rel: Relation of each edge; must be a ``torch.LongTensor``
                (i.e., dynamic relations are assumed).

        Returns:
            Whatever ``super().eval`` returns for the filtered scores.

        Raises:
            AssertionError: If ``batch_rel`` is not a ``torch.LongTensor``, or
                if an edge's own entity is missing from the filter map entry.
        """
        # Assume dynamic relations.
        assert isinstance(batch_rel, torch.LongTensor)

        _, _, lhs_neg_scores, rhs_neg_scores = scores
        b = batch_lhs.size(0)
        # Assume non-featurized. The conversion does not depend on the loop
        # index, so do it once up front instead of twice per iteration.
        lhs_tensor = batch_lhs.to_tensor()
        rhs_tensor = batch_rhs.to_tensor()
        for idx in range(b):
            cur_lhs = lhs_tensor[idx].item()
            cur_rel = batch_rel[idx].item()
            cur_rhs = rhs_tensor[idx].item()

            rhs_edges_filtered = self.lhs_map[cur_lhs, cur_rel]
            lhs_edges_filtered = self.rhs_map[cur_rhs, cur_rel]
            # The edge under evaluation must itself be among the filtered ones.
            assert cur_lhs in lhs_edges_filtered
            assert cur_rhs in rhs_edges_filtered

            # The rank is computed as the number of non-negative margins (as
            # that means a negative with at least as good a score as a positive)
            # so to avoid counting positives we give them a negative margin.
            lhs_neg_scores[idx][lhs_edges_filtered] = -1e9
            rhs_neg_scores[idx][rhs_edges_filtered] = -1e9

        return super().eval(scores, batch_lhs, batch_rhs, batch_rel)
Пример #2
0
    def prepare_negatives(
        self,
        pos_input: EntityList,
        pos_embs: FloatTensorType,
        module: AbstractEmbedding,
        type_: Negatives,
        num_uniform_neg: int,
        rel: Union[int, LongTensorType],
        entity_type: str,
        operator: Union[None, AbstractOperator, AbstractDynamicOperator],
    ) -> Tuple[FloatTensorType, Mask]:
        """Build chunked negatives for one side of a batch of positives.

        Only one side (left-hand or right-hand) is handled per call. The
        caller provides that side's positives both as raw input
        (``pos_input``) and as already-chunked embeddings (``pos_embs``),
        together with the ``module`` that maps the former to the latter.
        Negatives are generated chunk by chunk, according to ``type_``:

        - ``Negatives.NONE``: no negatives at all (zero-width tensor).
        - ``Negatives.UNIFORM``: ``num_uniform_neg`` entities sampled by
          ``module`` for each chunk.
        - ``Negatives.BATCH_UNIFORM``: the positives of the same chunk,
          optionally followed by ``num_uniform_neg`` sampled entities (a
          separate sample per chunk); sampling is skipped if ``module``
          raises ``NotImplementedError``.
        - ``Negatives.ALL``: every entity known to ``module``.

        Returns:
            A pair ``(neg_embs, ignore_mask)``: the chunked embeddings of the
            negatives, and a list of index tuples into the chunked
            positives-vs-negatives scores selecting the entries that must be
            ignored.

        Raises:
            NotImplementedError: If ``type_`` is none of the modes above.
        """
        num_pos = len(pos_input)
        num_chunks, chunk_size, dim = match_shape(pos_embs, -1, -1, -1)
        # All chunks but the last are full; the last one holds the remainder.
        last_chunk_size = num_pos - (num_chunks - 1) * chunk_size

        if type_ is Negatives.NONE:
            return torch.empty((num_chunks, 0, dim)), []

        if type_ is Negatives.UNIFORM:
            sampled = module.sample_entities(num_chunks, num_uniform_neg)
            return self.adjust_embs(sampled, rel, entity_type, operator), []

        if type_ is Negatives.BATCH_UNIFORM:
            negatives = pos_embs
            if num_uniform_neg > 0:
                try:
                    sampled = module.sample_entities(
                        num_chunks, num_uniform_neg)
                except NotImplementedError:
                    # No sampling available: stick to the in-batch negatives.
                    pass
                else:
                    extra = self.adjust_embs(
                        sampled, rel, entity_type, operator)
                    negatives = torch.cat([pos_embs, extra], dim=1)

            rows = torch.arange(chunk_size, dtype=torch.long)
            last_rows = rows[:last_chunk_size]
            batch_mask: Mask = [
                # A positive must not be scored against itself: drop the
                # diagonal of every full chunk...
                (slice(num_chunks - 1), rows, rows),
                # ...and the diagonal of the (possibly shorter) last chunk.
                (-1, last_rows, last_rows),
                # Within the last chunk, the real positives (the first
                # last_chunk_size rows) must not be compared with the padding
                # negatives (the columns after them). The column slice ends
                # at chunk_size so the uniformly-sampled negatives appended
                # afterwards are left untouched.
                (-1, slice(last_chunk_size),
                 slice(last_chunk_size, chunk_size)),
            ]
            return negatives, batch_mask

        if type_ is Negatives.ALL:
            pos_ids = pos_input.to_tensor()
            every_entity = module.get_all_entities().expand(
                num_chunks, -1, dim)
            negatives = self.adjust_embs(
                every_entity, rel, entity_type, operator)

            if num_uniform_neg > 0:
                log("WARNING: Adding uniform negatives makes no sense "
                    "when already using all negatives")

            rows = torch.arange(chunk_size, dtype=torch.long)
            last_rows = rows[:last_chunk_size]
            full_chunks = num_chunks - 1
            # The i-th positive pair has entity pos_ids[i] on this side, so
            # the score at column pos_ids[i] of row i has to be ignored. As
            # rows are wrapped into chunks (the last of which may be
            # shorter), the full chunks and the last chunk are addressed
            # separately.
            all_mask: Mask = [
                (
                    torch.arange(full_chunks, dtype=torch.long).unsqueeze(1),
                    rows.unsqueeze(0),
                    pos_ids[:-last_chunk_size].view(full_chunks, chunk_size),
                ),
                (-1, last_rows, pos_ids[-last_chunk_size:]),
            ]
            return negatives, all_mask

        raise NotImplementedError("Unknown negative type %s" % type_)
Пример #3
0
 def forward(self, input_: EntityList) -> FloatTensorType:
     """Return ``self.get`` applied to the tensor form of ``input_``."""
     entity_tensor = input_.to_tensor()
     return self.get(entity_tensor)