def __init__(self,
             r_model,
             observation_history_size=20000,
             training_interval=20000,
             num_epochs=6,
             checkpoint_dir=None):
  # The training interval is assumed to be the same as the history size
  # for invalid negative values.
  if training_interval < 0:
    training_interval = observation_history_size
  self._r_model = r_model
  self._training_interval = training_interval
  self._batch_size = 64
  self._num_epochs = num_epochs

  # Keeps track of the last N observations.
  # Those are used to train the R network in an online way.
  self._fifo_observations = [None] * observation_history_size
  self._fifo_dones = [None] * observation_history_size
  self._fifo_index = 0
  self._fifo_count = 0

  # Used to save checkpoints.
  self._current_epoch = 0
  self._checkpointer = None
  if checkpoint_dir is not None:
    checkpoint_period_in_epochs = self._num_epochs
    self._checkpointer = keras_checkpoint.GFileModelCheckpoint(
        os.path.join(checkpoint_dir, 'r_network_weights.{epoch:05d}.h5'),
        save_summary=False,
        save_weights_only=True,
        period=checkpoint_period_in_epochs)
    self._checkpointer.set_model(self._r_model)
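# A minimal, self-contained sketch (an assumption, not the original trainer
# code; `_FifoSketch` and `add` are hypothetical names) of the circular-FIFO
# bookkeeping set up above: new observations overwrite the oldest slot once
# the buffer is full, and a training pass is due every `training_interval`
# additions.
class _FifoSketch(object):

  def __init__(self, history_size, training_interval):
    self._fifo_observations = [None] * history_size
    self._fifo_dones = [None] * history_size
    self._fifo_index = 0
    self._fifo_count = 0
    self._training_interval = training_interval

  def add(self, observation, done):
    """Stores one transition; returns True when a training pass is due."""
    self._fifo_observations[self._fifo_index] = observation
    self._fifo_dones[self._fifo_index] = done
    self._fifo_index = (self._fifo_index + 1) % len(self._fifo_observations)
    self._fifo_count += 1
    return self._fifo_count % self._training_interval == 0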
def test_full_model_checkpoint(self):
  path = os.path.join(FLAGS.test_tmpdir, 'r_network_full.{epoch:05d}.h5')
  self._fit_model_with_callback(
      keras_checkpoint.GFileModelCheckpoint(
          path, save_summary=False, save_weights_only=False, period=1))
  self.assertTrue(tf.gfile.Exists(path.format(epoch=1)))
  self.assertFalse(
      tf.gfile.Exists(path.format(epoch=1).replace('h5', 'summary.txt')))
def test_model_weights_checkpoint(self):
  path = os.path.join(FLAGS.test_tmpdir, 'r_network_weights.{epoch:05d}.h5')
  self._fit_model_with_callback(
      keras_checkpoint.GFileModelCheckpoint(
          path,
          save_summary=True,
          summary=constants.Level('explore_goal_locations_small').asdict(),
          save_weights_only=True,
          period=1))
  self.assertTrue(tf.gfile.Exists(path.format(epoch=1)))
  self.assertTrue(
      tf.gfile.Exists(path.format(epoch=1).replace('h5', 'summary.txt')))
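# Both tests above call a `_fit_model_with_callback` helper that is not shown
# in this excerpt. A plausible sketch (the tiny model and dummy data are
# assumptions; only the callback wiring matters): fit a throwaway Keras model
# for one epoch so the checkpoint callback fires once.
import numpy as np  # Assumed available alongside the module's tf import.

def _fit_model_with_callback(self, callback):
  model = tf.keras.models.Sequential(
      [tf.keras.layers.Dense(1, input_shape=(2,))])
  model.compile(optimizer='sgd', loss='mse')
  # period=1 in the callbacks above means one epoch triggers one checkpoint.
  model.fit(np.zeros((8, 2)), np.zeros((8, 1)),
            epochs=1, callbacks=[callback], verbose=0)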
def __init__(self,
             rlb_model_wrapper,
             ensure_train_between_episodes=True,
             checkpoint_dir=None):
  observation_history_size = rlb_model_wrapper.all_rlb_args.outer_args[
      'rlb_ot_history_size']
  training_interval = rlb_model_wrapper.all_rlb_args.outer_args[
      'rlb_ot_train_interval']
  num_epochs = rlb_model_wrapper.all_rlb_args.outer_args[
      'rlb_ot_num_epochs']
  batch_size = rlb_model_wrapper.all_rlb_args.outer_args[
      'rlb_ot_batch_size']
  # The training interval is assumed to be the same as the history size
  # for invalid negative values.
  if training_interval < 0:
    training_interval = observation_history_size
  self._rlb_model_wrapper = rlb_model_wrapper
  self.training_interval = training_interval
  self._ensure_train_between_episodes = ensure_train_between_episodes
  self._batch_size = batch_size
  self._num_epochs = num_epochs

  # Keeps track of the last N observations.
  # Those are used to train the R network in an online way.
  self._fifo_observations = [None] * observation_history_size
  self._fifo_actions = [None] * observation_history_size
  self._fifo_dones = [None] * observation_history_size
  self._fifo_index = 0
  self._fifo_count = 0

  # Used to save checkpoints.
  self._current_epoch = 0
  self._checkpointer = None
  if checkpoint_dir is not None:
    checkpoint_period_in_epochs = self._num_epochs
    self._checkpointer = keras_checkpoint.GFileModelCheckpoint(
        os.path.join(checkpoint_dir, 'r_network_weights.{epoch:05d}.h5'),
        save_summary=False,
        save_weights_only=True,
        period=checkpoint_period_in_epochs)
    self._checkpointer.set_model(self._rlb_model_wrapper)
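# A sketch of a hypothetical helper (`_ordered_history` is not part of the
# original class) showing how the three parallel circular buffers above could
# be flattened back into chronological order before assembling training
# batches:
def _ordered_history(self):
  size = len(self._fifo_observations)
  n = min(self._fifo_count, size)
  # The oldest valid element sits at `self._fifo_index - n` (mod size).
  indices = [(self._fifo_index - n + i) % size for i in range(n)]
  observations = [self._fifo_observations[i] for i in indices]
  actions = [self._fifo_actions[i] for i in indices]
  dones = [self._fifo_dones[i] for i in indices]
  return observations, actions, dones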
def train(self):
  """Launch training."""
  logging.info('Started training!')
  model, initial_epoch = self.create_model()
  print('Model we train:')
  model.summary()

  checkpoint_summary = copy.copy(self.level.asdict())
  checkpoint_summary.update(
      action_set=FLAGS.action_set,
      episode_length=FLAGS.training_episode_length,
      max_input_env_steps=FLAGS.max_input_env_steps)
  # Keep saving the weights only (even though not technically needed since we
  # save the full model below), so that we stay compatible with downstream
  # code.
  weights_checkpoint_cb = keras_checkpoint.GFileModelCheckpoint(
      os.path.join(self.workdir, 'r_network_weights.{epoch:05d}.h5'),
      save_summary=True,
      summary=checkpoint_summary,
      save_weights_only=True,
      period=Const.STORE_CHECKPOINT_EVERY_N_EPOCHS)
  model_checkpoint_cb = keras_checkpoint.GFileModelCheckpoint(
      os.path.join(
          self.workdir,
          '%s.{epoch:05d}.h5' % self.saved_model_basename_prefix),
      save_summary=False,
      save_weights_only=False,
      # Keras does not store the dataset/iterator state, so training will
      # start from the beginning of the dataset. To make sure this is not a
      # problem (e.g. too many epochs training on the beginning of the
      # dataset because of restarts, leading to overfitting), we ensure that
      # we've iterated at least once through the dataset in the worst case
      # before dumping the full model checkpoint.
      # The dataset contains at most 800k examples, and we consume 6400 of
      # those per epoch, so one iteration over the dataset is at most 125
      # epochs.
      period=125)
  callbacks = [weights_checkpoint_cb, model_checkpoint_cb]
  if self.xm_series:
    callbacks.append(ExportStatsToXm(self.xm_series))

  assert (bool(FLAGS.input_r_training_dir) +
          bool(FLAGS.training_data_glob) == 1), (
              'Exactly one of --input_r_training_dir and '
              '--training_data_glob should be set.')
  if FLAGS.training_data_glob:
    filenames = tf.gfile.Glob(FLAGS.training_data_glob)
    filenames.sort()
    logging.info('Files with glob %s: %s', FLAGS.training_data_glob,
                 filenames)
    training_start_index = int(
        FLAGS.percent_validation_files * len(filenames) / 100.)
    training_filenames = filenames[training_start_index:]
    validation_filenames = filenames[:training_start_index]
    assert validation_filenames, (
        'No validation filename. Increase --percent_validation_files.')
  elif FLAGS.input_r_training_dir:
    training_glob = build_r_training_data_glob(
        self.level, Const.MIXER_SEEDS[constants.SplitType.R_TRAINING],
        FLAGS.training_episode_length, FLAGS.action_set, FLAGS.noise_type,
        FLAGS.tv_num_images, FLAGS.max_action_distance)
    logging.info('Looking for training files with glob: %s', training_glob)
    training_filenames = tf.gfile.Glob(training_glob)
    validation_glob = build_r_training_data_glob(
        self.level, Const.MIXER_SEEDS[constants.SplitType.VALIDATION],
        FLAGS.validation_episode_length, FLAGS.action_set, FLAGS.noise_type,
        FLAGS.tv_num_images, FLAGS.max_action_distance)
    logging.info('Looking for validation files with glob: %s',
                 validation_glob)
    validation_filenames = tf.gfile.Glob(validation_glob)

  # pylint: disable=g-long-lambda
  training_filter_fn = lambda example, label: filter_examples_fn(
      example, label, FLAGS.max_input_env_steps,
      FLAGS.training_episode_length)
  training_dataset = self.create_dataset(training_filenames,
                                         training_filter_fn)
  validation_filter_fn = lambda example, label: filter_examples_fn(
      example, label, FLAGS.max_input_env_steps,
      FLAGS.validation_episode_length)
  # pylint: enable=g-long-lambda
  validation_dataset = self.create_dataset(validation_filenames,
                                           validation_filter_fn)

  model.fit(training_dataset.make_one_shot_iterator(),
            steps_per_epoch=Const.DUMP_AFTER_BATCHES,
            epochs=Const.EDGE_MAX_EPOCHS,
            validation_data=validation_dataset.make_one_shot_iterator(),
            validation_steps=100,
            callbacks=callbacks,
            initial_epoch=initial_epoch)
  logging.info('Done training!')
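# The full-model checkpoint period above follows from a small worst-case
# computation; spelled out with the numbers quoted in the comment (the
# constant names here are illustrative, not from the original module):
MAX_DATASET_EXAMPLES = 800000  # Upper bound on dataset size.
EXAMPLES_PER_EPOCH = 6400      # Examples consumed per training epoch.
# Ceiling division: epochs needed to see the whole dataset at least once.
FULL_MODEL_CHECKPOINT_PERIOD = -(-MAX_DATASET_EXAMPLES // EXAMPLES_PER_EPOCH)
assert FULL_MODEL_CHECKPOINT_PERIOD == 125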