    def test_push_under_distribute_strategy(
            self, strategy: tf.distribute.Strategy) -> None:
        # Prepare nested variables under strategy scope to push into the server.
        with strategy.scope():
            variables = _create_nested_variable()
        logging.info('Variables: %s', variables)

        # Push the input to the server.
        variable_container = reverb_variable_container.ReverbVariableContainer(
            self._server_address)
        variable_container.push(variables)  # pytype: disable=wrong-arg-types

        # Check the content of the server.
        self._assert_nested_variable_in_server()
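
For context, the consumer side of this pattern reads the pushed values back out of the Reverb server. A minimal sketch, assuming tf_agents' `ReverbVariableContainer.update` method and a hypothetical `create_local_nested_variables` helper that mirrors the server's variable structure:

def sync_variables_from_server(server_address):
  # Hypothetical helper: builds the same nested structure of tf.Variables
  # locally, so the container knows what to assign into.
  local_variables = create_local_nested_variables()
  container = reverb_variable_container.ReverbVariableContainer(server_address)
  # `update` fetches the latest values stored in the server and assigns them
  # into the local (possibly nested) tf.Variables in place.
  container.update(local_variables)
  return local_variables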
Example #2
  def _get_input_iterator(
      self, input_fn: Callable[[params_dict.ParamsDict], tf.data.Dataset],
      strategy: tf.distribute.Strategy) -> Optional[Iterator[Any]]:
    """Returns distributed dataset iterator.

    Args:
      input_fn: (params: dict) -> tf.data.Dataset.
      strategy: an instance of tf.distribute.Strategy.

    Returns:
      An iterator that yields input tensors.
    """

    if input_fn is None:
      return None
    # When training with multiple TPU workers, datasets need to be cloned
    # across workers. Since a Dataset instance cannot be cloned in eager mode,
    # we instead pass a callable that returns a dataset.
    input_data = input_fn(self._params)
    return iter(strategy.experimental_distribute_dataset(input_data))
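
The comment about cloning datasets across TPU workers hints at the callable-based alternative: hand the strategy a function so each worker can build and shard its own input pipeline. A minimal sketch using a toy `tf.data.Dataset.range` pipeline and `distribute_datasets_from_function` (not the `_get_input_iterator` above):

def make_distributed_iterator(strategy: tf.distribute.Strategy,
                              global_batch_size: int):
  def dataset_fn(input_context: tf.distribute.InputContext):
    # Each input pipeline gets its per-replica batch size and its own shard.
    per_replica_batch = input_context.get_per_replica_batch_size(
        global_batch_size)
    dataset = tf.data.Dataset.range(1000)
    dataset = dataset.shard(input_context.num_input_pipelines,
                            input_context.input_pipeline_id)
    return dataset.batch(per_replica_batch)

  return iter(strategy.distribute_datasets_from_function(dataset_fn))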
Example #3
def run_train_loop(
        train_dataset_builder: ub.datasets.BaseDataset,
        validation_dataset_builder: Optional[ub.datasets.BaseDataset],
        test_dataset_builder: ub.datasets.BaseDataset,
        batch_size: int,
        eval_batch_size: int,
        model: tf.keras.Model,
        optimizer: tf.keras.optimizers.Optimizer,
        eval_frequency: int,
        log_frequency: int,
        trial_dir: Optional[str],
        train_steps: int,
        mode: str,
        strategy: tf.distribute.Strategy,
        metrics: Dict[str, Union[tf.keras.metrics.Metric,
                                 rm.metrics.KerasMetric]],
        hparams: Dict[str, Any]):
    """Train, possibly evaluate the model, and record metrics."""

    checkpoint_manager = None
    last_checkpoint_step = 0
    if trial_dir:
        # TODO(znado): add train_iterator to this once DistributedIterators are
        # checkpointable.
        checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
        checkpoint_manager = tf.train.CheckpointManager(checkpoint,
                                                        trial_dir,
                                                        max_to_keep=None)
        checkpoint_path = tf.train.latest_checkpoint(trial_dir)
        if checkpoint_path:
            last_checkpoint_step = int(checkpoint_path.split('-')[-1])
            if last_checkpoint_step >= train_steps:
                # If we have already finished training, exit.
                logging.info(
                    'Training has already finished at step %d. Exiting.',
                    train_steps)
                return
            elif last_checkpoint_step > 0:
                # Restore from where we previously finished.
                checkpoint.restore(checkpoint_manager.latest_checkpoint)
                logging.info('Resuming training from step %d.',
                             last_checkpoint_step)

    train_dataset = train_dataset_builder.load(batch_size=batch_size)
    train_dataset = strategy.experimental_distribute_dataset(train_dataset)
    train_iterator = iter(train_dataset)

    iterations_per_loop = min(eval_frequency, log_frequency)
    # We can only run `iterations_per_loop` steps at a time, because we cannot
    # checkpoint the model inside a tf.function.
    train_step_fn = _train_step_fn(model,
                                   optimizer,
                                   strategy,
                                   metrics,
                                   iterations_per_loop=iterations_per_loop)
    if trial_dir:
        train_summary_writer = tf.summary.create_file_writer(
            os.path.join(trial_dir, 'train'))
    else:
        train_summary_writer = None

    val_summary_writer = None
    test_summary_writer = None
    if mode == 'train_and_eval':
        (val_fn, val_dataset, val_summary_writer, test_fn, test_dataset,
         test_summary_writer) = eval_lib.setup_eval(
             validation_dataset_builder=validation_dataset_builder,
             test_dataset_builder=test_dataset_builder,
             batch_size=eval_batch_size,
             strategy=strategy,
             trial_dir=trial_dir,
             model=model,
             metrics=metrics)
    # Each call to train_step_fn will run iterations_per_loop steps.
    num_train_fn_steps = train_steps // iterations_per_loop
    # We are guaranteed that `last_checkpoint_step` will be divisible by
    # `iterations_per_loop` because that is how frequently we checkpoint.
    start_train_fn_step = last_checkpoint_step // iterations_per_loop
    for train_fn_step in range(start_train_fn_step, num_train_fn_steps):
        current_step = train_fn_step * iterations_per_loop
        # Checkpoint at the start of the step, before the training op is run.
        if (checkpoint_manager and current_step % eval_frequency == 0
                and current_step != last_checkpoint_step):
            checkpoint_manager.save(checkpoint_number=current_step)
        if mode == 'train_and_eval' and current_step % eval_frequency == 0:
            eval_lib.run_eval_epoch(
                val_fn,
                val_dataset,
                val_summary_writer,
                test_fn,
                test_dataset,
                test_summary_writer,
                current_step,
                hparams=None)  # Only write hparams on the last step.
        train_step_outputs = train_step_fn(train_iterator)
        if current_step % log_frequency == 0:
            _write_summaries(train_step_outputs, current_step,
                             train_summary_writer)
            train_step_outputs_np = {
                k: v.numpy()
                for k, v in train_step_outputs.items()
            }
            logging.info('Training metrics for step %d: %s', current_step,
                         train_step_outputs_np)

    if train_steps % iterations_per_loop != 0:
        remainder_train_step_fn = _train_step_fn(
            model,
            optimizer,
            strategy,
            metrics,
            iterations_per_loop=train_steps % iterations_per_loop)
        train_step_outputs = remainder_train_step_fn(train_iterator)

    # Always evaluate and record metrics at the end of training.
    _write_summaries(train_step_outputs, train_steps, train_summary_writer,
                     hparams)
    train_step_outputs_np = {
        k: v.numpy()
        for k, v in train_step_outputs.items()
    }
    logging.info('Training metrics for step %d: %s', train_steps,
                 train_step_outputs_np)
    if mode == 'train_and_eval':
        eval_lib.run_eval_epoch(val_fn,
                                val_dataset,
                                val_summary_writer,
                                test_fn,
                                test_dataset,
                                test_summary_writer,
                                train_steps,
                                hparams=hparams)
    # Save checkpoint at the end of training.
    if checkpoint_manager:
        checkpoint_manager.save(checkpoint_number=train_steps)
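
The comment in the loop above ("we cannot checkpoint the model inside a tf.function") is why training is chunked into `iterations_per_loop`-step calls. The real `_train_step_fn` is defined elsewhere; a minimal sketch of the pattern, assuming a classification model and a mean cross-entropy loss:

def _sketch_train_step_fn(model, optimizer, strategy, iterations_per_loop):
    """Illustrative stand-in for `_train_step_fn`; details are assumptions."""

    @tf.function
    def train_steps(iterator):
        def step_fn(batch):
            features, labels = batch
            with tf.GradientTape() as tape:
                logits = model(features, training=True)
                loss = tf.reduce_mean(
                    tf.keras.losses.sparse_categorical_crossentropy(
                        labels, logits, from_logits=True))
            grads = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))
            return loss

        # Run several distributed steps per call, so checkpointing and summary
        # writing can happen in eager code between calls.
        total_loss = tf.constant(0.0)
        for _ in tf.range(iterations_per_loop):
            per_replica_loss = strategy.run(step_fn, args=(next(iterator),))
            total_loss += strategy.reduce(
                tf.distribute.ReduceOp.MEAN, per_replica_loss, axis=None)
        return {'train/loss': total_loss / iterations_per_loop}

    return train_steps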
Example #4
def train(
    root_dir: Text,
    environment_name: Text,
    strategy: tf.distribute.Strategy,
    replay_buffer_server_address: Text,
    variable_container_server_address: Text,
    suite_load_fn: Callable[[Text],
                            py_environment.PyEnvironment] = suite_mujoco.load,
    # Training params
    learning_rate: float = 3e-4,
    batch_size: int = 256,
    num_iterations: int = 2000000,
    learner_iterations_per_call: int = 1) -> None:
  """Trains a DQN agent."""
  # Get the specs from the environment.
  logging.info('Training SAC with learning rate: %f', learning_rate)
  env = suite_load_fn(environment_name)
  observation_tensor_spec, action_tensor_spec, time_step_tensor_spec = (
      spec_utils.get_tensor_specs(env))

  # Create the agent.
  with strategy.scope():
    train_step = train_utils.create_train_step()
    agent = _create_agent(
        train_step=train_step,
        observation_tensor_spec=observation_tensor_spec,
        action_tensor_spec=action_tensor_spec,
        time_step_tensor_spec=time_step_tensor_spec,
        learning_rate=learning_rate)

  # Create the policy saver, which saves the initial model now and then
  # periodically checkpoints the policy weights.
  saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
  save_model_trigger = triggers.PolicySavedModelTrigger(
      saved_model_dir, agent, train_step, interval=1000)

  # Create the variable container.
  variables = {
      reverb_variable_container.POLICY_KEY: agent.collect_policy.variables(),
      reverb_variable_container.TRAIN_STEP_KEY: train_step
  }
  variable_container = reverb_variable_container.ReverbVariableContainer(
      variable_container_server_address,
      table_names=[reverb_variable_container.DEFAULT_TABLE])
  variable_container.push(variables)

  # Create the replay buffer.
  reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
      agent.collect_data_spec,
      sequence_length=2,
      table_name=reverb_replay_buffer.DEFAULT_TABLE,
      server_address=replay_buffer_server_address)

  # Initialize the dataset.
  def experience_dataset_fn():
    with strategy.scope():
      return reverb_replay.as_dataset(
          sample_batch_size=batch_size, num_steps=2).prefetch(3)

  # Create the learner.
  learning_triggers = [
      save_model_trigger,
      triggers.StepPerSecondLogTrigger(train_step, interval=1000)
  ]
  sac_learner = learner.Learner(
      root_dir,
      train_step,
      agent,
      experience_dataset_fn,
      triggers=learning_triggers,
      strategy=strategy)

  # Run the training loop.
  while train_step.numpy() < num_iterations:
    sac_learner.run(iterations=learner_iterations_per_call)
    variable_container.push(variables)
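
The `_create_agent` helper is defined elsewhere in the source. A plausible minimal sketch under the tf_agents SAC API, with the layer sizes and the single shared learning rate as assumptions:

from tf_agents.agents.ddpg import critic_network
from tf_agents.agents.sac import sac_agent
from tf_agents.agents.sac import tanh_normal_projection_network
from tf_agents.networks import actor_distribution_network


def _sketch_create_agent(train_step, observation_tensor_spec,
                         action_tensor_spec, time_step_tensor_spec,
                         learning_rate):
  """Illustrative stand-in for `_create_agent`; sizes are assumptions."""
  actor_net = actor_distribution_network.ActorDistributionNetwork(
      observation_tensor_spec,
      action_tensor_spec,
      fc_layer_params=(256, 256),
      continuous_projection_net=(
          tanh_normal_projection_network.TanhNormalProjectionNetwork))
  critic_net = critic_network.CriticNetwork(
      (observation_tensor_spec, action_tensor_spec),
      joint_fc_layer_params=(256, 256))
  agent = sac_agent.SacAgent(
      time_step_tensor_spec,
      action_tensor_spec,
      actor_network=actor_net,
      critic_network=critic_net,
      actor_optimizer=tf.keras.optimizers.Adam(learning_rate),
      critic_optimizer=tf.keras.optimizers.Adam(learning_rate),
      alpha_optimizer=tf.keras.optimizers.Adam(learning_rate),
      train_step_counter=train_step)
  agent.initialize()
  return agent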