コード例 #1
0
    def normal_flow(self, x, logdet=None):
        """
        Run this flow step in the forward (normalizing) direction.

        The step applies, in order:

        - activation normalization;
        - a channel permutation selected by ``self.permutation``;
        - a coupling layer selected by ``self.coupling``.

        :param x: input tensor
        :type x: torch.Tensor
        :param logdet: log-determinant accumulated so far; passed through
            ``self.actnorm`` (and ``self.invconv`` when used). For the
            non-additive coupling it is increased by the coupling's
            log-determinant. If it is still ``None`` at that point, the
            coupling's log-determinant becomes the running value.
        :type logdet: torch.Tensor or None
        :return: transformed tensor and updated log-determinant
        :rtype: tuple(torch.Tensor, torch.Tensor)
        """
        # activation normalization layer
        z, logdet = self.actnorm(x, logdet=logdet, reverse=False)

        # flow permutation layer; any value other than 'invconv' or
        # 'reverse' falls back to the shuffle permutation
        if self.permutation == 'invconv':
            z, logdet = self.invconv(z, logdet, reverse=False)
        elif self.permutation == 'reverse':
            z = self.reverse(z, reverse=False)
        else:
            z = self.shuffle(z, reverse=False)

        # flow coupling layer
        z1, z2 = ops.split_channel(z, 'simple')
        # Update z2 out-of-place. If split_channel returns slices of z
        # (presumably it does -- confirm), z1 and z2 share z's storage and
        # autograd version counter. An in-place update of z2 after
        # self.f(z1) has saved z1 for backward would make backward() raise
        # "modified by an inplace operation", and would also mutate z.
        if self.coupling == 'additive':
            z2 = z2 + self.f(z1)
        else:
            h = self.f(z1)
            shift, scale = ops.split_channel(h, 'cross')
            # sigmoid keeps the scale positive so log(scale) below is
            # defined. The +2 bias maps a raw value of 0 to sigmoid(2) ~ 0.88.
            scale = torch.sigmoid(scale + 2.)
            z2 = (z2 + shift) * scale
            # sum of log-scales over dims 1..3 (assumes a 4-D tensor,
            # e.g. NCHW -- confirm), one value per batch element
            coupling_logdet = ops.reduce_sum(torch.log(scale), dim=[1, 2, 3])
            if logdet is None:
                logdet = coupling_logdet
            else:
                logdet = logdet + coupling_logdet
        z = ops.cat_channel(z1, z2)

        return z, logdet
コード例 #2
0
 def test_reduce_sum(self):
     """Summing ones over dims 1-3 gives C*H*W for each batch element."""
     x = torch.ones(2, 3, 16, 16)
     total = ops.reduce_sum(x, dim=[1, 2, 3])
     # Each sample sums C * H * W ones. Build one expected entry per batch
     # element instead of hard-coding the batch size.
     per_sample = float(x.shape[1] * x.shape[2] * x.shape[3])
     expected = torch.full((x.shape[0],), per_sample)
     self.assertTrue(ops.tensor_equal(expected, total))