def reward_prediction_big(
      self, input_images, input_reward, action, latent, mid_outputs):
    """Builds a reward prediction network."""
    del mid_outputs
    conv_size = self.tinyify([32, 32, 16, 8])

    with tf.variable_scope("reward_pred", reuse=tf.AUTO_REUSE):
      x = tf.concat(input_images, axis=3)
      x = tfcl.layer_norm(x)

      if not self.hparams.small_mode:
        x = tfl.conv2d(x, conv_size[1], [3, 3], strides=(2, 2),
                       activation=tf.nn.relu, name="reward_conv1")
        x = tfcl.layer_norm(x)

      # Inject additional inputs
      if action is not None:
        x = common_video.inject_additional_input(
            x, action, "action_enc", self.hparams.action_injection)
      if input_reward is not None:
        x = common_video.inject_additional_input(x, input_reward, "reward_enc")
      if latent is not None:
        latent = tfl.flatten(latent)
        latent = tf.expand_dims(latent, axis=1)
        latent = tf.expand_dims(latent, axis=1)
        x = common_video.inject_additional_input(x, latent, "latent_enc")

      x = tfl.conv2d(x, conv_size[2], [3, 3], strides=(2, 2),
                     activation=tf.nn.relu, name="reward_conv2")
      x = tfcl.layer_norm(x)
      x = tfl.conv2d(x, conv_size[3], [3, 3], strides=(2, 2),
                     activation=tf.nn.relu, name="reward_conv3")
    return x
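
All of the conditioning branches above go through common_video.inject_additional_input. The snippet below is only a rough, self-contained sketch of the broadcast-and-concatenate idea behind that kind of helper; the function name, the fixed random projection, and the embedding size are invented for illustration and are not the tensor2tensor implementation.

import numpy as np

def concat_broadcast_conditioning(features, signal, emb_size=8):
    """Toy concat-style conditioning: embed `signal`, tile it over H x W,
    and concatenate it to `features` along the channel axis (illustrative
    only -- not common_video.inject_additional_input itself)."""
    b, h, w, _ = features.shape
    signal = signal.reshape(b, -1)
    # A fixed random projection stands in for a learned dense embedding.
    proj = np.random.RandomState(0).normal(size=(signal.shape[1], emb_size))
    emb = signal @ proj                                 # [b, emb_size]
    emb = np.tile(emb[:, None, None, :], (1, h, w, 1))  # [b, h, w, emb_size]
    return np.concatenate([features, emb], axis=-1)     # [b, h, w, c + emb_size]

features = np.zeros((2, 16, 16, 32), dtype=np.float32)
action = np.array([[1.0, 0.0], [0.0, 1.0]], dtype=np.float32)  # one-hot actions
print(concat_broadcast_conditioning(features, action).shape)   # (2, 16, 16, 40)
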
Example #2
    def inject_latent(self, layer, inputs, target, action):
        """Inject a deterministic latent based on the target frame."""
        hparams = self.hparams
        final_filters = common_layers.shape_list(layer)[-1]
        filters = hparams.hidden_size
        kernel = (4, 4)
        layer_shape = common_layers.shape_list(layer)
        activation_fn = common_layers.belu
        if hparams.activation_fn == "relu":
            activation_fn = tf.nn.relu

        def add_bits(layer, bits):
            z_mul = tfl.dense(bits, final_filters, name="unbottleneck_mul")
            if not hparams.complex_addn:
                return layer + z_mul
            layer *= tf.nn.sigmoid(z_mul)
            z_add = tfl.dense(bits, final_filters, name="unbottleneck_add")
            layer += z_add
            return layer

        if not self.is_training:
            if hparams.full_latent_tower:
                rand = tf.random_uniform(layer_shape[:-1] +
                                         [hparams.bottleneck_bits])
                bits = 2.0 * tf.to_float(tf.less(0.5, rand)) - 1.0
            else:
                bits, _ = discretization.predict_bits_with_lstm(
                    layer,
                    hparams.latent_predictor_state_size,
                    hparams.bottleneck_bits,
                    temperature=hparams.latent_predictor_temperature)
                bits = tf.expand_dims(tf.expand_dims(bits, axis=1), axis=2)
            return add_bits(layer, bits), 0.0

        # Embed.
        frames = tf.concat(inputs + [target], axis=-1)
        x = tfl.dense(
            frames,
            filters,
            name="latent_embed",
            bias_initializer=tf.random_normal_initializer(stddev=0.01))
        x = common_attention.add_timing_signal_nd(x)

        # Add embedded action if present.
        if action is not None:
            x = common_video.inject_additional_input(x, action,
                                                     "action_enc_latent",
                                                     hparams.action_injection)

        if hparams.full_latent_tower:
            for i in range(hparams.num_compress_steps):
                with tf.variable_scope("latent_downstride%d" % i):
                    x = common_layers.make_even_size(x)
                    if i < hparams.filter_double_steps:
                        filters *= 2
                    x = common_attention.add_timing_signal_nd(x)
                    x = tfl.conv2d(x,
                                   filters,
                                   kernel,
                                   activation=activation_fn,
                                   strides=(2, 2),
                                   padding="SAME")
                    x = common_layers.layer_norm(x)
        else:
            x = common_layers.double_discriminator(x)
            x = tf.expand_dims(tf.expand_dims(x, axis=1), axis=1)

        bits, bits_clean = discretization.tanh_discrete_bottleneck(
            x, hparams.bottleneck_bits, hparams.bottleneck_noise,
            hparams.discretize_warmup_steps, hparams.mode)
        if not hparams.full_latent_tower:
            _, pred_loss = discretization.predict_bits_with_lstm(
                layer,
                hparams.latent_predictor_state_size,
                hparams.bottleneck_bits,
                target_bits=bits_clean)
            # Mix predicted bits into the latent bits on the forward pass as noise.
            if hparams.latent_rnn_max_sampling > 0.0:
                with tf.variable_scope(tf.get_variable_scope(), reuse=True):
                    bits_pred, _ = discretization.predict_bits_with_lstm(
                        layer,
                        hparams.latent_predictor_state_size,
                        hparams.bottleneck_bits,
                        temperature=hparams.latent_predictor_temperature)
                    bits_pred = tf.expand_dims(tf.expand_dims(bits_pred,
                                                              axis=1),
                                               axis=2)
                # Forward pass uses bits_pred; gradients flow through bits_clean.
                bits_pred = bits_clean + tf.stop_gradient(bits_pred -
                                                          bits_clean)
                # Take bits from the predictor with probability bit_p.
                which_bit = tf.random_uniform(common_layers.shape_list(bits))
                bit_p = common_layers.inverse_lin_decay(
                    hparams.latent_rnn_warmup_steps)
                bit_p *= hparams.latent_rnn_max_sampling
                bits = tf.where(which_bit < bit_p, bits_pred, bits)

        res = add_bits(layer, bits)
        # During training, sometimes skip the latent to help action-conditioning.
        res_p = common_layers.inverse_lin_decay(
            hparams.latent_rnn_warmup_steps / 2)
        res_p *= hparams.latent_use_max_probability
        res_rand = tf.random_uniform([layer_shape[0]])
        res = tf.where(res_rand < res_p, res, layer)
        return res, pred_loss
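
The line bits_pred = bits_clean + tf.stop_gradient(bits_pred - bits_clean) above is a straight-through substitution: the forward value equals bits_pred while gradients flow to bits_clean. A minimal sketch of the trick in isolation (written against TensorFlow 2.x eager mode purely for brevity; the tensor values are made up):

import tensorflow as tf

clean = tf.Variable([1.0, -1.0, 1.0])   # differentiable path (e.g. bits_clean)
pred = tf.constant([0.2, -0.9, 0.7])    # sampled / predicted path (e.g. bits_pred)

with tf.GradientTape() as tape:
    # Forward value equals `pred`; d(mixed)/d(clean) is the identity.
    mixed = clean + tf.stop_gradient(pred - clean)
    loss = tf.reduce_sum(mixed ** 2)

print(mixed.numpy())                       # [ 0.2 -0.9  0.7] -> forward uses pred
print(tape.gradient(loss, clean).numpy())  # [ 0.4 -1.8  1.4] = 2 * pred, routed to clean
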
    def next_frame(self, frames, actions, rewards, target_frame,
                   internal_states, video_extra):
        del rewards, video_extra

        hparams = self.hparams
        filters = hparams.hidden_size
        kernel2 = (4, 4)

        # Embed the inputs.
        stacked_frames = tf.concat(frames, axis=-1)
        inputs_shape = common_layers.shape_list(stacked_frames)
        # Using non-zero bias initializer below for edge cases of uniform inputs.
        x = tf.layers.dense(
            stacked_frames,
            filters,
            name="inputs_embed",
            bias_initializer=tf.random_normal_initializer(stddev=0.01))
        x = common_attention.add_timing_signal_nd(x)

        # Down-stride.
        layer_inputs = [x]
        for i in range(hparams.num_compress_steps):
            with tf.variable_scope("downstride%d" % i):
                layer_inputs.append(x)
                x = tf.nn.dropout(x, 1.0 - self.hparams.dropout)
                x = common_layers.make_even_size(x)
                if i < hparams.filter_double_steps:
                    filters *= 2
                x = common_attention.add_timing_signal_nd(x)
                x = tf.layers.conv2d(x,
                                     filters,
                                     kernel2,
                                     activation=common_layers.belu,
                                     strides=(2, 2),
                                     padding="SAME")
                x = common_layers.layer_norm(x)

        # Add embedded action if present.
        if self.has_actions:
            action = actions[-1]
            x = common_video.inject_additional_input(x, action, "action_enc",
                                                     hparams.action_injection)

        # Inject latent if present. Only for stochastic models.
        x, extra_loss = self.inject_latent(x, frames, target_frame)

        x_mid = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
        x, internal_states = self.middle_network(x, internal_states)

        # Up-convolve.
        layer_inputs = list(reversed(layer_inputs))
        for i in range(hparams.num_compress_steps):
            with tf.variable_scope("upstride%d" % i):
                x = tf.nn.dropout(x, 1.0 - self.hparams.dropout)
                if self.has_actions:
                    x = common_video.inject_additional_input(
                        x, action, "action_enc", hparams.action_injection)
                if i >= hparams.num_compress_steps - hparams.filter_double_steps:
                    filters //= 2
                x = tf.layers.conv2d_transpose(x,
                                               filters,
                                               kernel2,
                                               activation=common_layers.belu,
                                               strides=(2, 2),
                                               padding="SAME")
                y = layer_inputs[i]
                shape = common_layers.shape_list(y)
                x = x[:, :shape[1], :shape[2], :]
                x = common_layers.layer_norm(x + y)
                x = common_attention.add_timing_signal_nd(x)

        # Cut down to original size.
        x = x[:, :inputs_shape[1], :inputs_shape[2], :]
        x_fin = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
        if self.is_per_pixel_softmax:
            x = tf.layers.dense(x,
                                hparams.problem.num_channels * 256,
                                name="logits")
        else:
            x = tf.layers.dense(x, hparams.problem.num_channels, name="logits")

        # No reward prediction if not needed.
        if not self.has_rewards:
            return x, None, extra_loss, internal_states

        # Reward prediction based on middle and final logits.
        reward_pred = tf.concat([x_mid, x_fin], axis=-1)
        reward_pred = tf.nn.relu(
            tf.layers.dense(reward_pred, 128, name="reward_pred"))
        reward_pred = tf.squeeze(reward_pred, axis=1)  # Remove extra dims
        reward_pred = tf.squeeze(reward_pred, axis=1)  # Remove extra dims
        return x, reward_pred, extra_loss, internal_states
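
The crops x = x[:, :shape[1], :shape[2], :] above exist because every down-stride first pads odd spatial sizes to even ones (common_layers.make_even_size) before its stride-2 convolution, so the transposed convolutions on the way back up can overshoot the stored skip shapes. A small arithmetic sketch of that size bookkeeping (the 105x80 frame size and six compress steps are example values):

def downstride_sizes(h, w, num_compress_steps):
    """Track (height, width) through repeated make-even + stride-2 steps."""
    sizes = [(h, w)]
    for _ in range(num_compress_steps):
        h, w = h + h % 2, w + w % 2   # make_even_size pads odd dims up by one
        h, w = h // 2, w // 2         # the stride-2 convolution halves each dim
        sizes.append((h, w))
    return sizes

print(downstride_sizes(105, 80, 6))
# [(105, 80), (53, 40), (27, 20), (14, 10), (7, 5), (4, 3), (2, 2)]
# Stride-2 transposed convs double the sizes on the way back up (2 -> 4 -> 8 ...),
# so each level is cropped to its stored skip shape, and the final output is
# cropped back to the original 105 x 80 with the `inputs_shape` slice.
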
Example #4
    def bottom_part_tower(self,
                          input_image,
                          input_reward,
                          action,
                          latent,
                          lstm_state,
                          lstm_size,
                          conv_size,
                          concat_latent=False):
        """The bottom part of predictive towers.

    With the current (early) design, the main prediction tower and
    the reward prediction tower share the same architecture. The TF variable
    scope can be adjusted as required to either share or not share the weights
    between the two towers.

    Args:
      input_image: the current image.
      input_reward: the current reward.
      action: the action taken by the agent.
      latent: the latent vector.
      lstm_state: the current internal states of conv lstms.
      lstm_size: the size of lstms.
      conv_size: the size of convolutions.
      concat_latent: whether or not to concatenate the latent at every step.

    Returns:
      - the output of the partial network.
      - intermediate outputs for skip connections.
      - the index of the next unused conv-lstm state (layer_id).
    """
        lstm_func = common_video.conv_lstm_2d
        tile_and_concat = common_video.tile_and_concat

        input_image = common_layers.make_even_size(input_image)
        concat_input_image = tile_and_concat(input_image,
                                             latent,
                                             concat_latent=concat_latent)

        layer_id = 0
        enc0 = tfl.conv2d(concat_input_image,
                          conv_size[0], [5, 5],
                          strides=(2, 2),
                          activation=tf.nn.relu,
                          padding="SAME",
                          name="scale1_conv1")
        enc0 = tfcl.layer_norm(enc0, scope="layer_norm1")

        hidden1, lstm_state[layer_id] = lstm_func(enc0,
                                                  lstm_state[layer_id],
                                                  lstm_size[layer_id],
                                                  name="state1")
        hidden1 = tile_and_concat(hidden1, latent, concat_latent=concat_latent)
        hidden1 = tfcl.layer_norm(hidden1, scope="layer_norm2")
        layer_id += 1

        hidden2, lstm_state[layer_id] = lstm_func(hidden1,
                                                  lstm_state[layer_id],
                                                  lstm_size[layer_id],
                                                  name="state2")
        hidden2 = tfcl.layer_norm(hidden2, scope="layer_norm3")
        hidden2 = common_layers.make_even_size(hidden2)
        enc1 = tfl.conv2d(hidden2,
                          hidden2.get_shape()[3], [3, 3],
                          strides=(2, 2),
                          padding="SAME",
                          activation=tf.nn.relu,
                          name="conv2")
        enc1 = tile_and_concat(enc1, latent, concat_latent=concat_latent)
        layer_id += 1

        if self.hparams.small_mode:
            hidden4, enc2 = hidden2, enc1
        else:
            hidden3, lstm_state[layer_id] = lstm_func(enc1,
                                                      lstm_state[layer_id],
                                                      lstm_size[layer_id],
                                                      name="state3")
            hidden3 = tile_and_concat(hidden3,
                                      latent,
                                      concat_latent=concat_latent)
            hidden3 = tfcl.layer_norm(hidden3, scope="layer_norm4")
            layer_id += 1

            hidden4, lstm_state[layer_id] = lstm_func(hidden3,
                                                      lstm_state[layer_id],
                                                      lstm_size[layer_id],
                                                      name="state4")
            hidden4 = tile_and_concat(hidden4,
                                      latent,
                                      concat_latent=concat_latent)
            hidden4 = tfcl.layer_norm(hidden4, scope="layer_norm5")
            hidden4 = common_layers.make_even_size(hidden4)
            enc2 = tfl.conv2d(hidden4,
                              hidden4.get_shape()[3], [3, 3],
                              strides=(2, 2),
                              padding="SAME",
                              activation=tf.nn.relu,
                              name="conv3")
            layer_id += 1

        if action is not None:
            enc2 = common_video.inject_additional_input(
                enc2, action, "action_enc", self.hparams.action_injection)
        if input_reward is not None:
            enc2 = common_video.inject_additional_input(
                enc2, input_reward, "reward_enc")
        if latent is not None and not concat_latent:
            with tf.control_dependencies([latent]):
                enc2 = tf.concat([enc2, latent], axis=3)

        enc3 = tfl.conv2d(enc2,
                          hidden4.get_shape()[3], [1, 1],
                          strides=(1, 1),
                          padding="SAME",
                          activation=tf.nn.relu,
                          name="conv4")

        hidden5, lstm_state[layer_id] = lstm_func(enc3,
                                                  lstm_state[layer_id],
                                                  lstm_size[layer_id],
                                                  name="state5")
        hidden5 = tfcl.layer_norm(hidden5, scope="layer_norm6")
        hidden5 = tile_and_concat(hidden5, latent, concat_latent=concat_latent)
        layer_id += 1
        return hidden5, (enc0, enc1), layer_id
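
bottom_part_tower indexes the conv-LSTM states through the lstm_state list with layer_id and returns the next free index, so a caller can keep using the same list for any LSTMs that follow. A toy sketch of that bookkeeping only (the list length of seven is an assumption and no real ConvLSTMs are involved):

def run_tower(lstm_state, small_mode=False):
    """Mimic only the layer_id accounting of bottom_part_tower."""
    layer_id = 0
    for name in ["state1", "state2"]:
        lstm_state[layer_id] = name   # conv_lstm_2d would write its new state here
        layer_id += 1
    if not small_mode:
        for name in ["state3", "state4"]:
            lstm_state[layer_id] = name
            layer_id += 1
    lstm_state[layer_id] = "state5"
    layer_id += 1
    return layer_id

states = [None] * 7
print(run_tower(states), states)               # 5 ['state1', ..., 'state5', None, None]
print(run_tower([None] * 7, small_mode=True))  # 3 -- states 3 and 4 are skipped
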
  def next_frame(self, frames, actions, rewards, target_frame,
                 internal_states, video_extra):
    del rewards, video_extra

    hparams = self.hparams
    filters = hparams.hidden_size
    kernel2 = (4, 4)
    action = actions[-1]

    # Stack the inputs.
    if internal_states is not None and hparams.concat_internal_states:
      # Use the first part of the first internal state if asked to concatenate.
      batch_size = common_layers.shape_list(frames[0])[0]
      internal_state = internal_states[0][0][:batch_size, :, :, :]
      stacked_frames = tf.concat(frames + [internal_state], axis=-1)
    else:
      stacked_frames = tf.concat(frames, axis=-1)
    inputs_shape = common_layers.shape_list(stacked_frames)

    # Update internal states early if requested.
    if hparams.concat_internal_states:
      internal_states = self.update_internal_states_early(
          internal_states, frames)

    # Using non-zero bias initializer below for edge cases of uniform inputs.
    x = tf.layers.dense(
        stacked_frames, filters, name="inputs_embed",
        bias_initializer=tf.random_normal_initializer(stddev=0.01))
    x = common_attention.add_timing_signal_nd(x)

    # Down-stride.
    layer_inputs = [x]
    for i in range(hparams.num_compress_steps):
      with tf.variable_scope("downstride%d" % i):
        layer_inputs.append(x)
        x = tf.nn.dropout(x, 1.0 - self.hparams.dropout)
        x = common_layers.make_even_size(x)
        if i < hparams.filter_double_steps:
          filters *= 2
        x = common_attention.add_timing_signal_nd(x)
        x = tf.layers.conv2d(x, filters, kernel2, activation=common_layers.belu,
                             strides=(2, 2), padding="SAME")
        x = common_layers.layer_norm(x)

    # Add embedded action if present.
    if self.has_actions:
      x = common_video.inject_additional_input(
          x, action, "action_enc", hparams.action_injection)

    # Inject latent if present. Only for stochastic models.
    x, extra_loss = self.inject_latent(x, frames, target_frame, action)

    x_mid = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
    x, internal_states = self.middle_network(x, internal_states)

    # Up-convolve.
    layer_inputs = list(reversed(layer_inputs))
    for i in range(hparams.num_compress_steps):
      with tf.variable_scope("upstride%d" % i):
        x = tf.nn.dropout(x, 1.0 - self.hparams.dropout)
        if self.has_actions:
          x = common_video.inject_additional_input(
              x, action, "action_enc", hparams.action_injection)
        if i >= hparams.num_compress_steps - hparams.filter_double_steps:
          filters //= 2
        x = tf.layers.conv2d_transpose(
            x, filters, kernel2, activation=common_layers.belu,
            strides=(2, 2), padding="SAME")
        y = layer_inputs[i]
        shape = common_layers.shape_list(y)
        x = x[:, :shape[1], :shape[2], :]
        x = common_layers.layer_norm(x + y)
        x = common_attention.add_timing_signal_nd(x)

    # Cut down to original size.
    x = x[:, :inputs_shape[1], :inputs_shape[2], :]
    x_fin = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
    if self.is_per_pixel_softmax:
      x = tf.layers.dense(x, hparams.problem.num_channels * 256, name="logits")
    else:
      x = tf.layers.dense(x, hparams.problem.num_channels, name="logits")

    # No reward prediction if not needed.
    if not self.has_rewards:
      return x, None, extra_loss, internal_states

    # Reward prediction based on middle and final logits.
    reward_pred = tf.concat([x_mid, x_fin], axis=-1)
    reward_pred = tf.nn.relu(tf.layers.dense(
        reward_pred, 128, name="reward_pred"))
    reward_pred = tf.squeeze(reward_pred, axis=1)  # Remove extra dims
    reward_pred = tf.squeeze(reward_pred, axis=1)  # Remove extra dims
    return x, reward_pred, extra_loss, internal_states
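
The reward head above pools the middle and final feature maps down to [batch, 1, 1, channels], concatenates them, applies a 128-unit dense layer with ReLU, and squeezes away the two singleton spatial axes. A NumPy sketch of just that shape bookkeeping (the channel counts 512 and 64 are made-up example values):

import numpy as np

batch, c_mid, c_fin = 4, 512, 64
x_mid = np.random.rand(batch, 1, 1, c_mid)          # reduce_mean over H, W with keepdims
x_fin = np.random.rand(batch, 1, 1, c_fin)
reward_pred = np.concatenate([x_mid, x_fin], -1)    # [batch, 1, 1, c_mid + c_fin]
w = np.random.rand(c_mid + c_fin, 128)
reward_pred = np.maximum(reward_pred @ w, 0.0)      # dense + relu -> [batch, 1, 1, 128]
reward_pred = np.squeeze(reward_pred, axis=(1, 2))  # the two squeezes -> [batch, 128]
print(reward_pred.shape)                            # (4, 128)
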
    def next_frame(self, frames, actions, rewards, target_frame,
                   internal_states, video_extra):
        del rewards, video_extra

        hparams = self.hparams
        filters = hparams.hidden_size
        kernel2 = (4, 4)
        action = actions[-1]
        activation_fn = common_layers.belu
        if self.hparams.activation_fn == "relu":
            activation_fn = tf.nn.relu

        # Normalize frames.
        frames = [common_layers.standardize_images(f) for f in frames]

        # Stack the inputs.
        if internal_states is not None and hparams.concat_internal_states:
            # Use the first part of the first internal state if asked to concatenate.
            batch_size = common_layers.shape_list(frames[0])[0]
            internal_state = internal_states[0][0][:batch_size, :, :, :]
            stacked_frames = tf.concat(frames + [internal_state], axis=-1)
        else:
            stacked_frames = tf.concat(frames, axis=-1)
        inputs_shape = common_layers.shape_list(stacked_frames)

        # Update internal states early if requested.
        if hparams.concat_internal_states:
            internal_states = self.update_internal_states_early(
                internal_states, frames)

        # Using non-zero bias initializer below for edge cases of uniform inputs.
        x = tf.layers.dense(
            stacked_frames,
            filters,
            name="inputs_embed",
            bias_initializer=tf.random_normal_initializer(stddev=0.01))
        x = common_attention.add_timing_signal_nd(x)

        # Down-stride.
        layer_inputs = [x]
        for i in range(hparams.num_compress_steps):
            with tf.variable_scope("downstride%d" % i):
                layer_inputs.append(x)
                x = tf.nn.dropout(x, 1.0 - self.hparams.dropout)
                x = common_layers.make_even_size(x)
                if i < hparams.filter_double_steps:
                    filters *= 2
                x = common_attention.add_timing_signal_nd(x)
                x = tf.layers.conv2d(x,
                                     filters,
                                     kernel2,
                                     activation=activation_fn,
                                     strides=(2, 2),
                                     padding="SAME")
                x = common_layers.layer_norm(x)

        if self.has_actions:
            with tf.variable_scope("policy"):
                x_flat = tf.layers.flatten(x)
                policy_pred = tf.layers.dense(x_flat,
                                              self.hparams.problem.num_actions)
                value_pred = tf.layers.dense(x_flat, 1)
                value_pred = tf.squeeze(value_pred, axis=-1)
        else:
            policy_pred, value_pred = None, None

        # Add embedded action if present.
        if self.has_actions:
            x = common_video.inject_additional_input(x, action, "action_enc",
                                                     hparams.action_injection)

        # Inject latent if present. Only for stochastic models.
        norm_target_frame = common_layers.standardize_images(target_frame)
        x, extra_loss = self.inject_latent(x, frames, norm_target_frame,
                                           action)

        x_mid = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
        x, internal_states = self.middle_network(x, internal_states)

        # Up-convolve.
        layer_inputs = list(reversed(layer_inputs))
        for i in range(hparams.num_compress_steps):
            with tf.variable_scope("upstride%d" % i):
                x = tf.nn.dropout(x, 1.0 - self.hparams.dropout)
                if self.has_actions:
                    x = common_video.inject_additional_input(
                        x, action, "action_enc", hparams.action_injection)
                if i >= hparams.num_compress_steps - hparams.filter_double_steps:
                    filters //= 2
                x = tf.layers.conv2d_transpose(x,
                                               filters,
                                               kernel2,
                                               activation=activation_fn,
                                               strides=(2, 2),
                                               padding="SAME")
                y = layer_inputs[i]
                shape = common_layers.shape_list(y)
                x = x[:, :shape[1], :shape[2], :]
                x = common_layers.layer_norm(x + y)
                x = common_attention.add_timing_signal_nd(x)

        # Cut down to original size.
        x = x[:, :inputs_shape[1], :inputs_shape[2], :]
        x_fin = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
        if hparams.do_autoregressive_rnn:
            # If enabled, we predict the target frame autoregressively using RNNs.
            # To this end, the current prediction is flattened into one long
            # sequence of sub-pixels, and so is the target frame. Each sub-pixel
            # (an RGB value from 0 to 255) is predicted with an RNN. To avoid
            # doing as many steps as width * height * channels, we only look back
            # over a window of hparams.autoregressive_rnn_lookback pixels.
            with tf.variable_scope("autoregressive_rnn"):
                batch_size = common_layers.shape_list(frames[0])[0]
                # Height, width, channels and lookback are the constants we need.
                h, w = inputs_shape[1], inputs_shape[2]  # 105, 80 on Atari games
                c = hparams.problem.num_channels
                lookback = hparams.autoregressive_rnn_lookback
                assert (h * w) % lookback == 0, (
                    "The lookback must divide the number of pixels.")
                m = (h * w) // lookback  # Batch size multiplier for the RNN.
                # These are logits that will be used as inputs to the RNN.
                rnn_inputs = tf.layers.dense(x, c * 64, name="rnn_inputs")
                # They have shape [batch_size, h, w, c * 64]; reshape now.
                rnn_inputs = tf.reshape(rnn_inputs,
                                        [batch_size * m, lookback * c, 64])
                # Same for the target frame.
                rnn_target = tf.reshape(target_frame,
                                        [batch_size * m, lookback * c])
                # Construct rnn starting state: flatten rnn_inputs, apply a relu layer.
                rnn_start_state = tf.nn.relu(
                    tf.layers.dense(tf.nn.relu(tf.layers.flatten(rnn_inputs)),
                                    256,
                                    name="rnn_start_state"))
                # The RNN prediction API works on bits; each sub-pixel has 8 bits.
                total_num_bits = lookback * c * 8
                # We need to provide RNN targets as bits (due to the API).
                rnn_target_bits = discretization.int_to_bit(rnn_target, 8)
                rnn_target_bits = tf.reshape(rnn_target_bits,
                                             [batch_size * m, total_num_bits])
                if self.is_training:
                    # Run the RNN in training mode and add its loss to the losses.
                    rnn_predict, rnn_loss = discretization.predict_bits_with_lstm(
                        rnn_start_state,
                        128,
                        total_num_bits,
                        target_bits=rnn_target_bits,
                        extra_inputs=rnn_inputs)
                    extra_loss += rnn_loss
                    # We still use non-RNN predictions too in order to guide the network.
                    x = tf.layers.dense(x, c * 256, name="logits")
                    x = tf.reshape(x, [batch_size, h, w, c, 256])
                    rnn_predict = tf.reshape(rnn_predict,
                                             [batch_size, h, w, c, 256])
                    # Mix non-RNN and RNN predictions so that after warmup the RNN is 90%.
                    x = tf.reshape(tf.nn.log_softmax(x),
                                   [batch_size, h, w, c * 256])
                    rnn_predict = tf.nn.log_softmax(rnn_predict)
                    rnn_predict = tf.reshape(rnn_predict,
                                             [batch_size, h, w, c * 256])
                    alpha = 0.9 * common_layers.inverse_lin_decay(
                        hparams.autoregressive_rnn_warmup_steps)
                    x = alpha * rnn_predict + (1.0 - alpha) * x
                else:
                    # In prediction mode, run the RNN without any targets.
                    bits, _ = discretization.predict_bits_with_lstm(
                        rnn_start_state,
                        128,
                        total_num_bits,
                        extra_inputs=rnn_inputs,
                        temperature=0.0
                    )  # No sampling from this RNN, just greedy.
                    # The output is in bits, get back the predicted pixels.
                    bits = tf.reshape(bits, [batch_size * m, lookback * c, 8])
                    ints = discretization.bit_to_int(tf.maximum(bits, 0), 8)
                    ints = tf.reshape(ints, [batch_size, h, w, c])
                    x = tf.reshape(tf.one_hot(ints, 256),
                                   [batch_size, h, w, c * 256])
        elif self.is_per_pixel_softmax:
            x = tf.layers.dense(x,
                                hparams.problem.num_channels * 256,
                                name="logits")
        else:
            x = tf.layers.dense(x, hparams.problem.num_channels, name="logits")

        reward_pred = None
        if self.has_rewards:
            # Reward prediction based on middle and final logits.
            reward_pred = tf.concat([x_mid, x_fin], axis=-1)
            reward_pred = tf.nn.relu(
                tf.layers.dense(reward_pred, 128, name="reward_pred"))
            reward_pred = tf.squeeze(reward_pred, axis=1)  # Remove extra dims
            reward_pred = tf.squeeze(reward_pred, axis=1)  # Remove extra dims

        return x, reward_pred, policy_pred, value_pred, extra_loss, internal_states
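
The reshapes in the autoregressive branch turn the [batch, h, w, c] prediction into an RNN batch of size batch * m, with m = (h * w) // lookback, so every RNN sequence covers lookback * c sub-pixels of 8 bits each. A quick arithmetic sketch with example values (h and w come from the Atari comment above; the lookback and batch size are invented):

batch_size, h, w, c = 8, 105, 80, 3   # 105 x 80 RGB Atari frames
lookback = 400                        # example hparams.autoregressive_rnn_lookback
assert (h * w) % lookback == 0        # 8400 pixels split evenly into chunks

m = (h * w) // lookback               # RNN batch-size multiplier
total_num_bits = lookback * c * 8     # 8 bits per sub-pixel in one chunk

print(m)                              # 21 chunks per frame
print((batch_size * m, lookback * c)) # RNN sequence layout: (168, 1200)
print(total_num_bits)                 # 9600 bits predicted per sequence
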
Example #7
  def bottom_part_tower(self, input_image, input_reward, action, latent,
                        lstm_state, lstm_size, conv_size, concat_latent=False):
    """The bottom part of predictive towers.

    With the current (early) design, the main prediction tower and
    the reward prediction tower share the same architecture. The TF variable
    scope can be adjusted as required to either share or not share the weights
    between the two towers.

    Args:
      input_image: the current image.
      input_reward: the current reward.
      action: the action taken by the agent.
      latent: the latent vector.
      lstm_state: the current internal states of conv lstms.
      lstm_size: the size of lstms.
      conv_size: the size of convolutions.
      concat_latent: whether or not to concatenate the latent at every step.

    Returns:
      - the output of the partial network.
      - intermediate outputs for skip connections.
      - the index of the next unused conv-lstm state (layer_id).
    """
    lstm_func = common_video.conv_lstm_2d
    tile_and_concat = common_video.tile_and_concat

    input_image = common_layers.make_even_size(input_image)
    concat_input_image = tile_and_concat(
        input_image, latent, concat_latent=concat_latent)

    layer_id = 0
    enc0 = tfl.conv2d(
        concat_input_image,
        conv_size[0], [5, 5],
        strides=(2, 2),
        activation=tf.nn.relu,
        padding="SAME",
        name="scale1_conv1")
    enc0 = tfcl.layer_norm(enc0, scope="layer_norm1")

    hidden1, lstm_state[layer_id] = lstm_func(
        enc0, lstm_state[layer_id], lstm_size[layer_id], name="state1")
    hidden1 = tile_and_concat(hidden1, latent, concat_latent=concat_latent)
    hidden1 = tfcl.layer_norm(hidden1, scope="layer_norm2")
    layer_id += 1

    hidden2, lstm_state[layer_id] = lstm_func(
        hidden1, lstm_state[layer_id], lstm_size[layer_id], name="state2")
    hidden2 = tfcl.layer_norm(hidden2, scope="layer_norm3")
    hidden2 = common_layers.make_even_size(hidden2)
    enc1 = tfl.conv2d(hidden2, hidden2.get_shape()[3], [3, 3], strides=(2, 2),
                      padding="SAME", activation=tf.nn.relu, name="conv2")
    enc1 = tile_and_concat(enc1, latent, concat_latent=concat_latent)
    layer_id += 1

    if self.hparams.small_mode:
      hidden4, enc2 = hidden2, enc1
    else:
      hidden3, lstm_state[layer_id] = lstm_func(
          enc1, lstm_state[layer_id], lstm_size[layer_id], name="state3")
      hidden3 = tile_and_concat(hidden3, latent, concat_latent=concat_latent)
      hidden3 = tfcl.layer_norm(hidden3, scope="layer_norm4")
      layer_id += 1

      hidden4, lstm_state[layer_id] = lstm_func(
          hidden3, lstm_state[layer_id], lstm_size[layer_id], name="state4")
      hidden4 = tile_and_concat(hidden4, latent, concat_latent=concat_latent)
      hidden4 = tfcl.layer_norm(hidden4, scope="layer_norm5")
      hidden4 = common_layers.make_even_size(hidden4)
      enc2 = tfl.conv2d(hidden4, hidden4.get_shape()[3], [3, 3], strides=(2, 2),
                        padding="SAME", activation=tf.nn.relu, name="conv3")
      layer_id += 1

    if action is not None:
      enc2 = common_video.inject_additional_input(
          enc2, action, "action_enc", self.hparams.action_injection)
    if input_reward is not None:
      enc2 = common_video.inject_additional_input(
          enc2, input_reward, "reward_enc")
    if latent is not None and not concat_latent:
      with tf.control_dependencies([latent]):
        enc2 = tf.concat([enc2, latent], axis=3)

    enc3 = tfl.conv2d(enc2, hidden4.get_shape()[3], [1, 1], strides=(1, 1),
                      padding="SAME", activation=tf.nn.relu, name="conv4")

    hidden5, lstm_state[layer_id] = lstm_func(
        enc3, lstm_state[layer_id], lstm_size[layer_id], name="state5")
    hidden5 = tfcl.layer_norm(hidden5, scope="layer_norm6")
    hidden5 = tile_and_concat(hidden5, latent, concat_latent=concat_latent)
    layer_id += 1
    return hidden5, (enc0, enc1), layer_id
  def inject_latent(self, layer, inputs, target, action):
    """Inject a deterministic latent based on the target frame."""
    hparams = self.hparams
    final_filters = common_layers.shape_list(layer)[-1]
    filters = hparams.hidden_size
    kernel = (4, 4)
    layer_shape = common_layers.shape_list(layer)

    def add_bits(layer, bits):
      z_mul = tfl.dense(bits, final_filters, name="unbottleneck_mul")
      if not hparams.complex_addn:
        return layer + z_mul
      layer *= tf.nn.sigmoid(z_mul)
      z_add = tfl.dense(bits, final_filters, name="unbottleneck_add")
      layer += z_add
      return layer

    if not self.is_training:
      if hparams.full_latent_tower:
        rand = tf.random_uniform(layer_shape[:-1] + [hparams.bottleneck_bits])
        bits = 2.0 * tf.to_float(tf.less(0.5, rand)) - 1.0
      else:
        bits, _ = discretization.predict_bits_with_lstm(
            layer, hparams.latent_predictor_state_size, hparams.bottleneck_bits,
            temperature=hparams.latent_predictor_temperature)
        bits = tf.expand_dims(tf.expand_dims(bits, axis=1), axis=2)
      return add_bits(layer, bits), 0.0

    # Embed.
    frames = tf.concat(inputs + [target], axis=-1)
    x = tfl.dense(
        frames, filters, name="latent_embed",
        bias_initializer=tf.random_normal_initializer(stddev=0.01))
    x = common_attention.add_timing_signal_nd(x)

    # Add embedded action if present.
    if action is not None:
      x = common_video.inject_additional_input(
          x, action, "action_enc_latent", hparams.action_injection)

    if hparams.full_latent_tower:
      for i in range(hparams.num_compress_steps):
        with tf.variable_scope("latent_downstride%d" % i):
          x = common_layers.make_even_size(x)
          if i < hparams.filter_double_steps:
            filters *= 2
          x = common_attention.add_timing_signal_nd(x)
          x = tfl.conv2d(x, filters, kernel,
                         activation=common_layers.belu,
                         strides=(2, 2), padding="SAME")
          x = common_layers.layer_norm(x)
    else:
      x = common_layers.double_discriminator(x)
      x = tf.expand_dims(tf.expand_dims(x, axis=1), axis=1)

    bits, bits_clean = discretization.tanh_discrete_bottleneck(
        x, hparams.bottleneck_bits, hparams.bottleneck_noise,
        hparams.discretize_warmup_steps, hparams.mode)
    if not hparams.full_latent_tower:
      _, pred_loss = discretization.predict_bits_with_lstm(
          layer, hparams.latent_predictor_state_size, hparams.bottleneck_bits,
          target_bits=bits_clean)
      # Mix predicted bits into the latent bits on the forward pass as noise.
      if hparams.latent_rnn_max_sampling > 0.0:
        with tf.variable_scope(tf.get_variable_scope(), reuse=True):
          bits_pred, _ = discretization.predict_bits_with_lstm(
              layer, hparams.latent_predictor_state_size,
              hparams.bottleneck_bits,
              temperature=hparams.latent_predictor_temperature)
          bits_pred = tf.expand_dims(tf.expand_dims(bits_pred, axis=1), axis=2)
        # Forward pass uses bits_pred; gradients flow through bits_clean.
        bits_pred = bits_clean + tf.stop_gradient(bits_pred - bits_clean)
        # Take bits from the predictor with probability bit_p.
        which_bit = tf.random_uniform(common_layers.shape_list(bits))
        bit_p = common_layers.inverse_lin_decay(hparams.latent_rnn_warmup_steps)
        bit_p *= hparams.latent_rnn_max_sampling
        bits = tf.where(which_bit < bit_p, bits_pred, bits)

    res = add_bits(layer, bits)
    # During training, sometimes skip the latent to help action-conditioning.
    res_p = common_layers.inverse_lin_decay(hparams.latent_rnn_warmup_steps / 2)
    res_p *= hparams.latent_use_max_probability
    res_rand = tf.random_uniform([layer_shape[0]])
    res = tf.where(res_rand < res_p, res, layer)
    return res, pred_loss
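
Both mixing probabilities above (bit_p, which controls how often predicted bits replace the bottleneck bits, and res_p, which controls how often the latent is actually used rather than skipped) ramp up with common_layers.inverse_lin_decay. The sketch below assumes that helper rises roughly linearly from near zero to 1.0 over the given number of steps; its exact floor and form are simplified here, and the hparam values are invented examples.

def inverse_lin_decay(max_step, step, min_value=0.01):
    """Simplified stand-in for common_layers.inverse_lin_decay."""
    return max(min_value, min(1.0, step / float(max_step)))

latent_rnn_warmup_steps = 40000     # example hparam value
latent_rnn_max_sampling = 0.5       # example hparam value
latent_use_max_probability = 0.9    # example hparam value

for step in [0, 10000, 20000, 40000, 80000]:
    bit_p = inverse_lin_decay(latent_rnn_warmup_steps, step) * latent_rnn_max_sampling
    res_p = inverse_lin_decay(latent_rnn_warmup_steps / 2, step) * latent_use_max_probability
    print(step, round(bit_p, 3), round(res_p, 3))
# Early in training almost no predicted bits are mixed in and the latent is
# usually skipped; after warmup bit_p saturates at 0.5 and res_p at 0.9.
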