def test_duplicate_actor_id(self):
  """Appending the same actor id twice in one batch must be rejected."""
  scalar_int_spec = tf.TensorSpec([], tf.int32)
  store = utils.UnrollStore(
      num_actors=2, unroll_length=3, timestep_specs=scalar_int_spec)
  with self.assertRaises(tf.errors.InvalidArgumentError):
    repeated_ids = tf.constant([2, 2], dtype=tf.int32)
    step_values = tf.constant([42, 43], dtype=tf.int32)
    store.append(repeated_ids, step_values)
def test_structure(self):
  """UnrollStore supports nested (namedtuple) timestep structures."""
  named_tuple = collections.namedtuple('named_tuple', 'x y')
  num_actors = 2
  unroll_length = 10
  store = utils.UnrollStore(
      num_actors=num_actors,
      unroll_length=unroll_length,
      timestep_specs=named_tuple(
          x=tf.TensorSpec([], tf.int32), y=tf.TensorSpec([], tf.int32)))

  all_actors = tf.range(num_actors)

  def zero_step():
    # One all-zero timestep for every actor, for both fields.
    return named_tuple(tf.zeros([num_actors], tf.int32),
                       tf.zeros([num_actors], tf.int32))

  # The first `unroll_length` appends must not complete any unroll.
  for _ in range(unroll_length):
    completed_ids, unrolls = store.append(all_actors, zero_step())
    self.assertAllEqual(tf.constant(()), completed_ids)
    self.assertAllEqual(
        named_tuple(tf.zeros([0, unroll_length + 1]),
                    tf.zeros([0, unroll_length + 1])), unrolls)

  # The next append completes one unroll per actor.
  completed_ids, unrolls = store.append(all_actors, zero_step())
  self.assertAllEqual(all_actors, completed_ids)
  self.assertAllEqual(
      named_tuple(tf.zeros([num_actors, unroll_length + 1]),
                  tf.zeros([num_actors, unroll_length + 1])), unrolls)
def learner_loop(create_env_fn, create_agent_fn, create_optimizer_fn):
  """Main learner loop.

  Builds a trained agent and a target agent, creates the optimizer, starts a
  gRPC inference server that environments call into, and then runs training
  steps (with periodic Polyak target updates and checkpointing) until
  `final_iteration` optimizer iterations have been performed.

  Args:
    create_env_fn: Callable that must return a newly created environment. It is
      called here as `create_env_fn(0, FLAGS)`: an arbitrary task ID of 0
      followed by the FLAGS object. The returned environment should follow
      GYM's API. It is only used for inferring tensor shapes and, when
      `FLAGS.her_window_length` is set, for its `compute_reward` method, which
      is passed to the hindsight replay buffer. This environment will not be
      used to generate experience.
    create_agent_fn: Function that must create a new tf.Module with the neural
      network that outputs actions and new agent state given the environment
      observations and previous agent state. See dmlab.agents.ImpalaDeep for an
      example. The factory function takes as input the environment action and
      observation spaces and a parametric distribution over actions. It is
      called twice: once for the trained agent and once for the target agent.
    create_optimizer_fn: Function that takes the final iteration as argument
      and must return a tf.keras.optimizers.Optimizer and a
      tf.keras.optimizers.schedules.LearningRateSchedule.
  """
  logging.info('Starting learner loop')
  validate_config()
  settings = utils.init_learner(FLAGS.num_training_tpus)
  strategy, inference_devices, training_strategy, encode, decode = settings
  env = create_env_fn(0, FLAGS)
  parametric_action_distribution = get_parametric_distribution_for_action_space(
      env.action_space)
  env_output_specs = utils.EnvOutput(
      tf.TensorSpec([], tf.float32, 'reward'),
      tf.TensorSpec([], tf.bool, 'done'),
      # Uses `observation_space.spaces` when the space has one (presumably a
      # dict-like space), otherwise the space itself.
      tf.nest.map_structure(
          lambda s: tf.TensorSpec(s.shape, s.dtype, 'observation'),
          env.observation_space.__dict__.get('spaces', env.observation_space)),
      tf.TensorSpec([], tf.bool, 'abandoned'),
      tf.TensorSpec([], tf.int32, 'episode_step'),
  )
  action_specs = tf.TensorSpec(env.action_space.shape,
                               env.action_space.dtype, 'action')

  # Initialize agent and variables.
  agent = create_agent_fn(env.action_space, env.observation_space,
                          parametric_action_distribution)
  target_agent = create_agent_fn(env.action_space, env.observation_space,
                                 parametric_action_distribution)

  initial_agent_state = agent.initial_state(1)
  # Per-example agent state specs: the leading (batch) dim of 1 is dropped.
  agent_state_specs = tf.nest.map_structure(
      lambda t: tf.TensorSpec(t.shape[1:], t.dtype), initial_agent_state)
  agent_input_specs = (action_specs, env_output_specs)

  # Zero-filled dummy inputs with two leading dims of size 1; `input_no_time`
  # drops the first of them. They are only used to trace the agent once so
  # that its variables get created.
  input_ = tf.nest.map_structure(
      lambda s: tf.zeros([1, 1] + list(s.shape), s.dtype), agent_input_specs)
  input_no_time = tf.nest.map_structure(lambda t: t[0], input_)

  input_ = encode(input_ + (initial_agent_state, ))
  input_no_time = encode(input_no_time + (initial_agent_state, ))

  with strategy.scope():
    # Initialize variables
    def initialize_agent_variables(agent):
      """Adds a learnable entropy cost (if missing) and creates variables."""
      if not hasattr(agent, 'entropy_cost'):
        mul = FLAGS.entropy_cost_adjustment_speed
        # entropy_cost() == exp(mul * param); the initial value
        # log(FLAGS.entropy_cost) / mul makes it start at FLAGS.entropy_cost.
        agent.entropy_cost_param = tf.Variable(
            tf.math.log(FLAGS.entropy_cost) / mul,
            # Without the constraint, the param gradient may get rounded to 0
            # for very small values.
            constraint=lambda v: tf.clip_by_value(v, -20 / mul, 20 / mul),
            trainable=True,
            dtype=tf.float32)
        agent.entropy_cost = lambda: tf.exp(mul * agent.entropy_cost_param)

      # Calls every network head once on the dummy inputs so that all of the
      # agent's variables exist before the optimizer slots are created.
      @tf.function
      def create_variables():
        return [
            agent.get_action(*decode(input_no_time)),
            agent.get_V(*decode(input_)),
            agent.get_Q(*decode(input_), action=decode(input_[0]))
        ]

      create_variables()

    initialize_agent_variables(agent)
    initialize_agent_variables(target_agent)

    # Target network update
    @tf.function
    def update_target_agent(polyak):
      """Synchronizes training and target agent variables.

      Each target variable becomes
      `polyak * target + (1 - polyak) * source`, so `polyak=0.` is a plain
      copy of the training agent's weights.
      """
      variables = agent.variables
      target_variables = target_agent.variables
      assert len(target_variables) == len(variables), (
          'Mismatch in number of net tensors: {} != {}'.format(
              len(target_variables), len(variables)))
      for target_var, source_var in zip(target_variables, variables):
        target_var.assign(polyak * target_var + (1. - polyak) * source_var)

    update_target_agent(polyak=0.)  # copy weights

    # Create optimizer.
    # Environment frames consumed per optimizer iteration.
    iter_frame_ratio = (
        get_replay_insertion_batch_size(per_replica=False) *
        (FLAGS.her_window_length or FLAGS.unroll_length) *
        FLAGS.num_action_repeats)
    final_iteration = int(
        math.ceil(FLAGS.total_environment_frames / iter_frame_ratio))
    optimizer, learning_rate_fn = create_optimizer_fn(final_iteration)

    iterations = optimizer.iterations
    optimizer._create_hypers()
    optimizer._create_slots(agent.trainable_variables)

    # ON_READ causes the replicated variable to act as independent variables for
    # each replica.
    temp_grads = [
        tf.Variable(tf.zeros_like(v), trainable=False,
                    synchronization=tf.VariableSynchronization.ON_READ)
        for v in agent.trainable_variables
    ]

  @tf.function
  def minimize(iterator):
    """Runs one training step on the next batch from `iterator`.

    Gradients are computed on `training_strategy` into `temp_grads` and then
    applied on `strategy`. `unroll_specs` and `logger` are defined later in
    this function; they are resolved when this tf.function is first traced.
    """
    data = next(iterator)

    def compute_gradients(args):
      args = tf.nest.pack_sequence_as(unroll_specs, decode(args, data))
      with tf.GradientTape() as tape:
        loss, logs = compute_loss(logger, parametric_action_distribution, agent,
                                  target_agent, *args)
      grads = tape.gradient(loss, agent.trainable_variables)
      for t, g in zip(temp_grads, grads):
        t.assign(g)
      return loss, logs

    loss, logs = training_strategy.run(compute_gradients, (data, ))
    loss = training_strategy.experimental_local_results(loss)[0]

    def apply_gradients(_):
      optimizer.apply_gradients(zip(temp_grads, agent.trainable_variables))

    strategy.run(apply_gradients, (loss, ))

    # Optional per-step hook on the agent; logs a message when absent.
    getattr(
        agent, 'end_of_training_step_callback',
        lambda: logging.info('end_of_training_step_callback not found'))()

    logger.step_end(logs, training_strategy, iter_frame_ratio)

  # Setup checkpointing and restore checkpoint.
  ckpt = tf.train.Checkpoint(agent=agent, target_agent=target_agent,
                             optimizer=optimizer)
  if FLAGS.init_checkpoint is not None:
    tf.print('Loading initial checkpoint from %s...' % FLAGS.init_checkpoint)
    ckpt.restore(FLAGS.init_checkpoint).assert_consumed()
  manager = tf.train.CheckpointManager(
      ckpt, FLAGS.logdir, max_to_keep=1, keep_checkpoint_every_n_hours=6)
  last_ckpt_time = 0  # Force checkpointing of the initial model.
  # A checkpoint in FLAGS.logdir (e.g. after a restart) takes precedence over
  # the initial checkpoint restored above.
  if manager.latest_checkpoint:
    logging.info('Restoring checkpoint: %s', manager.latest_checkpoint)
    ckpt.restore(manager.latest_checkpoint).assert_consumed()
    last_ckpt_time = time.time()

  # Logging.
  summary_writer = tf.summary.create_file_writer(
      FLAGS.logdir, flush_millis=20000, max_queue=1000)
  logger = utils.ProgressLogger(summary_writer=summary_writer,
                                starting_step=iterations * iter_frame_ratio)

  server = grpc.Server([FLAGS.server_address])

  # Each stored timestep is (previous action, env output, chosen action).
  store = utils.UnrollStore(
      FLAGS.num_envs, FLAGS.her_window_length or FLAGS.unroll_length,
      (action_specs, env_output_specs, action_specs))
  env_run_ids = utils.Aggregator(FLAGS.num_envs,
                                 tf.TensorSpec([], tf.int64, 'run_ids'))
  info_specs = (
      tf.TensorSpec([], tf.int64, 'episode_num_frames'),
      tf.TensorSpec([], tf.float32, 'episode_returns'),
      tf.TensorSpec([], tf.float32, 'episode_raw_returns'),
  )
  env_infos = utils.Aggregator(FLAGS.num_envs, info_specs, 'env_infos')

  # First agent state in an unroll.
  first_agent_states = utils.Aggregator(
      FLAGS.num_envs, agent_state_specs, 'first_agent_states')

  # Current agent state and action.
  agent_states = utils.Aggregator(
      FLAGS.num_envs, agent_state_specs, 'agent_states')
  actions = utils.Aggregator(FLAGS.num_envs, action_specs, 'actions')

  unroll_specs = Unroll(agent_state_specs, *store.unroll_specs)
  unroll_queue = utils.StructuredFIFOQueue(FLAGS.unroll_queue_max_size,
                                           unroll_specs)
  info_queue = utils.StructuredFIFOQueue(-1, info_specs)

  # Hindsight experience replay when a HER window is configured, otherwise a
  # prioritized replay buffer; both use importance_sampling_exponent=0.
  if FLAGS.her_window_length:
    replay_buffer = utils.HindsightExperienceReplay(
        FLAGS.replay_buffer_size,
        unroll_specs,
        compute_reward_fn=env.compute_reward,
        unroll_length=FLAGS.unroll_length,
        importance_sampling_exponent=0.,
        substitution_probability=FLAGS.her_substitution_probability)
  else:
    replay_buffer = utils.PrioritizedReplay(
        FLAGS.replay_buffer_size, unroll_specs,
        importance_sampling_exponent=0.)

  def add_batch_size(ts):
    return tf.TensorSpec([FLAGS.inference_batch_size] + list(ts.shape),
                         ts.dtype, ts.name)

  # Round-robin counter used to pick an inference device per call.
  inference_iteration = tf.Variable(-1, dtype=tf.int64)
  inference_specs = (
      tf.TensorSpec([], tf.int32, 'env_id'),
      tf.TensorSpec([], tf.int64, 'run_id'),
      env_output_specs,
      tf.TensorSpec([], tf.float32, 'raw_reward'),
  )
  inference_specs = tf.nest.map_structure(add_batch_size, inference_specs)

  @tf.function(input_signature=inference_specs)
  def inference(env_ids, run_ids, env_outputs, raw_rewards):
    """Serves one batch of environment steps and returns their actions."""
    # Reset the environment that had their first run or crashed.
    # An env is reset whenever its run id differs from the last one seen.
    previous_run_ids = env_run_ids.read(env_ids)
    env_run_ids.replace(env_ids, run_ids)
    reset_indices = tf.where(tf.not_equal(previous_run_ids, run_ids))[:, 0]
    envs_needing_reset = tf.gather(env_ids, reset_indices)
    if tf.not_equal(tf.shape(envs_needing_reset)[0], 0):
      tf.print('Environment ids needing reset:', envs_needing_reset)
    env_infos.reset(envs_needing_reset)
    store.reset(envs_needing_reset)
    initial_agent_states = agent.initial_state(
        tf.shape(envs_needing_reset)[0])
    first_agent_states.replace(envs_needing_reset, initial_agent_states)
    agent_states.replace(envs_needing_reset, initial_agent_states)
    actions.reset(envs_needing_reset)

    tf.debugging.assert_non_positive(
        tf.cast(env_outputs.abandoned, tf.int32),
        'Abandoned done states are not supported in SAC.')

    # Update steps and return.
    env_infos.add(env_ids, (0, env_outputs.reward, raw_rewards))
    # Episode stats of finished episodes go to `info_queue` for logging.
    done_ids = tf.gather(env_ids, tf.where(env_outputs.done)[:, 0])
    info_queue.enqueue_many(env_infos.read(done_ids))
    env_infos.reset(done_ids)
    env_infos.add(env_ids, (FLAGS.num_action_repeats, 0., 0.))

    # Inference.
    prev_actions = actions.read(env_ids)
    input_ = encode((prev_actions, env_outputs))
    prev_agent_states = agent_states.read(env_ids)
    def make_inference_fn(inference_device):
      def device_specific_inference_fn():
        with tf.device(inference_device):
          @tf.function
          def agent_inference(*args):
            return agent(*decode(args), is_training=False)

          return agent_inference(*input_, prev_agent_states)

      return device_specific_inference_fn

    # Distribute the inference calls among the inference cores.
    # tf.switch_case takes an int32 index, hence the cast from the int64
    # counter.
    branch_index = tf.cast(
        inference_iteration.assign_add(1) % len(inference_devices), tf.int32)
    agent_actions, curr_agent_states = tf.switch_case(branch_index, {
        i: make_inference_fn(inference_device)
        for i, inference_device in enumerate(inference_devices)
    })

    # Append the latest outputs to the unroll and insert completed unrolls in
    # queue.
    completed_ids, unrolls = store.append(
        env_ids, (prev_actions, env_outputs, agent_actions))
    unrolls = Unroll(first_agent_states.read(completed_ids), *unrolls)
    unroll_queue.enqueue_many(unrolls)
    # The state held before this step becomes the first state of the next
    # unroll for the envs that just completed one.
    first_agent_states.replace(completed_ids,
                               agent_states.read(completed_ids))

    # Update current state.
    agent_states.replace(env_ids, curr_agent_states)
    actions.replace(env_ids, agent_actions)

    # Return environment actions to environments.
    return agent_actions

  with strategy.scope():
    server.bind(inference)
  server.start()

  dataset = create_dataset(unroll_queue, replay_buffer, training_strategy,
                           FLAGS.batch_size, encode)
  it = iter(dataset)

  def additional_logs():
    """Extra summaries passed to `logger.start`."""
    tf.summary.scalar('learning_rate', learning_rate_fn(iterations))
    tf.summary.scalar('buffer/unrolls_inserted', replay_buffer.num_inserted)
    # log data from info_queue
    # Only a multiple of FLAGS.log_episode_frequency episodes is dequeued;
    # each group of that size is averaged into a single summary value.
    n_episodes = info_queue.size()
    n_episodes -= n_episodes % FLAGS.log_episode_frequency
    if tf.not_equal(n_episodes, 0):
      episode_stats = info_queue.dequeue_many(n_episodes)
      episode_keys = [
          'episode_num_frames', 'episode_return', 'episode_raw_return'
      ]
      for key, values in zip(episode_keys, episode_stats):
        for value in tf.split(
            values, values.shape[0] // FLAGS.log_episode_frequency):
          tf.summary.scalar(key, tf.reduce_mean(value))

      for (frames, ep_return, raw_return) in zip(*episode_stats):
        logging.info('Return: %f Raw return: %f Frames: %i', ep_return,
                     raw_return, frames)

  logger.start(additional_logs)
  # Execute learning.
  while iterations < final_iteration:
    if iterations.numpy() % FLAGS.update_target_every_n_step == 0:
      update_target_agent(FLAGS.polyak)

    # Save checkpoint.
    current_time = time.time()
    if current_time - last_ckpt_time >= FLAGS.save_checkpoint_secs:
      manager.save()
      # Apart from checkpointing, we also save the full model (including
      # the graph). This way we can load it after the code/parameters changed.
      tf.saved_model.save(agent, os.path.join(FLAGS.logdir, 'saved_model'))
      last_ckpt_time = current_time

    minimize(it)
  logger.shutdown()
  manager.save()
  tf.saved_model.save(agent, os.path.join(FLAGS.logdir, 'saved_model'))
  server.shutdown()
  unroll_queue.close()
def test_full(self):
  """Interleaved actors, a mid-unroll reset and batches of three timesteps."""
  store = utils.UnrollStore(
      num_actors=4, unroll_length=3,
      timestep_specs=tf.TensorSpec([], tf.int32))

  # Each entry is (should_reset, actor_id, value).
  steps = (
      (False, 0, 10),
      (False, 2, 30),
      (False, 1, 20),
      (False, 0, 11),
      (False, 2, 31),
      (False, 3, 40),
      (False, 0, 12),
      (False, 2, 32),
      (False, 3, 41),
      (False, 0, 13),  # Unroll: 10, 11, 12, 13
      (False, 1, 21),
      (True, 2, 33),  # No unroll because of reset
      (False, 0, 14),
      (False, 2, 34),
      (False, 3, 42),
      (False, 0, 15),
      (False, 1, 22),
      (False, 2, 35),
      (False, 0, 16),  # Unroll: 13, 14, 15, 16
      (False, 1, 23),  # Unroll: 20, 21, 22, 23
      (False, 2, 36),  # Unroll: 33, 34, 35, 36
  )

  def gen():
    yield from steps

  dataset = tf.data.Dataset.from_generator(
      gen, (tf.bool, tf.int32, tf.int32), ([], [], []))
  batches = iter(dataset.batch(3, drop_remainder=True))

  # One (expected completed ids, expected unrolls) pair per batch.
  nothing_completed = (tf.zeros([0]), tf.zeros([0, 4]))
  expectations = [
      nothing_completed,
      nothing_completed,
      nothing_completed,
      (tf.constant([0]), tf.constant([[10, 11, 12, 13]])),
      nothing_completed,
      nothing_completed,
      (tf.constant([0, 1, 2]),
       tf.constant([[13, 14, 15, 16], [20, 21, 22, 23], [33, 34, 35, 36]])),
  ]

  for expected_ids, expected_unrolls in expectations:
    should_reset, actor_ids, values = next(batches)
    store.reset(actor_ids[should_reset])
    completed_ids, unrolls = store.append(actor_ids, values)
    self.assertAllEqual(expected_ids, completed_ids)
    self.assertAllEqual(expected_unrolls, unrolls)
def test_overlap_2(self):
  """Unrolls of length 2 that also carry the 2 preceding overlapping steps."""
  store = utils.UnrollStore(
      num_actors=2, unroll_length=2,
      timestep_specs=tf.TensorSpec([], tf.int32),
      num_overlapping_steps=2)

  # Each entry is (should_reset, actor_id, value).
  steps = (
      (False, 0, 10),
      (False, 1, 20),
      (False, 0, 11),
      (False, 1, 21),
      (False, 0, 12),  # Unroll: 0, 0, 10, 11, 12
      (True, 1, 22),
      (False, 0, 13),
      (False, 1, 23),
      (False, 0, 14),  # Unroll: 10, 11, 12, 13, 14
      (False, 1, 24),  # Unroll: 0, 0, 22, 23, 24
      (True, 0, 15),
      (False, 1, 25),
      (False, 0, 16),
      (False, 1, 26),  # Unroll: 22, 23, 24, 25, 26
      (False, 0, 17),  # Unroll: 0, 0, 15, 16, 17
      (False, 1, 27),
  )

  def gen():
    yield from steps

  dataset = tf.data.Dataset.from_generator(
      gen, (tf.bool, tf.int32, tf.int32), ([], [], []))
  batches = iter(dataset.batch(2, drop_remainder=True))

  # One (expected completed ids, expected unrolls or None) pair per batch;
  # None means the unrolls of that batch are not checked.
  nothing_completed = (tf.zeros([0]), None)
  expectations = [
      nothing_completed,
      nothing_completed,
      (tf.constant([0]), tf.constant([[0, 0, 10, 11, 12]])),
      nothing_completed,
      (tf.constant([0, 1]),
       tf.constant([[10, 11, 12, 13, 14], [0, 0, 22, 23, 24]])),
      nothing_completed,
      (tf.constant([1]), tf.constant([[22, 23, 24, 25, 26]])),
      (tf.constant([0]), tf.constant([[0, 0, 15, 16, 17]])),
  ]

  for expected_ids, expected_unrolls in expectations:
    should_reset, actor_ids, values = next(batches)
    store.reset(actor_ids[should_reset])
    completed_ids, unrolls = store.append(actor_ids, values)
    self.assertAllEqual(expected_ids, completed_ids)
    if expected_unrolls is not None:
      self.assertAllEqual(expected_unrolls, unrolls)
def learner_loop(create_env_fn, create_agent_fn, create_optimizer_fn, fps_log):
  """Main learner loop.

  Builds the agent, creates the optimizer, serves batched inference requests
  from actors over gRPC and trains on the unrolls the actors produce until
  `final_iteration` optimizer iterations have been performed. Batch and
  episode summaries are written from a single background logging thread.

  Args:
    create_env_fn: Callable that must return a newly created environment. The
      callable takes the task ID as argument - an arbitrary task ID of 0 will
      be passed by the learner. The returned environment should follow GYM's
      API. It is only used for inferring tensor shapes. This environment will
      not be used to generate experience.
    create_agent_fn: Function that must create a new tf.Module with the neural
      network that outputs actions and new agent state given the environment
      observations and previous agent state. See dmlab.agents.ImpalaDeep for an
      example. The factory function takes as input the environment action and
      observation spaces and a parametric distribution over actions.
    create_optimizer_fn: Function that takes the final iteration as argument
      and must return a tf.keras.optimizers.Optimizer and a
      tf.keras.optimizers.schedules.LearningRateSchedule.
    fps_log: Object exposing a `logger` attribute; the measured number of
      environment frames per second is reported through
      `fps_log.logger.info`.
  """
  logging.info('Starting learner loop')
  validate_config()
  settings = utils.init_learner(FLAGS.num_training_tpus)
  strategy, inference_devices, training_strategy, encode, decode = settings
  env = create_env_fn(0)
  parametric_action_distribution = get_parametric_distribution_for_action_space(
      env.action_space)
  env_output_specs = utils.EnvOutput(
      tf.TensorSpec([], tf.float32, 'reward'),
      tf.TensorSpec([], tf.bool, 'done'),
      tf.TensorSpec(env.observation_space.shape, env.observation_space.dtype,
                    'observation'),
  )
  action_specs = tf.TensorSpec(env.action_space.shape,
                               env.action_space.dtype, 'action')
  agent_input_specs = (action_specs, env_output_specs)

  # Initialize agent and variables.
  agent = create_agent_fn(env.action_space, env.observation_space,
                          parametric_action_distribution)
  initial_agent_state = agent.initial_state(1)
  # Per-example agent state specs: the leading (batch) dim of 1 is dropped.
  agent_state_specs = tf.nest.map_structure(
      lambda t: tf.TensorSpec(t.shape[1:], t.dtype), initial_agent_state)
  # Zero-filled dummy input with a leading dim of 1, used only to trace the
  # agent once so that its variables get created.
  input_ = tf.nest.map_structure(
      lambda s: tf.zeros([1] + list(s.shape), s.dtype), agent_input_specs)
  input_ = encode(input_)

  with strategy.scope():
    @tf.function
    def create_variables(*args):
      return agent(*decode(args))

    initial_agent_output, _ = create_variables(input_, initial_agent_state)

    # Create optimizer.
    # Environment frames consumed per optimizer iteration.
    iter_frame_ratio = (
        FLAGS.batch_size * FLAGS.unroll_length * FLAGS.num_action_repeats)
    final_iteration = int(
        math.ceil(FLAGS.total_environment_frames / iter_frame_ratio))
    optimizer, learning_rate_fn = create_optimizer_fn(final_iteration)

    iterations = optimizer.iterations
    optimizer._create_hypers()
    optimizer._create_slots(agent.trainable_variables)

    # ON_READ causes the replicated variable to act as independent variables for
    # each replica.
    temp_grads = [
        tf.Variable(tf.zeros_like(v), trainable=False,
                    synchronization=tf.VariableSynchronization.ON_READ)
        for v in agent.trainable_variables
    ]

  @tf.function
  def minimize(iterator):
    """Runs one training step and returns the per-replica loss logs.

    Gradients are computed on `training_strategy` into `temp_grads` and then
    applied on `strategy`. `unroll_specs` is defined later in this function
    and is resolved when this tf.function is first traced.
    """
    data = next(iterator)

    def compute_gradients(args):
      args = tf.nest.pack_sequence_as(unroll_specs, decode(args, data))
      with tf.GradientTape() as tape:
        loss, logs = compute_loss(parametric_action_distribution, agent, *args)
      grads = tape.gradient(loss, agent.trainable_variables)
      for t, g in zip(temp_grads, grads):
        t.assign(g)
      return loss, logs

    loss, logs = training_strategy.experimental_run_v2(compute_gradients,
                                                       (data, ))
    loss = training_strategy.experimental_local_results(loss)[0]
    logs = training_strategy.experimental_local_results(logs)

    def apply_gradients(_):
      optimizer.apply_gradients(zip(temp_grads, agent.trainable_variables))

    strategy.experimental_run_v2(apply_gradients, (loss, ))

    return logs

  agent_output_specs = tf.nest.map_structure(
      lambda t: tf.TensorSpec(t.shape[1:], t.dtype), initial_agent_output)
  # Logging.
  summary_writer = tf.summary.create_file_writer(
      FLAGS.logdir, flush_millis=20000, max_queue=1000)

  # Setup checkpointing and restore checkpoint.
  ckpt = tf.train.Checkpoint(agent=agent, optimizer=optimizer)
  if FLAGS.init_checkpoint is not None:
    tf.print('Loading initial checkpoint from %s...' % FLAGS.init_checkpoint)
    ckpt.restore(FLAGS.init_checkpoint).assert_consumed()
  manager = tf.train.CheckpointManager(
      ckpt, FLAGS.logdir, max_to_keep=1, keep_checkpoint_every_n_hours=6)
  last_ckpt_time = 0  # Force checkpointing of the initial model.
  if manager.latest_checkpoint:
    logging.info('Restoring checkpoint: %s', manager.latest_checkpoint)
    ckpt.restore(manager.latest_checkpoint).assert_consumed()
    last_ckpt_time = time.time()

  server = grpc.Server([FLAGS.server_address])

  # Each stored timestep is (previous action, env output, agent output).
  store = utils.UnrollStore(
      FLAGS.num_actors, FLAGS.unroll_length,
      (action_specs, env_output_specs, agent_output_specs))
  actor_run_ids = utils.Aggregator(FLAGS.num_actors,
                                   tf.TensorSpec([], tf.int64, 'run_ids'))
  info_specs = (
      tf.TensorSpec([], tf.int64, 'episode_num_frames'),
      tf.TensorSpec([], tf.float32, 'episode_returns'),
      tf.TensorSpec([], tf.float32, 'episode_raw_returns'),
  )
  actor_infos = utils.Aggregator(FLAGS.num_actors, info_specs, 'actor_infos')

  # First agent state in an unroll.
  first_agent_states = utils.Aggregator(
      FLAGS.num_actors, agent_state_specs, 'first_agent_states')

  # Current agent state and action.
  agent_states = utils.Aggregator(
      FLAGS.num_actors, agent_state_specs, 'agent_states')
  actions = utils.Aggregator(FLAGS.num_actors, action_specs, 'actions')

  unroll_specs = Unroll(agent_state_specs, *store.unroll_specs)
  unroll_queue = utils.StructuredFIFOQueue(1, unroll_specs)
  info_queue = utils.StructuredFIFOQueue(-1, info_specs)

  def add_batch_size(ts):
    return tf.TensorSpec([FLAGS.inference_batch_size] + list(ts.shape),
                         ts.dtype, ts.name)

  # Round-robin counter used to pick an inference device per call. int64 so
  # that it cannot wrap around in long runs; it is cast to the int32 index
  # tf.switch_case expects (same approach as the SAC learner).
  inference_iteration = tf.Variable(-1, dtype=tf.int64)
  inference_specs = (
      tf.TensorSpec([], tf.int32, 'actor_id'),
      tf.TensorSpec([], tf.int64, 'run_id'),
      env_output_specs,
      tf.TensorSpec([], tf.float32, 'raw_reward'),
  )
  inference_specs = tf.nest.map_structure(add_batch_size, inference_specs)

  @tf.function(input_signature=inference_specs)
  def inference(actor_ids, run_ids, env_outputs, raw_rewards):
    """Serves one batch of actor steps and returns post-processed actions."""
    # Reset the actors that had their first run or crashed.
    # An actor is reset whenever its run id differs from the last one seen.
    previous_run_ids = actor_run_ids.read(actor_ids)
    actor_run_ids.replace(actor_ids, run_ids)
    reset_indices = tf.where(tf.not_equal(previous_run_ids, run_ids))[:, 0]
    actors_needing_reset = tf.gather(actor_ids, reset_indices)
    if tf.not_equal(tf.shape(actors_needing_reset)[0], 0):
      tf.print('Actor ids needing reset:', actors_needing_reset)
    actor_infos.reset(actors_needing_reset)
    store.reset(actors_needing_reset)
    initial_agent_states = agent.initial_state(
        tf.shape(actors_needing_reset)[0])
    first_agent_states.replace(actors_needing_reset, initial_agent_states)
    agent_states.replace(actors_needing_reset, initial_agent_states)
    actions.reset(actors_needing_reset)

    # Update steps and return.
    actor_infos.add(actor_ids, (0, env_outputs.reward, raw_rewards))
    # Episode stats of finished episodes go to `info_queue` for logging.
    done_ids = tf.gather(actor_ids, tf.where(env_outputs.done)[:, 0])
    info_queue.enqueue_many(actor_infos.read(done_ids))
    actor_infos.reset(done_ids)
    actor_infos.add(actor_ids, (FLAGS.num_action_repeats, 0., 0.))

    # Inference.
    prev_actions = actions.read(actor_ids)
    input_ = encode((prev_actions, env_outputs))
    prev_agent_states = agent_states.read(actor_ids)

    def make_inference_fn(inference_device):
      def device_specific_inference_fn():
        with tf.device(inference_device):
          @tf.function
          def agent_inference(*args):
            return agent(*decode(args), is_training=True)

          return agent_inference(input_, prev_agent_states)

      return device_specific_inference_fn

    # Distribute the inference calls among the inference cores.
    branch_index = tf.cast(
        inference_iteration.assign_add(1) % len(inference_devices), tf.int32)
    agent_outputs, curr_agent_states = tf.switch_case(branch_index, {
        i: make_inference_fn(inference_device)
        for i, inference_device in enumerate(inference_devices)
    })

    # Append the latest outputs to the unroll and insert completed unrolls in
    # queue.
    completed_ids, unrolls = store.append(
        actor_ids, (prev_actions, env_outputs, agent_outputs))
    unrolls = Unroll(first_agent_states.read(completed_ids), *unrolls)
    unroll_queue.enqueue_many(unrolls)
    # The state held before this step becomes the first state of the next
    # unroll for the actors that just completed one.
    first_agent_states.replace(completed_ids,
                               agent_states.read(completed_ids))

    # Update current state.
    agent_states.replace(actor_ids, curr_agent_states)
    actions.replace(actor_ids, agent_outputs.action)

    # Return environment actions to actors.
    return parametric_action_distribution.postprocess(agent_outputs.action)

  with strategy.scope():
    server.bind(inference, batched=True)
  server.start()

  def dequeue(ctx):
    # Create batch (time major).
    actor_outputs = tf.nest.map_structure(lambda *args: tf.stack(args), *[
        unroll_queue.dequeue()
        for _ in range(ctx.get_per_replica_batch_size(FLAGS.batch_size))
    ])
    actor_outputs = actor_outputs._replace(
        prev_actions=utils.make_time_major(actor_outputs.prev_actions),
        env_outputs=utils.make_time_major(actor_outputs.env_outputs),
        agent_outputs=utils.make_time_major(actor_outputs.agent_outputs))
    actor_outputs = actor_outputs._replace(
        env_outputs=encode(actor_outputs.env_outputs))
    # tf.data.Dataset treats list leafs as tensors, so we need to flatten and
    # repack.
    return tf.nest.flatten(actor_outputs)

  def dataset_fn(ctx):
    dataset = tf.data.Dataset.from_tensors(0).repeat(None)
    return dataset.map(lambda _: dequeue(ctx),
                       num_parallel_calls=ctx.num_replicas_in_sync)

  dataset = training_strategy.experimental_distribute_datasets_from_function(
      dataset_fn)
  it = iter(dataset)

  # Execute learning and track performance.
  with summary_writer.as_default(), \
      concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
    log_future = executor.submit(lambda: None)  # No-op future.
    last_num_env_frames = iterations * iter_frame_ratio
    last_log_time = time.time()
    # Per-key loss values accumulated since the last batch summary. Only the
    # main thread reads or mutates this dict; the logging thread receives a
    # detached snapshot instead.
    values_to_log = collections.defaultdict(list)

    def log(iteration, num_env_frames, batch_values):
      """Logs batch and episodes summaries (runs on the executor thread).

      Args:
        iteration: Snapshot of the optimizer iteration count, taken on the
          main thread right after the corresponding training step.
        num_env_frames: Environment frames consumed before that training
          step; used as the summary step.
        batch_values: Dict mapping log keys to the values accumulated since
          the previous batch summary, or None when no batch summary is due.
      """
      nonlocal last_num_env_frames, last_log_time
      summary_writer.set_as_default()
      tf.summary.experimental.set_step(num_env_frames)

      # log data from the current minibatch
      if batch_values is not None:
        for key, values in batch_values.items():
          tf.summary.scalar(key, tf.reduce_mean(values))
        tf.summary.scalar('learning_rate', learning_rate_fn(iteration))

      # log the number of frames per second
      dt = time.time() - last_log_time
      if dt > 60:
        df = tf.cast(num_env_frames - last_num_env_frames, tf.float32)
        tf.summary.scalar('num_environment_frames/sec', df / dt)
        fps_log.logger.info('FPS: %f', df / dt)
        print(f'FPS: {df / dt}')
        last_num_env_frames, last_log_time = num_env_frames, time.time()

      # log data from info_queue
      # Only a multiple of FLAGS.log_episode_frequency episodes is dequeued;
      # each group of that size is averaged into a single summary value.
      n_episodes = info_queue.size()
      n_episodes -= n_episodes % FLAGS.log_episode_frequency
      if tf.not_equal(n_episodes, 0):
        episode_stats = info_queue.dequeue_many(n_episodes)
        episode_keys = [
            'episode_num_frames', 'episode_return', 'episode_raw_return'
        ]
        for key, values in zip(episode_keys, episode_stats):
          for value in tf.split(
              values, values.shape[0] // FLAGS.log_episode_frequency):
            tf.summary.scalar(key, tf.reduce_mean(value))

        for (frames, ep_return, raw_return) in zip(*episode_stats):
          logging.info('Return: %f Raw return: %f Frames: %i', ep_return,
                       raw_return, frames)

    while iterations < final_iteration:
      num_env_frames = iterations * iter_frame_ratio
      tf.summary.experimental.set_step(num_env_frames)

      # Save checkpoint.
      current_time = time.time()
      if current_time - last_ckpt_time >= FLAGS.save_checkpoint_secs:
        manager.save()
        # Apart from checkpointing, we also save the full model (including
        # the graph). This way we can load it after the code/parameters changed.
        tf.saved_model.save(agent, os.path.join(FLAGS.logdir, 'saved_model'))
        last_ckpt_time = current_time

      logs = minimize(it)

      # NOTE(review): `log_keys` is not defined in this function; presumably
      # a module-level list of names filled in by `compute_loss` — verify.
      for per_replica_logs in logs:
        assert len(log_keys) == len(per_replica_logs)
        for key, value in zip(log_keys, per_replica_logs):
          values_to_log[key].extend(
              x.numpy()
              for x in training_strategy.experimental_local_results(value))

      log_future.result()  # Raise exception if any occurred in logging.
      # Snapshot the step and detach the accumulated batch values here, on
      # the main thread. `iterations` keeps advancing and `values_to_log`
      # keeps growing while the background log call runs, so passing either
      # one directly would race with the next training step.
      iteration = iterations.numpy()
      batch_values = None
      if iteration % FLAGS.log_batch_frequency == 0:
        batch_values = dict(values_to_log)
        values_to_log.clear()
      log_future = executor.submit(log, iteration, num_env_frames,
                                   batch_values)

    # Surface any exception raised by the final logging call; leaving the
    # executor context would otherwise wait for it and drop the error.
    log_future.result()

  manager.save()
  tf.saved_model.save(agent, os.path.join(FLAGS.logdir, 'saved_model'))
  server.shutdown()
  unroll_queue.close()
def learner_loop(create_env_fn, create_agent_fn, create_optimizer_fn):
  """Main learner loop.

  Builds the online and target agents, the optimizer, checkpointing, the gRPC
  inference server, the unroll queue and the prioritized replay buffer, then
  runs training steps until the optimizer's iteration counter reaches
  `final_iteration` (derived from `FLAGS.total_environment_frames`). A final
  checkpoint is saved after the loop, before the server is shut down.

  Args:
    create_env_fn: Callable that must return a newly created environment. The
      callable takes the task ID as argument - an arbitrary task ID of 0 will be
      passed by the learner. The returned environment should follow GYM's API;
      its `action_space.n` is used as the number of actions. It is only used
      for inferring tensor shapes. This environment will not be used to
      generate experience.
    create_agent_fn: Function that must create a new tf.Module with the neural
      network that outputs actions, Q values and new agent state given the
      environment observations and previous agent state. See
      atari.agents.DuelingLSTMDQNNet for an example. The factory function takes
      as input the environment output specs and the number of possible actions
      in the env. It is called twice: once for the online agent and once for
      the target agent.
    create_optimizer_fn: Function that takes the final iteration as argument
      and must return a tf.keras.optimizers.Optimizer and a
      tf.keras.optimizers.schedules.LearningRateSchedule.
  """
  logging.info('Starting learner loop')
  validate_config()
  settings = utils.init_learner(FLAGS.num_training_tpus)
  strategy, inference_devices, training_strategy, encode, decode = settings
  env = create_env_fn(0)
  env_output_specs = utils.EnvOutput(
      tf.TensorSpec([], tf.float32, 'reward'),
      tf.TensorSpec([], tf.bool, 'done'),
      tf.TensorSpec(env.observation_space.shape, env.observation_space.dtype,
                    'observation'),
      tf.TensorSpec([], tf.bool, 'abandoned'),
      tf.TensorSpec([], tf.int32, 'episode_step'),
  )
  # Actions are scalar int32 values; `num_actions` comes from the env.
  action_specs = tf.TensorSpec([], tf.int32, 'action')
  num_actions = env.action_space.n
  agent_input_specs = (action_specs, env_output_specs)

  # Initialize agent and variables.
  agent = create_agent_fn(env_output_specs, num_actions)
  target_agent = create_agent_fn(env_output_specs, num_actions)
  initial_agent_state = agent.initial_state(1)
  # Per-environment state specs: the leading batch dimension (of size 1) is
  # dropped.
  agent_state_specs = tf.nest.map_structure(
      lambda t: tf.TensorSpec(t.shape[1:], t.dtype), initial_agent_state)
  # All-zeros batch of size 1, only used to build the agents' variables.
  input_ = tf.nest.map_structure(
      lambda s: tf.zeros([1] + list(s.shape), s.dtype), agent_input_specs)
  input_ = encode(input_)

  with strategy.scope():

    @tf.function
    def create_variables(*args):
      return agent(*decode(args))

    @tf.function
    def create_target_agent_variables(*args):
      return target_agent(*decode(args))

    # The first call to Keras models to create variables for agent and target.
    initial_agent_output, _ = create_variables(input_, initial_agent_state)
    create_target_agent_variables(input_, initial_agent_state)

    @tf.function
    def update_target_agent():
      """Synchronizes training and target agent variables.

      Copies every trainable variable of `agent` into the trainable variable
      at the same position in `target_agent`.
      """
      variables = agent.trainable_variables
      target_variables = target_agent.trainable_variables
      assert len(target_variables) == len(variables), (
          'Mismatch in number of net tensors: {} != {}'.format(
              len(target_variables), len(variables)))
      for target_var, source_var in zip(target_variables, variables):
        target_var.assign(source_var)

    # Create optimizer.
    # Number of environment frames accounted to one training iteration. It is
    # used both to size the run and to convert iterations into frame counts.
    iter_frame_ratio = (get_replay_insertion_batch_size() *
                        FLAGS.unroll_length * FLAGS.num_action_repeats)
    final_iteration = int(
        math.ceil(FLAGS.total_environment_frames / iter_frame_ratio))
    optimizer, learning_rate_fn = create_optimizer_fn(final_iteration)

    iterations = optimizer.iterations
    # NOTE(review): these are private Keras optimizer APIs. They are presumably
    # used to create hyperparameter and slot variables eagerly inside
    # strategy.scope() before `minimize` is traced. Confirm against the pinned
    # TF version.
    optimizer._create_hypers()
    optimizer._create_slots(agent.trainable_variables)

    # ON_READ causes the replicated variable to act as independent variables for
    # each replica.
    # compute_gradients (inside minimize) writes the gradients here, and
    # apply_gradients then applies them to agent.trainable_variables.
    temp_grads = [
        tf.Variable(tf.zeros_like(v), trainable=False,
                    synchronization=tf.VariableSynchronization.ON_READ)
        for v in agent.trainable_variables
    ]

  @tf.function
  def minimize(iterator):
    """Computes and applies gradients.

    Gradients are computed under `training_strategy`, staged in `temp_grads`
    and applied under `strategy`.

    Args:
      iterator: An iterator of distributed dataset that produces `PerReplica`.

    Returns:
      A tuple:
        - loss, the value returned by `strategy.run(apply_gradients, ...)`,
          i.e. the first local training replica's importance-weighted loss.
        - priorities, the new priorities. Shape <float32>[batch_size].
        - indices, the indices for updating priorities. Shape
          <int32>[batch_size].
        - gradient_norm_before_clip, a scalar (the max over replicas when the
          per-replica results are `PerReplica`).
    """
    data = next(iterator)

    def compute_gradients(args):
      """A function to pass to `Strategy` for gradient computation."""
      args = decode(args, data)
      # `unroll_specs` is assigned later in learner_loop. This closure only
      # reads it when `minimize` is traced, which happens in the training loop.
      args = tf.nest.pack_sequence_as(SampledUnrolls(unroll_specs, 0, 0), args)
      with tf.GradientTape() as tape:
        # loss: [batch_size]
        # priorities: [batch_size]
        loss, priorities = compute_loss_and_priorities(
            agent,
            target_agent,
            args.unrolls.agent_state,
            args.unrolls.prev_actions,
            args.unrolls.env_outputs,
            args.unrolls.agent_outputs,
            gamma=FLAGS.discounting,
            burn_in=FLAGS.burn_in)
        # Weight each sample by its importance weight from prioritized replay.
        loss = tf.reduce_mean(loss * args.importance_weights)
      grads = tape.gradient(loss, agent.trainable_variables)
      gradient_norm_before_clip = tf.linalg.global_norm(grads)
      if FLAGS.clip_norm:
        grads, _ = tf.clip_by_global_norm(
            grads, FLAGS.clip_norm, use_norm=gradient_norm_before_clip)

      for t, g in zip(temp_grads, grads):
        t.assign(g)

      return loss, priorities, args.indices, gradient_norm_before_clip

    loss, priorities, indices, gradient_norm_before_clip = (
        training_strategy.run(compute_gradients, (data,)))
    # Only the first local replica's loss is kept.
    loss = training_strategy.experimental_local_results(loss)[0]

    def apply_gradients(loss):
      optimizer.apply_gradients(zip(temp_grads, agent.trainable_variables))
      return loss

    loss = strategy.run(apply_gradients, (loss,))

    # convert PerReplica to a Tensor
    if not isinstance(priorities, tf.Tensor):
      priorities = tf.reshape(tf.stack(priorities.values), [-1])
      indices = tf.reshape(tf.stack(indices.values), [-1])
      gradient_norm_before_clip = tf.reshape(
          tf.stack(gradient_norm_before_clip.values), [-1])
      gradient_norm_before_clip = tf.reduce_max(gradient_norm_before_clip)

    return loss, priorities, indices, gradient_norm_before_clip

  agent_output_specs = tf.nest.map_structure(
      lambda t: tf.TensorSpec(t.shape[1:], t.dtype), initial_agent_output)
  # Logging.
  summary_writer = tf.summary.create_file_writer(
      FLAGS.logdir, flush_millis=20000, max_queue=1000)

  # Set up checkpointing and restore checkpoint.
  ckpt = tf.train.Checkpoint(
      agent=agent, target_agent=target_agent, optimizer=optimizer)
  manager = tf.train.CheckpointManager(
      ckpt, FLAGS.logdir, max_to_keep=1, keep_checkpoint_every_n_hours=6)
  last_ckpt_time = 0  # Force checkpointing of the initial model.
  if manager.latest_checkpoint:
    logging.info('Restoring checkpoint: %s', manager.latest_checkpoint)
    ckpt.restore(manager.latest_checkpoint).assert_consumed()
    last_ckpt_time = time.time()

  server = grpc.Server([FLAGS.server_address])

  # Buffer of incomplete unrolls. Filled during inference with new transitions.
  # This only contains data from training environments.
  store = utils.UnrollStore(
      get_num_training_envs(), FLAGS.unroll_length,
      (action_specs, env_output_specs, agent_output_specs),
      num_overlapping_steps=FLAGS.burn_in)
  env_run_ids = utils.Aggregator(FLAGS.num_envs,
                                 tf.TensorSpec([], tf.int64, 'run_ids'))
  info_specs = (
      tf.TensorSpec([], tf.int64, 'episode_num_frames'),
      tf.TensorSpec([], tf.float32, 'episode_returns'),
      tf.TensorSpec([], tf.float32, 'episode_raw_returns'),
  )
  env_infos = utils.Aggregator(FLAGS.num_envs, info_specs, 'env_infos')

  # First agent state in an unroll.
  first_agent_states = utils.Aggregator(
      FLAGS.num_envs, agent_state_specs, 'first_agent_states')

  # Current agent state and action.
  agent_states = utils.Aggregator(
      FLAGS.num_envs, agent_state_specs, 'agent_states')
  actions = utils.Aggregator(FLAGS.num_envs, action_specs, 'actions')

  unroll_specs = Unroll(agent_state_specs,
                        tf.TensorSpec([], tf.float32, 'priority'),
                        *store.unroll_specs)
  # Queue of complete unrolls. Filled by the inference threads, and consumed by
  # the tf.data.Dataset thread.
  unroll_queue = utils.StructuredFIFOQueue(FLAGS.unroll_queue_max_size,
                                           unroll_specs)
  episode_info_specs = EpisodeInfo(*(
      info_specs + (tf.TensorSpec([], tf.int32, 'env_ids'),)))
  # Finished-episode statistics: enqueued by `inference`, drained by `log`.
  info_queue = utils.StructuredFIFOQueue(-1, episode_info_specs)

  replay_buffer = utils.PrioritizedReplay(FLAGS.replay_buffer_size,
                                          unroll_specs,
                                          FLAGS.importance_sampling_exponent)

  def add_batch_size(ts):
    return tf.TensorSpec([FLAGS.inference_batch_size] + list(ts.shape),
                         ts.dtype, ts.name)

  # Incremented on every inference call. It is used to pick the inference
  # device in round-robin order.
  inference_iteration = tf.Variable(-1, dtype=tf.int64)
  inference_specs = (
      tf.TensorSpec([], tf.int32, 'env_id'),
      tf.TensorSpec([], tf.int64, 'run_id'),
      env_output_specs,
      tf.TensorSpec([], tf.float32, 'raw_reward'),
  )
  inference_specs = tf.nest.map_structure(add_batch_size, inference_specs)

  @tf.function(input_signature=inference_specs)
  def inference(env_ids, run_ids, env_outputs, raw_rewards):
    """Agent inference.

    This evaluates the agent policy on the provided environment data (reward,
    done, observation), and stores appropriate data to feed the main training
    loop.

    Args:
      env_ids: <int32>[inference_batch_size], the environment task IDs (in range
        [0, num_tasks)).
      run_ids: <int64>[inference_batch_size], the environment run IDs.
        Environment generates a random int64 run id at startup, so this can be
        used to detect the environment jobs that restarted.
      env_outputs: Follows env_output_specs, but with the inference_batch_size
        added as first dimension. These are the actual environment outputs
        (reward, done, observation).
      raw_rewards: <float32>[inference_batch_size], representing the raw reward
        of each step.

    Returns:
      A tensor <int32>[inference_batch_size] with one action for each
      environment.
    """
    # Reset the environments that had their first run or crashed.
    previous_run_ids = env_run_ids.read(env_ids)
    env_run_ids.replace(env_ids, run_ids)
    reset_indices = tf.where(tf.not_equal(previous_run_ids, run_ids))[:, 0]
    envs_needing_reset = tf.gather(env_ids, reset_indices)
    if tf.not_equal(tf.shape(envs_needing_reset)[0], 0):
      tf.print('Environments needing reset:', envs_needing_reset)
    env_infos.reset(envs_needing_reset)
    # Only training environments are reset in `store`.
    store.reset(
        tf.gather(envs_needing_reset,
                  tf.where(is_training_env(envs_needing_reset))[:, 0]))
    initial_agent_states = agent.initial_state(
        tf.shape(envs_needing_reset)[0])
    first_agent_states.replace(envs_needing_reset, initial_agent_states)
    agent_states.replace(envs_needing_reset, initial_agent_states)
    actions.reset(envs_needing_reset)

    tf.debugging.assert_non_positive(
        tf.cast(env_outputs.abandoned, tf.int32),
        'Abandoned done states are not supported in R2D2.')

    # Update steps and return.
    # Rewards are added before finished episodes are flushed to info_queue.
    # The frame count is added after the reset below.
    env_infos.add(env_ids, (0, env_outputs.reward, raw_rewards))
    done_ids = tf.gather(env_ids, tf.where(env_outputs.done)[:, 0])
    done_episodes_info = env_infos.read(done_ids)
    info_queue.enqueue_many(EpisodeInfo(*(done_episodes_info + (done_ids,))))
    env_infos.reset(done_ids)
    env_infos.add(env_ids, (FLAGS.num_action_repeats, 0., 0.))

    # Inference.
    prev_actions = actions.read(env_ids)
    input_ = encode((prev_actions, env_outputs))
    prev_agent_states = agent_states.read(env_ids)

    def make_inference_fn(inference_device):
      def device_specific_inference_fn():
        with tf.device(inference_device):
          @tf.function
          def agent_inference(*args):
            return agent(*decode(args))

          return agent_inference(input_, prev_agent_states)

      return device_specific_inference_fn

    # Distribute the inference calls among the inference cores.
    branch_index = tf.cast(
        inference_iteration.assign_add(1) % len(inference_devices), tf.int32)
    agent_outputs, curr_agent_states = tf.switch_case(branch_index, {
        i: make_inference_fn(inference_device)
        for i, inference_device in enumerate(inference_devices)
    })

    agent_outputs = agent_outputs._replace(
        action=apply_epsilon_greedy(
            agent_outputs.action, env_ids,
            get_num_training_envs(),
            FLAGS.num_eval_envs, FLAGS.eval_epsilon, num_actions))

    # Append the latest outputs to the unroll, only for experience coming from
    # training environments (IDs < num_training_envs), and insert completed
    # unrolls in queue.
    # <int64>[number of training environments in this batch]
    training_indices = tf.where(is_training_env(env_ids))[:, 0]
    training_env_ids = tf.gather(env_ids, training_indices)
    training_prev_actions, training_env_outputs, training_agent_outputs = (
        tf.nest.map_structure(lambda s: tf.gather(s, training_indices),
                              (prev_actions, env_outputs, agent_outputs)))

    append_to_store = (
        training_prev_actions, training_env_outputs, training_agent_outputs)
    completed_ids, completed_unrolls = store.append(
        training_env_ids, append_to_store)
    _, unrolled_env_outputs, unrolled_agent_outputs = completed_unrolls
    unrolled_agent_states = first_agent_states.read(completed_ids)

    # Only use the suffix of the unrolls that is actually used for training. The
    # prefix is only used for burn-in of agent state at training time.
    _, agent_outputs_suffix = utils.split_structure(
        utils.make_time_major(unrolled_agent_outputs), FLAGS.burn_in)
    _, env_outputs_suffix = utils.split_structure(
        utils.make_time_major(unrolled_env_outputs), FLAGS.burn_in)
    _, initial_priorities = compute_loss_and_priorities_from_agent_outputs(
        # We don't use the outputs from a separated target network for computing
        # initial priorities.
        agent_outputs_suffix,
        agent_outputs_suffix,
        env_outputs_suffix,
        agent_outputs_suffix,
        gamma=FLAGS.discounting)

    unrolls = Unroll(unrolled_agent_states, initial_priorities,
                     *completed_unrolls)
    unroll_queue.enqueue_many(unrolls)
    # agent_states has not been updated for this step yet, so this reads the
    # state that was fed to the agent at this step.
    # NOTE(review): confirm this is the state matching the first step that
    # UnrollStore keeps for the next unroll, given
    # num_overlapping_steps=FLAGS.burn_in.
    first_agent_states.replace(completed_ids,
                               agent_states.read(completed_ids))

    # Update current state.
    agent_states.replace(env_ids, curr_agent_states)
    actions.replace(env_ids, agent_outputs.action)

    # Return environment actions to environments.
    return agent_outputs.action

  with strategy.scope():
    server.bind(inference)
  server.start()

  # Execute learning and track performance.
  with summary_writer.as_default(), \
      concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
    log_future = executor.submit(lambda: None)  # No-op future.
    tf.summary.experimental.set_step(iterations * iter_frame_ratio)
    dataset = create_dataset(unroll_queue, replay_buffer,
                             training_strategy, FLAGS.batch_size,
                             FLAGS.priority_exponent, encode)
    it = iter(dataset)

    last_num_env_frames = iterations * iter_frame_ratio
    last_log_time = time.time()
    max_gradient_norm_before_clip = 0.
    while iterations < final_iteration:
      num_env_frames = iterations * iter_frame_ratio
      tf.summary.experimental.set_step(num_env_frames)

      if iterations.numpy() % FLAGS.update_target_every_n_step == 0:
        update_target_agent()

      # Save checkpoint.
      current_time = time.time()
      if current_time - last_ckpt_time >= FLAGS.save_checkpoint_secs:
        manager.save()
        last_ckpt_time = current_time

      def log(num_env_frames):
        """Logs environment summaries."""
        summary_writer.set_as_default()
        tf.summary.experimental.set_step(num_env_frames)
        episode_info = info_queue.dequeue_many(info_queue.size())
        for n, r, _, env_id in zip(*episode_info):
          is_training = is_training_env(env_id)
          logging.info(
              'Return: %f Frames: %i Env id: %i (%s) Iteration: %i',
              r, n, env_id,
              'training' if is_training else 'eval',
              iterations.numpy())
          # Only evaluation episodes are exported as summaries.
          if not is_training:
            tf.summary.scalar('eval/episode_return', r)
            tf.summary.scalar('eval/episode_frames', n)

      # The executor has a single worker, and the previous job is awaited
      # before a new one is submitted, so at most one log job is in flight.
      log_future.result()  # Raise exception if any occurred in logging.
      log_future = executor.submit(log, num_env_frames)

      _, priorities, indices, gradient_norm = minimize(it)

      replay_buffer.update_priorities(indices, priorities)
      # Max of gradient norms (before clipping) since last tf.summary export.
      max_gradient_norm_before_clip = max(gradient_norm.numpy(),
                                          max_gradient_norm_before_clip)
      # Export throughput and replay statistics roughly every 120 seconds.
      if current_time - last_log_time >= 120:
        df = tf.cast(num_env_frames - last_num_env_frames, tf.float32)
        dt = time.time() - last_log_time
        tf.summary.scalar('num_environment_frames/sec (actors)', df / dt)
        tf.summary.scalar('num_environment_frames/sec (learner)',
                          df / dt * FLAGS.replay_ratio)

        tf.summary.scalar('learning_rate', learning_rate_fn(iterations))
        tf.summary.scalar('replay_buffer_num_inserted',
                          replay_buffer.num_inserted)
        tf.summary.scalar('unroll_queue_size', unroll_queue.size())

        last_num_env_frames, last_log_time = num_env_frames, time.time()
        tf.summary.histogram('updated_priorities', priorities)
        tf.summary.scalar('max_gradient_norm_before_clip',
                          max_gradient_norm_before_clip)
        max_gradient_norm_before_clip = 0.

  manager.save()
  server.shutdown()
  unroll_queue.close()
def create_host(i, host, inference_devices):
  """Builds and starts the inference-serving pipeline for one host.

  Under `tf.device(host)` this creates a gRPC server, a per-env unroll store,
  the per-env aggregators (run ids, episode statistics, unroll start states,
  current states, last actions) and a completed-unroll queue. It overwrites
  `unroll_specs[0]` with this host's unroll specs. One `inference` tf.function
  is bound per entry of `inference_devices`, the server is started, and the
  queue and server are appended to `unroll_queues` / `servers`.

  Args:
    i: Index of this host. Only host 0 enqueues finished-episode statistics
      into `info_queue`.
    host: Device passed to `tf.device` for all host-side state.
    inference_devices: Devices on which the agent forward pass is placed. One
      bound inference function is created per device.
  """
  log_episodes = i == 0

  with tf.device(host):
    grpc_server = grpc.Server([FLAGS.server_address])

    unroll_store = utils.UnrollStore(
        FLAGS.num_envs, FLAGS.unroll_length,
        (action_specs, env_output_specs, agent_output_specs))
    run_id_table = utils.Aggregator(
        FLAGS.num_envs, tf.TensorSpec([], tf.int64, 'run_ids'))
    episode_stats = utils.Aggregator(FLAGS.num_envs, info_specs, 'env_infos')
    # Agent state at the start of each env's in-progress unroll.
    unroll_start_states = utils.Aggregator(
        FLAGS.num_envs, agent_state_specs, 'first_agent_states')
    # Latest agent state and latest action of each env.
    current_states = utils.Aggregator(
        FLAGS.num_envs, agent_state_specs, 'agent_states')
    last_actions = utils.Aggregator(FLAGS.num_envs, action_specs, 'actions')

    unroll_specs[0] = Unroll(agent_state_specs, *unroll_store.unroll_specs)
    completed_queue = utils.StructuredFIFOQueue(1, unroll_specs[0])

    # Every inference argument carries a leading inference_batch_size axis.
    batched_specs = tf.nest.map_structure(
        lambda ts: tf.TensorSpec(
            [FLAGS.inference_batch_size] + list(ts.shape), ts.dtype, ts.name),
        (
            tf.TensorSpec([], tf.int32, 'env_id'),
            tf.TensorSpec([], tf.int64, 'run_id'),
            env_output_specs,
            tf.TensorSpec([], tf.float32, 'raw_reward'),
        ))

    def build_inference(inference_device):
      """Returns an `inference` tf.function whose forward pass runs on
      `inference_device`."""

      # The name `inference` is kept as-is: presumably the gRPC server exposes
      # the bound function under this name — confirm before renaming.
      @tf.function(input_signature=batched_specs)
      def inference(env_ids, run_ids, env_outputs, raw_rewards):
        # An env whose run id changed is new or was restarted; wipe its state.
        old_run_ids = run_id_table.read(env_ids)
        run_id_table.replace(env_ids, run_ids)
        run_changed = tf.not_equal(old_run_ids, run_ids)
        reset_ids = tf.gather(env_ids, tf.where(run_changed)[:, 0])
        num_resets = tf.shape(reset_ids)[0]
        if tf.not_equal(num_resets, 0):
          tf.print('Environment ids needing reset:', reset_ids)
        episode_stats.reset(reset_ids)
        unroll_store.reset(reset_ids)
        fresh_states = agent.initial_state(num_resets)
        unroll_start_states.replace(reset_ids, fresh_states)
        current_states.replace(reset_ids, fresh_states)
        last_actions.reset(reset_ids)

        tf.debugging.assert_non_positive(
            tf.cast(env_outputs.abandoned, tf.int32),
            'Abandoned done states are not supported in VTRACE.')

        # Add this step's rewards, flush finished episodes, then count frames.
        episode_stats.add(env_ids, (0, env_outputs.reward, raw_rewards))
        finished_ids = tf.gather(env_ids, tf.where(env_outputs.done)[:, 0])
        if log_episodes:
          info_queue.enqueue_many(episode_stats.read(finished_ids))
        episode_stats.reset(finished_ids)
        episode_stats.add(env_ids, (FLAGS.num_action_repeats, 0., 0.))

        # Agent forward pass, placed on this function's inference device.
        previous_actions = last_actions.read(env_ids)
        encoded_input = encode((previous_actions, env_outputs))
        previous_states = current_states.read(env_ids)
        with tf.device(inference_device):

          @tf.function
          def agent_inference(*args):
            return agent(*decode(args), is_training=False)

          outputs, next_states = agent_inference(*encoded_input,
                                                 previous_states)

        # Extend the unrolls. Completed ones are paired with their start state
        # and queued.
        ready_ids, ready_unrolls = unroll_store.append(
            env_ids, (previous_actions, env_outputs, outputs))
        queued_unrolls = Unroll(unroll_start_states.read(ready_ids),
                                *ready_unrolls)
        completed_queue.enqueue_many(queued_unrolls)
        # current_states still holds the state fed to the agent at this step;
        # it becomes the start state of the next unroll for these envs.
        unroll_start_states.replace(ready_ids,
                                    current_states.read(ready_ids))

        # Commit the post-step state and the chosen actions.
        current_states.replace(env_ids, next_states)
        last_actions.replace(env_ids, outputs.action)
        return outputs.action

      return inference

    with strategy.scope():
      bound_fns = [build_inference(device) for device in inference_devices]
      grpc_server.bind(bound_fns)
    grpc_server.start()
    unroll_queues.append(completed_queue)
    servers.append(grpc_server)