def update_actor(self, states, actions, advantages, old_pi):
    with tf.GradientTape() as tape:
        mean, std, log_std = self.actor(states)
        pi = DiagonalGaussian(mean, std, log_std)
        log_pi = pi.log_likelihood(actions)
        log_old_pi = old_pi.log_likelihood(actions)
        ratio = tf.exp(log_pi - log_old_pi)
        surr = tf.math.minimum(
            ratio * advantages,
            tf.clip_by_value(ratio, 1 - self.clip_ratio, 1 + self.clip_ratio) * advantages)
        loss = -tf.math.reduce_mean(surr)
        approx_ent = tf.math.reduce_mean(-log_pi)
        loss -= approx_ent * self.ent_weight  # maximize the entropy to encourage exploration
        approx_kl = tf.math.reduce_mean(log_old_pi - log_pi)
    grad = tape.gradient(loss, self.actor.trainable_weights)
    # very important to clip the gradient
    grad, grad_norm = tf.clip_by_global_norm(grad, 0.5)
    self.actor_optimizer.apply_gradients(
        zip(grad, self.actor.trainable_weights))
    return approx_kl
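For reference, the surrogate computed in `update_actor` is the standard PPO clipped objective (the loss is its negation, minus an entropy bonus):

    L^{CLIP}(\theta) = \hat{\mathbb{E}}_t\!\left[\min\!\left(r_t(\theta)\,\hat{A}_t,\;
        \operatorname{clip}\!\left(r_t(\theta),\, 1-\epsilon,\, 1+\epsilon\right)\hat{A}_t\right)\right],
    \qquad
    r_t(\theta) = \exp\!\left(\log \pi_\theta(a_t \mid s_t) - \log \pi_{\theta_{\text{old}}}(a_t \mid s_t)\right)

The returned `approx_kl` is the sample estimate \(\hat{\mathbb{E}}_t[\log \pi_{\theta_{\text{old}}}(a_t \mid s_t) - \log \pi_\theta(a_t \mid s_t)]\), which the `update` loop further down uses for early stopping.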
def sample_action(self, observation):
    # shape (1, act_dim)
    mean, std, log_std = self.actor(observation[np.newaxis, :])
    pi = DiagonalGaussian(mean, std, log_std)
    # action = tf.clip_by_value(pi.sample(), self.action_bound[0], self.action_bound[1])
    action = pi.sample()  # shape (1, 1)
    value = self.critic(observation[np.newaxis, :])
    return action[0], value[0, 0]
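Both snippets above assume a `DiagonalGaussian` helper built from the actor's `(mean, std, log_std)` outputs and exposing `sample()` and `log_likelihood()`. That class is not shown here; a minimal TF2 sketch consistent with how it is called (constructor arguments and method names taken from the snippets, everything else assumed) might look like this:

import numpy as np
import tensorflow as tf

class DiagonalGaussian:
    """Hypothetical diagonal-Gaussian policy head (sketch, not the original implementation)."""

    def __init__(self, mean, std, log_std):
        self.mean = mean        # shape (batch, act_dim)
        self.std = std          # shape (batch, act_dim) or (act_dim,)
        self.log_std = log_std

    def sample(self):
        # reparameterized draw: mean + std * eps, with eps ~ N(0, I)
        return self.mean + self.std * tf.random.normal(tf.shape(self.mean))

    def log_likelihood(self, x):
        # sum of per-dimension Gaussian log-densities
        pre_sum = -0.5 * (((x - self.mean) / (self.std + 1e-8)) ** 2
                          + 2.0 * self.log_std + np.log(2.0 * np.pi))
        return tf.reduce_sum(pre_sum, axis=-1)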
def sample(means, logvars, latent_dim, iaf=True, kl_min=None, anneal=False,
           kl_rate=None, dtype=tf.float32):
    """Perform sampling and calculate the KL divergence.

    Args:
      means: tensor of shape (batch_size, latent_dim).
      logvars: tensor of shape (batch_size, latent_dim).
      latent_dim: dimension of the latent space.
      iaf: whether to apply a linear IAF transformation.
      kl_min: lower bound for the KL divergence.
      anneal: whether to perform KL cost annealing.
      kl_rate: the KL divergence is multiplied by kl_rate if anneal is True.

    Returns:
      latent_vector: latent variable after sampling, of shape (batch_size, latent_dim).
      kl_obj: objective to be minimized for the KL term.
      kl_cost: the actual KL divergence.
    """
    if iaf:
        with tf.variable_scope('iaf'):
            prior = DiagonalGaussian(tf.zeros_like(means, dtype=dtype),
                                     tf.zeros_like(logvars, dtype=dtype))
            posterior = DiagonalGaussian(means, logvars)
            z = posterior.sample
            logqs = posterior.logps(z)
            L = tf.get_variable("inverse_cholesky", [latent_dim, latent_dim],
                                dtype=dtype, initializer=tf.zeros_initializer)
            diag_one = tf.ones([latent_dim], dtype=dtype)
            L = tf.matrix_set_diag(L, diag_one)
            # keep only the lower-triangular part (unit diagonal)
            mask = np.tril(np.ones([latent_dim, latent_dim]))
            L = L * mask
            latent_vector = tf.matmul(z, L)
            logps = prior.logps(latent_vector)
            kl_cost = logqs - logps
    else:
        noise = tf.random_normal(tf.shape(means))
        latent_vector = means + tf.exp(0.5 * logvars) * noise
        kl_cost = -0.5 * (logvars - tf.square(means) - tf.exp(logvars) + 1.0)
    kl_ave = tf.reduce_mean(kl_cost, [0])  # mean of kl_cost over the batch
    kl_obj = kl_cost = tf.reduce_sum(kl_ave)
    if kl_min:
        kl_obj = tf.reduce_sum(tf.maximum(kl_ave, kl_min))
    if anneal:
        kl_obj = kl_obj * kl_rate
    return latent_vector, kl_obj, kl_cost  # both kl_obj and kl_cost are scalars
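This graph-mode (TF1) variant constructs `DiagonalGaussian(mean, logvar)` and uses a `sample` attribute plus a `logps` method. Note that because `L` is lower-triangular with a unit diagonal, the linear IAF transform `tf.matmul(z, L)` has a log-determinant of zero, so `logqs` needs no Jacobian correction and `logqs - logps` is already a per-sample KL estimate. A sketch of the assumed distribution class (names and signatures inferred from the call sites; the real implementation may differ):

import numpy as np
import tensorflow as tf  # TF1.x-style API, matching the snippet above

class DiagonalGaussian(object):
    """Hypothetical diagonal Gaussian parameterized by mean and log-variance."""

    def __init__(self, mean, logvar):
        self.mean = mean
        self.logvar = logvar
        # `sample` is accessed as an attribute in the snippet, so draw it here
        noise = tf.random_normal(tf.shape(mean))
        self.sample = mean + tf.exp(0.5 * logvar) * noise

    def logps(self, z):
        # elementwise log-density; the caller reduces over dimensions as needed
        return -0.5 * (np.log(2.0 * np.pi) + self.logvar
                       + tf.square(z - self.mean) / tf.exp(self.logvar))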
def update(self, states, actions, returns, advantages):
    mean, std, log_std = self.actor(states)
    old_pi = DiagonalGaussian(mean, std, log_std)
    for i in range(self.train_a_iters):
        kl = self.update_actor(states, actions, advantages, old_pi)
        if kl > tf.constant(1.5 * self.kl_target):
            print('Early stopping at step %d due to reaching max kl.' % i)
            break
    for i in range(self.train_c_iters):
        self.update_critic(states, returns)
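Here `old_pi` is built once from the pre-update actor outputs and reused across all actor iterations, so `approx_kl` measures drift from the behavior policy. The `update_critic` method is referenced but not shown in this example; a minimal sketch of what such a value-function update typically looks like (an assumption, not the original code), mirroring the gradient handling in `update_actor`:

def update_critic(self, states, returns):
    # hypothetical mean-squared-error update of the value function;
    # `self.critic_optimizer` is assumed to exist alongside `self.actor_optimizer`
    with tf.GradientTape() as tape:
        values = tf.squeeze(self.critic(states), axis=-1)
        loss = tf.math.reduce_mean(tf.square(returns - values))
    grad = tape.gradient(loss, self.critic.trainable_weights)
    self.critic_optimizer.apply_gradients(zip(grad, self.critic.trainable_weights))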
def get_dists(self, obs):
    obs = torch.from_numpy(obs)
    mean, log_std = self.network(obs)
    dist = DiagonalGaussian(mean, log_std)
    return dist
def actions(self, obs):
    obs = torch.from_numpy(obs)
    mean, log_std = self.network(obs)
    dist = DiagonalGaussian(mean, log_std)
    sample = dist.sample()
    return sample, dist.logli(sample)
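These two policy methods assume a PyTorch `DiagonalGaussian` built from `(mean, log_std)` with `sample()` and `logli()`. A minimal sketch consistent with that interface (assumed, not the original class):

import math
import torch

class DiagonalGaussian:
    """Hypothetical diagonal Gaussian over actions, parameterized by mean and log-std."""

    def __init__(self, mean, log_std):
        self.mean = mean
        self.log_std = log_std

    def sample(self):
        # reparameterized draw: mean + std * eps, with eps ~ N(0, I)
        return self.mean + torch.exp(self.log_std) * torch.randn_like(self.mean)

    def logli(self, sample):
        # log-likelihood summed over the action dimension
        z = (sample - self.mean) / torch.exp(self.log_std)
        return (-0.5 * z ** 2 - self.log_std - 0.5 * math.log(2 * math.pi)).sum(dim=-1)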
class DenseLatentVariableModel(object):

    def __init__(self, train_config, arch, data_loader):
        self.encoding_form = arch['encoding_form']
        self.constant_variances = arch['constant_prior_variances']
        self.single_output_variance = arch['single_output_variance']
        self.posterior_form = arch['posterior_form']
        self.batch_size = train_config['batch_size']
        self.n_training_samples = train_config['n_samples']
        self.kl_min = train_config['kl_min']
        self.concat_variables = arch['concat_variables']
        self.top_size = arch['top_size']
        self.input_size = np.prod(tuple(next(iter(data_loader))[0].size()[1:])).astype(int)
        assert train_config['output_distribution'] in ['bernoulli', 'gaussian', 'multinomial'], \
            'Output distribution not recognized.'
        self.output_distribution = train_config['output_distribution']
        self.reconstruction = None
        self.kl_weight = 1.

        # construct the model
        self.levels = [None for _ in range(len(arch['n_latent']))]
        self.output_decoder = self.output_dist = self.mean_output = None
        self.log_var_output = self.trainable_log_var = None
        self.state_optimizer = None
        self.__construct__(arch)

        self._cuda_device = None
        if train_config['cuda_device'] is not None:
            self.cuda(train_config['cuda_device'])

    def __construct__(self, arch):
        """
        Construct the model from the architecture dictionary.
        :param arch: architecture dictionary
        :return None
        """
        # these are the same across all latent levels
        encoding_form = arch['encoding_form']
        variable_update_form = arch['variable_update_form']
        const_prior_var = arch['constant_prior_variances']
        posterior_form = arch['posterior_form']

        latent_level_type = RecurrentLatentLevel if arch['encoder_type'] == 'recurrent' else DenseLatentLevel

        encoder_arch = None
        if arch['encoder_type'] == 'inference_model':
            encoder_arch = dict()
            encoder_arch['non_linearity'] = arch['non_linearity_enc']
            encoder_arch['connection_type'] = arch['connection_type_enc']
            encoder_arch['batch_norm'] = arch['batch_norm_enc']
            encoder_arch['weight_norm'] = arch['weight_norm_enc']
            encoder_arch['dropout'] = arch['dropout_enc']

        decoder_arch = dict()
        decoder_arch['non_linearity'] = arch['non_linearity_dec']
        decoder_arch['connection_type'] = arch['connection_type_dec']
        decoder_arch['batch_norm'] = arch['batch_norm_dec']
        decoder_arch['weight_norm'] = arch['weight_norm_dec']
        decoder_arch['dropout'] = arch['dropout_dec']

        # construct a DenseLatentLevel for each level of latent variables
        for level in range(len(arch['n_latent'])):
            # get specifications for this level's encoder and decoder
            if arch['encoder_type'] == 'inference_model':
                encoder_arch['n_in'] = self.encoder_input_size(level, arch)
                encoder_arch['n_units'] = arch['n_units_enc'][level]
                encoder_arch['n_layers'] = arch['n_layers_enc'][level]

            decoder_arch['n_in'] = self.decoder_input_size(level, arch)
            decoder_arch['n_units'] = arch['n_units_dec'][level+1]
            decoder_arch['n_layers'] = arch['n_layers_dec'][level+1]

            n_latent = arch['n_latent'][level]
            n_det = [arch['n_det_enc'][level], arch['n_det_dec'][level]]

            learn_prior = True if arch['learn_top_prior'] else (level != len(arch['n_latent'])-1)

            self.levels[level] = latent_level_type(self.batch_size, encoder_arch, decoder_arch,
                                                   n_latent, n_det, encoding_form, const_prior_var,
                                                   variable_update_form, posterior_form, learn_prior)

        # construct the output decoder
        decoder_arch['n_in'] = self.decoder_input_size(-1, arch)
        decoder_arch['n_units'] = arch['n_units_dec'][0]
        decoder_arch['n_layers'] = arch['n_layers_dec'][0]
        self.output_decoder = MultiLayerPerceptron(**decoder_arch)

        # construct the output distribution
        if self.output_distribution == 'bernoulli':
            self.output_dist = Bernoulli(self.input_size, None)
            self.mean_output = Dense(arch['n_units_dec'][0], self.input_size,
                                     non_linearity='sigmoid', weight_norm=arch['weight_norm_dec'])
        elif self.output_distribution == 'multinomial':
            self.output_dist = Multinomial(self.input_size, None)
            self.mean_output = Dense(arch['n_units_dec'][0], self.input_size,
                                     non_linearity='linear', weight_norm=arch['weight_norm_dec'])
        elif self.output_distribution == 'gaussian':
            self.output_dist = DiagonalGaussian(self.input_size, None, None)
            self.mean_output = Dense(arch['n_units_dec'][0], self.input_size,
                                     non_linearity='sigmoid', weight_norm=arch['weight_norm_dec'])
            if self.constant_variances:
                if arch['single_output_variance']:
                    self.trainable_log_var = Variable(torch.zeros(1), requires_grad=True)
                else:
                    self.trainable_log_var = Variable(torch.normal(torch.zeros(self.input_size), 0.25),
                                                      requires_grad=True)
            else:
                self.log_var_output = Dense(arch['n_units_dec'][0], self.input_size,
                                            weight_norm=arch['weight_norm_dec'])

        # make the state trainable if encoder_type is EM
        if arch['encoder_type'] in ['em', 'EM']:
            self.trainable_state()

    def encoder_input_size(self, level_num, arch):
        """
        Calculates the size of the encoding input to a level. If we're encoding the gradient,
        then the encoding size is the size of the latent variables (x 2 if Gaussian variable).
        Otherwise, the encoding size depends on how many errors/variables we're encoding.
        :param level_num: the index of the level we're calculating the encoding size for
        :param arch: architecture dictionary
        :return: the size of this level's encoder's input
        """
        def _encoding_size(_self, _level_num, _arch, lower_level=False):
            if _level_num == 0:
                latent_size = _self.input_size
                det_size = 0
            else:
                latent_size = _arch['n_latent'][_level_num-1]
                det_size = _arch['n_det_enc'][_level_num-1]

            encoding_size = det_size

            if 'posterior' in _self.encoding_form:
                encoding_size += latent_size
            if 'mean' in _self.encoding_form and not lower_level:
                encoding_size += _arch['n_latent'][_level_num]
            if 'l2_norm_mean' in _self.encoding_form and not lower_level:
                encoding_size += _arch['n_latent'][_level_num]
            if 'layer_norm_mean' in _self.encoding_form and not lower_level:
                encoding_size += _arch['n_latent'][_level_num]
            if 'mean_gradient' in _self.encoding_form and not lower_level:
                encoding_size += _arch['n_latent'][_level_num]
            if 'l2_norm_mean_gradient' in _self.encoding_form and not lower_level:
                encoding_size += _arch['n_latent'][_level_num]
            if 'layer_norm_mean_gradient' in _self.encoding_form and not lower_level:
                encoding_size += _arch['n_latent'][_level_num]
            if 'log_var_gradient' in _self.encoding_form and not lower_level:
                encoding_size += _arch['n_latent'][_level_num]
            if 'l2_norm_log_var_gradient' in _self.encoding_form and not lower_level:
                encoding_size += _arch['n_latent'][_level_num]
            if 'layer_norm_log_var_gradient' in _self.encoding_form and not lower_level:
                encoding_size += _arch['n_latent'][_level_num]
            if 'log_var' in _self.encoding_form and not lower_level:
                encoding_size += _arch['n_latent'][_level_num]
            if 'l2_norm_log_var' in _self.encoding_form and not lower_level:
                encoding_size += _arch['n_latent'][_level_num]
            if 'layer_norm_log_var' in _self.encoding_form and not lower_level:
                encoding_size += _arch['n_latent'][_level_num]
            if 'var' in _self.encoding_form and not lower_level:
                encoding_size += _arch['n_latent'][_level_num]
            if 'bottom_error' in _self.encoding_form:
                encoding_size += latent_size
            if 'l2_norm_bottom_error' in _self.encoding_form:
                encoding_size += latent_size
            if 'layer_norm_bottom_error' in _self.encoding_form:
                encoding_size += latent_size
            if 'bottom_norm_error' in _self.encoding_form:
                encoding_size += latent_size
            if 'l2_norm_bottom_norm_error' in _self.encoding_form:
                encoding_size += latent_size
            if 'layer_norm_bottom_norm_error' in _self.encoding_form:
                encoding_size += latent_size
            if 'top_error' in _self.encoding_form and not lower_level:
                encoding_size += _arch['n_latent'][_level_num]
            if 'l2_norm_top_error' in _self.encoding_form and not lower_level:
                encoding_size += _arch['n_latent'][_level_num]
            if 'layer_norm_top_error' in _self.encoding_form and not lower_level:
                encoding_size += _arch['n_latent'][_level_num]
            if 'top_norm_error' in _self.encoding_form and not lower_level:
                encoding_size += _arch['n_latent'][_level_num]
            if 'l2_norm_top_norm_error' in _self.encoding_form and not lower_level:
                encoding_size += _arch['n_latent'][_level_num]
            if 'layer_norm_top_norm_error' in _self.encoding_form and not lower_level:
                encoding_size += _arch['n_latent'][_level_num]
            if 'gradient' in _self.encoding_form and not lower_level:
                encoding_size += _arch['n_latent'][_level_num]
                if self.posterior_form == 'gaussian':
                    encoding_size += _arch['n_latent'][_level_num]
            if 'l2_norm_gradient' in _self.encoding_form and not lower_level:
                encoding_size += _arch['n_latent'][_level_num]
                if self.posterior_form == 'gaussian':
                    encoding_size += _arch['n_latent'][_level_num]
            if 'log_gradient' in _self.encoding_form and not lower_level:
                encoding_size += _arch['n_latent'][_level_num]
                if self.posterior_form == 'gaussian':
                    encoding_size += _arch['n_latent'][_level_num]
            if 'scaled_log_gradient' in _self.encoding_form and not lower_level:
                encoding_size += _arch['n_latent'][_level_num]
                if self.posterior_form == 'gaussian':
                    encoding_size += _arch['n_latent'][_level_num]
            if 'sign_gradient' in _self.encoding_form and not lower_level:
                encoding_size += _arch['n_latent'][_level_num]
                if self.posterior_form == 'gaussian':
                    encoding_size += _arch['n_latent'][_level_num]

            return encoding_size

        encoder_size = _encoding_size(self, level_num, arch)
        if 'gradient' not in self.encoding_form:
            if self.concat_variables:
                for level in range(level_num):
                    encoder_size += _encoding_size(self, level, arch, lower_level=True)
        return encoder_size

    def decoder_input_size(self, level_num, arch):
        """Calculates the size of the decoding input to a level."""
        if level_num == len(arch['n_latent'])-1:
            return self.top_size

        decoder_size = arch['n_latent'][level_num+1] + arch['n_det_dec'][level_num+1]
        if self.concat_variables:
            for level in range(level_num+2, len(arch['n_latent'])):
                decoder_size += (arch['n_latent'][level] + arch['n_det_dec'][level])
        return decoder_size

    def process_input(self, input):
        """
        Whitens or scales the input.
        :param input: the input data
        """
        if self.output_distribution == 'multinomial':
            return input
        return input / 255.

    def process_output(self, mean):
        """
        Colors or un-scales the output.
        :param mean: the mean of the output distribution
        :return the unnormalized or unscaled mean
        """
        if self.output_distribution == 'multinomial':
            return mean
        return 255. * mean

    def get_input_encoding(self, input):
        """
        Encoding at the bottom level.
        :param input: the input data
        :return the encoding of the data
        """
        if 'bottom_error' in self.encoding_form or 'bottom_norm_error' in self.encoding_form:
            assert self.output_dist is not None, 'Cannot encode error. Output distribution is None.'
        encoding = None
        if 'posterior' in self.encoding_form:
            encoding = input - 0.5
        if 'bottom_error' in self.encoding_form:
            error = input - self.output_dist.mean.detach().mean(dim=1)
            encoding = error if encoding is None else torch.cat((encoding, error), dim=1)
        if 'norm_bottom_error' in self.encoding_form:
            error = input - self.output_dist.mean.detach().mean(dim=1)
            norm_error = error / torch.norm(error, 2, 1, True)
            encoding = norm_error if encoding is None else torch.cat((encoding, norm_error), dim=1)
        if 'log_bottom_error' in self.encoding_form:
            log_error = torch.log(torch.abs(input - self.output_dist.mean.detach().mean(dim=1)))
            encoding = log_error if encoding is None else torch.cat((encoding, log_error), dim=1)
        if 'sign_bottom_error' in self.encoding_form:
            sign_error = torch.sign(input - self.output_dist.mean.detach())
            encoding = sign_error if encoding is None else torch.cat((encoding, sign_error), dim=1)
        if 'bottom_norm_error' in self.encoding_form:
            error = input - self.output_dist.mean.detach().mean(dim=1)
            norm_error = None
            if self.output_distribution == 'gaussian':
                norm_error = error / torch.exp(self.output_dist.log_var.detach().mean(dim=1))
            elif self.output_distribution == 'bernoulli':
                mean = self.output_dist.mean.detach().mean(dim=1)
                norm_error = error * torch.exp(- torch.log(mean + 1e-5) - torch.log(1 - mean + 1e-5))
            encoding = norm_error if encoding is None else torch.cat((encoding, norm_error), dim=1)
        if 'norm_bottom_norm_error' in self.encoding_form:
            error = input - self.output_dist.mean.detach().mean(dim=1)
            norm_error = None
            if self.output_distribution == 'gaussian':
                norm_error = error / torch.exp(self.output_dist.log_var.detach().mean(dim=1))
            elif self.output_distribution == 'bernoulli':
                mean = self.output_dist.mean.detach().mean(dim=1)
                norm_error = error * torch.exp(- torch.log(mean + 1e-5) - torch.log(1 - mean + 1e-5))
            norm_norm_error = norm_error / torch.norm(norm_error, 2, 1, True)
            encoding = norm_norm_error if encoding is None else torch.cat((encoding, norm_norm_error), dim=1)
        return encoding

    def encode(self, input):
        """
        Encodes the input into an updated posterior estimate.
        :param input: the data input
        :return None
        """
        if self.state_optimizer is None:
            if self._cuda_device is not None:
                input = input.cuda(self._cuda_device)
            input = self.process_input(input.view(-1, self.input_size))
            h = self.get_input_encoding(input)
            for latent_level in self.levels:
                if self.concat_variables:
                    h = torch.cat([h, latent_level.encode(h)], dim=1)
                else:
                    h = latent_level.encode(h)

    def decode(self, n_samples=0, generate=False):
        """
        Decodes the posterior (prior) estimate to get a reconstruction (sample).
        :param n_samples: number of samples to decode
        :param generate: flag to generate or reconstruct the data
        :return output distribution of reconstruction/sample
        """
        if n_samples == 0:
            n_samples = self.n_training_samples
        h = Variable(torch.zeros(self.batch_size, n_samples, self.top_size))
        if self._cuda_device is not None:
            h = h.cuda(self._cuda_device)
        concat = False
        for latent_level in self.levels[::-1]:
            if self.concat_variables and concat:
                h = torch.cat([h, latent_level.decode(h, n_samples, generate)], dim=2)
            else:
                h = latent_level.decode(h, n_samples, generate)
            concat = True
        h = h.view(-1, h.size()[2])
        h = self.output_decoder(h)
        mean_out = self.mean_output(h)
        mean_out = mean_out.view(self.batch_size, n_samples, self.input_size)
        self.output_dist.mean = mean_out
        if self.output_distribution == 'gaussian':
            if self.constant_variances:
                if self.single_output_variance:
                    self.output_dist.log_var = torch.clamp(
                        self.trainable_log_var * Variable(torch.ones(self.batch_size, n_samples,
                                                                     self.input_size).cuda(self._cuda_device)),
                        -7, 15)
                else:
                    self.output_dist.log_var = torch.clamp(
                        self.trainable_log_var.view(1, 1, -1).repeat(self.batch_size, n_samples, 1), -7., 15)
            else:
                log_var_out = self.log_var_output(h)
                log_var_out = log_var_out.view(self.batch_size, n_samples, self.input_size)
                self.output_dist.log_var = torch.clamp(log_var_out, -7., 15)
        self.reconstruction = self.output_dist.mean[:, 0, :]
        if self.output_distribution in ['gaussian', 'bernoulli']:
            self.reconstruction = self.reconstruction * 255.
        return self.output_dist

    def kl_divergences(self, averaged=False):
        """
        Returns a list containing kl divergences at each level.
        :param averaged: whether to average across the batch dimension
        :return list of KL divergences at each level
        """
        kl = []
        for latent_level in range(len(self.levels)-1):
            # for latent_level in range(len(self.levels)):
            kl.append(torch.clamp(self.levels[latent_level].kl_divergence(), min=self.kl_min).sum(dim=2))
        kl.append(self.levels[-1].latent.analytical_kl().sum(dim=2))
        if averaged:
            return [level_kl.mean() for level_kl in kl]
        else:
            return kl

    def conditional_log_likelihoods(self, input, averaged=False):
        """
        Returns the conditional likelihood.
        :param input: the input data
        :param averaged: whether to average across the batch dimension
        :return the conditional log likelihood
        """
        if self._cuda_device is not None:
            input = input.cuda(self._cuda_device)
        input = input.view(-1, 1, self.input_size) / 255.
        # input = self.process_input(input.view(-1, self.input_size))
        n_samples = self.output_dist.mean.data.shape[1]
        input = input.repeat(1, n_samples, 1)
        log_prob = self.output_dist.log_prob(sample=input)
        if self.output_distribution == 'gaussian':
            log_prob = log_prob - np.log(256.)
        log_prob = log_prob.sum(dim=2)
        if averaged:
            return log_prob.mean()
        else:
            return log_prob

    def elbo(self, input, averaged=False):
        """
        Returns the ELBO.
        :param input: the input data
        :param averaged: whether to average across the batch dimension
        :return the ELBO
        """
        cond_like = self.conditional_log_likelihoods(input)
        kl = sum(self.kl_divergences())
        lower_bound = (cond_like - self.kl_weight * kl).mean(dim=1)  # average across sample dimension
        if averaged:
            return lower_bound.mean()
        else:
            return lower_bound

    def losses(self, input, averaged=False):
        """
        Returns all losses.
        :param input: the input data
        :param averaged: whether to average across the batch dimension
        """
        cll = self.conditional_log_likelihoods(input)
        cond_log_like = cll.mean(dim=1)
        kld = self.kl_divergences()
        kl_div = [kl.mean(dim=1) for kl in kld]
        lower_bound = (cll - self.kl_weight * sum(kld)).mean(dim=1)
        if averaged:
            return lower_bound.mean(), cond_log_like.mean(), [level_kl.mean() for level_kl in kl_div]
        else:
            return lower_bound, cond_log_like, kl_div

    def state_gradients(self):
        """
        Get the gradients for the approximate posterior parameters.
        :return: dictionary containing keys for each level with lists of gradients
        """
        state_grads = {}
        for level_num, latent_level in enumerate(self.levels):
            state_grads[level_num] = latent_level.state_gradients()
        return state_grads

    def reset_state(self, mean=None, log_var=None, from_prior=True):
        """Resets the posterior estimate."""
        for latent_level in self.levels:
            latent_level.reset(mean=mean, log_var=log_var, from_prior=from_prior)

    def trainable_state(self):
        """Makes the posterior estimate trainable."""
        for latent_level in self.levels:
            latent_level.trainable_state()

    def not_trainable_state(self):
        """Makes the posterior estimate not trainable."""
        for latent_level in self.levels:
            latent_level.not_trainable_state()

    def parameters(self):
        """Returns a list containing all parameters."""
        return self.encoder_parameters() + self.decoder_parameters() + self.state_parameters()

    def encoder_parameters(self):
        """Returns a list containing all parameters in the encoder."""
        params = []
        for level in self.levels:
            params.extend(level.encoder_parameters())
        return params

    def decoder_parameters(self):
        """Returns a list containing all parameters in the decoder."""
        params = []
        for level in self.levels:
            params.extend(level.decoder_parameters())
        params.extend(list(self.output_decoder.parameters()))
        params.extend(list(self.mean_output.parameters()))
        if self.output_distribution == 'gaussian':
            if self.constant_variances:
                params.append(self.trainable_log_var)
            else:
                params.extend(list(self.log_var_output.parameters()))
        return params

    def state_parameters(self):
        """Returns a list containing the posterior estimate (state)."""
        states = []
        for latent_level in self.levels:
            states.extend(list(latent_level.state_parameters()))
        return states

    def eval(self):
        """Puts the model into eval mode (affects batch_norm and dropout)."""
        for latent_level in self.levels:
            latent_level.eval()
        self.output_decoder.eval()
        self.mean_output.eval()
        if self.output_distribution == 'gaussian':
            if not self.constant_variances:
                self.log_var_output.eval()

    def train(self):
        """Puts the model into train mode (affects batch_norm and dropout)."""
        for latent_level in self.levels:
            latent_level.train()
        self.output_decoder.train()
        self.mean_output.train()
        if self.output_distribution == 'gaussian':
            if not self.constant_variances:
                self.log_var_output.train()

    def random_re_init(self, re_init_fraction=0.05):
        """Randomly re-initializes a fraction of all of the weights in the model."""
        for level in self.levels:
            level.random_re_init(re_init_fraction)
        self.output_decoder.random_re_init(re_init_fraction)
        self.mean_output.random_re_init(re_init_fraction)
        if self.output_distribution == 'gaussian':
            if not self.constant_variances:
                self.log_var_output.random_re_init(re_init_fraction)

    def cuda(self, device_id=0):
        """Places the model on the GPU."""
        self._cuda_device = device_id
        for latent_level in self.levels:
            latent_level.cuda(device_id)
        self.output_decoder = self.output_decoder.cuda(device_id)
        self.mean_output = self.mean_output.cuda(device_id)
        self.output_dist.cuda(device_id)
        if self.output_distribution == 'gaussian':
            if self.constant_variances:
                self.trainable_log_var = Variable(self.trainable_log_var.data.cuda(device_id),
                                                  requires_grad=True)
                self.log_var_output = self.trainable_log_var.unsqueeze(0).repeat(self.batch_size, 1)
            else:
                self.log_var_output = self.log_var_output.cuda(device_id)

    def cpu(self):
        """Places the model on the CPU."""
        self._cuda_device = None
        for latent_level in self.levels:
            latent_level.cpu()
        self.output_decoder = self.output_decoder.cpu()
        self.mean_output = self.mean_output.cpu()
        self.output_dist.cpu()
        if self.output_distribution == 'gaussian':
            if self.constant_variances:
                self.trainable_log_var = self.trainable_log_var.cpu()
                self.log_var_output = self.trainable_log_var.unsqueeze(0).repeat(self.batch_size, 1)
            else:
                self.log_var_output = self.log_var_output.cpu()
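For context, a model like this is typically driven by an outer loop that encodes a batch, decodes a reconstruction, and maximizes the ELBO. The following is only a rough sketch under the assumption of the amortized 'inference_model' encoder type; `train_config`, `arch`, `data_loader`, and the learning rate are placeholders following the conventions used by the class above, not part of the original example:

import torch
from torch.autograd import Variable

# hypothetical configuration objects matching the dictionary keys used above
model = DenseLatentVariableModel(train_config, arch, data_loader)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # learning rate is an assumption

for batch, _ in data_loader:
    model.train()
    model.reset_state()                     # start the posterior estimate from the prior
    model.encode(Variable(batch.float()))   # update the posterior with the inference model
    model.decode()                          # populate the output distribution
    loss = -model.elbo(Variable(batch.float()), averaged=True)  # maximize the ELBO
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()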