def decoder_fn(x):
  """Decoder for VAE."""
  mlp_decoding = snt.nets.MLP(name="mlp_decoder",
                              output_sizes=dec_units + [2 * n_inp],
                              activation=activation)
  # Predict a mean and a log-sigma for each of the n_inp inputs.
  net = mlp_decoding(x)
  # Clip the predicted parameters to a bounded range.
  net = tf.clip_by_value(net, -10, 10)
  return generative_utils.QuantizedNormal(mu_log_sigma=net)
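This function is a closure: dec_units, n_inp, activation, and generative_utils are resolved from the enclosing task definition rather than defined in the snippet. Below is a minimal usage sketch, assuming illustrative values for those names and reusing the decoder_fn defined above; the import path for generative_utils is an assumption, not confirmed by the snippet.

import sonnet as snt                          # Sonnet v1 API, as in these examples
import tensorflow.compat.v1 as tf
from task_set.tasks import generative_utils   # assumed import path

# Assumed closure context (hypothetical values, not any particular task's config).
n_inp = 784                  # e.g. flattened 28x28 grayscale images
dec_units = [128, 128]       # hidden layer sizes of the MLP decoder
activation = tf.nn.relu

# Decode a batch of latent samples into a distribution over pixel values.
z = tf.random_normal([32, 64])    # [batch_size, latent_dim]
p_x_given_z = decoder_fn(z)       # QuantizedNormal with 2 * n_inp parameters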
def decoder_fn(net):
  """Decoder for VAE."""
  # Project the latent code to an initial spatial feature map whose size
  # depends on the image resolution (28x28 or 32x32).
  if batch["image"].shape.as_list()[1] == 28:
    net = snt.Linear(7 * 7 * 32)(net)
    net = tf.reshape(net, [-1, 7, 7, 32])
  elif batch["image"].shape.as_list()[1] == 32:
    net = snt.Linear(8 * 8 * 32)(net)
    net = tf.reshape(net, [-1, 8, 8, 32])
  else:
    raise ValueError("Only 32x32 or 28x28 supported!")

  # Upsample back to image resolution with stride-2 transposed convolutions.
  net = snt.nets.ConvNet2DTranspose(
      dec_units,
      shapes,
      kernel_shapes=[(3, 3)],
      strides=[2, 2],
      paddings=[snt.SAME],
      activation=activation_fn,
      activate_final=True)(net)

  # Two values per output channel: a mean and a log-sigma for each pixel.
  outchannel = batch["image"].shape.as_list()[3]
  net = snt.Conv2D(2 * outchannel, kernel_shape=(1, 1))(net)
  net = tf.clip_by_value(net, -10, 10)

  return generative_utils.QuantizedNormal(mu_log_sigma=net)
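Here again shapes, dec_units, activation_fn, batch, and generative_utils are captured from the surrounding scope. A rough sketch of what that context might look like for 28x28 single-channel images follows; all values are assumptions for illustration. With two stride-2 transposed-conv layers, shapes would take the 7x7 feature map to 14x14 and then to 28x28.

import sonnet as snt
import tensorflow.compat.v1 as tf

# Assumed closure context for the convolutional decoder above (hypothetical).
activation_fn = tf.nn.relu
dec_units = [32, 32]                # output channels of the two deconv layers
shapes = [(14, 14), (28, 28)]       # spatial output shape after each layer
batch = {"image": tf.zeros([32, 28, 28, 1])}   # stand-in batch dict

z = tf.random_normal([32, 64])      # latent samples
p_x_given_z = decoder_fn(z)         # QuantizedNormal over the 28x28x1 pixels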
Example n. 3
def decoder_fn(net):
  """Decoder for VAE."""
  # Hidden layers from the config, plus a mean and a log-sigma per pixel.
  hidden_units = cfg["dec_hidden_units"] + [
      flat_img.shape.as_list()[1] * 2
  ]
  mod = snt.nets.MLP(hidden_units,
                     activation=act_fn,
                     initializers=init)
  net = mod(net)
  net = tf.clip_by_value(net, -10, 10)
  return generative_utils.QuantizedNormal(mu_log_sigma=net)
Example n. 4
def decoder_fn(net):
  """Decoder for VAE."""
  # Project latents to a 4x4 feature map, then upsample with three
  # stride-2 transposed convolutions.
  net = snt.Linear(4 * 4 * 32)(net)
  net = tf.reshape(net, [-1, 4, 4, 32])
  net = snt.nets.ConvNet2DTranspose(dec_units,
                                    shapes,
                                    kernel_shapes=[(3, 3)],
                                    strides=[2, 2, 2],
                                    paddings=[snt.SAME],
                                    activation=activation_fn,
                                    activate_final=True)(net)

  # A mean and a log-sigma for each output channel of the image.
  outchannel = batch["image"].shape.as_list()[3]
  net = snt.Conv2D(2 * outchannel, kernel_shape=(1, 1))(net)
  net = tf.clip_by_value(net, -10, 10)

  return generative_utils.QuantizedNormal(mu_log_sigma=net)