def _fn(batch):
  """Build the loss."""
  # Target output shapes for the decoder's transposed-conv stack.
  shapes = [(8, 8), (16, 16), (32, 32)]

  def encoder_fn(net):
    """Encoder for VAE."""
    net = snt.nets.ConvNet2D(
        enc_units,
        kernel_shapes=[(3, 3)],
        strides=[2, 2, 2],
        paddings=[snt.SAME],
        activation=activation_fn,
        activate_final=True)(net)
    # Flatten the conv features and project to the concatenated mean and
    # log-stddev of the Gaussian posterior q(z|x).
    flat_dims = int(np.prod(net.shape.as_list()[1:]))
    net = tf.reshape(net, [-1, flat_dims])
    net = snt.Linear(2 * n_z)(net)
    return generative_utils.LogStddevNormal(net)

  encoder = snt.Module(encoder_fn, name="encoder")

  def decoder_fn(net):
    """Decoder for VAE."""
    net = snt.Linear(4 * 4 * 32)(net)
    net = tf.reshape(net, [-1, 4, 4, 32])
    net = snt.nets.ConvNet2DTranspose(
        dec_units,
        shapes,
        kernel_shapes=[(3, 3)],
        strides=[2, 2, 2],
        paddings=[snt.SAME],
        activation=activation_fn,
        activate_final=True)(net)
    # Emit a mean and log-stddev per image channel; clip for numerical
    # stability.
    outchannel = batch["image"].shape.as_list()[3]
    net = snt.Conv2D(2 * outchannel, kernel_shape=(1, 1))(net)
    net = tf.clip_by_value(net, -10, 10)
    return generative_utils.QuantizedNormal(mu_log_sigma=net)

  decoder = snt.Module(decoder_fn, name="decoder")

  # Standard-normal prior, parameterized as zero mean and zero log-stddev.
  zshape = tf.stack([tf.shape(batch["image"])[0], 2 * n_z])
  prior = generative_utils.LogStddevNormal(tf.zeros(shape=zshape))

  # Shift pixels from [0, 1] to [-1, 1] so inputs are zero mean.
  input_image = (batch["image"] - 0.5) * 2

  log_p_x, kl_term = generative_utils.log_prob_elbo_components(
      encoder, decoder, prior, input_image)
  elbo = log_p_x - kl_term

  metrics = {
      "kl_term": tf.reduce_mean(kl_term),
      "log_kl_term": tf.log(tf.reduce_mean(kl_term)),
      "log_p_x": tf.reduce_mean(log_p_x),
      "elbo": tf.reduce_mean(elbo),
      # Log of the negative ELBO; the negative ELBO upper-bounds -log p(x).
      "log_neg_log_p_x": tf.log(-tf.reduce_mean(elbo))
  }
  return base.LossAndAux(-tf.reduce_mean(elbo), metrics)
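
# Usage sketch (illustrative; not part of the source). `_fn` is a closure:
# `enc_units`, `dec_units`, `n_z`, and `activation_fn` come from the enclosing
# task definition, and `np`/`snt`/`tf`/`generative_utils`/`base` are the
# surrounding module's imports (`generative_utils` and `base` are repo-internal
# helpers assumed importable). The bindings below are hypothetical and assume
# Sonnet v1 in TF1 graph mode.
if __name__ == "__main__":
  import numpy as np
  import sonnet as snt
  import tensorflow.compat.v1 as tf
  tf.disable_v2_behavior()

  enc_units = [32, 32, 32]  # one channel count per stride-2 conv layer
  dec_units = [32, 32, 32]
  n_z = 32                  # latent dimensionality
  activation_fn = tf.nn.relu

  # 32x32 inputs, so the decoder's (8, 8) -> (16, 16) -> (32, 32) upsampling
  # path lands back on the input resolution.
  images = tf.placeholder(tf.float32, [None, 32, 32, 1])
  loss_and_aux = _fn({"image": images})  # negative-ELBO loss plus metrics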
def _build(batch):
  """Build the sonnet module."""
  net = snt.BatchFlatten()(batch["image"])
  # Shift pixels from [0, 1] to [-1, 1] so inputs are zero mean.
  net = (net - 0.5) * 2
  n_inp = net.shape.as_list()[1]

  def encoder_fn(x):
    # MLP whose final layer emits the concatenated mean and log-stddev of
    # the Gaussian posterior q(z|x).
    mlp_encoding = snt.nets.MLP(
        name="mlp_encoder",
        output_sizes=enc_units + [2 * n_z],
        activation=activation)
    return generative_utils.LogStddevNormal(mlp_encoding(x))

  encoder = snt.Module(encoder_fn, name="encoder")

  def decoder_fn(x):
    # MLP whose final layer emits a mean and log-stddev per input pixel;
    # clip for numerical stability.
    mlp_decoding = snt.nets.MLP(
        name="mlp_decoder",
        output_sizes=dec_units + [2 * n_inp],
        activation=activation)
    net = mlp_decoding(x)
    net = tf.clip_by_value(net, -10, 10)
    return generative_utils.QuantizedNormal(mu_log_sigma=net)

  decoder = snt.Module(decoder_fn, name="decoder")

  # Standard-normal prior, parameterized as zero mean and zero log-stddev.
  zshape = tf.stack([tf.shape(net)[0], 2 * n_z])
  prior = generative_utils.LogStddevNormal(tf.zeros(shape=zshape))

  log_p_x, kl_term = generative_utils.log_prob_elbo_components(
      encoder, decoder, prior, net)
  elbo = log_p_x - kl_term

  metrics = {
      "kl_term": tf.reduce_mean(kl_term),
      "log_kl_term": tf.log(tf.reduce_mean(kl_term)),
      "log_p_x": tf.reduce_mean(log_p_x),
      "elbo": tf.reduce_mean(elbo),
      # Log of the negative ELBO; the negative ELBO upper-bounds -log p(x).
      "log_neg_log_p_x": tf.log(-tf.reduce_mean(elbo))
  }
  return base.LossAndAux(-tf.reduce_mean(elbo), metrics)
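
# Usage sketch (illustrative; not part of the source). As above, `_build`
# closes over `enc_units`, `dec_units`, `n_z`, and `activation`, assumed bound
# by the enclosing task definition; hypothetical bindings, assuming Sonnet v1
# in TF1 graph mode:
if __name__ == "__main__":
  import sonnet as snt
  import tensorflow.compat.v1 as tf
  tf.disable_v2_behavior()

  enc_units = [128, 128]  # encoder hidden-layer sizes
  dec_units = [128, 128]  # decoder hidden-layer sizes
  n_z = 32                # latent dimensionality
  activation = tf.nn.relu

  images = tf.placeholder(tf.float32, [None, 28, 28, 1])
  loss_and_aux = _build({"image": images})  # negative-ELBO loss plus metrics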
def _build(batch):
  """Build the sonnet module."""
  flat_img = snt.BatchFlatten()(batch["image"])
  # The last encoder entry doubles as the latent size.
  latent_size = cfg["enc_hidden_units"][-1]

  def encoder_fn(net):
    # The final layer emits the concatenated mean and log-stddev of q(z|x).
    hidden_units = cfg["enc_hidden_units"][:-1] + [latent_size * 2]
    mod = snt.nets.MLP(hidden_units, activation=act_fn, initializers=init)
    outputs = mod(net)
    return generative_utils.LogStddevNormal(outputs)

  encoder = snt.Module(encoder_fn, name="encoder")

  def decoder_fn(net):
    # The final layer emits a mean and log-stddev per input pixel; clip for
    # numerical stability.
    hidden_units = cfg["dec_hidden_units"] + [flat_img.shape.as_list()[1] * 2]
    mod = snt.nets.MLP(hidden_units, activation=act_fn, initializers=init)
    net = mod(net)
    net = tf.clip_by_value(net, -10, 10)
    return generative_utils.QuantizedNormal(mu_log_sigma=net)

  decoder = snt.Module(decoder_fn, name="decoder")

  # Standard-normal prior, parameterized as zero mean and zero log-stddev.
  zshape = tf.stack([tf.shape(flat_img)[0], 2 * latent_size])
  prior = generative_utils.LogStddevNormal(tf.zeros(shape=zshape))

  log_p_x, kl_term = generative_utils.log_prob_elbo_components(
      encoder, decoder, prior, flat_img)
  elbo = log_p_x - kl_term

  metrics = {
      "kl_term": tf.reduce_mean(kl_term),
      "log_kl_term": tf.log(tf.reduce_mean(kl_term)),
      "log_p_x": tf.reduce_mean(log_p_x),
      "elbo": tf.reduce_mean(elbo),
      # Log of the negative ELBO; the negative ELBO upper-bounds -log p(x).
      "log_neg_log_p_x": tf.log(-tf.reduce_mean(elbo))
  }
  return base.LossAndAux(-tf.reduce_mean(elbo), metrics)
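
# Usage sketch (illustrative; not part of the source). This variant reads a
# config dict `cfg` plus `act_fn` and `init` (a Sonnet-v1-style initializer
# dict) from its enclosing scope; all values below are hypothetical.
if __name__ == "__main__":
  import sonnet as snt
  import tensorflow.compat.v1 as tf
  tf.disable_v2_behavior()

  cfg = {
      "enc_hidden_units": [128, 32],  # last entry doubles as the latent size
      "dec_hidden_units": [128],
  }
  act_fn = tf.nn.relu
  init = {"w": tf.truncated_normal_initializer(stddev=0.02)}

  images = tf.placeholder(tf.float32, [None, 28, 28, 1])
  loss_and_aux = _build({"image": images})  # negative-ELBO loss plus metrics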