def reparameterize(self, logits_map):
    """ Reparameterize both the encoder logits and the prior logits.

    The prior logits are first passed through ``self._clamp_variance`` and
    the result is written back into ``logits_map['prior_logits']``, so the
    caller's map is modified in place.

    :param logits_map: map with 'encoder_logits', 'prior_logits'
                       and 'x_features' entries
    :returns: (z, params) where z holds the 'prior' / 'posterior' samples
              plus the untouched 'x_features', and params holds the
              'prior' / 'posterior' reparameterization params
    :rtype: dict, dict

    """
    encoder_logits = logits_map['encoder_logits']
    nan_check_and_break(encoder_logits, "enc_logits")
    nan_check_and_break(logits_map['prior_logits'], "prior_logits")

    # posterior sample comes straight from the encoder logits
    posterior_sample, posterior_params = self.reparameterizer(encoder_logits)

    # XXX: the prior variance is clamped (so it does not explode) before
    #      the prior is reparameterized; the clamped logits replace the
    #      originals inside the caller's map.
    clamped_prior_logits = self._clamp_variance(logits_map['prior_logits'])
    logits_map['prior_logits'] = clamped_prior_logits
    prior_sample, prior_params = self.reparameterizer(clamped_prior_logits)

    # the reparameterized samples (and pass-through features)
    z = {
        'prior': prior_sample,
        'posterior': posterior_sample,
        'x_features': logits_map['x_features'],
    }

    # the parameters produced by each reparameterization
    params = {
        'prior': prior_params,
        'posterior': posterior_params,
    }
    return z, params
def step(self, x_t, inference_only=False, **kwargs):
    """ Single step forward pass.

    :param x_t: the input tensor for time t
    :param inference_only: when True the decoding process is skipped
    :param kwargs: accepted for call compatibility, not used here
    :returns: (decoded, params) for the current time; decoded is None
              when inference_only is set
    :rtype: torch.Tensor or None, dict

    """
    # optionally add image quantization noise before inference
    if self.config['add_img_noise']:
        inference_input = add_noise_to_imgs(x_t)
    else:
        inference_input = x_t

    z_t, params_t = self.posterior(inference_input)

    # sanity checks on the input and on every entry of the latent dict
    nan_check_and_break(inference_input, "x_related_inference")
    for key in ('prior', 'posterior', 'x_features'):
        nan_check_and_break(z_t[key], key)

    if inference_only:
        return None, params_t

    # decode the posterior
    decoded_t = self.decode(z_t, produce_output=True)
    nan_check_and_break(decoded_t, "decoded_t")
    return decoded_t, params_t
def loss_function(self, x, labels, output_map):
    """ Combined loss: L_{classifier} * L_{VAE}.

    :param x: the input tensor; cloned once per decoded step as the target
    :param labels: classification targets
    :param output_map: dict with 'decoded', 'params', 'preds' and
                       'saccades_scalar' entries
    :returns: the VAE loss dict, with 'loss' rescaled by the classifier
              loss and extra '*_mean' / 'saccades_scalar' entries added
    :rtype: dict

    """
    decoded = output_map['decoded']
    targets = [x.clone() for _ in decoded]  # one target per decoded step
    vae_loss_map = self.vae.loss_function(decoded, targets,
                                          output_map['params'])

    # classifier loss: BCE-with-logits for multi-dimensional labels,
    # plain cross-entropy otherwise
    if labels.dim() > 1:
        loss_fn = F.binary_cross_entropy_with_logits
    else:
        loss_fn = F.cross_entropy

    pred_loss = loss_fn(input=output_map['preds'],
                        target=labels,
                        reduction='none')
    if pred_loss.dim() > 1:
        # collapse the trailing dimension so there is one value per sample
        pred_loss = pred_loss.sum(-1)

    nan_check_and_break(pred_loss, "pred_loss")

    # ACT loss is currently disabled: it is a zero tensor that only keeps
    # the bookkeeping below uniform.
    vae_loss_map['saccades_scalar'] = output_map['saccades_scalar']
    act_loss = torch.zeros_like(pred_loss)

    # TODO: try a multi-task (additive) loss instead of the product
    vae_loss_map['loss'] = vae_loss_map['loss'] * (pred_loss + act_loss)

    # means used for visualization
    vae_loss_map['act_loss_mean'] = act_loss.mean()
    vae_loss_map['pred_loss_mean'] = pred_loss.mean()
    vae_loss_map['loss_mean'] = vae_loss_map['loss'].mean()
    return vae_loss_map
def decode(self, z_t, produce_output=False, update_memory=True):
    """ Decodes using the VRNN.

    The recurrent state is read *before* the memory is (optionally)
    updated, so the decoder sees the state from prior to this step.

    :param z_t: the latent dict ('posterior' and 'x_features' are read)
    :param produce_output: produce output or just update state
    :param update_memory: whether to run the memory on this step's input
    :returns: decoded logits, or None when produce_output is False
    :rtype: torch.Tensor or None

    """
    # grab state from RNN, TODO: evaluate other recovery methods
    # [0] grabs the h from LSTM (as opposed to (h, c))
    hidden = self.memory.get_state()[0]
    rnn_state = torch.mean(hidden, 0)

    # feature transform for the posterior sample
    z_features = self.phi_z(z_t['posterior'])

    # sanity checks
    nan_check_and_break(rnn_state, "final_rnn_output[decode]")
    nan_check_and_break(z_features, "phi_z_t")

    if update_memory:
        # concat input features with latent features and step the RNN
        memory_input = torch.cat([z_t['x_features'], z_features], -1)
        memory_input = memory_input.unsqueeze(0)
        self.memory(memory_input.contiguous())

    # decode only if flag is set
    if not produce_output:
        return None

    decoder_input = torch.cat([z_features, rnn_state], -1)
    return self.decoder(decoder_input)
def reparmeterize(self, logits):
    """ Split logits in half along the last dimension and reparameterize
        to a gaussian: the first half is used as the mean, the second half
        (plus a small eps) is handed on as the spread term.

    :param logits: unactivated logits, 2 or 3 dimensional
    :returns: whatever self._reparametrize_gaussian returns
              (reparameterized tensor, param dict)
    :rtype: torch.Tensor, dict

    """
    eps = eps_fn(self.config['half'])
    feature_size = logits.size(-1)
    assert feature_size % 2 == 0 and feature_size // 2 == self.output_size
    half = feature_size // 2

    num_dims = logits.dim()
    if num_dims not in (2, 3):
        raise Exception(
            "unknown number of dims for isotropic gauss reparam")

    # the ellipsis covers both the [B, F] and the [B, T, F] layouts
    mu = logits[..., :half]
    if num_dims == 2:
        # only the 2d case is NaN-checked here
        nan_check_and_break(mu, "mu")

    # the raw second half is used as-is (no softplus / hardtanh applied)
    sigma = logits[..., half:] + eps
    return self._reparametrize_gaussian(mu, sigma)
def loss_function(self, recon_x, x, params, K=1, **extra_loss_terms):
    """ Produces the ELBO and the full (beta-weighted) training loss.

    :param recon_x: the unactivated reconstruction preds.
    :param x: input tensor.
    :param params: the dict of reparameterization.
    :param K: number of monte-carlo samples to use (training only).
    :param extra_loss_terms: kwargs of extra [B] dimensional losses
    :returns: loss dict
    :rtype: dict

    """
    nll_type = self.config['nll_type']
    nll = self.nll(x, recon_x, nll_type)

    # while training, average the NLL over K decoder samples; the first
    # sample is the reconstruction that was handed in.
    if self.training:
        for _ in range(K - 1):
            z_sample, _unused_params = self.reparameterize(
                logits=params['logits'],
                labels=params.get('labels', None))
            resampled_recon = self.decode(z_sample)
            nll = nll + self.nll(x, resampled_recon, nll_type)

        nll = nll / K

    kld = self.kld(params)
    # keep the plain ELBO around; the full loss uses the beta-weighted KL
    elbo = nll + kld

    # mutual information term
    mut_info = self.mut_info(params, x.size(0))

    # kl-beta comes from the annealer (or is just the fixed value)
    kl_beta = self.compute_kl_beta([self.config['kl_beta']])[0]

    # sanity checks are only done in fp32
    if not self.config['half']:
        utils.nan_check_and_break(nll, "nll")
        if kl_beta > 0:  # the KLD only matters when it is weighted in
            utils.nan_check_and_break(kld, "kld")

    # sum up any additional losses that were provided
    if extra_loss_terms:
        stacked_extras = torch.stack(list(extra_loss_terms.values()), 0)
        additional_losses = stacked_extras.sum(0)
    else:
        additional_losses = torch.zeros_like(nll)

    # full loss used for optimization
    loss = (nll + additional_losses + kl_beta * kld) - mut_info

    return {
        'loss': loss,
        'elbo': elbo,
        'loss_mean': loss.mean(),
        'elbo_mean': elbo.mean(),
        'nll_mean': nll.mean(),
        'kld_mean': kld.mean(),
        'additional_loss_mean': additional_losses.mean(),
        'kl_beta_scalar': kl_beta,
        'mut_info_mean': mut_info.mean(),
    }
def log_likelihood(self, z, params):
    """ Log-likelihood of a mixed sample: the leading
        ``self.continuous.output_size`` columns of z are scored by the
        continuous part and the remaining columns by the discrete part.

    :param z: 2d sample tensor holding continuous then discrete columns
    :param params: the params handed to both sub-likelihoods
    :returns: continuous and discrete log-likelihoods concatenated on dim 1
    :rtype: torch.Tensor

    """
    split_at = self.continuous.output_size
    continuous_ll = self.continuous.log_likelihood(z[:, :split_at], params)
    discrete_ll = self.discrete.log_likelihood(z[:, split_at:], params)

    # give the discrete term a trailing dim so it can be concatenated
    if discrete_ll.dim() < 2:
        discrete_ll = discrete_ll.unsqueeze(-1)

    # sanity check and return
    nan_check_and_break(continuous_ll, 'cont_ll')
    nan_check_and_break(discrete_ll, 'disc_ll')
    return torch.cat([continuous_ll, discrete_ll], 1)
def _reparametrize_gaussian(self, mu, logvar): """ Internal member to reparametrize gaussian. :param mu: mean logits :param logvar: log-variance. :returns: reparameterized tensor and param dict :rtype: torch.Tensor, dict """ if self.training: # returns a stochastic sample for training std = logvar.mul(0.5).exp() eps = same_type(is_half(logvar), logvar.is_cuda)(logvar.size()).normal_() eps = Variable(eps) nan_check_and_break(logvar, "logvar") return eps.mul(std).add_(mu), {'mu': mu, 'logvar': logvar} return mu, {'mu': mu, 'logvar': logvar}
def loss_function(self, recon_x, x, reparam_map):
    """ VAE with no KL objective: the loss is just the NLL. The remaining
        entries of the usual loss dict are reported as zeros.

    :param recon_x: the unactivated reconstruction preds.
    :param x: input tensor.
    :param reparam_map: the reparameterization dict (not used here).
    :returns: loss dict
    :rtype: dict

    """
    nll = distributions.nll(x, recon_x, self.config['nll_type'])
    utils.nan_check_and_break(nll, "nll")

    def _zero_mean():
        # fresh zero-valued scalar for each term that is unused here
        return torch.zeros_like(nll).mean()

    return {
        'loss': nll,
        'loss_mean': nll.mean(),
        'elbo_mean': _zero_mean(),
        'nll_mean': nll.mean(),
        'kld_mean': _zero_mean(),
        'proxy_mean': _zero_mean(),
        'mut_info_mean': _zero_mean(),
    }
def encode(self, x, *xargs):
    """ Single sample encode using x.

    :param x: the input tensor
    :param xargs: extra positional args forwarded to the feature extractor
    :returns: dict with 'encoder_logits', 'prior_logits' and 'x_features'
    :rtype: dict

    """
    if self.config['decoder_layer_type'] == 'pixelcnn':
        # pixelcnn path: shift by 0.5 and double the input
        x = 2. * (x - .5)

    # get the memory trace, TODO: evaluate other recovery methods
    rnn_state = self.memory.get_state()[0].mean(0)
    nan_check_and_break(rnn_state, "final_rnn_output")

    # extract input data features
    # NOTE(review): the bare squeeze() drops every size-1 dim; confirm
    #               this is safe for a batch of one.
    x_features = self._extract_features(x, *xargs).squeeze()

    # encoder projection (the encoder is built lazily from the input width)
    encoder_input = torch.cat([x_features, rnn_state], dim=-1)
    encoder = self._lazy_build_encoder(encoder_input.size(-1))
    encoder_logits = encoder(encoder_input)

    # prior projection
    prior_logits = self.prior(rnn_state.contiguous())

    # sanity checks
    nan_check_and_break(encoder_logits, "enc_t")
    nan_check_and_break(prior_logits, "priot_t")

    return {
        'encoder_logits': encoder_logits,
        'prior_logits': prior_logits,
        'x_features': x_features,
    }
def z_where_inv(z_where, clip_scale=5.0):
    """ Take a batch of z_where vectors and compute their "inverse".

    For each row this computes ``[s, x, y] -> [1/s', -x/s', -y/s']`` with
    ``s' = max(|s|, clip_scale)``. These are the parameters required to
    perform the inverse of the spatial transform performed in the
    generative model.

    :param z_where: 2d tensor, one row per sample; column 0 is the scale
                    and the remaining columns are negated and rescaled
    :param clip_scale: lower bound applied to the absolute scale
    :returns: tensor with the same number of rows and columns as z_where
    :rtype: torch.Tensor
    :raises ValueError: if any (clamped) scale is exactly zero

    """
    n = z_where.size(0)

    # numerator: a column of ones followed by the negated remaining columns
    ones_col = LocalizedSpatialTransformerFn.ng_ones(
        [1, 1]).type_as(z_where).expand(n, 1)
    out = torch.cat((ones_col, -z_where[:, 1:]), 1)

    # abs(scale) ensures images aren't flipped; the scale is additionally
    # floored at clip_scale.
    raw_scale = z_where[:, 0:1]
    scale = torch.max(torch.abs(raw_scale),
                      zeros_like(raw_scale) + clip_scale)

    # Raise instead of print + exit(): the interactive exit() helper tears
    # down the whole process from library code and hides the failure from
    # callers and from any traceback.
    if bool((scale == 0).any()):
        raise ValueError(
            "tensor scale of {} dim was 0!!".format(scale.shape))

    nan_check_and_break(scale, "scale")

    # divide all entries by the scale
    return out / scale
def step(self, x_i, inference_only=False):
    """ Single step forward pass.

    :param x_i: input tensor
    :param inference_only: accepted but not used by this implementation;
                           the posterior is always decoded
    :returns: decoded output and the params for this step
    :rtype: torch.Tensor, dict

    """
    # optionally add image quantization noise before inference
    if self.config['add_img_noise']:
        inference_input = add_noise_to_imgs(x_i)
    else:
        inference_input = x_i

    z_t, params_t = self.posterior(inference_input)

    # sanity checks on the input and on every entry of the latent dict
    nan_check_and_break(inference_input, "x_related_inference")
    for key in ('prior', 'posterior', 'x_features'):
        nan_check_and_break(z_t[key], key)

    # decode the posterior
    decoded_t = self.decode(z_t, produce_output=True)
    nan_check_and_break(decoded_t, "decoded_t")

    return decoded_t, params_t
def loss_function(self, recon_x, x, params):
    """ Produces ELBO, handles mutual info and proxy loss terms too.

    :param recon_x: the unactivated reconstruction preds.
    :param x: input tensor.
    :param params: the dict of reparameterization.
    :returns: loss dict
    :rtype: dict

    """
    if self.config['decoder_layer_type'] == 'pixelcnn':
        # pixelcnn path: shift by 0.5 and double the target
        x = 2. * (x - .5)

    nll = nll_fn(x, recon_x, self.config['nll_type'])
    nan_check_and_break(nll, "nll")

    kld = self.kld(params)
    nan_check_and_break(kld, "kld")

    # keep the plain ELBO around; the full loss uses the beta-weighted KL
    elbo = nll + kld

    # proxy loss, when the reparameterizer has a proxy layer.
    # NOTE(review): it is only reported through 'proxy_mean' and is not
    #               added to 'loss' -- confirm this is intended.
    if hasattr(self.reparameterizer, 'proxy_layer'):
        proxy_loss = self.reparameterizer.proxy_layer.loss_function()
    else:
        proxy_loss = torch.zeros_like(elbo)

    # mutual information term
    mut_info = self.mut_info(params, x.size(0))

    loss = (nll + self.config['kl_beta'] * kld) - mut_info
    return {
        'loss': loss,
        'loss_mean': loss.mean(),
        'elbo_mean': elbo.mean(),
        'nll_mean': nll.mean(),
        'kld_mean': kld.mean(),
        'proxy_mean': torch.mean(proxy_loss),
        'mut_info_mean': torch.mean(mut_info),
    }
def loss_function(self, recon_x, x, **unused_kwargs):
    """ Autoencoder loss: simply the NLL term of the VAE. The KL related
        entries of the usual loss dict are reported as plain zeros.

    :param recon_x: the unactivated reconstruction preds.
    :param x: input tensor.
    :param unused_kwargs: accepted for call compatibility, ignored.
    :returns: loss dict
    :rtype: dict

    """
    nll = distributions.nll(x, recon_x, self.config['nll_type'])

    # the NaN check is skipped when running in half precision
    running_fp32 = not self.config['half']
    if running_fp32:
        utils.nan_check_and_break(nll, "nll")

    loss_map = {'loss': nll, 'elbo': torch.zeros_like(nll)}
    loss_map['loss_mean'] = nll.mean()
    loss_map['elbo_mean'] = 0
    loss_map['nll_mean'] = nll.mean()
    loss_map['kld_mean'] = 0
    loss_map['kl_beta_scalar'] = 0
    loss_map['proxy_mean'] = 0
    loss_map['mut_info_mean'] = 0
    return loss_map
def _reparametrize_gaussian(self, mu, logvar, force=False): """ Internal member to reparametrize gaussian. :param mu: mean logits :param logvar: log-variance. :returns: reparameterized tensor and param dict :rtype: torch.Tensor, dict """ if self.training or force: # returns a stochastic sample for training std = logvar.mul(0.5) # Usually has .exp(), but overflows fp16 eps = torch.zeros_like(logvar).normal_().type(std.dtype) if not self.config['half']: # sanity check while not fp16 nan_check_and_break(logvar, "logvar") reparam_sample = eps.mul(std).add_(mu) return reparam_sample, { 'z': reparam_sample, 'mu': mu, 'logvar': logvar } # return D.Normal(mu, logvar).rsample(), {'mu': mu, 'logvar': logvar} return mu, {'z': mu, 'mu': mu, 'logvar': logvar}
def _z_to_image_transformer(self, z, imgs):
    """ Project z and run the spatial transformer over imgs to get crops.

    :param z: the latent tensor, projected through self.posterior_to_st
    :param imgs: the images to crop from; cast to z's dtype if a tensor,
                 passed through untouched otherwise
    :returns: dict with a single 'crops_pred' entry
    :rtype: dict

    """
    if isinstance(imgs, torch.Tensor):
        imgs = imgs.type(z.dtype)

    transform_params = self.posterior_to_st(z)
    predicted_crops = self.spatial_transformer(
        transform_params, imgs, self.vae.chans)
    nan_check_and_break(predicted_crops, "predicted_crops_t")

    return {'crops_pred': predicted_crops}
def _forward_internal(x_related, inference_only=False):
    """ Run the full multi-step forward pass over x_related.

    Closure: ``self``, ``batch_size`` and ``x`` (the image crops are taken
    from) come from the enclosing scope.

    :param x_related: the tensor fed to the posterior at every step
    :param inference_only: accepted but not used by this implementation
    :returns: dict with 'act', 'saccades_scalar', 'decoded', 'params',
              'preds', 'inlays', 'crops' and 'crops_true'
    :rtype: dict

    """
    params, crops, crops_true, inlays, decodes = [], [], [], [], []
    on_cuda = x_related.is_cuda
    add_noise = self.config['add_img_noise']
    concat_preds = self.config['concat_prediction_size'] > 0

    # reset the state, output and the truncate window
    self.vae.memory.init_state(batch_size, cuda=on_cuda)
    self.vae.memory.init_output(batch_size, seqlen=1, cuda=on_cuda)

    # accumulator for ACT
    act = zeros((batch_size, 1), cuda=on_cuda,
                dtype=get_dtype(x_related)).squeeze().requires_grad_()

    # accumulator for predictions: a list that is concatenated at the end,
    # or a running sum when no concat size is configured
    if concat_preds:
        x_preds = []
    else:
        x_preds = zeros((batch_size, self.config['latent_size']),
                        cuda=on_cuda,
                        dtype=get_dtype(x_related)).requires_grad_()

    for i in range(self.config['max_time_steps']):
        # get posterior and params
        posterior_input = add_noise_to_imgs(x_related) \
            if add_noise else x_related
        z_t, params_t = self.vae.posterior(posterior_input)

        nan_check_and_break(posterior_input, "x_related_inference")
        for key in ('prior', 'posterior', 'x_features'):
            nan_check_and_break(z_t[key], key)

        # extract the required crop from the original image
        x_trunc_t = self._z_to_image(z_t['posterior'], x)

        # project the (optionally perturbed) crop with the current state
        state = torch.mean(self.vae.memory.get_state()[0], 0)
        crop_input = add_noise_to_imgs(x_trunc_t['crops_pred']) \
            if add_noise else x_trunc_t['crops_pred']
        state_proj = self.latent_projector(crop_input, state)

        # last column is reserved for ACT, the rest is the prediction
        step_pred = state_proj[:, 0:-1]
        if concat_preds:
            x_preds.append(step_pred)
        else:
            x_preds = x_preds + step_pred

        # decode the posterior
        decoded_t = self.vae.decode(z_t, produce_output=True)
        nan_check_and_break(decoded_t, "decoded_t")

        # cache for loss function & visualization
        params.append(params_t)
        crops.append(x_trunc_t['crops_pred'])
        decodes.append(decoded_t)

        # only present in the lambda setting
        if 'crops_true' in x_trunc_t and 'inlay' in x_trunc_t:
            crops_true.append(x_trunc_t['crops_true'])
            inlays.append(x_trunc_t['inlay'])

        # NOTE: the ACT based early exit is disabled, so every step runs

    # stack if we are using the concat solution
    if concat_preds:
        x_preds = torch.cat(x_preds, -1)

    preds = self.latent_projector.get_output(
        x_preds / self.config['max_time_steps'])

    # `i` is the index of the last executed step
    return {
        'act': act / max(i, 1),
        'saccades_scalar': i,
        'decoded': decodes,
        'params': params,
        'preds': preds,
        'inlays': inlays,
        'crops': crops,
        'crops_true': crops_true,
    }