Example #1
  def _unroll_neck_steps(self, env_output, initial_state):
    """Unrolls all timesteps and returns a list of outputs and a final state."""
    unused_reward, done, observation, unused_info = env_output
    # Add current time_step and batch_size.
    self._current_num_timesteps = tf.shape(done)[0]
    self._current_batch_size = tf.shape(done)[1]

    torso_output = utils.batch_apply(self._torso, observation)

    # shape: [num_timesteps, batch_size, ...], where the trailing dimensions
    # are the same as the trailing dimensions of `neck_state`.
    neck_state = initial_state
    reset_state = self._get_reset_state(observation, done, neck_state)
    neck_output_list = []
    for timestep, d in enumerate(tf.unstack(done)):
      neck_input = utils.get_row_nested_tensor(torso_output, timestep)
      # If the episode ended, the neck state should be reset before the next
      # step.
      curr_timestep_reset_state = utils.get_row_nested_tensor(
          reset_state, timestep)
      neck_state = tf.nest.map_structure(
          lambda reset_state, state: tf.compat.v1.where(d, reset_state, state),
          curr_timestep_reset_state, neck_state)
      neck_output, neck_state = self._neck(neck_input, neck_state)
      neck_output_list.append(neck_output)
    return neck_output_list, neck_state
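
Example #1 unrolls the recurrent "neck" one timestep at a time and, at every step, swaps in a fresh reset state for any batch element whose episode just ended. The key detail is that `tf.compat.v1.where`, given a 1-D boolean condition, selects whole rows from its two operands. Below is a minimal, self-contained sketch of that row-selection behaviour; it is not taken from the codebase above, and the shapes are assumptions:

import tensorflow as tf

# With a 1-D boolean `done` of shape [batch_size], tf.compat.v1.where copies
# whole rows from `reset_state` or `state`. This row-wise selection is why the
# v1 op is used in the unroll loop above.
done = tf.constant([True, False, True])   # episodes 0 and 2 just finished
reset_state = tf.zeros([3, 4])            # fresh per-element state
state = tf.ones([3, 4])                   # carried-over state
next_state = tf.compat.v1.where(done, reset_state, state)
# Rows 0 and 2 of next_state are zeros; row 1 keeps the carried-over ones.
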
Example #2
    def call(self, env_output, neck_state):
        """Runs the entire episode given time-major tensors.

        Args:
          env_output: An `EnvOutput` tuple with the following expectations:
            reward - Unused.
            done - A boolean tensor of shape [num_timesteps, batch_size].
            observation - A nested structure with individual tensors that have
              first two dimensions equal to [num_timesteps, batch_size].
            info - Unused.
          neck_state: A tensor or nested structure with individual tensors that
            have first dimension equal to batch_size and no time dimension.

        Returns:
          An `AgentOutput` tuple with individual tensors that have first two
            dimensions equal to [num_timesteps, batch_size].
        """
        unused_reward, done, observation, unused_info = env_output
        # Add current time_step and batch_size.
        self._current_num_timesteps = tf.shape(done)[0]
        self._current_batch_size = tf.shape(done)[1]

        torso_output = utils.batch_apply(self._torso, observation)
        # shape: [num_timesteps, batch_size, ...], where the trailing dimensions
        # are the same as the trailing dimensions of `neck_state`.
        reset_state = self._get_reset_state(observation, done, neck_state)
        neck_output_list = []
        for timestep, d in enumerate(tf.unstack(done)):
            neck_input = utils.get_row_nested_tensor(torso_output, timestep)
            # If the episode ended, the neck state should be reset before the next
            # step.
            curr_timestep_reset_state = utils.get_row_nested_tensor(
                reset_state, timestep)
            neck_state = tf.nest.map_structure(
                lambda reset_state, state: tf.compat.v1.where(
                    d, reset_state, state), curr_timestep_reset_state,
                neck_state)
            neck_output, neck_state = self._neck(neck_input, neck_state)
            neck_output_list.append(neck_output)

        head_input = tf.nest.map_structure(lambda *tensors: tf.stack(tensors),
                                           *neck_output_list)
        head_output = utils.batch_apply(self._head, head_input)
        assert isinstance(head_output, common.AgentOutput)
        return head_output, neck_state
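
Example #2 is the full `call` method: after the same unroll loop as in Example #1, the per-timestep neck outputs are stacked back into time-major tensors before the head is applied. The restacking idiom with `tf.nest.map_structure` and `tf.stack` can be shown in isolation; the shapes below are assumptions, not values from the codebase:

import tensorflow as tf

# Sketch of the `head_input` construction: a list with one (nested) neck output
# per timestep is restacked into tensors of shape [num_timesteps, batch_size, ...].
num_timesteps, batch_size, hidden = 3, 2, 5
neck_output_list = [
    {'features': tf.fill([batch_size, hidden], float(t))}
    for t in range(num_timesteps)
]
head_input = tf.nest.map_structure(lambda *tensors: tf.stack(tensors),
                                   *neck_output_list)
print(head_input['features'].shape)  # (3, 2, 5) == [num_timesteps, batch_size, hidden]
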
Example #3
  def testGetRowNestedTensor(self):
    x = {
        'a': tf.constant([[0., 0.], [1., 1.]]),
        'b': {
            'b_1': tf.ones(shape=(2, 3))
        }
    }
    result = utils.get_row_nested_tensor(x, 1)
    np.testing.assert_array_almost_equal(
        np.array([1., 1.]), result['a'].numpy())
    np.testing.assert_array_almost_equal(
        np.array([1., 1., 1.]), result['b']['b_1'].numpy())
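
Example #3 is the unit test that pins down what `utils.get_row_nested_tensor` does: it takes row `row` from every leaf tensor of a nested structure. One implementation that would satisfy this test is sketched below; the real helper in `utils` may differ, so treat it only as an illustration of the contract:

import tensorflow as tf

def get_row_nested_tensor(nested_tensor, row):
  # Take the `row`-th slice along the first dimension of every leaf tensor,
  # preserving the nested structure (dicts, tuples, namedtuples, ...).
  return tf.nest.map_structure(lambda t: t[row], nested_tensor)

With the `x` from the test, `get_row_nested_tensor(x, 1)['b']['b_1']` evaluates to `[1., 1., 1.]`, matching the assertion.
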
Example #4
def run_with_address(
    problem_type: framework_problem_type.ProblemType,
    listen_address: Text,
    hparams: Dict[Text, Any],
):
    """Runs the learner with the given problem type.

    Args:
      problem_type: An instance of `framework_problem_type.ProblemType`.
      listen_address: The network address on which to listen.
      hparams: A dict containing hyperparameter settings.
    """
    devices = device_lib.list_local_devices()
    logging.info('Found devices: %s', devices)
    devices = [d for d in devices if d.device_type == FLAGS.agent_device]
    assert devices, 'Could not find a device of type %s' % FLAGS.agent_device
    agent_device = devices[0].name
    logging.info('Using agent device: %s', agent_device)

    # Initialize agent, variables.
    specs = utils.read_specs(hparams['logdir'])
    flat_specs = [
        tf.TensorSpec.from_spec(s, str(i))
        for i, s in enumerate(tf.nest.flatten(specs))
    ]
    queue_capacity = FLAGS.queue_capacity or FLAGS.batch_size * 10
    queue = tf.queue.FIFOQueue(
        queue_capacity,
        [t.dtype for t in flat_specs],
        [t.shape for t in flat_specs],
    )
    agent = problem_type.get_agent()
    # Create dummy environment output of shape [num_timesteps, batch_size, ...].
    env_output = tf.nest.map_structure(
        lambda s: tf.zeros(
            list(s.shape)[0:1] + [FLAGS.batch_size] + list(s.shape)[1:],
            s.dtype),
        specs.env_output)
    init_observation = utils.get_row_nested_tensor(env_output.observation, 0)
    init_agent_state = agent.get_initial_state(init_observation,
                                               batch_size=FLAGS.batch_size)
    env_output = _convert_uint8_to_bfloat16(env_output)
    with tf.device(agent_device):
        agent(env_output, init_agent_state)

        # Create optimizer.

        if FLAGS.lr_decay_steps > 0 and FLAGS.lr_decay_rate < 1.:
            lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
                initial_learning_rate=FLAGS.learning_rate,
                decay_steps=FLAGS.lr_decay_steps,
                decay_rate=FLAGS.lr_decay_rate)
        else:
            lr_schedule = FLAGS.learning_rate
        optimizer = problem_type.get_optimizer(lr_schedule)
        # NOTE: `iterations` is a non-trainable variable which is managed by
        # optimizer (created inside optimizer as well as incremented by 1 on every
        # call to optimizer.minimize).
        iterations = optimizer.iterations
        study_loss_types = problem_type.get_study_loss_types()

    @tf.function
    def train_step(iterator):
        """Training StepFn."""
        def step_fn(actor_output):
            """Per-replica StepFn."""
            actor_output = tf.nest.pack_sequence_as(specs, actor_output)
            (initial_agent_state, env_output, actor_agent_output, actor_action,
             loss_type, info) = actor_output
            with tf.GradientTape() as tape:
                loss = loss_fns.compute_loss(
                    study_loss_types=study_loss_types,
                    current_batch_loss_type=loss_type,
                    agent=agent,
                    agent_state=initial_agent_state,
                    env_output=env_output,
                    actor_agent_output=actor_agent_output,
                    actor_action=actor_action,
                    num_steps=iterations)
            grads = tape.gradient(loss, agent.trainable_variables)
            if FLAGS.gradient_clip_norm > 0.:
                for i, g in enumerate(grads):
                    if g is not None:
                        grads[i] = tf.clip_by_norm(g, FLAGS.gradient_clip_norm)
            grad_norms = {}
            for var, grad in zip(agent.trainable_variables, grads):
                # For parameters which are initialized but not used for loss
                # computation, gradient tape would return None.
                if grad is not None:
                    grad_norms[var.name] = tf.norm(grad)
            optimizer.apply_gradients(zip(grads, agent.trainable_variables))
            return info, grad_norms

        return step_fn(next(iterator))

    ckpt_manager = _maybe_restore_from_ckpt(hparams['logdir'],
                                            agent=agent,
                                            optimizer=optimizer)
    server = _create_server(listen_address,
                            specs,
                            agent,
                            queue,
                            extra_variables=[iterations])
    logging.info('Starting gRPC server')
    server.start()

    dataset = tf.data.Dataset.from_tensors(0).repeat(None)
    dataset = dataset.map(lambda _: queue.dequeue())
    dataset = dataset.batch(FLAGS.batch_size, drop_remainder=True)
    # Transpose each batch to time-major order. This is relatively slow, so do
    # this work outside of the training loop.
    dataset = dataset.map(functools.partial(_transpose_batch, specs))
    dataset = dataset.apply(tf.data.experimental.copy_to_device(agent_device))
    with tf.device(agent_device):
        dataset = dataset.prefetch(1)
        iterator = iter(dataset)

    # Execute learning and track performance.
    summary_writer = tf.summary.create_file_writer(hparams['logdir'],
                                                   flush_millis=20000,
                                                   max_queue=1000)
    last_ckpt_time = time.time()
    with summary_writer.as_default():
        last_log_iterations = iterations.numpy()
        last_log_num_env_frames = iterations * hparams['iter_frame_ratio']
        last_log_time = time.time()
        while iterations < hparams['final_iteration']:
            logging.info('Iteration %d of %d', iterations + 1,
                         hparams['final_iteration'])

            # Save checkpoint at specified intervals or if no previous ckpt exists.
            current_time = time.time()
            if (current_time - last_ckpt_time >= FLAGS.save_checkpoint_secs
                    or not ckpt_manager.latest_checkpoint):
                ckpt_manager.save(checkpoint_number=iterations)
                last_ckpt_time = current_time

            with utils.WallTimer() as wt:
                with tf.device(agent_device):
                    info, grad_norms = train_step(iterator)
            tf.summary.scalar('steps_summary/step_seconds',
                              wt.duration,
                              step=iterations)
            norm_summ_family = 'grad_norms/'
            for name, norm in grad_norms.items():
                tf.summary.scalar(norm_summ_family + name,
                                  norm,
                                  step=iterations)

            if current_time - last_log_time >= 120:
                num_env_frames = (iterations.numpy() *
                                  hparams['iter_frame_ratio'])
                num_frames_since = num_env_frames - last_log_num_env_frames
                num_iterations_since = iterations.numpy() - last_log_iterations
                elapsed_time = time.time() - last_log_time
                tf.summary.scalar(
                    'steps_summary/num_environment_frames_per_sec',
                    tf.cast(num_frames_since, tf.float32) / elapsed_time,
                    step=iterations)
                tf.summary.scalar('steps_summary/num_iterations_per_sec',
                                  tf.cast(num_iterations_since, tf.float32) /
                                  elapsed_time,
                                  step=iterations)
                tf.summary.scalar('queue_size', queue.size(), step=iterations)
                tf.summary.scalar('learning_rate',
                                  optimizer._decayed_lr(var_dtype=tf.float32),
                                  step=iterations)
                last_log_num_env_frames, last_log_iterations, last_log_time = (
                    num_env_frames, iterations.numpy(), time.time())
                logging.info('Number of environment frames: %d',
                             num_env_frames)

            problem_type.create_summary(step=iterations, info=info)

    # Finishing up.
    ckpt_manager.save(checkpoint_number=iterations)
    queue.close()
    server.shutdown()
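
Example #4 shows the helper inside a full learner loop: `get_row_nested_tensor(env_output.observation, 0)` pulls the first timestep's observations so that `agent.get_initial_state` can be built, and the input pipeline then transposes each dequeued batch into the time-major layout the agent expects. `_transpose_batch` itself is not shown in the example, so the sketch below is only an assumption about what such a batch-major to time-major conversion looks like:

import tensorflow as tf

# Hypothetical sketch: convert every leaf tensor from [batch_size, num_timesteps, ...]
# to [num_timesteps, batch_size, ...] by swapping the two leading dimensions.
def to_time_major(nested):
  def swap(t):
    perm = [1, 0] + list(range(2, t.shape.rank))
    return tf.transpose(t, perm)
  return tf.nest.map_structure(swap, nested)

batch_major = {'observation': tf.zeros([8, 20, 4])}  # [batch_size, num_timesteps, 4]
time_major = to_time_major(batch_major)
print(time_major['observation'].shape)               # (20, 8, 4)
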