Esempio n. 1
0
 def _transform_history_observations(self, frames):
   """Expand packed history frames into bits shaped like observations.

   Args:
     frames: tensor whose first two static dimensions are read as
       (batch, history). Presumably integer-packed frames -- TODO confirm
       against the caller.

   Returns:
     The 8-bit expansion of `frames`, reshaped to
     (batch, history) + self.observ_shape and cast to self.observ_dtype.
   """
   num_envs, num_steps = frames.get_shape().as_list()[:2]
   unpacked_bits = discretization.int_to_bit(frames, 8)
   target_shape = (num_envs, num_steps) + self.observ_shape
   reshaped = tf.reshape(unpacked_bits, target_shape)
   return tf.cast(reshaped, self.observ_dtype)
 def testIntToBitOnes(self):
   """Checks that int_to_bit with num_bits=3 turns 7 into three ones."""
   x_bit = tf.ones(shape=[1, 3], dtype=tf.float32)
   x_int = 7 * tf.ones(shape=[1], dtype=tf.int32)
   diff = discretization.int_to_bit(x_int, num_bits=3) - x_bit
   # The graph built here holds no variables, so no initializer has to run;
   # evaluate() executes the tensor in the test case's own session.
   d = self.evaluate(diff)
   self.assertTrue(np.all(d == 0))
 def testIntToBitOnes(self):
     """Checks that int_to_bit with num_bits=3 turns 7 into three ones."""
     x_bit = tf.ones(shape=[1, 3], dtype=tf.float32)
     x_int = 7 * tf.ones(shape=[1], dtype=tf.int32)
     diff = discretization.int_to_bit(x_int, num_bits=3) - x_bit
     # The graph built here holds no variables, so no initializer has to
     # run; evaluate() executes the tensor in the test case's own session.
     d = self.evaluate(diff)
     self.assertTrue(np.all(d == 0))
 def _reset_non_empty(self, indices):
   """Reset the wrapped envs at `indices` and store their unpacked frames.

   Args:
     indices: forwarded to the wrapped batch env and reused as the scatter
       indices into self._observ.

   Returns:
     The unpacked observations, available only after self._observ has been
     updated with them.
   """
   # pylint: disable=protected-access
   packed = self._batch_env._reset_non_empty(indices)
   # pylint: enable=protected-access
   bits = discretization.int_to_bit(packed, 8)
   unpacked = tf.reshape(bits, (-1,) + self.observ_shape)
   update_observ = tf.scatter_update(self._observ, indices, unpacked)
   # Tie the returned tensor to the update so readers see the stored value.
   with tf.control_dependencies([update_observ]):
     return tf.identity(unpacked)
Esempio n. 5
0
 def simulate(self, action):
   """Step the wrapped batch env and refresh the unpacked observation.

   Args:
     action: forwarded unchanged to the wrapped batch env's simulate().

   Returns:
     A (reward, done) pair of tensors that are only produced after
     self._observ has been assigned the new unpacked observation.
   """
   reward, done = self._batch_env.simulate(action)
   # Read the wrapped env's observation only once the step has happened.
   with tf.control_dependencies([reward, done]), tf.variable_scope(
       tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
     bits = discretization.int_to_bit(self._batch_env.observ, 8)
     shaped = tf.reshape(bits, (-1,) + self.observ_shape)
     store_observ = self._observ.assign(tf.cast(shaped, self.observ_dtype))
     with tf.control_dependencies([store_observ]):
       return tf.identity(reward), tf.identity(done)
    def add_task_id(self, task, example, encoder, hparams, is_infer):
        """Convert example to code switching mode by adding a task id.

        Args:
            task: object exposing `has_inputs`, `task_id` and, for
                classification tasks, a `class_labels` attribute.
            example: dict holding "targets" and, when the task has inputs,
                "inputs". It is modified in place.
            encoder: only `vocab_size` is read, for subword vocabularies.
            hparams: supplies `multiproblem_max_input_length`,
                `multiproblem_max_target_length` and
                `multiproblem_fixed_train_length`.
            is_infer: truthy during inference.

        Returns:
            The same `example` dict, with the task id spliced into "inputs"
            (inference with inputs) or "targets" (otherwise) and a
            "task_id" entry added.
        """
        has_inputs = task.has_inputs
        if has_inputs:
            example["inputs"] = example["inputs"][:-1]  # remove EOS token

        if hasattr(task, "class_labels"):
            vocab_type = self.vocab_type
            if vocab_type == text_problems.VocabType.CHARACTER:
                # TODO(urvashik): handle the case where num_labels > 9
                # A single base-10 digit is extracted and shifted by 50;
                # presumably 50 is the character id of "0" -- TODO confirm.
                digit = discretization.int_to_bit(
                    example["targets"], 1, base=10)
                example["targets"] = tf.squeeze(
                    tf.cast(digit + 50, tf.int64), axis=[-1])
            elif vocab_type == text_problems.VocabType.SUBWORD:
                label_offset = encoder.vocab_size + len(self.task_list)
                example["targets"] = label_offset + example["targets"]
        elif has_inputs:
            # sequence with inputs and targets eg: summarization
            input_cap = hparams.multiproblem_max_input_length
            if input_cap > 0:
                example["inputs"] = example["inputs"][:input_cap]
            # Do not truncate targets during inference with beam decoding.
            if hparams.multiproblem_max_target_length > 0 and not is_infer:
                target_cap = hparams.multiproblem_max_target_length
                example["targets"] = example["targets"][:target_cap]

        def pad_or_trim(seq, length):
            """Cut `seq` to `length` items and zero-pad it up to `length`."""
            head = seq[:length]
            missing = length - tf.shape(head)[0]
            return tf.reshape(tf.pad(head, [[0, missing]]), [length])

        if has_inputs and is_infer:
            example["inputs"] = tf.concat(
                [example["inputs"], [task.task_id]], axis=0)
        elif has_inputs:
            source = example.pop("inputs")
            example["targets"] = tf.concat(
                [source, [task.task_id], example["targets"]], axis=0)
            if hparams.multiproblem_fixed_train_length > 0:
                example["targets"] = pad_or_trim(
                    example["targets"],
                    hparams.multiproblem_fixed_train_length)
        else:
            example["targets"] = tf.concat(
                [[task.task_id], example["targets"]], axis=0)
            if not is_infer and hparams.multiproblem_fixed_train_length > 0:
                example["targets"] = pad_or_trim(
                    example["targets"],
                    hparams.multiproblem_fixed_train_length)

        example["task_id"] = tf.constant([task.task_id], dtype=tf.int64)
        return example
  def simulate(self, action):
    """Step the wrapped batch env and refresh the unpacked observation.

    Args:
      action: action tensor; it is routed through a tf.Print debug op and
        then forwarded to the wrapped batch env's simulate().

    Returns:
      A (reward, done) pair of tensors that are only produced after
      self._observ has been assigned the new unpacked observation.
    """
    # Debug aid: logs up to 200 entries of the action when it is evaluated.
    action = tf.Print(action, [action], message="action=", summarize=200)

    # NOTE: a temporary workaround that replaced the action with
    # tf.zeros_like(action) used to sit here and is disabled.
    reward, done = self._batch_env.simulate(action)
    # Read the wrapped env's observation only once the step has happened.
    with tf.control_dependencies([reward, done]), tf.variable_scope(
        tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
      bits = discretization.int_to_bit(self._batch_env.observ, 8)
      shaped = tf.reshape(bits, (-1,) + self.observ_shape)
      store_observ = self._observ.assign(shaped)
      with tf.control_dependencies([store_observ]):
        return tf.identity(reward), tf.identity(done)
Esempio n. 8
0
  def add_task_id(self, task, example, encoder, hparams, is_infer):
    """Convert example to code switching mode by adding a task id.

    Args:
      task: object exposing `has_inputs`, `task_id` and, for classification
        tasks, a `class_labels` attribute.
      example: dict holding "targets" and, when the task has inputs,
        "inputs". It is modified in place.
      encoder: only `vocab_size` is read, for subword vocabularies.
      hparams: supplies `multiproblem_max_input_length`,
        `multiproblem_max_target_length` and
        `multiproblem_fixed_train_length`.
      is_infer: truthy during inference.

    Returns:
      The same `example` dict, with the task id spliced into "inputs"
      (inference with inputs) or "targets" (otherwise) and a "task_id"
      entry added.
    """
    has_inputs = task.has_inputs
    if has_inputs:
      example["inputs"] = example["inputs"][:-1]  # remove EOS token

    if hasattr(task, "class_labels"):
      vocab_type = self.vocab_type
      if vocab_type == text_problems.VocabType.CHARACTER:
        # TODO(urvashik): handle the case where num_labels > 9
        # A single base-10 digit is extracted and shifted by 50; presumably
        # 50 is the character id of "0" -- TODO confirm.
        digit = discretization.int_to_bit(example["targets"], 1, base=10)
        example["targets"] = tf.squeeze(
            tf.cast(digit + 50, tf.int64), axis=[-1])
      elif vocab_type == text_problems.VocabType.SUBWORD:
        label_offset = encoder.vocab_size + len(self.task_list)
        example["targets"] = label_offset + example["targets"]
    elif has_inputs:
      # sequence with inputs and targets eg: summarization
      input_cap = hparams.multiproblem_max_input_length
      if input_cap > 0:
        example["inputs"] = example["inputs"][:input_cap]
      # Do not truncate targets during inference with beam decoding.
      if hparams.multiproblem_max_target_length > 0 and not is_infer:
        target_cap = hparams.multiproblem_max_target_length
        example["targets"] = example["targets"][:target_cap]

    def pad_or_trim(seq, length):
      """Cut `seq` to `length` items and zero-pad it up to `length`."""
      head = seq[:length]
      missing = length - tf.shape(head)[0]
      return tf.reshape(tf.pad(head, [[0, missing]]), [length])

    if has_inputs and is_infer:
      example["inputs"] = tf.concat(
          [example["inputs"], [task.task_id]], axis=0)
    elif has_inputs:
      source = example.pop("inputs")
      example["targets"] = tf.concat(
          [source, [task.task_id], example["targets"]], axis=0)
      if hparams.multiproblem_fixed_train_length > 0:
        example["targets"] = pad_or_trim(
            example["targets"], hparams.multiproblem_fixed_train_length)
    else:
      example["targets"] = tf.concat(
          [[task.task_id], example["targets"]], axis=0)
      if not is_infer and hparams.multiproblem_fixed_train_length > 0:
        example["targets"] = pad_or_trim(
            example["targets"], hparams.multiproblem_fixed_train_length)

    example["task_id"] = tf.constant([task.task_id], dtype=tf.int64)
    return example
Esempio n. 9
0
  def add_task_id(self, task, example):
    """Convert example to code switching mode by adding a task id.

    Args:
      task: object exposing `has_inputs`, `task_id` and, for classification
        tasks, a `class_labels` attribute.
      example: dict holding "targets" and, when the task has inputs,
        "inputs". It is modified in place.

    Returns:
      The same dict; "targets" becomes [inputs,] task_id, targets and the
      "inputs" entry is removed when the task has inputs.
    """
    if hasattr(task, "class_labels"):
      # TODO(urvashik): handle the case where num_labels > 9
      # A single base-10 digit is extracted and shifted by 50; presumably
      # 50 is the character id of "0" -- TODO confirm.
      digit = discretization.int_to_bit(example["targets"], 1, base=10)
      example["targets"] = tf.squeeze(
          tf.cast(digit + 50, tf.int64), axis=[-1])

    prefix = [example.pop("inputs")] if task.has_inputs else []
    pieces = prefix + [[task.task_id], example["targets"]]
    example["targets"] = tf.concat(pieces, 0)
    return example
Esempio n. 10
0
  def add_task_id(self, task, example):
    """Convert example to code switching mode by adding a task id.

    Args:
      task: object exposing `has_inputs`, `task_id` and, for classification
        tasks, a `class_labels` attribute.
      example: dict holding "targets" and, when the task has inputs,
        "inputs". It is modified in place.

    Returns:
      The same dict; "targets" becomes [inputs,] task_id, targets and the
      "inputs" entry is removed when the task has inputs.
    """
    if hasattr(task, "class_labels"):
      # TODO(urvashik): handle the case where num_labels > 9
      # A single base-10 digit is extracted and shifted by 50; presumably
      # 50 is the character id of "0" -- TODO confirm.
      label_digit = discretization.int_to_bit(example["targets"], 1, base=10)
      as_int64 = tf.cast(label_digit + 50, tf.int64)
      example["targets"] = tf.squeeze(as_int64, axis=[-1])

    leading = [example.pop("inputs")] if task.has_inputs else []
    example["targets"] = tf.concat(
        leading + [[task.task_id], example["targets"]], 0)
    return example
Esempio n. 11
0
  def add_task_id(self, task, example, encoder):
    """Convert example to code switching mode by adding a task id.

    Args:
      task: object exposing `has_inputs`, `task_id` and, for classification
        tasks, a `class_labels` attribute.
      example: dict holding "targets" and, when the task has inputs,
        "inputs". It is modified in place.
      encoder: only `vocab_size` is read, for subword vocabularies.

    Returns:
      The same dict; "targets" becomes [inputs,] task_id, targets and the
      "inputs" entry is removed when the task has inputs.
    """
    if hasattr(task, "class_labels"):
      vocab_type = self.vocab_type
      if vocab_type == text_problems.VocabType.CHARACTER:
        # TODO(urvashik): handle the case where num_labels > 9
        # A single base-10 digit is extracted and shifted by 50; presumably
        # 50 is the character id of "0" -- TODO confirm.
        digit = discretization.int_to_bit(example["targets"], 1, base=10)
        example["targets"] = tf.squeeze(
            tf.cast(digit + 50, tf.int64), axis=[-1])
      elif vocab_type == text_problems.VocabType.SUBWORD:
        # Class ids are shifted past the vocabulary and the task ids.
        label_offset = encoder.vocab_size + len(self.task_list)
        example["targets"] = label_offset + example["targets"]

    prefix = [example.pop("inputs")] if task.has_inputs else []
    pieces = prefix + [[task.task_id], example["targets"]]
    example["targets"] = tf.concat(pieces, 0)
    return example
Esempio n. 12
0
 def testIntToBitOnes(self):
   """Checks that int_to_bit with num_bits=3 turns 7 into three ones."""
   expected_bits = tf.ones(shape=[1, 3], dtype=tf.float32)
   sevens = 7 * tf.ones(shape=[1], dtype=tf.int32)
   actual_bits = discretization.int_to_bit(sevens, num_bits=3)
   residual = self.evaluate(actual_bits - expected_bits)
   self.assertTrue(np.all(residual == 0))
Esempio n. 13
0
    def _setup(self):
        """Build the in-graph data collection pipeline for this problem.

        Initializes the optional debug counters, creates a single-agent
        batch environment wrapped with MemoryWrapper and
        StackAndSkipWrapper(skip=4), picks the policy, optionally builds the
        autoencoder feeds, defines the collect op and finally exposes the
        size and dequeue ops of the memory wrapper's queue.
        """
        if self.make_extra_debug_info:
            self.report_reward_statistics_every = 10
            self.dones = 0
            self.real_reward = 0
            # Slight weirdness to make sim env and real env aligned
            if self.simulated_environment:
                self.real_env.reset()
                # Step the real env with action 0 once per input frame.
                for _ in range(self.num_input_frames):
                    self.real_ob, _, _, _ = self.real_env.step(0)
            self.total_sim_reward, self.total_real_reward = 0.0, 0.0
            self.sum_of_rewards = 0.0
            self.successful_episode_reward_predictions = 0

        # Append the wrappers this class always needs to the configured ones.
        in_graph_wrappers = self.in_graph_wrappers + [
            (atari.MemoryWrapper, {}), (StackAndSkipWrapper, {
                "skip": 4
            })
        ]
        env_hparams = tf.contrib.training.HParams(
            in_graph_wrappers=in_graph_wrappers,
            problem=self.real_env_problem if self.real_env_problem else self,
            simulated_environment=self.simulated_environment)
        if self.simulated_environment:
            env_hparams.add_hparam("simulation_random_starts",
                                   self.simulation_random_starts)
            env_hparams.add_hparam("intrinsic_reward_scale",
                                   self.intrinsic_reward_scale)

        generator_batch_env = batch_env_factory(self.environment_spec,
                                                env_hparams,
                                                num_agents=1,
                                                xvfb=False)

        with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
            if FLAGS.agent_policy_path:
                policy_lambda = self.collect_hparams.network
            else:
                # When no agent_policy_path is set, just generate random samples.
                policy_lambda = rl.random_policy_fun

        if FLAGS.autoencoder_path:
            # TODO(lukaszkaiser): remove hard-coded autoencoder params.
            with tf.variable_scope(tf.get_variable_scope(),
                                   reuse=tf.AUTO_REUSE):
                self.setup_autoencoder()
                autoencoder_model = self.autoencoder_model
                # Feeds for autoencoding.
                shape = [
                    self.raw_frame_height, self.raw_frame_width,
                    self.num_channels
                ]
                self.autoencoder_feed = tf.placeholder(tf.int32, shape=shape)
                self.autoencoder_result = self.autoencode_tensor(
                    self.autoencoder_feed)
                # Now for autodecoding.
                shape = self.frame_shape
                self.autodecoder_feed = tf.placeholder(tf.int32, shape=shape)
                # Each fed integer is expanded into 8 bits, which are folded
                # into the channel dimension of the bottleneck.
                bottleneck = tf.reshape(
                    discretization.int_to_bit(self.autodecoder_feed, 8), [
                        1, 1, self.frame_height, self.frame_width,
                        self.num_channels * 8
                    ])
                autoencoder_model.set_mode(tf.estimator.ModeKeys.PREDICT)
                self.autodecoder_result = autoencoder_model.decode(bottleneck)

        def preprocess_fn(x):
            """Autoencode the 4 stacked frames in `x` and restack them."""
            shape = [
                self.raw_frame_height, self.raw_frame_width, self.num_channels
            ]
            # TODO(lukaszkaiser): we assume x comes from StackAndSkipWrapper skip=4.
            # Split the last axis into 4 frames and encode them as a batch of 4.
            xs = [tf.reshape(t, [1] + shape) for t in tf.split(x, 4, axis=-1)]
            autoencoded = self.autoencode_tensor(tf.concat(xs, axis=0),
                                                 batch_size=4)
            encs = [
                tf.squeeze(t, axis=[0])
                for t in tf.split(autoencoded, 4, axis=0)
            ]
            # Re-join the encoded frames on the last axis and add a batch dim.
            res = tf.to_float(tf.concat(encs, axis=-1))
            return tf.expand_dims(res, axis=0)

        # TODO(lukaszkaiser): x is from StackAndSkipWrapper thus 4*num_channels.
        shape = [1, self.frame_height, self.frame_width, 4 * self.num_channels]
        # Preprocess only when an autoencoder exists and the env is real.
        do_preprocess = (self.autoencoder_model is not None
                         and not self.simulated_environment)
        preprocess = (preprocess_fn, shape) if do_preprocess else None

        def policy(x):
            """Apply the selected policy function to observation `x`."""
            return policy_lambda(self.environment_spec().action_space,
                                 self.collect_hparams, x)

        with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
            self.collect_hparams.epoch_length = 10
            _, self.collect_trigger_op = collect.define_collect(
                policy,
                generator_batch_env,
                self.collect_hparams,
                eval_phase=self.eval_phase,
                scope="define_collect",
                preprocess=preprocess)

        # NOTE(review): "avilable" is misspelled, but it is an attribute name
        # that other code may read, so it is left unchanged.
        self.avilable_data_size_op = atari.MemoryWrapper.singleton.speculum.size(
        )
        self.data_get_op = atari.MemoryWrapper.singleton.speculum.dequeue()
Esempio n. 14
0
    def inject_latent(self, layer, features, filters):
        """Inject a deterministic latent based on the target frame.

        Args:
            layer: tensor the latent is added to; its last dimension sets
                the width of the projected latent.
            features: dict read at "cur_target_frame" and "inputs" (only
                when not predicting).
            filters: ignored; it is deleted immediately and replaced by
                hparams.hidden_size.

        Returns:
            A pair (layer with the latent added, prediction loss). The loss
            is 0.0 when predicting or when hparams.full_latent_tower is set.
        """
        del filters
        hparams = self.hparams
        final_filters = common_layers.shape_list(layer)[-1]
        filters = hparams.hidden_size
        kernel = (4, 4)
        layer_shape = common_layers.shape_list(layer)
        batch_size = layer_shape[0]
        state_size = hparams.latent_predictor_state_size
        # The cell and the two dense layers are created once and shared by
        # the prediction branch and the training branch below.
        lstm_cell = tf.contrib.rnn.LSTMCell(state_size)
        discrete_predict = tf.layers.Dense(256, name="discrete_predict")
        discrete_embed = tf.layers.Dense(state_size, name="discrete_embed")

        def add_d(layer, d):
            """Project latent `d` to the layer width and merge it in."""
            z_mul = tf.layers.dense(d, final_filters, name="unbottleneck_mul")
            if not hparams.complex_addn:
                return layer + z_mul
            # Complex variant: gate the layer by sigmoid(z_mul), then add a
            # second projection of d.
            layer *= tf.nn.sigmoid(z_mul)
            z_add = tf.layers.dense(d, final_filters, name="unbottleneck_add")
            layer += z_add
            return layer

        if self.is_predicting:
            if hparams.full_latent_tower:
                # No target frame is used here: draw uniform random values.
                rand = tf.random_uniform(layer_shape[:-1] +
                                         [hparams.bottleneck_bits])
            else:
                # Sample the latent 8 bits at a time with the LSTM, starting
                # from states computed from the flattened layer.
                layer_pred = tf.reshape(
                    layer, [batch_size, prod(layer_shape[1:])])
                prediction = tf.layers.dense(layer_pred,
                                             state_size,
                                             name="istate")
                c_state = tf.layers.dense(layer_pred,
                                          state_size,
                                          name="cstate")
                m_state = tf.layers.dense(layer_pred,
                                          state_size,
                                          name="mstate")
                state = (c_state, m_state)
                outputs = []
                for i in range(hparams.bottleneck_bits // 8):
                    output, state = lstm_cell(prediction, state)
                    discrete_logits = discrete_predict(output)
                    discrete_samples = common_layers.sample_with_temperature(
                        discrete_logits, hparams.latent_predictor_temperature)
                    outputs.append(tf.expand_dims(discrete_samples, axis=1))
                    # The embedded sample is the next step's LSTM input.
                    prediction = discrete_embed(
                        tf.one_hot(discrete_samples, 256))
                outputs = tf.concat(outputs, axis=1)
                # Expand every sampled value (one of 256) into 8 bits.
                outputs = discretization.int_to_bit(outputs, 8)
                rand = tf.reshape(outputs,
                                  [batch_size, 1, 1, hparams.bottleneck_bits])
            # Threshold at 0.5 and map to -1.0 / +1.0.
            d = 2.0 * tf.to_float(tf.less(0.5, rand)) - 1.0
            return add_d(layer, d), 0.0

        # Embed.
        frames = tf.concat([features["cur_target_frame"], features["inputs"]],
                           axis=-1)
        x = tf.layers.dense(
            frames,
            filters,
            name="latent_embed",
            bias_initializer=tf.random_normal_initializer(stddev=0.01))
        x = common_attention.add_timing_signal_nd(x)

        if hparams.full_latent_tower:
            # Stack of stride-2 convolutions; the filter count doubles for
            # the first hparams.filter_double_steps steps.
            for i in range(hparams.num_compress_steps):
                with tf.variable_scope("latent_downstride%d" % i):
                    x = common_layers.make_even_size(x)
                    if i < hparams.filter_double_steps:
                        filters *= 2
                    x = common_attention.add_timing_signal_nd(x)
                    x = tf.layers.conv2d(x,
                                         filters,
                                         kernel,
                                         activation=common_layers.belu,
                                         strides=(2, 2),
                                         padding="SAME")
                    x = common_layers.layer_norm(x)
        else:
            x = common_layers.double_discriminator(x)
            # Insert two singleton dimensions after the batch dimension.
            x = tf.expand_dims(tf.expand_dims(x, axis=1), axis=1)
        x = tf.layers.dense(x, hparams.bottleneck_bits, name="bottleneck")
        x0 = tf.tanh(x)
        # Forward value is -1/+1 by the sign of x0; stop_gradient hides the
        # correction term so the gradient is that of x0.
        d = x0 + tf.stop_gradient(2.0 * tf.to_float(tf.less(0.0, x0)) - 1.0 -
                                  x0)
        pred_loss = 0.0
        if not hparams.full_latent_tower:
            # Turn the -1/+1 latent into 0/1 bits, group them by 8 and
            # convert each group to an integer class in [0, 256).
            d_pred = tf.reshape(tf.maximum(tf.stop_gradient(d), 0),
                                [batch_size, hparams.bottleneck_bits // 8, 8])
            d_int = discretization.bit_to_int(d_pred, 8)
            tf.summary.histogram("d_int", tf.reshape(d_int, [-1]))
            d_hot = tf.one_hot(d_int, 256, axis=-1)
            d_pred = discrete_embed(d_hot)
            layer_pred = tf.reshape(layer, [batch_size, prod(layer_shape[1:])])
            prediction0 = tf.layers.dense(layer_pred,
                                          state_size,
                                          name="istate")
            c_state = tf.layers.dense(layer_pred, state_size, name="cstate")
            m_state = tf.layers.dense(layer_pred, state_size, name="mstate")
            # LSTM inputs: the initial prediction followed by the embedded
            # true groups; only the first bottleneck_bits // 8 are consumed.
            pred = tf.concat([tf.expand_dims(prediction0, axis=1), d_pred],
                             axis=1)
            state = (c_state, m_state)
            outputs = []
            for i in range(hparams.bottleneck_bits // 8):
                output, state = lstm_cell(pred[:, i, :], state)
                outputs.append(tf.expand_dims(output, axis=1))
            outputs = tf.concat(outputs, axis=1)
            d_int_pred = discrete_predict(outputs)
            # Train the LSTM to predict each 8-bit group of the latent.
            pred_loss = tf.losses.sparse_softmax_cross_entropy(
                logits=d_int_pred, labels=d_int)
            pred_loss = tf.reduce_mean(pred_loss)
        if hparams.mode == tf.estimator.ModeKeys.TRAIN:
            # In training, d is recomputed from a noised bottleneck;
            # pred_loss above still refers to the noise-free bits.
            x += tf.truncated_normal(common_layers.shape_list(x),
                                     mean=0.0,
                                     stddev=0.2)
            x = tf.tanh(x)
            # Multiply entries by -1 where the uniform draw does not exceed
            # hparams.bottleneck_noise, and by +1 elsewhere.
            noise = tf.random_uniform(common_layers.shape_list(x))
            noise = 2.0 * tf.to_float(tf.less(hparams.bottleneck_noise,
                                              noise)) - 1.0
            x *= noise
            d = x + tf.stop_gradient(2.0 * tf.to_float(tf.less(0.0, x)) - 1.0 -
                                     x)
            # Per example, use the discretized d with probability p and the
            # continuous x otherwise. p presumably grows over
            # hparams.discrete_warmup_steps -- TODO confirm inverse_lin_decay.
            p = common_layers.inverse_lin_decay(hparams.discrete_warmup_steps)
            d = tf.where(tf.less(tf.random_uniform([batch_size]), p), d, x)
        return add_d(layer, d), pred_loss
    def next_frame(self, frames, actions, rewards, target_frame,
                   internal_states, video_extra):
        """Predict the next frame from the given input frames.

        Args:
            frames: list of input frame tensors; they are standardized and
                concatenated on the last axis.
            actions: sequence of actions; only the last one is used.
            rewards: unused (deleted on entry).
            target_frame: frame to predict; used for latent injection and as
                the target of the autoregressive RNN.
            internal_states: recurrent states passed to the middle network,
                or None.
            video_extra: unused (deleted on entry).

        Returns:
            A tuple (x, reward_pred, policy_pred, value_pred, extra_loss,
            internal_states). reward_pred is None without rewards;
            policy_pred and value_pred are None without actions.

        Raises:
            AssertionError: if hparams.do_autoregressive_rnn is set and
                hparams.autoregressive_rnn_lookback does not divide the
                number of pixels (height * width).
        """
        del rewards, video_extra

        hparams = self.hparams
        filters = hparams.hidden_size
        kernel2 = (4, 4)
        action = actions[-1]
        activation_fn = common_layers.belu
        if self.hparams.activation_fn == "relu":
            activation_fn = tf.nn.relu

        # Normalize frames.
        frames = [common_layers.standardize_images(f) for f in frames]

        # Stack the inputs.
        if internal_states is not None and hparams.concat_internal_states:
            # Use the first part of the first internal state if asked to concatenate.
            batch_size = common_layers.shape_list(frames[0])[0]
            internal_state = internal_states[0][0][:batch_size, :, :, :]
            stacked_frames = tf.concat(frames + [internal_state], axis=-1)
        else:
            stacked_frames = tf.concat(frames, axis=-1)
        inputs_shape = common_layers.shape_list(stacked_frames)

        # Update internal states early if requested.
        if hparams.concat_internal_states:
            internal_states = self.update_internal_states_early(
                internal_states, frames)

        # Using non-zero bias initializer below for edge cases of uniform inputs.
        x = tf.layers.dense(
            stacked_frames,
            filters,
            name="inputs_embed",
            bias_initializer=tf.random_normal_initializer(stddev=0.01))
        x = common_attention.add_timing_signal_nd(x)

        # Down-stride.
        # layer_inputs keeps the input of every down step for the skip
        # connections of the up-convolution below.
        layer_inputs = [x]
        for i in range(hparams.num_compress_steps):
            with tf.variable_scope("downstride%d" % i):
                layer_inputs.append(x)
                # Second argument is the keep probability.
                x = tf.nn.dropout(x, 1.0 - self.hparams.dropout)
                x = common_layers.make_even_size(x)
                if i < hparams.filter_double_steps:
                    filters *= 2
                x = common_attention.add_timing_signal_nd(x)
                x = tf.layers.conv2d(x,
                                     filters,
                                     kernel2,
                                     activation=activation_fn,
                                     strides=(2, 2),
                                     padding="SAME")
                x = common_layers.layer_norm(x)

        if self.has_actions:
            # Policy and value heads read the flattened bottleneck.
            with tf.variable_scope("policy"):
                x_flat = tf.layers.flatten(x)
                policy_pred = tf.layers.dense(x_flat,
                                              self.hparams.problem.num_actions)
                value_pred = tf.layers.dense(x_flat, 1)
                value_pred = tf.squeeze(value_pred, axis=-1)
        else:
            policy_pred, value_pred = None, None

        # Add embedded action if present.
        if self.has_actions:
            x = common_video.inject_additional_input(x, action, "action_enc",
                                                     hparams.action_injection)

        # Inject latent if present. Only for stochastic models.
        norm_target_frame = common_layers.standardize_images(target_frame)
        x, extra_loss = self.inject_latent(x, frames, norm_target_frame,
                                           action)

        # Spatial mean of the bottleneck, reused for reward prediction.
        x_mid = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
        x, internal_states = self.middle_network(x, internal_states)

        # Up-convolve.
        layer_inputs = list(reversed(layer_inputs))
        for i in range(hparams.num_compress_steps):
            with tf.variable_scope("upstride%d" % i):
                x = tf.nn.dropout(x, 1.0 - self.hparams.dropout)
                if self.has_actions:
                    x = common_video.inject_additional_input(
                        x, action, "action_enc", hparams.action_injection)
                if i >= hparams.num_compress_steps - hparams.filter_double_steps:
                    filters //= 2
                x = tf.layers.conv2d_transpose(x,
                                               filters,
                                               kernel2,
                                               activation=activation_fn,
                                               strides=(2, 2),
                                               padding="SAME")
                # Crop to the matching down-stride input and add it as a
                # skip connection.
                y = layer_inputs[i]
                shape = common_layers.shape_list(y)
                x = x[:, :shape[1], :shape[2], :]
                x = common_layers.layer_norm(x + y)
                x = common_attention.add_timing_signal_nd(x)

        # Cut down to original size.
        x = x[:, :inputs_shape[1], :inputs_shape[2], :]
        x_fin = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
        if hparams.do_autoregressive_rnn:
            # If enabled, we predict the target frame autoregressively using rnns.
            # To this end, the current prediction is flattened into one long sequence
            # of sub-pixels, and so is the target frame. Each sub-pixel (RGB value,
            # from 0 to 255) is predicted with an RNN. To avoid doing as many steps
            # as width * height * channels, we only use a number of pixels back,
            # as many as hparams.autoregressive_rnn_lookback.
            with tf.variable_scope("autoregressive_rnn"):
                batch_size = common_layers.shape_list(frames[0])[0]
                # Height, width, channels and lookback are the constants we need.
                h, w = inputs_shape[1], inputs_shape[
                    2]  # 105, 80 on Atari games
                c = hparams.problem.num_channels
                lookback = hparams.autoregressive_rnn_lookback
                # The pixels are cut into chunks of `lookback`, so lookback
                # has to divide the pixel count evenly.
                assert (h * w) % lookback == 0, (
                    "Lookback must divide the number of pixels.")
                m = (h * w) // lookback  # Batch size multiplier for the RNN.
                # These are logits that will be used as inputs to the RNN.
                rnn_inputs = tf.layers.dense(x, c * 64, name="rnn_inputs")
                # They are of shape [batch_size, h, w, c, 64], reshaping now.
                rnn_inputs = tf.reshape(rnn_inputs,
                                        [batch_size * m, lookback * c, 64])
                # Same for the target frame.
                rnn_target = tf.reshape(target_frame,
                                        [batch_size * m, lookback * c])
                # Construct rnn starting state: flatten rnn_inputs, apply a relu layer.
                rnn_start_state = tf.nn.relu(
                    tf.layers.dense(tf.nn.relu(tf.layers.flatten(rnn_inputs)),
                                    256,
                                    name="rnn_start_state"))
                # Our RNN function API is on bits, each subpixel has 8 bits.
                total_num_bits = lookback * c * 8
                # We need to provide RNN targets as bits (due to the API).
                rnn_target_bits = discretization.int_to_bit(rnn_target, 8)
                rnn_target_bits = tf.reshape(rnn_target_bits,
                                             [batch_size * m, total_num_bits])
                if self.is_training:
                    # Run the RNN in training mode, add its loss to the losses.
                    rnn_predict, rnn_loss = discretization.predict_bits_with_lstm(
                        rnn_start_state,
                        128,
                        total_num_bits,
                        target_bits=rnn_target_bits,
                        extra_inputs=rnn_inputs)
                    extra_loss += rnn_loss
                    # We still use non-RNN predictions too in order to guide the network.
                    x = tf.layers.dense(x, c * 256, name="logits")
                    x = tf.reshape(x, [batch_size, h, w, c, 256])
                    rnn_predict = tf.reshape(rnn_predict,
                                             [batch_size, h, w, c, 256])
                    # Mix non-RNN and RNN predictions so that after warmup the RNN is 90%.
                    x = tf.reshape(tf.nn.log_softmax(x),
                                   [batch_size, h, w, c * 256])
                    rnn_predict = tf.nn.log_softmax(rnn_predict)
                    rnn_predict = tf.reshape(rnn_predict,
                                             [batch_size, h, w, c * 256])
                    alpha = 0.9 * common_layers.inverse_lin_decay(
                        hparams.autoregressive_rnn_warmup_steps)
                    x = alpha * rnn_predict + (1.0 - alpha) * x
                else:
                    # In prediction mode, run the RNN without any targets.
                    bits, _ = discretization.predict_bits_with_lstm(
                        rnn_start_state,
                        128,
                        total_num_bits,
                        extra_inputs=rnn_inputs,
                        temperature=0.0
                    )  # No sampling from this RNN, just greedy.
                    # The output is in bits, get back the predicted pixels.
                    bits = tf.reshape(bits, [batch_size * m, lookback * c, 8])
                    ints = discretization.bit_to_int(tf.maximum(bits, 0), 8)
                    ints = tf.reshape(ints, [batch_size, h, w, c])
                    x = tf.reshape(tf.one_hot(ints, 256),
                                   [batch_size, h, w, c * 256])
        elif self.is_per_pixel_softmax:
            x = tf.layers.dense(x,
                                hparams.problem.num_channels * 256,
                                name="logits")
        else:
            x = tf.layers.dense(x, hparams.problem.num_channels, name="logits")

        reward_pred = None
        if self.has_rewards:
            # Reward prediction based on middle and final logits.
            reward_pred = tf.concat([x_mid, x_fin], axis=-1)
            reward_pred = tf.nn.relu(
                tf.layers.dense(reward_pred, 128, name="reward_pred"))
            # The spatial means kept two singleton dims; drop both of them.
            reward_pred = tf.squeeze(reward_pred, axis=1)  # Remove extra dims
            reward_pred = tf.squeeze(reward_pred, axis=1)  # Remove extra dims

        return x, reward_pred, policy_pred, value_pred, extra_loss, internal_states
Esempio n. 16
0
 def testIntToBitOnes(self):
   """Checks that int_to_bit with num_bits=3 turns 7 into three ones."""
   all_ones = tf.ones(shape=[1, 3], dtype=tf.float32)
   seven = 7 * tf.ones(shape=[1], dtype=tf.int32)
   bits_of_seven = discretization.int_to_bit(seven, num_bits=3)
   mismatch = self.evaluate(bits_of_seven - all_ones)
   self.assertTrue(np.all(mismatch == 0))
Esempio n. 17
0
  def _setup(self):
    """Build the in-graph data collection pipeline for this problem.

    Initializes the optional debug counters, creates a single-agent batch
    environment wrapped with MemoryWrapper, builds the policy template and
    the collect op, optionally builds the autoencoder feeds, and exposes the
    size and dequeue ops of the memory wrapper's queue.
    """
    if self.make_extra_debug_info:
      self.report_reward_statistics_every = 10
      self.dones = 0
      self.real_reward = 0
      self.real_env.reset()
      # Slight weirdness to make sim env and real env aligned
      # Step the real env with action 0 once per input frame.
      for _ in range(self.num_input_frames):
        self.real_ob, _, _, _ = self.real_env.step(0)
      self.total_sim_reward, self.total_real_reward = 0.0, 0.0
      self.sum_of_rewards = 0.0
      self.successful_episode_reward_predictions = 0

    # Append the memory wrapper to the configured in-graph wrappers.
    in_graph_wrappers = self.in_graph_wrappers + [(atari.MemoryWrapper, {})]
    env_hparams = tf.contrib.training.HParams(
        in_graph_wrappers=in_graph_wrappers,
        problem=self,
        simulated_environment=self.simulated_environment)

    generator_batch_env = batch_env_factory(
        self.environment_spec, env_hparams, num_agents=1, xvfb=False)

    with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
      if FLAGS.agent_policy_path:
        policy_lambda = self.collect_hparams.network
      else:
        # When no agent_policy_path is set, just generate random samples.
        policy_lambda = rl.random_policy_fun
      # Bind the action space and hparams, and wrap the policy in a
      # template named "network".
      policy_factory = tf.make_template(
          "network",
          functools.partial(policy_lambda, self.environment_spec().action_space,
                            self.collect_hparams),
          create_scope_now_=True,
          unique_name_="network")

    with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
      self.collect_hparams.epoch_length = 10
      _, self.collect_trigger_op = collect.define_collect(
          policy_factory, generator_batch_env, self.collect_hparams,
          eval_phase=False, scope="define_collect")

    if FLAGS.autoencoder_path:
      # TODO(lukaszkaiser): remove hard-coded autoencoder params.
      with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
        self.setup_autoencoder()
        autoencoder_model = self.autoencoder_model
        # Feeds for autoencoding.
        shape = [self.raw_frame_height, self.raw_frame_width, self.num_channels]
        self.autoencoder_feed = tf.placeholder(tf.int32, shape=shape)
        # Two leading singleton dimensions are added before encoding.
        autoencoded = autoencoder_model.encode(
            tf.reshape(self.autoencoder_feed, [1, 1] + shape))
        autoencoded = tf.reshape(
            autoencoded, [self.frame_height, self.frame_width,
                          self.num_channels, 8])  # 8-bit groups.
        self.autoencoder_result = discretization.bit_to_int(autoencoded, 8)
        # Now for autodecoding.
        shape = [self.frame_height, self.frame_width, self.num_channels]
        self.autodecoder_feed = tf.placeholder(tf.int32, shape=shape)
        # Each fed integer is expanded into 8 bits, which are folded into
        # the channel dimension of the bottleneck.
        bottleneck = tf.reshape(
            discretization.int_to_bit(self.autodecoder_feed, 8),
            [1, 1, self.frame_height, self.frame_width, self.num_channels * 8])
        autoencoder_model.set_mode(tf.estimator.ModeKeys.PREDICT)
        self.autodecoder_result = autoencoder_model.decode(bottleneck)

    # NOTE(review): "avilable" is misspelled, but it is an attribute name
    # that other code may read, so it is left unchanged.
    self.avilable_data_size_op = atari.MemoryWrapper.singleton.speculum.size()
    self.data_get_op = atari.MemoryWrapper.singleton.speculum.dequeue()
Esempio n. 18
0
def normalize_example_nlp(task, example, is_infer, vocab_type, vocab_offset,
                          max_input_length, max_target_length,
                          fixed_train_length):
    """Rewrite an NLP example into the layout shared by all merged tasks.

    The task id is spliced into the token sequence so that examples coming
    from different tasks can be mixed together:

    * Task with inputs, training: "inputs" is removed from the example and
      targets = inputs + [task_id] + targets.
    * Task with inputs, inference: inputs = inputs + [task_id]; "targets" is
      left without the task id and "inputs" stays in the example.
    * Task without inputs: targets = [task_id] + targets.

    For tasks that expose a `class_labels` attribute the targets are first
    re-encoded according to `vocab_type` (see the inline comments); the
    `max_input_length` / `max_target_length` limits are only applied to tasks
    without `class_labels` that have inputs.

    Args:
      task: the Problem class of the task we are normalizing.
      example: a dictionary of tensors, the example to normalize. It is
        modified in place and also returned.
      is_infer: bool, whether we are performing inference or not.
      vocab_type: the type of vocabulary in use.
      vocab_offset: integer, offset added to class targets for subword
        vocabularies.
      max_input_length: maximum length to cut inputs to; ignored if <= 0.
      max_target_length: maximum length to cut targets to; ignored if <= 0
        and never applied during inference.
      fixed_train_length: if > 0, training targets are cut or padded to
        exactly this length after the task id has been spliced in.

    Returns:
      The same `example` dictionary with "targets" (and, at inference time
      for tasks with inputs, "inputs") normalized and a "task_id" feature
      added.
    """

    def _fit_to_length(seq, length):
        # Clip to at most `length` elements, then pad the tail (tf.pad
        # defaults) so the result has a static shape of exactly [length].
        clipped = seq[:length]
        deficit = length - tf.shape(clipped)[0]
        padded = tf.pad(clipped, [[0, deficit]])
        return tf.reshape(padded, [length])

    if task.has_inputs:
        # Drop the final input token (the EOS marker).
        example["inputs"] = example["inputs"][:-1]

    if hasattr(task, "class_labels"):
        if vocab_type == text_problems.VocabType.CHARACTER:
            # TODO(urvashik): handle the case where num_labels > 9
            # One base-10 digit per label, shifted by 50 and stored as int64;
            # the digit axis added by int_to_bit is squeezed away again.
            digits = discretization.int_to_bit(example["targets"], 1, base=10)
            label_ids = tf.cast(digits + 50, tf.int64)
            example["targets"] = tf.squeeze(label_ids, axis=[-1])
        elif vocab_type == text_problems.VocabType.SUBWORD:
            example["targets"] = vocab_offset + example["targets"]
    elif task.has_inputs:
        # Sequence task with both inputs and targets, e.g. summarization.
        if max_input_length > 0:
            example["inputs"] = example["inputs"][:max_input_length]
        # Do not truncate targets during inference with beam decoding.
        if max_target_length > 0 and not is_infer:
            example["targets"] = example["targets"][:max_target_length]

    if not task.has_inputs:
        # Targets-only task: the task id leads the target sequence.
        merged = tf.concat([[task.task_id], example["targets"]], axis=0)
        if not is_infer and fixed_train_length > 0:
            merged = _fit_to_length(merged, fixed_train_length)
        example["targets"] = merged
    elif is_infer:
        # Inference: only the inputs carry the task id.
        example["inputs"] = tf.concat(
            [example["inputs"], [task.task_id]], axis=0)
    else:
        # Training: fold the inputs into the targets, separated by the
        # task id, and remove the standalone "inputs" feature.
        source_tokens = example.pop("inputs")
        merged = tf.concat(
            [source_tokens, [task.task_id], example["targets"]], axis=0)
        if fixed_train_length > 0:
            merged = _fit_to_length(merged, fixed_train_length)
        example["targets"] = merged

    example["task_id"] = tf.constant([task.task_id], dtype=tf.int64)
    return example