def get_policy_output(observations):
  # Get the fresh params for collecting the policy.
  if policy_net_apply is not None:
    return policy_net_apply(observations, trax_opt.get_params(ppo_opt_state))

  assert policy_and_value_net_apply

  policy_predictions, unused_value_predictions = policy_and_value_net_apply(
      observations, trax_opt.get_params(policy_and_value_opt_state))
  return policy_predictions
def ppo_opt_step(i,
                 opt_state,
                 ppo_opt_update,
                 policy_net_apply,
                 old_policy_params,
                 value_net_apply,
                 value_net_params,
                 padded_observations,
                 padded_actions,
                 padded_rewards,
                 reward_mask,
                 gamma=0.99,
                 lambda_=0.95,
                 epsilon=0.1):
  """PPO optimizer step."""
  new_policy_params = trax_opt.get_params(opt_state)
  g = grad(
      ppo_loss, argnums=1)(
          policy_net_apply,
          new_policy_params,
          old_policy_params,
          value_net_apply,
          value_net_params,
          padded_observations,
          padded_actions,
          padded_rewards,
          reward_mask,
          gamma=gamma,
          lambda_=lambda_,
          epsilon=epsilon)
  return ppo_opt_update(i, g, opt_state)
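# Note (added for context; these are the standard PPO definitions and an
# assumption about what ppo_loss implements, not something stated in this
# file): epsilon is the PPO clipping parameter, and gamma / lambda_ are the
# discount and GAE(lambda) parameters. The clipped surrogate that ppo_loss is
# assumed to compute is
#   L(theta) = E_t[ min(r_t(theta) * A_t,
#                       clip(r_t(theta), 1 - epsilon, 1 + epsilon) * A_t) ],
# where r_t(theta) = pi_theta(a_t | s_t) / pi_theta_old(a_t | s_t) and A_t is
# the GAE advantage estimate computed with gamma and lambda_.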
def single_update(i, opt_state, batch, rng):
  _, opt_update = optimizer(lr_fun)
  params = trax_opt.get_params(opt_state)
  return opt_update(
      i, backend.grad(loss_fun)(params, batch, predict_fun, rng), opt_state)
def single_update(i, opt_state, batch, rng):
  rng, subrng = jax_random.split(rng[0])
  _, opt_update = optimizer(lr_fun)
  params = trax_opt.get_params(opt_state)
  return opt_update(
      i, backend.grad(loss_fun)(params, batch, predict_fun, rng),
      opt_state), [subrng]
def mapped_update(i, opt_state, batch, rng):
  """This is a multi-device version of the update function above."""
  # We assume all tensors have the first dimension = num_devices.
  _, opt_update = optimizer(lr_fun)
  params = trax_opt.get_params(opt_state)
  grads = backend.grad(loss_fun)(params, batch, predict_fun, rng)
  grads = jax.tree_util.tree_map(
      lambda g: lax.psum(g, "batch"), grads)
  return opt_update(i, grads, opt_state)
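# Sketch (an assumption, not part of the original source): mapped_update is
# only meaningful once it is compiled with jax.pmap, because the "batch" axis
# name used by lax.psum above must match pmap's axis_name. Assuming the same
# enclosing scope (optimizer, lr_fun, loss_fun, predict_fun) as above:
pmapped_update = jax.pmap(mapped_update, axis_name="batch")
# Every argument passed to pmapped_update then needs a leading dimension of
# size num_devices (step index and opt_state replicated, batch and rng
# sharded per device), e.g.:
#   opt_state = pmapped_update(replicated_step, opt_state, device_batch, rngs)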
def value_opt_step(i,
                   opt_state,
                   opt_update,
                   value_net_apply,
                   padded_observations,
                   padded_rewards,
                   reward_mask,
                   gamma=0.99):
  """Value optimizer step."""
  value_params = trax_opt.get_params(opt_state)
  # Note this partial application here and argnums above in ppo_opt_step.
  g = grad(functools.partial(value_loss, value_net_apply))(
      value_params,
      padded_observations,
      padded_rewards,
      reward_mask,
      gamma=gamma)
  return opt_update(i, g, opt_state)
def policy_and_value_opt_step(i,
                              opt_state,
                              opt_update,
                              policy_and_value_net_apply,
                              old_params,
                              padded_observations,
                              padded_actions,
                              padded_rewards,
                              reward_mask,
                              c1=1.0,
                              c2=0.01,
                              gamma=0.99,
                              lambda_=0.95,
                              epsilon=0.1):
  """Policy and Value optimizer step."""

  # Combined loss function given the new params.
  def policy_and_value_loss(params):
    """Returns the combined loss given just parameters."""
    (loss, _, _, _) = combined_loss(
        params,
        old_params,
        policy_and_value_net_apply,
        padded_observations,
        padded_actions,
        padded_rewards,
        reward_mask,
        c1=c1,
        c2=c2,
        gamma=gamma,
        lambda_=lambda_,
        epsilon=epsilon)
    return loss

  new_params = trax_opt.get_params(opt_state)
  g = grad(policy_and_value_loss)(new_params)
  return opt_update(i, g, opt_state)
def train(output_dir,
          model=gin.REQUIRED,
          loss_fun=loss,
          inputs=trax_inputs.inputs,
          optimizer=trax_opt.adam,
          lr_schedule=lr.MultifactorSchedule,
          train_steps=1000,
          eval_steps=10,
          eval_frequency=100,
          num_devices=None,
          random_seed=None,
          run_debug_step=False):
  """Train the model on the inputs.

  Args:
    output_dir: Directory where to put the logs and checkpoints.
    model: The model to train as a callable returning 2 callables, an init_fun
      and apply_fun.
    loss_fun: callable with signature: params, trax.inputs.Inputs, model, rng
      -> loss.
    inputs: callable returning trax.inputs.Inputs.
    optimizer: The optimizer as a callable taking a learning_rate callable and
      returning 2 callables, opt_init and opt_update.
    lr_schedule: A learning rate schedule as a function that takes history and
      returns a function from step to learning rate (a float).
    train_steps: int, total number of training steps.
    eval_steps: int, num of steps per evaluation. If None or 0, eval disabled.
    eval_frequency: int, how often to run evaluation (every eval_frequency
      steps). If None or 0, eval disabled.
    num_devices: how many devices to use (if None, default, use all available).
    random_seed: the random seed to use; time/os dependent if None (default).
    run_debug_step: bool, if True, will run the model and loss without @jit for
      one step.

  Returns:
    trax.State
  """
  num_devices = num_devices or jax.lib.xla_bridge.device_count()
  rng = get_random_number_generator_and_set_seed(random_seed)
  gfile.makedirs(output_dir)

  # Create summary writers and history.
  train_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "train"))
  eval_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "eval"))

  inputs = inputs(num_devices)

  # Setup optimizer and model.
  state = restore_state(output_dir)
  history = state.history
  lr_fun = lr_schedule(history)
  opt_init, _ = optimizer(lr_fun)
  model_init, model_predict_train = model(mode="train")
  _, model_predict_eval = model(mode="eval")

  # Setup state.
  step = state.step or 0
  rng, init_key = jax_random.split(rng)
  params_initializer = \
      lambda: model_init(init_key, [-1] + list(inputs.input_shape))[1]
  params = state.params or params_initializer()
  opt_state = opt_init(params)
  if num_devices > 1:  # TODO(lukaszkaiser): use everywhere when pmap is stable.
    opt_state = jax.replicate(opt_state)

  # jit model_predict and update so they're fast.
  jit_model_predict_eval = _jit_predict_fun(model_predict_eval, num_devices)
  jit_update_fun = _jit_update_fun(
      model_predict_train, loss_fun, optimizer, lr_fun, num_devices)

  print()
  train_stream = inputs.train_stream()
  epoch_steps = [train_steps]  # Only training if eval_frequency is 0 or None.
  if eval_frequency and eval_steps > 0:
    epoch_steps = itertools.chain([1,  # first epoch only 1 step
                                   eval_frequency - 1],
                                  itertools.repeat(eval_frequency))
  step_log(step, "Starting training using %d devices" % num_devices)

  # Non-compiled debug step helps find problems in models easier.
  if run_debug_step:
    debug_loss = loss_fun(params, next(train_stream), model_predict_train, rng)
    step_log(step, "Debug step loss %.8f" % debug_loss)

  for epoch, epoch_steps in epochs(train_steps, epoch_steps):
    # Log separator.
    print()

    # Timer.
    start_time = time.time()

    for _ in range(epoch_steps):
      # Train.
      next_train_batch = next(train_stream)
      if num_devices > 1:  # TODO(lukaszkaiser): use everywhere when possible.
        next_train_batch = reshape_by_device_pair(next_train_batch,
                                                  num_devices)
      rng, subrng = jax_random.split(rng)
      opt_state = jit_update_fun(step, opt_state, next_train_batch, subrng)
      step += 1

      # LR log.
      if step == 1 or step % 10 == 0:
        train_sw.scalar("training/learning rate", lr_fun(step), step=step)

    # Timer.
    epoch_time = time.time() - start_time
    step_log(step, "Ran %d train steps in %0.2f secs" %
             (epoch_steps, epoch_time))
    if epoch_steps > 1:
      train_sw.scalar("training/steps per second",
                      epoch_steps / epoch_time, step=step)

    # Evaluate.
    params = trax_opt.get_params(opt_state)
    evaluate_train_and_eval(
        step=step,
        inputs=inputs,
        predict_fun=functools.partial(jit_model_predict_eval, params),
        eval_steps=eval_steps,
        rng=rng,
        train_sw=train_sw,
        eval_sw=eval_sw,
        history=history)

    # Save state.
    save_state(State(params=params, step=step, history=history), output_dir)

    # Save Gin config.
    # Gin only tracks the used parameters, so we save it after the first epoch.
    if epoch == 1:
      save_gin(output_dir, train_sw)

    # Update learning rate with new history.
    old_lr_fun = lr_fun
    lr_fun = lr_schedule(history)
    if lr_fun != old_lr_fun:  # For performance, only jit if there is a change.
      jit_update_fun = _jit_update_fun(
          model_predict_train, loss_fun, optimizer, lr_fun, num_devices)

    # Flush summary writers.
    train_sw.flush()
    eval_sw.flush()

  step_log(step, "Training done")
  return State(params=params, step=step, history=history)
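# Hypothetical usage sketch (not from the original source); `my_model` stands
# in for any callable returning (init_fun, apply_fun) per mode, and the output
# directory is illustrative only:
#
#   state = train(
#       output_dir="/tmp/trax_train_example",
#       model=my_model,
#       train_steps=1000,
#       eval_steps=10,
#       eval_frequency=100)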
def training_loop(
    env=None,
    env_name="CartPole-v0",
    epochs=EPOCHS,
    policy_net_fun=None,
    value_net_fun=None,
    policy_and_value_net_fun=None,  # TODO(afrozm): Implement.
    policy_optimizer_fun=optimizer_fun,
    value_optimizer_fun=optimizer_fun,
    batch_size=BATCH_TRAJECTORIES,
    num_optimizer_steps=NUM_OPTIMIZER_STEPS,
    print_every_optimizer_steps=PRINT_EVERY_OPTIMIZER_STEP,
    boundary=20,
    max_timestep=None,
    random_seed=None):
  """Runs the training loop for PPO, with fixed policy and value nets."""
  jax_rng_key = trax.get_random_number_generator_and_set_seed(random_seed)

  value_losses = []
  ppo_objective = []
  average_rewards = []

  env = env if env is not None else gym.make(env_name)

  # Batch Observations Shape = [-1, -1] + OBS, because we will eventually call
  # policy and value networks on shape [B, T] + OBS.
  batch_observations_shape = (-1, -1) + env.observation_space.shape

  assert isinstance(env.action_space, gym.spaces.Discrete)
  num_actions = env.action_space.n

  # TODO(afrozm): Have a single net for both policy and value.
  assert policy_and_value_net_fun is None

  # Initialize the policy and value functions.
  assert policy_net_fun and value_net_fun
  jax_rng_key, key1, key2 = jax_random.split(jax_rng_key, num=3)

  policy_net_params, policy_net_apply = policy_net_fun(
      key1, batch_observations_shape, num_actions)
  value_net_params, value_net_apply = value_net_fun(
      key2, batch_observations_shape, num_actions)

  # Initialize the optimizers.
  assert policy_optimizer_fun and value_optimizer_fun
  ppo_opt_state, ppo_opt_update = policy_optimizer_fun(policy_net_params)
  value_opt_state, value_opt_update = value_optimizer_fun(value_net_params)

  for i in range(epochs):
    t = time.time()
    t0 = t
    logging.vlog(1, "Epoch [% 6d] collecting trajectories.", i)
    trajs = collect_trajectories(
        env,
        policy_net_apply,
        policy_net_params,
        num_trajectories=batch_size,
        policy=POLICY,
        max_timestep=max_timestep,
        epsilon=(10.0 / (i + 10.0)))  # this is a different epsilon.

    avg_reward = float(sum(np.sum(traj[2]) for traj in trajs)) / len(trajs)
    max_reward = max(np.sum(traj[2]) for traj in trajs)
    min_reward = min(np.sum(traj[2]) for traj in trajs)
    average_rewards.append(avg_reward)

    logging.vlog(1, "Rewards average=[%0.2f], max=[%0.2f], min=[%0.2f]",
                 avg_reward, max_reward, min_reward)
    logging.vlog(1, "Collecting trajectories took %0.2f msec.", get_time(t))
    logging.vlog(
        1, "Trajectory Length average=[%0.2f], max=[%0.2f], min=[%0.2f]",
        float(sum(len(traj[0]) for traj in trajs)) / len(trajs),
        max(len(traj[0]) for traj in trajs),
        min(len(traj[0]) for traj in trajs))

    t = time.time()
    (_, reward_mask, padded_observations, padded_actions,
     padded_rewards) = pad_trajectories(trajs, boundary=boundary)

    logging.vlog(1, "Padding trajectories took %0.2f msec.", get_time(t))
    logging.vlog(1, "Padded Observations' shape [%s]",
                 str(padded_observations.shape))
    logging.vlog(1, "Padded Actions' shape [%s]", str(padded_actions.shape))
    logging.vlog(1, "Padded Rewards' shape [%s]", str(padded_rewards.shape))

    # Some assertions.
    B, T = padded_actions.shape  # pylint: disable=invalid-name
    assert (B, T) == padded_rewards.shape
    assert (B, T) == reward_mask.shape
    assert (B, T + 1) == padded_observations.shape[:2]
    assert ((B, T + 1) + env.observation_space.shape ==
            padded_observations.shape)

    # Linear annealing from 0.1 to 0.0.
    epsilon = 0.1 if epochs == 1 else 0.1 * (1.0 - (i / (epochs - 1)))

    t = time.time()
    cur_value_loss = value_loss(
        value_net_apply,
        value_net_params,
        padded_observations,
        padded_rewards,
        reward_mask,
        gamma=GAMMA)

    logging.vlog(1, "Calculating value loss took %0.2f msec.", get_time(t))
    value_losses.append(cur_value_loss)

    t = time.time()
    cur_ppo_loss = ppo_loss(
        policy_net_apply,
        policy_net_params,
        policy_net_params,
        value_net_apply,
        value_net_params,
        padded_observations,
        padded_actions,
        padded_rewards,
        reward_mask,
        gamma=GAMMA,
        lambda_=LAMBDA,
        epsilon=epsilon)
    # ppo_loss = 11.00110011
    logging.vlog(1, "Calculating PPO loss took %0.2f msec.", get_time(t))
    ppo_objective.append(-cur_ppo_loss)

    # Run optimizers.
    logging.vlog(1, "PPO Optimization")
    t1 = time.time()

    for j in range(num_optimizer_steps):
      t = time.time()
      # Update the optimizer state.
      ppo_opt_state = ppo_opt_step(
          j,
          ppo_opt_state,
          ppo_opt_update,
          policy_net_apply,
          policy_net_params,
          value_net_apply,
          value_net_params,
          padded_observations,
          padded_actions,
          padded_rewards,
          reward_mask,
          gamma=GAMMA,
          lambda_=LAMBDA,
          epsilon=epsilon)
      t2 = time.time()
      # Get the new params.
      new_policy_net_params = trax_opt.get_params(ppo_opt_state)
      if ((j + 1) % print_every_optimizer_steps == 0) or (
          j == num_optimizer_steps - 1):
        new_ppo_loss = ppo_loss(
            policy_net_apply,
            new_policy_net_params,
            policy_net_params,
            value_net_apply,
            value_net_params,
            padded_observations,
            padded_actions,
            padded_rewards,
            reward_mask,
            gamma=GAMMA,
            lambda_=LAMBDA,
            epsilon=epsilon)
        logging.vlog(1, "One PPO grad desc took: %0.2f msec", get_time(t, t2))
        logging.vlog(1, "PPO loss [%10.2f] -> [%10.2f]", cur_ppo_loss,
                     new_ppo_loss)
      # Update the params.
      policy_net_params = new_policy_net_params

    logging.vlog(1, "Total PPO loss reduction [%0.2f]%%",
                 (100 * (cur_ppo_loss - new_ppo_loss) / np.abs(cur_ppo_loss)))

    logging.vlog(1, "Value Optimization")

    for j in range(num_optimizer_steps):
      t = time.time()
      value_opt_state = value_opt_step(
          j,
          value_opt_state,
          value_opt_update,
          value_net_apply,
          padded_observations,
          padded_rewards,
          reward_mask,
          gamma=GAMMA)
      t2 = time.time()
      value_net_params = trax_opt.get_params(value_opt_state)
      if ((j + 1) % print_every_optimizer_steps == 0) or (
          j == num_optimizer_steps - 1):
        new_value_loss = value_loss(
            value_net_apply,
            value_net_params,
            padded_observations,
            padded_rewards,
            reward_mask,
            gamma=GAMMA)
        logging.vlog(1, "One value grad desc took: %0.2f msec",
                     get_time(t, t2))
        logging.vlog(1, "Value loss [%10.2f] -> [%10.2f]", cur_value_loss,
                     new_value_loss)

    logging.vlog(
        1, "Total value loss reduction [%0.2f]%%",
        (100 * (cur_value_loss - new_value_loss) / np.abs(cur_value_loss)))

    logging.vlog(1, "Grad desc took %0.2f msec", get_time(t1))

    # Set the optimized params to new params.
    policy_net_params = trax_opt.get_params(ppo_opt_state)
    value_net_params = trax_opt.get_params(value_opt_state)

    logging.info(
        "Epoch [% 6d], Reward[min, max, avg] [%10.2f,%10.2f,%10.2f], "
        "ppo loss [%10.2f], value loss [%10.2f], took [%10.2f msec]", i,
        min_reward, max_reward, avg_reward, new_ppo_loss, new_value_loss,
        get_time(t0))

  logging.vlog(1, "value_losses: %s", np.stack(value_losses))
  logging.vlog(1, "ppo_objective: %s", np.stack(ppo_objective))
  logging.vlog(1, "average_rewards: %s", average_rewards)

  return ((policy_net_params, value_net_params), average_rewards,
          np.stack(value_losses), np.stack(ppo_objective))
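# Hypothetical driver sketch (not from the original source); the network
# builders `my_policy_net_fun` / `my_value_net_fun` are placeholders, and the
# hyperparameters are illustrative only:
#
#   env = gym.make("CartPole-v0")
#   (final_params, avg_rewards, value_loss_history, ppo_history) = (
#       training_loop(
#           env=env,
#           epochs=10,
#           policy_net_fun=my_policy_net_fun,
#           value_net_fun=my_value_net_fun,
#           policy_optimizer_fun=optimizer_fun,
#           value_optimizer_fun=optimizer_fun,
#           batch_size=32))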
def training_loop(env=None,
                  env_name="CartPole-v0",
                  epochs=EPOCHS,
                  policy_net_fun=None,
                  value_net_fun=None,
                  policy_and_value_net_fun=None,
                  policy_optimizer_fun=None,
                  value_optimizer_fun=None,
                  policy_and_value_optimizer_fun=None,
                  batch_size=BATCH_TRAJECTORIES,
                  num_optimizer_steps=NUM_OPTIMIZER_STEPS,
                  print_every_optimizer_steps=PRINT_EVERY_OPTIMIZER_STEP,
                  boundary=20,
                  max_timestep=None,
                  random_seed=None,
                  gamma=GAMMA,
                  lambda_=LAMBDA,
                  epsilon=EPSILON,
                  c1=1.0,
                  c2=0.01):
  """Runs the training loop for PPO, with fixed policy and value nets."""
  jax_rng_key = trax.get_random_number_generator_and_set_seed(random_seed)

  value_losses = []
  ppo_objective = []
  combined_losses = []
  average_rewards = []

  env = env if env is not None else gym.make(env_name)

  # Batch Observations Shape = [-1, -1] + OBS, because we will eventually call
  # policy and value networks on shape [B, T] + OBS.
  batch_observations_shape = (-1, -1) + env.observation_space.shape

  assert isinstance(env.action_space, gym.spaces.Discrete)
  num_actions = env.action_space.n

  policy_and_value_net_params, policy_and_value_net_apply = None, None
  policy_and_value_opt_state, policy_and_value_opt_update = None, None
  policy_net_params, policy_net_apply = None, None
  value_net_params, value_net_apply = None, None

  if policy_and_value_net_fun is not None:
    jax_rng_key, subkey = jax_random.split(jax_rng_key)
    # Initialize the policy and value network.
    policy_and_value_net_params, policy_and_value_net_apply = (
        policy_and_value_net_fun(subkey, batch_observations_shape,
                                 num_actions))
    # Initialize the optimizers.
    policy_and_value_opt_state, policy_and_value_opt_update = (
        policy_and_value_optimizer_fun(policy_and_value_net_params))
  else:
    # Initialize the policy and value functions.
    assert policy_net_fun and value_net_fun
    jax_rng_key, key1, key2 = jax_random.split(jax_rng_key, num=3)

    policy_net_params, policy_net_apply = policy_net_fun(
        key1, batch_observations_shape, num_actions)
    value_net_params, value_net_apply = value_net_fun(
        key2, batch_observations_shape, num_actions)

    # Initialize the optimizers.
    ppo_opt_state, ppo_opt_update = policy_optimizer_fun(policy_net_params)
    value_opt_state, value_opt_update = value_optimizer_fun(value_net_params)

  # A function that will call the appropriate policy function with parameters.
  def get_policy_output(observations):
    if policy_net_apply is not None:
      assert policy_net_params
      return policy_net_apply(observations, policy_net_params)

    assert policy_and_value_net_apply and policy_and_value_net_params
    policy_predictions, unused_value_predictions = policy_and_value_net_apply(
        observations, policy_and_value_net_params)
    return policy_predictions

  for i in range(epochs):
    t = time.time()
    t0 = t
    logging.vlog(1, "Epoch [% 6d] collecting trajectories.", i)
    trajs = collect_trajectories(
        env,
        policy_fun=get_policy_output,
        num_trajectories=batch_size,
        policy=POLICY,
        max_timestep=max_timestep,
        epsilon=(10.0 / (i + 10.0)))  # this is a different epsilon.

    avg_reward = float(sum(np.sum(traj[2]) for traj in trajs)) / len(trajs)
    max_reward = max(np.sum(traj[2]) for traj in trajs)
    min_reward = min(np.sum(traj[2]) for traj in trajs)
    average_rewards.append(avg_reward)

    logging.vlog(1, "Rewards average=[%0.2f], max=[%0.2f], min=[%0.2f]",
                 avg_reward, max_reward, min_reward)
    logging.vlog(1, "Collecting trajectories took %0.2f msec.", get_time(t))
    logging.vlog(
        1, "Trajectory Length average=[%0.2f], max=[%0.2f], min=[%0.2f]",
        float(sum(len(traj[0]) for traj in trajs)) / len(trajs),
        max(len(traj[0]) for traj in trajs),
        min(len(traj[0]) for traj in trajs))

    t = time.time()
    (_, reward_mask, padded_observations, padded_actions,
     padded_rewards) = pad_trajectories(trajs, boundary=boundary)

    logging.vlog(1, "Padding trajectories took %0.2f msec.", get_time(t))
    logging.vlog(1, "Padded Observations' shape [%s]",
                 str(padded_observations.shape))
    logging.vlog(1, "Padded Actions' shape [%s]", str(padded_actions.shape))
    logging.vlog(1, "Padded Rewards' shape [%s]", str(padded_rewards.shape))

    # Some assertions.
    B, T = padded_actions.shape  # pylint: disable=invalid-name
    assert (B, T) == padded_rewards.shape
    assert (B, T) == reward_mask.shape
    assert (B, T + 1) == padded_observations.shape[:2]
    assert ((B, T + 1) + env.observation_space.shape ==
            padded_observations.shape)

    # Linear annealing from epsilon to 0.0.
    epsilon_schedule = epsilon if epochs == 1 else epsilon * (
        1.0 - (i / (epochs - 1)))

    # Compute value and ppo losses.
    cur_value_loss, cur_ppo_loss, cur_combined_loss = None, None, None
    if policy_and_value_net_apply is not None:
      t = time.time()
      cur_combined_loss, cur_ppo_loss, cur_value_loss, _ = (
          combined_loss(
              policy_and_value_net_params,
              policy_and_value_net_params,
              policy_and_value_net_apply,
              padded_observations,
              padded_actions,
              padded_rewards,
              reward_mask,
              gamma=gamma,
              lambda_=lambda_,
              epsilon=epsilon_schedule,
              c1=c1,
              c2=c2))
      logging.vlog(
          1, "Calculating P&V loss [%10.2f(%10.2f, %10.2f)] took %0.2f msec.",
          cur_combined_loss, cur_value_loss, cur_ppo_loss, get_time(t))
    else:
      t = time.time()
      cur_value_loss = value_loss(
          value_net_apply,
          value_net_params,
          padded_observations,
          padded_rewards,
          reward_mask,
          gamma=gamma)

      logging.vlog(1, "Calculating value loss took %0.2f msec.", get_time(t))

      t = time.time()
      cur_ppo_loss = ppo_loss(
          policy_net_apply,
          policy_net_params,
          policy_net_params,
          value_net_apply,
          value_net_params,
          padded_observations,
          padded_actions,
          padded_rewards,
          reward_mask,
          gamma=gamma,
          lambda_=lambda_,
          epsilon=epsilon_schedule)
      logging.vlog(1, "Calculating PPO loss took %0.2f msec.", get_time(t))

    value_losses.append(cur_value_loss)
    ppo_objective.append(-1.0 * cur_ppo_loss)
    combined_losses.append(cur_combined_loss)

    if policy_and_value_net_apply:
      logging.vlog(1, "Policy and Value Optimization")
      t1 = time.time()

      for j in range(num_optimizer_steps):
        t = time.time()
        # Update the optimizer state.
        policy_and_value_opt_state = policy_and_value_opt_step(
            j,
            policy_and_value_opt_state,
            policy_and_value_opt_update,
            policy_and_value_net_apply,
            policy_and_value_net_params,
            padded_observations,
            padded_actions,
            padded_rewards,
            reward_mask,
            c1=c1,
            c2=c2,
            gamma=gamma,
            lambda_=lambda_,
            epsilon=epsilon_schedule)
        t2 = time.time()
        # Get the new params.
        new_policy_and_value_net_params = trax_opt.get_params(
            policy_and_value_opt_state)
        if ((j + 1) % print_every_optimizer_steps == 0) or (
            j == num_optimizer_steps - 1):
          # Compute and log the loss.
          (loss_combined, loss_ppo, loss_value, unused_entropy_bonus) = (
              combined_loss(
                  new_policy_and_value_net_params,
                  policy_and_value_net_params,  # old params
                  policy_and_value_net_apply,
                  padded_observations,
                  padded_actions,
                  padded_rewards,
                  reward_mask,
                  gamma=gamma,
                  lambda_=lambda_,
                  epsilon=epsilon_schedule,
                  c1=c1,
                  c2=c2))
          logging.vlog(1, "One Policy and Value grad desc took: %0.2f msec",
                       get_time(t, t2))
          logging.vlog(
              1,
              "Combined Loss(value, ppo) [%10.2f] -> [%10.2f(%10.2f,%10.2f)]",
              cur_combined_loss, loss_combined, loss_value, loss_ppo)
        # Update the params.
        policy_and_value_net_params = new_policy_and_value_net_params

      logging.vlog(1, "Total PPO loss reduction [%0.2f]%%",
                   (100 * (cur_combined_loss - loss_combined) /
                    np.abs(cur_combined_loss)))

      logging.info(
          "Epoch [% 6d], Reward[min, max, avg] [%10.2f,%10.2f,%10.2f], Combined"
          " Loss(value, ppo) [%10.2f(%10.2f,%10.2f)], took [%10.2f msec]", i,
          min_reward, max_reward, avg_reward, loss_combined, loss_value,
          loss_ppo, get_time(t1))
    else:
      # Run optimizers.
      logging.vlog(1, "PPO Optimization")
      t1 = time.time()

      for j in range(num_optimizer_steps):
        t = time.time()
        # Update the optimizer state.
        ppo_opt_state = ppo_opt_step(
            j,
            ppo_opt_state,
            ppo_opt_update,
            policy_net_apply,
            policy_net_params,
            value_net_apply,
            value_net_params,
            padded_observations,
            padded_actions,
            padded_rewards,
            reward_mask,
            gamma=gamma,
            lambda_=lambda_,
            epsilon=epsilon_schedule,
        )
        t2 = time.time()
        # Get the new params.
        new_policy_net_params = trax_opt.get_params(ppo_opt_state)
        if ((j + 1) % print_every_optimizer_steps == 0) or (
            j == num_optimizer_steps - 1):
          new_ppo_loss = ppo_loss(
              policy_net_apply,
              new_policy_net_params,
              policy_net_params,
              value_net_apply,
              value_net_params,
              padded_observations,
              padded_actions,
              padded_rewards,
              reward_mask,
              gamma=gamma,
              lambda_=lambda_,
              epsilon=epsilon_schedule,
          )
          logging.vlog(1, "One PPO grad desc took: %0.2f msec",
                       get_time(t, t2))
          logging.vlog(1, "PPO loss [%10.2f] -> [%10.2f]", cur_ppo_loss,
                       new_ppo_loss)
        # Update the params.
        policy_net_params = new_policy_net_params

      logging.vlog(
          1, "Total PPO loss reduction [%0.2f]%%",
          (100 * (cur_ppo_loss - new_ppo_loss) / np.abs(cur_ppo_loss)))

      logging.vlog(1, "Value Optimization")

      for j in range(num_optimizer_steps):
        t = time.time()
        value_opt_state = value_opt_step(
            j,
            value_opt_state,
            value_opt_update,
            value_net_apply,
            padded_observations,
            padded_rewards,
            reward_mask,
            gamma=gamma)
        t2 = time.time()
        value_net_params = trax_opt.get_params(value_opt_state)
        if ((j + 1) % print_every_optimizer_steps == 0) or (
            j == num_optimizer_steps - 1):
          new_value_loss = value_loss(
              value_net_apply,
              value_net_params,
              padded_observations,
              padded_rewards,
              reward_mask,
              gamma=gamma)
          logging.vlog(1, "One value grad desc took: %0.2f msec",
                       get_time(t, t2))
          logging.vlog(1, "Value loss [%10.2f] -> [%10.2f]", cur_value_loss,
                       new_value_loss)

      logging.vlog(
          1, "Total value loss reduction [%0.2f]%%",
          (100 * (cur_value_loss - new_value_loss) / np.abs(cur_value_loss)))

      logging.vlog(1, "Grad desc took %0.2f msec", get_time(t1))

      # Set the optimized params to new params.
      policy_net_params = trax_opt.get_params(ppo_opt_state)
      value_net_params = trax_opt.get_params(value_opt_state)

      logging.info(
          "Epoch [% 6d], Reward[min, max, avg] [%10.2f,%10.2f,%10.2f], "
          "ppo loss [%10.2f], value loss [%10.2f], took [%10.2f msec]", i,
          min_reward, max_reward, avg_reward, new_ppo_loss, new_value_loss,
          get_time(t0))

    # Log the parameters, just for the sake of it.
    if policy_net_params:
      log_params(policy_net_params, "policy_net_params")
    if value_net_params:
      log_params(value_net_params, "value_net_params")
    if policy_and_value_net_params:
      log_params(policy_and_value_net_params, "policy_and_value_net_params")

  if value_losses:
    logging.vlog(1, "value_losses: %s", np.stack(value_losses))
  if ppo_objective:
    logging.vlog(1, "ppo_objective: %s", np.stack(ppo_objective))
  if average_rewards:
    logging.vlog(1, "average_rewards: %s", average_rewards)

  return ((policy_net_params, value_net_params), average_rewards,
          np.stack(value_losses), np.stack(ppo_objective))
def training_loop(
    env=None,
    epochs=EPOCHS,
    policy_net_fun=None,
    value_net_fun=None,
    policy_and_value_net_fun=None,
    policy_optimizer_fun=None,
    value_optimizer_fun=None,
    policy_and_value_optimizer_fun=None,
    batch_size=BATCH_TRAJECTORIES,
    num_optimizer_steps=NUM_OPTIMIZER_STEPS,
    policy_only_num_optimizer_steps=POLICY_ONLY_NUM_OPTIMIZER_STEPS,
    value_only_num_optimizer_steps=VALUE_ONLY_NUM_OPTIMIZER_STEPS,
    print_every_optimizer_steps=PRINT_EVERY_OPTIMIZER_STEP,
    target_kl=0.01,
    boundary=20,
    max_timestep=None,
    random_seed=None,
    gamma=GAMMA,
    lambda_=LAMBDA,
    epsilon=EPSILON,
    c1=1.0,
    c2=0.01):
  """Runs the training loop for PPO, with fixed policy and value nets."""
  assert env
  jax_rng_key = trax.get_random_number_generator_and_set_seed(random_seed)

  value_losses = []
  ppo_objective = []
  combined_losses = []
  average_rewards = []

  # Batch Observations Shape = [-1, -1] + OBS, because we will eventually call
  # policy and value networks on shape [B, T] + OBS.
  batch_observations_shape = (-1, -1) + env.observation_space.shape

  assert isinstance(env.action_space, gym.spaces.Discrete)
  num_actions = env.action_space.n

  policy_and_value_net_params, policy_and_value_net_apply = None, None
  policy_and_value_opt_state, policy_and_value_opt_update = None, None
  policy_net_params, policy_net_apply = None, None
  value_net_params, value_net_apply = None, None

  if policy_and_value_net_fun is not None:
    jax_rng_key, subkey = jax_random.split(jax_rng_key)
    # Initialize the policy and value network.
    policy_and_value_net_params, policy_and_value_net_apply = (
        policy_and_value_net_fun(subkey, batch_observations_shape,
                                 num_actions))
    # Initialize the optimizers.
    policy_and_value_opt_state, policy_and_value_opt_update = (
        policy_and_value_optimizer_fun(policy_and_value_net_params))
    policy_and_value_net_apply = jit(policy_and_value_net_apply)
  else:
    # Initialize the policy and value functions.
    assert policy_net_fun and value_net_fun
    jax_rng_key, key1, key2 = jax_random.split(jax_rng_key, num=3)

    policy_net_params, policy_net_apply = policy_net_fun(
        key1, batch_observations_shape, num_actions)
    value_net_params, value_net_apply = value_net_fun(
        key2, batch_observations_shape, num_actions)

    policy_net_apply = jit(policy_net_apply)
    value_net_apply = jit(value_net_apply)

    # Initialize the optimizers.
    ppo_opt_state, ppo_opt_update = policy_optimizer_fun(policy_net_params)
    value_opt_state, value_opt_update = value_optimizer_fun(value_net_params)

  # A function that will call the appropriate policy function with parameters.
  def get_policy_output(observations):
    # Get the fresh params for collecting the policy.
    if policy_net_apply is not None:
      return policy_net_apply(observations, trax_opt.get_params(ppo_opt_state))

    assert policy_and_value_net_apply

    policy_predictions, unused_value_predictions = policy_and_value_net_apply(
        observations, trax_opt.get_params(policy_and_value_opt_state))
    return policy_predictions

  for i in range(epochs):
    t = time.time()
    t0 = t
    logging.vlog(1, "Epoch [% 6d] collecting trajectories.", i)
    trajs = collect_trajectories(
        env,
        policy_fun=get_policy_output,
        num_trajectories=batch_size,
        policy=POLICY,
        max_timestep=max_timestep,
        boundary=boundary,
        epsilon=(10.0 / (i + 10.0)))  # this is a different epsilon.
    logging.vlog(1, "Collecting trajectories took %0.2f msec.", get_time(t))

    # These were the params that were used to collect the trajectory.
    if policy_and_value_net_apply:
      policy_and_value_net_params = trax_opt.get_params(
          policy_and_value_opt_state)
    else:
      policy_net_params = trax_opt.get_params(ppo_opt_state)
      value_net_params = trax_opt.get_params(value_opt_state)

    avg_reward = float(sum(np.sum(traj[2]) for traj in trajs)) / len(trajs)
    max_reward = max(np.sum(traj[2]) for traj in trajs)
    min_reward = min(np.sum(traj[2]) for traj in trajs)
    average_rewards.append(avg_reward)

    logging.vlog(1, "Rewards average=[%0.2f], max=[%0.2f], min=[%0.2f]",
                 avg_reward, max_reward, min_reward)
    logging.vlog(2, "Rewards: %s", [float(np.sum(traj[2])) for traj in trajs])
    logging.vlog(1, "Average Rewards: %s", average_rewards)

    logging.vlog(1,
                 "Trajectory Length average=[%0.2f], max=[%0.2f], min=[%0.2f]",
                 float(sum(len(traj[0]) for traj in trajs)) / len(trajs),
                 max(len(traj[0]) for traj in trajs),
                 min(len(traj[0]) for traj in trajs))
    logging.vlog(2, "Trajectory Lengths: %s", [len(traj[0]) for traj in trajs])

    t = time.time()
    (_, reward_mask, padded_observations, padded_actions,
     padded_rewards) = pad_trajectories(trajs, boundary=boundary)

    logging.vlog(1, "Padding trajectories took %0.2f msec.", get_time(t))
    logging.vlog(1, "Padded Observations' shape [%s]",
                 str(padded_observations.shape))
    logging.vlog(1, "Padded Actions' shape [%s]", str(padded_actions.shape))
    logging.vlog(1, "Padded Rewards' shape [%s]", str(padded_rewards.shape))

    # Some assertions.
    B, T = padded_actions.shape  # pylint: disable=invalid-name
    assert (B, T) == padded_rewards.shape
    assert (B, T) == reward_mask.shape
    assert (B, T + 1) == padded_observations.shape[:2]
    assert ((B, T + 1) + env.observation_space.shape ==
            padded_observations.shape)

    # Linear annealing from epsilon to 0.0:
    # epsilon_schedule = epsilon if epochs == 1 else epsilon * (
    #     1.0 - (i / (epochs - 1)))

    # Constant epsilon.
    epsilon_schedule = epsilon

    # Compute value and ppo losses.
    cur_value_loss, cur_ppo_loss, cur_combined_loss = None, None, None
    if policy_and_value_net_apply is not None:
      logging.vlog(2, "Starting to compute P&V loss.")
      t = time.time()
      cur_combined_loss, cur_ppo_loss, cur_value_loss, _ = (
          combined_loss(
              policy_and_value_net_params,
              policy_and_value_net_params,
              policy_and_value_net_apply,
              padded_observations,
              padded_actions,
              padded_rewards,
              reward_mask,
              gamma=gamma,
              lambda_=lambda_,
              epsilon=epsilon_schedule,
              c1=c1,
              c2=c2))
      logging.vlog(
          1, "Calculating P&V loss [%10.2f(%10.2f, %10.2f)] took %0.2f msec.",
          cur_combined_loss, cur_value_loss, cur_ppo_loss, get_time(t))
    else:
      t = time.time()
      cur_value_loss = value_loss(
          value_net_apply,
          value_net_params,
          padded_observations,
          padded_rewards,
          reward_mask,
          gamma=gamma)

      logging.vlog(1, "Calculating value loss took %0.2f msec.", get_time(t))

      t = time.time()
      cur_ppo_loss = ppo_loss(
          policy_net_apply,
          policy_net_params,
          policy_net_params,
          value_net_apply,
          value_net_params,
          padded_observations,
          padded_actions,
          padded_rewards,
          reward_mask,
          gamma=gamma,
          lambda_=lambda_,
          epsilon=epsilon_schedule)
      logging.vlog(1, "Calculating PPO loss took %0.2f msec.", get_time(t))

    value_losses.append(cur_value_loss)
    ppo_objective.append(-1.0 * cur_ppo_loss)
    if cur_combined_loss:
      combined_losses.append(cur_combined_loss)

    if policy_and_value_net_apply:
      logging.vlog(1, "Policy and Value Optimization")
      t1 = time.time()

      for j in range(num_optimizer_steps):
        t = time.time()
        # Update the optimizer state.
        policy_and_value_opt_state = policy_and_value_opt_step(
            j,
            policy_and_value_opt_state,
            policy_and_value_opt_update,
            policy_and_value_net_apply,
            # For the entirety of this loop, this should refer to the params
            # that were used to collect the trajectory.
            policy_and_value_net_params,
            padded_observations,
            padded_actions,
            padded_rewards,
            reward_mask,
            c1=c1,
            c2=c2,
            gamma=gamma,
            lambda_=lambda_,
            epsilon=epsilon_schedule)
        t2 = time.time()
        if ((j + 1) % print_every_optimizer_steps == 0) or (
            j == num_optimizer_steps - 1):
          # Compute and log the loss.
          # Get the new params.
          new_policy_and_value_net_params = trax_opt.get_params(
              policy_and_value_opt_state)
          (loss_combined, loss_ppo, loss_value, unused_entropy_bonus) = (
              combined_loss(
                  new_policy_and_value_net_params,
                  # Old params, that were used to collect the trajectory.
                  policy_and_value_net_params,
                  policy_and_value_net_apply,
                  padded_observations,
                  padded_actions,
                  padded_rewards,
                  reward_mask,
                  gamma=gamma,
                  lambda_=lambda_,
                  epsilon=epsilon_schedule,
                  c1=c1,
                  c2=c2))
          logging.vlog(1, "One Policy and Value grad desc took: %0.2f msec",
                       get_time(t, t2))
          logging.vlog(
              1,
              "Combined Loss(value, ppo) [%10.2f] -> [%10.2f(%10.2f,%10.2f)]",
              cur_combined_loss, loss_combined, loss_value, loss_ppo)

      # Update the params only after all the optimization steps are done.
      policy_and_value_net_params = new_policy_and_value_net_params

      logging.vlog(1, "Total Combined Loss reduction [%0.2f]%%",
                   (100 * (cur_combined_loss - loss_combined) /
                    np.abs(cur_combined_loss)))

      logging.info(
          "Epoch [% 6d], Reward[min, max, avg] [%10.2f,%10.2f,%10.2f], Combined"
          " Loss(value, ppo) [%10.2f(%10.2f,%10.2f)], took [%10.2f msec]", i,
          min_reward, max_reward, avg_reward, loss_combined, loss_value,
          loss_ppo, get_time(t1))
    else:
      # Run optimizers.
      logging.vlog(1, "PPO Optimization")
      t1 = time.time()

      for j in range(policy_only_num_optimizer_steps):
        t = time.time()
        # Update the optimizer state.
        ppo_opt_state = ppo_opt_step(
            j,
            ppo_opt_state,
            ppo_opt_update,
            policy_net_apply,
            policy_net_params,
            value_net_apply,
            value_net_params,
            padded_observations,
            padded_actions,
            padded_rewards,
            reward_mask,
            gamma=gamma,
            lambda_=lambda_,
            epsilon=epsilon_schedule,
        )
        t2 = time.time()
        # Get the new params.
        new_policy_net_params = trax_opt.get_params(ppo_opt_state)

        # These are the "old" params - policy_net_params.

        # Compute the approx KL for early stopping.
        log_probab_actions_old = policy_net_apply(padded_observations,
                                                  policy_net_params)
        log_probab_actions_new = policy_net_apply(padded_observations,
                                                  new_policy_net_params)

        approx_kl = np.mean(log_probab_actions_old - log_probab_actions_new)

        early_stopping = approx_kl > 1.5 * target_kl
        if early_stopping:
          logging.vlog(
              1, "Early stopping policy optimization at iter: %d, "
              "with approx_kl: %0.2f", j, approx_kl)
          # We don't return right away, we want the below to execute on the
          # last iteration.

        if (((j + 1) % print_every_optimizer_steps == 0) or
            (j == num_optimizer_steps - 1) or early_stopping):
          new_ppo_loss = ppo_loss(
              policy_net_apply,
              new_policy_net_params,
              policy_net_params,
              value_net_apply,
              value_net_params,
              padded_observations,
              padded_actions,
              padded_rewards,
              reward_mask,
              gamma=gamma,
              lambda_=lambda_,
              epsilon=epsilon_schedule,
          )
          logging.vlog(1, "One PPO grad desc took: %0.2f msec",
                       get_time(t, t2))
          logging.vlog(1, "PPO loss [%10.2f] -> [%10.2f]", cur_ppo_loss,
                       new_ppo_loss)

        if early_stopping:
          break

      # Update the params ONLY AFTER we complete all the optimization
      # iterations; until then `policy_net_params` should refer to the params
      # that were used to collect the trajectory.
      # policy_net_params = trax_opt.get_params(ppo_opt_state)

      logging.vlog(
          1, "Total PPO loss reduction [%0.2f]%%",
          (100 * (cur_ppo_loss - new_ppo_loss) / np.abs(cur_ppo_loss)))

      logging.vlog(1, "Value Optimization")

      for j in range(value_only_num_optimizer_steps):
        t = time.time()
        value_opt_state = value_opt_step(
            j,
            value_opt_state,
            value_opt_update,
            value_net_apply,
            padded_observations,
            padded_rewards,
            reward_mask,
            gamma=gamma)
        t2 = time.time()
        value_net_params = trax_opt.get_params(value_opt_state)
        if ((j + 1) % print_every_optimizer_steps == 0) or (
            j == num_optimizer_steps - 1):
          new_value_loss = value_loss(
              value_net_apply,
              value_net_params,
              padded_observations,
              padded_rewards,
              reward_mask,
              gamma=gamma)
          logging.vlog(1, "One value grad desc took: %0.2f msec",
                       get_time(t, t2))
          logging.vlog(1, "Value loss [%10.2f] -> [%10.2f]", cur_value_loss,
                       new_value_loss)

      logging.vlog(
          1, "Total value loss reduction [%0.2f]%%",
          (100 * (cur_value_loss - new_value_loss) / np.abs(cur_value_loss)))

      logging.vlog(1, "Grad desc took %0.2f msec", get_time(t1))

      # Set the optimized params to new params.
      policy_net_params = trax_opt.get_params(ppo_opt_state)
      value_net_params = trax_opt.get_params(value_opt_state)

      logging.info(
          "Epoch [% 6d], Reward[min, max, avg] [%10.2f,%10.2f,%10.2f], "
          "ppo loss [%10.2f], value loss [%10.2f], took [%10.2f msec]", i,
          min_reward, max_reward, avg_reward, new_ppo_loss, new_value_loss,
          get_time(t0))

    # Log the parameters, just for the sake of it.
    if policy_net_params:
      log_params(policy_net_params, "policy_net_params")
    if value_net_params:
      log_params(value_net_params, "value_net_params")
    if policy_and_value_net_params:
      log_params(policy_and_value_net_params, "policy_and_value_net_params")

  if value_losses:
    logging.vlog(1, "value_losses: %s", np.stack(value_losses))
  if ppo_objective:
    logging.vlog(1, "ppo_objective:\n%s", np.stack(ppo_objective))
  if combined_losses:
    logging.vlog(1, "combined_losses:\n%s", np.stack(combined_losses))
  if average_rewards:
    logging.vlog(1, "average_rewards:\n%s", average_rewards)

  return ((policy_net_params, value_net_params), average_rewards,
          np.stack(value_losses), np.stack(ppo_objective))