import numpy as np
import torch

# `vq`, `vq_st`, and `vq_st_i` are assumed to be imported from the
# project's vector-quantization functions module.


def straight_through(self, z_e_x):
    if not self.object_level:
        # Spatial latents: move channels last so every (H, W) position is
        # quantized against the codebook.
        z_e_x_ = z_e_x.permute(0, 2, 3, 1).contiguous()
        z_q_x_, indices = vq_st(z_e_x_, self.embedding.weight.detach(),
                                self.object_level)
        z_q_x = z_q_x_.permute(0, 3, 1, 2).contiguous()

        # Re-select the same codebook rows without detaching, so gradients
        # from the codebook loss can reach the embedding weights.
        z_q_x_bar_flatten = torch.index_select(self.embedding.weight,
                                               dim=0, index=indices)
        z_q_x_bar_ = z_q_x_bar_flatten.view_as(z_e_x_)
        z_q_x_bar = z_q_x_bar_.permute(0, 3, 1, 2).contiguous()
    else:
        # Object-level latents are already (..., dim); no permutation needed.
        z_q_x, indices = vq_st(z_e_x, self.embedding.weight.detach(),
                               self.object_level)
        z_q_x_bar_flatten = torch.index_select(self.embedding.weight,
                                               dim=0, index=indices)
        z_q_x_bar = z_q_x_bar_flatten.view_as(z_e_x)
    return z_q_x, z_q_x_bar
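The tests that follow pin down the contract of vq_st: the forward pass returns the quantized codes plus flattened assignment indices, the backward pass copies the output gradient straight through to the encoder inputs, and codebook gradients are scatter-added per selected code. A minimal sketch of such an op is given below for orientation; it is an assumption, not the repository's actual vq_st, and it relies on a nearest-neighbour lookup vq (sketched after test_vq_st_gradient2) and on the two-argument call signature used in the tests.

import torch
from torch.autograd import Function


class VQStraightThrough(Function):
    # Sketch of a straight-through vector-quantization op (assumed,
    # not the repository's implementation).
    # usage: codes, indices = VQStraightThrough.apply(inputs, codebook)

    @staticmethod
    def forward(ctx, inputs, codebook):
        indices = vq(inputs, codebook)          # nearest-neighbour lookup
        indices_flatten = indices.view(-1)
        ctx.save_for_backward(indices_flatten, codebook)
        ctx.mark_non_differentiable(indices_flatten)
        codes_flatten = torch.index_select(codebook, dim=0,
                                           index=indices_flatten)
        codes = codes_flatten.view_as(inputs)
        return codes, indices_flatten

    @staticmethod
    def backward(ctx, grad_output, grad_indices):
        grad_inputs, grad_codebook = None, None
        if ctx.needs_input_grad[0]:
            # Straight-through estimator: pass the gradient unchanged.
            grad_inputs = grad_output.clone()
        if ctx.needs_input_grad[1]:
            # Scatter-add each position's gradient into the code it
            # selected, matching the gradient of torch.embedding.
            indices_flatten, codebook = ctx.saved_tensors
            embedding_size = codebook.size(1)
            grad_output_flatten = grad_output.contiguous().view(-1, embedding_size)
            grad_codebook = torch.zeros_like(codebook)
            grad_codebook.index_add_(0, indices_flatten, grad_output_flatten)
        return grad_inputs, grad_codebook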
def test_vq_st_gradient1():
    inputs = torch.rand((2, 3, 5, 7), dtype=torch.float32, requires_grad=True)
    codebook = torch.rand((11, 7), dtype=torch.float32, requires_grad=True)
    codes, _ = vq_st(inputs, codebook)

    grad_output = torch.rand((2, 3, 5, 7))
    grad_inputs, = torch.autograd.grad(codes, inputs,
                                       grad_outputs=[grad_output])

    # Straight-through estimator
    assert grad_inputs.size() == (2, 3, 5, 7)
    assert np.allclose(grad_output.numpy(), grad_inputs.numpy())
def test_vq_st_shape():
    inputs = torch.rand((2, 3, 5, 7), dtype=torch.float32, requires_grad=True)
    codebook = torch.rand((11, 7), dtype=torch.float32, requires_grad=True)
    codes, indices = vq_st(inputs, codebook)

    assert codes.size() == (2, 3, 5, 7)
    assert codes.requires_grad
    assert codes.dtype == torch.float32

    assert indices.size() == (2 * 3 * 5,)
    assert not indices.requires_grad
    assert indices.dtype == torch.int64
def straight_through(self, z_e_x):
    z_e_x_ = z_e_x.permute(0, 2, 3, 1).contiguous()
    # Detach the codebook here so only the z_q_x_bar path below updates it
    # (consistent with the other straight_through variants in this section).
    z_q_x_, indices = vq_st(z_e_x_, self.embedding.weight.detach())
    z_q_x = z_q_x_.permute(0, 3, 1, 2).contiguous()

    z_q_x_bar_flatten = torch.index_select(self.embedding.weight,
                                           dim=0, index=indices)
    z_q_x_bar_ = z_q_x_bar_flatten.view_as(z_e_x_)
    z_q_x_bar = z_q_x_bar_.permute(0, 3, 1, 2).contiguous()

    return z_q_x, z_q_x_bar
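The two return values serve different gradient paths: z_q_x is the straight-through output fed to the decoder, so reconstruction gradients flow back into the encoder, while z_q_x_bar re-selects the same codebook rows without detaching, so the codebook loss can move the embeddings. A hypothetical training step illustrating the split; encoder, decoder, model, x, and beta are stand-in names, not from the source:

import torch.nn.functional as F

z_e_x = encoder(x)                                   # continuous latents
z_q_x_st, z_q_x_bar = model.straight_through(z_e_x)
x_tilde = decoder(z_q_x_st)                          # ST path: grads reach encoder

loss_recons = F.mse_loss(x_tilde, x)
loss_vq = F.mse_loss(z_q_x_bar, z_e_x.detach())      # pulls codebook toward encoder
loss_commit = F.mse_loss(z_e_x, z_q_x_bar.detach())  # commits encoder to codebook
loss = loss_recons + loss_vq + beta * loss_commit
loss.backward()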
def test_vq_st_gradient2():
    inputs = torch.rand((2, 3, 5, 7), dtype=torch.float32, requires_grad=True)
    codebook = torch.rand((11, 7), dtype=torch.float32, requires_grad=True)
    codes, _ = vq_st(inputs, codebook)

    indices = vq(inputs, codebook)
    codes_torch = torch.embedding(codebook, indices, padding_idx=-1,
                                  scale_grad_by_freq=False, sparse=False)

    grad_output = torch.rand((2, 3, 5, 7), dtype=torch.float32)
    grad_codebook, = torch.autograd.grad(codes, codebook,
                                         grad_outputs=[grad_output])
    grad_codebook_torch, = torch.autograd.grad(codes_torch, codebook,
                                               grad_outputs=[grad_output])

    # Gradient is the same as for the torch.embedding function
    assert grad_codebook.size() == (11, 7)
    assert np.allclose(grad_codebook.numpy(), grad_codebook_torch.numpy())
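This test compares the codebook gradient of vq_st against torch.embedding, which holds exactly when the backward pass scatter-adds each position's gradient into the code it selected. The vq it calls is assumed to be a plain nearest-neighbour search over the codebook; a minimal sketch under that assumption (not the repository's actual vq):

def vq(inputs, codebook):
    # For every input vector, return the index of the closest codebook row
    # under Euclidean distance; output shape is inputs.shape[:-1].
    with torch.no_grad():
        embedding_size = codebook.size(1)
        inputs_flatten = inputs.view(-1, embedding_size)
        distances = torch.cdist(inputs_flatten, codebook)
        indices_flatten = distances.argmin(dim=1)
        return indices_flatten.view(*inputs.shape[:-1])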
def straight_through(self, z_e_x, index=False):
    z_e_x_ = z_e_x.contiguous()
    if index:
        # vq_st_i additionally returns the assignment indices in their
        # original (unflattened) shape.
        z_q_x_, indices, indices_not_flatten = vq_st_i(
            z_e_x_, self.embedding.weight.detach())
    else:
        z_q_x_, indices = vq_st(z_e_x_, self.embedding.weight.detach())
    z_q_x = z_q_x_.contiguous()

    z_q_x_bar_flatten = torch.index_select(self.embedding.weight,
                                           dim=0, index=indices)
    z_q_x_bar_ = z_q_x_bar_flatten.view_as(z_e_x_)
    z_q_x_bar = z_q_x_bar_.contiguous()

    if index:
        return z_q_x, z_q_x_bar, indices_not_flatten
    return z_q_x, z_q_x_bar
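Keeping index=False as the default preserves the two-value signature of the other variants, while index=True exposes the discrete code assigned to each input for callers that need it (e.g. logging or downstream discrete models). Hypothetical calls, with codebook and z_e_x as stand-in names:

z_q_x_st, z_q_x_bar = codebook.straight_through(z_e_x)
z_q_x_st, z_q_x_bar, obj_indices = codebook.straight_through(z_e_x, index=True)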