def forward(self, embeddings: FloatTensorType) -> FloatTensorType:
    match_shape(embeddings, ..., self.dim)
    # We add a dimension so that matmul performs a matrix-vector product.
    return torch.matmul(
        self.linear_transformation.to(device=embeddings.device),
        embeddings.unsqueeze(-1),
    ).squeeze(-1)
def forward(
    self, embeddings: FloatTensorType, operator_idxs: LongTensorType
) -> FloatTensorType:
    match_shape(embeddings, ..., self.dim)
    match_shape(operator_idxs, *embeddings.size()[:-1])
    return (
        embeddings
        + self.translations.to(device=embeddings.device)[operator_idxs]
    )
def forward(
    self, embeddings: FloatTensorType, operator_idxs: LongTensorType
) -> FloatTensorType:
    match_shape(embeddings, ..., self.dim)
    match_shape(operator_idxs, *embeddings.size()[:-1])
    # We add a dimension so that matmul performs a matrix-vector product.
    return torch.matmul(
        self.linear_transformations.to(device=embeddings.device)[operator_idxs],
        embeddings.unsqueeze(-1),
    ).squeeze(-1)
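# In both linear operators above, unsqueeze(-1) turns each embedding into a
# (dim, 1) column so that matmul broadcasts a matrix-vector product over all
# leading dimensions. A minimal shape check, with made-up sizes, purely for
# illustration (not part of the module):
import torch

dim = 4
transform = torch.randn(dim, dim)          # stand-in for linear_transformation
embeddings = torch.randn(3, 5, dim)        # arbitrary leading dimensions

out = torch.matmul(transform, embeddings.unsqueeze(-1)).squeeze(-1)
assert out.shape == (3, 5, dim)
assert torch.allclose(out[0, 0], transform @ embeddings[0, 0], atol=1e-5)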
def forward(self, embeddings: FloatTensorType) -> FloatTensorType:
    match_shape(embeddings, ..., self.dim)
    real_a = embeddings[..., :self.dim // 2]
    imag_a = embeddings[..., self.dim // 2:]
    real_b = self.real.to(device=embeddings.device)
    imag_b = self.imag.to(device=embeddings.device)
    prod = torch.empty_like(embeddings)
    prod[..., :self.dim // 2] = real_a * real_b - imag_a * imag_b
    prod[..., self.dim // 2:] = real_a * imag_b + imag_a * real_b
    return prod
def forward(
    self, embeddings: FloatTensorType, operator_idxs: LongTensorType
) -> FloatTensorType:
    match_shape(embeddings, ..., self.dim)
    match_shape(operator_idxs, *embeddings.size()[:-1])
    real_a = embeddings[..., :self.dim // 2]
    imag_a = embeddings[..., self.dim // 2:]
    real_b = self.real.to(device=embeddings.device)[operator_idxs]
    imag_b = self.imag.to(device=embeddings.device)[operator_idxs]
    prod = torch.empty_like(embeddings)
    prod[..., :self.dim // 2] = real_a * real_b - imag_a * imag_b
    prod[..., self.dim // 2:] = real_a * imag_b + imag_a * real_b
    return prod
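# The two complex-diagonal operators above treat a dim-sized embedding as
# dim // 2 complex numbers, stored as the real halves followed by the
# imaginary halves, and multiply them elementwise by learned complex
# coefficients. A minimal standalone check of that layout against torch's
# native complex multiplication (hypothetical shapes, not part of the module):
import torch

dim = 8
emb = torch.randn(3, dim)           # batch of 3 embeddings
real_b = torch.randn(dim // 2)      # real parts of the coefficients
imag_b = torch.randn(dim // 2)      # imaginary parts of the coefficients

# Same computation as in forward() above.
prod = torch.empty_like(emb)
prod[:, :dim // 2] = emb[:, :dim // 2] * real_b - emb[:, dim // 2:] * imag_b
prod[:, dim // 2:] = emb[:, :dim // 2] * imag_b + emb[:, dim // 2:] * real_b

# Cross-check by packing the halves into complex tensors.
emb_c = torch.complex(emb[:, :dim // 2], emb[:, dim // 2:])
coef_c = torch.complex(real_b, imag_b)
expected = emb_c * coef_c
assert torch.allclose(prod[:, :dim // 2], expected.real, atol=1e-5)
assert torch.allclose(prod[:, dim // 2:], expected.imag, atol=1e-5)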
def test_zero_dimensions(self):
    t = torch.zeros(())
    self.assertIsNone(match_shape(t))
    self.assertIsNone(match_shape(t, ...))
    with self.assertRaises(TypeError):
        match_shape(t, 0)
    with self.assertRaises(TypeError):
        match_shape(t, 1)
    with self.assertRaises(TypeError):
        match_shape(t, -1)
def forward(
    self,
    lhs_pos: FloatTensorType,
    rhs_pos: FloatTensorType,
    lhs_neg: FloatTensorType,
    rhs_neg: FloatTensorType,
) -> Tuple[FloatTensorType, FloatTensorType, FloatTensorType]:
    num_chunks, num_pos_per_chunk, dim = match_shape(lhs_pos, -1, -1, -1)
    match_shape(rhs_pos, num_chunks, num_pos_per_chunk, dim)
    match_shape(lhs_neg, num_chunks, -1, dim)
    match_shape(rhs_neg, num_chunks, -1, dim)

    pos_scores, lhs_neg_scores, rhs_neg_scores = self.base_comparator.forward(
        lhs_pos[..., 1:], rhs_pos[..., 1:], lhs_neg[..., 1:], rhs_neg[..., 1:])

    lhs_pos_bias = lhs_pos[..., 0]
    rhs_pos_bias = rhs_pos[..., 0]

    pos_scores += lhs_pos_bias
    pos_scores += rhs_pos_bias
    lhs_neg_scores += rhs_pos_bias.unsqueeze(-1)
    lhs_neg_scores += lhs_neg[..., 0].unsqueeze(-2)
    rhs_neg_scores += lhs_pos_bias.unsqueeze(-1)
    rhs_neg_scores += rhs_neg[..., 0].unsqueeze(-2)

    return pos_scores, lhs_neg_scores, rhs_neg_scores
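# The biased comparator above reserves coordinate 0 of every embedding as an
# additive per-entity bias and scores only the remaining dim - 1 coordinates
# with the wrapped comparator. A small illustrative sketch, with made-up
# shapes and a plain dot product standing in for self.base_comparator:
import torch

lhs_pos = torch.randn(2, 5, 11)  # (num_chunks, num_pos_per_chunk, 1 + true_dim)
rhs_pos = torch.randn(2, 5, 11)

base_pos_scores = (lhs_pos[..., 1:] * rhs_pos[..., 1:]).sum(-1)
biased_pos_scores = base_pos_scores + lhs_pos[..., 0] + rhs_pos[..., 0]
# biased_pos_scores has shape (2, 5): one score per positive pair, shifted by
# the two biases, which is what forward() returns as pos_scores.
assert biased_pos_scores.shape == (2, 5)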
def test_bad_args(self):
    t = torch.empty((0,))
    with self.assertRaises(RuntimeError):
        match_shape(t, ..., ...)
    with self.assertRaises(RuntimeError):
        match_shape(t, "foo")
    with self.assertRaises(AttributeError):
        match_shape(None)
def batched_all_pairs_squared_l2_dist(
    a: FloatTensorType,
    b: FloatTensorType,
) -> FloatTensorType:
    """For each batch, return the squared L2 distance between each pair of vectors

    Let A and B be tensors of shape NxM_AxD and NxM_BxD, each containing N*M_A
    and N*M_B vectors of dimension D grouped in N batches of size M_A and M_B.
    For each batch, for each vector of A and each vector of B, return the sum
    of the squares of the differences of their components.

    """
    num_chunks, num_a, dim = match_shape(a, -1, -1, -1)
    num_b = match_shape(b, num_chunks, -1, dim)
    a_squared = a.norm(dim=-1).pow(2)
    b_squared = b.norm(dim=-1).pow(2)
    # Calculate res_i,k = sum_j((a_i,j - b_k,j)^2) for each i and k as
    # sum_j(a_i,j^2) - 2 sum_j(a_i,j b_k,j) + sum_j(b_k,j^2), by using a matrix
    # multiplication for the ab part, adding the b^2 as part of the baddbmm
    # call and the a^2 afterwards.
    res = torch.baddbmm(
        b_squared.unsqueeze(-2), a, b.transpose(-2, -1), alpha=-2
    ).add_(a_squared.unsqueeze(-1))
    match_shape(res, num_chunks, num_a, num_b)
    return res
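# A quick way to sanity-check the baddbmm-based expansion above is to compare
# it against torch.cdist. This snippet is illustrative only, uses made-up
# shapes, and assumes the function defined above is in scope:
import torch

a = torch.randn(4, 6, 10)   # (num_chunks, num_a, dim)
b = torch.randn(4, 9, 10)   # (num_chunks, num_b, dim)
fast = batched_all_pairs_squared_l2_dist(a, b)
naive = torch.cdist(a, b, p=2).pow(2)
assert torch.allclose(fast, naive, atol=1e-4)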
def forward(
    self,
    lhs_pos: FloatTensorType,
    rhs_pos: FloatTensorType,
    lhs_neg: FloatTensorType,
    rhs_neg: FloatTensorType,
) -> Tuple[FloatTensorType, FloatTensorType, FloatTensorType]:
    num_chunks, num_pos_per_chunk, dim = match_shape(lhs_pos, -1, -1, -1)
    match_shape(rhs_pos, num_chunks, num_pos_per_chunk, dim)
    match_shape(lhs_neg, num_chunks, -1, dim)
    match_shape(rhs_neg, num_chunks, -1, dim)

    # Smaller distances are higher scores, so take their negatives.
    pos_scores = (
        (lhs_pos.float() - rhs_pos.float())
        .pow_(2)
        .sum(dim=-1)
        .clamp_min_(1e-30)
        .sqrt_()
        .neg()
    )
    lhs_neg_scores = batched_all_pairs_l2_dist(rhs_pos, lhs_neg).neg()
    rhs_neg_scores = batched_all_pairs_l2_dist(lhs_pos, rhs_neg).neg()

    return pos_scores, lhs_neg_scores, rhs_neg_scores
def forward(
    self,
    lhs_pos: FloatTensorType,
    rhs_pos: FloatTensorType,
    lhs_neg: FloatTensorType,
    rhs_neg: FloatTensorType,
) -> Tuple[FloatTensorType, FloatTensorType, FloatTensorType]:
    num_chunks, num_pos_per_chunk, dim = match_shape(lhs_pos, -1, -1, -1)
    match_shape(rhs_pos, num_chunks, num_pos_per_chunk, dim)
    match_shape(lhs_neg, num_chunks, -1, dim)
    match_shape(rhs_neg, num_chunks, -1, dim)

    # Equivalent to (but faster than) torch.einsum('cid,cid->ci', ...).
    pos_scores = (lhs_pos.float() * rhs_pos.float()).sum(-1)
    # Equivalent to (but faster than) torch.einsum('cid,cjd->cij', ...).
    lhs_neg_scores = torch.bmm(rhs_pos, lhs_neg.transpose(-1, -2))
    rhs_neg_scores = torch.bmm(lhs_pos, rhs_neg.transpose(-1, -2))

    return pos_scores, lhs_neg_scores, rhs_neg_scores
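# The bmm calls above compute every positive-vs-negative dot product within a
# chunk in one shot. A shape-level sketch with made-up sizes (illustration
# only), confirming the einsum equivalence mentioned in the comments:
import torch

num_chunks, num_pos, num_neg, dim = 2, 5, 7, 16
lhs_pos = torch.randn(num_chunks, num_pos, dim)
rhs_neg = torch.randn(num_chunks, num_neg, dim)

rhs_neg_scores = torch.bmm(lhs_pos, rhs_neg.transpose(-1, -2))
assert rhs_neg_scores.shape == (num_chunks, num_pos, num_neg)
assert torch.allclose(
    rhs_neg_scores, torch.einsum("cid,cjd->cij", lhs_pos, rhs_neg), atol=1e-5
)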
def forward(self, embeddings: FloatTensorType) -> FloatTensorType:
    match_shape(embeddings, ..., self.dim)
    return embeddings + self.translation.to(device=embeddings.device)
def forward(self, embeddings: FloatTensorType) -> FloatTensorType:
    match_shape(embeddings, ..., self.dim)
    return self.diagonal.to(device=embeddings.device) * embeddings
def forward(
    self, embeddings: FloatTensorType, operator_idxs: LongTensorType
) -> FloatTensorType:
    match_shape(embeddings, ..., self.dim)
    match_shape(operator_idxs, *embeddings.size()[:-1])
    return embeddings
def test_many_dimensions(self):
    t = torch.zeros((3, 4, 5))
    self.assertIsNone(match_shape(t, 3, 4, 5))
    self.assertIsNone(match_shape(t, ...))
    self.assertIsNone(match_shape(t, ..., 5))
    self.assertIsNone(match_shape(t, 3, ..., 5))
    self.assertIsNone(match_shape(t, 3, 4, 5, ...))
    self.assertEqual(match_shape(t, -1, 4, 5), 3)
    self.assertEqual(match_shape(t, -1, ...), 3)
    self.assertEqual(match_shape(t, -1, 4, ...), 3)
    self.assertEqual(match_shape(t, -1, ..., 5), 3)
    self.assertEqual(match_shape(t, -1, 4, -1), (3, 5))
    self.assertEqual(match_shape(t, ..., -1, -1), (4, 5))
    self.assertEqual(match_shape(t, -1, -1, -1), (3, 4, 5))
    self.assertEqual(match_shape(t, -1, -1, ..., -1), (3, 4, 5))
    with self.assertRaises(TypeError):
        match_shape(t)
    with self.assertRaises(TypeError):
        match_shape(t, 3)
    with self.assertRaises(TypeError):
        match_shape(t, 3, 4)
    with self.assertRaises(TypeError):
        match_shape(t, 5, 4, 3)
    with self.assertRaises(TypeError):
        match_shape(t, 3, 4, 5, 6)
    with self.assertRaises(TypeError):
        match_shape(t, 3, 4, ..., 4, 5)
def test_one_dimension(self):
    t = torch.zeros((3,))
    self.assertIsNone(match_shape(t, 3))
    self.assertIsNone(match_shape(t, ...))
    self.assertIsNone(match_shape(t, 3, ...))
    self.assertIsNone(match_shape(t, ..., 3))
    self.assertEqual(match_shape(t, -1), 3)
    with self.assertRaises(TypeError):
        match_shape(t)
    with self.assertRaises(TypeError):
        match_shape(t, 3, 1)
    with self.assertRaises(TypeError):
        match_shape(t, 3, ..., 3)
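# For reference, the behavior these tests pin down could be implemented along
# the following lines: exact sizes must match, -1 acts as a wildcard whose
# actual size is returned, and a single Ellipsis absorbs any run of
# dimensions. This is only a sketch inferred from the tests above, not the
# library's actual match_shape:
def match_shape_sketch(tensor, *expected):
    # tensor.size() raises AttributeError for non-tensors (see test_bad_args).
    actual = list(tensor.size())
    for e in expected:
        if e is not Ellipsis and not isinstance(e, int):
            raise RuntimeError("Expected int or Ellipsis, got %r" % (e,))
    if expected.count(Ellipsis) > 1:
        raise RuntimeError("Only one Ellipsis is allowed")
    if Ellipsis in expected:
        i = expected.index(Ellipsis)
        head, tail = expected[:i], expected[i + 1:]
        if len(head) + len(tail) > len(actual):
            raise TypeError("Shape doesn't match")
        # Keep only the dimensions pinned down by the head and tail patterns.
        actual = actual[:len(head)] + actual[len(actual) - len(tail):]
        expected = head + tail
    if len(actual) != len(expected):
        raise TypeError("Shape doesn't match")
    captured = []
    for act, exp in zip(actual, expected):
        if exp == -1:
            captured.append(act)  # wildcard: report this dimension's size
        elif act != exp:
            raise TypeError("Shape doesn't match")
    if not captured:
        return None
    return captured[0] if len(captured) == 1 else tuple(captured)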
def prepare_negatives(
    self,
    pos_input: EntityList,
    pos_embs: FloatTensorType,
    module: AbstractEmbedding,
    type_: Negatives,
    num_uniform_neg: int,
    rel: Union[int, LongTensorType],
    entity_type: str,
    operator: Union[None, AbstractOperator, AbstractDynamicOperator],
) -> Tuple[FloatTensorType, Mask]:
    """Given some chunked positives, set up chunks of negatives.

    This function operates on one side (left-hand or right-hand) at a time.
    It takes all the information about the positives on that side (the
    original input value, the corresponding embeddings, and the module used
    to convert one to the other). It then produces negatives for that side
    according to the specified mode. The positive embeddings come in in
    chunked form and the negatives are produced within each of these chunks.
    The negatives can be either none, or the positives from the same chunk,
    or all the possible entities. In the second mode, uniformly-sampled
    entities can also be appended to the per-chunk negatives (each chunk
    having a different sample).

    This function returns both the chunked embeddings of the negatives and a
    mask of the same size as the chunked positives-vs-negatives scores, whose
    non-zero elements correspond to the scores that must be ignored.

    """
    num_pos = len(pos_input)
    num_chunks, chunk_size, dim = match_shape(pos_embs, -1, -1, -1)
    last_chunk_size = num_pos - (num_chunks - 1) * chunk_size

    ignore_mask: Mask = []
    if type_ is Negatives.NONE:
        neg_embs = pos_embs.new_empty((num_chunks, 0, dim))
    elif type_ is Negatives.UNIFORM:
        uniform_neg_embs = module.sample_entities(num_chunks, num_uniform_neg)
        neg_embs = self.adjust_embs(uniform_neg_embs, rel, entity_type, operator)
    elif type_ is Negatives.BATCH_UNIFORM:
        neg_embs = pos_embs
        if num_uniform_neg > 0:
            try:
                uniform_neg_embs = module.sample_entities(
                    num_chunks, num_uniform_neg)
            except NotImplementedError:
                pass  # only use pos_embs, i.e. batch negatives
            else:
                neg_embs = torch.cat(
                    [
                        pos_embs,
                        self.adjust_embs(
                            uniform_neg_embs, rel, entity_type, operator),
                    ],
                    dim=1,
                )

        chunk_indices = torch.arange(
            chunk_size, dtype=torch.long, device=pos_embs.device)
        last_chunk_indices = chunk_indices[:last_chunk_size]
        # Ignore scores between positive pairs.
        ignore_mask.append(
            (slice(num_chunks - 1), chunk_indices, chunk_indices))
        ignore_mask.append((-1, last_chunk_indices, last_chunk_indices))
        # In the last chunk, ignore the scores between the positives that
        # are not padding (i.e., the first last_chunk_size ones) and the
        # negatives that are padding (i.e., all of them except the first
        # last_chunk_size ones). Stop the last slice at chunk_size so that
        # it doesn't also affect the uniformly-sampled negatives.
        ignore_mask.append(
            (-1, slice(last_chunk_size), slice(last_chunk_size, chunk_size)))

    elif type_ is Negatives.ALL:
        pos_input_ten = pos_input.to_tensor()
        neg_embs = self.adjust_embs(
            module.get_all_entities().expand(num_chunks, -1, dim),
            rel,
            entity_type,
            operator,
        )

        if num_uniform_neg > 0:
            logger.warning("Adding uniform negatives makes no sense "
                           "when already using all negatives")

        chunk_indices = torch.arange(
            chunk_size, dtype=torch.long, device=pos_embs.device)
        last_chunk_indices = chunk_indices[:last_chunk_size]

        # Ignore scores between positive pairs: since the i-th such pair has
        # the pos_input[i] entity on this side, ignore_mask[i, pos_input[i]]
        # must be set to 1 for every i. This becomes slightly more tricky as
        # the rows may be wrapped into multiple chunks (the last of which
        # may be smaller).
        ignore_mask.append((
            torch.arange(
                num_chunks - 1, dtype=torch.long, device=pos_embs.device
            ).unsqueeze(1),
            chunk_indices.unsqueeze(0),
            pos_input_ten[:-last_chunk_size].view(num_chunks - 1, chunk_size),
        ))
        ignore_mask.append(
            (-1, last_chunk_indices, pos_input_ten[-last_chunk_size:]))

    else:
        raise NotImplementedError("Unknown negative type %s" % type_)

    return neg_embs, ignore_mask
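# Each entry of the returned mask is an index tuple into the
# (num_chunks, chunk_size, num_negatives) tensor of positive-vs-negative
# scores; the elements it selects are the scores to discard (e.g. an entity
# scored against itself). A toy illustration with made-up sizes of how such a
# mask could be applied downstream (a sketch, not the actual training loop):
import torch

num_chunks, chunk_size, num_neg = 2, 4, 4
neg_scores = torch.randn(num_chunks, chunk_size, num_neg)
chunk_indices = torch.arange(chunk_size)
ignore_mask = [(slice(num_chunks), chunk_indices, chunk_indices)]

for index in ignore_mask:
    neg_scores[index] = -1e9  # push masked scores out of the ranking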
def forward(self, embeddings: FloatTensorType) -> FloatTensorType:
    match_shape(embeddings, ..., self.dim)
    return embeddings