class FullyConnectedLatentVariable(LatentVariable):
    """
    A fully-connected (Gaussian) latent variable.

    Args:
        latent_config (dict): dictionary containing variable configuration
                              parameters: n_variables, n_in, inference_procedure,
                              normalize_samples
    """
    def __init__(self, latent_config):
        super(FullyConnectedLatentVariable, self).__init__(latent_config)
        self._construct(latent_config)

    def _construct(self, latent_config):
        """
        Method to construct the latent variable from the latent_config dictionary.
        """
        self.inference_procedure = latent_config['inference_procedure']
        if self.inference_procedure in ['gradient', 'error']:
            self.update_type = latent_config['update_type']
        n_variables = latent_config['n_variables']
        n_inputs = latent_config['n_in']
        self.normalize_samples = latent_config['normalize_samples']
        if self.normalize_samples:
            self.normalizer = LayerNorm()
        if self.inference_procedure in ['direct', 'gradient', 'error']:
            # approximate posterior inputs
            self.inf_mean_output = FullyConnectedLayer({'n_in': n_inputs[0],
                                                        'n_out': n_variables})
            self.inf_log_var_output = FullyConnectedLayer({'n_in': n_inputs[0],
                                                           'n_out': n_variables})
            # self.approx_post_mean = FullyConnectedLayer({'n_in': n_inputs[0],
            #                                              'n_out': n_variables})
            # self.approx_post_log_var = FullyConnectedLayer({'n_in': n_inputs[0],
            #                                                 'n_out': n_variables})
        if self.inference_procedure in ['gradient', 'error']:
            self.approx_post_mean_gate = FullyConnectedLayer({'n_in': n_inputs[0],
                                                              'n_out': n_variables,
                                                              'non_linearity': 'sigmoid'})
            self.approx_post_log_var_gate = FullyConnectedLayer({'n_in': n_inputs[0],
                                                                 'n_out': n_variables,
                                                                 'non_linearity': 'sigmoid'})
            # self.close_gates()
        if self.inference_procedure == 'sgd':
            self.learning_rate = latent_config['inf_lr']

        # prior inputs
        self.prior_mean = FullyConnectedLayer({'n_in': n_inputs[1],
                                               'n_out': n_variables})
        self.prior_log_var = FullyConnectedLayer({'n_in': n_inputs[1],
                                                  'n_out': n_variables})

        # distributions
        self.approx_post = Normal()
        self.prior = Normal()
        self.approx_post.re_init()
        self.prior.re_init()

    def infer(self, input):
        """
        Method to perform inference.

        Args:
            input (Tensor): input to the inference procedure
        """
        if self.inference_procedure in ['direct', 'gradient', 'error']:
            approx_post_mean = self.inf_mean_output(input)
            approx_post_log_var = self.inf_log_var_output(input)
            # approx_post_mean = self.approx_post_mean(input)
            # approx_post_log_var = self.approx_post_log_var(input)

        if self.inference_procedure == 'direct':
            self.approx_post.mean = approx_post_mean
            self.approx_post.log_var = torch.clamp(approx_post_log_var, -15, 5)
        elif self.inference_procedure in ['gradient', 'error']:
            if self.update_type == 'highway':
                # gated highway update
                approx_post_mean_gate = self.approx_post_mean_gate(input)
                self.approx_post.mean = approx_post_mean_gate * self.approx_post.mean.detach() \
                                        + (1 - approx_post_mean_gate) * approx_post_mean
                approx_post_log_var_gate = self.approx_post_log_var_gate(input)
                self.approx_post.log_var = torch.clamp(approx_post_log_var_gate * self.approx_post.log_var.detach() \
                                                       + (1 - approx_post_log_var_gate) * approx_post_log_var, -15, 5)
            elif self.update_type == 'learned_sgd':
                # SGD-style update with a learned learning rate and offset
                mean_grad, log_var_grad = self.approx_posterior_gradients()
                mean_lr = self.approx_post_mean_gate(input)
                log_var_lr = self.approx_post_log_var_gate(input)
                self.approx_post.mean = self.approx_post.mean.detach() - mean_lr * mean_grad + approx_post_mean
                self.approx_post.log_var = torch.clamp(self.approx_post.log_var.detach()
                                                       - log_var_lr * log_var_grad
                                                       + approx_post_log_var, -15, 5)
        elif self.inference_procedure == 'sgd':
            self.approx_post.mean = self.approx_post.mean.detach() - self.learning_rate * input[0]
            self.approx_post.log_var = torch.clamp(self.approx_post.log_var.detach()
                                                   - self.learning_rate * input[1], -15, 5)
            self.approx_post.mean.requires_grad = True
            self.approx_post.log_var.requires_grad = True
        else:
            raise NotImplementedError

        if self.normalize_samples:
            # apply layer normalization to the approximate posterior means
            self.approx_post.mean = self.normalizer(self.approx_post.mean)

        # retain the gradients (for inference)
        self.approx_post.mean.retain_grad()
        self.approx_post.log_var.retain_grad()

    def generate(self, input, gen, n_samples):
        """
        Method to generate, i.e. run the model forward.

        Args:
            input (Tensor): input to the generative procedure
            gen (boolean): whether to sample from the prior (True) or the
                           approximate posterior (False)
            n_samples (int): number of samples to draw
        """
        if input is not None:
            b, s, n = input.data.shape
            input = input.view(b * s, n)
            self.prior.mean = self.prior_mean(input).view(b, s, -1)
            self.prior.log_var = torch.clamp(self.prior_log_var(input).view(b, s, -1), -15, 5)
        dist = self.prior if gen else self.approx_post
        sample = dist.sample(n_samples, resample=True)
        sample = sample.detach() if self.detach else sample
        return sample

    def re_init(self):
        """
        Method to reinitialize the approximate posterior and prior over the variable.
        """
        # TODO: this is wrong. we shouldn't set the posterior to the prior, then zero out the prior...
        self.re_init_approx_posterior()
        self.prior.re_init()

    def re_init_approx_posterior(self):
        """
        Method to reinitialize the approximate posterior.
        """
        mean = self.prior.mean.detach().mean(dim=1).data
        log_var = self.prior.log_var.detach().mean(dim=1).data
        self.approx_post.re_init(mean, log_var)

    def step(self):
        """
        Method to step the latent variable forward in the sequence.
        """
        pass

    def error(self, averaged=True):
        """
        Calculates the Gaussian (precision-weighted) error for encoding.

        Args:
            averaged (boolean): whether or not to average over samples
        """
        sample = self.approx_post.sample()
        n_samples = sample.data.shape[1]
        prior_mean = self.prior.mean.detach()
        if len(prior_mean.data.shape) == 2:
            prior_mean = prior_mean.unsqueeze(1).repeat(1, n_samples, 1)
        prior_log_var = self.prior.log_var.detach()
        if len(prior_log_var.data.shape) == 2:
            prior_log_var = prior_log_var.unsqueeze(1).repeat(1, n_samples, 1)
        # the stabilizing constant is added outside of the exponential
        n_error = (sample - prior_mean) / (torch.exp(prior_log_var) + 1e-7)
        if averaged:
            n_error = n_error.mean(dim=1)
        return n_error

    def close_gates(self):
        nn.init.constant(self.approx_post_mean_gate.linear.bias, 5.)
        nn.init.constant(self.approx_post_log_var_gate.linear.bias, 5.)

    def inference_parameters(self):
        """
        Method to obtain inference parameters.
        """
        params = nn.ParameterList()
        params.extend(list(self.inf_mean_output.parameters()))
        params.extend(list(self.inf_log_var_output.parameters()))
        # params.extend(list(self.approx_post_mean.parameters()))
        # params.extend(list(self.approx_post_log_var.parameters()))
        if self.inference_procedure != 'direct':
            params.extend(list(self.approx_post_mean_gate.parameters()))
            params.extend(list(self.approx_post_log_var_gate.parameters()))
        return params

    def generative_parameters(self):
        """
        Method to obtain generative parameters.
        """
        params = nn.ParameterList()
        params.extend(list(self.prior_mean.parameters()))
        params.extend(list(self.prior_log_var.parameters()))
        return params

    def approx_posterior_parameters(self):
        return [self.approx_post.mean.detach(), self.approx_post.log_var.detach()]

    def approx_posterior_gradients(self):
        assert self.approx_post.mean.grad is not None, 'Approximate posterior gradients are None.'
        grads = [self.approx_post.mean.grad.detach()]
        grads += [self.approx_post.log_var.grad.detach()]
        for grad in grads:
            grad.volatile = False
        return grads
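
# ---------------------------------------------------------------------------
# Hypothetical usage sketch (not part of the original module): illustrates the
# latent_config keys read by _construct above and the basic infer/generate
# cycle. The dimensions, the dummy inputs, and the explicit `detach` flag are
# assumptions made for illustration (in the full models below, `detach` is set
# through inference_mode() / generative_mode()); assumes a recent PyTorch where
# Tensors and Variables are merged.
# ---------------------------------------------------------------------------
def _example_fc_latent_usage():
    latent_config = {'n_variables': 64,                # dimensionality of z
                     'n_in': (256, 256),               # (inference input size, prior input size)
                     'inference_procedure': 'direct',  # 'direct', 'gradient', 'error', or 'sgd'
                     'normalize_samples': False}
    latent = FullyConnectedLatentVariable(latent_config)
    latent.detach = False                              # keep samples in the graph
    enc = torch.randn(4, 256)                          # inference encoding, [batch, n_in[0]]
    dec = torch.randn(4, 1, 256)                       # prior input, [batch, 1, n_in[1]]
    latent.infer(enc)                                  # updates latent.approx_post in place
    z = latent.generate(dec, gen=False, n_samples=1)   # [batch, 1, n_variables]
    return z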
class SVG(LatentVariableModel):
    """
    Stochastic video generation (SVG) model from "Stochastic Video Generation
    with a Learned Prior," Denton & Fergus, 2018.

    Args:
        model_config (dict): dictionary containing model configuration params
    """
    def __init__(self, model_config):
        super(SVG, self).__init__(model_config)
        self._construct(model_config)

    def _construct(self, model_config):
        """
        Method for constructing the SVG model using the model configuration file.

        Args:
            model_config (dict): dictionary containing model configuration params
        """
        model_type = model_config['model_type'].lower()
        self.modified = model_config['modified']
        self.inference_procedure = model_config['inference_procedure'].lower()
        level_config = {}
        latent_config = {}
        latent_config['normalize_samples'] = model_config['normalize_latent_samples']
        latent_config['inference_procedure'] = self.inference_procedure
        # hard-coded because we handle inference here in the model
        level_config['inference_procedure'] = 'direct'
        if not self.modified:
            level_config['inference_config'] = {'n_layers': 1,
                                                'n_units': 256,
                                                'n_in': 128}
            latent_config['n_in'] = (256, 256)  # number of encoder, decoder units
        else:
            level_config['inference_config'] = None
            latent_config['n_in'] = [None, 256]  # number of encoder, decoder units
        level_config['generative_config'] = None

        if model_type == 'sm_mnist':
            from lib.modules.networks.dcgan_64 import encoder, decoder
            self.n_input_channels = 1
            self.encoder = encoder(128, self.n_input_channels)
            self.decoder = decoder(128, self.n_input_channels)
            self.output_dist = Bernoulli()
            latent_config['n_variables'] = 10
            if self.modified:
                if self.inference_procedure == 'direct':
                    pass
                elif self.inference_procedure == 'gradient':
                    pass
                elif self.inference_procedure == 'error':
                    pass
                else:
                    raise NotImplementedError
        elif model_type == 'kth_actions':
            from lib.modules.networks.vgg_64 import encoder, decoder
            self.n_input_channels = 1
            self.encoder = encoder(128, self.n_input_channels)
            if model_config['global_output_log_var']:
                output_channels = self.n_input_channels
                self.output_log_var = nn.Parameter(torch.zeros(self.n_input_channels, 64, 64))
            else:
                output_channels = 2 * self.n_input_channels
            self.decoder = decoder(128, output_channels)
            self.output_dist = Normal()
            latent_config['n_variables'] = 512
            if self.modified:
                if self.inference_procedure == 'direct':
                    # another convolutional encoder
                    self.inf_encoder = encoder(128, self.n_input_channels)
                    # fully-connected inference model
                    inf_config = {'n_layers': 2,
                                  'n_units': 256,
                                  'n_in': 128,
                                  'non_linearity': 'relu'}
                    self.inf_model = FullyConnectedNetwork(inf_config)
                    latent_config['n_in'][0] = 256
                elif self.inference_procedure == 'gradient':
                    # fully-connected encoder / latent inference model
                    n_units = 1024
                    inf_config = {'n_layers': 1,
                                  'n_units': n_units,
                                  'n_in': 4 * latent_config['n_variables'],
                                  'non_linearity': 'elu',
                                  'connection_type': 'highway'}
                    if model_config['concat_observation']:
                        inf_config['n_in'] += (self.n_input_channels * 64 * 64)
                    self.inf_model = FullyConnectedNetwork(inf_config)
                    latent_config['n_in'][0] = n_units
                    latent_config['update_type'] = model_config['update_type']
                elif self.inference_procedure == 'error':
                    # convolutional observation error encoder
                    obs_error_enc_config = {'n_layers': 3,
                                            'n_filters': 64,
                                            'n_in': self.n_input_channels,
                                            'filter_size': 3,
                                            'non_linearity': 'relu'}
                    if model_config['concat_observation']:
                        obs_error_enc_config['n_in'] += self.n_input_channels
                    self.obs_error_enc = ConvolutionalNetwork(obs_error_enc_config)
                    # fully-connected error encoder (latent error + params + encoded observation errors)
                    inf_config = {'n_layers': 3,
                                  'n_units': 1024,
                                  'n_in': 4 * latent_config['n_variables'],
                                  'non_linearity': 'relu'}
                    if model_config['concat_observation']:
                        inf_config['n_in'] += (self.n_input_channels * 64 * 64)
                    self.inf_model = FullyConnectedNetwork(inf_config)
                    latent_config['n_in'][0] = 1024
                    latent_config['update_type'] = model_config['update_type']
                else:
                    raise NotImplementedError
        elif model_type == 'bair_robot_pushing':
            from lib.modules.networks.vgg_64 import encoder, decoder
            self.n_input_channels = 3
            self.encoder = encoder(128, self.n_input_channels)
            if model_config['global_output_log_var']:
                output_channels = self.n_input_channels
                self.output_log_var = nn.Parameter(torch.zeros(self.n_input_channels, 64, 64))
            else:
                output_channels = 2 * self.n_input_channels
            self.decoder = decoder(128, output_channels)
            self.output_dist = Normal()
            latent_config['n_variables'] = 64
            if self.modified:
                if self.inference_procedure == 'direct':
                    # another convolutional encoder
                    self.inf_encoder = encoder(128, self.n_input_channels)
                    # fully-connected inference model
                    inf_config = {'n_layers': 2,
                                  'n_units': 256,
                                  'n_in': 128,
                                  'non_linearity': 'relu'}
                    self.inf_model = FullyConnectedNetwork(inf_config)
                    latent_config['n_in'][0] = 256
                elif self.inference_procedure == 'gradient':
                    # fully-connected encoder / latent inference model
                    inf_config = {'n_layers': 3,
                                  'n_units': 1024,
                                  'n_in': 4 * latent_config['n_variables'],
                                  'non_linearity': 'relu'}
                    if model_config['concat_observation']:
                        inf_config['n_in'] += (self.n_input_channels * 64 * 64)
                    self.inf_model = FullyConnectedNetwork(inf_config)
                    latent_config['n_in'][0] = 1024
                    latent_config['update_type'] = model_config['update_type']
                elif self.inference_procedure == 'error':
                    # convolutional observation error encoder
                    obs_error_enc_config = {'n_layers': 3,
                                            'n_filters': 64,
                                            'n_in': self.n_input_channels,
                                            'filter_size': 3,
                                            'non_linearity': 'relu'}
                    if model_config['concat_observation']:
                        obs_error_enc_config['n_in'] += self.n_input_channels
                    self.obs_error_enc = ConvolutionalNetwork(obs_error_enc_config)
                    # fully-connected error encoder (latent error + params + encoded observation errors)
                    inf_config = {'n_layers': 3,
                                  'n_units': 1024,
                                  'n_in': 4 * latent_config['n_variables'],
                                  'non_linearity': 'relu'}
                    if model_config['concat_observation']:
                        inf_config['n_in'] += (self.n_input_channels * 64 * 64)
                    self.inf_model = FullyConnectedNetwork(inf_config)
                    latent_config['n_in'][0] = 1024
                    latent_config['update_type'] = model_config['update_type']
                else:
                    raise NotImplementedError
        else:
            raise Exception('SVG model type must be one of 1) sm_mnist, '
                            '2) kth_actions, or 3) bair_robot_pushing. '
                            'Invalid model type: ' + model_type + '.')

        # construct a recurrent latent level
        level_config['latent_config'] = latent_config
        self.latent_levels = nn.ModuleList([LSTMLatentLevel(level_config)])

        self.prior_lstm = LSTMNetwork({'n_layers': 1,
                                       'n_units': 256,
                                       'n_in': 128})
        self.decoder_lstm = LSTMNetwork({'n_layers': 2,
                                         'n_units': 256,
                                         'n_in': 128 + latent_config['n_variables']})
        self.decoder_lstm_output = FullyConnectedLayer({'n_in': 256,
                                                        'n_out': 128,
                                                        'non_linearity': 'tanh'})
        self.output_interval = 1. / 256

    def _get_encoding_form(self, observation):
        """
        Gets the appropriate input form for the inference procedure.

        Args:
            observation (Variable, Tensor): the input observation
        """
        if self.inference_procedure == 'direct':
            return observation - 0.5
        elif self.inference_procedure == 'gradient':
            grads = self.latent_levels[0].latent.approx_posterior_gradients()
            # normalization
            if self.model_config['input_normalization'] in ['layer', 'batch']:
                norm_dim = 0 if self.model_config['input_normalization'] == 'batch' else 1
                for ind, grad in enumerate(grads):
                    mean = grad.mean(dim=norm_dim, keepdim=True)
                    std = grad.std(dim=norm_dim, keepdim=True)
                    grads[ind] = (grad - mean) / (std + 1e-7)
            grads = torch.cat(grads, dim=1)
            # concatenate with the parameters
            params = self.latent_levels[0].latent.approx_posterior_parameters()
            if self.model_config['norm_parameters']:
                if self.model_config['input_normalization'] in ['layer', 'batch']:
                    norm_dim = 0 if self.model_config['input_normalization'] == 'batch' else 1
                    for ind, param in enumerate(params):
                        mean = param.mean(dim=norm_dim, keepdim=True)
                        std = param.std(dim=norm_dim, keepdim=True)
                        params[ind] = (param - mean) / (std + 1e-7)
            params = torch.cat(params, dim=1)
            grads_params = torch.cat([grads, params], dim=1)
            # concatenate with the observation
            if self.model_config['concat_observation']:
                grads_params = torch.cat([grads_params, observation - 0.5], dim=1)
            return grads_params
        elif self.inference_procedure == 'error':
            # TODO: figure out proper normalization for the observation error
            errors = [self._output_error(observation),
                      self.latent_levels[0].latent.error()]
            # normalize
            for ind, error in enumerate(errors):
                mean = error.mean(dim=0, keepdim=True)
                std = error.std(dim=0, keepdim=True)
                errors[ind] = (error - mean) / (std + 1e-5)
            # concatenate
            params = torch.cat(self.latent_levels[0].latent.approx_posterior_parameters(), dim=1)
            latent_error_params = torch.cat([errors[1], params], dim=1)
            if self.model_config['concat_observation']:
                latent_error_params = torch.cat([latent_error_params, observation - 0.5], dim=1)
            return errors[0], latent_error_params
        else:
            raise NotImplementedError

    def _output_error(self, observation, averaged=True):
        """
        Calculates the Gaussian (precision-weighted) error for encoding.

        Args:
            observation (Tensor): observation to use for error calculation
        """
        # get the output mean and log variance
        output_mean = self.output_dist.mean.detach()
        output_log_var = self.output_dist.log_var.detach()
        # repeat the observation along the sample dimension
        n_samples = output_mean.data.shape[1]
        observation = observation.unsqueeze(1).repeat(1, n_samples, 1, 1, 1)
        # calculate the precision-weighted observation error
        n_error = (observation - output_mean) / (output_log_var.exp() + 1e-7)
        if averaged:
            # average along the sample dimension
            n_error = n_error.mean(dim=1)
        return n_error

    def infer(self, observation):
        """
        Method for performing inference of the approximate posterior over the
        latent variables.

        Args:
            observation (Tensor): observation to infer latent variables from
        """
        if self.modified:
            if not self._obs_encoded:
                # encode the observation (to be used at the next time step)
                self._h, self._skip = self.encoder(observation - 0.5)
                self._obs_encoded = True
            enc = self._get_encoding_form(observation)
            if self.inference_procedure == 'direct':
                # separate encoder model
                enc_h, _ = self.inf_encoder(enc)
                enc_h = self.inf_model(enc_h)
            elif self.inference_procedure == 'gradient':
                # encode through the inference model
                enc_h = self.inf_model(enc)
            elif self.inference_procedure == 'error':
                # encode the error and flatten it
                enc_error = self.obs_error_enc(enc[0])
                enc_error = enc_error.view(enc_error.data.shape[0], -1)
                # concatenate the error with the rest of the terms
                enc = torch.cat([enc_error, enc[1]], dim=1)
                # encode through the inference model
                enc_h = self.inf_model(enc)
            self.latent_levels[0].infer(enc_h)
        else:
            observation = self._get_encoding_form(observation)
            self._h, self._skip = self.encoder(observation)
            self.latent_levels[0].infer(self._h)

    def generate(self, gen=False, n_samples=1):
        """
        Method for generating observations, i.e. running the generative model
        forward.

        Args:
            gen (boolean): whether to sample from the prior or the approximate
                           posterior
            n_samples (int): number of samples to draw and evaluate
        """
        batch_size = self._prev_h.data.shape[0]
        # get the previous h and skip connections
        prev_h = self._prev_h.unsqueeze(1)
        prev_skip = [0. * _prev_skip.repeat(n_samples, 1, 1, 1) for _prev_skip in self._prev_skip]
        # detach prev_h and prev_skip if necessary
        if self._detach_h:
            prev_h = prev_h.detach()
            prev_skip = [_prev_skip.detach() for _prev_skip in prev_skip]
        # get the prior input, detach if necessary
        gen_input = self._gen_input
        gen_input = gen_input.detach() if self._detach_h else gen_input
        # sample the latent variables
        z = self.latent_levels[0].generate(gen_input.unsqueeze(1), gen=gen, n_samples=n_samples)
        # pass through the decoder
        decoder_input = torch.cat([z, prev_h], dim=2).view(batch_size * n_samples, -1)
        g = self.decoder_lstm(decoder_input, detach=self._detach_h)
        g = self.decoder_lstm_output(g)
        output = self.decoder([g, prev_skip])
        b, _, h, w = output.data.shape
        # get the output mean and log variance
        if self.model_config['global_output_log_var']:
            # repeat along the batch and sample dimensions
            output = output.view(b, -1, self.n_input_channels, h, w)
            log_var = self.output_log_var.unsqueeze(0).unsqueeze(0).repeat(batch_size, n_samples, 1, 1, 1)
            self.output_dist.log_var = torch.clamp(log_var, min=-10)
        else:
            output = output.view(b, -1, 2 * self.n_input_channels, h, w)
            self.output_dist.log_var = torch.clamp(output[:, :, self.n_input_channels:, :, :], min=-10)
        self.output_dist.mean = output[:, :, :self.n_input_channels, :, :].sigmoid()
        return torch.clamp(self.output_dist.sample(), 0., 1.)

    def step(self):
        """
        Method for stepping the generative model forward one step in the sequence.
        """
        # TODO: set n_samples in a smart way
        # step the lstms and the latent level
        self.latent_levels[0].step()
        self.prior_lstm.step()
        self.decoder_lstm.step()
        # copy over the hidden and skip variables
        self._prev_h = self._h
        self._prev_skip = self._skip
        # clear the current hidden and skip variables, set the flag
        self._h = self._skip = None
        self._obs_encoded = False
        # use the prior lstm to get the generative model input
        self._gen_input = self.prior_lstm(self._prev_h.unsqueeze(1), detach=False)
        # set the prior and approximate posterior
        self.latent_levels[0].generate(self._gen_input.detach().unsqueeze(1), gen=True, n_samples=1)
        self.latent_levels[0].latent.re_init_approx_posterior()

    def re_init(self, input):
        """
        Method for reinitializing the state (distributions and hidden states).

        Args:
            input (Variable, Tensor): contains the observation at t = -1
        """
        # TODO: set n_samples in a smart way
        # flag to encode the hidden state for later decoding
        self._obs_encoded = False
        # re-initialize the lstms and distributions
        self.latent_levels[0].re_init()
        self.prior_lstm.re_init(input)
        self.decoder_lstm.re_init(input)
        # clear the hidden state and skip connections
        self._h = self._skip = None
        # encode this input to set the previous h and skip
        self._prev_h, self._prev_skip = self.encoder(input - 0.5)
        # set the prior and approximate posterior
        self._gen_input = self.prior_lstm(self._prev_h, detach=False)
        self.latent_levels[0].generate(self._gen_input.unsqueeze(1), gen=True, n_samples=1)
        self.latent_levels[0].latent.re_init_approx_posterior()

    def inference_parameters(self):
        """
        Method for obtaining the inference parameters.
        """
        params = nn.ParameterList()
        if self.modified:
            params.extend(list(self.inf_model.parameters()))
            if self.inference_procedure == 'direct':
                params.extend(list(self.inf_encoder.parameters()))
            elif self.inference_procedure == 'gradient':
                pass  # no other inference parameters
            elif self.inference_procedure == 'error':
                params.extend(list(self.obs_error_enc.parameters()))
            else:
                raise NotImplementedError
        else:
            params.extend(list(self.encoder.parameters()))
        params.extend(list(self.latent_levels[0].inference_parameters()))
        return params

    def generative_parameters(self):
        """
        Method for obtaining the generative parameters.
        """
        params = nn.ParameterList()
        if self.modified:
            params.extend(list(self.encoder.parameters()))
        params.extend(list(self.prior_lstm.parameters()))
        params.extend(list(self.decoder.parameters()))
        params.extend(list(self.latent_levels[0].generative_parameters()))
        params.extend(list(self.decoder_lstm.parameters()))
        params.extend(list(self.decoder_lstm_output.parameters()))
        if self.model_config['global_output_log_var']:
            params.append(self.output_log_var)
        return params

    def inference_mode(self):
        """
        Method to set the model's current mode to inference.
        """
        self.latent_levels[0].latent.detach = False
        self._detach_h = True

    def generative_mode(self):
        """
        Method to set the model's current mode to generation.
        """
        self.latent_levels[0].latent.detach = True
        self._detach_h = False
class ConvolutionalLatentVariable(LatentVariable):
    """
    A convolutional (Gaussian) latent variable.

    Args:
        variable_config (dict): dictionary containing variable config parameters
    """
    def __init__(self, variable_config):
        super(ConvolutionalLatentVariable, self).__init__()
        self.posterior = self.prior = None
        self._construct(variable_config)

    def _construct(self, variable_config):
        """
        Constructs the latent variable according to the variable_config dict.
        Currently hard-coded to Gaussian distributions for both the approximate
        posterior and the prior.

        Note: the 'n_in' and 'const_prior_var' keys are assumed config entries;
        the original constructor code had been misplaced inside infer().

        Args:
            variable_config (dict): dictionary containing variable config params
        """
        self.n_variable_channels = variable_config['n_channels']
        self.filter_size = variable_config['filter_size']
        n_input = variable_config['n_in']  # (inference, generative) input channels
        const_prior_var = variable_config.get('const_prior_var', False)

        # approximate posterior inputs
        self.posterior_mean = Convolutional(n_input[0], self.n_variable_channels, self.filter_size)
        self.posterior_mean_gate = Convolutional(n_input[0], self.n_variable_channels, self.filter_size, 'sigmoid')
        self.posterior_log_var = Convolutional(n_input[0], self.n_variable_channels, self.filter_size)
        self.posterior_log_var_gate = Convolutional(n_input[0], self.n_variable_channels, self.filter_size, 'sigmoid')

        # prior inputs
        self.prior_mean = Convolutional(n_input[1], self.n_variable_channels, self.filter_size)
        # self.prior_mean_gate = Convolutional(n_input[1], self.n_variable_channels, self.filter_size, 'sigmoid', gate=True)
        self.prior_log_var = None
        if not const_prior_var:
            self.prior_log_var = Convolutional(n_input[1], self.n_variable_channels, self.filter_size)
            # self.prior_log_var_gate = Convolutional(n_input[1], self.n_variable_channels, self.filter_size, 'sigmoid', gate=True)

        # distributions
        self.previous_posterior = Normal(self.n_variable_channels)
        self.posterior = Normal(self.n_variable_channels)
        self.prior = Normal(self.n_variable_channels)
        if const_prior_var:
            self.prior.log_var_trainable()

    def infer(self, input, n_samples=1):
        """
        Method to perform inference.

        Args:
            input (Tensor): input to the inference procedure
            n_samples (int): number of samples to draw
        """
        # infer the approximate posterior
        mean_gate = self.posterior_mean_gate(input)
        mean_update = self.posterior_mean(input) * mean_gate
        # self.posterior.mean = self.posterior.mean.detach() + mean_update
        self.posterior.mean = mean_update
        log_var_gate = self.posterior_log_var_gate(input)
        log_var_update = self.posterior_log_var(input) * log_var_gate
        # self.posterior.log_var = (1. - log_var_gate) * self.posterior.log_var.detach() + log_var_update
        self.posterior.log_var = log_var_update
        return self.posterior.sample(n_samples, resample=True)

    def generate(self, input, gen, n_samples):
        b, s, c, h, w = input.data.shape
        input = input.view(-1, c, h, w)
        # mean_gate = self.prior_mean_gate(input).view(b, s, -1, h, w)
        mean_update = self.prior_mean(input).view(b, s, -1, h, w)  # * mean_gate
        # self.prior.mean = (1. - mean_gate) * self.posterior.mean.detach() + mean_update
        self.prior.mean = mean_update
        # log_var_gate = self.prior_log_var_gate(input).view(b, s, -1, h, w)
        log_var_update = self.prior_log_var(input).view(b, s, -1, h, w)  # * log_var_gate
        # self.prior.log_var = (1. - log_var_gate) * self.posterior.log_var.detach() + log_var_update
        self.prior.log_var = log_var_update
        if gen:
            return self.prior.sample(n_samples, resample=True)
        return self.posterior.sample(n_samples, resample=True)

    def step(self):
        # set the previous posterior to the current posterior
        self.previous_posterior.mean = self.posterior.mean.detach()
        self.previous_posterior.log_var = self.posterior.log_var.detach()

    def error(self, averaged=True, weighted=False):
        sample = self.posterior.sample()
        n_samples = sample.data.shape[1]
        prior_mean = self.prior.mean.detach()
        err = sample - prior_mean[:n_samples]
        if weighted:
            # weight the error by the prior precision
            prior_log_var = self.prior.log_var.detach()
            err = err / (torch.exp(prior_log_var) + 1e-7)
        if averaged:
            err = err.mean(dim=1)
        return err

    def reset_approx_posterior(self):
        mean = self.prior.mean.data.clone().mean(dim=1)
        log_var = self.prior.log_var.data.clone().mean(dim=1)
        self.posterior.reset(mean, log_var)

    def reset_prior(self):
        self.prior.reset()
        if self.prior_log_var is None:
            self.prior.log_var_trainable()

    def reinitialize_variable(self, output_dims):
        b, _, h, w = output_dims
        # reinitialize the previous approximate posterior and prior
        self.previous_posterior.reset()
        self.previous_posterior.mean = self.previous_posterior.mean.view(1, 1, 1, 1, 1).repeat(
            b, 1, self.n_variable_channels, h, w)
        self.previous_posterior.log_var = self.previous_posterior.log_var.view(1, 1, 1, 1, 1).repeat(
            b, 1, self.n_variable_channels, h, w)
        self.prior.reset()
        self.prior.mean = self.prior.mean.view(1, 1, 1, 1, 1).repeat(
            b, 1, self.n_variable_channels, h, w)
        self.prior.log_var = self.prior.log_var.view(1, 1, 1, 1, 1).repeat(
            b, 1, self.n_variable_channels, h, w)

    def inference_model_parameters(self):
        inference_params = []
        inference_params.extend(list(self.posterior_mean.parameters()))
        inference_params.extend(list(self.posterior_mean_gate.parameters()))
        inference_params.extend(list(self.posterior_log_var.parameters()))
        inference_params.extend(list(self.posterior_log_var_gate.parameters()))
        return inference_params

    def generative_model_parameters(self):
        generative_params = []
        generative_params.extend(list(self.prior_mean.parameters()))
        if self.prior_log_var is not None:
            generative_params.extend(list(self.prior_log_var.parameters()))
        else:
            generative_params.append(self.prior.log_var)
        return generative_params

    def approx_posterior_parameters(self):
        return [self.posterior.mean.detach(), self.posterior.log_var.detach()]

    def approx_posterior_gradients(self):
        assert self.posterior.mean.grad is not None, 'Approximate posterior gradients are None.'
        grads = [self.posterior.mean.grad.detach()]
        grads += [self.posterior.log_var.grad.detach()]
        for grad in grads:
            grad.volatile = False
        return grads
class VRNN(LatentVariableModel):
    """
    Variational recurrent neural network (VRNN) from "A Recurrent Latent
    Variable Model for Sequential Data," Chung et al., 2015.

    Args:
        model_config (dict): dictionary containing model configuration params
    """
    def __init__(self, model_config):
        super(VRNN, self).__init__(model_config)
        self._construct(model_config)

    def _construct(self, model_config):
        """
        Method for constructing the VRNN using the model configuration file.

        Args:
            model_config (dict): dictionary containing model configuration params
        """
        model_type = model_config['model_type'].lower()
        self.modified = model_config['modified']
        self.inference_procedure = model_config['inference_procedure'].lower()
        if not self.modified:
            assert self.inference_procedure == 'direct', 'The original model only supports direct inference.'
        self._detach_h = False
        latent_config = {}
        level_config = {}
        latent_config['inference_procedure'] = self.inference_procedure
        # hard-coded because we handle inference here in the model
        level_config['inference_procedure'] = 'direct'

        if model_type == 'timit':
            lstm_units = 2000
            encoder_units = 500
            prior_units = 500
            decoder_units = 600
            x_units = 600
            z_units = 500
            hidden_layers = 4
            x_dim = 200
            z_dim = 200
            self.output_interval = 0.0018190742
        elif model_type == 'blizzard':
            lstm_units = 4000
            encoder_units = 500
            prior_units = 500
            decoder_units = 600
            x_units = 600
            z_units = 500
            hidden_layers = 4
            x_dim = 200
            z_dim = 200
            # TODO: check if this is correct
            self.output_interval = 0.0018190742
        elif model_type == 'iam_ondb':
            lstm_units = 1200
            encoder_units = 150
            prior_units = 150
            decoder_units = 250
            x_units = 250
            z_units = 150
            hidden_layers = 1
            x_dim = 3
            z_dim = 50
        elif model_type == 'bball':
            lstm_units = 1000
            encoder_units = 200
            prior_units = 200
            decoder_units = 200
            x_units = 200
            z_units = 200
            hidden_layers = 2
            x_dim = 2
            z_dim = 50
            self.output_interval = Variable(torch.from_numpy(
                np.array([1e-5 / 94., 1e-5 / 50.]).astype('float32')),
                requires_grad=False).cuda()
        else:
            raise Exception('VRNN model type must be one of 1) timit, 2) blizzard, '
                            '3) iam_ondb, or 4) bball. Invalid model type: ' + model_type + '.')

        # LSTM
        lstm_config = {'n_layers': 1, 'n_units': lstm_units, 'n_in': x_units + z_units}
        self.lstm = LSTMNetwork(lstm_config)

        # x model
        x_config = {'n_in': x_dim,
                    'n_units': x_units,
                    'n_layers': hidden_layers,
                    'non_linearity': 'relu'}
        self.x_model = FullyConnectedNetwork(x_config)

        # inference model
        if self.modified:
            if self.inference_procedure in ['direct', 'gradient', 'error']:
                # set the input encoding size
                if self.inference_procedure == 'direct':
                    input_dim = x_dim
                elif self.inference_procedure == 'gradient':
                    latent_config['update_type'] = model_config['update_type']
                    input_dim = 4 * z_dim
                    if model_config['concat_observation']:
                        input_dim += x_dim
                elif self.inference_procedure == 'error':
                    latent_config['update_type'] = model_config['update_type']
                    input_dim = x_dim + 3 * z_dim
                    if model_config['concat_observation']:
                        input_dim += x_dim
                else:
                    raise NotImplementedError
                encoder_units = 1024
                inf_config = {'n_in': input_dim,
                              'n_units': encoder_units,
                              'n_layers': 2,
                              'non_linearity': 'elu'}
                inf_config['connection_type'] = 'highway'
                # self.inf_model = FullyConnectedNetwork(inf_config)
            else:
                inf_config = None
                latent_config['inf_lr'] = model_config['learning_rate']
        else:
            inf_input_units = lstm_units + x_units
            inf_config = {'n_in': inf_input_units,
                          'n_units': encoder_units,
                          'n_layers': hidden_layers,
                          'non_linearity': 'relu'}

        # latent level (encoder model and prior model)
        level_config['inference_config'] = inf_config
        gen_config = {'n_in': lstm_units,
                      'n_units': prior_units,
                      'n_layers': hidden_layers,
                      'non_linearity': 'relu'}
        level_config['generative_config'] = gen_config
        latent_config['n_variables'] = z_dim
        latent_config['n_in'] = (encoder_units, prior_units)
        latent_config['normalize_samples'] = model_config['normalize_latent_samples']
        # latent_config['n_in'] = (encoder_units + input_dim, prior_units)
        level_config['latent_config'] = latent_config
        latent = FullyConnectedLatentLevel(level_config)
        self.latent_levels = nn.ModuleList([latent])

        # z model
        z_config = {'n_in': z_dim,
                    'n_units': z_units,
                    'n_layers': hidden_layers,
                    'non_linearity': 'relu'}
        self.z_model = FullyConnectedNetwork(z_config)

        # decoder
        decoder_config = {'n_in': lstm_units + z_units,
                          'n_units': decoder_units,
                          'n_layers': hidden_layers,
                          'non_linearity': 'relu'}
        self.decoder_model = FullyConnectedNetwork(decoder_config)
        self.output_dist = Normal()
        self.output_mean = FullyConnectedLayer({'n_in': decoder_units, 'n_out': x_dim})
        if model_config['global_output_log_var']:
            self.output_log_var = nn.Parameter(torch.zeros(x_dim))
        else:
            self.output_log_var = FullyConnectedLayer({'n_in': decoder_units, 'n_out': x_dim})

    def _get_encoding_form(self, observation):
        """
        Gets the appropriate input form for the inference procedure.

        Args:
            observation (Variable, Tensor): the input observation
        """
        if self.inference_procedure == 'direct':
            return observation
        elif self.inference_procedure == 'gradient':
            grads = self.latent_levels[0].latent.approx_posterior_gradients()
            # normalization
            if self.model_config['input_normalization'] in ['layer', 'batch']:
                norm_dim = 0 if self.model_config['input_normalization'] == 'batch' else 1
                for ind, grad in enumerate(grads):
                    mean = grad.mean(dim=norm_dim, keepdim=True)
                    std = grad.std(dim=norm_dim, keepdim=True)
                    grads[ind] = (grad - mean) / (std + 1e-7)
            grads = torch.cat(grads, dim=1)
            # concatenate with the parameters
            params = self.latent_levels[0].latent.approx_posterior_parameters()
            if self.model_config['norm_parameters']:
                if self.model_config['input_normalization'] in ['layer', 'batch']:
                    norm_dim = 0 if self.model_config['input_normalization'] == 'batch' else 1
                    for ind, param in enumerate(params):
                        mean = param.mean(dim=norm_dim, keepdim=True)
                        std = param.std(dim=norm_dim, keepdim=True)
                        params[ind] = (param - mean) / (std + 1e-7)
            params = torch.cat(params, dim=1)
            grads_params = torch.cat([grads, params], dim=1)
            # concatenate with the observation
            if self.model_config['concat_observation']:
                grads_params = torch.cat([grads_params, observation], dim=1)
            return grads_params
        elif self.inference_procedure == 'error':
            errors = [self._output_error(observation),
                      self.latent_levels[0].latent.error()]
            # normalization
            if self.model_config['input_normalization'] in ['layer', 'batch']:
                norm_dim = 0 if self.model_config['input_normalization'] == 'batch' else 1
                for ind, error in enumerate(errors):
                    mean = error.mean(dim=0, keepdim=True)
                    std = error.std(dim=0, keepdim=True)
                    errors[ind] = (error - mean) / (std + 1e-7)
            errors = torch.cat(errors, dim=1)
            # concatenate with the parameters
            params = self.latent_levels[0].latent.approx_posterior_parameters()
            if self.model_config['norm_parameters']:
                if self.model_config['input_normalization'] in ['layer', 'batch']:
                    norm_dim = 0 if self.model_config['input_normalization'] == 'batch' else 1
                    for ind, param in enumerate(params):
                        mean = param.mean(dim=norm_dim, keepdim=True)
                        std = param.std(dim=norm_dim, keepdim=True)
                        params[ind] = (param - mean) / (std + 1e-7)
            params = torch.cat(params, dim=1)
            error_params = torch.cat([errors, params], dim=1)
            if self.model_config['concat_observation']:
                error_params = torch.cat([error_params, observation], dim=1)
            return error_params
        elif self.inference_procedure == 'sgd':
            grads = self.latent_levels[0].latent.approx_posterior_gradients()
            return grads
        else:
            raise NotImplementedError

    def _output_error(self, observation, averaged=True):
        """
        Calculates the Gaussian (precision-weighted) error for encoding.

        Args:
            observation (Tensor): observation to use for error calculation
        """
        output_mean = self.output_dist.mean.detach()
        output_log_var = self.output_dist.log_var.detach()
        n_samples = output_mean.data.shape[1]
        if len(observation.data.shape) == 2:
            observation = observation.unsqueeze(1).repeat(1, n_samples, 1)
        # the stabilizing constant is added outside of the exponential
        n_error = (observation - output_mean) / (torch.exp(output_log_var) + 1e-7)
        if averaged:
            n_error = n_error.mean(dim=1)
        return n_error

    def infer(self, observation):
        """
        Method for performing inference of the approximate posterior over the
        latent variables.

        Args:
            observation (Tensor): observation to infer latent variables from
        """
        self._x_enc = self.x_model(observation)
        if self.modified:
            enc = self._get_encoding_form(observation)
            self.latent_levels[0].infer(enc)
        else:
            inf_input = self._x_enc
            prev_h = self._prev_h
            # prev_h = prev_h.detach() if self._detach_h else prev_h
            enc = torch.cat([inf_input, prev_h], dim=1)
            self.latent_levels[0].infer(enc)

    def generate(self, gen=False, n_samples=1):
        """
        Method for generating observations, i.e. running the generative model
        forward.

        Args:
            gen (boolean): whether to sample from the prior or the approximate
                           posterior
            n_samples (int): number of samples to draw and evaluate
        """
        # TODO: handle the sampling dimension
        # possibly detach the hidden state, preventing backprop
        prev_h = self._prev_h.unsqueeze(1)
        prev_h = prev_h.detach() if self._detach_h else prev_h
        # generate the prior, sample the latent variables
        z = self.latent_levels[0].generate(prev_h, gen=gen, n_samples=n_samples)
        # transform z through the z model
        b, s, _ = z.data.shape
        self._z_enc = self.z_model(z.view(b * s, -1)).view(b, s, -1)
        # pass the encoded z and the previous h through the decoder model
        dec = torch.cat([self._z_enc, prev_h.repeat(1, s, 1)], dim=2)
        b, s, _ = dec.data.shape
        output = self.decoder_model(dec.view(b * s, -1)).view(b, s, -1)
        # get the output mean and log variance
        self.output_dist.mean = self.output_mean(output)
        if self.model_config['global_output_log_var']:
            b, s = output.data.shape[0], output.data.shape[1]
            log_var = self.output_log_var.view(1, 1, -1).repeat(b, s, 1)
            self.output_dist.log_var = torch.clamp(log_var, min=-20., max=5)
        else:
            self.output_dist.log_var = torch.clamp(self.output_log_var(output), min=-20., max=5)
        return self.output_dist.sample()

    def step(self, n_samples=1):
        """
        Method for stepping the generative model forward one step in the sequence.
        """
        # TODO: handle the sampling dimension
        self._prev_h = self.lstm(torch.cat([self._x_enc, self._z_enc[:, 0]], dim=1))
        prev_h = self._prev_h.unsqueeze(1)
        self.lstm.step()
        # get the prior, use it to initialize the approximate posterior
        self.latent_levels[0].generate(prev_h, gen=True, n_samples=n_samples)
        self.latent_levels[0].latent.re_init_approx_posterior()

    def re_init(self, input):
        """
        Method for reinitializing the state (approximate posterior and priors)
        of the dynamical latent variable model.
        """
        # re-initialize the LSTM hidden and cell states
        self.lstm.re_init(input)
        # set the previous hidden state, add a sample dimension
        self._prev_h = self.lstm.layers[0].hidden_state
        prev_h = self._prev_h.unsqueeze(1)
        # get the prior, use it to initialize the approximate posterior
        self.latent_levels[0].generate(prev_h, gen=True, n_samples=1)
        self.latent_levels[0].latent.re_init_approx_posterior()

    def inference_parameters(self):
        """
        Method for obtaining the inference parameters.
        """
        params = nn.ParameterList()
        if self.inference_procedure != 'sgd':
            params.extend(list(self.latent_levels[0].inference_parameters()))
        return params

    def generative_parameters(self):
        """
        Method for obtaining the generative parameters.
        """
        params = nn.ParameterList()
        params.extend(list(self.lstm.parameters()))
        params.extend(list(self.latent_levels[0].generative_parameters()))
        params.extend(list(self.x_model.parameters()))
        params.extend(list(self.z_model.parameters()))
        params.extend(list(self.decoder_model.parameters()))
        params.extend(list(self.output_mean.parameters()))
        if self.model_config['global_output_log_var']:
            params.append(self.output_log_var)
        else:
            params.extend(list(self.output_log_var.parameters()))
        return params

    def inference_mode(self):
        """
        Method to set the model's current mode to inference.
        """
        self.latent_levels[0].latent.detach = False
        self._detach_h = True

    def generative_mode(self):
        """
        Method to set the model's current mode to generation.
        """
        self.latent_levels[0].latent.detach = True
        self._detach_h = False