示例#1
0
 def _exp_score(self, h, t, r_h, r_t, p, power_norm) -> torch.FloatTensor:
     """Return the expected score ``-||R_h h - R_t t||_p`` for a single triple.

     Only the plain-norm variant is handled (``power_norm`` must be false).
     h and t are turned into column vectors and multiplied by their relation
     matrices; the score is the negated p-norm of the difference.
     """
     assert not power_norm
     h, t, r_h, r_t = strip_dim(h, t, r_h, r_t)
     # unsqueeze makes h / t column vectors so the matrices act on them
     head_proj = r_h @ h.unsqueeze(dim=-1)
     tail_proj = r_t @ t.unsqueeze(dim=-1)
     gap = head_proj - tail_proj
     return -gap.norm(p)
示例#2
0
 def _exp_score(self, exact, h_mean, h_var, r_mean, r_var, similarity, t_mean, t_var):
     """Return the expected negated-KL score for a single triple.

     The entity pair is modelled as N(h_mean - t_mean, diag(h_var + t_var)).
     The relation is modelled as N(r_mean, diag(r_var)).
     The score is ``-KL(entity || relation)``.
     Only ``similarity == "KL"`` is supported. ``exact`` is accepted but not
     used by this computation.
     """
     assert similarity == "KL"
     h_mean, h_var, r_mean, r_var, t_mean, t_var = strip_dim(h_mean, h_var, r_mean, r_var, t_mean, t_var)

     def _diag_gaussian(mean, var):
         # torch.diag builds a diagonal covariance from the variance vector
         # (assumes var is 1-D after strip_dim -- TODO confirm)
         return torch.distributions.MultivariateNormal(loc=mean, covariance_matrix=torch.diag(var))

     entity_dist = _diag_gaussian(h_mean - t_mean, h_var + t_var)
     relation_dist = _diag_gaussian(r_mean, r_var)
     return -torch.distributions.kl.kl_divergence(entity_dist, relation_dist)
示例#3
0
 def _exp_score(self, h, r, m_r, t, p, power_norm) -> torch.FloatTensor:
     """Return the expected score ``-||h_bot + r - t_bot||_p^p`` for a single triple.

     h and t are used as row vectors and multiplied by ``m_r``. Each
     projection is then passed through ``clamp_norm`` (2-norm, ``maxnorm=1``),
     and the projected head is translated by ``r``. Only the power-norm
     variant is handled (``power_norm`` must be true).

     The p-th power is applied to the *absolute* differences, so the result is
     ``-sum_i |x_i|^p``. Without ``abs``, odd ``p`` would let negative
     components cancel positive ones, and fractional ``p`` would give NaN.
     For ``p == 2`` the value is unchanged.
     """
     assert power_norm
     h, r, m_r, t = strip_dim(h, r, m_r, t)
     # (1, d) @ m_r projects into the relation space; norm is clamped to <= 1
     h_bot, t_bot = [
         clamp_norm(x.unsqueeze(dim=0) @ m_r, p=2, dim=-1, maxnorm=1.)
         for x in (h, t)
     ]
     diff = h_bot + r - t_bot
     return -(diff.abs() ** p).sum()
示例#4
0
 def _exp_score(self, bn_h, bn_hr, core_tensor, do_h, do_r, do_hr, h, r, t) -> torch.FloatTensor:
     """Return the expected core-tensor score for a single triple.

     Evaluates ``DO_hr(BN_hr(DO_h(BN_h(h)) x_1 DO_r(W x_2 r))) x_3 t``, where
     ``W`` is ``core_tensor`` and ``x_k`` is the mode-k product. The
     dropout / batch-norm modules are called in the order do_r, bn_h, do_h,
     bn_hr, do_hr. Shapes below assume ``core_tensor`` is (d, d_r, d) --
     TODO confirm.
     """
     h, r, t = strip_dim(h, r, t)
     # W x_2 r: contract the middle axis with r, then DO_r -> (d, 1, d)
     rel_core = do_r((core_tensor * r[None, :, None]).sum(dim=1, keepdim=True))
     # BN_h / DO_h see h as a single-row batch; flattened back to (d,)
     head = do_h(bn_h(h.view(1, -1))).view(-1)
     # x_1 with the normalised head: contract the first axis -> (1, 1, d)
     head_rel = (head[:, None, None] * rel_core).sum(dim=0, keepdim=True)
     # BN_hr / DO_hr on the (1, d) row, reshaped to (1, 1, d)
     mixed = do_hr(bn_hr(head_rel.view(1, -1))).view(1, 1, -1)
     # x_3 t: final contraction over the last axis
     return (mixed * t[None, None, :]).sum()
示例#5
0
 def _exp_score(self, h, t, w, vt, vh, b, u, activation) -> torch.FloatTensor:
     """Return the expected per-slice tensor score for a single triple.

     f(h, r, t) = sum_i u_i * act(h^T W_i t + <vh_i, h> + <vt_i, t> + b_i)

     Each of the ``k`` slices has its own bilinear term, separate linear terms
     for head (``vh``) and tail (``vt``), and a bias. The activated slice
     values are weighted by ``u`` and summed.
     Shapes: w (k, dim, dim), vh/vt (k, dim), b/u (k,), h/t (dim,).
     """
     h, t, w, vt, vh, b, u = strip_dim(h, t, w, vt, vh, b, u)

     def _slice(i):
         bilinear = h.view(1, self.dim) @ w[i] @ t.view(self.dim, 1)
         head_lin = (vh[i] * h.view(-1)).sum()
         tail_lin = (vt[i] * t.view(-1)).sum()
         return u[i] * activation(bilinear + head_lin + tail_lin + b[i])

     # float start (0.) matches the accumulator's initial value, incl. k == 0
     return sum((_slice(i) for i in range(u.shape[-1])), 0.)
示例#6
0
 def _exp_score(self, h, r, t, h_inv, r_inv, t_inv,
                clamp) -> torch.FloatTensor:
     """Return the mean of the forward and inverse ``distmult_interaction`` scores.

     Batch dimensions are stripped from all six inputs first. Clamping is not
     supported, so ``clamp`` must be ``None``.
     """
     h, r, t, h_inv, r_inv, t_inv = strip_dim(h, r, t, h_inv, r_inv, t_inv)
     assert clamp is None
     forward = distmult_interaction(h, r, t)
     backward = distmult_interaction(h_inv, r_inv, t_inv)
     return 0.5 * forward + 0.5 * backward
示例#7
0
 def _exp_score(self, h, t, p, power_norm) -> torch.FloatTensor:
     """Return the expected score ``-||h - t||_p^p`` for a single triple.

     Only the power-norm variant is handled (``power_norm`` must be true).
     The absolute value is taken before raising to ``p``, so the result is
     ``-sum_i |h_i - t_i|^p`` for any ``p``. Without it, odd ``p`` gives a
     signed sum and fractional ``p`` gives NaN. For ``p == 2`` the value is
     unchanged.
     """
     assert power_norm
     h, t = strip_dim(h, t)
     return -(h - t).abs().pow(p).sum()
示例#8
0
 def _exp_score(self, h, w_r, d_r, t, p,
                power_norm) -> torch.FloatTensor:  # noqa: D102
     """Return the expected hyperplane-projection score for a single triple.

     h and t are each mapped to ``x - <x, w_r> w_r``. This is the orthogonal
     projection onto the hyperplane with normal ``w_r`` only if ``w_r`` has
     unit norm, which is not checked here. The projected head is translated
     by ``d_r``, and the score is ``-||h_perp + d_r - t_perp||_p``. Only the
     plain-norm variant is handled.
     """
     assert not power_norm
     h, w_r, d_r, t = strip_dim(h, w_r, d_r, t)

     def _onto_hyperplane(x):
         return x - (x * w_r).sum() * w_r

     h_perp = _onto_hyperplane(h)
     t_perp = _onto_hyperplane(t)
     return -(h_perp + d_r - t_perp).norm(p=p)
示例#9
0
 def _exp_score(self, h, r, t) -> torch.FloatTensor:  # noqa: D102
     """Return the expected score ``-||h * r - t||_2`` in complex space.

     h, r and t are converted with ``view_complex`` before their batch
     dimensions are stripped. ``r`` must have unit Euclidean length (checked
     with ``torch.allclose``).
     """
     def _l2(z):
         # Euclidean norm over the last axis, using the complex modulus
         return (z.abs() ** 2).sum(dim=-1).sqrt()

     h, r, t = strip_dim(*map(view_complex, (h, r, t)))
     assert torch.allclose(_l2(r), torch.ones(1))
     return -_l2(h * r - t)
示例#10
0
 def _exp_score(self, h, r, t) -> torch.FloatTensor:
     """Return the expected bilinear score ``h^T R t`` for a single triple.

     h is used as a row vector and t as a column vector, with ``r`` as the
     matrix between them. Assuming r is (d, d), the result is a (1, 1)
     tensor -- TODO confirm.
     """
     h, r, t = strip_dim(h, r, t)
     head_rel = h.view(1, -1) @ r
     return head_rel @ t.view(-1, 1)
示例#11
0
 def _exp_score(self, h, r, t, d_e, d_r, b_c, b_p,
                activation) -> torch.FloatTensor:
     """Return the expected combination score for a single triple.

     Computes ``<t, act(d_e * h + d_r * r + b_c)> + b_p``. This is the formula
     ``g(t z(D_e h + D_r r + b_c) + b_p)`` with ``z = activation``, with
     D_e / D_r presumably given as diagonal vectors (they are applied
     elementwise). The outer ``g`` is *not* applied here. Only h, r and t have
     their batch dimensions stripped; d_e, d_r, b_c and b_p are used as given.
     """
     h, r, t = strip_dim(h, r, t)
     hidden = activation(d_e * h + d_r * r + b_c)
     return (t * hidden).sum() + b_p
示例#12
0
 def _exp_score(self, h, r, t) -> torch.FloatTensor:  # noqa: D102
     """Return the expected quaternion score for a single triple.

     h and r are split with ``_split_quaternion`` and combined with
     ``_rotate_quaternion``. The score is the negated sum of the elementwise
     product of that result with t.
     """
     h, r, t = strip_dim(h, r, t)
     rotated = _rotate_quaternion(_split_quaternion(h), _split_quaternion(r))
     return -(rotated * t).sum()