def _unroll_neck_steps(self, env_output, initial_state):
  """Runs the neck over every timestep of a time-major episode.

  Args:
    env_output: A 4-tuple `(reward, done, observation, info)`; only `done` and
      `observation` are used. Axis 0 of `done` is recorded as the number of
      timesteps and axis 1 as the batch size.
    initial_state: Neck state fed into the first timestep.

  Returns:
    A tuple `(outputs, final_state)`: `outputs` is a Python list with the neck
    output of each timestep in order, and `final_state` is the neck state
    produced by the last timestep.
  """
  _, done, observation, _ = env_output
  # Store the dynamic number of timesteps and batch size on the agent.
  done_shape = tf.shape(done)
  self._current_num_timesteps = done_shape[0]
  self._current_batch_size = done_shape[1]

  torso_output = utils.batch_apply(self._torso, observation)
  # shape: [num_timesteps, batch_size, ...], where the trailing dimensions are
  # same as trailing dimensions of `neck_state`.
  reset_states = self._get_reset_state(observation, done, initial_state)

  state = initial_state
  outputs = []
  for step_idx, step_done in enumerate(tf.unstack(done)):
    step_input = utils.get_row_nested_tensor(torso_output, step_idx)
    step_reset = utils.get_row_nested_tensor(reset_states, step_idx)
    # Where the episode ended, swap in the reset state before the neck runs on
    # this step. `step_done` is one row of `done`; the v1 `where` selects whole
    # leading-axis slices when given such a vector condition.
    state = tf.nest.map_structure(
        lambda fresh, current: tf.compat.v1.where(step_done, fresh, current),
        step_reset, state)
    step_output, state = self._neck(step_input, state)
    outputs.append(step_output)
  return outputs, state
def call(self, env_output, neck_state):
  """Runs the entire episode given time-major tensors.

  The per-timestep neck unroll (including resetting the neck state wherever
  `done` is set) is delegated to `_unroll_neck_steps` so that logic lives in
  one place; the head is then applied to all timesteps at once.

  Args:
    env_output: An `EnvOutput` tuple with following expectations:
      reward - Unused
      done - A boolean tensor of shape [num_timesteps, batch_size].
      observation - A nested structure with individual tensors that have first
        two dimensions equal to [num_timesteps, batch_size]
      info - Unused
    neck_state: A tensor or nested structure with individual tensors that have
      first dimension equal to batch_size and no time dimension.

  Returns:
    A tuple `(head_output, neck_state)`:
      head_output - An `AgentOutput` tuple with individual tensors that have
        first two dimensions equal to [num_timesteps, batch_size].
      neck_state - The neck state after the last timestep.
  """
  neck_output_list, neck_state = self._unroll_neck_steps(env_output, neck_state)
  # `_unroll_neck_steps` yields one output per timestep; stack them along a new
  # leading (time) axis so the head can be batch-applied over all timesteps.
  head_input = tf.nest.map_structure(lambda *tensors: tf.stack(tensors),
                                     *neck_output_list)
  head_output = utils.batch_apply(self._head, head_input)
  assert isinstance(head_output, common.AgentOutput)
  return head_output, neck_state
def testGetRowNestedTensor(self):
  """Checks that `get_row_nested_tensor` takes row 1 of every nested leaf."""
  nested = {
      'a': tf.constant([[0., 0.], [1., 1.]]),
      'b': {'b_1': tf.ones(shape=(2, 3))},
  }
  row = utils.get_row_nested_tensor(nested, 1)
  # Each pair is (expected row, extracted leaf); nesting must be preserved.
  checks = [
      (np.array([1., 1.]), row['a']),
      (np.array([1., 1., 1.]), row['b']['b_1']),
  ]
  for expected, actual in checks:
    np.testing.assert_array_almost_equal(expected, actual.numpy())
def run_with_address(
    problem_type: framework_problem_type.ProblemType,
    listen_address: Text,
    hparams: Dict[Text, Any],
):
  """Runs the learner with the given problem type.

  Builds the agent and optimizer, starts a gRPC server that feeds actor
  outputs into a FIFO queue, and trains on batches dequeued from it until
  `hparams['final_iteration']` optimizer iterations have run. Checkpoints and
  TensorBoard summaries are written under `hparams['logdir']`.

  Args:
    problem_type: An instance of `framework_problem_type.ProblemType`.
    listen_address: The network address on which to listen.
    hparams: A dict containing hyperparameter settings. This function reads
      the keys 'logdir', 'iter_frame_ratio' and 'final_iteration'.

  Raises:
    ValueError: If no local device of type `FLAGS.agent_device` exists.
  """
  devices = device_lib.list_local_devices()
  logging.info('Found devices: %s', devices)
  devices = [d for d in devices if d.device_type == FLAGS.agent_device]
  # Raise rather than assert: asserts are stripped under `python -O`, which
  # would turn this into an opaque IndexError on `devices[0]`.
  if not devices:
    raise ValueError(
        'Could not find a device of type %s' % FLAGS.agent_device)
  agent_device = devices[0].name
  logging.info('Using agent device: %s', agent_device)

  # Initialize agent, variables.
  specs = utils.read_specs(hparams['logdir'])
  flat_specs = [
      tf.TensorSpec.from_spec(s, str(i))
      for i, s in enumerate(tf.nest.flatten(specs))
  ]
  queue_capacity = FLAGS.queue_capacity or FLAGS.batch_size * 10
  queue = tf.queue.FIFOQueue(
      queue_capacity,
      [t.dtype for t in flat_specs],
      [t.shape for t in flat_specs],
  )
  agent = problem_type.get_agent()
  # Create dummy environment output of shape [num_timesteps, batch_size, ...]
  # (the batch dimension is inserted after the leading spec dimension).
  env_output = tf.nest.map_structure(
      lambda s: tf.zeros(
          list(s.shape)[0:1] + [FLAGS.batch_size] + list(s.shape)[1:],
          s.dtype),
      specs.env_output)
  init_observation = utils.get_row_nested_tensor(env_output.observation, 0)
  init_agent_state = agent.get_initial_state(
      init_observation, batch_size=FLAGS.batch_size)
  env_output = _convert_uint8_to_bfloat16(env_output)
  # One call on dummy inputs so the agent's variables get created.
  with tf.device(agent_device):
    agent(env_output, init_agent_state)

  # Create optimizer.
  if FLAGS.lr_decay_steps > 0 and FLAGS.lr_decay_rate < 1.:
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate=FLAGS.learning_rate,
        decay_steps=FLAGS.lr_decay_steps,
        decay_rate=FLAGS.lr_decay_rate)
  else:
    lr_schedule = FLAGS.learning_rate
  optimizer = problem_type.get_optimizer(lr_schedule)
  # NOTE: `iterations` is a non-trainable variable which is managed by
  # optimizer (created inside optimizer as well as incremented by 1 on every
  # call to optimizer.minimize).
  iterations = optimizer.iterations
  study_loss_types = problem_type.get_study_loss_types()

  @tf.function
  def train_step(iterator):
    """Training StepFn."""

    def step_fn(actor_output):
      """Per-replica StepFn."""
      actor_output = tf.nest.pack_sequence_as(specs, actor_output)
      (initial_agent_state, env_output, actor_agent_output, actor_action,
       loss_type, info) = actor_output
      with tf.GradientTape() as tape:
        loss = loss_fns.compute_loss(
            study_loss_types=study_loss_types,
            current_batch_loss_type=loss_type,
            agent=agent,
            agent_state=initial_agent_state,
            env_output=env_output,
            actor_agent_output=actor_agent_output,
            actor_action=actor_action,
            num_steps=iterations)
      grads = tape.gradient(loss, agent.trainable_variables)
      if FLAGS.gradient_clip_norm > 0.:
        for i, g in enumerate(grads):
          if g is not None:
            grads[i] = tf.clip_by_norm(g, FLAGS.gradient_clip_norm)
      grad_norms = {}
      for var, grad in zip(agent.trainable_variables, grads):
        # For parameters which are initialized but not used for loss
        # computation, gradient tape would return None.
        if grad is not None:
          grad_norms[var.name] = tf.norm(grad)
      optimizer.apply_gradients(zip(grads, agent.trainable_variables))
      return info, grad_norms

    return step_fn(next(iterator))

  ckpt_manager = _maybe_restore_from_ckpt(
      hparams['logdir'], agent=agent, optimizer=optimizer)
  server = _create_server(
      listen_address, specs, agent, queue, extra_variables=[iterations])
  logging.info('Starting gRPC server')
  server.start()

  # Everything after the server starts is wrapped so the queue is closed and
  # the server shut down even if the pipeline or training loop raises.
  try:
    dataset = tf.data.Dataset.from_tensors(0).repeat(None)
    dataset = dataset.map(lambda _: queue.dequeue())
    dataset = dataset.batch(FLAGS.batch_size, drop_remainder=True)
    # Transpose each batch to time-major order. This is relatively slow, so do
    # this work outside of the training loop.
    dataset = dataset.map(functools.partial(_transpose_batch, specs))
    dataset = dataset.apply(
        tf.data.experimental.copy_to_device(agent_device))
    with tf.device(agent_device):
      dataset = dataset.prefetch(1)
      iterator = iter(dataset)

    # Execute learning and track performance.
    summary_writer = tf.summary.create_file_writer(
        hparams['logdir'], flush_millis=20000, max_queue=1000)
    log_interval_secs = 120
    norm_summ_family = 'grad_norms/'
    last_ckpt_time = time.time()
    with summary_writer.as_default():
      last_log_iterations = iterations.numpy()
      # Seed from the numpy value so this matches the type of later updates
      # (which use `iterations.numpy()`) rather than producing a Tensor.
      last_log_num_env_frames = (
          last_log_iterations * hparams['iter_frame_ratio'])
      last_log_time = time.time()
      while iterations < hparams['final_iteration']:
        logging.info('Iteration %d of %d', iterations.numpy() + 1,
                     hparams['final_iteration'])

        # Save checkpoint at specified intervals or if no previous ckpt
        # exists.
        current_time = time.time()
        if (current_time - last_ckpt_time >= FLAGS.save_checkpoint_secs or
            not ckpt_manager.latest_checkpoint):
          ckpt_manager.save(checkpoint_number=iterations)
          last_ckpt_time = current_time

        with utils.WallTimer() as wt:
          with tf.device(agent_device):
            info, grad_norms = train_step(iterator)
        tf.summary.scalar(
            'steps_summary/step_seconds', wt.duration, step=iterations)
        for name, norm in grad_norms.items():
          tf.summary.scalar(norm_summ_family + name, norm, step=iterations)

        if current_time - last_log_time >= log_interval_secs:
          num_env_frames = iterations.numpy() * hparams['iter_frame_ratio']
          num_frames_since = num_env_frames - last_log_num_env_frames
          num_iterations_since = iterations.numpy() - last_log_iterations
          elapsed_time = time.time() - last_log_time
          tf.summary.scalar(
              'steps_summary/num_environment_frames_per_sec',
              tf.cast(num_frames_since, tf.float32) / elapsed_time,
              step=iterations)
          tf.summary.scalar(
              'steps_summary/num_iterations_per_sec',
              tf.cast(num_iterations_since, tf.float32) / elapsed_time,
              step=iterations)
          tf.summary.scalar('queue_size', queue.size(), step=iterations)
          tf.summary.scalar(
              'learning_rate',
              optimizer._decayed_lr(var_dtype=tf.float32),  # pylint: disable=protected-access
              step=iterations)
          last_log_num_env_frames, last_log_iterations, last_log_time = (
              num_env_frames, iterations.numpy(), time.time())
          logging.info('Number of environment frames: %d', num_env_frames)

        problem_type.create_summary(step=iterations, info=info)

    # Finishing up.
    ckpt_manager.save(checkpoint_number=iterations)
  finally:
    queue.close()
    server.shutdown()