Beispiel #1
0
    def simple_discrete_latent_tower(self, input_image, target_image):
        """Produce discrete latent bits for a frame transition.

        When ``self.is_predicting`` is set, the bits are sampled independently
        and uniformly from {-1.0, 1.0}. Otherwise the input and target frames
        are concatenated along the last axis, encoded by
        ``self.basic_conv_net`` (scope "posterior_enc"), flattened, and passed
        through ``discretization.tanh_discrete_bottleneck``.

        Args:
          input_image: input frame tensor; at prediction time only its leading
            dimension (used as the batch size) is read.
          target_image: target frame tensor; unused at prediction time.

        Returns:
          The bits tensor. At prediction time it has shape
          [batch, hparams.bottleneck_bits]; otherwise it is whatever the
          bottleneck returns for the flattened encoding.
        """
        hp = self.hparams

        if not self.is_predicting:
            filter_sizes = self.tinyify([64, 32, 32, 1])
            frames = tf.concat([input_image, target_image], axis=-1)
            encoded = self.basic_conv_net(frames, filter_sizes, "posterior_enc")
            encoded = tfl.flatten(encoded)
            noisy_bits, _ = discretization.tanh_discrete_bottleneck(
                encoded, hp.bottleneck_bits, hp.bottleneck_noise,
                hp.discretize_warmup_steps, hp.mode)
            return noisy_bits

        num_examples = common_layers.shape_list(input_image)[0]
        uniform = tf.random_uniform([num_examples, hp.bottleneck_bits])
        # Threshold at 0.5 and rescale {0, 1} -> {-1, +1}.
        return 2.0 * tf.to_float(tf.less(0.5, uniform)) - 1.0
Beispiel #2
0
  def learned_discrete_tower(self, input_image, target_image):
    """Discrete latent tower with an LSTM-learned prior over the bits.

    A prior encoding is always computed from ``input_image`` alone (scope
    "prior_enc"). At prediction time, the bits are either uniform random
    {-1, 1} (``hparams.full_latent_tower``) or sampled from the LSTM prior.
    At training time, the bits come from a posterior encoder over input and
    target frames (scope "posterior_enc") followed by
    ``discretization.tanh_discrete_bottleneck``. Unless
    ``full_latent_tower`` is set, the LSTM prior is trained to match the
    clean posterior bits.

    Args:
      input_image: input frame tensor.
      target_image: target frame tensor; unused at prediction time.

    Returns:
      A pair ``(decoded_bits, pred_loss)``. ``decoded_bits`` are the bits
      mapped back to the prior encoder's output shape via
      ``common_video.encode_to_shape``. ``pred_loss`` is 0.0 unless the LSTM
      prior is trained.
    """
    hp = self.hparams
    filter_sizes = [64, 32, 32, 1]

    # Prior: encode the input frames only.
    prior = self.basic_conv_net(input_image, filter_sizes, "prior_enc")
    prior_shape = common_layers.shape_list(prior)
    prior = tfl.flatten(prior)

    def bits_to_tower_shape(raw_bits):
      return common_video.encode_to_shape(raw_bits, prior_shape, "bits_dec")

    if self.is_predicting:
      if not hp.full_latent_tower:
        # Sample bits from the learned prior at inference time.
        sampled, _ = discretization.predict_bits_with_lstm(
            prior,
            hp.latent_predictor_state_size,
            hp.bottleneck_bits,
            temperature=hp.latent_predictor_temperature)
      else:
        noise = tf.random_uniform([prior_shape[0], hp.bottleneck_bits])
        sampled = 2.0 * tf.to_float(tf.less(0.5, noise)) - 1.0
      return bits_to_tower_shape(sampled), 0.0

    # Posterior: encode input and target frames together.
    joint = tf.concat([input_image, target_image], axis=-1)
    joint = tfl.flatten(
        self.basic_conv_net(joint, filter_sizes, "posterior_enc"))
    bits, bits_clean = discretization.tanh_discrete_bottleneck(
        joint, hp.bottleneck_bits, hp.bottleneck_noise,
        hp.discretize_warmup_steps, hp.mode)

    if hp.full_latent_tower:
      pred_loss = 0.0
    else:
      # Fit the prior LSTM to the clean posterior bits.
      _, pred_loss = discretization.predict_bits_with_lstm(
          prior,
          hp.latent_predictor_state_size,
          hp.bottleneck_bits,
          target_bits=bits_clean)

    return bits_to_tower_shape(bits), pred_loss
Beispiel #3
0
  def simple_discrete_latent_tower(self, input_image, target_image):
    """Return discrete latent bits for an (input, target) frame pair.

    At prediction time there is no posterior, so each bit is drawn
    independently as -1.0 or +1.0 with equal probability. Otherwise both
    frames are stacked on the last axis, run through
    ``self.basic_conv_net`` (scope "posterior_enc"), flattened, and
    discretized by ``discretization.tanh_discrete_bottleneck``.

    Args:
      input_image: input frame tensor; its leading dimension is the batch
        size used at prediction time.
      target_image: target frame tensor; ignored at prediction time.

    Returns:
      The bits tensor produced by the selected path.
    """
    hparams = self.hparams
    num_bits = hparams.bottleneck_bits

    if self.is_predicting:
      batch = common_layers.shape_list(input_image)[0]
      coin = tf.less(0.5, tf.random_uniform([batch, num_bits]))
      return 2.0 * tf.to_float(coin) - 1.0

    layer_sizes = self.tinyify([64, 32, 32, 1])
    stacked = tf.concat([input_image, target_image], axis=-1)
    features = tfl.flatten(
        self.basic_conv_net(stacked, layer_sizes, "posterior_enc"))
    discrete, _ = discretization.tanh_discrete_bottleneck(
        features,
        num_bits,
        hparams.bottleneck_noise,
        hparams.discretize_warmup_steps,
        hparams.mode)
    return discrete
Beispiel #4
0
    def inject_latent(self, layer, inputs, target, action):
        """Inject a deterministic latent based on the target frame.

        When not training, the bits are either drawn uniformly from {-1, 1}
        (``hparams.full_latent_tower``) or sampled from the LSTM prior
        conditioned on ``layer``, and the prediction loss is 0.0.

        During training, ``inputs + [target]`` are concatenated, embedded
        densely, given a timing signal and (optionally) the action. The
        result is then reduced by strided convolutions
        (``full_latent_tower``) or by ``double_discriminator``, and
        discretized with ``tanh_discrete_bottleneck``. Without the full
        tower, an LSTM prior is trained on the clean bits, and its samples
        may replace some posterior bits (``latent_rnn_max_sampling``).
        Finally, with a warmup-dependent probability, the latent is skipped
        per example and ``layer`` is returned unchanged for that example.

        Args:
          layer: tensor to inject the latent into. Its last dimension sets
            the width of the dense projection of the bits. (Presumably
            [batch, height, width, channels] -- TODO confirm.)
          inputs: list of input-frame tensors (list-concatenated with
            ``[target]``).
          target: target-frame tensor.
          action: action tensor, or None to skip action injection.

        Returns:
          A pair ``(output, pred_loss)``. ``pred_loss`` is the LSTM prior's
          loss when that prior is trained, and 0.0 otherwise.
        """
        hparams = self.hparams
        layer_shape = common_layers.shape_list(layer)
        final_filters = layer_shape[-1]
        filters = hparams.hidden_size
        kernel = (4, 4)
        activation_fn = common_layers.belu
        if hparams.activation_fn == "relu":
            activation_fn = tf.nn.relu

        def add_bits(layer, bits):
            """Project `bits` to `final_filters` channels and merge into `layer`."""
            z_mul = tfl.dense(bits, final_filters, name="unbottleneck_mul")
            if not hparams.complex_addn:
                return layer + z_mul
            # Gated merge: scale by sigmoid(z_mul), then add a second projection.
            layer *= tf.nn.sigmoid(z_mul)
            z_add = tfl.dense(bits, final_filters, name="unbottleneck_add")
            layer += z_add
            return layer

        if not self.is_training:
            if hparams.full_latent_tower:
                # Each bit is +1 or -1 with equal probability.
                rand = tf.random_uniform(layer_shape[:-1] +
                                         [hparams.bottleneck_bits])
                bits = 2.0 * tf.to_float(tf.less(0.5, rand)) - 1.0
            else:
                bits, _ = discretization.predict_bits_with_lstm(
                    layer,
                    hparams.latent_predictor_state_size,
                    hparams.bottleneck_bits,
                    temperature=hparams.latent_predictor_temperature)
                # Insert axes 1 and 2, presumably so the projected bits
                # broadcast over layer's spatial dims -- TODO confirm.
                bits = tf.expand_dims(tf.expand_dims(bits, axis=1), axis=2)
            return add_bits(layer, bits), 0.0

        # Embed.
        frames = tf.concat(inputs + [target], axis=-1)
        x = tfl.dense(
            frames,
            filters,
            name="latent_embed",
            bias_initializer=tf.random_normal_initializer(stddev=0.01))
        x = common_attention.add_timing_signal_nd(x)

        # Add embedded action if present.
        if action is not None:
            x = common_video.inject_additional_input(x, action,
                                                     "action_enc_latent",
                                                     hparams.action_injection)

        if hparams.full_latent_tower:
            # Stride-2 conv downsampling; filters double for the first
            # `filter_double_steps` steps.
            for i in range(hparams.num_compress_steps):
                with tf.variable_scope("latent_downstride%d" % i):
                    x = common_layers.make_even_size(x)
                    if i < hparams.filter_double_steps:
                        filters *= 2
                    x = common_attention.add_timing_signal_nd(x)
                    x = tfl.conv2d(x,
                                   filters,
                                   kernel,
                                   activation=activation_fn,
                                   strides=(2, 2),
                                   padding="SAME")
                    x = common_layers.layer_norm(x)
        else:
            x = common_layers.double_discriminator(x)
            x = tf.expand_dims(tf.expand_dims(x, axis=1), axis=1)

        bits, bits_clean = discretization.tanh_discrete_bottleneck(
            x, hparams.bottleneck_bits, hparams.bottleneck_noise,
            hparams.discretize_warmup_steps, hparams.mode)

        # Only the learned-prior branch below produces a prediction loss.
        # Default it so the full-latent-tower branch does not reach the return
        # with `pred_loss` unbound (previously an UnboundLocalError).
        pred_loss = 0.0
        if not hparams.full_latent_tower:
            _, pred_loss = discretization.predict_bits_with_lstm(
                layer,
                hparams.latent_predictor_state_size,
                hparams.bottleneck_bits,
                target_bits=bits_clean)
            # Mix bits from latent with predicted bits on forward pass as a noise.
            if hparams.latent_rnn_max_sampling > 0.0:
                # Reuse the prior LSTM's variables created just above.
                with tf.variable_scope(tf.get_variable_scope(), reuse=True):
                    bits_pred, _ = discretization.predict_bits_with_lstm(
                        layer,
                        hparams.latent_predictor_state_size,
                        hparams.bottleneck_bits,
                        temperature=hparams.latent_predictor_temperature)
                    bits_pred = tf.expand_dims(tf.expand_dims(bits_pred,
                                                              axis=1),
                                               axis=2)
                # Be bits_pred on the forward pass but bits on the backward one.
                bits_pred = bits_clean + tf.stop_gradient(bits_pred -
                                                          bits_clean)
                # Select which bits to take from pred sampling with bit_p probability.
                which_bit = tf.random_uniform(common_layers.shape_list(bits))
                bit_p = common_layers.inverse_lin_decay(
                    hparams.latent_rnn_warmup_steps)
                bit_p *= hparams.latent_rnn_max_sampling
                bits = tf.where(which_bit < bit_p, bits_pred, bits)

        res = add_bits(layer, bits)
        # During training, sometimes skip the latent to help action-conditioning.
        res_p = common_layers.inverse_lin_decay(
            hparams.latent_rnn_warmup_steps / 2)
        res_p *= hparams.latent_use_max_probability
        res_rand = tf.random_uniform([layer_shape[0]])
        res = tf.where(res_rand < res_p, res, layer)
        return res, pred_loss
    def inject_latent(self, layer, inputs, target):
        """Inject a deterministic latent based on the target frame.

        When not training, the bits are either drawn uniformly from {-1, 1}
        (``hparams.full_latent_tower``) or sampled from the LSTM prior
        conditioned on ``layer``, and the prediction loss is 0.0.

        During training, ``inputs + [target]`` are concatenated, embedded
        densely and given a timing signal. The result is then reduced by
        strided convolutions (``full_latent_tower``) or by
        ``double_discriminator``, and discretized with
        ``tanh_discrete_bottleneck``. Without the full tower, an LSTM prior
        is trained to predict the clean bits.

        Args:
          layer: tensor to inject the latent into. Its last dimension sets
            the width of the dense projection of the bits. (Presumably
            [batch, height, width, channels] -- TODO confirm.)
          inputs: list of input-frame tensors (list-concatenated with
            ``[target]``).
          target: target-frame tensor.

        Returns:
          A pair ``(output, pred_loss)``. ``pred_loss`` is the LSTM prior's
          loss when that prior is trained, and 0.0 otherwise.
        """
        hparams = self.hparams
        layer_shape = common_layers.shape_list(layer)
        final_filters = layer_shape[-1]
        filters = hparams.hidden_size
        kernel = (4, 4)

        def add_bits(layer, bits):
            """Project `bits` to `final_filters` channels and merge into `layer`."""
            z_mul = tfl.dense(bits, final_filters, name="unbottleneck_mul")
            if not hparams.complex_addn:
                return layer + z_mul
            # Gated merge: scale by sigmoid(z_mul), then add a second projection.
            layer *= tf.nn.sigmoid(z_mul)
            z_add = tfl.dense(bits, final_filters, name="unbottleneck_add")
            layer += z_add
            return layer

        if not self.is_training:
            if hparams.full_latent_tower:
                # Each bit is +1 or -1 with equal probability.
                rand = tf.random_uniform(layer_shape[:-1] +
                                         [hparams.bottleneck_bits])
                bits = 2.0 * tf.to_float(tf.less(0.5, rand)) - 1.0
            else:
                bits, _ = discretization.predict_bits_with_lstm(
                    layer,
                    hparams.latent_predictor_state_size,
                    hparams.bottleneck_bits,
                    temperature=hparams.latent_predictor_temperature)
                # Insert axes 1 and 2, presumably so the projected bits
                # broadcast over layer's spatial dims -- TODO confirm.
                bits = tf.expand_dims(tf.expand_dims(bits, axis=1), axis=2)
            return add_bits(layer, bits), 0.0

        # Embed.
        frames = tf.concat(inputs + [target], axis=-1)
        x = tfl.dense(
            frames,
            filters,
            name="latent_embed",
            bias_initializer=tf.random_normal_initializer(stddev=0.01))
        x = common_attention.add_timing_signal_nd(x)

        if hparams.full_latent_tower:
            # Stride-2 conv downsampling; filters double for the first
            # `filter_double_steps` steps.
            for i in range(hparams.num_compress_steps):
                with tf.variable_scope("latent_downstride%d" % i):
                    x = common_layers.make_even_size(x)
                    if i < hparams.filter_double_steps:
                        filters *= 2
                    x = common_attention.add_timing_signal_nd(x)
                    x = tfl.conv2d(x,
                                   filters,
                                   kernel,
                                   activation=common_layers.belu,
                                   strides=(2, 2),
                                   padding="SAME")
                    x = common_layers.layer_norm(x)
        else:
            x = common_layers.double_discriminator(x)
            x = tf.expand_dims(tf.expand_dims(x, axis=1), axis=1)

        bits, bits_clean = discretization.tanh_discrete_bottleneck(
            x, hparams.bottleneck_bits, hparams.bottleneck_noise,
            hparams.discretize_warmup_steps, hparams.mode)

        # Only the learned-prior branch produces a prediction loss. Default it
        # so the full-latent-tower branch does not return an unbound name
        # (previously an UnboundLocalError).
        pred_loss = 0.0
        if not hparams.full_latent_tower:
            _, pred_loss = discretization.predict_bits_with_lstm(
                layer,
                hparams.latent_predictor_state_size,
                hparams.bottleneck_bits,
                target_bits=bits_clean)

        return add_bits(layer, bits), pred_loss
  def inject_latent(self, layer, inputs, target, action):
    """Inject a deterministic latent based on the target frame.

    When not training, the bits are either drawn uniformly from {-1, 1}
    (``hparams.full_latent_tower``) or sampled from the LSTM prior
    conditioned on ``layer``, and the prediction loss is 0.0.

    During training, ``inputs + [target]`` are concatenated, embedded
    densely, given a timing signal and (optionally) the action. The result
    is then reduced by strided convolutions (``full_latent_tower``) or by
    ``double_discriminator``, and discretized with
    ``tanh_discrete_bottleneck``. Without the full tower, an LSTM prior is
    trained on the clean bits, and its samples may replace some posterior
    bits (``latent_rnn_max_sampling``). Finally, with a warmup-dependent
    probability, the latent is skipped per example and ``layer`` is
    returned unchanged for that example.

    Args:
      layer: tensor to inject the latent into. Its last dimension sets the
        width of the dense projection of the bits. (Presumably
        [batch, height, width, channels] -- TODO confirm.)
      inputs: list of input-frame tensors (list-concatenated with
        ``[target]``).
      target: target-frame tensor.
      action: action tensor, or None to skip action injection.

    Returns:
      A pair ``(output, pred_loss)``. ``pred_loss`` is the LSTM prior's loss
      when that prior is trained, and 0.0 otherwise.
    """
    hparams = self.hparams
    layer_shape = common_layers.shape_list(layer)
    final_filters = layer_shape[-1]
    filters = hparams.hidden_size
    kernel = (4, 4)

    def add_bits(layer, bits):
      """Project `bits` to `final_filters` channels and merge into `layer`."""
      z_mul = tfl.dense(bits, final_filters, name="unbottleneck_mul")
      if not hparams.complex_addn:
        return layer + z_mul
      # Gated merge: scale by sigmoid(z_mul), then add a second projection.
      layer *= tf.nn.sigmoid(z_mul)
      z_add = tfl.dense(bits, final_filters, name="unbottleneck_add")
      layer += z_add
      return layer

    if not self.is_training:
      if hparams.full_latent_tower:
        # Each bit is +1 or -1 with equal probability.
        rand = tf.random_uniform(layer_shape[:-1] + [hparams.bottleneck_bits])
        bits = 2.0 * tf.to_float(tf.less(0.5, rand)) - 1.0
      else:
        bits, _ = discretization.predict_bits_with_lstm(
            layer, hparams.latent_predictor_state_size, hparams.bottleneck_bits,
            temperature=hparams.latent_predictor_temperature)
        # Insert axes 1 and 2, presumably so the projected bits broadcast
        # over layer's spatial dims -- TODO confirm.
        bits = tf.expand_dims(tf.expand_dims(bits, axis=1), axis=2)
      return add_bits(layer, bits), 0.0

    # Embed.
    frames = tf.concat(inputs + [target], axis=-1)
    x = tfl.dense(
        frames, filters, name="latent_embed",
        bias_initializer=tf.random_normal_initializer(stddev=0.01))
    x = common_attention.add_timing_signal_nd(x)

    # Add embedded action if present.
    if action is not None:
      x = common_video.inject_additional_input(
          x, action, "action_enc_latent", hparams.action_injection)

    if hparams.full_latent_tower:
      # Stride-2 conv downsampling; filters double for the first
      # `filter_double_steps` steps.
      for i in range(hparams.num_compress_steps):
        with tf.variable_scope("latent_downstride%d" % i):
          x = common_layers.make_even_size(x)
          if i < hparams.filter_double_steps:
            filters *= 2
          x = common_attention.add_timing_signal_nd(x)
          x = tfl.conv2d(x, filters, kernel,
                         activation=common_layers.belu,
                         strides=(2, 2), padding="SAME")
          x = common_layers.layer_norm(x)
    else:
      x = common_layers.double_discriminator(x)
      x = tf.expand_dims(tf.expand_dims(x, axis=1), axis=1)

    bits, bits_clean = discretization.tanh_discrete_bottleneck(
        x, hparams.bottleneck_bits, hparams.bottleneck_noise,
        hparams.discretize_warmup_steps, hparams.mode)

    # Only the learned-prior branch below produces a prediction loss. Default
    # it so the full-latent-tower branch does not reach the return with
    # `pred_loss` unbound (previously an UnboundLocalError).
    pred_loss = 0.0
    if not hparams.full_latent_tower:
      _, pred_loss = discretization.predict_bits_with_lstm(
          layer, hparams.latent_predictor_state_size, hparams.bottleneck_bits,
          target_bits=bits_clean)
      # Mix bits from latent with predicted bits on forward pass as a noise.
      if hparams.latent_rnn_max_sampling > 0.0:
        # Reuse the prior LSTM's variables created just above.
        with tf.variable_scope(tf.get_variable_scope(), reuse=True):
          bits_pred, _ = discretization.predict_bits_with_lstm(
              layer, hparams.latent_predictor_state_size,
              hparams.bottleneck_bits,
              temperature=hparams.latent_predictor_temperature)
          bits_pred = tf.expand_dims(tf.expand_dims(bits_pred, axis=1), axis=2)
        # Be bits_pred on the forward pass but bits on the backward one.
        bits_pred = bits_clean + tf.stop_gradient(bits_pred - bits_clean)
        # Select which bits to take from pred sampling with bit_p probability.
        which_bit = tf.random_uniform(common_layers.shape_list(bits))
        bit_p = common_layers.inverse_lin_decay(hparams.latent_rnn_warmup_steps)
        bit_p *= hparams.latent_rnn_max_sampling
        bits = tf.where(which_bit < bit_p, bits_pred, bits)

    res = add_bits(layer, bits)
    # During training, sometimes skip the latent to help action-conditioning.
    res_p = common_layers.inverse_lin_decay(hparams.latent_rnn_warmup_steps / 2)
    res_p *= hparams.latent_use_max_probability
    res_rand = tf.random_uniform([layer_shape[0]])
    res = tf.where(res_rand < res_p, res, layer)
    return res, pred_loss