def test_vec_log_sum_exp_batch_stable():
    h = np.random.randint(22, 41)
    i1 = torch.rand(1, h, h)
    i2 = torch.rand(1, h, h)
    i = torch.cat([i1, i2], dim=0)
    lse1 = vec_log_sum_exp(i1, 2)
    lse2 = vec_log_sum_exp(i2, 2)
    one_x_one = torch.cat([lse1, lse2], dim=0)
    lse = vec_log_sum_exp(i, 2)
    np.testing.assert_allclose(one_x_one.numpy(), lse.numpy())
def test_vec_log_sum_exp_shape():
    dim = torch.randint(0, 3, (1,)).item()
    shape = torch.randint(1, 21, (3,))
    in_ = torch.rand(*shape)
    out = vec_log_sum_exp(in_, dim)
    shape[dim] = 1
    for i in range(len(shape)):
        assert out.size(i) == shape[i]
def test_vec_log_sum_exp():
    vec = torch.rand(1, np.random.randint(5, 31))
    ours = vec_log_sum_exp(vec, 1).squeeze()
    xs = {}
    for i in range(vec.size(1)):
        xs[i] = vec[0, i].item()
    gold = explicit_log_sum_exp(xs)
    np.testing.assert_allclose(ours, gold, rtol=1e-6)
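# The explicit_log_sum_exp oracle used above is not shown in this section. A minimal
# sketch consistent with how the test calls it (a dict of scalar values in, one scalar
# out) could look like the following; the actual helper in the repo may differ.
import math


def explicit_log_sum_exp(xs):
    """Naive log(sum(exp(x))) over a dict of scalars, used only as a test oracle.

    This is the textbook formula without the max-shift trick, so it is only
    safe for the small values generated by the test.
    """
    return math.log(sum(math.exp(v) for v in xs.values()))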
def forward(self, unary, lengths, batch_size):
    """Run the CRF forward algorithm on a batch.

    :param unary: torch.FloatTensor: [T, B, N]
    :param lengths: torch.LongTensor: [B]
    :param batch_size: int: B

    :return: torch.FloatTensor: [B]
    """
    min_length = torch.min(lengths)
    # alphas: [B, 1, N]
    alphas = torch.Tensor(batch_size, 1, self.n_tags).fill_(-1e4).to(unary.device)
    alphas[:, 0, self.start_idx] = 0.
    alphas.requires_grad = True

    trans = self.transitions  # [1, N, N]

    for i, unary_t in enumerate(unary):
        # unary_t: [B, N]
        unary_t = unary_t.unsqueeze(2)  # [B, N, 1]
        # Broadcast alphas along the rows of trans
        # Broadcast trans along the batch of alphas
        # [B, 1, N] + [1, N, N] -> [B, N, N]
        # Broadcast unary_t along the cols of the result
        # [B, N, N] + [B, N, 1] -> [B, N, N]
        scores = alphas + trans + unary_t
        new_alphas = vec_log_sum_exp(scores, 2).transpose(1, 2)
        # If we haven't reached this example's length, zero out the old alpha and take the new one.
        # If we are past its length, zero out the new alpha and keep the old one.
        if i >= min_length:
            mask = (i < lengths).view(-1, 1, 1)
            alphas = alphas.masked_fill(mask, 0) + new_alphas.masked_fill(mask == 0, 0)
        else:
            alphas = new_alphas

    terminal_vars = alphas + trans[:, self.end_idx]
    alphas = vec_log_sum_exp(terminal_vars, 2)
    return alphas.squeeze()
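# The broadcasting in the recurrence above can be sanity-checked in isolation. The
# check below is illustrative only: it assumes vec_log_sum_exp keeps the reduced
# dimension (as the shape test above implies), so torch.logsumexp(..., keepdim=True)
# stands in for it, and check_crf_broadcast_shapes is not part of the repo.
import torch


def check_crf_broadcast_shapes(B=4, N=7):
    alphas = torch.rand(B, 1, N)   # forward scores so far
    trans = torch.rand(1, N, N)    # transition scores, shared across the batch
    unary_t = torch.rand(B, N, 1)  # emission scores at step t

    # [B, 1, N] + [1, N, N] + [B, N, 1] -> [B, N, N]
    scores = alphas + trans + unary_t
    assert scores.shape == (B, N, N)

    # Reducing over dim 2 with keepdim gives [B, N, 1]; the transpose in the CRF
    # loop turns that back into the [B, 1, N] layout of alphas.
    new_alphas = torch.logsumexp(scores, dim=2, keepdim=True).transpose(1, 2)
    assert new_alphas.shape == (B, 1, N)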
def forward(self, inputs, targets):
    # This is the cosine distance annealing referred to in https://arxiv.org/pdf/1911.03688.pdf
    fract = min(self.steps / self.warmup_steps, 1)
    c = (self.max_scale - 1) * fract + 1
    self.steps += 1
    # These will get broadcast to [B, B, H]
    query = self.model.encode_query(inputs).unsqueeze(1)  # [B, 1, H]
    response = self.model.encode_response(targets).unsqueeze(0)  # [1, B, H]
    # all_score is now a batch x batch matrix where index (i, j) is the score between
    # the i^th query vector and the j^th response vector
    all_score = c * self.score(query, response)  # [B, B]
    # The diagonal has the scores of the correct pairs, (i, i)
    pos_score = torch.diag(all_score)
    # vec_log_sum_exp calculates the batched log_sum_exp in a numerically stable way.
    # The result is a [B, 1] vector which we squeeze to [B] to match the diag.
    # Because we are minimizing the negative log, the division becomes a subtraction here.
    loss = pos_score - vec_log_sum_exp(all_score, -1).squeeze()
    # Batch loss
    loss = torch.sum(loss)
    # Minimize the negative loss
    return -loss
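# Because pos_score - vec_log_sum_exp(all_score, -1) is the log-softmax of the
# diagonal entries, summing and negating it matches cross-entropy over the score
# matrix with targets 0..B-1. The check below is a sketch that uses torch.logsumexp
# as a stand-in for vec_log_sum_exp; check_in_batch_loss is not part of the repo.
import torch
import torch.nn.functional as F


def check_in_batch_loss(B=8):
    all_score = torch.rand(B, B)

    pos_score = torch.diag(all_score)
    loss = -(pos_score - torch.logsumexp(all_score, dim=-1)).sum()

    # Cross-entropy with target i for row i computes the same quantity
    # (sum reduction to match the torch.sum above).
    reference = F.cross_entropy(all_score, torch.arange(B), reduction='sum')
    assert torch.allclose(loss, reference)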
def test_vec_log_sum_exp_ones():
    l = np.random.randint(1, 21)
    in_ = torch.ones(1, l)
    lse = vec_log_sum_exp(in_, 1).squeeze()
    np.testing.assert_allclose(lse.detach().numpy(), math.log(l * math.e))
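# The tests above pin down the contract of vec_log_sum_exp: it reduces one dimension
# while keeping it with size 1, and it must be numerically stable. A minimal sketch
# that satisfies these tests is shown below; the actual library implementation may differ.
import torch


def vec_log_sum_exp(vec, dim):
    """Numerically stable log(sum(exp(vec))) over `dim`, keeping that dim with size 1.

    Shifts by the max along `dim` before exponentiating so large inputs do not
    overflow, then adds the max back after the log.
    """
    max_scores, _ = vec.max(dim, keepdim=True)
    return max_scores + torch.log(torch.sum(torch.exp(vec - max_scores), dim, keepdim=True))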