Example #1
0
    def decode(self,
               zs,
               sigma_scale,
               latent_id=None,
               reuse=False,
               is_training=True):
        """Decode the (single) latent code back to observation space.

        Args:
            zs: list of latent codes; only zs[0] is used in this
                single-layer variant.
            sigma_scale: scalar multiplier applied to the decoder Sigma
                before sampling (gauss decoder only).
            latent_id: unused in this single-layer variant; kept for
                interface compatibility with the multi-layer decode.
            reuse: whether to reuse the decoder variables.
            is_training: training-mode flag forwarded to the decoder net.

        Returns:
            ([x], [mean], [Sigma]): single-element lists holding the
            reconstruction and the decoder's Gaussian parameters.
        """
        # Output has the full data shape of the configured dataset.
        output_dim = datashapes[self.opts['dataset']][:-1] + [
            datashapes[self.opts['dataset']][-1],
        ]
        z = zs[0]
        zshape = z.get_shape().as_list()[1:]
        if len(zshape) > 1:
            # codes are [batch, nresamples, zdim]; fold the resampling
            # axis into the batch so the decoder sees [-1, zdim]
            z = tf.squeeze(tf.concat(tf.split(z, zshape[0], 1), axis=0), [1])
        mean, Sigma = Decoder(self.opts,
                              input=z,
                              archi=self.opts['archi'][0],
                              nlayers=self.opts['nlayers'][0],
                              nfilters=self.opts['nfilters'][0],
                              filters_size=self.opts['filters_size'][0],
                              output_dim=output_dim,
                              upsample=self.opts['upsample'],
                              output_layer=self.opts['output_layer'][0],
                              scope='decoder/layer_0',
                              reuse=reuse,
                              is_training=is_training)
        # restore the resampling axis: [-1, nresamples, output_dim]
        if len(zshape) > 1:
            mean = tf.stack(tf.split(mean, zshape[0]), axis=1)
            Sigma = tf.stack(tf.split(Sigma, zshape[0]), axis=1)
        # - resampling reconstructed
        if self.opts['decoder'][0] == 'det':
            # - deterministic decoder: the reconstruction is the mean
            x = mean
            if self.opts['use_sigmoid']:
                x = tf.compat.v1.sigmoid(x)
        elif self.opts['decoder'][0] == 'gauss':
            # - gaussian decoder: sample from N(mean, sigma_scale*Sigma)
            p_params = tf.concat((mean, sigma_scale * Sigma), axis=-1)
            x = sample_gaussian(self.opts, p_params, 'tensorflow')
        else:
            # bug fix: `idx` was undefined here (NameError on failure);
            # this variant always uses index 0, and it is a decoder.
            assert False, 'Unknown decoder %s' % self.opts['decoder'][0]

        return [
            x,
        ], [
            mean,
        ], [
            Sigma,
        ]
Example #2
0
    def layerwise_agg_kl(self, inputs, sigma_scale):
        """Compute layer-wise aggregated KL terms between q and p.

        For each latent layer the p-side density is estimated from the
        decoder of the layer above (the fixed prior for the deepest
        layer) and the q-side from the encoder, both aggregated over
        the batch with a log-sum-exp.

        Args:
            inputs: observations fed to the forward pass.
            sigma_scale: Sigma multiplier forwarded to the forward pass.

        Returns:
            kls: list with one scalar KL estimate per latent layer.

        NOTE(review): each term is E[logp - logq]; verify this sign
        convention matches the intended KL(q || p) direction.
        """

        def _agg_logprob(samples, means, Sigmas):
            # Per-dimension log-densities of each sample under each
            # component, summed over the dimension axis, then
            # log-sum-exp'd over the component (aggregation) axis.
            log_prob = log_normal(tf.expand_dims(samples, 1),
                                  tf.expand_dims(means, 0),
                                  tf.expand_dims(Sigmas, 0))
            return tf.reduce_logsumexp(tf.reduce_sum(log_prob,
                                                     axis=2,
                                                     keepdims=False),
                                       axis=1,
                                       keepdims=False)

        zs, enc_means, enc_Sigmas, xs, dec_means, dec_Sigmas = self.forward_pass(
            inputs, sigma_scale, False, reuse=True, is_training=False)
        kls = []
        # latent layers up to N-1: p-side uses the decoder of layer n+1
        for n in range(len(enc_means) - 1):
            logp = _agg_logprob(xs[n + 1], dec_means[n + 1], dec_Sigmas[n + 1])
            logq = _agg_logprob(zs[n], enc_means[n], enc_Sigmas[n])
            kls.append(tf.reduce_mean(logp - logq))
        # deepest layer: p-side is the fixed prior pz
        pz_mean, pz_Sigma = tf.split(self.pz_params, 2, axis=-1)
        pz_samples = sample_gaussian(self.opts, self.pz_params, 'numpy',
                                     self.opts['batch_size'])
        logp = _agg_logprob(pz_samples,
                            tf.expand_dims(pz_mean, axis=0),
                            tf.expand_dims(pz_Sigma, axis=0))
        logq = _agg_logprob(zs[-1], enc_means[-1], enc_Sigmas[-1])
        kls.append(tf.reduce_mean(logp - logq))

        return kls
Example #3
0
    def encode(self,
               inputs,
               sigma_scale,
               resample,
               nresamples=1,
               reuse=False,
               is_training=True):
        """Encode inputs through the full stack of latent layers.

        Args:
            inputs: observations fed to the first encoder layer.
            sigma_scale: multiplier applied to Sigma when resampling.
            resample: if True, feed the previous layer's mean forward
                and draw `nresamples` codes per input (visualisation).
            nresamples: number of codes per input when resampling.
            reuse: whether to reuse the encoder variables.
            is_training: training-mode flag forwarded to the encoders.

        Returns:
            (zs, means, Sigmas): per-layer lists of sampled codes and
            the encoders' Gaussian parameters.
        """
        zs, means, Sigmas = [], [], []
        for n in range(self.opts['nlatents']):
            if n == 0:
                input = inputs
            elif resample:
                # when resampling for vizu, we just pass the mean
                input = means[-1]
            else:
                input = zs[-1]
            mean, Sigma = Encoder(self.opts,
                                  input=input,
                                  archi=self.opts['archi'][n],
                                  nlayers=self.opts['nlayers'][n],
                                  nfilters=self.opts['nfilters'][n],
                                  filters_size=self.opts['filters_size'][n],
                                  output_dim=self.opts['zdim'][n],
                                  downsample=self.opts['upsample'],
                                  output_layer=self.opts['output_layer'][n],
                                  scope='encoder/layer_%d' % (n + 1),
                                  reuse=reuse,
                                  is_training=is_training)
            if self.opts['encoder'][n] == 'det':
                # - deterministic encoder: the code is the mean
                z = mean
            elif self.opts['encoder'][n] == 'gauss':
                # - gaussian encoder: sample from N(mean, Sigma)
                if resample:
                    q_params = tf.concat((mean, sigma_scale * Sigma), axis=-1)
                    # duplicate params to draw nresamples codes per input
                    q_params = tf.stack([q_params for i in range(nresamples)],
                                        axis=1)
                else:
                    q_params = tf.concat((mean, Sigma), axis=-1)
                z = sample_gaussian(self.opts, q_params, 'tensorflow')
            else:
                # bug fix: report the offending layer's setting rather
                # than the whole per-layer list
                assert False, 'Unknown encoder %s' % self.opts['encoder'][n]
            zs.append(z)
            means.append(mean)
            Sigmas.append(Sigma)

        return zs, means, Sigmas
Example #4
0
 def losses(self, inputs, sigma_scale, reuse=False, is_training=True):
     """Assemble the training losses of the stacked WAE."""
     # Full encode/decode pass, no resampling.
     zs, enc_means, enc_Sigmas, xs, dec_means, dec_Sigmas = self.forward_pass(
         inputs, sigma_scale, False, reuse=reuse, is_training=is_training)
     # Reconstruction cost on the observed data.
     obs_cost = obs_reconstruction_loss(self.opts, inputs, xs[0])
     # Per-layer costs pairing each decoder output with the code below it.
     latent_cost = self.latent_cost(xs[1:], dec_means[1:], dec_Sigmas[1:],
                                    zs[:-1], enc_means[:-1], enc_Sigmas[:-1])
     # Match the deepest codes against samples drawn from the prior.
     prior_samples = sample_gaussian(self.opts, self.pz_params, 'numpy',
                                     self.opts['batch_size'])
     matching_penalty = latent_penalty(self.opts, zs[-1], prior_samples)
     # Sigma regularizers; decoder Sigmas skip the observation layer.
     enc_Sigma_penalty = self.Sigma_penalty(enc_Sigmas)
     dec_Sigma_penalty = self.Sigma_penalty(dec_Sigmas[1:])
     return obs_cost, latent_cost, matching_penalty, enc_Sigma_penalty, dec_Sigma_penalty
Example #5
0
    def encode(self,
               inputs,
               sigma_scale,
               resample,
               nresamples=1,
               reuse=False,
               is_training=True):
        """Encode inputs through the single latent layer.

        Args:
            inputs: observations fed to the encoder.
            sigma_scale: multiplier applied to Sigma when resampling.
            resample: if True, draw `nresamples` codes per input.
            nresamples: number of codes per input when resampling.
            reuse: whether to reuse the encoder variables.
            is_training: training-mode flag forwarded to the encoder.

        Returns:
            ([z], [mean], [Sigma]): single-element lists holding the
            sampled code and the encoder's Gaussian parameters.
        """
        mean, Sigma = Encoder(self.opts,
                              input=inputs,
                              archi=self.opts['archi'][0],
                              nlayers=self.opts['nlayers'][0],
                              nfilters=self.opts['nfilters'][0],
                              filters_size=self.opts['filters_size'][0],
                              output_dim=self.opts['zdim'][0],
                              downsample=self.opts['upsample'],
                              output_layer=self.opts['output_layer'][0],
                              scope='encoder/layer_1',
                              reuse=reuse,
                              is_training=is_training)
        if self.opts['encoder'][0] == 'det':
            # - deterministic encoder: the code is the mean
            z = mean
        elif self.opts['encoder'][0] == 'gauss':
            # - gaussian encoder: sample from N(mean, Sigma)
            if resample:
                q_params = tf.concat((mean, sigma_scale * Sigma), axis=-1)
                # duplicate params to draw nresamples codes per input
                q_params = tf.stack([q_params for i in range(nresamples)],
                                    axis=1)
            else:
                q_params = tf.concat((mean, Sigma), axis=-1)
            z = sample_gaussian(self.opts, q_params, 'tensorflow')
        else:
            # bug fix: the branches test entry [0]; report it rather
            # than the whole per-layer list
            assert False, 'Unknown encoder %s' % self.opts['encoder'][0]

        return [
            z,
        ], [
            mean,
        ], [
            Sigma,
        ]
Example #6
0
 def losses(self,
            inputs,
            sigma_scale,
            resample,
            nresamples=1,
            reuse=False,
            is_training=True):
     """Assemble the training losses of the stacked WAE."""
     # Full encode/decode pass with the requested (re)sampling mode.
     zs, _, enc_Sigmas, xs, _, _ = self.forward_pass(
         inputs, sigma_scale, resample, nresamples, reuse, is_training)
     # Reconstruction cost on the observed data.
     obs_cost = obs_reconstruction_loss(self.opts, inputs, xs[0])
     # No intermediate latent costs in this variant.
     latent_cost = []
     # Prior samples are pushed through the implicit-prior decoder
     # before being matched against the deepest codes.
     prior_samples = tf.convert_to_tensor(
         sample_gaussian(self.opts, self.pz_params, 'numpy',
                         self.opts['batch_size']))
     prior_samples = self.decode_implicit_prior(prior_samples, reuse,
                                                is_training)
     # With resampling the codes carry an extra axis; keep one draw.
     qz_samples = zs[-1][:, 0] if resample else zs[-1]
     matching_penalty = latent_penalty(self.opts, qz_samples, prior_samples)
     Sigma_penalty = self.Sigma_penalty(enc_Sigmas)
     return obs_cost, latent_cost, matching_penalty, Sigma_penalty
Example #7
0
def encoder(opts,
            input,
            output_dim,
            scope=None,
            reuse=False,
            is_training=False):
    """Build the encoder network and return (z, mean, Sigma).

    The architecture is selected by opts['network']['e_arch']; the raw
    network output is split into mean and log-Sigma halves, Sigma is
    made positive via softplus, and the code z is either the mean
    (deterministic encoder) or a Gaussian sample.
    """
    # Architecture name -> builder function.
    builders = {
        'mlp': mlp_encoder,
        'conv_locatello': locatello_encoder,
        'conv_rae': rae_encoder,
    }
    arch = opts['network']['e_arch']
    with tf.variable_scope(scope, reuse=reuse):
        builder = builders.get(arch)
        if builder is None:
            raise ValueError('%s : Unknown encoder architecture' %
                             opts['network']['e_arch'])
        outputs = builder(opts, input, output_dim, reuse, is_training)

    mean, logSigma = tf.split(outputs, 2, axis=-1)
    # Softplus of the clipped pre-activation keeps Sigma positive.
    Sigma = tf.nn.softplus(tf.clip_by_value(logSigma, -20, 500))
    mean = tf.layers.flatten(mean)
    Sigma = tf.layers.flatten(Sigma)

    encoder_type = opts['encoder']
    if encoder_type == 'det':
        # Deterministic encoder: the code is the mean.
        z = mean
    elif encoder_type == 'gauss':
        # Gaussian encoder: sample from N(mean, Sigma).
        z = sample_gaussian(tf.concat((mean, Sigma), axis=-1), 'tensorflow')
    else:
        assert False, 'Unknown encoder %s' % opts['encoder']

    return z, mean, Sigma
Example #8
0
    def decode(self,
               zs,
               sigma_scale,
               latent_id=None,
               reuse=False,
               is_training=True):
        """Decode each latent code down one layer of the stack.

        Args:
            zs: list of latent codes, one per layer.
            sigma_scale: scalar multiplier applied to the decoder Sigma
                before sampling (gauss decoder only).
            latent_id: if given, every code is decoded with this layer's
                decoder; otherwise code n uses decoder n.
            reuse: whether to reuse the decoder variables.
            is_training: training-mode flag forwarded to the decoders.

        Returns:
            (xs, means, Sigmas): per-layer lists of samples and the
            decoders' Gaussian parameters.
        """
        xs, means, Sigmas = [], [], []
        for n in range(len(zs)):
            if latent_id is not None:
                idx = latent_id
            else:
                idx = n
            if idx == 0:
                # layer 0 decodes to the full data shape
                output_dim = datashapes[self.opts['dataset']][:-1] + [
                    datashapes[self.opts['dataset']][-1],
                ]
            else:
                # deeper layers decode to the latent dim below
                output_dim = [
                    self.opts['zdim'][idx - 1],
                ]
            z = zs[n]
            zshape = z.get_shape().as_list()[1:]
            if len(zshape) > 1:
                # codes are [batch, nresamples, zdim]; fold the
                # resampling axis into the batch: [-1, zdim]
                z = tf.squeeze(tf.concat(tf.split(z, zshape[0], 1), axis=0),
                               [1])
            mean, Sigma = Decoder(
                self.opts,
                input=z,
                archi=self.opts['archi'][idx],
                nlayers=self.opts['nlayers'][idx],
                nfilters=self.opts['nfilters'][idx],
                filters_size=self.opts['filters_size'][idx],
                output_dim=output_dim,
                # features_dim=features_dim,
                upsample=self.opts['upsample'],
                output_layer=self.opts['output_layer'][idx],
                scope='decoder/layer_%d' % idx,
                reuse=reuse,
                is_training=is_training)
            # restore the resampling axis: [-1, nresamples, output_dim]
            if len(zshape) > 1:
                mean = tf.stack(tf.split(mean, zshape[0]), axis=1)
                Sigma = tf.stack(tf.split(Sigma, zshape[0]), axis=1)
            # - resampling reconstructed
            if self.opts['decoder'][idx] == 'det':
                # - deterministic decoder: the sample is the mean
                x = mean
                if self.opts['use_sigmoid']:
                    x = tf.compat.v1.sigmoid(x)
            elif self.opts['decoder'][idx] == 'gauss':
                # - gaussian decoder: sample from N(mean, sigma_scale*Sigma)
                p_params = tf.concat((mean, sigma_scale * Sigma), axis=-1)
                x = sample_gaussian(self.opts, p_params, 'tensorflow')
            elif self.opts['decoder'][idx] == 'bernoulli':
                x = sample_bernoulli(mean)
            else:
                # bug fix: message said "encoder" for a decoder setting
                assert False, 'Unknown decoder %s' % self.opts['decoder'][idx]
            xs.append(x)
            means.append(mean)
            Sigmas.append(Sigma)

        return xs, means, Sigmas