def _d_kl_beta(p, q):
    alpha_p, beta_p = convert_parameters_beta(p)
    alpha_q, beta_q = convert_parameters_beta(q)
    dist_p = dist.Beta(alpha_p, beta_p)
    dist_q = dist.Beta(alpha_q, beta_q)
    d_kl = dist.kl_divergence(dist_p, dist_q).mean(-1)
    return d_kl
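
For reference, the same per-example KL can be computed directly from concentration tensors; a minimal sketch with made-up shapes, assuming dist is torch.distributions (convert_parameters_beta from the example above is not reproduced here):

import torch
import torch.distributions as dist

# hypothetical (batch, K) concentration tensors
alpha_p, beta_p = torch.rand(8, 4) + 0.5, torch.rand(8, 4) + 0.5
alpha_q, beta_q = torch.rand(8, 4) + 0.5, torch.rand(8, 4) + 0.5
d_kl = dist.kl_divergence(dist.Beta(alpha_p, beta_p),
                          dist.Beta(alpha_q, beta_q)).mean(-1)  # shape (8,)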
Example #2
    def forward(self, x):
        """
        This function takes a batch of data `x` and returns:
        - nll: -\E_q[log p(x | z, A)]
        - q_z: a torch Bernoulli Distribution for q(z)
        - p_z: a torch Bernoulli Distribution for p(z | nu) *where nu ~ q(nu)* [because it's for the KL divergence]
        - q_nu: a torch Beta Distribution for q(nu)
        - p_nu: a torch Beta Distribution for p(nu)
        - q_a: a torch Normal Distribution (univariate / diagonal) for q(A)
        - p_a: a torch Normal Distribution for p(A)

        The negative ELBO can be computed as:
        -ELBO = nll + KL(q_z || p_z) + KL(q_nu || p_nu)
        """
        batch_sz = x.size()[0]

        # p(nu)
        sz = self.beta_a.size()
        p_nu = distributions.Beta(torch.ones(sz) * self.alpha0, torch.ones(sz))

        # compute q(nu) parameters, and take samples
        beta_a = F.softplus(self.beta_a) + 0.01
        beta_b = F.softplus(self.beta_b) + 0.01
        q_nu = distributions.Beta(beta_a, beta_b)

        nu = q_nu.rsample()  # NOTE: differentiable sample! via Knowles et al.

        # p(z | nu)
        logpi = torch.cumsum((nu + SMALL).log(),
                             dim=-1).unsqueeze(0).repeat(batch_sz, 1)
        p_z = distributions.Bernoulli(probs=logpi.exp())

        # q(z)
        # machine/fp precision is higher near 0 than at 1 (crucial)
        probs = F.sigmoid(torch.clamp(self.encoder(x.view(-1, self.D)), -25,
                                      9))

        q_z = shared.STRelaxedBernoulli(temperature=0.1, probs=probs)
        # q_z = distributions.RelaxedBernoulli(temperature=0.2, probs=probs)
        z = q_z.rsample()
        q_z = distributions.Bernoulli(probs=probs)
        # self.z_log_prob = q_z.log_prob(z)  # save for later

        # p(A)
        p_a = distributions.Normal(loc=0,
                                   scale=1)  # NOTE: this is broadcast up

        # q(A) - this is wrong, it normalizes the wrong thing
        q_a = distributions.Normal(loc=self.A_mean,
                                   scale=(self.A_logvar / 2).exp())

        A = self.A_mean
        # A = q_a.rsample()

        # now compute NLL:
        x_mean = torch.mm(z, A)
        nll = -(distributions.Normal(loc=x_mean,
                                     scale=self.sigma_n).log_prob(x))

        return nll, p_nu, q_nu, p_z, q_z, p_a, q_a
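
The returned pieces can be combined into the negative ELBO described in the docstring. A minimal sketch (not part of the original code), assuming model is an instance of the module above and distributions is torch.distributions:

def negative_elbo(model, x):
    # -ELBO = nll + KL(q_z || p_z) + KL(q_nu || p_nu), as stated in the docstring
    nll, p_nu, q_nu, p_z, q_z, p_a, q_a = model(x)
    kl_z = distributions.kl_divergence(q_z, p_z).sum(-1)   # per example, summed over latents
    kl_nu = distributions.kl_divergence(q_nu, p_nu).sum()  # nu is shared across the batch
    return nll.sum(-1) + kl_z + kl_nu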
Example #3
    def kl(self, dist_a, prior=None):
        if prior is None:  # use standard reparameterizer
            return self._kld_beta_kerman_prior(
                dist_a['beta']['conc1'], dist_a['beta']['conc2']
            )

        # we have two distributions provided (eg: VRNN)
        return torch.sum(D.kl_divergence(
            D.Beta(dist_a['beta']['conc1'], dist_a['beta']['conc2']),
            D.Beta(prior['beta']['conc1'], prior['beta']['conc2'])
        ), -1)
Example #4
    def mutual_info(self, params, eps=1e-9):
        """ I(z_d; x) ~ H(z_prior, z_d) + H(z_prior)

        :param params: parameters of distribution
        :param eps: tolerance
        :returns: batch_size mutual information (prop-to) tensor.
        :rtype: torch.Tensor

        """
        z_true = D.Beta(params['beta']['conc1'],
                        params['beta']['conc2'])
        z_match = D.Beta(params['q_z_given_xhat']['beta']['conc1'],
                         params['q_z_given_xhat']['beta']['conc2'])
        kl_proxy_to_xent = torch.sum(D.kl_divergence(z_match, z_true), dim=-1)
        return self.config['continuous_mut_info'] * kl_proxy_to_xent
Example #5
    def _reparametrize_beta(self, conc1, conc2, force=False):
        """ Internal function to reparameterize beta distribution using concentrations.

        :param conc1: concentration 1
        :param conc2: concentration 2
        :param force: if True, use a reparameterized sample even when not in training mode
        :returns: reparameterized sample, distribution params
        :rtype: torch.Tensor, dict

        """
        if self.training or force:
            beta = D.Beta(conc1, conc2).rsample()
            return beta, {'conc1': conc1, 'conc2': conc2}

        # can't use mean like in gaussian because beta mean can be > 1.0
        return D.Beta(conc1, conc2).sample(), {'conc1': conc1, 'conc2': conc2}
Example #6
    def _kld_beta_kerman_prior(self, conc1, conc2):
        """ Internal function to do a KL-div against the prior.

        :param conc1: concentration 1.
        :param conc2: concentration 2.
        :returns: batch_size tensor of kld against prior.
        :rtype: torch.Tensor

        """
        # prior = D.Beta(zeros_like(conc1) + 1/3,
        #                zeros_like(conc2) + 1/3)
        prior = D.Beta(zeros_like(conc1) + 1.1,
                       zeros_like(conc2) + 1.1)
        beta = D.Beta(conc1, conc2)
        return torch.sum(D.kl_divergence(beta, prior), -1)
Example #7
 def policy_to_action(self, alpha, beta):
     # alpha and beta must be non-negative float
     eps = 1e-6  # to avoid inf and nan
     p = dist.Beta(alpha + eps, beta + eps)
     action = p.sample()
     log_prob = p.log_prob(action)
     return action, log_prob
Example #8
 def _remix(x: List[torch.Tensor],
            remix_alpha: float) -> List[torch.Tensor]:
     # Create a random permutation index over the mini-batch dimension.
     idx = torch.randperm(x[0].shape[0])
     # Sample one mixing coefficient per example from Beta(remix_alpha + 1, remix_alpha).
     mix = dist.Beta(remix_alpha + 1, remix_alpha).sample((x[0].shape[0],))
     x = [(mix * t) + ((1 - mix) * t[idx]) for t in x]
     return x
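
A hedged usage sketch with 1-D per-example tensors (the enclosing class, if any, is not shown; dist is assumed to be torch.distributions, imported alongside torch):

batch = [torch.rand(16), torch.rand(16)]  # two per-example quantities of shape (batch,)
mixed = _remix(batch, remix_alpha=0.3)    # each tensor mixed with a permuted copy of itself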
Example #9
def stick_breaking(alpha0, k):
    """ This function breaks a stick into k pieces """
    betas = dist.Beta(torch.tensor([1.]),
                      torch.tensor([alpha0])).sample([k]).squeeze()
    remains = torch.cat(
        (torch.tensor([1.]), torch.cumprod(1 - betas[:-1], dim=0)), 0)
    p = betas * remains
    p /= p.sum()
    return p
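
A quick usage sketch (assuming torch and torch.distributions as dist are imported as above); the returned weights form a probability vector and can serve, for example, as mixture weights:

weights = stick_breaking(alpha0=1.0, k=10)
assert torch.isclose(weights.sum(), torch.tensor(1.0))
component = dist.Categorical(probs=weights).sample()  # draw a component index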
Example #10
    def forward(self, x):

        batch_sz = x.size()[0]
        sz = self.q_pi_a.size()

        p_pi = distributions.Beta(
            torch.ones(sz) * self.p_pi_alpha,
            torch.ones(sz) * self.p_pi_beta)

        beta_a = F.softplus(self.q_pi_alpha) + 0.01
        beta_b = F.softplus(self.q_pi_beta) + 0.01
        q_pi = distributions.Beta(beta_a, beta_b)

        # Differentiable Sample Knowles et al.
        qpi_sample = q_pi.rsample()
        q_z = shared.STRelaxedBernoulli(temperature=0.1, probs=qpi_sample)
        z = q_z.rsample()
        q_z = distributions.Bernoulli(probs=qpi_sample)

        q_phi = distributions.Normal(loc=self.phi_mean,
                                     scale=(self.phi_logvar / 2).exp())
        q_w = distributions.Normal(loc=self.w_mean,
                                   scale=(self.w_logvar / 2).exp())

        # For now, just take the mean
        phi = q_phi.mean
        w = q_w.mean

        # Alternatively, sample
        # phi = q_phi.rsample()
        # w = q_w.rsample()

        # NLL
        sinbasis = torch.ones(K, N_SAMPLES) * torch.arange(0, N_SAMPLES, 1)

        for k in range(K):
            sinbasis[k] = torch.sin(sinbasis[k] * phi[k])

        x_mean = torch.mm(torch.mul(z, w),
                          sinbasis)  # z and w multiplied elementwise
        nll = -(distributions.Normal(loc=x_mean,
                                     scale=self.sigma_n).log_prob(x))
        return nll, p_pi, q_pi, q_z, q_phi, q_w, sinbasis
Example #11
    def log_likelihood(self, z, params):
        """ Log-likelihood of z induced under params.

        :param z: inferred latent z
        :param params: the params of the distribution
        :returns: log-likelihood
        :rtype: torch.Tensor

        """
        return D.Beta(params['beta']['conc1'],
                      params['beta']['conc2']).log_prob(z)
Example #12
    def sample_params(self, n_sample=torch.Size([])):
        clusters = self.cluster_distr.rsample(n_sample)
        params = self.cluster_to_params_graph(clusters)

        alpha_hsl0 = F.softplus(params[0:3])
        beta_hsl0 = F.softplus(params[3:6])
        hsl0 = td.Beta(alpha_hsl0, beta_hsl0).rsample()

        alpha_hsl1 = F.softplus(params[6:9])
        beta_hsl1 = F.softplus(params[9:12])
        hsl1 = td.Beta(alpha_hsl1, beta_hsl1).rsample()

        shape_trans01 = F.softplus(params[12:15])
        scale_trans01 = F.softplus(params[15:18])
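        # Note: torch.distributions.Gamma takes (concentration, rate), so the
        # softplus'd "scale_*" tensors below are used as rate parameters.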
        trans01 = td.Gamma(shape_trans01, scale_trans01).rsample()

        shape_trans10 = F.softplus(params[18:21])
        scale_trans10 = F.softplus(params[21:24])
        trans10 = td.Gamma(shape_trans10, scale_trans10).rsample()

        return hsl0, hsl1, trans01, trans10
Example #13
    def __init__(self, in_features: int, out_channels: int, num_repetitions: int = 1, dropout=0.0):
        """Creat a beta layer.

        Args:
            out_channels: Number of parallel representations for each input feature.
            in_features: Number of input features.
            num_repetitions: Number of parallel repetitions of this layer.

        """
        super().__init__(in_features, out_channels, num_repetitions, dropout)

        # Create beta parameters
        self.concentration0 = nn.Parameter(torch.rand(1, in_features, out_channels, num_repetitions))
        self.concentration1 = nn.Parameter(torch.rand(1, in_features, out_channels, num_repetitions))
        self.beta = dist.Beta(concentration0=self.concentration0, concentration1=self.concentration1)
Example #14
    def __init__(self, multiplicity, in_features, dropout=0.0):
        """Creat a beta layer.

        Args:
            multiplicity: Number of parallel representations for each input feature.
            in_features: Number of input features.

        """
        super().__init__(multiplicity, in_features, dropout)

        # Create beta parameters
        self.concentration0 = nn.Parameter(torch.rand(1, in_features, multiplicity))
        self.concentration1 = nn.Parameter(torch.rand(1, in_features, multiplicity))
        self.beta = dist.Beta(
            concentration0=self.concentration0, concentration1=self.concentration1
        )
Example #15
    def with_beta_dist(
        cls: type[RandomMixUp],
        alpha: float = 0.2,
        *,
        beta: float | None = None,
        mode: MixUpMode | str = MixUpMode.linear,
        p: float = 1.0,
        num_classes: int | None = None,
        inplace: bool = False,
        featurewise: bool = False,
    ) -> RandomMixUp[td.Beta]:
        """
        Instantiate a :class:`RandomMixUp` with a Beta-distribution sampler.

        :param alpha: 1st concentration parameter of the distribution. Must be positive.
        :param beta:  2nd concentration parameter of the distribution. Must be positive.
            If ``None``, then the parameter will be set to ``alpha``.

        :param mode: Which mode to use to mix up samples: geometric or linear.

        .. note::
            The (weighted) geometric mean, enabled by ``mode=geometric``, is only valid for positive
            inputs.

        :param p: The probability with which the transform will be applied to a given sample.
        :param num_classes: The total number of classes in the dataset; this must be specified when
            mixing up targets that are label-encoded. Passing label-encoded targets without
            specifying ``num_classes`` will result in a RuntimeError.
        :param featurewise: Whether to sample feature-wise instead of sample-wise.
        :param inplace: Whether the transform should be performed in-place.
        :return: A :class:`RandomMixUp` instance with ``lambda_sampler`` set to a Beta-distribution
            with ``concentration0=alpha`` and ``concentration1=beta``.
        """
        beta = alpha if beta is None else beta
        lambda_sampler = td.Beta(concentration0=alpha, concentration1=beta)
        return cls(
            lambda_sampler=lambda_sampler,
            mode=mode,
            p=p,
            num_classes=num_classes,
            inplace=inplace,
            featurewise=featurewise,
        )
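
For reference, the sampler this classmethod constructs, shown standalone (a sketch; the rest of RandomMixUp is not reproduced here):

import torch.distributions as td

lambda_sampler = td.Beta(concentration0=0.2, concentration1=0.2)  # alpha = beta = 0.2
lam = lambda_sampler.sample()  # mixing coefficient in (0, 1)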
Example #16
    def __init__(
        self,
        alpha: float = 1.0,
        *,
        p: float = 0.5,
        num_classes: int | None = None,
        inplace: bool = False,
        seed: Optional[int] = None,
    ) -> None:
        """
        :param alpha: hyperparameter of the Beta distribution used for sampling the areas
            of the bounding boxes.

        :param num_classes: The total number of classes in the dataset; this must be specified when
            mixing up targets that are label-encoded. Passing label-encoded targets without
            specifying ``num_classes`` will result in a RuntimeError.

        :param p: The probability with which the transform will be applied to a given sample.

        :param inplace: Whether the transform should be performed in-place.

        :param seed: The PRNG seed to use for sampling pairs and bounding-box coordinates.

        :raises ValueError: if ``p`` is not in the range [0, 1], if ``num_classes < 1``, or if
            ``alpha`` is not a positive real number.

        """
        super().__init__()
        if not 0 <= p <= 1:
            raise ValueError("'p' must be in the range [0, 1].")
        self.p = p
        if alpha <= 0:
            raise ValueError("'alpha' must be positive.")
        self.alpha = alpha
        if (num_classes is not None) and num_classes < 1:
            raise ValueError(f"'num_classes' must be at least 1, got {num_classes}.")
        self.lambda_sampler = td.Beta(concentration0=alpha,
                                      concentration1=alpha)
        self.num_classes = num_classes
        self.inplace = inplace
        self.seed = seed
Example #17
    def prior(self, batch_size, **kwargs):
        """ Returns a Kerman beta prior.

        Kerman, J. (2011). Neutral noninformative and informative
        conjugate beta and gamma prior distributions. Electronic
        Journal of Statistics, 5, 1450-1470.

        :param batch_size: the number of prior samples
        :returns: prior
        :rtype: torch.Tensor

        """
        conc1 = Variable(
            same_type(self.config['half'], self.config['cuda'])(
                batch_size, self.output_size
            ).zero_() + 1/3
        )
        conc2 = Variable(
            same_type(self.config['half'], self.config['cuda'])(
                batch_size, self.output_size
            ).zero_() + 1/3
        )
        return D.Beta(conc1, conc2).sample()
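
For reference, a minimal sketch of the same Kerman Beta(1/3, 1/3) prior with plain torch tensors instead of the half/cuda type helper (sizes are made up):

import torch
import torch.distributions as D

batch_size, output_size = 16, 8
conc = torch.full((batch_size, output_size), 1.0 / 3.0)
prior_sample = D.Beta(conc, conc).sample()  # one draw from Beta(1/3, 1/3) per latent dimension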
Example #18
 def act(self, x):
     VARIANCE = 0.25
     # Set a 4D shape
     x = x.view(1, 1, self.IMAGE_SIZE, self.IMAGE_SIZE)
     # First get the hidden representation
     mu, log_sigma = self.encode(x)
     # Compute alpha and beta for the beta distribution
     z = F.relu(self.linear1(mu))
     if self.action_dist == 'beta':
         alpha = F.softplus(self.linear2(z)) + 1
         beta = F.softplus(self.linear3(z)) + 1
         # Sample the beta distribution
         a_dist = dist.Beta(alpha, beta)
         actions = a_dist.sample()[0]
         log_proba = torch.sum(a_dist.log_prob(actions))
         # Now move the 3 beta samples in the action space
         # Note: only the first action, steer, must be rescaled
         actions[0] = actions[0] * 2 - 1
     elif self.action_dist == 'gaussian':
         raise Exception('TODO')
     else:
         raise Exception('Unrecognized action distribution.')
     return actions.numpy(), log_proba
Example #19
def go(arg):

    tbw = SummaryWriter(log_dir=arg.tb_dir)

    ## Load the data
    if arg.task == 'mnist':
        transform = tfs.Compose([tfs.Pad(padding=2), tfs.ToTensor()])

        trainset = torchvision.datasets.MNIST(root=arg.data_dir,
                                              train=True,
                                              download=True,
                                              transform=transform)
        trainloader = torch.utils.data.DataLoader(trainset,
                                                  batch_size=arg.batch_size,
                                                  shuffle=True,
                                                  num_workers=2)

        testset = torchvision.datasets.MNIST(root=arg.data_dir,
                                             train=False,
                                             download=True,
                                             transform=transform)
        testloader = torch.utils.data.DataLoader(testset,
                                                 batch_size=arg.batch_size,
                                                 shuffle=False,
                                                 num_workers=2)
        C, H, W = 1, 32, 32

    elif arg.task == 'cifar10':
        trainset = torchvision.datasets.CIFAR10(root=arg.data_dir,
                                                train=True,
                                                download=True,
                                                transform=tfs.ToTensor())
        trainloader = torch.utils.data.DataLoader(trainset,
                                                  batch_size=arg.batch_size,
                                                  shuffle=True,
                                                  num_workers=2)

        testset = torchvision.datasets.CIFAR10(root=arg.data_dir,
                                               train=False,
                                               download=True,
                                               transform=tfs.ToTensor())
        testloader = torch.utils.data.DataLoader(testset,
                                                 batch_size=arg.batch_size,
                                                 shuffle=False,
                                                 num_workers=2)
        C, H, W = 3, 32, 32

    elif arg.task == 'cifar-gs':
        transform = tfs.Compose([tfs.Grayscale(), tfs.ToTensor()])

        trainset = torchvision.datasets.CIFAR10(root=arg.data_dir,
                                                train=True,
                                                download=True,
                                                transform=transform)
        trainloader = torch.utils.data.DataLoader(trainset,
                                                  batch_size=arg.batch_size,
                                                  shuffle=True,
                                                  num_workers=2)

        testset = torchvision.datasets.CIFAR10(root=arg.data_dir,
                                               train=False,
                                               download=True,
                                               transform=transform)
        testloader = torch.utils.data.DataLoader(testset,
                                                 batch_size=arg.batch_size,
                                                 shuffle=False,
                                                 num_workers=2)
        C, H, W = 1, 32, 32

    elif arg.task == 'imagenet64':

        transform = tfs.Compose([tfs.ToTensor()])

        trainset = torchvision.datasets.ImageFolder(root=arg.data_dir +
                                                    os.sep + 'train',
                                                    transform=transform)
        trainloader = torch.utils.data.DataLoader(trainset,
                                                  batch_size=arg.batch_size,
                                                  shuffle=True,
                                                  num_workers=2)

        testset = torchvision.datasets.ImageFolder(root=arg.data_dir + os.sep +
                                                   'valid',
                                                   transform=transform)
        testloader = torch.utils.data.DataLoader(testset,
                                                 batch_size=arg.batch_size,
                                                 shuffle=False,
                                                 num_workers=2)

        C, H, W = 3, 64, 64

    else:
        raise Exception('Task {} not recognized.'.format(arg.task))

    ## Set up the model
    out_channels = C
    if (arg.rloss == 'gauss' or arg.rloss == 'laplace'
            or arg.rloss == 'signorm' or arg.rloss == 'siglaplace'
            or arg.rloss == 'beta') and arg.scale is None:
        out_channels = 2 * C

    print(f'out channels: {out_channels}')

    encoder = Encoder(zsize=arg.zsize, colors=C)
    decoder = Decoder(zsize=arg.zsize,
                      out_channels=out_channels,
                      mult=arg.mult)

    if arg.testmodel:
        decoder = Test(out_channels=out_channels, height=H, width=W)

    if torch.cuda.is_available():
        encoder.cuda()
        decoder.cuda()

    opt = torch.optim.Adam(lr=arg.lr,
                           params=list(encoder.parameters()) +
                           list(decoder.parameters()))

    if arg.esched is not None:
        start, end = int(arg.esched[0] * arg.epochs), (arg.esched[1] *
                                                       arg.epochs)
        slope = 1.0 / (end - start)

    for epoch in range(arg.epochs):

        if arg.esched is not None:
            weight = (epoch - start) * slope
            weight = np.clip(weight, 0, 1)
        else:
            weight = 1.0

        for i, (input, _) in enumerate(tqdm.tqdm(trainloader)):
            if arg.limit is not None and i * arg.batch_size > arg.limit:
                break

            # Prepare the input
            b, c, w, h = input.size()
            if torch.cuda.is_available():
                input = input.cuda()

            # Forward pass
            if not arg.testmodel:
                zs = encoder(input)

                kloss = kl_loss(zs[:, :arg.zsize], zs[:, arg.zsize:])
                z = sample(zs[:, :arg.zsize], zs[:, arg.zsize:])

                out = decoder(z)
            else:
                out = decoder(input)
                kloss = 0

            # compute -log p per dimension
            if arg.rloss == 'xent':  # binary cross-entropy (not a proper log-prob)

                rloss = F.binary_cross_entropy_with_logits(out,
                                                           input,
                                                           reduction='none')

            elif arg.rloss == 'bdist':  #   xent + correction
                rloss = F.binary_cross_entropy_with_logits(out,
                                                           input,
                                                           reduction='none')

                za = out.abs()
                eza = (-za).exp()

                # - np.log(za) + np.log1p(-eza + EPS) - np.log1p(eza + EPS)
                logpart = -(za + arg.eps).log() + (-eza + arg.eps).log1p() - (
                    eza + arg.eps).log1p()

                rloss = rloss + weight * logpart

            elif arg.rloss == 'gauss':  # Gaussian negative log-likelihood
                if arg.scale is None:
                    means = T.sigmoid(out[:, :c, :, :])
                    vars = T.sigmoid(out[:, c:, :, :])

                    rloss = GAUSS_CONST + vars.log() + (
                        1.0 / (2.0 * vars.pow(2.0))) * (input - means).pow(2.0)
                else:
                    means = T.sigmoid(out[:, :c, :, :])
                    var = arg.scale

                    rloss = GAUSS_CONST + ln(
                        var) + (1.0 / (2.0 *
                                       (var * var))) * (input - means).pow(2.0)

            elif arg.rloss == 'mse':
                means = T.sigmoid(out[:, :c, :, :])
                rloss = (input - means).pow(2.0)

            elif arg.rloss == 'mae':
                means = T.sigmoid(out[:, :c, :, :])
                rloss = (input - means).abs()

            elif arg.rloss == 'laplace':  # Laplace negative log-likelihood
                if arg.scale is None:
                    means = T.sigmoid(out[:, :c, :, :])
                    vars = F.softplus(out[:, c:, :, :])

                    rloss = (2.0 * vars).log() + (1.0 / vars) * (input -
                                                                 means).abs()
                else:
                    means = T.sigmoid(out[:, :c, :, :])
                    var = arg.scale

                    rloss = ln(2.0 * var) + (1.0 / var) * (input - means).abs()

            elif arg.rloss == 'signorm':
                if arg.scale is None:

                    mus = out[:, :c, :, :]
                    sgs, lsgs = T.exp(
                        out[:, c:, :, :] *
                        arg.varmult), out[:, c:, :, :] * arg.varmult

                else:
                    mus = out[:, :c, :, :]
                    sgs, lsgs = arg.scale, math.log(arg.scale)

                y = input

                lny = torch.log(y + arg.eps)
                ln1y = torch.log(1 - y + arg.eps)

                x = lny - ln1y

                rloss = lny + ln1y + lsgs + GAUSS_CONST + \
                        0.5 * (1.0 / (sgs * sgs + arg.eps)) * (x - mus) ** 2

            elif arg.rloss == 'siglaplace':

                if arg.scale is None:

                    mus = out[:, :c, :, :]
                    sgs, lsgs = T.exp(
                        out[:, c:, :, :] *
                        arg.varmult), out[:, c:, :, :] * arg.varmult

                else:
                    mus = out[:, :c, :, :]
                    sgs, lsgs = arg.scale, math.log(arg.scale)

                y = input

                lny = torch.log(y + arg.eps)
                ln1y = torch.log(1 - y + arg.eps)

                x = lny - ln1y

                rloss = lny + ln1y + lsgs + math.log(2.0) + \
                        (x - mus).abs() / sgs

            elif arg.rloss == 'beta':

                mean = T.sigmoid(out[:, :c, :, :])
                mult = F.softplus(out[:, c:, :, :] +
                                  arg.beta_add) + (1.0 /
                                                   (mean + arg.eps)) + arg.eps

                alpha = mean * mult
                beta = (1 - mean) * mult

                part = alpha.lgamma() + beta.lgamma() - (alpha + beta).lgamma()
                x = input

                rloss = -(alpha - 1) * (x + arg.eps).log() - (beta - 1) * (
                    1 - x + arg.eps).log() + part
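                # Up to the eps terms this is -Beta(alpha, beta).log_prob(x),
                # i.e. the negative Beta log-density written out explicitly.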

            else:
                raise Exception(
                    f'reconstruction loss {arg.rloss} not recognized.')

            if contains_nan(rloss):
                if arg.rloss == 'beta':
                    print('part contains nan', contains_nan(part))

                    print('alpha contains nan', contains_nan(alpha))
                    print('beta  contains nan', contains_nan(beta))

                    print('log x contains nan',
                          contains_nan((x + arg.eps).log()))
                    print('log (1-x)  contains nan',
                          contains_nan((1 - x + arg.eps).log()))

                raise Exception('rloss contains nan')

            rloss = rloss.reshape(b, -1).sum(dim=1)  # reduce
            loss = (rloss + kloss).mean()

            opt.zero_grad()
            loss.backward()

            opt.step()

        with torch.no_grad():
            N = 5

            # Plot reconstructions

            inputs, _ = next(iter(testloader))

            if torch.cuda.is_available():
                inputs = inputs.cuda()

            b, c, h, w = inputs.size()

            if not arg.testmodel:
                zs = encoder(inputs)
                res = decoder(zs[:, :arg.zsize])
            else:
                res = decoder(inputs)

            outputs = res[:, :c, :, :]
            means = T.sigmoid(outputs)

            samples = None

            if arg.rloss == 'signorm' and out_channels > c:
                means = res[:, :c, :, :]
                vars = res[:, c:, :, :] * arg.varmult

                dist = ds.Normal(means, vars)
                samples = T.sigmoid(dist.sample())
                means = T.sigmoid(dist.mean)

            if arg.rloss == 'siglaplace' and out_channels > c:
                means = res[:, :c, :, :]
                vars = res[:, c:, :, :] * arg.varmult

                dist = ds.Laplace(means, vars)
                samples = T.sigmoid(dist.sample())
                means = T.sigmoid(dist.mean)

            if arg.rloss == 'beta':

                mean = T.sigmoid(res[:, :c, :, :])
                mult = (res[:, c:, :, :] +
                        arg.beta_add).exp() + (1.0 / mean) + arg.eps

                alpha = mean * mult
                beta = (1 - mean) * mult

                dist = ds.Beta(alpha, beta)
                samples = dist.sample()
                means = dist.mean
                vars = dist.variance

            plt.figure(figsize=(5, 4))

            for i in range(N):

                ax = plt.subplot(4, N, i + 1)
                inp = inputs[i].permute(1, 2, 0).cpu().numpy()
                if c == 1:
                    inp = inp.squeeze()

                ax.imshow(inp, cmap='gray_r')

                if i == 0:
                    ax.set_title('input')
                plt.axis('off')

                ax = plt.subplot(4, N, N + i + 1)

                outp = means[i].permute(1, 2, 0).cpu().numpy()
                if c == 1:
                    outp = outp.squeeze()

                ax.imshow(outp, cmap='gray_r')

                if i == 0:
                    ax.set_title('means/modes')
                plt.axis('off')

                if samples is not None:  # plot samples

                    ax = plt.subplot(4, N, 2 * N + i + 1)

                    outp = samples[i].permute(1, 2, 0).detach().cpu().numpy()
                    if c == 1:
                        outp = outp.squeeze()

                    ax.imshow(outp, cmap='gray_r')

                    if i == 0:
                        ax.set_title('sampled')
                    plt.axis('off')

                if out_channels > c:  # plot the variance (or other uncertainty)

                    ax = plt.subplot(4, N, 3 * N + i + 1)

                    outp = vars[i].permute(1, 2, 0).detach().cpu().numpy()
                    if c == 1:
                        outp = outp.squeeze()

                    ax.imshow(outp, cmap='copper')

                    if i == 0:
                        ax.set_title('var')
                    plt.axis('off')

            plt.tight_layout()
            plt.savefig(f'reconstruction.{arg.rloss}.{epoch:03}.png')

            if arg.zsize == 2:  # latent space plot

                N = 2000
                # gather up enough test batches to cover roughly N images
                numbatches = N // arg.batch_size
                images, labels = [], []
                for i, (ims, lbs) in enumerate(testloader):
                    images.append(ims)
                    labels.append(lbs)

                    if i > numbatches:
                        break

                images, labels = torch.cat(images, dim=0), torch.cat(labels,
                                                                     dim=0)

                imagesg = images
                if torch.cuda.is_available():
                    imagesg = imagesg.cuda()

                n, c, h, w = images.size()

                z = encoder(imagesg)
                latents = z[:, :2].data.detach().cpu()

                mn, mx = latents.min(), latents.max()
                size = 1.0 * (mx - mn) / math.sqrt(n)
                # Change the 1.0 multiplier above (roughly 0.5 to 1.5) to make the digits smaller or bigger

                fig = plt.figure(figsize=(8, 8))

                # colormap for the images
                norm = mpl.colors.Normalize(vmin=0, vmax=9)
                cmap = mpl.cm.get_cmap('tab10')

                for i in range(n):
                    x, y = latents[i, 0:2]
                    l = labels[i]

                    im = images[i, :]
                    alpha_im = im.permute(1, 2, 0).detach().cpu().numpy()
                    color = cmap(norm(l))
                    color_im = np.asarray(color)[None, None, :3]
                    color_im = np.broadcast_to(color_im, (h, w, 3))
                    # -- To make the digits transparent we make them solid color images and use the
                    #    actual data as an alpha channel.
                    #    color_im: 3-channel color image, with solid color corresponding to class
                    #    alpha_im: 1-channel grayscale image corresponding to input data

                    im = np.concatenate([color_im, alpha_im], axis=2)
                    plt.imshow(im, extent=(x, x + size, y, y + size))

                    plt.xlim(mn, mx)
                    plt.ylim(mn, mx)

                plt.savefig(f'latent.{arg.rloss}.{epoch:03}.png')
Example #20
 def __init__(self, a):
     dist = dists.Beta(a[0], a[1])
     super().__init__(dist, "beta", 2, a[0], a[1])