Example #1
def test_vec_log_sum_exp_batch_stable():
    h = np.random.randint(22, 41)
    i1 = torch.rand(1, h, h)
    i2 = torch.rand(1, h, h)
    i = torch.cat([i1, i2], dim=0)
    lse1 = vec_log_sum_exp(i1, 2)
    lse2 = vec_log_sum_exp(i2, 2)
    one_x_one = torch.cat([lse1, lse2], dim=0)
    lse = vec_log_sum_exp(i, 2)
    np.testing.assert_allclose(one_x_one.numpy(), lse.numpy())
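None of these examples show the definition of vec_log_sum_exp itself. As a point of reference, a minimal sketch of a numerically stable, batched log-sum-exp using the max-shift trick is given below; the signature and the keep-dim behavior are inferred from the tests, and the library's actual implementation may differ.

import torch

def vec_log_sum_exp(vec, dim):
    """Reference sketch: log-sum-exp along `dim`, keeping that dim as size 1."""
    # Subtract the per-slice max so exp() cannot overflow, then add it back after the log.
    max_scores, _ = vec.max(dim, keepdim=True)
    return max_scores + torch.log(torch.sum(torch.exp(vec - max_scores), dim, keepdim=True))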
Example #2
def test_vec_log_sum_exp_shape():
    dim = torch.randint(0, 3, (1,)).item()
    shape = torch.randint(1, 21, (3,))
    in_ = torch.rand(*shape)
    out = vec_log_sum_exp(in_, dim)
    shape[dim] = 1
    for i in range(len(shape)):
        assert out.size(i) == shape[i]
Example #3
def test_vec_log_sum_exp():
    vec = torch.rand(1, np.random.randint(5, 31))
    ours = vec_log_sum_exp(vec, 1).squeeze()
    xs = {}
    for i in range(vec.size(1)):
        xs[i] = vec[0, i].item()
    gold = explicit_log_sum_exp(xs)
    np.testing.assert_allclose(ours, gold, rtol=1e-6)
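explicit_log_sum_exp is not shown here; from the way the test builds a dict of index -> float and compares with a loose rtol, it is presumably a brute-force scalar oracle along these lines (a hedged sketch, not the suite's actual helper):

import math

def explicit_log_sum_exp(xs):
    # Naive reference: exponentiate every value, sum, and take the log.
    return math.log(sum(math.exp(v) for v in xs.values()))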
Example #4
    def forward(self, unary, lengths, batch_size):
        """For CRF forward on a batch.

        :param unary: torch.FloatTensor: [T, B, N]
        :param lengths: torch.LongTensor: [B]
        :param batch_size: int: B

        :return: torch.FloatTensor: [B]
        """
        min_length = torch.min(lengths)
        # alphas: [B, 1, N]
        alphas = torch.Tensor(batch_size, 1, self.n_tags).fill_(-1e4).to(unary.device)
        alphas[:, 0, self.start_idx] = 0.
        alphas.requires_grad = True

        trans = self.transitions  # [1, N, N]

        for i, unary_t in enumerate(unary):
            # unary_t: [B, N]
            unary_t = unary_t.unsqueeze(2)  # [B, N, 1]
            # Broadcast alphas along the rows of trans
            # Broadcast trans along the batch of alphas
            # [B, 1, N] + [1, N, N] -> [B, N, N]
            # Broadcast unary_t along the cols of result
            # [B, N, N] + [B, N, 1] -> [B, N, N]
            scores = alphas + trans + unary_t
            new_alphas = vec_log_sum_exp(scores, 2).transpose(1, 2)
            # If we haven't reached this sequence's length, zero out the old alpha and take the new one.
            # If we are past this sequence's length, zero out the new alpha and keep the old one.

            if i >= min_length:
                mask = (i < lengths).view(-1, 1, 1)
                alphas = alphas.masked_fill(mask, 0) + new_alphas.masked_fill(mask == 0, 0)
            else:
                alphas = new_alphas

        terminal_vars = alphas + trans[:, self.end_idx]
        alphas = vec_log_sum_exp(terminal_vars, 2)
        return alphas.squeeze()
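The broadcasting comments inside the loop carry the core of the forward recursion. A standalone shape check, with hypothetical sizes B=2 and N=5, makes the pattern concrete:

import torch

B, N = 2, 5                    # hypothetical batch size and tag count
alphas = torch.zeros(B, 1, N)
trans = torch.zeros(1, N, N)
unary_t = torch.zeros(B, N, 1)
# [B, 1, N] + [1, N, N] -> [B, N, N], then + [B, N, 1] -> [B, N, N]
scores = alphas + trans + unary_t
assert scores.shape == (B, N, N)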
Example #5
    def forward(self, inputs, targets):
        # This is the cosine distance annealing referred to in https://arxiv.org/pdf/1911.03688.pdf
        fract = min(self.steps / self.warmup_steps, 1)
        c = (self.max_scale - 1) * fract + 1
        self.steps += 1
        # These will get broadcast to [B, B, H]
        query = self.model.encode_query(inputs).unsqueeze(1)  # [B, 1, H]
        response = self.model.encode_response(targets).unsqueeze(0)  # [1, B, H]
        # all_score is a batch x batch matrix where index (i, j) is the score between
        # the i^th query vector and the j^th response vector
        all_score = c * self.score(query, response)  # [B, B]
        # The diagonal holds the scores of the correct pairs, (i, i)
        pos_score = torch.diag(all_score)
        # vec_log_sum_exp calculates the batched log_sum_exp in a numerically stable way;
        # the result is a [B, 1] vector which we squeeze to [B] to match the diagonal.
        # Because we work in log space, the softmax's division becomes a subtraction here.
        loss = pos_score - vec_log_sum_exp(all_score, -1).squeeze()
        # Batch loss
        loss = torch.sum(loss)
        # Return the negative so that minimizing it maximizes the log-likelihood
        return -loss
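Each row's pos_score minus its log-sum-exp is the log-softmax of the diagonal entry, so the summed negative is the same objective as a cross-entropy over in-batch negatives. A quick check of that equivalence (using torch.logsumexp in place of vec_log_sum_exp, purely for illustration):

import torch
import torch.nn.functional as F

B = 4
all_score = torch.randn(B, B)            # stand-in for c * score(query, response)
pos_score = torch.diag(all_score)
loss = torch.sum(pos_score - torch.logsumexp(all_score, dim=-1))
ce = F.cross_entropy(all_score, torch.arange(B), reduction="sum")
assert torch.allclose(-loss, ce)         # the two formulations match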
Example #6
def test_vec_log_sum_exp_ones():
    l = np.random.randint(1, 21)
    in_ = torch.ones(1, l)
    lse = vec_log_sum_exp(in_, 1).squeeze()
    np.testing.assert_allclose(lse.detach().numpy(), math.log(l * math.e))
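The expected value here follows from a one-line identity: for a length-l vector of ones, log(sum_i e^1) = log(l * e) = 1 + log(l), which is exactly what the test encodes as math.log(l * math.e).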