Example #1
    def momentum_sample(self, state):
        '''
        Generates either Laplace or Gaussian momentum, depending on the problem set-up.

        :return: dict mapping each latent name to a sampled momentum tensor
        '''
        p = {}
        # Gaussian momentum for the continuous latents.
        p_cont = dict([[
            key,
            dist.Normal(loc=torch.zeros(state[key].size()),
                        scale=torch.ones(state[key].size())).sample()
        ] for key in self._cont_latents
                       ]) if self._cont_latents is not None else {}
        p_disc = dict([[
            key,
            dist.Laplace(loc=torch.zeros(state[key].size()),
                         scale=torch.ones(state[key].size())).sample()
        ] for key in self._disc_latents
                       ]) if self._disc_latents is not None else {}
        p_if = dict([[
            key,
            dist.Laplace(loc=torch.zeros(state[key].size()),
                         scale=torch.ones(state[key].size())).sample()
        ] for key in self._if_latents
                     ]) if self._if_latents is not None else {}
        # quicker than using dict then update.
        p.update(p_cont)
        p.update(p_disc)
        p.update(p_if)
        return p
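As a standalone illustration of the same idea, a minimal sketch that draws Gaussian momentum for continuous latents and Laplace momentum for discrete ones (the state keys and shapes here are hypothetical, not the original class's attributes):

import torch
import torch.distributions as dist

# Hypothetical latent state: one continuous and one discrete latent variable.
state = {"z_cont": torch.zeros(3), "z_disc": torch.zeros(2)}

def momentum_sample(state, cont_keys, disc_keys):
    """Gaussian momentum for continuous latents, Laplace momentum for discrete ones."""
    p = {}
    for key in cont_keys:
        p[key] = dist.Normal(torch.zeros_like(state[key]),
                             torch.ones_like(state[key])).sample()
    for key in disc_keys:
        p[key] = dist.Laplace(torch.zeros_like(state[key]),
                              torch.ones_like(state[key])).sample()
    return p

p = momentum_sample(state, cont_keys=["z_cont"], disc_keys=["z_disc"])
print({k: tuple(v.shape) for k, v in p.items()})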
Example #2
 def _dist(self, loc, scale, inv_scale, log_sqrt_vals, base_scale,
           event_shape):
     zeros = torch.zeros((), device=loc.device,
                         dtype=loc.dtype).expand(event_shape)
     return td.TransformedDistribution(
         td.Laplace(loc=zeros, scale=base_scale.expand(event_shape)),
         PCATransform(loc, scale, inv_scale, log_sqrt_vals))
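The project-specific PCATransform is not shown here, but the same pattern works with any torch.distributions transform; a minimal sketch using a plain AffineTransform over a Laplace base:

import torch
import torch.distributions as td

# Standard Laplace base over three independent dimensions.
base = td.Laplace(loc=torch.zeros(3), scale=torch.ones(3))

# Shift-and-scale transform standing in for the project-specific PCATransform.
affine = td.transforms.AffineTransform(loc=torch.tensor([1.0, -1.0, 0.0]),
                                       scale=torch.tensor([2.0, 0.5, 1.0]))

pushforward = td.TransformedDistribution(base, [affine])
x = pushforward.sample((5,))
print(x.shape, pushforward.log_prob(x).shape)  # (5, 3) and (5, 3)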
Example #3
def nll_laplace(x, recon):
    batch_size, num_half_chans = x.size(0), recon.size(1) // 2
    recon_mu = recon[:, 0:num_half_chans, :, :].contiguous()
    recon_logvar = recon[:, num_half_chans:, :, :].contiguous()

    nll = D.Laplace(
        recon_mu.view(batch_size, -1),
        # NOTE: the second half of the channels is used directly as the Laplace scale,
        # so it must be positive despite the "logvar" name; an earlier variant clamped
        # it via F.hardtanh(..., min_val=-4.5, max_val=0) + 1e-6.
        recon_logvar.view(batch_size, -1)).log_prob(x.view(batch_size, -1))
    return -torch.sum(nll, dim=-1)
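A hedged usage sketch for the nll_laplace above (assuming the function and its D alias are in scope): the decoder is expected to emit 2*C channels, the first half reconstruction means and the second half positive Laplace scales; the tensors here are just random placeholders.

import torch

# Hypothetical shapes: a batch of 4 RGB images, decoder output with 2*C channels.
x = torch.rand(4, 3, 32, 32)
recon = torch.cat([torch.rand(4, 3, 32, 32),         # means
                   torch.rand(4, 3, 32, 32) + 0.1],  # scales (kept positive)
                  dim=1)

loss = nll_laplace(x, recon)  # uses the function defined above
print(loss.shape)             # torch.Size([4]): one NLL per image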
Example #4
def rec_loss_fn(recon_x, x, sum_samples=True, correct=False, sumdim=(1, 2, 3)):
    if correct:
        x_dist = dist.Laplace(recon_x, 1.0)
        log_p_x_z = x_dist.log_prob(x)
        log_p_x_z = torch.sum(log_p_x_z, dim=sumdim)
    else:
        log_p_x_z = -torch.abs(recon_x - x)
        log_p_x_z = torch.mean(log_p_x_z, dim=sumdim)
    if sum_samples:
        return -torch.mean(log_p_x_z)
    else:
        return -log_p_x_z
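For scale 1.0 the Laplace log-density is -|x - mu| - log 2, so the correct=True branch differs from the fast branch only by the constant log 2 and by the choice of reduction. A quick sanity check:

import math
import torch
import torch.distributions as dist

recon_x = torch.rand(8, 1, 4, 4)
x = torch.rand(8, 1, 4, 4)

manual = -torch.abs(recon_x - x) - math.log(2.0)  # -|x - mu| - log 2
exact = dist.Laplace(recon_x, 1.0).log_prob(x)    # same thing via torch.distributions
print(torch.allclose(manual, exact))              # True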
Example #5
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.distributions as dist
import matplotlib.pyplot as plt
from scipy.sparse.linalg import LinearOperator, cg

from src.params_to_flat import params_to_flat

p = 2
n = 30
od = 1
dist_w = dist.Laplace(0, 0.1)

# number of samples
N = 1500

# create parameters
xi = nn.Parameter(torch.zeros(n, p), requires_grad=True)
old_ig = None

class MarginalNetwork(nn.Module):

    def __init__(self, inp_dim, out_dim, hidden_dim=16, batch_shape=1):

        nn.Module.__init__(self)
        self.l1 = nn.Linear(inp_dim, hidden_dim)
        self.l2 = nn.Linear(hidden_dim, hidden_dim)
        self.l_mu = nn.Linear(hidden_dim, out_dim)
Example #6
    def train(epoch):
        pace = args.pace
        for i in range(4):
            models[i].train()
            # halve the learning rate every 20 epochs
            if epoch % 20 == 0:
                for param_group1 in optimizers[i].param_groups:
                    param_group1['lr'] = 0.5 * param_group1['lr']

        #define weights
        w = dict()
        denominator = np.sum(np.array(tbs))
        for i in range(4):
            w[i] = 0.25  #tbs[i]/denominator

        loss_all = dict()
        num_data = dict()
        for i in range(4):
            loss_all[i] = 0
            num_data[i] = 0
        count = 0
        for t in range(args.nsteps):
            for i in range(4):
                optimizers[i].zero_grad()
                a, b = next(iter(train_loaders[i]))
                num_data[i] += b.size(0)
                a = a.to(device)
                b = b.to(device)
                output = models[i](a)
                loss = nnloss(output, b)
                loss.backward()
                loss_all[i] += loss.item() * b.size(0)
                optimizers[i].step()
            count += 1
            if count % pace == 0 or t == args.nsteps - 1:
                with torch.no_grad():
                    for key in model.state_dict().keys():
                        if models[0].state_dict()[key].dtype == torch.int64:
                            model.state_dict()[key].data.copy_(
                                models[0].state_dict()[key])
                        else:
                            temp = torch.zeros_like(model.state_dict()[key])
                            # add noise
                            for s in range(4):
                                if args.type == 'G':
                                    nn = tdist.Normal(
                                        torch.tensor([0.0]),
                                        args.noise *
                                        torch.std(models[s].state_dict()
                                                  [key].detach().cpu()))
                                else:
                                    nn = tdist.Laplace(
                                        torch.tensor([0.0]),
                                        args.noise *
                                        torch.std(models[s].state_dict()
                                                  [key].detach().cpu()))
                                noise = nn.sample(models[s].state_dict()
                                                  [key].size()).squeeze()
                                noise = noise.to(device)
                                temp += w[s] * (models[s].state_dict()[key] +
                                                noise)
                            # update global model
                            model.state_dict()[key].data.copy_(temp)
                            # update the local models
                            for s in range(4):
                                models[s].state_dict()[key].data.copy_(
                                    model.state_dict()[key])

        return loss_all[0] / num_data[0], loss_all[1] / num_data[1], \
               loss_all[2] / num_data[2], loss_all[3] / num_data[3]
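The aggregation step above can be read in isolation as noisy federated averaging: each site's parameters get zero-mean Gaussian or Laplace noise whose scale is proportional to that tensor's standard deviation, and the weighted sum becomes the new global tensor. A standalone sketch (noise_mult and noise_type stand in for args.noise and args.type; all names here are illustrative):

import torch
import torch.distributions as tdist

def aggregate(local_params, weights, noise_mult=0.01, noise_type='G'):
    """Weighted average of local parameter tensors with per-site noise added."""
    temp = torch.zeros_like(local_params[0])
    for w, param in zip(weights, local_params):
        std = torch.std(param.detach())
        if noise_type == 'G':
            noise_dist = tdist.Normal(torch.tensor(0.0), noise_mult * std)
        else:
            noise_dist = tdist.Laplace(torch.tensor(0.0), noise_mult * std)
        temp += w * (param + noise_dist.sample(param.size()))
    return temp

locals_ = [torch.randn(5, 5) for _ in range(4)]
global_param = aggregate(locals_, weights=[0.25] * 4)
print(global_param.shape)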
Example #7
    def train(epoch):
        pace = args.pace
        for i in range(4):
            models[i].train()
            # halve the learning rates every 20 epochs
            if epoch % 20 == 0:
                for param_group1 in optimizers[i].param_groups:
                    param_group1['lr'] = 0.5 * param_group1['lr']
                for param_group1 in optimizerGs[i].param_groups:
                    param_group1['lr'] = 0.5 * param_group1['lr']

            discriminators[i].train()
            if epoch % 20 == 0:
                for param_group1 in optimizerDs[i].param_groups:
                    param_group1['lr'] = 0.5 * param_group1['lr']

        #define weights
        w = dict()
        denominator = np.sum(np.array(tbs))
        for i in range(4):
            w[i] = 0.25  #tbs[i]/denominator

        loss_all = dict()
        lossD_all = dict()
        lossG_all = dict()
        num_data = dict()
        num_dataG = dict()
        num_dataD = dict()
        for i in range(4):
            loss_all[i] = 0
            num_data[i] = EPS
            num_dataG[i] = EPS
            lossG_all[i] = 0
            lossD_all[i] = 0
            num_dataD[i] = EPS

        count = 0
        for t in range(args.nsteps):
            fs = []

            # optimize classifier

            for i in range(4):
                optimizers[i].zero_grad()
                a, b = next(data_iters[i])
                num_data[i] += b.size(0)
                a = a.to(device)
                b = b.to(device)
                output = models[i](a)
                loss = celoss(output, b)
                loss_all[i] += loss.item() * b.size(0)
                if epoch >= 0:
                    loss.backward(retain_graph=True)
                    optimizers[i].step()

                fs.append(models[i].encoder(a))

            #optimize alignment

            nn = []
            noises = []
            for i in range(4):
                nn = tdist.Normal(torch.tensor([0.0]),
                                  0.001 * torch.std(fs[i].detach().cpu()))
                noises.append(nn.sample(fs[i].size()).squeeze().to(device))

            for i in range(4):
                for j in range(4):
                    if i != j:
                        optimizerDs[i].zero_grad()
                        optimizerGs[i].zero_grad()
                        optimizerGs[j].zero_grad()

                        d1 = discriminators[i](fs[i] + noises[i])
                        d2 = discriminators[i](fs[j] + noises[j])
                        num_dataG[i] += d1.size(0)
                        num_dataD[i] += d1.size(0)
                        lossD = advDloss(d1, d2)
                        lossG = advGloss(d1, d2)
                        lossD_all[i] += lossD.item() * d1.size(0)
                        lossG_all[i] += lossG.item() * d1.size(0)
                        lossG_all[j] += lossG.item() * d2.size(0)
                        lossD = 0.1 * lossD
                        if epoch >= 5:
                            lossD.backward(retain_graph=True)
                            optimizerDs[i].step()
                            lossG.backward(retain_graph=True)
                            optimizerGs[i].step()
                            optimizerGs[j].step()
                        writer.add_histogram(
                            'Hist/hist_' + site[i] + '2' + site[j] + '_source',
                            d1, epoch * args.nsteps + t)
                        writer.add_histogram(
                            'Hist/hist_' + site[i] + '2' + site[j] + '_target',
                            d2, epoch * args.nsteps + t)

            count += 1
            if count % pace == 0 or t == args.nsteps - 1:
                with torch.no_grad():
                    for key in model.state_dict().keys():
                        if models[0].state_dict()[key].dtype == torch.int64:
                            model.state_dict()[key].data.copy_(
                                models[0].state_dict()[key])
                        else:
                            temp = torch.zeros_like(model.state_dict()[key])
                            # add noise
                            for s in range(4):
                                if args.type == 'G':
                                    nn = tdist.Normal(
                                        torch.tensor([0.0]),
                                        args.noise *
                                        torch.std(models[s].state_dict()
                                                  [key].detach().cpu()))
                                else:
                                    nn = tdist.Laplace(
                                        torch.tensor([0.0]),
                                        args.noise *
                                        torch.std(models[s].state_dict()
                                                  [key].detach().cpu()))
                                noise = nn.sample(models[s].state_dict()
                                                  [key].size()).squeeze()
                                noise = noise.to(device)
                                temp += w[s] * (models[s].state_dict()[key] +
                                                noise)
                            # update global model
                            model.state_dict()[key].data.copy_(temp)
                            # update the local models
                            for s in range(4):
                                models[s].state_dict()[key].data.copy_(
                                    model.state_dict()[key])

        return loss_all, lossG_all, lossD_all, num_data, num_dataG, num_dataD
Example #8
 def __init__(self, dim, sigma=None, lambd=None, device='cpu'):
     super().__init__(dim, sigma, lambd, device)
     self.laplace_dist = D.Laplace(loc=torch.tensor(0.0, device=device),
                                   scale=torch.tensor(self.lambd,
                                                      device=device))
     self.linf_radii = self.linf_rho = self._linf_table_info = None
Example #9
File: vae.py  Project: pbloem/blog
def go(arg):

    tbw = SummaryWriter(log_dir=arg.tb_dir)

    ## Load the data
    if arg.task == 'mnist':
        transform = tfs.Compose([tfs.Pad(padding=2), tfs.ToTensor()])

        trainset = torchvision.datasets.MNIST(root=arg.data_dir,
                                              train=True,
                                              download=True,
                                              transform=transform)
        trainloader = torch.utils.data.DataLoader(trainset,
                                                  batch_size=arg.batch_size,
                                                  shuffle=True,
                                                  num_workers=2)

        testset = torchvision.datasets.MNIST(root=arg.data_dir,
                                             train=False,
                                             download=True,
                                             transform=transform)
        testloader = torch.utils.data.DataLoader(testset,
                                                 batch_size=arg.batch_size,
                                                 shuffle=False,
                                                 num_workers=2)
        C, H, W = 1, 32, 32

    elif arg.task == 'cifar10':
        trainset = torchvision.datasets.CIFAR10(root=arg.data_dir,
                                                train=True,
                                                download=True,
                                                transform=tfs.ToTensor())
        trainloader = torch.utils.data.DataLoader(trainset,
                                                  batch_size=arg.batch_size,
                                                  shuffle=True,
                                                  num_workers=2)

        testset = torchvision.datasets.CIFAR10(root=arg.data_dir,
                                               train=False,
                                               download=True,
                                               transform=tfs.ToTensor())
        testloader = torch.utils.data.DataLoader(testset,
                                                 batch_size=arg.batch_size,
                                                 shuffle=False,
                                                 num_workers=2)
        C, H, W = 3, 32, 32

    elif arg.task == 'cifar-gs':
        transform = tfs.Compose([tfs.Grayscale(), tfs.ToTensor()])

        trainset = torchvision.datasets.CIFAR10(root=arg.data_dir,
                                                train=True,
                                                download=True,
                                                transform=transform)
        trainloader = torch.utils.data.DataLoader(trainset,
                                                  batch_size=arg.batch_size,
                                                  shuffle=True,
                                                  num_workers=2)

        testset = torchvision.datasets.CIFAR10(root=arg.data_dir,
                                               train=False,
                                               download=True,
                                               transform=transform)
        testloader = torch.utils.data.DataLoader(testset,
                                                 batch_size=arg.batch_size,
                                                 shuffle=False,
                                                 num_workers=2)
        C, H, W = 1, 32, 32

    elif arg.task == 'imagenet64':

        transform = tfs.Compose([tfs.ToTensor()])

        trainset = torchvision.datasets.ImageFolder(root=arg.data_dir +
                                                    os.sep + 'train',
                                                    transform=transform)
        trainloader = torch.utils.data.DataLoader(trainset,
                                                  batch_size=arg.batch_size,
                                                  shuffle=True,
                                                  num_workers=2)

        testset = torchvision.datasets.ImageFolder(root=arg.data_dir + os.sep +
                                                   'valid',
                                                   transform=transform)
        testloader = torch.utils.data.DataLoader(testset,
                                                 batch_size=arg.batch_size,
                                                 shuffle=False,
                                                 num_workers=2)

        C, H, W = 3, 64, 64

    else:
        raise Exception('Task {} not recognized.'.format(arg.task))

    ## Set up the model
    out_channels = C
    if (arg.rloss == 'gauss' or arg.rloss == 'laplace'
            or arg.rloss == 'signorm' or arg.rloss == 'siglaplace'
            or arg.rloss == 'beta') and arg.scale is None:
        out_channels = 2 * C

    print(f'out channels: {out_channels}')

    encoder = Encoder(zsize=arg.zsize, colors=C)
    decoder = Decoder(zsize=arg.zsize,
                      out_channels=out_channels,
                      mult=arg.mult)

    if arg.testmodel:
        decoder = Test(out_channels=out_channels, height=H, width=W)

    if torch.cuda.is_available():
        encoder.cuda()
        decoder.cuda()

    opt = torch.optim.Adam(lr=arg.lr,
                           params=list(encoder.parameters()) +
                           list(decoder.parameters()))

    if arg.esched is not None:
        start, end = int(arg.esched[0] * arg.epochs), (arg.esched[1] *
                                                       arg.epochs)
        slope = 1.0 / (end - start)

    for epoch in range(arg.epochs):

        if arg.esched is not None:
            weight = (epoch - start) * slope
            weight = np.clip(weight, 0, 1)
        else:
            weight = 1.0

        for i, (input, _) in enumerate(tqdm.tqdm(trainloader)):
            if arg.limit is not None and i * arg.batch_size > arg.limit:
                break

            # Prepare the input
            b, c, h, w = input.size()
            if torch.cuda.is_available():
                input = input.cuda()

            # Forward pass
            if not arg.testmodel:
                zs = encoder(input)

                kloss = kl_loss(zs[:, :arg.zsize], zs[:, arg.zsize:])
                z = sample(zs[:, :arg.zsize], zs[:, arg.zsize:])

                out = decoder(z)
            else:
                out = decoder(input)
                kloss = 0

            # compute -log p per dimension
            if arg.rloss == 'xent':  # binary cross-entropy (not a proper log-prob)

                rloss = F.binary_cross_entropy_with_logits(out,
                                                           input,
                                                           reduction='none')

            elif arg.rloss == 'bdist':  #   xent + correction
                rloss = F.binary_cross_entropy_with_logits(out,
                                                           input,
                                                           reduction='none')

                za = out.abs()
                eza = (-za).exp()

                # - np.log(za) + np.log1p(-eza + EPS) - np.log1p(eza + EPS)
                logpart = -(za + arg.eps).log() + (-eza + arg.eps).log1p() - (
                    eza + arg.eps).log1p()

                rloss = rloss + weight * logpart

            elif arg.rloss == 'gauss':  # Gaussian negative log-likelihood
                if arg.scale is None:
                    means = T.sigmoid(out[:, :c, :, :])
                    vars = F.sigmoid(out[:, c:, :, :])

                    rloss = GAUSS_CONST + vars.log() + (
                        1.0 / (2.0 * vars.pow(2.0))) * (input - means).pow(2.0)
                else:
                    means = T.sigmoid(out[:, :c, :, :])
                    var = arg.scale

                    rloss = GAUSS_CONST + ln(
                        var) + (1.0 / (2.0 *
                                       (var * var))) * (input - means).pow(2.0)

            elif arg.rloss == 'mse':
                means = T.sigmoid(out[:, :c, :, :])
                rloss = (input - means).pow(2.0)

            elif arg.rloss == 'mae':
                means = T.sigmoid(out[:, :c, :, :])
                rloss = (input - means).abs()

            elif arg.rloss == 'laplace':  # Laplace negative log-likelihood
                if arg.scale is None:
                    means = T.sigmoid(out[:, :c, :, :])
                    vars = F.softplus(out[:, c:, :, :])

                    rloss = (2.0 * vars).log() + (1.0 / vars) * (input -
                                                                 means).abs()
                else:
                    means = T.sigmoid(out[:, :c, :, :])
                    var = arg.scale

                    rloss = ln(2.0 * var) + (1.0 / var) * (input - means).abs()

            elif arg.rloss == 'signorm':
                if arg.scale is None:

                    mus = out[:, :c, :, :]
                    sgs, lsgs = T.exp(
                        out[:, c:, :, :] *
                        arg.varmult), out[:, c:, :, :] * arg.varmult

                else:
                    mus = out[:, :c, :, :]
                    sgs, lsgs = arg.scale, math.log(arg.scale)

                y = input

                lny = torch.log(y + arg.eps)
                ln1y = torch.log(1 - y + arg.eps)

                x = lny - ln1y

                rloss = lny + ln1y + lsgs + GAUSS_CONST + \
                        0.5 * (1.0 / (sgs * sgs + arg.eps)) * (x - mus) ** 2

            elif arg.rloss == 'siglaplace':

                if arg.scale is None:

                    mus = out[:, :c, :, :]
                    sgs, lsgs = T.exp(
                        out[:, c:, :, :] *
                        arg.varmult), out[:, c:, :, :] * arg.varmult

                else:
                    mus = out[:, :c, :, :]
                    sgs, lsgs = arg.scale, math.log(arg.scale)

                y = input

                lny = torch.log(y + arg.eps)
                ln1y = torch.log(1 - y + arg.eps)

                x = lny - ln1y

                rloss = lny + ln1y + lsgs + math.log(2.0) + \
                        (x - mus).abs() / sgs

            elif arg.rloss == 'beta':

                mean = T.sigmoid(out[:, :c, :, :])
                mult = F.softplus(out[:, c:, :, :] +
                                  arg.beta_add) + (1.0 /
                                                   (mean + arg.eps)) + arg.eps

                alpha = mean * mult
                beta = (1 - mean) * mult

                part = alpha.lgamma() + beta.lgamma() - (alpha + beta).lgamma()
                x = input

                rloss = -(alpha - 1) * (x + arg.eps).log() - (beta - 1) * (
                    1 - x + arg.eps).log() + part

            else:
                raise Exception(
                    f'reconstruction loss {arg.rloss} not recognized.')

            if contains_nan(rloss):
                if arg.rloss == 'beta':
                    print('part contains nan', contains_nan(part))

                    print('alpha contains nan', contains_nan(alpha))
                    print('beta  contains nan', contains_nan(beta))

                    print('log x contains nan',
                          contains_nan((x + arg.eps).log()))
                    print('log (1-x)  contains nan',
                          contains_nan((1 - x + arg.eps).log()))

                raise Exception('rloss contains nan')

            rloss = rloss.reshape(b, -1).sum(dim=1)  # reduce
            loss = (rloss + kloss).mean()

            opt.zero_grad()
            loss.backward()

            opt.step()

        with torch.no_grad():
            N = 5

            # Plot reconstructions

            inputs, _ = next(iter(testloader))

            if torch.cuda.is_available():
                inputs = inputs.cuda()

            b, c, h, w = inputs.size()

            if not arg.testmodel:
                zs = encoder(inputs)
                res = decoder(zs[:, :arg.zsize])
            else:
                res = decoder(inputs)

            outputs = res[:, :c, :, :]
            means = T.sigmoid(outputs)

            samples = None

            if arg.rloss == 'signorm' and out_channels > c:
                means = res[:, :c, :, :]
                vars = res[:, c:, :, :] * arg.varmult

                dist = ds.Normal(means, vars)
                samples = T.sigmoid(dist.sample())
                means = T.sigmoid(dist.mean)

            if arg.rloss == 'siglaplace' and out_channels > c:
                means = res[:, :c, :, :]
                vars = res[:, c:, :, :] * arg.varmult

                dist = ds.Laplace(means, vars)
                samples = T.sigmoid(dist.sample())
                means = T.sigmoid(dist.mean)

            if arg.rloss == 'beta':

                mean = T.sigmoid(res[:, :c, :, :])
                mult = (res[:, c:, :, :] +
                        arg.beta_add).exp() + (1.0 / mean) + arg.eps

                alpha = mean * mult
                beta = (1 - mean) * mult

                dist = ds.Beta(alpha, beta)
                samples = dist.sample()
                means = dist.mean
                vars = dist.variance

            plt.figure(figsize=(5, 4))

            for i in range(N):

                ax = plt.subplot(4, N, i + 1)
                inp = inputs[i].permute(1, 2, 0).cpu().numpy()
                if c == 1:
                    inp = inp.squeeze()

                ax.imshow(inp, cmap='gray_r')

                if i == 0:
                    ax.set_title('input')
                plt.axis('off')

                ax = plt.subplot(4, N, N + i + 1)

                outp = means[i].permute(1, 2, 0).cpu().numpy()
                if c == 1:
                    outp = outp.squeeze()

                ax.imshow(outp, cmap='gray_r')

                if i == 0:
                    ax.set_title('means/modes')
                plt.axis('off')

                if samples is not None:  # plot samples

                    ax = plt.subplot(4, N, 2 * N + i + 1)

                    outp = samples[i].permute(1, 2, 0).detach().cpu().numpy()
                    if c == 1:
                        outp = outp.squeeze()

                    ax.imshow(outp, cmap='gray_r')

                    if i == 0:
                        ax.set_title('sampled')
                    plt.axis('off')

                if out_channels > c:  # plot the variance (or other uncertainty)

                    ax = plt.subplot(4, N, 3 * N + i + 1)

                    outp = vars[i].permute(1, 2, 0).detach().cpu().numpy()
                    if c == 1:
                        outp = outp.squeeze()

                    ax.imshow(outp, cmap='copper')

                    if i == 0:
                        ax.set_title('var')
                    plt.axis('off')

            plt.tight_layout()
            plt.savefig(f'reconstruction.{arg.rloss}.{epoch:03}.png')

            if arg.zsize == 2:  # latent space plot

                N = 2000
                # gather enough batches for roughly N images into one big tensor
                numbatches = N // arg.batch_size
                images, labels = [], []
                for i, (ims, lbs) in enumerate(testloader):
                    images.append(ims)
                    labels.append(lbs)

                    if i > numbatches:
                        break

                images, labels = torch.cat(images, dim=0), torch.cat(labels,
                                                                     dim=0)

                imagesg = images
                if torch.cuda.is_available():
                    imagesg = imagesg.cuda()

                n, c, h, w = images.size()

                z = encoder(imagesg)
                latents = z[:, :2].data.detach().cpu()

                mn, mx = latents.min(), latents.max()
                size = 1.0 * (mx - mn) / math.sqrt(n)
                # Change the 1.0 factor to any value between ~0.5 and 1.5 to make the digits smaller or bigger

                fig = plt.figure(figsize=(8, 8))

                # colormap for the images
                norm = mpl.colors.Normalize(vmin=0, vmax=9)
                cmap = mpl.cm.get_cmap('tab10')

                for i in range(n):
                    x, y = latents[i, 0:2]
                    l = labels[i]

                    im = images[i, :]
                    alpha_im = im.permute(1, 2, 0).detach().cpu().numpy()
                    color = cmap(norm(l))
                    color_im = np.asarray(color)[None, None, :3]
                    color_im = np.broadcast_to(color_im, (h, w, 3))
                    # -- To make the digits transparent we make them solid color images and use the
                    #    actual data as an alpha channel.
                    #    color_im: 3-channel color image, with solid color corresponding to class
                    #    alpha_im: 1-channel grayscale image corresponding to input data

                    im = np.concatenate([color_im, alpha_im], axis=2)
                    plt.imshow(im, extent=(x, x + size, y, y + size))

                    plt.xlim(mn, mx)
                    plt.ylim(mn, mx)

                plt.savefig(f'latent.{arg.rloss}.{epoch:03}.png')
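The 'laplace' reconstruction branch above computes the per-pixel negative log-likelihood by hand as log(2b) + |x - mu| / b. A small check, assuming a fixed scale as in the arg.scale path, that this matches torch.distributions.Laplace:

import torch
import torch.distributions as ds

mu = torch.rand(2, 1, 8, 8)   # sigmoid-squashed decoder means
x = torch.rand(2, 1, 8, 8)    # target pixels
scale = 0.1                   # fixed scale, as with arg.scale

manual = torch.log(torch.tensor(2.0 * scale)) + (x - mu).abs() / scale
exact = -ds.Laplace(mu, scale).log_prob(x)
print(torch.allclose(manual, exact))  # True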
Example #10
 def __init__(self, mu, std):
     super(Laplace, self).__init__()
     self.dist = distributions.Laplace(mu, std)
Example #11
 def sample(self, batch_size):
     return dists.Laplace(self.loc, self.scale).rsample((batch_size, ))
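Unlike .sample(), .rsample() draws reparameterized samples, so gradients flow back into loc and scale; a minimal sketch with hypothetical parameters:

import torch
import torch.distributions as dists

loc = torch.zeros(3, requires_grad=True)
scale = torch.ones(3, requires_grad=True)

samples = dists.Laplace(loc, scale).rsample((4,))  # shape (4, 3), differentiable
samples.sum().backward()
print(loc.grad, scale.grad)  # both populated thanks to the reparameterized draw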