Example 1
    def _fn(batch):
        """Build the loss."""
        shapes = [(8, 8), (16, 16), (32, 32)]

        def encoder_fn(net):
            """Encoder for VAE."""
            net = snt.nets.ConvNet2D(enc_units,
                                     kernel_shapes=[(3, 3)],
                                     strides=[2, 2, 2],
                                     paddings=[snt.SAME],
                                     activation=activation_fn,
                                     activate_final=True)(net)

            flat_dims = int(np.prod(net.shape.as_list()[1:]))
            net = tf.reshape(net, [-1, flat_dims])
            net = snt.Linear(2 * n_z)(net)
            return generative_utils.LogStddevNormal(net)

        encoder = snt.Module(encoder_fn, name="encoder")

        def decoder_fn(net):
            """Decoder for VAE."""
            net = snt.Linear(4 * 4 * 32)(net)
            net = tf.reshape(net, [-1, 4, 4, 32])
            net = snt.nets.ConvNet2DTranspose(dec_units,
                                              shapes,
                                              kernel_shapes=[(3, 3)],
                                              strides=[2, 2, 2],
                                              paddings=[snt.SAME],
                                              activation=activation_fn,
                                              activate_final=True)(net)

            outchannel = batch["image"].shape.as_list()[3]
            # Two outputs per image channel: a mean and a log-stddev.
            net = snt.Conv2D(2 * outchannel, kernel_shape=(1, 1))(net)
            # Clip the distribution parameters for numerical stability.
            net = tf.clip_by_value(net, -10, 10)

            return generative_utils.QuantizedNormal(mu_log_sigma=net)

        decoder = snt.Module(decoder_fn, name="decoder")

        # Standard-normal prior, parameterized as zero mean / zero log-stddev.
        zshape = tf.stack([tf.shape(batch["image"])[0], 2 * n_z])
        prior = generative_utils.LogStddevNormal(tf.zeros(shape=zshape))

        # Rescale images from [0, 1] to [-1, 1].
        input_image = (batch["image"] - 0.5) * 2

        log_p_x, kl_term = generative_utils.log_prob_elbo_components(
            encoder, decoder, prior, input_image)

        elbo = log_p_x - kl_term

        metrics = {
            "kl_term": tf.reduce_mean(kl_term),
            "log_kl_term": tf.log(tf.reduce_mean(kl_term)),
            "log_p_x": tf.reduce_mean(log_p_x),
            "elbo": tf.reduce_mean(elbo),
            # Despite the name, this is log(-mean(elbo)): the log of the
            # negative ELBO, not of -log_p_x.
            "log_neg_log_p_x": tf.log(-tf.reduce_mean(elbo))
        }

        return base.LossAndAux(-tf.reduce_mean(elbo), metrics)
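
Note that _fn closes over several names from its enclosing task definition (enc_units, dec_units, n_z, activation_fn, and the base and generative_utils helper modules). A minimal sketch of what that surrounding scope could look like; every value below is an illustrative assumption, not the original configuration:

    import numpy as np
    import sonnet as snt                   # Sonnet v1
    import tensorflow.compat.v1 as tf     # TF1-style ops (tf.log, etc.)

    # Illustrative hyperparameters only; the original values are not shown
    # in this example.
    n_z = 32                    # latent dimensionality
    enc_units = [32, 64, 128]   # three conv layers, matching strides=[2, 2, 2]
    dec_units = [128, 64, 32]   # mirrored for the transposed convs
    activation_fn = tf.nn.relu

With three stride-2 encoder layers, a 32x32 input is downsampled to 4x4, which is why decoder_fn reshapes the output of its first linear layer to [-1, 4, 4, 32] before upsampling back through shapes = [(8, 8), (16, 16), (32, 32)].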
Example 2
  def _build(batch):
    """Build the sonnet module."""
    net = snt.BatchFlatten()(batch["image"])
    # Rescale pixel values from [0, 1] to [-1, 1].
    net = (net - 0.5) * 2

    n_inp = net.shape.as_list()[1]

    def encoder_fn(x):
      mlp_encoding = snt.nets.MLP(
          name="mlp_encoder",
          output_sizes=enc_units + [2 * n_z],
          activation=activation)
      return generative_utils.LogStddevNormal(mlp_encoding(x))

    encoder = snt.Module(encoder_fn, name="encoder")

    def decoder_fn(x):
      mlp_decoding = snt.nets.MLP(
          name="mlp_decoder",
          output_sizes=dec_units + [2 * n_inp],
          activation=activation)
      net = mlp_decoding(x)
      net = tf.clip_by_value(net, -10, 10)
      return generative_utils.QuantizedNormal(mu_log_sigma=net)

    decoder = snt.Module(decoder_fn, name="decoder")

    zshape = tf.stack([tf.shape(net)[0], 2 * n_z])

    prior = generative_utils.LogStddevNormal(tf.zeros(shape=zshape))

    log_p_x, kl_term = generative_utils.log_prob_elbo_components(
        encoder, decoder, prior, net)

    elbo = log_p_x - kl_term

    metrics = {
        "kl_term": tf.reduce_mean(kl_term),
        "log_kl_term": tf.log(tf.reduce_mean(kl_term)),
        "log_p_x": tf.reduce_mean(log_p_x),
        "elbo": tf.reduce_mean(elbo),
        # As in Example 1, this is log(-mean(elbo)), the log of the
        # negative ELBO.
        "log_neg_log_p_x": tf.log(-tf.reduce_mean(elbo))
    }

    return base.LossAndAux(-tf.reduce_mean(elbo), metrics)
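
The quantity these examples optimize is the standard evidence lower bound. Restating the textbook definition in the names used by the code, where log_prob_elbo_components returns the two terms separately:

    \mathrm{ELBO}(x) = \underbrace{\mathbb{E}_{q(z \mid x)}\left[\log p(x \mid z)\right]}_{\texttt{log\_p\_x}} - \underbrace{D_{\mathrm{KL}}\left(q(z \mid x) \,\|\, p(z)\right)}_{\texttt{kl\_term}}

The returned loss is the negative batch mean, -\frac{1}{B}\sum_{i=1}^{B}\mathrm{ELBO}(x_i), which is what base.LossAndAux carries alongside the metrics dict.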
Example 3
    def _build(batch):
        """Build the sonnet module."""
        flat_img = snt.BatchFlatten()(batch["image"])
        latent_size = cfg["enc_hidden_units"][-1]

        def encoder_fn(net):
            hidden_units = cfg["enc_hidden_units"][:-1] + [latent_size * 2]
            mod = snt.nets.MLP(hidden_units,
                               activation=act_fn,
                               initializers=init)
            outputs = mod(net)
            return generative_utils.LogStddevNormal(outputs)

        encoder = snt.Module(encoder_fn, name="encoder")

        def decoder_fn(net):
            hidden_units = cfg["dec_hidden_units"] + [
                flat_img.shape.as_list()[1] * 2
            ]
            mod = snt.nets.MLP(hidden_units,
                               activation=act_fn,
                               initializers=init)
            net = mod(net)
            net = tf.clip_by_value(net, -10, 10)
            return generative_utils.QuantizedNormal(mu_log_sigma=net)

        decoder = snt.Module(decoder_fn, name="decoder")
        zshape = tf.stack([tf.shape(flat_img)[0], 2 * latent_size])
        prior = generative_utils.LogStddevNormal(tf.zeros(shape=zshape))

        log_p_x, kl_term = generative_utils.log_prob_elbo_components(
            encoder, decoder, prior, flat_img)
        elbo = log_p_x - kl_term

        metrics = {
            "kl_term": tf.reduce_mean(kl_term),
            "log_kl_term": tf.log(tf.reduce_mean(kl_term)),
            "log_p_x": tf.reduce_mean(log_p_x),
            "elbo": tf.reduce_mean(elbo),
            # As in Example 1, this is log(-mean(elbo)), the log of the
            # negative ELBO.
            "log_neg_log_p_x": tf.log(-tf.reduce_mean(elbo))
        }

        return base.LossAndAux(-tf.reduce_mean(elbo), metrics)
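
Example 3 drives both MLPs from a shared cfg dict plus act_fn and init values taken from the surrounding scope. A plausible shape for that configuration; the key names come from the code above, while every value is a made-up placeholder:

    # Hypothetical configuration; only the key names are taken from Example 3.
    cfg = {
        "enc_hidden_units": [512, 128, 32],  # last entry doubles as latent size
        "dec_hidden_units": [128, 512],
    }
    act_fn = tf.nn.relu
    init = {"w": tf.truncated_normal_initializer(stddev=0.01)}  # snt.nets.MLP initializers

Note that the encoder consumes the last entry of enc_hidden_units as latent_size and replaces it with latent_size * 2 outputs, one mean and one log-stddev per latent dimension.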