def testReadWriteSpecs(self):
    """Checks that specs survive a `write_specs` / `read_specs` round trip."""
    output_dir = FLAGS.test_tmpdir

    # Nested part of the structure, built separately for readability.
    nested_b = {
        'b_1': tf.TensorSpec(shape=(5,), dtype=tf.string),
        'b_2': tf.TensorSpec(shape=(5, 6), dtype=tf.int32),
    }
    expected_specs = {
        'a': tf.TensorSpec(shape=(2, 3), dtype=tf.float32),
        'b': nested_b,
    }

    utils.write_specs(output_dir, expected_specs)

    # Now read and verify: every leaf spec must equal the one written.
    actual_specs = utils.read_specs(output_dir)
    tf.nest.map_structure(self.assertEqual, expected_specs, actual_specs)
def run_with_address(
    problem_type: framework_problem_type.ProblemType,
    listen_address: Text,
    hparams: Dict[Text, Any],
):
    """Runs the learner with the given problem type.

    Builds the agent and optimizer, optionally restores a checkpoint, starts
    the gRPC server (which is handed the queue that training batches are
    dequeued from) and runs training steps until the optimizer's iteration
    counter reaches `hparams['final_iteration']`. Checkpoints and summaries
    are written to `hparams['logdir']`. The queue is closed and the server is
    shut down when the function exits, whether or not training succeeded.

    Args:
      problem_type: An instance of `framework_problem_type.ProblemType`.
      listen_address: The network address on which to listen.
      hparams: A dict containing hyperparameter settings. The keys read here
        are 'logdir', 'iter_frame_ratio' and 'final_iteration'.

    Raises:
      AssertionError: If no local device of type `FLAGS.agent_device` exists.
    """
    # Seconds between two rounds of throughput / queue / learning-rate
    # summaries.
    log_interval_secs = 120

    devices = device_lib.list_local_devices()
    logging.info('Found devices: %s', devices)
    devices = [d for d in devices if d.device_type == FLAGS.agent_device]
    assert devices, 'Could not find a device of type %s' % FLAGS.agent_device
    agent_device = devices[0].name
    logging.info('Using agent device: %s', agent_device)

    # Initialize agent, variables.
    specs = utils.read_specs(hparams['logdir'])
    flat_specs = [
        tf.TensorSpec.from_spec(s, str(i))
        for i, s in enumerate(tf.nest.flatten(specs))
    ]
    queue_capacity = FLAGS.queue_capacity or FLAGS.batch_size * 10
    queue = tf.queue.FIFOQueue(
        queue_capacity,
        [t.dtype for t in flat_specs],
        [t.shape for t in flat_specs],
    )
    agent = problem_type.get_agent()

    # Create dummy environment output of shape [num_timesteps, batch_size, ...]
    # by inserting the batch dimension after the first dimension of each spec.
    env_output = tf.nest.map_structure(
        lambda s: tf.zeros(
            list(s.shape)[0:1] + [FLAGS.batch_size] + list(s.shape)[1:],
            s.dtype),
        specs.env_output)
    init_observation = utils.get_row_nested_tensor(env_output.observation, 0)
    init_agent_state = agent.get_initial_state(
        init_observation, batch_size=FLAGS.batch_size)
    env_output = _convert_uint8_to_bfloat16(env_output)

    # NOTE(review): the original indentation was lost in this file; the
    # optimizer setup is assumed to live inside the agent-device scope
    # together with the agent call -- confirm against the upstream source.
    with tf.device(agent_device):
        # Call the agent once on the dummy inputs (presumably so that its
        # variables exist before a checkpoint is restored -- confirm).
        agent(env_output, init_agent_state)

        # Create optimizer. A decaying schedule is used only when both decay
        # flags describe an actual decay; otherwise the rate is constant.
        if FLAGS.lr_decay_steps > 0 and FLAGS.lr_decay_rate < 1.:
            lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
                initial_learning_rate=FLAGS.learning_rate,
                decay_steps=FLAGS.lr_decay_steps,
                decay_rate=FLAGS.lr_decay_rate)
        else:
            lr_schedule = FLAGS.learning_rate
        optimizer = problem_type.get_optimizer(lr_schedule)
        # NOTE: `iterations` is a non-trainable variable which is managed by
        # optimizer (created inside optimizer as well as incremented by 1 on
        # every call to optimizer.minimize).
        iterations = optimizer.iterations
        study_loss_types = problem_type.get_study_loss_types()

    @tf.function
    def train_step(iterator):
        """Training StepFn: runs one optimization step on the next batch."""

        def step_fn(actor_output):
            """Per-replica StepFn.

            Computes the loss for one batch, applies the gradients and
            returns `(info, grad_norms)`, where `grad_norms` maps variable
            names to the norm of their gradient.
            """
            actor_output = tf.nest.pack_sequence_as(specs, actor_output)
            (initial_agent_state, env_output, actor_agent_output,
             actor_action, loss_type, info) = actor_output
            with tf.GradientTape() as tape:
                loss = loss_fns.compute_loss(
                    study_loss_types=study_loss_types,
                    current_batch_loss_type=loss_type,
                    agent=agent,
                    agent_state=initial_agent_state,
                    env_output=env_output,
                    actor_agent_output=actor_agent_output,
                    actor_action=actor_action,
                    num_steps=iterations)
            grads = tape.gradient(loss, agent.trainable_variables)
            if FLAGS.gradient_clip_norm > 0.:
                # Clip each gradient independently; `None` entries are kept.
                grads = [
                    g if g is None
                    else tf.clip_by_norm(g, FLAGS.gradient_clip_norm)
                    for g in grads
                ]
            # For parameters which are initialized but not used for loss
            # computation, gradient tape would return None; skip those.
            grad_norms = {
                var.name: tf.norm(grad)
                for var, grad in zip(agent.trainable_variables, grads)
                if grad is not None
            }
            optimizer.apply_gradients(zip(grads, agent.trainable_variables))
            return info, grad_norms

        return step_fn(next(iterator))

    ckpt_manager = _maybe_restore_from_ckpt(
        hparams['logdir'], agent=agent, optimizer=optimizer)
    server = _create_server(
        listen_address, specs, agent, queue, extra_variables=[iterations])
    logging.info('Starting gRPC server')
    server.start()

    # Everything after `server.start()` runs under try/finally so that the
    # queue and the server are always released, even if training raises.
    try:
        dataset = tf.data.Dataset.from_tensors(0).repeat(None)
        dataset = dataset.map(lambda _: queue.dequeue())
        dataset = dataset.batch(FLAGS.batch_size, drop_remainder=True)
        # Transpose each batch to time-major order. This is relatively slow,
        # so do this work outside of the training loop.
        dataset = dataset.map(functools.partial(_transpose_batch, specs))
        dataset = dataset.apply(
            tf.data.experimental.copy_to_device(agent_device))
        with tf.device(agent_device):
            dataset = dataset.prefetch(1)
            iterator = iter(dataset)

        # Execute learning and track performance.
        summary_writer = tf.summary.create_file_writer(
            hparams['logdir'], flush_millis=20000, max_queue=1000)
        last_ckpt_time = time.time()
        with summary_writer.as_default():
            last_log_iterations = iterations.numpy()
            # Computed from the NumPy value so that the baseline has the same
            # type as the per-interval frame counts computed in the loop.
            last_log_num_env_frames = (
                last_log_iterations * hparams['iter_frame_ratio'])
            last_log_time = time.time()
            while iterations < hparams['final_iteration']:
                logging.info('Iteration %d of %d', iterations + 1,
                             hparams['final_iteration'])

                # Save checkpoint at specified intervals or if no previous
                # ckpt exists.
                current_time = time.time()
                if (current_time - last_ckpt_time >=
                        FLAGS.save_checkpoint_secs
                        or not ckpt_manager.latest_checkpoint):
                    ckpt_manager.save(checkpoint_number=iterations)
                    last_ckpt_time = current_time

                with utils.WallTimer() as wt:
                    with tf.device(agent_device):
                        info, grad_norms = train_step(iterator)
                tf.summary.scalar(
                    'steps_summary/step_seconds', wt.duration,
                    step=iterations)
                norm_summ_family = 'grad_norms/'
                for name, norm in grad_norms.items():
                    tf.summary.scalar(
                        norm_summ_family + name, norm, step=iterations)

                # NOTE(review): the extent of this branch was inferred from
                # the collapsed source (throughput, queue-size and
                # learning-rate summaries inside; `create_summary` outside) --
                # confirm against the upstream source.
                if current_time - last_log_time >= log_interval_secs:
                    current_iterations = iterations.numpy()
                    num_env_frames = (
                        current_iterations * hparams['iter_frame_ratio'])
                    num_frames_since = (
                        num_env_frames - last_log_num_env_frames)
                    num_iterations_since = (
                        current_iterations - last_log_iterations)
                    elapsed_time = time.time() - last_log_time
                    tf.summary.scalar(
                        'steps_summary/num_environment_frames_per_sec',
                        tf.cast(num_frames_since, tf.float32) / elapsed_time,
                        step=iterations)
                    tf.summary.scalar(
                        'steps_summary/num_iterations_per_sec',
                        tf.cast(num_iterations_since, tf.float32) /
                        elapsed_time,
                        step=iterations)
                    tf.summary.scalar(
                        'queue_size', queue.size(), step=iterations)
                    # NOTE: `_decayed_lr` is a private optimizer method; it
                    # may not exist on every optimizer implementation.
                    tf.summary.scalar(
                        'learning_rate',
                        optimizer._decayed_lr(var_dtype=tf.float32),
                        step=iterations)
                    last_log_num_env_frames = num_env_frames
                    last_log_iterations = current_iterations
                    last_log_time = time.time()
                    logging.info(
                        'Number of environment frames: %d', num_env_frames)

                problem_type.create_summary(step=iterations, info=info)

        # Finishing up: the final checkpoint is written only when the
        # training loop completed without raising.
        ckpt_manager.save(checkpoint_number=iterations)
    finally:
        queue.close()
        server.shutdown()