  def test_joint_distribution_logprob(self):
    joint_distribution = parametric_distribution.get_parametric_distribution_for_action_space(
        self.create_tuple_space())
    parameters = np.array([0., 0., 0.,   # Normal locs
                           .1, .2, .3,   # Normal scales
                           1, 0, 0, 0,   # Discrete action 1
                           0, 1, 0, 0,   # Discrete action 2
                           0, 0, 1, 0],  # Discrete action 3
                          np.float32)
    actions = np.array([[0, 0, 0, 0, 1, 2],
                        [0, 0, .99, 0, 1, 2],
                        [0, .99, 0, 0, 1, 2],
                        [.99, 0, 0, 0, 1, 2],
                        [0, 0, 0, 0, 1, 3],
                        [0, 0, 0, 0, 2, 2],
                        [0, 0, 0, 0, 2, 3],
                        [0, 0, 0, 1, 2, 3]], np.float32)
    continuous_actions = actions[:, :3]
    discrete_actions = actions[:, 3:]

    log_probs = joint_distribution(parameters).log_prob(actions)

    normaltanh_dist = parametric_distribution.get_parametric_distribution_for_action_space(
        self.create_box_space())
    continuous_parameters = parameters[:6]
    continuous_log_probs = normaltanh_dist(continuous_parameters).log_prob(
        continuous_actions)

    multidiscrete_dist = parametric_distribution.get_parametric_distribution_for_action_space(
        self.create_multidiscrete_space())
    discrete_parameters = tf.convert_to_tensor(parameters[6:])
    discrete_log_probs = multidiscrete_dist(discrete_parameters).log_prob(
        discrete_actions)

    self.assertAllClose(log_probs, continuous_log_probs + discrete_log_probs)

  def test_clipped_distribution_kl(self):
    clipped_distribution = parametric_distribution.get_parametric_distribution_for_action_space(
        self.create_box_space(),
        continuous_config=parametric_distribution.continuous_action_config(
            action_postprocessor='ClippedIdentity'))

    dist = clipped_distribution(np.ones((6,), np.float32))

    clipped_distribution2 = parametric_distribution.get_parametric_distribution_for_action_space(
        self.create_box_space(),
        continuous_config=parametric_distribution.continuous_action_config(
            action_postprocessor='ClippedIdentity'))
    dist2 = clipped_distribution2(np.ones((6,), np.float32))

    self.assertEqual(
        dist.kl_divergence(dist2), 0)
Example #3
    def test_joint_distribution_shape(self):
        joint_distribution = parametric_distribution.get_parametric_distribution_for_action_space(
            self.create_tuple_space())

        batch_shape = [3, 2]
        parameters_shape = [3 * 2 + 3 * 4]

        parameters = tf.zeros(batch_shape + parameters_shape)
        self.assertEqual(
            joint_distribution.entropy(parameters).shape, batch_shape)
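    # For reference, the space helpers used by these tests are not shown in
    # this snippet; presumably they are thin wrappers around gym.spaces along
    # the lines of this illustrative sketch:
    #
    #   def create_box_space(self):
    #     return gym.spaces.Box(low=-1., high=1., shape=(3,), dtype=np.float32)
    #
    #   def create_multidiscrete_space(self):
    #     return gym.spaces.MultiDiscrete([4, 4, 4])
    #
    #   def create_tuple_space(self):
    #     return gym.spaces.Tuple(
    #         (self.create_box_space(), self.create_multidiscrete_space()))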
Example #4
def learner_loop(create_env_fn, create_agent_fn, create_optimizer_fn):
    """Main learner loop.

  Args:
    create_env_fn: Callable that must return a newly created environment. The
      callable takes the task ID and the flags object as arguments - an
      arbitrary task ID of 0 will be passed by the learner. The returned
      environment should follow Gym's API. It is only used for inferring
      tensor shapes. This environment will not be used to generate experience.
    create_agent_fn: Function that must create a new tf.Module with the neural
      network that outputs actions and new agent state given the environment
      observations and previous agent state. See dmlab.agents.ImpalaDeep for an
      example. The factory function takes as input the environment action and
      observation spaces and a parametric distribution over actions.
    create_optimizer_fn: Function that takes the final iteration as argument
      and must return a tf.keras.optimizers.Optimizer and a
      tf.keras.optimizers.schedules.LearningRateSchedule.
  """
    logging.info('Starting learner loop')
    validate_config()
    settings = utils.init_learner(FLAGS.num_training_tpus)
    strategy, inference_devices, training_strategy, encode, decode = settings
    env = create_env_fn(0, FLAGS)
    parametric_action_distribution = get_parametric_distribution_for_action_space(
        env.action_space)
    env_output_specs = utils.EnvOutput(
        tf.TensorSpec([], tf.float32, 'reward'),
        tf.TensorSpec([], tf.bool, 'done'),
        tf.nest.map_structure(
            lambda s: tf.TensorSpec(s.shape, s.dtype, 'observation'),
            env.observation_space.__dict__.get('spaces',
                                               env.observation_space)),
        tf.TensorSpec([], tf.bool, 'abandoned'),
        tf.TensorSpec([], tf.int32, 'episode_step'),
    )
    action_specs = tf.TensorSpec(env.action_space.shape,
                                 env.action_space.dtype, 'action')

    # Initialize agent and variables.
    agent = create_agent_fn(env.action_space, env.observation_space,
                            parametric_action_distribution)
    target_agent = create_agent_fn(env.action_space, env.observation_space,
                                   parametric_action_distribution)
    initial_agent_state = agent.initial_state(1)
    agent_state_specs = tf.nest.map_structure(
        lambda t: tf.TensorSpec(t.shape[1:], t.dtype), initial_agent_state)
    agent_input_specs = (action_specs, env_output_specs)

    input_ = tf.nest.map_structure(
        lambda s: tf.zeros([1, 1] + list(s.shape), s.dtype), agent_input_specs)
    input_no_time = tf.nest.map_structure(lambda t: t[0], input_)

    input_ = encode(input_ + (initial_agent_state, ))
    input_no_time = encode(input_no_time + (initial_agent_state, ))
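    # The zero-filled inputs above are only used to trace the agent once so
    # its variables get created; `input_no_time` drops the time dimension for
    # the per-step get_action call.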

    with strategy.scope():
        # Initialize variables
        def initialize_agent_variables(agent):
            if not hasattr(agent, 'entropy_cost'):
                mul = FLAGS.entropy_cost_adjustment_speed
                agent.entropy_cost_param = tf.Variable(
                    tf.math.log(FLAGS.entropy_cost) / mul,
                    # Without the constraint, the param gradient may get rounded to 0
                    # for very small values.
                    constraint=lambda v: tf.clip_by_value(
                        v, -20 / mul, 20 / mul),
                    trainable=True,
                    dtype=tf.float32)
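                # The effective entropy cost is exp(mul * param); scaling the
                # learnable parameter by `mul` controls how fast gradient
                # updates can move the entropy bonus.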
                agent.entropy_cost = (
                    lambda: tf.exp(mul * agent.entropy_cost_param))

            @tf.function
            def create_variables():
                return [
                    agent.get_action(*decode(input_no_time)),
                    agent.get_V(*decode(input_)),
                    agent.get_Q(*decode(input_), action=decode(input_[0]))
                ]

            create_variables()

        initialize_agent_variables(agent)
        initialize_agent_variables(target_agent)

        # Target network update
        @tf.function
        def update_target_agent(polyak):
            """Synchronizes training and target agent variables."""
            variables = agent.variables
            target_variables = target_agent.variables
            assert len(target_variables) == len(variables), (
                'Mismatch in number of net tensors: {} != {}'.format(
                    len(target_variables), len(variables)))
            for target_var, source_var in zip(target_variables, variables):
                target_var.assign(polyak * target_var +
                                  (1. - polyak) * source_var)

        update_target_agent(polyak=0.)  # copy weights

        # Create optimizer.
        iter_frame_ratio = (
            get_replay_insertion_batch_size(per_replica=False) *
            (FLAGS.her_window_length or FLAGS.unroll_length) *
            FLAGS.num_action_repeats)
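        # iter_frame_ratio is the number of environment frames consumed per
        # training iteration.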
        final_iteration = int(
            math.ceil(FLAGS.total_environment_frames / iter_frame_ratio))
        optimizer, learning_rate_fn = create_optimizer_fn(final_iteration)

        iterations = optimizer.iterations
        optimizer._create_hypers()
        optimizer._create_slots(agent.trainable_variables)
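        # The two calls above create the optimizer's hyper and slot variables
        # eagerly, so they already exist when the checkpoint is restored below.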

        # ON_READ causes the replicated variable to act as independent variables for
        # each replica.
        temp_grads = [
            tf.Variable(tf.zeros_like(v),
                        trainable=False,
                        synchronization=tf.VariableSynchronization.ON_READ)
            for v in agent.trainable_variables
        ]

    @tf.function
    def minimize(iterator):
        data = next(iterator)

        def compute_gradients(args):
            args = tf.nest.pack_sequence_as(unroll_specs, decode(args, data))
            with tf.GradientTape() as tape:
                loss, logs = compute_loss(logger,
                                          parametric_action_distribution,
                                          agent, target_agent, *args)
            grads = tape.gradient(loss, agent.trainable_variables)
            for t, g in zip(temp_grads, grads):
                t.assign(g)
            return loss, logs

        loss, logs = training_strategy.run(compute_gradients, (data, ))
        loss = training_strategy.experimental_local_results(loss)[0]
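        # compute_gradients ran per replica on the training strategy and
        # stashed its gradients in the ON_READ temp_grads variables; they are
        # applied to the agent variables below under the main strategy.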

        def apply_gradients(_):
            optimizer.apply_gradients(
                zip(temp_grads, agent.trainable_variables))

        strategy.run(apply_gradients, (loss, ))

        getattr(
            agent, 'end_of_training_step_callback',
            lambda: logging.info('end_of_training_step_callback not found'))()

        logger.step_end(logs, training_strategy, iter_frame_ratio)

    # Setup checkpointing and restore checkpoint.
    ckpt = tf.train.Checkpoint(agent=agent,
                               target_agent=target_agent,
                               optimizer=optimizer)
    if FLAGS.init_checkpoint is not None:
        tf.print('Loading initial checkpoint from %s...' %
                 FLAGS.init_checkpoint)
        ckpt.restore(FLAGS.init_checkpoint).assert_consumed()
    manager = tf.train.CheckpointManager(ckpt,
                                         FLAGS.logdir,
                                         max_to_keep=1,
                                         keep_checkpoint_every_n_hours=6)
    last_ckpt_time = 0  # Force checkpointing of the initial model.
    if manager.latest_checkpoint:
        logging.info('Restoring checkpoint: %s', manager.latest_checkpoint)
        ckpt.restore(manager.latest_checkpoint).assert_consumed()
        last_ckpt_time = time.time()

    # Logging.
    summary_writer = tf.summary.create_file_writer(FLAGS.logdir,
                                                   flush_millis=20000,
                                                   max_queue=1000)
    logger = utils.ProgressLogger(summary_writer=summary_writer,
                                  starting_step=iterations * iter_frame_ratio)
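    # starting_step resumes the environment-frame counter from the (possibly
    # checkpoint-restored) optimizer iteration count.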

    server = grpc.Server([FLAGS.server_address])

    store = utils.UnrollStore(FLAGS.num_envs, FLAGS.her_window_length
                              or FLAGS.unroll_length,
                              (action_specs, env_output_specs, action_specs))
    env_run_ids = utils.Aggregator(FLAGS.num_envs,
                                   tf.TensorSpec([], tf.int64, 'run_ids'))
    info_specs = (
        tf.TensorSpec([], tf.int64, 'episode_num_frames'),
        tf.TensorSpec([], tf.float32, 'episode_returns'),
        tf.TensorSpec([], tf.float32, 'episode_raw_returns'),
    )
    env_infos = utils.Aggregator(FLAGS.num_envs, info_specs, 'env_infos')

    # First agent state in an unroll.
    first_agent_states = utils.Aggregator(FLAGS.num_envs, agent_state_specs,
                                          'first_agent_states')

    # Current agent state and action.
    agent_states = utils.Aggregator(FLAGS.num_envs, agent_state_specs,
                                    'agent_states')
    actions = utils.Aggregator(FLAGS.num_envs, action_specs, 'actions')

    unroll_specs = Unroll(agent_state_specs, *store.unroll_specs)
    unroll_queue = utils.StructuredFIFOQueue(FLAGS.unroll_queue_max_size,
                                             unroll_specs)
    info_queue = utils.StructuredFIFOQueue(-1, info_specs)

    if FLAGS.her_window_length:
        replay_buffer = utils.HindsightExperienceReplay(
            FLAGS.replay_buffer_size,
            unroll_specs,
            compute_reward_fn=env.compute_reward,
            unroll_length=FLAGS.unroll_length,
            importance_sampling_exponent=0.,
            substitution_probability=FLAGS.her_substitution_probability)
    else:
        replay_buffer = utils.PrioritizedReplay(
            FLAGS.replay_buffer_size,
            unroll_specs,
            importance_sampling_exponent=0.)
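    # In both cases importance_sampling_exponent=0. disables the importance
    # sampling correction, i.e. sampled unrolls are weighted uniformly in the
    # loss.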

    def add_batch_size(ts):
        return tf.TensorSpec([FLAGS.inference_batch_size] + list(ts.shape),
                             ts.dtype, ts.name)

    inference_iteration = tf.Variable(-1, dtype=tf.int64)
    inference_specs = (
        tf.TensorSpec([], tf.int32, 'env_id'),
        tf.TensorSpec([], tf.int64, 'run_id'),
        env_output_specs,
        tf.TensorSpec([], tf.float32, 'raw_reward'),
    )
    inference_specs = tf.nest.map_structure(add_batch_size, inference_specs)
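    # The specs above gain a leading `inference_batch_size` dimension because
    # the gRPC server batches incoming inference requests before calling
    # `inference` below.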

    @tf.function(input_signature=inference_specs)
    def inference(env_ids, run_ids, env_outputs, raw_rewards):
        # Reset the environments that had their first run or crashed.
        previous_run_ids = env_run_ids.read(env_ids)
        env_run_ids.replace(env_ids, run_ids)
        reset_indices = tf.where(tf.not_equal(previous_run_ids, run_ids))[:, 0]
        envs_needing_reset = tf.gather(env_ids, reset_indices)
        if tf.not_equal(tf.shape(envs_needing_reset)[0], 0):
            tf.print('Environment ids needing reset:', envs_needing_reset)
        env_infos.reset(envs_needing_reset)
        store.reset(envs_needing_reset)
        initial_agent_states = agent.initial_state(
            tf.shape(envs_needing_reset)[0])
        first_agent_states.replace(envs_needing_reset, initial_agent_states)
        agent_states.replace(envs_needing_reset, initial_agent_states)
        actions.reset(envs_needing_reset)

        tf.debugging.assert_non_positive(
            tf.cast(env_outputs.abandoned, tf.int32),
            'Abandoned done states are not supported in SAC.')

        # Update steps and return.
        env_infos.add(env_ids, (0, env_outputs.reward, raw_rewards))
        done_ids = tf.gather(env_ids, tf.where(env_outputs.done)[:, 0])
        info_queue.enqueue_many(env_infos.read(done_ids))
        env_infos.reset(done_ids)
        env_infos.add(env_ids, (FLAGS.num_action_repeats, 0., 0.))

        # Inference.
        prev_actions = actions.read(env_ids)
        input_ = encode((prev_actions, env_outputs))
        prev_agent_states = agent_states.read(env_ids)

        def make_inference_fn(inference_device):
            def device_specific_inference_fn():
                with tf.device(inference_device):

                    @tf.function
                    def agent_inference(*args):
                        return agent(*decode(args), is_training=False)

                    return agent_inference(*input_, prev_agent_states)

            return device_specific_inference_fn

        # Distribute the inference calls among the inference cores.
        branch_index = tf.cast(
            inference_iteration.assign_add(1) % len(inference_devices),
            tf.int32)
        agent_actions, curr_agent_states = tf.switch_case(
            branch_index, {
                i: make_inference_fn(inference_device)
                for i, inference_device in enumerate(inference_devices)
            })

        # Append the latest outputs to the unroll and insert completed unrolls in
        # queue.
        completed_ids, unrolls = store.append(
            env_ids, (prev_actions, env_outputs, agent_actions))
        unrolls = Unroll(first_agent_states.read(completed_ids), *unrolls)
        unroll_queue.enqueue_many(unrolls)
        first_agent_states.replace(completed_ids,
                                   agent_states.read(completed_ids))

        # Update current state.
        agent_states.replace(env_ids, curr_agent_states)
        actions.replace(env_ids, agent_actions)

        # Return environment actions to environments.
        return agent_actions

    with strategy.scope():
        server.bind(inference)
    server.start()

    dataset = create_dataset(unroll_queue, replay_buffer, training_strategy,
                             FLAGS.batch_size, encode)
    it = iter(dataset)

    def additional_logs():
        tf.summary.scalar('learning_rate', learning_rate_fn(iterations))
        tf.summary.scalar('buffer/unrolls_inserted',
                          replay_buffer.num_inserted)
        # log data from info_queue
        n_episodes = info_queue.size()
        n_episodes -= n_episodes % FLAGS.log_episode_frequency
        if tf.not_equal(n_episodes, 0):
            episode_stats = info_queue.dequeue_many(n_episodes)
            episode_keys = [
                'episode_num_frames', 'episode_return', 'episode_raw_return'
            ]
            for key, values in zip(episode_keys, episode_stats):
                for value in tf.split(
                        values,
                        values.shape[0] // FLAGS.log_episode_frequency):
                    tf.summary.scalar(key, tf.reduce_mean(value))

            for (frames, ep_return, raw_return) in zip(*episode_stats):
                logging.info('Return: %f Raw return: %f Frames: %i', ep_return,
                             raw_return, frames)

    logger.start(additional_logs)
    # Execute learning.
    while iterations < final_iteration:
        if iterations.numpy() % FLAGS.update_target_every_n_step == 0:
            update_target_agent(FLAGS.polyak)
        # Save checkpoint.
        current_time = time.time()
        if current_time - last_ckpt_time >= FLAGS.save_checkpoint_secs:
            manager.save()
            # Apart from checkpointing, we also save the full model (including
            # the graph). This way we can load it after the code/parameters changed.
            tf.saved_model.save(agent, os.path.join(FLAGS.logdir,
                                                    'saved_model'))
            last_ckpt_time = current_time
        minimize(it)
    logger.shutdown()
    manager.save()
    tf.saved_model.save(agent, os.path.join(FLAGS.logdir, 'saved_model'))
    server.shutdown()
    unroll_queue.close()
Example #5
def learner_loop(create_env_fn, create_agent_fn, create_optimizer_fn, fps_log):
    """Main learner loop.

  Args:
    create_env_fn: Callable that must return a newly created environment. The
      callable takes the task ID as argument - an arbitrary task ID of 0 will be
      passed by the learner. The returned environment should follow Gym's API.
      It is only used for inferring tensor shapes. This environment will not be
      used to generate experience.
    create_agent_fn: Function that must create a new tf.Module with the neural
      network that outputs actions and new agent state given the environment
      observations and previous agent state. See dmlab.agents.ImpalaDeep for an
      example. The factory function takes as input the environment action and
      observation spaces and a parametric distribution over actions.
    create_optimizer_fn: Function that takes the final iteration as argument
      and must return a tf.keras.optimizers.Optimizer and a
      tf.keras.optimizers.schedules.LearningRateSchedule.
    fps_log: Object with a `logger` attribute used to report the measured
      frames-per-second throughput.
  """
    logging.info('Starting learner loop')
    validate_config()
    settings = utils.init_learner(FLAGS.num_training_tpus)
    strategy, inference_devices, training_strategy, encode, decode = settings
    env = create_env_fn(0)
    parametric_action_distribution = get_parametric_distribution_for_action_space(
        env.action_space)
    env_output_specs = utils.EnvOutput(
        tf.TensorSpec([], tf.float32, 'reward'),
        tf.TensorSpec([], tf.bool, 'done'),
        tf.TensorSpec(env.observation_space.shape, env.observation_space.dtype,
                      'observation'),
    )
    action_specs = tf.TensorSpec(env.action_space.shape,
                                 env.action_space.dtype, 'action')
    agent_input_specs = (action_specs, env_output_specs)

    # Initialize agent and variables.
    agent = create_agent_fn(env.action_space, env.observation_space,
                            parametric_action_distribution)
    initial_agent_state = agent.initial_state(1)
    agent_state_specs = tf.nest.map_structure(
        lambda t: tf.TensorSpec(t.shape[1:], t.dtype), initial_agent_state)
    input_ = tf.nest.map_structure(
        lambda s: tf.zeros([1] + list(s.shape), s.dtype), agent_input_specs)
    input_ = encode(input_)

    with strategy.scope():

        @tf.function
        def create_variables(*args):
            return agent(*decode(args))

        initial_agent_output, _ = create_variables(input_, initial_agent_state)
        # Create optimizer.
        iter_frame_ratio = (FLAGS.batch_size * FLAGS.unroll_length *
                            FLAGS.num_action_repeats)
        final_iteration = int(
            math.ceil(FLAGS.total_environment_frames / iter_frame_ratio))
        optimizer, learning_rate_fn = create_optimizer_fn(final_iteration)

        iterations = optimizer.iterations
        optimizer._create_hypers()
        optimizer._create_slots(agent.trainable_variables)

        # ON_READ causes the replicated variable to act as independent variables for
        # each replica.
        temp_grads = [
            tf.Variable(tf.zeros_like(v),
                        trainable=False,
                        synchronization=tf.VariableSynchronization.ON_READ)
            for v in agent.trainable_variables
        ]

    @tf.function
    def minimize(iterator):
        data = next(iterator)

        def compute_gradients(args):
            args = tf.nest.pack_sequence_as(unroll_specs, decode(args, data))
            with tf.GradientTape() as tape:
                loss, logs = compute_loss(parametric_action_distribution,
                                          agent, *args)
            grads = tape.gradient(loss, agent.trainable_variables)
            for t, g in zip(temp_grads, grads):
                t.assign(g)
            return loss, logs

        loss, logs = training_strategy.experimental_run_v2(
            compute_gradients, (data, ))
        loss = training_strategy.experimental_local_results(loss)[0]
        logs = training_strategy.experimental_local_results(logs)

        def apply_gradients(_):
            optimizer.apply_gradients(
                zip(temp_grads, agent.trainable_variables))

        strategy.experimental_run_v2(apply_gradients, (loss, ))

        return logs

    agent_output_specs = tf.nest.map_structure(
        lambda t: tf.TensorSpec(t.shape[1:], t.dtype), initial_agent_output)
    # Logging.
    summary_writer = tf.summary.create_file_writer(FLAGS.logdir,
                                                   flush_millis=20000,
                                                   max_queue=1000)

    # Setup checkpointing and restore checkpoint.
    ckpt = tf.train.Checkpoint(agent=agent, optimizer=optimizer)
    if FLAGS.init_checkpoint is not None:
        tf.print('Loading initial checkpoint from %s...' %
                 FLAGS.init_checkpoint)
        ckpt.restore(FLAGS.init_checkpoint).assert_consumed()
    manager = tf.train.CheckpointManager(ckpt,
                                         FLAGS.logdir,
                                         max_to_keep=1,
                                         keep_checkpoint_every_n_hours=6)
    last_ckpt_time = 0  # Force checkpointing of the initial model.
    if manager.latest_checkpoint:
        logging.info('Restoring checkpoint: %s', manager.latest_checkpoint)
        ckpt.restore(manager.latest_checkpoint).assert_consumed()
        last_ckpt_time = time.time()

    server = grpc.Server([FLAGS.server_address])

    store = utils.UnrollStore(
        FLAGS.num_actors, FLAGS.unroll_length,
        (action_specs, env_output_specs, agent_output_specs))
    actor_run_ids = utils.Aggregator(FLAGS.num_actors,
                                     tf.TensorSpec([], tf.int64, 'run_ids'))
    info_specs = (
        tf.TensorSpec([], tf.int64, 'episode_num_frames'),
        tf.TensorSpec([], tf.float32, 'episode_returns'),
        tf.TensorSpec([], tf.float32, 'episode_raw_returns'),
    )
    actor_infos = utils.Aggregator(FLAGS.num_actors, info_specs, 'actor_infos')

    # First agent state in an unroll.
    first_agent_states = utils.Aggregator(FLAGS.num_actors, agent_state_specs,
                                          'first_agent_states')

    # Current agent state and action.
    agent_states = utils.Aggregator(FLAGS.num_actors, agent_state_specs,
                                    'agent_states')
    actions = utils.Aggregator(FLAGS.num_actors, action_specs, 'actions')

    unroll_specs = Unroll(agent_state_specs, *store.unroll_specs)
    unroll_queue = utils.StructuredFIFOQueue(1, unroll_specs)
    info_queue = utils.StructuredFIFOQueue(-1, info_specs)

    def add_batch_size(ts):
        return tf.TensorSpec([FLAGS.inference_batch_size] + list(ts.shape),
                             ts.dtype, ts.name)

    inference_iteration = tf.Variable(-1)
    inference_specs = (
        tf.TensorSpec([], tf.int32, 'actor_id'),
        tf.TensorSpec([], tf.int64, 'run_id'),
        env_output_specs,
        tf.TensorSpec([], tf.float32, 'raw_reward'),
    )
    inference_specs = tf.nest.map_structure(add_batch_size, inference_specs)

    @tf.function(input_signature=inference_specs)
    def inference(actor_ids, run_ids, env_outputs, raw_rewards):
        # Reset the actors that had their first run or crashed.
        previous_run_ids = actor_run_ids.read(actor_ids)
        actor_run_ids.replace(actor_ids, run_ids)
        reset_indices = tf.where(tf.not_equal(previous_run_ids, run_ids))[:, 0]
        actors_needing_reset = tf.gather(actor_ids, reset_indices)
        if tf.not_equal(tf.shape(actors_needing_reset)[0], 0):
            tf.print('Actor ids needing reset:', actors_needing_reset)
        actor_infos.reset(actors_needing_reset)
        store.reset(actors_needing_reset)
        initial_agent_states = agent.initial_state(
            tf.shape(actors_needing_reset)[0])
        first_agent_states.replace(actors_needing_reset, initial_agent_states)
        agent_states.replace(actors_needing_reset, initial_agent_states)
        actions.reset(actors_needing_reset)

        # Update steps and return.
        actor_infos.add(actor_ids, (0, env_outputs.reward, raw_rewards))
        done_ids = tf.gather(actor_ids, tf.where(env_outputs.done)[:, 0])
        info_queue.enqueue_many(actor_infos.read(done_ids))
        actor_infos.reset(done_ids)
        actor_infos.add(actor_ids, (FLAGS.num_action_repeats, 0., 0.))

        # Inference.
        prev_actions = actions.read(actor_ids)
        input_ = encode((prev_actions, env_outputs))
        prev_agent_states = agent_states.read(actor_ids)

        def make_inference_fn(inference_device):
            def device_specific_inference_fn():
                with tf.device(inference_device):

                    @tf.function
                    def agent_inference(*args):
                        return agent(*decode(args), is_training=True)

                    return agent_inference(input_, prev_agent_states)

            return device_specific_inference_fn

        # Distribute the inference calls among the inference cores.
        branch_index = inference_iteration.assign_add(1) % len(
            inference_devices)
        agent_outputs, curr_agent_states = tf.switch_case(
            branch_index, {
                i: make_inference_fn(inference_device)
                for i, inference_device in enumerate(inference_devices)
            })

        # Append the latest outputs to the unroll and insert completed unrolls in
        # queue.
        completed_ids, unrolls = store.append(
            actor_ids, (prev_actions, env_outputs, agent_outputs))
        unrolls = Unroll(first_agent_states.read(completed_ids), *unrolls)
        unroll_queue.enqueue_many(unrolls)
        first_agent_states.replace(completed_ids,
                                   agent_states.read(completed_ids))

        # Update current state.
        agent_states.replace(actor_ids, curr_agent_states)
        actions.replace(actor_ids, agent_outputs.action)

        # Return environment actions to the actors. `postprocess` maps the raw
        # sampled actions into the environment's action space (e.g. tanh
        # squashing for bounded continuous actions).
        return parametric_action_distribution.postprocess(agent_outputs.action)

    with strategy.scope():
        server.bind(inference, batched=True)
    server.start()

    def dequeue(ctx):
        # Create batch (time major).
        actor_outputs = tf.nest.map_structure(
            lambda *args: tf.stack(args), *[
                unroll_queue.dequeue() for i in range(
                    ctx.get_per_replica_batch_size(FLAGS.batch_size))
            ])
        actor_outputs = actor_outputs._replace(
            prev_actions=utils.make_time_major(actor_outputs.prev_actions),
            env_outputs=utils.make_time_major(actor_outputs.env_outputs),
            agent_outputs=utils.make_time_major(actor_outputs.agent_outputs))
        actor_outputs = actor_outputs._replace(
            env_outputs=encode(actor_outputs.env_outputs))
        # tf.data.Dataset treats list leafs as tensors, so we need to flatten and
        # repack.
        return tf.nest.flatten(actor_outputs)

    def dataset_fn(ctx):
        dataset = tf.data.Dataset.from_tensors(0).repeat(None)
        return dataset.map(lambda _: dequeue(ctx),
                           num_parallel_calls=ctx.num_replicas_in_sync)

    dataset = training_strategy.experimental_distribute_datasets_from_function(
        dataset_fn)
    it = iter(dataset)

    # Execute learning and track performance.
    with summary_writer.as_default(), \
      concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
        log_future = executor.submit(lambda: None)  # No-op future.
        last_num_env_frames = iterations * iter_frame_ratio
        last_log_time = time.time()
        values_to_log = collections.defaultdict(lambda: [])
        while iterations < final_iteration:
            num_env_frames = iterations * iter_frame_ratio
            tf.summary.experimental.set_step(num_env_frames)

            # Save checkpoint.
            current_time = time.time()
            if current_time - last_ckpt_time >= FLAGS.save_checkpoint_secs:
                manager.save()
                # Apart from checkpointing, we also save the full model (including
                # the graph). This way we can load it after the code/parameters changed.
                tf.saved_model.save(agent,
                                    os.path.join(FLAGS.logdir, 'saved_model'))
                last_ckpt_time = current_time

            def log(iterations, num_env_frames):
                """Logs batch and episodes summaries."""
                nonlocal last_num_env_frames, last_log_time
                summary_writer.set_as_default()
                tf.summary.experimental.set_step(num_env_frames)

                # log data from the current minibatch
                if iterations % FLAGS.log_batch_frequency == 0:
                    for key, values in copy.deepcopy(values_to_log).items():
                        tf.summary.scalar(key, tf.reduce_mean(values))
                    values_to_log.clear()
                    tf.summary.scalar('learning_rate',
                                      learning_rate_fn(iterations))

                # log the number of frames per second
                dt = time.time() - last_log_time
                if dt > 60:
                    df = tf.cast(num_env_frames - last_num_env_frames,
                                 tf.float32)
                    tf.summary.scalar('num_environment_frames/sec', df / dt)
                    fps_log.logger.info('FPS: %f', df / dt)
                    print(f'FPS: {df / dt}')

                    last_num_env_frames = num_env_frames
                    last_log_time = time.time()

                # log data from info_queue
                n_episodes = info_queue.size()
                n_episodes -= n_episodes % FLAGS.log_episode_frequency
                if tf.not_equal(n_episodes, 0):
                    episode_stats = info_queue.dequeue_many(n_episodes)
                    episode_keys = [
                        'episode_num_frames', 'episode_return',
                        'episode_raw_return'
                    ]
                    for key, values in zip(episode_keys, episode_stats):
                        for value in tf.split(
                                values, values.shape[0] //
                                FLAGS.log_episode_frequency):
                            tf.summary.scalar(key, tf.reduce_mean(value))

                    for (frames, ep_return, raw_return) in zip(*episode_stats):
                        logging.info('Return: %f Raw return: %f Frames: %i',
                                     ep_return, raw_return, frames)

            logs = minimize(it)

            for per_replica_logs in logs:
                assert len(log_keys) == len(per_replica_logs)
                for key, value in zip(log_keys, per_replica_logs):
                    values_to_log[key].extend(
                        x.numpy()
                        for x in training_strategy.experimental_local_results(
                            value))

            log_future.result()  # Raise exception if any occurred in logging.
            log_future = executor.submit(log, iterations, num_env_frames)

    manager.save()
    tf.saved_model.save(agent, os.path.join(FLAGS.logdir, 'saved_model'))
    server.shutdown()
    unroll_queue.close()
Example #6
def visualize(create_env_fn, create_agent_fn, create_optimizer_fn):
  print('Visualization launched...')

  settings = utils.init_learner_multi_host(1)
  strategy, hosts, training_strategy, encode, decode = settings

  env = create_env_fn(0)
  parametric_action_distribution = get_parametric_distribution_for_action_space(
    env.action_space)
  agent = create_agent_fn(env.action_space, env.observation_space,
                          parametric_action_distribution)
  optimizer, learning_rate_fn = create_optimizer_fn(1e9)

  env_output_specs = utils.EnvOutput(
    tf.TensorSpec([], tf.float32, 'reward'),
    tf.TensorSpec([], tf.bool, 'done'),
    tf.TensorSpec(env.observation_space.shape, env.observation_space.dtype,
                  'observation'),
    tf.TensorSpec([], tf.bool, 'abandoned'),
    tf.TensorSpec([], tf.int32, 'episode_step'),
  )
  action_specs = tf.TensorSpec(env.action_space.shape,
                               env.action_space.dtype, 'action')
  agent_input_specs = (action_specs, env_output_specs)
  # Initialize agent and variables.
  agent = create_agent_fn(env.action_space, env.observation_space,
                          parametric_action_distribution)
  initial_agent_state = agent.initial_state(1)
  agent_state_specs = tf.nest.map_structure(
    lambda t: tf.TensorSpec(t.shape[1:], t.dtype), initial_agent_state)
  unroll_specs = [None]  # Lazy initialization.
  input_ = tf.nest.map_structure(
    lambda s: tf.zeros([1] + list(s.shape), s.dtype), agent_input_specs)
  input_ = encode(input_)

  with strategy.scope():
    @tf.function
    def create_variables(*args):
      return agent.get_action(*decode(args))

    initial_agent_output, _ = create_variables(*input_, initial_agent_state)

    if not hasattr(agent, 'entropy_cost'):
      mul = FLAGS.entropy_cost_adjustment_speed
      agent.entropy_cost_param = tf.Variable(
          tf.math.log(FLAGS.entropy_cost) / mul,
          # Without the constraint, the param gradient may get rounded to 0
          # for very small values.
          constraint=lambda v: tf.clip_by_value(v, -20 / mul, 20 / mul),
          trainable=True,
          dtype=tf.float32)
      agent.entropy_cost = lambda: tf.exp(mul * agent.entropy_cost_param)
    # Create optimizer.
    iter_frame_ratio = (
        FLAGS.batch_size * FLAGS.unroll_length * FLAGS.num_action_repeats)
    final_iteration = int(
        math.ceil(FLAGS.total_environment_frames / iter_frame_ratio))
    optimizer, learning_rate_fn = create_optimizer_fn(final_iteration)


    iterations = optimizer.iterations
    optimizer._create_hypers()
    optimizer._create_slots(agent.trainable_variables)

    # ON_READ causes the replicated variable to act as independent variables for
    # each replica.
    temp_grads = [
        tf.Variable(tf.zeros_like(v), trainable=False,
                    synchronization=tf.VariableSynchronization.ON_READ)
        for v in agent.trainable_variables
    ]

  agent_output_specs = tf.nest.map_structure(
    lambda t: tf.TensorSpec(t.shape[1:], t.dtype), initial_agent_output)

  # Restore a pretrained agent from a hard-coded checkpoint path for
  # visualization.
  ckpt = tf.train.Checkpoint(agent=agent, optimizer=optimizer)
  ckpt.restore('seed_rl/checkpoints/agent_good_3m/ckpt-9').assert_consumed()

  def get_agent_action(obs):
    initial_agent_state = agent.initial_state(1)
    shaped_obs = tf.reshape(tf.convert_to_tensor(obs), shape=(1,)+env.observation_space.shape)
    initial_env_output = (tf.constant([1.]), tf.constant([False]), shaped_obs,
                          tf.constant([False]),
                          # episode_step is declared int32 in env_output_specs.
                          tf.constant([1], dtype=tf.int32))
    agent_out = agent(tf.zeros([0], dtype=tf.float32), initial_env_output,
                      initial_agent_state)
    return agent_out

  def run_episode(steps):
    mode = None
    obs = env.reset()
    rewards = []

    for _ in range(steps):
      agent_out, state = get_agent_action(obs)
      action = agent_out.action.numpy()[0]
      obs, rew, done, info = env.step(action)
      rewards.append(rew)

      if done:
        break

    reward = np.sum(rewards)
    print('reward: {0}'.format(reward))
    return reward

  all_rewards = []
  episode_idx = 0

  while True:
    all_rewards.append(run_episode(250))
    if len(all_rewards) > 1000:
      all_rewards = all_rewards[-1000:]
    print('mean cum reward: {0}'.format(np.mean(all_rewards)))

    if episode_idx % 10 == 0:
      env.save_replay()
      print('\n REPLAY SAVED\n')
    episode_idx += 1

  print('Graceful termination')
  sys.exit(0)
Example #7
def learner_loop(create_env_fn, create_agent_fn, create_optimizer_fn):
    """Main learner loop.

  Args:
    create_env_fn: Callable that must return a newly created environment. The
      callable takes the task ID as argument - an arbitrary task ID of 0 will be
      passed by the learner. The returned environment should follow Gym's API.
      It is only used for inferring tensor shapes. This environment will not be
      used to generate experience.
    create_agent_fn: Function that must create a new tf.Module with the neural
      network that outputs actions and new agent state given the environment
      observations and previous agent state. See dmlab.agents.ImpalaDeep for an
      example. The factory function takes as input the environment action and
      observation spaces and a parametric distribution over actions.
    create_optimizer_fn: Function that takes the final iteration as argument
      and must return a tf.keras.optimizers.Optimizer and a
      tf.keras.optimizers.schedules.LearningRateSchedule.
  """
    logging.info('Starting learner loop')
    validate_config()
    settings = utils.init_learner_multi_host(FLAGS.num_training_tpus)
    strategy, hosts, training_strategy, encode, decode = settings
    env = create_env_fn(0)
    parametric_action_distribution = get_parametric_distribution_for_action_space(
        env.action_space)
    env_output_specs = utils.EnvOutput(
        tf.TensorSpec([], tf.float32, 'reward'),
        tf.TensorSpec([], tf.bool, 'done'),
        tf.TensorSpec(env.observation_space.shape, env.observation_space.dtype,
                      'observation'),
        tf.TensorSpec([], tf.bool, 'abandoned'),
        tf.TensorSpec([], tf.int32, 'episode_step'),
    )
    action_specs = tf.TensorSpec(env.action_space.shape,
                                 env.action_space.dtype, 'action')
    agent_input_specs = (action_specs, env_output_specs)

    # Initialize agent and variables.
    agent = create_agent_fn(env.action_space, env.observation_space,
                            parametric_action_distribution)
    initial_agent_state = agent.initial_state(1)
    agent_state_specs = tf.nest.map_structure(
        lambda t: tf.TensorSpec(t.shape[1:], t.dtype), initial_agent_state)
    unroll_specs = [None]  # Lazy initialization.
    input_ = tf.nest.map_structure(
        lambda s: tf.zeros([1] + list(s.shape), s.dtype), agent_input_specs)
    input_ = encode(input_)

    with strategy.scope():

        @tf.function
        def create_variables(*args):
            return agent.get_action(*decode(args))

        initial_agent_output, _ = create_variables(*input_,
                                                   initial_agent_state)

        if not hasattr(agent, 'entropy_cost'):
            mul = FLAGS.entropy_cost_adjustment_speed
            agent.entropy_cost_param = tf.Variable(
                tf.math.log(FLAGS.entropy_cost) / mul,
                # Without the constraint, the param gradient may get rounded to 0
                # for very small values.
                constraint=lambda v: tf.clip_by_value(v, -20 / mul, 20 / mul),
                trainable=True,
                dtype=tf.float32)
            agent.entropy_cost = lambda: tf.exp(mul * agent.entropy_cost_param)
        # Create optimizer.
        iter_frame_ratio = (FLAGS.batch_size * FLAGS.unroll_length *
                            FLAGS.num_action_repeats)
        final_iteration = int(
            math.ceil(FLAGS.total_environment_frames / iter_frame_ratio))
        optimizer, learning_rate_fn = create_optimizer_fn(final_iteration)

        iterations = optimizer.iterations
        optimizer._create_hypers()
        optimizer._create_slots(agent.trainable_variables)

        # ON_READ causes the replicated variable to act as independent variables for
        # each replica.
        temp_grads = [
            tf.Variable(tf.zeros_like(v),
                        trainable=False,
                        synchronization=tf.VariableSynchronization.ON_READ)
            for v in agent.trainable_variables
        ]

    @tf.function
    def minimize(iterator):
        data = next(iterator)

        def compute_gradients(args):
            args = tf.nest.pack_sequence_as(unroll_specs[0],
                                            decode(args, data))
            with tf.GradientTape() as tape:
                loss, logs = compute_loss(logger,
                                          parametric_action_distribution,
                                          agent, *args)
            grads = tape.gradient(loss, agent.trainable_variables)
            for t, g in zip(temp_grads, grads):
                t.assign(g)
            return loss, logs

        loss, logs = training_strategy.run(compute_gradients, (data, ))
        loss = training_strategy.experimental_local_results(loss)[0]

        def apply_gradients(_):
            optimizer.apply_gradients(
                zip(temp_grads, agent.trainable_variables))

        strategy.run(apply_gradients, (loss, ))

        getattr(
            agent, 'end_of_training_step_callback',
            lambda: logging.info('end_of_training_step_callback not found'))()

        logger.step_end(logs, training_strategy, iter_frame_ratio)

    agent_output_specs = tf.nest.map_structure(
        lambda t: tf.TensorSpec(t.shape[1:], t.dtype), initial_agent_output)

    # Setup checkpointing and restore checkpoint.
    ckpt = tf.train.Checkpoint(agent=agent, optimizer=optimizer)
    if FLAGS.init_checkpoint is not None:
        tf.print('Loading initial checkpoint from %s...' %
                 FLAGS.init_checkpoint)
        ckpt.restore(FLAGS.init_checkpoint).assert_consumed()
    manager = tf.train.CheckpointManager(ckpt,
                                         FLAGS.logdir,
                                         max_to_keep=1,
                                         keep_checkpoint_every_n_hours=6)
    last_ckpt_time = 0  # Force checkpointing of the initial model.
    if manager.latest_checkpoint:
        logging.info('Restoring checkpoint: %s', manager.latest_checkpoint)
        ckpt.restore(manager.latest_checkpoint).assert_consumed()
        last_ckpt_time = time.time()

    # Logging.
    summary_writer = tf.summary.create_file_writer(FLAGS.logdir,
                                                   flush_millis=20000,
                                                   max_queue=1000)
    logger = utils.ProgressLogger(summary_writer=summary_writer,
                                  starting_step=iterations * iter_frame_ratio)

    servers = []
    unroll_queues = []
    info_specs = (
        tf.TensorSpec([], tf.int64, 'episode_num_frames'),
        tf.TensorSpec([], tf.float32, 'episode_returns'),
        tf.TensorSpec([], tf.float32, 'episode_raw_returns'),
    )

    info_queue = utils.StructuredFIFOQueue(-1, info_specs)

    def create_host(i, host, inference_devices):
        with tf.device(host):
            server = grpc.Server([FLAGS.server_address])

            store = utils.UnrollStore(
                FLAGS.num_envs, FLAGS.unroll_length,
                (action_specs, env_output_specs, agent_output_specs))
            env_run_ids = utils.Aggregator(
                FLAGS.num_envs, tf.TensorSpec([], tf.int64, 'run_ids'))
            env_infos = utils.Aggregator(FLAGS.num_envs, info_specs,
                                         'env_infos')

            # First agent state in an unroll.
            first_agent_states = utils.Aggregator(FLAGS.num_envs,
                                                  agent_state_specs,
                                                  'first_agent_states')

            # Current agent state and action.
            agent_states = utils.Aggregator(FLAGS.num_envs, agent_state_specs,
                                            'agent_states')
            actions = utils.Aggregator(FLAGS.num_envs, action_specs, 'actions')

            unroll_specs[0] = Unroll(agent_state_specs, *store.unroll_specs)
            unroll_queue = utils.StructuredFIFOQueue(1, unroll_specs[0])

            def add_batch_size(ts):
                return tf.TensorSpec([FLAGS.inference_batch_size] +
                                     list(ts.shape), ts.dtype, ts.name)

            inference_specs = (
                tf.TensorSpec([], tf.int32, 'env_id'),
                tf.TensorSpec([], tf.int64, 'run_id'),
                env_output_specs,
                tf.TensorSpec([], tf.float32, 'raw_reward'),
            )
            inference_specs = tf.nest.map_structure(add_batch_size,
                                                    inference_specs)

            def create_inference_fn(inference_device):
                @tf.function(input_signature=inference_specs)
                def inference(env_ids, run_ids, env_outputs, raw_rewards):
                    # Reset the environments that had their first run or crashed.
                    previous_run_ids = env_run_ids.read(env_ids)
                    env_run_ids.replace(env_ids, run_ids)
                    reset_indices = tf.where(
                        tf.not_equal(previous_run_ids, run_ids))[:, 0]
                    envs_needing_reset = tf.gather(env_ids, reset_indices)
                    if tf.not_equal(tf.shape(envs_needing_reset)[0], 0):
                        tf.print('Environment ids needing reset:',
                                 envs_needing_reset)
                    env_infos.reset(envs_needing_reset)
                    store.reset(envs_needing_reset)
                    initial_agent_states = agent.initial_state(
                        tf.shape(envs_needing_reset)[0])
                    first_agent_states.replace(envs_needing_reset,
                                               initial_agent_states)
                    agent_states.replace(envs_needing_reset,
                                         initial_agent_states)
                    actions.reset(envs_needing_reset)

                    tf.debugging.assert_non_positive(
                        tf.cast(env_outputs.abandoned, tf.int32),
                        'Abandoned done states are not supported in VTRACE.')

                    # Update steps and return.
                    env_infos.add(env_ids,
                                  (0, env_outputs.reward, raw_rewards))
                    done_ids = tf.gather(env_ids,
                                         tf.where(env_outputs.done)[:, 0])
                    if i == 0:
                        info_queue.enqueue_many(env_infos.read(done_ids))
                    env_infos.reset(done_ids)
                    env_infos.add(env_ids, (FLAGS.num_action_repeats, 0., 0.))

                    # Inference.
                    prev_actions = parametric_action_distribution.postprocess(
                        actions.read(env_ids))
                    input_ = encode((prev_actions, env_outputs))
                    prev_agent_states = agent_states.read(env_ids)
                    with tf.device(inference_device):

                        @tf.function
                        def agent_inference(*args):
                            return agent(*decode(args),
                                         is_training=False,
                                         postprocess_action=False)

                        agent_outputs, curr_agent_states = agent_inference(
                            *input_, prev_agent_states)

                    # Append the latest outputs to the unroll and insert completed unrolls
                    # in queue.
                    completed_ids, unrolls = store.append(
                        env_ids, (prev_actions, env_outputs, agent_outputs))
                    unrolls = Unroll(first_agent_states.read(completed_ids),
                                     *unrolls)
                    unroll_queue.enqueue_many(unrolls)
                    first_agent_states.replace(
                        completed_ids, agent_states.read(completed_ids))

                    # Update current state.
                    agent_states.replace(env_ids, curr_agent_states)
                    actions.replace(env_ids, agent_outputs.action)
                    # Return environment actions to environments.
                    return parametric_action_distribution.postprocess(
                        agent_outputs.action)

                return inference

            with strategy.scope():
                server.bind(
                    [create_inference_fn(d) for d in inference_devices])
            server.start()
            unroll_queues.append(unroll_queue)
            servers.append(server)

    for i, (host, inference_devices) in enumerate(hosts):
        create_host(i, host, inference_devices)

    def dequeue(ctx):
        # Create batch (time major).
        env_outputs = tf.nest.map_structure(
            lambda *args: tf.stack(args), *[
                unroll_queues[ctx.input_pipeline_id].dequeue() for i in range(
                    ctx.get_per_replica_batch_size(FLAGS.batch_size))
            ])
        env_outputs = env_outputs._replace(
            prev_actions=utils.make_time_major(env_outputs.prev_actions),
            env_outputs=utils.make_time_major(env_outputs.env_outputs),
            agent_outputs=utils.make_time_major(env_outputs.agent_outputs))
        env_outputs = env_outputs._replace(
            env_outputs=encode(env_outputs.env_outputs))
        # tf.data.Dataset treats list leafs as tensors, so we need to flatten and
        # repack.
        return tf.nest.flatten(env_outputs)

    def dataset_fn(ctx):
        dataset = tf.data.Dataset.from_tensors(0).repeat(None)

        def _dequeue(_):
            return dequeue(ctx)

        return dataset.map(_dequeue,
                           num_parallel_calls=ctx.num_replicas_in_sync //
                           len(hosts))

    dataset = training_strategy.experimental_distribute_datasets_from_function(
        dataset_fn)
    it = iter(dataset)

    def additional_logs():
        tf.summary.scalar('learning_rate', learning_rate_fn(iterations))
        n_episodes = info_queue.size()
        n_episodes -= n_episodes % FLAGS.log_episode_frequency
        if tf.not_equal(n_episodes, 0):
            episode_stats = info_queue.dequeue_many(n_episodes)
            episode_keys = [
                'episode_num_frames', 'episode_return', 'episode_raw_return'
            ]
            for key, values in zip(episode_keys, episode_stats):
                for value in tf.split(
                        values,
                        values.shape[0] // FLAGS.log_episode_frequency):
                    tf.summary.scalar(key, tf.reduce_mean(value))

            for (frames, ep_return, raw_return) in zip(*episode_stats):
                logging.info('Return: %f Raw return: %f Frames: %i', ep_return,
                             raw_return, frames)

    logger.start(additional_logs)
    # Execute learning.
    while iterations < final_iteration:
        # Save checkpoint.
        current_time = time.time()
        if current_time - last_ckpt_time >= FLAGS.save_checkpoint_secs:
            manager.save()
            # Apart from checkpointing, we also save the full model (including
            # the graph). This way we can load it after the code/parameters changed.
            tf.saved_model.save(agent, os.path.join(FLAGS.logdir,
                                                    'saved_model'))
            last_ckpt_time = current_time
        minimize(it)
    logger.shutdown()
    manager.save()
    tf.saved_model.save(agent, os.path.join(FLAGS.logdir, 'saved_model'))
    for server in servers:
        server.shutdown()
    for unroll_queue in unroll_queues:
        unroll_queue.close()

def learner_loop(env_descriptor,
                 create_agent_fn,
                 create_optimizer_fn,
                 config: learner_config.LearnerConfig,
                 mzconfig,
                 pretraining=False):
  """Main learner loop.

  Args:
    env_descriptor: An instance of utils.EnvironmentDescriptor.
    create_agent_fn: Function that must create a new tf.Module with the neural
      network that outputs actions and new agent state given the environment
      observations and previous agent state. See dmlab.agents.ImpalaDeep for an
      example. The factory function takes as input the environment descriptor
      and a parametric distribution over actions.
    create_optimizer_fn: Function that takes the final iteration as argument and
      must return a tf.keras.optimizers.Optimizer and a
      tf.keras.optimizers.schedules.LearningRateSchedule.
    config: A LearnerConfig object.
    mzconfig: A MuZeroConfig object.
    pretraining: Do pretraining.
  """
  logging.info('Starting learner loop')
  validate_config()
  settings = utils.init_learner(config.num_training_tpus)
  strategy, inference_devices, training_strategy, encode, decode = settings
  tf_function = noop_decorator if config.debug else tf.function
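  # noop_decorator is presumably a pass-through (returning its argument
  # unchanged), so in debug mode the functions below run eagerly and are
  # easier to step through.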
  parametric_action_distribution = get_parametric_distribution_for_action_space(
      env_descriptor.action_space)

  observation_specs = make_spec_from_gym_space(env_descriptor.observation_space,
                                               'observation')
  action_specs = make_spec_from_gym_space(env_descriptor.action_space, 'action')

  if pretraining:
    assert env_descriptor.pretraining_space is not None, (
        'Must define a pretraining space')
    pretraining_specs = make_spec_from_gym_space(
        env_descriptor.pretraining_space, 'pretraining')

  # Initialize agent and variables.
  with strategy.scope():
    agent = create_agent_fn(env_descriptor, parametric_action_distribution)
  initial_agent_state = agent.initial_state(1)
  if config.debug:
    logging.info('initial state:\n{}'.format(initial_agent_state))

  agent_state_specs = tf.nest.map_structure(
      lambda t: tf.TensorSpec(t.shape[1:], t.dtype), initial_agent_state)

  zero_observation = tf.nest.map_structure(
      lambda s: tf.zeros([1] + list(s.shape), s.dtype), observation_specs)
  zero_action = tf.nest.map_structure(
      lambda s: tf.zeros([1] + list(s.shape), s.dtype), action_specs)

  zero_initial_args = [encode(zero_observation)]
  zero_recurrent_args = [encode(initial_agent_state), encode(zero_action)]
  if config.debug:
    logging.info('zero initial args:\n{}'.format(zero_initial_args))
    logging.info('zero recurrent args:\n{}'.format(zero_recurrent_args))

  if pretraining:
    zero_pretraining = tf.nest.map_structure(
        lambda s: tf.zeros([1] + list(s.shape), s.dtype), pretraining_specs)
    zero_pretraining_args = [encode(zero_pretraining)]
    logging.info('zero pretraining args:\n{}'.format(zero_pretraining_args))
  else:
    zero_pretraining_args = None

  with strategy.scope():

    def create_variables(initial_args, recurrent_args, pretraining_args):
      agent.initial_inference(*map(decode, initial_args))
      agent.recurrent_inference(*map(decode, recurrent_args))
      if pretraining_args is not None:
        agent.pretraining_loss(*map(decode, pretraining_args))

    # Building variables via these dummy calls complicates BatchNormalization,
    # so it cannot be used.
    create_variables(zero_initial_args, zero_recurrent_args,
                     zero_pretraining_args)

  with strategy.scope():
    # Create optimizer.
    optimizer, learning_rate_fn = create_optimizer_fn(config.total_iterations)

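    # The optimizer's hyper-parameter and slot variables are created eagerly
    # here (rather than lazily on the first apply_gradients call), presumably
    # so they exist under the strategy scope and are captured by the
    # checkpoint defined below.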
    # pylint: disable=protected-access
    iterations = optimizer.iterations
    optimizer._create_hypers()
    optimizer._create_slots(
        agent.get_trainable_variables(pretraining=pretraining))
    # pylint: enable=protected-access

  with strategy.scope():
    # ON_READ causes the replicated variable to act as independent variables for
    # each replica.
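    # `minimize` below computes gradients on the training strategy's replicas,
    # stores them in these temporaries, and then reads and applies them under
    # the main strategy, so they act as a hand-off buffer between the two
    # strategies.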
    temp_grads = [
        tf.Variable(
            tf.zeros_like(v),
            trainable=False,
            synchronization=tf.VariableSynchronization.ON_READ,
            name='temp_grad_{}'.format(v.name),
        ) for v in agent.get_trainable_variables(pretraining=pretraining)
    ]

  logging.info('--------------------------')
  logging.info('TRAINABLE VARIABLES')
  for v in agent.get_trainable_variables(pretraining=pretraining):
    logging.info('{}: {} | {}'.format(v.name, v.shape, v.dtype))
  logging.info('--------------------------')

  @tf_function
  def _compute_loss(*args, **kwargs):
    if pretraining:
      return compute_pretrain_loss(config, *args, **kwargs)
    else:
      return compute_loss(config, *args, **kwargs)

  @tf_function
  def minimize(iterator):
    data = next(iterator)

    @tf_function
    def compute_gradients(args):
      args = tf.nest.pack_sequence_as(weighted_replay_buffer_specs,
                                      decode(args, data))
      with tf.GradientTape() as tape:
        loss, logs = _compute_loss(parametric_action_distribution, agent, *args)
      grads = tape.gradient(
          loss, agent.get_trainable_variables(pretraining=pretraining))
      for t, g in zip(temp_grads, grads):
        t.assign(g if g is not None else tf.zeros_like(t))
      return loss, logs

    loss, logs = training_strategy.run(compute_gradients, (data,))
    loss = training_strategy.experimental_local_results(loss)[0]
    logs = training_strategy.experimental_local_results(logs)

    @tf_function
    def apply_gradients(_):
      grads = temp_grads
      if config.gradient_norm_clip > 0.:
        grads, _ = tf.clip_by_global_norm(grads, config.gradient_norm_clip)
      optimizer.apply_gradients(
          zip(grads, agent.get_trainable_variables(pretraining=pretraining)))

    strategy.run(apply_gradients, (loss,))

    return logs

  # Logging.
  logdir = os.path.join(config.logdir, 'learner')
  summary_writer = tf.summary.create_file_writer(
      logdir,
      flush_millis=config.flush_learner_log_every_n_s * 1000,
      max_queue=int(1E6))

  # Setup checkpointing and restore checkpoint.
  ckpt = tf.train.Checkpoint(agent=agent, optimizer=optimizer)
  manager = tf.train.CheckpointManager(
      ckpt, logdir, max_to_keep=1, keep_checkpoint_every_n_hours=6)

  # If we are continuing a run from an intermediate checkpoint, we do not need
  # to read `init_checkpoint`.
  if manager.latest_checkpoint:
    logging.info('Restoring checkpoint: %s', manager.latest_checkpoint)
    ckpt.restore(manager.latest_checkpoint).assert_consumed()
    last_ckpt_time = time.time()

    # Also properly reset iterations.
    iterations = optimizer.iterations
  else:
    last_ckpt_time = 0  # Force checkpointing of the initial model.
    # If there is a checkpoint from pre-training specified, load it now.
    # Note that we only need to do this if we are not already restoring a
    # checkpoint from the actual training.
    if config.init_checkpoint is not None:
      logging.info('Loading initial checkpoint from %s ...',
                   config.init_checkpoint)
      # We don't want to restore the optimizer state from pretraining.
      ckpt_without_optimizer = tf.train.Checkpoint(agent=agent)
      # Checkpoints from independent pre-training may lack some variables, e.g.
      # optimizer weights (or may have used a different optimizer), and may not
      # have instantiated all network parts (such as the "core" recurrence).
      # We therefore only assert that the match is non-trivial; we still want
      # to catch the case where nothing matches at all, but cannot be stricter.
      ckpt_without_optimizer.restore(
          config.init_checkpoint).assert_nontrivial_match()
      logging.info('Finished loading the initial checkpoint.')

  server = grpc.Server([config.server_address])

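  # Per-step MuZero targets: value/reward/policy masks plus the value, reward
  # and policy targets for each of the num_unroll_steps + 1 unrolled model
  # steps.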
  num_target_steps = mzconfig.num_unroll_steps + 1
  target_specs = (
      tf.TensorSpec([num_target_steps], tf.float32, 'value_mask'),
      tf.TensorSpec([num_target_steps], tf.float32, 'reward_mask'),
      tf.TensorSpec([num_target_steps], tf.float32, 'policy_mask'),
      tf.TensorSpec([num_target_steps], tf.float32, 'value'),
      tf.TensorSpec([num_target_steps], tf.float32, 'reward'),
      tf.TensorSpec([num_target_steps, env_descriptor.action_space.n],
                    tf.float32, 'policy'),
  )

  if pretraining:
    replay_buffer_specs = pretraining_specs
  else:
    replay_buffer_specs = (
        observation_specs,
        tf.TensorSpec(
            env_descriptor.action_space.shape + (mzconfig.num_unroll_steps,),
            env_descriptor.action_space.dtype, 'history'),
        *target_specs,
    )

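  # A single training example, as unpacked in `compute_gradients` above:
  # (importance_weight, observation, action history, *per-step targets) during
  # regular training, or (importance_weight, *pretraining example) when
  # `pretraining` is set.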
  weighted_replay_buffer_specs = (
      tf.TensorSpec([], tf.float32, 'importance_weights'), *replay_buffer_specs)

  episode_stat_specs = (
      tf.TensorSpec([], tf.string, 'summary_name'),
      tf.TensorSpec([], tf.float32, 'reward'),
      tf.TensorSpec([], tf.int64, 'episode_length'),
  )
  if env_descriptor.extras:
    episode_stat_specs += tuple(
        tf.TensorSpec([], stat[1], stat[0])
        for stat in env_descriptor.extras.get('learner_stats', []))

  replay_buffer_size = config.replay_buffer_size
  replay_buffer = utils.PrioritizedReplay(
      replay_buffer_size,
      replay_buffer_specs,
      config.importance_sampling_exponent,
  )

  replay_queue_specs = (
      tf.TensorSpec([], tf.float32, 'priority'),
      *replay_buffer_specs,
  )
  replay_queue_size = config.replay_queue_size
  replay_buffer_queue = utils.StructuredFIFOQueue(replay_queue_size,
                                                  replay_queue_specs)

  episode_stat_queue = utils.StructuredFIFOQueue(-1, episode_stat_specs)

  def get_add_batch_size(batch_size):

    def add_batch_size(ts):
      return tf.TensorSpec([batch_size] + list(ts.shape), ts.dtype, ts.name)

    return add_batch_size
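  # For example, get_add_batch_size(32)(tf.TensorSpec([8], tf.float32, 'x'))
  # yields tf.TensorSpec([32, 8], tf.float32, 'x'); it is used below to build
  # the batched input signatures of the inference functions.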

  def make_inference_fn(inference_device, inference_fn, *args):

    args = encode(args)

    def device_specific_inference_fn():
      with tf.device(inference_device):

        @tf_function
        def agent_inference(*args):
          return inference_fn(*decode(args), training=False)

        return agent_inference(*args)

    return device_specific_inference_fn

  initial_inference_specs = (observation_specs,)

  def make_initial_inference_fn(inference_device):

    @tf.function(
        input_signature=tf.nest.map_structure(
            get_add_batch_size(config.initial_inference_batch_size),
            initial_inference_specs))
    def initial_inference(observation):
      return make_inference_fn(inference_device, agent.initial_inference,
                               observation)()

    return initial_inference

  recurrent_inference_specs = (
      agent_state_specs,
      action_specs,
  )

  def make_recurrent_inference_fn(inference_device):

    @tf.function(
        input_signature=tf.nest.map_structure(
            get_add_batch_size(config.recurrent_inference_batch_size),
            recurrent_inference_specs))
    def recurrent_inference(hidden_state, action):
      return make_inference_fn(inference_device, agent.recurrent_inference,
                               hidden_state, action)()

    return recurrent_inference

  @tf.function(
      input_signature=tf.nest.map_structure(
          get_add_batch_size(config.batch_size), replay_queue_specs))
  def add_to_replay_buffer(*batch):
    queue_size = replay_buffer_queue.size()
    num_free = replay_queue_size - queue_size
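    # If blocking is disabled and there is not enough free space, evict the
    # oldest entries so that the enqueue below does not stall the callers.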
    if not config.replay_queue_block and num_free < config.recurrent_inference_batch_size:
      replay_buffer_queue.dequeue_many(config.recurrent_inference_batch_size)
    replay_buffer_queue.enqueue_many(batch)

  @tf.function(input_signature=episode_stat_specs)
  def add_to_reward_queue(*stats):
    episode_stat_queue.enqueue(stats)

  @tf.function(input_signature=[])
  def learning_iteration():
    return optimizer.iterations

  with strategy.scope():
    server.bind([make_initial_inference_fn(d) for d in inference_devices])
    server.bind([make_recurrent_inference_fn(d) for d in inference_devices])
    server.bind(add_to_replay_buffer)
    server.bind(add_to_reward_queue)
    server.bind(learning_iteration)
  server.start()

  @tf_function
  def dequeue(ctx):

    while tf.constant(True):

      num_dequeues = config.learner_skip + 1
      if num_dequeues < 1:
        queue_size = replay_buffer_queue.size()
        num_dequeues = tf.maximum(queue_size // config.batch_size - 1,
                                  tf.ones_like(queue_size))
      for _ in tf.range(num_dequeues):
        batch = replay_buffer_queue.dequeue_many(config.batch_size)
        priorities, *samples = batch
        replay_buffer.insert(tuple(samples), priorities)

      if replay_buffer.num_inserted >= replay_buffer_size:
        break

      tf.print(
          'waiting for replay buffer to fill. Status:',
          replay_buffer.num_inserted,
          ' / ',
          replay_buffer_size,
      )

    indices, weights, replays = replay_buffer.sample(
        ctx.get_per_replica_batch_size(config.batch_size),
        config.priority_sampling_exponent)
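    # Optionally overwrite the priorities of the items that were just sampled
    # with a fixed configured value, presumably so that stale high-priority
    # entries do not dominate subsequent batches.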
    if config.replay_buffer_update_priority_after_sampling_value >= 0.:
      replay_buffer.update_priorities(
          indices,
          tf.convert_to_tensor(
              np.ones(indices.shape) *
              config.replay_buffer_update_priority_after_sampling_value,
              dtype=tf.float32))

    data = (weights, *replays)
    data = tuple(map(encode, data))

    # tf.data.Dataset treats list leafs as tensors, so we need to flatten and
    # repack.
    return tf.nest.flatten(data)

  def dataset_fn(ctx):
    dataset = tf.data.Dataset.from_tensors(0).repeat(None)
    return dataset.map(
        lambda _: dequeue(ctx), num_parallel_calls=ctx.num_replicas_in_sync)

  dataset = training_strategy.experimental_distribute_datasets_from_function(
      dataset_fn)
  it = iter(dataset)

  # Execute learning and track performance.
  with summary_writer.as_default(), \
       concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
    log_future = executor.submit(lambda: None)  # No-op future.
    last_iterations = iterations
    last_log_time = time.time()
    values_to_log = collections.defaultdict(lambda: [])
    while iterations < config.total_iterations:
      tf.summary.experimental.set_step(iterations)

      # Save checkpoint.
      current_time = time.time()
      if current_time - last_ckpt_time >= config.save_checkpoint_secs:
        manager.save()
        if config.export_agent:
          # We also export the agent as a SavedModel to be used for inference.
          saved_model_dir = os.path.join(logdir, 'saved_model')
          network.export_agent_for_initial_inference(
              agent=agent,
              model_dir=os.path.join(saved_model_dir, 'initial_inference'))
          network.export_agent_for_recurrent_inference(
              agent=agent,
              model_dir=os.path.join(saved_model_dir, 'recurrent_inference'))
        last_ckpt_time = current_time

      def log(iterations):
        """Logs batch and episodes summaries."""
        nonlocal last_iterations, last_log_time
        summary_writer.set_as_default()
        tf.summary.experimental.set_step(iterations)

        # log data from the current minibatch
        for key, values in copy.deepcopy(values_to_log).items():
          if values:
            tf.summary.scalar(key, values[-1])  # could also take mean
        values_to_log.clear()
        tf.summary.scalar('learning_rate', learning_rate_fn(iterations))
        tf.summary.scalar('replay_queue_size', replay_buffer_queue.size())
        stats = episode_stat_queue.dequeue_many(episode_stat_queue.size())

        summary_name_idx = [
            spec.name for spec in episode_stat_specs].index('summary_name')
        summary_name_stats = stats[summary_name_idx]
        unique_summary_names, unique_summary_name_idx = tf.unique(
            summary_name_stats)

        def log_mean_value(values, label):
          mean_value = tf.reduce_mean(tf.cast(values, tf.float32))
          tf.summary.scalar(label, mean_value)


        for stat, stat_spec in zip(stats, episode_stat_specs):
          if stat_spec.name == 'summary_name' or len(stat) <= 0:
            continue

          for idx, summary_name in enumerate(unique_summary_names):
            add_to_summary = unique_summary_name_idx == idx
            stat_masked = tf.boolean_mask(stat, add_to_summary)
            label = f'{summary_name.numpy().decode()}/mean_{stat_spec.name}'
            if len(stat_masked) > 0:  # pylint: disable=g-explicit-length-test
              log_mean_value(stat_masked, label=label)

      logs = minimize(it)

      if (config.enable_learner_logging == 1 and
          iterations % config.log_frequency == 0):
        for per_replica_logs in logs:
          assert len(log_keys) == len(per_replica_logs)
          for key, value in zip(log_keys, per_replica_logs):
            try:
              values_to_log[key].append(value.numpy())
            except AttributeError:
              values_to_log[key].extend(
                  x.numpy()
                  for x in training_strategy.experimental_local_results(value))

        log_future.result()  # Raise exception if any occurred in logging.
        log_future = executor.submit(log, iterations)

  manager.save()
  server.shutdown()
Example #9
def learner_loop(create_env_fn, create_agent_fn):
    """Main learner loop.

  Args:
    create_env_fn: Callable that must return a newly created environment. The
      callable takes the task ID as argument - an arbitrary task ID of 0 will be
      passed by the learner. The returned environment should follow GYM's API.
      It is only used for inferring tensor shapes. This environment will not be
      used to generate experience.
    create_agent_fn: Function that must create a new tf.Module with the neural
      network that outputs actions and new agent state given the environment
      observations and previous agent state. See dmlab.agents.ImpalaDeep for an
      example. The factory function takes as input the environment action and
      observation spaces and a parametric distribution over actions.
  """
    logging.info('Starting learner loop')
    # validate_config()
    settings = utils.init_learner(FLAGS.num_training_tpus)
    strategy, inference_devices, training_strategy, encode, decode = settings
    env = create_env_fn(0)
    parametric_action_distribution = get_parametric_distribution_for_action_space(
        env.action_space)
    env_output_specs = utils.EnvOutput(
        tf.TensorSpec([], tf.float32, 'reward'),
        tf.TensorSpec([], tf.bool, 'done'),
        tf.TensorSpec(env.observation_space.shape, env.observation_space.dtype,
                      'observation'),
    )
    action_specs = tf.TensorSpec(env.action_space.shape,
                                 env.action_space.dtype, 'action')
    agent_input_specs = (action_specs, env_output_specs)

    # Initialize agent and variables.
    agent = create_agent_fn(env.action_space, env.observation_space,
                            parametric_action_distribution)
    initial_agent_state = agent.initial_state(1)
    agent_state_specs = tf.nest.map_structure(
        lambda t: tf.TensorSpec(t.shape[1:], t.dtype), initial_agent_state)
    input_ = tf.nest.map_structure(
        lambda s: tf.zeros([1] + list(s.shape), s.dtype), agent_input_specs)
    input_ = encode(input_)

    # Setup checkpointing and restore checkpoint.
    ckpt = tf.train.Checkpoint(agent=agent)
    if FLAGS.init_checkpoint is not None:
        tf.print('Loading initial checkpoint from %s...' %
                 FLAGS.init_checkpoint)
        ckpt.restore(FLAGS.init_checkpoint)  #.assert_consumed()
    manager = tf.train.CheckpointManager(ckpt,
                                         FLAGS.logdir,
                                         max_to_keep=1,
                                         keep_checkpoint_every_n_hours=6)
    last_ckpt_time = 0  # Force checkpointing of the initial model.
    if manager.latest_checkpoint:
        logging.info('Restoring checkpoint: %s', manager.latest_checkpoint)
        ckpt.restore(manager.latest_checkpoint)  #.assert_consumed()
        last_ckpt_time = time.time()

    actor_run_ids = utils.Aggregator(FLAGS.num_actors,
                                     tf.TensorSpec([], tf.int64, 'run_ids'))

    # Current agent state and action.
    agent_states = utils.Aggregator(FLAGS.num_actors, agent_state_specs,
                                    'agent_states')
    actions = utils.Aggregator(FLAGS.num_actors, action_specs, 'actions')

    def add_batch_size(ts):
        return tf.TensorSpec([FLAGS.inference_batch_size] + list(ts.shape),
                             ts.dtype, ts.name)

    inference_iteration = tf.Variable(-1)
    inference_specs = (
        tf.TensorSpec([], tf.int32, 'actor_id'),
        tf.TensorSpec([], tf.int64, 'run_id'),
        env_output_specs,
        tf.TensorSpec([], tf.float32, 'raw_reward'),
    )
    inference_specs = tf.nest.map_structure(add_batch_size, inference_specs)

    @tf.function(input_signature=inference_specs)
    def inference(actor_ids, run_ids, env_outputs, raw_rewards):
        # Reset the actors that had their first run or crashed.
        previous_run_ids = actor_run_ids.read(actor_ids)
        actor_run_ids.replace(actor_ids, run_ids)
        reset_indices = tf.where(tf.not_equal(previous_run_ids, run_ids))[:, 0]
        actors_needing_reset = tf.gather(actor_ids, reset_indices)
        if tf.not_equal(tf.shape(actors_needing_reset)[0], 0):
            tf.print('Actor ids needing reset:', actors_needing_reset)
        initial_agent_states = agent.initial_state(
            tf.shape(actors_needing_reset)[0])
        # tf.print("agent initial_agent_states",tf.reduce_mean(initial_agent_states))
        agent_states.replace(actors_needing_reset, initial_agent_states)
        actions.reset(actors_needing_reset)

        # Inference.
        prev_actions = parametric_action_distribution.postprocess(
            actions.read(actor_ids))
        input_ = encode((prev_actions, env_outputs))
        prev_agent_states = agent_states.read(actor_ids)
        tf.print("agent states", tf.reduce_mean(prev_agent_states),
                 tf.reduce_max(prev_agent_states))
        # Note: the original multi-device dispatch (one inference fn per
        # device, selected with tf.switch_case on inference_iteration) is
        # disabled here; inference always runs on the first inference device.
        with tf.device(inference_devices[0]):
            agent_outputs, curr_agent_states = agent(
                *(decode((*input_, prev_agent_states))))
        # Update current state.
        agent_states.replace(actor_ids, curr_agent_states)
        actions.replace(actor_ids, agent_outputs.action)

        # Return environment actions to actors.
        return parametric_action_distribution.postprocess(agent_outputs.action)

    summary_writer = tf.summary.create_noop_writer()
    timer_cls = utils.nullcontext

    actor_step = 0
    with summary_writer.as_default():
        while True:
            try:

                env = create_env_fn(FLAGS.task)

                # Unique ID to identify a specific run of an actor.
                run_id = np.random.randint(np.iinfo(np.int64).max)
                observation = env.reset()
                reward = 0.0
                raw_reward = 0.0
                done = False

                episode_step = 0
                episode_return = 0
                episode_raw_return = 0

                while True:
                    tf.summary.experimental.set_step(actor_step)
                    env_output = utils.EnvOutput(
                        [tf.cast(reward, tf.float32)], [done],
                        tf.cast(observation[None], tf.float32))
                    with timer_cls('actor/elapsed_inference_s', 1000):
                        action = inference([FLAGS.task], [run_id], env_output,
                                           [raw_reward])[0]
                    with timer_cls('actor/elapsed_env_step_s', 1000):
                        observation, reward, done, info = env.step(
                            action.numpy())

                    # env.render()
                    episode_step += 1
                    episode_return += reward
                    raw_reward = float(
                        (info or {}).get('score_reward', reward))
                    episode_raw_return += raw_reward
                    if done:
                        logging.info('Return: %f Raw return: %f Steps: %i',
                                     episode_return, episode_raw_return,
                                     episode_step)
                        env.render()
                        with timer_cls('actor/elapsed_env_reset_s', 10):
                            observation = env.reset()
                            episode_step = 0
                            episode_return = 0
                            episode_raw_return = 0

                    actor_step += 1
            except (tf.errors.UnavailableError, tf.errors.CancelledError) as e:
                logging.exception(e)
                env.close()