Example 1
    def _fn(batch):
        """Build the loss."""
        shapes = [(8, 8), (16, 16), (32, 32)]

        def encoder_fn(net):
            """Encoder for VAE."""
            net = snt.nets.ConvNet2D(enc_units,
                                     kernel_shapes=[(3, 3)],
                                     strides=[2, 2, 2],
                                     paddings=[snt.SAME],
                                     activation=activation_fn,
                                     activate_final=True)(net)

            flat_dims = int(np.prod(net.shape.as_list()[1:]))
            net = tf.reshape(net, [-1, flat_dims])
            net = snt.Linear(2 * n_z)(net)
            return generative_utils.LogStddevNormal(net)

        encoder = snt.Module(encoder_fn, name="encoder")

        def decoder_fn(net):
            """Decoder for VAE."""
            net = snt.Linear(4 * 4 * 32)(net)
            net = tf.reshape(net, [-1, 4, 4, 32])
            net = snt.nets.ConvNet2DTranspose(dec_units,
                                              shapes,
                                              kernel_shapes=[(3, 3)],
                                              strides=[2, 2, 2],
                                              paddings=[snt.SAME],
                                              activation=activation_fn,
                                              activate_final=True)(net)

            outchannel = batch["image"].shape.as_list()[3]
            net = snt.Conv2D(2 * outchannel, kernel_shape=(1, 1))(net)
            # Clip the [mu, log_sigma] parameters so exp(log_sigma) stays
            # numerically stable.
            net = tf.clip_by_value(net, -10, 10)

            return generative_utils.QuantizedNormal(mu_log_sigma=net)

        decoder = snt.Module(decoder_fn, name="decoder")

        zshape = tf.stack([tf.shape(batch["image"])[0], 2 * n_z])
        prior = generative_utils.LogStddevNormal(tf.zeros(shape=zshape))

        # Rescale pixels from [0, 1] to [-1, 1] so the input is zero-mean.
        input_image = (batch["image"] - 0.5) * 2

        log_p_x, kl_term = generative_utils.log_prob_elbo_components(
            encoder, decoder, prior, input_image)

        elbo = log_p_x - kl_term

        metrics = {
            "kl_term": tf.reduce_mean(kl_term),
            "log_kl_term": tf.log(tf.reduce_mean(kl_term)),
            "log_p_x": tf.reduce_mean(log_p_x),
            "elbo": tf.reduce_mean(elbo),
            # -ELBO is an upper bound on -log p(x), logged here in log-space.
            "log_neg_log_p_x": tf.log(-tf.reduce_mean(elbo))
        }

        return base.LossAndAux(-tf.reduce_mean(elbo), metrics)
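
All four examples lean on `generative_utils.log_prob_elbo_components` to split the ELBO into a reconstruction term and a KL term. The repository's implementation is not shown on this page; the sketch below is a minimal single-sample version of the contract these call sites assume, where the encoder, decoder, and prior are distribution-like objects exposing `sample()` and `log_prob()` (those method names are assumptions here, not the actual `generative_utils` API):

import tensorflow as tf  # TF1-style API, matching the examples on this page

def log_prob_elbo_components(encoder, decoder, prior, x):
  """Single-sample estimate of (log p(x|z), KL(q(z|x) || p(z))).

  A sketch of the assumed contract, not the repository's implementation:
  encoder(x), decoder(z), and prior are distribution-like objects with
  `sample()` and `log_prob()` methods.
  """
  q_z = encoder(x)           # approximate posterior q(z|x)
  z = q_z.sample()           # reparameterized latent sample
  p_x_given_z = decoder(z)   # likelihood model p(x|z)

  # Sum per-dimension log-probs over everything except the batch axis.
  log_p_x = tf.reduce_sum(
      p_x_given_z.log_prob(x), axis=list(range(1, x.shape.ndims)))

  # Single-sample Monte Carlo estimate of KL(q(z|x) || p(z)).
  kl = tf.reduce_sum(q_z.log_prob(z) - prior.log_prob(z), axis=1)
  return log_p_x, kl

Estimating the KL from log-densities of the sampled `z`, rather than analytically, avoids needing a registered KL for the custom distribution classes.
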
Example 2
  def _build(batch):
    """Build the sonnet module."""
    net = snt.BatchFlatten()(batch["image"])
    # Rescale pixels from [0, 1] to [-1, 1] so the input is zero-mean.
    net = (net - 0.5) * 2

    n_inp = net.shape.as_list()[1]

    def encoder_fn(x):
      mlp_encoding = snt.nets.MLP(
          name="mlp_encoder",
          output_sizes=enc_units + [2 * n_z],
          activation=activation)
      return generative_utils.LogStddevNormal(mlp_encoding(x))

    encoder = snt.Module(encoder_fn, name="encoder")

    def decoder_fn(x):
      mlp_decoding = snt.nets.MLP(
          name="mlp_decoder",
          output_sizes=dec_units + [2 * n_inp],
          activation=activation)
      net = mlp_decoding(x)
      net = tf.clip_by_value(net, -10, 10)
      return generative_utils.QuantizedNormal(mu_log_sigma=net)

    decoder = snt.Module(decoder_fn, name="decoder")

    zshape = tf.stack([tf.shape(net)[0], 2 * n_z])

    prior = generative_utils.LogStddevNormal(tf.zeros(shape=zshape))

    log_p_x, kl_term = generative_utils.log_prob_elbo_components(
        encoder, decoder, prior, net)

    elbo = log_p_x - kl_term

    metrics = {
        "kl_term": tf.reduce_mean(kl_term),
        "log_kl_term": tf.log(tf.reduce_mean(kl_term)),
        "log_p_x": tf.reduce_mean(log_p_x),
        "elbo": tf.reduce_mean(elbo),
        "log_neg_log_p_x": tf.log(-tf.reduce_mean(elbo))
    }

    return base.LossAndAux(-tf.reduce_mean(elbo), metrics)
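
The `2 * n_z` and `2 * n_inp` output sizes follow from the parameterization convention of `LogStddevNormal` (and, by the `mu_log_sigma` keyword, presumably `QuantizedNormal` as well): the final layer emits means and log-standard-deviations concatenated along the last axis. This also explains the `tf.clip_by_value(net, -10, 10)` calls, since `exp` of an unbounded log-sigma would overflow or underflow. A minimal sketch of that parameterization, inferred from the call sites (the repository's class may differ in detail):

import tensorflow as tf

class LogStddevNormal(object):
  """Diagonal Gaussian parameterized by a concatenated [mu, log_sigma] tensor.

  A sketch of the assumed convention behind the `2 * n_z` output sizes;
  not the repository's actual implementation.
  """

  def __init__(self, mu_log_sigma):
    self.mu, log_sigma = tf.split(mu_log_sigma, 2, axis=-1)
    self.sigma = tf.exp(log_sigma)

  def sample(self):
    # Reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I).
    eps = tf.random_normal(tf.shape(self.mu))
    return self.mu + self.sigma * eps

  def log_prob(self, x):
    return tf.distributions.Normal(self.mu, self.sigma).log_prob(x)
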
Example 3
    def _build(batch):
        """Build the sonnet module."""
        flat_img = snt.BatchFlatten()(batch["image"])
        latent_size = cfg["enc_hidden_units"][-1]

        def encoder_fn(net):
            hidden_units = cfg["enc_hidden_units"][:-1] + [latent_size * 2]
            mod = snt.nets.MLP(hidden_units,
                               activation=act_fn,
                               initializers=init)
            outputs = mod(net)
            return generative_utils.LogStddevNormal(outputs)

        encoder = snt.Module(encoder_fn, name="encoder")

        def decoder_fn(net):
            hidden_units = cfg["dec_hidden_units"] + [
                flat_img.shape.as_list()[1] * 2
            ]
            mod = snt.nets.MLP(hidden_units,
                               activation=act_fn,
                               initializers=init)
            net = mod(net)
            net = tf.clip_by_value(net, -10, 10)
            return generative_utils.QuantizedNormal(mu_log_sigma=net)

        decoder = snt.Module(decoder_fn, name="decoder")
        zshape = tf.stack([tf.shape(flat_img)[0], 2 * latent_size])
        prior = generative_utils.LogStddevNormal(tf.zeros(shape=zshape))

        log_p_x, kl_term = generative_utils.log_prob_elbo_components(
            encoder, decoder, prior, flat_img)
        elbo = log_p_x - kl_term

        metrics = {
            "kl_term": tf.reduce_mean(kl_term),
            "log_kl_term": tf.log(tf.reduce_mean(kl_term)),
            "log_p_x": tf.reduce_mean(log_p_x),
            "elbo": tf.reduce_mean(elbo),
            "log_neg_log_p_x": tf.log(-tf.reduce_mean(elbo))
        }

        return base.LossAndAux(-tf.reduce_mean(elbo), metrics)
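
Note that the prior in each example is built from an all-zeros parameter tensor in the same [mu, log_sigma] layout, i.e. mu = 0 and sigma = exp(0) = 1, so the latent prior is a standard normal N(0, I). A tiny illustration (the batch and latent sizes are arbitrary stand-ins):

import tensorflow as tf

batch_size, latent_size = 32, 16          # stand-in sizes for illustration
zshape = tf.stack([batch_size, 2 * latent_size])
prior_params = tf.zeros(shape=zshape)     # [mu | log_sigma] = [0 | 0]
mu, log_sigma = tf.split(prior_params, 2, axis=-1)
# mu == 0 and exp(log_sigma) == 1 everywhere, i.e. a standard normal prior.
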
Example 4
  def _build(batch):
    """Builds the sonnet module.

    Args:
      batch: Dict with "image", "label", and "label_onehot" keys. This is the
        input batch used to compute the loss over.

    Returns:
      The loss and a metrics dict.
    """
    net = snt.BatchFlatten()(batch["image"])
    logits = snt.nets.MLP(hidden_units, activation=activation)(net)
    if losstype == "ce":
      loss_vec = tf.nn.softmax_cross_entropy_with_logits_v2(
          labels=batch["label_onehot"], logits=logits)
    elif losstype == "mse":
      loss_vec = tf.reduce_mean(
          tf.square(batch["label_onehot"] - tf.nn.sigmoid(logits)), [1])
    else:
      raise ValueError("Loss type [%s] not supported." % losstype)

    aux = {"accuracy": utils.accuracy(label=batch["label"], logits=logits)}
    return base.LossAndAux(loss=tf.reduce_mean(loss_vec), aux=aux)
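
`utils.accuracy` is not shown on this page; a typical implementation compares the argmax of the logits against the integer labels. A sketch of the assumed behavior:

import tensorflow as tf

def accuracy(label, logits):
  """Fraction of examples whose argmax prediction matches the integer label.

  A sketch of the assumed behavior of `utils.accuracy`, not its actual code.
  """
  predictions = tf.argmax(logits, axis=1)
  correct = tf.equal(predictions, tf.cast(label, predictions.dtype))
  return tf.reduce_mean(tf.cast(correct, tf.float32))
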