def save_checkpoint(ckpt_dir, target, step, prefix='checkpoint_', keep=1):
  """Save a checkpoint of the model.

  Attempts to be pre-emption safe by writing to a temporary file before a
  final rename and cleanup of past files.

  Args:
    ckpt_dir: str: path to store checkpoint files in.
    target: serializable flax object, usually a flax optimizer.
    step: int or float: training step number or other metric number.
    prefix: str: checkpoint file name prefix.
    keep: number of past checkpoint files to keep.

  Returns:
    Filename of saved checkpoint.
  """
  # Write temporary checkpoint file.
  logging.info('Saving checkpoint at step: %s', step)
  ckpt_tmp_path = _checkpoint_path(ckpt_dir, 'tmp', prefix)
  ckpt_path = _checkpoint_path(ckpt_dir, step, prefix)
  gfile.makedirs(os.path.dirname(ckpt_path))
  with gfile.GFile(ckpt_tmp_path, 'wb') as fp:
    fp.write(serialization.to_bytes(target))

  # Rename only once serialization and writing have finished.
  gfile.rename(ckpt_tmp_path, ckpt_path)
  logging.info('Saved checkpoint at %s', ckpt_path)

  # Remove old checkpoint files.
  base_path = os.path.join(ckpt_dir, prefix)
  checkpoint_files = natural_sort(gfile.glob(base_path + '*'))
  if len(checkpoint_files) > keep:
    old_ckpts = checkpoint_files[:-keep]
    for path in old_ckpts:
      logging.info('Removing checkpoint at %s', path)
      gfile.remove(path)

  return ckpt_path
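# Usage sketch (illustrative, not from the original source): saving a training
# state every 1000 steps while keeping only the three most recent checkpoint
# files. The state dict and the checkpoint directory below are placeholders;
# any object that flax.serialization can serialize works as `target`.
def _example_save_loop():
  import numpy as _np
  state = {'weights': _np.zeros((2, 2)), 'step': 0}
  for step in range(0, 3000, 1000):
    state['step'] = step
    # Keep only the three most recent checkpoint files.
    save_checkpoint('/tmp/my_run/checkpoints', state, step, keep=3)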
def log_video(writer, video, tb_key, name, step, work_unit_dir,
              save_raw=False, scale=False):
  """Save video frames to tensorboard and a file."""
  video_raw = video
  if scale:
    video = scale_depth(video)
  if writer is not None:
    logging.info('Logging video frames')
    writer.write_images(step, {f'{tb_key}/{name}': make_image_grid(video)})
  filename = f'{tb_key}_{name}_{step:05d}.mp4'
  local_path = os.path.join('/tmp', filename)
  logging.info('Writing video to %s', local_path)
  media.write_video(local_path, video, fps=30)
  wu_path = os.path.join(work_unit_dir, filename)
  logging.info('Copying video to %s', wu_path)
  gfile.copy(local_path, wu_path, overwrite=True)
  gfile.remove(local_path)
  if save_raw:
    # Save raw floating point values so depth can be scaled properly later.
    raw_filename = f'{tb_key}_{name}_{step:05d}.npy'
    raw_path = os.path.join(work_unit_dir, raw_filename)
    logging.info('Saving raw video to %s', raw_path)
    with gfile.GFile(raw_path, 'wb') as raw_f:
      onp.save(raw_f, video_raw)
  logging.info('Done logging video.')
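# Usage sketch (illustrative): logging a short random video without a
# TensorBoard writer. `work_unit_dir` is a placeholder path assumed to exist;
# passing writer=None skips the image-grid summary and only writes the mp4
# (plus the raw .npy when save_raw=True).
def _example_log_video():
  video = onp.random.rand(16, 64, 64, 3).astype(onp.float32)  # (T, H, W, C)
  log_video(writer=None, video=video, tb_key='eval', name='depth',
            step=0, work_unit_dir='/tmp/my_work_unit', save_raw=True)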
def training_loop(
    env=None,
    epochs=EPOCHS,
    policy_and_value_net_fn=None,
    policy_and_value_optimizer_fn=None,
    batch_size=BATCH_TRAJECTORIES,
    n_optimizer_steps=N_OPTIMIZER_STEPS,
    print_every_optimizer_steps=PRINT_EVERY_OPTIMIZER_STEP,
    target_kl=0.01,
    boundary=20,
    max_timestep=None,
    max_timestep_eval=20000,
    random_seed=None,
    gamma=GAMMA,
    lambda_=LAMBDA,
    epsilon=EPSILON,
    c1=1.0,
    c2=0.01,
    output_dir=None,
    eval_every_n=1000,
    eval_env=None,
    done_frac_for_policy_save=0.5,
    enable_early_stopping=True,
    env_name=None,
    n_evals=1,
    len_history_for_policy=4,
):
  """Runs the training loop for PPO, with fixed policy and value nets."""
  assert env
  assert output_dir
  assert env_name

  gfile.makedirs(output_dir)

  # Create summary writers and history.
  train_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "train"))
  timing_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "timing"))
  eval_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "eval"))

  train_sw.text("env_name", env_name)
  timing_sw.text("env_name", env_name)
  eval_sw.text("env_name", env_name)

  jax_rng_key = trax.get_random_number_generator_and_set_seed(random_seed)

  # Batch Observations Shape = [1, 1] + OBS, because we will eventually call
  # policy and value networks on shape [B, T] + OBS.
  batch_observations_shape = (1, 1) + env.observation_space.shape
  observations_dtype = env.observation_space.dtype

  assert isinstance(env.action_space, gym.spaces.Discrete)
  n_actions = env.action_space.n

  jax_rng_key, key1 = jax_random.split(jax_rng_key, num=2)

  # Initialize the policy and value network.
  policy_and_value_net_params, policy_and_value_net_apply = (
      policy_and_value_net_fn(key1, batch_observations_shape,
                              observations_dtype, n_actions))

  # Maybe restore the policy params. If there is nothing to restore, then
  # iteration = 0 and policy_and_value_net_params are returned as is.
  restore, policy_and_value_net_params, iteration = (maybe_restore_params(
      output_dir, policy_and_value_net_params))

  if restore:
    logging.info("Restored parameters from iteration [%d]", iteration)
    # We should start from the next iteration.
    iteration += 1

  policy_and_value_net_apply = jit(policy_and_value_net_apply)

  # Initialize the optimizers.
  policy_and_value_optimizer = (
      policy_and_value_optimizer_fn(policy_and_value_net_params))
  (policy_and_value_opt_state, policy_and_value_opt_update,
   policy_and_value_get_params) = policy_and_value_optimizer

  n_trajectories_done = 0
  last_saved_at = 0

  logging.info("Starting the PPO training loop.")
  for i in range(iteration, epochs):
    epoch_start_time = time.time()

    # Params we'll use to collect the trajectories.
    policy_and_value_net_params = policy_and_value_get_params(
        policy_and_value_opt_state)

    # A function to get the policy and value predictions.
    def get_predictions(observations, rng=None):
      """Returns log-probs, value predictions and key back."""
      key, key1 = jax_random.split(rng, num=2)

      log_probs, value_preds = policy_and_value_net_apply(
          observations, policy_and_value_net_params, rng=key1)

      return log_probs, value_preds, key

    # Evaluate the policy.
    policy_eval_start_time = time.time()
    if ((i + 1) % eval_every_n == 0) or (i == epochs - 1):
      jax_rng_key, key = jax_random.split(jax_rng_key, num=2)

      logging.vlog(1, "Epoch [% 6d] evaluating policy.", i)

      avg_reward, avg_reward_unclipped = evaluate_policy(
          eval_env,
          get_predictions,
          max_timestep=max_timestep_eval,
          n_evals=n_evals,
          len_history_for_policy=len_history_for_policy,
          rng=key)
      for k, v in avg_reward.items():
        eval_sw.scalar("eval/mean_reward/%s" % k, v, step=i)
        logging.info("Epoch [% 6d] Policy Evaluation (clipped) [%s] = %10.2f",
                     i, k, v)
      for k, v in avg_reward_unclipped.items():
        eval_sw.scalar("eval/mean_reward_unclipped/%s" % k, v, step=i)
        logging.info(
            "Epoch [% 6d] Policy Evaluation (unclipped) [%s] = %10.2f", i, k,
            v)
    policy_eval_time = get_time(policy_eval_start_time)

    trajectory_collection_start_time = time.time()
    logging.vlog(1, "Epoch [% 6d] collecting trajectories.", i)
    jax_rng_key, key = jax_random.split(jax_rng_key)
    trajs, n_done, timing_info = collect_trajectories(
        env,
        policy_fn=get_predictions,
        n_trajectories=batch_size,
        max_timestep=max_timestep,
        rng=key,
        len_history_for_policy=len_history_for_policy,
        reset=(i == 0) or restore,
        epsilon=(10.0 / (i + 10.0)))  # this is a different epsilon.
    trajectory_collection_time = get_time(trajectory_collection_start_time)

    logging.vlog(1, "Collecting trajectories took %0.2f msec.",
                 trajectory_collection_time)

    avg_reward = float(sum(np.sum(traj[2]) for traj in trajs)) / len(trajs)
    max_reward = max(np.sum(traj[2]) for traj in trajs)
    min_reward = min(np.sum(traj[2]) for traj in trajs)

    train_sw.scalar("train/mean_reward", avg_reward, step=i)

    logging.vlog(1, "Rewards avg=[%0.2f], max=[%0.2f], min=[%0.2f], all=%s",
                 avg_reward, max_reward, min_reward,
                 [float(np.sum(traj[2])) for traj in trajs])

    logging.vlog(1,
                 "Trajectory Length average=[%0.2f], max=[%0.2f], min=[%0.2f]",
                 float(sum(len(traj[0]) for traj in trajs)) / len(trajs),
                 max(len(traj[0]) for traj in trajs),
                 min(len(traj[0]) for traj in trajs))
    logging.vlog(2, "Trajectory Lengths: %s", [len(traj[0]) for traj in trajs])

    padding_start_time = time.time()
    (_, reward_mask, padded_observations, padded_actions,
     padded_rewards) = pad_trajectories(trajs, boundary=boundary)
    padding_time = get_time(padding_start_time)

    logging.vlog(1, "Padding trajectories took %0.2f msec.",
                 get_time(padding_start_time))
    logging.vlog(1, "Padded Observations' shape [%s]",
                 str(padded_observations.shape))
    logging.vlog(1, "Padded Actions' shape [%s]", str(padded_actions.shape))
    logging.vlog(1, "Padded Rewards' shape [%s]", str(padded_rewards.shape))

    # Calculate log-probabilities and value predictions of the trajectories.
    # We'll pass these to the loss functions so that they don't get recomputed.
    # NOTE: There is a slight problem here: if the policy network contains
    # stochasticity in the log-probabilities (e.g. dropout), then calculating
    # these again here is not going to be correct and should be done in the
    # collect function.
    log_prob_recompute_start_time = time.time()
    jax_rng_key, key = jax_random.split(jax_rng_key)
    log_probabs_traj, value_predictions_traj, _ = get_predictions(
        padded_observations, rng=key)
    log_prob_recompute_time = get_time(log_prob_recompute_start_time)

    # Some assertions.
    B, T = padded_actions.shape  # pylint: disable=invalid-name
    assert (B, T) == padded_rewards.shape
    assert (B, T) == reward_mask.shape
    assert (B, T + 1) == padded_observations.shape[:2]
    assert (B, T + 1) + env.observation_space.shape == padded_observations.shape

    # Linear annealing from 0.1 to 0.0
    # epsilon_schedule = epsilon if epochs == 1 else epsilon * (1.0 -
    #                                                           (i /
    #                                                            (epochs - 1)))

    # Constant epsilon.
    epsilon_schedule = epsilon

    # Compute value and ppo losses.
    jax_rng_key, key1 = jax_random.split(jax_rng_key, num=2)
    logging.vlog(2, "Starting to compute P&V loss.")
    loss_compute_start_time = time.time()
    cur_combined_loss, cur_ppo_loss, cur_value_loss, entropy_bonus = (
        combined_loss(
            policy_and_value_net_params,
            log_probabs_traj,
            value_predictions_traj,
            policy_and_value_net_apply,
            padded_observations,
            padded_actions,
            padded_rewards,
            reward_mask,
            gamma=gamma,
            lambda_=lambda_,
            epsilon=epsilon_schedule,
            c1=c1,
            c2=c2,
            rng=key1))
    loss_compute_time = get_time(loss_compute_start_time)
    logging.vlog(
        1,
        "Calculating P&V loss [%10.2f(%10.2f, %10.2f, %10.2f)] took %0.2f msec.",
        cur_combined_loss, cur_value_loss, cur_ppo_loss, entropy_bonus,
        get_time(loss_compute_start_time))

    jax_rng_key, key1 = jax_random.split(jax_rng_key, num=2)
    logging.vlog(1, "Policy and Value Optimization")
    optimization_start_time = time.time()
    keys = jax_random.split(key1, num=n_optimizer_steps)
    for j in range(n_optimizer_steps):
      k1, k2, k3 = jax_random.split(keys[j], num=3)
      t = time.time()
      # Update the optimizer state.
      policy_and_value_opt_state = policy_and_value_opt_step(
          j,
          policy_and_value_opt_state,
          policy_and_value_opt_update,
          policy_and_value_get_params,
          policy_and_value_net_apply,
          log_probabs_traj,
          value_predictions_traj,
          padded_observations,
          padded_actions,
          padded_rewards,
          reward_mask,
          c1=c1,
          c2=c2,
          gamma=gamma,
          lambda_=lambda_,
          epsilon=epsilon_schedule,
          rng=k1)

      # Compute the approx KL for early stopping.
      new_policy_and_value_net_params = policy_and_value_get_params(
          policy_and_value_opt_state)

      log_probab_actions_new, _ = policy_and_value_net_apply(
          padded_observations, new_policy_and_value_net_params, rng=k2)

      approx_kl = approximate_kl(log_probab_actions_new, log_probabs_traj,
                                 reward_mask)

      early_stopping = enable_early_stopping and approx_kl > 1.5 * target_kl
      if early_stopping:
        logging.vlog(
            1, "Early stopping policy and value optimization at iter: %d, "
            "with approx_kl: %0.2f", j, approx_kl)
        # We don't return right away; we want the code below to execute on the
        # last iteration.

      t2 = time.time()
      if (((j + 1) % print_every_optimizer_steps == 0) or
          (j == n_optimizer_steps - 1) or early_stopping):
        # Compute and log the loss.
        (loss_combined, loss_ppo, loss_value, entropy_bonus) = (
            combined_loss(
                new_policy_and_value_net_params,
                log_probabs_traj,
                value_predictions_traj,
                policy_and_value_net_apply,
                padded_observations,
                padded_actions,
                padded_rewards,
                reward_mask,
                gamma=gamma,
                lambda_=lambda_,
                epsilon=epsilon_schedule,
                c1=c1,
                c2=c2,
                rng=k3))
        logging.vlog(1, "One Policy and Value grad desc took: %0.2f msec",
                     get_time(t, t2))
        logging.vlog(1, "Combined Loss(value, ppo, entropy_bonus) [%10.2f] ->"
                     " [%10.2f(%10.2f,%10.2f,%10.2f)]", cur_combined_loss,
                     loss_combined, loss_value, loss_ppo, entropy_bonus)

      if early_stopping:
        break

    optimization_time = get_time(optimization_start_time)

    logging.vlog(
        1, "Total Combined Loss reduction [%0.2f]%%",
        (100 * (cur_combined_loss - loss_combined) / np.abs(cur_combined_loss)))

    # Save parameters every time we see the end of at least a fraction of batch
    # number of trajectories that are done (not completed -- completed includes
    # truncated and done).
    # Also don't save too frequently, enforce a minimum gap.
    # Or if this is the last iteration.
    policy_save_start_time = time.time()
    n_trajectories_done += n_done
    # TODO(afrozm): Refactor to trax.save_state.
    if (((n_trajectories_done >= done_frac_for_policy_save * batch_size) and
         (i - last_saved_at > eval_every_n) and
         (((i + 1) % eval_every_n == 0))) or (i == epochs - 1)):
      logging.vlog(1, "Epoch [% 6d] saving model.", i)
      old_model_files = gfile.glob(os.path.join(output_dir, "model-??????.pkl"))
      params_file = os.path.join(output_dir, "model-%06d.pkl" % i)
      with gfile.GFile(params_file, "wb") as f:
        pickle.dump(policy_and_value_net_params, f)
      # Remove the old model files.
      for path in old_model_files:
        gfile.remove(path)
      # Reset this number.
      n_trajectories_done = 0
      last_saved_at = i
    policy_save_time = get_time(policy_save_start_time)

    epoch_time = get_time(epoch_start_time)

    logging.info(
        "Epoch [% 6d], Reward[min, max, avg] [%5.2f,%5.2f,%5.2f], Combined"
        " Loss(value, ppo, entropy) [%2.5f(%2.5f,%2.5f,%2.5f)]", i, min_reward,
        max_reward, avg_reward, loss_combined, loss_value, loss_ppo,
        entropy_bonus)

    timing_dict = {
        "epoch": epoch_time,
        "policy_eval": policy_eval_time,
        "trajectory_collection": trajectory_collection_time,
        "padding": padding_time,
        "log_prob_recompute": log_prob_recompute_time,
        "loss_compute": loss_compute_time,
        "optimization": optimization_time,
        "policy_save": policy_save_time,
    }

    timing_dict.update(timing_info)

    for k, v in timing_dict.items():
      timing_sw.scalar("timing/%s" % k, v, step=i)

    max_key_len = max(len(k) for k in timing_dict)
    timing_info_list = [
        "%s : % 10.2f" % (k.rjust(max_key_len + 1), v)
        for k, v in sorted(timing_dict.items())
    ]
    logging.info("Epoch [% 6d], Timings: \n%s", i, "\n".join(timing_info_list))

    # Reset restore.
    restore = False

    # Flush summary writers once in a while.
    if (i + 1) % 1000 == 0 or i == epochs - 1:
      train_sw.flush()
      timing_sw.flush()
      eval_sw.flush()
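# A minimal NumPy sketch (not the approximate_kl helper used above) of the
# masked approximate-KL test that drives the early stopping in the inner
# optimizer loop: optimization stops once the mean divergence between old and
# new policy log-probs over non-padded steps exceeds 1.5 * target_kl.
def _example_masked_approx_kl(log_probs_new, log_probs_old, mask):
  import numpy as _np  # local import so the sketch is self-contained
  # log_probs_*: (B, T, A) log-probabilities; mask: (B, T), 1 = valid step.
  diff = (log_probs_old - log_probs_new) * mask[:, :, None]
  return float(_np.sum(diff) / _np.sum(mask))

# Example early-stopping decision with an assumed target_kl of 0.01:
#   stop = _example_masked_approx_kl(new_lp, old_lp, mask) > 1.5 * 0.01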
def save_checkpoint(ckpt_dir: Union[str, os.PathLike],
                    target,
                    step,
                    prefix='checkpoint_',
                    keep=1,
                    overwrite=False):
  """Save a checkpoint of the model.

  Attempts to be pre-emption safe by writing to a temporary file before a
  final rename and cleanup of past files.

  Args:
    ckpt_dir: str or pathlib-like path to store checkpoint files in.
    target: serializable flax object, usually a flax optimizer.
    step: int or float: training step number or other metric number.
    prefix: str: checkpoint file name prefix.
    keep: number of past checkpoint files to keep.
    overwrite: overwrite existing checkpoint files if a checkpoint at the
      current or a later step already exists (default: False).

  Returns:
    Filename of saved checkpoint.
  """
  ckpt_dir = os.fspath(ckpt_dir)  # Pathlib -> str
  # Write temporary checkpoint file.
  logging.info('Saving checkpoint at step: %s', step)
  if ckpt_dir.startswith('./'):
    ckpt_dir = ckpt_dir[2:]  # gfile.glob() can remove the leading './'.
  ckpt_tmp_path = _checkpoint_path(ckpt_dir, 'tmp', prefix)
  ckpt_path = _checkpoint_path(ckpt_dir, step, prefix)
  gfile.makedirs(os.path.dirname(ckpt_path))
  base_path = os.path.join(ckpt_dir, prefix)
  checkpoint_files = gfile.glob(base_path + '*')

  if ckpt_path in checkpoint_files:
    if not overwrite:
      raise errors.InvalidCheckpointError(ckpt_path, step)
  else:
    checkpoint_files.append(ckpt_path)

  checkpoint_files = natural_sort(checkpoint_files)
  if checkpoint_files[-1] == ckpt_tmp_path:
    checkpoint_files.pop(-1)
  if ckpt_path != checkpoint_files[-1]:
    if not overwrite:
      raise errors.InvalidCheckpointError(ckpt_path, step)

  with gfile.GFile(ckpt_tmp_path, 'wb') as fp:
    fp.write(serialization.to_bytes(target))

  # Rename once serialization and writing have finished.
  gfile.rename(ckpt_tmp_path, ckpt_path, overwrite=overwrite)
  logging.info('Saved checkpoint at %s', ckpt_path)
  print(ckpt_path)

  # Remove newer checkpoints.
  if overwrite:
    ind = checkpoint_files.index(ckpt_path) + 1
    newer_ckpts = checkpoint_files[ind:]
    checkpoint_files = checkpoint_files[:ind]
    for path in newer_ckpts:
      logging.info('Removing checkpoint at %s', path)
      gfile.remove(path)

  # Remove old checkpoint files.
  if len(checkpoint_files) > keep:
    old_ckpts = checkpoint_files[:-keep]
    for path in old_ckpts:
      logging.info('Removing checkpoint at %s', path)
      gfile.remove(path)

  return ckpt_path
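# Usage sketch (illustrative, placeholder paths and states): rewinding to an
# earlier step. Saving at step 10 when a step-10 or step-20 checkpoint already
# exists raises errors.InvalidCheckpointError; with overwrite=True the newer
# files are removed and step 10 becomes the latest checkpoint.
def _example_overwrite_semantics(state_at_10, state_at_20):
  ckpt_dir = '/tmp/my_run/checkpoints'
  save_checkpoint(ckpt_dir, state_at_10, step=10)
  save_checkpoint(ckpt_dir, state_at_20, step=20)
  try:
    save_checkpoint(ckpt_dir, state_at_10, step=10)  # refuses to go backwards
  except errors.InvalidCheckpointError:
    save_checkpoint(ckpt_dir, state_at_10, step=10, overwrite=True)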
def run(workdir,
        data,
        strategy,
        architecture,
        n_layers,
        n_hiddens,
        activation,
        dropout_rate,
        l2_penalty,
        w_init_name,
        b_init_name,
        optimizer_name,
        learning_rate,
        n_epochs,
        epochs_between_checkpoints,
        init_stddev,
        cnn_stride,
        reduce_learningrate=False,
        verbosity=0):
  """Runs the whole training procedure."""
  data_tr, data_te, dataset_info = data
  n_outputs = dataset_info['num_classes']

  with strategy.scope():
    optimizer = tf.keras.optimizers.get(optimizer_name)
    optimizer.learning_rate = learning_rate

    w_init = tf.keras.initializers.get(w_init_name)
    if w_init_name.lower() in ['truncatednormal', 'randomnormal']:
      w_init.stddev = init_stddev
    b_init = tf.keras.initializers.get(b_init_name)
    if b_init_name.lower() in ['truncatednormal', 'randomnormal']:
      b_init.stddev = init_stddev
    w_reg = tf.keras.regularizers.l2(l2_penalty) if l2_penalty > 0 else None

    if architecture == 'cnn' or architecture == 'cnnbn':
      model = build_cnn(n_layers, n_hiddens, n_outputs, dropout_rate,
                        activation, cnn_stride, w_reg, w_init, b_init,
                        architecture == 'cnnbn')
    elif architecture == 'fcn':
      model = build_fcn(n_layers, n_hiddens, n_outputs, dropout_rate,
                        activation, w_reg, w_init, b_init, False)
    else:
      assert False, 'Unknown architecture: %s' % architecture

    model.compile(
        optimizer=optimizer,
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy', 'mse', 'sparse_categorical_crossentropy'])

    # Force the model to set input shapes and init weights.
    for x, _ in data_tr:
      model.predict(x)
      if verbosity:
        model.summary()
      break

  ckpt = tf.train.Checkpoint(
      step=optimizer.iterations, optimizer=optimizer, model=model)
  ckpt_dir = os.path.join(workdir, 'temporary-ckpt')
  ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3)
  if ckpt_manager.latest_checkpoint:
    logging.info('restoring checkpoint: %s', ckpt_manager.latest_checkpoint)
    print('restoring from %s' % ckpt_manager.latest_checkpoint)
    with strategy.scope():
      ckpt.restore(ckpt_manager.latest_checkpoint)
    info = restore_results(os.path.join(workdir, '.intermediate-results.json'))
    print(info, flush=True)
  else:
    info = {
        'steps': 0,
        'start_time': time.time(),
        'train_loss': dict(),
        'train_accuracy': dict(),
        'test_loss': dict(),
        'test_accuracy': dict(),
    }
    info.update(_get_workunit_params())  # Add command line parameters.

  logger = None
  starting_epoch = len(info['train_loss'])
  cur_epoch = starting_epoch
  for cur_epoch in range(starting_epoch, n_epochs):
    if reduce_learningrate and cur_epoch == n_epochs - (n_epochs // 10):
      optimizer.learning_rate = learning_rate / 10
    elif reduce_learningrate and cur_epoch == n_epochs - 2:
      optimizer.learning_rate = learning_rate / 100

    # Train until we reach the criterion or get NaNs.
    try:
      # Always keep checkpoints for the first few epochs. We evaluate first
      # and train afterwards, so we also capture the at-initialization data.
      if cur_epoch < 4 or (cur_epoch % epochs_between_checkpoints) == 0:
        eval_model(model, data_tr, data_te, info, logger, cur_epoch, workdir)

      model.fit(data_tr, epochs=1, verbose=verbosity)
      ckpt_manager.save()
      store_results(info, os.path.join(workdir, '.intermediate-results.json'))

      dt = time.time() - info['start_time']
      logging.info('epoch %d (%3.2fs)', cur_epoch, dt)
    except tf.errors.InvalidArgumentError as e:
      # We got NaN in the loss; most likely the gradients produced NaNs.
      logging.info(str(e))
      info['status'] = 'NaN'
      logging.info('Stop training because NaNs encountered')
      break

  eval_model(model, data_tr, data_te, info, logger, cur_epoch + 1, workdir)
  store_results(info, os.path.join(workdir, 'results.json'))

  # We don't need the temporary checkpoints anymore.
  gfile.rmtree(os.path.join(workdir, 'temporary-ckpt'))
  gfile.remove(os.path.join(workdir, '.intermediate-results.json'))
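# A minimal standalone sketch (placeholder names and paths) of the
# resume-from-checkpoint pattern used in run() above: bundle the optimizer
# step counter, optimizer and model into a tf.train.Checkpoint, keep a few
# recent copies via tf.train.CheckpointManager, and restore the latest one
# when the job restarts after pre-emption.
def _example_checkpoint_resume(workdir='/tmp/my_workdir'):
  import tensorflow as tf
  model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
  optimizer = tf.keras.optimizers.SGD()
  ckpt = tf.train.Checkpoint(
      step=optimizer.iterations, optimizer=optimizer, model=model)
  manager = tf.train.CheckpointManager(
      ckpt, os.path.join(workdir, 'temporary-ckpt'), max_to_keep=3)
  if manager.latest_checkpoint:
    ckpt.restore(manager.latest_checkpoint)  # resume where we left off
  # ... train for an epoch ...
  manager.save()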
def training_loop(
    env,
    eval_env,
    env_name,
    policy_and_value_net_fn,
    policy_and_value_optimizer_fn,
    output_dir,
    epochs=EPOCHS,
    n_optimizer_steps=N_OPTIMIZER_STEPS,
    print_every_optimizer_steps=PRINT_EVERY_OPTIMIZER_STEP,
    target_kl=0.01,
    boundary=20,
    max_timestep=None,
    max_timestep_eval=20000,
    random_seed=None,
    gamma=GAMMA,
    lambda_=LAMBDA,
    epsilon=EPSILON,
    c1=1.0,
    c2=0.01,
    eval_every_n=1000,
    done_frac_for_policy_save=0.5,
    enable_early_stopping=True,
    n_evals=1,
    len_history_for_policy=4,
    eval_temperatures=(1.0, 0.5),
):
  """Runs the training loop for PPO, with fixed policy and value nets.

  Args:
    env: gym.Env to use for training.
    eval_env: gym.Env to use for evaluation.
    env_name: Name of the environment.
    policy_and_value_net_fn: Function defining the policy and value network.
    policy_and_value_optimizer_fn: Function defining the optimizer.
    output_dir: Output dir.
    epochs: Number of epochs to run for.
    n_optimizer_steps: Number of optimizer steps.
    print_every_optimizer_steps: How often to log during the policy
      optimization process.
    target_kl: Policy iteration early stopping criterion.
    boundary: We pad trajectories at integer multiples of this number.
    max_timestep: If set to an integer, maximum number of time-steps in a
      trajectory. Used in the collect procedure.
    max_timestep_eval: If set to an integer, maximum number of time-steps in an
      evaluation trajectory. Used in the collect procedure.
    random_seed: Random seed.
    gamma: Reward discount factor.
    lambda_: N-step TD-error discount factor in GAE.
    epsilon: Random action probability in epsilon-greedy sampling.
    c1: Value loss coefficient.
    c2: Entropy loss coefficient.
    eval_every_n: How frequently to evaluate the policy.
    done_frac_for_policy_save: Fraction of the trajectories that should be done
      to checkpoint the policy.
    enable_early_stopping: Whether to enable early stopping.
    n_evals: Number of times to evaluate.
    len_history_for_policy: How much of the history to give to the policy.
    eval_temperatures: Sequence of temperatures to try for categorical sampling
      during evaluation.
  """
  gfile.makedirs(output_dir)

  # Create summary writers and history.
  train_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "train"))
  timing_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "timing"))
  eval_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "eval"))

  train_sw.text("env_name", env_name)
  timing_sw.text("env_name", env_name)
  eval_sw.text("env_name", env_name)

  jax_rng_key = trax.get_random_number_generator_and_set_seed(random_seed)

  # Batch Observations Shape = [1, 1] + OBS, because we will eventually call
  # policy and value networks on shape [B, T] + OBS.
  batch_observations_shape = (1, 1) + env.observation_space.shape
  observations_dtype = env.observation_space.dtype

  assert isinstance(env.action_space, gym.spaces.Discrete)
  n_actions = env.action_space.n

  jax_rng_key, key1 = jax_random.split(jax_rng_key, num=2)

  # Initialize the policy and value network.
  policy_and_value_net_params, policy_and_value_net_apply = (
      policy_and_value_net_fn(key1, batch_observations_shape,
                              observations_dtype, n_actions))

  # Maybe restore the policy params. If there is nothing to restore, then
  # iteration = 0 and policy_and_value_net_params are returned as is.
  restore, policy_and_value_net_params, iteration = (maybe_restore_params(
      output_dir, policy_and_value_net_params))

  if restore:
    logging.info("Restored parameters from iteration [%d]", iteration)
    # We should start from the next iteration.
    iteration += 1

  policy_and_value_net_apply = jit(policy_and_value_net_apply)

  # Initialize the optimizers.
  policy_and_value_optimizer = (
      policy_and_value_optimizer_fn(policy_and_value_net_params))
  (policy_and_value_opt_state, policy_and_value_opt_update,
   policy_and_value_get_params) = policy_and_value_optimizer

  n_trajectories_done = 0
  last_saved_at = 0

  logging.info("Starting the PPO training loop.")
  for i in range(iteration, epochs):
    epoch_start_time = time.time()

    # Params we'll use to collect the trajectories.
    policy_and_value_net_params = policy_and_value_get_params(
        policy_and_value_opt_state)

    # A function to get the policy and value predictions.
    def get_predictions(observations, rng=None):
      """Returns log-probs, value predictions and key back."""
      key, key1 = jax_random.split(rng, num=2)

      log_probs, value_preds = policy_and_value_net_apply(
          observations, policy_and_value_net_params, rng=key1)

      return log_probs, value_preds, key

    # Evaluate the policy.
    policy_eval_start_time = time.time()
    if ((i + 1) % eval_every_n == 0) or (i == epochs - 1):
      jax_rng_key, key = jax_random.split(jax_rng_key, num=2)

      logging.vlog(1, "Epoch [% 6d] evaluating policy.", i)

      reward_stats = evaluate_policy(
          eval_env,
          get_predictions,
          temperatures=eval_temperatures,
          max_timestep=max_timestep_eval,
          n_evals=n_evals,
          len_history_for_policy=len_history_for_policy,
          rng=key)
      write_eval_reward_summaries(reward_stats, eval_sw, epoch=i)
    policy_eval_time = get_time(policy_eval_start_time)

    trajectory_collection_start_time = time.time()
    logging.vlog(1, "Epoch [% 6d] collecting trajectories.", i)
    jax_rng_key, key = jax_random.split(jax_rng_key)
    trajs, n_done, timing_info = collect_trajectories(
        env,
        policy_fn=get_predictions,
        n_trajectories=env.batch_size,
        max_timestep=max_timestep,
        rng=key,
        len_history_for_policy=len_history_for_policy,
        reset=(i == 0) or restore,
        epsilon=(10.0 / (i + 10.0)))  # this is a different epsilon.
    trajectory_collection_time = get_time(trajectory_collection_start_time)

    logging.vlog(1, "Collecting trajectories took %0.2f msec.",
                 trajectory_collection_time)

    avg_reward = float(sum(np.sum(traj[2]) for traj in trajs)) / len(trajs)
    max_reward = max(np.sum(traj[2]) for traj in trajs)
    min_reward = min(np.sum(traj[2]) for traj in trajs)

    train_sw.scalar("train/reward_mean_truncated", avg_reward, step=i)

    logging.vlog(1, "Rewards avg=[%0.2f], max=[%0.2f], min=[%0.2f], all=%s",
                 avg_reward, max_reward, min_reward,
                 [float(np.sum(traj[2])) for traj in trajs])

    logging.vlog(1,
                 "Trajectory Length average=[%0.2f], max=[%0.2f], min=[%0.2f]",
                 float(sum(len(traj[0]) for traj in trajs)) / len(trajs),
                 max(len(traj[0]) for traj in trajs),
                 min(len(traj[0]) for traj in trajs))
    logging.vlog(2, "Trajectory Lengths: %s", [len(traj[0]) for traj in trajs])

    padding_start_time = time.time()
    (_, reward_mask, padded_observations, padded_actions, padded_rewards,
     padded_infos) = pad_trajectories(trajs, boundary=boundary)
    padding_time = get_time(padding_start_time)

    logging.vlog(1, "Padding trajectories took %0.2f msec.",
                 get_time(padding_start_time))
    logging.vlog(1, "Padded Observations' shape [%s]",
                 str(padded_observations.shape))
    logging.vlog(1, "Padded Actions' shape [%s]", str(padded_actions.shape))
    logging.vlog(1, "Padded Rewards' shape [%s]", str(padded_rewards.shape))

    # Some assertions.
    B, T = padded_actions.shape  # pylint: disable=invalid-name
    assert (B, T) == padded_rewards.shape
    assert (B, T) == reward_mask.shape
    assert (B, T + 1) == padded_observations.shape[:2]
    assert (B, T + 1) + env.observation_space.shape == padded_observations.shape

    log_prob_recompute_start_time = time.time()
    assert ("log_prob_actions" in padded_infos and
            "value_predictions" in padded_infos)
    # These are the actual log-probabs and value predictions seen while picking
    # the actions.
    actual_log_probabs_traj = padded_infos["log_prob_actions"]
    actual_value_predictions_traj = padded_infos["value_predictions"]

    assert (B, T) == actual_log_probabs_traj.shape[:2]
    A = actual_log_probabs_traj.shape[2]  # pylint: disable=invalid-name
    assert (B, T, 1) == actual_value_predictions_traj.shape

    # TODO(afrozm): log-probabs doesn't need to be (B, T+1, A), it can do with
    # (B, T, A), so make that change throughout.

    # NOTE: We don't have the log-probabs and value-predictions for the last
    # observation, so we re-calculate for everything, but use the original ones
    # for all but the last time-step.
    jax_rng_key, key = jax_random.split(jax_rng_key)
    log_probabs_traj, value_predictions_traj, _ = get_predictions(
        padded_observations, rng=key)

    assert (B, T + 1, A) == log_probabs_traj.shape
    assert (B, T + 1, 1) == value_predictions_traj.shape

    # Concatenate the last time-step's log-probabs and value predictions to the
    # actual log-probabs and value predictions and use those going forward.
    log_probabs_traj = np.concatenate(
        (actual_log_probabs_traj, log_probabs_traj[:, -1:, :]), axis=1)
    value_predictions_traj = np.concatenate(
        (actual_value_predictions_traj, value_predictions_traj[:, -1:, :]),
        axis=1)

    log_prob_recompute_time = get_time(log_prob_recompute_start_time)

    # Linear annealing from 0.1 to 0.0
    # epsilon_schedule = epsilon if epochs == 1 else epsilon * (1.0 -
    #                                                           (i /
    #                                                            (epochs - 1)))

    # Constant epsilon.
    epsilon_schedule = epsilon

    # Compute value and ppo losses.
    jax_rng_key, key1 = jax_random.split(jax_rng_key, num=2)
    logging.vlog(2, "Starting to compute P&V loss.")
    loss_compute_start_time = time.time()
    cur_combined_loss, cur_ppo_loss, cur_value_loss, entropy_bonus = (
        combined_loss(
            policy_and_value_net_params,
            log_probabs_traj,
            value_predictions_traj,
            policy_and_value_net_apply,
            padded_observations,
            padded_actions,
            padded_rewards,
            reward_mask,
            gamma=gamma,
            lambda_=lambda_,
            epsilon=epsilon_schedule,
            c1=c1,
            c2=c2,
            rng=key1))
    loss_compute_time = get_time(loss_compute_start_time)
    logging.vlog(
        1,
        "Calculating P&V loss [%10.2f(%10.2f, %10.2f, %10.2f)] took %0.2f msec.",
        cur_combined_loss, cur_value_loss, cur_ppo_loss, entropy_bonus,
        get_time(loss_compute_start_time))

    jax_rng_key, key1 = jax_random.split(jax_rng_key, num=2)
    logging.vlog(1, "Policy and Value Optimization")
    optimization_start_time = time.time()
    keys = jax_random.split(key1, num=n_optimizer_steps)
    for j in range(n_optimizer_steps):
      k1, k2, k3 = jax_random.split(keys[j], num=3)
      t = time.time()
      # Update the optimizer state.
      policy_and_value_opt_state = policy_and_value_opt_step(
          j,
          policy_and_value_opt_state,
          policy_and_value_opt_update,
          policy_and_value_get_params,
          policy_and_value_net_apply,
          log_probabs_traj,
          value_predictions_traj,
          padded_observations,
          padded_actions,
          padded_rewards,
          reward_mask,
          c1=c1,
          c2=c2,
          gamma=gamma,
          lambda_=lambda_,
          epsilon=epsilon_schedule,
          rng=k1)

      # Compute the approx KL for early stopping.
      new_policy_and_value_net_params = policy_and_value_get_params(
          policy_and_value_opt_state)

      log_probab_actions_new, _ = policy_and_value_net_apply(
          padded_observations, new_policy_and_value_net_params, rng=k2)

      approx_kl = approximate_kl(log_probab_actions_new, log_probabs_traj,
                                 reward_mask)

      early_stopping = enable_early_stopping and approx_kl > 1.5 * target_kl
      if early_stopping:
        logging.vlog(
            1, "Early stopping policy and value optimization at iter: %d, "
            "with approx_kl: %0.2f", j, approx_kl)
        # We don't return right away; we want the code below to execute on the
        # last iteration.

      t2 = time.time()
      if (((j + 1) % print_every_optimizer_steps == 0) or
          (j == n_optimizer_steps - 1) or early_stopping):
        # Compute and log the loss.
        (loss_combined, loss_ppo, loss_value, entropy_bonus) = (
            combined_loss(
                new_policy_and_value_net_params,
                log_probabs_traj,
                value_predictions_traj,
                policy_and_value_net_apply,
                padded_observations,
                padded_actions,
                padded_rewards,
                reward_mask,
                gamma=gamma,
                lambda_=lambda_,
                epsilon=epsilon_schedule,
                c1=c1,
                c2=c2,
                rng=k3))
        logging.vlog(1, "One Policy and Value grad desc took: %0.2f msec",
                     get_time(t, t2))
        logging.vlog(1, "Combined Loss(value, ppo, entropy_bonus) [%10.2f] ->"
                     " [%10.2f(%10.2f,%10.2f,%10.2f)]", cur_combined_loss,
                     loss_combined, loss_value, loss_ppo, entropy_bonus)

      if early_stopping:
        break

    optimization_time = get_time(optimization_start_time)

    logging.vlog(
        1, "Total Combined Loss reduction [%0.2f]%%",
        (100 * (cur_combined_loss - loss_combined) / np.abs(cur_combined_loss)))

    # Save parameters every time we see the end of at least a fraction of batch
    # number of trajectories that are done (not completed -- completed includes
    # truncated and done).
    # Also don't save too frequently, enforce a minimum gap.
    # Or if this is the last iteration.
    policy_save_start_time = time.time()
    n_trajectories_done += n_done
    # TODO(afrozm): Refactor to trax.save_state.
    if (((n_trajectories_done >= done_frac_for_policy_save * env.batch_size)
         and (i - last_saved_at > eval_every_n) and
         (((i + 1) % eval_every_n == 0))) or (i == epochs - 1)):
      logging.vlog(1, "Epoch [% 6d] saving model.", i)
      old_model_files = gfile.glob(os.path.join(output_dir, "model-??????.pkl"))
      params_file = os.path.join(output_dir, "model-%06d.pkl" % i)
      with gfile.GFile(params_file, "wb") as f:
        pickle.dump(policy_and_value_net_params, f)
      # Remove the old model files.
      for path in old_model_files:
        gfile.remove(path)
      # Reset this number.
      n_trajectories_done = 0
      last_saved_at = i
    policy_save_time = get_time(policy_save_start_time)

    epoch_time = get_time(epoch_start_time)

    logging.info(
        "Epoch [% 6d], Reward[min, max, avg] [%5.2f,%5.2f,%5.2f], Combined"
        " Loss(value, ppo, entropy) [%2.5f(%2.5f,%2.5f,%2.5f)]", i, min_reward,
        max_reward, avg_reward, loss_combined, loss_value, loss_ppo,
        entropy_bonus)

    timing_dict = {
        "epoch": epoch_time,
        "policy_eval": policy_eval_time,
        "trajectory_collection": trajectory_collection_time,
        "padding": padding_time,
        "log_prob_recompute": log_prob_recompute_time,
        "loss_compute": loss_compute_time,
        "optimization": optimization_time,
        "policy_save": policy_save_time,
    }

    timing_dict.update(timing_info)

    for k, v in timing_dict.items():
      timing_sw.scalar("timing/%s" % k, v, step=i)

    max_key_len = max(len(k) for k in timing_dict)
    timing_info_list = [
        "%s : % 10.2f" % (k.rjust(max_key_len + 1), v)
        for k, v in sorted(timing_dict.items())
    ]
    logging.info("Epoch [% 6d], Timings: \n%s", i, "\n".join(timing_info_list))

    # Reset restore.
    restore = False

    # Flush summary writers once in a while.
    if (i + 1) % 1000 == 0 or i == epochs - 1:
      train_sw.flush()
      timing_sw.flush()
      eval_sw.flush()
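# Invocation sketch (illustrative; `make_env`, `net_fn` and `optimizer_fn` are
# assumed placeholders, not helpers defined in this file). The environments
# are expected to be batched gym envs with a discrete action space, and the
# two factory functions must match the signatures used above:
#   net_fn(rng, batch_observations_shape, observations_dtype, n_actions)
#     -> (params, apply_fn)
#   optimizer_fn(params) -> (opt_state, opt_update, get_params)
#
#   training_loop(
#       env=make_env(batch_size=16),
#       eval_env=make_env(batch_size=4),
#       env_name='CartPole-v0',
#       policy_and_value_net_fn=net_fn,
#       policy_and_value_optimizer_fn=optimizer_fn,
#       output_dir='/tmp/ppo_run',
#       epochs=100,
#       eval_every_n=10)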
def remove_remote(filename):
  """Wrapper that can remove local and remote files like `gs://...`."""
  # Conditional import
  return gfile.remove(filename)
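# A sketch (assumption, not the original implementation) of the conditional
# import the comment above alludes to: pull in TensorFlow's gfile lazily so
# that purely local runs don't need TensorFlow, falling back to os.remove for
# local paths.
def _example_remove_remote(filename):
  if filename.startswith('gs://'):
    from tensorflow.io import gfile as tf_gfile  # conditional import
    tf_gfile.remove(filename)
  else:
    os.remove(filename)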