def test_split_channel(self):
    # use distinct channel values so the equality checks below are
    # non-trivial (with torch.ones every channel slice would match)
    x = torch.arange(2 * 4 * 16 * 16, dtype=torch.float).view(2, 4, 16, 16)
    nc = x.shape[1]
    # simple splitting: first half vs. second half of the channels
    x1, x2 = ops.split_channel(x, 'simple')
    for c in range(nc // 2):
        self.assertTrue(ops.tensor_equal(x1[:, c, :, :], x[:, c, :, :]))
        self.assertTrue(
            ops.tensor_equal(x2[:, c, :, :], x[:, nc // 2 + c, :, :]))
    # cross splitting: even-indexed vs. odd-indexed channels
    x1, x2 = ops.split_channel(x, 'cross')
    for c in range(nc // 2):
        self.assertTrue(ops.tensor_equal(x1[:, c, :, :], x[:, 2 * c, :, :]))
        self.assertTrue(
            ops.tensor_equal(x2[:, c, :, :], x[:, 2 * c + 1, :, :]))
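# For reference, a minimal sketch of the behavior the test above expects
# from ops.split_channel; this is inferred from the assertions, not
# necessarily the actual implementation in the ops module.
def _split_channel_sketch(x, mode='simple'):
    nc = x.shape[1]
    if mode == 'simple':
        # first half and second half of the channels
        return x[:, :nc // 2, :, :], x[:, nc // 2:, :, :]
    # 'cross': even-indexed and odd-indexed channels
    return x[:, 0::2, :, :], x[:, 1::2, :, :]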
def normal_flow(self, x, logdet=None):
    """
    Normal (forward) flow: actnorm -> permutation -> coupling

    :param x: input tensor
    :type x: torch.Tensor
    :param logdet: accumulated log-determinant
    :type logdet: torch.Tensor
    :return: output and updated log-determinant
    :rtype: tuple(torch.Tensor, torch.Tensor)
    """
    # 1. activation normalization layer
    z, logdet = self.actnorm(x, logdet=logdet, reverse=False)
    # 2. flow permutation layer
    if self.permutation == 'invconv':
        z, logdet = self.invconv(z, logdet, reverse=False)
    elif self.permutation == 'reverse':
        z = self.reverse(z, reverse=False)
    else:
        z = self.shuffle(z, reverse=False)
    # 3. flow coupling layer
    z1, z2 = ops.split_channel(z, 'simple')
    if self.coupling == 'additive':
        # additive coupling is volume-preserving, so logdet is unchanged
        z2 += self.f(z1)
    else:
        # affine coupling: shift and scale are functions of z1 only
        h = self.f(z1)
        shift, scale = ops.split_channel(h, 'cross')
        # scale = F.sigmoid(scale + 2.)  # F.sigmoid is deprecated
        scale = torch.sigmoid(scale + 2.)
        z2 += shift
        z2 *= scale
        logdet = ops.reduce_sum(torch.log(scale), dim=[1, 2, 3]) + logdet
    z = ops.cat_channel(z1, z2)
    return z, logdet
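# A self-contained sketch showing why the affine coupling above is
# invertible and where sum(log(scale)) comes from: the forward map
# z2 -> (z2 + shift) * scale is elementwise affine, so its inverse is
# y2 / scale - shift and its Jacobian log-determinant is sum(log(scale)).
# 'toy_f' is an illustrative stand-in for self.f, not code from this repo.
import torch

def toy_f(z1):
    # stand-in coupling network: emits twice as many channels as z1
    return torch.cat([z1 * 0.1, z1 * 0.2], dim=1)

z = torch.randn(2, 4, 16, 16)
z1, z2 = z[:, :2], z[:, 2:]
h = toy_f(z1)
shift, scale = h[:, 0::2], torch.sigmoid(h[:, 1::2] + 2.)
y2 = (z2 + shift) * scale                     # forward coupling
logdet = torch.log(scale).sum(dim=[1, 2, 3])  # per-sample log-determinant
z2_rec = y2 / scale - shift                   # inverse coupling
assert torch.allclose(z2, z2_rec, atol=1e-5)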
def prior(self, y_onehot=None):
    """
    Prior

    :param y_onehot: one-hot vector of label
    :type y_onehot: torch.Tensor
    :return: prior parameters, split into two channel halves
    :rtype: tuple(torch.Tensor, torch.Tensor)
    """
    nc = self.h_top.shape[1]
    # h_top is a zero buffer that only provides the shape
    h = self.h_top.detach().clone()
    assert torch.sum(h) == 0.
    if self.hps.ablation.learn_top:
        h = self.learn_top(h)
    if self.hps.ablation.y_condition:
        assert y_onehot is not None
        # condition the prior on the class embedding
        h += self.y_emb(y_onehot).view(-1, nc, 1, 1)
    return ops.split_channel(h, 'simple')
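# Assuming the two halves returned by prior() are the mean and the log
# standard deviation of a diagonal Gaussian (the usual Glow convention),
# sampling from it would look like the sketch below; 'model' and the call
# pattern are illustrative, not taken from this codebase.
mean, logs = model.prior(y_onehot)
eps = torch.randn_like(mean)
z = mean + torch.exp(logs) * eps  # reparameterized draw from N(mean, exp(logs)**2)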
def test_cat_channel(self):
    # distinct values so that channel ordering is actually verified
    x = torch.arange(2 * 4 * 16 * 16, dtype=torch.float).view(2, 4, 16, 16)
    x1, x2 = ops.split_channel(x, 'simple')
    self.assertTrue(ops.tensor_equal(ops.cat_channel(x1, x2), x))
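# A minimal sketch of the cat_channel behavior the round-trip test above
# relies on; inferred, not necessarily the actual ops implementation.
def _cat_channel_sketch(x1, x2):
    # concatenate along the channel dimension (NCHW layout)
    return torch.cat([x1, x2], dim=1)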