class FullyConnectedLatentVariable(LatentVariable):
    """
    A fully-connected (Gaussian) latent variable.

    Args:
        latent_config (dict): dictionary containing variable configuration
                              parameters: 'n_variables', 'n_in' (input sizes;
                              index 0 feeds the inference layers, index 1 the
                              prior layers), 'inference_procedure',
                              'normalize_samples', plus 'update_type' for the
                              'gradient'/'error' procedures and 'inf_lr' for
                              the 'sgd' procedure
    """
    def __init__(self, latent_config):
        super(FullyConnectedLatentVariable, self).__init__(latent_config)
        self._construct(latent_config)

    def _construct(self, latent_config):
        """
        Method to construct the latent variable from the latent_config dictionary.

        Builds the approximate posterior output layers ('direct', 'gradient',
        'error'), the sigmoid gating layers ('gradient', 'error'), the prior
        output layers, and the approximate posterior / prior distributions.
        """
        self.inference_procedure = latent_config['inference_procedure']
        if self.inference_procedure in ['gradient', 'error']:
            self.update_type = latent_config['update_type']
        n_variables = latent_config['n_variables']
        n_inputs = latent_config['n_in']
        self.normalize_samples = latent_config['normalize_samples']
        if self.normalize_samples:
            self.normalizer = LayerNorm()

        if self.inference_procedure in ['direct', 'gradient', 'error']:
            # approximate posterior outputs, computed from n_inputs[0] features
            self.inf_mean_output = FullyConnectedLayer({
                'n_in': n_inputs[0],
                'n_out': n_variables
            })
            self.inf_log_var_output = FullyConnectedLayer({
                'n_in': n_inputs[0],
                'n_out': n_variables
            })
        if self.inference_procedure in ['gradient', 'error']:
            # sigmoid gates: interpolation weights for 'highway' updates,
            # learned step sizes for 'learned_sgd' updates
            self.approx_post_mean_gate = FullyConnectedLayer({
                'n_in': n_inputs[0],
                'n_out': n_variables,
                'non_linearity': 'sigmoid'
            })
            self.approx_post_log_var_gate = FullyConnectedLayer({
                'n_in': n_inputs[0],
                'n_out': n_variables,
                'non_linearity': 'sigmoid'
            })
            # close_gates() can be called to bias both gates toward 1; it is
            # not applied by default.
        if self.inference_procedure == 'sgd':
            self.learning_rate = latent_config['inf_lr']

        # prior outputs, computed from n_inputs[1] features
        self.prior_mean = FullyConnectedLayer({
            'n_in': n_inputs[1],
            'n_out': n_variables
        })
        self.prior_log_var = FullyConnectedLayer({
            'n_in': n_inputs[1],
            'n_out': n_variables
        })
        # distributions
        self.approx_post = Normal()
        self.prior = Normal()
        self.approx_post.re_init()
        self.prior.re_init()

    def infer(self, input):
        """
        Method to perform inference, i.e. update the approximate posterior.

        Args:
            input (Tensor): input to the inference procedure. For 'sgd' it is
                indexed as input[0] / input[1], the update directions for the
                mean / log variance (presumably their gradients — confirm).

        Raises:
            NotImplementedError: for an unknown inference procedure or an
                unknown update type.
        """
        if self.inference_procedure in ['direct', 'gradient', 'error']:
            approx_post_mean = self.inf_mean_output(input)
            approx_post_log_var = self.inf_log_var_output(input)
        if self.inference_procedure == 'direct':
            self.approx_post.mean = approx_post_mean
            self.approx_post.log_var = torch.clamp(approx_post_log_var, -15, 5)
        elif self.inference_procedure in ['gradient', 'error']:
            if self.update_type == 'highway':
                # gated highway update: gate * previous + (1 - gate) * new
                approx_post_mean_gate = self.approx_post_mean_gate(input)
                self.approx_post.mean = (
                    approx_post_mean_gate * self.approx_post.mean.detach() +
                    (1 - approx_post_mean_gate) * approx_post_mean)
                approx_post_log_var_gate = self.approx_post_log_var_gate(input)
                self.approx_post.log_var = torch.clamp(
                    approx_post_log_var_gate * self.approx_post.log_var.detach() +
                    (1 - approx_post_log_var_gate) * approx_post_log_var, -15, 5)
            elif self.update_type == 'learned_sgd':
                # SGD style update with learned learning rate and offset
                mean_grad, log_var_grad = self.approx_posterior_gradients()
                mean_lr = self.approx_post_mean_gate(input)
                log_var_lr = self.approx_post_log_var_gate(input)
                self.approx_post.mean = (self.approx_post.mean.detach() -
                                         mean_lr * mean_grad + approx_post_mean)
                self.approx_post.log_var = torch.clamp(
                    self.approx_post.log_var.detach() -
                    log_var_lr * log_var_grad + approx_post_log_var, -15, 5)
            else:
                raise NotImplementedError(
                    'Unknown update_type: ' + str(self.update_type))
        elif self.inference_procedure == 'sgd':
            # plain SGD step with a fixed learning rate
            self.approx_post.mean = (self.approx_post.mean.detach() -
                                     self.learning_rate * input[0])
            self.approx_post.log_var = torch.clamp(
                self.approx_post.log_var.detach() -
                self.learning_rate * input[1], -15, 5)
            self.approx_post.mean.requires_grad = True
            self.approx_post.log_var.requires_grad = True
        else:
            raise NotImplementedError

        if self.normalize_samples:
            # apply layer normalization to the approximate posterior means
            self.approx_post.mean = self.normalizer(self.approx_post.mean)

        # retain the gradients (read back by approx_posterior_gradients)
        self.approx_post.mean.retain_grad()
        self.approx_post.log_var.retain_grad()

    def generate(self, input, gen, n_samples):
        """
        Method to generate, i.e. run the model forward.

        Args:
            input (Tensor): input to the generative procedure, shaped
                (batch, steps, features); if None the prior is left unchanged
            gen (boolean): whether to sample from approximate posterior (False)
                or the prior (True)
            n_samples (int): number of samples to draw

        Returns:
            Tensor: the sample, detached when self.detach is set
        """
        if input is not None:
            b, s, n = input.data.shape
            input = input.view(b * s, n)
            self.prior.mean = self.prior_mean(input).view(b, s, -1)
            self.prior.log_var = torch.clamp(
                self.prior_log_var(input).view(b, s, -1), -15, 5)
        dist = self.prior if gen else self.approx_post
        sample = dist.sample(n_samples, resample=True)
        sample = sample.detach() if self.detach else sample
        return sample

    def re_init(self):
        """
        Method to reinitialize the approximate posterior and prior over the variable.
        """
        # TODO: this is wrong. we shouldnt set the posterior to the prior then zero out the prior...
        self.re_init_approx_posterior()
        self.prior.re_init()

    def re_init_approx_posterior(self):
        """
        Method to reinitialize the approximate posterior from the current
        prior, averaging the prior parameters over dimension 1.
        """
        mean = self.prior.mean.detach().mean(dim=1).data
        log_var = self.prior.log_var.detach().mean(dim=1).data
        self.approx_post.re_init(mean, log_var)

    def step(self):
        """
        Method to step the latent variable forward in the sequence (no-op).
        """
        pass

    def error(self, averaged=True):
        """
        Calculates the Gaussian (precision-weighted) error of an approximate
        posterior sample relative to the detached prior.

        Args:
            averaged (boolean): whether or not to average over samples (dim 1)

        Returns:
            Tensor: (sample - prior mean) / prior variance
        """
        sample = self.approx_post.sample()
        n_samples = sample.data.shape[1]
        prior_mean = self.prior.mean.detach()
        if len(prior_mean.data.shape) == 2:
            # add a sample dimension to match the sample
            prior_mean = prior_mean.unsqueeze(1).repeat(1, n_samples, 1)
        prior_log_var = self.prior.log_var.detach()
        if len(prior_log_var.data.shape) == 2:
            prior_log_var = prior_log_var.unsqueeze(1).repeat(1, n_samples, 1)
        # the epsilon guards the variance, not the log variance
        n_error = (sample - prior_mean) / (torch.exp(prior_log_var) + 1e-7)
        if averaged:
            n_error = n_error.mean(dim=1)
        return n_error

    def close_gates(self):
        """
        Sets the gate biases to 5, pushing the sigmoid gates toward 1.
        Only valid for the 'gradient' / 'error' procedures (gates exist).
        """
        # prefer the in-place initializer; fall back for older PyTorch
        init_constant = getattr(nn.init, 'constant_', None) or nn.init.constant
        init_constant(self.approx_post_mean_gate.linear.bias, 5.)
        init_constant(self.approx_post_log_var_gate.linear.bias, 5.)

    def inference_parameters(self):
        """
        Method to obtain inference parameters.

        Returns:
            nn.ParameterList: the posterior output layer parameters and, for
            the 'gradient' / 'error' procedures, the gate parameters. Empty
            for 'sgd', which has no learned inference layers.
        """
        params = nn.ParameterList()
        if self.inference_procedure in ['direct', 'gradient', 'error']:
            params.extend(list(self.inf_mean_output.parameters()))
            params.extend(list(self.inf_log_var_output.parameters()))
        if self.inference_procedure in ['gradient', 'error']:
            params.extend(list(self.approx_post_mean_gate.parameters()))
            params.extend(list(self.approx_post_log_var_gate.parameters()))
        return params

    def generative_parameters(self):
        """
        Method to obtain generative (prior) parameters.
        """
        params = nn.ParameterList()
        params.extend(list(self.prior_mean.parameters()))
        params.extend(list(self.prior_log_var.parameters()))
        return params

    def approx_posterior_parameters(self):
        """
        Returns the detached approximate posterior [mean, log_var].
        """
        return [
            self.approx_post.mean.detach(),
            self.approx_post.log_var.detach()
        ]

    def approx_posterior_gradients(self):
        """
        Returns the detached gradients [d mean, d log_var] retained in infer().
        """
        assert self.approx_post.mean.grad is not None, 'Approximate posterior gradients are None.'
        assert self.approx_post.log_var.grad is not None, 'Approximate posterior log variance gradients are None.'
        grads = [self.approx_post.mean.grad.detach()]
        grads += [self.approx_post.log_var.grad.detach()]
        for grad in grads:
            # legacy (pre-0.4 PyTorch) flag
            grad.volatile = False
        return grads
# Example #2
class SVG(LatentVariableModel):
    """
    Stochastic video generation (SVG) model from "Stochastic Video Generation
    with a Learned Prior," Denton & Fergus, 2018.

    Args:
        model_config (dict): dictionary containing model configuration params
    """
    def __init__(self, model_config):
        super(SVG, self).__init__(model_config)
        self._construct(model_config)

    def _construct(self, model_config):
        """
        Method for constructing SVG model using the model configuration file.

        Args:
            model_config (dict): dictionary containing model configuration
                params. Always reads 'model_type', 'modified',
                'inference_procedure' and 'normalize_latent_samples'; the
                'kth_actions' / 'bair_robot_pushing' types also read
                'global_output_log_var' and, in modified mode,
                'concat_observation' / 'update_type' as needed.

        Raises:
            ValueError: if 'model_type' is not supported.
            NotImplementedError: if modified mode uses an unsupported
                inference procedure.
        """
        model_type = model_config['model_type'].lower()
        self.modified = model_config['modified']
        self.inference_procedure = model_config['inference_procedure'].lower()

        level_config = {}
        latent_config = {}
        latent_config['normalize_samples'] = model_config[
            'normalize_latent_samples']
        latent_config['inference_procedure'] = self.inference_procedure
        # hard coded because we handle inference here in the model
        level_config['inference_procedure'] = 'direct'

        if not self.modified:
            level_config['inference_config'] = {
                'n_layers': 1,
                'n_units': 256,
                'n_in': 128
            }
            # number of encoder, decoder units
            latent_config['n_in'] = (256, 256)
        else:
            level_config['inference_config'] = None
            # number of encoder, decoder units; the encoder entry is filled in
            # when the modified inference model is built
            latent_config['n_in'] = [None, 256]
        level_config['generative_config'] = None

        if model_type == 'sm_mnist':
            from lib.modules.networks.dcgan_64 import encoder, decoder
            self.n_input_channels = 1
            self.encoder = encoder(128, self.n_input_channels)
            self.decoder = decoder(128, self.n_input_channels)
            self.output_dist = Bernoulli()
            latent_config['n_variables'] = 10
            # NOTE(review): no modified inference model is built for sm_mnist
            # (latent_config['n_in'][0] stays None); only the procedure name
            # is validated here.
            if self.modified and self.inference_procedure not in [
                    'direct', 'gradient', 'error'
            ]:
                raise NotImplementedError

        elif model_type in ['kth_actions', 'bair_robot_pushing']:
            from lib.modules.networks.vgg_64 import encoder, decoder
            if model_type == 'kth_actions':
                self.n_input_channels = 1
                n_variables = 512
            else:
                self.n_input_channels = 3
                n_variables = 64
            self.encoder = encoder(128, self.n_input_channels)
            if model_config['global_output_log_var']:
                # one learned log variance per pixel, repeated over the batch
                # and sample dimensions in generate()
                output_channels = self.n_input_channels
                self.output_log_var = nn.Parameter(
                    torch.zeros(self.n_input_channels, 64, 64))
            else:
                # the decoder outputs both a mean and a log variance channel
                output_channels = 2 * self.n_input_channels
            self.decoder = decoder(128, output_channels)
            self.output_dist = Normal()
            latent_config['n_variables'] = n_variables
            if self.modified:
                self._construct_modified_inference(model_type, model_config,
                                                   latent_config, encoder)
        else:
            raise ValueError(
                'SVG model type must be one of 1) sm_mnist, 2) kth_actions, '
                'or 3) bair_robot_pushing. Invalid model type: ' +
                model_type + '.')

        # construct a recurrent latent level
        level_config['latent_config'] = latent_config
        self.latent_levels = nn.ModuleList([LSTMLatentLevel(level_config)])

        self.prior_lstm = LSTMNetwork({
            'n_layers': 1,
            'n_units': 256,
            'n_in': 128
        })

        self.decoder_lstm = LSTMNetwork({
            'n_layers': 2,
            'n_units': 256,
            'n_in': 128 + latent_config['n_variables']
        })
        self.decoder_lstm_output = FullyConnectedLayer({
            'n_in': 256,
            'n_out': 128,
            'non_linearity': 'tanh'
        })
        self.output_interval = 1. / 256

    def _construct_modified_inference(self, model_type, model_config,
                                      latent_config, encoder):
        """
        Builds the separate inference networks used in modified mode for the
        VGG-based model types and stores their output size in
        latent_config['n_in'][0].

        Args:
            model_type (str): 'kth_actions' or 'bair_robot_pushing'
            model_config (dict): model configuration params
            latent_config (dict): latent variable config, updated in place
            encoder (callable): the VGG encoder constructor

        Raises:
            NotImplementedError: for an unsupported inference procedure.
        """
        n_variables = latent_config['n_variables']
        if self.inference_procedure == 'direct':
            # another convolutional encoder
            self.inf_encoder = encoder(128, self.n_input_channels)
            # fully-connected inference model
            self.inf_model = FullyConnectedNetwork({
                'n_layers': 2,
                'n_units': 256,
                'n_in': 128,
                'non_linearity': 'relu'
            })
            latent_config['n_in'][0] = 256
        elif self.inference_procedure == 'gradient':
            # fully-connected encoder / latent inference model over the
            # concatenated gradients (2 * n_variables) and parameters
            # (2 * n_variables)
            if model_type == 'kth_actions':
                inf_config = {
                    'n_layers': 1,
                    'n_units': 1024,
                    'n_in': 4 * n_variables,
                    'non_linearity': 'elu',
                    'connection_type': 'highway'
                }
            else:
                inf_config = {
                    'n_layers': 3,
                    'n_units': 1024,
                    'n_in': 4 * n_variables,
                    'non_linearity': 'relu'
                }
            if model_config['concat_observation']:
                # the flattened observation is appended to the input
                inf_config['n_in'] += (self.n_input_channels * 64 * 64)
            self.inf_model = FullyConnectedNetwork(inf_config)
            latent_config['n_in'][0] = inf_config['n_units']
            latent_config['update_type'] = model_config['update_type']
        elif self.inference_procedure == 'error':
            # convolutional observation error encoder
            obs_error_enc_config = {
                'n_layers': 3,
                'n_filters': 64,
                'n_in': self.n_input_channels,
                'filter_size': 3,
                'non_linearity': 'relu'
            }
            if model_config['concat_observation']:
                # NOTE(review): _get_encoding_form appends the observation to
                # the fully-connected input, not to the observation error, so
                # obs_error_enc appears to receive only n_input_channels
                # channels — confirm.
                obs_error_enc_config['n_in'] += self.n_input_channels
            self.obs_error_enc = ConvolutionalNetwork(obs_error_enc_config)
            # fully-connected error encoder (latent error + params + encoded
            # observation errors)
            # NOTE(review): the actual input is the flattened obs_error_enc
            # output + latent error (n_variables) + params (2 * n_variables);
            # n_in = 4 * n_variables does not obviously match — confirm.
            inf_config = {
                'n_layers': 3,
                'n_units': 1024,
                'n_in': 4 * n_variables,
                'non_linearity': 'relu'
            }
            if model_config['concat_observation']:
                inf_config['n_in'] += (self.n_input_channels * 64 * 64)
            self.inf_model = FullyConnectedNetwork(inf_config)
            latent_config['n_in'][0] = 1024
            latent_config['update_type'] = model_config['update_type']
        else:
            raise NotImplementedError

    def _input_norm_dim(self):
        """
        Returns the dimension to normalize inference inputs over: 0 for
        'batch', 1 for 'layer', or None when input normalization is off.
        """
        norm_type = self.model_config['input_normalization']
        if norm_type == 'batch':
            return 0
        if norm_type == 'layer':
            return 1
        return None

    @staticmethod
    def _standardize(tensors, dim, eps):
        """
        Returns a new list with each tensor shifted to zero mean and scaled
        by (std + eps) along `dim`.
        """
        return [(t - t.mean(dim=dim, keepdim=True)) /
                (t.std(dim=dim, keepdim=True) + eps) for t in tensors]

    def _get_encoding_form(self, observation):
        """
        Gets the appropriate input form for the inference procedure.

        Args:
            observation (Variable, tensor): the input observation, an image
                batch (batch, channels, height, width)

        Returns:
            'direct': the centered observation.
            'gradient': a 2D tensor [gradients, parameters(, flattened
                centered observation)] concatenated along dim 1.
            'error': a tuple (normalized observation error, 2D tensor
                [latent error, parameters(, flattened centered observation)]).
        """
        if self.inference_procedure == 'direct':
            return observation - 0.5

        latent = self.latent_levels[0].latent

        if self.inference_procedure == 'gradient':
            grads = latent.approx_posterior_gradients()

            # normalization
            norm_dim = self._input_norm_dim()
            if norm_dim is not None:
                grads = self._standardize(grads, norm_dim, 1e-7)
            grads = torch.cat(grads, dim=1)

            # concatenate with the parameters
            params = latent.approx_posterior_parameters()
            if self.model_config['norm_parameters'] and norm_dim is not None:
                params = self._standardize(params, norm_dim, 1e-7)
            params = torch.cat(params, dim=1)

            grads_params = torch.cat([grads, params], dim=1)

            # concatenate with the observation, flattened to match the 2D
            # features (the inference model's n_in includes C * 64 * 64)
            if self.model_config['concat_observation']:
                flat_obs = (observation - 0.5).view(observation.size(0), -1)
                grads_params = torch.cat([grads_params, flat_obs], dim=1)

            return grads_params

        elif self.inference_procedure == 'error':
            # TODO: figure out proper normalization for observation error
            errors = self._standardize(
                [self._output_error(observation),
                 latent.error()], 0, 1e-5)
            # concatenate
            params = torch.cat(latent.approx_posterior_parameters(), dim=1)
            latent_error_params = torch.cat([errors[1], params], dim=1)
            if self.model_config['concat_observation']:
                flat_obs = (observation - 0.5).view(observation.size(0), -1)
                latent_error_params = torch.cat(
                    [latent_error_params, flat_obs], dim=1)
            return errors[0], latent_error_params
        else:
            raise NotImplementedError

    def _output_error(self, observation, averaged=True):
        """
        Calculates Gaussian (precision-weighted) observation error for encoding.

        Args:
            observation (tensor): observation to use for error calculation,
                (batch, channels, height, width)
            averaged (boolean): whether to average over the sample dimension

        Returns:
            Tensor: (observation - output mean) / (output variance + 1e-7)
        """
        # get the output mean and log variance
        output_mean = self.output_dist.mean.detach()
        output_log_var = self.output_dist.log_var.detach()
        # repeat the observation along the sample dimension
        n_samples = output_mean.data.shape[1]
        observation = observation.unsqueeze(1).repeat(1, n_samples, 1, 1, 1)
        # calculate the precision-weighted observation error
        n_error = (observation - output_mean) / (output_log_var.exp() + 1e-7)
        if averaged:
            # average along the sample dimension
            n_error = n_error.mean(dim=1)
        return n_error

    def infer(self, observation):
        """
        Method for perfoming inference of the approximate posterior over the
        latent variables.

        Args:
            observation (tensor): observation to infer latent variables from
        """
        if self.modified:
            if not self._obs_encoded:
                # encode the observation (to be used at the next time step)
                self._h, self._skip = self.encoder(observation - 0.5)
                self._obs_encoded = True

            enc = self._get_encoding_form(observation)

            if self.inference_procedure == 'direct':
                # separate encoder model
                enc_h, _ = self.inf_encoder(enc)
                enc_h = self.inf_model(enc_h)

            elif self.inference_procedure == 'gradient':
                # encode through the inference model
                enc_h = self.inf_model(enc)

            elif self.inference_procedure == 'error':
                # encode the error and flatten it
                enc_error = self.obs_error_enc(enc[0])
                enc_error = enc_error.view(enc_error.data.shape[0], -1)
                # concatenate the error with the rest of the terms
                enc = torch.cat([enc_error, enc[1]], dim=1)
                # encode through the inference model
                enc_h = self.inf_model(enc)

            self.latent_levels[0].infer(enc_h)

        else:
            observation = self._get_encoding_form(observation)
            self._h, self._skip = self.encoder(observation)
            self.latent_levels[0].infer(self._h)

    def generate(self, gen=False, n_samples=1):
        """
        Method for generating observations, i.e. running the generative model
        forward.

        Args:
            gen (boolean): whether to sample from prior or approximate posterior
            n_samples (int): number of samples to draw and evaluate

        Returns:
            Tensor: an output sample clamped to [0, 1]
        """
        # NOTE(review): prev_h is not repeated over samples and the output is
        # viewed as (batch * n_samples, 1, ...), so n_samples > 1 looks
        # unsupported — confirm before relying on it.
        batch_size = self._prev_h.data.shape[0]

        # get the previous h and skip
        prev_h = self._prev_h.unsqueeze(1)
        # NOTE(review): skip features are multiplied by 0, so the decoder
        # receives zeroed skip connections — confirm this is intentional.
        prev_skip = [
            0. * _prev_skip.repeat(n_samples, 1, 1, 1)
            for _prev_skip in self._prev_skip
        ]

        # detach prev_h and prev_skip if necessary
        if self._detach_h:
            prev_h = prev_h.detach()
            prev_skip = [_prev_skip.detach() for _prev_skip in prev_skip]

        # get the prior input, detach if necessary
        gen_input = self._gen_input
        gen_input = gen_input.detach() if self._detach_h else gen_input

        # sample the latent variables
        z = self.latent_levels[0].generate(gen_input.unsqueeze(1),
                                           gen=gen,
                                           n_samples=n_samples)

        # pass through the decoder
        decoder_input = torch.cat([z, prev_h],
                                  dim=2).view(batch_size * n_samples, -1)
        g = self.decoder_lstm(decoder_input, detach=self._detach_h)
        g = self.decoder_lstm_output(g)
        output = self.decoder([g, prev_skip])
        b, _, h, w = output.data.shape

        # get the output mean and log variance
        if self.model_config['global_output_log_var']:
            # repeat along batch and sample dimensions
            output = output.view(b, -1, self.n_input_channels, h, w)
            log_var = self.output_log_var.unsqueeze(0).unsqueeze(0).repeat(
                batch_size, n_samples, 1, 1, 1)
            self.output_dist.log_var = torch.clamp(log_var, min=-10)
        else:
            output = output.view(b, -1, 2 * self.n_input_channels, h, w)
            self.output_dist.log_var = torch.clamp(
                output[:, :, self.n_input_channels:, :, :], min=-10)

        self.output_dist.mean = output[:, :, :self.
                                       n_input_channels, :, :].sigmoid()

        return torch.clamp(self.output_dist.sample(), 0., 1.)

    def step(self):
        """
        Method for stepping the generative model forward one step in the sequence.
        """
        # TODO: set n_samples in a smart way
        # step the lstms and latent level
        self.latent_levels[0].step()
        self.prior_lstm.step()
        self.decoder_lstm.step()
        # copy over the hidden and skip variables
        self._prev_h = self._h
        self._prev_skip = self._skip
        # clear the current hidden and skip variables, set the flag
        self._h = self._skip = None
        self._obs_encoded = False
        # use the prior lstm to get generative model input
        # NOTE(review): re_init() feeds self._prev_h without unsqueeze(1);
        # confirm which input rank prior_lstm expects.
        self._gen_input = self.prior_lstm(self._prev_h.unsqueeze(1),
                                          detach=False)
        # set the prior and approximate posterior
        self.latent_levels[0].generate(self._gen_input.detach().unsqueeze(1),
                                       gen=True,
                                       n_samples=1)
        self.latent_levels[0].latent.re_init_approx_posterior()

    def re_init(self, input):
        """
        Method for reinitializing the state (distributions and hidden states).

        Args:
            input (Variable, Tensor): contains observation at t = -1
        """
        # TODO: set n_samples in a smart way
        # flag to encode the hidden state for later decoding
        self._obs_encoded = False
        # re-initialize the lstms and distributions
        self.latent_levels[0].re_init()
        self.prior_lstm.re_init(input)
        self.decoder_lstm.re_init(input)
        # clear the hidden state and skip
        self._h = self._skip = None
        # encode this input to set the previous h and skip
        self._prev_h, self._prev_skip = self.encoder(input - 0.5)
        # set the prior and approximate posterior
        self._gen_input = self.prior_lstm(self._prev_h, detach=False)
        self.latent_levels[0].generate(self._gen_input.unsqueeze(1),
                                       gen=True,
                                       n_samples=1)
        self.latent_levels[0].latent.re_init_approx_posterior()

    def inference_parameters(self):
        """
        Method for obtaining the inference parameters.

        Raises:
            NotImplementedError: in modified mode with an unknown procedure.
        """
        params = nn.ParameterList()
        if self.modified:
            params.extend(list(self.inf_model.parameters()))
            if self.inference_procedure == 'direct':
                params.extend(list(self.inf_encoder.parameters()))
            elif self.inference_procedure == 'gradient':
                pass  # no other inference parameters
            elif self.inference_procedure == 'error':
                params.extend(list(self.obs_error_enc.parameters()))
            else:
                raise NotImplementedError
        else:
            # the unmodified model infers through the shared encoder
            params.extend(list(self.encoder.parameters()))
        params.extend(list(self.latent_levels[0].inference_parameters()))
        return params

    def generative_parameters(self):
        """
        Method for obtaining the generative parameters.
        """
        params = nn.ParameterList()
        if self.modified:
            # in modified mode the encoder only feeds the generative model
            params.extend(list(self.encoder.parameters()))
        params.extend(list(self.prior_lstm.parameters()))
        params.extend(list(self.decoder.parameters()))
        params.extend(list(self.latent_levels[0].generative_parameters()))
        params.extend(list(self.decoder_lstm.parameters()))
        params.extend(list(self.decoder_lstm_output.parameters()))
        if self.model_config['global_output_log_var']:
            params.append(self.output_log_var)
        return params

    def inference_mode(self):
        """
        Method to set the model's current mode to inference: latent samples
        keep their graph, hidden states are detached.
        """
        self.latent_levels[0].latent.detach = False
        self._detach_h = True

    def generative_mode(self):
        """
        Method to set the model's current mode to generation: latent samples
        are detached, hidden states keep their graph.
        """
        self.latent_levels[0].latent.detach = True
        self._detach_h = False
class ConvolutionalLatentVariable(LatentVariable):
    """
    A convolutional latent variable.

    Args:
        variable_config (dict): dictionary containing variable config parameters
                                ('n_channels' and 'filter_size' are read)

    NOTE(review): ``_construct`` only builds ``approx_posterior`` and ``prior``.
    The remaining methods also use ``n_variable_channels``, the layers
    ``posterior_mean``, ``posterior_mean_gate``, ``posterior_log_var``,
    ``posterior_log_var_gate``, ``prior_mean`` and ``prior_log_var``, and the
    distributions ``posterior`` and ``previous_posterior``. These must be
    created before those methods are used.
    """
    def __init__(self, variable_config):
        super(ConvolutionalLatentVariable, self).__init__()
        self.approx_posterior = self.prior = None
        self._construct(variable_config)

    def _construct(self, variable_config):
        """
        Constructs the latent variable according to the variable_config dict.
        Currently hard-coded to Gaussian distributions for both approximate
        posterior and prior.

        Args:
            variable_config (dict): dictionary containing variable config params
        """
        self.n_channels = variable_config['n_channels']
        self.filter_size = variable_config['filter_size']

        mean = Variable(torch.zeros(1, self.n_channels, 1, 1))
        std = Variable(torch.ones(1, self.n_channels, 1, 1))
        self.approx_posterior = Normal(mean, std)
        self.prior = Normal(mean, std)

    def infer(self, input, n_samples=1):
        """
        Method to perform inference.

        Sets the approximate posterior mean and log variance to gated layer
        outputs computed from ``input``. Each is the layer output multiplied
        by its gate's output.

        Args:
            input (Tensor): input to the inference procedure
            n_samples (int): number of approximate posterior samples to draw

        Returns:
            freshly drawn samples from the approximate posterior
        """
        mean_gate = self.posterior_mean_gate(input)
        self.posterior.mean = self.posterior_mean(input) * mean_gate
        log_var_gate = self.posterior_log_var_gate(input)
        self.posterior.log_var = self.posterior_log_var(input) * log_var_gate
        return self.posterior.sample(n_samples, resample=True)

    def generate(self, input, gen, n_samples):
        """
        Computes the prior from ``input`` and draws latent samples.

        Args:
            input (Tensor): 5D input, unpacked as (b, s, c, h, w); the first
                            two dims are flattened before the prior layers
            gen (boolean): sample from the prior if True, otherwise from the
                           approximate posterior
            n_samples (int): number of samples to draw

        Returns:
            samples from the prior (``gen``) or the approximate posterior
        """
        b, s, c, h, w = input.data.shape
        input = input.view(-1, c, h, w)
        self.prior.mean = self.prior_mean(input).view(b, s, -1, h, w)
        # without a log-variance layer the prior log variance is a trainable
        # constant (see reset_prior), so it is left untouched here
        if self.prior_log_var is not None:
            self.prior.log_var = self.prior_log_var(input).view(b, s, -1, h, w)
        if gen:
            return self.prior.sample(n_samples, resample=True)
        return self.posterior.sample(n_samples, resample=True)

    def step(self):
        """
        Copies the current approximate posterior (detached) into the previous
        approximate posterior.
        """
        self.previous_posterior.mean = self.posterior.mean.detach()
        self.previous_posterior.log_var = self.posterior.log_var.detach()

    def error(self, averaged=True, weighted=False):
        """
        Computes the error of an approximate posterior sample relative to the
        (detached) prior mean.

        Args:
            averaged (boolean): average the error over the sample dimension (1)
            weighted (boolean): divide the error by the prior variance

        Returns:
            the (optionally weighted and averaged) error tensor
        """
        sample = self.posterior.sample()
        # samples live on dim 1, matching the mean(dim=1) average below
        n_samples = sample.data.shape[1]
        prior_mean = self.prior.mean.detach()
        err = sample - prior_mean[:, :n_samples]
        if weighted:
            prior_log_var = self.prior.log_var.detach()
            # weight by the variance; eps keeps the division bounded
            err = err / (torch.exp(prior_log_var) + 1e-7)
        if averaged:
            err = err.mean(dim=1)
        return err

    def reset_approx_posterior(self):
        """
        Resets the approximate posterior to the prior mean and log variance,
        averaged over dim 1.
        """
        mean = self.prior.mean.data.clone().mean(dim=1)
        log_var = self.prior.log_var.data.clone().mean(dim=1)
        self.posterior.reset(mean, log_var)

    def reset_prior(self):
        """
        Resets the prior; with no log-variance layer the prior log variance
        is made trainable again.
        """
        self.prior.reset()
        if self.prior_log_var is None:
            self.prior.log_var_trainable()

    def reinitialize_variable(self, output_dims):
        """
        Resets the previous approximate posterior and prior and tiles their
        parameters to shape (b, 1, n_variable_channels, h, w).

        Args:
            output_dims (tuple): 4-tuple unpacked as (b, _, h, w)
        """
        b, _, h, w = output_dims
        self.previous_posterior.reset()
        self.previous_posterior.mean = self.previous_posterior.mean.view(
            1, 1, 1, 1, 1).repeat(b, 1, self.n_variable_channels, h, w)
        self.previous_posterior.log_var = self.previous_posterior.log_var.view(
            1, 1, 1, 1, 1).repeat(b, 1, self.n_variable_channels, h, w)
        self.prior.reset()
        self.prior.mean = self.prior.mean.view(1, 1, 1, 1, 1).repeat(
            b, 1, self.n_variable_channels, h, w)
        self.prior.log_var = self.prior.log_var.view(1, 1, 1, 1, 1).repeat(
            b, 1, self.n_variable_channels, h, w)

    def inference_model_parameters(self):
        """
        Returns a list of the parameters of the approximate posterior layers
        and their gates.
        """
        inference_params = []
        inference_params.extend(list(self.posterior_mean.parameters()))
        inference_params.extend(list(self.posterior_mean_gate.parameters()))
        inference_params.extend(list(self.posterior_log_var.parameters()))
        inference_params.extend(list(self.posterior_log_var_gate.parameters()))
        return inference_params

    def generative_model_parameters(self):
        """
        Returns a list of the prior layers' parameters.

        The prior log variance tensor itself is used when there is no
        log-variance layer.
        """
        generative_params = []
        generative_params.extend(list(self.prior_mean.parameters()))
        if self.prior_log_var is not None:
            generative_params.extend(list(self.prior_log_var.parameters()))
        else:
            generative_params.append(self.prior.log_var)
        return generative_params

    def approx_posterior_parameters(self):
        """
        Returns the detached approximate posterior [mean, log_var].
        """
        return [self.posterior.mean.detach(), self.posterior.log_var.detach()]

    def approx_posterior_gradients(self):
        """
        Returns the detached gradients of the approximate posterior
        [mean, log_var].

        Raises:
            AssertionError: if the mean gradient is None
        """
        assert self.posterior.mean.grad is not None, 'Approximate posterior gradients are None.'
        grads = [self.posterior.mean.grad.detach()]
        grads += [self.posterior.log_var.grad.detach()]
        for grad in grads:
            # appears to be a legacy (pre-0.4 PyTorch) flag; kept as-is
            grad.volatile = False
        return grads
class VRNN(LatentVariableModel):
    """
    Variational recurrent neural network (VRNN) from "A Recurrent Latent
    Variable Model for Sequential Data," Chung et al., 2015.

    Args:
        model_config (dict): dictionary containing model configuration params
    """
    def __init__(self, model_config):
        """
        Initializes the base model and then builds the VRNN submodules.

        Args:
            model_config (dict): dictionary containing model configuration params
        """
        base = super(VRNN, self)
        base.__init__(model_config)
        self._construct(model_config)

    def _construct(self, model_config):
        """
        Builds the VRNN submodules from ``model_config``: the LSTM, the x
        model, the latent level (encoder + prior), the z model, the decoder,
        and the output mean / log variance.

        Layer sizes are selected by ``model_type`` ('timit', 'blizzard',
        'iam_ondb' or 'bball'). In the modified model, the inference input
        size depends on ``inference_procedure``.

        NOTE(review): if LatentVariableModel is an nn.Module (the use of
        nn.ModuleList suggests so), the order of the attribute assignments
        below fixes the registration order of submodules and parameters.

        Args:
            model_config (dict): dictionary containing model configuration params.
                Keys read here: 'model_type', 'modified', 'inference_procedure',
                'normalize_latent_samples' and 'global_output_log_var'. In the
                modified model it also reads 'update_type' and
                'concat_observation' (gradient / error procedures), or
                'learning_rate' (any other procedure).

        Raises:
            AssertionError: if the original (unmodified) model is configured
                with an inference procedure other than 'direct'.
            Exception: if ``model_type`` is not a supported type.
        """
        model_type = model_config['model_type'].lower()
        self.modified = model_config['modified']
        self.inference_procedure = model_config['inference_procedure'].lower()
        if not self.modified:
            assert self.inference_procedure == 'direct', 'The original model only supports direct inference.'
        # whether generate() detaches the previous hidden state; toggled by
        # inference_mode() / generative_mode()
        self._detach_h = False
        latent_config = {}
        level_config = {}
        latent_config['inference_procedure'] = self.inference_procedure
        # hard coded because we handle inference here in the model
        level_config['inference_procedure'] = 'direct'

        if model_type == 'timit':
            lstm_units = 2000
            encoder_units = 500
            prior_units = 500
            decoder_units = 600
            x_units = 600
            z_units = 500
            hidden_layers = 4
            x_dim = 200
            z_dim = 200
            self.output_interval = 0.0018190742
        elif model_type == 'blizzard':
            lstm_units = 4000
            encoder_units = 500
            prior_units = 500
            decoder_units = 600
            x_units = 600
            z_units = 500
            hidden_layers = 4
            x_dim = 200
            z_dim = 200
            # TODO: check if this is correct
            self.output_interval = 0.0018190742
        elif model_type == 'iam_ondb':
            # NOTE(review): self.output_interval is not set for this model
            # type; confirm nothing reads it when model_type is iam_ondb.
            lstm_units = 1200
            encoder_units = 150
            prior_units = 150
            decoder_units = 250
            x_units = 250
            z_units = 150
            hidden_layers = 1
            x_dim = 3
            z_dim = 50
        elif model_type == 'bball':
            lstm_units = 1000
            encoder_units = 200
            prior_units = 200
            decoder_units = 200
            x_units = 200
            z_units = 200
            hidden_layers = 2
            x_dim = 2
            z_dim = 50
            # one interval per observation dimension (x_dim == 2); the
            # .cuda() call means construction requires a CUDA device
            self.output_interval = Variable(torch.from_numpy(
                np.array([1e-5 / 94., 1e-5 / 50.]).astype('float32')),
                                            requires_grad=False).cuda()
        else:
            # NOTE(review): the backslash continuations are inside the string
            # literal, so each continuation line's leading whitespace becomes
            # part of the error message.
            raise Exception('VRNN model type must be one of 1) timit, 2) \
                            blizzard, 3) iam_ondb, or 4) bball. Invalid model \
                            type: ' + model_type + '.')

        # LSTM
        # input is the concatenated x and z encodings (see step())
        lstm_config = {
            'n_layers': 1,
            'n_units': lstm_units,
            'n_in': x_units + z_units
        }
        self.lstm = LSTMNetwork(lstm_config)

        # x model
        x_config = {
            'n_in': x_dim,
            'n_units': x_units,
            'n_layers': hidden_layers,
            'non_linearity': 'relu'
        }
        self.x_model = FullyConnectedNetwork(x_config)

        # inf model
        if self.modified:
            if self.inference_procedure in ['direct', 'gradient', 'error']:
                # set the input encoding size
                if self.inference_procedure == 'direct':
                    # the raw observation (see _get_encoding_form)
                    input_dim = x_dim
                elif self.inference_procedure == 'gradient':
                    latent_config['update_type'] = model_config['update_type']
                    # presumably approx. posterior gradients + parameters
                    # (mean and log variance each) -- see _get_encoding_form
                    input_dim = 4 * z_dim
                    if model_config['concat_observation']:
                        input_dim += x_dim
                elif self.inference_procedure == 'error':
                    latent_config['update_type'] = model_config['update_type']
                    # presumably output error (x_dim) + latent error and
                    # approx. posterior mean / log variance (3 * z_dim)
                    input_dim = x_dim + 3 * z_dim
                    if model_config['concat_observation']:
                        input_dim += x_dim
                else:
                    # unreachable: guarded by the membership test above
                    raise NotImplementedError

                # overrides the per-model encoder size; this value is also
                # used for latent_config['n_in'] below
                encoder_units = 1024
                inf_config = {
                    'n_in': input_dim,
                    'n_units': encoder_units,
                    'n_layers': 2,
                    'non_linearity': 'elu'
                }
                inf_config['connection_type'] = 'highway'
                # presumably built by the latent level from
                # level_config['inference_config'] instead
                # self.inf_model = FullyConnectedNetwork(inf_config)
            else:
                # other procedures (e.g. 'sgd') have no inference network; the
                # latent variable is given a learning rate instead
                inf_config = None
                latent_config['inf_lr'] = model_config['learning_rate']
        else:
            # original model: the encoder sees the x encoding and the
            # previous LSTM hidden state (see infer())
            inf_input_units = lstm_units + x_units
            inf_config = {
                'n_in': inf_input_units,
                'n_units': encoder_units,
                'n_layers': hidden_layers,
                'non_linearity': 'relu'
            }

        # latent level (encoder model and prior model)
        level_config['inference_config'] = inf_config
        # the prior network is conditioned on the LSTM hidden state
        gen_config = {
            'n_in': lstm_units,
            'n_units': prior_units,
            'n_layers': hidden_layers,
            'non_linearity': 'relu'
        }
        level_config['generative_config'] = gen_config
        latent_config['n_variables'] = z_dim
        latent_config['n_in'] = (encoder_units, prior_units)
        latent_config['normalize_samples'] = model_config[
            'normalize_latent_samples']
        # latent_config['n_in'] = (encoder_units+input_dim, prior_units)
        level_config['latent_config'] = latent_config
        latent = FullyConnectedLatentLevel(level_config)
        self.latent_levels = nn.ModuleList([latent])

        # z model
        z_config = {
            'n_in': z_dim,
            'n_units': z_units,
            'n_layers': hidden_layers,
            'non_linearity': 'relu'
        }
        self.z_model = FullyConnectedNetwork(z_config)

        # decoder
        # input is the z encoding and the previous hidden state (see generate())
        decoder_config = {
            'n_in': lstm_units + z_units,
            'n_units': decoder_units,
            'n_layers': hidden_layers,
            'non_linearity': 'relu'
        }
        self.decoder_model = FullyConnectedNetwork(decoder_config)

        self.output_dist = Normal()
        self.output_mean = FullyConnectedLayer({
            'n_in': decoder_units,
            'n_out': x_dim
        })
        # either one learned log-variance vector shared over batch and samples
        # (repeated in generate()) or a per-step prediction from the decoder
        if model_config['global_output_log_var']:
            self.output_log_var = nn.Parameter(torch.zeros(x_dim))
        else:
            self.output_log_var = FullyConnectedLayer({
                'n_in': decoder_units,
                'n_out': x_dim
            })

    def _get_encoding_form(self, observation):
        """
        Gets the appropriate input form for the inference procedure.

        - 'direct': the observation itself.
        - 'gradient': normalized approximate posterior gradients, concatenated
          with the approximate posterior parameters (normalized if
          ``norm_parameters``) and, optionally, the observation.
        - 'error': normalized output and latent errors, concatenated with the
          approximate posterior parameters (normalized if ``norm_parameters``)
          and, optionally, the observation.
        - 'sgd': the raw list of approximate posterior gradients.

        Args:
            observation (Variable, tensor): the input observation

        Returns:
            the encoding (a tensor, or a list of tensors for 'sgd')

        Raises:
            NotImplementedError: for an unsupported inference procedure
        """
        latent = self.latent_levels[0].latent
        procedure = self.inference_procedure

        if procedure == 'direct':
            return observation
        if procedure == 'sgd':
            return latent.approx_posterior_gradients()

        if procedure == 'gradient':
            terms = latent.approx_posterior_gradients()
        elif procedure == 'error':
            terms = [self._output_error(observation), latent.error()]
        else:
            raise NotImplementedError

        # normalize the gradients / errors, then flatten along dim 1
        terms = torch.cat(self._normalize_encoding_inputs(terms), dim=1)

        # concatenate with the parameters
        params = latent.approx_posterior_parameters()
        if self.model_config['norm_parameters']:
            params = self._normalize_encoding_inputs(params)
        params = torch.cat(params, dim=1)

        encoding = torch.cat([terms, params], dim=1)

        # concatenate with the observation
        if self.model_config['concat_observation']:
            encoding = torch.cat([encoding, observation], dim=1)

        return encoding

    def _normalize_encoding_inputs(self, tensors):
        """
        Standardizes each tensor according to ``input_normalization``.

        'batch' normalizes over dim 0 and 'layer' over dim 1. Any other
        setting returns the tensors unchanged.

        Args:
            tensors (list): tensors to normalize

        Returns:
            a new list of (possibly) normalized tensors
        """
        normalization = self.model_config['input_normalization']
        if normalization not in ['layer', 'batch']:
            return list(tensors)
        norm_dim = 0 if normalization == 'batch' else 1
        normalized = []
        for tensor in tensors:
            mean = tensor.mean(dim=norm_dim, keepdim=True)
            std = tensor.std(dim=norm_dim, keepdim=True)
            normalized.append((tensor - mean) / (std + 1e-7))
        return normalized

    def _output_error(self, observation, averaged=True):
        """
        Calculates Gaussian error for encoding.

        The error is the difference between the observation and the (detached)
        output mean, divided by the output variance.

        Args:
            observation (tensor): observation to use for error calculation; a
                                  2D observation is repeated along a new
                                  sample dimension (dim 1) to match the output
            averaged (boolean): whether to average the error over dim 1

        Returns:
            the (optionally averaged) variance-weighted error tensor
        """
        output_mean = self.output_dist.mean.detach()
        output_log_var = self.output_dist.log_var.detach()
        n_samples = output_mean.data.shape[1]
        if len(observation.data.shape) == 2:
            observation = observation.unsqueeze(1).repeat(1, n_samples, 1)
        # eps is added to the variance: inside exp() it has no stabilizing effect
        variance = torch.exp(output_log_var) + 1e-7
        n_error = (observation - output_mean) / variance
        if averaged:
            n_error = n_error.mean(dim=1)
        return n_error

    def infer(self, observation):
        """
        Method for perfoming inference of the approximate posterior over the
        latent variables.

        The observation is always encoded by the x model. The modified model
        feeds the latent level the procedure-specific encoding. The original
        model feeds it the x encoding concatenated with the previous hidden
        state.

        Args:
            observation (tensor): observation to infer latent variables from
        """
        self._x_enc = self.x_model(observation)
        if not self.modified:
            encoding = torch.cat([self._x_enc, self._prev_h], dim=1)
        else:
            encoding = self._get_encoding_form(observation)
        self.latent_levels[0].infer(encoding)

    def generate(self, gen=False, n_samples=1):
        """
        Method for generating observations, i.e. running the generative model
        forward.

        The latent level generates from the previous hidden state. Its output
        is encoded by the z model, decoded together with the hidden state, and
        mapped to the output Gaussian (log variance clamped to [-20, 5]).

        Args:
            gen (boolean): whether to sample from prior or approximate posterior
            n_samples (int): number of samples to draw and evaluate

        Returns:
            a sample from the output distribution
        """
        # TODO: handle sampling dimension
        hidden = self._prev_h.unsqueeze(1)
        if self._detach_h:
            # set by inference_mode(): prevent backprop into the hidden state
            hidden = hidden.detach()

        z = self.latent_levels[0].generate(hidden, gen=gen, n_samples=n_samples)

        batch_size, n_draws, _ = z.data.shape
        flat_z = z.view(batch_size * n_draws, -1)
        self._z_enc = self.z_model(flat_z).view(batch_size, n_draws, -1)

        decoder_input = torch.cat(
            [self._z_enc, hidden.repeat(1, n_draws, 1)], dim=2)
        flat_input = decoder_input.view(batch_size * n_draws, -1)
        decoded = self.decoder_model(flat_input).view(batch_size, n_draws, -1)

        self.output_dist.mean = self.output_mean(decoded)
        if self.model_config['global_output_log_var']:
            raw_log_var = self.output_log_var.view(1, 1, -1).repeat(
                batch_size, n_draws, 1)
        else:
            raw_log_var = self.output_log_var(decoded)
        self.output_dist.log_var = torch.clamp(raw_log_var, min=-20., max=5)

        return self.output_dist.sample()

    def step(self, n_samples=1):
        """
        Method for stepping the generative model forward one step in the sequence.

        Feeds the x encoding and the first z encoding sample into the LSTM,
        then computes the new prior and re-initializes the approximate
        posterior from it.

        Args:
            n_samples (int): number of samples passed to the latent level
        """
        # TODO: handle sampling dimension
        lstm_input = torch.cat([self._x_enc, self._z_enc[:, 0]], dim=1)
        self._prev_h = self.lstm(lstm_input)
        hidden = self._prev_h.unsqueeze(1)
        self.lstm.step()
        level = self.latent_levels[0]
        level.generate(hidden, gen=True, n_samples=n_samples)
        level.latent.re_init_approx_posterior()

    def re_init(self, input):
        """
        Method for reinitializing the state (approximate posterior and priors)
        of the dynamical latent variable model.

        Args:
            input: passed through to the LSTM's ``re_init``
        """
        self.lstm.re_init(input)
        # the first LSTM layer's hidden state becomes the previous hidden state
        self._prev_h = self.lstm.layers[0].hidden_state
        level = self.latent_levels[0]
        # compute the prior (with a sample dimension added) and initialize the
        # approximate posterior from it
        level.generate(self._prev_h.unsqueeze(1), gen=True, n_samples=1)
        level.latent.re_init_approx_posterior()

    def inference_parameters(self):
        """
        Method for obtaining the inference parameters.

        Returns:
            nn.ParameterList with the latent level's inference parameters,
            or an empty list for the 'sgd' procedure
        """
        params = nn.ParameterList()
        if self.inference_procedure == 'sgd':
            return params
        params.extend(list(self.latent_levels[0].inference_parameters()))
        return params

    def generative_parameters(self):
        """
        Method for obtaining the generative parameters.

        Collects, in order: the LSTM, the latent level's generative
        parameters, the x model, the z model, the decoder, the output mean
        layer, and the output log variance (the global parameter itself, or
        the layer's parameters).

        Returns:
            nn.ParameterList of generative parameters
        """
        params = nn.ParameterList()
        params.extend(list(self.lstm.parameters()))
        params.extend(list(self.latent_levels[0].generative_parameters()))
        for module in (self.x_model, self.z_model, self.decoder_model,
                       self.output_mean):
            params.extend(list(module.parameters()))
        if not self.model_config['global_output_log_var']:
            params.extend(list(self.output_log_var.parameters()))
        else:
            params.append(self.output_log_var)
        return params

    def inference_mode(self):
        """
        Switch the model into inference mode.

        Clears the latent variable's ``detach`` flag and makes ``generate``
        detach the previous hidden state.
        """
        latent = self.latent_levels[0].latent
        latent.detach = False
        self._detach_h = True

    def generative_mode(self):
        """
        Switch the model into generation mode.

        Sets the latent variable's ``detach`` flag and stops ``generate``
        from detaching the previous hidden state.
        """
        latent = self.latent_levels[0].latent
        latent.detach = True
        self._detach_h = False