def test_vq_shape():
    inputs = torch.rand((2, 3, 5, 7), dtype=torch.float32, requires_grad=True)
    codebook = torch.rand((11, 7), dtype=torch.float32, requires_grad=True)
    indices = vq(inputs, codebook)

    assert indices.size() == (2, 3, 5)
    assert not indices.requires_grad
    assert indices.dtype == torch.int64
def test_vq():
    inputs = torch.rand((2, 3, 5, 7), dtype=torch.float32, requires_grad=True)
    codebook = torch.rand((11, 7), dtype=torch.float32, requires_grad=True)
    indices = vq(inputs, codebook)

    # Brute-force reference: pick the codebook row with the smallest
    # Euclidean distance to each input vector.
    differences = inputs.unsqueeze(3) - codebook
    distances = torch.norm(differences, p=2, dim=4)
    _, indices_torch = torch.min(distances, dim=3)

    assert np.allclose(indices.numpy(), indices_torch.numpy())
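# The two tests above pin down the contract of vq: int64 nearest-neighbour
# indices with the batch shape of the inputs and no gradient. A minimal
# sketch that satisfies that contract (hypothetical reference, not
# necessarily the repository's actual implementation) could look like this:
def vq_reference(inputs, codebook):
    # Nearest-neighbour lookup: index of the closest codebook row for each
    # input vector, computed outside autograd so the result carries no grad.
    with torch.no_grad():
        embedding_size = codebook.size(1)
        inputs_flat = inputs.reshape(-1, embedding_size)   # (N, D)
        distances = torch.cdist(inputs_flat, codebook)     # (N, K) L2 distances
        indices = torch.argmin(distances, dim=1)           # int64, no grad
        return indices.view(*inputs.shape[:-1])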
def test_vq_st_gradient2():
    inputs = torch.rand((2, 3, 5, 7), dtype=torch.float32, requires_grad=True)
    codebook = torch.rand((11, 7), dtype=torch.float32, requires_grad=True)
    codes, _ = vq_st(inputs, codebook)

    indices = vq(inputs, codebook)
    codes_torch = torch.embedding(codebook, indices, padding_idx=-1,
        scale_grad_by_freq=False, sparse=False)

    grad_output = torch.rand((2, 3, 5, 7), dtype=torch.float32)
    grad_codebook, = torch.autograd.grad(codes, codebook,
        grad_outputs=[grad_output])
    grad_codebook_torch, = torch.autograd.grad(codes_torch, codebook,
        grad_outputs=[grad_output])

    # Gradient is the same as torch.embedding function
    assert grad_codebook.size() == (11, 7)
    assert np.allclose(grad_codebook.numpy(), grad_codebook_torch.numpy())
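# A hedged sketch of how vq_st could produce the behaviour checked above:
# a custom autograd Function that copies the output gradient straight
# through to the inputs and scatter-adds it onto the selected codebook rows,
# which is the same accumulation rule torch.embedding uses. The names below
# are illustrative, not the repository's actual classes.
class _VQStraightThroughSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inputs, codebook):
        indices = vq(inputs, codebook)                 # nearest-neighbour indices
        flat_indices = indices.reshape(-1)
        ctx.save_for_backward(flat_indices, codebook)
        codes = torch.index_select(codebook, dim=0, index=flat_indices)
        return codes.view_as(inputs), flat_indices

    @staticmethod
    def backward(ctx, grad_codes, grad_indices):
        flat_indices, codebook = ctx.saved_tensors
        grad_inputs = grad_codes.clone()               # straight-through estimator
        grad_codebook = torch.zeros_like(codebook)
        grad_codebook.index_add_(                      # same rule as torch.embedding
            0, flat_indices,
            grad_codes.contiguous().view(-1, codebook.size(1)))
        return grad_inputs, grad_codebook


def vq_st_sketch(inputs, codebook):
    return _VQStraightThroughSketch.apply(inputs, codebook)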
def forward(self, z_e_x):
    # Module forward pass: reorder encoder features from (B, C, H, W) to
    # (B, H, W, C) and map each feature vector to its codebook index.
    z_e_x_ = z_e_x.permute(0, 2, 3, 1).contiguous()
    latents = vq(z_e_x_, self.embedding.weight)
    return latents
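# Illustrative standalone usage (not from the repository): the same shape
# flow as the forward pass above, with a codebook sized to match the tests.
# Assumes torch.nn is available as nn.
import torch.nn as nn

def example_vq_embedding_usage():
    codebook = nn.Embedding(11, 7)                      # K=11 codes of dimension D=7
    z_e_x = torch.rand(2, 7, 3, 5)                      # encoder output (B, C=7, H=3, W=5)
    z_e_x_ = z_e_x.permute(0, 2, 3, 1).contiguous()     # (B, H, W, C)
    latents = vq(z_e_x_, codebook.weight)               # (B, H, W) int64 indices
    assert latents.size() == (2, 3, 5)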