Example 1
    def __init__(self,
                 r_model,
                 observation_history_size=20000,
                 training_interval=20000,
                 num_epochs=6,
                 checkpoint_dir=None):
        # Negative training intervals are invalid; in that case the training
        # interval defaults to the observation history size.
        if training_interval < 0:
            training_interval = observation_history_size

        self._r_model = r_model
        self._training_interval = training_interval
        self._batch_size = 64
        self._num_epochs = num_epochs

        # Keeps track of the last N observations.
        # Those are used to train the R network in an online way.
        self._fifo_observations = [None] * observation_history_size
        self._fifo_dones = [None] * observation_history_size
        self._fifo_index = 0
        self._fifo_count = 0

        # Used to save checkpoints.
        self._current_epoch = 0
        self._checkpointer = None
        if checkpoint_dir is not None:
            checkpoint_period_in_epochs = self._num_epochs
            self._checkpointer = keras_checkpoint.GFileModelCheckpoint(
                os.path.join(checkpoint_dir,
                             'r_network_weights.{epoch:05d}.h5'),
                save_summary=False,
                save_weights_only=True,
                period=checkpoint_period_in_epochs)
            self._checkpointer.set_model(self._r_model)
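The constructor above only allocates the observation ring buffer. For context, a minimal sketch of how observations might be pushed into that buffer and how training could be triggered every training_interval steps follows; the method name on_new_observation and the _train_on_history entry point are hypothetical illustrations, not part of the original class.

    def on_new_observation(self, observation, done):
        """Hypothetical sketch: store one observation, train periodically."""
        history_size = len(self._fifo_observations)
        # Write into the circular buffer and advance the write index.
        self._fifo_observations[self._fifo_index] = observation
        self._fifo_dones[self._fifo_index] = done
        self._fifo_index = (self._fifo_index + 1) % history_size
        self._fifo_count += 1
        # Once enough new observations have accumulated, retrain the R network
        # on the buffered history (the entry point below is hypothetical).
        if self._fifo_count % self._training_interval == 0:
            self._train_on_history()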
Example 2
    def test_full_model_checkpoint(self):
        path = os.path.join(FLAGS.test_tmpdir, 'r_network_full.{epoch:05d}.h5')
        self._fit_model_with_callback(
            keras_checkpoint.GFileModelCheckpoint(path,
                                                  save_summary=False,
                                                  save_weights_only=False,
                                                  period=1))
        self.assertTrue(tf.gfile.Exists(path.format(epoch=1)))
        self.assertFalse(
            tf.gfile.Exists(path.format(epoch=1).replace('h5', 'summary.txt')))
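Both tests call a _fit_model_with_callback helper that is not shown here. A plausible stand-in, assuming tensorflow is imported as tf and numpy as np, fits a trivial Keras model for a single epoch so that the checkpoint callback fires exactly once:

    def _fit_model_with_callback(self, callback):
        """Hypothetical helper: fit a tiny model so the callback fires once."""
        model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
        model.compile(optimizer='sgd', loss='mse')
        x = np.zeros((8, 4), dtype=np.float32)
        y = np.zeros((8, 1), dtype=np.float32)
        model.fit(x, y, epochs=1, callbacks=[callback], verbose=0)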
Example 3
    def test_model_weights_checkpoint(self):
        path = os.path.join(FLAGS.test_tmpdir,
                            'r_network_weights.{epoch:05d}.h5')
        self._fit_model_with_callback(
            keras_checkpoint.GFileModelCheckpoint(
                path,
                save_summary=True,
                summary=constants.Level(
                    'explore_goal_locations_small').asdict(),
                save_weights_only=True,
                period=1))
        self.assertTrue(tf.gfile.Exists(path.format(epoch=1)))
        self.assertTrue(
            tf.gfile.Exists(path.format(epoch=1).replace('h5', 'summary.txt')))
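With save_summary=True, the callback also writes the summary dict to a sidecar file next to the weights, with the .h5 suffix swapped for summary.txt, which is what the assertions above check. A short sketch of reading that sidecar back, reusing the path variable from the test:

    summary_path = path.format(epoch=1).replace('h5', 'summary.txt')
    with tf.gfile.GFile(summary_path) as f:
        print(f.read())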
Example 4
    def __init__(self,
                 rlb_model_wrapper,
                 ensure_train_between_episodes=True,
                 checkpoint_dir=None):

        observation_history_size = rlb_model_wrapper.all_rlb_args.outer_args[
            'rlb_ot_history_size']
        training_interval = rlb_model_wrapper.all_rlb_args.outer_args[
            'rlb_ot_train_interval']
        num_epochs = rlb_model_wrapper.all_rlb_args.outer_args[
            'rlb_ot_num_epochs']
        batch_size = rlb_model_wrapper.all_rlb_args.outer_args[
            'rlb_ot_batch_size']

        # Negative training intervals are invalid; in that case the training
        # interval defaults to the observation history size.
        if training_interval < 0:
            training_interval = observation_history_size

        self._rlb_model_wrapper = rlb_model_wrapper
        self.training_interval = training_interval
        self._ensure_train_between_episodes = ensure_train_between_episodes
        self._batch_size = batch_size
        self._num_epochs = num_epochs

        # Keeps track of the last N observations.
        # Those are used to train the R network in an online way.
        self._fifo_observations = [None] * observation_history_size
        self._fifo_actions = [None] * observation_history_size
        self._fifo_dones = [None] * observation_history_size
        self._fifo_index = 0
        self._fifo_count = 0

        # Used to save checkpoints.
        self._current_epoch = 0
        self._checkpointer = None
        if checkpoint_dir is not None:
            checkpoint_period_in_epochs = self._num_epochs
            self._checkpointer = keras_checkpoint.GFileModelCheckpoint(
                os.path.join(checkpoint_dir,
                             'r_network_weights.{epoch:05d}.h5'),
                save_summary=False,
                save_weights_only=True,
                period=checkpoint_period_in_epochs)
            self._checkpointer.set_model(self._rlb_model_wrapper)
Example 5
    def train(self):
        """Launch training."""
        logging.info('Started training!')
        model, initial_epoch = self.create_model()
        print('Model we train:')
        model.summary()
        checkpoint_summary = copy.copy(self.level.asdict())
        checkpoint_summary.update(
            action_set=FLAGS.action_set,
            episode_length=FLAGS.training_episode_length,
            max_input_env_steps=FLAGS.max_input_env_steps)
        # Keep saving the weights only (even though it is not technically needed
        # since we save the full model below), so that we stay compatible with
        # downstream code.
        weights_checkpoint_cb = keras_checkpoint.GFileModelCheckpoint(
            os.path.join(self.workdir, 'r_network_weights.{epoch:05d}.h5'),
            save_summary=True,
            summary=checkpoint_summary,
            save_weights_only=True,
            period=Const.STORE_CHECKPOINT_EVERY_N_EPOCHS)
        model_checkpoint_cb = keras_checkpoint.GFileModelCheckpoint(
            os.path.join(
                self.workdir,
                '%s.{epoch:05d}.h5' % self.saved_model_basename_prefix),
            save_summary=False,
            save_weights_only=False,
            # Keras does not store the dataset/iterator state, so training will
            # start from the beginning of the dataset. To make sure this is not a
            # problem (e.g. too many epochs training on the beginning of the dataset
            # because of restarts, leading to overfitting), we ensure that we've
            # iterated at least once through the dataset in the worst case before
            # dumping the full model checkpoint.
            # The dataset contains at most 800k examples, and we consume 6400 of
            # those per epoch, so one iteration over the dataset is at most 123
            # epochs.
            period=123)
        callbacks = [weights_checkpoint_cb, model_checkpoint_cb]
        if self.xm_series:
            callbacks.append(ExportStatsToXm(self.xm_series))

        assert (
            bool(FLAGS.input_r_training_dir) +
            bool(FLAGS.training_data_glob) == 1), (
                'Exactly one of --input_r_training_dir, training_data_glob, '
                'input_from_experiment should be set.')

        if FLAGS.training_data_glob:
            filenames = tf.gfile.Glob(FLAGS.training_data_glob)
            filenames.sort()
            logging.info('Files with glob %s: %s', FLAGS.training_data_glob,
                         filenames)
            training_start_index = int(FLAGS.percent_validation_files *
                                       len(filenames) / 100.)
            training_filenames = filenames[training_start_index:]
            validation_filenames = filenames[:training_start_index]
            assert validation_filenames, (
                'No validation filename. Increase --percent_validation_files.')
        elif FLAGS.input_r_training_dir:
            training_glob = build_r_training_data_glob(
                self.level, Const.MIXER_SEEDS[constants.SplitType.R_TRAINING],
                FLAGS.training_episode_length, FLAGS.action_set,
                FLAGS.noise_type, FLAGS.tv_num_images,
                FLAGS.max_action_distance)
            logging.info('Looking for training files with glob: %s',
                         training_glob)
            training_filenames = tf.gfile.Glob(training_glob)
            validation_glob = build_r_training_data_glob(
                self.level, Const.MIXER_SEEDS[constants.SplitType.VALIDATION],
                FLAGS.validation_episode_length, FLAGS.action_set,
                FLAGS.noise_type, FLAGS.tv_num_images,
                FLAGS.max_action_distance)
            logging.info('Looking for validation files with glob: %s',
                         validation_glob)
            validation_filenames = tf.gfile.Glob(validation_glob)

        # pylint: disable=g-long-lambda
        training_filter_fn = lambda example, label: filter_examples_fn(
            example, label, FLAGS.max_input_env_steps,
            FLAGS.training_episode_length)

        training_dataset = self.create_dataset(training_filenames,
                                               training_filter_fn)

        validation_filter_fn = lambda example, label: filter_examples_fn(
            example, label, FLAGS.max_input_env_steps,
            FLAGS.validation_episode_length)
        # pylint: enable=g-long-lambda

        validation_dataset = self.create_dataset(validation_filenames,
                                                 validation_filter_fn)

        model.fit(training_dataset.make_one_shot_iterator(),
                  steps_per_epoch=Const.DUMP_AFTER_BATCHES,
                  epochs=Const.EDGE_MAX_EPOCHS,
                  validation_data=validation_dataset.make_one_shot_iterator(),
                  validation_steps=100,
                  callbacks=callbacks,
                  initial_epoch=initial_epoch)
        logging.info('Done training!')
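create_dataset and filter_examples_fn are elided in this example. Purely as an illustration of the pipeline shape the call sites imply, a hypothetical create_dataset could look roughly like the sketch below, assuming the input files are TFRecords and that parse_example_fn (also hypothetical) decodes one record into an (example, label) pair:

    def create_dataset(self, filenames, filter_fn):
        """Hypothetical sketch of the elided dataset pipeline."""
        dataset = tf.data.TFRecordDataset(filenames)
        dataset = dataset.map(parse_example_fn)   # hypothetical record parser
        dataset = dataset.filter(filter_fn)       # drop out-of-range examples
        dataset = dataset.shuffle(buffer_size=10000)
        dataset = dataset.batch(64)               # batch size is an assumption
        dataset = dataset.repeat()                # model.fit controls epoch count
        return dataset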