Example No. 1
    def reset(self, output_dir):
        """Reset the model parameters.

    Restores the parameters from the given output_dir if a checkpoint exists,
    otherwise randomly initializes them.

    Does not re-jit the model.

    Args:
      output_dir: Output directory.
    """
        self._output_dir = output_dir
        gfile.makedirs(output_dir)
        # Create summary writers and history.
        self._train_sw = jaxboard.SummaryWriter(
            os.path.join(output_dir, "train"))
        self._eval_sw = jaxboard.SummaryWriter(os.path.join(
            output_dir, "eval"))

        # Reset the train and eval streams.
        self._train_stream = self._inputs.train_stream()
        # TODO(lukaszkaiser): add an option to evaluate exactly on the full eval
        #   set by adding padding and stopping the stream when it gets too large.
        self._eval_stream = _repeat_stream(self._inputs.eval_stream)
        self._train_eval_stream = _repeat_stream(
            self._inputs.train_eval_stream)

        # Restore the training state.
        state = restore_state(output_dir)
        self._step = state.step or 0
        history = state.history
        self._lr_fn = self._lr_schedule(history)
        self._history = history
        if state.opt_state:
            opt_state = state.opt_state
            model_state = state.model_state
        else:
            opt_state, model_state = self._initialize()
            model_state = layers.nested_map(model_state, self._maybe_replicate)
        self._opt_state = OptState(
            *layers.nested_map(opt_state, self._maybe_replicate))
        self._model_state = model_state
        if not state.opt_state:
            self._maybe_save_state(keep=False)

        self.update_optimizer_params()
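
The reset() above follows a restore-or-initialize pattern: reuse the training state stored in output_dir when a checkpoint exists, otherwise initialize from scratch. Below is a minimal standalone sketch of that pattern; the pickled-dict checkpoint format and the helper name are illustrative assumptions, not the trax implementation.

# Restore-or-initialize sketch (hypothetical checkpoint format, for illustration).
import os
import pickle
import random


def restore_or_initialize(output_dir):
    """Returns (step, params), restoring from output_dir when possible."""
    checkpoint = os.path.join(output_dir, "model.pkl")
    if os.path.exists(checkpoint):
        with open(checkpoint, "rb") as f:
            state = pickle.load(f)
        return state["step"], state["params"]
    # No checkpoint found: start at step 0 with freshly initialized parameters.
    return 0, [random.gauss(0.0, 1.0) for _ in range(4)]


step, params = restore_or_initialize("/tmp/example_run")  # (0, ...) on first run
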
Example No. 2
    def __init__(
            self,
            train_env,
            eval_env,
            output_dir,
            policy_trainer_class,
            n_real_epochs=10,
            data_eval_frac=0.125,
            model_train_batch_size=64,
            n_model_train_steps=1000,
            simulated_env_problem_class=(
                simulated_env_problem.SerializedSequenceSimulatedEnvProblem),
            simulated_batch_size=16,
            n_simulated_epochs=1000,
            trajectory_dump_dir=None,
            initial_trajectory_dir=None,
            initial_trajectory_mix_prob=0.5,
            **kwargs):
        super(SimPLe, self).__init__(train_env, eval_env, output_dir, **kwargs)
        self._policy_dir = os.path.join(output_dir, "policy")
        self._policy_trainer = policy_trainer_class(
            train_env=train_env,
            eval_env=eval_env,
            output_dir=self._policy_dir,
            async_mode=self._async_mode,
            async_mode_trajectory_subdir=self._async_mode_trajectory_subdir,
        )
        self._n_real_epochs = n_real_epochs
        self._model_train_batch_size = model_train_batch_size
        self._n_model_train_steps = n_model_train_steps
        self._data_eval_frac = data_eval_frac
        self._model_dir = os.path.join(output_dir, "model")
        self._sim_env = simulated_env_problem_class(
            batch_size=None,
            observation_space=train_env.observation_space,
            action_space=train_env.action_space,
            reward_range=train_env.reward_range,
            discrete_rewards=train_env.discrete_rewards,
            history_stream=None,  # TODO(pkozakowski): Support this.
            output_dir=self._model_dir,
        )
        self._simulated_batch_size = simulated_batch_size
        self._n_simulated_epochs = n_simulated_epochs

        # If trajectory_dump_dir is not provided explicitly, save the trajectories
        # in output_dir.
        if trajectory_dump_dir is None:
            trajectory_dump_dir = os.path.join(output_dir, "trajectories")
        self._trajectory_dump_root_dir = trajectory_dump_dir

        self._initial_trajectory_dir = initial_trajectory_dir
        self._initial_trajectory_mix_prob = initial_trajectory_mix_prob

        self._summary_writer = jaxboard.SummaryWriter(self._output_dir)

        self._simple_epoch = 0
        self._policy_epoch = 0
        self._model_train_step = 0
Example No. 3
    def reset(self, output_dir):
        """Reset the model parameters.

    Restores the parameters from the given output_dir if a checkpoint exists,
    otherwise randomly initializes them.

    Does not re-jit the model.

    Args:
      output_dir: Output directory.
    """
        self._output_dir = output_dir
        gfile.makedirs(output_dir)
        # Create summary writers and history.
        self._train_sw = jaxboard.SummaryWriter(
            os.path.join(output_dir, "train"))
        self._eval_sw = jaxboard.SummaryWriter(os.path.join(
            output_dir, "eval"))

        # Reset the training stream.
        self._train_stream = self._inputs.train_stream()

        # Restore the training state.
        state = restore_state(output_dir)
        self._step = state.step or 0
        history = state.history
        self._lr_fn = self._lr_schedule(history)
        self._history = history
        if state.opt_state:
            opt_state = state.opt_state
            model_state = state.model_state
        else:
            opt_state, model_state = self._initialize()
            model_state = layers.nested_map(model_state, self._maybe_replicate)
        self._opt_state = OptState(
            *layers.nested_map(opt_state, self._maybe_replicate))
        self._model_state = model_state
        if not state.opt_state:
            self._maybe_save_state(keep=False)

        self.update_learning_rate()
Example No. 4
    def __init__(self,
                 train_env,
                 eval_env,
                 output_dir,
                 policy_and_value_model=trax_models.FrameStackMLP,
                 policy_and_value_optimizer=functools.partial(
                     trax_opt.Adam, learning_rate=1e-3),
                 policy_and_value_two_towers=False,
                 n_optimizer_steps=N_OPTIMIZER_STEPS,
                 print_every_optimizer_steps=PRINT_EVERY_OPTIMIZER_STEP,
                 target_kl=0.01,
                 boundary=20,
                 max_timestep=None,
                 max_timestep_eval=20000,
                 random_seed=None,
                 gamma=GAMMA,
                 lambda_=LAMBDA,
                 c1=1.0,
                 c2=0.01,
                 eval_every_n=1000,
                 done_frac_for_policy_save=0.5,
                 n_evals=1,
                 len_history_for_policy=4,
                 eval_temperatures=(1.0, 0.5),
                 **kwargs):
        """Creates the PPO trainer.

    Args:
      train_env: gym.Env to use for training.
      eval_env: gym.Env to use for evaluation.
      output_dir: Output dir.
      policy_and_value_model: Function defining the policy and value network,
        without the policy and value heads.
      policy_and_value_optimizer: Function defining the optimizer.
      policy_and_value_two_towers: Whether to use two separate models as the
        policy and value networks. If False, share their parameters.
      n_optimizer_steps: Number of optimizer steps.
      print_every_optimizer_steps: How often to log during the policy
        optimization process.
      target_kl: Policy iteration early stopping. Set to infinity to disable
        early stopping.
      boundary: We pad trajectories at integer multiples of this number.
      max_timestep: If set to an integer, maximum number of time-steps in
        a trajectory. Used in the collect procedure.
      max_timestep_eval: If set to an integer, maximum number of time-steps in
        an evaluation trajectory. Used in the collect procedure.
      random_seed: Random seed.
      gamma: Reward discount factor.
      lambda_: N-step TD-error discount factor in GAE.
      c1: Value loss coefficient.
      c2: Entropy loss coefficient.
      eval_every_n: How frequently to eval the policy.
      done_frac_for_policy_save: Fraction of the trajectories that should be
        done to checkpoint the policy.
      n_evals: Number of times to evaluate.
      len_history_for_policy: How much of history to give to the policy.
      eval_temperatures: Sequence of temperatures to try for categorical
        sampling during evaluation.
      **kwargs: Additional keyword arguments passed to the base class.
    """
        # Set in base class constructor.
        self._train_env = None
        self._should_reset = None

        super(PPO, self).__init__(train_env, eval_env, output_dir, **kwargs)

        self._n_optimizer_steps = n_optimizer_steps
        self._print_every_optimizer_steps = print_every_optimizer_steps
        self._target_kl = target_kl
        self._boundary = boundary
        self._max_timestep = max_timestep
        self._max_timestep_eval = max_timestep_eval
        self._gamma = gamma
        self._lambda_ = lambda_
        self._c1 = c1
        self._c2 = c2
        self._eval_every_n = eval_every_n
        self._done_frac_for_policy_save = done_frac_for_policy_save
        self._n_evals = n_evals
        self._len_history_for_policy = len_history_for_policy
        self._eval_temperatures = eval_temperatures

        assert isinstance(self.train_env.action_space, gym.spaces.Discrete)
        n_actions = self.train_env.action_space.n

        # Batch Observations Shape = [1, 1] + OBS, because we will eventually call
        # policy and value networks on shape [B, T] + OBS.
        batch_observations_shape = (
            (1, 1) + self.train_env.observation_space.shape)
        observations_dtype = self.train_env.observation_space.dtype

        self._rng = trax.get_random_number_generator_and_set_seed(random_seed)
        self._rng, key1 = jax_random.split(self._rng, num=2)

        # Initialize the policy and value network.
        policy_and_value_net_params, self._model_state, policy_and_value_net_apply = (
            ppo.policy_and_value_net(
                rng_key=key1,
                batch_observations_shape=batch_observations_shape,
                observations_dtype=observations_dtype,
                n_actions=n_actions,
                bottom_layers_fn=policy_and_value_model,
                two_towers=policy_and_value_two_towers,
            ))
        self._policy_and_value_net_apply = jit(policy_and_value_net_apply)

        # Initialize the optimizer.
        (policy_and_value_opt_state, self._policy_and_value_opt_update,
         self._policy_and_value_get_params) = ppo.optimizer_fn(
             policy_and_value_optimizer, policy_and_value_net_params)

        # Maybe restore the optimization state. If there is nothing to restore, then
        # iteration = 0 and policy_and_value_opt_state is returned as is.
        (restored, self._policy_and_value_opt_state, self._model_state,
         self._epoch, self._total_opt_step) = ppo.maybe_restore_opt_state(
             output_dir, policy_and_value_opt_state, self._model_state)

        if restored:
            logging.info("Restored parameters from iteration [%d]",
                         self._epoch)
            # We should start from the next iteration.
            self._epoch += 1

        # Create summary writers and history.
        self._train_sw = jaxboard.SummaryWriter(
            os.path.join(self._output_dir, "train"))
        self._timing_sw = jaxboard.SummaryWriter(
            os.path.join(self._output_dir, "timing"))
        self._eval_sw = jaxboard.SummaryWriter(
            os.path.join(self._output_dir, "eval"))

        self._n_trajectories_done = 0

        self._last_saved_at = 0
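
The (1, 1) + OBS shape used for network initialization above mirrors the [B, T] + OBS layout in which observations are later fed to the policy and value networks: batch and time dimensions first, then the per-step observation shape. Here is a small numpy sketch of that convention; the observation shape and batch/time sizes are arbitrary illustrative values.

# [B, T] + OBS batching convention (arbitrary example shapes).
import numpy as np

obs_shape = (84, 84, 4)                    # e.g. a stack of Atari frames
init_shape = (1, 1) + obs_shape            # dummy [B=1, T=1] batch used at init
batch = np.zeros((2, 3) + obs_shape)       # a [B=2, T=3] batch at training time
assert batch.shape[2:] == obs_shape        # trailing dims are the observation
print(init_shape)                          # (1, 1, 84, 84, 4)
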
Example No. 5
    def __init__(self,
                 train_env,
                 eval_env,
                 output_dir,
                 policy_and_value_model=trax_models.FrameStackMLP,
                 policy_and_value_optimizer=functools.partial(
                     trax_opt.Adam, learning_rate=1e-3),
                 policy_and_value_two_towers=False,
                 policy_and_value_vocab_size=None,
                 n_optimizer_steps=N_OPTIMIZER_STEPS,
                 optimizer_batch_size=64,
                 print_every_optimizer_steps=PRINT_EVERY_OPTIMIZER_STEP,
                 target_kl=0.01,
                 boundary=20,
                 max_timestep=100,
                 max_timestep_eval=20000,
                 random_seed=None,
                 gamma=GAMMA,
                 lambda_=LAMBDA,
                 c1=1.0,
                 c2=0.01,
                 eval_every_n=1000,
                 save_every_n=1000,
                 done_frac_for_policy_save=0.5,
                 n_evals=1,
                 len_history_for_policy=4,
                 eval_temperatures=(1.0, 0.5),
                 separate_eval=True,
                 init_policy_from_world_model_output_dir=None,
                 **kwargs):
        """Creates the PPO trainer.

    Args:
      train_env: gym.Env to use for training.
      eval_env: gym.Env to use for evaluation.
      output_dir: Output dir.
      policy_and_value_model: Function defining the policy and value network,
        without the policy and value heads.
      policy_and_value_optimizer: Function defining the optimizer.
      policy_and_value_two_towers: Whether to use two separate models as the
        policy and value networks. If False, share their parameters.
      policy_and_value_vocab_size: Vocabulary size of a policy and value network
        operating on serialized representation. If None, use raw continuous
        representation.
      n_optimizer_steps: Number of optimizer steps.
      optimizer_batch_size: Batch size of an optimizer step.
      print_every_optimizer_steps: How often to log during the policy
        optimization process.
      target_kl: Policy iteration early stopping. Set to infinity to disable
        early stopping.
      boundary: We pad trajectories at integer multiples of this number.
      max_timestep: If set to an integer, maximum number of time-steps in a
        trajectory. Used in the collect procedure.
      max_timestep_eval: If set to an integer, maximum number of time-steps in
        an evaluation trajectory. Used in the collect procedure.
      random_seed: Random seed.
      gamma: Reward discount factor.
      lambda_: N-step TD-error discount factor in GAE.
      c1: Value loss coefficient.
      c2: Entropy loss coefficient.
      eval_every_n: How frequently to eval the policy.
      save_every_n: How frequently to save the policy.
      done_frac_for_policy_save: Fraction of the trajectories that should be
        done to checkpoint the policy.
      n_evals: Number of times to evaluate.
      len_history_for_policy: How much of history to give to the policy.
      eval_temperatures: Sequence of temperatures to try for categorical
        sampling during evaluation.
      separate_eval: Whether to run separate evaluation using a set of
        temperatures. If False, the training reward is reported as evaluation
        reward with temperature 1.0.
      init_policy_from_world_model_output_dir: Model output dir for initializing
        the policy. If None, initialize randomly.
      **kwargs: Additional keyword arguments passed to the base class.
    """
        # Set in base class constructor.
        self._train_env = None
        self._should_reset = None

        super(PPO, self).__init__(train_env, eval_env, output_dir, **kwargs)

        self._n_optimizer_steps = n_optimizer_steps
        self._optimizer_batch_size = optimizer_batch_size
        self._print_every_optimizer_steps = print_every_optimizer_steps
        self._target_kl = target_kl
        self._boundary = boundary
        self._max_timestep = max_timestep
        self._max_timestep_eval = max_timestep_eval
        self._gamma = gamma
        self._lambda_ = lambda_
        self._c1 = c1
        self._c2 = c2
        self._eval_every_n = eval_every_n
        self._save_every_n = save_every_n
        self._done_frac_for_policy_save = done_frac_for_policy_save
        self._n_evals = n_evals
        self._len_history_for_policy = len_history_for_policy
        self._eval_temperatures = eval_temperatures
        self._separate_eval = separate_eval

        action_space = self.train_env.action_space
        assert isinstance(action_space,
                          (gym.spaces.Discrete, gym.spaces.MultiDiscrete))
        if isinstance(action_space, gym.spaces.Discrete):
            n_actions = action_space.n
            n_controls = 1
        else:
            (n_controls, ) = action_space.nvec.shape
            assert n_controls > 0
            assert onp.min(action_space.nvec) == onp.max(action_space.nvec), (
                "Every control must have the same number of actions.")
            n_actions = action_space.nvec[0]
        self._n_actions = n_actions
        self._n_controls = n_controls

        self._rng = trax.get_random_number_generator_and_set_seed(random_seed)
        self._rng, key1 = jax_random.split(self._rng, num=2)

        vocab_size = policy_and_value_vocab_size
        self._serialized_sequence_policy = vocab_size is not None
        if self._serialized_sequence_policy:
            self._serialization_kwargs = self._init_serialization(vocab_size)
        else:
            self._serialization_kwargs = {}

        # Initialize the policy and value network.
        policy_and_value_net = ppo.policy_and_value_net(
            n_actions=n_actions,
            n_controls=n_controls,
            vocab_size=vocab_size,
            bottom_layers_fn=policy_and_value_model,
            two_towers=policy_and_value_two_towers,
        )
        self._policy_and_value_net_apply = jit(policy_and_value_net)
        (batch_obs_shape, obs_dtype) = self._batch_obs_shape_and_dtype
        policy_and_value_net_params, self._model_state = (
            policy_and_value_net.initialize_once(batch_obs_shape, obs_dtype,
                                                 key1))
        if init_policy_from_world_model_output_dir is not None:
            policy_and_value_net_params = ppo.init_policy_from_world_model_checkpoint(
                policy_and_value_net_params,
                init_policy_from_world_model_output_dir)

        # Initialize the optimizer.
        (policy_and_value_opt_state, self._policy_and_value_opt_update,
         self._policy_and_value_get_params) = ppo.optimizer_fn(
             policy_and_value_optimizer, policy_and_value_net_params)

        # Restore the optimizer state.
        self._policy_and_value_opt_state = policy_and_value_opt_state
        self._epoch = 0
        self._total_opt_step = 0
        self.update_optimization_state(
            output_dir, policy_and_value_opt_state=policy_and_value_opt_state)

        # Create summary writers and history.
        self._train_sw = jaxboard.SummaryWriter(
            os.path.join(self._output_dir, "train"))
        self._timing_sw = jaxboard.SummaryWriter(
            os.path.join(self._output_dir, "timing"))
        self._eval_sw = jaxboard.SummaryWriter(
            os.path.join(self._output_dir, "eval"))

        self._n_trajectories_done = 0

        self._last_saved_at = 0
        if self._async_mode:
            logging.info(
                "Saving model on startup to have a model policy file.")
            self.save()

        self._rewards_to_actions = self._init_rewards_to_actions()
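
The action-space handling above distinguishes a single Discrete control from a MultiDiscrete space with several identical controls. A standalone sketch of the same logic, assuming only that the gym package is available:

# Discrete vs. MultiDiscrete action spaces -> (n_actions, n_controls).
import gym
import numpy as onp


def n_actions_and_controls(action_space):
    if isinstance(action_space, gym.spaces.Discrete):
        return action_space.n, 1
    assert isinstance(action_space, gym.spaces.MultiDiscrete)
    (n_controls,) = action_space.nvec.shape
    # Every control must have the same number of actions, as asserted above.
    assert onp.min(action_space.nvec) == onp.max(action_space.nvec)
    return int(action_space.nvec[0]), n_controls


print(n_actions_and_controls(gym.spaces.Discrete(6)))            # (6, 1)
print(n_actions_and_controls(gym.spaces.MultiDiscrete([4, 4])))  # (4, 2)
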
Example No. 6
def training_loop(
    env=None,
    epochs=EPOCHS,
    policy_and_value_net_fn=None,
    policy_and_value_optimizer_fn=None,
    batch_size=BATCH_TRAJECTORIES,
    n_optimizer_steps=N_OPTIMIZER_STEPS,
    print_every_optimizer_steps=PRINT_EVERY_OPTIMIZER_STEP,
    target_kl=0.01,
    boundary=20,
    max_timestep=None,
    max_timestep_eval=20000,
    random_seed=None,
    gamma=GAMMA,
    lambda_=LAMBDA,
    epsilon=EPSILON,
    c1=1.0,
    c2=0.01,
    output_dir=None,
    eval_every_n=1000,
    eval_env=None,
    done_frac_for_policy_save=0.5,
    enable_early_stopping=True,
    env_name=None,
    n_evals=1,
    len_history_for_policy=4,
):
    """Runs the training loop for PPO, with fixed policy and value nets."""
    assert env
    assert output_dir
    assert env_name

    gfile.makedirs(output_dir)

    # Create summary writers and history.
    train_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "train"))
    timing_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "timing"))
    eval_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "eval"))

    train_sw.text("env_name", env_name)
    timing_sw.text("env_name", env_name)
    eval_sw.text("env_name", env_name)

    jax_rng_key = trax.get_random_number_generator_and_set_seed(random_seed)

    # Batch Observations Shape = [1, 1] + OBS, because we will eventually call
    # policy and value networks on shape [B, T] + OBS.
    batch_observations_shape = (1, 1) + env.observation_space.shape
    observations_dtype = env.observation_space.dtype

    assert isinstance(env.action_space, gym.spaces.Discrete)
    n_actions = env.action_space.n

    jax_rng_key, key1 = jax_random.split(jax_rng_key, num=2)

    # Initialize the policy and value network.
    policy_and_value_net_params, policy_and_value_net_apply = (
        policy_and_value_net_fn(key1, batch_observations_shape,
                                observations_dtype, n_actions))

    # Maybe restore the policy params. If there is nothing to restore, then
    # iteration = 0 and policy_and_value_net_params are returned as is.
    restore, policy_and_value_net_params, iteration = (maybe_restore_params(
        output_dir, policy_and_value_net_params))

    if restore:
        logging.info("Restored parameters from iteration [%d]", iteration)
        # We should start from the next iteration.
        iteration += 1

    policy_and_value_net_apply = jit(policy_and_value_net_apply)

    # Initialize the optimizers.
    policy_and_value_optimizer = (
        policy_and_value_optimizer_fn(policy_and_value_net_params))
    (policy_and_value_opt_state, policy_and_value_opt_update,
     policy_and_value_get_params) = policy_and_value_optimizer

    n_trajectories_done = 0
    last_saved_at = 0

    logging.info("Starting the PPO training loop.")
    for i in range(iteration, epochs):
        epoch_start_time = time.time()

        # Params we'll use to collect the trajectories.
        policy_and_value_net_params = policy_and_value_get_params(
            policy_and_value_opt_state)

        # A function to get the policy and value predictions.
        def get_predictions(observations, rng=None):
            """Returns log-probs, value predictions and key back."""
            key, key1 = jax_random.split(rng, num=2)

            log_probs, value_preds = policy_and_value_net_apply(
                observations, policy_and_value_net_params, rng=key1)

            return log_probs, value_preds, key

        # Evaluate the policy.
        policy_eval_start_time = time.time()
        if ((i + 1) % eval_every_n == 0) or (i == epochs - 1):
            jax_rng_key, key = jax_random.split(jax_rng_key, num=2)

            logging.vlog(1, "Epoch [% 6d] evaluating policy.", i)

            avg_reward, avg_reward_unclipped = evaluate_policy(
                eval_env,
                get_predictions,
                max_timestep=max_timestep_eval,
                n_evals=n_evals,
                len_history_for_policy=len_history_for_policy,
                rng=key)
            for k, v in avg_reward.items():
                eval_sw.scalar("eval/mean_reward/%s" % k, v, step=i)
                logging.info(
                    "Epoch [% 6d] Policy Evaluation (clipped) [%s] = %10.2f",
                    i, k, v)
            for k, v in avg_reward_unclipped.items():
                eval_sw.scalar("eval/mean_reward_unclipped/%s" % k, v, step=i)
                logging.info(
                    "Epoch [% 6d] Policy Evaluation (unclipped) [%s] = %10.2f",
                    i, k, v)
        policy_eval_time = get_time(policy_eval_start_time)

        trajectory_collection_start_time = time.time()
        logging.vlog(1, "Epoch [% 6d] collecting trajectories.", i)
        jax_rng_key, key = jax_random.split(jax_rng_key)
        trajs, n_done, timing_info = collect_trajectories(
            env,
            policy_fn=get_predictions,
            n_trajectories=batch_size,
            max_timestep=max_timestep,
            rng=key,
            len_history_for_policy=len_history_for_policy,
            reset=(i == 0) or restore,
            epsilon=(10.0 / (i + 10.0)))  # this is a different epsilon.
        trajectory_collection_time = get_time(trajectory_collection_start_time)

        logging.vlog(1, "Collecting trajectories took %0.2f msec.",
                     trajectory_collection_time)

        avg_reward = float(sum(np.sum(traj[2]) for traj in trajs)) / len(trajs)
        max_reward = max(np.sum(traj[2]) for traj in trajs)
        min_reward = min(np.sum(traj[2]) for traj in trajs)

        train_sw.scalar("train/mean_reward", avg_reward, step=i)

        logging.vlog(1,
                     "Rewards avg=[%0.2f], max=[%0.2f], min=[%0.2f], all=%s",
                     avg_reward, max_reward, min_reward,
                     [float(np.sum(traj[2])) for traj in trajs])

        logging.vlog(
            1, "Trajectory Length average=[%0.2f], max=[%0.2f], min=[%0.2f]",
            float(sum(len(traj[0]) for traj in trajs)) / len(trajs),
            max(len(traj[0]) for traj in trajs),
            min(len(traj[0]) for traj in trajs))
        logging.vlog(2, "Trajectory Lengths: %s",
                     [len(traj[0]) for traj in trajs])

        padding_start_time = time.time()
        (_, reward_mask, padded_observations, padded_actions,
         padded_rewards) = pad_trajectories(trajs, boundary=boundary)
        padding_time = get_time(padding_start_time)

        logging.vlog(1, "Padding trajectories took %0.2f msec.",
                     get_time(padding_start_time))
        logging.vlog(1, "Padded Observations' shape [%s]",
                     str(padded_observations.shape))
        logging.vlog(1, "Padded Actions' shape [%s]",
                     str(padded_actions.shape))
        logging.vlog(1, "Padded Rewards' shape [%s]",
                     str(padded_rewards.shape))

        # Calculate log-probabilities and value predictions of the trajectories.
        # We'll pass these to the loss functions so as to not get recomputed.

        # NOTE:
        # There is a slight problem here, if the policy network contains
        # stochasticity in the log-probabilities (ex: dropout), then calculating
        # these again here is not going to be correct and should be done in the
        # collect function.

        log_prob_recompute_start_time = time.time()
        jax_rng_key, key = jax_random.split(jax_rng_key)
        log_probabs_traj, value_predictions_traj, _ = get_predictions(
            padded_observations, rng=key)
        log_prob_recompute_time = get_time(log_prob_recompute_start_time)

        # Some assertions.
        B, T = padded_actions.shape  # pylint: disable=invalid-name
        assert (B, T) == padded_rewards.shape
        assert (B, T) == reward_mask.shape
        assert (B, T + 1) == padded_observations.shape[:2]
        assert (B, T +
                1) + env.observation_space.shape == padded_observations.shape

        # Linear annealing from 0.1 to 0.0:
        # epsilon_schedule = (
        #     epsilon if epochs == 1 else epsilon * (1.0 - i / (epochs - 1)))

        # Constant epsilon.
        epsilon_schedule = epsilon

        # Compute value and ppo losses.
        jax_rng_key, key1 = jax_random.split(jax_rng_key, num=2)
        logging.vlog(2, "Starting to compute P&V loss.")
        loss_compute_start_time = time.time()
        cur_combined_loss, cur_ppo_loss, cur_value_loss, entropy_bonus = (
            combined_loss(policy_and_value_net_params,
                          log_probabs_traj,
                          value_predictions_traj,
                          policy_and_value_net_apply,
                          padded_observations,
                          padded_actions,
                          padded_rewards,
                          reward_mask,
                          gamma=gamma,
                          lambda_=lambda_,
                          epsilon=epsilon_schedule,
                          c1=c1,
                          c2=c2,
                          rng=key1))
        loss_compute_time = get_time(loss_compute_start_time)
        logging.vlog(
            1,
            "Calculating P&V loss [%10.2f(%10.2f, %10.2f, %10.2f)] took %0.2f msec.",
            cur_combined_loss, cur_value_loss, cur_ppo_loss, entropy_bonus,
            get_time(loss_compute_start_time))

        jax_rng_key, key1 = jax_random.split(jax_rng_key, num=2)
        logging.vlog(1, "Policy and Value Optimization")
        optimization_start_time = time.time()
        keys = jax_random.split(key1, num=n_optimizer_steps)
        for j in range(n_optimizer_steps):
            k1, k2, k3 = jax_random.split(keys[j], num=3)
            t = time.time()
            # Update the optimizer state.
            policy_and_value_opt_state = policy_and_value_opt_step(
                j,
                policy_and_value_opt_state,
                policy_and_value_opt_update,
                policy_and_value_get_params,
                policy_and_value_net_apply,
                log_probabs_traj,
                value_predictions_traj,
                padded_observations,
                padded_actions,
                padded_rewards,
                reward_mask,
                c1=c1,
                c2=c2,
                gamma=gamma,
                lambda_=lambda_,
                epsilon=epsilon_schedule,
                rng=k1)

            # Compute the approx KL for early stopping.
            new_policy_and_value_net_params = policy_and_value_get_params(
                policy_and_value_opt_state)

            log_probab_actions_new, _ = policy_and_value_net_apply(
                padded_observations, new_policy_and_value_net_params, rng=k2)

            approx_kl = approximate_kl(log_probab_actions_new,
                                       log_probabs_traj, reward_mask)

            early_stopping = enable_early_stopping and approx_kl > 1.5 * target_kl
            if early_stopping:
                logging.vlog(
                    1,
                    "Early stopping policy and value optimization at iter: %d, "
                    "with approx_kl: %0.2f", j, approx_kl)
                # We don't return right away; we want the code below to
                # execute on the last iteration.

            t2 = time.time()
            if (((j + 1) % print_every_optimizer_steps == 0)
                    or (j == n_optimizer_steps - 1) or early_stopping):
                # Compute and log the loss.
                (loss_combined, loss_ppo, loss_value,
                 entropy_bonus) = (combined_loss(
                     new_policy_and_value_net_params,
                     log_probabs_traj,
                     value_predictions_traj,
                     policy_and_value_net_apply,
                     padded_observations,
                     padded_actions,
                     padded_rewards,
                     reward_mask,
                     gamma=gamma,
                     lambda_=lambda_,
                     epsilon=epsilon_schedule,
                     c1=c1,
                     c2=c2,
                     rng=k3))
                logging.vlog(
                    1, "One Policy and Value grad desc took: %0.2f msec",
                    get_time(t, t2))
                logging.vlog(
                    1, "Combined Loss(value, ppo, entropy_bonus) [%10.2f] ->"
                    " [%10.2f(%10.2f,%10.2f,%10.2f)]", cur_combined_loss,
                    loss_combined, loss_value, loss_ppo, entropy_bonus)

            if early_stopping:
                break

        optimization_time = get_time(optimization_start_time)

        logging.vlog(
            1, "Total Combined Loss reduction [%0.2f]%%",
            (100 *
             (cur_combined_loss - loss_combined) / np.abs(cur_combined_loss)))

        # Save parameters whenever at least a done_frac_for_policy_save fraction
        # of a batch's worth of trajectories has finished with `done` (as opposed
        # to completed, which also includes truncated trajectories). Don't save
        # too frequently: enforce a minimum gap between saves. Always save on
        # the last iteration.
        policy_save_start_time = time.time()
        n_trajectories_done += n_done
        # TODO(afrozm): Refactor to trax.save_state.
        if (((n_trajectories_done >= done_frac_for_policy_save * batch_size)
             and (i - last_saved_at > eval_every_n) and
             (((i + 1) % eval_every_n == 0))) or (i == epochs - 1)):
            logging.vlog(1, "Epoch [% 6d] saving model.", i)
            old_model_files = gfile.glob(
                os.path.join(output_dir, "model-??????.pkl"))
            params_file = os.path.join(output_dir, "model-%06d.pkl" % i)
            with gfile.GFile(params_file, "wb") as f:
                pickle.dump(policy_and_value_net_params, f)
            # Remove the old model files.
            for path in old_model_files:
                gfile.remove(path)
            # Reset this number.
            n_trajectories_done = 0
            last_saved_at = i
        policy_save_time = get_time(policy_save_start_time)

        epoch_time = get_time(epoch_start_time)

        logging.info(
            "Epoch [% 6d], Reward[min, max, avg] [%5.2f,%5.2f,%5.2f], Combined"
            " Loss(value, ppo, entropy) [%2.5f(%2.5f,%2.5f,%2.5f)]", i,
            min_reward, max_reward, avg_reward, loss_combined, loss_value,
            loss_ppo, entropy_bonus)

        timing_dict = {
            "epoch": epoch_time,
            "policy_eval": policy_eval_time,
            "trajectory_collection": trajectory_collection_time,
            "padding": padding_time,
            "log_prob_recompute": log_prob_recompute_time,
            "loss_compute": loss_compute_time,
            "optimization": optimization_time,
            "policy_save": policy_save_time,
        }

        timing_dict.update(timing_info)

        for k, v in timing_dict.items():
            timing_sw.scalar("timing/%s" % k, v, step=i)

        max_key_len = max(len(k) for k in timing_dict)
        timing_info_list = [
            "%s : % 10.2f" % (k.rjust(max_key_len + 1), v)
            for k, v in sorted(timing_dict.items())
        ]
        logging.info("Epoch [% 6d], Timings: \n%s", i,
                     "\n".join(timing_info_list))

        # Reset restore.
        restore = False

        # Flush summary writers once in a while.
        if (i + 1) % 1000 == 0 or i == epochs - 1:
            train_sw.flush()
            timing_sw.flush()
            eval_sw.flush()
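
The optimizer loop above stops early when the approximate KL divergence between the new and the old policy exceeds 1.5 * target_kl, measured only over valid (unpadded) time-steps. The sketch below illustrates that test with a masked exact categorical KL; the estimator used by approximate_kl in the code above may differ.

# Masked KL-based early-stopping test (illustrative estimator).
import numpy as np


def masked_mean_kl(log_p_old, log_p_new, mask):
    """log_p_* have shape [B, T, n_actions]; mask has shape [B, T]."""
    per_step_kl = np.sum(np.exp(log_p_old) * (log_p_old - log_p_new), axis=-1)
    return np.sum(per_step_kl * mask) / np.sum(mask)


target_kl = 0.01
log_p_old = np.log(np.full((2, 3, 4), 0.25))             # uniform old policy
log_p_new = np.log(np.array([0.4, 0.2, 0.2, 0.2]) * np.ones((2, 3, 1)))
mask = np.array([[1.0, 1.0, 0.0], [1.0, 1.0, 1.0]])      # last step of traj 0 is padding
approx_kl = masked_mean_kl(log_p_old, log_p_new, mask)
print(approx_kl, approx_kl > 1.5 * target_kl)            # ~0.05, True -> stop early
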
Example No. 7
  def __init__(self, model, loss_fn, optimizer, lr_schedule, inputs, output_dir,
               random_seed=None, n_devices=None, save_steps=None):
    if save_steps is None:
      save_steps = []
    self._save_steps = save_steps
    device_count = jax.lib.xla_bridge.device_count()
    n_devices = n_devices or device_count
    # TODO(lukaszkaiser): remove this restriction when possible.
    if n_devices != device_count:
      raise ValueError("Jax cannot work yet with n_devices != all devices: "
                       "%d != %d" % (n_devices, device_count))
    self._n_devices = n_devices
    rng = get_random_number_generator_and_set_seed(random_seed)
    self._output_dir = output_dir
    gfile.makedirs(output_dir)
    # Create summary writers and history.
    self._train_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "train"))
    self._eval_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "eval"))

    # Create input streams.
    inputs = inputs(n_devices)
    self._inputs = inputs
    self._train_stream = inputs.train_stream()

    # Setup optimizer and model.
    state = restore_state(output_dir)
    history = state.history
    self._lr_fn = lr_schedule(history)
    opt = optimizer(self._lr_fn)

    model_train = model(mode="train")
    model_predict_eval = model(mode="eval")

    # Setup state.
    step = state.step or 0
    rng, init_rng = jax_random.split(rng)
    self._rngs = jax_random.split(rng, n_devices)
    first_shape = inputs.input_shape[0]
    # If the inputs are a tuple/list, add [None] (batch) to each element.
    if isinstance(first_shape, (list, tuple)):
      model_input_shape = tuple(
          tuple([None] + list(shape)) for shape in inputs.input_shape)
    else:  # Otherwise just add [None] to the input shape.
      model_input_shape = tuple([None] + list(inputs.input_shape))
    # Change all None to 1 in input shape.
    model_input_shape = layers.nested_map(
        model_input_shape, lambda x: x if x else 1)
    if state.params:
      params = state.params[0]
      opt_state = state.params
    else:
      params = model_train.initialize(
          model_input_shape, inputs.input_dtype, init_rng)
      opt_state = (params, opt.tree_init(params))
    if n_devices > 1:
      replicate = lambda x: numpy.broadcast_to(x, (n_devices,) + x.shape)
      opt_state = layers.nested_map(opt_state, replicate)

    # jit model_predict and update so they're fast
    self._jit_model_predict_eval = _jit_predict_fn(
        model_predict_eval, n_devices)
    self._jit_update_fn = _jit_update_fn(model_train, loss_fn, opt, n_devices)

    self._step = step
    self._model_train = model_train
    self._model_predict_eval = model_predict_eval
    self._loss_fn = loss_fn
    self._optimizer = optimizer
    self._opt_state = opt_state
    self._history = history
    self._lr_schedule = lr_schedule
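
The shape handling above prepends a batch dimension of None to each input shape (whether the inputs are a single array or a tuple of arrays) and then replaces every None with 1, so the model can be initialized on a concrete dummy shape. A standalone sketch of the same logic, using plain Python in place of layers.nested_map:

# Input-shape preparation sketch (plain Python stand-in for layers.nested_map).
def model_init_shape(input_shape):
    first = input_shape[0]
    if isinstance(first, (list, tuple)):   # a tuple/list of per-input shapes
        with_batch = tuple(tuple([None] + list(s)) for s in input_shape)
    else:                                  # a single input shape
        with_batch = tuple([None] + list(input_shape))

    def fill(x):                           # replace None (unknown batch) with 1
        if isinstance(x, tuple):
            return tuple(fill(e) for e in x)
        return x if x else 1

    return fill(with_batch)


print(model_init_shape((32, 32, 3)))             # (1, 32, 32, 3)
print(model_init_shape(((None, 128), (None,))))  # ((1, 1, 128), (1, 1))
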
Example No. 8
def train(output_dir,
          model=gin.REQUIRED,
          loss_fn=loss,
          inputs=trax_inputs.inputs,
          optimizer=trax_opt.SM3,
          lr_schedule=lr.MultifactorSchedule,
          train_steps=1000,
          save_steps=None,
          eval_steps=10,
          eval_frequency=100,
          n_devices=None,
          random_seed=None,
          run_debug_step=False,
          save_graphs=True,
          save_backward_graph=False):
  """Train the model on the inputs.

  Args:
    output_dir: Directory where to put the logs and checkpoints.
    model: The model to train as a callable returning 2 callables, an init_fn
      and apply_fn.
    loss_fn: callable with signature: params, trax.inputs.Inputs, model, rng
      -> loss.
    inputs: callable returning trax.inputs.Inputs.
    optimizer: The optimizer (see optimizers/base.py for signature).
    lr_schedule: A learning rate schedule as a function that takes history and
      returns a function from step to learning rate (a float).
    train_steps: int, total number of training steps.
    save_steps: list of integers. Keep a model file at each of the supplied save
      steps.
    eval_steps: int, num of steps per evaluation. If None or 0, eval disabled.
    eval_frequency: int, how often to run evaluation (every eval_frequency
      steps). If None or 0, eval disabled.
    n_devices: how many devices to use (if None, default, use all available)
    random_seed: the random seed to use; time/os dependent if None (default).
    run_debug_step: bool, if True, will run the model and loss without @jit for
      one step.
    save_graphs: bool, if True, save computation graph to file.
    save_backward_graph: bool, if True, save backward graph to file too.

  Returns:
    trax.State
  """
  if save_steps is None:
    save_steps = []
  device_count = jax.lib.xla_bridge.device_count()
  n_devices = n_devices or device_count
  # TODO(lukaszkaiser): remove this restriction when possible.
  if n_devices != device_count:
    raise ValueError("Jax cannot work yet with n_devices != all devices: "
                     "%d != %d" % (n_devices, device_count))
  rng = get_random_number_generator_and_set_seed(random_seed)
  gfile.makedirs(output_dir)
  # Create summary writers and history.
  train_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "train"))
  eval_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "eval"))

  inputs = inputs(n_devices)

  # Setup optimizer and model
  state = restore_state(output_dir)
  history = state.history
  lr_fn = lr_schedule(history)
  opt = optimizer(lr_fn)

  model_train = layers.Serial(model(mode="train"))
  model_predict_eval = layers.Serial(model(mode="eval"))

  # Setup state
  step = state.step or 0
  rng, init_rng = jax_random.split(rng)
  rngs = jax_random.split(rng, n_devices)
  first_shape = inputs.input_shape[0]
  # If the inputs are a tuple/list, add [-1] (batch) to each element.
  if isinstance(first_shape, (list, tuple)):
    model_input_shape = tuple(
        [tuple([-1] + list(shape)) for shape in inputs.input_shape])
  else:  # Otherwise just add [-1] to the input shape.
    model_input_shape = tuple([-1] + list(inputs.input_shape))
  if state.params:
    params = state.params[0]
    opt_state = state.params
  else:
    params = model_train.initialize(model_input_shape, init_rng)
    opt_state = (params, opt.tree_init(params))
  if n_devices > 1:
    replicate = lambda x: numpy.broadcast_to(x, (n_devices,) + x.shape)
    opt_state = layers.nested_map(opt_state, replicate)

  # jit model_predict and update so they're fast
  jit_model_predict_eval = _jit_predict_fn(model_predict_eval, n_devices)
  jit_update_fn = _jit_update_fn(model_train, loss_fn, opt, n_devices)

  train_stream = inputs.train_stream()
  epoch_steps = [train_steps]  # Only training if eval_frequency is 0 or None.
  if eval_frequency and eval_steps > 0:
    epoch_steps = itertools.chain([1,  # first epoch only 1 step
                                   eval_frequency - 1],
                                  itertools.repeat(eval_frequency))
  step_log(step, "Starting training using %d devices" % n_devices)

  # A non-compiled debug step makes it easier to find problems in models.
  if run_debug_step:
    debug_loss = loss_fn(params, next(train_stream), model_train, rng)
    step_log(step, "Debug step loss %.8f" % debug_loss)

  for epoch, epoch_steps in epochs(train_steps, epoch_steps):
    # Log separator
    print()

    # Timer
    start_time = time.time()

    for _ in range(epoch_steps):
      # Train
      next_train_batch = next(train_stream)
      if n_devices > 1:  # TODO(lukaszkaiser): use everywhere when possible.
        next_train_batch = reshape_by_device(next_train_batch, n_devices)
      opt_state, rngs = jit_update_fn(step, opt_state, next_train_batch, rngs)
      step += 1

      if step in save_steps:
        _save_replicated(opt_state, step, history, n_devices, output_dir, True)

      # LR log
      if step == 1 or step % 10 == 0:
        train_sw.scalar("training/learning rate",
                        lr_fn(step), step=step)

    # Timer
    epoch_time = time.time() - start_time
    step_log(step, "Ran %d train steps in %0.2f secs" %
             (epoch_steps, epoch_time))
    if epoch_steps > 1:
      train_sw.scalar("training/steps per second",
                      epoch_steps / epoch_time, step=step)

    # Print number of parameters
    if step == 1:
      sizes = layers.sizes(opt_state[0])
      if n_devices > 1:
        unreplicate = lambda x: x.mean(0)
        single_params = layers.nested_map(opt_state[0], unreplicate)
        sizes = layers.sizes(single_params)
      total_size = layers.nested_reduce(sizes, sum)
      step_log(step, "Total trainable parameters size: %d" % total_size)

    # Evaluate in parallel
    evaluate_train_and_eval(
        step=step,
        inputs=inputs,
        predict_fn=functools.partial(jit_model_predict_eval,
                                     params=opt_state[0]),
        eval_steps=eval_steps,
        rng=rng,
        train_sw=train_sw,
        eval_sw=eval_sw,
        history=history)

    # Save computation graph (single-device only for now).
    if save_graphs and step == 1 and n_devices == 1:
      params = opt_state[0]
      # Dump computation graphs to files.
      forward_computation = jax.xla_computation(model_predict_eval)(
          next_train_batch[0], params=params, rng=rng)
      with gfile.GFile(os.path.join(output_dir, "forward.txt"), "w") as f:
        f.write(forward_computation.GetHloText())
      with gfile.GFile(os.path.join(output_dir, "forward.dot"), "w") as f:
        f.write(forward_computation.GetHloDotGraph())
      backward_computation = jax.xla_computation(jit_update_fn)(
          step, opt_state, next_train_batch, rngs)
      with gfile.GFile(os.path.join(output_dir, "backward.txt"), "w") as f:
        f.write(backward_computation.GetHloText())
      if save_backward_graph:  # Backward graphs can be large so we guard it.
        with gfile.GFile(os.path.join(output_dir, "backward.dot"), "w") as f:
          f.write(backward_computation.GetHloDotGraph())

    # Save state
    _save_replicated(opt_state, step, history, n_devices, output_dir, False)

    # Save Gin config
    # Gin only tracks the used parameters, so we save it after the first epoch.
    if epoch == 1:
      save_gin(output_dir, train_sw)

    # Update learning rate with new history
    old_lr_fn = lr_fn
    lr_fn = lr_schedule(history)
    if lr_fn != old_lr_fn:  # For performance, only jit if there is a change.
      opt = optimizer(lr_fn)
      jit_update_fn = _jit_update_fn(model_train, loss_fn, opt, n_devices)

    # Flush summary writers
    train_sw.flush()
    eval_sw.flush()

  step_log(step, "Training done")
  return State(params=opt_state, step=step, history=history)
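
The epoch_steps schedule built above (chain([1, eval_frequency - 1], repeat(eval_frequency))) makes the loop evaluate after the very first training step and then every eval_frequency steps. The sketch below shows how those epoch sizes unfold; the capping at train_steps is illustrative, and the real epochs() helper may differ.

# Evaluation schedule sketch: eval after step 1, then every eval_frequency steps.
import itertools

train_steps, eval_frequency = 25, 10
epoch_steps = itertools.chain([1, eval_frequency - 1],
                              itertools.repeat(eval_frequency))

done = 0
for steps in epoch_steps:
    steps = min(steps, train_steps - done)
    done += steps
    print("ran %d steps, total %d, evaluating" % (steps, done))
    if done >= train_steps:
        break
# Epochs of 1, 9, 10, 5 steps -> evaluations at steps 1, 10, 20, 25.
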
Example No. 9
def train(output_dir,
          model=gin.REQUIRED,
          loss_fun=loss,
          inputs=trax_inputs.inputs,
          optimizer=trax_opt.adam,
          lr_schedule=lr.MultifactorSchedule,
          train_steps=1000,
          eval_steps=10,
          eval_frequency=100,
          num_devices=None,
          random_seed=None,
          run_debug_step=False):
  """Train the model on the inputs.

  Args:
    output_dir: Directory where to put the logs and checkpoints.
    model: The model to train as a callable returning 2 callables, an init_fun
      and apply_fun.
    loss_fun: callable with signature: params, trax.inputs.Inputs, model, rng
      -> loss.
    inputs: callable returning trax.inputs.Inputs.
    optimizer: The optimizer as a callable taking a learning_rate callable and
      returning 2 callables, opt_init and opt_update.
    lr_schedule: A learning rate schedule as a function that takes history and
      returns a function from step to learning rate (a float).
    train_steps: int, total number of training steps.
    eval_steps: int, num of steps per evaluation. If None or 0, eval disabled.
    eval_frequency: int, how often to run evaluation (every eval_frequency
      steps). If None or 0, eval disabled.
    num_devices: how many devices to use (if None, default, use all available)
    random_seed: the random seed to use; time/os dependent if None (default).
    run_debug_step: bool, if True, will run the model and loss without @jit for
      one step.

  Returns:
    trax.State
  """
  num_devices = num_devices or jax.lib.xla_bridge.device_count()
  rng = get_random_number_generator_and_set_seed(random_seed)
  gfile.makedirs(output_dir)
  # Create summary writers and history.
  train_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "train"))
  eval_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "eval"))

  inputs = inputs(num_devices)

  # Setup optimizer and model
  state = restore_state(output_dir)
  history = state.history
  lr_fun = lr_schedule(history)
  opt_init, _ = optimizer(lr_fun)
  model_init, model_predict_train = model(mode="train")
  _, model_predict_eval = model(mode="eval")

  # Setup state
  step = state.step or 0
  rng, init_key = jax_random.split(rng)
  params_initializer = \
      lambda: model_init(init_key, [-1] + list(inputs.input_shape))[1]
  params = state.params or params_initializer()
  opt_state = opt_init(params)
  if num_devices > 1:  # TODO(lukaszkaiser): use everywhere when pmap is stable.
    opt_state = jax.replicate(opt_state)

  # jit model_predict and update so they're fast
  jit_model_predict_eval = _jit_predict_fun(model_predict_eval, num_devices)
  jit_update_fun = _jit_update_fun(
      model_predict_train, loss_fun, optimizer, lr_fun, num_devices)

  print()
  train_stream = inputs.train_stream()
  epoch_steps = [train_steps]  # Only training if eval_frequency is 0 or None.
  if eval_frequency and eval_steps > 0:
    epoch_steps = itertools.chain([1,  # first epoch only 1 step
                                   eval_frequency - 1],
                                  itertools.repeat(eval_frequency))
  step_log(step, "Starting training using %d devices" % num_devices)

  # A non-compiled debug step makes it easier to find problems in models.
  if run_debug_step:
    debug_loss = loss_fun(params, next(train_stream), model_predict_train, rng)
    step_log(step, "Debug step loss %.8f" % debug_loss)

  for epoch, epoch_steps in epochs(train_steps, epoch_steps):
    # Log separator
    print()

    # Timer
    start_time = time.time()

    for _ in range(epoch_steps):
      # Train
      next_train_batch = next(train_stream)
      if num_devices > 1:  # TODO(lukaszkaiser): use everywhere when possible.
        next_train_batch = reshape_by_device_pair(next_train_batch, num_devices)
      rng, subrng = jax_random.split(rng)
      opt_state = jit_update_fun(step, opt_state, next_train_batch, subrng)
      step += 1

      # LR log
      if step == 1 or step % 10 == 0:
        train_sw.scalar("training/learning rate",
                        lr_fun(step), step=step)

    # Timer
    epoch_time = time.time() - start_time
    step_log(step, "Ran %d train steps in %0.2f secs" %
             (epoch_steps, epoch_time))
    if epoch_steps > 1:
      train_sw.scalar("training/steps per second",
                      epoch_steps / epoch_time, step=step)

    # Evaluate
    params = trax_opt.get_params(opt_state)
    evaluate_train_and_eval(
        step=step,
        inputs=inputs,
        predict_fun=functools.partial(jit_model_predict_eval, params),
        eval_steps=eval_steps,
        rng=rng,
        train_sw=train_sw,
        eval_sw=eval_sw,
        history=history)

    # Save state
    save_state(State(params=params, step=step, history=history), output_dir)

    # Save Gin config
    # Gin only tracks the used parameters, so we save it after the first epoch.
    if epoch == 1:
      save_gin(output_dir, train_sw)

    # Update learning rate with new history
    old_lr_fun = lr_fun
    lr_fun = lr_schedule(history)
    if lr_fun != old_lr_fun:  # For performance, only jit if there is a change.
      jit_update_fun = _jit_update_fun(
          model_predict_train, loss_fun, optimizer, lr_fun, num_devices)

    # Flush summary writers
    train_sw.flush()
    eval_sw.flush()

  step_log(step, "Training done")
  return State(params=params, step=step, history=history)
Example No. 10
def train(output_dir,
          model=gin.REQUIRED,
          inputs=gin.REQUIRED,
          optimizer=trax_opt.adam,
          lr_schedule=lr.MultifactorSchedule,
          train_steps=1000,
          eval_steps=10,
          eval_frequency=100,
          run_debug_step=False):
    """Train the model on the inputs.

  Args:
    output_dir: Directory where to put the logs and checkpoints.
    model: The model to train as a callable returning 2 callables, an init_fun
      and apply_fun.
    inputs: callable returning trax.inputs.Inputs.
    optimizer: The optimizer as a callable taking a learning_rate callable and
      returning 2 callables, opt_init and opt_update.
    lr_schedule: A learning rate schedule as a function that takes history and
      returns a function from step to learning rate (a float).
    train_steps: int, total number of training steps.
    eval_steps: int, num of steps per evaluation. If None or 0, eval disabled.
    eval_frequency: int, how often to run evaluation (every eval_frequency
      steps). If None or 0, eval disabled.
    run_debug_step: bool, if True, will run the model and loss without @jit for
      one step.

  Returns:
    trax.State
  """
    rng = random.PRNGKey(0)
    gfile.makedirs(output_dir)
    # Create summary writers and history.
    train_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "train"))
    eval_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "eval"))

    inputs = inputs()

    # Setup optimizer and model
    state = restore_state(output_dir)
    history = state.history
    lr_fun = lr_schedule(history)
    opt_init, _ = optimizer(lr_fun)
    model_init, model_predict_original = model()

    # We need a model_predict that fills in the random generator if needed.
    def model_predict(x, y, **kwargs):
        """Same as model_predict_original but fill in rng if it isn't passed."""
        if "rng" in kwargs:
            return model_predict_original(x, y, **kwargs)
        return model_predict_original(x, y, rng=rng, **kwargs)

    # Setup state
    step = state.step or 0
    params_initializer = lambda: model_init([-1] + list(inputs.input_shape))[1]
    params = state.params or params_initializer()
    opt_state = opt_init(params)

    # jit model_predict and update so they're fast
    jit_model_predict = jax.jit(model_predict)  # for evaluation
    jit_update_fun = _jit_update_fun(model_predict, loss, optimizer, lr_fun)

    print()
    train_stream = inputs.train_stream()
    epoch_steps = itertools.chain(
        [
            1,  # first epoch only 1 step
            eval_frequency - 1
        ],
        itertools.repeat(eval_frequency))
    step_log(step, "Starting training")

    # A non-compiled debug step makes it easier to find problems in models.
    if run_debug_step:
        debug_loss = loss(params, next(train_stream), model_predict)
        step_log(step, "Debug step loss %.8f" % debug_loss)

    for epoch, epoch_steps in epochs(train_steps, epoch_steps):
        # Log separator
        print()

        # Timer
        start_time = time.time()

        for _ in range(epoch_steps):
            # Train
            opt_state = jit_update_fun(step, opt_state, next(train_stream))
            step += 1

            # LR log
            if step == 1 or step % 10 == 0:
                train_sw.scalar("training/learning rate",
                                lr_fun(step),
                                step=step)

        # Timer
        epoch_time = time.time() - start_time
        step_log(
            step,
            "Ran %d train steps in %0.2f secs" % (epoch_steps, epoch_time))
        if epoch_steps > 1:
            train_sw.scalar("training/steps per second",
                            epoch_steps / epoch_time,
                            step=step)

        # Evaluate
        params = jax_opt.get_params(opt_state)
        evaluate_train_and_eval(step=step,
                                inputs=inputs,
                                predict_fun=functools.partial(
                                    jit_model_predict, params),
                                eval_steps=eval_steps,
                                train_sw=train_sw,
                                eval_sw=eval_sw,
                                history=history)

        # Save state
        save_state(State(params=params, step=step, history=history),
                   output_dir)

        # Save Gin config
        # Gin only tracks the used parameters, so we save it after the first epoch.
        if epoch == 1:
            save_gin(output_dir, train_sw)

        # Update learning rate with new history
        old_lr_fun = lr_fun
        lr_fun = lr_schedule(history)
        if lr_fun != old_lr_fun:  # For performance, only jit if there is a change.
            jit_update_fun = _jit_update_fun(model_predict, loss, optimizer,
                                             lr_fun)

        # Flush summary writers
        train_sw.writer.flush()
        eval_sw.writer.flush()

    step_log(step, "Training done")
    return State(params=params, step=step, history=history)
Example No. 11
def train(output_dir,
          model=gin.REQUIRED,
          inputs=gin.REQUIRED,
          optimizer=trax_opt.adam,
          train_steps=1000,
          eval_steps=10,
          eval_frequency=100):
    """Train the model on the inputs.

  Args:
    output_dir: Directory where to put the logs and checkpoints.
    model: The model to train as a callable returning 2 callables, an init_fun
      and apply_fun.
    inputs: callable returning trax.inputs.Inputs.
    optimizer: The optimizer as a callable taking a learning_rate callable and
      returning 2 callables, opt_init and opt_update.
    train_steps: int, total number of training steps.
    eval_steps: int, num of steps per evaluation. If None or 0, eval disabled.
    eval_frequency: int, how often to run evaluation (every eval_frequency
      steps). If None or 0, eval disabled.

  Returns:
    trax.State
  """
    gfile.makedirs(output_dir)

    inputs = inputs()

    # Setup optimizer and model
    opt_init, opt_update = optimizer(learning_rate)
    model_init, model_predict = model()

    # Setup state
    state = restore_state(output_dir)
    step = state.step or 0
    params_initializer = lambda: model_init([-1] + list(inputs.input_shape))[1]
    opt_state = opt_init(state.params or params_initializer())

    # Create summary writers.
    train_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "train"))
    eval_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "eval"))

    # jit model_predict and update so they're fast
    jit_predict = jax.jit(model_predict)  # for evaluation

    @jax.jit
    def update(i, opt_state, batch):
        params = jax_opt.get_params(opt_state)
        return opt_update(i,
                          jax.grad(loss)(params, batch, model_predict),
                          opt_state)

    print()
    step_log(step, "starting training")
    inputs_stream = inputs.train_fn()
    eval_enabled = eval_steps and eval_frequency
    is_first_step = True
    # Evaluate after the first training step, then reset to normal_epoch_steps
    normal_epoch_steps = (eval_enabled and eval_frequency) or train_steps
    epoch_steps = 1
    while step < train_steps:
        print()  # separate logging for each loop iteration

        # Train
        start_time = time.time()
        for _ in range(epoch_steps):
            opt_state = update(step, opt_state, next(inputs_stream))
            if step % 10 == 0:  # Log learning rate curve each 10 steps.
                train_sw.scalar("training/learning rate",
                                learning_rate(step),
                                step=step)
            step += 1
        epoch_time = time.time() - start_time
        step_log(
            step,
            "ran %d train steps in %0.2f secs" % (epoch_steps, epoch_time))

        # Save state
        params = jax_opt.get_params(opt_state)
        save_state(State(params=params, step=step), output_dir)

        # Evaluate
        if eval_enabled:
            step_log(step, "starting evaluation")
            train_metrics, eval_metrics = evaluate(
                inputs, functools.partial(jit_predict, params), eval_steps)
            log_metrics(train_metrics, train_sw, "train", step)
            log_metrics(eval_metrics, eval_sw, "eval ", step)
            eval_sw.writer.flush()

        # Gin only tracks the used parameters, so we save it after the first step.
        if is_first_step:
            save_gin(output_dir, train_sw)

        # Log non-metric reports.
        if not is_first_step:
            train_sw.scalar("training/steps per second",
                            epoch_steps / epoch_time,
                            step=step)
        train_sw.writer.flush()

        # After the first step, train for normal_epoch_steps steps before evaluating
        epoch_steps = ((normal_epoch_steps - 1)
                       if is_first_step else normal_epoch_steps)
        is_first_step = False

    print()
    step_log(step, "finished training")
    return State(params=params, step=step)
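
Because model, inputs and (in the module this snippet comes from) learning_rate are Gin-configurable, train() is normally driven by a config rather than by passing arguments directly. A hypothetical binding, assuming train() and the referenced builders are registered with @gin.configurable under these names:

import gin

# Hypothetical bindings; the @-references must point at configurables registered
# elsewhere in the project (the model and inputs names are placeholders).
gin.parse_config("""
train.model = @transformer_lm
train.inputs = @lm1b_inputs
train.train_steps = 2000
train.eval_frequency = 200
""")

state = train(output_dir="/tmp/example_run")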
Example No. 12
def train(output_dir,
          data_dir,
          model=gin.REQUIRED,
          dataset=gin.REQUIRED,
          optimizer=trax_opt.adam,
          train_steps=1000,
          eval_steps=10,
          eval_frequency=100):
  """Train the given model on the given dataset.

  Args:
    output_dir: Directory where to put the logs and checkpoints.
    data_dir: Directory where the data is located.
    model: The model to train as a callable returning 2 callables, an init_fun
      and apply_fun.
    dataset: The name of the TFDS dataset to train on. To train on a T2T
      dataset, prefix the name with "t2t_".
    optimizer: The optimizer as a callable taking a learning_rate callable and
      returning 2 callables, opt_init and opt_update.
    train_steps: int, total number of training steps.
    eval_steps: int, num of steps per evaluation.
    eval_frequency: int, how often to run evaluation (every eval_frequency
      steps).
  """
  gfile.makedirs(output_dir)

  # Make Inputs
  inputs = inputs_lib.make_inputs(dataset, data_dir)

  # Setup optimizer and model
  opt_init, opt_update = optimizer(learning_rate)
  model_init, model_predict = model()

  # Setup state
  state = restore_state(output_dir)
  step = state.step or 0
  params_initializer = lambda: model_init([-1] + list(inputs.input_shape))[1]
  opt_state = opt_init(state.params or params_initializer())

  # Create summary writers.
  train_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "train"))
  eval_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "eval"))

  # jit model_predict and update so they're fast
  jit_predict = jax.jit(model_predict)  # for evaluation

  @jax.jit
  def update(i, opt_state, batch):
    params = jax_opt.get_params(opt_state)
    return opt_update(i, jax.grad(loss)(
        params, batch, model_predict), opt_state)

  print()
  step_log(step, "starting training")
  inputs_stream = inputs.train_fn()
  is_first_step = True
  epoch_steps = 1  # First evaluation after the first training step.
  while step < train_steps:
    print()

    # Train
    start_time = time.time()
    for _ in range(epoch_steps):
      opt_state = update(step, opt_state, next(inputs_stream))
      if step % 10 == 0:  # Log learning rate curve each 10 steps.
        train_sw.scalar("training/learning rate",
                        learning_rate(step), step=step)
      step += 1
    epoch_time = time.time() - start_time
    step_log(step, "ran %d train steps in %0.2f secs" %
             (epoch_steps, epoch_time))

    # Save state
    params = jax_opt.get_params(opt_state)
    save_state(State(params=params, step=step), output_dir,
               save_gin=is_first_step)

    # Evaluate
    step_log(step, "starting evaluation")
    train_metrics, eval_metrics = evaluate(
        inputs, functools.partial(jit_predict, params), eval_steps)
    log_metrics(train_metrics, train_sw, "train", step)
    log_metrics(eval_metrics, eval_sw, "eval ", step)

    # Log non-metric reports and flush.
    if not is_first_step:
      train_sw.scalar("training/steps per second",
                      epoch_steps / epoch_time, step=step)
    train_sw.writer.flush()
    eval_sw.writer.flush()

    # After the first step, train for eval_frequency steps before evaluating
    epoch_steps = (eval_frequency - 1) if is_first_step else eval_frequency
    is_first_step = False

  print()
  step_log(step, "finished training")
Example No. 13
def train(output_dir,
          model=gin.REQUIRED,
          inputs=gin.REQUIRED,
          optimizer=trax_opt.adam,
          lr_schedule=lr.MultifactorSchedule,
          train_steps=1000,
          eval_steps=10,
          eval_frequency=100):
    """Train the model on the inputs.

  Args:
    output_dir: Directory where to put the logs and checkpoints.
    model: The model to train as a callable returning 2 callables, an init_fun
      and apply_fun.
    inputs: callable returning trax.inputs.Inputs.
    optimizer: The optimizer as a callable taking a learning_rate callable and
      returning 2 callables, opt_init and opt_update.
    lr_schedule: A learning rate schedule as a function that takes history and
      returns a function from step to learning rate (a float).
    train_steps: int, total number of training steps.
    eval_steps: int, num of steps per evaluation. If None or 0, eval disabled.
    eval_frequency: int, how often to run evaluation (every eval_frequency
      steps). If None or 0, eval disabled.

  Returns:
    trax.State
  """
    gfile.makedirs(output_dir)
    # Create summary writers and history.
    train_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "train"))
    eval_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "eval"))

    inputs = inputs()

    # Setup optimizer and model
    state = restore_state(output_dir)
    history = state.history
    lr_fun = lr_schedule(history)
    opt_init, opt_update = optimizer(lr_fun)
    model_init, model_predict = model()

    # Setup state
    step = state.step or 0
    params_initializer = lambda: model_init([-1] + list(inputs.input_shape))[1]
    params = state.params or params_initializer()
    opt_state = opt_init(params)

    # jit model_predict and update so they're fast
    jit_predict = jax.jit(model_predict)  # for evaluation

    @jax.jit
    def update(i, opt_state, batch):
        params = jax_opt.get_params(opt_state)
        return opt_update(i,
                          jax.grad(loss)(params, batch, model_predict),
                          opt_state)

    print()
    train_stream = inputs.train_stream()
    epoch_steps = itertools.chain(
        [
            1,  # first epoch only 1 step
            eval_frequency - 1
        ],
        itertools.repeat(eval_frequency))
    step_log(step, "Starting training")

    for epoch, epoch_steps in epochs(train_steps, epoch_steps):
        # Log separator
        print()

        # Timer
        start_time = time.time()

        for _ in range(epoch_steps):
            # Train
            opt_state = update(step, opt_state, next(train_stream))
            step += 1

            # LR log
            if step == 1 or step % 10 == 0:
                train_sw.scalar("training/learning rate",
                                lr_fun(step),
                                step=step)

        # Timer
        epoch_time = time.time() - start_time
        step_log(
            step,
            "Ran %d train steps in %0.2f secs" % (epoch_steps, epoch_time))
        if epoch_steps > 1:
            train_sw.scalar("training/steps per second",
                            epoch_steps / epoch_time,
                            step=step)

        # Evaluate
        params = jax_opt.get_params(opt_state)
        evaluate_train_and_eval(step=step,
                                inputs=inputs,
                                predict_fun=functools.partial(
                                    jit_predict, params),
                                eval_steps=eval_steps,
                                train_sw=train_sw,
                                eval_sw=eval_sw,
                                history=history)

        # Save state
        save_state(State(params=params, step=step, history=history),
                   output_dir)

        # Save Gin config
        # Gin only tracks the used parameters, so we save it after the first epoch.
        if epoch == 1:
            save_gin(output_dir, train_sw)

        # Flush summary writers
        train_sw.writer.flush()
        eval_sw.writer.flush()

    step_log(step, "Training done")
    return State(params=params, step=step, history=history)
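
lr_schedule maps a History object to a function from step to learning rate. A minimal sketch of a compatible schedule factory (a stand-in for illustration, not lr.MultifactorSchedule itself), here a constant rate with linear warmup that ignores the history:

def warmup_constant_schedule(history, base_lr=1e-3, warmup_steps=100):
    """Returns lr_fun: step -> learning rate; ignores history in this sketch."""
    del history  # Real schedules may inspect past metrics; this one does not.
    def lr_fun(step):
        warmup_frac = min(1.0, float(step + 1) / warmup_steps)
        return base_lr * warmup_frac
    return lr_fun

The resulting lr_fun is used both inside the optimizer (via optimizer(lr_fun)) and for the "training/learning rate" summaries logged every 10 steps.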
Example No. 14
def training_loop(
        env,
        eval_env,
        env_name,
        policy_and_value_net_fn,
        policy_and_value_optimizer_fn,
        output_dir,
        epochs=EPOCHS,
        n_optimizer_steps=N_OPTIMIZER_STEPS,
        print_every_optimizer_steps=PRINT_EVERY_OPTIMIZER_STEP,
        target_kl=0.01,
        boundary=20,
        max_timestep=None,
        max_timestep_eval=20000,
        random_seed=None,
        gamma=GAMMA,
        lambda_=LAMBDA,
        epsilon=EPSILON,
        c1=1.0,
        c2=0.01,
        eval_every_n=1000,
        done_frac_for_policy_save=0.5,
        enable_early_stopping=True,
        n_evals=1,
        len_history_for_policy=4,
        eval_temperatures=(1.0, 0.5),
):
    """Runs the training loop for PPO, with fixed policy and value nets.

  Args:
    env: gym.Env to use for training.
    eval_env: gym.Env to use for evaluation.
    env_name: Name of the environment.
    policy_and_value_net_fn: Function defining the policy and value network.
    policy_and_value_optimizer_fn: Function defining the optimizer.
    output_dir: Output dir.
    epochs: Number of epochs to run for.
    n_optimizer_steps: Number of optimizer steps.
    print_every_optimizer_steps: How often to log during the policy optimization
      process.
    target_kl: Approximate-KL threshold used to early-stop the policy
      optimization within an epoch.
    boundary: We pad trajectories at integer multiples of this number.
    max_timestep: If set to an integer, maximum number of time-steps in
      a trajectory. Used in the collect procedure.
    max_timestep_eval: If set to an integer, maximum number of time-steps in an
      evaluation trajectory. Used in the collect procedure.
    random_seed: Random seed.
    gamma: Reward discount factor.
    lambda_: N-step TD-error discount factor in GAE.
    epsilon: Clipping parameter for the PPO surrogate objective (the collect
      procedure uses a separate, annealed exploration epsilon).
    c1: Value loss coefficient.
    c2: Entropy loss coefficient.
    eval_every_n: How frequently to eval the policy.
    done_frac_for_policy_save: Fraction of the trajectories that should be done
      to checkpoint the policy.
    enable_early_stopping: Whether to enable early stopping.
    n_evals: Number of times to evaluate.
    len_history_for_policy: How much of history to give to the policy.
    eval_temperatures: Sequence of temperatures to try for categorical sampling
      during evaluation.
  """
    gfile.makedirs(output_dir)

    # Create summary writers and history.
    train_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "train"))
    timing_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "timing"))
    eval_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "eval"))

    train_sw.text("env_name", env_name)
    timing_sw.text("env_name", env_name)
    eval_sw.text("env_name", env_name)

    jax_rng_key = trax.get_random_number_generator_and_set_seed(random_seed)

    # Batch Observations Shape = [1, 1] + OBS, because we will eventually call
    # policy and value networks on shape [B, T] + OBS.
    batch_observations_shape = (1, 1) + env.observation_space.shape
    observations_dtype = env.observation_space.dtype

    assert isinstance(env.action_space, gym.spaces.Discrete)
    n_actions = env.action_space.n

    jax_rng_key, key1 = jax_random.split(jax_rng_key, num=2)

    # Initialize the policy and value network.
    policy_and_value_net_params, policy_and_value_net_apply = (
        policy_and_value_net_fn(key1, batch_observations_shape,
                                observations_dtype, n_actions))

    # Maybe restore the policy params. If there is nothing to restore, then
    # iteration = 0 and policy_and_value_net_params are returned as is.
    restore, policy_and_value_net_params, iteration = (maybe_restore_params(
        output_dir, policy_and_value_net_params))

    if restore:
        logging.info("Restored parameters from iteration [%d]", iteration)
        # We should start from the next iteration.
        iteration += 1

    policy_and_value_net_apply = jit(policy_and_value_net_apply)

    # Initialize the optimizers.
    policy_and_value_optimizer = (
        policy_and_value_optimizer_fn(policy_and_value_net_params))
    (policy_and_value_opt_state, policy_and_value_opt_update,
     policy_and_value_get_params) = policy_and_value_optimizer

    n_trajectories_done = 0
    last_saved_at = 0

    logging.info("Starting the PPO training loop.")
    for i in range(iteration, epochs):
        epoch_start_time = time.time()

        # Params we'll use to collect the trajectories.
        policy_and_value_net_params = policy_and_value_get_params(
            policy_and_value_opt_state)

        # A function to get the policy and value predictions.
        def get_predictions(observations, rng=None):
            """Returns log-probs, value predictions and key back."""
            key, key1 = jax_random.split(rng, num=2)

            log_probs, value_preds = policy_and_value_net_apply(
                observations, policy_and_value_net_params, rng=key1)

            return log_probs, value_preds, key

        # Evaluate the policy.
        policy_eval_start_time = time.time()
        if ((i + 1) % eval_every_n == 0) or (i == epochs - 1):
            jax_rng_key, key = jax_random.split(jax_rng_key, num=2)

            logging.vlog(1, "Epoch [% 6d] evaluating policy.", i)

            reward_stats = evaluate_policy(
                eval_env,
                get_predictions,
                temperatures=eval_temperatures,
                max_timestep=max_timestep_eval,
                n_evals=n_evals,
                len_history_for_policy=len_history_for_policy,
                rng=key)
            write_eval_reward_summaries(reward_stats, eval_sw, epoch=i)
        policy_eval_time = get_time(policy_eval_start_time)

        trajectory_collection_start_time = time.time()
        logging.vlog(1, "Epoch [% 6d] collecting trajectories.", i)
        jax_rng_key, key = jax_random.split(jax_rng_key)
        trajs, n_done, timing_info = collect_trajectories(
            env,
            policy_fn=get_predictions,
            n_trajectories=env.batch_size,
            max_timestep=max_timestep,
            rng=key,
            len_history_for_policy=len_history_for_policy,
            reset=(i == 0) or restore,
            epsilon=(10.0 / (i + 10.0)))  # this is a different epsilon.
        trajectory_collection_time = get_time(trajectory_collection_start_time)

        logging.vlog(1, "Collecting trajectories took %0.2f msec.",
                     trajectory_collection_time)

        avg_reward = float(sum(np.sum(traj[2]) for traj in trajs)) / len(trajs)
        max_reward = max(np.sum(traj[2]) for traj in trajs)
        min_reward = min(np.sum(traj[2]) for traj in trajs)

        train_sw.scalar("train/reward_mean_truncated", avg_reward, step=i)

        logging.vlog(1,
                     "Rewards avg=[%0.2f], max=[%0.2f], min=[%0.2f], all=%s",
                     avg_reward, max_reward, min_reward,
                     [float(np.sum(traj[2])) for traj in trajs])

        logging.vlog(
            1, "Trajectory Length average=[%0.2f], max=[%0.2f], min=[%0.2f]",
            float(sum(len(traj[0]) for traj in trajs)) / len(trajs),
            max(len(traj[0]) for traj in trajs),
            min(len(traj[0]) for traj in trajs))
        logging.vlog(2, "Trajectory Lengths: %s",
                     [len(traj[0]) for traj in trajs])

        padding_start_time = time.time()
        (_, reward_mask, padded_observations, padded_actions, padded_rewards,
         padded_infos) = pad_trajectories(trajs, boundary=boundary)
        padding_time = get_time(padding_start_time)

        logging.vlog(1, "Padding trajectories took %0.2f msec.",
                     get_time(padding_start_time))
        logging.vlog(1, "Padded Observations' shape [%s]",
                     str(padded_observations.shape))
        logging.vlog(1, "Padded Actions' shape [%s]",
                     str(padded_actions.shape))
        logging.vlog(1, "Padded Rewards' shape [%s]",
                     str(padded_rewards.shape))

        # Some assertions.
        B, T = padded_actions.shape  # pylint: disable=invalid-name
        assert (B, T) == padded_rewards.shape
        assert (B, T) == reward_mask.shape
        assert (B, T + 1) == padded_observations.shape[:2]
        assert ((B, T + 1) + env.observation_space.shape ==
                padded_observations.shape)

        log_prob_recompute_start_time = time.time()
        assert ("log_prob_actions" in padded_infos
                and "value_predictions" in padded_infos)
        # These are the actual log-probabs and value predictions seen while picking
        # the actions.
        actual_log_probabs_traj = padded_infos["log_prob_actions"]
        actual_value_predictions_traj = padded_infos["value_predictions"]

        assert (B, T) == actual_log_probabs_traj.shape[:2]
        A = actual_log_probabs_traj.shape[2]  # pylint: disable=invalid-name
        assert (B, T, 1) == actual_value_predictions_traj.shape

        # TODO(afrozm): log-probabs doesn't need to be (B, T+1, A); (B, T, A)
        # suffices, so make that change throughout.

        # NOTE: We don't have the log-probabs and value-predictions for the last
        # observation, so we re-calculate for everything, but use the original ones
        # for all but the last time-step.
        jax_rng_key, key = jax_random.split(jax_rng_key)
        log_probabs_traj, value_predictions_traj, _ = get_predictions(
            padded_observations, rng=key)

        assert (B, T + 1, A) == log_probabs_traj.shape
        assert (B, T + 1, 1) == value_predictions_traj.shape

        # Concatenate the last time-step's log-probabs and value predictions to the
        # actual log-probabs and value predictions and use those going forward.
        log_probabs_traj = np.concatenate(
            (actual_log_probabs_traj, log_probabs_traj[:, -1:, :]), axis=1)
        value_predictions_traj = np.concatenate(
            (actual_value_predictions_traj, value_predictions_traj[:, -1:, :]),
            axis=1)

        log_prob_recompute_time = get_time(log_prob_recompute_start_time)

        # Linear annealing from 0.1 to 0.0
        # epsilon_schedule = epsilon if epochs == 1 else epsilon * (1.0 -
        #                                                           (i /
        #                                                            (epochs - 1)))

        # Constant epsilon.
        epsilon_schedule = epsilon

        # Compute value and ppo losses.
        jax_rng_key, key1 = jax_random.split(jax_rng_key, num=2)
        logging.vlog(2, "Starting to compute P&V loss.")
        loss_compute_start_time = time.time()
        cur_combined_loss, cur_ppo_loss, cur_value_loss, entropy_bonus = (
            combined_loss(policy_and_value_net_params,
                          log_probabs_traj,
                          value_predictions_traj,
                          policy_and_value_net_apply,
                          padded_observations,
                          padded_actions,
                          padded_rewards,
                          reward_mask,
                          gamma=gamma,
                          lambda_=lambda_,
                          epsilon=epsilon_schedule,
                          c1=c1,
                          c2=c2,
                          rng=key1))
        loss_compute_time = get_time(loss_compute_start_time)
        logging.vlog(
            1,
            "Calculating P&V loss [%10.2f(%10.2f, %10.2f, %10.2f)] took %0.2f msec.",
            cur_combined_loss, cur_value_loss, cur_ppo_loss, entropy_bonus,
            get_time(loss_compute_start_time))

        jax_rng_key, key1 = jax_random.split(jax_rng_key, num=2)
        logging.vlog(1, "Policy and Value Optimization")
        optimization_start_time = time.time()
        keys = jax_random.split(key1, num=n_optimizer_steps)
        for j in range(n_optimizer_steps):
            k1, k2, k3 = jax_random.split(keys[j], num=3)
            t = time.time()
            # Update the optimizer state.
            policy_and_value_opt_state = policy_and_value_opt_step(
                j,
                policy_and_value_opt_state,
                policy_and_value_opt_update,
                policy_and_value_get_params,
                policy_and_value_net_apply,
                log_probabs_traj,
                value_predictions_traj,
                padded_observations,
                padded_actions,
                padded_rewards,
                reward_mask,
                c1=c1,
                c2=c2,
                gamma=gamma,
                lambda_=lambda_,
                epsilon=epsilon_schedule,
                rng=k1)

            # Compute the approx KL for early stopping.
            new_policy_and_value_net_params = policy_and_value_get_params(
                policy_and_value_opt_state)

            log_probab_actions_new, _ = policy_and_value_net_apply(
                padded_observations, new_policy_and_value_net_params, rng=k2)

            approx_kl = approximate_kl(log_probab_actions_new,
                                       log_probabs_traj, reward_mask)

            early_stopping = enable_early_stopping and approx_kl > 1.5 * target_kl
            if early_stopping:
                logging.vlog(
                    1,
                    "Early stopping policy and value optimization at iter: %d, "
                    "with approx_kl: %0.2f", j, approx_kl)
                # We don't break right away; we want the logging below to run
                # on this final iteration.

            t2 = time.time()
            if (((j + 1) % print_every_optimizer_steps == 0)
                    or (j == n_optimizer_steps - 1) or early_stopping):
                # Compute and log the loss.
                (loss_combined, loss_ppo, loss_value,
                 entropy_bonus) = (combined_loss(
                     new_policy_and_value_net_params,
                     log_probabs_traj,
                     value_predictions_traj,
                     policy_and_value_net_apply,
                     padded_observations,
                     padded_actions,
                     padded_rewards,
                     reward_mask,
                     gamma=gamma,
                     lambda_=lambda_,
                     epsilon=epsilon_schedule,
                     c1=c1,
                     c2=c2,
                     rng=k3))
                logging.vlog(
                    1, "One Policy and Value grad desc took: %0.2f msec",
                    get_time(t, t2))
                logging.vlog(
                    1, "Combined Loss(value, ppo, entropy_bonus) [%10.2f] ->"
                    " [%10.2f(%10.2f,%10.2f,%10.2f)]", cur_combined_loss,
                    loss_combined, loss_value, loss_ppo, entropy_bonus)

            if early_stopping:
                break

        optimization_time = get_time(optimization_start_time)

        logging.vlog(
            1, "Total Combined Loss reduction [%0.2f]%%",
            (100 *
             (cur_combined_loss - loss_combined) / np.abs(cur_combined_loss)))

        # Save parameters whenever at least done_frac_for_policy_save of the
        # batch's trajectories have finished as "done" ("completed" would also
        # include truncated trajectories). Don't save too frequently (enforce a
        # minimum gap of eval_every_n epochs), and always save on the last
        # iteration.
        policy_save_start_time = time.time()
        n_trajectories_done += n_done
        # TODO(afrozm): Refactor to trax.save_state.
        if ((
            (n_trajectories_done >= done_frac_for_policy_save * env.batch_size)
                and (i - last_saved_at > eval_every_n) and
            (((i + 1) % eval_every_n == 0))) or (i == epochs - 1)):
            logging.vlog(1, "Epoch [% 6d] saving model.", i)
            old_model_files = gfile.glob(
                os.path.join(output_dir, "model-??????.pkl"))
            params_file = os.path.join(output_dir, "model-%06d.pkl" % i)
            with gfile.GFile(params_file, "wb") as f:
                pickle.dump(policy_and_value_net_params, f)
            # Remove the old model files.
            for path in old_model_files:
                gfile.remove(path)
            # Reset this number.
            n_trajectories_done = 0
            last_saved_at = i
        policy_save_time = get_time(policy_save_start_time)

        epoch_time = get_time(epoch_start_time)

        logging.info(
            "Epoch [% 6d], Reward[min, max, avg] [%5.2f,%5.2f,%5.2f], Combined"
            " Loss(value, ppo, entropy) [%2.5f(%2.5f,%2.5f,%2.5f)]", i,
            min_reward, max_reward, avg_reward, loss_combined, loss_value,
            loss_ppo, entropy_bonus)

        timing_dict = {
            "epoch": epoch_time,
            "policy_eval": policy_eval_time,
            "trajectory_collection": trajectory_collection_time,
            "padding": padding_time,
            "log_prob_recompute": log_prob_recompute_time,
            "loss_compute": loss_compute_time,
            "optimization": optimization_time,
            "policy_save": policy_save_time,
        }

        timing_dict.update(timing_info)

        for k, v in timing_dict.items():
            timing_sw.scalar("timing/%s" % k, v, step=i)

        max_key_len = max(len(k) for k in timing_dict)
        timing_info_list = [
            "%s : % 10.2f" % (k.rjust(max_key_len + 1), v)
            for k, v in sorted(timing_dict.items())
        ]
        logging.info("Epoch [% 6d], Timings: \n%s", i,
                     "\n".join(timing_info_list))

        # Reset restore.
        restore = False

        # Flush summary writers once in a while.
        if (i + 1) % 1000 == 0 or i == epochs - 1:
            train_sw.flush()
            timing_sw.flush()
            eval_sw.flush()
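
Early stopping in the optimizer loop compares approximate_kl(...) against 1.5 * target_kl. A sketch of a masked KL(old || new) estimate with the same call signature, assuming log-probabilities of shape [B, T+1, A] and a reward mask of shape [B, T]; the helper the loop actually uses may compute the estimate differently:

import jax.numpy as jnp

def approximate_kl(log_probs_new, log_probs_old, mask):
    """Masked mean KL(old || new) between per-step action distributions."""
    probs_old = jnp.exp(log_probs_old)
    kl = jnp.sum(probs_old * (log_probs_old - log_probs_new), axis=-1)  # [B, T+1]
    kl = kl[:, :-1]  # The reward mask only covers the first T time-steps.
    return jnp.sum(kl * mask) / jnp.sum(mask)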
Example No. 15
  def __init__(
      self,
      train_env,
      eval_env,
      policy_and_value_model,
      policy_and_value_optimizer_fn,
      policy_and_value_two_towers,
      output_dir,
      n_optimizer_steps,
      print_every_optimizer_steps,
      target_kl,
      boundary,
      max_timestep,
      max_timestep_eval,
      random_seed,
      gamma,
      lambda_,
      c1,
      c2,
      eval_every_n,
      done_frac_for_policy_save,
      n_evals,
      len_history_for_policy,
      eval_temperatures,
  ):
    self._train_env = train_env
    self._eval_env = eval_env
    self._n_optimizer_steps = n_optimizer_steps
    self._print_every_optimizer_steps = print_every_optimizer_steps
    self._target_kl = target_kl
    self._boundary = boundary
    self._max_timestep = max_timestep
    self._max_timestep_eval = max_timestep_eval
    self._gamma = gamma
    self._lambda_ = lambda_
    self._c1 = c1
    self._c2 = c2
    self._eval_every_n = eval_every_n
    self._done_frac_for_policy_save = done_frac_for_policy_save
    self._n_evals = n_evals
    self._len_history_for_policy = len_history_for_policy
    self._eval_temperatures = eval_temperatures

    assert isinstance(self._train_env.action_space, gym.spaces.Discrete)
    n_actions = self._train_env.action_space.n

    # Batch Observations Shape = [1, 1] + OBS, because we will eventually call
    # policy and value networks on shape [B, T] + OBS.
    batch_observations_shape = (1, 1) + self._train_env.observation_space.shape
    observations_dtype = self._train_env.observation_space.dtype

    self._rng = trax.get_random_number_generator_and_set_seed(random_seed)
    self._rng, key1 = jax_random.split(self._rng, num=2)

    # Initialize the policy and value network.
    policy_and_value_net_params, policy_and_value_net_apply = (
        policy_and_value_net(
            rng_key=key1,
            batch_observations_shape=batch_observations_shape,
            observations_dtype=observations_dtype,
            n_actions=n_actions,
            bottom_layers_fn=policy_and_value_model,
            two_towers=policy_and_value_two_towers,
        )
    )
    self._policy_and_value_net_apply = jit(policy_and_value_net_apply)

    # Maybe restore the policy params. If there is nothing to restore, then
    # iteration = 0 and policy_and_value_net_params are returned as is.
    restored, policy_and_value_net_params, self._epoch = (
        maybe_restore_params(output_dir, policy_and_value_net_params))

    if restored:
      logging.info("Restored parameters from iteration [%d]", self._epoch)
      # We should start from the next iteration.
      self._epoch += 1

    # Initialize the optimizers.
    policy_and_value_optimizer = (
        policy_and_value_optimizer_fn(policy_and_value_net_params))
    (self._policy_and_value_opt_state, self._policy_and_value_opt_update,
     self._policy_and_value_get_params) = policy_and_value_optimizer

    self._output_dir = output_dir
    gfile.makedirs(self._output_dir)

    # Create summary writers and history.
    self._train_sw = jaxboard.SummaryWriter(
        os.path.join(self._output_dir, "train"))
    self._timing_sw = jaxboard.SummaryWriter(
        os.path.join(self._output_dir, "timing"))
    self._eval_sw = jaxboard.SummaryWriter(
        os.path.join(self._output_dir, "eval"))

    self._should_reset = True
    self._n_trajectories_done = 0

    self._last_saved_at = 0
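
The constructor unpacks policy_and_value_optimizer_fn(params) into (opt_state, opt_update, get_params). A sketch of a compatible factory built on JAX's bundled optimizers (jax.example_libraries.optimizers in current JAX, jax.experimental.optimizers in releases contemporary with this code; the Adam step size is an assumption):

from jax.example_libraries import optimizers

def adam_policy_and_value_optimizer_fn(params, step_size=1e-4):
    """Returns (opt_state, opt_update, get_params), as the trainer expects."""
    opt_init, opt_update, get_params = optimizers.adam(step_size)
    return opt_init(params), opt_update, get_params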
Example No. 16
    def __init__(
            self,
            train_env,
            eval_env,
            output_dir,
            policy_trainer_class,
            n_real_epochs=10,
            data_eval_frac=0.125,
            model_train_batch_size=64,
            n_model_initial_train_steps=1000,
            n_model_train_steps_per_epoch=1000,
            simulated_env_problem_class=(
                simulated_env_problem.SerializedSequenceSimulatedEnvProblem),
            simulated_batch_size=16,
            n_simulated_epochs=1000,
            trajectory_dump_dir=None,
            initial_trajectory_dir=None,
            initial_trajectory_mix_prob=0.5,
            initial_model=None,
            init_policy_from_world_model=False,
            **kwargs):
        super(SimPLe, self).__init__(train_env, eval_env, output_dir, **kwargs)
        self._policy_dir = os.path.join(output_dir, "policy")
        self._model_dir = os.path.join(output_dir, "model")
        # Initialize the policy trainer lazily, so that when the policy is
        # initialized from a world-model checkpoint, the trainer only tries to
        # load that checkpoint _after_ it has been created in train_model().
        self._policy_trainer_fn = functools.partial(
            policy_trainer_class,
            train_env=train_env,
            eval_env=eval_env,
            output_dir=self._policy_dir,
            async_mode=self._async_mode,
            init_policy_from_world_model_output_dir=(
                self._model_dir if init_policy_from_world_model else None),
        )
        self._policy_trainer = None
        self._n_real_epochs = n_real_epochs
        self._model_train_batch_size = model_train_batch_size
        self._n_model_initial_train_steps = n_model_initial_train_steps
        self._n_model_train_steps_per_epoch = n_model_train_steps_per_epoch
        self._data_eval_frac = data_eval_frac

        gfile.makedirs(self._model_dir)
        if initial_model is not None:
            gfile.copy(
                initial_model,
                os.path.join(self._model_dir, "model.pkl"),
                overwrite=True,
            )
        self._initial_model = initial_model
        self._initial_trajectories = None

        self._sim_env = simulated_env_problem_class(
            batch_size=None,
            observation_space=train_env.observation_space,
            action_space=train_env.action_space,
            reward_range=train_env.reward_range,
            discrete_rewards=train_env.discrete_rewards,
            history_stream=None,  # TODO(pkozakowski): Support this.
            output_dir=self._model_dir,
        )
        self._simulated_batch_size = simulated_batch_size
        self._n_simulated_epochs = n_simulated_epochs

        # If trajectory_dump_dir is not provided explicitly, save the trajectories
        # in output_dir.
        if trajectory_dump_dir is None:
            trajectory_dump_dir = os.path.join(output_dir, "trajectories")
        self._trajectory_dump_root_dir = trajectory_dump_dir

        self._initial_trajectory_dir = initial_trajectory_dir
        self._initial_trajectory_mix_prob = initial_trajectory_mix_prob

        self._summary_writer = jaxboard.SummaryWriter(self._output_dir)

        self._simple_epoch = 0
        self._policy_epoch = 0
        self._model_train_step = 0
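
The comment in the constructor explains that the policy trainer is built lazily so that a world-model checkpoint can exist before the trainer tries to load it. A self-contained illustration of that lazy-construction pattern (not the actual accessor SimPLe uses):

class LazyHolder:
    """Builds an expensive object on first access instead of at construction time."""

    def __init__(self, build_fn):
        self._build_fn = build_fn
        self._obj = None

    @property
    def obj(self):
        # Deferred until first use, after any prerequisite files exist on disk.
        if self._obj is None:
            self._obj = self._build_fn()
        return self._obj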