Example 1
  def get_policy_output(observations):
    # Get the fresh params for collecting the policy.
    if policy_net_apply is not None:
      return policy_net_apply(observations, trax_opt.get_params(ppo_opt_state))

    assert policy_and_value_net_apply

    policy_predictions, unused_value_predictions = policy_and_value_net_apply(
        observations, trax_opt.get_params(policy_and_value_opt_state))
    return policy_predictions
Example 2
def ppo_opt_step(i,
                 opt_state,
                 ppo_opt_update,
                 policy_net_apply,
                 old_policy_params,
                 value_net_apply,
                 value_net_params,
                 padded_observations,
                 padded_actions,
                 padded_rewards,
                 reward_mask,
                 gamma=0.99,
                 lambda_=0.95,
                 epsilon=0.1):
    """PPO optimizer step."""
    new_policy_params = trax_opt.get_params(opt_state)
    g = grad(ppo_loss, argnums=1)(policy_net_apply,
                                  new_policy_params,
                                  old_policy_params,
                                  value_net_apply,
                                  value_net_params,
                                  padded_observations,
                                  padded_actions,
                                  padded_rewards,
                                  reward_mask,
                                  gamma=gamma,
                                  lambda_=lambda_,
                                  epsilon=epsilon)
    return ppo_opt_update(i, g, opt_state)
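
A minimal, self-contained sketch of the grad(..., argnums=1) pattern used above, with a toy apply function and scalar parameter (all names here are illustrative, not part of the original code). argnums selects which positional argument to differentiate, so the gradient is taken with respect to the new policy params rather than the apply function:

import jax
import jax.numpy as jnp

def toy_loss(apply_fn, params, x):
  # Same argument layout as ppo_loss: the network function first, params second.
  return jnp.sum(apply_fn(x, params) ** 2)

toy_apply = lambda x, w: x * w
# argnums=1 differentiates with respect to params (the second argument),
# mirroring grad(ppo_loss, argnums=1) above.
g = jax.grad(toy_loss, argnums=1)(toy_apply, 3.0, jnp.array(2.0))
# g == 24.0, i.e. d/dw of (x * w)^2 at w=3, x=2.
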
Example 3
 def single_update(i, opt_state, batch, rng):
     _, opt_update = optimizer(lr_fun)
     params = trax_opt.get_params(opt_state)
     return opt_update(
         i,
         backend.grad(loss_fun)(params, batch, predict_fun, rng),
         opt_state)
Example 4
 def single_update(i, opt_state, batch, rng):
     rng, subrng = jax_random.split(rng[0])
     _, opt_update = optimizer(lr_fun)
     params = trax_opt.get_params(opt_state)
     return opt_update(
         i,
         backend.grad(loss_fun)(params, batch, predict_fun, rng),
         opt_state), [subrng]
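
Example 4 differs from Example 3 only in how it threads the PRNG key: the key arrives as a one-element list, is split, and the fresh subkey is handed back for the next step. A minimal sketch of that splitting with jax.random directly (jax_random above is assumed to be the same module):

import jax

rng = jax.random.PRNGKey(0)
# Split the key so each training step sees fresh randomness; the update
# above does the same with rng[0] and returns [subrng] for the next call.
rng, subrng = jax.random.split(rng)
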
Example 5
 def mapped_update(i, opt_state, batch, rng):
   """This is a multi-device version of the update function above."""
   # We assume all tensors have the first dimension = num_devices.
   _, opt_update = optimizer(lr_fun)
   params = trax_opt.get_params(opt_state)
   grads = backend.grad(loss_fun)(params, batch, predict_fun, rng)
   grads = jax.tree_util.tree_map(
       lambda g: lax.psum(g, "batch"), grads)
   return opt_update(i, grads, opt_state)
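
The lax.psum(g, "batch") call only makes sense once mapped_update runs under a parallel map with a matching axis name. A small standalone sketch of that pairing, independent of the trax code (names and data are illustrative):

import jax
import jax.numpy as jnp
from jax import lax

n = jax.local_device_count()
xs = jnp.arange(n, dtype=jnp.float32)  # one value per device

# pmap introduces the named axis "batch"; inside, lax.psum sums across it,
# which is how mapped_update combines per-device gradients.
summed = jax.pmap(lambda x: lax.psum(x, "batch"), axis_name="batch")(xs)
# Every entry of summed equals 0 + 1 + ... + (n - 1).
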
Example 6
def value_opt_step(i,
                   opt_state,
                   opt_update,
                   value_net_apply,
                   padded_observations,
                   padded_rewards,
                   reward_mask,
                   gamma=0.99):
    """Value optimizer step."""
    value_params = trax_opt.get_params(opt_state)
    # Note this partial application here and argnums above in ppo_opt_step.
    g = grad(functools.partial(value_loss,
                               value_net_apply))(value_params,
                                                 padded_observations,
                                                 padded_rewards,
                                                 reward_mask,
                                                 gamma=gamma)
    return opt_update(i, g, opt_state)
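
The comment in value_opt_step points at an alternative to argnums: partially apply the network function so that the parameters become the first remaining argument, which grad differentiates by default. A toy sketch of that pattern (illustrative names only):

import functools
import jax
import jax.numpy as jnp

def toy_value_loss(apply_fn, params, x):
  return jnp.sum((apply_fn(x, params) - 1.0) ** 2)

toy_apply = lambda x, w: x * w
# functools.partial fixes apply_fn, so params becomes argument 0 and
# grad's default argnums=0 differentiates with respect to it.
g = jax.grad(functools.partial(toy_value_loss, toy_apply))(3.0, jnp.array(2.0))
# g == 20.0, i.e. d/dw of (x * w - 1)^2 at w=3, x=2.
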
Example 7
def policy_and_value_opt_step(i,
                              opt_state,
                              opt_update,
                              policy_and_value_net_apply,
                              old_params,
                              padded_observations,
                              padded_actions,
                              padded_rewards,
                              reward_mask,
                              c1=1.0,
                              c2=0.01,
                              gamma=0.99,
                              lambda_=0.95,
                              epsilon=0.1):
  """Policy and Value optimizer step."""

  # Combined loss function given the new params.
  def policy_and_value_loss(params):
    """Returns the combined loss given just parameters."""
    (loss, _, _, _) = combined_loss(
        params,
        old_params,
        policy_and_value_net_apply,
        padded_observations,
        padded_actions,
        padded_rewards,
        reward_mask,
        c1=c1,
        c2=c2,
        gamma=gamma,
        lambda_=lambda_,
        epsilon=epsilon)
    return loss

  new_params = trax_opt.get_params(opt_state)
  g = grad(policy_and_value_loss)(new_params)
  return opt_update(i, g, opt_state)
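
Example 7 uses a third variant of the same idea: close over everything except the new parameters, then take the gradient of the resulting single-argument function. A toy sketch (illustrative names only):

import jax
import jax.numpy as jnp

x = jnp.array(2.0)  # stands in for the padded batch closed over above

def loss_of_params(w):
  # Only the parameters vary; all other inputs are captured from the
  # enclosing scope, as policy_and_value_loss does with the padded batch.
  return jnp.sum((x * w - 1.0) ** 2)

g = jax.grad(loss_of_params)(3.0)  # 20.0, same toy problem as in Example 6
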
Example 8
def train(output_dir,
          model=gin.REQUIRED,
          loss_fun=loss,
          inputs=trax_inputs.inputs,
          optimizer=trax_opt.adam,
          lr_schedule=lr.MultifactorSchedule,
          train_steps=1000,
          eval_steps=10,
          eval_frequency=100,
          num_devices=None,
          random_seed=None,
          run_debug_step=False):
  """Train the model on the inputs.

  Args:
    output_dir: Directory where to put the logs and checkpoints.
    model: The model to train as a callable returning 2 callables, an init_fun
      and apply_fun.
    loss_fun: callable with signature: params, trax.inputs.Inputs, model, rng
      -> loss.
    inputs: callable returning trax.inputs.Inputs.
    optimizer: The optimizer as a callable taking a learning_rate callable and
      returning 2 callables, opt_init and opt_update.
    lr_schedule: A learning rate schedule as a function that takes history and
      returns a function from step to learning rate (a float).
    train_steps: int, total number of training steps.
    eval_steps: int, num of steps per evaluation. If None or 0, eval disabled.
    eval_frequency: int, how often to run evaluation (every eval_frequency
      steps). If None or 0, eval disabled.
    num_devices: how many devices to use (if None, default, use all available)
    random_seed: the random seed to use; time/os dependent if None (default).
    run_debug_step: bool, if True, will run the model and loss without @jit for
      one step.

  Returns:
    trax.State
  """
  num_devices = num_devices or jax.lib.xla_bridge.device_count()
  rng = get_random_number_generator_and_set_seed(random_seed)
  gfile.makedirs(output_dir)
  # Create summary writers and history.
  train_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "train"))
  eval_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "eval"))

  inputs = inputs(num_devices)

  # Setup optimizer and model
  state = restore_state(output_dir)
  history = state.history
  lr_fun = lr_schedule(history)
  opt_init, _ = optimizer(lr_fun)
  model_init, model_predict_train = model(mode="train")
  _, model_predict_eval = model(mode="eval")

  # Setup state
  step = state.step or 0
  rng, init_key = jax_random.split(rng)
  params_initializer = \
      lambda: model_init(init_key, [-1] + list(inputs.input_shape))[1]
  params = state.params or params_initializer()
  opt_state = opt_init(params)
  if num_devices > 1:  # TODO(lukaszkaiser): use everywhere when pmap is stable.
    opt_state = jax.replicate(opt_state)

  # jit model_predict and update so they're fast
  jit_model_predict_eval = _jit_predict_fun(model_predict_eval, num_devices)
  jit_update_fun = _jit_update_fun(
      model_predict_train, loss_fun, optimizer, lr_fun, num_devices)

  print()
  train_stream = inputs.train_stream()
  epoch_steps = [train_steps]  # Only training if eval_frequency is 0 or None.
  if eval_frequency and eval_steps > 0:
    epoch_steps = itertools.chain([1,  # first epoch only 1 step
                                   eval_frequency - 1],
                                  itertools.repeat(eval_frequency))
  step_log(step, "Starting training using %d devices" % num_devices)

  # Non-compiled debug step helps find problems in models easier.
  if run_debug_step:
    debug_loss = loss_fun(params, next(train_stream), model_predict_train, rng)
    step_log(step, "Debug step loss %.8f" % debug_loss)

  for epoch, epoch_steps in epochs(train_steps, epoch_steps):
    # Log separator
    print()

    # Timer
    start_time = time.time()

    for _ in range(epoch_steps):
      # Train
      next_train_batch = next(train_stream)
      if num_devices > 1:  # TODO(lukaszkaiser): use everywhere when possible.
        next_train_batch = reshape_by_device_pair(next_train_batch, num_devices)
      rng, subrng = jax_random.split(rng)
      opt_state = jit_update_fun(step, opt_state, next_train_batch, subrng)
      step += 1

      # LR log
      if step == 1 or step % 10 == 0:
        train_sw.scalar("training/learning rate",
                        lr_fun(step), step=step)

    # Timer
    epoch_time = time.time() - start_time
    step_log(step, "Ran %d train steps in %0.2f secs" %
             (epoch_steps, epoch_time))
    if epoch_steps > 1:
      train_sw.scalar("training/steps per second",
                      epoch_steps / epoch_time, step=step)

    # Evaluate
    params = trax_opt.get_params(opt_state)
    evaluate_train_and_eval(
        step=step,
        inputs=inputs,
        predict_fun=functools.partial(jit_model_predict_eval, params),
        eval_steps=eval_steps,
        rng=rng,
        train_sw=train_sw,
        eval_sw=eval_sw,
        history=history)

    # Save state
    save_state(State(params=params, step=step, history=history), output_dir)

    # Save Gin config
    # Gin only tracks the used parameters, so we save it after the first epoch.
    if epoch == 1:
      save_gin(output_dir, train_sw)

    # Update learning rate with new history
    old_lr_fun = lr_fun
    lr_fun = lr_schedule(history)
    if lr_fun != old_lr_fun:  # For performance, only jit if there is a change.
      jit_update_fun = _jit_update_fun(
          model_predict_train, loss_fun, optimizer, lr_fun, num_devices)

    # Flush summary writers
    train_sw.flush()
    eval_sw.flush()

  step_log(step, "Training done")
  return State(params=params, step=step, history=history)
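
One easy-to-miss detail in train is the epoch schedule: when evaluation is enabled, the first epoch is a single step (so an eval happens immediately), the second epoch completes the first eval_frequency block, and every later epoch is exactly eval_frequency steps. The generator in isolation:

import itertools

eval_frequency = 100
epoch_steps = itertools.chain([1, eval_frequency - 1],
                              itertools.repeat(eval_frequency))
# First few epoch lengths: [1, 99, 100, 100, 100]
print(list(itertools.islice(epoch_steps, 5)))
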
Example 9
def training_loop(
        env=None,
        env_name="CartPole-v0",
        epochs=EPOCHS,
        policy_net_fun=None,
        value_net_fun=None,
        policy_and_value_net_fun=None,  # TODO(afrozm): Implement.
        policy_optimizer_fun=optimizer_fun,
        value_optimizer_fun=optimizer_fun,
        batch_size=BATCH_TRAJECTORIES,
        num_optimizer_steps=NUM_OPTIMIZER_STEPS,
        print_every_optimizer_steps=PRINT_EVERY_OPTIMIZER_STEP,
        boundary=20,
        max_timestep=None,
        random_seed=None):
    """Runs the training loop for PPO, with fixed policy and value nets."""
    jax_rng_key = trax.get_random_number_generator_and_set_seed(random_seed)

    value_losses = []
    ppo_objective = []
    average_rewards = []

    env = env if env is not None else gym.make(env_name)

    # Batch Observations Shape = [-1, -1] + OBS, because we will eventually call
    # policy and value networks on shape [B, T] + OBS
    batch_observations_shape = (-1, -1) + env.observation_space.shape

    assert isinstance(env.action_space, gym.spaces.Discrete)
    num_actions = env.action_space.n

    # TODO(afrozm): Have a single net for both policy and action.
    assert policy_and_value_net_fun is None

    # Initialize the policy and value functions.
    assert policy_net_fun and value_net_fun
    jax_rng_key, key1, key2 = jax_random.split(jax_rng_key, num=3)

    policy_net_params, policy_net_apply = policy_net_fun(
        key1, batch_observations_shape, num_actions)
    value_net_params, value_net_apply = value_net_fun(
        key2, batch_observations_shape, num_actions)

    # Initialize the optimizers.
    assert policy_optimizer_fun and value_optimizer_fun

    ppo_opt_state, ppo_opt_update = policy_optimizer_fun(policy_net_params)
    value_opt_state, value_opt_update = value_optimizer_fun(value_net_params)

    for i in range(epochs):
        t = time.time()
        t0 = t
        logging.vlog(1, "Epoch [% 6d] collecting trajectories.", i)
        trajs = collect_trajectories(
            env,
            policy_net_apply,
            policy_net_params,
            num_trajectories=batch_size,
            policy=POLICY,
            max_timestep=max_timestep,
            epsilon=(10.0 / (i + 10.0)))  # this is a different epsilon.

        avg_reward = float(sum(np.sum(traj[2]) for traj in trajs)) / len(trajs)
        max_reward = max(np.sum(traj[2]) for traj in trajs)
        min_reward = min(np.sum(traj[2]) for traj in trajs)
        average_rewards.append(avg_reward)

        logging.vlog(1, "Rewards average=[%0.2f], max=[%0.2f], min=[%0.2f]",
                     avg_reward, max_reward, min_reward)
        logging.vlog(1, "Collecting trajectories took %0.2f msec.",
                     get_time(t))
        logging.vlog(
            1, "Trajectory Length average=[%0.2f], max=[%0.2f], min=[%0.2f]",
            float(sum(len(traj[0]) for traj in trajs)) / len(trajs),
            max(len(traj[0]) for traj in trajs),
            min(len(traj[0]) for traj in trajs))

        t = time.time()
        (_, reward_mask, padded_observations, padded_actions,
         padded_rewards) = pad_trajectories(trajs, boundary=boundary)

        logging.vlog(1, "Padding trajectories took %0.2f msec.", get_time(t))
        logging.vlog(1, "Padded Observations' shape [%s]",
                     str(padded_observations.shape))
        logging.vlog(1, "Padded Actions' shape [%s]",
                     str(padded_actions.shape))
        logging.vlog(1, "Padded Rewards' shape [%s]",
                     str(padded_rewards.shape))

        # Some assertions.
        B, T = padded_actions.shape  # pylint: disable=invalid-name
        assert (B, T) == padded_rewards.shape
        assert (B, T) == reward_mask.shape
        assert (B, T + 1) == padded_observations.shape[:2]
        assert (B, T +
                1) + env.observation_space.shape == padded_observations.shape

        # Linear annealing from 0.1 to 0.0
        epsilon = 0.1 if epochs == 1 else 0.1 * (1.0 - (i / (epochs - 1)))

        t = time.time()
        cur_value_loss = value_loss(value_net_apply,
                                    value_net_params,
                                    padded_observations,
                                    padded_rewards,
                                    reward_mask,
                                    gamma=GAMMA)

        logging.vlog(1, "Calculating value loss took %0.2f msec.", get_time(t))
        value_losses.append(cur_value_loss)

        t = time.time()
        cur_ppo_loss = ppo_loss(policy_net_apply,
                                policy_net_params,
                                policy_net_params,
                                value_net_apply,
                                value_net_params,
                                padded_observations,
                                padded_actions,
                                padded_rewards,
                                reward_mask,
                                gamma=GAMMA,
                                lambda_=LAMBDA,
                                epsilon=epsilon)
        # ppo_loss = 11.00110011
        logging.vlog(1, "Calculating PPO loss took %0.2f msec.", get_time(t))
        ppo_objective.append(-cur_ppo_loss)

        # Run optimizers.
        logging.vlog(1, "PPO Optimization")
        t1 = time.time()

        for j in range(num_optimizer_steps):
            t = time.time()
            # Update the optimizer state.
            ppo_opt_state = ppo_opt_step(j,
                                         ppo_opt_state,
                                         ppo_opt_update,
                                         policy_net_apply,
                                         policy_net_params,
                                         value_net_apply,
                                         value_net_params,
                                         padded_observations,
                                         padded_actions,
                                         padded_rewards,
                                         reward_mask,
                                         gamma=GAMMA,
                                         lambda_=LAMBDA,
                                         epsilon=epsilon)
            t2 = time.time()
            # Get the new params.
            new_policy_net_params = trax_opt.get_params(ppo_opt_state)
            if ((j + 1) % print_every_optimizer_steps
                    == 0) or (j == num_optimizer_steps - 1):
                new_ppo_loss = ppo_loss(policy_net_apply,
                                        new_policy_net_params,
                                        policy_net_params,
                                        value_net_apply,
                                        value_net_params,
                                        padded_observations,
                                        padded_actions,
                                        padded_rewards,
                                        reward_mask,
                                        gamma=GAMMA,
                                        lambda_=LAMBDA,
                                        epsilon=epsilon)
                logging.vlog(1, "One PPO grad desc took: %0.2f msec",
                             get_time(t, t2))
                logging.vlog(1, "PPO loss [%10.2f] -> [%10.2f]", cur_ppo_loss,
                             new_ppo_loss)
            # Update the params.
            policy_net_params = new_policy_net_params

        logging.vlog(1, "Total PPO loss reduction [%0.2f]%%",
                     (100 *
                      (cur_ppo_loss - new_ppo_loss) / np.abs(cur_ppo_loss)))

        logging.vlog(1, "Value Optimization")

        for j in range(num_optimizer_steps):
            t = time.time()
            value_opt_state = value_opt_step(j,
                                             value_opt_state,
                                             value_opt_update,
                                             value_net_apply,
                                             padded_observations,
                                             padded_rewards,
                                             reward_mask,
                                             gamma=GAMMA)
            t2 = time.time()
            value_net_params = trax_opt.get_params(value_opt_state)
            if ((j + 1) % print_every_optimizer_steps
                    == 0) or (j == num_optimizer_steps - 1):
                new_value_loss = value_loss(value_net_apply,
                                            value_net_params,
                                            padded_observations,
                                            padded_rewards,
                                            reward_mask,
                                            gamma=GAMMA)
                logging.vlog(1, "One value grad desc took: %0.2f msec",
                             get_time(t, t2))
                logging.vlog(1, "Value loss [%10.2f] -> [%10.2f]",
                             cur_value_loss, new_value_loss)
        logging.vlog(
            1, "Total value loss reduction [%0.2f]%%",
            (100 * (cur_value_loss - new_value_loss) / np.abs(cur_value_loss)))

        logging.vlog(1, "Grad desc took %0.2f msec", get_time(t1))

        # Set the optimized params to new params.
        policy_net_params = trax_opt.get_params(ppo_opt_state)
        value_net_params = trax_opt.get_params(value_opt_state)

        logging.info(
            "Epoch [% 6d], Reward[min, max, avg] [%10.2f,%10.2f,%10.2f], "
            "ppo loss [%10.2f], value loss [%10.2f], took [%10.2f msec]", i,
            min_reward, max_reward, avg_reward, new_ppo_loss, new_value_loss,
            get_time(t0))

    logging.vlog(1, "value_losses: %s", np.stack(value_losses))
    logging.vlog(1, "ppo_objective: %s", np.stack(ppo_objective))
    logging.vlog(1, "average_rewards: %s", average_rewards)

    return ((policy_net_params, value_net_params), average_rewards,
            np.stack(value_losses), np.stack(ppo_objective))
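
The shape assertions in the loop above pin down the padded-batch layout: actions, rewards, and the reward mask are [B, T], while observations carry one extra time step, [B, T + 1] + observation shape, because the final observation of each trajectory has no corresponding action or reward. A small sketch of that layout, assuming a toy 4-dimensional observation space:

import numpy as np

B, T = 2, 5            # trajectories in the batch, padded time steps
obs_shape = (4,)       # e.g. CartPole-v0 observations
padded_observations = np.zeros((B, T + 1) + obs_shape)  # one extra step
padded_actions = np.zeros((B, T), dtype=np.int32)
padded_rewards = np.zeros((B, T))
reward_mask = np.ones((B, T))
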
Example 10
def training_loop(env=None,
                  env_name="CartPole-v0",
                  epochs=EPOCHS,
                  policy_net_fun=None,
                  value_net_fun=None,
                  policy_and_value_net_fun=None,
                  policy_optimizer_fun=None,
                  value_optimizer_fun=None,
                  policy_and_value_optimizer_fun=None,
                  batch_size=BATCH_TRAJECTORIES,
                  num_optimizer_steps=NUM_OPTIMIZER_STEPS,
                  print_every_optimizer_steps=PRINT_EVERY_OPTIMIZER_STEP,
                  boundary=20,
                  max_timestep=None,
                  random_seed=None,
                  gamma=GAMMA,
                  lambda_=LAMBDA,
                  epsilon=EPSILON,
                  c1=1.0,
                  c2=0.01):
    """Runs the training loop for PPO, with fixed policy and value nets."""
    jax_rng_key = trax.get_random_number_generator_and_set_seed(random_seed)

    value_losses = []
    ppo_objective = []
    combined_losses = []
    average_rewards = []

    env = env if env is not None else gym.make(env_name)

    # Batch Observations Shape = [-1, -1] + OBS, because we will eventually call
    # policy and value networks on shape [B, T] + OBS
    batch_observations_shape = (-1, -1) + env.observation_space.shape

    assert isinstance(env.action_space, gym.spaces.Discrete)
    num_actions = env.action_space.n

    policy_and_value_net_params, policy_and_value_net_apply = None, None
    policy_and_value_opt_state, policy_and_value_opt_update = None, None
    policy_net_params, policy_net_apply = None, None
    value_net_params, value_net_apply = None, None
    if policy_and_value_net_fun is not None:
        jax_rng_key, subkey = jax_random.split(jax_rng_key)

        # Initialize the policy and value network.
        policy_and_value_net_params, policy_and_value_net_apply = (
            policy_and_value_net_fun(subkey, batch_observations_shape,
                                     num_actions))

        # Initialize the optimizers.
        policy_and_value_opt_state, policy_and_value_opt_update = (
            policy_and_value_optimizer_fun(policy_and_value_net_params))
    else:
        # Initialize the policy and value functions.
        assert policy_net_fun and value_net_fun
        jax_rng_key, key1, key2 = jax_random.split(jax_rng_key, num=3)

        policy_net_params, policy_net_apply = policy_net_fun(
            key1, batch_observations_shape, num_actions)
        value_net_params, value_net_apply = value_net_fun(
            key2, batch_observations_shape, num_actions)

        # Initialize the optimizers.
        ppo_opt_state, ppo_opt_update = policy_optimizer_fun(policy_net_params)
        value_opt_state, value_opt_update = value_optimizer_fun(
            value_net_params)

    # A function that will call the appropriate policy function with parameters.
    def get_policy_output(observations):
        if policy_net_apply is not None:
            assert policy_net_params
            return policy_net_apply(observations, policy_net_params)

        assert policy_and_value_net_apply and policy_and_value_net_params
        policy_predictions, unused_value_predictions = policy_and_value_net_apply(
            observations, policy_and_value_net_params)
        return policy_predictions

    for i in range(epochs):
        t = time.time()
        t0 = t
        logging.vlog(1, "Epoch [% 6d] collecting trajectories.", i)
        trajs = collect_trajectories(
            env,
            policy_fun=get_policy_output,
            num_trajectories=batch_size,
            policy=POLICY,
            max_timestep=max_timestep,
            epsilon=(10.0 / (i + 10.0)))  # this is a different epsilon.

        avg_reward = float(sum(np.sum(traj[2]) for traj in trajs)) / len(trajs)
        max_reward = max(np.sum(traj[2]) for traj in trajs)
        min_reward = min(np.sum(traj[2]) for traj in trajs)
        average_rewards.append(avg_reward)

        logging.vlog(1, "Rewards average=[%0.2f], max=[%0.2f], min=[%0.2f]",
                     avg_reward, max_reward, min_reward)
        logging.vlog(1, "Collecting trajectories took %0.2f msec.",
                     get_time(t))
        logging.vlog(
            1, "Trajectory Length average=[%0.2f], max=[%0.2f], min=[%0.2f]",
            float(sum(len(traj[0]) for traj in trajs)) / len(trajs),
            max(len(traj[0]) for traj in trajs),
            min(len(traj[0]) for traj in trajs))

        t = time.time()
        (_, reward_mask, padded_observations, padded_actions,
         padded_rewards) = pad_trajectories(trajs, boundary=boundary)

        logging.vlog(1, "Padding trajectories took %0.2f msec.", get_time(t))
        logging.vlog(1, "Padded Observations' shape [%s]",
                     str(padded_observations.shape))
        logging.vlog(1, "Padded Actions' shape [%s]",
                     str(padded_actions.shape))
        logging.vlog(1, "Padded Rewards' shape [%s]",
                     str(padded_rewards.shape))

        # Some assertions.
        B, T = padded_actions.shape  # pylint: disable=invalid-name
        assert (B, T) == padded_rewards.shape
        assert (B, T) == reward_mask.shape
        assert (B, T + 1) == padded_observations.shape[:2]
        assert (B, T +
                1) + env.observation_space.shape == padded_observations.shape

        # Linearly anneal epsilon down to 0.0 over the epochs.
        epsilon_schedule = epsilon if epochs == 1 else epsilon * (
            1.0 - (i / (epochs - 1)))

        # Compute value and ppo losses.
        cur_value_loss, cur_ppo_loss, cur_combined_loss = None, None, None
        if policy_and_value_net_apply is not None:
            t = time.time()
            cur_combined_loss, cur_ppo_loss, cur_value_loss, _ = (
                combined_loss(policy_and_value_net_params,
                              policy_and_value_net_params,
                              policy_and_value_net_apply,
                              padded_observations,
                              padded_actions,
                              padded_rewards,
                              reward_mask,
                              gamma=gamma,
                              lambda_=lambda_,
                              epsilon=epsilon_schedule,
                              c1=c1,
                              c2=c2))
            logging.vlog(
                1,
                "Calculating P&V loss [%10.2f(%10.2f, %10.2f)] took %0.2f msec.",
                cur_combined_loss, cur_value_loss, cur_ppo_loss, get_time(t))
        else:
            t = time.time()
            cur_value_loss = value_loss(value_net_apply,
                                        value_net_params,
                                        padded_observations,
                                        padded_rewards,
                                        reward_mask,
                                        gamma=gamma)

            logging.vlog(1, "Calculating value loss took %0.2f msec.",
                         get_time(t))

            t = time.time()
            cur_ppo_loss = ppo_loss(policy_net_apply,
                                    policy_net_params,
                                    policy_net_params,
                                    value_net_apply,
                                    value_net_params,
                                    padded_observations,
                                    padded_actions,
                                    padded_rewards,
                                    reward_mask,
                                    gamma=gamma,
                                    lambda_=lambda_,
                                    epsilon=epsilon_schedule)
            logging.vlog(1, "Calculating PPO loss took %0.2f msec.",
                         get_time(t))

        value_losses.append(cur_value_loss)
        ppo_objective.append(-1.0 * cur_ppo_loss)
        combined_losses.append(cur_combined_loss)

        if policy_and_value_net_apply:
            logging.vlog(1, "Policy and Value Optimization")
            t1 = time.time()
            for j in range(num_optimizer_steps):
                t = time.time()
                # Update the optimizer state.
                policy_and_value_opt_state = policy_and_value_opt_step(
                    j,
                    policy_and_value_opt_state,
                    policy_and_value_opt_update,
                    policy_and_value_net_apply,
                    policy_and_value_net_params,
                    padded_observations,
                    padded_actions,
                    padded_rewards,
                    reward_mask,
                    c1=c1,
                    c2=c2,
                    gamma=gamma,
                    lambda_=lambda_,
                    epsilon=epsilon_schedule)
                t2 = time.time()
                # Get the new params.
                new_policy_and_value_net_params = trax_opt.get_params(
                    policy_and_value_opt_state)
                if ((j + 1) % print_every_optimizer_steps
                        == 0) or (j == num_optimizer_steps - 1):
                    # Compute and log the loss.
                    (loss_combined, loss_ppo, loss_value,
                     unused_entropy_bonus) = (
                         combined_loss(
                             new_policy_and_value_net_params,
                             policy_and_value_net_params,  # old params
                             policy_and_value_net_apply,
                             padded_observations,
                             padded_actions,
                             padded_rewards,
                             reward_mask,
                             gamma=gamma,
                             lambda_=lambda_,
                             epsilon=epsilon_schedule,
                             c1=c1,
                             c2=c2))
                    logging.vlog(
                        1, "One Policy and Value grad desc took: %0.2f msec",
                        get_time(t, t2))
                    logging.vlog(
                        1,
                        "Combined Loss(value, ppo) [%10.2f] -> [%10.2f(%10.2f,%10.2f)]",
                        cur_combined_loss, loss_combined, loss_value, loss_ppo)
                # Update the params.
                policy_and_value_net_params = new_policy_and_value_net_params

            logging.vlog(1, "Total PPO loss reduction [%0.2f]%%",
                         (100 * (cur_combined_loss - loss_combined) /
                          np.abs(cur_combined_loss)))

            logging.info(
                "Epoch [% 6d], Reward[min, max, avg] [%10.2f,%10.2f,%10.2f], Combined"
                " Loss(value, ppo) [%10.2f(%10.2f,%10.2f)], took [%10.2f msec]",
                i, min_reward, max_reward, avg_reward, loss_combined,
                loss_value, loss_ppo, get_time(t1))
        else:
            # Run optimizers.
            logging.vlog(1, "PPO Optimization")
            t1 = time.time()

            for j in range(num_optimizer_steps):
                t = time.time()
                # Update the optimizer state.
                ppo_opt_state = ppo_opt_step(
                    j,
                    ppo_opt_state,
                    ppo_opt_update,
                    policy_net_apply,
                    policy_net_params,
                    value_net_apply,
                    value_net_params,
                    padded_observations,
                    padded_actions,
                    padded_rewards,
                    reward_mask,
                    gamma=gamma,
                    lambda_=lambda_,
                    epsilon=epsilon_schedule,
                )
                t2 = time.time()
                # Get the new params.
                new_policy_net_params = trax_opt.get_params(ppo_opt_state)
                if ((j + 1) % print_every_optimizer_steps
                        == 0) or (j == num_optimizer_steps - 1):
                    new_ppo_loss = ppo_loss(
                        policy_net_apply,
                        new_policy_net_params,
                        policy_net_params,
                        value_net_apply,
                        value_net_params,
                        padded_observations,
                        padded_actions,
                        padded_rewards,
                        reward_mask,
                        gamma=gamma,
                        lambda_=lambda_,
                        epsilon=epsilon_schedule,
                    )
                    logging.vlog(1, "One PPO grad desc took: %0.2f msec",
                                 get_time(t, t2))
                    logging.vlog(1, "PPO loss [%10.2f] -> [%10.2f]",
                                 cur_ppo_loss, new_ppo_loss)
                # Update the params.
                policy_net_params = new_policy_net_params

            logging.vlog(
                1, "Total PPO loss reduction [%0.2f]%%",
                (100 * (cur_ppo_loss - new_ppo_loss) / np.abs(cur_ppo_loss)))

            logging.vlog(1, "Value Optimization")

            for j in range(num_optimizer_steps):
                t = time.time()
                value_opt_state = value_opt_step(j,
                                                 value_opt_state,
                                                 value_opt_update,
                                                 value_net_apply,
                                                 padded_observations,
                                                 padded_rewards,
                                                 reward_mask,
                                                 gamma=gamma)
                t2 = time.time()
                value_net_params = trax_opt.get_params(value_opt_state)
                if ((j + 1) % print_every_optimizer_steps
                        == 0) or (j == num_optimizer_steps - 1):
                    new_value_loss = value_loss(value_net_apply,
                                                value_net_params,
                                                padded_observations,
                                                padded_rewards,
                                                reward_mask,
                                                gamma=gamma)
                    logging.vlog(1, "One value grad desc took: %0.2f msec",
                                 get_time(t, t2))
                    logging.vlog(1, "Value loss [%10.2f] -> [%10.2f]",
                                 cur_value_loss, new_value_loss)
            logging.vlog(
                1, "Total value loss reduction [%0.2f]%%",
                (100 *
                 (cur_value_loss - new_value_loss) / np.abs(cur_value_loss)))

            logging.vlog(1, "Grad desc took %0.2f msec", get_time(t1))

            # Set the optimized params to new params.
            policy_net_params = trax_opt.get_params(ppo_opt_state)
            value_net_params = trax_opt.get_params(value_opt_state)

            logging.info(
                "Epoch [% 6d], Reward[min, max, avg] [%10.2f,%10.2f,%10.2f], "
                "ppo loss [%10.2f], value loss [%10.2f], took [%10.2f msec]",
                i, min_reward, max_reward, avg_reward, new_ppo_loss,
                new_value_loss, get_time(t0))

    # Log the parameters, just for the sake of it.
    if policy_net_params:
        log_params(policy_net_params, "policy_net_params")
    if value_net_params:
        log_params(value_net_params, "value_net_params")
    if policy_and_value_net_params:
        log_params(policy_and_value_net_params, "policy_and_value_net_params")

    if value_losses:
        logging.vlog(1, "value_losses: %s", np.stack(value_losses))
    if ppo_objective:
        logging.vlog(1, "ppo_objective: %s", np.stack(ppo_objective))
    if average_rewards:
        logging.vlog(1, "average_rewards: %s", average_rewards)

    return ((policy_net_params, value_net_params), average_rewards,
            np.stack(value_losses), np.stack(ppo_objective))
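
The clipping epsilon in Example 10 is annealed linearly from its initial value down to zero over the epochs (Example 11 below switches to a constant epsilon instead). The schedule in isolation:

epochs = 5
epsilon = 0.1
schedule = [epsilon if epochs == 1 else epsilon * (1.0 - i / (epochs - 1))
            for i in range(epochs)]
# [0.1, 0.075, 0.05, 0.025, 0.0]: linear annealing from epsilon down to 0.
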
Example 11
def training_loop(
    env=None,
    epochs=EPOCHS,
    policy_net_fun=None,
    value_net_fun=None,
    policy_and_value_net_fun=None,
    policy_optimizer_fun=None,
    value_optimizer_fun=None,
    policy_and_value_optimizer_fun=None,
    batch_size=BATCH_TRAJECTORIES,
    num_optimizer_steps=NUM_OPTIMIZER_STEPS,
    policy_only_num_optimizer_steps=POLICY_ONLY_NUM_OPTIMIZER_STEPS,
    value_only_num_optimizer_steps=VALUE_ONLY_NUM_OPTIMIZER_STEPS,
    print_every_optimizer_steps=PRINT_EVERY_OPTIMIZER_STEP,
    target_kl=0.01,
    boundary=20,
    max_timestep=None,
    random_seed=None,
    gamma=GAMMA,
    lambda_=LAMBDA,
    epsilon=EPSILON,
    c1=1.0,
    c2=0.01):
  """Runs the training loop for PPO, with fixed policy and value nets."""
  assert env
  jax_rng_key = trax.get_random_number_generator_and_set_seed(random_seed)

  value_losses = []
  ppo_objective = []
  combined_losses = []
  average_rewards = []

  # Batch Observations Shape = [-1, -1] + OBS, because we will eventually call
  # policy and value networks on shape [B, T] + OBS
  batch_observations_shape = (-1, -1) + env.observation_space.shape

  assert isinstance(env.action_space, gym.spaces.Discrete)
  num_actions = env.action_space.n

  policy_and_value_net_params, policy_and_value_net_apply = None, None
  policy_and_value_opt_state, policy_and_value_opt_update = None, None
  policy_net_params, policy_net_apply = None, None
  value_net_params, value_net_apply = None, None
  if policy_and_value_net_fun is not None:
    jax_rng_key, subkey = jax_random.split(jax_rng_key)

    # Initialize the policy and value network.
    policy_and_value_net_params, policy_and_value_net_apply = (
        policy_and_value_net_fun(subkey, batch_observations_shape, num_actions))

    # Initialize the optimizers.
    policy_and_value_opt_state, policy_and_value_opt_update = (
        policy_and_value_optimizer_fun(policy_and_value_net_params))

    policy_and_value_net_apply = jit(policy_and_value_net_apply)
  else:
    # Initialize the policy and value functions.
    assert policy_net_fun and value_net_fun
    jax_rng_key, key1, key2 = jax_random.split(jax_rng_key, num=3)

    policy_net_params, policy_net_apply = policy_net_fun(
        key1, batch_observations_shape, num_actions)
    value_net_params, value_net_apply = value_net_fun(key2,
                                                      batch_observations_shape,
                                                      num_actions)

    policy_net_apply = jit(policy_net_apply)
    value_net_apply = jit(value_net_apply)

    # Initialize the optimizers.
    ppo_opt_state, ppo_opt_update = policy_optimizer_fun(policy_net_params)
    value_opt_state, value_opt_update = value_optimizer_fun(value_net_params)

  # A function that will call the appropriate policy function with parameters.
  def get_policy_output(observations):
    # Get the fresh params for collecting the policy.
    if policy_net_apply is not None:
      return policy_net_apply(observations, trax_opt.get_params(ppo_opt_state))

    assert policy_and_value_net_apply

    policy_predictions, unused_value_predictions = policy_and_value_net_apply(
        observations, trax_opt.get_params(policy_and_value_opt_state))
    return policy_predictions

  for i in range(epochs):
    t = time.time()
    t0 = t
    logging.vlog(1, "Epoch [% 6d] collecting trajectories.", i)
    trajs = collect_trajectories(
        env,
        policy_fun=get_policy_output,
        num_trajectories=batch_size,
        policy=POLICY,
        max_timestep=max_timestep,
        boundary=boundary,
        epsilon=(10.0 / (i + 10.0)))  # this is a different epsilon.

    logging.vlog(1, "Collecting trajectories took %0.2f msec.", get_time(t))

    # These were the params that were used to collect the trajectory.
    if policy_and_value_net_apply:
      policy_and_value_net_params = trax_opt.get_params(
          policy_and_value_opt_state)
    else:
      policy_net_params = trax_opt.get_params(ppo_opt_state)
      value_net_params = trax_opt.get_params(value_opt_state)

    avg_reward = float(sum(np.sum(traj[2]) for traj in trajs)) / len(trajs)
    max_reward = max(np.sum(traj[2]) for traj in trajs)
    min_reward = min(np.sum(traj[2]) for traj in trajs)
    average_rewards.append(avg_reward)

    logging.vlog(1, "Rewards average=[%0.2f], max=[%0.2f], min=[%0.2f]",
                 avg_reward, max_reward, min_reward)
    logging.vlog(2, "Rewards: %s", [float(np.sum(traj[2])) for traj in trajs])
    logging.vlog(1, "Average Rewards: %s", average_rewards)

    logging.vlog(1,
                 "Trajectory Length average=[%0.2f], max=[%0.2f], min=[%0.2f]",
                 float(sum(len(traj[0]) for traj in trajs)) / len(trajs),
                 max(len(traj[0]) for traj in trajs),
                 min(len(traj[0]) for traj in trajs))
    logging.vlog(2, "Trajectory Lengths: %s", [len(traj[0]) for traj in trajs])

    t = time.time()
    (_, reward_mask, padded_observations, padded_actions,
     padded_rewards) = pad_trajectories(
         trajs, boundary=boundary)

    logging.vlog(1, "Padding trajectories took %0.2f msec.", get_time(t))
    logging.vlog(1, "Padded Observations' shape [%s]",
                 str(padded_observations.shape))
    logging.vlog(1, "Padded Actions' shape [%s]", str(padded_actions.shape))
    logging.vlog(1, "Padded Rewards' shape [%s]", str(padded_rewards.shape))

    # Some assertions.
    B, T = padded_actions.shape  # pylint: disable=invalid-name
    assert (B, T) == padded_rewards.shape
    assert (B, T) == reward_mask.shape
    assert (B, T + 1) == padded_observations.shape[:2]
    assert (B, T + 1) + env.observation_space.shape == padded_observations.shape

    # Linear annealing from 0.1 to 0.0
    # epsilon_schedule = epsilon if epochs == 1 else epsilon * (1.0 -
    #                                                           (i /
    #                                                            (epochs - 1)))

    # Constant epsilon.
    epsilon_schedule = epsilon

    # Compute value and ppo losses.
    cur_value_loss, cur_ppo_loss, cur_combined_loss = None, None, None
    if policy_and_value_net_apply is not None:
      logging.vlog(2, "Starting to compute P&V loss.")
      t = time.time()
      cur_combined_loss, cur_ppo_loss, cur_value_loss, _ = (
          combined_loss(
              policy_and_value_net_params,
              policy_and_value_net_params,
              policy_and_value_net_apply,
              padded_observations,
              padded_actions,
              padded_rewards,
              reward_mask,
              gamma=gamma,
              lambda_=lambda_,
              epsilon=epsilon_schedule,
              c1=c1,
              c2=c2))
      logging.vlog(
          1, "Calculating P&V loss [%10.2f(%10.2f, %10.2f)] took %0.2f msec.",
          cur_combined_loss, cur_value_loss, cur_ppo_loss, get_time(t))
    else:
      t = time.time()
      cur_value_loss = value_loss(
          value_net_apply,
          value_net_params,
          padded_observations,
          padded_rewards,
          reward_mask,
          gamma=gamma)

      logging.vlog(1, "Calculating value loss took %0.2f msec.", get_time(t))

      t = time.time()
      cur_ppo_loss = ppo_loss(
          policy_net_apply,
          policy_net_params,
          policy_net_params,
          value_net_apply,
          value_net_params,
          padded_observations,
          padded_actions,
          padded_rewards,
          reward_mask,
          gamma=gamma,
          lambda_=lambda_,
          epsilon=epsilon_schedule)
      logging.vlog(1, "Calculating PPO loss took %0.2f msec.", get_time(t))

    value_losses.append(cur_value_loss)
    ppo_objective.append(-1.0 * cur_ppo_loss)
    if cur_combined_loss:
      combined_losses.append(cur_combined_loss)

    if policy_and_value_net_apply:
      logging.vlog(1, "Policy and Value Optimization")
      t1 = time.time()
      for j in range(num_optimizer_steps):
        t = time.time()
        # Update the optimizer state.
        policy_and_value_opt_state = policy_and_value_opt_step(
            j,
            policy_and_value_opt_state,
            policy_and_value_opt_update,
            policy_and_value_net_apply,
            # for the entirety of this loop, this should refer to params that
            # were used to collect the trajectory.
            policy_and_value_net_params,
            padded_observations,
            padded_actions,
            padded_rewards,
            reward_mask,
            c1=c1,
            c2=c2,
            gamma=gamma,
            lambda_=lambda_,
            epsilon=epsilon_schedule)
        t2 = time.time()
        if ((j + 1) %
            print_every_optimizer_steps == 0) or (j == num_optimizer_steps - 1):
          # Compute and log the loss.
          # Get the new params.
          new_policy_and_value_net_params = trax_opt.get_params(
              policy_and_value_opt_state)
          (loss_combined, loss_ppo, loss_value, unused_entropy_bonus) = (
              combined_loss(
                  new_policy_and_value_net_params,
                  # old params, that were used to collect the trajectory
                  policy_and_value_net_params,
                  policy_and_value_net_apply,
                  padded_observations,
                  padded_actions,
                  padded_rewards,
                  reward_mask,
                  gamma=gamma,
                  lambda_=lambda_,
                  epsilon=epsilon_schedule,
                  c1=c1,
                  c2=c2))
          logging.vlog(1, "One Policy and Value grad desc took: %0.2f msec",
                       get_time(t, t2))
          logging.vlog(
              1,
              "Combined Loss(value, ppo) [%10.2f] -> [%10.2f(%10.2f,%10.2f)]",
              cur_combined_loss, loss_combined, loss_value, loss_ppo)

      # Update the params.
      policy_and_value_net_params = new_policy_and_value_net_params

      logging.vlog(
          1, "Total Combined Loss reduction [%0.2f]%%",
          (100 *
           (cur_combined_loss - loss_combined) / np.abs(cur_combined_loss)))

      logging.info(
          "Epoch [% 6d], Reward[min, max, avg] [%10.2f,%10.2f,%10.2f], Combined"
          " Loss(value, ppo) [%10.2f(%10.2f,%10.2f)], took [%10.2f msec]", i,
          min_reward, max_reward, avg_reward, loss_combined, loss_value,
          loss_ppo, get_time(t1))
    else:
      # Run optimizers.
      logging.vlog(1, "PPO Optimization")
      t1 = time.time()

      for j in range(policy_only_num_optimizer_steps):
        t = time.time()
        # Update the optimizer state.
        ppo_opt_state = ppo_opt_step(
            j,
            ppo_opt_state,
            ppo_opt_update,
            policy_net_apply,
            policy_net_params,
            value_net_apply,
            value_net_params,
            padded_observations,
            padded_actions,
            padded_rewards,
            reward_mask,
            gamma=gamma,
            lambda_=lambda_,
            epsilon=epsilon_schedule,
        )
        t2 = time.time()
        # Get the new params.
        new_policy_net_params = trax_opt.get_params(ppo_opt_state)

        # These are the "old" params - policy_net_params

        # Compute the approx KL for early stopping.
        log_probab_actions_old = policy_net_apply(padded_observations,
                                                  policy_net_params)
        log_probab_actions_new = policy_net_apply(padded_observations,
                                                  new_policy_net_params)

        approx_kl = np.mean(log_probab_actions_old - log_probab_actions_new)

        early_stopping = approx_kl > 1.5 * target_kl
        if early_stopping:
          logging.vlog(
              1, "Early stopping policy optimization at iter: %d, "
              "with approx_kl: %0.2f", j, approx_kl)
          # We don't return right-away, we want the below to execute on the last
          # iteration.

        if (((j + 1) % print_every_optimizer_steps == 0) or
            (j == policy_only_num_optimizer_steps - 1) or early_stopping):
          new_ppo_loss = ppo_loss(
              policy_net_apply,
              new_policy_net_params,
              policy_net_params,
              value_net_apply,
              value_net_params,
              padded_observations,
              padded_actions,
              padded_rewards,
              reward_mask,
              gamma=gamma,
              lambda_=lambda_,
              epsilon=epsilon_schedule,
          )
          logging.vlog(1, "One PPO grad desc took: %0.2f msec", get_time(t, t2))
          logging.vlog(1, "PPO loss [%10.2f] -> [%10.2f]", cur_ppo_loss,
                       new_ppo_loss)

        if early_stopping:
          break

      # Update the params ONLY AND ONLY AFTER we complete all the optimization
      # iterations, till then `policy_net_params` should refer to the params
      # that were used in collecting the policy.
      # policy_net_params = trax_opt.get_params(ppo_opt_state)

      logging.vlog(1, "Total PPO loss reduction [%0.2f]%%",
                   (100 * (cur_ppo_loss - new_ppo_loss) / np.abs(cur_ppo_loss)))

      logging.vlog(1, "Value Optimization")

      for j in range(value_only_num_optimizer_steps):
        t = time.time()
        value_opt_state = value_opt_step(
            j,
            value_opt_state,
            value_opt_update,
            value_net_apply,
            padded_observations,
            padded_rewards,
            reward_mask,
            gamma=gamma)
        t2 = time.time()
        value_net_params = trax_opt.get_params(value_opt_state)
        if (((j + 1) % print_every_optimizer_steps == 0) or
            (j == value_only_num_optimizer_steps - 1)):
          new_value_loss = value_loss(
              value_net_apply,
              value_net_params,
              padded_observations,
              padded_rewards,
              reward_mask,
              gamma=gamma)
          logging.vlog(1, "One value grad desc took: %0.2f msec",
                       get_time(t, t2))
          logging.vlog(1, "Value loss [%10.2f] -> [%10.2f]", cur_value_loss,
                       new_value_loss)
      logging.vlog(1, "Total value loss reduction [%0.2f]%%",
                   (100 *
                    (cur_value_loss - new_value_loss) / np.abs(cur_value_loss)))

      logging.vlog(1, "Grad desc took %0.2f msec", get_time(t1))

      # Set the optimized params to new params.
      policy_net_params = trax_opt.get_params(ppo_opt_state)
      value_net_params = trax_opt.get_params(value_opt_state)

      logging.info(
          "Epoch [% 6d], Reward[min, max, avg] [%10.2f,%10.2f,%10.2f], "
          "ppo loss [%10.2f], value loss [%10.2f], took [%10.2f msec]", i,
          min_reward, max_reward, avg_reward, new_ppo_loss, new_value_loss,
          get_time(t0))

  # Log the parameters, just for the sake of it.
  if policy_net_params:
    log_params(policy_net_params, "policy_net_params")
  if value_net_params:
    log_params(value_net_params, "value_net_params")
  if policy_and_value_net_params:
    log_params(policy_and_value_net_params, "policy_and_value_net_params")

  if value_losses:
    logging.vlog(1, "value_losses: %s", np.stack(value_losses))
  if ppo_objective:
    logging.vlog(1, "ppo_objective:\n%s", np.stack(ppo_objective))
  if combined_losses:
    logging.vlog(1, "combined_losses:\n%s", np.stack(combined_losses))
  if average_rewards:
    logging.vlog(1, "average_rewards:\n%s", average_rewards)

  return ((policy_net_params, value_net_params), average_rewards,
          np.stack(value_losses), np.stack(ppo_objective))
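
Example 11 adds early stopping for the policy updates: the mean difference between old and new log-probabilities serves as a cheap KL estimate, and optimization stops once it exceeds 1.5 * target_kl. A small numerical sketch of that check (made-up probabilities, not real network outputs):

import numpy as np

target_kl = 0.01
log_probs_old = np.log(np.array([0.5, 0.3, 0.2]))
log_probs_new = np.log(np.array([0.6, 0.25, 0.15]))

# Approximate KL(old || new) by the mean log-probability difference,
# as in the training loop above.
approx_kl = np.mean(log_probs_old - log_probs_new)
early_stopping = approx_kl > 1.5 * target_kl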