Example #1
 def bottom(self, inputs):
     with tf.variable_scope(self.name):
         inputs_shape = common_layers.shape_list(inputs)
         if len(inputs_shape) != 5:
             raise ValueError(
                 "Assuming videos given as tensors in the format "
                 "[batch, time, height, width, channels] but got one "
                 "of shape: %s" % str(inputs_shape))
         if not tf.contrib.eager.in_eager_mode():
             if inputs.get_shape().as_list()[1] is None:
                 tf.summary.image("inputs_last_frame",
                                  tf.cast(inputs[:, -1, :, :, :], tf.uint8),
                                  max_outputs=1)
             else:
                 for k in range(inputs_shape[1]):
                     tf.summary.image("inputs_frame_%d" % k,
                                      tf.cast(inputs[:, k, :, :, :],
                                              tf.uint8),
                                      max_outputs=1)
         # Standardize frames.
         inputs = tf.reshape(inputs, [-1] + inputs_shape[2:])
         inputs = common_layers.standardize_images(inputs)
         inputs = tf.reshape(inputs, inputs_shape)
         # Concatenate the time dimension on channels for image models to work.
         transposed = tf.transpose(inputs, [0, 2, 3, 1, 4])
         return tf.reshape(transposed, [
             inputs_shape[0], inputs_shape[2], inputs_shape[3],
             inputs_shape[1] * inputs_shape[4]
         ])
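The examples on this page all feed raw pixel tensors through common_layers.standardize_images before any convolutions. As a rough orientation, the snippet below is a minimal NumPy sketch of the assumed semantics (per-image zero-mean, unit-variance normalization with a floor on the divisor, in the spirit of tf.image.per_image_standardization); the exact axes and floor used by tensor2tensor may differ, so treat it as an illustration rather than the library implementation.

import numpy as np

def standardize_images_sketch(x):
    """Rough sketch of per-image standardization (assumed semantics).

    x: float array of shape [batch, height, width, channels]. Each image is
    shifted to zero mean and scaled to unit variance, with the divisor
    floored at 1/sqrt(num_elements) so near-uniform images do not blow up.
    """
    x = x.astype(np.float64)
    mean = x.mean(axis=(1, 2, 3), keepdims=True)
    std = x.std(axis=(1, 2, 3), keepdims=True)
    num_elements = np.prod(x.shape[1:])
    return (x - mean) / np.maximum(std, 1.0 / np.sqrt(num_elements))

# The shape is preserved, which is what the unit test in Example #14 checks.
frames = np.random.rand(5, 7, 7, 3)
print(standardize_images_sketch(frames).shape)  # (5, 7, 7, 3)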
Example #2
  def get_sampled_frame(self, pred_frame):
    """Samples the frame based on modality.

      if the modality is L2/L1 then the next predicted frame is the
      next frame and there is no sampling but in case of Softmax loss
      the next actual frame should be sampled from predicted frame.

      This enables multi-frame target prediction with Softmax loss.

    Args:
      pred_frame: predicted frame.

    Returns:
      sampled frame.

    """
    # TODO(lukaszkaiser): the logic below heavily depends on the current
    # (a bit strange) video modalities - we should change that.

    if self.is_per_pixel_softmax:
      frame_shape = common_layers.shape_list(pred_frame)
      target_shape = frame_shape[:-1] + [self.hparams.problem.num_channels]
      sampled_frame = tf.reshape(pred_frame, target_shape + [256])
      sampled_frame = pixels_from_softmax(
          sampled_frame, temperature=self.hparams.pixel_sampling_temperature)
      # TODO(lukaszkaiser): this should be consistent with modality.bottom()
      sampled_frame = common_layers.standardize_images(sampled_frame)
    else:
      x = common_layers.convert_real_to_rgb(pred_frame)
      # Straight-through rounding: the forward value is tf.round(x) while the
      # gradient flows through x unchanged.
      x = x - tf.stop_gradient(x - tf.round(x))
      x = common_layers.convert_rgb_to_real(x)
      return x
    return sampled_frame
Example #4
    def get_sampled_frame(self, pred_frame):
        """Samples the frame based on modality.

      if the modality is L2/L1 then the next predicted frame is the
      next frame and there is no sampling but in case of Softmax loss
      the next actual frame should be sampled from predicted frame.

      This enables multi-frame target prediction with Softmax loss.

    Args:
      pred_frame: predicted frame.

    Returns:
      sampled frame.

    """
        if not self.is_per_pixel_softmax:
            return pred_frame
        frame_shape = common_layers.shape_list(pred_frame)
        target_shape = frame_shape[:-1] + [self.hparams.problem.num_channels]
        sampled_frame = tf.reshape(pred_frame, target_shape + [256])
        # TODO(lukaszkaiser): should this be argmax or real sampling.
        sampled_frame = tf.argmax(sampled_frame, axis=-1)
        sampled_frame = tf.to_float(sampled_frame)
        # TODO(lukaszkaiser): this should be consistent with modality.bottom()
        sampled_frame = common_layers.standardize_images(sampled_frame)
        return sampled_frame
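Examples #2 and #4 treat the per-pixel Softmax case differently: Example #2 samples pixel intensities from the softmax with a temperature (via pixels_from_softmax), while Example #4 simply takes the argmax. The NumPy sketch below illustrates the difference on dummy logits; the helper names and shapes are illustrative and are not tensor2tensor APIs.

import numpy as np

def greedy_pixels(logits):
    """Greedy decoding: pick the most likely intensity per sub-pixel."""
    return np.argmax(logits, axis=-1).astype(np.float32)

def sampled_pixels(logits, temperature=1.0, seed=0):
    """Temperature sampling from per-pixel categorical distributions."""
    rng = np.random.default_rng(seed)
    scaled = logits / max(temperature, 1e-6)
    scaled = scaled - scaled.max(axis=-1, keepdims=True)  # numerical stability
    probs = np.exp(scaled)
    probs = probs / probs.sum(axis=-1, keepdims=True)
    flat = probs.reshape(-1, probs.shape[-1])
    draws = np.array([rng.choice(flat.shape[-1], p=p) for p in flat])
    return draws.reshape(probs.shape[:-1]).astype(np.float32)

logits = np.random.randn(2, 4, 4, 3, 256)  # [batch, height, width, channels, 256]
print(greedy_pixels(logits).shape)          # (2, 4, 4, 3)
print(sampled_pixels(logits, 0.5).shape)    # (2, 4, 4, 3)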
Example #5
    def bottom(self, inputs):
        """Transform input from data space to model space.

    Perform the Xception "Entry flow", which consists of two convolutional
    filter upscalings followed by three residually connected separable
    convolution blocks.

    Args:
      inputs: A Tensor with shape [batch, ...]
    Returns:
      body_input: A Tensor with shape [batch, ?, ?, body_input_depth].
    """
        with tf.variable_scope(self.name):

            def xnet_resblock(x, filters, res_relu, name):
                with tf.variable_scope(name):
                    y = common_layers.separable_conv_block(
                        x,
                        filters, [((1, 1), (3, 3)), ((1, 1), (3, 3))],
                        first_relu=True,
                        padding="SAME",
                        force2d=True,
                        name="sep_conv_block")
                    y = common_layers.pool(y, (3, 3),
                                           "MAX",
                                           "SAME",
                                           strides=(2, 2))
                    return y + common_layers.conv_block(x,
                                                        filters, [((1, 1),
                                                                   (1, 1))],
                                                        padding="SAME",
                                                        strides=(2, 2),
                                                        first_relu=res_relu,
                                                        force2d=True,
                                                        name="res_conv0")

            inputs = common_layers.standardize_images(inputs)
            # TODO(lukaszkaiser): summaries here don't work in multi-problem case yet.
            # tf.summary.image("inputs", inputs, max_outputs=2)
            x = common_layers.conv_block(inputs,
                                         32, [((1, 1), (3, 3))],
                                         first_relu=False,
                                         padding="SAME",
                                         strides=(2, 2),
                                         force2d=True,
                                         name="conv0")
            x = common_layers.conv_block(x,
                                         64, [((1, 1), (3, 3))],
                                         padding="SAME",
                                         force2d=True,
                                         name="conv1")
            x = xnet_resblock(x, min(128, self._body_input_depth), True,
                              "block0")
            x = xnet_resblock(x, min(256, self._body_input_depth), False,
                              "block1")
            return xnet_resblock(x, self._body_input_depth, False, "block2")
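Note that the entry flow above reduces spatial resolution four times: conv0 uses stride 2 and each of the three residual blocks max-pools with stride 2, so with SAME padding the height and width shrink by a factor of 16 (rounded up at each stage). A quick sanity check of the implied output size, as a sketch:

import math

def xception_entry_output_hw(height, width, num_stride2_stages=4):
    """Spatial size after conv0 plus three stride-2 residual blocks."""
    for _ in range(num_stride2_stages):
        height = math.ceil(height / 2)
        width = math.ceil(width / 2)
    return height, width

print(xception_entry_output_hw(224, 224))  # (14, 14)
print(xception_entry_output_hw(105, 80))   # (7, 5)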
Example #6
  def body(self, features):
    def deconv2d(cur, i, kernel_size, output_filters, activation=tf.nn.relu):
      thicker = common_layers.conv(
          cur,
          output_filters * 4,
          kernel_size,
          padding="SAME",
          activation=activation,
          name="deconv2d" + str(i))
      return tf.depth_to_space(thicker, 2)

    cur_frame = common_layers.standardize_images(features["inputs_0"])
    prev_frame = common_layers.standardize_images(features["inputs_1"])

    frames = tf.concat([cur_frame, prev_frame], axis=3)
    frames = tf.reshape(frames, [-1, 210, 160, 6])

    h1 = tf.layers.conv2d(frames, filters=64, strides=2, kernel_size=(8, 8),
                          padding="SAME", activation=tf.nn.relu)
    h2 = tf.layers.conv2d(h1, filters=128, strides=2, kernel_size=(6, 6),
                          padding="SAME", activation=tf.nn.relu)
    h3 = tf.layers.conv2d(h2, filters=128, strides=2, kernel_size=(6, 6),
                          padding="SAME", activation=tf.nn.relu)
    h4 = tf.layers.conv2d(h3, filters=128, strides=2, kernel_size=(4, 4),
                          padding="SAME", activation=tf.nn.relu)
    h45 = tf.reshape(h4, [-1, 14 * 10 * 128])
    h5 = tf.layers.dense(h45, 2048, activation=tf.nn.relu)
    h6 = tf.layers.dense(h5, 2048, activation=tf.nn.relu)
    h7 = tf.layers.dense(h6, 14 * 10 * 128, activation=tf.nn.relu)
    h8 = tf.reshape(h7, [-1, 14, 10, 128])

    h9 = deconv2d(h8, 1, (4, 4), 128, activation=tf.nn.relu)
    h9 = h9[:, :27, :, :]
    h10 = deconv2d(h9, 2, (6, 6), 128, activation=tf.nn.relu)
    h10 = h10[:, :53, :, :]
    h11 = deconv2d(h10, 3, (6, 6), 128, activation=tf.nn.relu)
    h11 = h11[:, :105, :, :]
    h12 = deconv2d(h11, 4, (8, 8), 3 * 256, activation=tf.identity)

    reward = tf.layers.flatten(h12)

    return {"targets": h12, "reward": reward}
Example #7
 def bottom(self, inputs):
   with tf.variable_scope(self.name):
     inputs = common_layers.standardize_images(inputs)
     tf.summary.image("inputs", inputs, max_outputs=2)
     return common_layers.conv_block(
         inputs,
         self._body_input_depth, [((1, 1), (3, 3))],
         first_relu=False,
         padding="SAME",
         force2d=True,
         name="small_image_conv")
Example #8
 def bottom(self, inputs):
     with tf.variable_scope(self.name):
         common_layers.summarize_video(inputs, "inputs")
         inputs_shape = common_layers.shape_list(inputs)
         # Standardize frames.
         inputs = tf.reshape(inputs, [-1] + inputs_shape[2:])
         inputs = common_layers.standardize_images(inputs)
         inputs = tf.reshape(inputs, inputs_shape)
         # Concatenate the time dimension on channels for image models to work.
         transposed = tf.transpose(inputs, [0, 2, 3, 1, 4])
         return tf.reshape(transposed, [
             inputs_shape[0], inputs_shape[2], inputs_shape[3],
             inputs_shape[1] * inputs_shape[4]
         ])
Example #9
 def bottom(self, x):
   inputs = x
   with tf.variable_scope(self.name):
     common_layers.summarize_video(inputs, "inputs")
     inputs_shape = common_layers.shape_list(inputs)
     # Standardize frames.
     inputs = tf.reshape(inputs, [-1] + inputs_shape[2:])
     inputs = common_layers.standardize_images(inputs)
     inputs = tf.reshape(inputs, inputs_shape)
     # Concatenate the time dimension on channels for image models to work.
     transposed = tf.transpose(inputs, [0, 2, 3, 1, 4])
     return tf.reshape(transposed, [
         inputs_shape[0], inputs_shape[2], inputs_shape[3],
         inputs_shape[1] * inputs_shape[4]
     ])
Example #10
 def bottom(self, inputs):
     with tf.variable_scope(self.name):
         inputs = common_layers.standardize_images(inputs)
         # TODO(lukaszkaiser): summaries here don't work in multi-problem case yet.
         # tf.summary.image("inputs", inputs, max_outputs=2)
         if self._model_hparams.compress_steps > 0:
             strides = (2, 2)
         else:
             strides = (1, 1)
         return common_layers.conv_block(inputs,
                                         self._body_input_depth,
                                         [((1, 1), (3, 3))],
                                         first_relu=False,
                                         strides=strides,
                                         padding="SAME",
                                         force2d=True,
                                         name="small_image_conv")
Example #11
def xception_entry(inputs, hidden_dim):
    with tf.variable_scope("xception_entry"):

        def xnet_resblock(x, filters, res_relu, name):
            with tf.variable_scope(name):
                y = common_layers.separable_conv_block(x,
                                                       filters,
                                                       [((1, 1), (3, 3)),
                                                        ((1, 1), (3, 3))],
                                                       first_relu=True,
                                                       padding="SAME",
                                                       force2d=True,
                                                       name="sep_conv_block")
                y = common_layers.pool(y, (3, 3),
                                       "MAX",
                                       "SAME",
                                       strides=(2, 2))
                return y + common_layers.conv_block(x,
                                                    filters, [((1, 1),
                                                               (1, 1))],
                                                    padding="SAME",
                                                    strides=(2, 2),
                                                    first_relu=res_relu,
                                                    force2d=True,
                                                    name="res_conv0")

        inputs = common_layers.standardize_images(inputs)
        # TODO(lukaszkaiser): summaries here don't work in multi-problem case yet.
        # tf.summary.image("inputs", inputs, max_outputs=2)
        x = common_layers.conv_block(inputs,
                                     32, [((1, 1), (3, 3))],
                                     first_relu=False,
                                     padding="SAME",
                                     strides=(2, 2),
                                     force2d=True,
                                     name="conv0")
        x = common_layers.conv_block(x,
                                     64, [((1, 1), (3, 3))],
                                     padding="SAME",
                                     force2d=True,
                                     name="conv1")
        x = xnet_resblock(x, min(128, hidden_dim), True, "block0")
        x = xnet_resblock(x, min(256, hidden_dim), False, "block1")
        return xnet_resblock(x, hidden_dim, False, "block2")
Example #12
 def bottom(self, inputs):
     with tf.variable_scope(self.name):
         inputs_shape = common_layers.shape_list(inputs)
         if len(inputs_shape) != 5:
             raise ValueError(
                 "Assuming videos given as tensors in the format "
                 "[batch, time, height, width, channels].")
         if not context.in_eager_mode():
             tf.summary.image("inputs",
                              tf.cast(inputs[:, -1, :, :, :], tf.uint8),
                              max_outputs=1)
         # Standardize frames.
         inputs = tf.reshape(inputs, [-1] + inputs_shape[2:])
         inputs = common_layers.standardize_images(inputs)
         inputs = tf.reshape(inputs, inputs_shape)
         # Concatenate the time dimension on channels for image models to work.
         transposed = tf.transpose(inputs, [0, 2, 3, 1, 4])
         return tf.reshape(transposed, [
             inputs_shape[0], inputs_shape[2], inputs_shape[3],
             inputs_shape[1] * inputs_shape[4]
         ])
Example #13
def xception_entry(inputs, hidden_dim):
  with tf.variable_scope("xception_entry"):

    def xnet_resblock(x, filters, res_relu, name):
      with tf.variable_scope(name):
        y = common_layers.separable_conv_block(
            x,
            filters, [((1, 1), (3, 3)), ((1, 1), (3, 3))],
            first_relu=True,
            padding="SAME",
            force2d=True,
            name="sep_conv_block")
        y = common_layers.pool(y, (3, 3), "MAX", "SAME", strides=(2, 2))
        return y + common_layers.conv_block(
            x,
            filters, [((1, 1), (1, 1))],
            padding="SAME",
            strides=(2, 2),
            first_relu=res_relu,
            force2d=True,
            name="res_conv0")

    inputs = common_layers.standardize_images(inputs)
    # TODO(lukaszkaiser): summaries here don't work in multi-problem case yet.
    # tf.summary.image("inputs", inputs, max_outputs=2)
    x = common_layers.conv_block(
        inputs,
        32, [((1, 1), (3, 3))],
        first_relu=False,
        padding="SAME",
        strides=(2, 2),
        force2d=True,
        name="conv0")
    x = common_layers.conv_block(
        x, 64, [((1, 1), (3, 3))], padding="SAME", force2d=True, name="conv1")
    x = xnet_resblock(x, min(128, hidden_dim), True, "block0")
    x = xnet_resblock(x, min(256, hidden_dim), False, "block1")
    return xnet_resblock(x, hidden_dim, False, "block2")
Example #14
 def testStandardizeImages(self):
     x = np.random.rand(5, 7, 7, 3)
     with self.test_session() as session:
         y = common_layers.standardize_images(tf.constant(x))
         res = session.run(y)
     self.assertEqual(res.shape, (5, 7, 7, 3))
Example #15
    def body(self, features):
        hparams = self.hparams
        is_predicting = hparams.mode == tf.estimator.ModeKeys.PREDICT

        # TODO(lukaszkaiser): the split axes and the argmax below heavily depend on
        # using the default (a bit strange) video modality - we should change that.

        # Split inputs and targets into lists.
        input_frames = tf.unstack(features["inputs"], axis=1)
        target_frames = tf.unstack(features["targets"], axis=1)
        all_frames = input_frames + target_frames
        if "input_action" in features:
            input_actions = list(
                tf.split(features["input_action"],
                         hparams.video_num_input_frames,
                         axis=1))
            target_actions = list(
                tf.split(features["target_action"],
                         hparams.video_num_target_frames,
                         axis=1))
            all_actions = input_actions + target_actions

        orig_frame_shape = common_layers.shape_list(all_frames[0])

        # Run a number of steps.
        res_frames, sampled_frames, sampled_frames_raw = [], [], []
        if "target_reward" in features:
            res_rewards, extra_loss = [], 0.0
        sample_prob = common_layers.inverse_exp_decay(
            hparams.scheduled_sampling_warmup_steps)
        sample_prob *= hparams.scheduled_sampling_prob
        for i in range(hparams.video_num_target_frames):
            cur_frames = all_frames[i:i + hparams.video_num_input_frames]
            features["inputs"] = tf.concat(cur_frames, axis=-1)
            features["cur_target_frame"] = all_frames[
                i + hparams.video_num_input_frames]
            if "input_action" in features:
                cur_actions = all_actions[i:i + hparams.video_num_input_frames]
                features["input_action"] = tf.concat(cur_actions, axis=1)

            # Run model.
            with tf.variable_scope(tf.get_variable_scope(), reuse=i > 0):
                if "target_reward" not in features:
                    res_frame = self.body_single(features)
                else:
                    res_dict, res_extra_loss = self.body_single(features)
                    extra_loss += res_extra_loss
                    res_frame = res_dict["targets"]
                    res_reward = res_dict["target_reward"]
                    res_rewards.append(res_reward)
            res_frames.append(res_frame)

            # Only for Softmax loss: sample frame so we can keep iterating.
            sampled_frame_raw = self.get_sampled_frame(res_frame)
            sampled_frames_raw.append(sampled_frame_raw)
            # TODO(lukaszkaiser): this should be consistent with modality.bottom()
            sampled_frame = common_layers.standardize_images(sampled_frame_raw)
            sampled_frames.append(sampled_frame)

            if is_predicting:
                all_frames[i + hparams.video_num_input_frames] = sampled_frame

            # Scheduled sampling during training.
            if (hparams.scheduled_sampling_prob > 0.0 and self.is_training):
                do_sample = tf.less(tf.random_uniform([orig_frame_shape[0]]),
                                    sample_prob)
                orig_frame = all_frames[i + hparams.video_num_input_frames]
                sampled_frame = tf.where(do_sample, sampled_frame, orig_frame)
                all_frames[i + hparams.video_num_input_frames] = sampled_frame

        # Concatenate results and return them.
        frames = tf.stack(res_frames, axis=1)

        if "target_reward" not in features:
            return frames
        rewards = tf.concat(res_rewards, axis=1)
        return {"targets": frames, "target_reward": rewards}, extra_loss
Example #16
 def targets_bottom(self, x):
     common_video.gif_summary("targets", x, max_outputs=1)
     x = common_layers.standardize_images(x)
     return x
Example #17
 def bottom(self, inputs):
   with tf.variable_scope(self.name):
     inputs = common_layers.standardize_images(inputs)
     if not context.in_eager_mode():
       tf.summary.image("inputs", inputs, max_outputs=2)
     return tf.to_float(inputs)
Example #18
    def next_frame(self, frames, actions, rewards, target_frame,
                   internal_states, video_extra):
        del rewards, video_extra

        hparams = self.hparams
        filters = hparams.hidden_size
        kernel2 = (4, 4)
        action = actions[-1]
        activation_fn = common_layers.belu
        if self.hparams.activation_fn == "relu":
            activation_fn = tf.nn.relu

        # Normalize frames.
        frames = [common_layers.standardize_images(f) for f in frames]

        # Stack the inputs.
        if internal_states is not None and hparams.concat_internal_states:
            # Use the first part of the first internal state if asked to concatenate.
            batch_size = common_layers.shape_list(frames[0])[0]
            internal_state = internal_states[0][0][:batch_size, :, :, :]
            stacked_frames = tf.concat(frames + [internal_state], axis=-1)
        else:
            stacked_frames = tf.concat(frames, axis=-1)
        inputs_shape = common_layers.shape_list(stacked_frames)

        # Update internal states early if requested.
        if hparams.concat_internal_states:
            internal_states = self.update_internal_states_early(
                internal_states, frames)

        # Using non-zero bias initializer below for edge cases of uniform inputs.
        x = tf.layers.dense(
            stacked_frames,
            filters,
            name="inputs_embed",
            bias_initializer=tf.random_normal_initializer(stddev=0.01))
        x = common_attention.add_timing_signal_nd(x)

        # Down-stride.
        layer_inputs = [x]
        for i in range(hparams.num_compress_steps):
            with tf.variable_scope("downstride%d" % i):
                layer_inputs.append(x)
                x = tf.nn.dropout(x, 1.0 - self.hparams.dropout)
                x = common_layers.make_even_size(x)
                if i < hparams.filter_double_steps:
                    filters *= 2
                x = common_attention.add_timing_signal_nd(x)
                x = tf.layers.conv2d(x,
                                     filters,
                                     kernel2,
                                     activation=activation_fn,
                                     strides=(2, 2),
                                     padding="SAME")
                x = common_layers.layer_norm(x)

        if self.has_actions:
            with tf.variable_scope("policy"):
                x_flat = tf.layers.flatten(x)
                policy_pred = tf.layers.dense(x_flat,
                                              self.hparams.problem.num_actions)
                value_pred = tf.layers.dense(x_flat, 1)
                value_pred = tf.squeeze(value_pred, axis=-1)
        else:
            policy_pred, value_pred = None, None

        # Add embedded action if present.
        if self.has_actions:
            x = common_video.inject_additional_input(x, action, "action_enc",
                                                     hparams.action_injection)

        # Inject latent if present. Only for stochastic models.
        norm_target_frame = common_layers.standardize_images(target_frame)
        x, extra_loss = self.inject_latent(x, frames, norm_target_frame,
                                           action)

        x_mid = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
        x, internal_states = self.middle_network(x, internal_states)

        # Up-convolve.
        layer_inputs = list(reversed(layer_inputs))
        for i in range(hparams.num_compress_steps):
            with tf.variable_scope("upstride%d" % i):
                x = tf.nn.dropout(x, 1.0 - self.hparams.dropout)
                if self.has_actions:
                    x = common_video.inject_additional_input(
                        x, action, "action_enc", hparams.action_injection)
                if i >= hparams.num_compress_steps - hparams.filter_double_steps:
                    filters //= 2
                x = tf.layers.conv2d_transpose(x,
                                               filters,
                                               kernel2,
                                               activation=activation_fn,
                                               strides=(2, 2),
                                               padding="SAME")
                y = layer_inputs[i]
                shape = common_layers.shape_list(y)
                x = x[:, :shape[1], :shape[2], :]
                x = common_layers.layer_norm(x + y)
                x = common_attention.add_timing_signal_nd(x)

        # Cut down to original size.
        x = x[:, :inputs_shape[1], :inputs_shape[2], :]
        x_fin = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
        if hparams.do_autoregressive_rnn:
            # If enabled, we predict the target frame autoregressively using RNNs.
            # To this end, the current prediction is flattened into one long sequence
            # of sub-pixels, and so is the target frame. Each sub-pixel (RGB value,
            # from 0 to 255) is predicted with an RNN. To avoid doing as many steps
            # as width * height * channels, we only use a number of pixels back,
            # as many as hparams.autoregressive_rnn_lookback.
            with tf.variable_scope("autoregressive_rnn"):
                batch_size = common_layers.shape_list(frames[0])[0]
                # Height, width, channels and lookback are the constants we need.
                h, w = inputs_shape[1], inputs_shape[
                    2]  # 105, 80 on Atari games
                c = hparams.problem.num_channels
                lookback = hparams.autoregressive_rnn_lookback
                assert (
                    h * w
                ) % lookback == 0, "Number of pixels must divide lookback."
                m = (h * w) // lookback  # Batch size multiplier for the RNN.
                # These are logits that will be used as inputs to the RNN.
                rnn_inputs = tf.layers.dense(x, c * 64, name="rnn_inputs")
                # They are of shape [batch_size, h, w, c, 64], reshaping now.
                rnn_inputs = tf.reshape(rnn_inputs,
                                        [batch_size * m, lookback * c, 64])
                # Same for the target frame.
                rnn_target = tf.reshape(target_frame,
                                        [batch_size * m, lookback * c])
                # Construct rnn starting state: flatten rnn_inputs, apply a relu layer.
                rnn_start_state = tf.nn.relu(
                    tf.layers.dense(tf.nn.relu(tf.layers.flatten(rnn_inputs)),
                                    256,
                                    name="rnn_start_state"))
                # Our RNN function API is on bits, each subpixel has 8 bits.
                total_num_bits = lookback * c * 8
                # We need to provide RNN targets as bits (due to the API).
                rnn_target_bits = discretization.int_to_bit(rnn_target, 8)
                rnn_target_bits = tf.reshape(rnn_target_bits,
                                             [batch_size * m, total_num_bits])
                if self.is_training:
                    # Run the RNN in training mode, add its loss to the losses.
                    rnn_predict, rnn_loss = discretization.predict_bits_with_lstm(
                        rnn_start_state,
                        128,
                        total_num_bits,
                        target_bits=rnn_target_bits,
                        extra_inputs=rnn_inputs)
                    extra_loss += rnn_loss
                    # We still use non-RNN predictions too in order to guide the network.
                    x = tf.layers.dense(x, c * 256, name="logits")
                    x = tf.reshape(x, [batch_size, h, w, c, 256])
                    rnn_predict = tf.reshape(rnn_predict,
                                             [batch_size, h, w, c, 256])
                    # Mix non-RNN and RNN predictions so that after warmup the RNN is 90%.
                    x = tf.reshape(tf.nn.log_softmax(x),
                                   [batch_size, h, w, c * 256])
                    rnn_predict = tf.nn.log_softmax(rnn_predict)
                    rnn_predict = tf.reshape(rnn_predict,
                                             [batch_size, h, w, c * 256])
                    alpha = 0.9 * common_layers.inverse_lin_decay(
                        hparams.autoregressive_rnn_warmup_steps)
                    x = alpha * rnn_predict + (1.0 - alpha) * x
                else:
                    # In prediction mode, run the RNN without any targets.
                    bits, _ = discretization.predict_bits_with_lstm(
                        rnn_start_state,
                        128,
                        total_num_bits,
                        extra_inputs=rnn_inputs,
                        temperature=0.0
                    )  # No sampling from this RNN, just greedy.
                    # The output is in bits, get back the predicted pixels.
                    bits = tf.reshape(bits, [batch_size * m, lookback * c, 8])
                    ints = discretization.bit_to_int(tf.maximum(bits, 0), 8)
                    ints = tf.reshape(ints, [batch_size, h, w, c])
                    x = tf.reshape(tf.one_hot(ints, 256),
                                   [batch_size, h, w, c * 256])
        elif self.is_per_pixel_softmax:
            x = tf.layers.dense(x,
                                hparams.problem.num_channels * 256,
                                name="logits")
        else:
            x = tf.layers.dense(x, hparams.problem.num_channels, name="logits")

        reward_pred = None
        if self.has_rewards:
            # Reward prediction based on middle and final logits.
            reward_pred = tf.concat([x_mid, x_fin], axis=-1)
            reward_pred = tf.nn.relu(
                tf.layers.dense(reward_pred, 128, name="reward_pred"))
            reward_pred = tf.squeeze(reward_pred, axis=1)  # Remove extra dims
            reward_pred = tf.squeeze(reward_pred, axis=1)  # Remove extra dims

        return x, reward_pred, policy_pred, value_pred, extra_loss, internal_states
Example #19
    def body(self, features):
        hparams = self.hparams
        is_predicting = hparams.mode == tf.estimator.ModeKeys.PREDICT

        # TODO(lukaszkaiser): the split axes and the argmax below heavily depend on
        # using the default (a bit strange) video modality - we should change that.

        # Split inputs and targets into lists.
        input_frames = tf.unstack(features["inputs"], axis=1)
        target_frames = tf.unstack(features["targets"], axis=1)
        all_frames = input_frames + target_frames
        if "input_action" in features:
            input_actions = list(
                tf.split(features["input_action"],
                         hparams.video_num_input_frames,
                         axis=1))
            target_actions = list(
                tf.split(features["target_action"],
                         hparams.video_num_target_frames,
                         axis=1))
            all_actions = input_actions + target_actions

        orig_frame_shape = common_layers.shape_list(all_frames[0])
        batch_size = orig_frame_shape[0]
        ss_func = self.get_scheduled_sample_func(batch_size)

        # Run a number of steps.
        res_frames, sampled_frames, sampled_frames_raw = [], [], []
        extra_loss = 0.0
        if "target_reward" in features:
            res_rewards = []
        for i in range(hparams.video_num_target_frames):
            cur_frames = all_frames[i:i + hparams.video_num_input_frames]
            features["inputs"] = tf.concat(cur_frames, axis=-1)
            features["cur_target_frame"] = all_frames[
                i + hparams.video_num_input_frames]
            if "input_action" in features:
                cur_actions = all_actions[i:i + hparams.video_num_input_frames]
                features["input_action"] = tf.concat(cur_actions, axis=1)

            # Run model.
            with tf.variable_scope(tf.get_variable_scope(), reuse=i > 0):
                if "target_reward" not in features:
                    res_frame, res_extra_loss = self.body_single(features)
                else:
                    res_dict, res_extra_loss = self.body_single(features)
                    res_frame = res_dict["targets"]
                    res_reward = res_dict["target_reward"]
                    res_rewards.append(res_reward)
            extra_loss += res_extra_loss / float(
                hparams.video_num_target_frames)
            res_frames.append(res_frame)

            # Only for Softmax loss: sample frame so we can keep iterating.
            sampled_frame_raw = self.get_sampled_frame(res_frame)
            sampled_frames_raw.append(sampled_frame_raw)
            # TODO(lukaszkaiser): this should be consistent with modality.bottom()
            sampled_frame = common_layers.standardize_images(sampled_frame_raw)
            sampled_frames.append(sampled_frame)

            if is_predicting:
                all_frames[i + hparams.video_num_input_frames] = sampled_frame

            # Scheduled sampling during training.
            if self.is_training:
                done_warm_start = True  # Always true for non-recurrent networks.
                groundtruth_items = [
                    all_frames[i + hparams.video_num_input_frames]
                ]
                generated_items = [sampled_frame]
                ss_frame, = self.get_scheduled_sample_inputs(
                    done_warm_start, groundtruth_items, generated_items,
                    ss_func)
                all_frames[i + hparams.video_num_input_frames] = ss_frame

        # Concatenate results and return them.
        frames = tf.stack(res_frames, axis=1)

        if "target_reward" not in features:
            return frames
        rewards = tf.concat(res_rewards, axis=1)
        return {"targets": frames, "target_reward": rewards}, extra_loss
Example #20
    def network(self):
        def middle_network(layer):
            # Run a stack of convolutions.
            x = layer
            kernel1 = (3, 3)
            filters = common_layers.shape_list(x)[-1]
            for i in range(2):
                with tf.variable_scope("layer%d" % i):
                    y = tf.nn.dropout(x, 1.0 - 0.5)
                    y = tf.layers.conv2d(y,
                                         filters,
                                         kernel1,
                                         activation=self.activation_fn,
                                         strides=(1, 1),
                                         padding="SAME")
                    if i == 0:
                        x = y
                    else:
                        x = common_layers.layer_norm(x + y)
            return x

        batch_size = tf.shape(self.states_ph)[0]

        filters = self.hidden_size
        kernel2 = (4, 4)
        action = self.actions_oph  #[0] NOTE - might remove this

        # Normalize states
        if (self.n_envs > 1):
            states = [
                common_layers.standardize_images(self.states_ph[i, :, :, :])
                for i in range(self.n_envs)
            ]
            stacked_states = tf.stack(states)
        else:
            stacked_states = common_layers.standardize_images(self.states_ph)
        inputs_shape = common_layers.shape_list(stacked_states)

        # Using non-zero bias initializer below for edge cases of uniform inputs.
        x = tf.layers.dense(
            stacked_states,
            filters,
            name="inputs_embed",
            bias_initializer=tf.random_normal_initializer(stddev=0.01))
        x = common_attention.add_timing_signal_nd(x)

        # Down-stride.
        layer_inputs = [x]
        for i in range(self.layers):
            with tf.variable_scope("downstride%d" % i):
                layer_inputs.append(x)
                x = tf.nn.dropout(x, 1.0 - self.dropout_p)
                x = common_layers.make_even_size(x)
                if i < 2:
                    filters *= 2
                x = common_attention.add_timing_signal_nd(x)
                x = tf.layers.conv2d(x,
                                     filters,
                                     kernel2,
                                     activation=self.activation_fn,
                                     strides=(2, 2),
                                     padding="SAME")
                x = common_layers.layer_norm(x)

        if self.is_policy:
            with tf.variable_scope("policy"):
                x_flat = tf.layers.flatten(x)
                policy_pred = tf.layers.dense(x_flat, self.action_dim)
                value_pred = tf.layers.dense(x_flat, 1)
                value_pred = tf.squeeze(value_pred, axis=-1)
        else:
            policy_pred, value_pred = None, None

        #if self.has_actions:
        x = inject_additional_input(x, action, "action_enc", "multi_additive")

        # Inject latent if present. Only for stochastic models.
        target_states = common_layers.standardize_images(self.target_states)

        x_mid = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
        x = middle_network(x)

        # Up-convolve.
        layer_inputs = list(reversed(layer_inputs))
        for i in range(self.layers):
            with tf.variable_scope("upstride%d" % i):
                x = tf.nn.dropout(x, 1.0 - 0.1)
                if i >= self.layers - 2:
                    filters //= 2
                x = tf.layers.conv2d_transpose(x,
                                               filters,
                                               kernel2,
                                               activation=self.activation_fn,
                                               strides=(2, 2),
                                               padding="SAME")
                y = layer_inputs[i]
                shape = common_layers.shape_list(y)
                x = x[:, :shape[1], :shape[2], :]
                x = common_layers.layer_norm(x + y)
                x = common_attention.add_timing_signal_nd(x)

        # Cut down to original size.
        x = x[:, :inputs_shape[1], :inputs_shape[2], :]
        x_fin = tf.reduce_mean(x, axis=[1, 2], keepdims=True)

        x = tf.layers.dense(x, self.depth, name="logits")

        reward_pred = None
        if self.has_rewards:
            # Reward prediction based on middle and final logits.
            reward_pred = tf.concat([x_mid, x_fin], axis=-1)
            reward_pred = tf.nn.relu(
                tf.layers.dense(reward_pred, 128, name="reward_pred"))
            reward_pred = tf.squeeze(reward_pred, axis=1)  # Remove extra dims
            reward_pred = tf.squeeze(reward_pred, axis=1)  # Remove extra dims

        return x, reward_pred, policy_pred, value_pred
Example #21
 def bottom(self, x):
   inputs = x
   with tf.variable_scope(self.name, reuse=tf.AUTO_REUSE):
     common_layers.summarize_video(inputs, "inputs")
     inputs = common_layers.standardize_images(inputs)
     return common_layers.time_to_channels(inputs)
Example #22
  def body(self, features):
    hparams = self.hparams
    self.has_action = "input_action" in features
    self.has_reward = "target_reward" in features
    # dirty hack to enable the latent tower
    self.features = features

    # Split inputs and targets into lists.
    input_frames = tf.unstack(features["inputs"], axis=1)
    target_frames = tf.unstack(features["targets"], axis=1)
    all_frames = input_frames + target_frames
    if self.has_action:
      input_actions = tf.unstack(features["input_action"], axis=1)
      target_actions = tf.unstack(features["target_action"], axis=1)
      all_actions = input_actions + target_actions

    res_frames, sampled_frames, sampled_frames_raw, res_rewards = [], [], [], []
    lstm_states = [None] * hparams.num_lstm_layers
    extra_loss = 0.0

    num_frames = len(all_frames)
    for i in range(num_frames - 1):
      frame = all_frames[i]
      action = all_actions[i] if self.has_action else None

      # another hack to enable the latent tower
      # TODO(mbz): clean this up.
      self.features["inputs"] = all_frames[i]
      self.features["cur_target_frame"] = all_frames[i+1]

      # Run model.
      with tf.variable_scope("recurrent_model", reuse=tf.AUTO_REUSE):
        func_out = self.predict_next_frame(frame, action, lstm_states)
        res_frame, res_reward, res_extra_loss, lstm_states = func_out
        res_frames.append(res_frame)
        res_rewards.append(res_reward)
        extra_loss += res_extra_loss

      sampled_frame_raw = self.get_sampled_frame(res_frame)
      sampled_frames_raw.append(sampled_frame_raw)
      # TODO(lukaszkaiser): this should be consistent with modality.bottom()
      sampled_frame = common_layers.standardize_images(sampled_frame_raw)
      sampled_frames.append(sampled_frame)

      # Only for Softmax loss: sample next frame so we can keep iterating.
      if self.is_predicting and i >= hparams.video_num_input_frames:
        all_frames[i+1] = sampled_frame

    # Concatenate results and return them.
    output_frames = res_frames[hparams.video_num_input_frames-1:]
    frames = tf.stack(output_frames, axis=1)

    has_input_predictions = hparams.video_num_input_frames > 1
    if self.is_training and hparams.internal_loss and has_input_predictions:
      # add the loss for input frames as well.
      extra_gts = input_frames[1:]
      extra_pds = res_frames[:hparams.video_num_input_frames-1]
      extra_raw_gts = features["inputs_raw"][:, 1:]
      recon_loss = self.get_extra_internal_loss(
          extra_raw_gts, extra_gts, extra_pds)
      extra_loss += recon_loss

    if not self.has_reward:
      return frames, extra_loss
    rewards = tf.concat(res_rewards[hparams.video_num_input_frames-1:], axis=1)
    return {"targets": frames, "target_reward": rewards}, extra_loss
Example #24
    def body(self, features):
        hparams = self.hparams
        is_predicting = hparams.mode == tf.estimator.ModeKeys.PREDICT
        if hparams.video_num_target_frames < 2:
            res = self.body_single(features)
            return res

        # TODO(lukaszkaiser): the split axes and the argmax below heavily depend on
        # using the default (a bit strange) video modality - we should change that.

        # Split inputs and targets into lists.
        input_frames = list(
            tf.split(features["inputs"],
                     hparams.video_num_input_frames,
                     axis=-1))
        target_frames = list(
            tf.split(features["targets"],
                     hparams.video_num_target_frames,
                     axis=-1))
        all_frames = input_frames + target_frames
        if "input_action" in features:
            input_actions = list(
                tf.split(features["input_action"],
                         hparams.video_num_input_frames,
                         axis=1))
            target_actions = list(
                tf.split(features["target_action"],
                         hparams.video_num_target_frames,
                         axis=1))
            all_actions = input_actions + target_actions

        # Run a number of steps.
        res_frames = []
        if "target_reward" in features:
            res_rewards, extra_loss = [], 0.0
        sample_prob = common_layers.inverse_exp_decay(
            hparams.scheduled_sampling_warmup_steps)
        sample_prob *= hparams.scheduled_sampling_prob
        for i in range(hparams.video_num_target_frames):
            cur_frames = all_frames[i:i + hparams.video_num_input_frames]
            features["inputs"] = tf.concat(cur_frames, axis=-1)
            if "input_action" in features:
                cur_actions = all_actions[i:i + hparams.video_num_input_frames]
                features["input_action"] = tf.concat(cur_actions, axis=1)

            # Run model.
            with tf.variable_scope(tf.get_variable_scope(), reuse=i > 0):
                if "target_reward" not in features:
                    res_frames.append(self.body_single(features))
                else:
                    res_dict, res_extra_loss = self.body_single(features)
                    extra_loss += res_extra_loss
                    res_frames.append(res_dict["targets"])
                    res_rewards.append(res_dict["target_reward"])

            # When predicting, use the generated frame.
            orig_frame = all_frames[i + hparams.video_num_input_frames]
            shape = common_layers.shape_list(orig_frame)
            sampled_frame = tf.reshape(
                res_frames[-1],
                shape[:-1] + [hparams.problem.num_channels, 256])
            sampled_frame = tf.to_float(tf.argmax(sampled_frame, axis=-1))
            sampled_frame = common_layers.standardize_images(sampled_frame)
            if is_predicting:
                all_frames[i + hparams.video_num_input_frames] = sampled_frame

            # Scheduled sampling during training.
            if (hparams.scheduled_sampling_prob > 0.0 and self.is_training):
                do_sample = tf.less(tf.random_uniform([shape[0]]), sample_prob)
                sampled_frame = tf.where(do_sample, sampled_frame, orig_frame)
                all_frames[i + hparams.video_num_input_frames] = sampled_frame

        # Concatenate results and return them.
        frames = tf.concat(res_frames, axis=-1)
        if "target_reward" not in features:
            return frames
        rewards = tf.concat(res_rewards, axis=1)
        return {"targets": frames, "target_reward": rewards}, extra_loss