コード例 #1
0
  def testReadWriteSpecs(self):
    """Specs written with `utils.write_specs` are read back unchanged."""
    spec_dir = FLAGS.test_tmpdir
    expected_specs = {
        'a': tf.TensorSpec(shape=(2, 3), dtype=tf.float32),
        'b': {
            'b_1': tf.TensorSpec(shape=(5,), dtype=tf.string),
            'b_2': tf.TensorSpec(shape=(5, 6), dtype=tf.int32),
        },
    }
    utils.write_specs(spec_dir, expected_specs)

    # Round-trip: read the specs back, then compare the two nests leaf by
    # leaf (map_structure also fails if the nest structures differ).
    actual_specs = utils.read_specs(spec_dir)
    tf.nest.map_structure(self.assertEqual, expected_specs, actual_specs)
コード例 #2
0
def run_with_address(
    problem_type: framework_problem_type.ProblemType,
    listen_address: Text,
    hparams: Dict[Text, Any],
):
    """Runs the learner with the given problem type.

    Builds the agent and optimizer on the device selected by
    `FLAGS.agent_device`, starts a gRPC server that is handed a FIFO queue,
    and trains on batches dequeued from that queue until the optimizer's
    iteration counter reaches `hparams['final_iteration']`. Summaries are
    written to `hparams['logdir']`, which is also passed to the checkpoint
    restore helper.

    Args:
      problem_type: An instance of `framework_problem_type.ProblemType`.
      listen_address: The network address on which to listen.
      hparams: A dict containing hyperparameter settings. This function reads
        the keys 'logdir', 'iter_frame_ratio' and 'final_iteration'.

    Raises:
      AssertionError: If no local device of type `FLAGS.agent_device` is
        found (only when Python assertions are enabled).
    """
    devices = device_lib.list_local_devices()
    logging.info('Found devices: %s', devices)
    devices = [d for d in devices if d.device_type == FLAGS.agent_device]
    assert devices, 'Could not find a device of type %s' % FLAGS.agent_device
    agent_device = devices[0].name
    logging.info('Using agent device: %s', agent_device)

    # Initialize agent, variables.
    specs = utils.read_specs(hparams['logdir'])
    # The queue holds the flattened leaves of `specs`; each leaf spec is
    # renamed to its index in the flattened structure.
    flat_specs = [
        tf.TensorSpec.from_spec(s, str(i))
        for i, s in enumerate(tf.nest.flatten(specs))
    ]
    queue_capacity = FLAGS.queue_capacity or FLAGS.batch_size * 10
    queue = tf.queue.FIFOQueue(
        queue_capacity,
        [t.dtype for t in flat_specs],
        [t.shape for t in flat_specs],
    )
    agent = problem_type.get_agent()
    # Create dummy environment output of shape [num_timesteps, batch_size, ...].
    env_output = tf.nest.map_structure(
        lambda s: tf.zeros(
            list(s.shape)[0:1] + [FLAGS.batch_size] + list(s.shape)[1:],
            s.dtype),
        specs.env_output)
    init_observation = utils.get_row_nested_tensor(env_output.observation, 0)
    init_agent_state = agent.get_initial_state(init_observation,
                                               batch_size=FLAGS.batch_size)
    env_output = _convert_uint8_to_bfloat16(env_output)
    with tf.device(agent_device):
        # One call on dummy inputs so the agent's variables exist before the
        # optimizer, checkpoint restore and server are set up.
        agent(env_output, init_agent_state)

        # Create optimizer.
        if FLAGS.lr_decay_steps > 0 and FLAGS.lr_decay_rate < 1.:
            lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
                initial_learning_rate=FLAGS.learning_rate,
                decay_steps=FLAGS.lr_decay_steps,
                decay_rate=FLAGS.lr_decay_rate)
        else:
            lr_schedule = FLAGS.learning_rate
        optimizer = problem_type.get_optimizer(lr_schedule)
        # NOTE: `iterations` is a non-trainable variable which is managed by
        # optimizer (created inside optimizer as well as incremented by 1 on every
        # call to optimizer.minimize).
        iterations = optimizer.iterations
        study_loss_types = problem_type.get_study_loss_types()

    @tf.function
    def train_step(iterator):
        """Runs one optimization step on the next batch from `iterator`.

        Returns:
          A tuple `(info, grad_norms)`. `info` is the last element of the
          unpacked batch, and `grad_norms` maps variable names to the norm of
          their (possibly clipped) gradient.
        """
        def step_fn(actor_output):
            """Per-replica StepFn."""
            actor_output = tf.nest.pack_sequence_as(specs, actor_output)
            (initial_agent_state, env_output, actor_agent_output, actor_action,
             loss_type, info) = actor_output
            with tf.GradientTape() as tape:
                loss = loss_fns.compute_loss(
                    study_loss_types=study_loss_types,
                    current_batch_loss_type=loss_type,
                    agent=agent,
                    agent_state=initial_agent_state,
                    env_output=env_output,
                    actor_agent_output=actor_agent_output,
                    actor_action=actor_action,
                    num_steps=iterations)
            grads = tape.gradient(loss, agent.trainable_variables)
            if FLAGS.gradient_clip_norm > 0.:
                for i, g in enumerate(grads):
                    if g is not None:
                        grads[i] = tf.clip_by_norm(g, FLAGS.gradient_clip_norm)
            grad_norms = {}
            for var, grad in zip(agent.trainable_variables, grads):
                # For parameters which are initialized but not used for loss
                # computation, gradient tape would return None.
                if grad is not None:
                    grad_norms[var.name] = tf.norm(grad)
            optimizer.apply_gradients(zip(grads, agent.trainable_variables))
            return info, grad_norms

        return step_fn(next(iterator))

    ckpt_manager = _maybe_restore_from_ckpt(hparams['logdir'],
                                            agent=agent,
                                            optimizer=optimizer)
    server = _create_server(listen_address,
                            specs,
                            agent,
                            queue,
                            extra_variables=[iterations])
    logging.info('Starting gRPC server')
    server.start()

    # From here on the server is running: make sure it and the queue are
    # released even if building the pipeline or training raises.
    try:
        dataset = tf.data.Dataset.from_tensors(0).repeat(None)
        dataset = dataset.map(lambda _: queue.dequeue())
        dataset = dataset.batch(FLAGS.batch_size, drop_remainder=True)
        # Transpose each batch to time-major order. This is relatively slow,
        # so do this work outside of the training loop.
        dataset = dataset.map(functools.partial(_transpose_batch, specs))
        dataset = dataset.apply(
            tf.data.experimental.copy_to_device(agent_device))
        with tf.device(agent_device):
            dataset = dataset.prefetch(1)
            iterator = iter(dataset)

        # Execute learning and track performance.
        summary_writer = tf.summary.create_file_writer(hparams['logdir'],
                                                       flush_millis=20000,
                                                       max_queue=1000)
        log_interval_secs = 120
        last_ckpt_time = time.time()
        with summary_writer.as_default():
            last_log_iterations = iterations.numpy()
            # Derive the frame baseline from the host-side count, exactly as
            # the periodic log below does. Multiplying the integer
            # `iterations` variable itself by a non-integer ratio would fail
            # dtype conversion.
            last_log_num_env_frames = (
                last_log_iterations * hparams['iter_frame_ratio'])
            last_log_time = time.time()
            while iterations < hparams['final_iteration']:
                logging.info('Iteration %d of %d', iterations.numpy() + 1,
                             hparams['final_iteration'])

                # Save checkpoint at specified intervals or if no previous
                # ckpt exists.
                current_time = time.time()
                if (current_time - last_ckpt_time >= FLAGS.save_checkpoint_secs
                        or not ckpt_manager.latest_checkpoint):
                    ckpt_manager.save(checkpoint_number=iterations)
                    last_ckpt_time = current_time

                with utils.WallTimer() as wt:
                    with tf.device(agent_device):
                        info, grad_norms = train_step(iterator)
                tf.summary.scalar('steps_summary/step_seconds',
                                  wt.duration,
                                  step=iterations)
                norm_summ_family = 'grad_norms/'
                for name, norm in grad_norms.items():
                    tf.summary.scalar(norm_summ_family + name,
                                      norm,
                                      step=iterations)

                if current_time - last_log_time >= log_interval_secs:
                    current_iterations = iterations.numpy()
                    num_env_frames = (
                        current_iterations * hparams['iter_frame_ratio'])
                    num_frames_since = num_env_frames - last_log_num_env_frames
                    num_iterations_since = (
                        current_iterations - last_log_iterations)
                    elapsed_time = time.time() - last_log_time
                    tf.summary.scalar(
                        'steps_summary/num_environment_frames_per_sec',
                        tf.cast(num_frames_since, tf.float32) / elapsed_time,
                        step=iterations)
                    tf.summary.scalar(
                        'steps_summary/num_iterations_per_sec',
                        tf.cast(num_iterations_since, tf.float32) /
                        elapsed_time,
                        step=iterations)
                    tf.summary.scalar('queue_size', queue.size(),
                                      step=iterations)
                    tf.summary.scalar(
                        'learning_rate',
                        optimizer._decayed_lr(var_dtype=tf.float32),
                        step=iterations)
                    last_log_num_env_frames = num_env_frames
                    last_log_iterations = current_iterations
                    last_log_time = time.time()
                    logging.info('Number of environment frames: %d',
                                 num_env_frames)

                problem_type.create_summary(step=iterations, info=info)

        # Finishing up: the final checkpoint is only written when training
        # completed normally.
        ckpt_manager.save(checkpoint_number=iterations)
    finally:
        queue.close()
        server.shutdown()