def reset(self, output_dir):
  """Reset the model parameters.

  Restores the parameters from the given output_dir if a checkpoint exists,
  otherwise randomly initializes them.

  Does not re-jit the model.

  Args:
    output_dir: Output directory.
  """
  self._output_dir = output_dir
  gfile.makedirs(output_dir)
  # Create summary writers and history.
  self._train_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "train"))
  self._eval_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "eval"))

  # Reset the train and eval streams.
  self._train_stream = self._inputs.train_stream()
  # TODO(lukaszkaiser): add an option to evaluate exactly on the full eval
  # set by adding a padding and stopping the stream when too large.
  self._eval_stream = _repeat_stream(self._inputs.eval_stream)
  self._train_eval_stream = _repeat_stream(self._inputs.train_eval_stream)

  # Restore the training state.
  state = restore_state(output_dir)
  self._step = state.step or 0
  history = state.history
  self._lr_fn = self._lr_schedule(history)
  self._history = history
  if state.opt_state:
    opt_state = state.opt_state
    model_state = state.model_state
  else:
    opt_state, model_state = self._initialize()
  model_state = layers.nested_map(model_state, self._maybe_replicate)
  self._opt_state = OptState(
      *layers.nested_map(opt_state, self._maybe_replicate))
  self._model_state = model_state
  if not state.opt_state:
    self._maybe_save_state(keep=False)

  self.update_optimizer_params()
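# The eval streams above are built with `_repeat_stream`, which is referenced
# here but defined elsewhere. A minimal sketch of such a helper, assuming it
# takes a stream-constructor and re-creates the stream whenever it runs out:
def _repeat_stream(stream_fn):
  """Repeat a finite stream indefinitely by re-creating it once exhausted."""
  while True:
    for example in stream_fn():
      yield example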
def __init__(self,
             train_env,
             eval_env,
             output_dir,
             policy_trainer_class,
             n_real_epochs=10,
             data_eval_frac=0.125,
             model_train_batch_size=64,
             n_model_train_steps=1000,
             simulated_env_problem_class=(
                 simulated_env_problem.SerializedSequenceSimulatedEnvProblem),
             simulated_batch_size=16,
             n_simulated_epochs=1000,
             trajectory_dump_dir=None,
             initial_trajectory_dir=None,
             initial_trajectory_mix_prob=0.5,
             **kwargs):
  super(SimPLe, self).__init__(train_env, eval_env, output_dir, **kwargs)
  self._policy_dir = os.path.join(output_dir, "policy")
  self._policy_trainer = policy_trainer_class(
      train_env=train_env,
      eval_env=eval_env,
      output_dir=self._policy_dir,
      async_mode=self._async_mode,
      async_mode_trajectory_subdir=self._async_mode_trajectory_subdir,
  )
  self._n_real_epochs = n_real_epochs
  self._model_train_batch_size = model_train_batch_size
  self._n_model_train_steps = n_model_train_steps
  self._data_eval_frac = data_eval_frac
  self._model_dir = os.path.join(output_dir, "model")
  self._sim_env = simulated_env_problem_class(
      batch_size=None,
      observation_space=train_env.observation_space,
      action_space=train_env.action_space,
      reward_range=train_env.reward_range,
      discrete_rewards=train_env.discrete_rewards,
      history_stream=None,  # TODO(pkozakowski): Support this.
      output_dir=self._model_dir,
  )
  self._simulated_batch_size = simulated_batch_size
  self._n_simulated_epochs = n_simulated_epochs

  # If trajectory_dump_dir is not provided explicitly, save the trajectories
  # in output_dir.
  if trajectory_dump_dir is None:
    trajectory_dump_dir = os.path.join(output_dir, "trajectories")
  self._trajectory_dump_root_dir = trajectory_dump_dir

  self._initial_trajectory_dir = initial_trajectory_dir
  self._initial_trajectory_mix_prob = initial_trajectory_mix_prob

  self._summary_writer = jaxboard.SummaryWriter(self._output_dir)

  self._simple_epoch = 0
  self._policy_epoch = 0
  self._model_train_step = 0
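# For orientation, the sub-directories SimPLe.__init__ wires up under
# output_dir (a summary of the code above; the one-line descriptions are
# informal):
#   output_dir/policy        - output dir handed to the policy trainer
#   output_dir/model         - world-model (simulated env) checkpoints
#   output_dir/trajectories  - default trajectory dump dir when none is given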
def reset(self, output_dir):
  """Reset the model parameters.

  Restores the parameters from the given output_dir if a checkpoint exists,
  otherwise randomly initializes them.

  Does not re-jit the model.

  Args:
    output_dir: Output directory.
  """
  self._output_dir = output_dir
  gfile.makedirs(output_dir)
  # Create summary writers and history.
  self._train_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "train"))
  self._eval_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "eval"))

  # Reset the training stream.
  self._train_stream = self._inputs.train_stream()

  # Restore the training state.
  state = restore_state(output_dir)
  self._step = state.step or 0
  history = state.history
  self._lr_fn = self._lr_schedule(history)
  self._history = history
  if state.opt_state:
    opt_state = state.opt_state
    model_state = state.model_state
  else:
    opt_state, model_state = self._initialize()
  model_state = layers.nested_map(model_state, self._maybe_replicate)
  self._opt_state = OptState(
      *layers.nested_map(opt_state, self._maybe_replicate))
  self._model_state = model_state
  if not state.opt_state:
    self._maybe_save_state(keep=False)

  self.update_learning_rate()
def __init__(self, train_env, eval_env, output_dir, policy_and_value_model=trax_models.FrameStackMLP, policy_and_value_optimizer=functools.partial( trax_opt.Adam, learning_rate=1e-3), policy_and_value_two_towers=False, n_optimizer_steps=N_OPTIMIZER_STEPS, print_every_optimizer_steps=PRINT_EVERY_OPTIMIZER_STEP, target_kl=0.01, boundary=20, max_timestep=None, max_timestep_eval=20000, random_seed=None, gamma=GAMMA, lambda_=LAMBDA, c1=1.0, c2=0.01, eval_every_n=1000, done_frac_for_policy_save=0.5, n_evals=1, len_history_for_policy=4, eval_temperatures=(1.0, 0.5), **kwargs): """Creates the PPO trainer. Args: train_env: gym.Env to use for training. eval_env: gym.Env to use for evaluation. output_dir: Output dir. policy_and_value_model: Function defining the policy and value network, without the policy and value heads. policy_and_value_optimizer: Function defining the optimizer. policy_and_value_two_towers: Whether to use two separate models as the policy and value networks. If False, share their parameters. n_optimizer_steps: Number of optimizer steps. print_every_optimizer_steps: How often to log during the policy optimization process. target_kl: Policy iteration early stopping. Set to infinity to disable early stopping. boundary: We pad trajectories at integer multiples of this number. max_timestep: If set to an integer, maximum number of time-steps in a trajectory. Used in the collect procedure. max_timestep_eval: If set to an integer, maximum number of time-steps in an evaluation trajectory. Used in the collect procedure. random_seed: Random seed. gamma: Reward discount factor. lambda_: N-step TD-error discount factor in GAE. c1: Value loss coefficient. c2: Entropy loss coefficient. eval_every_n: How frequently to eval the policy. done_frac_for_policy_save: Fraction of the trajectories that should be done to checkpoint the policy. n_evals: Number of times to evaluate. len_history_for_policy: How much of history to give to the policy. eval_temperatures: Sequence of temperatures to try for categorical sampling during evaluation. **kwargs: Additional keyword arguments passed to the base class. """ # Set in base class constructor. self._train_env = None self._should_reset = None super(PPO, self).__init__(train_env, eval_env, output_dir, **kwargs) self._n_optimizer_steps = n_optimizer_steps self._print_every_optimizer_steps = print_every_optimizer_steps self._target_kl = target_kl self._boundary = boundary self._max_timestep = max_timestep self._max_timestep_eval = max_timestep_eval self._gamma = gamma self._lambda_ = lambda_ self._c1 = c1 self._c2 = c2 self._eval_every_n = eval_every_n self._done_frac_for_policy_save = done_frac_for_policy_save self._n_evals = n_evals self._len_history_for_policy = len_history_for_policy self._eval_temperatures = eval_temperatures assert isinstance(self.train_env.action_space, gym.spaces.Discrete) n_actions = self.train_env.action_space.n # Batch Observations Shape = [1, 1] + OBS, because we will eventually call # policy and value networks on shape [B, T] +_OBS batch_observations_shape = (1, 1) + self.train_env.observation_space.shape observations_dtype = self.train_env.observation_space.dtype self._rng = trax.get_random_number_generator_and_set_seed(random_seed) self._rng, key1 = jax_random.split(self._rng, num=2) # Initialize the policy and value network. 
  policy_and_value_net_params, self._model_state, policy_and_value_net_apply = (
      ppo.policy_and_value_net(
          rng_key=key1,
          batch_observations_shape=batch_observations_shape,
          observations_dtype=observations_dtype,
          n_actions=n_actions,
          bottom_layers_fn=policy_and_value_model,
          two_towers=policy_and_value_two_towers,
      ))
  self._policy_and_value_net_apply = jit(policy_and_value_net_apply)

  # Initialize the optimizer.
  (policy_and_value_opt_state, self._policy_and_value_opt_update,
   self._policy_and_value_get_params) = ppo.optimizer_fn(
       policy_and_value_optimizer, policy_and_value_net_params)

  # Maybe restore the optimization state. If there is nothing to restore, then
  # iteration = 0 and policy_and_value_opt_state is returned as is.
  (restored, self._policy_and_value_opt_state, self._model_state, self._epoch,
   self._total_opt_step) = ppo.maybe_restore_opt_state(
       output_dir, policy_and_value_opt_state, self._model_state)

  if restored:
    logging.info("Restored parameters from iteration [%d]", self._epoch)
    # We should start from the next iteration.
    self._epoch += 1

  # Create summary writers and history.
  self._train_sw = jaxboard.SummaryWriter(
      os.path.join(self._output_dir, "train"))
  self._timing_sw = jaxboard.SummaryWriter(
      os.path.join(self._output_dir, "timing"))
  self._eval_sw = jaxboard.SummaryWriter(
      os.path.join(self._output_dir, "eval"))

  self._n_trajectories_done = 0
  self._last_saved_at = 0
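# The (1, 1) prefix in batch_observations_shape above stands for (batch, time).
# A concrete example with a hypothetical observation space of shape (84, 84, 4):
#   batch_observations_shape = (1, 1, 84, 84, 4)
# so the network initialized here can later be applied to [B, T] + OBS inputs.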
def __init__(self, train_env, eval_env, output_dir, policy_and_value_model=trax_models.FrameStackMLP, policy_and_value_optimizer=functools.partial( trax_opt.Adam, learning_rate=1e-3), policy_and_value_two_towers=False, policy_and_value_vocab_size=None, n_optimizer_steps=N_OPTIMIZER_STEPS, optimizer_batch_size=64, print_every_optimizer_steps=PRINT_EVERY_OPTIMIZER_STEP, target_kl=0.01, boundary=20, max_timestep=100, max_timestep_eval=20000, random_seed=None, gamma=GAMMA, lambda_=LAMBDA, c1=1.0, c2=0.01, eval_every_n=1000, save_every_n=1000, done_frac_for_policy_save=0.5, n_evals=1, len_history_for_policy=4, eval_temperatures=(1.0, 0.5), separate_eval=True, init_policy_from_world_model_output_dir=None, **kwargs): """Creates the PPO trainer. Args: train_env: gym.Env to use for training. eval_env: gym.Env to use for evaluation. output_dir: Output dir. policy_and_value_model: Function defining the policy and value network, without the policy and value heads. policy_and_value_optimizer: Function defining the optimizer. policy_and_value_two_towers: Whether to use two separate models as the policy and value networks. If False, share their parameters. policy_and_value_vocab_size: Vocabulary size of a policy and value network operating on serialized representation. If None, use raw continuous representation. n_optimizer_steps: Number of optimizer steps. optimizer_batch_size: Batch size of an optimizer step. print_every_optimizer_steps: How often to log during the policy optimization process. target_kl: Policy iteration early stopping. Set to infinity to disable early stopping. boundary: We pad trajectories at integer multiples of this number. max_timestep: If set to an integer, maximum number of time-steps in a trajectory. Used in the collect procedure. max_timestep_eval: If set to an integer, maximum number of time-steps in an evaluation trajectory. Used in the collect procedure. random_seed: Random seed. gamma: Reward discount factor. lambda_: N-step TD-error discount factor in GAE. c1: Value loss coefficient. c2: Entropy loss coefficient. eval_every_n: How frequently to eval the policy. save_every_n: How frequently to save the policy. done_frac_for_policy_save: Fraction of the trajectories that should be done to checkpoint the policy. n_evals: Number of times to evaluate. len_history_for_policy: How much of history to give to the policy. eval_temperatures: Sequence of temperatures to try for categorical sampling during evaluation. separate_eval: Whether to run separate evaluation using a set of temperatures. If False, the training reward is reported as evaluation reward with temperature 1.0. init_policy_from_world_model_output_dir: Model output dir for initializing the policy. If None, initialize randomly. **kwargs: Additional keyword arguments passed to the base class. """ # Set in base class constructor. 
self._train_env = None self._should_reset = None super(PPO, self).__init__(train_env, eval_env, output_dir, **kwargs) self._n_optimizer_steps = n_optimizer_steps self._optimizer_batch_size = optimizer_batch_size self._print_every_optimizer_steps = print_every_optimizer_steps self._target_kl = target_kl self._boundary = boundary self._max_timestep = max_timestep self._max_timestep_eval = max_timestep_eval self._gamma = gamma self._lambda_ = lambda_ self._c1 = c1 self._c2 = c2 self._eval_every_n = eval_every_n self._save_every_n = save_every_n self._done_frac_for_policy_save = done_frac_for_policy_save self._n_evals = n_evals self._len_history_for_policy = len_history_for_policy self._eval_temperatures = eval_temperatures self._separate_eval = separate_eval action_space = self.train_env.action_space assert isinstance(action_space, (gym.spaces.Discrete, gym.spaces.MultiDiscrete)) if isinstance(action_space, gym.spaces.Discrete): n_actions = action_space.n n_controls = 1 else: (n_controls, ) = action_space.nvec.shape assert n_controls > 0 assert onp.min(action_space.nvec) == onp.max(action_space.nvec), ( "Every control must have the same number of actions.") n_actions = action_space.nvec[0] self._n_actions = n_actions self._n_controls = n_controls self._rng = trax.get_random_number_generator_and_set_seed(random_seed) self._rng, key1 = jax_random.split(self._rng, num=2) vocab_size = policy_and_value_vocab_size self._serialized_sequence_policy = vocab_size is not None if self._serialized_sequence_policy: self._serialization_kwargs = self._init_serialization(vocab_size) else: self._serialization_kwargs = {} # Initialize the policy and value network. policy_and_value_net = ppo.policy_and_value_net( n_actions=n_actions, n_controls=n_controls, vocab_size=vocab_size, bottom_layers_fn=policy_and_value_model, two_towers=policy_and_value_two_towers, ) self._policy_and_value_net_apply = jit(policy_and_value_net) (batch_obs_shape, obs_dtype) = self._batch_obs_shape_and_dtype policy_and_value_net_params, self._model_state = ( policy_and_value_net.initialize_once(batch_obs_shape, obs_dtype, key1)) if init_policy_from_world_model_output_dir is not None: policy_and_value_net_params = ppo.init_policy_from_world_model_checkpoint( policy_and_value_net_params, init_policy_from_world_model_output_dir) # Initialize the optimizer. (policy_and_value_opt_state, self._policy_and_value_opt_update, self._policy_and_value_get_params) = ppo.optimizer_fn( policy_and_value_optimizer, policy_and_value_net_params) # Restore the optimizer state. self._policy_and_value_opt_state = policy_and_value_opt_state self._epoch = 0 self._total_opt_step = 0 self.update_optimization_state( output_dir, policy_and_value_opt_state=policy_and_value_opt_state) # Create summary writers and history. self._train_sw = jaxboard.SummaryWriter( os.path.join(self._output_dir, "train")) self._timing_sw = jaxboard.SummaryWriter( os.path.join(self._output_dir, "timing")) self._eval_sw = jaxboard.SummaryWriter( os.path.join(self._output_dir, "eval")) self._n_trajectories_done = 0 self._last_saved_at = 0 if self._async_mode: logging.info( "Saving model on startup to have a model policy file.") self.save() self._rewards_to_actions = self._init_rewards_to_actions()
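# How n_actions / n_controls fall out of the two supported action spaces, as a
# small standalone example (the concrete space sizes below are hypothetical):
import gym
import numpy as onp

multi = gym.spaces.MultiDiscrete([6, 6, 6])
(n_controls,) = multi.nvec.shape                    # 3 controls
assert onp.min(multi.nvec) == onp.max(multi.nvec)   # same arity for each control
n_actions = int(multi.nvec[0])                      # 6 actions per control

single = gym.spaces.Discrete(4)
n_actions, n_controls = single.n, 1                 # 4 actions, a single control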
def training_loop( env=None, epochs=EPOCHS, policy_and_value_net_fn=None, policy_and_value_optimizer_fn=None, batch_size=BATCH_TRAJECTORIES, n_optimizer_steps=N_OPTIMIZER_STEPS, print_every_optimizer_steps=PRINT_EVERY_OPTIMIZER_STEP, target_kl=0.01, boundary=20, max_timestep=None, max_timestep_eval=20000, random_seed=None, gamma=GAMMA, lambda_=LAMBDA, epsilon=EPSILON, c1=1.0, c2=0.01, output_dir=None, eval_every_n=1000, eval_env=None, done_frac_for_policy_save=0.5, enable_early_stopping=True, env_name=None, n_evals=1, len_history_for_policy=4, ): """Runs the training loop for PPO, with fixed policy and value nets.""" assert env assert output_dir assert env_name gfile.makedirs(output_dir) # Create summary writers and history. train_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "train")) timing_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "timing")) eval_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "eval")) train_sw.text("env_name", env_name) timing_sw.text("env_name", env_name) eval_sw.text("env_name", env_name) jax_rng_key = trax.get_random_number_generator_and_set_seed(random_seed) # Batch Observations Shape = [1, 1] + OBS, because we will eventually call # policy and value networks on shape [B, T] +_OBS batch_observations_shape = (1, 1) + env.observation_space.shape observations_dtype = env.observation_space.dtype assert isinstance(env.action_space, gym.spaces.Discrete) n_actions = env.action_space.n jax_rng_key, key1 = jax_random.split(jax_rng_key, num=2) # Initialize the policy and value network. policy_and_value_net_params, policy_and_value_net_apply = ( policy_and_value_net_fn(key1, batch_observations_shape, observations_dtype, n_actions)) # Maybe restore the policy params. If there is nothing to restore, then # iteration = 0 and policy_and_value_net_params are returned as is. restore, policy_and_value_net_params, iteration = (maybe_restore_params( output_dir, policy_and_value_net_params)) if restore: logging.info("Restored parameters from iteration [%d]", iteration) # We should start from the next iteration. iteration += 1 policy_and_value_net_apply = jit(policy_and_value_net_apply) # Initialize the optimizers. policy_and_value_optimizer = ( policy_and_value_optimizer_fn(policy_and_value_net_params)) (policy_and_value_opt_state, policy_and_value_opt_update, policy_and_value_get_params) = policy_and_value_optimizer n_trajectories_done = 0 last_saved_at = 0 logging.info("Starting the PPO training loop.") for i in range(iteration, epochs): epoch_start_time = time.time() # Params we'll use to collect the trajectories. policy_and_value_net_params = policy_and_value_get_params( policy_and_value_opt_state) # A function to get the policy and value predictions. def get_predictions(observations, rng=None): """Returns log-probs, value predictions and key back.""" key, key1 = jax_random.split(rng, num=2) log_probs, value_preds = policy_and_value_net_apply( observations, policy_and_value_net_params, rng=key1) return log_probs, value_preds, key # Evaluate the policy. 
policy_eval_start_time = time.time() if ((i + 1) % eval_every_n == 0) or (i == epochs - 1): jax_rng_key, key = jax_random.split(jax_rng_key, num=2) logging.vlog(1, "Epoch [% 6d] evaluating policy.", i) avg_reward, avg_reward_unclipped = evaluate_policy( eval_env, get_predictions, max_timestep=max_timestep_eval, n_evals=n_evals, len_history_for_policy=len_history_for_policy, rng=key) for k, v in avg_reward.items(): eval_sw.scalar("eval/mean_reward/%s" % k, v, step=i) logging.info( "Epoch [% 6d] Policy Evaluation (clipped) [%s] = %10.2f", i, k, v) for k, v in avg_reward_unclipped.items(): eval_sw.scalar("eval/mean_reward_unclipped/%s" % k, v, step=i) logging.info( "Epoch [% 6d] Policy Evaluation (unclipped) [%s] = %10.2f", i, k, v) policy_eval_time = get_time(policy_eval_start_time) trajectory_collection_start_time = time.time() logging.vlog(1, "Epoch [% 6d] collecting trajectories.", i) jax_rng_key, key = jax_random.split(jax_rng_key) trajs, n_done, timing_info = collect_trajectories( env, policy_fn=get_predictions, n_trajectories=batch_size, max_timestep=max_timestep, rng=key, len_history_for_policy=len_history_for_policy, reset=(i == 0) or restore, epsilon=(10.0 / (i + 10.0))) # this is a different epsilon. trajectory_collection_time = get_time(trajectory_collection_start_time) logging.vlog(1, "Collecting trajectories took %0.2f msec.", trajectory_collection_time) avg_reward = float(sum(np.sum(traj[2]) for traj in trajs)) / len(trajs) max_reward = max(np.sum(traj[2]) for traj in trajs) min_reward = min(np.sum(traj[2]) for traj in trajs) train_sw.scalar("train/mean_reward", avg_reward, step=i) logging.vlog(1, "Rewards avg=[%0.2f], max=[%0.2f], min=[%0.2f], all=%s", avg_reward, max_reward, min_reward, [float(np.sum(traj[2])) for traj in trajs]) logging.vlog( 1, "Trajectory Length average=[%0.2f], max=[%0.2f], min=[%0.2f]", float(sum(len(traj[0]) for traj in trajs)) / len(trajs), max(len(traj[0]) for traj in trajs), min(len(traj[0]) for traj in trajs)) logging.vlog(2, "Trajectory Lengths: %s", [len(traj[0]) for traj in trajs]) padding_start_time = time.time() (_, reward_mask, padded_observations, padded_actions, padded_rewards) = pad_trajectories(trajs, boundary=boundary) padding_time = get_time(padding_start_time) logging.vlog(1, "Padding trajectories took %0.2f msec.", get_time(padding_start_time)) logging.vlog(1, "Padded Observations' shape [%s]", str(padded_observations.shape)) logging.vlog(1, "Padded Actions' shape [%s]", str(padded_actions.shape)) logging.vlog(1, "Padded Rewards' shape [%s]", str(padded_rewards.shape)) # Calculate log-probabilities and value predictions of the trajectories. # We'll pass these to the loss functions so as to not get recomputed. # NOTE: # There is a slight problem here, if the policy network contains # stochasticity in the log-probabilities (ex: dropout), then calculating # these again here is not going to be correct and should be done in the # collect function. log_prob_recompute_start_time = time.time() jax_rng_key, key = jax_random.split(jax_rng_key) log_probabs_traj, value_predictions_traj, _ = get_predictions( padded_observations, rng=key) log_prob_recompute_time = get_time(log_prob_recompute_start_time) # Some assertions. 
B, T = padded_actions.shape # pylint: disable=invalid-name assert (B, T) == padded_rewards.shape assert (B, T) == reward_mask.shape assert (B, T + 1) == padded_observations.shape[:2] assert (B, T + 1) + env.observation_space.shape == padded_observations.shape # Linear annealing from 0.1 to 0.0 # epsilon_schedule = epsilon if epochs == 1 else epsilon * (1.0 - # (i / # (epochs - 1))) # Constant epsilon. epsilon_schedule = epsilon # Compute value and ppo losses. jax_rng_key, key1 = jax_random.split(jax_rng_key, num=2) logging.vlog(2, "Starting to compute P&V loss.") loss_compute_start_time = time.time() cur_combined_loss, cur_ppo_loss, cur_value_loss, entropy_bonus = ( combined_loss(policy_and_value_net_params, log_probabs_traj, value_predictions_traj, policy_and_value_net_apply, padded_observations, padded_actions, padded_rewards, reward_mask, gamma=gamma, lambda_=lambda_, epsilon=epsilon_schedule, c1=c1, c2=c2, rng=key1)) loss_compute_time = get_time(loss_compute_start_time) logging.vlog( 1, "Calculating P&V loss [%10.2f(%10.2f, %10.2f, %10.2f)] took %0.2f msec.", cur_combined_loss, cur_value_loss, cur_ppo_loss, entropy_bonus, get_time(loss_compute_start_time)) jax_rng_key, key1 = jax_random.split(jax_rng_key, num=2) logging.vlog(1, "Policy and Value Optimization") optimization_start_time = time.time() keys = jax_random.split(key1, num=n_optimizer_steps) for j in range(n_optimizer_steps): k1, k2, k3 = jax_random.split(keys[j], num=3) t = time.time() # Update the optimizer state. policy_and_value_opt_state = policy_and_value_opt_step( j, policy_and_value_opt_state, policy_and_value_opt_update, policy_and_value_get_params, policy_and_value_net_apply, log_probabs_traj, value_predictions_traj, padded_observations, padded_actions, padded_rewards, reward_mask, c1=c1, c2=c2, gamma=gamma, lambda_=lambda_, epsilon=epsilon_schedule, rng=k1) # Compute the approx KL for early stopping. new_policy_and_value_net_params = policy_and_value_get_params( policy_and_value_opt_state) log_probab_actions_new, _ = policy_and_value_net_apply( padded_observations, new_policy_and_value_net_params, rng=k2) approx_kl = approximate_kl(log_probab_actions_new, log_probabs_traj, reward_mask) early_stopping = enable_early_stopping and approx_kl > 1.5 * target_kl if early_stopping: logging.vlog( 1, "Early stopping policy and value optimization at iter: %d, " "with approx_kl: %0.2f", j, approx_kl) # We don't return right-away, we want the below to execute on the last # iteration. t2 = time.time() if (((j + 1) % print_every_optimizer_steps == 0) or (j == n_optimizer_steps - 1) or early_stopping): # Compute and log the loss. 
(loss_combined, loss_ppo, loss_value, entropy_bonus) = (combined_loss( new_policy_and_value_net_params, log_probabs_traj, value_predictions_traj, policy_and_value_net_apply, padded_observations, padded_actions, padded_rewards, reward_mask, gamma=gamma, lambda_=lambda_, epsilon=epsilon_schedule, c1=c1, c2=c2, rng=k3)) logging.vlog( 1, "One Policy and Value grad desc took: %0.2f msec", get_time(t, t2)) logging.vlog( 1, "Combined Loss(value, ppo, entropy_bonus) [%10.2f] ->" " [%10.2f(%10.2f,%10.2f,%10.2f)]", cur_combined_loss, loss_combined, loss_value, loss_ppo, entropy_bonus) if early_stopping: break optimization_time = get_time(optimization_start_time) logging.vlog( 1, "Total Combined Loss reduction [%0.2f]%%", (100 * (cur_combined_loss - loss_combined) / np.abs(cur_combined_loss))) # Save parameters every time we see the end of at least a fraction of batch # number of trajectories that are done (not completed -- completed includes # truncated and done). # Also don't save too frequently, enforce a minimum gap. # Or if this is the last iteration. policy_save_start_time = time.time() n_trajectories_done += n_done # TODO(afrozm): Refactor to trax.save_state. if (((n_trajectories_done >= done_frac_for_policy_save * batch_size) and (i - last_saved_at > eval_every_n) and (((i + 1) % eval_every_n == 0))) or (i == epochs - 1)): logging.vlog(1, "Epoch [% 6d] saving model.", i) old_model_files = gfile.glob( os.path.join(output_dir, "model-??????.pkl")) params_file = os.path.join(output_dir, "model-%06d.pkl" % i) with gfile.GFile(params_file, "wb") as f: pickle.dump(policy_and_value_net_params, f) # Remove the old model files. for path in old_model_files: gfile.remove(path) # Reset this number. n_trajectories_done = 0 last_saved_at = i policy_save_time = get_time(policy_save_start_time) epoch_time = get_time(epoch_start_time) logging.info( "Epoch [% 6d], Reward[min, max, avg] [%5.2f,%5.2f,%5.2f], Combined" " Loss(value, ppo, entropy) [%2.5f(%2.5f,%2.5f,%2.5f)]", i, min_reward, max_reward, avg_reward, loss_combined, loss_value, loss_ppo, entropy_bonus) timing_dict = { "epoch": epoch_time, "policy_eval": policy_eval_time, "trajectory_collection": trajectory_collection_time, "padding": padding_time, "log_prob_recompute": log_prob_recompute_time, "loss_compute": loss_compute_time, "optimization": optimization_time, "policy_save": policy_save_time, } timing_dict.update(timing_info) for k, v in timing_dict.items(): timing_sw.scalar("timing/%s" % k, v, step=i) max_key_len = max(len(k) for k in timing_dict) timing_info_list = [ "%s : % 10.2f" % (k.rjust(max_key_len + 1), v) for k, v in sorted(timing_dict.items()) ] logging.info("Epoch [% 6d], Timings: \n%s", i, "\n".join(timing_info_list)) # Reset restore. restore = False # Flush summary writers once in a while. if (i + 1) % 1000 == 0 or i == epochs - 1: train_sw.flush() timing_sw.flush() eval_sw.flush()
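# `get_time` is used for all the timing above but defined elsewhere. A
# plausible minimal version, assuming it returns elapsed wall-clock time in
# milliseconds and optionally accepts an explicit end time (as in
# `get_time(t, t2)`):
import time

def get_time(t0, t1=None):
  t1 = time.time() if t1 is None else t1
  return round((t1 - t0) * 1000, 2)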
def __init__(self, model, loss_fn, optimizer, lr_schedule, inputs, output_dir,
             random_seed=None, n_devices=None, save_steps=None):
  if save_steps is None:
    save_steps = []
  self._save_steps = save_steps
  device_count = jax.lib.xla_bridge.device_count()
  n_devices = n_devices or device_count
  # TODO(lukaszkaiser): remove this restriction when possible.
  if n_devices != device_count:
    raise ValueError("Jax cannot work yet with n_devices != all devices: "
                     "%d != %d" % (n_devices, device_count))
  self._n_devices = n_devices
  rng = get_random_number_generator_and_set_seed(random_seed)
  self._output_dir = output_dir
  gfile.makedirs(output_dir)
  # Create summary writers and history.
  self._train_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "train"))
  self._eval_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "eval"))

  # Create input streams.
  inputs = inputs(n_devices)
  self._inputs = inputs
  self._train_stream = inputs.train_stream()

  # Setup optimizer and model.
  state = restore_state(output_dir)
  history = state.history
  self._lr_fn = lr_schedule(history)
  opt = optimizer(self._lr_fn)
  model_train = model(mode="train")
  model_predict_eval = model(mode="eval")

  # Setup state.
  step = state.step or 0
  rng, init_rng = jax_random.split(rng)
  self._rngs = jax_random.split(rng, n_devices)
  first_shape = inputs.input_shape[0]
  # If the inputs are a tuple/list, add [None] (batch) to each element.
  if isinstance(first_shape, (list, tuple)):
    model_input_shape = tuple(
        tuple([None] + list(shape)) for shape in inputs.input_shape)
  else:
    # Otherwise just add [None] to the input shape.
    model_input_shape = tuple([None] + list(inputs.input_shape))
  # Change all None to 1 in input shape.
  model_input_shape = layers.nested_map(
      model_input_shape, lambda x: x if x else 1)
  if state.params:
    params = state.params[0]
    opt_state = state.params
  else:
    params = model_train.initialize(
        model_input_shape, inputs.input_dtype, init_rng)
    opt_state = (params, opt.tree_init(params))
  if n_devices > 1:
    replicate = lambda x: numpy.broadcast_to(x, (n_devices,) + x.shape)
    opt_state = layers.nested_map(opt_state, replicate)

  # jit model_predict and update so they're fast
  self._jit_model_predict_eval = _jit_predict_fn(model_predict_eval, n_devices)
  self._jit_update_fn = _jit_update_fn(model_train, loss_fn, opt, n_devices)

  self._step = step
  self._model_train = model_train
  self._model_predict_eval = model_predict_eval
  self._loss_fn = loss_fn
  self._optimizer = optimizer
  self._opt_state = opt_state
  self._history = history
  self._lr_schedule = lr_schedule
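# With n_devices > 1 the training loop shards each batch across devices via
# `reshape_by_device` (see the loop below). A minimal sketch of such a helper,
# assuming the leading dimension of every array in the batch is the batch size:
def reshape_by_device(batch, n_devices):
  def _reshape(x):
    batch_size = x.shape[0]
    assert batch_size % n_devices == 0
    return x.reshape((n_devices, batch_size // n_devices) + x.shape[1:])
  return layers.nested_map(batch, _reshape)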
def train(output_dir, model=gin.REQUIRED, loss_fn=loss, inputs=trax_inputs.inputs, optimizer=trax_opt.SM3, lr_schedule=lr.MultifactorSchedule, train_steps=1000, save_steps=None, eval_steps=10, eval_frequency=100, n_devices=None, random_seed=None, run_debug_step=False, save_graphs=True, save_backward_graph=False): """Train the model on the inputs. Args: output_dir: Directory where to put the logs and checkpoints. model: The model to train as a callable returning 2 callables, an init_fn and apply_fn. loss_fn: callable with signature: params, trax.inputs.Inputs, model, rng -> loss. inputs: callable returning trax.inputs.Inputs. optimizer: The optimizer (see optimizers/base.py for signature). lr_schedule: A learning rate schedule as a function that takes history and returns a function from step to learning rate (a float). train_steps: int, total number of training steps. save_steps: list of integers. Keep a model file at each of the supplied save steps. eval_steps: int, num of steps per evaluation. If None or 0, eval disabled. eval_frequency: int, how often to run evaluation (every eval_frequency steps). If None or 0, eval disabled. n_devices: how many devices to use (if None, default, use all available) random_seed: the random seed to use; time/os dependent if None (default). run_debug_step: bool, if True, will run the model and loss without @jit for one step. save_graphs: bool, if True, save computation graph to file. save_backward_graph: bool, if True, save backward graph to file too. Returns: trax.State """ if save_steps is None: save_steps = [] device_count = jax.lib.xla_bridge.device_count() n_devices = n_devices or device_count # TODO(lukaszkaiser): remove this restriction when possible. if n_devices != device_count: raise ValueError("Jax cannot work yet with n_devices != all devices: " "%d != %d" % (n_devices, device_count)) rng = get_random_number_generator_and_set_seed(random_seed) gfile.makedirs(output_dir) # Create summary writers and history. train_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "train")) eval_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "eval")) inputs = inputs(n_devices) # Setup optimizer and model state = restore_state(output_dir) history = state.history lr_fn = lr_schedule(history) opt = optimizer(lr_fn) model_train = layers.Serial(model(mode="train")) model_predict_eval = layers.Serial(model(mode="eval")) # Setup state step = state.step or 0 rng, init_rng = jax_random.split(rng) rngs = jax_random.split(rng, n_devices) first_shape = inputs.input_shape[0] # If the inputs are a tuple/list, add [-1] (batch) to each element. if isinstance(first_shape, (list, tuple)): model_input_shape = tuple( [tuple([-1] + list(shape)) for shape in inputs.input_shape]) else: # Otherwise just add [-1] to the input shape. model_input_shape = tuple([-1] + list(inputs.input_shape)) if state.params: params = state.params[0] opt_state = state.params else: params = model_train.initialize(model_input_shape, init_rng) opt_state = (params, opt.tree_init(params)) if n_devices > 1: replicate = lambda x: numpy.broadcast_to(x, (n_devices,) + x.shape) opt_state = layers.nested_map(opt_state, replicate) # jit model_predict and update so they're fast jit_model_predict_eval = _jit_predict_fn(model_predict_eval, n_devices) jit_update_fn = _jit_update_fn(model_train, loss_fn, opt, n_devices) train_stream = inputs.train_stream() epoch_steps = [train_steps] # Only training if eval_frequency is 0 or None. 
if eval_frequency and eval_steps > 0: epoch_steps = itertools.chain([1, # first epoch only 1 step eval_frequency - 1], itertools.repeat(eval_frequency)) step_log(step, "Starting training using %d devices" % n_devices) # Non-compiled debug step helps find problems in models easier. if run_debug_step: debug_loss = loss_fn(params, next(train_stream), model_train, rng) step_log(step, "Debug step loss %.8f" % debug_loss) for epoch, epoch_steps in epochs(train_steps, epoch_steps): # Log separator print() # Timer start_time = time.time() for _ in range(epoch_steps): # Train next_train_batch = next(train_stream) if n_devices > 1: # TODO(lukaszkaiser): use everywhere when possible. next_train_batch = reshape_by_device(next_train_batch, n_devices) opt_state, rngs = jit_update_fn(step, opt_state, next_train_batch, rngs) step += 1 if step in save_steps: _save_replicated(opt_state, step, history, n_devices, output_dir, True) # LR log if step == 1 or step % 10 == 0: train_sw.scalar("training/learning rate", lr_fn(step), step=step) # Timer epoch_time = time.time() - start_time step_log(step, "Ran %d train steps in %0.2f secs" % (epoch_steps, epoch_time)) if epoch_steps > 1: train_sw.scalar("training/steps per second", epoch_steps / epoch_time, step=step) # Print number of parameters if step == 1: sizes = layers.sizes(opt_state[0]) if n_devices > 1: unreplicate = lambda x: x.mean(0) single_params = layers.nested_map(opt_state[0], unreplicate) sizes = layers.sizes(single_params) total_size = layers.nested_reduce(sizes, sum) step_log(step, "Total trainable parameters size: %d" % total_size) # Evaluate in parallel evaluate_train_and_eval( step=step, inputs=inputs, predict_fn=functools.partial(jit_model_predict_eval, params=opt_state[0]), eval_steps=eval_steps, rng=rng, train_sw=train_sw, eval_sw=eval_sw, history=history) # Save computation graph (single-device only for now). if save_graphs and step == 1 and n_devices == 1: params = opt_state[0] # Dump computation graphs to files. forward_computation = jax.xla_computation(model_predict_eval)( next_train_batch[0], params=params, rng=rng) with gfile.GFile(os.path.join(output_dir, "forward.txt"), "w") as f: f.write(forward_computation.GetHloText()) with gfile.GFile(os.path.join(output_dir, "forward.dot"), "w") as f: f.write(forward_computation.GetHloDotGraph()) backward_computation = jax.xla_computation(jit_update_fn)( step, opt_state, next_train_batch, rngs) with gfile.GFile(os.path.join(output_dir, "backward.txt"), "w") as f: f.write(backward_computation.GetHloText()) if save_backward_graph: # Backward graphs can be large so we guard it. with gfile.GFile(os.path.join(output_dir, "backward.dot"), "w") as f: f.write(backward_computation.GetHloDotGraph()) # Save state _save_replicated(opt_state, step, history, n_devices, output_dir, False) # Save Gin config # Gin only tracks the used parameters, so we save it after the first epoch. if epoch == 1: save_gin(output_dir, train_sw) # Update learning rate with new history old_lr_fn = lr_fn lr_fn = lr_schedule(history) if lr_fn != old_lr_fn: # For performance, only jit if there is a change. opt = optimizer(lr_fn) jit_update_fn = _jit_update_fn(model_train, loss_fn, opt, n_devices) # Flush summary writers train_sw.flush() eval_sw.flush() step_log(step, "Training done") return State(params=opt_state, step=step, history=history)
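# The loop `for epoch, epoch_steps in epochs(train_steps, epoch_steps)` relies
# on an `epochs` generator defined elsewhere. A hedged sketch, assuming it
# yields (epoch_number, steps_in_this_epoch) pairs and caps the cumulative
# number of steps at train_steps:
def epochs(total_steps, steps_to_take):
  steps_done = 0
  for epoch, steps in enumerate(steps_to_take, start=1):
    steps = min(steps, total_steps - steps_done)
    if steps <= 0:
      break
    steps_done += steps
    yield epoch, steps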
def train(output_dir, model=gin.REQUIRED, loss_fun=loss, inputs=trax_inputs.inputs, optimizer=trax_opt.adam, lr_schedule=lr.MultifactorSchedule, train_steps=1000, eval_steps=10, eval_frequency=100, num_devices=None, random_seed=None, run_debug_step=False): """Train the model on the inputs. Args: output_dir: Directory where to put the logs and checkpoints. model: The model to train as a callable returning 2 callables, an init_fun and apply_fun. loss_fun: callable with signature: params, trax.inputs.Inputs, model, rng -> loss. inputs: callable returning trax.inputs.Inputs. optimizer: The optimizer as a callable taking a learning_rate callable and returning 2 callables, opt_init and opt_update. lr_schedule: A learning rate schedule as a function that takes history and returns a function from step to learning rate (a float). train_steps: int, total number of training steps. eval_steps: int, num of steps per evaluation. If None or 0, eval disabled. eval_frequency: int, how often to run evaluation (every eval_frequency steps). If None or 0, eval disabled. num_devices: how many devices to use (if None, default, use all available) random_seed: the random seed to use; time/os dependent if None (default). run_debug_step: bool, if True, will run the model and loss without @jit for one step. Returns: trax.State """ num_devices = num_devices or jax.lib.xla_bridge.device_count() rng = get_random_number_generator_and_set_seed(random_seed) gfile.makedirs(output_dir) # Create summary writers and history. train_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "train")) eval_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "eval")) inputs = inputs(num_devices) # Setup optimizer and model state = restore_state(output_dir) history = state.history lr_fun = lr_schedule(history) opt_init, _ = optimizer(lr_fun) model_init, model_predict_train = model(mode="train") _, model_predict_eval = model(mode="eval") # Setup state step = state.step or 0 rng, init_key = jax_random.split(rng) params_initializer = \ lambda: model_init(init_key, [-1] + list(inputs.input_shape))[1] params = state.params or params_initializer() opt_state = opt_init(params) if num_devices > 1: # TODO(lukaszkaiser): use everywhere when pmap is stable. opt_state = jax.replicate(opt_state) # jit model_predict and update so they're fast jit_model_predict_eval = _jit_predict_fun(model_predict_eval, num_devices) jit_update_fun = _jit_update_fun( model_predict_train, loss_fun, optimizer, lr_fun, num_devices) print() train_stream = inputs.train_stream() epoch_steps = [train_steps] # Only training if eval_frequency is 0 or None. if eval_frequency and eval_steps > 0: epoch_steps = itertools.chain([1, # first epoch only 1 step eval_frequency - 1], itertools.repeat(eval_frequency)) step_log(step, "Starting training using %d devices" % num_devices) # Non-compiled debug step helps find problems in models easier. if run_debug_step: debug_loss = loss_fun(params, next(train_stream), model_predict_train, rng) step_log(step, "Debug step loss %.8f" % debug_loss) for epoch, epoch_steps in epochs(train_steps, epoch_steps): # Log separator print() # Timer start_time = time.time() for _ in range(epoch_steps): # Train next_train_batch = next(train_stream) if num_devices > 1: # TODO(lukaszkaiser): use everywhere when possible. 
        next_train_batch = reshape_by_device_pair(next_train_batch,
                                                  num_devices)
      rng, subrng = jax_random.split(rng)
      opt_state = jit_update_fun(step, opt_state, next_train_batch, subrng)
      step += 1

      # LR log
      if step == 1 or step % 10 == 0:
        train_sw.scalar("training/learning rate", lr_fun(step), step=step)

    # Timer
    epoch_time = time.time() - start_time
    step_log(step, "Ran %d train steps in %0.2f secs" %
             (epoch_steps, epoch_time))
    if epoch_steps > 1:
      train_sw.scalar("training/steps per second",
                      epoch_steps / epoch_time, step=step)

    # Evaluate
    params = trax_opt.get_params(opt_state)
    evaluate_train_and_eval(
        step=step,
        inputs=inputs,
        predict_fun=functools.partial(jit_model_predict_eval, params),
        eval_steps=eval_steps,
        rng=rng,
        train_sw=train_sw,
        eval_sw=eval_sw,
        history=history)

    # Save state
    save_state(State(params=params, step=step, history=history), output_dir)

    # Save Gin config
    # Gin only tracks the used parameters, so we save it after the first epoch.
    if epoch == 1:
      save_gin(output_dir, train_sw)

    # Update learning rate with new history
    old_lr_fun = lr_fun
    lr_fun = lr_schedule(history)
    if lr_fun != old_lr_fun:  # For performance, only jit if there is a change.
      jit_update_fun = _jit_update_fun(
          model_predict_train, loss_fun, optimizer, lr_fun, num_devices)

    # Flush summary writers
    train_sw.flush()
    eval_sw.flush()

  step_log(step, "Training done")
  return State(params=params, step=step, history=history)
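# `State` is the checkpoint structure saved and returned above. Judging from
# its use it behaves like a named tuple; a sketch for this version of train()
# (later variants in this file also carry fields such as opt_state):
import collections
State = collections.namedtuple("State", ["params", "step", "history"])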
def train(output_dir, model=gin.REQUIRED, inputs=gin.REQUIRED, optimizer=trax_opt.adam, lr_schedule=lr.MultifactorSchedule, train_steps=1000, eval_steps=10, eval_frequency=100, run_debug_step=False): """Train the model on the inputs. Args: output_dir: Directory where to put the logs and checkpoints. model: The model to train as a callable returning 2 callables, an init_fun and apply_fun. inputs: callable returning trax.inputs.Inputs. optimizer: The optimizer as a callable taking a learning_rate callable and returning 2 callables, opt_init and opt_update. lr_schedule: A learning rate schedule as a function that takes history and returns a function from step to learning rate (a float). train_steps: int, total number of training steps. eval_steps: int, num of steps per evaluation. If None or 0, eval disabled. eval_frequency: int, how often to run evaluation (every eval_frequency steps). If None or 0, eval disabled. run_debug_step: bool, if True, will run the model and loss without @jit for one step. Returns: trax.State """ rng = random.PRNGKey(0) gfile.makedirs(output_dir) # Create summary writers and history. train_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "train")) eval_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "eval")) inputs = inputs() # Setup optimizer and model state = restore_state(output_dir) history = state.history lr_fun = lr_schedule(history) opt_init, _ = optimizer(lr_fun) model_init, model_predict_original = model() # We need a model_predict that fills in the random generator if needed. def model_predict(x, y, **kwargs): """Same as model_predict_original but fill in rng if it isn't passed.""" if "rng" in kwargs: return model_predict_original(x, y, **kwargs) return model_predict_original(x, y, rng=rng, **kwargs) # Setup state step = state.step or 0 params_initializer = lambda: model_init([-1] + list(inputs.input_shape))[1] params = state.params or params_initializer() opt_state = opt_init(params) # jit model_predict and update so they're fast jit_model_predict = jax.jit(model_predict) # for evaluation jit_update_fun = _jit_update_fun(model_predict, loss, optimizer, lr_fun) print() train_stream = inputs.train_stream() epoch_steps = itertools.chain( [ 1, # first epoch only 1 step eval_frequency - 1 ], itertools.repeat(eval_frequency)) step_log(step, "Starting training") # Non-compiled debug step helps find problems in models easier. if run_debug_step: debug_loss = loss(params, next(train_stream), model_predict) step_log(step, "Debug step loss %.8f" % debug_loss) for epoch, epoch_steps in epochs(train_steps, epoch_steps): # Log separator print() # Timer start_time = time.time() for _ in range(epoch_steps): # Train opt_state = jit_update_fun(step, opt_state, next(train_stream)) step += 1 # LR log if step == 1 or step % 10 == 0: train_sw.scalar("training/learning rate", lr_fun(step), step=step) # Timer epoch_time = time.time() - start_time step_log( step, "Ran %d train steps in %0.2f secs" % (epoch_steps, epoch_time)) if epoch_steps > 1: train_sw.scalar("training/steps per second", epoch_steps / epoch_time, step=step) # Evaluate params = jax_opt.get_params(opt_state) evaluate_train_and_eval(step=step, inputs=inputs, predict_fun=functools.partial( jit_model_predict, params), eval_steps=eval_steps, train_sw=train_sw, eval_sw=eval_sw, history=history) # Save state save_state(State(params=params, step=step, history=history), output_dir) # Save Gin config # Gin only tracks the used parameters, so we save it after the first epoch. 
    if epoch == 1:
      save_gin(output_dir, train_sw)

    # Update learning rate with new history
    old_lr_fun = lr_fun
    lr_fun = lr_schedule(history)
    if lr_fun != old_lr_fun:  # For performance, only jit if there is a change.
      jit_update_fun = _jit_update_fun(model_predict, loss, optimizer, lr_fun)

    # Flush summary writers
    train_sw.writer.flush()
    eval_sw.writer.flush()

  step_log(step, "Training done")
  return State(params=params, step=step, history=history)
def train(output_dir,
          model=gin.REQUIRED,
          inputs=gin.REQUIRED,
          optimizer=trax_opt.adam,
          train_steps=1000,
          eval_steps=10,
          eval_frequency=100):
  """Train the model on the inputs.

  Args:
    output_dir: Directory where to put the logs and checkpoints.
    model: The model to train as a callable returning 2 callables, an init_fun
      and apply_fun.
    inputs: callable returning trax.inputs.Inputs.
    optimizer: The optimizer as a callable taking a learning_rate callable and
      returning 2 callables, opt_init and opt_update.
    train_steps: int, total number of training steps.
    eval_steps: int, num of steps per evaluation. If None or 0, eval disabled.
    eval_frequency: int, how often to run evaluation (every eval_frequency
      steps). If None or 0, eval disabled.

  Returns:
    trax.State
  """
  gfile.makedirs(output_dir)
  inputs = inputs()

  # Setup optimizer and model
  opt_init, opt_update = optimizer(learning_rate)
  model_init, model_predict = model()

  # Setup state
  state = restore_state(output_dir)
  step = state.step or 0
  params_initializer = lambda: model_init([-1] + inputs.input_shape)[1]
  opt_state = opt_init(state.params or params_initializer())

  # Create summary writers.
  train_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "train"))
  eval_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "eval"))

  # jit model_predict and update so they're fast
  jit_predict = jax.jit(model_predict)  # for evaluation

  @jax.jit
  def update(i, opt_state, batch):
    params = jax_opt.get_params(opt_state)
    return opt_update(i, jax.grad(loss)(params, batch, model_predict),
                      opt_state)

  print()
  step_log(step, "starting training")
  inputs_stream = inputs.train_fn()
  eval_enabled = eval_steps and eval_frequency
  is_first_step = True
  # Evaluate after the first training step, then reset to normal_epoch_steps
  normal_epoch_steps = (eval_enabled and eval_frequency) or train_steps
  epoch_steps = 1
  while step < train_steps:
    print()  # separate logging for each loop iteration

    # Train
    start_time = time.time()
    for _ in range(epoch_steps):
      opt_state = update(step, opt_state, next(inputs_stream))
      if step % 10 == 0:  # Log learning rate curve each 10 steps.
        train_sw.scalar("training/learning rate",
                        learning_rate(step), step=step)
      step += 1
    epoch_time = time.time() - start_time
    step_log(step, "ran %d train steps in %0.2f secs" %
             (epoch_steps, epoch_time))

    # Save state
    params = jax_opt.get_params(opt_state)
    save_state(State(params=params, step=step), output_dir)

    # Evaluate
    if eval_enabled:
      step_log(step, "starting evaluation")
      train_metrics, eval_metrics = evaluate(
          inputs, functools.partial(jit_predict, params), eval_steps)
      log_metrics(train_metrics, train_sw, "train", step)
      log_metrics(eval_metrics, eval_sw, "eval ", step)
      eval_sw.writer.flush()

    # Gin only tracks the used parameters, so we save it after the first step.
    if is_first_step:
      save_gin(output_dir, train_sw)

    # Log non-metric reports.
    if not is_first_step:
      train_sw.scalar("training/steps per second",
                      epoch_steps / epoch_time, step=step)
      train_sw.writer.flush()

    # After the first step, train for normal_epoch_steps steps before
    # evaluating.
    epoch_steps = ((normal_epoch_steps - 1) if is_first_step
                   else normal_epoch_steps)
    is_first_step = False

  print()
  step_log(step, "finished training")
  return State(params=params, step=step)
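# `step_log` is a small logging helper shared by these training loops. A
# plausible stand-in, assuming it only prefixes the message with the step:
def step_log(step, msg):
  print("Step % 6d: %s" % (step, msg))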
def train(output_dir, data_dir, model=gin.REQUIRED, dataset=gin.REQUIRED, optimizer=trax_opt.adam, train_steps=1000, eval_steps=10, eval_frequency=100): """Train the given model on the given dataset. Args: output_dir: Directory where to put the logs and checkpoints. data_dir: Directory where the data is located. model: The model to train as a callable returning 2 callables, an init_fun and apply_fun. dataset: The name of the TFDS dataset to train on. To train on a T2T dataset, prefix the name with "t2t_". optimizer: The optimizer as a callable taking a learning_rate callable and returning 2 callables, opt_init and opt_update. train_steps: int, total number of training steps. eval_steps: int, num of steps per evaluation. eval_frequency: int, how often to run evaluation (every eval_frequency steps). """ gfile.makedirs(output_dir) # Make Inputs inputs = inputs_lib.make_inputs(dataset, data_dir) # Setup optimizer and model opt_init, opt_update = optimizer(learning_rate) model_init, model_predict = model() # Setup state state = restore_state(output_dir) step = state.step or 0 params_initializer = lambda: model_init([-1] + inputs.input_shape)[1] opt_state = opt_init(state.params or params_initializer()) # Create summary writers. train_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "train")) eval_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "eval")) # jit model_predict and update so they're fast jit_predict = jax.jit(model_predict) # for evaluation @jax.jit def update(i, opt_state, batch): params = jax_opt.get_params(opt_state) return opt_update(i, jax.grad(loss)( params, batch, model_predict), opt_state) print() step_log(step, "starting training") inputs_stream = inputs.train_fn() is_first_step = True epoch_steps = 1 # First evaluation after the first training step. while step < train_steps: print() # Train start_time = time.time() for _ in range(epoch_steps): opt_state = update(step, opt_state, next(inputs_stream)) if step % 10 == 0: # Log learning rate curve each 10 steps. train_sw.scalar("training/learning rate", learning_rate(step), step=step) step += 1 epoch_time = time.time() - start_time step_log(step, "ran %d train steps in %0.2f secs" % (epoch_steps, epoch_time)) # Save state params = jax_opt.get_params(opt_state) save_state(State(params=params, step=step), output_dir, save_gin=is_first_step) # Evaluate step_log(step, "starting evaluation") train_metrics, eval_metrics = evaluate( inputs, functools.partial(jit_predict, params), eval_steps) log_metrics(train_metrics, train_sw, "train", step) log_metrics(eval_metrics, eval_sw, "eval ", step) # Log non-metric reports and flush. if not is_first_step: train_sw.scalar("training/steps per second", epoch_steps / epoch_time, step=step) train_sw.writer.flush() eval_sw.writer.flush() # After the first step, train for eval_frequency steps before evaluating epoch_steps = (eval_frequency - 1) if is_first_step else eval_frequency is_first_step = False print() step_log(step, "finished training")
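# In these early versions `learning_rate` is a module-level, step -> float
# schedule passed directly to the optimizer above. A minimal stand-in (the
# constant value here is hypothetical):
def learning_rate(step):  # pylint: disable=unused-argument
  return 0.001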
def train(output_dir, model=gin.REQUIRED, inputs=gin.REQUIRED, optimizer=trax_opt.adam, lr_schedule=lr.MultifactorSchedule, train_steps=1000, eval_steps=10, eval_frequency=100): """Train the model on the inputs. Args: output_dir: Directory where to put the logs and checkpoints. model: The model to train as a callable returning 2 callables, an init_fun and apply_fun. inputs: callable returning trax.inputs.Inputs. optimizer: The optimizer as a callable taking a learning_rate callable and returning 2 callables, opt_init and opt_update. lr_schedule: A learning rate schedule as a function that takes history and returns a function from step to learning rate (a float). train_steps: int, total number of training steps. eval_steps: int, num of steps per evaluation. If None or 0, eval disabled. eval_frequency: int, how often to run evaluation (every eval_frequency steps). If None or 0, eval disabled. Returns: trax.State """ gfile.makedirs(output_dir) # Create summary writers and history. train_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "train")) eval_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "eval")) inputs = inputs() # Setup optimizer and model state = restore_state(output_dir) history = state.history lr_fun = lr_schedule(history) opt_init, opt_update = optimizer(lr_fun) model_init, model_predict = model() # Setup state step = state.step or 0 params_initializer = lambda: model_init([-1] + list(inputs.input_shape))[1] params = state.params or params_initializer() opt_state = opt_init(params) # jit model_predict and update so they're fast jit_predict = jax.jit(model_predict) # for evaluation @jax.jit def update(i, opt_state, batch): params = jax_opt.get_params(opt_state) return opt_update(i, jax.grad(loss)(params, batch, model_predict), opt_state) print() train_stream = inputs.train_stream() epoch_steps = itertools.chain( [ 1, # first epoch only 1 step eval_frequency - 1 ], itertools.repeat(eval_frequency)) step_log(step, "Starting training") for epoch, epoch_steps in epochs(train_steps, epoch_steps): # Log separator print() # Timer start_time = time.time() for _ in range(epoch_steps): # Train opt_state = update(step, opt_state, next(train_stream)) step += 1 # LR log if step == 1 or step % 10 == 0: train_sw.scalar("training/learning rate", lr_fun(step), step=step) # Timer epoch_time = time.time() - start_time step_log( step, "Ran %d train steps in %0.2f secs" % (epoch_steps, epoch_time)) if epoch_steps > 1: train_sw.scalar("training/steps per second", epoch_steps / epoch_time, step=step) # Evaluate params = jax_opt.get_params(opt_state) evaluate_train_and_eval(step=step, inputs=inputs, predict_fun=functools.partial( jit_predict, params), eval_steps=eval_steps, train_sw=train_sw, eval_sw=eval_sw, history=history) # Save state save_state(State(params=params, step=step, history=history), output_dir) # Save Gin config # Gin only tracks the used parameters, so we save it after the first epoch. if epoch == 1: save_gin(output_dir, train_sw) # Flush summary writers train_sw.writer.flush() eval_sw.writer.flush() step_log(step, "Training done") return State(params=params, step=step, history=history)
def training_loop( env, eval_env, env_name, policy_and_value_net_fn, policy_and_value_optimizer_fn, output_dir, epochs=EPOCHS, n_optimizer_steps=N_OPTIMIZER_STEPS, print_every_optimizer_steps=PRINT_EVERY_OPTIMIZER_STEP, target_kl=0.01, boundary=20, max_timestep=None, max_timestep_eval=20000, random_seed=None, gamma=GAMMA, lambda_=LAMBDA, epsilon=EPSILON, c1=1.0, c2=0.01, eval_every_n=1000, done_frac_for_policy_save=0.5, enable_early_stopping=True, n_evals=1, len_history_for_policy=4, eval_temperatures=(1.0, 0.5), ): """Runs the training loop for PPO, with fixed policy and value nets. Args: env: gym.Env to use for training. eval_env: gym.Env to use for evaluation. env_name: Name of the environment. policy_and_value_net_fn: Function defining the policy and value network. policy_and_value_optimizer_fn: Function defining the optimizer. output_dir: Output dir. epochs: Number of epochs to run for. n_optimizer_steps: Number of optimizer steps. print_every_optimizer_steps: How often to log during the policy optimization process. target_kl: Policy iteration early stopping. boundary: We pad trajectories at integer multiples of this number. max_timestep: If set to an integer, maximum number of time-steps in a trajectory. Used in the collect procedure. max_timestep_eval: If set to an integer, maximum number of time-steps in an evaluation trajectory. Used in the collect procedure. random_seed: Random seed. gamma: Reward discount factor. lambda_: N-step TD-error discount factor in GAE. epsilon: Random action probability in epsilon-greedy sampling. c1: Value loss coefficient. c2: Entropy loss coefficient. eval_every_n: How frequently to eval the policy. done_frac_for_policy_save: Fraction of the trajectories that should be done to checkpoint the policy. enable_early_stopping: Whether to enable early stopping. n_evals: Number of times to evaluate. len_history_for_policy: How much of history to give to the policy. eval_temperatures: Sequence of temperatures to try for categorical sampling during evaluation. """ gfile.makedirs(output_dir) # Create summary writers and history. train_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "train")) timing_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "timing")) eval_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "eval")) train_sw.text("env_name", env_name) timing_sw.text("env_name", env_name) eval_sw.text("env_name", env_name) jax_rng_key = trax.get_random_number_generator_and_set_seed(random_seed) # Batch Observations Shape = [1, 1] + OBS, because we will eventually call # policy and value networks on shape [B, T] +_OBS batch_observations_shape = (1, 1) + env.observation_space.shape observations_dtype = env.observation_space.dtype assert isinstance(env.action_space, gym.spaces.Discrete) n_actions = env.action_space.n jax_rng_key, key1 = jax_random.split(jax_rng_key, num=2) # Initialize the policy and value network. policy_and_value_net_params, policy_and_value_net_apply = ( policy_and_value_net_fn(key1, batch_observations_shape, observations_dtype, n_actions)) # Maybe restore the policy params. If there is nothing to restore, then # iteration = 0 and policy_and_value_net_params are returned as is. restore, policy_and_value_net_params, iteration = (maybe_restore_params( output_dir, policy_and_value_net_params)) if restore: logging.info("Restored parameters from iteration [%d]", iteration) # We should start from the next iteration. iteration += 1 policy_and_value_net_apply = jit(policy_and_value_net_apply) # Initialize the optimizers. 
policy_and_value_optimizer = ( policy_and_value_optimizer_fn(policy_and_value_net_params)) (policy_and_value_opt_state, policy_and_value_opt_update, policy_and_value_get_params) = policy_and_value_optimizer n_trajectories_done = 0 last_saved_at = 0 logging.info("Starting the PPO training loop.") for i in range(iteration, epochs): epoch_start_time = time.time() # Params we'll use to collect the trajectories. policy_and_value_net_params = policy_and_value_get_params( policy_and_value_opt_state) # A function to get the policy and value predictions. def get_predictions(observations, rng=None): """Returns log-probs, value predictions and key back.""" key, key1 = jax_random.split(rng, num=2) log_probs, value_preds = policy_and_value_net_apply( observations, policy_and_value_net_params, rng=key1) return log_probs, value_preds, key # Evaluate the policy. policy_eval_start_time = time.time() if ((i + 1) % eval_every_n == 0) or (i == epochs - 1): jax_rng_key, key = jax_random.split(jax_rng_key, num=2) logging.vlog(1, "Epoch [% 6d] evaluating policy.", i) reward_stats = evaluate_policy( eval_env, get_predictions, temperatures=eval_temperatures, max_timestep=max_timestep_eval, n_evals=n_evals, len_history_for_policy=len_history_for_policy, rng=key) write_eval_reward_summaries(reward_stats, eval_sw, epoch=i) policy_eval_time = get_time(policy_eval_start_time) trajectory_collection_start_time = time.time() logging.vlog(1, "Epoch [% 6d] collecting trajectories.", i) jax_rng_key, key = jax_random.split(jax_rng_key) trajs, n_done, timing_info = collect_trajectories( env, policy_fn=get_predictions, n_trajectories=env.batch_size, max_timestep=max_timestep, rng=key, len_history_for_policy=len_history_for_policy, reset=(i == 0) or restore, epsilon=(10.0 / (i + 10.0))) # this is a different epsilon. trajectory_collection_time = get_time(trajectory_collection_start_time) logging.vlog(1, "Collecting trajectories took %0.2f msec.", trajectory_collection_time) avg_reward = float(sum(np.sum(traj[2]) for traj in trajs)) / len(trajs) max_reward = max(np.sum(traj[2]) for traj in trajs) min_reward = min(np.sum(traj[2]) for traj in trajs) train_sw.scalar("train/reward_mean_truncated", avg_reward, step=i) logging.vlog(1, "Rewards avg=[%0.2f], max=[%0.2f], min=[%0.2f], all=%s", avg_reward, max_reward, min_reward, [float(np.sum(traj[2])) for traj in trajs]) logging.vlog( 1, "Trajectory Length average=[%0.2f], max=[%0.2f], min=[%0.2f]", float(sum(len(traj[0]) for traj in trajs)) / len(trajs), max(len(traj[0]) for traj in trajs), min(len(traj[0]) for traj in trajs)) logging.vlog(2, "Trajectory Lengths: %s", [len(traj[0]) for traj in trajs]) padding_start_time = time.time() (_, reward_mask, padded_observations, padded_actions, padded_rewards, padded_infos) = pad_trajectories(trajs, boundary=boundary) padding_time = get_time(padding_start_time) logging.vlog(1, "Padding trajectories took %0.2f msec.", get_time(padding_start_time)) logging.vlog(1, "Padded Observations' shape [%s]", str(padded_observations.shape)) logging.vlog(1, "Padded Actions' shape [%s]", str(padded_actions.shape)) logging.vlog(1, "Padded Rewards' shape [%s]", str(padded_rewards.shape)) # Some assertions. 
B, T = padded_actions.shape # pylint: disable=invalid-name assert (B, T) == padded_rewards.shape assert (B, T) == reward_mask.shape assert (B, T + 1) == padded_observations.shape[:2] assert (B, T + 1) + env.observation_space.shape == padded_observations.shape log_prob_recompute_start_time = time.time() assert ("log_prob_actions" in padded_infos and "value_predictions" in padded_infos) # These are the actual log-probabs and value predictions seen while picking # the actions. actual_log_probabs_traj = padded_infos["log_prob_actions"] actual_value_predictions_traj = padded_infos["value_predictions"] assert (B, T) == actual_log_probabs_traj.shape[:2] A = actual_log_probabs_traj.shape[2] # pylint: disable=invalid-name assert (B, T, 1) == actual_value_predictions_traj.shape # TODO(afrozm): log-probabs doesn't need to be (B, T+1, A) it can do with # (B, T, A), so make that change throughout. # NOTE: We don't have the log-probabs and value-predictions for the last # observation, so we re-calculate for everything, but use the original ones # for all but the last time-step. jax_rng_key, key = jax_random.split(jax_rng_key) log_probabs_traj, value_predictions_traj, _ = get_predictions( padded_observations, rng=key) assert (B, T + 1, A) == log_probabs_traj.shape assert (B, T + 1, 1) == value_predictions_traj.shape # Concatenate the last time-step's log-probabs and value predictions to the # actual log-probabs and value predictions and use those going forward. log_probabs_traj = np.concatenate( (actual_log_probabs_traj, log_probabs_traj[:, -1:, :]), axis=1) value_predictions_traj = np.concatenate( (actual_value_predictions_traj, value_predictions_traj[:, -1:, :]), axis=1) log_prob_recompute_time = get_time(log_prob_recompute_start_time) # Linear annealing from 0.1 to 0.0 # epsilon_schedule = epsilon if epochs == 1 else epsilon * (1.0 - # (i / # (epochs - 1))) # Constant epsilon. epsilon_schedule = epsilon # Compute value and ppo losses. jax_rng_key, key1 = jax_random.split(jax_rng_key, num=2) logging.vlog(2, "Starting to compute P&V loss.") loss_compute_start_time = time.time() cur_combined_loss, cur_ppo_loss, cur_value_loss, entropy_bonus = ( combined_loss(policy_and_value_net_params, log_probabs_traj, value_predictions_traj, policy_and_value_net_apply, padded_observations, padded_actions, padded_rewards, reward_mask, gamma=gamma, lambda_=lambda_, epsilon=epsilon_schedule, c1=c1, c2=c2, rng=key1)) loss_compute_time = get_time(loss_compute_start_time) logging.vlog( 1, "Calculating P&V loss [%10.2f(%10.2f, %10.2f, %10.2f)] took %0.2f msec.", cur_combined_loss, cur_value_loss, cur_ppo_loss, entropy_bonus, get_time(loss_compute_start_time)) jax_rng_key, key1 = jax_random.split(jax_rng_key, num=2) logging.vlog(1, "Policy and Value Optimization") optimization_start_time = time.time() keys = jax_random.split(key1, num=n_optimizer_steps) for j in range(n_optimizer_steps): k1, k2, k3 = jax_random.split(keys[j], num=3) t = time.time() # Update the optimizer state. policy_and_value_opt_state = policy_and_value_opt_step( j, policy_and_value_opt_state, policy_and_value_opt_update, policy_and_value_get_params, policy_and_value_net_apply, log_probabs_traj, value_predictions_traj, padded_observations, padded_actions, padded_rewards, reward_mask, c1=c1, c2=c2, gamma=gamma, lambda_=lambda_, epsilon=epsilon_schedule, rng=k1) # Compute the approx KL for early stopping. 
new_policy_and_value_net_params = policy_and_value_get_params( policy_and_value_opt_state) log_probab_actions_new, _ = policy_and_value_net_apply( padded_observations, new_policy_and_value_net_params, rng=k2) approx_kl = approximate_kl(log_probab_actions_new, log_probabs_traj, reward_mask) early_stopping = enable_early_stopping and approx_kl > 1.5 * target_kl if early_stopping: logging.vlog( 1, "Early stopping policy and value optimization at iter: %d, " "with approx_kl: %0.2f", j, approx_kl) # We don't return right-away, we want the below to execute on the last # iteration. t2 = time.time() if (((j + 1) % print_every_optimizer_steps == 0) or (j == n_optimizer_steps - 1) or early_stopping): # Compute and log the loss. (loss_combined, loss_ppo, loss_value, entropy_bonus) = (combined_loss( new_policy_and_value_net_params, log_probabs_traj, value_predictions_traj, policy_and_value_net_apply, padded_observations, padded_actions, padded_rewards, reward_mask, gamma=gamma, lambda_=lambda_, epsilon=epsilon_schedule, c1=c1, c2=c2, rng=k3)) logging.vlog( 1, "One Policy and Value grad desc took: %0.2f msec", get_time(t, t2)) logging.vlog( 1, "Combined Loss(value, ppo, entropy_bonus) [%10.2f] ->" " [%10.2f(%10.2f,%10.2f,%10.2f)]", cur_combined_loss, loss_combined, loss_value, loss_ppo, entropy_bonus) if early_stopping: break optimization_time = get_time(optimization_start_time) logging.vlog( 1, "Total Combined Loss reduction [%0.2f]%%", (100 * (cur_combined_loss - loss_combined) / np.abs(cur_combined_loss))) # Save parameters every time we see the end of at least a fraction of batch # number of trajectories that are done (not completed -- completed includes # truncated and done). # Also don't save too frequently, enforce a minimum gap. # Or if this is the last iteration. policy_save_start_time = time.time() n_trajectories_done += n_done # TODO(afrozm): Refactor to trax.save_state. if (( (n_trajectories_done >= done_frac_for_policy_save * env.batch_size) and (i - last_saved_at > eval_every_n) and (((i + 1) % eval_every_n == 0))) or (i == epochs - 1)): logging.vlog(1, "Epoch [% 6d] saving model.", i) old_model_files = gfile.glob( os.path.join(output_dir, "model-??????.pkl")) params_file = os.path.join(output_dir, "model-%06d.pkl" % i) with gfile.GFile(params_file, "wb") as f: pickle.dump(policy_and_value_net_params, f) # Remove the old model files. for path in old_model_files: gfile.remove(path) # Reset this number. n_trajectories_done = 0 last_saved_at = i policy_save_time = get_time(policy_save_start_time) epoch_time = get_time(epoch_start_time) logging.info( "Epoch [% 6d], Reward[min, max, avg] [%5.2f,%5.2f,%5.2f], Combined" " Loss(value, ppo, entropy) [%2.5f(%2.5f,%2.5f,%2.5f)]", i, min_reward, max_reward, avg_reward, loss_combined, loss_value, loss_ppo, entropy_bonus) timing_dict = { "epoch": epoch_time, "policy_eval": policy_eval_time, "trajectory_collection": trajectory_collection_time, "padding": padding_time, "log_prob_recompute": log_prob_recompute_time, "loss_compute": loss_compute_time, "optimization": optimization_time, "policy_save": policy_save_time, } timing_dict.update(timing_info) for k, v in timing_dict.items(): timing_sw.scalar("timing/%s" % k, v, step=i) max_key_len = max(len(k) for k in timing_dict) timing_info_list = [ "%s : % 10.2f" % (k.rjust(max_key_len + 1), v) for k, v in sorted(timing_dict.items()) ] logging.info("Epoch [% 6d], Timings: \n%s", i, "\n".join(timing_info_list)) # Reset restore. restore = False # Flush summary writers once in a while. 
if (i + 1) % 1000 == 0 or i == epochs - 1: train_sw.flush() timing_sw.flush() eval_sw.flush()
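# Hedged reference sketch of the quantities that combined_loss() above is
# built from: GAE advantages (gamma, lambda_), the clipped PPO surrogate
# (epsilon) and the combined objective (+ c1 * value loss - c2 * entropy
# bonus). This is a single-trajectory NumPy illustration of the standard PPO
# formulas, not the repo's batched, reward-masked implementation.
import numpy as np


def gae_advantages(rewards, values, gamma, lambda_):
  """rewards: [T]; values: [T + 1] (includes a bootstrap value for the last obs)."""
  deltas = rewards + gamma * values[1:] - values[:-1]
  advantages = np.zeros_like(deltas)
  gae = 0.0
  for t in reversed(range(len(deltas))):
    gae = deltas[t] + gamma * lambda_ * gae
    advantages[t] = gae
  return advantages


def ppo_combined_loss(new_log_probs, old_log_probs, advantages, values,
                      returns, entropy, epsilon, c1, c2):
  """All array arguments are per-time-step vectors of shape [T]; entropy is a scalar."""
  ratios = np.exp(new_log_probs - old_log_probs)
  clipped = np.clip(ratios, 1.0 - epsilon, 1.0 + epsilon)
  # Clipped surrogate: maximize the minimum, i.e. minimize its negation.
  ppo_loss = -np.mean(np.minimum(ratios * advantages, clipped * advantages))
  value_loss = np.mean((values - returns) ** 2)
  return ppo_loss + c1 * value_loss - c2 * entropy


# `returns` for the value loss are typically advantages + values[:-1] (the GAE
# returns), and `entropy` is the mean policy entropy over real time-steps.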
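# Hedged illustration of the early-stopping test in the optimizer loop above:
# a reward-masked approximate KL divergence between the old and new policies,
# compared against 1.5 * target_kl. masked_approximate_kl below is only an
# assumed NumPy stand-in for approximate_kl(); the repo's implementation may
# differ (e.g. in how it handles the extra bootstrap time-step).
import numpy as np


def masked_approximate_kl(log_probs_new, log_probs_old, mask):
  """log_probs_* have shape [B, T, A]; mask has shape [B, T] (1 = real step)."""
  # Per-step KL(old || new), summed over the action dimension.
  p_old = np.exp(log_probs_old)
  kl = np.sum(p_old * (log_probs_old - log_probs_new), axis=-1)
  # Average only over unpadded time-steps.
  return np.sum(kl * mask) / np.sum(mask)


# In the loop above the [B, T + 1, A] log-prob tensors would first be cut to
# the [B, T] reward mask, e.g.:
#   approx_kl = masked_approximate_kl(log_probab_actions_new[:, :-1, :],
#                                     log_probabs_traj[:, :-1, :], reward_mask)
#   early_stopping = enable_early_stopping and approx_kl > 1.5 * target_kl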
def __init__( self, train_env, eval_env, policy_and_value_model, policy_and_value_optimizer_fn, policy_and_value_two_towers, output_dir, n_optimizer_steps, print_every_optimizer_steps, target_kl, boundary, max_timestep, max_timestep_eval, random_seed, gamma, lambda_, c1, c2, eval_every_n, done_frac_for_policy_save, n_evals, len_history_for_policy, eval_temperatures, ): self._train_env = train_env self._eval_env = eval_env self._n_optimizer_steps = n_optimizer_steps self._print_every_optimizer_steps = print_every_optimizer_steps self._target_kl = target_kl self._boundary = boundary self._max_timestep = max_timestep self._max_timestep_eval = max_timestep_eval self._gamma = gamma self._lambda_ = lambda_ self._c1 = c1 self._c2 = c2 self._eval_every_n = eval_every_n self._done_frac_for_policy_save = done_frac_for_policy_save self._n_evals = n_evals self._len_history_for_policy = len_history_for_policy self._eval_temperatures = eval_temperatures assert isinstance(self._train_env.action_space, gym.spaces.Discrete) n_actions = self._train_env.action_space.n # Batch Observations Shape = [1, 1] + OBS, because we will eventually call # policy and value networks on shape [B, T] +_OBS batch_observations_shape = (1, 1) + self._train_env.observation_space.shape observations_dtype = self._train_env.observation_space.dtype self._rng = trax.get_random_number_generator_and_set_seed(random_seed) self._rng, key1 = jax_random.split(self._rng, num=2) # Initialize the policy and value network. policy_and_value_net_params, policy_and_value_net_apply = ( policy_and_value_net( rng_key=key1, batch_observations_shape=batch_observations_shape, observations_dtype=observations_dtype, n_actions=n_actions, bottom_layers_fn=policy_and_value_model, two_towers=policy_and_value_two_towers, ) ) self._policy_and_value_net_apply = jit(policy_and_value_net_apply) # Maybe restore the policy params. If there is nothing to restore, then # iteration = 0 and policy_and_value_net_params are returned as is. restored, policy_and_value_net_params, self._epoch = ( maybe_restore_params(output_dir, policy_and_value_net_params)) if restored: logging.info("Restored parameters from iteration [%d]", self._epoch) # We should start from the next iteration. self._epoch += 1 # Initialize the optimizers. policy_and_value_optimizer = ( policy_and_value_optimizer_fn(policy_and_value_net_params)) (self._policy_and_value_opt_state, self._policy_and_value_opt_update, self._policy_and_value_get_params) = policy_and_value_optimizer self._output_dir = output_dir gfile.makedirs(self._output_dir) # Create summary writers and history. self._train_sw = jaxboard.SummaryWriter( os.path.join(self._output_dir, "train")) self._timing_sw = jaxboard.SummaryWriter( os.path.join(self._output_dir, "timing")) self._eval_sw = jaxboard.SummaryWriter( os.path.join(self._output_dir, "eval")) self._should_reset = True self._n_trajectories_done = 0 self._last_saved_at = 0
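# Illustrative sketch (assumed, not taken from policy_and_value_net) of the
# shape contract the trainer above relies on: the network is initialized with
# observations of shape [1, 1] + OBS and is later applied to whole padded
# trajectories of shape [B, T] + OBS, returning log-probabilities of shape
# [B, T, n_actions] and value predictions of shape [B, T, 1].
# dummy_policy_and_value_apply is a made-up stand-in used only to show shapes.
import numpy as np


def dummy_policy_and_value_apply(observations, n_actions):
  batch, time = observations.shape[:2]
  log_probs = np.full((batch, time, n_actions), -np.log(n_actions))  # uniform policy
  value_preds = np.zeros((batch, time, 1))
  return log_probs, value_preds


observations = np.zeros((2, 5, 84, 84, 3))  # [B, T] + OBS for an image env
log_probs, value_preds = dummy_policy_and_value_apply(observations, n_actions=4)
assert log_probs.shape == (2, 5, 4)
assert value_preds.shape == (2, 5, 1)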
def __init__( self, train_env, eval_env, output_dir, policy_trainer_class, n_real_epochs=10, data_eval_frac=0.125, model_train_batch_size=64, n_model_initial_train_steps=1000, n_model_train_steps_per_epoch=1000, simulated_env_problem_class=( simulated_env_problem.SerializedSequenceSimulatedEnvProblem), simulated_batch_size=16, n_simulated_epochs=1000, trajectory_dump_dir=None, initial_trajectory_dir=None, initial_trajectory_mix_prob=0.5, initial_model=None, init_policy_from_world_model=False, **kwargs): super(SimPLe, self).__init__(train_env, eval_env, output_dir, **kwargs) self._policy_dir = os.path.join(output_dir, "policy") self._model_dir = os.path.join(output_dir, "model") # Initialize the policy trainer lazily, so in case of initializing the # policy from world model checkpoint, the trainer will try to load the # checkpoint _after_ it's been created in train_model(). self._policy_trainer_fn = functools.partial( policy_trainer_class, train_env=train_env, eval_env=eval_env, output_dir=self._policy_dir, async_mode=self._async_mode, init_policy_from_world_model_output_dir=( self._model_dir if init_policy_from_world_model else None), ) self._policy_trainer = None self._n_real_epochs = n_real_epochs self._model_train_batch_size = model_train_batch_size self._n_model_initial_train_steps = n_model_initial_train_steps self._n_model_train_steps_per_epoch = n_model_train_steps_per_epoch self._data_eval_frac = data_eval_frac gfile.makedirs(self._model_dir) if initial_model is not None: gfile.copy( initial_model, os.path.join(self._model_dir, "model.pkl"), overwrite=True, ) self._initial_model = initial_model self._initial_trajectories = None self._sim_env = simulated_env_problem_class( batch_size=None, observation_space=train_env.observation_space, action_space=train_env.action_space, reward_range=train_env.reward_range, discrete_rewards=train_env.discrete_rewards, history_stream=None, # TODO(pkozakowski): Support this. output_dir=self._model_dir, ) self._simulated_batch_size = simulated_batch_size self._n_simulated_epochs = n_simulated_epochs # If trajectory_dump_dir is not provided explicitly, save the trajectories # in output_dir. if trajectory_dump_dir is None: trajectory_dump_dir = os.path.join(output_dir, "trajectories") self._trajectory_dump_root_dir = trajectory_dump_dir self._initial_trajectory_dir = initial_trajectory_dir self._initial_trajectory_mix_prob = initial_trajectory_mix_prob self._summary_writer = jaxboard.SummaryWriter(self._output_dir) self._simple_epoch = 0 self._policy_epoch = 0 self._model_train_step = 0
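# Hedged sketch of how data_eval_frac above could be applied: when the world
# model is trained on collected trajectories, roughly that fraction of them is
# held out for model evaluation. split_trajectories is a hypothetical
# illustration of such a split; the actual split inside SimPLe's model
# training / input pipeline may work at a different granularity.
import random


def split_trajectories(trajectories, data_eval_frac, seed=0):
  """Returns (train_trajs, eval_trajs) with ~data_eval_frac held out for eval."""
  trajs = list(trajectories)
  random.Random(seed).shuffle(trajs)
  n_eval = max(1, int(len(trajs) * data_eval_frac))
  return trajs[n_eval:], trajs[:n_eval]


# Example: with 64 collected trajectories and data_eval_frac=0.125, 8 of them
# go to the evaluation set and the remaining 56 are used to train the model.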