Example #1
0
 def forward(self, embeddings: FloatTensorType) -> FloatTensorType:
     """Apply the operator's matrix to every embedding in ``embeddings``.

     The last dimension of ``embeddings`` must equal ``self.dim``. Leading
     dimensions are broadcast over by ``torch.matmul``. Note:
     ``self.linear_transformation`` is presumably a ``(dim, dim)`` matrix;
     its shape is not visible here, so confirm.
     """
     match_shape(embeddings, ..., self.dim)
     matrix = self.linear_transformation.to(device=embeddings.device)
     # Give each embedding a trailing axis of size 1 so it is a column vector.
     # matmul then computes a matrix-vector product; squeeze drops the axis.
     as_columns = embeddings.unsqueeze(-1)
     return torch.matmul(matrix, as_columns).squeeze(-1)
Example #2
0
 def forward(
     self, embeddings: FloatTensorType, operator_idxs: LongTensorType
 ) -> FloatTensorType:
     """Translate each embedding by the vector chosen by its operator index.

     ``operator_idxs`` must have exactly the leading (non-embedding) shape of
     ``embeddings``. Each index selects a row of ``self.translations``, and
     that row is added to the corresponding embedding.
     """
     match_shape(embeddings, ..., self.dim)
     match_shape(operator_idxs, *embeddings.size()[:-1])
     all_translations = self.translations.to(device=embeddings.device)
     selected = all_translations[operator_idxs]
     return embeddings + selected
Example #3
0
 def forward(self, embeddings: FloatTensorType,
             operator_idxs: LongTensorType) -> FloatTensorType:
     """Multiply each embedding by the matrix chosen by its operator index.

     ``operator_idxs`` must have exactly the leading (non-embedding) shape of
     ``embeddings``. Each index selects one entry of
     ``self.linear_transformations``, which then multiplies the matching
     embedding.
     """
     match_shape(embeddings, ..., self.dim)
     match_shape(operator_idxs, *embeddings.size()[:-1])
     all_matrices = self.linear_transformations.to(device=embeddings.device)
     per_item_matrices = all_matrices[operator_idxs]
     # A trailing singleton axis makes each embedding a column vector, so the
     # batched matmul is a matrix-vector product; squeeze removes it again.
     return torch.matmul(per_item_matrices,
                         embeddings.unsqueeze(-1)).squeeze(-1)
Example #4
0
 def forward(self, embeddings: FloatTensorType) -> FloatTensorType:
     """Multiply each embedding, read as a complex vector, by a fixed one.

     In both input and output, the first ``self.dim // 2`` entries of the last
     dimension hold real parts and the remaining entries hold imaginary parts.
     """
     match_shape(embeddings, ..., self.dim)
     half = self.dim // 2
     re_x, im_x = embeddings[..., :half], embeddings[..., half:]
     re_w = self.real.to(device=embeddings.device)
     im_w = self.imag.to(device=embeddings.device)
     # (re_x + i im_x)(re_w + i im_w)
     #   = (re_x re_w - im_x im_w) + i (re_x im_w + im_x re_w)
     # empty_like keeps the input's dtype/layout; results are cast into it.
     out = torch.empty_like(embeddings)
     out[..., :half] = re_x * re_w - im_x * im_w
     out[..., half:] = re_x * im_w + im_x * re_w
     return out
Example #5
0
 def forward(self, embeddings: FloatTensorType,
             operator_idxs: LongTensorType) -> FloatTensorType:
     """Complex-multiply each embedding by the factor chosen by its index.

     Real parts occupy the first ``self.dim // 2`` entries of the last
     dimension and imaginary parts the rest. ``operator_idxs`` (same leading
     shape as ``embeddings``) selects rows of ``self.real`` / ``self.imag``.
     """
     match_shape(embeddings, ..., self.dim)
     match_shape(operator_idxs, *embeddings.size()[:-1])
     half = self.dim // 2
     re_x, im_x = embeddings[..., :half], embeddings[..., half:]
     re_w = self.real.to(device=embeddings.device)[operator_idxs]
     im_w = self.imag.to(device=embeddings.device)[operator_idxs]
     # (re_x + i im_x)(re_w + i im_w)
     #   = (re_x re_w - im_x im_w) + i (re_x im_w + im_x re_w)
     # empty_like keeps the input's dtype/layout; results are cast into it.
     out = torch.empty_like(embeddings)
     out[..., :half] = re_x * re_w - im_x * im_w
     out[..., half:] = re_x * im_w + im_x * re_w
     return out
Example #6
0
 def test_zero_dimensions(self):
     scalar = torch.zeros(())
     # A 0-d tensor matches an empty pattern or a lone ellipsis...
     self.assertIsNone(match_shape(scalar))
     self.assertIsNone(match_shape(scalar, ...))
     # ...but any explicit dimension is a mismatch, including -1, which
     # otherwise asks for a dimension's size.
     for bad_dim in (0, 1, -1):
         with self.assertRaises(TypeError):
             match_shape(scalar, bad_dim)
Example #7
0
    def forward(
        self,
        lhs_pos: FloatTensorType,
        rhs_pos: FloatTensorType,
        lhs_neg: FloatTensorType,
        rhs_neg: FloatTensorType,
    ) -> Tuple[FloatTensorType, FloatTensorType, FloatTensorType]:
        """Score with the base comparator, then add per-embedding biases.

        Component 0 of every embedding is treated as a scalar bias. The
        remaining components go to ``self.base_comparator``. Each score is
        the base score plus the biases of both embeddings in the pair. The
        base comparator's output tensors are updated in place and returned.
        """
        num_chunks, num_pos_per_chunk, dim = match_shape(lhs_pos, -1, -1, -1)
        match_shape(rhs_pos, num_chunks, num_pos_per_chunk, dim)
        match_shape(lhs_neg, num_chunks, -1, dim)
        match_shape(rhs_neg, num_chunks, -1, dim)

        pos_scores, lhs_neg_scores, rhs_neg_scores = self.base_comparator.forward(
            lhs_pos[..., 1:],
            rhs_pos[..., 1:],
            lhs_neg[..., 1:],
            rhs_neg[..., 1:],
        )

        # One scalar bias per embedding: (chunks, pos) or (chunks, neg).
        lhs_pos_bias = lhs_pos[..., 0]
        rhs_pos_bias = rhs_pos[..., 0]
        lhs_neg_bias = lhs_neg[..., 0]
        rhs_neg_bias = rhs_neg[..., 0]

        # Positive pairs line up element-wise with both biases.
        pos_scores.add_(lhs_pos_bias).add_(rhs_pos_bias)
        # Negative score matrices are (chunks, pos, neg). The positive-side
        # bias broadcasts across columns (unsqueeze(-1)) and the negative-side
        # bias across rows (unsqueeze(-2)).
        lhs_neg_scores.add_(rhs_pos_bias.unsqueeze(-1)).add_(
            lhs_neg_bias.unsqueeze(-2))
        rhs_neg_scores.add_(lhs_pos_bias.unsqueeze(-1)).add_(
            rhs_neg_bias.unsqueeze(-2))

        return pos_scores, lhs_neg_scores, rhs_neg_scores
Example #8
0
 def test_bad_args(self):
     t = torch.empty((0,))
     with self.assertRaises(RuntimeError):
         match_shape(t, ..., ...)
     with self.assertRaises(RuntimeError):
         match_shape(t, "foo")
     with self.assertRaises(AttributeError):
         match_shape(None)
Example #9
0
def batched_all_pairs_squared_l2_dist(a: FloatTensorType,
                                      b: FloatTensorType) -> FloatTensorType:
    """For each batch, return the squared L2 distance between each pair of vectors.

    ``a`` has shape N x M_A x D and ``b`` has shape N x M_B x D: N batches
    holding M_A (resp. M_B) vectors of dimension D. The result has shape
    N x M_A x M_B. Entry [n, i, k] is the sum over components of the squared
    differences between ``a[n, i]`` and ``b[n, k]``.

    The distances are computed through the expansion
    ||x - y||^2 = ||x||^2 - 2 <x, y> + ||y||^2. Floating-point cancellation
    can therefore make entries for (nearly) identical vectors come out
    slightly negative. Clamp before taking a square root.

    """
    num_chunks, num_a, dim = match_shape(a, -1, -1, -1)
    num_b = match_shape(b, num_chunks, -1, dim)
    sq_norms_a = a.norm(dim=-1).pow(2)  # N x M_A
    sq_norms_b = b.norm(dim=-1).pow(2)  # N x M_B
    b_t = b.transpose(-2, -1)  # N x D x M_B
    # One fused call yields ||b_k||^2 - 2 <a_i, b_k> (the b^2 term is the
    # baddbmm input, broadcast over rows); ||a_i||^2 is then added in place,
    # broadcast over columns.
    dists = torch.baddbmm(sq_norms_b.unsqueeze(-2), a, b_t, alpha=-2)
    dists.add_(sq_norms_a.unsqueeze(-1))
    match_shape(dists, num_chunks, num_a, num_b)
    return dists
Example #10
0
    def forward(
        self,
        lhs_pos: FloatTensorType,
        rhs_pos: FloatTensorType,
        lhs_neg: FloatTensorType,
        rhs_neg: FloatTensorType,
    ) -> Tuple[FloatTensorType, FloatTensorType, FloatTensorType]:
        """Score pairs by their negated Euclidean (L2) distance.

        Positive pairs are scored element-wise, in float32. Negatives are
        scored all-pairs per chunk via ``batched_all_pairs_l2_dist``.
        """
        num_chunks, num_pos_per_chunk, dim = match_shape(lhs_pos, -1, -1, -1)
        match_shape(rhs_pos, num_chunks, num_pos_per_chunk, dim)
        match_shape(lhs_neg, num_chunks, -1, dim)
        match_shape(rhs_neg, num_chunks, -1, dim)

        sq_diff = (lhs_pos.float() - rhs_pos.float()).pow_(2)
        # Clamping away from zero keeps sqrt's gradient finite when the two
        # vectors coincide.
        pos_dist = sq_diff.sum(dim=-1).clamp_min_(1e-30).sqrt_()
        # A smaller distance must give a higher score, hence the negation.
        pos_scores = pos_dist.neg()
        lhs_neg_scores = batched_all_pairs_l2_dist(rhs_pos, lhs_neg).neg()
        rhs_neg_scores = batched_all_pairs_l2_dist(lhs_pos, rhs_neg).neg()

        return pos_scores, lhs_neg_scores, rhs_neg_scores
Example #11
0
    def forward(
        self,
        lhs_pos: FloatTensorType,
        rhs_pos: FloatTensorType,
        lhs_neg: FloatTensorType,
        rhs_neg: FloatTensorType,
    ) -> Tuple[FloatTensorType, FloatTensorType, FloatTensorType]:
        """Score pairs by their dot product.

        Positive pairs are scored element-wise (in float32). Each chunk's
        positives are scored against every negative of that chunk with a
        batched matrix product.
        """
        num_chunks, num_pos_per_chunk, dim = match_shape(lhs_pos, -1, -1, -1)
        match_shape(rhs_pos, num_chunks, num_pos_per_chunk, dim)
        match_shape(lhs_neg, num_chunks, -1, dim)
        match_shape(rhs_neg, num_chunks, -1, dim)

        # Same result as torch.einsum('cid,cid->ci', ...), but faster.
        pos_scores = lhs_pos.float().mul(rhs_pos.float()).sum(dim=-1)
        # Same result as torch.einsum('cid,cjd->cij', ...), but faster.
        lhs_neg_scores = torch.bmm(rhs_pos, lhs_neg.transpose(-2, -1))
        rhs_neg_scores = torch.bmm(lhs_pos, rhs_neg.transpose(-2, -1))

        return pos_scores, lhs_neg_scores, rhs_neg_scores
Example #12
0
 def forward(self, embeddings: FloatTensorType) -> FloatTensorType:
     """Shift every embedding by the same vector, ``self.translation``."""
     match_shape(embeddings, ..., self.dim)
     offset = self.translation.to(device=embeddings.device)
     return embeddings + offset
Example #13
0
 def forward(self, embeddings: FloatTensorType) -> FloatTensorType:
     """Scale each embedding coordinate-wise by ``self.diagonal``."""
     match_shape(embeddings, ..., self.dim)
     scale = self.diagonal.to(device=embeddings.device)
     # An element-wise product is multiplication by the matrix diag(scale).
     return scale * embeddings
Example #14
0
 def forward(self, embeddings: FloatTensorType,
             operator_idxs: LongTensorType) -> FloatTensorType:
     """Return ``embeddings`` unchanged; ``operator_idxs`` is only shape-checked."""
     match_shape(embeddings, ..., self.dim)
     leading_shape = embeddings.size()[:-1]
     match_shape(operator_idxs, *leading_shape)
     # No transformation: the very same tensor object is handed back.
     return embeddings
Example #15
0
 def test_many_dimension(self):
     tensor = torch.zeros((3, 4, 5))
     # Patterns that match without requesting any sizes return None.
     matching_patterns = [
         (3, 4, 5),
         (...,),
         (..., 5),
         (3, ..., 5),
         (3, 4, 5, ...),
     ]
     for pattern in matching_patterns:
         self.assertIsNone(match_shape(tensor, *pattern))
     # Each -1 requests that dimension's size: a single -1 yields an int,
     # several yield a tuple in dimension order.
     sized_patterns = [
         ((-1, 4, 5), 3),
         ((-1, ...), 3),
         ((-1, 4, ...), 3),
         ((-1, ..., 5), 3),
         ((-1, 4, -1), (3, 5)),
         ((..., -1, -1), (4, 5)),
         ((-1, -1, -1), (3, 4, 5)),
         ((-1, -1, ..., -1), (3, 4, 5)),
     ]
     for pattern, expected in sized_patterns:
         self.assertEqual(match_shape(tensor, *pattern), expected)
     # Wrong rank, wrong sizes, or more fixed dims than the tensor has.
     mismatching_patterns = [
         (),
         (3,),
         (3, 4),
         (5, 4, 3),
         (3, 4, 5, 6),
         (3, 4, ..., 4, 5),
     ]
     for pattern in mismatching_patterns:
         with self.assertRaises(TypeError):
             match_shape(tensor, *pattern)
Example #16
0
 def test_one_dimension(self):
     vector = torch.zeros((3,))
     # Matching patterns that request no sizes return None.
     for pattern in [(3,), (...,), (3, ...), (..., 3)]:
         self.assertIsNone(match_shape(vector, *pattern))
     # A single -1 returns that dimension's size as a plain int.
     self.assertEqual(match_shape(vector, -1), 3)
     # Too few or too many dimensions in the pattern.
     for pattern in [(), (3, 1), (3, ..., 3)]:
         with self.assertRaises(TypeError):
             match_shape(vector, *pattern)
Example #17
0
    def prepare_negatives(
        self,
        pos_input: EntityList,
        pos_embs: FloatTensorType,
        module: AbstractEmbedding,
        type_: Negatives,
        num_uniform_neg: int,
        rel: Union[int, LongTensorType],
        entity_type: str,
        operator: Union[None, AbstractOperator, AbstractDynamicOperator],
    ) -> Tuple[FloatTensorType, Mask]:
        """Given some chunked positives, set up chunks of negatives.

        This function operates on one side (left-hand or right-hand) at a time.
        It takes all the information about the positives on that side: the
        original input value, the corresponding embeddings, and the module used
        to convert one to the other. It then produces negatives for that side
        according to the specified mode. The positive embeddings arrive in
        chunked form, with shape (num_chunks, chunk_size, dim), and the
        negatives are produced within each of these chunks. Modes:

        - ``Negatives.NONE``: no negatives (a second dimension of size 0).
        - ``Negatives.UNIFORM``: ``num_uniform_neg`` entities per chunk, drawn
          via ``module.sample_entities``.
        - ``Negatives.BATCH_UNIFORM``: the positives from the same chunk,
          followed by ``num_uniform_neg`` sampled entities per chunk when
          ``num_uniform_neg > 0`` and the module supports sampling.
        - ``Negatives.ALL``: all entities from ``module.get_all_entities()``.

        Returns both the chunked embeddings of the negatives and an ignore
        mask for the chunked positives-vs-negatives scores. The mask marks the
        scores that must be ignored. It is a list of 3-tuples of indices
        (chunk, positive, negative), each selecting entries to ignore. The
        caller presumably applies them via advanced indexing; not visible here.

        Raises:
            NotImplementedError: if ``type_`` is not one of the modes above.

        """
        num_pos = len(pos_input)
        num_chunks, chunk_size, dim = match_shape(pos_embs, -1, -1, -1)
        # Number of real (non-padding) positives in the final chunk; the
        # remaining rows of that chunk are padding.
        last_chunk_size = num_pos - (num_chunks - 1) * chunk_size

        # Index tuples (chunk, positive, negative) of scores to ignore.
        ignore_mask: Mask = []
        if type_ is Negatives.NONE:
            # No negatives at all: an empty (num_chunks, 0, dim) tensor.
            neg_embs = pos_embs.new_empty((num_chunks, 0, dim))
        elif type_ is Negatives.UNIFORM:
            # Only sampled entities; none can coincide with a positive by
            # construction of the mask, so nothing is masked.
            uniform_neg_embs = module.sample_entities(num_chunks,
                                                      num_uniform_neg)
            neg_embs = self.adjust_embs(uniform_neg_embs, rel, entity_type,
                                        operator)
        elif type_ is Negatives.BATCH_UNIFORM:
            # The chunk's own positives double as negatives for each other.
            neg_embs = pos_embs
            if num_uniform_neg > 0:
                try:
                    uniform_neg_embs = module.sample_entities(
                        num_chunks, num_uniform_neg)
                except NotImplementedError:
                    pass  # only use pos_embs i.e. batch negatives
                else:
                    # Sampled negatives go after the batch ones (dim=1), so
                    # negative indices [0, chunk_size) are still the positives.
                    neg_embs = torch.cat(
                        [
                            pos_embs,
                            self.adjust_embs(uniform_neg_embs, rel,
                                             entity_type, operator),
                        ],
                        dim=1,
                    )

            chunk_indices = torch.arange(chunk_size,
                                         dtype=torch.long,
                                         device=pos_embs.device)
            last_chunk_indices = chunk_indices[:last_chunk_size]
            # Ignore scores between positive pairs.
            ignore_mask.append(
                (slice(num_chunks - 1), chunk_indices, chunk_indices))
            ignore_mask.append((-1, last_chunk_indices, last_chunk_indices))
            # In the last chunk, ignore the scores between the positives that
            # are not padding (i.e., the first last_chunk_size ones) and the
            # negatives that are padding (i.e., all of them except the first
            # last_chunk_size ones). Stop the last slice at chunk_size so that
            # it doesn't also affect the uniformly-sampled negatives.
            ignore_mask.append(
                (-1, slice(last_chunk_size), slice(last_chunk_size,
                                                   chunk_size)))

        elif type_ is Negatives.ALL:
            # Entity ids of the positives on this side, presumably a 1-D
            # tensor of length num_pos (it is sliced and reshaped below).
            pos_input_ten = pos_input.to_tensor()
            # Every entity is a negative for every chunk. expand() broadcasts
            # the single entity table across chunks without copying it.
            neg_embs = self.adjust_embs(
                module.get_all_entities().expand(num_chunks, -1, dim),
                rel,
                entity_type,
                operator,
            )

            if num_uniform_neg > 0:
                logger.warning("Adding uniform negatives makes no sense "
                               "when already using all negatives")

            chunk_indices = torch.arange(chunk_size,
                                         dtype=torch.long,
                                         device=pos_embs.device)
            last_chunk_indices = chunk_indices[:last_chunk_size]
            # Ignore scores between positive pairs: since the i-th such pair has
            # the pos_input[i] entity on this side, ignore_mask[i, pos_input[i]]
            # must be set to 1 for every i. This becomes slightly more tricky as
            # the rows may be wrapped into multiple chunks (the last of which
            # may be smaller).
            # The first entry covers all full chunks: the (num_chunks - 1, 1)
            # chunk ids and (1, chunk_size) row ids broadcast against the
            # (num_chunks - 1, chunk_size) entity ids.
            ignore_mask.append((
                torch.arange(num_chunks - 1,
                             dtype=torch.long,
                             device=pos_embs.device).unsqueeze(1),
                chunk_indices.unsqueeze(0),
                pos_input_ten[:-last_chunk_size].view(num_chunks - 1,
                                                      chunk_size),
            ))
            # The second entry covers the real (non-padding) rows of the
            # last chunk.
            ignore_mask.append(
                (-1, last_chunk_indices, pos_input_ten[-last_chunk_size:]))

        else:
            raise NotImplementedError("Unknown negative type %s" % type_)

        return neg_embs, ignore_mask
Example #18
0
 def forward(self, embeddings: FloatTensorType) -> FloatTensorType:
     """Identity operator: check the trailing dimension, return the input."""
     match_shape(embeddings, ..., self.dim)
     # The same tensor object is returned untouched (no copy).
     return embeddings