def decode(self, zs, sigma_scale, latent_id=None, reuse=False, is_training=True):
    # --- Decoding loop
    output_dim = datashapes[self.opts['dataset']][:-1] + [
        datashapes[self.opts['dataset']][-1], ]
    z = zs[0]
    zshape = z.get_shape().as_list()[1:]
    if len(zshape) > 1:
        # fold the resampling axis into the batch: [-1, output_dim]
        z = tf.squeeze(tf.concat(tf.split(z, zshape[0], 1), axis=0), [1])
    mean, Sigma = Decoder(self.opts,
                          input=z,
                          archi=self.opts['archi'][0],
                          nlayers=self.opts['nlayers'][0],
                          nfilters=self.opts['nfilters'][0],
                          filters_size=self.opts['filters_size'][0],
                          output_dim=output_dim,
                          upsample=self.opts['upsample'],
                          output_layer=self.opts['output_layer'][0],
                          scope='decoder/layer_0',
                          reuse=reuse,
                          is_training=is_training)
    # reshape back to [-1, nresamples, output_dim] if needed
    if len(zshape) > 1:
        mean = tf.stack(tf.split(mean, zshape[0]), axis=1)
        Sigma = tf.stack(tf.split(Sigma, zshape[0]), axis=1)
    # - sampling the reconstruction
    if self.opts['decoder'][0] == 'det':
        # - deterministic decoder
        x = mean
        if self.opts['use_sigmoid']:
            x = tf.compat.v1.sigmoid(x)
    elif self.opts['decoder'][0] == 'gauss':
        # - gaussian decoder
        p_params = tf.concat((mean, sigma_scale * Sigma), axis=-1)
        x = sample_gaussian(self.opts, p_params, 'tensorflow')
    else:
        assert False, 'Unknown decoder %s' % self.opts['decoder'][0]

    return [x, ], [mean, ], [Sigma, ]
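# A minimal sketch of what sample_gaussian(opts, params, 'tensorflow') is assumed
# to do in the decoders above: `params` carries (mean, Sigma) concatenated along
# the last axis, and a reparameterized draw is mean + sqrt(Sigma) * eps with
# eps ~ N(0, I). `_sample_gaussian_sketch` is a hypothetical name, not part of
# this codebase, and the actual helper may differ in details.
def _sample_gaussian_sketch(params):
    import tensorflow as tf
    mean, Sigma = tf.split(params, 2, axis=-1)  # unstack (mean, Sigma)
    eps = tf.random.normal(tf.shape(mean))      # standard normal noise
    return mean + tf.sqrt(Sigma) * eps          # reparameterization trick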
def layerwise_agg_kl(self, inputs, sigma_scale):
    # --- compute the layer-wise KL(q(z_n|z_{n-1}) || p(z_n|z_{n+1})) between
    # the aggregated encoding and decoding distributions
    zs, enc_means, enc_Sigmas, xs, dec_means, dec_Sigmas = self.forward_pass(
        inputs, sigma_scale, False, reuse=True, is_training=False)
    kls = []
    # latent layers up to N-1
    for n in range(len(enc_means) - 1):
        logp_prob = log_normal(tf.expand_dims(xs[n + 1], 1),
                               tf.expand_dims(dec_means[n + 1], 0),
                               tf.expand_dims(dec_Sigmas[n + 1], 0))
        logp = tf.reduce_logsumexp(tf.reduce_sum(logp_prob, axis=2, keepdims=False),
                                   axis=1, keepdims=False)
        logq_prob = log_normal(tf.expand_dims(zs[n], 1),
                               tf.expand_dims(enc_means[n], 0),
                               tf.expand_dims(enc_Sigmas[n], 0))
        logq = tf.reduce_logsumexp(tf.reduce_sum(logq_prob, axis=2, keepdims=False),
                                   axis=1, keepdims=False)
        kl = tf.reduce_mean(logp - logq)
        kls.append(kl)
    # deepest layer
    pz_mean, pz_Sigma = tf.split(self.pz_params, 2, axis=-1)
    pz_samples = sample_gaussian(self.opts, self.pz_params, 'numpy',
                                 self.opts['batch_size'])
    logp_prob = log_normal(
        tf.expand_dims(pz_samples, 1),
        tf.expand_dims(tf.expand_dims(pz_mean, axis=0), 0),
        tf.expand_dims(tf.expand_dims(pz_Sigma, axis=0), 0))
    logp = tf.reduce_logsumexp(tf.reduce_sum(logp_prob, axis=2, keepdims=False),
                               axis=1, keepdims=False)
    logq_prob = log_normal(tf.expand_dims(zs[-1], 1),
                           tf.expand_dims(enc_means[-1], 0),
                           tf.expand_dims(enc_Sigmas[-1], 0))
    logq = tf.reduce_logsumexp(tf.reduce_sum(logq_prob, axis=2, keepdims=False),
                               axis=1, keepdims=False)
    kl = tf.reduce_mean(logp - logq)
    kls.append(kl)

    return kls
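# A minimal sketch of the diagonal-Gaussian log-density assumed behind
# log_normal(x, mean, Sigma) above: log N(x; mean, Sigma) with Sigma a diagonal
# variance, returned per dimension so the caller can tf.reduce_sum over the
# feature axis before the logsumexp. `_log_normal_sketch` is a hypothetical
# name, not part of this codebase.
def _log_normal_sketch(x, mean, Sigma):
    import numpy as np
    import tensorflow as tf
    return -0.5 * (tf.square(x - mean) / Sigma
                   + tf.math.log(Sigma)
                   + np.log(2. * np.pi))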
def encode(self, inputs, sigma_scale, resample, nresamples=1, reuse=False, is_training=True):
    # --- Encoding loop
    zs, means, Sigmas = [], [], []
    for n in range(self.opts['nlatents']):
        if n == 0:
            input = inputs
        else:
            if resample:
                # when resampling for visualization, we pass the mean instead
                input = means[-1]
            else:
                input = zs[-1]
        mean, Sigma = Encoder(self.opts,
                              input=input,
                              archi=self.opts['archi'][n],
                              nlayers=self.opts['nlayers'][n],
                              nfilters=self.opts['nfilters'][n],
                              filters_size=self.opts['filters_size'][n],
                              output_dim=self.opts['zdim'][n],
                              downsample=self.opts['upsample'],
                              output_layer=self.opts['output_layer'][n],
                              scope='encoder/layer_%d' % (n + 1),
                              reuse=reuse,
                              is_training=is_training)
        if self.opts['encoder'][n] == 'det':
            # - deterministic encoder
            z = mean
        elif self.opts['encoder'][n] == 'gauss':
            # - gaussian encoder
            if resample:
                q_params = tf.concat((mean, sigma_scale * Sigma), axis=-1)
                q_params = tf.stack([q_params for i in range(nresamples)], axis=1)
            else:
                q_params = tf.concat((mean, Sigma), axis=-1)
            z = sample_gaussian(self.opts, q_params, 'tensorflow')
        else:
            assert False, 'Unknown encoder %s' % self.opts['encoder'][n]
        zs.append(z)
        means.append(mean)
        Sigmas.append(Sigma)

    return zs, means, Sigmas
def losses(self, inputs, sigma_scale, reuse=False, is_training=True):
    # --- compute the losses of the stackedWAE
    zs, enc_means, enc_Sigmas, xs, dec_means, dec_Sigmas = self.forward_pass(
        inputs, sigma_scale, False, reuse=reuse, is_training=is_training)
    obs_cost = obs_reconstruction_loss(self.opts, inputs, xs[0])
    latent_cost = self.latent_cost(xs[1:], dec_means[1:], dec_Sigmas[1:],
                                   zs[:-1], enc_means[:-1], enc_Sigmas[:-1])
    pz_samples = sample_gaussian(self.opts, self.pz_params, 'numpy',
                                 self.opts['batch_size'])
    matching_penalty = latent_penalty(self.opts, zs[-1], pz_samples)
    enc_Sigma_penalty = self.Sigma_penalty(enc_Sigmas)
    dec_Sigma_penalty = self.Sigma_penalty(dec_Sigmas[1:])

    return obs_cost, latent_cost, matching_penalty, enc_Sigma_penalty, dec_Sigma_penalty
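# A minimal sketch of a squared-error observation cost of the kind
# obs_reconstruction_loss(opts, inputs, recons) is assumed to compute above:
# per-image L2 distance between input and reconstruction, averaged over the
# batch. The actual loss may differ depending on opts (e.g. L1 or
# cross-entropy); `_obs_cost_sketch` is a hypothetical name.
def _obs_cost_sketch(inputs, recons):
    import tensorflow as tf
    # sum squared error over the image axes (assumes [batch, h, w, c] inputs)
    sq = tf.reduce_sum(tf.square(inputs - recons), axis=[1, 2, 3])
    return tf.reduce_mean(sq)  # average over the batch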
def encode(self, inputs, sigma_scale, resample, nresamples=1, reuse=False, is_training=True):
    # --- Encoding loop
    mean, Sigma = Encoder(self.opts,
                          input=inputs,
                          archi=self.opts['archi'][0],
                          nlayers=self.opts['nlayers'][0],
                          nfilters=self.opts['nfilters'][0],
                          filters_size=self.opts['filters_size'][0],
                          output_dim=self.opts['zdim'][0],
                          downsample=self.opts['upsample'],
                          output_layer=self.opts['output_layer'][0],
                          scope='encoder/layer_1',
                          reuse=reuse,
                          is_training=is_training)
    if self.opts['encoder'][0] == 'det':
        # - deterministic encoder
        z = mean
    elif self.opts['encoder'][0] == 'gauss':
        # - gaussian encoder
        if resample:
            q_params = tf.concat((mean, sigma_scale * Sigma), axis=-1)
            q_params = tf.stack([q_params for i in range(nresamples)], axis=1)
        else:
            q_params = tf.concat((mean, Sigma), axis=-1)
        z = sample_gaussian(self.opts, q_params, 'tensorflow')
    else:
        assert False, 'Unknown encoder %s' % self.opts['encoder'][0]

    return [z, ], [mean, ], [Sigma, ]
def losses(self, inputs, sigma_scale, resample, nresamples=1, reuse=False, is_training=True):
    # --- compute the losses of the stackedWAE
    zs, _, enc_Sigmas, xs, _, _ = self.forward_pass(
        inputs, sigma_scale, resample, nresamples, reuse, is_training)
    obs_cost = obs_reconstruction_loss(self.opts, inputs, xs[0])
    latent_cost = []
    pz_samples = tf.convert_to_tensor(
        sample_gaussian(self.opts, self.pz_params, 'numpy',
                        self.opts['batch_size']))
    pz_samples = self.decode_implicit_prior(pz_samples, reuse, is_training)
    if resample:
        qz_samples = zs[-1][:, 0]
    else:
        qz_samples = zs[-1]
    matching_penalty = latent_penalty(self.opts, qz_samples, pz_samples)
    Sigma_penalty = self.Sigma_penalty(enc_Sigmas)

    return obs_cost, latent_cost, matching_penalty, Sigma_penalty
def encoder(opts, input, output_dim, scope=None, reuse=False, is_training=False):
    with tf.variable_scope(scope, reuse=reuse):
        if opts['network']['e_arch'] == 'mlp':
            # encoder uses only fully connected layers with ReLUs
            outputs = mlp_encoder(opts, input, output_dim, reuse, is_training)
        elif opts['network']['e_arch'] == 'conv_locatello':
            # fully convolutional architecture similar to Locatello et al.
            outputs = locatello_encoder(opts, input, output_dim, reuse, is_training)
        elif opts['network']['e_arch'] == 'conv_rae':
            # fully convolutional architecture similar to the RAE encoder
            outputs = rae_encoder(opts, input, output_dim, reuse, is_training)
        else:
            raise ValueError('%s : Unknown encoder architecture'
                             % opts['network']['e_arch'])
        mean, logSigma = tf.split(outputs, 2, axis=-1)
        logSigma = tf.clip_by_value(logSigma, -20, 500)
        Sigma = tf.nn.softplus(logSigma)
        mean = tf.layers.flatten(mean)
        Sigma = tf.layers.flatten(Sigma)
        if opts['encoder'] == 'det':
            # - deterministic encoder
            z = mean
        elif opts['encoder'] == 'gauss':
            # - gaussian encoder
            qz_params = tf.concat((mean, Sigma), axis=-1)
            z = sample_gaussian(qz_params, 'tensorflow')
        else:
            assert False, 'Unknown encoder %s' % opts['encoder']

    return z, mean, Sigma
def decode(self, zs, sigma_scale, latent_id=None, reuse=False, is_training=True):
    # --- Decoding loop
    xs, means, Sigmas = [], [], []
    for n in range(len(zs)):
        if latent_id is not None:
            idx = latent_id
        else:
            idx = n
        if idx == 0:
            output_dim = datashapes[self.opts['dataset']][:-1] + [
                datashapes[self.opts['dataset']][-1], ]
        else:
            output_dim = [self.opts['zdim'][idx - 1], ]
        z = zs[n]
        zshape = z.get_shape().as_list()[1:]
        if len(zshape) > 1:
            # fold the resampling axis into the batch: [-1, output_dim]
            z = tf.squeeze(tf.concat(tf.split(z, zshape[0], 1), axis=0), [1])
        mean, Sigma = Decoder(self.opts,
                              input=z,
                              archi=self.opts['archi'][idx],
                              nlayers=self.opts['nlayers'][idx],
                              nfilters=self.opts['nfilters'][idx],
                              filters_size=self.opts['filters_size'][idx],
                              output_dim=output_dim,
                              upsample=self.opts['upsample'],
                              output_layer=self.opts['output_layer'][idx],
                              scope='decoder/layer_%d' % idx,
                              reuse=reuse,
                              is_training=is_training)
        # reshape back to [-1, nresamples, output_dim] if needed
        if len(zshape) > 1:
            mean = tf.stack(tf.split(mean, zshape[0]), axis=1)
            Sigma = tf.stack(tf.split(Sigma, zshape[0]), axis=1)
        # - sampling the reconstruction
        if self.opts['decoder'][idx] == 'det':
            # - deterministic decoder
            x = mean
            if self.opts['use_sigmoid']:
                x = tf.compat.v1.sigmoid(x)
        elif self.opts['decoder'][idx] == 'gauss':
            # - gaussian decoder
            p_params = tf.concat((mean, sigma_scale * Sigma), axis=-1)
            x = sample_gaussian(self.opts, p_params, 'tensorflow')
        elif self.opts['decoder'][idx] == 'bernoulli':
            # - bernoulli decoder
            x = sample_bernoulli(mean)
        else:
            assert False, 'Unknown decoder %s' % self.opts['decoder'][idx]
        xs.append(x)
        means.append(mean)
        Sigmas.append(Sigma)

    return xs, means, Sigmas
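# The split/concat pattern in decode() folds the resampling axis into the batch
# so the Decoder sees a plain [batch * nresamples, zdim] tensor, and
# tf.stack(tf.split(...)) then restores [batch, nresamples, output_dim]. A
# minimal standalone illustration of the same round trip, with hypothetical
# sizes (not part of this codebase):
def _fold_unfold_demo():
    import tensorflow as tf
    batch, nresamples, zdim = 4, 3, 2  # hypothetical sizes
    z = tf.zeros([batch, nresamples, zdim])
    # fold: [batch, nresamples, zdim] -> [batch * nresamples, zdim]
    flat = tf.squeeze(tf.concat(tf.split(z, nresamples, 1), axis=0), [1])
    # unfold: [batch * nresamples, zdim] -> [batch, nresamples, zdim]
    back = tf.stack(tf.split(flat, nresamples), axis=1)
    return flat.shape, back.shape  # (12, 2) and (4, 3, 2)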