예제 #1
0
 def test_split_channel(self):
     """Check 'simple' and 'cross' channel splitting against the input.

     The input holds a distinct value at every position, so a split that
     selects the wrong channels, or returns them in the wrong order, makes
     the comparisons fail. An all-ones input would pass for any selection.
     """
     x = torch.arange(2 * 4 * 16 * 16, dtype=torch.float32).view(2, 4, 16, 16)
     nc = x.shape[1]
     half = nc // 2
     # simple splitting: first half of the channels, then the second half
     x1, x2 = ops.split_channel(x, 'simple')
     self.assertEqual(x1.shape[1], half)
     self.assertEqual(x2.shape[1], half)
     for c in range(half):
         self.assertTrue(ops.tensor_equal(x1[:, c, :, :], x[:, c, :, :]))
         self.assertTrue(
             ops.tensor_equal(x2[:, c, :, :], x[:, half + c, :, :]))
     # cross splitting: even-indexed channels, then odd-indexed channels
     x1, x2 = ops.split_channel(x, 'cross')
     self.assertEqual(x1.shape[1], half)
     self.assertEqual(x2.shape[1], half)
     for c in range(half):
         self.assertTrue(
             ops.tensor_equal(x1[:, c, :, :], x[:, 2 * c, :, :]))
         self.assertTrue(
             ops.tensor_equal(x2[:, c, :, :], x[:, 2 * c + 1, :, :]))
예제 #2
0
    def normal_flow(self, x, logdet=None):
        """
        Normal (forward) flow through this flow step.

        Applies, in order: activation normalization, the channel permutation
        selected by ``self.permutation``, and the coupling layer selected by
        ``self.coupling``.

        :param x: input tensor
        :type x: torch.Tensor
        :param logdet: log determinant accumulated so far
        :type logdet: torch.Tensor
        :return: output and updated logdet
        :rtype: tuple(torch.Tensor, torch.Tensor)
        """
        # activation normalization layer
        z, logdet = self.actnorm(x, logdet=logdet, reverse=False)

        # flow permutation layer; only the invconv branch updates logdet
        if self.permutation == 'invconv':
            z, logdet = self.invconv(z, logdet, reverse=False)
        elif self.permutation == 'reverse':
            z = self.reverse(z, reverse=False)
        else:
            z = self.shuffle(z, reverse=False)

        # flow coupling layer: z1 passes through unchanged and conditions z2.
        # The arithmetic on z2 is deliberately out-of-place. If
        # ops.split_channel returns views (NOTE(review): confirm), z1 and z2
        # share their base's autograd version counter, so an in-place update
        # of z2 would invalidate z1 wherever self.f saved it for backward.
        z1, z2 = ops.split_channel(z, 'simple')
        if self.coupling == 'additive':
            z2 = z2 + self.f(z1)
        else:
            h = self.f(z1)
            shift, scale = ops.split_channel(h, 'cross')
            # sigmoid keeps the scale strictly inside (0, 1), so log(scale)
            # below is finite
            scale = torch.sigmoid(scale + 2.)
            z2 = (z2 + shift) * scale
            # NOTE(review): assumes self.actnorm returns a tensor logdet even
            # when called with logdet=None; otherwise this addition fails.
            logdet = ops.reduce_sum(torch.log(scale), dim=[1, 2, 3]) + logdet
        z = ops.cat_channel(z1, z2)

        return z, logdet
예제 #3
0
    def prior(self, y_onehot=None):
        """
        Build the prior's hidden tensor and split it into two channel halves.

        Starts from a detached copy of ``self.h_top``. It is optionally
        transformed by ``self.learn_top`` and, when label conditioning is
        enabled, shifted by the label embedding.

        :param y_onehot: one-hot vector of label; required when
            ``self.hps.ablation.y_condition`` is enabled
        :type y_onehot: torch.Tensor
        :return: the two halves produced by ``ops.split_channel(..., 'simple')``
        :rtype: tuple(torch.Tensor, torch.Tensor)
        """
        channels = self.h_top.shape[1]
        # Work on a detached copy so the stored self.h_top is never modified.
        top = self.h_top.detach().clone()
        # Sanity check: the stored top tensor is expected to sum to zero
        # (presumably an all-zeros buffer; confirm against its initialization).
        assert torch.sum(top) == 0.
        ablation = self.hps.ablation
        if ablation.learn_top:
            top = self.learn_top(top)
        if ablation.y_condition:
            assert y_onehot is not None
            # Reshape the label embedding to (-1, channels, 1, 1) so it
            # broadcasts over the last two dimensions.
            label_shift = self.y_emb(y_onehot).view(-1, channels, 1, 1)
            top += label_shift
        return ops.split_channel(top, 'simple')
예제 #4
0
 def test_cat_channel(self):
     """Check that cat_channel undoes a 'simple' split_channel.

     Every position holds a distinct value, so concatenating the halves in
     the wrong order makes the comparison fail. An all-ones input would not
     catch that.
     """
     x = torch.arange(2 * 4 * 16 * 16, dtype=torch.float32).view(2, 4, 16, 16)
     x1, x2 = ops.split_channel(x, 'simple')
     self.assertTrue(ops.tensor_equal(ops.cat_channel(x1, x2), x))