def normal_flow(self, x, logdet=None):
    """
    Normal (forward) flow: actnorm -> permutation -> coupling

    :param x: input tensor
    :type x: torch.Tensor
    :param logdet: accumulated log-determinant of the Jacobian
    :type logdet: torch.Tensor
    :return: output and logdet
    :rtype: tuple(torch.Tensor, torch.Tensor)
    """
    # 1. activation normalization layer
    z, logdet = self.actnorm(x, logdet=logdet, reverse=False)

    # 2. flow permutation layer
    if self.permutation == 'invconv':
        z, logdet = self.invconv(z, logdet, reverse=False)
    elif self.permutation == 'reverse':
        z = self.reverse(z, reverse=False)
    else:
        z = self.shuffle(z, reverse=False)

    # 3. flow coupling layer
    z1, z2 = ops.split_channel(z, 'simple')
    if self.coupling == 'additive':
        # additive coupling: z2 <- z2 + f(z1); the Jacobian determinant is 1,
        # so logdet is unchanged
        z2 += self.f(z1)
    else:
        # affine coupling: z2 <- (z2 + shift) * scale
        h = self.f(z1)
        shift, scale = ops.split_channel(h, 'cross')
        # torch.sigmoid replaces the deprecated F.sigmoid; the +2 bias keeps
        # the scale close to 1 at initialization
        scale = torch.sigmoid(scale + 2.)
        z2 += shift
        z2 *= scale
        logdet = ops.reduce_sum(torch.log(scale), dim=[1, 2, 3]) + logdet
    z = ops.cat_channel(z1, z2)

    return z, logdet
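The coupling step leans on a few small tensor helpers from the `ops` module whose bodies are not shown here. The following is a minimal sketch of how `split_channel`, `cat_channel`, and `reduce_sum` are assumed to behave, based only on how they are called above ('simple' splits the channel dimension into two halves, 'cross' takes alternating channels):

import torch

def split_channel(tensor, split_type='simple'):
    """Assumed behaviour: split an NCHW tensor into two halves along channels.

    'simple' returns the first and second half of the channels,
    'cross' returns the even- and odd-indexed channels.
    """
    assert tensor.size(1) % 2 == 0
    if split_type == 'simple':
        half = tensor.size(1) // 2
        return tensor[:, :half], tensor[:, half:]
    elif split_type == 'cross':
        return tensor[:, 0::2], tensor[:, 1::2]

def cat_channel(a, b):
    """Assumed behaviour: concatenate two tensors along the channel dimension."""
    return torch.cat((a, b), dim=1)

def reduce_sum(tensor, dim):
    """Assumed behaviour: sum over the given dimensions, keeping the batch dim."""
    for d in sorted(dim, reverse=True):
        tensor = tensor.sum(dim=d)
    return tensor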
def test_reduce_sum(self):
    x = torch.ones(2, 3, 16, 16)
    # summing a tensor of ones over dims (1, 2, 3) gives C * H * W per sample
    total = ops.reduce_sum(x, dim=[1, 2, 3])
    expected = float(x.shape[1] * x.shape[2] * x.shape[3])
    self.assertTrue(
        ops.tensor_equal(torch.Tensor([expected, expected]), total))
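The test compares tensors with `ops.tensor_equal`, whose implementation is likewise not shown. A minimal sketch consistent with the usage above (same shape, element-wise equality within a small tolerance) might look like:

import torch

def tensor_equal(a, b, eps=1e-6):
    """Assumed behaviour: True if the two tensors have the same shape and
    all corresponding elements differ by less than eps."""
    if a.shape != b.shape:
        return False
    return bool((a - b).abs().max() < eps)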