예제 #1
0
    def inject_latent(self, layer, features, filters):
        """Add a projected binary latent code to `layer`.

        In PREDICT mode every bit of the code is drawn uniformly at random and
        `features` is not consulted. In any other mode the code is computed
        from the target frame concatenated with the inputs.

        Args:
          layer: tensor the latent is added to; its last dimension sets the
            width of the "unbottleneck" projection.
          features: mapping read for "cur_target_frame" and "inputs" (only
            outside PREDICT mode).
          filters: ignored; the tower width comes from `hparams.hidden_size`.

        Returns:
          A pair `(layer + projected_code, 0.0)`.
        """
        del filters
        hp = self.hparams
        out_channels = common_layers.shape_list(layer)[-1]
        width = hp.hidden_size
        stride_kernel = (4, 4)

        if hp.mode == tf.estimator.ModeKeys.PREDICT:
            dims = common_layers.shape_list(layer)
            if hp.full_latent_tower:
                code_shape = dims[:-1] + [hp.bottleneck_bits]
            else:
                code_shape = dims[:-3] + [1, 1, hp.bottleneck_bits]
            uniform = tf.random_uniform(code_shape)
            # Threshold at 0.5 and rescale {0, 1} to {-1, +1}.
            code = 2.0 * tf.to_float(tf.less(0.5, uniform)) - 1.0
            projected = tf.layers.dense(code, out_channels, name="unbottleneck")
            return layer + projected, 0.0

        # Embed the target frame together with the inputs.
        stacked = tf.concat(
            [features["cur_target_frame"], features["inputs"]], axis=-1)
        small_bias = tf.random_normal_initializer(stddev=0.01)
        net = tf.layers.dense(
            stacked, width, name="latent_embed", bias_initializer=small_bias)
        net = common_attention.add_timing_signal_nd(net)

        if not hp.full_latent_tower:
            net = common_layers.double_discriminator(net)
            net = tf.expand_dims(tf.expand_dims(net, axis=1), axis=1)
        else:
            # Strided conv tower; the width doubles for the first
            # `filter_double_steps` steps.
            for step in range(hp.num_compress_steps):
                with tf.variable_scope("latent_downstride%d" % step):
                    net = common_layers.make_even_size(net)
                    if step < hp.filter_double_steps:
                        width *= 2
                    net = common_attention.add_timing_signal_nd(net)
                    net = tf.layers.conv2d(
                        net,
                        width,
                        stride_kernel,
                        activation=common_layers.belu,
                        strides=(2, 2),
                        padding="SAME")
                    net = common_layers.layer_norm(net)

        logits = tf.layers.dense(net, hp.bottleneck_bits, name="bottleneck")
        soft = tf.tanh(logits)
        # Forward value is the +/-1 sign of `soft`; the stop_gradient leaves
        # `soft` as the only path gradients can take.
        hard = 2.0 * tf.to_float(tf.less(0.0, soft)) - 1.0
        code = soft + tf.stop_gradient(hard - soft)
        if hp.mode == tf.estimator.ModeKeys.TRAIN:
            # Multiply by -1 wherever the uniform draw does not exceed
            # `bottleneck_noise`, and by +1 elsewhere.
            uniform = tf.random_uniform(common_layers.shape_list(soft))
            flips = 2.0 * tf.to_float(
                tf.less(hp.bottleneck_noise, uniform)) - 1.0
            code = code * flips

        projected = tf.layers.dense(code, out_channels, name="unbottleneck")
        return layer + projected, 0.0
예제 #2
0
 def discriminate(x):
   """Apply the discriminator selected by `hparams.discriminator` to `x`.

   Recognized selections are "default", "patched", "single" and "double";
   anything else raises an Exception naming the unknown value.
   """
   kind = hparams.discriminator
   if kind == "default":
     return common_layers.deep_discriminator(
         x, hparams.discriminator_batchnorm, is_training)
   if kind == "patched":
     return common_layers.patch_discriminator(x)
   if kind in ("single", "double"):
     # Both variants take the same size/kernel/strides/pure_mean arguments.
     if kind == "single":
       build = common_layers.single_discriminator
     else:
       build = common_layers.double_discriminator
     return build(
         x,
         hparams.discriminator_size,
         hparams.discriminator_kernel_size,
         hparams.discriminator_strides,
         pure_mean=hparams.discriminator_pure_mean)
   raise Exception("Unknown discriminator %s" % kind)
예제 #3
0
 def discriminate(x):
   """Run the discriminator named by `hparams.discriminator` on `x`.

   Supported names: "default", "patched", "single", "double". An unknown
   name raises an Exception that includes the offending value.
   """
   choice = hparams.discriminator
   if choice == "default":
     return common_layers.deep_discriminator(
         x, hparams.discriminator_batchnorm, is_training)
   if choice == "patched":
     return common_layers.patch_discriminator(x)
   if choice == "single":
     return common_layers.single_discriminator(
         x,
         hparams.discriminator_size,
         hparams.discriminator_kernel_size,
         hparams.discriminator_strides,
         pure_mean=hparams.discriminator_pure_mean)
   if choice == "double":
     return common_layers.double_discriminator(
         x,
         hparams.discriminator_size,
         hparams.discriminator_kernel_size,
         hparams.discriminator_strides,
         pure_mean=hparams.discriminator_pure_mean)
   raise Exception("Unknown discriminator %s" % choice)
예제 #4
0
    def inject_latent(self, layer, inputs, target, action):
        """Inject a deterministic latent based on the target frame.

        Outside training the latent bits are either drawn uniformly at random
        (when `hparams.full_latent_tower` is set) or produced by
        `discretization.predict_bits_with_lstm`. During training they come
        from a discrete bottleneck over the embedded `inputs` + `target`
        frames (plus `action`, when given).

        Args:
          layer: tensor the latent is injected into.
          inputs: list of tensors, concatenated with `target` on the last axis.
          target: tensor appended to `inputs` before embedding.
          action: optional extra input mixed into the embedding; skipped when
            None.

        Returns:
          A pair `(layer_with_latent, pred_loss)`. `pred_loss` is 0.0 outside
          training and when `hparams.full_latent_tower` is set, because no
          bit predictor is run in those cases.
        """
        hparams = self.hparams
        final_filters = common_layers.shape_list(layer)[-1]
        filters = hparams.hidden_size
        kernel = (4, 4)
        layer_shape = common_layers.shape_list(layer)
        activation_fn = common_layers.belu
        if hparams.activation_fn == "relu":
            activation_fn = tf.nn.relu

        def add_bits(layer, bits):
            """Project `bits` to `final_filters` channels and merge into `layer`."""
            z_mul = tfl.dense(bits, final_filters, name="unbottleneck_mul")
            if not hparams.complex_addn:
                return layer + z_mul
            # Gated variant: scale by sigmoid(z_mul), then add a second
            # projection of the same bits.
            layer *= tf.nn.sigmoid(z_mul)
            z_add = tfl.dense(bits, final_filters, name="unbottleneck_add")
            layer += z_add
            return layer

        if not self.is_training:
            if hparams.full_latent_tower:
                rand = tf.random_uniform(layer_shape[:-1] +
                                         [hparams.bottleneck_bits])
                # Threshold at 0.5 and rescale {0, 1} to {-1, +1}.
                bits = 2.0 * tf.to_float(tf.less(0.5, rand)) - 1.0
            else:
                bits, _ = discretization.predict_bits_with_lstm(
                    layer,
                    hparams.latent_predictor_state_size,
                    hparams.bottleneck_bits,
                    temperature=hparams.latent_predictor_temperature)
                bits = tf.expand_dims(tf.expand_dims(bits, axis=1), axis=2)
            return add_bits(layer, bits), 0.0

        # Embed.
        frames = tf.concat(inputs + [target], axis=-1)
        x = tfl.dense(
            frames,
            filters,
            name="latent_embed",
            bias_initializer=tf.random_normal_initializer(stddev=0.01))
        x = common_attention.add_timing_signal_nd(x)

        # Add embedded action if present.
        if action is not None:
            x = common_video.inject_additional_input(x, action,
                                                     "action_enc_latent",
                                                     hparams.action_injection)

        if hparams.full_latent_tower:
            # Strided conv tower; the filter count doubles for the first
            # `filter_double_steps` steps.
            for i in range(hparams.num_compress_steps):
                with tf.variable_scope("latent_downstride%d" % i):
                    x = common_layers.make_even_size(x)
                    if i < hparams.filter_double_steps:
                        filters *= 2
                    x = common_attention.add_timing_signal_nd(x)
                    x = tfl.conv2d(x,
                                   filters,
                                   kernel,
                                   activation=activation_fn,
                                   strides=(2, 2),
                                   padding="SAME")
                    x = common_layers.layer_norm(x)
        else:
            x = common_layers.double_discriminator(x)
            x = tf.expand_dims(tf.expand_dims(x, axis=1), axis=1)

        bits, bits_clean = discretization.tanh_discrete_bottleneck(
            x, hparams.bottleneck_bits, hparams.bottleneck_noise,
            hparams.discretize_warmup_steps, hparams.mode)
        # Default for the full-latent-tower case, where no bit predictor is
        # trained. Without it the final return raised UnboundLocalError.
        pred_loss = 0.0
        if not hparams.full_latent_tower:
            _, pred_loss = discretization.predict_bits_with_lstm(
                layer,
                hparams.latent_predictor_state_size,
                hparams.bottleneck_bits,
                target_bits=bits_clean)
            # Mix bits from latent with predicted bits on forward pass as a noise.
            if hparams.latent_rnn_max_sampling > 0.0:
                # reuse=True: sample from the predictor variables created by
                # the call above instead of creating new ones.
                with tf.variable_scope(tf.get_variable_scope(), reuse=True):
                    bits_pred, _ = discretization.predict_bits_with_lstm(
                        layer,
                        hparams.latent_predictor_state_size,
                        hparams.bottleneck_bits,
                        temperature=hparams.latent_predictor_temperature)
                    bits_pred = tf.expand_dims(tf.expand_dims(bits_pred,
                                                              axis=1),
                                               axis=2)
                # Be bits_pred on the forward pass but bits on the backward one.
                bits_pred = bits_clean + tf.stop_gradient(bits_pred -
                                                          bits_clean)
                # Select which bits to take from pred sampling with bit_p probability.
                which_bit = tf.random_uniform(common_layers.shape_list(bits))
                bit_p = common_layers.inverse_lin_decay(
                    hparams.latent_rnn_warmup_steps)
                bit_p *= hparams.latent_rnn_max_sampling
                bits = tf.where(which_bit < bit_p, bits_pred, bits)

        res = add_bits(layer, bits)
        # During training, sometimes skip the latent to help action-conditioning.
        res_p = common_layers.inverse_lin_decay(
            hparams.latent_rnn_warmup_steps / 2)
        res_p *= hparams.latent_use_max_probability
        # One draw per entry of layer's first dimension: keep the injected
        # result where the draw is below res_p, else fall back to `layer`.
        res_rand = tf.random_uniform([layer_shape[0]])
        res = tf.where(res_rand < res_p, res, layer)
        return res, pred_loss
예제 #5
0
    def inject_latent(self, layer, features, filters):
        """Inject a deterministic latent based on the target frame.

        When predicting, the latent bits are either uniform random
        (`hparams.full_latent_tower`) or sampled 8 bits at a time from an
        LSTM conditioned on `layer`. Otherwise the bits come from a tanh
        bottleneck over the embedded target frame and inputs, and (when not
        using the full latent tower) the LSTM is trained to predict them.

        Args:
          layer: tensor the latent is merged into.
          features: mapping read for "cur_target_frame" and "inputs" (only
            when not predicting).
          filters: ignored; rebound from `hparams.hidden_size` below.

        Returns:
          A pair `(layer_with_latent, pred_loss)`; `pred_loss` is 0.0 when
          predicting or when `hparams.full_latent_tower` is set.
        """
        # The caller-supplied value is discarded; `filters` is rebound below.
        del filters
        hparams = self.hparams
        final_filters = common_layers.shape_list(layer)[-1]
        filters = hparams.hidden_size
        kernel = (4, 4)
        layer_shape = common_layers.shape_list(layer)
        batch_size = layer_shape[0]
        state_size = hparams.latent_predictor_state_size
        # Predictor pieces are built once and used by whichever branch runs
        # (sampling when predicting, loss computation otherwise).
        lstm_cell = tf.contrib.rnn.LSTMCell(state_size)
        # 256 logits: one per possible value of an 8-bit group.
        discrete_predict = tf.layers.Dense(256, name="discrete_predict")
        discrete_embed = tf.layers.Dense(state_size, name="discrete_embed")

        def add_d(layer, d):
            """Project `d` to `final_filters` channels and merge into `layer`."""
            z_mul = tf.layers.dense(d, final_filters, name="unbottleneck_mul")
            if not hparams.complex_addn:
                return layer + z_mul
            # Gated variant: scale by sigmoid(z_mul), then add a second
            # projection of the same code.
            layer *= tf.nn.sigmoid(z_mul)
            z_add = tf.layers.dense(d, final_filters, name="unbottleneck_add")
            layer += z_add
            return layer

        if self.is_predicting:
            if hparams.full_latent_tower:
                rand = tf.random_uniform(layer_shape[:-1] +
                                         [hparams.bottleneck_bits])
            else:
                # NOTE(review): `prod` is not defined in this block;
                # presumably a product-of-list helper imported elsewhere.
                layer_pred = tf.reshape(
                    layer, [batch_size, prod(layer_shape[1:])])
                # First LSTM input and both halves of the initial state are
                # separate projections of the flattened layer.
                prediction = tf.layers.dense(layer_pred,
                                             state_size,
                                             name="istate")
                c_state = tf.layers.dense(layer_pred,
                                          state_size,
                                          name="cstate")
                m_state = tf.layers.dense(layer_pred,
                                          state_size,
                                          name="mstate")
                state = (c_state, m_state)
                outputs = []
                # One step per 8-bit group; each sampled value is embedded
                # and fed back as the next step's input.
                for i in range(hparams.bottleneck_bits // 8):
                    output, state = lstm_cell(prediction, state)
                    discrete_logits = discrete_predict(output)
                    discrete_samples = common_layers.sample_with_temperature(
                        discrete_logits, hparams.latent_predictor_temperature)
                    outputs.append(tf.expand_dims(discrete_samples, axis=1))
                    prediction = discrete_embed(
                        tf.one_hot(discrete_samples, 256))
                outputs = tf.concat(outputs, axis=1)
                # Expand each sampled integer back into 8 bits.
                outputs = discretization.int_to_bit(outputs, 8)
                rand = tf.reshape(outputs,
                                  [batch_size, 1, 1, hparams.bottleneck_bits])
            # Threshold at 0.5 and rescale {0, 1} to {-1, +1}.
            d = 2.0 * tf.to_float(tf.less(0.5, rand)) - 1.0
            return add_d(layer, d), 0.0

        # Embed.
        frames = tf.concat([features["cur_target_frame"], features["inputs"]],
                           axis=-1)
        x = tf.layers.dense(
            frames,
            filters,
            name="latent_embed",
            bias_initializer=tf.random_normal_initializer(stddev=0.01))
        x = common_attention.add_timing_signal_nd(x)

        if hparams.full_latent_tower:
            # Strided conv tower; the filter count doubles for the first
            # `filter_double_steps` steps.
            for i in range(hparams.num_compress_steps):
                with tf.variable_scope("latent_downstride%d" % i):
                    x = common_layers.make_even_size(x)
                    if i < hparams.filter_double_steps:
                        filters *= 2
                    x = common_attention.add_timing_signal_nd(x)
                    x = tf.layers.conv2d(x,
                                         filters,
                                         kernel,
                                         activation=common_layers.belu,
                                         strides=(2, 2),
                                         padding="SAME")
                    x = common_layers.layer_norm(x)
        else:
            x = common_layers.double_discriminator(x)
            x = tf.expand_dims(tf.expand_dims(x, axis=1), axis=1)
        x = tf.layers.dense(x, hparams.bottleneck_bits, name="bottleneck")
        x0 = tf.tanh(x)
        # Forward value is the +/-1 sign of x0; the stop_gradient leaves x0
        # as the only path gradients can take.
        d = x0 + tf.stop_gradient(2.0 * tf.to_float(tf.less(0.0, x0)) - 1.0 -
                                  x0)
        pred_loss = 0.0
        if not hparams.full_latent_tower:
            # Targets for the predictor: map -1/+1 to 0/1 and pack every 8
            # bits into one integer. These use the noise-free `d` computed
            # above, not the noisy one rebuilt in the TRAIN block below.
            d_pred = tf.reshape(tf.maximum(tf.stop_gradient(d), 0),
                                [batch_size, hparams.bottleneck_bits // 8, 8])
            d_int = discretization.bit_to_int(d_pred, 8)
            tf.summary.histogram("d_int", tf.reshape(d_int, [-1]))
            d_hot = tf.one_hot(d_int, 256, axis=-1)
            d_pred = discrete_embed(d_hot)
            # NOTE(review): `prod` is not defined in this block; presumably
            # a product-of-list helper imported elsewhere.
            layer_pred = tf.reshape(layer, [batch_size, prod(layer_shape[1:])])
            prediction0 = tf.layers.dense(layer_pred,
                                          state_size,
                                          name="istate")
            c_state = tf.layers.dense(layer_pred, state_size, name="cstate")
            m_state = tf.layers.dense(layer_pred, state_size, name="mstate")
            # Input sequence is [istate projection, embedded true groups...],
            # so step i sees the true group i-1. The loop below consumes only
            # the first bottleneck_bits // 8 entries, leaving the last
            # embedded group unused.
            pred = tf.concat([tf.expand_dims(prediction0, axis=1), d_pred],
                             axis=1)
            state = (c_state, m_state)
            outputs = []
            for i in range(hparams.bottleneck_bits // 8):
                output, state = lstm_cell(pred[:, i, :], state)
                outputs.append(tf.expand_dims(output, axis=1))
            outputs = tf.concat(outputs, axis=1)
            d_int_pred = discrete_predict(outputs)
            pred_loss = tf.losses.sparse_softmax_cross_entropy(
                logits=d_int_pred, labels=d_int)
            pred_loss = tf.reduce_mean(pred_loss)
        if hparams.mode == tf.estimator.ModeKeys.TRAIN:
            # Rebuild the code from noised pre-activations: add truncated
            # normal noise before the tanh, then multiply by -1 wherever the
            # uniform draw does not exceed `bottleneck_noise` (+1 elsewhere).
            x += tf.truncated_normal(common_layers.shape_list(x),
                                     mean=0.0,
                                     stddev=0.2)
            x = tf.tanh(x)
            noise = tf.random_uniform(common_layers.shape_list(x))
            noise = 2.0 * tf.to_float(tf.less(hparams.bottleneck_noise,
                                              noise)) - 1.0
            x *= noise
            d = x + tf.stop_gradient(2.0 * tf.to_float(tf.less(0.0, x)) - 1.0 -
                                     x)
            # One draw per batch entry: use the discretized `d` where the
            # draw is below p, otherwise the continuous `x`.
            # Presumably p ramps up over `discrete_warmup_steps` -- confirm
            # against common_layers.inverse_lin_decay.
            p = common_layers.inverse_lin_decay(hparams.discrete_warmup_steps)
            d = tf.where(tf.less(tf.random_uniform([batch_size]), p), d, x)
        return add_d(layer, d), pred_loss
    def inject_latent(self, layer, inputs, target):
        """Inject a deterministic latent based on the target frame.

        Outside training the latent bits are either drawn uniformly at random
        (when `hparams.full_latent_tower` is set) or produced by
        `discretization.predict_bits_with_lstm`. During training they come
        from a discrete bottleneck over the embedded `inputs` + `target`
        frames.

        Args:
          layer: tensor the latent is injected into.
          inputs: list of tensors, concatenated with `target` on the last axis.
          target: tensor appended to `inputs` before embedding.

        Returns:
          A pair `(layer_with_latent, pred_loss)`. `pred_loss` is 0.0 outside
          training and when `hparams.full_latent_tower` is set, because no
          bit predictor is run in those cases.
        """
        hparams = self.hparams
        final_filters = common_layers.shape_list(layer)[-1]
        filters = hparams.hidden_size
        kernel = (4, 4)
        layer_shape = common_layers.shape_list(layer)

        def add_bits(layer, bits):
            """Project `bits` to `final_filters` channels and merge into `layer`."""
            z_mul = tfl.dense(bits, final_filters, name="unbottleneck_mul")
            if not hparams.complex_addn:
                return layer + z_mul
            # Gated variant: scale by sigmoid(z_mul), then add a second
            # projection of the same bits.
            layer *= tf.nn.sigmoid(z_mul)
            z_add = tfl.dense(bits, final_filters, name="unbottleneck_add")
            layer += z_add
            return layer

        if not self.is_training:
            if hparams.full_latent_tower:
                rand = tf.random_uniform(layer_shape[:-1] +
                                         [hparams.bottleneck_bits])
                # Threshold at 0.5 and rescale {0, 1} to {-1, +1}.
                bits = 2.0 * tf.to_float(tf.less(0.5, rand)) - 1.0
            else:
                bits, _ = discretization.predict_bits_with_lstm(
                    layer,
                    hparams.latent_predictor_state_size,
                    hparams.bottleneck_bits,
                    temperature=hparams.latent_predictor_temperature)
                bits = tf.expand_dims(tf.expand_dims(bits, axis=1), axis=2)
            return add_bits(layer, bits), 0.0

        # Embed.
        frames = tf.concat(inputs + [target], axis=-1)
        x = tfl.dense(
            frames,
            filters,
            name="latent_embed",
            bias_initializer=tf.random_normal_initializer(stddev=0.01))
        x = common_attention.add_timing_signal_nd(x)

        if hparams.full_latent_tower:
            # Strided conv tower; the filter count doubles for the first
            # `filter_double_steps` steps.
            for i in range(hparams.num_compress_steps):
                with tf.variable_scope("latent_downstride%d" % i):
                    x = common_layers.make_even_size(x)
                    if i < hparams.filter_double_steps:
                        filters *= 2
                    x = common_attention.add_timing_signal_nd(x)
                    x = tfl.conv2d(x,
                                   filters,
                                   kernel,
                                   activation=common_layers.belu,
                                   strides=(2, 2),
                                   padding="SAME")
                    x = common_layers.layer_norm(x)
        else:
            x = common_layers.double_discriminator(x)
            x = tf.expand_dims(tf.expand_dims(x, axis=1), axis=1)

        bits, bits_clean = discretization.tanh_discrete_bottleneck(
            x, hparams.bottleneck_bits, hparams.bottleneck_noise,
            hparams.discretize_warmup_steps, hparams.mode)
        # Default for the full-latent-tower case, where no bit predictor is
        # trained. Without it the final return raised UnboundLocalError.
        pred_loss = 0.0
        if not hparams.full_latent_tower:
            _, pred_loss = discretization.predict_bits_with_lstm(
                layer,
                hparams.latent_predictor_state_size,
                hparams.bottleneck_bits,
                target_bits=bits_clean)

        return add_bits(layer, bits), pred_loss
예제 #7
0
  def inject_latent(self, layer, inputs, target, action):
    """Inject a deterministic latent based on the target frame.

    Outside training the latent bits are either drawn uniformly at random
    (when `hparams.full_latent_tower` is set) or produced by
    `discretization.predict_bits_with_lstm`. During training they come from a
    discrete bottleneck over the embedded `inputs` + `target` frames (plus
    `action`, when given).

    Args:
      layer: tensor the latent is injected into.
      inputs: list of tensors, concatenated with `target` on the last axis.
      target: tensor appended to `inputs` before embedding.
      action: optional extra input mixed into the embedding; skipped when None.

    Returns:
      A pair `(layer_with_latent, pred_loss)`. `pred_loss` is 0.0 outside
      training and when `hparams.full_latent_tower` is set, because no bit
      predictor is run in those cases.
    """
    hparams = self.hparams
    final_filters = common_layers.shape_list(layer)[-1]
    filters = hparams.hidden_size
    kernel = (4, 4)
    layer_shape = common_layers.shape_list(layer)

    def add_bits(layer, bits):
      """Project `bits` to `final_filters` channels and merge into `layer`."""
      z_mul = tfl.dense(bits, final_filters, name="unbottleneck_mul")
      if not hparams.complex_addn:
        return layer + z_mul
      # Gated variant: scale by sigmoid(z_mul), then add a second projection
      # of the same bits.
      layer *= tf.nn.sigmoid(z_mul)
      z_add = tfl.dense(bits, final_filters, name="unbottleneck_add")
      layer += z_add
      return layer

    if not self.is_training:
      if hparams.full_latent_tower:
        rand = tf.random_uniform(layer_shape[:-1] + [hparams.bottleneck_bits])
        # Threshold at 0.5 and rescale {0, 1} to {-1, +1}.
        bits = 2.0 * tf.to_float(tf.less(0.5, rand)) - 1.0
      else:
        bits, _ = discretization.predict_bits_with_lstm(
            layer, hparams.latent_predictor_state_size, hparams.bottleneck_bits,
            temperature=hparams.latent_predictor_temperature)
        bits = tf.expand_dims(tf.expand_dims(bits, axis=1), axis=2)
      return add_bits(layer, bits), 0.0

    # Embed.
    frames = tf.concat(inputs + [target], axis=-1)
    x = tfl.dense(
        frames, filters, name="latent_embed",
        bias_initializer=tf.random_normal_initializer(stddev=0.01))
    x = common_attention.add_timing_signal_nd(x)

    # Add embedded action if present.
    if action is not None:
      x = common_video.inject_additional_input(
          x, action, "action_enc_latent", hparams.action_injection)

    if hparams.full_latent_tower:
      # Strided conv tower; the filter count doubles for the first
      # `filter_double_steps` steps.
      for i in range(hparams.num_compress_steps):
        with tf.variable_scope("latent_downstride%d" % i):
          x = common_layers.make_even_size(x)
          if i < hparams.filter_double_steps:
            filters *= 2
          x = common_attention.add_timing_signal_nd(x)
          x = tfl.conv2d(x, filters, kernel,
                         activation=common_layers.belu,
                         strides=(2, 2), padding="SAME")
          x = common_layers.layer_norm(x)
    else:
      x = common_layers.double_discriminator(x)
      x = tf.expand_dims(tf.expand_dims(x, axis=1), axis=1)

    bits, bits_clean = discretization.tanh_discrete_bottleneck(
        x, hparams.bottleneck_bits, hparams.bottleneck_noise,
        hparams.discretize_warmup_steps, hparams.mode)
    # Default for the full-latent-tower case, where no bit predictor is
    # trained. Without it the final return raised UnboundLocalError.
    pred_loss = 0.0
    if not hparams.full_latent_tower:
      _, pred_loss = discretization.predict_bits_with_lstm(
          layer, hparams.latent_predictor_state_size, hparams.bottleneck_bits,
          target_bits=bits_clean)
      # Mix bits from latent with predicted bits on forward pass as a noise.
      if hparams.latent_rnn_max_sampling > 0.0:
        # reuse=True: sample from the predictor variables created by the call
        # above instead of creating new ones.
        with tf.variable_scope(tf.get_variable_scope(), reuse=True):
          bits_pred, _ = discretization.predict_bits_with_lstm(
              layer, hparams.latent_predictor_state_size,
              hparams.bottleneck_bits,
              temperature=hparams.latent_predictor_temperature)
          bits_pred = tf.expand_dims(tf.expand_dims(bits_pred, axis=1), axis=2)
        # Be bits_pred on the forward pass but bits on the backward one.
        bits_pred = bits_clean + tf.stop_gradient(bits_pred - bits_clean)
        # Select which bits to take from pred sampling with bit_p probability.
        which_bit = tf.random_uniform(common_layers.shape_list(bits))
        bit_p = common_layers.inverse_lin_decay(hparams.latent_rnn_warmup_steps)
        bit_p *= hparams.latent_rnn_max_sampling
        bits = tf.where(which_bit < bit_p, bits_pred, bits)

    res = add_bits(layer, bits)
    # During training, sometimes skip the latent to help action-conditioning.
    res_p = common_layers.inverse_lin_decay(hparams.latent_rnn_warmup_steps / 2)
    res_p *= hparams.latent_use_max_probability
    # One draw per entry of layer's first dimension: keep the injected result
    # where the draw is below res_p, else fall back to `layer`.
    res_rand = tf.random_uniform([layer_shape[0]])
    res = tf.where(res_rand < res_p, res, layer)
    return res, pred_loss