Example #1
    def reparameterize(self, logits_map):
        """ reparameterize the encoder output and the prior

        :param logits_map: the map of logits
        :returns: reparameterized samples and their distribution params
        :rtype: (dict, dict)

        """
        nan_check_and_break(logits_map['encoder_logits'], "enc_logits")
        nan_check_and_break(logits_map['prior_logits'], "prior_logits")
        z_enc_t, params_enc_t = self.reparameterizer(
            logits_map['encoder_logits'])

        # XXX: clamp the variance of gaussian priors to not explode
        logits_map['prior_logits'] = self._clamp_variance(
            logits_map['prior_logits'])

        # reparameterize the prior distribution
        z_prior_t, params_prior_t = self.reparameterizer(
            logits_map['prior_logits'])

        z = {  # reparameterization
            'prior': z_prior_t,
            'posterior': z_enc_t,
            'x_features': logits_map['x_features']
        }
        params = {  # params of the reparameterization
            'prior': params_prior_t,
            'posterior': params_enc_t
        }

        return z, params
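
A note on _clamp_variance, which is called above but not shown on this page: below is a minimal self-contained sketch of one plausible implementation, assuming the usual logits layout (first half mean, second half variance) and borrowing the hardtanh bounds from the commented-out line in Example #5. The helper name and the bounds are assumptions, not the project's confirmed code.

import torch
import torch.nn.functional as F

def clamp_variance_sketch(logits, min_val=-6., max_val=2.):
    # Hypothetical: bound the variance half of the logits so a later
    # exp() cannot explode; bounds mirror Example #5's commented hardtanh.
    half = logits.size(-1) // 2
    mu, var_logits = logits[..., :half], logits[..., half:]
    return torch.cat([mu, F.hardtanh(var_logits, min_val, max_val)], -1)

print(clamp_variance_sketch(torch.randn(4, 8) * 10).shape)  # torch.Size([4, 8])
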
Example #2
    def step(self, x_t, inference_only=False, **kwargs):
        """ Single step forward pass.

        :param x_t: the input tensor for time t
        :param inference_only: if True, skip decoding and return only the params
        :returns: decoded tensor (or None) and params for the current time step
        :rtype: torch.Tensor, dict

        """
        x_t_inference = add_noise_to_imgs(x_t) \
            if self.config['add_img_noise'] else x_t     # add image quantization noise
        z_t, params_t = self.posterior(x_t_inference)

        # sanity checks
        nan_check_and_break(x_t_inference, "x_related_inference")
        nan_check_and_break(z_t['prior'], "prior")
        nan_check_and_break(z_t['posterior'], "posterior")
        nan_check_and_break(z_t['x_features'], "x_features")

        if not inference_only:  # decode the posterior
            decoded_t = self.decode(z_t, produce_output=True)
            nan_check_and_break(decoded_t, "decoded_t")
            return decoded_t, params_t

        return None, params_t
Example #3
    def loss_function(self, x, labels, output_map):
        ''' loss is: L_{classifier} * L_{VAE} '''
        vae_loss_map = self.vae.loss_function(
            output_map['decoded'],
            [x.clone() for _ in range(len(output_map['decoded']))],
            output_map['params'])

        # get classifier loss, use BCE with logits if multi-dimensional
        loss_fn = F.binary_cross_entropy_with_logits \
            if len(labels.shape) > 1 else F.cross_entropy
        pred_loss = loss_fn(input=output_map['preds'],
                            target=labels,
                            reduction='none')
        pred_loss = torch.sum(pred_loss, -1) \
            if len(pred_loss.shape) > 1 else pred_loss
        nan_check_and_break(pred_loss, "pred_loss")

        # ACT loss
        vae_loss_map['saccades_scalar'] = output_map['saccades_scalar']
        act_loss = torch.zeros_like(pred_loss)
        # act_loss = F.binary_cross_entropy_with_logits(
        #     input=output_map['act'],
        #     target=torch.ones_like(output_map['act']),
        #     reduction='none'
        # )

        # TODO: try multi-task loss, ie:
        # vae_loss_map['loss'] = vae_loss_map['loss'] + (pred_loss + act_loss)
        vae_loss_map['loss'] = vae_loss_map['loss'] * (pred_loss + act_loss)

        # compute the means for visualizations of bp
        vae_loss_map['act_loss_mean'] = torch.mean(act_loss)
        vae_loss_map['pred_loss_mean'] = torch.mean(pred_loss)
        vae_loss_map['loss_mean'] = torch.mean(vae_loss_map['loss'])
        return vae_loss_map
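
To make the composition concrete, here is a self-contained toy comparison of the multiplicative objective above against the additive multi-task variant from the TODO (dummy per-sample values, purely illustrative):

import torch

vae_loss  = torch.tensor([1.0, 2.0, 0.5])  # per-sample VAE loss, shape [B]
pred_loss = torch.tensor([0.1, 1.5, 3.0])  # per-sample classifier loss, shape [B]
act_loss  = torch.zeros_like(pred_loss)    # zeroed ACT term, as above

print(vae_loss * (pred_loss + act_loss))   # used above: tensor([0.1000, 3.0000, 1.5000])
print(vae_loss + (pred_loss + act_loss))   # TODO variant: tensor([1.1000, 3.5000, 3.5000])
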
Example #4
    def decode(self, z_t, produce_output=False, update_memory=True):
        """ decodes using VRNN

        :param z_t: the latent sample
        :param produce_output: produce output or just update state
        :param update_memory: whether to run the RNN forward to update its state
        :returns: decoded logits (or None if produce_output is False)
        :rtype: torch.Tensor

        """
        # grab state from RNN, TODO: evaluate recovery methods below
        # [0] grabs the h from LSTM (as opposed to (h, c))
        final_state = torch.mean(self.memory.get_state()[0], 0)

        # feature transform for z_t
        phi_z_t = self.phi_z(z_t['posterior'])

        # sanity checks
        nan_check_and_break(final_state, "final_rnn_output[decode]")
        nan_check_and_break(phi_z_t, "phi_z_t")

        if update_memory:  # concat and run through RNN to update state
            input_t = torch.cat([z_t['x_features'], phi_z_t], -1).unsqueeze(0)
            self.memory(input_t.contiguous())

        # decode only if flag is set
        dec_t = None
        if produce_output:
            dec_input_t = torch.cat([phi_z_t, final_state], -1)
            dec_t = self.decoder(dec_input_t)

        return dec_t
Example #5
    def reparmeterize(self, logits):
        """ Given logits reparameterize to a gaussian using
            first half of features for mean and second half for std.

        :param logits: unactivated logits
        :returns: reparameterized tensor (if training), param dict
        :rtype: torch.Tensor, dict

        """
        eps = eps_fn(self.config['half'])
        feature_size = logits.size(-1)
        assert feature_size % 2 == 0 and feature_size // 2 == self.output_size
        half = feature_size // 2
        if logits.dim() == 2:
            mu = logits[:, 0:half]
            nan_check_and_break(mu, "mu")
            sigma = logits[:, half:] + eps
            # sigma = F.softplus(logits[:, half:]) + eps
            # sigma = F.hardtanh(logits[:, half:], min_val=-6., max_val=2.)
        elif logits.dim() == 3:
            mu = logits[:, :, 0:half]
            sigma = logits[:, :, half:] + eps
        else:
            raise Exception(
                "unknown number of dims for isotropic gauss reparam")

        return self._reparametrize_gaussian(mu, sigma)
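
A standalone sketch of the same split under the layout documented above; note that ellipsis indexing subsumes the explicit 2-dim / 3-dim branching:

import torch

def split_gaussian_logits(logits, eps=1e-6):
    # First half of the last dim is the mean, second half the raw std.
    half = logits.size(-1) // 2
    return logits[..., :half], logits[..., half:] + eps

mu, sigma = split_gaussian_logits(torch.randn(4, 10))
print(mu.shape, sigma.shape)  # torch.Size([4, 5]) torch.Size([4, 5])
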
Example #6
    def loss_function(self, recon_x, x, params, K=1, **extra_loss_terms):
        """ Produces ELBO.

        :param recon_x: the unactivated reconstruction preds.
        :param x: input tensor.
        :param params: the dict of reparameterization.
        :param K: number of monte-carlo samples to use.
        :param extra_loss_terms: kwargs of extra [B] dimensional losses
        :returns: loss dict
        :rtype: dict

        """
        nll = self.nll(x, recon_x, self.config['nll_type'])

        # multiple monte-carlo samples for the decoder.
        if self.training:
            for k in range(1, K):
                z_k, params_k = self.reparameterize(
                    logits=params['logits'],
                    labels=params.get('labels', None))
                recon_x_i = self.decode(z_k)
                nll = nll + self.nll(x, recon_x_i, self.config['nll_type'])

            nll = nll / K

        kld = self.kld(params)
        elbo = nll + kld  # save the base ELBO, but use the beta-vae elbo for the full loss

        # handle the mutual information term
        mut_info = self.mut_info(params, x.size(0))

        # get the kl-beta from the annealer or just set to fixed value
        kl_beta = self.compute_kl_beta([self.config['kl_beta']])[0]

        # sanity checks: only done in fp32 due to too much fp16 magic
        if not self.config['half']:
            utils.nan_check_and_break(nll, "nll")
            if kl_beta > 0:  # only check if we have a KLD
                utils.nan_check_and_break(kld, "kld")

        # if we are provided additional losses add them together
        additional_losses = torch.sum(
            torch.cat([v.unsqueeze(0) for v in extra_loss_terms.values()], 0), 0) \
            if extra_loss_terms else torch.zeros_like(nll)

        # compute full loss to use for optimization
        loss = (nll + additional_losses + kl_beta * kld) - mut_info
        return {
            'loss': loss,
            'elbo': elbo,
            'loss_mean': torch.mean(loss),
            'elbo_mean': torch.mean(elbo),
            'nll_mean': torch.mean(nll),
            'kld_mean': torch.mean(kld),
            'additional_loss_mean': torch.mean(additional_losses),
            'kl_beta_scalar': kl_beta,
            'mut_info_mean': torch.mean(mut_info)
        }
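
For reference, the scalar bookkeeping above reduced to a self-contained toy computation; the values are dummies and only the composition mirrors the code:

import torch

nll      = torch.tensor([2.0, 1.5])   # reconstruction term, shape [B]
kld      = torch.tensor([0.3, 0.4])   # KL divergence, shape [B]
extra    = torch.zeros_like(nll)      # summed extra_loss_terms, if any
mut_info = torch.zeros_like(nll)      # mutual-information term
kl_beta  = 1.0                        # stand-in for config['kl_beta']

elbo = nll + kld                                  # logged base ELBO
loss = (nll + extra + kl_beta * kld) - mut_info   # optimized objective
print(elbo, loss)  # both tensor([2.3000, 1.9000]) when kl_beta == 1
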
Example #7
    def log_likelihood(self, z, params):
        cont = self.continuous.log_likelihood(z[:, 0:self.continuous.output_size], params)
        disc = self.discrete.log_likelihood(z[:, self.continuous.output_size:], params)
        if disc.dim() < 2:
            disc = disc.unsqueeze(-1)

        # sanity check and return
        nan_check_and_break(cont, 'cont_ll')
        nan_check_and_break(disc, 'disc_ll')

        return torch.cat([cont, disc], 1)
Example #8
    def _reparametrize_gaussian(self, mu, logvar):
        """ Internal member to reparametrize gaussian.

        :param mu: mean logits
        :param logvar: log-variance.
        :returns: reparameterized tensor and param dict
        :rtype: torch.Tensor, dict

        """
        if self.training:  # returns a stochastic sample for training
            std = logvar.mul(0.5).exp()
            eps = same_type(is_half(logvar),
                            logvar.is_cuda)(logvar.size()).normal_()
            eps = Variable(eps)
            nan_check_and_break(logvar, "logvar")
            return eps.mul(std).add_(mu), {'mu': mu, 'logvar': logvar}

        return mu, {'mu': mu, 'logvar': logvar}
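
The same reparameterization trick in modern, self-contained form; torch.randn_like replaces the same_type(...)/Variable plumbing and is equivalent in fp32:

import torch

def reparametrize_gaussian(mu, logvar):
    # z = mu + eps * std with eps ~ N(0, I); gradients flow through mu/logvar.
    std = logvar.mul(0.5).exp()
    return torch.randn_like(std) * std + mu

z = reparametrize_gaussian(torch.zeros(4, 3), torch.zeros(4, 3))
print(z.shape)  # torch.Size([4, 3])
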
Example #9
    def loss_function(self, recon_x, x, reparam_map):
        """ VAE with no KL objective. Still uses reparam.

        :param recon_x: the unactivated reconstruction preds.
        :param x: input tensor.
        :param reparam_map: the reparameterization dict (unused by this loss)
        :returns: loss dict
        :rtype: dict

        """
        nll = distributions.nll(x, recon_x, self.config['nll_type'])
        utils.nan_check_and_break(nll, "nll")
        return {
            'loss': nll,
            'loss_mean': torch.mean(nll),
            'elbo_mean': torch.mean(torch.zeros_like(nll)),
            'nll_mean': torch.mean(nll),
            'kld_mean': torch.mean(torch.zeros_like(nll)),
            'proxy_mean': torch.mean(torch.zeros_like(nll)),
            'mut_info_mean': torch.mean(torch.zeros_like(nll)),
        }
Example #10
    def encode(self, x, *xargs):
        """ single sample encode using x

        :param x: the input tensor
        :returns: dict of encoded logits
        :rtype: dict

        """
        if self.config['decoder_layer_type'] == 'pixelcnn':
            x = (x - .5) * 2.  # rescale [0, 1] --> [-1, 1] for the pixelcnn decoder

        # get the memory trace, TODO: evaluate different recovery methods below
        final_state = torch.mean(self.memory.get_state()[0], 0)
        nan_check_and_break(final_state, "final_rnn_output")

        # extract input data features
        phi_x_t = self._extract_features(x, *xargs).squeeze()

        # encoder projection
        enc_input_t = torch.cat([phi_x_t, final_state], dim=-1)
        enc_t = self._lazy_build_encoder(enc_input_t.size(-1))(enc_input_t)

        # prior projection (consider adding eps_fn(self.config['cuda']) for stability)
        prior_t = self.prior(final_state.contiguous())

        # sanity checks
        nan_check_and_break(enc_t, "enc_t")
        nan_check_and_break(prior_t, "prior_t")

        return {
            'encoder_logits': enc_t,
            'prior_logits': prior_t,
            'x_features': phi_x_t
        }
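
_lazy_build_encoder is called above but never defined in these examples. Below is a minimal sketch of the lazy-construction pattern the call implies, where the projection is built on first use once the concatenated input width is known (the class name and sizes are hypothetical):

import torch
import torch.nn as nn

class LazyEncoderSketch(nn.Module):
    def __init__(self, latent_size=32):
        super().__init__()
        self.latent_size = latent_size
        self.encoder = None  # built lazily on first call

    def _lazy_build_encoder(self, input_size):
        if self.encoder is None:
            self.encoder = nn.Linear(input_size, self.latent_size)
        return self.encoder

m = LazyEncoderSketch()
print(m._lazy_build_encoder(96)(torch.randn(8, 96)).shape)  # torch.Size([8, 32])
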
Example #11
    def z_where_inv(z_where, clip_scale=5.0):
        # Take a batch of z_where vectors, and compute their "inverse".
        # That is, for each row compute:
        # [s,x,y] -> [1/s,-x/s,-y/s]
        # These are the parameters required to perform the inverse of the
        # spatial transform performed in the generative model.
        n = z_where.size(0)
        out = torch.cat((LocalizedSpatialTransformerFn.ng_ones(
            [1, 1]).type_as(z_where).expand(n, 1), -z_where[:, 1:]), 1)

        # Divide all entries by the scale; abs() ensures images aren't flipped,
        # and the max() with clip_scale bounds the divisor away from zero.
        scale = torch.max(torch.abs(z_where[:, 0:1]),
                          zeros_like(z_where[:, 0:1]) + clip_scale)
        if torch.sum(scale == 0) > 0:
            raise ValueError("tensor scale of {} dim was 0!".format(scale.shape))

        nan_check_and_break(scale, "scale")
        out = out / scale
        # out = out / z_where[:, 0:1]

        return out
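
A quick self-contained check of the identity in the comment, [s,x,y] -> [1/s,-x/s,-y/s], using the affine map u -> s*u + [x,y] and ignoring the clip_scale clamp:

import torch

z = torch.tensor([[2.0, 0.5, -1.0]])       # [s, x, y]
z_inv = torch.cat([1.0 / z[:, 0:1], -z[:, 1:] / z[:, 0:1]], 1)

u = torch.tensor([[3.0, 4.0]])             # arbitrary test point
fwd = z[:, 0:1] * u + z[:, 1:]             # u -> s*u + [x, y]
back = z_inv[:, 0:1] * fwd + z_inv[:, 1:]  # apply the inverse params
print(torch.allclose(back, u))             # True
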
Example #12
    def step(self, x_i, inference_only=False):
        """ Single step forward pass.

        :param x_i: the input tensor for the current step
        :param inference_only: currently unused; decoding always runs
        :returns: decoded tensor and reparameterization params
        :rtype: torch.Tensor, dict

        """
        x_i_inference = add_noise_to_imgs(x_i) \
            if self.config['add_img_noise'] else x_i             # add image quantization noise
        z_t, params_t = self.posterior(x_i_inference)
        nan_check_and_break(x_i_inference, "x_related_inference")
        nan_check_and_break(z_t['prior'], "prior")
        nan_check_and_break(z_t['posterior'], "posterior")
        nan_check_and_break(z_t['x_features'], "x_features")

        # decode the posterior
        decoded_t = self.decode(z_t, produce_output=True)
        nan_check_and_break(decoded_t, "decoded_t")

        return decoded_t, params_t
Example #13
    def loss_function(self, recon_x, x, params):
        """ Produces ELBO, handles mutual info and proxy loss terms too.

        :param recon_x: the unactivated reconstruction preds.
        :param x: input tensor.
        :param params: the dict of reparameterization.
        :returns: loss dict
        :rtype: dict

        """
        if self.config['decoder_layer_type'] == 'pixelcnn':
            x = (x - .5) * 2.  # rescale [0, 1] --> [-1, 1] for the pixelcnn decoder

        nll = nll_fn(x, recon_x, self.config['nll_type'])
        nan_check_and_break(nll, "nll")
        kld = self.kld(params)
        nan_check_and_break(kld, "kld")
        elbo = nll + kld  # save the base ELBO, but use the beta-vae elbo for the full loss

        # add the proxy loss if it exists
        proxy_loss = self.reparameterizer.proxy_layer.loss_function() \
            if hasattr(self.reparameterizer, 'proxy_layer') else torch.zeros_like(elbo)

        # handle the mutual information term
        mut_info = self.mut_info(params, x.size(0))

        loss = (nll + self.config['kl_beta'] * kld) - mut_info
        return {
            'loss': loss,
            'loss_mean': torch.mean(loss),
            'elbo_mean': torch.mean(elbo),
            'nll_mean': torch.mean(nll),
            'kld_mean': torch.mean(kld),
            'proxy_mean': torch.mean(proxy_loss),
            'mut_info_mean': torch.mean(mut_info)
        }
Example #14
    def loss_function(self, recon_x, x, **unused_kwargs):
        """ Autoencoder is simple the NLL term in the VAE.

        :param recon_x: the unactivated reconstruction preds.
        :param x: input tensor.
        :returns: loss dict
        :rtype: dict

        """
        nll = distributions.nll(x, recon_x, self.config['nll_type'])
        if not self.config['half']:
            utils.nan_check_and_break(nll, "nll")

        return {
            'loss': nll,
            'elbo': torch.zeros_like(nll),
            'loss_mean': torch.mean(nll),
            'elbo_mean': 0,
            'nll_mean': torch.mean(nll),
            'kld_mean': 0,
            'kl_beta_scalar': 0,
            'proxy_mean': 0,
            'mut_info_mean': 0,
        }
Example #15
    def _reparametrize_gaussian(self, mu, logvar, force=False):
        """ Internal member to reparametrize gaussian.

        :param mu: mean logits
        :param logvar: log-variance.
        :param force: if True, sample stochastically even outside training
        :returns: reparameterized tensor and param dict
        :rtype: torch.Tensor, dict

        """
        if self.training or force:  # returns a stochastic sample for training
            std = logvar.mul(0.5)  # Usually has .exp(), but overflows fp16
            eps = torch.zeros_like(logvar).normal_().type(std.dtype)
            if not self.config['half']:  # sanity check while not fp16
                nan_check_and_break(logvar, "logvar")

            reparam_sample = eps.mul(std).add_(mu)
            return reparam_sample, {
                'z': reparam_sample,
                'mu': mu,
                'logvar': logvar
            }
            # fp32 alternative (note: Normal takes a std, not a logvar):
            # return D.Normal(mu, logvar.mul(0.5).exp()).rsample(), {'mu': mu, 'logvar': logvar}

        return mu, {'z': mu, 'mu': mu, 'logvar': logvar}
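
A two-line illustration of the fp16 overflow the comment works around: exp(logvar / 2) can exceed the fp16 maximum of 65504, which is why the code leaves logvar.mul(0.5) un-exponentiated (at the cost of treating it as the std):

import torch

print(torch.tensor([24.0]).mul(0.5).exp())         # tensor([162754.7969]) in fp32
print(torch.tensor([24.0]).mul(0.5).exp().half())  # tensor([inf], dtype=torch.float16)
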
Example #16
    def _z_to_image_transformer(self, z, imgs):
        imgs = imgs.type(z.dtype) if isinstance(imgs, torch.Tensor) else imgs
        z_proj = self.posterior_to_st(z)
        crops_pred = self.spatial_transformer(z_proj, imgs, self.vae.chans)
        nan_check_and_break(crops_pred, "predicted_crops_t")
        return {'crops_pred': crops_pred}
Example #17
        def _forward_internal(x_related, inference_only=False):
            params, crops, crops_true, inlays, decodes = [], [], [], [], []

            # reset the state, output and the truncate window
            self.vae.memory.init_state(batch_size, cuda=x_related.is_cuda)
            self.vae.memory.init_output(batch_size,
                                        seqlen=1,
                                        cuda=x_related.is_cuda)

            # accumulator for predictions and ACT
            act = zeros((batch_size, 1),
                        cuda=x_related.is_cuda,
                        dtype=get_dtype(x_related)).squeeze().requires_grad_()
            if self.config['concat_prediction_size'] <= 0:
                x_preds = zeros((batch_size, self.config['latent_size']),
                                cuda=x_related.is_cuda,
                                dtype=get_dtype(x_related)).requires_grad_()
            else:  # the below creates a buffer that can concat results
                x_preds = []  # just concat this and project

            for i in range(self.config['max_time_steps']):
                # get posterior and params, expand 0'th dim for seqlen
                x_related_inference = add_noise_to_imgs(x_related) \
                    if self.config['add_img_noise'] else x_related
                z_t, params_t = self.vae.posterior(x_related_inference)

                nan_check_and_break(x_related_inference, "x_related_inference")
                nan_check_and_break(z_t['prior'], "prior")
                nan_check_and_break(z_t['posterior'], "posterior")
                nan_check_and_break(z_t['x_features'], "x_features")

                # extract the required crop from original image
                x_trunc_t = self._z_to_image(z_t['posterior'], x)

                # do preds and sum
                state = torch.mean(self.vae.memory.get_state()[0], 0)
                crops_pred_perturbed = add_noise_to_imgs(x_trunc_t['crops_pred']) \
                    if self.config['add_img_noise'] else x_trunc_t['crops_pred']
                state_proj = self.latent_projector(crops_pred_perturbed, state)

                # add to crops if concat_prediction_size is not specified
                # otherwise use a concatenation strategy
                if self.config['concat_prediction_size'] <= 0:
                    x_preds = x_preds + state_proj[:, 0:-1]  # last bit is for ACT
                else:
                    x_preds.append(state_proj[:, 0:-1])  # last bit is for ACT

                # decode the posterior
                decoded_t = self.vae.decode(z_t, produce_output=True)
                nan_check_and_break(decoded_t, "decoded_t")

                # cache for loss function & visualization
                params.append(params_t)
                crops.append(x_trunc_t['crops_pred'])
                decodes.append(decoded_t)

                # only add these if we are in the lambda setting:
                if 'crops_true' in x_trunc_t and 'inlay' in x_trunc_t:
                    crops_true.append(x_trunc_t['crops_true'])
                    inlays.append(x_trunc_t['inlay'])

                # conditionally break away based on ACT
                # act = act + torch.sigmoid(state_proj[:, -1])
                # if torch.max(act / max(i, 1)) >= 0.9998:
                #     break

            # stack if we are using the concat solution
            x_preds = torch.cat(x_preds, -1) \
                if self.config['concat_prediction_size'] > 0 else x_preds
            preds = self.latent_projector.get_output(
                x_preds / self.config['max_time_steps'])
            return {
                'act': act / max(i, 1),
                'saccades_scalar': i,
                'decoded': decodes,
                'params': params,
                'preds': preds,
                'inlays': inlays,
                'crops': crops,
                'crops_true': crops_true
            }
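
The commented-out break above is an adaptive-computation-time style halting rule; here is a standalone sketch of the pattern, with the 0.9998 threshold taken from the commented code and the halting logits randomized for illustration:

import torch

act = torch.zeros(4)             # per-sample halting accumulator
for i in range(1, 11):
    halt_logit = torch.randn(4)  # stand-in for state_proj[:, -1]
    act = act + torch.sigmoid(halt_logit)
    if torch.max(act / max(i, 1)) >= 0.9998:
        break                    # halt once the mean score saturates
print(i, (act / max(i, 1)).max())
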