Example #1
import torch.distributions as td
from torch import Tensor


def logistic_distribution(loc: Tensor, scale: Tensor):
    # The base must be Uniform(0, 1); note new_ones(1) for the upper bound,
    # since Uniform(0, 0) would have empty support.
    base_distribution = td.Uniform(loc.new_zeros(1), loc.new_ones(1))
    transforms = [
        td.SigmoidTransform().inv,  # logit: maps (0, 1) to R
        td.AffineTransform(loc=loc, scale=scale)
    ]
    return td.TransformedDistribution(base_distribution, transforms)
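A quick usage sketch for the helper above (the parameter values are illustrative):

import torch

logistic = logistic_distribution(torch.tensor(0.0), torch.tensor(1.0))
samples = logistic.rsample((5,))   # reparameterized samples, shape (5, 1)
print(logistic.log_prob(samples))  # log-density of the standard logistic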
Example #2
    def __init__(self, loc, scale, **kwargs):
        # Accept Python scalars as well as tensors.
        loc, scale = map(torch.as_tensor, (loc, scale))
        base_distribution = ptd.Uniform(torch.zeros_like(loc),
                                        torch.ones_like(loc), **kwargs)
        transforms = [
            ptd.SigmoidTransform().inv,  # logit: maps (0, 1) to R
            ptd.AffineTransform(loc=loc, scale=scale),
        ]
        super().__init__(base_distribution, transforms)
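The `super().__init__(...)` call implies this method belongs to a subclass of `TransformedDistribution`. A minimal self-contained sketch under that assumption (the class name `Logistic` and the `ptd` alias are guesses, not taken from the source):

import torch
import torch.distributions as ptd


class Logistic(ptd.TransformedDistribution):
    def __init__(self, loc, scale, **kwargs):
        loc, scale = map(torch.as_tensor, (loc, scale))
        base_distribution = ptd.Uniform(torch.zeros_like(loc),
                                        torch.ones_like(loc), **kwargs)
        transforms = [
            ptd.SigmoidTransform().inv,
            ptd.AffineTransform(loc=loc, scale=scale),
        ]
        super().__init__(base_distribution, transforms)


logistic = Logistic(loc=0.0, scale=2.0)
x = logistic.sample((3,))
print(logistic.log_prob(x))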
Example #3
    def __init__(self, loc: torch.Tensor, scale: torch.Tensor):
        # Broadcast loc and scale to a common batch shape.
        self.loc, self.scale = dist.utils.broadcast_all(loc, scale)

        # Constants 0 and 1 with the same dtype/device as loc.
        zero, one = torch.tensor([0.0, 1.0]).type_as(loc)

        # Expand the Uniform(0, 1) base to the broadcast batch shape.
        base_distribution = dist.Uniform(zero, one).expand(self.loc.shape)
        transforms = [
            dist.SigmoidTransform().inv,  # logit: maps (0, 1) to R
            dist.AffineTransform(loc=self.loc, scale=self.scale)
        ]

        super(Logistic, self).__init__(base_distribution, transforms)
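Since the `super(Logistic, self)` call names the class, here is a self-contained sketch wrapping this `__init__` (assuming `dist` aliases `torch.distributions`), plus a quick broadcasting check:

import torch
import torch.distributions as dist


class Logistic(dist.TransformedDistribution):
    def __init__(self, loc: torch.Tensor, scale: torch.Tensor):
        self.loc, self.scale = dist.utils.broadcast_all(loc, scale)
        zero, one = torch.tensor([0.0, 1.0]).type_as(self.loc)
        base_distribution = dist.Uniform(zero, one).expand(self.loc.shape)
        transforms = [
            dist.SigmoidTransform().inv,
            dist.AffineTransform(loc=self.loc, scale=self.scale),
        ]
        super(Logistic, self).__init__(base_distribution, transforms)


# A batch of 4 locations broadcast against a scalar scale.
logistic = Logistic(torch.zeros(4), torch.tensor(2.0))
print(logistic.sample().shape)            # torch.Size([4])
print(logistic.log_prob(torch.zeros(4)))  # elementwise log-densities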
Example #4
            # only pass batch_size to the prior; the prior stores final_shape as an attribute
            samples = self.prior().rsample((batch_size,))
            samples = samples.view((batch_size,) + self.final_shape)
            return self.inverse(samples, max_iter=max_iter)

    def set_num_terms(self, n_terms):
        # Propagate the number of power-series terms to every layer of every block.
        for block in self.stack:
            for layer in block.stack:
                layer.numSeriesTerms = n_terms


if __name__ == "__main__":
    scale = 1.
    loc = 0.
    base_distribution = distributions.Uniform(0., 1.)
    transforms_1 = [
        distributions.SigmoidTransform().inv,
        distributions.AffineTransform(loc=loc, scale=scale)
    ]
    logistic_1 = distributions.TransformedDistribution(base_distribution,
                                                       transforms_1)

    transforms_2 = [
        LogisticTransform(),  # defined elsewhere in the module; see the sketch below
        distributions.AffineTransform(loc=loc, scale=scale)
    ]
    logistic_2 = distributions.TransformedDistribution(base_distribution,
                                                       transforms_2)

    x = torch.zeros(2)
    # Both constructions define the same logistic distribution, so the two
    # printed log-probs should agree.
    print(logistic_1.log_prob(x), logistic_2.log_prob(x))
    1 / 0  # intentional crash: halt here; the code below never runs

    diff = lambda x, y: (x - y).abs().sum()
    batch_size = 13
    channels = 3
    h, w = 32, 32
    in_shape = (batch_size, channels, h, w)
    x = torch.randn((batch_size, channels, h, w), requires_grad=True)
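`LogisticTransform` is referenced above but not defined in the snippet. A minimal sketch of what it plausibly is, assuming it implements the logit transform equivalent to `SigmoidTransform().inv` (the definition below is a guess, not the repository's actual code):

import torch
from torch import distributions
from torch.distributions import constraints


class LogisticTransform(distributions.Transform):
    """Maps (0, 1) -> R via the logit; equivalent to SigmoidTransform().inv."""
    domain = constraints.unit_interval
    codomain = constraints.real
    bijective = True
    sign = +1

    def _call(self, x):
        return torch.log(x) - torch.log1p(-x)  # logit(x)

    def _inverse(self, y):
        return torch.sigmoid(y)

    def log_abs_det_jacobian(self, x, y):
        # d/dx logit(x) = 1 / (x * (1 - x))
        return -(torch.log(x) + torch.log1p(-x))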
Example #5
def sample(args):
    """
    Performs the following:
    1. construct model object & load state dict from saved model;
    2. make H x W samples from a set of gaussian or logistic prior on the latent space;
    3. save to disk as a grid of images.
    """
    # parse settings:
    if args.dataset == 'mnist':
        input_dim = 28 * 28
        img_height = 28
        img_width = 28
        img_depth = 1
    elif args.dataset == 'svhn':
        input_dim = 32 * 32 * 3
        img_height = 32
        img_width = 32
        img_depth = 3
    elif args.dataset == 'cifar10':
        input_dim = 32 * 32 * 3
        img_height = 32
        img_width = 32
        img_depth = 3
    elif args.dataset == 'tfd':
        raise NotImplementedError(
            "[sample] Toronto Faces Dataset unsupported right now. Sorry!")
    else:
        raise ValueError("[sample] Unrecognized dataset: {}".format(args.dataset))

    # shut off gradients for sampling:
    torch.set_grad_enabled(False)

    # build model & load state dict:
    nice = NICEModel(input_dim, args.nhidden, args.nlayers)
    if args.model_path is not None:
        nice.load_state_dict(torch.load(args.model_path, map_location='cpu'))
        print("[sample] Loaded model from file.")
    nice.eval()

    # sample a batch:
    if args.prior == 'logistic':
        LOGISTIC_LOC = 0.0
        # Var[Logistic(0, s)] = s^2 * pi^2 / 3, so s = sqrt(3) / pi gives unit variance.
        LOGISTIC_SCALE = np.sqrt(3.) / np.pi
        logistic = dist.TransformedDistribution(dist.Uniform(0.0, 1.0), [
            dist.SigmoidTransform().inv,
            dist.AffineTransform(loc=LOGISTIC_LOC, scale=LOGISTIC_SCALE)
        ])
        print(
            "[sample] sampling from logistic prior with loc={0:.4f}, scale={1:.4f}."
            .format(LOGISTIC_LOC, LOGISTIC_SCALE))
        ys = logistic.sample(torch.Size([args.nrows * args.ncols, input_dim]))
        xs = nice.inverse(ys)
    elif args.prior == 'gaussian':
        print("[sample] sampling from gaussian prior with loc=0.0, scale=1.0.")
        ys = torch.randn(args.nrows * args.ncols, input_dim)
        xs = nice.inverse(ys)

    # format sample into images of correct shape:
    image_batch = unflatten_images(xs, img_depth, img_height, img_width)

    # arrange into a grid and save to file:
    torchvision.utils.save_image(image_batch,
                                 args.save_image_path,
                                 nrow=args.nrows)
    print("[sample] Saved {0}-by-{1} sampled images to {2}.".format(
        args.nrows, args.ncols, args.save_image_path))
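`unflatten_images` is imported from elsewhere in the repository and not shown here. A minimal sketch of what it presumably does, assuming it simply reshapes flat sample vectors into image tensors (the argument order follows the call above; the clamping to [0, 1] is an assumption):

import torch


def unflatten_images(xs, img_depth, img_height, img_width):
    """Reshape a batch of flat vectors (N, D*H*W) into images (N, D, H, W)."""
    images = xs.reshape(xs.size(0), img_depth, img_height, img_width)
    # Clamp to [0, 1] so save_image renders the samples sensibly (assumption).
    return images.clamp(0.0, 1.0)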