def _construct(self, network_config):
    """
    Method to construct the (LSTM) network from the network_config
    dictionary parameters.

    Args:
        network_config (dict): configuration with keys 'n_in', 'n_units',
            'n_layers' and optionally 'connection_type' (one of
            'sequential', 'residual', 'highway', 'concat_input', 'concat';
            defaults to 'sequential').
    """
    self.layers = nn.ModuleList([])
    self.gates = nn.ModuleList([])
    if 'connection_type' in network_config:
        connection_types = [
            'sequential', 'residual', 'highway', 'concat_input', 'concat'
        ]
        assert network_config[
            'connection_type'] in connection_types, 'Connection type not found.'
        self.connection_type = network_config['connection_type']
    else:
        self.connection_type = 'sequential'
    n_in = network_config['n_in']
    n_in_orig = network_config['n_in']
    n_units = network_config['n_units']
    if self.connection_type in ['residual', 'highway']:
        # initial linear layer to embed the input to the correct size
        # (FullyConnectedLayer, consistent with the fully-connected
        # network builder elsewhere in this file)
        self.initial_fc = FullyConnectedLayer({'n_in': n_in, 'n_out': n_units})
    for _ in range(network_config['n_layers']):
        self.layers.append(LSTMLayer({'n_in': n_in, 'n_units': n_units}))
        if self.connection_type == 'highway':
            # sigmoid gate that mixes the layer output with its input
            self.gates.append(
                FullyConnectedLayer({
                    'n_in': n_in,
                    'n_out': n_units,
                    'non_linearity': 'sigmoid'
                }))
        # track the input size of the next layer for each connection type
        if self.connection_type in ['sequential', 'residual', 'highway']:
            n_in = n_units
        elif self.connection_type == 'concat_input':
            n_in = n_units + n_in_orig
        elif self.connection_type == 'concat':
            n_in += n_units
    output_size = n_in
    # expose the network output size to callers (was previously computed
    # into a dead local and never stored)
    self.n_out = output_size
class FullyConnectedLatentVariable(LatentVariable):
    """
    A fully-connected (Gaussian) latent variable.

    Holds an approximate posterior and a prior (both diagonal Gaussians)
    whose parameters are produced by fully-connected layers, and supports
    several inference procedures: 'direct', 'gradient', 'error', and 'sgd'.

    Args:
        latent_config (dict): dictionary containing variable configuration
                              parameters: n_variables, n_in,
                              inference_procedure, normalize_samples, and
                              (for gradient/error) update_type, or (for
                              sgd) inf_lr.
    """
    def __init__(self, latent_config):
        super(FullyConnectedLatentVariable, self).__init__(latent_config)
        self._construct(latent_config)

    def _construct(self, latent_config):
        """
        Method to construct the latent variable from the latent_config
        dictionary.
        """
        self.inference_procedure = latent_config['inference_procedure']
        if self.inference_procedure in ['gradient', 'error']:
            self.update_type = latent_config['update_type']
        n_variables = latent_config['n_variables']
        # n_in is a pair: (inference input size, generative input size)
        n_inputs = latent_config['n_in']
        self.normalize_samples = latent_config['normalize_samples']
        if self.normalize_samples:
            self.normalizer = LayerNorm()
        if self.inference_procedure in ['direct', 'gradient', 'error']:
            # approximate posterior parameter outputs
            self.inf_mean_output = FullyConnectedLayer({
                'n_in': n_inputs[0],
                'n_out': n_variables
            })
            self.inf_log_var_output = FullyConnectedLayer({
                'n_in': n_inputs[0],
                'n_out': n_variables
            })
        if self.inference_procedure in ['gradient', 'error']:
            # sigmoid gates used by the highway / learned_sgd updates
            self.approx_post_mean_gate = FullyConnectedLayer({
                'n_in': n_inputs[0],
                'n_out': n_variables,
                'non_linearity': 'sigmoid'
            })
            self.approx_post_log_var_gate = FullyConnectedLayer({
                'n_in': n_inputs[0],
                'n_out': n_variables,
                'non_linearity': 'sigmoid'
            })
        if self.inference_procedure == 'sgd':
            self.learning_rate = latent_config['inf_lr']
        # prior parameter outputs
        self.prior_mean = FullyConnectedLayer({
            'n_in': n_inputs[1],
            'n_out': n_variables
        })
        self.prior_log_var = FullyConnectedLayer({
            'n_in': n_inputs[1],
            'n_out': n_variables
        })
        # distributions
        self.approx_post = Normal()
        self.prior = Normal()
        self.approx_post.re_init()
        self.prior.re_init()

    def infer(self, input):
        """
        Method to perform inference.

        Args:
            input (Tensor): input to the inference procedure
        """
        if self.inference_procedure in ['direct', 'gradient', 'error']:
            # candidate posterior parameters from the inference input
            approx_post_mean = self.inf_mean_output(input)
            approx_post_log_var = self.inf_log_var_output(input)
        if self.inference_procedure == 'direct':
            self.approx_post.mean = approx_post_mean
            self.approx_post.log_var = torch.clamp(approx_post_log_var, -15, 5)
        elif self.inference_procedure in ['gradient', 'error']:
            if self.update_type == 'highway':
                # gated highway update: interpolate between the previous
                # (detached) parameters and the new candidate parameters
                approx_post_mean_gate = self.approx_post_mean_gate(input)
                self.approx_post.mean = approx_post_mean_gate * self.approx_post.mean.detach() \
                    + (1 - approx_post_mean_gate) * approx_post_mean
                approx_post_log_var_gate = self.approx_post_log_var_gate(input)
                self.approx_post.log_var = torch.clamp(
                    approx_post_log_var_gate * self.approx_post.log_var.detach() \
                    + (1 - approx_post_log_var_gate) * approx_post_log_var, -15, 5)
            elif self.update_type == 'learned_sgd':
                # SGD-style update with learned learning rate and offset
                mean_grad, log_var_grad = self.approx_posterior_gradients()
                mean_lr = self.approx_post_mean_gate(input)
                log_var_lr = self.approx_post_log_var_gate(input)
                self.approx_post.mean = self.approx_post.mean.detach(
                ) - mean_lr * mean_grad + approx_post_mean
                self.approx_post.log_var = torch.clamp(
                    self.approx_post.log_var.detach() -
                    log_var_lr * log_var_grad + approx_post_log_var, -15, 5)
        elif self.inference_procedure == 'sgd':
            # plain SGD on the posterior parameters; input is expected to
            # be the (mean grad, log-var grad) pair
            self.approx_post.mean = self.approx_post.mean.detach(
            ) - self.learning_rate * input[0]
            self.approx_post.log_var = torch.clamp(
                self.approx_post.log_var.detach() -
                self.learning_rate * input[1], -15, 5)
            self.approx_post.mean.requires_grad = True
            self.approx_post.log_var.requires_grad = True
        else:
            raise NotImplementedError
        if self.normalize_samples:
            # apply layer normalization to the approximate posterior means
            self.approx_post.mean = self.normalizer(self.approx_post.mean)
        # retain the gradients (for iterative inference)
        self.approx_post.mean.retain_grad()
        self.approx_post.log_var.retain_grad()

    def generate(self, input, gen, n_samples):
        """
        Method to generate, i.e. run the model forward.

        Args:
            input (Tensor): input to the generative procedure
            gen (boolean): whether to sample from approximate posterior
                           (False) or the prior (True)
            n_samples (int): number of samples to draw
        """
        if input is not None:
            b, s, n = input.data.shape
            input = input.view(b * s, n)
            self.prior.mean = self.prior_mean(input).view(b, s, -1)
            self.prior.log_var = torch.clamp(
                self.prior_log_var(input).view(b, s, -1), -15, 5)
        dist = self.prior if gen else self.approx_post
        sample = dist.sample(n_samples, resample=True)
        sample = sample.detach() if self.detach else sample
        return sample

    def re_init(self):
        """
        Method to reinitialize the approximate posterior and prior over
        the variable.
        """
        # TODO: this is wrong. we shouldnt set the posterior to the prior
        # then zero out the prior...
        self.re_init_approx_posterior()
        self.prior.re_init()

    def re_init_approx_posterior(self):
        """
        Method to reinitialize the approximate posterior from the
        (sample-averaged) prior parameters.
        """
        mean = self.prior.mean.detach().mean(dim=1).data
        log_var = self.prior.log_var.detach().mean(dim=1).data
        self.approx_post.re_init(mean, log_var)

    def step(self):
        """
        Method to step the latent variable forward in the sequence.
        """
        pass

    def error(self, averaged=True):
        """
        Calculates (precision-weighted) Gaussian error for encoding.

        Args:
            averaged (boolean): whether or not to average over samples
        """
        sample = self.approx_post.sample()
        n_samples = sample.data.shape[1]
        prior_mean = self.prior.mean.detach()
        if len(prior_mean.data.shape) == 2:
            prior_mean = prior_mean.unsqueeze(1).repeat(1, n_samples, 1)
        prior_log_var = self.prior.log_var.detach()
        if len(prior_log_var.data.shape) == 2:
            prior_log_var = prior_log_var.unsqueeze(1).repeat(1, n_samples, 1)
        # epsilon belongs outside the exp (guards the division), matching
        # SVG._output_error; previously it was added inside the exp where
        # it had no effect
        n_error = (sample - prior_mean) / (torch.exp(prior_log_var) + 1e-7)
        if averaged:
            n_error = n_error.mean(dim=1)
        return n_error

    def close_gates(self):
        """Initialize gate biases high so the gates start (nearly) closed."""
        # nn.init.constant is deprecated; constant_ is the in-place API
        nn.init.constant_(self.approx_post_mean_gate.linear.bias, 5.)
        nn.init.constant_(self.approx_post_log_var_gate.linear.bias, 5.)

    def inference_parameters(self):
        """
        Method to obtain inference parameters.
        """
        # NOTE(review): assumes inf_mean_output / inf_log_var_output exist,
        # i.e. inference_procedure is not 'sgd' — confirm with callers
        params = nn.ParameterList()
        params.extend(list(self.inf_mean_output.parameters()))
        params.extend(list(self.inf_log_var_output.parameters()))
        if self.inference_procedure != 'direct':
            params.extend(list(self.approx_post_mean_gate.parameters()))
            params.extend(list(self.approx_post_log_var_gate.parameters()))
        return params

    def generative_parameters(self):
        """
        Method to obtain generative parameters.
        """
        params = nn.ParameterList()
        params.extend(list(self.prior_mean.parameters()))
        params.extend(list(self.prior_log_var.parameters()))
        return params

    def approx_posterior_parameters(self):
        """Return detached approximate posterior mean and log-variance."""
        return [
            self.approx_post.mean.detach(),
            self.approx_post.log_var.detach()
        ]

    def approx_posterior_gradients(self):
        """Return detached gradients of the approximate posterior params."""
        assert self.approx_post.mean.grad is not None, 'Approximate posterior gradients are None.'
        grads = [self.approx_post.mean.grad.detach()]
        grads += [self.approx_post.log_var.grad.detach()]
        for grad in grads:
            # NOTE(review): 'volatile' is a legacy (pre-0.4) PyTorch flag;
            # kept for compatibility with the old API this file targets
            grad.volatile = False
        return grads
def _construct(self, latent_config):
    """
    Build the latent variable's modules and distributions from the
    latent_config dictionary.
    """
    proc = latent_config['inference_procedure']
    self.inference_procedure = proc
    if proc in ('gradient', 'error'):
        self.update_type = latent_config['update_type']
    n_variables = latent_config['n_variables']
    n_inputs = latent_config['n_in']
    self.normalize_samples = latent_config['normalize_samples']
    if self.normalize_samples:
        self.normalizer = LayerNorm()

    def _fc(n_in, non_linearity=None):
        # small local factory for the fully-connected parameter outputs
        config = {'n_in': n_in, 'n_out': n_variables}
        if non_linearity is not None:
            config['non_linearity'] = non_linearity
        return FullyConnectedLayer(config)

    if proc in ('direct', 'gradient', 'error'):
        # approximate posterior parameter outputs
        self.inf_mean_output = _fc(n_inputs[0])
        self.inf_log_var_output = _fc(n_inputs[0])
    if proc in ('gradient', 'error'):
        # sigmoid gates for the parameter updates
        self.approx_post_mean_gate = _fc(n_inputs[0], 'sigmoid')
        self.approx_post_log_var_gate = _fc(n_inputs[0], 'sigmoid')
    if proc == 'sgd':
        self.learning_rate = latent_config['inf_lr']
    # prior parameter outputs
    self.prior_mean = _fc(n_inputs[1])
    self.prior_log_var = _fc(n_inputs[1])
    # distributions
    self.approx_post = Normal()
    self.prior = Normal()
    self.approx_post.re_init()
    self.prior.re_init()
def _construct(self, network_config):
    """
    Build the fully-connected network from the network_config dictionary
    parameters.
    """
    self.layers = nn.ModuleList([])
    self.gates = nn.ModuleList([])
    if 'connection_type' in network_config:
        valid = ['sequential', 'residual', 'highway', 'concat_input', 'concat']
        assert network_config['connection_type'] in valid, 'Connection type not found.'
        self.connection_type = network_config['connection_type']
    else:
        self.connection_type = 'sequential'
    n_in = network_config['n_in']
    n_in_orig = network_config['n_in']
    n_units = network_config['n_units']
    # optional settings, with the same defaults as before
    batch_norm = bool(network_config.get('batch_norm', False))
    weight_norm = bool(network_config.get('weight_norm', False))
    non_linearity = network_config.get('non_linearity', 'linear')
    dropout = network_config.get('dropout', None)
    output_size = 0
    if self.connection_type in ['residual', 'highway']:
        # initial linear layer to embed the input to the hidden size
        self.initial_fc = FullyConnectedLayer({
            'n_in': n_in,
            'n_out': n_units,
            'batch_norm': batch_norm,
            'weight_norm': weight_norm
        })
    for _ in range(network_config['n_layers']):
        self.layers.append(
            FullyConnectedLayer({
                'n_in': n_in,
                'n_out': n_units,
                'non_linearity': non_linearity,
                'batch_norm': batch_norm,
                'weight_norm': weight_norm,
                'dropout': dropout
            }))
        if self.connection_type == 'highway':
            # sigmoid gate mixing the layer output with its input
            self.gates.append(
                FullyConnectedLayer({
                    'n_in': n_in,
                    'n_out': n_units,
                    'non_linearity': 'sigmoid',
                    'batch_norm': batch_norm,
                    'weight_norm': weight_norm
                }))
        # input size of the next layer depends on the connection type
        if self.connection_type in ['sequential', 'residual', 'highway']:
            n_in = n_units
        elif self.connection_type == 'concat_input':
            n_in = n_units + n_in_orig
        elif self.connection_type == 'concat':
            n_in += n_units
        output_size = n_in
    self.n_out = output_size
class SVG(LatentVariableModel):
    """
    Stochastic video generation (SVG) model from "Stochastic Video
    Generation with a Learned Prior," Denton & Fergus, 2018.

    Args:
        model_config (dict): dictionary containing model configuration
                             params
    """
    def __init__(self, model_config):
        super(SVG, self).__init__(model_config)
        self._construct(model_config)

    def _construct(self, model_config):
        """
        Method for constructing SVG model using the model configuration
        file.

        Args:
            model_config (dict): dictionary containing model configuration
                                 params
        """
        model_type = model_config['model_type'].lower()
        self.modified = model_config['modified']
        self.inference_procedure = model_config['inference_procedure'].lower()
        level_config = {}
        latent_config = {}
        latent_config['normalize_samples'] = model_config[
            'normalize_latent_samples']
        latent_config['inference_procedure'] = self.inference_procedure
        # hard coded because we handle inference here in the model
        level_config['inference_procedure'] = 'direct'
        if not self.modified:
            level_config['inference_config'] = {
                'n_layers': 1,
                'n_units': 256,
                'n_in': 128
            }
            latent_config['n_in'] = (256, 256)  # number of encoder, decoder units
        else:
            level_config['inference_config'] = None
            # first entry is filled in below, per model type / procedure
            latent_config['n_in'] = [None, 256]  # number of encoder, decoder units
        level_config['generative_config'] = None
        if model_type == 'sm_mnist':
            from lib.modules.networks.dcgan_64 import encoder, decoder
            self.n_input_channels = 1
            self.encoder = encoder(128, self.n_input_channels)
            self.decoder = decoder(128, self.n_input_channels)
            self.output_dist = Bernoulli()
            latent_config['n_variables'] = 10
            if self.modified:
                # NOTE(review): modified inference models for sm_mnist are
                # not implemented yet; only the procedure names are accepted
                if self.inference_procedure == 'direct':
                    pass
                elif self.inference_procedure == 'gradient':
                    pass
                elif self.inference_procedure == 'error':
                    pass
                else:
                    raise NotImplementedError
        elif model_type == 'kth_actions':
            from lib.modules.networks.vgg_64 import encoder, decoder
            self.n_input_channels = 1
            self.encoder = encoder(128, self.n_input_channels)
            if model_config['global_output_log_var']:
                # single learned log-variance map shared across the batch
                output_channels = self.n_input_channels
                self.output_log_var = nn.Parameter(
                    torch.zeros(self.n_input_channels, 64, 64))
            else:
                # decoder outputs both mean and log-variance channels
                output_channels = 2 * self.n_input_channels
            self.decoder = decoder(128, output_channels)
            self.output_dist = Normal()
            latent_config['n_variables'] = 512
            if self.modified:
                if self.inference_procedure == 'direct':
                    # another convolutional encoder
                    self.inf_encoder = encoder(128, self.n_input_channels)
                    # fully-connected inference model
                    inf_config = {
                        'n_layers': 2,
                        'n_units': 256,
                        'n_in': 128,
                        'non_linearity': 'relu'
                    }
                    self.inf_model = FullyConnectedNetwork(inf_config)
                    latent_config['n_in'][0] = 256
                elif self.inference_procedure == 'gradient':
                    # fully-connected encoder / latent inference model
                    n_units = 1024
                    inf_config = {
                        'n_layers': 1,
                        'n_units': n_units,
                        'n_in': 4 * latent_config['n_variables'],
                        'non_linearity': 'elu',
                        'connection_type': 'highway'
                    }
                    if model_config['concat_observation']:
                        inf_config['n_in'] += (self.n_input_channels * 64 * 64)
                    self.inf_model = FullyConnectedNetwork(inf_config)
                    latent_config['n_in'][0] = n_units
                    latent_config['update_type'] = model_config['update_type']
                elif self.inference_procedure == 'error':
                    # convolutional observation error encoder
                    obs_error_enc_config = {
                        'n_layers': 3,
                        'n_filters': 64,
                        'n_in': self.n_input_channels,
                        'filter_size': 3,
                        'non_linearity': 'relu'
                    }
                    if model_config['concat_observation']:
                        obs_error_enc_config['n_in'] += self.n_input_channels
                    self.obs_error_enc = ConvolutionalNetwork(
                        obs_error_enc_config)
                    # fully-connected error encoder (latent error + params
                    # + encoded observation errors)
                    inf_config = {
                        'n_layers': 3,
                        'n_units': 1024,
                        'n_in': 4 * latent_config['n_variables'],
                        'non_linearity': 'relu'
                    }
                    if model_config['concat_observation']:
                        inf_config['n_in'] += (self.n_input_channels * 64 * 64)
                    self.inf_model = FullyConnectedNetwork(inf_config)
                    latent_config['n_in'][0] = 1024
                    latent_config['update_type'] = model_config['update_type']
                else:
                    raise NotImplementedError
        elif model_type == 'bair_robot_pushing':
            from lib.modules.networks.vgg_64 import encoder, decoder
            self.n_input_channels = 3
            self.encoder = encoder(128, self.n_input_channels)
            if model_config['global_output_log_var']:
                # single learned log-variance map shared across the batch
                output_channels = self.n_input_channels
                self.output_log_var = nn.Parameter(
                    torch.zeros(self.n_input_channels, 64, 64))
            else:
                # decoder outputs both mean and log-variance channels
                output_channels = 2 * self.n_input_channels
            self.decoder = decoder(128, output_channels)
            self.output_dist = Normal()
            latent_config['n_variables'] = 64
            if self.modified:
                if self.inference_procedure == 'direct':
                    # another convolutional encoder
                    self.inf_encoder = encoder(128, self.n_input_channels)
                    # fully-connected inference model
                    inf_config = {
                        'n_layers': 2,
                        'n_units': 256,
                        'n_in': 128,
                        'non_linearity': 'relu'
                    }
                    self.inf_model = FullyConnectedNetwork(inf_config)
                    latent_config['n_in'][0] = 256
                elif self.inference_procedure == 'gradient':
                    # fully-connected encoder / latent inference model
                    inf_config = {
                        'n_layers': 3,
                        'n_units': 1024,
                        'n_in': 4 * latent_config['n_variables'],
                        'non_linearity': 'relu'
                    }
                    if model_config['concat_observation']:
                        inf_config['n_in'] += (self.n_input_channels * 64 * 64)
                    self.inf_model = FullyConnectedNetwork(inf_config)
                    latent_config['n_in'][0] = 1024
                    latent_config['update_type'] = model_config['update_type']
                elif self.inference_procedure == 'error':
                    # convolutional observation error encoder
                    obs_error_enc_config = {
                        'n_layers': 3,
                        'n_filters': 64,
                        'n_in': self.n_input_channels,
                        'filter_size': 3,
                        'non_linearity': 'relu'
                    }
                    if model_config['concat_observation']:
                        obs_error_enc_config['n_in'] += self.n_input_channels
                    self.obs_error_enc = ConvolutionalNetwork(
                        obs_error_enc_config)
                    # fully-connected error encoder (latent error + params
                    # + encoded observation errors)
                    inf_config = {
                        'n_layers': 3,
                        'n_units': 1024,
                        'n_in': 4 * latent_config['n_variables'],
                        'non_linearity': 'relu'
                    }
                    if model_config['concat_observation']:
                        inf_config['n_in'] += (self.n_input_channels * 64 * 64)
                    self.inf_model = FullyConnectedNetwork(inf_config)
                    latent_config['n_in'][0] = 1024
                    latent_config['update_type'] = model_config['update_type']
                else:
                    raise NotImplementedError
        else:
            raise Exception('SVG model type must be one of 1) sm_mnist, 2) \
                kth_action, or 3) bair_robot_pushing. Invalid model \
                type: ' + model_type + '.')
        # construct a recurrent latent level
        level_config['latent_config'] = latent_config
        self.latent_levels = nn.ModuleList([LSTMLatentLevel(level_config)])
        self.prior_lstm = LSTMNetwork({
            'n_layers': 1,
            'n_units': 256,
            'n_in': 128
        })
        self.decoder_lstm = LSTMNetwork({
            'n_layers': 2,
            'n_units': 256,
            'n_in': 128 + latent_config['n_variables']
        })
        self.decoder_lstm_output = FullyConnectedLayer({
            'n_in': 256,
            'n_out': 128,
            'non_linearity': 'tanh'
        })
        # quantization bin width of the 8-bit pixel outputs
        self.output_interval = 1. / 256

    def _get_encoding_form(self, observation):
        """
        Gets the appropriate input form for the inference procedure.

        Args:
            observation (Variable, tensor): the input observation
        """
        if self.inference_procedure == 'direct':
            # center pixel values around zero
            return observation - 0.5
        if self.inference_procedure == 'gradient':
            grads = self.latent_levels[0].latent.approx_posterior_gradients()
            # normalization
            if self.model_config['input_normalization'] in ['layer', 'batch']:
                norm_dim = 0 if self.model_config[
                    'input_normalization'] == 'batch' else 1
                for ind, grad in enumerate(grads):
                    mean = grad.mean(dim=norm_dim, keepdim=True)
                    std = grad.std(dim=norm_dim, keepdim=True)
                    grads[ind] = (grad - mean) / (std + 1e-7)
            grads = torch.cat(grads, dim=1)
            # concatenate with the parameters
            params = self.latent_levels[0].latent.approx_posterior_parameters()
            if self.model_config['norm_parameters']:
                if self.model_config['input_normalization'] in [
                        'layer', 'batch'
                ]:
                    norm_dim = 0 if self.model_config[
                        'input_normalization'] == 'batch' else 1
                    for ind, param in enumerate(params):
                        mean = param.mean(dim=norm_dim, keepdim=True)
                        std = param.std(dim=norm_dim, keepdim=True)
                        params[ind] = (param - mean) / (std + 1e-7)
            params = torch.cat(params, dim=1)
            grads_params = torch.cat([grads, params], dim=1)
            # concatenate with the observation
            if self.model_config['concat_observation']:
                grads_params = torch.cat([grads_params, observation - 0.5],
                                         dim=1)
            return grads_params
        elif self.inference_procedure == 'error':
            # TODO: figure out proper normalization for observation error
            errors = [
                self._output_error(observation),
                self.latent_levels[0].latent.error()
            ]
            # normalize
            for ind, error in enumerate(errors):
                mean = error.mean(dim=0, keepdim=True)
                std = error.std(dim=0, keepdim=True)
                errors[ind] = (error - mean) / (std + 1e-5)
            # concatenate
            params = torch.cat(
                self.latent_levels[0].latent.approx_posterior_parameters(),
                dim=1)
            latent_error_params = torch.cat([errors[1], params], dim=1)
            if self.model_config['concat_observation']:
                latent_error_params = torch.cat(
                    [latent_error_params, observation - 0.5], dim=1)
            # returns (observation error, latent error + params) pair
            return errors[0], latent_error_params
        else:
            raise NotImplementedError

    def _output_error(self, observation, averaged=True):
        """
        Calculates Gaussian error for encoding.

        Args:
            observation (tensor): observation to use for error calculation
        """
        # get the output mean and log variance
        output_mean = self.output_dist.mean.detach()
        output_log_var = self.output_dist.log_var.detach()
        # repeat the observation along the sample dimension
        n_samples = output_mean.data.shape[1]
        observation = observation.unsqueeze(1).repeat(1, n_samples, 1, 1, 1)
        # calculate the precision-weighted observation error
        n_error = (observation - output_mean) / (output_log_var.exp() + 1e-7)
        if averaged:
            # average along the sample dimension
            n_error = n_error.mean(dim=1)
        return n_error

    def infer(self, observation):
        """
        Method for performing inference of the approximate posterior over
        the latent variables.

        Args:
            observation (tensor): observation to infer latent variables
                                  from
        """
        if self.modified:
            if not self._obs_encoded:
                # encode the observation (to be used at the next time step)
                self._h, self._skip = self.encoder(observation - 0.5)
                self._obs_encoded = True
            enc = self._get_encoding_form(observation)
            if self.inference_procedure == 'direct':
                # separate encoder model
                enc_h, _ = self.inf_encoder(enc)
                enc_h = self.inf_model(enc_h)
            elif self.inference_procedure == 'gradient':
                # encode through the inference model
                enc_h = self.inf_model(enc)
            elif self.inference_procedure == 'error':
                # encode the error and flatten it
                enc_error = self.obs_error_enc(enc[0])
                enc_error = enc_error.view(enc_error.data.shape[0], -1)
                # concatenate the error with the rest of the terms
                enc = torch.cat([enc_error, enc[1]], dim=1)
                # encode through the inference model
                enc_h = self.inf_model(enc)
            self.latent_levels[0].infer(enc_h)
        else:
            # original SVG: encode the observation and infer directly
            observation = self._get_encoding_form(observation)
            self._h, self._skip = self.encoder(observation)
            self.latent_levels[0].infer(self._h)

    def generate(self, gen=False, n_samples=1):
        """
        Method for generating observations, i.e. running the generative
        model forward.

        Args:
            gen (boolean): whether to sample from prior or approximate
                           posterior
            n_samples (int): number of samples to draw and evaluate
        """
        batch_size = self._prev_h.data.shape[0]
        # get the previous h and skip
        prev_h = self._prev_h.unsqueeze(1)
        # NOTE(review): the 0. * multiplier zeroes the skip connections
        # while keeping their shapes — looks deliberate (ablation?) but
        # verify against the training configuration
        prev_skip = [
            0. * _prev_skip.repeat(n_samples, 1, 1, 1)
            for _prev_skip in self._prev_skip
        ]
        # detach prev_h and prev_skip if necessary
        if self._detach_h:
            prev_h = prev_h.detach()
            prev_skip = [_prev_skip.detach() for _prev_skip in prev_skip]
        # get the prior input, detach if necessary
        gen_input = self._gen_input
        gen_input = gen_input.detach() if self._detach_h else gen_input
        # sample the latent variables
        z = self.latent_levels[0].generate(gen_input.unsqueeze(1),
                                           gen=gen,
                                           n_samples=n_samples)
        # pass through the decoder
        decoder_input = torch.cat([z, prev_h],
                                  dim=2).view(batch_size * n_samples, -1)
        g = self.decoder_lstm(decoder_input, detach=self._detach_h)
        g = self.decoder_lstm_output(g)
        output = self.decoder([g, prev_skip])
        b, _, h, w = output.data.shape
        # get the output mean and log variance
        if self.model_config['global_output_log_var']:
            # repeat along batch and sample dimensions
            output = output.view(b, -1, self.n_input_channels, h, w)
            log_var = self.output_log_var.unsqueeze(0).unsqueeze(0).repeat(
                batch_size, n_samples, 1, 1, 1)
            self.output_dist.log_var = torch.clamp(log_var, min=-10)
        else:
            output = output.view(b, -1, 2 * self.n_input_channels, h, w)
            self.output_dist.log_var = torch.clamp(
                output[:, :, self.n_input_channels:, :, :], min=-10)
        self.output_dist.mean = output[:, :, :self.
                                       n_input_channels, :, :].sigmoid()
        return torch.clamp(self.output_dist.sample(), 0., 1.)

    def step(self):
        """
        Method for stepping the generative model forward one step in the
        sequence.
        """
        # TODO: set n_samples in a smart way
        # step the lstms and latent level
        self.latent_levels[0].step()
        self.prior_lstm.step()
        self.decoder_lstm.step()
        # copy over the hidden and skip variables
        self._prev_h = self._h
        self._prev_skip = self._skip
        # clear the current hidden and skip variables, set the flag
        self._h = self._skip = None
        self._obs_encoded = False
        # use the prior lstm to get generative model input
        self._gen_input = self.prior_lstm(self._prev_h.unsqueeze(1),
                                          detach=False)
        # set the prior and approximate posterior
        self.latent_levels[0].generate(self._gen_input.detach().unsqueeze(1),
                                       gen=True,
                                       n_samples=1)
        self.latent_levels[0].latent.re_init_approx_posterior()

    def re_init(self, input):
        """
        Method for reinitializing the state (distributions and hidden
        states).

        Args:
            input (Variable, Tensor): contains observation at t = -1
        """
        # TODO: set n_samples in a smart way
        # flag to encode the hidden state for later decoding
        self._obs_encoded = False
        # re-initialize the lstms and distributions
        self.latent_levels[0].re_init()
        self.prior_lstm.re_init(input)
        self.decoder_lstm.re_init(input)
        # clear the hidden state and skip
        self._h = self._skip = None
        # encode this input to set the previous h and skip
        self._prev_h, self._prev_skip = self.encoder(input - 0.5)
        # set the prior and approximate posterior
        self._gen_input = self.prior_lstm(self._prev_h, detach=False)
        self.latent_levels[0].generate(self._gen_input.unsqueeze(1),
                                       gen=True,
                                       n_samples=1)
        self.latent_levels[0].latent.re_init_approx_posterior()

    def inference_parameters(self):
        """
        Method for obtaining the inference parameters.
        """
        params = nn.ParameterList()
        if self.modified:
            params.extend(list(self.inf_model.parameters()))
            if self.inference_procedure == 'direct':
                params.extend(list(self.inf_encoder.parameters()))
            elif self.inference_procedure == 'gradient':
                pass  # no other inference parameters
            elif self.inference_procedure == 'error':
                params.extend(list(self.obs_error_enc.parameters()))
            else:
                raise NotImplementedError
        else:
            params.extend(list(self.encoder.parameters()))
        params.extend(list(self.latent_levels[0].inference_parameters()))
        return params

    def generative_parameters(self):
        """
        Method for obtaining the generative parameters.
        """
        params = nn.ParameterList()
        if self.modified:
            # in the modified model the encoder is part of generation
            params.extend(list(self.encoder.parameters()))
        params.extend(list(self.prior_lstm.parameters()))
        params.extend(list(self.decoder.parameters()))
        params.extend(list(self.latent_levels[0].generative_parameters()))
        params.extend(list(self.decoder_lstm.parameters()))
        params.extend(list(self.decoder_lstm_output.parameters()))
        if self.model_config['global_output_log_var']:
            params.append(self.output_log_var)
        return params

    def inference_mode(self):
        """
        Method to set the model's current mode to inference.
        """
        self.latent_levels[0].latent.detach = False
        self._detach_h = True

    def generative_mode(self):
        """
        Method to set the model's current mode to generation.
        """
        self.latent_levels[0].latent.detach = True
        self._detach_h = False
def _construct(self, model_config):
    """
    Method for constructing SVG model using the model configuration file.

    Args:
        model_config (dict): dictionary containing model configuration
                             params
    """
    model_type = model_config['model_type'].lower()
    self.modified = model_config['modified']
    self.inference_procedure = model_config['inference_procedure'].lower()
    level_config = {}
    latent_config = {}
    latent_config['normalize_samples'] = model_config[
        'normalize_latent_samples']
    latent_config['inference_procedure'] = self.inference_procedure
    # hard coded because we handle inference here in the model
    level_config['inference_procedure'] = 'direct'
    if not self.modified:
        level_config['inference_config'] = {
            'n_layers': 1,
            'n_units': 256,
            'n_in': 128
        }
        latent_config['n_in'] = (256, 256)  # number of encoder, decoder units
    else:
        level_config['inference_config'] = None
        # first entry is filled in below, per model type / procedure
        latent_config['n_in'] = [None, 256]  # number of encoder, decoder units
    level_config['generative_config'] = None
    if model_type == 'sm_mnist':
        from lib.modules.networks.dcgan_64 import encoder, decoder
        self.n_input_channels = 1
        self.encoder = encoder(128, self.n_input_channels)
        self.decoder = decoder(128, self.n_input_channels)
        self.output_dist = Bernoulli()
        latent_config['n_variables'] = 10
        if self.modified:
            # NOTE(review): modified inference models for sm_mnist are not
            # implemented yet; only the procedure names are accepted
            if self.inference_procedure == 'direct':
                pass
            elif self.inference_procedure == 'gradient':
                pass
            elif self.inference_procedure == 'error':
                pass
            else:
                raise NotImplementedError
    elif model_type == 'kth_actions':
        from lib.modules.networks.vgg_64 import encoder, decoder
        self.n_input_channels = 1
        self.encoder = encoder(128, self.n_input_channels)
        if model_config['global_output_log_var']:
            # single learned log-variance map shared across the batch
            output_channels = self.n_input_channels
            self.output_log_var = nn.Parameter(
                torch.zeros(self.n_input_channels, 64, 64))
        else:
            # decoder outputs both mean and log-variance channels
            output_channels = 2 * self.n_input_channels
        self.decoder = decoder(128, output_channels)
        self.output_dist = Normal()
        latent_config['n_variables'] = 512
        if self.modified:
            if self.inference_procedure == 'direct':
                # another convolutional encoder
                self.inf_encoder = encoder(128, self.n_input_channels)
                # fully-connected inference model
                inf_config = {
                    'n_layers': 2,
                    'n_units': 256,
                    'n_in': 128,
                    'non_linearity': 'relu'
                }
                self.inf_model = FullyConnectedNetwork(inf_config)
                latent_config['n_in'][0] = 256
            elif self.inference_procedure == 'gradient':
                # fully-connected encoder / latent inference model
                n_units = 1024
                inf_config = {
                    'n_layers': 1,
                    'n_units': n_units,
                    'n_in': 4 * latent_config['n_variables'],
                    'non_linearity': 'elu',
                    'connection_type': 'highway'
                }
                if model_config['concat_observation']:
                    inf_config['n_in'] += (self.n_input_channels * 64 * 64)
                self.inf_model = FullyConnectedNetwork(inf_config)
                latent_config['n_in'][0] = n_units
                latent_config['update_type'] = model_config['update_type']
            elif self.inference_procedure == 'error':
                # convolutional observation error encoder
                obs_error_enc_config = {
                    'n_layers': 3,
                    'n_filters': 64,
                    'n_in': self.n_input_channels,
                    'filter_size': 3,
                    'non_linearity': 'relu'
                }
                if model_config['concat_observation']:
                    obs_error_enc_config['n_in'] += self.n_input_channels
                self.obs_error_enc = ConvolutionalNetwork(
                    obs_error_enc_config)
                # fully-connected error encoder (latent error + params +
                # encoded observation errors)
                inf_config = {
                    'n_layers': 3,
                    'n_units': 1024,
                    'n_in': 4 * latent_config['n_variables'],
                    'non_linearity': 'relu'
                }
                if model_config['concat_observation']:
                    inf_config['n_in'] += (self.n_input_channels * 64 * 64)
                self.inf_model = FullyConnectedNetwork(inf_config)
                latent_config['n_in'][0] = 1024
                latent_config['update_type'] = model_config['update_type']
            else:
                raise NotImplementedError
    elif model_type == 'bair_robot_pushing':
        from lib.modules.networks.vgg_64 import encoder, decoder
        self.n_input_channels = 3
        self.encoder = encoder(128, self.n_input_channels)
        if model_config['global_output_log_var']:
            # single learned log-variance map shared across the batch
            output_channels = self.n_input_channels
            self.output_log_var = nn.Parameter(
                torch.zeros(self.n_input_channels, 64, 64))
        else:
            # decoder outputs both mean and log-variance channels
            output_channels = 2 * self.n_input_channels
        self.decoder = decoder(128, output_channels)
        self.output_dist = Normal()
        latent_config['n_variables'] = 64
        if self.modified:
            if self.inference_procedure == 'direct':
                # another convolutional encoder
                self.inf_encoder = encoder(128, self.n_input_channels)
                # fully-connected inference model
                inf_config = {
                    'n_layers': 2,
                    'n_units': 256,
                    'n_in': 128,
                    'non_linearity': 'relu'
                }
                self.inf_model = FullyConnectedNetwork(inf_config)
                latent_config['n_in'][0] = 256
            elif self.inference_procedure == 'gradient':
                # fully-connected encoder / latent inference model
                inf_config = {
                    'n_layers': 3,
                    'n_units': 1024,
                    'n_in': 4 * latent_config['n_variables'],
                    'non_linearity': 'relu'
                }
                if model_config['concat_observation']:
                    inf_config['n_in'] += (self.n_input_channels * 64 * 64)
                self.inf_model = FullyConnectedNetwork(inf_config)
                latent_config['n_in'][0] = 1024
                latent_config['update_type'] = model_config['update_type']
            elif self.inference_procedure == 'error':
                # convolutional observation error encoder
                obs_error_enc_config = {
                    'n_layers': 3,
                    'n_filters': 64,
                    'n_in': self.n_input_channels,
                    'filter_size': 3,
                    'non_linearity': 'relu'
                }
                if model_config['concat_observation']:
                    obs_error_enc_config['n_in'] += self.n_input_channels
                self.obs_error_enc = ConvolutionalNetwork(
                    obs_error_enc_config)
                # fully-connected error encoder (latent error + params +
                # encoded observation errors)
                inf_config = {
                    'n_layers': 3,
                    'n_units': 1024,
                    'n_in': 4 * latent_config['n_variables'],
                    'non_linearity': 'relu'
                }
                if model_config['concat_observation']:
                    inf_config['n_in'] += (self.n_input_channels * 64 * 64)
                self.inf_model = FullyConnectedNetwork(inf_config)
                latent_config['n_in'][0] = 1024
                latent_config['update_type'] = model_config['update_type']
            else:
                raise NotImplementedError
    else:
        raise Exception('SVG model type must be one of 1) sm_mnist, 2) \
            kth_action, or 3) bair_robot_pushing. Invalid model \
            type: ' + model_type + '.')
    # construct a recurrent latent level
    level_config['latent_config'] = latent_config
    self.latent_levels = nn.ModuleList([LSTMLatentLevel(level_config)])
    self.prior_lstm = LSTMNetwork({
        'n_layers': 1,
        'n_units': 256,
        'n_in': 128
    })
    self.decoder_lstm = LSTMNetwork({
        'n_layers': 2,
        'n_units': 256,
        'n_in': 128 + latent_config['n_variables']
    })
    self.decoder_lstm_output = FullyConnectedLayer({
        'n_in': 256,
        'n_out': 128,
        'non_linearity': 'tanh'
    })
    # quantization bin width of the 8-bit pixel outputs
    self.output_interval = 1. / 256
def _construct(self, model_config):
    """
    Build all VRNN sub-modules from the configuration dictionary.

    Sets up the per-dataset sizes (timit / blizzard / iam_ondb / bball),
    the recurrent core, the observation (x) and latent (z) embedding
    networks, the latent level, the decoder, and the Gaussian output
    distribution.

    Args:
        model_config (dict): dictionary containing model configuration
            params; reads 'model_type', 'modified', 'inference_procedure',
            'update_type', 'concat_observation', 'learning_rate',
            'normalize_latent_samples', 'global_output_log_var'
    """
    model_type = model_config['model_type'].lower()
    self.modified = model_config['modified']
    self.inference_procedure = model_config['inference_procedure'].lower()
    if not self.modified:
        assert self.inference_procedure == 'direct', 'The original model only supports direct inference.'
    # whether to detach the hidden state during inference/generation modes
    self._detach_h = False
    latent_config = {}
    level_config = {}
    latent_config['inference_procedure'] = self.inference_procedure
    # hard coded because we handle inference here in the model
    level_config['inference_procedure'] = 'direct'
    # per-dataset architecture sizes
    if model_type == 'timit':
        lstm_units = 2000
        encoder_units = 500
        prior_units = 500
        decoder_units = 600
        x_units = 600
        z_units = 500
        hidden_layers = 4
        x_dim = 200
        z_dim = 200
        self.output_interval = 0.0018190742
    elif model_type == 'blizzard':
        lstm_units = 4000
        encoder_units = 500
        prior_units = 500
        decoder_units = 600
        x_units = 600
        z_units = 500
        hidden_layers = 4
        x_dim = 200
        z_dim = 200
        # TODO: check if this is correct
        self.output_interval = 0.0018190742
    elif model_type == 'iam_ondb':
        lstm_units = 1200
        encoder_units = 150
        prior_units = 150
        decoder_units = 250
        x_units = 250
        z_units = 150
        hidden_layers = 1
        x_dim = 3
        z_dim = 50
    elif model_type == 'bball':
        lstm_units = 1000
        encoder_units = 200
        prior_units = 200
        decoder_units = 200
        x_units = 200
        z_units = 200
        hidden_layers = 2
        x_dim = 2
        z_dim = 50
        # per-dimension output interval as a constant (non-trainable) tensor
        self.output_interval = Variable(torch.from_numpy(
            np.array([1e-5 / 94., 1e-5 / 50.]).astype('float32')),
            requires_grad=False).cuda()
    else:
        raise Exception('VRNN model type must be one of 1) timit, 2) \
            blizzard, 3) iam_ondb, or 4) bball. Invalid model \
            type: ' + model_type + '.')
    # LSTM: recurrent core consumes the concatenated x and z encodings
    lstm_config = {
        'n_layers': 1,
        'n_units': lstm_units,
        'n_in': x_units + z_units
    }
    self.lstm = LSTMNetwork(lstm_config)
    # x model: embeds the raw observation
    x_config = {
        'n_in': x_dim,
        'n_units': x_units,
        'n_layers': hidden_layers,
        'non_linearity': 'relu'
    }
    self.x_model = FullyConnectedNetwork(x_config)
    # inf model
    if self.modified:
        if self.inference_procedure in ['direct', 'gradient', 'error']:
            # set the input encoding size
            if self.inference_procedure == 'direct':
                input_dim = x_dim
            elif self.inference_procedure == 'gradient':
                latent_config['update_type'] = model_config['update_type']
                # gradient input: grads + params of (mean, log_var) -> 4 * z_dim
                input_dim = 4 * z_dim
                if model_config['concat_observation']:
                    input_dim += x_dim
            elif self.inference_procedure == 'error':
                latent_config['update_type'] = model_config['update_type']
                # error input: obs error + latent error + params -> x_dim + 3 * z_dim
                input_dim = x_dim + 3 * z_dim
                if model_config['concat_observation']:
                    input_dim += x_dim
            else:
                raise NotImplementedError
            encoder_units = 1024
            inf_config = {
                'n_in': input_dim,
                'n_units': encoder_units,
                'n_layers': 2,
                'non_linearity': 'elu'
            }
            inf_config['connection_type'] = 'highway'
            # self.inf_model = FullyConnectedNetwork(inf_config)
        else:
            # 'sgd' inference: no learned inference network
            inf_config = None
            latent_config['inf_lr'] = model_config['learning_rate']
    else:
        # original VRNN: inference conditions on LSTM state and x encoding
        inf_input_units = lstm_units + x_units
        inf_config = {
            'n_in': inf_input_units,
            'n_units': encoder_units,
            'n_layers': hidden_layers,
            'non_linearity': 'relu'
        }
    # latent level (encoder model and prior model)
    level_config['inference_config'] = inf_config
    gen_config = {
        'n_in': lstm_units,
        'n_units': prior_units,
        'n_layers': hidden_layers,
        'non_linearity': 'relu'
    }
    level_config['generative_config'] = gen_config
    latent_config['n_variables'] = z_dim
    latent_config['n_in'] = (encoder_units, prior_units)
    latent_config['normalize_samples'] = model_config[
        'normalize_latent_samples']
    # latent_config['n_in'] = (encoder_units+input_dim, prior_units)
    level_config['latent_config'] = latent_config
    latent = FullyConnectedLatentLevel(level_config)
    self.latent_levels = nn.ModuleList([latent])
    # z model: embeds the sampled latent variable
    z_config = {
        'n_in': z_dim,
        'n_units': z_units,
        'n_layers': hidden_layers,
        'non_linearity': 'relu'
    }
    self.z_model = FullyConnectedNetwork(z_config)
    # decoder: maps (LSTM state, z encoding) to output statistics
    decoder_config = {
        'n_in': lstm_units + z_units,
        'n_units': decoder_units,
        'n_layers': hidden_layers,
        'non_linearity': 'relu'
    }
    self.decoder_model = FullyConnectedNetwork(decoder_config)
    self.output_dist = Normal()
    self.output_mean = FullyConnectedLayer({
        'n_in': decoder_units,
        'n_out': x_dim
    })
    if model_config['global_output_log_var']:
        # single shared (learned) log variance across time steps
        self.output_log_var = nn.Parameter(torch.zeros(x_dim))
    else:
        # per-step log variance predicted from the decoder output
        self.output_log_var = FullyConnectedLayer({
            'n_in': decoder_units,
            'n_out': x_dim
        })
class VRNN(LatentVariableModel):
    """
    Variational recurrent neural network (VRNN) from "A Recurrent Latent
    Variable Model for Sequential Data," Chung et al., 2015.

    Args:
        model_config (dict): dictionary containing model configuration params
    """

    def __init__(self, model_config):
        super(VRNN, self).__init__(model_config)
        self._construct(model_config)

    def _construct(self, model_config):
        """
        Build all VRNN sub-modules from the configuration dictionary.

        Args:
            model_config (dict): dictionary containing model configuration
                params; reads 'model_type', 'modified', 'inference_procedure',
                'update_type', 'concat_observation', 'learning_rate',
                'normalize_latent_samples', 'global_output_log_var'
        """
        model_type = model_config['model_type'].lower()
        self.modified = model_config['modified']
        self.inference_procedure = model_config['inference_procedure'].lower()
        if not self.modified:
            assert self.inference_procedure == 'direct', 'The original model only supports direct inference.'
        # whether to detach the hidden state (set by inference/generative mode)
        self._detach_h = False
        latent_config = {}
        level_config = {}
        latent_config['inference_procedure'] = self.inference_procedure
        # hard coded because we handle inference here in the model
        level_config['inference_procedure'] = 'direct'
        # per-dataset architecture sizes
        if model_type == 'timit':
            lstm_units = 2000
            encoder_units = 500
            prior_units = 500
            decoder_units = 600
            x_units = 600
            z_units = 500
            hidden_layers = 4
            x_dim = 200
            z_dim = 200
            self.output_interval = 0.0018190742
        elif model_type == 'blizzard':
            lstm_units = 4000
            encoder_units = 500
            prior_units = 500
            decoder_units = 600
            x_units = 600
            z_units = 500
            hidden_layers = 4
            x_dim = 200
            z_dim = 200
            # TODO: check if this is correct
            self.output_interval = 0.0018190742
        elif model_type == 'iam_ondb':
            lstm_units = 1200
            encoder_units = 150
            prior_units = 150
            decoder_units = 250
            x_units = 250
            z_units = 150
            hidden_layers = 1
            x_dim = 3
            z_dim = 50
        elif model_type == 'bball':
            lstm_units = 1000
            encoder_units = 200
            prior_units = 200
            decoder_units = 200
            x_units = 200
            z_units = 200
            hidden_layers = 2
            x_dim = 2
            z_dim = 50
            # per-dimension output interval as a constant (non-trainable) tensor
            self.output_interval = Variable(torch.from_numpy(
                np.array([1e-5 / 94., 1e-5 / 50.]).astype('float32')),
                requires_grad=False).cuda()
        else:
            raise Exception('VRNN model type must be one of 1) timit, 2) \
                blizzard, 3) iam_ondb, or 4) bball. Invalid model \
                type: ' + model_type + '.')
        # LSTM: recurrent core consumes the concatenated x and z encodings
        lstm_config = {
            'n_layers': 1,
            'n_units': lstm_units,
            'n_in': x_units + z_units
        }
        self.lstm = LSTMNetwork(lstm_config)
        # x model: embeds the raw observation
        x_config = {
            'n_in': x_dim,
            'n_units': x_units,
            'n_layers': hidden_layers,
            'non_linearity': 'relu'
        }
        self.x_model = FullyConnectedNetwork(x_config)
        # inf model
        if self.modified:
            if self.inference_procedure in ['direct', 'gradient', 'error']:
                # set the input encoding size
                if self.inference_procedure == 'direct':
                    input_dim = x_dim
                elif self.inference_procedure == 'gradient':
                    latent_config['update_type'] = model_config['update_type']
                    # gradients + parameters of (mean, log_var) -> 4 * z_dim
                    input_dim = 4 * z_dim
                    if model_config['concat_observation']:
                        input_dim += x_dim
                elif self.inference_procedure == 'error':
                    latent_config['update_type'] = model_config['update_type']
                    # obs error + latent error + params -> x_dim + 3 * z_dim
                    input_dim = x_dim + 3 * z_dim
                    if model_config['concat_observation']:
                        input_dim += x_dim
                else:
                    raise NotImplementedError
                encoder_units = 1024
                inf_config = {
                    'n_in': input_dim,
                    'n_units': encoder_units,
                    'n_layers': 2,
                    'non_linearity': 'elu'
                }
                inf_config['connection_type'] = 'highway'
                # self.inf_model = FullyConnectedNetwork(inf_config)
            else:
                # 'sgd' inference: no learned inference network
                inf_config = None
                latent_config['inf_lr'] = model_config['learning_rate']
        else:
            # original VRNN: inference conditions on LSTM state and x encoding
            inf_input_units = lstm_units + x_units
            inf_config = {
                'n_in': inf_input_units,
                'n_units': encoder_units,
                'n_layers': hidden_layers,
                'non_linearity': 'relu'
            }
        # latent level (encoder model and prior model)
        level_config['inference_config'] = inf_config
        gen_config = {
            'n_in': lstm_units,
            'n_units': prior_units,
            'n_layers': hidden_layers,
            'non_linearity': 'relu'
        }
        level_config['generative_config'] = gen_config
        latent_config['n_variables'] = z_dim
        latent_config['n_in'] = (encoder_units, prior_units)
        latent_config['normalize_samples'] = model_config[
            'normalize_latent_samples']
        # latent_config['n_in'] = (encoder_units+input_dim, prior_units)
        level_config['latent_config'] = latent_config
        latent = FullyConnectedLatentLevel(level_config)
        self.latent_levels = nn.ModuleList([latent])
        # z model: embeds the sampled latent variable
        z_config = {
            'n_in': z_dim,
            'n_units': z_units,
            'n_layers': hidden_layers,
            'non_linearity': 'relu'
        }
        self.z_model = FullyConnectedNetwork(z_config)
        # decoder: maps (LSTM state, z encoding) to output statistics
        decoder_config = {
            'n_in': lstm_units + z_units,
            'n_units': decoder_units,
            'n_layers': hidden_layers,
            'non_linearity': 'relu'
        }
        self.decoder_model = FullyConnectedNetwork(decoder_config)
        self.output_dist = Normal()
        self.output_mean = FullyConnectedLayer({
            'n_in': decoder_units,
            'n_out': x_dim
        })
        if model_config['global_output_log_var']:
            # single shared (learned) log variance across time steps
            self.output_log_var = nn.Parameter(torch.zeros(x_dim))
        else:
            # per-step log variance predicted from the decoder output
            self.output_log_var = FullyConnectedLayer({
                'n_in': decoder_units,
                'n_out': x_dim
            })

    def _get_encoding_form(self, observation):
        """
        Gets the appropriate input form for the inference procedure.

        Args:
            observation (Variable, tensor): the input observation

        Returns:
            the tensor fed to the inference model ('direct' passes the
            observation through; 'gradient'/'error' build normalized
            gradient/error encodings concatenated with the posterior
            parameters; 'sgd' returns the raw gradients list).
        """
        if self.inference_procedure == 'direct':
            return observation
        elif self.inference_procedure == 'gradient':
            grads = self.latent_levels[0].latent.approx_posterior_gradients()
            # normalization: dim 0 = batch norm, dim 1 = layer norm
            if self.model_config['input_normalization'] in ['layer', 'batch']:
                norm_dim = 0 if self.model_config[
                    'input_normalization'] == 'batch' else 1
                for ind, grad in enumerate(grads):
                    mean = grad.mean(dim=norm_dim, keepdim=True)
                    std = grad.std(dim=norm_dim, keepdim=True)
                    # 1e-7 guards against division by zero
                    grads[ind] = (grad - mean) / (std + 1e-7)
            grads = torch.cat(grads, dim=1)
            # concatenate with the parameters
            params = self.latent_levels[0].latent.approx_posterior_parameters()
            if self.model_config['norm_parameters']:
                if self.model_config['input_normalization'] in [
                        'layer', 'batch'
                ]:
                    norm_dim = 0 if self.model_config[
                        'input_normalization'] == 'batch' else 1
                    for ind, param in enumerate(params):
                        mean = param.mean(dim=norm_dim, keepdim=True)
                        std = param.std(dim=norm_dim, keepdim=True)
                        params[ind] = (param - mean) / (std + 1e-7)
            params = torch.cat(params, dim=1)
            grads_params = torch.cat([grads, params], dim=1)
            # concatenate with the observation
            if self.model_config['concat_observation']:
                grads_params = torch.cat([grads_params, observation], dim=1)
            return grads_params
        elif self.inference_procedure == 'error':
            errors = [
                self._output_error(observation),
                self.latent_levels[0].latent.error()
            ]
            # normalization
            if self.model_config['input_normalization'] in ['layer', 'batch']:
                norm_dim = 0 if self.model_config[
                    'input_normalization'] == 'batch' else 1
                # NOTE(review): norm_dim is computed but the normalization
                # below uses dim=0 unconditionally, unlike the 'gradient'
                # branch above and the SRNN counterpart (which use norm_dim).
                # Looks unintentional — verify.
                for ind, error in enumerate(errors):
                    mean = error.mean(dim=0, keepdim=True)
                    std = error.std(dim=0, keepdim=True)
                    errors[ind] = (error - mean) / (std + 1e-7)
            errors = torch.cat(errors, dim=1)
            # concatenate with the parameters
            params = self.latent_levels[0].latent.approx_posterior_parameters()
            if self.model_config['norm_parameters']:
                if self.model_config['input_normalization'] in [
                        'layer', 'batch'
                ]:
                    norm_dim = 0 if self.model_config[
                        'input_normalization'] == 'batch' else 1
                    for ind, param in enumerate(params):
                        mean = param.mean(dim=norm_dim, keepdim=True)
                        std = param.std(dim=norm_dim, keepdim=True)
                        params[ind] = (param - mean) / (std + 1e-7)
            params = torch.cat(params, dim=1)
            error_params = torch.cat([errors, params], dim=1)
            if self.model_config['concat_observation']:
                error_params = torch.cat([error_params, observation], dim=1)
            return error_params
        elif self.inference_procedure == 'sgd':
            grads = self.latent_levels[0].latent.approx_posterior_gradients()
            return grads
        else:
            raise NotImplementedError

    def _output_error(self, observation, averaged=True):
        """
        Calculates Gaussian error for encoding.

        Args:
            observation (tensor): observation to use for error calculation
            averaged (bool): whether to average the error over the sample
                dimension

        Returns:
            precision-weighted output error, detached from the graph
        """
        output_mean = self.output_dist.mean.detach()
        output_log_var = self.output_dist.log_var.detach()
        n_samples = output_mean.data.shape[1]
        if len(observation.data.shape) == 2:
            # add and tile the sample dimension to match the output stats
            observation = observation.unsqueeze(1).repeat(1, n_samples, 1)
        # NOTE(review): the 1e-7 epsilon sits inside exp(), where it has
        # essentially no effect; exp(log_var) + 1e-7 was presumably
        # intended — confirm before changing.
        n_error = (observation -
                   output_mean) / torch.exp(output_log_var + 1e-7)
        if averaged:
            n_error = n_error.mean(dim=1)
        return n_error

    def infer(self, observation):
        """
        Method for perfoming inference of the approximate posterior over the
        latent variables.

        Args:
            observation (tensor): observation to infer latent variables from
        """
        self._x_enc = self.x_model(observation)
        if self.modified:
            enc = self._get_encoding_form(observation)
            self.latent_levels[0].infer(enc)
        else:
            # original VRNN: condition on x encoding and previous LSTM state
            inf_input = self._x_enc
            prev_h = self._prev_h
            # prev_h = prev_h.detach() if self._detach_h else prev_h
            enc = torch.cat([inf_input, prev_h], dim=1)
            self.latent_levels[0].infer(enc)

    def generate(self, gen=False, n_samples=1):
        """
        Method for generating observations, i.e. running the generative model
        forward.

        Args:
            gen (boolean): whether to sample from prior or approximate posterior
            n_samples (int): number of samples to draw and evaluate
        """
        # TODO: handle sampling dimension
        # possibly detach the hidden state, preventing backprop
        prev_h = self._prev_h.unsqueeze(1)
        prev_h = prev_h.detach() if self._detach_h else prev_h
        # generate the prior
        z = self.latent_levels[0].generate(prev_h,
                                           gen=gen,
                                           n_samples=n_samples)
        # transform z through the z model; flatten (batch, samples) to apply
        # the fully-connected network, then restore the sample dimension
        b, s, _ = z.data.shape
        self._z_enc = self.z_model(z.view(b * s, -1)).view(b, s, -1)
        # pass encoded z and previous h through the decoder model
        dec = torch.cat([self._z_enc, prev_h.repeat(1, s, 1)], dim=2)
        b, s, _ = dec.data.shape
        output = self.decoder_model(dec.view(b * s, -1)).view(b, s, -1)
        # get the output mean and log variance
        self.output_dist.mean = self.output_mean(output)
        if self.model_config['global_output_log_var']:
            # broadcast the shared log variance over batch and samples
            b, s = output.data.shape[0], output.data.shape[1]
            log_var = self.output_log_var.view(1, 1, -1).repeat(b, s, 1)
            self.output_dist.log_var = torch.clamp(log_var, min=-20., max=5)
        else:
            self.output_dist.log_var = torch.clamp(self.output_log_var(output),
                                                   min=-20.,
                                                   max=5)
        return self.output_dist.sample()

    def step(self, n_samples=1):
        """
        Method for stepping the generative model forward one step in the
        sequence.
        """
        # TODO: handle sampling dimension
        self._prev_h = self.lstm(
            torch.cat([self._x_enc, self._z_enc[:, 0]], dim=1))
        prev_h = self._prev_h.unsqueeze(1)
        self.lstm.step()
        # get the prior, use it to initialize the approximate posterior
        self.latent_levels[0].generate(prev_h, gen=True, n_samples=n_samples)
        self.latent_levels[0].latent.re_init_approx_posterior()

    def re_init(self, input):
        """
        Method for reinitializing the state (approximate posterior and priors)
        of the dynamical latent variable model.

        Args:
            input: initial input used to re-initialize the LSTM state
        """
        # re-initialize the LSTM hidden and cell states
        self.lstm.re_init(input)
        # set the previous hidden state, add sample dimension
        self._prev_h = self.lstm.layers[0].hidden_state
        prev_h = self._prev_h.unsqueeze(1)
        # get the prior, use it to initialize the approximate posterior
        self.latent_levels[0].generate(prev_h, gen=True, n_samples=1)
        self.latent_levels[0].latent.re_init_approx_posterior()

    def inference_parameters(self):
        """
        Method for obtaining the inference parameters.
        """
        params = nn.ParameterList()
        # 'sgd' inference has no learned inference parameters
        if self.inference_procedure != 'sgd':
            params.extend(list(self.latent_levels[0].inference_parameters()))
        return params

    def generative_parameters(self):
        """
        Method for obtaining the generative parameters.
        """
        params = nn.ParameterList()
        params.extend(list(self.lstm.parameters()))
        params.extend(list(self.latent_levels[0].generative_parameters()))
        params.extend(list(self.x_model.parameters()))
        params.extend(list(self.z_model.parameters()))
        params.extend(list(self.decoder_model.parameters()))
        params.extend(list(self.output_mean.parameters()))
        if self.model_config['global_output_log_var']:
            # shared log variance is a single nn.Parameter
            params.append(self.output_log_var)
        else:
            params.extend(list(self.output_log_var.parameters()))
        return params

    def inference_mode(self):
        """
        Method to set the model's current mode to inference.
        """
        self.latent_levels[0].latent.detach = False
        self._detach_h = True

    def generative_mode(self):
        """
        Method to set the model's current mode to generation.
        """
        self.latent_levels[0].latent.detach = True
        self._detach_h = False
def _construct(self, model_config):
    """
    Build all SRNN sub-modules from the configuration dictionary.

    Sets up the per-dataset sizes (timit / midi), the output distribution
    (Gaussian for timit, Bernoulli for midi), the recurrent core, the
    latent level, the inference model, and the decoder.

    Args:
        model_config (dict): dictionary containing model configuration
            params; reads 'model_type', 'modified', 'inference_procedure',
            'normalize_latent_samples', 'global_output_log_var',
            'concat_observation', 'update_type'
    """
    model_type = model_config['model_type'].lower()
    self.modified = model_config['modified']
    self.inference_procedure = model_config['inference_procedure'].lower()
    if not self.modified:
        assert self.inference_procedure == 'direct', 'The original model only supports direct inference.'
    # whether to detach the hidden state during inference/generation modes
    self._detach_h = False
    latent_config = {}
    level_config = {}
    latent_config['inference_procedure'] = self.inference_procedure
    latent_config['normalize_samples'] = model_config[
        'normalize_latent_samples']
    # hard coded because we handle inference here in the model
    level_config['inference_procedure'] = 'direct'
    if model_type == 'timit':
        lstm_units = 1024
        x_dim = 200
        z_dim = 256
        n_layers = 2
        n_units = 512
        # Gaussian output
        self.output_interval = 0.0018190742
        self.output_dist = Normal()
        self.output_mean = FullyConnectedLayer({
            'n_in': n_units,
            'n_out': x_dim
        })
        if model_config['global_output_log_var']:
            # single shared (learned) log variance
            self.output_log_var = nn.Parameter(torch.zeros(x_dim))
        else:
            # per-step log variance predicted from the decoder output
            self.output_log_var = FullyConnectedLayer({
                'n_in': n_units,
                'n_out': x_dim
            })
    elif model_type == 'midi':
        lstm_units = 300
        x_dim = 88
        z_dim = 100
        n_layers = 1
        n_units = 500
        # Bernoulli output
        self.output_dist = Bernoulli()
        self.output_mean = FullyConnectedLayer({
            'n_in': n_units,
            'n_out': x_dim,
            'non_linearity': 'sigmoid'
        })
        # no log variance for the Bernoulli output
        self.output_log_var = None
    else:
        raise Exception('SRNN model type must be one of 1) timit, 2) \
            or 4) midi. Invalid model type: ' + model_type + '.')
    # LSTM: recurrent core driven by the raw observation
    lstm_config = {'n_layers': 1, 'n_units': lstm_units, 'n_in': x_dim}
    self.lstm = LSTMNetwork(lstm_config)
    # non_linearity = 'sigmoid'
    # latent level
    gen_config = {
        'n_in': lstm_units + z_dim,
        'n_units': n_units,
        'n_layers': n_layers,
        'non_linearity': 'clipped_leaky_relu'
    }
    # gen_config = {'n_in': lstm_units + z_dim, 'n_units': n_units,
    #               'n_layers': n_layers, 'non_linearity': non_linearity}
    level_config['generative_config'] = gen_config
    level_config['inference_config'] = None
    latent_config['n_variables'] = z_dim
    if self.modified:
        inf_model_units = 1024
        inf_model_layers = 2
        # gradients + parameters of (mean, log_var) -> 4 * z_dim inputs
        inf_model_config = {
            'n_in': 4 * z_dim,
            'n_units': inf_model_units,
            'n_layers': inf_model_layers,
            'non_linearity': 'elu'
        }
        # inf_model_config = {'n_in': 4 * z_dim, 'n_units': inf_model_units,
        #                     'n_layers': inf_model_layers, 'non_linearity': non_linearity}
        if model_config['concat_observation']:
            inf_model_config['n_in'] += x_dim
        inf_model_config['connection_type'] = 'highway'
        latent_config['update_type'] = model_config['update_type']
    else:
        # original SRNN: inference conditions on LSTM state and observation
        inf_model_units = n_units
        inf_model_config = {
            'n_in': lstm_units + x_dim,
            'n_units': n_units,
            'n_layers': n_layers,
            'non_linearity': 'clipped_leaky_relu'
        }
        # inf_model_config = {'n_in': lstm_units + x_dim, 'n_units': n_units,
        #                     'n_layers': n_layers, 'non_linearity': non_linearity}
    self.inference_model = FullyConnectedNetwork(inf_model_config)
    latent_config['n_in'] = (inf_model_units, n_units)
    level_config['latent_config'] = latent_config
    latent = FullyConnectedLatentLevel(level_config)
    self.latent_levels = nn.ModuleList([latent])
    # learned initial latent state used at sequence start
    self._initial_z = nn.Parameter(torch.zeros(z_dim))
    # decoder
    decoder_config = {
        'n_in': lstm_units + z_dim,
        'n_units': n_units,
        'n_layers': 2,
        'non_linearity': 'clipped_leaky_relu'
    }
    # decoder_config = {'n_in': lstm_units + z_dim, 'n_units': n_units,
    #                   'n_layers': 2, 'non_linearity': non_linearity}
    self.decoder_model = FullyConnectedNetwork(decoder_config)
class SRNN(LatentVariableModel):
    """
    Stochastic recurrent neural network (SRNN) from Fraccaro et al., 2016.

    Args:
        model_config (dict): dictionary containing model configuration params
    """

    def __init__(self, model_config):
        super(SRNN, self).__init__(model_config)
        self._construct(model_config)

    def _construct(self, model_config):
        """
        Build all SRNN sub-modules from the configuration dictionary.

        Args:
            model_config (dict): dictionary containing model configuration
                params; reads 'model_type', 'modified', 'inference_procedure',
                'normalize_latent_samples', 'global_output_log_var',
                'concat_observation', 'update_type'
        """
        model_type = model_config['model_type'].lower()
        self.modified = model_config['modified']
        self.inference_procedure = model_config['inference_procedure'].lower()
        if not self.modified:
            assert self.inference_procedure == 'direct', 'The original model only supports direct inference.'
        # whether to detach the hidden state (set by inference/generative mode)
        self._detach_h = False
        latent_config = {}
        level_config = {}
        latent_config['inference_procedure'] = self.inference_procedure
        latent_config['normalize_samples'] = model_config[
            'normalize_latent_samples']
        # hard coded because we handle inference here in the model
        level_config['inference_procedure'] = 'direct'
        if model_type == 'timit':
            lstm_units = 1024
            x_dim = 200
            z_dim = 256
            n_layers = 2
            n_units = 512
            # Gaussian output
            self.output_interval = 0.0018190742
            self.output_dist = Normal()
            self.output_mean = FullyConnectedLayer({
                'n_in': n_units,
                'n_out': x_dim
            })
            if model_config['global_output_log_var']:
                # single shared (learned) log variance
                self.output_log_var = nn.Parameter(torch.zeros(x_dim))
            else:
                # per-step log variance predicted from the decoder output
                self.output_log_var = FullyConnectedLayer({
                    'n_in': n_units,
                    'n_out': x_dim
                })
        elif model_type == 'midi':
            lstm_units = 300
            x_dim = 88
            z_dim = 100
            n_layers = 1
            n_units = 500
            # Bernoulli output
            self.output_dist = Bernoulli()
            self.output_mean = FullyConnectedLayer({
                'n_in': n_units,
                'n_out': x_dim,
                'non_linearity': 'sigmoid'
            })
            # no log variance for the Bernoulli output
            self.output_log_var = None
        else:
            raise Exception('SRNN model type must be one of 1) timit, 2) \
                or 4) midi. Invalid model type: ' + model_type + '.')
        # LSTM: recurrent core driven by the raw observation
        lstm_config = {'n_layers': 1, 'n_units': lstm_units, 'n_in': x_dim}
        self.lstm = LSTMNetwork(lstm_config)
        # non_linearity = 'sigmoid'
        # latent level
        gen_config = {
            'n_in': lstm_units + z_dim,
            'n_units': n_units,
            'n_layers': n_layers,
            'non_linearity': 'clipped_leaky_relu'
        }
        # gen_config = {'n_in': lstm_units + z_dim, 'n_units': n_units,
        #               'n_layers': n_layers, 'non_linearity': non_linearity}
        level_config['generative_config'] = gen_config
        level_config['inference_config'] = None
        latent_config['n_variables'] = z_dim
        if self.modified:
            inf_model_units = 1024
            inf_model_layers = 2
            # gradients + parameters of (mean, log_var) -> 4 * z_dim inputs
            inf_model_config = {
                'n_in': 4 * z_dim,
                'n_units': inf_model_units,
                'n_layers': inf_model_layers,
                'non_linearity': 'elu'
            }
            # inf_model_config = {'n_in': 4 * z_dim, 'n_units': inf_model_units,
            #                     'n_layers': inf_model_layers, 'non_linearity': non_linearity}
            if model_config['concat_observation']:
                inf_model_config['n_in'] += x_dim
            inf_model_config['connection_type'] = 'highway'
            latent_config['update_type'] = model_config['update_type']
        else:
            # original SRNN: inference conditions on LSTM state and observation
            inf_model_units = n_units
            inf_model_config = {
                'n_in': lstm_units + x_dim,
                'n_units': n_units,
                'n_layers': n_layers,
                'non_linearity': 'clipped_leaky_relu'
            }
            # inf_model_config = {'n_in': lstm_units + x_dim, 'n_units': n_units,
            #                     'n_layers': n_layers, 'non_linearity': non_linearity}
        self.inference_model = FullyConnectedNetwork(inf_model_config)
        latent_config['n_in'] = (inf_model_units, n_units)
        level_config['latent_config'] = latent_config
        latent = FullyConnectedLatentLevel(level_config)
        self.latent_levels = nn.ModuleList([latent])
        # learned initial latent state used at sequence start
        self._initial_z = nn.Parameter(torch.zeros(z_dim))
        # decoder
        decoder_config = {
            'n_in': lstm_units + z_dim,
            'n_units': n_units,
            'n_layers': 2,
            'non_linearity': 'clipped_leaky_relu'
        }
        # decoder_config = {'n_in': lstm_units + z_dim, 'n_units': n_units,
        #                   'n_layers': 2, 'non_linearity': non_linearity}
        self.decoder_model = FullyConnectedNetwork(decoder_config)

    def _get_encoding_form(self, observation):
        """
        Gets the appropriate input form for the inference procedure.

        Args:
            observation (Variable, tensor): the input observation

        Returns:
            the tensor fed to the inference model ('direct' passes the
            observation through; 'gradient'/'error' build normalized
            gradient/error encodings concatenated with the posterior
            parameters; 'sgd' returns the raw gradients list).
        """
        if self.inference_procedure == 'direct':
            return observation
        elif self.inference_procedure == 'gradient':
            grads = self.latent_levels[0].latent.approx_posterior_gradients()
            # normalization: dim 0 = batch norm, dim 1 = layer norm
            if self.model_config['input_normalization'] in ['layer', 'batch']:
                norm_dim = 0 if self.model_config[
                    'input_normalization'] == 'batch' else 1
                for ind, grad in enumerate(grads):
                    mean = grad.mean(dim=norm_dim, keepdim=True)
                    std = grad.std(dim=norm_dim, keepdim=True)
                    # 1e-7 guards against division by zero
                    grads[ind] = (grad - mean) / (std + 1e-7)
            grads = torch.cat(grads, dim=1)
            # concatenate with the parameters
            params = self.latent_levels[0].latent.approx_posterior_parameters()
            if self.model_config['norm_parameters']:
                if self.model_config['input_normalization'] in [
                        'layer', 'batch'
                ]:
                    norm_dim = 0 if self.model_config[
                        'input_normalization'] == 'batch' else 1
                    for ind, param in enumerate(params):
                        mean = param.mean(dim=norm_dim, keepdim=True)
                        std = param.std(dim=norm_dim, keepdim=True)
                        params[ind] = (param - mean) / (std + 1e-7)
            params = torch.cat(params, dim=1)
            grads_params = torch.cat([grads, params], dim=1)
            # concatenate with the observation
            if self.model_config['concat_observation']:
                grads_params = torch.cat([grads_params, observation], dim=1)
            return grads_params
        elif self.inference_procedure == 'error':
            errors = [
                self._output_error(observation),
                self.latent_levels[0].latent.error()
            ]
            # normalization
            if self.model_config['input_normalization'] in ['layer', 'batch']:
                norm_dim = 0 if self.model_config[
                    'input_normalization'] == 'batch' else 1
                for ind, error in enumerate(errors):
                    mean = error.mean(dim=norm_dim, keepdim=True)
                    std = error.std(dim=norm_dim, keepdim=True)
                    errors[ind] = (error - mean) / (std + 1e-7)
            errors = torch.cat(errors, dim=1)
            # concatenate with the parameters
            params = self.latent_levels[0].latent.approx_posterior_parameters()
            if self.model_config['norm_parameters']:
                if self.model_config['input_normalization'] in [
                        'layer', 'batch'
                ]:
                    norm_dim = 0 if self.model_config[
                        'input_normalization'] == 'batch' else 1
                    for ind, param in enumerate(params):
                        mean = param.mean(dim=norm_dim, keepdim=True)
                        std = param.std(dim=norm_dim, keepdim=True)
                        params[ind] = (param - mean) / (std + 1e-7)
            params = torch.cat(params, dim=1)
            error_params = torch.cat([errors, params], dim=1)
            if self.model_config['concat_observation']:
                error_params = torch.cat([error_params, observation], dim=1)
            return error_params
        elif self.inference_procedure == 'sgd':
            grads = self.latent_levels[0].latent.approx_posterior_gradients()
            return grads
        else:
            raise NotImplementedError

    def _output_error(self, observation, averaged=True):
        """
        Calculates Gaussian error for encoding.

        Args:
            observation (tensor): observation to use for error calculation
            averaged (bool): whether to average the error over the sample
                dimension

        Returns:
            precision-weighted output error, detached from the graph
        """
        output_mean = self.output_dist.mean.detach()
        output_log_var = self.output_dist.log_var.detach()
        n_samples = output_mean.data.shape[1]
        if len(observation.data.shape) == 2:
            # add and tile the sample dimension to match the output stats
            observation = observation.unsqueeze(1).repeat(1, n_samples, 1)
        # NOTE(review): the 1e-7 epsilon sits inside exp(), where it has
        # essentially no effect; exp(log_var) + 1e-7 was presumably
        # intended — confirm before changing.
        n_error = (observation -
                   output_mean) / torch.exp(output_log_var + 1e-7)
        if averaged:
            n_error = n_error.mean(dim=1)
        return n_error

    def infer(self, observation):
        """
        Method for perfoming inference of the approximate posterior over the
        latent variables.

        Args:
            observation (tensor): observation to infer latent variables from
        """
        if self._x is None:
            # store the observation
            self._x = observation
        enc = self._get_encoding_form(observation)
        if self.modified:
            pass
        else:
            # original SRNN: also condition on the current LSTM state
            h = self._h
            enc = torch.cat([enc, h], dim=1)
        enc = self.inference_model(enc)
        self.latent_levels[0].infer(enc)

    def generate(self, gen=False, n_samples=1):
        """
        Method for generating observations, i.e. running the generative model
        forward.

        Args:
            gen (boolean): whether to sample from prior or approximate posterior
            n_samples (int): number of samples to draw and evaluate
        """
        h = self._h.detach() if self._detach_h else self._h
        h = h.unsqueeze(1)
        prev_z = self._prev_z.detach() if self._detach_h else self._prev_z
        if prev_z.data.shape[1] != n_samples:
            # tile the previous latent over the requested sample dimension
            prev_z = prev_z.repeat(1, n_samples, 1)
        gen_input = torch.cat([h.repeat(1, n_samples, 1), prev_z], dim=2)
        self._z = self.latent_levels[0].generate(gen_input,
                                                 gen=gen,
                                                 n_samples=n_samples)
        dec_input = torch.cat([h.repeat(1, n_samples, 1), self._z], dim=2)
        # flatten (batch, samples) to run the decoder, then restore
        b, s, _ = dec_input.data.shape
        dec = self.decoder_model(dec_input.view(b * s, -1)).view(b, s, -1)
        output_mean = self.output_mean(dec)
        # NOTE(review): truthiness of self.output_log_var — for the
        # global_output_log_var configuration this is a multi-element
        # nn.Parameter, and `if <tensor>:` raises on multi-element tensors;
        # `is not None` was presumably intended. Also, in the global branch
        # only log_var of the existing output_dist is updated (the mean is
        # not refreshed) — verify both.
        if self.output_log_var:
            # Gaussian output
            if self.model_config['global_output_log_var']:
                b, s = dec.data.shape[0], dec.data.shape[1]
                log_var = self.output_log_var.view(1, 1, -1).repeat(b, s, 1)
                self.output_dist.log_var = torch.clamp(log_var, min=-10)
            else:
                output_log_var = torch.clamp(self.output_log_var(dec),
                                             min=-10)
                self.output_dist = Normal(output_mean, output_log_var)
        else:
            # Bernoulli output
            self.output_dist = Bernoulli(output_mean)
        return self.output_dist.sample()

    def step(self, n_samples=1):
        """
        Method for stepping the generative model forward one step in the
        sequence.
        """
        # set the previous z
        self._prev_z = self._z
        s = self._prev_z.data.shape[1]
        # step the LSTM (using the previous observation)
        self._h = self.lstm(self._x)
        self.lstm.step()
        self._x = None
        # get the prior, use it to initialize the approximate posterior
        gen_input = torch.cat(
            [self._h.unsqueeze(1).repeat(1, s, 1), self._prev_z], dim=2)
        self.latent_levels[0].generate(gen_input,
                                       gen=True,
                                       n_samples=n_samples)
        self.latent_levels[0].latent.re_init_approx_posterior()

    def re_init(self, input):
        """
        Method for reinitializing the state (approximate posterior and priors)
        of the dynamical latent variable model.

        Args:
            input: initial input used to re-initialize the LSTM state
        """
        # re-initialize the LSTM hidden and cell states
        self.lstm.re_init(input)
        # set the previous hidden state, add sample dimension
        self._h = self.lstm(input)
        self.lstm.step()
        self._x = None
        self._z = None
        # start from the learned initial latent, tiled over the batch
        self._prev_z = self._initial_z.view(1, 1,
                                            -1).repeat(input.data.shape[0], 1,
                                                       1)
        # get the prior, use it to initialize the approximate posterior
        gen_input = torch.cat([self._h.unsqueeze(1), self._prev_z], dim=2)
        self.latent_levels[0].generate(gen_input, gen=True, n_samples=1)
        self.latent_levels[0].latent.re_init_approx_posterior()

    def inference_parameters(self):
        """
        Method for obtaining the inference parameters.
        """
        params = nn.ParameterList()
        # 'sgd' inference has no learned inference parameters
        if self.inference_procedure != 'sgd':
            params.extend(list(self.inference_model.parameters()))
            params.extend(list(self.latent_levels[0].inference_parameters()))
            params.append(self._initial_z)
        return params

    def generative_parameters(self):
        """
        Method for obtaining the generative parameters.
        """
        params = nn.ParameterList()
        params.extend(list(self.lstm.parameters()))
        params.extend(list(self.latent_levels[0].generative_parameters()))
        params.extend(list(self.decoder_model.parameters()))
        params.extend(list(self.output_mean.parameters()))
        if self.output_log_var is not None:
            if self.model_config['global_output_log_var']:
                # shared log variance is a single nn.Parameter
                params.append(self.output_log_var)
            else:
                params.extend(list(self.output_log_var.parameters()))
        return params

    def inference_mode(self):
        """
        Method to set the model's current mode to inference.
        """
        self.latent_levels[0].latent.detach = False
        self._detach_h = True

    def generative_mode(self):
        """
        Method to set the model's current mode to generation.
        """
        self.latent_levels[0].latent.detach = True
        self._detach_h = False