def init_policy_from_world_model_checkpoint(policy_params, model_output_dir):
  """Seeds policy parameters with layers taken from a world-model checkpoint.

  Reads `model.pkl` from `model_output_dir`, takes element 0 of the pickled
  tuple's first entry as the world model's parameters, and stores the slice
  that drops the first entry and the last two entries into `policy_params[0]`.

  Args:
    policy_params: Indexable, mutable container of policy parameters. Its
      element 0 is overwritten in place.
    model_output_dir: Directory holding the world model's `model.pkl`.

  Returns:
    The same `policy_params` object, after the in-place assignment.
  """
  pickler = utils.get_pickle_module()
  checkpoint_path = os.path.join(model_output_dir, "model.pkl")
  # Unpickle by hand: going through trax.restore_state would create a
  # circular import.
  with gfile.GFile(checkpoint_path, "rb") as checkpoint:
    saved_opt_state = pickler.load(checkpoint)[0]
  world_model_params = saved_opt_state[0]

  # TODO(pkozakowski): Brittle and hardcoded. This only works when
  # transplanting TransformerLM parameters into a TransformerDecoder-based
  # policy network of the same configuration. Find a more general method.
  transplanted = world_model_params[1:-2]
  policy_params[0] = transplanted
  return policy_params
def init_policy_from_world_model_checkpoint(policy_weights, model_output_dir,
                                            substitute_fn):
  """Builds policy weights by transplanting layers from a world model.

  Loads `model.pkl` from `model_output_dir` and reads its `'weights'` entry.
  It unwraps those weights via `serialization_utils.extract_inner_model`,
  drops the first entry and the last two entries, and passes the remainder
  to `substitute_fn`.

  Args:
    policy_weights: Current policy weights; passed as the first argument to
      `substitute_fn`.
    model_output_dir: Directory containing the world model's `model.pkl`.
    substitute_fn: Callable `(policy_weights, donor_weights)` whose result is
      returned unchanged.

  Returns:
    Whatever `substitute_fn` returns.
  """
  pickler = utils.get_pickle_module()
  checkpoint_path = os.path.join(model_output_dir, 'model.pkl')
  # Unpickle by hand: trax.load_trainer_state would cause a circular import.
  with tf.io.gfile.GFile(checkpoint_path, 'rb') as checkpoint:
    checkpoint_contents = pickler.load(checkpoint)
  inner_weights = serialization_utils.extract_inner_model(
      checkpoint_contents['weights'])

  # TODO(pkozakowski): Brittle and hardcoded. This only works when
  # transplanting TransformerLM parameters into a TransformerDecoder-based
  # policy network of the same configuration. Find a more general method.
  donor_weights = inner_weights[1:-2]
  return substitute_fn(policy_weights, donor_weights)
def save_trainer_state(state, output_dir, keep=False):
  """Pickles a TrainerState into `output_dir`.

  The checkpoint always goes to `model.pkl`. With `keep=True` an identical
  copy is additionally written to `model_<step>.pkl`. The path of the last
  file written is logged.

  Args:
    state: TrainerState-like object. Its `opt_state` (converted to a plain
      tuple), `step`, `history` and `model_state` are saved as a 4-tuple.
    output_dir: Destination directory.
    keep: Whether to also write the step-numbered copy.
  """
  pickler = utils.get_pickle_module()
  destinations = [os.path.join(output_dir, 'model.pkl')]
  if keep:
    destinations.append(
        os.path.join(output_dir, 'model_{}.pkl'.format(state.step)))

  for destination in destinations:
    with gfile.GFile(destination, 'wb') as out:
      pickler.dump(
          (tuple(state.opt_state), state.step, state.history,
           state.model_state),
          out)

  log('Model saved to %s' % destinations[-1], stdout=False)
def save_opt_state(output_dir,
                   policy_and_value_opt_state,
                   policy_and_value_state,
                   epoch,
                   total_opt_step):
  """Checkpoints the policy/value optimization state for `epoch`.

  Writes `(policy_and_value_opt_state, policy_and_value_state,
  total_opt_step)` to `model-<epoch, zero-padded to 6 digits>.pkl` in
  `output_dir`. It then deletes every file past the first
  `LAST_N_POLICY_MODELS_TO_KEEP` entries of the listing taken before the
  write, never deleting the file just written.

  Args:
    output_dir: Checkpoint directory.
    policy_and_value_opt_state: Optimization state to save.
    policy_and_value_state: Network state to save.
    epoch: Epoch number used in the checkpoint file name.
    total_opt_step: Total optimization step count to save.
  """
  pickler = utils.get_pickle_module()
  # Snapshot the existing checkpoints before writing the new one.
  previous_files = get_policy_model_files(output_dir)
  checkpoint_path = os.path.join(output_dir, "model-%06d.pkl" % epoch)
  with gfile.GFile(checkpoint_path, "wb") as out:
    pickler.dump(
        (policy_and_value_opt_state, policy_and_value_state, total_opt_step),
        out)

  # Keep the last k model files lying around. k > 1 because the latest model
  # file might be in the process of getting read asynchronously. This
  # presumably relies on get_policy_model_files listing newest first.
  stale_files = (
      candidate
      for candidate in previous_files[LAST_N_POLICY_MODELS_TO_KEEP:]
      if candidate != checkpoint_path)
  for stale in stale_files:
    gfile.remove(stale)
def load_trainer_state(output_dir):
  """Restores a TrainerState from `output_dir`/model.pkl.

  Args:
    output_dir: Directory that may contain a `model.pkl` checkpoint.

  Returns:
    A TrainerState built from the pickled
    `(opt_state, step, history, model_state)` tuple, with `opt_state`
    rewrapped as an `OptState`. If no checkpoint exists, returns a
    TrainerState with `None` step/opt_state/model_state and a fresh
    `trax_history.History()`.
  """
  checkpoint_path = os.path.join(output_dir, 'model.pkl')
  if gfile.exists(checkpoint_path):
    pickler = utils.get_pickle_module()
    with gfile.GFile(checkpoint_path, 'rb') as src:
      saved_opt_state, saved_step, saved_history, saved_model_state = (
          pickler.load(src))
    log('Model loaded from %s at step %d' % (checkpoint_path, saved_step))
    logging.debug('From loaded model : history = %s', saved_history)
    return TrainerState(
        step=saved_step,
        opt_state=OptState(*saved_opt_state),
        history=saved_history,
        model_state=saved_model_state,
    )

  # No checkpoint yet: start from an empty state.
  return TrainerState(
      step=None,
      opt_state=None,
      history=trax_history.History(),
      model_state=None,
  )
def dump_trajectories(self, force=False):
  """Writes buffered, finished trajectories to this epoch's shard file.

  Should be called at most once per epoch. Does nothing when
  `self.trajectory_dump_dir` is None. Completed trajectories that contain at
  least one action are moved into `self._trajectory_buffer`. The buffer is
  written to `<trajectory_dump_dir>/<epoch>.pkl` once it holds at least
  `self._trajectory_dump_min_count_per_shard` items, or whenever `force` is
  set.

  Args:
    force: (bool) Whether to complete unfinished trajectories and write the
      shard even if the buffer has not reached the minimum size.
  """
  pickler = utils.get_pickle_module()
  if self.trajectory_dump_dir is None:
    return
  gfile.makedirs(self.trajectory_dump_dir)

  env_trajectories = self.train_env.trajectories
  if force:
    env_trajectories.complete_all_trajectories()

  # complete_all_trajectories() also adds trajectories that were just reset.
  # Those hold only the initial observation and no action, so skip them.
  with_actions = [
      candidate for candidate in env_trajectories.completed_trajectories
      if candidate.time_steps and candidate.time_steps[0].action is not None
  ]
  self._trajectory_buffer.extend(with_actions)
  env_trajectories.clear_completed_trajectories()

  enough_for_shard = (
      len(self._trajectory_buffer) >=
      self._trajectory_dump_min_count_per_shard)
  if not (enough_for_shard or force):
    return

  shard_path = os.path.join(self.trajectory_dump_dir,
                            "{}.pkl".format(self.epoch))
  if gfile.exists(shard_path):
    # The training loop does an extra dump at its end, so one epoch can be
    # dumped twice. Prepend whatever the earlier dump wrote.
    with gfile.GFile(shard_path, "rb") as src:
      earlier_dump = pickler.load(src)
    self._trajectory_buffer = earlier_dump + self._trajectory_buffer

  with gfile.GFile(shard_path, "wb") as dst:
    pickler.dump(self._trajectory_buffer, dst)
  self._trajectory_buffer = []
def maybe_restore_opt_state(output_dir, policy_and_value_opt_state=None,
                            policy_and_value_state=None):
  """Maybe restore the optimization state from the checkpoint dir.

  Optimization state includes parameters and optimizer slots. Candidate files
  come from `get_policy_model_files(output_dir)` and are tried in that order.
  The first one that unpickles without an `EOFError` is used. A file raising
  `EOFError` (e.g. truncated by an interrupted write) is logged and skipped.

  Args:
    output_dir: Directory where saved model checkpoints are stored.
    policy_and_value_opt_state: Default optimization state, returned if no
      model can be restored.
    policy_and_value_state: Default state of the policy and value network,
      returned if no model can be restored.

  Returns:
    tuple (opt_state, state, epoch (int), opt_step (int), history):
      epoch: The epoch from which the optimization state was restored, or 0
        if no checkpoint could be restored.
      opt_step: The total optimization step, i.e. the sum of all optimization
        steps made up to that epoch (0 if nothing was restored).
      history: The restored training history, or a fresh
        `trax_history.History()` if nothing was restored.
  """
  pkl_module = utils.get_pickle_module()
  epoch = 0
  total_opt_step = 0
  history = trax_history.History()
  for model_file in get_policy_model_files(output_dir):
    logging.info('Trying to restore model from %s', model_file)
    try:
      with tf.io.gfile.GFile(model_file, 'rb') as f:
        (policy_and_value_opt_state, policy_and_value_state, total_opt_step,
         history) = pkl_module.load(f)
    except EOFError as e:
      # Unreadable checkpoint: fall back to the next candidate (presumably an
      # older one, depending on get_policy_model_files' ordering).
      logging.error('Unable to load model from: %s with %s', model_file, e)
      continue
    # Only the unpickling above is expected to raise EOFError; the success
    # path stays outside the try.
    epoch = get_epoch_from_policy_model_file(model_file)
    break

  return (
      policy_and_value_opt_state,
      policy_and_value_state,
      epoch,
      total_opt_step,
      history,
  )
def save_state(self, keep):
  """Pickles the trainer's state, de-replicating `opt_state` if needed.

  When `self.n_devices > 1`, only the first replica of every leaf of the
  optimizer state is kept. Writes `(opt_state, step, history, model_state)`
  to `model.pkl` in `self._output_dir`. If `keep` is true, it also writes an
  identical `model_<step>.pkl`. The path of the last file written is logged.

  Args:
    keep: Whether to also write the step-numbered copy.
  """
  opt_state = self._opt_state
  if self.n_devices > 1:
    def take_first_replica(replicated):
      return replicated[0]
    opt_state = OptState(*backend.nested_map(take_first_replica, opt_state))

  # Optional, but it lets JAX transfer arrays from device to host in
  # parallel, which is particularly important for cloud TPU.
  if backend.get_name() == 'jax':
    opt_state = jax.device_get(opt_state)

  step = self._step
  history = self._history
  model_state = self._model_state
  pickler = utils.get_pickle_module()

  destinations = [os.path.join(self._output_dir, 'model.pkl')]
  if keep:
    destinations.append(
        os.path.join(self._output_dir, 'model_{}.pkl'.format(step)))

  for destination in destinations:
    with tf.io.gfile.GFile(destination, 'wb') as out:
      pickler.dump((tuple(opt_state), step, history, model_state), out)

  log('Model saved to %s' % destinations[-1], stdout=False)
def load_trajectories(trajectory_dir, eval_frac):
  """Loads trajectories from a possibly nested directory of pickles.

  Every file under `trajectory_dir` (searched recursively) is unpickled as a
  sequence of trajectories. Each shard is split independently: the first
  `int(len(shard) * (1 - eval_frac))` items go to training and the rest to
  evaluation. Shards that raise `EOFError` while unpickling are logged and
  skipped.

  NOTE: unpickling can execute arbitrary code, so only point this at
  directories whose contents are trusted.

  Args:
    trajectory_dir: Root directory to search for trajectory shards.
    eval_frac: Fraction of each shard reserved for evaluation.

  Returns:
    (train_trajectories, eval_trajectories) lists.

  Raises:
    AssertionError: If either split ends up empty.
  """
  pkl_module = utils.get_pickle_module()
  train_trajectories = []
  eval_trajectories = []
  # Search the entire directory subtree for trajectories.
  for (subdir, _, filenames) in tf.io.gfile.walk(trajectory_dir):
    for filename in filenames:
      shard_path = os.path.join(subdir, filename)
      # Only the unpickling can hit a truncated shard; keep the try that
      # narrow so the split logic below is never silently skipped.
      try:
        with tf.io.gfile.GFile(shard_path, 'rb') as f:
          trajectories = pkl_module.load(f)
      except EOFError:
        logging.warning(
            'Could not load trajectories from a corrupted shard %s.',
            shard_path,
        )
        continue
      pivot = int(len(trajectories) * (1 - eval_frac))
      train_trajectories.extend(trajectories[:pivot])
      eval_trajectories.extend(trajectories[pivot:])
  assert train_trajectories, "Can't find training data in %s" % trajectory_dir
  assert eval_trajectories, "Can't find evaluation data in %s" % trajectory_dir
  return train_trajectories, eval_trajectories
def _dump_trajectory_pickle(self, observations, path):
  """Pickles a list of singleton trajectories to `path`.

  Each element of `observations` is converted via
  `self._make_singleton_trajectory`, and the resulting list is written as a
  single pickle.

  Args:
    observations: Iterable of observations to wrap.
    path: Destination file path.
  """
  pickler = utils.get_pickle_module()
  make_singleton = self._make_singleton_trajectory
  singletons = [make_singleton(observation) for observation in observations]
  with gfile.GFile(path, 'wb') as out:
    pickler.dump(singletons, out)