Example #1
 def _create_env_output(self, batch_size, unroll_length):
     return utils.EnvOutput(
         reward=tf.random.uniform([unroll_length, batch_size]),
         done=tf.cast(
             tf.random.uniform([unroll_length, batch_size],
                               maxval=2,
                               dtype=tf.int32), tf.bool),
         observation=self._random_obs(batch_size, unroll_length))
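
Note: `utils.EnvOutput` in SEED RL is a plain namedtuple; the examples on this page construct it with either three fields (reward, done, observation) or five (adding abandoned and episode_step). A minimal sketch of the five-field variant, assuming no behaviour beyond the namedtuple itself:

import collections

# Field order matches the keyword arguments used in the examples below. Older
# examples build the shorter (reward, done, observation) variant instead.
EnvOutput = collections.namedtuple(
    'EnvOutput', 'reward done observation abandoned episode_step')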
Example #2
 def wrap(self, x, y=None):
     unroll = collections.namedtuple('unroll', 'env_outputs')
     return unroll(env_outputs=utils.EnvOutput(
         observation={
             'achieved_goal': x,
             'desired_goal': y if (y is not None) else x
         },
         done=tf.zeros(x.shape[:-1], tf.bool),
         reward=tf.zeros(x.shape[:-1], tf.float32),
         abandoned=tf.zeros(x.shape[:-1], tf.bool),
         episode_step=tf.ones(x.shape[:-1], tf.int32),
     ))
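
A quick usage sketch of `wrap`, assuming `wrapper` is a hypothetical instance of the class that defines it and the goals are `[batch, goal_dim]` tensors:

import tensorflow as tf

achieved = tf.random.normal([4, 3])
desired = tf.random.normal([4, 3])
unroll = wrapper.wrap(achieved, desired)
# When y is omitted, the desired goal defaults to the achieved goal.
self_play_unroll = wrapper.wrap(achieved)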
Example #3
def _dummy_input(unroll):
    """Returns a dummy tuple that can be fed into an agent."""
    batch_size = 15
    base_shape = [6, batch_size] if unroll else [batch_size]
    prev_actions = tf.zeros(base_shape + [10], tf.float32)
    # Create the environment output.
    env_outputs = utils.EnvOutput(reward=tf.zeros(base_shape, tf.float32),
                                  done=tf.zeros(base_shape, tf.bool),
                                  observation=tf.zeros(base_shape + [17],
                                                       tf.float32),
                                  abandoned=tf.zeros(base_shape, tf.bool),
                                  episode_step=tf.zeros(base_shape, tf.int32))
    core_state = _dummy_rnn_core_state(batch_size=batch_size)
    return (prev_actions, env_outputs), core_state
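
`_dummy_rnn_core_state` is referenced but not shown here. A minimal sketch of what such a helper might return, assuming an LSTM-style (hidden, cell) state pair and a hypothetical state size of 64:

import tensorflow as tf

def _dummy_rnn_core_state(batch_size, state_size=64):
    # Hypothetical helper: zeroed hidden and cell states shaped like an LSTM
    # core state. The real helper may use a different structure or size.
    return (tf.zeros([batch_size, state_size], tf.float32),
            tf.zeros([batch_size, state_size], tf.float32))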
Example #4
    def test_actor_critic_lstm(self):
        n_steps = 100
        batch_size = 10
        obs_size = 15
        action_size = 3

        action_dist = parametric_distribution.normal_tanh_distribution(
            action_size)
        agent = networks.ActorCriticLSTM(action_dist,
                                         n_critics=2,
                                         lstm_sizes=[10, 20],
                                         pre_mlp_sizes=[30, 40],
                                         post_mlp_sizes=[50],
                                         ff_mlp_sizes=[25, 35, 45])
        env_output = utils.EnvOutput(
            observation=tf.random.normal((n_steps, batch_size, obs_size)),
            reward=tf.random.normal((n_steps, batch_size)),
            done=tf.cast(tf.random.uniform((n_steps, batch_size), 0, 1),
                         tf.bool),
            abandoned=tf.zeros((n_steps, batch_size), dtype=tf.bool),
            episode_step=tf.ones((n_steps, batch_size), dtype=tf.int32))
        prev_action = tf.random.normal((n_steps, batch_size, action_size))
        action = tf.random.normal((n_steps, batch_size, action_size))
        state = agent.initial_state(batch_size)

        # Run in one call.
        v_one_call = agent.get_V(prev_action, env_output, state)
        q_one_call = agent.get_Q(prev_action, env_output, state, action)

        # Run step-by-step.
        v_many_calls = []
        q_many_calls = []
        for i in range(n_steps):
            env_output_i = tf.nest.map_structure(lambda t: t[i], env_output)
            expanded_env_output_i = tf.nest.map_structure(
                lambda t: t[i, tf.newaxis], env_output)
            v_many_calls.append(
                agent.get_V(prev_action[i, tf.newaxis], expanded_env_output_i,
                            state)[0])
            q_many_calls.append(
                agent.get_Q(prev_action[i, tf.newaxis], expanded_env_output_i,
                            state, action[i, tf.newaxis])[0])
            unused_action, state = agent(prev_action[i], env_output_i, state)
        v_many_calls = tf.stack(v_many_calls)
        q_many_calls = tf.stack(q_many_calls)

        # Check if results are the same.
        self.assertAllClose(v_one_call, v_many_calls, 1e-4, 1e-4)
        self.assertAllClose(q_one_call, q_many_calls, 1e-4, 1e-4)
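
The step-by-step branch relies on `tf.nest.map_structure` to slice every field of the `EnvOutput` namedtuple at once, either dropping or keeping the leading time axis. A small illustration of the same pattern on a toy namedtuple:

import collections
import tensorflow as tf

Toy = collections.namedtuple('Toy', 'a b')
batch = Toy(a=tf.zeros([4, 2]), b=tf.ones([4, 2]))
step0 = tf.nest.map_structure(lambda t: t[0], batch)                      # field shapes [2]
step0_keepdim = tf.nest.map_structure(lambda t: t[0, tf.newaxis], batch)  # field shapes [1, 2]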
Example #5
def actor_loop(create_env_fn):
  """Main actor loop.

  Args:
    create_env_fn: Callable (taking the task ID as argument) that must return a
      newly created environment.
  """
  logging.info('Starting actor loop')
  if are_summaries_enabled():
    summary_writer = tf.summary.create_file_writer(
        os.path.join(FLAGS.logdir, 'actor_{}'.format(FLAGS.task)),
        flush_millis=20000, max_queue=1000)
    timer_cls = profiling.ExportingTimer
  else:
    summary_writer = tf.summary.create_noop_writer()
    timer_cls = utils.nullcontext

  actor_step = 0
  with summary_writer.as_default():
    while True:
      try:
        # Client to communicate with the learner.
        client = grpc.Client(FLAGS.server_address)

        env = create_env_fn(FLAGS.task)

        # Unique ID to identify a specific run of an actor.
        run_id = np.random.randint(np.iinfo(np.int64).max)
        observation = env.reset()
        reward = 0.0
        raw_reward = 0.0
        done = False

        while True:
          tf.summary.experimental.set_step(actor_step)
          env_output = utils.EnvOutput(reward, done, observation)
          with timer_cls('actor/elapsed_inference_s', 1000):
            action = client.inference(
                (FLAGS.task, run_id, env_output, raw_reward))
          with timer_cls('actor/elapsed_env_step_s', 1000):
            observation, reward, done, info = env.step(action.numpy())
          raw_reward = float(info.get('score_reward', reward))
          if done:
            with timer_cls('actor/elapsed_env_reset_s', 10):
              observation = env.reset()
          actor_step += 1
      except (tf.errors.UnavailableError, tf.errors.CancelledError) as e:
        logging.exception(e)
        env.close()
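
When summaries are disabled, `timer_cls` becomes `utils.nullcontext`, so the `with timer_cls(...)` blocks turn into no-ops. A sketch of an equivalent no-op context manager, assuming it only needs to swallow the timer's name and export-period arguments (Python's built-in `contextlib.nullcontext` does not accept extra positional arguments):

import contextlib

@contextlib.contextmanager
def nullcontext(*unused_args, **unused_kwargs):
    # Accepts and ignores the summary name and aggregation period, then yields.
    yield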
Example #6
def actor_loop(create_env_fn):
    """Main actor loop.

  Args:
    create_env_fn: Callable (taking the task ID as argument) that must return a
      newly created environment.
  """
    logging.info('Starting actor eval loop')

    summary_writer = tf.summary.create_file_writer(os.path.join(
        FLAGS.logdir, 'actor_{}'.format(FLAGS.task)),
                                                   flush_millis=20000,
                                                   max_queue=1000)
    timer_cls = profiling.ExportingTimer

    actor_step = 0
    with summary_writer.as_default():
        while True:
            try:
                # Client to communicate with the learner.
                client = grpc.Client(FLAGS.server_address)

                env = create_env_fn(FLAGS.task, color='black')
                env1 = create_env_fn(FLAGS.task, color='white')

                # Unique ID to identify a specific run of an actor.
                run_id = np.random.randint(np.iinfo(np.int64).max)
                observation = env.reset()
                reward = 0.0
                raw_reward = 0.0
                done = False

                episode_step = 0
                episode_return = 0
                episode_raw_return = 0

                eval_times = 0
                eval_state = 'black'
                print("starting eval: ", eval_state)

                while True:
                    tf.summary.experimental.set_step(actor_step)
                    env_output = utils.EnvOutput(
                        tf.cast(reward, tf.float32), done,
                        tf.cast(observation, tf.float32))
                    with timer_cls('actor/elapsed_inference_s', 1000):
                        action = client.inference_eval(FLAGS.task, run_id,
                                                       env_output, raw_reward)

                    if eval_state == 'black':
                        with timer_cls('actor/elapsed_env_step_s', 1000):
                            observation, reward, done, info = env.step(
                                action.numpy())
                    else:
                        with timer_cls('actor/elapsed_env_step_s', 1000):
                            observation, reward, done, info = env1.step(
                                action.numpy())

                    if is_rendering_enabled():
                        env.render()
                    episode_step += 1
                    episode_return += reward
                    raw_reward = float((info
                                        or {}).get('score_reward', reward))
                    episode_raw_return += raw_reward

                    if done:
                        eval_times += 1
                        if eval_times >= 50:
                            tf.summary.scalar(
                                'actor/eval_return_' + eval_state,
                                episode_return)
                            logging.info(
                                '%s win/all: %d/%d Raw return: %f Steps: %i',
                                eval_state, (episode_return + eval_times) / 2,
                                eval_times, episode_raw_return, episode_step)
                            episode_step = 0
                            episode_return = 0
                            episode_raw_return = 0

                            time.sleep(300)
                            eval_times = 0
                            eval_state = 'white' if eval_state == 'black' else 'black'
                            print("starting eval: ", eval_state)

                        if eval_state == 'black':
                            with timer_cls('actor/elapsed_env_reset_s', 10):
                                observation = env.reset()
                        else:
                            with timer_cls('actor/elapsed_env_reset_s', 10):
                                observation = env1.reset()

                    actor_step += 1
            except (tf.errors.UnavailableError, tf.errors.CancelledError) as e:
                logging.exception(e)
                env.close()
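
The win count logged above, `(episode_return + eval_times) / 2`, assumes every finished game contributes a reward of +1 for a win and -1 for a loss, so the number of wins can be recovered from the summed return. A quick check of that arithmetic under the same assumption:

# Hypothetical tally: 30 wins (+1 each) and 20 losses (-1 each) over 50 games.
episode_return = 30 * 1 + 20 * (-1)       # = 10
eval_times = 50
wins = (episode_return + eval_times) / 2  # = 30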
Example #7
def learner_loop(create_env_fn, create_agent_fn, create_optimizer_fn, fps_log):
    """Main learner loop.

  Args:
    create_env_fn: Callable that must return a newly created environment. The
      callable takes the task ID as argument - an arbitrary task ID of 0 will be
      passed by the learner. The returned environment should follow Gym's API.
      It is only used for inferring tensor shapes. This environment will not be
      used to generate experience.
    create_agent_fn: Function that must create a new tf.Module with the neural
      network that outputs actions and new agent state given the environment
      observations and previous agent state. See dmlab.agents.ImpalaDeep for an
      example. The factory function takes as input the environment action and
      observation spaces and a parametric distribution over actions.
    create_optimizer_fn: Function that takes the final iteration as argument
      and must return a tf.keras.optimizers.Optimizer and a
      tf.keras.optimizers.schedules.LearningRateSchedule.
    fps_log: Object exposing a `logger` used to log frames-per-second (FPS)
      measurements.
  """
    logging.info('Starting learner loop')
    validate_config()
    settings = utils.init_learner(FLAGS.num_training_tpus)
    strategy, inference_devices, training_strategy, encode, decode = settings
    env = create_env_fn(0)
    parametric_action_distribution = get_parametric_distribution_for_action_space(
        env.action_space)
    env_output_specs = utils.EnvOutput(
        tf.TensorSpec([], tf.float32, 'reward'),
        tf.TensorSpec([], tf.bool, 'done'),
        tf.TensorSpec(env.observation_space.shape, env.observation_space.dtype,
                      'observation'),
    )
    action_specs = tf.TensorSpec(env.action_space.shape,
                                 env.action_space.dtype, 'action')
    agent_input_specs = (action_specs, env_output_specs)

    # Initialize agent and variables.
    agent = create_agent_fn(env.action_space, env.observation_space,
                            parametric_action_distribution)
    initial_agent_state = agent.initial_state(1)
    agent_state_specs = tf.nest.map_structure(
        lambda t: tf.TensorSpec(t.shape[1:], t.dtype), initial_agent_state)
    input_ = tf.nest.map_structure(
        lambda s: tf.zeros([1] + list(s.shape), s.dtype), agent_input_specs)
    input_ = encode(input_)

    with strategy.scope():

        @tf.function
        def create_variables(*args):
            return agent(*decode(args))

        initial_agent_output, _ = create_variables(input_, initial_agent_state)
        # Create optimizer.
        iter_frame_ratio = (FLAGS.batch_size * FLAGS.unroll_length *
                            FLAGS.num_action_repeats)
        final_iteration = int(
            math.ceil(FLAGS.total_environment_frames / iter_frame_ratio))
        optimizer, learning_rate_fn = create_optimizer_fn(final_iteration)

        iterations = optimizer.iterations
        optimizer._create_hypers()
        optimizer._create_slots(agent.trainable_variables)

        # ON_READ causes the replicated variable to act as independent variables for
        # each replica.
        temp_grads = [
            tf.Variable(tf.zeros_like(v),
                        trainable=False,
                        synchronization=tf.VariableSynchronization.ON_READ)
            for v in agent.trainable_variables
        ]

    @tf.function
    def minimize(iterator):
        data = next(iterator)

        def compute_gradients(args):
            args = tf.nest.pack_sequence_as(unroll_specs, decode(args, data))
            with tf.GradientTape() as tape:
                loss, logs = compute_loss(parametric_action_distribution,
                                          agent, *args)
            grads = tape.gradient(loss, agent.trainable_variables)
            for t, g in zip(temp_grads, grads):
                t.assign(g)
            return loss, logs

        loss, logs = training_strategy.experimental_run_v2(
            compute_gradients, (data, ))
        loss = training_strategy.experimental_local_results(loss)[0]
        logs = training_strategy.experimental_local_results(logs)

        def apply_gradients(_):
            optimizer.apply_gradients(
                zip(temp_grads, agent.trainable_variables))

        strategy.experimental_run_v2(apply_gradients, (loss, ))

        return logs

    agent_output_specs = tf.nest.map_structure(
        lambda t: tf.TensorSpec(t.shape[1:], t.dtype), initial_agent_output)
    # Logging.
    summary_writer = tf.summary.create_file_writer(FLAGS.logdir,
                                                   flush_millis=20000,
                                                   max_queue=1000)

    # Setup checkpointing and restore checkpoint.
    ckpt = tf.train.Checkpoint(agent=agent, optimizer=optimizer)
    if FLAGS.init_checkpoint is not None:
        tf.print('Loading initial checkpoint from %s...' %
                 FLAGS.init_checkpoint)
        ckpt.restore(FLAGS.init_checkpoint).assert_consumed()
    manager = tf.train.CheckpointManager(ckpt,
                                         FLAGS.logdir,
                                         max_to_keep=1,
                                         keep_checkpoint_every_n_hours=6)
    last_ckpt_time = 0  # Force checkpointing of the initial model.
    if manager.latest_checkpoint:
        logging.info('Restoring checkpoint: %s', manager.latest_checkpoint)
        ckpt.restore(manager.latest_checkpoint).assert_consumed()
        last_ckpt_time = time.time()

    server = grpc.Server([FLAGS.server_address])

    store = utils.UnrollStore(
        FLAGS.num_actors, FLAGS.unroll_length,
        (action_specs, env_output_specs, agent_output_specs))
    actor_run_ids = utils.Aggregator(FLAGS.num_actors,
                                     tf.TensorSpec([], tf.int64, 'run_ids'))
    info_specs = (
        tf.TensorSpec([], tf.int64, 'episode_num_frames'),
        tf.TensorSpec([], tf.float32, 'episode_returns'),
        tf.TensorSpec([], tf.float32, 'episode_raw_returns'),
    )
    actor_infos = utils.Aggregator(FLAGS.num_actors, info_specs, 'actor_infos')

    # First agent state in an unroll.
    first_agent_states = utils.Aggregator(FLAGS.num_actors, agent_state_specs,
                                          'first_agent_states')

    # Current agent state and action.
    agent_states = utils.Aggregator(FLAGS.num_actors, agent_state_specs,
                                    'agent_states')
    actions = utils.Aggregator(FLAGS.num_actors, action_specs, 'actions')

    unroll_specs = Unroll(agent_state_specs, *store.unroll_specs)
    unroll_queue = utils.StructuredFIFOQueue(1, unroll_specs)
    info_queue = utils.StructuredFIFOQueue(-1, info_specs)

    def add_batch_size(ts):
        return tf.TensorSpec([FLAGS.inference_batch_size] + list(ts.shape),
                             ts.dtype, ts.name)

    inference_iteration = tf.Variable(-1)
    inference_specs = (
        tf.TensorSpec([], tf.int32, 'actor_id'),
        tf.TensorSpec([], tf.int64, 'run_id'),
        env_output_specs,
        tf.TensorSpec([], tf.float32, 'raw_reward'),
    )
    inference_specs = tf.nest.map_structure(add_batch_size, inference_specs)

    @tf.function(input_signature=inference_specs)
    def inference(actor_ids, run_ids, env_outputs, raw_rewards):
        # Reset the actors that had their first run or crashed.
        previous_run_ids = actor_run_ids.read(actor_ids)
        actor_run_ids.replace(actor_ids, run_ids)
        reset_indices = tf.where(tf.not_equal(previous_run_ids, run_ids))[:, 0]
        actors_needing_reset = tf.gather(actor_ids, reset_indices)
        if tf.not_equal(tf.shape(actors_needing_reset)[0], 0):
            tf.print('Actor ids needing reset:', actors_needing_reset)
        actor_infos.reset(actors_needing_reset)
        store.reset(actors_needing_reset)
        initial_agent_states = agent.initial_state(
            tf.shape(actors_needing_reset)[0])
        first_agent_states.replace(actors_needing_reset, initial_agent_states)
        agent_states.replace(actors_needing_reset, initial_agent_states)
        actions.reset(actors_needing_reset)

        # Update steps and return.
        actor_infos.add(actor_ids, (0, env_outputs.reward, raw_rewards))
        done_ids = tf.gather(actor_ids, tf.where(env_outputs.done)[:, 0])
        info_queue.enqueue_many(actor_infos.read(done_ids))
        actor_infos.reset(done_ids)
        actor_infos.add(actor_ids, (FLAGS.num_action_repeats, 0., 0.))

        # Inference.
        prev_actions = actions.read(actor_ids)
        input_ = encode((prev_actions, env_outputs))
        prev_agent_states = agent_states.read(actor_ids)

        def make_inference_fn(inference_device):
            def device_specific_inference_fn():
                with tf.device(inference_device):

                    @tf.function
                    def agent_inference(*args):
                        return agent(*decode(args), is_training=True)

                    return agent_inference(input_, prev_agent_states)

            return device_specific_inference_fn

        # Distribute the inference calls among the inference cores.
        branch_index = inference_iteration.assign_add(1) % len(
            inference_devices)
        agent_outputs, curr_agent_states = tf.switch_case(
            branch_index, {
                i: make_inference_fn(inference_device)
                for i, inference_device in enumerate(inference_devices)
            })

        # Append the latest outputs to the unroll and insert completed unrolls in
        # queue.
        completed_ids, unrolls = store.append(
            actor_ids, (prev_actions, env_outputs, agent_outputs))
        unrolls = Unroll(first_agent_states.read(completed_ids), *unrolls)
        unroll_queue.enqueue_many(unrolls)
        first_agent_states.replace(completed_ids,
                                   agent_states.read(completed_ids))

        # Update current state.
        agent_states.replace(actor_ids, curr_agent_states)
        actions.replace(actor_ids, agent_outputs.action)

        # Return environment actions to actors.
        return parametric_action_distribution.postprocess(agent_outputs.action)

    with strategy.scope():
        server.bind(inference, batched=True)
    server.start()

    def dequeue(ctx):
        # Create batch (time major).
        actor_outputs = tf.nest.map_structure(
            lambda *args: tf.stack(args), *[
                unroll_queue.dequeue() for i in range(
                    ctx.get_per_replica_batch_size(FLAGS.batch_size))
            ])
        actor_outputs = actor_outputs._replace(
            prev_actions=utils.make_time_major(actor_outputs.prev_actions),
            env_outputs=utils.make_time_major(actor_outputs.env_outputs),
            agent_outputs=utils.make_time_major(actor_outputs.agent_outputs))
        actor_outputs = actor_outputs._replace(
            env_outputs=encode(actor_outputs.env_outputs))
        # tf.data.Dataset treats list leaves as tensors, so we need to flatten and
        # repack.
        return tf.nest.flatten(actor_outputs)

    def dataset_fn(ctx):
        dataset = tf.data.Dataset.from_tensors(0).repeat(None)
        return dataset.map(lambda _: dequeue(ctx),
                           num_parallel_calls=ctx.num_replicas_in_sync)

    dataset = training_strategy.experimental_distribute_datasets_from_function(
        dataset_fn)
    it = iter(dataset)

    # Execute learning and track performance.
    with summary_writer.as_default(), \
      concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
        log_future = executor.submit(lambda: None)  # No-op future.
        last_num_env_frames = iterations * iter_frame_ratio
        last_log_time = time.time()
        values_to_log = collections.defaultdict(lambda: [])
        while iterations < final_iteration:
            num_env_frames = iterations * iter_frame_ratio
            tf.summary.experimental.set_step(num_env_frames)

            # Save checkpoint.
            current_time = time.time()
            if current_time - last_ckpt_time >= FLAGS.save_checkpoint_secs:
                manager.save()
                # Apart from checkpointing, we also save the full model (including
                # the graph). This way we can load it even after the code or
                # parameters change.
                tf.saved_model.save(agent,
                                    os.path.join(FLAGS.logdir, 'saved_model'))
                last_ckpt_time = current_time

            def log(iterations, num_env_frames):
                """Logs batch and episodes summaries."""
                nonlocal last_num_env_frames, last_log_time
                summary_writer.set_as_default()
                tf.summary.experimental.set_step(num_env_frames)

                # log data from the current minibatch
                if iterations % FLAGS.log_batch_frequency == 0:
                    for key, values in copy.deepcopy(values_to_log).items():
                        tf.summary.scalar(key, tf.reduce_mean(values))
                    values_to_log.clear()
                    tf.summary.scalar('learning_rate',
                                      learning_rate_fn(iterations))

                # log the number of frames per second
                dt = time.time() - last_log_time
                if dt > 60:
                    df = tf.cast(num_env_frames - last_num_env_frames,
                                 tf.float32)
                    tf.summary.scalar('num_environment_frames/sec', df / dt)
                    fps_log.logger.info('FPS: %f', df / dt)
                    print(f'FPS: {df / dt}')

                    last_num_env_frames, last_log_time = num_env_frames, time.time()

                # log data from info_queue
                n_episodes = info_queue.size()
                n_episodes -= n_episodes % FLAGS.log_episode_frequency
                if tf.not_equal(n_episodes, 0):
                    episode_stats = info_queue.dequeue_many(n_episodes)
                    episode_keys = [
                        'episode_num_frames', 'episode_return',
                        'episode_raw_return'
                    ]
                    for key, values in zip(episode_keys, episode_stats):
                        for value in tf.split(
                                values, values.shape[0] //
                                FLAGS.log_episode_frequency):
                            tf.summary.scalar(key, tf.reduce_mean(value))

                    for (frames, ep_return, raw_return) in zip(*episode_stats):
                        logging.info('Return: %f Raw return: %f Frames: %i',
                                     ep_return, raw_return, frames)

            logs = minimize(it)

            for per_replica_logs in logs:
                assert len(log_keys) == len(per_replica_logs)
                for key, value in zip(log_keys, per_replica_logs):
                    values_to_log[key].extend(
                        x.numpy()
                        for x in training_strategy.experimental_local_results(
                            value))

            log_future.result()  # Raise exception if any occurred in logging.
            log_future = executor.submit(log, iterations, num_env_frames)

    manager.save()
    tf.saved_model.save(agent, os.path.join(FLAGS.logdir, 'saved_model'))
    server.shutdown()
    unroll_queue.close()
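
The `Unroll` structure used by this learner is not shown. A sketch of what its usage implies (built from the agent state plus the store's unroll specs, and later accessed via `prev_actions`, `env_outputs` and `agent_outputs`), assuming a plain namedtuple:

import collections

# Inferred from Unroll(agent_state_specs, *store.unroll_specs) and the
# ._replace(prev_actions=..., env_outputs=..., agent_outputs=...) calls above.
Unroll = collections.namedtuple(
    'Unroll', 'agent_state prev_actions env_outputs agent_outputs')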
Example #8
def learner_loop(create_env_fn, create_agent_fn, create_optimizer_fn):
    """Main learner loop.

  Args:
    create_env_fn: Callable that must return a newly created environment. The
      callable takes the task ID as argument - an arbitrary task ID of 0 will be
      passed by the learner. The returned environment should follow Gym's API.
      It is only used for inferring tensor shapes. This environment will not be
      used to generate experience.
    create_agent_fn: Function that must create a new tf.Module with the neural
      network that outputs actions, Q values and new agent state given the
      environment observations and previous agent state. See
      atari.agents.DuelingLSTMDQNNet for an example. The factory function takes
      as input the environment output specs and the number of possible actions
      in the env.
    create_optimizer_fn: Function that takes the final iteration as argument
      and must return a tf.keras.optimizers.Optimizer and a
      tf.keras.optimizers.schedules.LearningRateSchedule.
  """
    logging.info('Starting learner loop')
    validate_config()
    settings = utils.init_learner(FLAGS.num_training_tpus)
    strategy, inference_devices, training_strategy, encode, decode = settings
    env = create_env_fn(0)
    env_output_specs = utils.EnvOutput(
        tf.TensorSpec([], tf.float32, 'reward'),
        tf.TensorSpec([], tf.bool, 'done'),
        tf.TensorSpec(env.observation_space.shape, env.observation_space.dtype,
                      'observation'),
        tf.TensorSpec([], tf.bool, 'abandoned'),
        tf.TensorSpec([], tf.int32, 'episode_step'),
    )
    action_specs = tf.TensorSpec([], tf.int32, 'action')
    num_actions = env.action_space.n
    agent_input_specs = (action_specs, env_output_specs)

    # Initialize agent and variables.
    agent = create_agent_fn(env_output_specs, num_actions)
    target_agent = create_agent_fn(env_output_specs, num_actions)
    initial_agent_state = agent.initial_state(1)
    agent_state_specs = tf.nest.map_structure(
        lambda t: tf.TensorSpec(t.shape[1:], t.dtype), initial_agent_state)
    input_ = tf.nest.map_structure(
        lambda s: tf.zeros([1] + list(s.shape), s.dtype), agent_input_specs)
    input_ = encode(input_)

    with strategy.scope():

        @tf.function
        def create_variables(*args):
            return agent(*decode(args))

        @tf.function
        def create_target_agent_variables(*args):
            return target_agent(*decode(args))

        # The first call to the Keras models creates variables for the agent
        # and target networks.
        initial_agent_output, _ = create_variables(input_, initial_agent_state)
        create_target_agent_variables(input_, initial_agent_state)

        @tf.function
        def update_target_agent():
            """Synchronizes training and target agent variables."""
            variables = agent.trainable_variables
            target_variables = target_agent.trainable_variables
            assert len(target_variables) == len(variables), (
                'Mismatch in number of net tensors: {} != {}'.format(
                    len(target_variables), len(variables)))
            for target_var, source_var in zip(target_variables, variables):
                target_var.assign(source_var)

        # Create optimizer.
        iter_frame_ratio = (get_replay_insertion_batch_size() *
                            FLAGS.unroll_length * FLAGS.num_action_repeats)
        final_iteration = int(
            math.ceil(FLAGS.total_environment_frames / iter_frame_ratio))
        optimizer, learning_rate_fn = create_optimizer_fn(final_iteration)

        iterations = optimizer.iterations
        optimizer._create_hypers()
        optimizer._create_slots(agent.trainable_variables)

        # ON_READ causes the replicated variable to act as independent variables for
        # each replica.
        temp_grads = [
            tf.Variable(tf.zeros_like(v),
                        trainable=False,
                        synchronization=tf.VariableSynchronization.ON_READ)
            for v in agent.trainable_variables
        ]

    @tf.function
    def minimize(iterator):
        """Computes and applies gradients.

    Args:
      iterator: An iterator of distributed dataset that produces `PerReplica`.

    Returns:
      A tuple:
        - priorities, the new priorities. Shape <float32>[batch_size].
        - indices, the indices for updating priorities. Shape
        <int32>[batch_size].
        - gradient_norm_before_clip, a scalar.
    """
        data = next(iterator)

        def compute_gradients(args):
            """A function to pass to `Strategy` for gradient computation."""
            args = decode(args, data)
            args = tf.nest.pack_sequence_as(SampledUnrolls(unroll_specs, 0, 0),
                                            args)
            with tf.GradientTape() as tape:
                # loss: [batch_size]
                # priorities: [batch_size]
                loss, priorities = compute_loss_and_priorities(
                    agent,
                    target_agent,
                    args.unrolls.agent_state,
                    args.unrolls.prev_actions,
                    args.unrolls.env_outputs,
                    args.unrolls.agent_outputs,
                    gamma=FLAGS.discounting,
                    burn_in=FLAGS.burn_in)
                loss = tf.reduce_mean(loss * args.importance_weights)
            grads = tape.gradient(loss, agent.trainable_variables)
            gradient_norm_before_clip = tf.linalg.global_norm(grads)
            if FLAGS.clip_norm:
                grads, _ = tf.clip_by_global_norm(
                    grads, FLAGS.clip_norm, use_norm=gradient_norm_before_clip)

            for t, g in zip(temp_grads, grads):
                t.assign(g)

            return loss, priorities, args.indices, gradient_norm_before_clip

        loss, priorities, indices, gradient_norm_before_clip = (
            training_strategy.run(compute_gradients, (data, )))
        loss = training_strategy.experimental_local_results(loss)[0]

        def apply_gradients(loss):
            optimizer.apply_gradients(
                zip(temp_grads, agent.trainable_variables))
            return loss

        loss = strategy.run(apply_gradients, (loss, ))

        # convert PerReplica to a Tensor
        if not isinstance(priorities, tf.Tensor):
            priorities = tf.reshape(tf.stack(priorities.values), [-1])
            indices = tf.reshape(tf.stack(indices.values), [-1])
            gradient_norm_before_clip = tf.reshape(
                tf.stack(gradient_norm_before_clip.values), [-1])
            gradient_norm_before_clip = tf.reduce_max(
                gradient_norm_before_clip)

        return loss, priorities, indices, gradient_norm_before_clip

    agent_output_specs = tf.nest.map_structure(
        lambda t: tf.TensorSpec(t.shape[1:], t.dtype), initial_agent_output)
    # Logging.
    summary_writer = tf.summary.create_file_writer(FLAGS.logdir,
                                                   flush_millis=20000,
                                                   max_queue=1000)

    # Setup checkpointing and restore checkpoint.

    ckpt = tf.train.Checkpoint(agent=agent,
                               target_agent=target_agent,
                               optimizer=optimizer)
    manager = tf.train.CheckpointManager(ckpt,
                                         FLAGS.logdir,
                                         max_to_keep=1,
                                         keep_checkpoint_every_n_hours=6)
    last_ckpt_time = 0  # Force checkpointing of the initial model.
    if manager.latest_checkpoint:
        logging.info('Restoring checkpoint: %s', manager.latest_checkpoint)
        ckpt.restore(manager.latest_checkpoint).assert_consumed()
        last_ckpt_time = time.time()

    server = grpc.Server([FLAGS.server_address])

    # Buffer of incomplete unrolls. Filled during inference with new transitions.
    # This only contains data from training environments.
    store = utils.UnrollStore(
        get_num_training_envs(),
        FLAGS.unroll_length,
        (action_specs, env_output_specs, agent_output_specs),
        num_overlapping_steps=FLAGS.burn_in)
    env_run_ids = utils.Aggregator(FLAGS.num_envs,
                                   tf.TensorSpec([], tf.int64, 'run_ids'))
    info_specs = (
        tf.TensorSpec([], tf.int64, 'episode_num_frames'),
        tf.TensorSpec([], tf.float32, 'episode_returns'),
        tf.TensorSpec([], tf.float32, 'episode_raw_returns'),
    )
    env_infos = utils.Aggregator(FLAGS.num_envs, info_specs, 'env_infos')

    # First agent state in an unroll.
    first_agent_states = utils.Aggregator(FLAGS.num_envs, agent_state_specs,
                                          'first_agent_states')

    # Current agent state and action.
    agent_states = utils.Aggregator(FLAGS.num_envs, agent_state_specs,
                                    'agent_states')
    actions = utils.Aggregator(FLAGS.num_envs, action_specs, 'actions')

    unroll_specs = Unroll(agent_state_specs,
                          tf.TensorSpec([], tf.float32, 'priority'),
                          *store.unroll_specs)
    # Queue of complete unrolls. Filled by the inference threads, and consumed by
    # the tf.data.Dataset thread.
    unroll_queue = utils.StructuredFIFOQueue(FLAGS.unroll_queue_max_size,
                                             unroll_specs)
    episode_info_specs = EpisodeInfo(
        *(info_specs + (tf.TensorSpec([], tf.int32, 'env_ids'), )))
    info_queue = utils.StructuredFIFOQueue(-1, episode_info_specs)

    replay_buffer = utils.PrioritizedReplay(FLAGS.replay_buffer_size,
                                            unroll_specs,
                                            FLAGS.importance_sampling_exponent)

    def add_batch_size(ts):
        return tf.TensorSpec([FLAGS.inference_batch_size] + list(ts.shape),
                             ts.dtype, ts.name)

    inference_iteration = tf.Variable(-1, dtype=tf.int64)
    inference_specs = (
        tf.TensorSpec([], tf.int32, 'env_id'),
        tf.TensorSpec([], tf.int64, 'run_id'),
        env_output_specs,
        tf.TensorSpec([], tf.float32, 'raw_reward'),
    )
    inference_specs = tf.nest.map_structure(add_batch_size, inference_specs)

    @tf.function(input_signature=inference_specs)
    def inference(env_ids, run_ids, env_outputs, raw_rewards):
        """Agent inference.

    This evaluates the agent policy on the provided environment data (reward,
    done, observation), and stores appropriate data to feed the main training
    loop.

    Args:
      env_ids: <int32>[inference_batch_size], the environment task IDs (in range
        [0, num_tasks)).
      run_ids: <int64>[inference_batch_size], the environment run IDs.
        Environment generates a random int64 run id at startup, so this can be
        used to detect the environment jobs that restarted.
      env_outputs: Follows env_output_specs, but with the inference_batch_size
        added as first dimension. These are the actual environment outputs
        (reward, done, observation).
      raw_rewards: <float32>[inference_batch_size], representing the raw reward
        of each step.

    Returns:
      A tensor <int32>[inference_batch_size] with one action for each
        environment.
    """
        # Reset the environments that had their first run or crashed.
        previous_run_ids = env_run_ids.read(env_ids)
        env_run_ids.replace(env_ids, run_ids)
        reset_indices = tf.where(tf.not_equal(previous_run_ids, run_ids))[:, 0]
        envs_needing_reset = tf.gather(env_ids, reset_indices)
        if tf.not_equal(tf.shape(envs_needing_reset)[0], 0):
            tf.print('Environments needing reset:', envs_needing_reset)
        env_infos.reset(envs_needing_reset)
        store.reset(
            tf.gather(envs_needing_reset,
                      tf.where(is_training_env(envs_needing_reset))[:, 0]))
        initial_agent_states = agent.initial_state(
            tf.shape(envs_needing_reset)[0])
        first_agent_states.replace(envs_needing_reset, initial_agent_states)
        agent_states.replace(envs_needing_reset, initial_agent_states)
        actions.reset(envs_needing_reset)

        tf.debugging.assert_non_positive(
            tf.cast(env_outputs.abandoned, tf.int32),
            'Abandoned done states are not supported in R2D2.')

        # Update steps and return.
        env_infos.add(env_ids, (0, env_outputs.reward, raw_rewards))
        done_ids = tf.gather(env_ids, tf.where(env_outputs.done)[:, 0])
        done_episodes_info = env_infos.read(done_ids)
        info_queue.enqueue_many(
            EpisodeInfo(*(done_episodes_info + (done_ids, ))))
        env_infos.reset(done_ids)
        env_infos.add(env_ids, (FLAGS.num_action_repeats, 0., 0.))

        # Inference.
        prev_actions = actions.read(env_ids)
        input_ = encode((prev_actions, env_outputs))
        prev_agent_states = agent_states.read(env_ids)

        def make_inference_fn(inference_device):
            def device_specific_inference_fn():
                with tf.device(inference_device):

                    @tf.function
                    def agent_inference(*args):
                        return agent(*decode(args))

                    return agent_inference(input_, prev_agent_states)

            return device_specific_inference_fn

        # Distribute the inference calls among the inference cores.
        branch_index = tf.cast(
            inference_iteration.assign_add(1) % len(inference_devices),
            tf.int32)
        agent_outputs, curr_agent_states = tf.switch_case(
            branch_index, {
                i: make_inference_fn(inference_device)
                for i, inference_device in enumerate(inference_devices)
            })

        agent_outputs = agent_outputs._replace(action=apply_epsilon_greedy(
            agent_outputs.action, env_ids, get_num_training_envs(),
            FLAGS.num_eval_envs, FLAGS.eval_epsilon, num_actions))

        # Append the latest outputs to the unroll, only for experience coming from
        # training environments (IDs < num_training_envs), and insert completed
        # unrolls in queue.
        # <int64>[num_training_envs]
        training_indices = tf.where(is_training_env(env_ids))[:, 0]
        training_env_ids = tf.gather(env_ids, training_indices)
        training_prev_actions, training_env_outputs, training_agent_outputs = (
            tf.nest.map_structure(lambda s: tf.gather(s, training_indices),
                                  (prev_actions, env_outputs, agent_outputs)))

        append_to_store = (training_prev_actions, training_env_outputs,
                           training_agent_outputs)
        completed_ids, completed_unrolls = store.append(
            training_env_ids, append_to_store)
        _, unrolled_env_outputs, unrolled_agent_outputs = completed_unrolls
        unrolled_agent_states = first_agent_states.read(completed_ids)

        # Only use the suffix of the unrolls that is actually used for training. The
        # prefix is only used for burn-in of agent state at training time.
        _, agent_outputs_suffix = utils.split_structure(
            utils.make_time_major(unrolled_agent_outputs), FLAGS.burn_in)
        _, env_outputs_suffix = utils.split_structure(
            utils.make_time_major(unrolled_env_outputs), FLAGS.burn_in)
        _, initial_priorities = compute_loss_and_priorities_from_agent_outputs(
            # We don't use the outputs from a separate target network for
            # computing initial priorities.
            agent_outputs_suffix,
            agent_outputs_suffix,
            env_outputs_suffix,
            agent_outputs_suffix,
            gamma=FLAGS.discounting)

        unrolls = Unroll(unrolled_agent_states, initial_priorities,
                         *completed_unrolls)
        unroll_queue.enqueue_many(unrolls)
        first_agent_states.replace(completed_ids,
                                   agent_states.read(completed_ids))

        # Update current state.
        agent_states.replace(env_ids, curr_agent_states)
        actions.replace(env_ids, agent_outputs.action)

        # Return environment actions to environments.
        return agent_outputs.action

    with strategy.scope():
        server.bind(inference)
    server.start()

    # Execute learning and track performance.
    with summary_writer.as_default(), \
      concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
        log_future = executor.submit(lambda: None)  # No-op future.
        tf.summary.experimental.set_step(iterations * iter_frame_ratio)
        dataset = create_dataset(unroll_queue, replay_buffer,
                                 training_strategy, FLAGS.batch_size,
                                 FLAGS.priority_exponent, encode)
        it = iter(dataset)

        last_num_env_frames = iterations * iter_frame_ratio
        last_log_time = time.time()
        max_gradient_norm_before_clip = 0.
        while iterations < final_iteration:
            num_env_frames = iterations * iter_frame_ratio
            tf.summary.experimental.set_step(num_env_frames)

            if iterations.numpy() % FLAGS.update_target_every_n_step == 0:
                update_target_agent()

            # Save checkpoint.
            current_time = time.time()
            if current_time - last_ckpt_time >= FLAGS.save_checkpoint_secs:
                manager.save()
                last_ckpt_time = current_time

            def log(num_env_frames):
                """Logs environment summaries."""
                summary_writer.set_as_default()
                tf.summary.experimental.set_step(num_env_frames)
                episode_info = info_queue.dequeue_many(info_queue.size())
                for n, r, _, env_id in zip(*episode_info):
                    is_training = is_training_env(env_id)
                    logging.info(
                        'Return: %f Frames: %i Env id: %i (%s) Iteration: %i',
                        r, n, env_id, 'training' if is_training else 'eval',
                        iterations.numpy())
                    if not is_training:
                        tf.summary.scalar('eval/episode_return', r)
                        tf.summary.scalar('eval/episode_frames', n)

            log_future.result()  # Raise exception if any occurred in logging.
            log_future = executor.submit(log, num_env_frames)

            _, priorities, indices, gradient_norm = minimize(it)

            replay_buffer.update_priorities(indices, priorities)
            # Max of gradient norms (before clipping) since last tf.summary export.
            max_gradient_norm_before_clip = max(gradient_norm.numpy(),
                                                max_gradient_norm_before_clip)
            if current_time - last_log_time >= 120:
                df = tf.cast(num_env_frames - last_num_env_frames, tf.float32)
                dt = time.time() - last_log_time
                tf.summary.scalar('num_environment_frames/sec (actors)',
                                  df / dt)
                tf.summary.scalar('num_environment_frames/sec (learner)',
                                  df / dt * FLAGS.replay_ratio)

                tf.summary.scalar('learning_rate',
                                  learning_rate_fn(iterations))
                tf.summary.scalar('replay_buffer_num_inserted',
                                  replay_buffer.num_inserted)
                tf.summary.scalar('unroll_queue_size', unroll_queue.size())

                last_num_env_frames, last_log_time = num_env_frames, time.time()
                tf.summary.histogram('updated_priorities', priorities)
                tf.summary.scalar('max_gradient_norm_before_clip',
                                  max_gradient_norm_before_clip)
                max_gradient_norm_before_clip = 0.

    manager.save()
    server.shutdown()
    unroll_queue.close()
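
This R2D2-style learner extends `Unroll` with a per-unroll priority and wraps replayed batches in `SampledUnrolls`. Neither definition is shown; a sketch of both as implied by the field accesses above, assuming plain namedtuples:

import collections

# Priority sits in front of the store's (prev_actions, env_outputs,
# agent_outputs) specs, matching Unroll(agent_state_specs, priority_spec, ...).
Unroll = collections.namedtuple(
    'Unroll', 'agent_state priority prev_actions env_outputs agent_outputs')

# What the replay dataset yields: the unrolls plus importance-sampling
# weights and the replay-buffer indices used to update priorities.
SampledUnrolls = collections.namedtuple(
    'SampledUnrolls', 'unrolls importance_weights indices')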
Example #9
def actor_loop(create_env_fn, config=None, log_period=1):
  """Main actor loop.

  Args:
    create_env_fn: Callable (taking the task ID as argument) that must return a
      newly created environment.
    config: Configuration of the training.
    log_period: How often to log in seconds.
  """
  if not config:
    config = FLAGS
  env_batch_size = FLAGS.env_batch_size
  logging.info('Starting actor loop. Task: %r. Environment batch size: %r',
               FLAGS.task, env_batch_size)
  is_rendering_enabled = FLAGS.render and FLAGS.task == 0
  if are_summaries_enabled():
    summary_writer = tf.summary.create_file_writer(
        os.path.join(FLAGS.logdir, 'actor_{}'.format(FLAGS.task)),
        flush_millis=20000, max_queue=1000)
    timer_cls = profiling.ExportingTimer
  else:
    summary_writer = tf.summary.create_noop_writer()
    timer_cls = utils.nullcontext

  actor_step = 0
  with summary_writer.as_default():
    while True:
      try:
        # Client to communicate with the learner.
        client = grpc.Client(FLAGS.server_address)
        utils.update_config(config, client)
        batched_env = env_wrappers.BatchedEnvironment(
            create_env_fn, env_batch_size, FLAGS.task * env_batch_size, config)

        env_id = batched_env.env_ids
        run_id = np.random.randint(
            low=0,
            high=np.iinfo(np.int64).max,
            size=env_batch_size,
            dtype=np.int64)
        observation = batched_env.reset()
        reward = np.zeros(env_batch_size, np.float32)
        raw_reward = np.zeros(env_batch_size, np.float32)
        done = np.zeros(env_batch_size, np.bool_)
        abandoned = np.zeros(env_batch_size, np.bool_)

        global_step = 0
        episode_step = np.zeros(env_batch_size, np.int32)
        episode_return = np.zeros(env_batch_size, np.float32)
        episode_raw_return = np.zeros(env_batch_size, np.float32)
        episode_step_sum = 0
        episode_return_sum = 0
        episode_raw_return_sum = 0
        episodes_in_report = 0

        elapsed_inference_s_timer = timer_cls('actor/elapsed_inference_s', 1000)
        last_log_time = timeit.default_timer()
        last_global_step = 0
        while True:
          tf.summary.experimental.set_step(actor_step)
          env_output = utils.EnvOutput(reward, done, observation,
                                       abandoned, episode_step)
          with elapsed_inference_s_timer:
            action = client.inference(env_id, run_id, env_output, raw_reward)
          with timer_cls('actor/elapsed_env_step_s', 1000):
            observation, reward, done, info = batched_env.step(action.numpy())
          if is_rendering_enabled:
            batched_env.render()
          for i in range(env_batch_size):
            episode_step[i] += 1
            episode_return[i] += reward[i]
            raw_reward[i] = float((info[i] or {}).get('score_reward',
                                                      reward[i]))
            episode_raw_return[i] += raw_reward[i]
            # If the info dict contains an entry abandoned=True and the
            # episode was ended (done=True), then we need to specially handle
            # the final transition as per the explanations below.
            abandoned[i] = (info[i] or {}).get('abandoned', False)
            assert done[i] if abandoned[i] else True
            if done[i]:
              # If the episode was abandoned, we need to report the final
              # transition including the final observation as if the episode has
              # not terminated yet. This way, learning algorithms can use the
              # transition for learning.
              if abandoned[i]:
                # We do not signal yet that the episode was abandoned. This will
                # happen for the transition from the terminal state to the
                # reset state.
                assert env_batch_size == 1 and i == 0, (
                    'Mixing of batched and non-batched inference calls is not '
                    'yet supported')
                env_output = utils.EnvOutput(reward,
                                             np.array([False]), observation,
                                             np.array([False]), episode_step)
                with elapsed_inference_s_timer:
                  # action is ignored
                  client.inference(env_id, run_id, env_output, raw_reward)
                reward[i] = 0.0
                raw_reward[i] = 0.0

              # Periodically log statistics.
              current_time = timeit.default_timer()
              episode_step_sum += episode_step[i]
              episode_return_sum += episode_return[i]
              episode_raw_return_sum += episode_raw_return[i]
              global_step += episode_step[i]
              episodes_in_report += 1
              if current_time - last_log_time >= log_period:
                logging.info(
                    'Actor steps: %i, Return: %f Raw return: %f '
                    'Episode steps: %f, Speed: %f steps/s', global_step,
                    episode_return_sum / episodes_in_report,
                    episode_raw_return_sum / episodes_in_report,
                    episode_step_sum / episodes_in_report,
                    (global_step - last_global_step) /
                    (current_time - last_log_time))
                last_global_step = global_step
                episode_return_sum = 0
                episode_raw_return_sum = 0
                episode_step_sum = 0
                episodes_in_report = 0
                last_log_time = current_time

              episode_step[i] = 0
              episode_return[i] = 0
              episode_raw_return[i] = 0

          # Finally, we reset the episode which will report the transition
          # from the terminal state to the reset state in the next loop
          # iteration (with zero rewards).
          with timer_cls('actor/elapsed_env_reset_s', 10):
            observation = batched_env.reset_if_done(done)

          if is_rendering_enabled and done[0]:
            batched_env.render()

          actor_step += 1
      except (tf.errors.UnavailableError, tf.errors.CancelledError):
        logging.info('Inference call failed. This is normal at the end of '
                     'training.')
        batched_env.close()
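
`BatchedEnvironment.reset_if_done` is used above but not shown. A rough sketch of the behaviour the loop relies on, written as a free function over a hypothetical list of sub-environments (the real wrapper also tracks env IDs and returns batched observations):

import numpy as np

def reset_if_done(envs, observations, done):
    # Reset only the sub-environments whose episode just ended; keep the last
    # observation for the others, and return the batch as a stacked array.
    new_obs = [env.reset() if d else obs
               for env, obs, d in zip(envs, observations, done)]
    return np.stack(new_obs)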
Example #10
def visualize(create_env_fn, create_agent_fn, create_optimizer_fn):
  print('Visualization launched...')

  settings = utils.init_learner_multi_host(1)
  strategy, hosts, training_strategy, encode, decode = settings

  env = create_env_fn(0)
  parametric_action_distribution = get_parametric_distribution_for_action_space(
    env.action_space)
  agent = create_agent_fn(env.action_space, env.observation_space,
                          parametric_action_distribution)
  optimizer, learning_rate_fn = create_optimizer_fn(1e9)

  env_output_specs = utils.EnvOutput(
    tf.TensorSpec([], tf.float32, 'reward'),
    tf.TensorSpec([], tf.bool, 'done'),
    tf.TensorSpec(env.observation_space.shape, env.observation_space.dtype,
                  'observation'),
    tf.TensorSpec([], tf.bool, 'abandoned'),
    tf.TensorSpec([], tf.int32, 'episode_step'),
  )
  action_specs = tf.TensorSpec(env.action_space.shape,
                               env.action_space.dtype, 'action')
  agent_input_specs = (action_specs, env_output_specs)
  # Initialize agent and variables.
  agent = create_agent_fn(env.action_space, env.observation_space,
                          parametric_action_distribution)
  initial_agent_state = agent.initial_state(1)
  agent_state_specs = tf.nest.map_structure(
    lambda t: tf.TensorSpec(t.shape[1:], t.dtype), initial_agent_state)
  unroll_specs = [None]  # Lazy initialization.
  input_ = tf.nest.map_structure(
    lambda s: tf.zeros([1] + list(s.shape), s.dtype), agent_input_specs)
  input_ = encode(input_)

  with strategy.scope():
    @tf.function
    def create_variables(*args):
      return agent.get_action(*decode(args))

    initial_agent_output, _ = create_variables(*input_, initial_agent_state)

    if not hasattr(agent, 'entropy_cost'):
      mul = FLAGS.entropy_cost_adjustment_speed
      agent.entropy_cost_param = tf.Variable(
          tf.math.log(FLAGS.entropy_cost) / mul,
          # Without the constraint, the param gradient may get rounded to 0
          # for very small values.
          constraint=lambda v: tf.clip_by_value(v, -20 / mul, 20 / mul),
          trainable=True,
          dtype=tf.float32)
      agent.entropy_cost = lambda: tf.exp(mul * agent.entropy_cost_param)
    # Create optimizer.
    iter_frame_ratio = (
        FLAGS.batch_size * FLAGS.unroll_length * FLAGS.num_action_repeats)
    final_iteration = int(
        math.ceil(FLAGS.total_environment_frames / iter_frame_ratio))
    optimizer, learning_rate_fn = create_optimizer_fn(final_iteration)


    iterations = optimizer.iterations
    optimizer._create_hypers()
    optimizer._create_slots(agent.trainable_variables)

    # ON_READ causes the replicated variable to act as independent variables for
    # each replica.
    temp_grads = [
        tf.Variable(tf.zeros_like(v), trainable=False,
                    synchronization=tf.VariableSynchronization.ON_READ)
        for v in agent.trainable_variables
    ]

  agent_output_specs = tf.nest.map_structure(
    lambda t: tf.TensorSpec(t.shape[1:], t.dtype), initial_agent_output)

  if True:
    ckpt = tf.train.Checkpoint(agent=agent, optimizer=optimizer)
    ckpt.restore('seed_rl/checkpoints/agent_good_3m/ckpt-9').assert_consumed()

  def get_agent_action(obs):
    initial_agent_state = agent.initial_state(1)
    shaped_obs = tf.reshape(tf.convert_to_tensor(obs),
                            shape=(1,) + env.observation_space.shape)
    initial_env_output = utils.EnvOutput(
        tf.constant([1.]), tf.constant([False]), shaped_obs,
        tf.constant([False]), tf.constant([1], dtype=tf.int32))
    agent_out = agent(tf.zeros([0], dtype=tf.float32), initial_env_output,
                      initial_agent_state)
    return agent_out

  def run_episode(steps):
    mode = None
    obs = env.reset()
    rewards = []

    for _ in range(steps):
      agent_out, state = get_agent_action(obs)
      action = agent_out.action.numpy()[0]
      obs, rew, done, info = env.step(action)
      rewards.append(rew)

      if done:
        break

    reward = np.sum(rewards)
    print('reward: {0}'.format(reward))
    return reward

  all_rewards = []
  iter = 0

  while True:
    all_rewards.append(run_episode(250))
    if len(all_rewards) > 1000:
      all_rewards = all_rewards[-1000:]
    print('mean cum reward: {0}'.format(np.mean(all_rewards)))

    if iter % 10 == 0:
      env.save_replay()
      print('\n REPLAY SAVED\n')
    iter += 1

  print('Graceful termination')
  sys.exit(0)
Example #11
    def test_ppo_training_step(self, batch_mode, use_agent_state):
        action_space = gym.spaces.Box(low=-1,
                                      high=1,
                                      shape=[128],
                                      dtype=np.float32)
        distribution = (
            parametric_distribution.
            get_parametric_distribution_for_action_space(action_space))
        training_agent = continuous_control_agent.ContinuousControlAgent(
            distribution)
        virtual_bs = 32
        unroll_length = 5
        batches_per_step = 4
        done = tf.zeros([unroll_length, virtual_bs], dtype=tf.bool)
        prev_actions = tf.reshape(
            tf.stack([
                action_space.sample()
                for _ in range(unroll_length * virtual_bs)
            ]), [unroll_length, virtual_bs, -1])
        env_outputs = utils.EnvOutput(
            reward=tf.random.uniform([unroll_length, virtual_bs]),
            done=done,
            observation=tf.zeros([unroll_length, virtual_bs, 128],
                                 dtype=tf.float32),
            abandoned=tf.zeros_like(done),
            episode_step=tf.ones([unroll_length, virtual_bs], dtype=tf.int32))
        if use_agent_state:
            core_state = tf.zeros([virtual_bs, 64])
        else:
            core_state = training_agent.initial_state(virtual_bs)
        agent_outputs, _ = training_agent((prev_actions, env_outputs),
                                          core_state,
                                          unroll=True)
        args = Unroll(core_state, prev_actions, env_outputs, agent_outputs)

        class DummyStrategy:
            def __init__(self):
                self.num_replicas_in_sync = 1

        loss_fn = generalized_onpolicy_loss.GeneralizedOnPolicyLoss(
            training_agent,
            popart.PopArt(running_statistics.FixedMeanStd(), compensate=False),
            distribution,
            ga_advantages.GAE(lambda_=0.9),
            policy_losses.ppo(0.9),
            discount_factor=0.99,
            regularizer=policy_regularizers.KLPolicyRegularizer(entropy=0.5),
            baseline_cost=0.5,
            max_abs_reward=None,
            frame_skip=1,
            reward_scaling=10)
        loss_fn.init()
        loss, logs = ppo_training_step_utils.ppo_training_step(
            epochs_per_step=8,
            loss_fn=loss_fn,
            args=args,
            batch_mode=batch_mode,
            training_strategy=DummyStrategy(),
            virtual_batch_size=virtual_bs,
            unroll_length=unroll_length - 1,
            batches_per_step=batches_per_step,
            clip_norm=50.,
            optimizer=tf.keras.optimizers.Adam(1e-3),
            logger=utils.ProgressLogger())
        del loss
        del logs
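
For orientation, the parameters in this test split the virtual batch of 32
unrolls into 4 minibatches of 8, and with 8 epochs per step that presumably
amounts to 32 optimizer updates per call. The helper below is a hypothetical
illustration of that arithmetic only; it is not part of ppo_training_step_utils.

def minibatch_layout(virtual_batch_size, batches_per_step, epochs_per_step):
    """Returns (unrolls per minibatch, assumed optimizer updates per step)."""
    assert virtual_batch_size % batches_per_step == 0
    unrolls_per_minibatch = virtual_batch_size // batches_per_step  # 32 // 4 = 8
    updates_per_step = epochs_per_step * batches_per_step  # 8 * 4 = 32
    return unrolls_per_minibatch, updates_per_step


print(minibatch_layout(32, 4, 8))  # (8, 32)
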
Exemplo n.º 12
0
def actor_loop(create_env_fn):
    """Main actor loop.

  Args:
    create_env_fn: Callable (taking the task ID as argument) that must return a
      newly created environment.
  """

    project = neptune.init('pmtest/marl-vtrace')
    experiment = DummyExperiment()

    if FLAGS.task == 0 and not FLAGS.is_local:
        # First actor logs winning rate.
        while True:
            time.sleep(5)
            experiments = project.get_experiments(tag=FLAGS.nonce)
            if len(experiments) == 0:
                logging.info('Experiment not found, retry...')
            else:
                experiment = experiments[-1]
                break

    log_period = 5
    log_period_growth = 1.05
    log_period_max = 600

    last_replay_time = timeit.default_timer()
    replay_period = 600
    replay_period_growth = 1.2
    replay_period_max = 3600

    env_batch_size = FLAGS.env_batch_size
    logging.info('Starting actor loop. Task: %r. Environment batch size: %r',
                 FLAGS.task, env_batch_size)
    is_rendering_enabled = FLAGS.render and FLAGS.task == 0
    if are_summaries_enabled():
        summary_writer = tf.summary.create_file_writer(os.path.join(
            FLAGS.logdir, 'actor_{}'.format(FLAGS.task)),
                                                       flush_millis=20000,
                                                       max_queue=1000)
        timer_cls = profiling.ExportingTimer
    else:
        summary_writer = tf.summary.create_noop_writer()
        timer_cls = utils.nullcontext

    actor_step = 0
    with summary_writer.as_default():
        while True:
            try:
                # Client to communicate with the learner.
                client = grpc.Client(FLAGS.server_address)

                batched_env = env_wrappers.BatchedEnvironment(
                    create_env_fn, env_batch_size, FLAGS.task * env_batch_size)

                env_id = batched_env.env_ids
                run_id = np.random.randint(low=0,
                                           high=np.iinfo(np.int64).max,
                                           size=env_batch_size,
                                           dtype=np.int64)
                observation = batched_env.reset()
                reward = np.zeros(env_batch_size, np.float32)
                raw_reward = np.zeros(env_batch_size, np.float32)
                done = np.zeros(env_batch_size, bool)
                abandoned = np.zeros(env_batch_size, bool)

                global_step = 0
                episode_step = np.zeros(env_batch_size, np.int32)
                episode_return = np.zeros(env_batch_size, np.float32)
                episode_raw_return = np.zeros(env_batch_size, np.float32)
                episode_step_sum = 0
                episode_return_sum = 0
                episode_raw_return_sum = 0
                episode_won = 0
                episodes_in_report = 0

                elapsed_inference_s_timer = timer_cls(
                    'actor/elapsed_inference_s', 1000)
                last_log_time = timeit.default_timer()
                last_global_step = 0
                while True:
                    tf.summary.experimental.set_step(actor_step)
                    env_output = utils.EnvOutput(reward, done, observation,
                                                 abandoned, episode_step)
                    with elapsed_inference_s_timer:
                        action = client.inference(env_id, run_id, env_output,
                                                  raw_reward)
                    with timer_cls('actor/elapsed_env_step_s', 1000):
                        observation, reward, done, info = batched_env.step(
                            action.numpy())
                    if is_rendering_enabled:
                        batched_env.render()
                    for i in range(env_batch_size):
                        episode_step[i] += 1
                        episode_return[i] += reward[i]
                        raw_reward[i] = float(
                            (info[i] or {}).get('score_reward', reward[i]))
                        episode_raw_return[i] += raw_reward[i]
                        # If the info dict contains an entry abandoned=True and the
                        # episode was ended (done=True), then we need to specially handle
                        # the final transition as per the explanations below.
                        abandoned[i] = (info[i] or {}).get('abandoned', False)
                        assert not abandoned[i] or done[i]
                        if done[i]:
                            # If the episode was abandoned, we need to report the final
                            # transition including the final observation as if the episode has
                            # not terminated yet. This way, learning algorithms can use the
                            # transition for learning.
                            if abandoned[i]:
                                # We do not signal yet that the episode was abandoned. This will
                                # happen for the transition from the terminal state to the
                                # reset state.
                                assert env_batch_size == 1 and i == 0, (
                                    'Mixing of batched and non-batched inference calls is not '
                                    'yet supported')
                                env_output = utils.EnvOutput(
                                    reward, np.array([False]), observation,
                                    np.array([False]), episode_step)
                                with elapsed_inference_s_timer:
                                    # action is ignored
                                    client.inference(env_id, run_id,
                                                     env_output, raw_reward)
                                reward[i] = 0.0
                                raw_reward[i] = 0.0

                            # Periodically log statistics.
                            current_time = timeit.default_timer()
                            episode_step_sum += episode_step[i]
                            episode_return_sum += episode_return[i]
                            episode_raw_return_sum += episode_raw_return[i]
                            global_step += episode_step[i]
                            episode_won += (info[i]
                                            or {}).get('battle_won', False)
                            episodes_in_report += 1

                            if (FLAGS.task == 0 and
                                    current_time - last_replay_time > replay_period):
                                replay_period = min(
                                    replay_period_max,
                                    replay_period * replay_period_growth)
                                last_replay_time = current_time
                                batched_env.envs[0].save_replay()

                            if current_time - last_log_time > log_period:
                                log_period = min(
                                    log_period_max,
                                    log_period * log_period_growth)
                                logging.info(
                                    'Actor steps: %i, Return: %f Raw return: %f '
                                    'Episode steps: %f, Speed: %f steps/s, Won: %.2f',
                                    global_step,
                                    episode_return_sum / episodes_in_report,
                                    episode_raw_return_sum /
                                    episodes_in_report,
                                    episode_step_sum / episodes_in_report,
                                    (global_step - last_global_step) /
                                    (current_time - last_log_time),
                                    episode_won / episodes_in_report)
                                tf.summary.scalar('episodes win rate',
                                                  episode_won /
                                                  episodes_in_report,
                                                  step=global_step)
                                if FLAGS.task == 0:
                                    experiment.log_metric(
                                        log_name='episode win rate',
                                        x=global_step,
                                        y=episode_won / episodes_in_report)

                                last_global_step = global_step
                                episode_return_sum = 0
                                episode_raw_return_sum = 0
                                episode_step_sum = 0
                                episode_won = 0
                                episodes_in_report = 0
                                last_log_time = current_time

                            episode_step[i] = 0
                            episode_return[i] = 0
                            episode_raw_return[i] = 0

                    # Finally, we reset the episode which will report the transition
                    # from the terminal state to the reset state in the next loop
                    # iteration (with zero rewards).
                    with timer_cls('actor/elapsed_env_reset_s', 10):
                        observation = batched_env.reset_if_done(done)

                    if is_rendering_enabled and done[0]:
                        batched_env.render()

                    actor_step += 1
            except (tf.errors.UnavailableError, tf.errors.CancelledError) as e:
                logging.exception(e)
                batched_env.close()
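
DummyExperiment above is defined elsewhere in this project; judging from its
single call site (experiment.log_metric(...)), a no-op stand-in would only need
the method sketched below. This is an assumption about the interface, not the
project's actual class.

class DummyExperiment:
    """No-op placeholder used until a real Neptune experiment is resolved."""

    def log_metric(self, log_name, x, y):
        # Intentionally discards the metric (assumed interface); the real
        # experiment forwards it to Neptune.
        del log_name, x, y
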
Exemplo n.º 13
0
def learner_loop(create_env_fn, create_agent_fn, create_optimizer_fn):
    """Main learner loop.

  Args:
    create_env_fn: Callable that must return a newly created environment. The
      callable takes the task ID as argument - an arbitrary task ID of 0 will be
      passed by the learner. The returned environment should follow GYM's API.
      It is only used for inferring tensor shapes. This environment will not be
      used to generate experience.
    create_agent_fn: Function that must create a new tf.Module with the neural
      network that outputs actions and new agent state given the environment
      observations and previous agent state. See dmlab.agents.ImpalaDeep for an
      example. The factory function takes as input the environment action and
      observation spaces and a parametric distribution over actions.
    create_optimizer_fn: Function that takes the final iteration as argument
      and must return a tf.keras.optimizers.Optimizer and a
      tf.keras.optimizers.schedules.LearningRateSchedule.
  """
    logging.info('Starting learner loop')
    validate_config()
    settings = utils.init_learner_multi_host(FLAGS.num_training_tpus)
    strategy, hosts, training_strategy, encode, decode = settings
    env = create_env_fn(0)
    parametric_action_distribution = get_parametric_distribution_for_action_space(
        env.action_space)
    env_output_specs = utils.EnvOutput(
        tf.TensorSpec([], tf.float32, 'reward'),
        tf.TensorSpec([], tf.bool, 'done'),
        tf.TensorSpec(env.observation_space.shape, env.observation_space.dtype,
                      'observation'),
        tf.TensorSpec([], tf.bool, 'abandoned'),
        tf.TensorSpec([], tf.int32, 'episode_step'),
    )
    action_specs = tf.TensorSpec(env.action_space.shape,
                                 env.action_space.dtype, 'action')
    agent_input_specs = (action_specs, env_output_specs)

    # Initialize agent and variables.
    agent = create_agent_fn(env.action_space, env.observation_space,
                            parametric_action_distribution)
    initial_agent_state = agent.initial_state(1)
    agent_state_specs = tf.nest.map_structure(
        lambda t: tf.TensorSpec(t.shape[1:], t.dtype), initial_agent_state)
    unroll_specs = [None]  # Lazy initialization.
    input_ = tf.nest.map_structure(
        lambda s: tf.zeros([1] + list(s.shape), s.dtype), agent_input_specs)
    input_ = encode(input_)

    with strategy.scope():

        @tf.function
        def create_variables(*args):
            return agent.get_action(*decode(args))

        initial_agent_output, _ = create_variables(*input_,
                                                   initial_agent_state)

        if not hasattr(agent, 'entropy_cost'):
            mul = FLAGS.entropy_cost_adjustment_speed
            agent.entropy_cost_param = tf.Variable(
                tf.math.log(FLAGS.entropy_cost) / mul,
                # Without the constraint, the param gradient may get rounded to 0
                # for very small values.
                constraint=lambda v: tf.clip_by_value(v, -20 / mul, 20 / mul),
                trainable=True,
                dtype=tf.float32)
            agent.entropy_cost = lambda: tf.exp(mul * agent.entropy_cost_param)
        # Create optimizer.
        iter_frame_ratio = (FLAGS.batch_size * FLAGS.unroll_length *
                            FLAGS.num_action_repeats)
        final_iteration = int(
            math.ceil(FLAGS.total_environment_frames / iter_frame_ratio))
        optimizer, learning_rate_fn = create_optimizer_fn(final_iteration)

        iterations = optimizer.iterations
        optimizer._create_hypers()
        optimizer._create_slots(agent.trainable_variables)

        # ON_READ causes the replicated variable to act as independent variables for
        # each replica.
        temp_grads = [
            tf.Variable(tf.zeros_like(v),
                        trainable=False,
                        synchronization=tf.VariableSynchronization.ON_READ)
            for v in agent.trainable_variables
        ]

    @tf.function
    def minimize(iterator):
        data = next(iterator)

        def compute_gradients(args):
            args = tf.nest.pack_sequence_as(unroll_specs[0],
                                            decode(args, data))
            with tf.GradientTape() as tape:
                loss, logs = compute_loss(logger,
                                          parametric_action_distribution,
                                          agent, *args)
            grads = tape.gradient(loss, agent.trainable_variables)
            for t, g in zip(temp_grads, grads):
                t.assign(g)
            return loss, logs

        loss, logs = training_strategy.run(compute_gradients, (data, ))
        loss = training_strategy.experimental_local_results(loss)[0]

        def apply_gradients(_):
            optimizer.apply_gradients(
                zip(temp_grads, agent.trainable_variables))

        strategy.run(apply_gradients, (loss, ))

        getattr(
            agent, 'end_of_training_step_callback',
            lambda: logging.info('end_of_training_step_callback not found'))()

        logger.step_end(logs, training_strategy, iter_frame_ratio)

    agent_output_specs = tf.nest.map_structure(
        lambda t: tf.TensorSpec(t.shape[1:], t.dtype), initial_agent_output)

    # Setup checkpointing and restore checkpoint.
    ckpt = tf.train.Checkpoint(agent=agent, optimizer=optimizer)
    if FLAGS.init_checkpoint is not None:
        tf.print('Loading initial checkpoint from %s...' %
                 FLAGS.init_checkpoint)
        ckpt.restore(FLAGS.init_checkpoint).assert_consumed()
    manager = tf.train.CheckpointManager(ckpt,
                                         FLAGS.logdir,
                                         max_to_keep=1,
                                         keep_checkpoint_every_n_hours=6)
    last_ckpt_time = 0  # Force checkpointing of the initial model.
    if manager.latest_checkpoint:
        logging.info('Restoring checkpoint: %s', manager.latest_checkpoint)
        ckpt.restore(manager.latest_checkpoint).assert_consumed()
        last_ckpt_time = time.time()

    # Logging.
    summary_writer = tf.summary.create_file_writer(FLAGS.logdir,
                                                   flush_millis=20000,
                                                   max_queue=1000)
    logger = utils.ProgressLogger(summary_writer=summary_writer,
                                  starting_step=iterations * iter_frame_ratio)

    servers = []
    unroll_queues = []
    info_specs = (
        tf.TensorSpec([], tf.int64, 'episode_num_frames'),
        tf.TensorSpec([], tf.float32, 'episode_returns'),
        tf.TensorSpec([], tf.float32, 'episode_raw_returns'),
    )

    info_queue = utils.StructuredFIFOQueue(-1, info_specs)

    def create_host(i, host, inference_devices):
        with tf.device(host):
            server = grpc.Server([FLAGS.server_address])

            store = utils.UnrollStore(
                FLAGS.num_envs, FLAGS.unroll_length,
                (action_specs, env_output_specs, agent_output_specs))
            env_run_ids = utils.Aggregator(
                FLAGS.num_envs, tf.TensorSpec([], tf.int64, 'run_ids'))
            env_infos = utils.Aggregator(FLAGS.num_envs, info_specs,
                                         'env_infos')

            # First agent state in an unroll.
            first_agent_states = utils.Aggregator(FLAGS.num_envs,
                                                  agent_state_specs,
                                                  'first_agent_states')

            # Current agent state and action.
            agent_states = utils.Aggregator(FLAGS.num_envs, agent_state_specs,
                                            'agent_states')
            actions = utils.Aggregator(FLAGS.num_envs, action_specs, 'actions')

            unroll_specs[0] = Unroll(agent_state_specs, *store.unroll_specs)
            unroll_queue = utils.StructuredFIFOQueue(1, unroll_specs[0])

            def add_batch_size(ts):
                return tf.TensorSpec([FLAGS.inference_batch_size] +
                                     list(ts.shape), ts.dtype, ts.name)

            inference_specs = (
                tf.TensorSpec([], tf.int32, 'env_id'),
                tf.TensorSpec([], tf.int64, 'run_id'),
                env_output_specs,
                tf.TensorSpec([], tf.float32, 'raw_reward'),
            )
            inference_specs = tf.nest.map_structure(add_batch_size,
                                                    inference_specs)

            def create_inference_fn(inference_device):
                @tf.function(input_signature=inference_specs)
                def inference(env_ids, run_ids, env_outputs, raw_rewards):
                    # Reset the environments that had their first run or crashed.
                    previous_run_ids = env_run_ids.read(env_ids)
                    env_run_ids.replace(env_ids, run_ids)
                    reset_indices = tf.where(
                        tf.not_equal(previous_run_ids, run_ids))[:, 0]
                    envs_needing_reset = tf.gather(env_ids, reset_indices)
                    if tf.not_equal(tf.shape(envs_needing_reset)[0], 0):
                        tf.print('Environment ids needing reset:',
                                 envs_needing_reset)
                    env_infos.reset(envs_needing_reset)
                    store.reset(envs_needing_reset)
                    initial_agent_states = agent.initial_state(
                        tf.shape(envs_needing_reset)[0])
                    first_agent_states.replace(envs_needing_reset,
                                               initial_agent_states)
                    agent_states.replace(envs_needing_reset,
                                         initial_agent_states)
                    actions.reset(envs_needing_reset)

                    tf.debugging.assert_non_positive(
                        tf.cast(env_outputs.abandoned, tf.int32),
                        'Abandoned done states are not supported in VTRACE.')

                    # Update steps and return.
                    env_infos.add(env_ids,
                                  (0, env_outputs.reward, raw_rewards))
                    done_ids = tf.gather(env_ids,
                                         tf.where(env_outputs.done)[:, 0])
                    if i == 0:
                        info_queue.enqueue_many(env_infos.read(done_ids))
                    env_infos.reset(done_ids)
                    env_infos.add(env_ids, (FLAGS.num_action_repeats, 0., 0.))

                    # Inference.
                    prev_actions = parametric_action_distribution.postprocess(
                        actions.read(env_ids))
                    input_ = encode((prev_actions, env_outputs))
                    prev_agent_states = agent_states.read(env_ids)
                    with tf.device(inference_device):

                        @tf.function
                        def agent_inference(*args):
                            return agent(*decode(args),
                                         is_training=False,
                                         postprocess_action=False)

                        agent_outputs, curr_agent_states = agent_inference(
                            *input_, prev_agent_states)

                    # Append the latest outputs to the unroll and insert completed unrolls
                    # in queue.
                    completed_ids, unrolls = store.append(
                        env_ids, (prev_actions, env_outputs, agent_outputs))
                    unrolls = Unroll(first_agent_states.read(completed_ids),
                                     *unrolls)
                    unroll_queue.enqueue_many(unrolls)
                    first_agent_states.replace(
                        completed_ids, agent_states.read(completed_ids))

                    # Update current state.
                    agent_states.replace(env_ids, curr_agent_states)
                    actions.replace(env_ids, agent_outputs.action)
                    # Return environment actions to environments.
                    return parametric_action_distribution.postprocess(
                        agent_outputs.action)

                return inference

            with strategy.scope():
                server.bind(
                    [create_inference_fn(d) for d in inference_devices])
            server.start()
            unroll_queues.append(unroll_queue)
            servers.append(server)

    for i, (host, inference_devices) in enumerate(hosts):
        create_host(i, host, inference_devices)

    def dequeue(ctx):
        # Create batch (time major).
        env_outputs = tf.nest.map_structure(
            lambda *args: tf.stack(args), *[
                unroll_queues[ctx.input_pipeline_id].dequeue() for i in range(
                    ctx.get_per_replica_batch_size(FLAGS.batch_size))
            ])
        env_outputs = env_outputs._replace(
            prev_actions=utils.make_time_major(env_outputs.prev_actions),
            env_outputs=utils.make_time_major(env_outputs.env_outputs),
            agent_outputs=utils.make_time_major(env_outputs.agent_outputs))
        env_outputs = env_outputs._replace(
            env_outputs=encode(env_outputs.env_outputs))
        # tf.data.Dataset treats list leafs as tensors, so we need to flatten and
        # repack.
        return tf.nest.flatten(env_outputs)

    def dataset_fn(ctx):
        dataset = tf.data.Dataset.from_tensors(0).repeat(None)

        def _dequeue(_):
            return dequeue(ctx)

        return dataset.map(_dequeue,
                           num_parallel_calls=ctx.num_replicas_in_sync //
                           len(hosts))

    dataset = training_strategy.experimental_distribute_datasets_from_function(
        dataset_fn)
    it = iter(dataset)

    def additional_logs():
        tf.summary.scalar('learning_rate', learning_rate_fn(iterations))
        n_episodes = info_queue.size()
        n_episodes -= n_episodes % FLAGS.log_episode_frequency
        if tf.not_equal(n_episodes, 0):
            episode_stats = info_queue.dequeue_many(n_episodes)
            episode_keys = [
                'episode_num_frames', 'episode_return', 'episode_raw_return'
            ]
            for key, values in zip(episode_keys, episode_stats):
                for value in tf.split(
                        values,
                        values.shape[0] // FLAGS.log_episode_frequency):
                    tf.summary.scalar(key, tf.reduce_mean(value))

            for (frames, ep_return, raw_return) in zip(*episode_stats):
                logging.info('Return: %f Raw return: %f Frames: %i', ep_return,
                             raw_return, frames)

    logger.start(additional_logs)
    # Execute learning.
    while iterations < final_iteration:
        # Save checkpoint.
        current_time = time.time()
        if current_time - last_ckpt_time >= FLAGS.save_checkpoint_secs:
            manager.save()
            # Apart from checkpointing, we also save the full model (including
            # the graph). This way we can load it after the code/parameters changed.
            tf.saved_model.save(agent, os.path.join(FLAGS.logdir,
                                                    'saved_model'))
            last_ckpt_time = current_time
        minimize(it)
    logger.shutdown()
    manager.save()
    tf.saved_model.save(agent, os.path.join(FLAGS.logdir, 'saved_model'))
    for server in servers:
        server.shutdown()
    for unroll_queue in unroll_queues:
        unroll_queue.close()
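
The create_optimizer_fn contract documented above takes the final training
iteration and returns an optimizer together with a learning-rate schedule. One
way to satisfy it is sketched below; the Adam settings and the linear decay to
zero are illustrative assumptions, not this project's defaults.

import tensorflow as tf


def create_optimizer_fn(final_iteration):
    """Returns (optimizer, learning_rate_fn) for a run of `final_iteration` steps."""
    # Decay the learning rate linearly to zero over the whole run
    # (illustrative choice; any LearningRateSchedule works).
    learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay(
        initial_learning_rate=1e-4,
        decay_steps=final_iteration,
        end_learning_rate=0.)
    optimizer = tf.keras.optimizers.Adam(learning_rate_fn)
    return optimizer, learning_rate_fn
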
Exemplo n.º 14
0
def actor_loop(create_env_fn):
    """Main actor loop.

  Args:
    create_env_fn: Callable (taking the task ID as argument) that must return a
      newly created environment.
  """
    logging.info('Starting actor loop')
    if are_summaries_enabled():
        summary_writer = tf.summary.create_file_writer(os.path.join(
            FLAGS.logdir, 'actor_{}'.format(FLAGS.task)),
                                                       flush_millis=20000,
                                                       max_queue=1000)
        timer_cls = profiling.ExportingTimer
    else:
        summary_writer = tf.summary.create_noop_writer()
        timer_cls = utils.nullcontext

    actor_step = 0
    with summary_writer.as_default():
        while True:
            try:
                # Client to communicate with the learner.
                client = grpc.Client(FLAGS.server_address)

                env = create_env_fn(FLAGS.task)

                # Unique ID to identify a specific run of an actor.
                run_id = np.random.randint(np.iinfo(np.int64).max)
                observation = env.reset()
                reward = 0.0
                raw_reward = 0.0
                done = False
                abandoned = False

                global_step = 0
                episode_step = 0
                episode_step_sum = 0
                episode_return_sum = 0
                episode_raw_return_sum = 0
                episodes_in_report = 0

                elapsed_inference_s_timer = timer_cls(
                    'actor/elapsed_inference_s', 1000)
                last_log_time = timeit.default_timer()
                while True:
                    tf.summary.experimental.set_step(actor_step)
                    env_output = utils.EnvOutput(reward, done, observation,
                                                 abandoned, episode_step)
                    with elapsed_inference_s_timer:
                        action = client.inference(FLAGS.task, run_id,
                                                  env_output, raw_reward)
                    with timer_cls('actor/elapsed_env_step_s', 1000):
                        observation, reward, done, info = env.step(
                            action.numpy())
                    if is_rendering_enabled():
                        env.render()
                    episode_step += 1
                    episode_return_sum += reward
                    raw_reward = float((info
                                        or {}).get('score_reward', reward))
                    episode_raw_return_sum += raw_reward
                    # If the info dict contains an entry abandoned=True and the
                    # episode was ended (done=True), then we need to specially handle
                    # the final transition as per the explanations below.
                    abandoned = (info or {}).get('abandoned', False)
                    assert not abandoned or done
                    if done:
                        # If the episode was abandoned, we need to report the final
                        # transition including the final observation as if the episode has
                        # not terminated yet. This way, learning algorithms can use the
                        # transition for learning.
                        if abandoned:
                            # We do not signal yet that the episode was abandoned. This will
                            # happen for the transition from the terminal state to the
                            # reset state.
                            env_output = utils.EnvOutput(
                                reward, False, observation, False,
                                episode_step)
                            with elapsed_inference_s_timer:
                                action = client.inference(
                                    FLAGS.task, run_id, env_output, raw_reward)
                            reward = 0.0
                            raw_reward = 0.0

                        # Periodically log statistics.
                        current_time = timeit.default_timer()
                        episode_step_sum += episode_step
                        global_step += episode_step
                        episodes_in_report += 1
                        if current_time - last_log_time > 1:
                            logging.info(
                                'Actor steps: %i, Return: %f Raw return: %f Episode steps: %f',
                                global_step,
                                episode_return_sum / episodes_in_report,
                                episode_raw_return_sum / episodes_in_report,
                                episode_step_sum / episodes_in_report)
                            episode_return_sum = 0
                            episode_raw_return_sum = 0
                            episode_step_sum = 0
                            episodes_in_report = 0
                            last_log_time = current_time

                        # Finally, we reset the episode which will report the transition
                        # from the terminal state to the reset state in the next loop
                        # iteration (with zero rewards).
                        with timer_cls('actor/elapsed_env_reset_s', 10):
                            observation = env.reset()
                            episode_step = 0
                        if is_rendering_enabled():
                            env.render()
                    actor_step += 1
            except (tf.errors.UnavailableError, tf.errors.CancelledError) as e:
                logging.exception(e)
                env.close()
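
The create_env_fn argument only needs to accept a task ID and return an
environment that follows Gym's reset/step API. A minimal sketch is below; the
environment name and the per-task seeding are assumptions for illustration.

import gym


def create_env_fn(task):
    """Creates one environment instance for actor `task`."""
    env = gym.make('CartPole-v1')  # illustrative environment choice
    env.seed(task)  # seed per task so actors produce different rollouts
    return env
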
Exemplo n.º 15
0
def learner_loop(create_env_fn, create_agent_fn):
    """Main learner loop.

  Args:
    create_env_fn: Callable that must return a newly created environment. The
      callable takes the task ID as argument - an arbitrary task ID of 0 will be
      passed by the learner. The returned environment should follow GYM's API.
      It is only used for inferring tensor shapes. This environment will not be
      used to generate experience.
    create_agent_fn: Function that must create a new tf.Module with the neural
      network that outputs actions and new agent state given the environment
      observations and previous agent state. See dmlab.agents.ImpalaDeep for an
      example. The factory function takes as input the environment action and
      observation spaces and a parametric distribution over actions.
  """
    logging.info('Starting learner loop')
    # validate_config()
    settings = utils.init_learner(FLAGS.num_training_tpus)
    strategy, inference_devices, training_strategy, encode, decode = settings
    env = create_env_fn(0)
    parametric_action_distribution = get_parametric_distribution_for_action_space(
        env.action_space)
    env_output_specs = utils.EnvOutput(
        tf.TensorSpec([], tf.float32, 'reward'),
        tf.TensorSpec([], tf.bool, 'done'),
        tf.TensorSpec(env.observation_space.shape, env.observation_space.dtype,
                      'observation'),
    )
    action_specs = tf.TensorSpec(env.action_space.shape,
                                 env.action_space.dtype, 'action')
    agent_input_specs = (action_specs, env_output_specs)

    # Initialize agent and variables.
    agent = create_agent_fn(env.action_space, env.observation_space,
                            parametric_action_distribution)
    initial_agent_state = agent.initial_state(1)
    agent_state_specs = tf.nest.map_structure(
        lambda t: tf.TensorSpec(t.shape[1:], t.dtype), initial_agent_state)
    input_ = tf.nest.map_structure(
        lambda s: tf.zeros([1] + list(s.shape), s.dtype), agent_input_specs)
    input_ = encode(input_)

    # Setup checkpointing and restore checkpoint.
    ckpt = tf.train.Checkpoint(agent=agent)
    if FLAGS.init_checkpoint is not None:
        tf.print('Loading initial checkpoint from %s...' %
                 FLAGS.init_checkpoint)
        ckpt.restore(FLAGS.init_checkpoint)  #.assert_consumed()
    manager = tf.train.CheckpointManager(ckpt,
                                         FLAGS.logdir,
                                         max_to_keep=1,
                                         keep_checkpoint_every_n_hours=6)
    last_ckpt_time = 0  # Force checkpointing of the initial model.
    if manager.latest_checkpoint:
        logging.info('Restoring checkpoint: %s', manager.latest_checkpoint)
        ckpt.restore(manager.latest_checkpoint)  #.assert_consumed()
        last_ckpt_time = time.time()

    actor_run_ids = utils.Aggregator(FLAGS.num_actors,
                                     tf.TensorSpec([], tf.int64, 'run_ids'))

    # Current agent state and action.
    agent_states = utils.Aggregator(FLAGS.num_actors, agent_state_specs,
                                    'agent_states')
    actions = utils.Aggregator(FLAGS.num_actors, action_specs, 'actions')

    def add_batch_size(ts):
        return tf.TensorSpec([FLAGS.inference_batch_size] + list(ts.shape),
                             ts.dtype, ts.name)

    inference_iteration = tf.Variable(-1)
    inference_specs = (
        tf.TensorSpec([], tf.int32, 'actor_id'),
        tf.TensorSpec([], tf.int64, 'run_id'),
        env_output_specs,
        tf.TensorSpec([], tf.float32, 'raw_reward'),
    )
    inference_specs = tf.nest.map_structure(add_batch_size, inference_specs)

    @tf.function(input_signature=inference_specs)
    def inference(actor_ids, run_ids, env_outputs, raw_rewards):
        # Reset the actors that had their first run or crashed.
        previous_run_ids = actor_run_ids.read(actor_ids)
        actor_run_ids.replace(actor_ids, run_ids)
        reset_indices = tf.where(tf.not_equal(previous_run_ids, run_ids))[:, 0]
        actors_needing_reset = tf.gather(actor_ids, reset_indices)
        if tf.not_equal(tf.shape(actors_needing_reset)[0], 0):
            tf.print('Actor ids needing reset:', actors_needing_reset)
        initial_agent_states = agent.initial_state(
            tf.shape(actors_needing_reset)[0])
        # tf.print("agent initial_agent_states",tf.reduce_mean(initial_agent_states))
        agent_states.replace(actors_needing_reset, initial_agent_states)
        actions.reset(actors_needing_reset)

        # Inference.
        prev_actions = parametric_action_distribution.postprocess(
            actions.read(actor_ids))
        input_ = encode((prev_actions, env_outputs))
        prev_agent_states = agent_states.read(actor_ids)
        tf.print("agent states", tf.reduce_mean(prev_agent_states),
                 tf.reduce_max(prev_agent_states))
        # def make_inference_fn(inference_device):
        #   def device_specific_inference_fn():
        #     with tf.device(inference_device):
        #       @tf.function
        #       def agent_inference(*args):
        #         return agent(*decode(args), is_training=False,
        #                      postprocess_action=False)

        #       return agent_inference(*input_, prev_agent_states)

        #   return device_specific_inference_fn

        # # Distribute the inference calls among the inference cores.
        # branch_index = inference_iteration.assign_add(1) % len(inference_devices)
        # agent_outputs, curr_agent_states = tf.switch_case(branch_index, {
        #     i: make_inference_fn(inference_device)
        #     for i, inference_device in enumerate(inference_devices)
        # })
        with tf.device(inference_devices[0]):
            agent_outputs, curr_agent_states = agent(
                *(decode((*input_, prev_agent_states))))
        # Update current state.
        agent_states.replace(actor_ids, curr_agent_states)
        actions.replace(actor_ids, agent_outputs.action)

        # Return environment actions to actors.
        return parametric_action_distribution.postprocess(agent_outputs.action)

    summary_writer = tf.summary.create_noop_writer()
    timer_cls = utils.nullcontext

    actor_step = 0
    with summary_writer.as_default():
        while True:
            try:

                env = create_env_fn(FLAGS.task)

                # Unique ID to identify a specific run of an actor.
                run_id = np.random.randint(np.iinfo(np.int64).max)
                observation = env.reset()
                reward = 0.0
                raw_reward = 0.0
                done = False

                episode_step = 0
                episode_return = 0
                episode_raw_return = 0

                while True:
                    tf.summary.experimental.set_step(actor_step)
                    env_output = utils.EnvOutput(
                        [tf.cast(reward, tf.float32)], [done],
                        tf.cast(observation[None], tf.float32))
                    with timer_cls('actor/elapsed_inference_s', 1000):
                        action = inference([FLAGS.task], [run_id], env_output,
                                           [raw_reward])[0]
                    with timer_cls('actor/elapsed_env_step_s', 1000):
                        observation, reward, done, info = env.step(
                            action.numpy())

                    # env.render()
                    episode_step += 1
                    episode_return += reward
                    raw_reward = float((info
                                        or {}).get('score_reward', reward))
                    episode_raw_return += raw_reward
                    if done:
                        logging.info('Return: %f Raw return: %f Steps: %i',
                                     episode_return, episode_raw_return,
                                     episode_step)
                        env.render()
                        with timer_cls('actor/elapsed_env_reset_s', 10):
                            observation = env.reset()
                            episode_step = 0
                            episode_return = 0
                            episode_raw_return = 0

                    actor_step += 1
            except (tf.errors.UnavailableError, tf.errors.CancelledError) as e:
                logging.exception(e)
                env.close()
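
The commented-out block above distributes inference calls across devices with
tf.switch_case, round-robining on a counter. Stripped of the agent specifics,
the dispatch pattern looks like the sketch below; the device list and the
constant branch bodies are placeholders for illustration.

import tensorflow as tf

devices = ['/cpu:0', '/cpu:0']  # assumed inference device list
inference_counter = tf.Variable(-1)


def make_branch(device):
    def branch():
        with tf.device(device):
            # Stands in for the per-device agent inference call.
            return tf.constant(1.0)
    return branch


@tf.function
def dispatch():
    # Pick the next device in round-robin order and run its branch.
    branch_index = tf.cast(
        inference_counter.assign_add(1) % len(devices), tf.int32)
    return tf.switch_case(
        branch_index, {i: make_branch(d) for i, d in enumerate(devices)})
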
Exemplo n.º 16
0
def learner_loop(create_env_fn, create_agent_fn, create_optimizer_fn):
    """Main learner loop.

  Args:
    create_env_fn: Callable that must return a newly created environment. The
      callable takes the task ID as argument - an arbitrary task ID of 0 will be
      passed by the learner. The returned environment should follow GYM's API.
      It is only used for inferring tensor shapes. This environment will not be
      used to generate experience.
    create_agent_fn: Function that must create a new tf.Module with the neural
      network that outputs actions and new agent state given the environment
      observations and previous agent state. See dmlab.agents.ImpalaDeep for an
      example. The factory function takes as input the environment action and
      observation spaces and a parametric distribution over actions.
    create_optimizer_fn: Function that takes the final iteration as argument
      and must return a tf.keras.optimizers.Optimizer and a
      tf.keras.optimizers.schedules.LearningRateSchedule.
  """
    logging.info('Starting learner loop')
    validate_config()
    settings = utils.init_learner(FLAGS.num_training_tpus)
    strategy, inference_devices, training_strategy, encode, decode = settings
    env = create_env_fn(0, FLAGS)
    parametric_action_distribution = get_parametric_distribution_for_action_space(
        env.action_space)
    env_output_specs = utils.EnvOutput(
        tf.TensorSpec([], tf.float32, 'reward'),
        tf.TensorSpec([], tf.bool, 'done'),
        tf.nest.map_structure(
            lambda s: tf.TensorSpec(s.shape, s.dtype, 'observation'),
            env.observation_space.__dict__.get('spaces',
                                               env.observation_space)),
        tf.TensorSpec([], tf.bool, 'abandoned'),
        tf.TensorSpec([], tf.int32, 'episode_step'),
    )
    action_specs = tf.TensorSpec(env.action_space.shape,
                                 env.action_space.dtype, 'action')

    # Initialize agent and variables.
    agent = create_agent_fn(env.action_space, env.observation_space,
                            parametric_action_distribution)
    target_agent = create_agent_fn(env.action_space, env.observation_space,
                                   parametric_action_distribution)
    initial_agent_state = agent.initial_state(1)
    agent_state_specs = tf.nest.map_structure(
        lambda t: tf.TensorSpec(t.shape[1:], t.dtype), initial_agent_state)
    agent_input_specs = (action_specs, env_output_specs)

    input_ = tf.nest.map_structure(
        lambda s: tf.zeros([1, 1] + list(s.shape), s.dtype), agent_input_specs)
    input_no_time = tf.nest.map_structure(lambda t: t[0], input_)

    input_ = encode(input_ + (initial_agent_state, ))
    input_no_time = encode(input_no_time + (initial_agent_state, ))

    with strategy.scope():
        # Initialize variables
        def initialize_agent_variables(agent):
            if not hasattr(agent, 'entropy_cost'):
                mul = FLAGS.entropy_cost_adjustment_speed
                agent.entropy_cost_param = tf.Variable(
                    tf.math.log(FLAGS.entropy_cost) / mul,
                    # Without the constraint, the param gradient may get rounded to 0
                    # for very small values.
                    constraint=lambda v: tf.clip_by_value(
                        v, -20 / mul, 20 / mul),
                    trainable=True,
                    dtype=tf.float32)
                agent.entropy_cost = (
                    lambda: tf.exp(mul * agent.entropy_cost_param))

            @tf.function
            def create_variables():
                return [
                    agent.get_action(*decode(input_no_time)),
                    agent.get_V(*decode(input_)),
                    agent.get_Q(*decode(input_), action=decode(input_[0]))
                ]

            create_variables()

        initialize_agent_variables(agent)
        initialize_agent_variables(target_agent)

        # Target network update
        @tf.function
        def update_target_agent(polyak):
            """Synchronizes training and target agent variables."""
            variables = agent.variables
            target_variables = target_agent.variables
            assert len(target_variables) == len(variables), (
                'Mismatch in number of net tensors: {} != {}'.format(
                    len(target_variables), len(variables)))
            for target_var, source_var in zip(target_variables, variables):
                target_var.assign(polyak * target_var +
                                  (1. - polyak) * source_var)

        update_target_agent(polyak=0.)  # copy weights

        # Create optimizer.
        iter_frame_ratio = (
            get_replay_insertion_batch_size(per_replica=False) *
            (FLAGS.her_window_length or FLAGS.unroll_length) *
            FLAGS.num_action_repeats)
        final_iteration = int(
            math.ceil(FLAGS.total_environment_frames / iter_frame_ratio))
        optimizer, learning_rate_fn = create_optimizer_fn(final_iteration)

        iterations = optimizer.iterations
        optimizer._create_hypers()
        optimizer._create_slots(agent.trainable_variables)

        # ON_READ causes the replicated variable to act as independent variables for
        # each replica.
        temp_grads = [
            tf.Variable(tf.zeros_like(v),
                        trainable=False,
                        synchronization=tf.VariableSynchronization.ON_READ)
            for v in agent.trainable_variables
        ]

    @tf.function
    def minimize(iterator):
        data = next(iterator)

        def compute_gradients(args):
            args = tf.nest.pack_sequence_as(unroll_specs, decode(args, data))
            with tf.GradientTape() as tape:
                loss, logs = compute_loss(logger,
                                          parametric_action_distribution,
                                          agent, target_agent, *args)
            grads = tape.gradient(loss, agent.trainable_variables)
            for t, g in zip(temp_grads, grads):
                t.assign(g)
            return loss, logs

        loss, logs = training_strategy.run(compute_gradients, (data, ))
        loss = training_strategy.experimental_local_results(loss)[0]

        def apply_gradients(_):
            optimizer.apply_gradients(
                zip(temp_grads, agent.trainable_variables))

        strategy.run(apply_gradients, (loss, ))

        getattr(
            agent, 'end_of_training_step_callback',
            lambda: logging.info('end_of_training_step_callback not found'))()

        logger.step_end(logs, training_strategy, iter_frame_ratio)

    # Setup checkpointing and restore checkpoint.
    ckpt = tf.train.Checkpoint(agent=agent,
                               target_agent=target_agent,
                               optimizer=optimizer)
    if FLAGS.init_checkpoint is not None:
        tf.print('Loading initial checkpoint from %s...' %
                 FLAGS.init_checkpoint)
        ckpt.restore(FLAGS.init_checkpoint).assert_consumed()
    manager = tf.train.CheckpointManager(ckpt,
                                         FLAGS.logdir,
                                         max_to_keep=1,
                                         keep_checkpoint_every_n_hours=6)
    last_ckpt_time = 0  # Force checkpointing of the initial model.
    if manager.latest_checkpoint:
        logging.info('Restoring checkpoint: %s', manager.latest_checkpoint)
        ckpt.restore(manager.latest_checkpoint).assert_consumed()
        last_ckpt_time = time.time()

    # Logging.
    summary_writer = tf.summary.create_file_writer(FLAGS.logdir,
                                                   flush_millis=20000,
                                                   max_queue=1000)
    logger = utils.ProgressLogger(summary_writer=summary_writer,
                                  starting_step=iterations * iter_frame_ratio)

    server = grpc.Server([FLAGS.server_address])

    store = utils.UnrollStore(FLAGS.num_envs, FLAGS.her_window_length
                              or FLAGS.unroll_length,
                              (action_specs, env_output_specs, action_specs))
    env_run_ids = utils.Aggregator(FLAGS.num_envs,
                                   tf.TensorSpec([], tf.int64, 'run_ids'))
    info_specs = (
        tf.TensorSpec([], tf.int64, 'episode_num_frames'),
        tf.TensorSpec([], tf.float32, 'episode_returns'),
        tf.TensorSpec([], tf.float32, 'episode_raw_returns'),
    )
    env_infos = utils.Aggregator(FLAGS.num_envs, info_specs, 'env_infos')

    # First agent state in an unroll.
    first_agent_states = utils.Aggregator(FLAGS.num_envs, agent_state_specs,
                                          'first_agent_states')

    # Current agent state and action.
    agent_states = utils.Aggregator(FLAGS.num_envs, agent_state_specs,
                                    'agent_states')
    actions = utils.Aggregator(FLAGS.num_envs, action_specs, 'actions')

    unroll_specs = Unroll(agent_state_specs, *store.unroll_specs)
    unroll_queue = utils.StructuredFIFOQueue(FLAGS.unroll_queue_max_size,
                                             unroll_specs)
    info_queue = utils.StructuredFIFOQueue(-1, info_specs)

    if FLAGS.her_window_length:
        replay_buffer = utils.HindsightExperienceReplay(
            FLAGS.replay_buffer_size,
            unroll_specs,
            compute_reward_fn=env.compute_reward,
            unroll_length=FLAGS.unroll_length,
            importance_sampling_exponent=0.,
            substitution_probability=FLAGS.her_substitution_probability)
    else:
        replay_buffer = utils.PrioritizedReplay(
            FLAGS.replay_buffer_size,
            unroll_specs,
            importance_sampling_exponent=0.)

    def add_batch_size(ts):
        return tf.TensorSpec([FLAGS.inference_batch_size] + list(ts.shape),
                             ts.dtype, ts.name)

    inference_iteration = tf.Variable(-1, dtype=tf.int64)
    inference_specs = (
        tf.TensorSpec([], tf.int32, 'env_id'),
        tf.TensorSpec([], tf.int64, 'run_id'),
        env_output_specs,
        tf.TensorSpec([], tf.float32, 'raw_reward'),
    )
    inference_specs = tf.nest.map_structure(add_batch_size, inference_specs)

    @tf.function(input_signature=inference_specs)
    def inference(env_ids, run_ids, env_outputs, raw_rewards):
        # Reset the environments that had their first run or crashed.
        previous_run_ids = env_run_ids.read(env_ids)
        env_run_ids.replace(env_ids, run_ids)
        reset_indices = tf.where(tf.not_equal(previous_run_ids, run_ids))[:, 0]
        envs_needing_reset = tf.gather(env_ids, reset_indices)
        if tf.not_equal(tf.shape(envs_needing_reset)[0], 0):
            tf.print('Environment ids needing reset:', envs_needing_reset)
        env_infos.reset(envs_needing_reset)
        store.reset(envs_needing_reset)
        initial_agent_states = agent.initial_state(
            tf.shape(envs_needing_reset)[0])
        first_agent_states.replace(envs_needing_reset, initial_agent_states)
        agent_states.replace(envs_needing_reset, initial_agent_states)
        actions.reset(envs_needing_reset)

        tf.debugging.assert_non_positive(
            tf.cast(env_outputs.abandoned, tf.int32),
            'Abandoned done states are not supported in SAC.')

        # Update steps and return.
        env_infos.add(env_ids, (0, env_outputs.reward, raw_rewards))
        done_ids = tf.gather(env_ids, tf.where(env_outputs.done)[:, 0])
        info_queue.enqueue_many(env_infos.read(done_ids))
        env_infos.reset(done_ids)
        env_infos.add(env_ids, (FLAGS.num_action_repeats, 0., 0.))

        # Inference.
        prev_actions = actions.read(env_ids)
        input_ = encode((prev_actions, env_outputs))
        prev_agent_states = agent_states.read(env_ids)

        def make_inference_fn(inference_device):
            def device_specific_inference_fn():
                with tf.device(inference_device):

                    @tf.function
                    def agent_inference(*args):
                        return agent(*decode(args), is_training=False)

                    return agent_inference(*input_, prev_agent_states)

            return device_specific_inference_fn

        # Distribute the inference calls among the inference cores.
        branch_index = tf.cast(
            inference_iteration.assign_add(1) % len(inference_devices),
            tf.int32)
        agent_actions, curr_agent_states = tf.switch_case(
            branch_index, {
                i: make_inference_fn(inference_device)
                for i, inference_device in enumerate(inference_devices)
            })

        # Append the latest outputs to the unroll and insert completed unrolls in
        # queue.
        completed_ids, unrolls = store.append(
            env_ids, (prev_actions, env_outputs, agent_actions))
        unrolls = Unroll(first_agent_states.read(completed_ids), *unrolls)
        unroll_queue.enqueue_many(unrolls)
        first_agent_states.replace(completed_ids,
                                   agent_states.read(completed_ids))

        # Update current state.
        agent_states.replace(env_ids, curr_agent_states)
        actions.replace(env_ids, agent_actions)

        # Return environment actions to environments.
        return agent_actions

    with strategy.scope():
        server.bind(inference)
    server.start()

    dataset = create_dataset(unroll_queue, replay_buffer, training_strategy,
                             FLAGS.batch_size, encode)
    it = iter(dataset)

    def additional_logs():
        tf.summary.scalar('learning_rate', learning_rate_fn(iterations))
        tf.summary.scalar('buffer/unrolls_inserted',
                          replay_buffer.num_inserted)
        # log data from info_queue
        n_episodes = info_queue.size()
        n_episodes -= n_episodes % FLAGS.log_episode_frequency
        if tf.not_equal(n_episodes, 0):
            episode_stats = info_queue.dequeue_many(n_episodes)
            episode_keys = [
                'episode_num_frames', 'episode_return', 'episode_raw_return'
            ]
            for key, values in zip(episode_keys, episode_stats):
                for value in tf.split(
                        values,
                        values.shape[0] // FLAGS.log_episode_frequency):
                    tf.summary.scalar(key, tf.reduce_mean(value))

            for (frames, ep_return, raw_return) in zip(*episode_stats):
                logging.info('Return: %f Raw return: %f Frames: %i', ep_return,
                             raw_return, frames)

    logger.start(additional_logs)
    # Execute learning.
    while iterations < final_iteration:
        if iterations.numpy() % FLAGS.update_target_every_n_step == 0:
            update_target_agent(FLAGS.polyak)
        # Save checkpoint.
        current_time = time.time()
        if current_time - last_ckpt_time >= FLAGS.save_checkpoint_secs:
            manager.save()
            # Apart from checkpointing, we also save the full model (including
            # the graph). This way we can load it after the code/parameters changed.
            tf.saved_model.save(agent, os.path.join(FLAGS.logdir,
                                                    'saved_model'))
            last_ckpt_time = current_time
        minimize(it)
    logger.shutdown()
    manager.save()
    tf.saved_model.save(agent, os.path.join(FLAGS.logdir, 'saved_model'))
    server.shutdown()
    unroll_queue.close()
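For reference, the round-robin device dispatch used above (a counter incremented with assign_add, taken modulo the number of inference devices, and fed to tf.switch_case) can be isolated into the following self-contained sketch. The device strings and the toy computation are illustrative stand-ins only, not part of the example above.

import tensorflow as tf

inference_devices = ['/cpu:0', '/cpu:0']  # stand-ins for real accelerator devices
inference_iteration = tf.Variable(-1, dtype=tf.int64)


@tf.function
def dispatch(x):
  """Runs the toy computation on a different device on each call."""

  def make_inference_fn(inference_device):
    def device_specific_inference_fn():
      with tf.device(inference_device):
        return x * 2.0  # placeholder for the real agent(...) call
    return device_specific_inference_fn

  # Round-robin: each call advances the counter and selects the next branch.
  branch_index = tf.cast(
      inference_iteration.assign_add(1) % len(inference_devices), tf.int32)
  return tf.switch_case(
      branch_index,
      {i: make_inference_fn(d) for i, d in enumerate(inference_devices)})


print(dispatch(tf.constant([1.0, 2.0])))  # handled by branch 0
print(dispatch(tf.constant([3.0, 4.0])))  # handled by branch 1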
Example no. 17
0
def actor_loop(create_env_fn):
  """Main actor loop.

  Args:
    create_env_fn: Callable (taking the task ID as argument) that must return a
      newly created environment.
  """
  logging.info('Starting actor loop')
  if are_summaries_enabled():
    summary_writer = tf.summary.create_file_writer(
        os.path.join(FLAGS.logdir, 'actor_{}'.format(FLAGS.task)),
        flush_millis=20000, max_queue=1000)
    timer_cls = profiling.ExportingTimer
  else:
    summary_writer = tf.summary.create_noop_writer()
    timer_cls = utils.nullcontext

  actor_step = 0
  with summary_writer.as_default():
    while True:
      try:
        # Client to communicate with the learner.
        client = grpc.Client(FLAGS.server_address)

        env = create_env_fn(FLAGS.task)

        # Unique ID to identify a specific run of an actor.
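        # This process drives two virtual actor slots of the same environment
        # (selected via color_state below), so a separate run ID is drawn for
        # each slot.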
        run_id = np.random.randint(np.iinfo(np.int64).max)
        run_id1 = np.random.randint(np.iinfo(np.int64).max)
        observation = env.reset()
        reward = 0.0
        raw_reward = 0.0
        done = False

        episode_step = 0
        episode_return = 0

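        # color_state selects which virtual actor takes the next step
        # (0: slot FLAGS.task, 1: slot FLAGS.num_actors / 2 + FLAGS.task);
        # it is reset to 0 whenever the environment is reset.
        # episode_end delays the done signal by one step so that the second
        # virtual actor also observes the end of the episode.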
        color_state = 0
        episode_end = False

        while True:
          tf.summary.experimental.set_step(actor_step)

          env_output = utils.EnvOutput(
              reward=tf.cast(reward, tf.float32),
              done=done,
              observation=tf.cast(observation, tf.float32))
          if color_state == 0:
            with timer_cls('actor/elapsed_inference_s', 1000):
              action = client.inference(
                  FLAGS.task, run_id, env_output, reward)
              
            with timer_cls('actor/elapsed_env_step_s', 1000):
              observation, _reward, _done, info = env.step(action.numpy())

          else:
            with timer_cls('actor/elapsed_inference_s', 1000):
              action = client.inference(
                  int(FLAGS.num_actors / 2 + FLAGS.task), run_id1, env_output, reward)
            with timer_cls('actor/elapsed_env_step_s', 1000):
              observation, _reward, _done, info = env.step(action.numpy())

          episode_step += 1
          if _done:
            random_num_ = np.random.random()
            if random_num_ > 0.98:
              if is_rendering_enabled():
                env.render()

            with timer_cls('actor/elapsed_env_reset_s', 10):
              observation = env.reset()

            color_state = 0
          else:
            color_state = 1 - color_state

          if episode_end:
            # The delayed episode end is processed here, on the second virtual
            # actor's turn (this color must be white).
            assert color_state == 1
            if random_num_ > 0.98:
              logging.info('Return: %f Steps: %i', episode_return, episode_step)
            episode_step = 0
            episode_return = 0

            done = episode_end
            # Flip the sign of the final reward for the opposing virtual actor
            # (presumably a zero-sum game).
            reward = -reward

            episode_end = _done
          else:
            reward = _reward
            episode_end = _done
            done = episode_end
            episode_return += reward

          actor_step += 1
          
      except (tf.errors.UnavailableError, tf.errors.CancelledError) as e:
        logging.exception(e)
        env.close()
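To make the two-virtual-actor bookkeeping above easier to follow, here is a minimal pure-Python sketch of the same control flow, assuming a turn-based, zero-sum two-player environment. The stand-in environment, inference stub, constants and run IDs are illustrative assumptions, not part of the original code.

import random

NUM_ACTORS = 4  # assumed even: the second half of the slots serves the second player
TASK = 0


def fake_inference(slot, run_id, reward, done, observation):
  """Stand-in for client.inference(); returns a random action."""
  del slot, run_id, reward, done, observation  # unused in this sketch
  return random.choice([0, 1, 2])


def fake_env_step(action):
  """Stand-in for env.step(); ends an episode with 10% probability."""
  del action
  done = random.random() < 0.1
  return [0.0], (1.0 if done else 0.0), done


observation, reward, done = [0.0], 0.0, False
run_ids = (111, 222)      # one run ID per virtual actor, as in the loop above
color_state = 0           # 0: first virtual actor's turn, 1: second's
episode_end = False       # delayed done signal for the second virtual actor
episode_return, episode_step = 0.0, 0

for _ in range(50):
  # Route the transition to one of the two learner-side actor slots.
  slot = TASK if color_state == 0 else NUM_ACTORS // 2 + TASK
  action = fake_inference(slot, run_ids[color_state], reward, done, observation)
  observation, _reward, _done = fake_env_step(action)

  episode_step += 1
  if _done:
    observation = [0.0]   # env.reset() in the real loop
    color_state = 0
  else:
    color_state = 1 - color_state

  if episode_end:
    # The opposing virtual actor sees the episode end one step later, with
    # the reward sign flipped (zero-sum assumption).
    episode_step, episode_return = 0, 0.0
    done, reward, episode_end = True, -reward, _done
  else:
    reward, episode_end = _reward, _done
    done = episode_end
    episode_return += reward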