示例#1
0
 def straight_through(self, z_e_x):
     """Quantize ``z_e_x`` against the codebook, returning two views of the codes.

     Returns a pair ``(z_q_x, z_q_x_bar)``:

     * ``z_q_x`` comes from ``vq_st`` called with a *detached* copy of
       ``self.embedding.weight``, so gradients through it never reach the
       codebook parameters.
     * ``z_q_x_bar`` is an ``index_select`` of the live embedding weight at
       the indices chosen by ``vq_st``, so gradients through it do reach the
       codebook.

     When ``self.object_level`` is falsy, axis 1 of ``z_e_x`` is moved to the
     last position before quantizing and moved back afterwards (this needs a
     4-D input; presumably (batch, channels, H, W) — confirm with callers).
     Otherwise ``z_e_x`` is handed to ``vq_st`` as-is.
     """
     codebook = self.embedding.weight
     per_position = not self.object_level

     # Put axis 1 last for the per-position case; object-level input is used unchanged.
     z_in = z_e_x.permute(0, 2, 3, 1).contiguous() if per_position else z_e_x
     z_q, indices = vq_st(z_in, codebook.detach(), self.object_level)
     # Gather from the non-detached weight so this output carries codebook gradients.
     z_bar = torch.index_select(codebook, dim=0, index=indices).view_as(z_in)

     if per_position:
         # Undo the permute so both outputs share the layout of z_e_x.
         z_q = z_q.permute(0, 3, 1, 2).contiguous()
         z_bar = z_bar.permute(0, 3, 1, 2).contiguous()
     return z_q, z_bar
def test_vq_st_gradient1():
    """vq_st must pass the upstream gradient to its inputs unchanged."""
    shape = (2, 3, 5, 7)
    x = torch.rand(shape, dtype=torch.float32, requires_grad=True)
    book = torch.rand((11, 7), dtype=torch.float32, requires_grad=True)
    quantized, _ = vq_st(x, book)

    upstream = torch.rand(shape)
    (grad_x,) = torch.autograd.grad(quantized, x, grad_outputs=[upstream])

    # Straight-through estimator: the input gradient equals the output gradient.
    assert grad_x.size() == shape
    assert np.allclose(upstream.numpy(), grad_x.numpy())
def test_vq_st_shape():
    """Check the shapes, dtypes and autograd flags of vq_st's two outputs."""
    batch, rows, cols, dim = 2, 3, 5, 7
    x = torch.rand((batch, rows, cols, dim), dtype=torch.float32,
                   requires_grad=True)
    book = torch.rand((11, dim), dtype=torch.float32, requires_grad=True)
    quantized, idx = vq_st(x, book)

    # Codes keep the input shape and remain part of the autograd graph.
    assert quantized.size() == (batch, rows, cols, dim)
    assert quantized.requires_grad
    assert quantized.dtype == torch.float32

    # Indices are expected flat over all but the last axis, int64, no grad.
    assert idx.size() == (batch * rows * cols,)
    assert not idx.requires_grad
    assert idx.dtype == torch.int64
示例#4
0
    def straight_through(self, z_e_x):
        """Quantize ``z_e_x`` and return ``(z_q_x, z_q_x_bar)``.

        Axis 1 of ``z_e_x`` is moved to the last position before quantizing
        and moved back afterwards, so both outputs share the layout of
        ``z_e_x`` (a 4-D tensor; presumably (batch, channels, H, W) — confirm).

        * ``z_q_x`` is produced by ``vq_st`` from a *detached* codebook, so
          gradients through it do not reach ``self.embedding.weight``.
        * ``z_q_x_bar`` gathers rows of the live codebook with
          ``index_select``, so gradients through it do reach the codebook.
        """
        z_e_x_ = z_e_x.permute(0, 2, 3, 1).contiguous()
        # Detach the codebook: vq_st's backward also emits a codebook gradient
        # (it matches torch.embedding's), so passing the live weight would let
        # every loss on z_q_x update the codebook too. In the VQ-VAE
        # formulation the codebook is trained only via the z_q_x_bar path.
        z_q_x_, indices = vq_st(z_e_x_, self.embedding.weight.detach())
        z_q_x = z_q_x_.permute(0, 3, 1, 2).contiguous()

        # Non-detached gather: this is the path that carries codebook gradients.
        z_q_x_bar_flatten = torch.index_select(self.embedding.weight,
                                               dim=0,
                                               index=indices)
        z_q_x_bar_ = z_q_x_bar_flatten.view_as(z_e_x_)
        z_q_x_bar = z_q_x_bar_.permute(0, 3, 1, 2).contiguous()

        return z_q_x, z_q_x_bar
def test_vq_st_gradient2():
    """vq_st's codebook gradient must match that of torch.embedding."""
    shape = (2, 3, 5, 7)
    x = torch.rand(shape, dtype=torch.float32, requires_grad=True)
    book = torch.rand((11, 7), dtype=torch.float32, requires_grad=True)
    quantized, _ = vq_st(x, book)

    # Reference path: look up the same nearest codes with the built-in op.
    nearest = vq(x, book)
    reference = torch.embedding(book, nearest, padding_idx=-1,
                                scale_grad_by_freq=False, sparse=False)

    upstream = torch.rand(shape, dtype=torch.float32)
    (grad_book,) = torch.autograd.grad(quantized, book,
                                       grad_outputs=[upstream])
    (grad_book_ref,) = torch.autograd.grad(reference, book,
                                           grad_outputs=[upstream])

    assert grad_book.size() == (11, 7)
    assert np.allclose(grad_book.numpy(), grad_book_ref.numpy())
示例#6
0
    def straight_through(self, z_e_x, index=False):
        """Quantize ``z_e_x`` without any axis permutation.

        Returns ``(z_q_x, z_q_x_bar)``, or ``(z_q_x, z_q_x_bar, indices)``
        when ``index`` is true:

        * ``z_q_x`` comes from ``vq_st`` (or ``vq_st_i``) called with a
          *detached* codebook, so gradients through it skip the codebook.
        * ``z_q_x_bar`` gathers rows of the live ``self.embedding.weight``
          at the selected indices, reshaped like ``z_e_x``, so gradients
          through it do reach the codebook.
        * the optional third item is the third value returned by
          ``vq_st_i`` (presumably the indices before flattening — confirm).
        """
        codebook = self.embedding.weight
        z_in = z_e_x.contiguous()

        if index:
            z_q_raw, flat_idx, unflat_idx = vq_st_i(z_in, codebook.detach())
        else:
            z_q_raw, flat_idx = vq_st(z_in, codebook.detach())

        z_q = z_q_raw.contiguous()
        # Gather from the non-detached weight so this output trains the codebook.
        gathered = torch.index_select(codebook, dim=0, index=flat_idx)
        z_bar = gathered.view_as(z_in).contiguous()

        return (z_q, z_bar, unflat_idx) if index else (z_q, z_bar)