Example #1
    def initialFitting(test_obj, model, train_ds, num_epoch, steps,
                       saving_filepath):
        # The saving_filepath shouldn't exist at the beginning.
        test_obj.assertFalse(training_state.checkpoint_exists(saving_filepath))

        model.fit(x=train_ds,
                  epochs=num_epoch,
                  steps_per_epoch=steps,
                  callbacks=[
                      callbacks.ModelCheckpoint(filepath=saving_filepath,
                                                save_weights_only=True)
                  ])

        # The saving_filepath should exist after fitting with the callback. Both
        # the chief and non-chief workers should see that it exists (even though
        # it was saved only by the chief).
        test_obj.assertTrue(training_state.checkpoint_exists(saving_filepath))

        history_after_one_more_epoch = model.fit(x=train_ds,
                                                 epochs=1,
                                                 steps_per_epoch=steps)

        # The saving_filepath should continue to exist (if it did before) after
        # fitting without the callback.
        test_obj.assertTrue(training_state.checkpoint_exists(saving_filepath))

        return saving_filepath, history_after_one_more_epoch
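
Example #1 drives `callbacks.ModelCheckpoint` with `save_weights_only=True` through an internal multi-worker test helper. For reference, here is a minimal, self-contained sketch of the same checkpointing pattern using only the public Keras API; the toy model, data, and `/tmp` path are illustrative choices, not part of the test above.

    import numpy as np
    import tensorflow as tf

    # Toy model and data; only meant to illustrate the ModelCheckpoint wiring.
    model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
    model.compile(optimizer='sgd', loss='mse')
    x, y = np.random.rand(32, 4), np.random.rand(32, 1)

    ckpt_path = '/tmp/weights.h5'
    model.fit(x, y, epochs=2,
              callbacks=[tf.keras.callbacks.ModelCheckpoint(
                  filepath=ckpt_path, save_weights_only=True)])

    # After fit() the weights file should exist and be loadable.
    assert tf.io.gfile.exists(ckpt_path)
    model.load_weights(ckpt_path)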
Example #2
    def testCheckpointExists(self, file_format, save_weights_only):
        train_ds, _ = multi_worker_testing_utils.mnist_synthetic_dataset(64, 2)
        model = multi_worker_testing_utils.get_mnist_model((28, 28, 1))
        saving_dir = self.get_temp_dir()
        saving_filepath = os.path.join(saving_dir, 'checkpoint.' + file_format)
        callbacks_list = [
            callbacks.ModelCheckpoint(filepath=saving_filepath,
                                      save_weights_only=save_weights_only)
        ]
        self.assertFalse(training_state.checkpoint_exists(saving_filepath))

        try:
            model.fit(x=train_ds,
                      epochs=2,
                      steps_per_epoch=2,
                      callbacks=callbacks_list)
        except NotFoundError as e:
            if 'Failed to create a NewWriteableFile' in e.message:
                self.skipTest(
                    'b/138941852, path not found error in Windows py35.')

        self.assertTrue(training_state.checkpoint_exists(saving_filepath))
        self.assertTrue(
            training_state.remove_checkpoint_if_exists(saving_dir,
                                                       saving_filepath))
        self.assertFalse(training_state.checkpoint_exists(saving_filepath))
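
Example #2 receives `file_format` and `save_weights_only` as test-method arguments, so it is presumably generated by a parameterized-test decorator. Below is a rough sketch of how such parameterization could look using `absl.testing.parameterized`; the class name and parameter tuples are assumptions for illustration, not the original test's actual decorator.

    from absl.testing import parameterized
    import tensorflow as tf

    class CheckpointExistsTest(tf.test.TestCase, parameterized.TestCase):

        @parameterized.parameters(
            ('h5', True),
            ('h5', False),
            ('tf', True),
        )
        def test_checkpoint_exists(self, file_format, save_weights_only):
            # The body would mirror Example #2: build a model, fit with a
            # ModelCheckpoint callback, then assert the checkpoint exists.
            self.assertIn(file_format, ('h5', 'tf'))

    if __name__ == '__main__':
        tf.test.main()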
Example #3
    def callableForTestModelCheckpointSavesOnChiefButNotOtherwise(
            model, test_obj, train_ds, num_epoch, steps, strategy,
            saving_filepath, **kwargs):

        extension = os.path.splitext(saving_filepath)[1]

        # Incorporate task type/index information into saving_filepath to ensure
        # every worker has a unique path. Note that in a normal use case the
        # saving_filepath would be the same for all workers, but we use different
        # ones here just to verify that the chief saves the checkpoint while
        # non-chief workers do not.

        saving_filepath = os.path.join(
            test_obj.get_temp_dir(), 'checkpoint_%s_%d%s' %
            (test_base.get_task_type(), test_base.get_task_index(), extension))

        # The saving_filepath shouldn't exist at the beginning (as it's unique).
        test_obj.assertFalse(training_state.checkpoint_exists(saving_filepath))

        model.fit(
            x=train_ds,
            epochs=num_epoch,
            steps_per_epoch=steps,
            callbacks=[callbacks.ModelCheckpoint(filepath=saving_filepath)])

        # If this worker is the chief, the checkpoint should have been saved; if
        # not, it shouldn't.
        test_obj.assertEqual(training_state.checkpoint_exists(saving_filepath),
                             test_base.is_chief())
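
Example #3 asserts that only the chief ends up with a checkpoint, relying on the harness's `test_base.is_chief()`. Outside the harness, chief status is commonly derived from the `TF_CONFIG` environment variable; the helper below is a hypothetical sketch of that convention (an explicit 'chief' task, or worker 0 when no dedicated chief exists), not a library function.

    import json
    import os

    def is_chief_from_tf_config():
        """Best-effort chief detection from TF_CONFIG (illustrative only)."""
        tf_config = json.loads(os.environ.get('TF_CONFIG', '{}'))
        task = tf_config.get('task', {})
        task_type = task.get('type', 'chief')
        task_index = task.get('index', 0)
        if task_type == 'chief':
            return True
        # Common convention: worker 0 acts as chief when the cluster has no
        # dedicated 'chief' task.
        cluster = tf_config.get('cluster', {})
        return (task_type == 'worker' and task_index == 0
                and 'chief' not in cluster)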
Example #4
    def callableForTestUnmatchedModelFile(model, test_obj, train_ds, num_epoch,
                                          steps, strategy, saving_filepath,
                                          **kwargs):

        # The saving_filepath shouldn't exist at the beginning.
        test_obj.assertFalse(training_state.checkpoint_exists(saving_filepath))

        model.fit(x=train_ds,
                  epochs=num_epoch,
                  steps_per_epoch=steps,
                  callbacks=[
                      callbacks.ModelCheckpoint(filepath=saving_filepath,
                                                save_weights_only=True)
                  ])

        (train_ds, _), (_, _) = testing_utils.get_test_data(train_samples=10,
                                                            test_samples=10,
                                                            input_shape=(3, ),
                                                            num_classes=2)

        # Switch to a model with a different structure.
        with strategy.scope():
            model = keras.models.Sequential()
            model.add(keras.layers.Dense(5, input_dim=3, activation='relu'))
            model.add(keras.layers.Dense(2, activation='softmax'))
            model.compile(loss='categorical_crossentropy',
                          optimizer='rmsprop',
                          metrics=['acc'])

        test_obj.assertTrue(training_state.checkpoint_exists(saving_filepath))

        if saving_filepath.endswith('.tf'):
            test_obj.skipTest(
                'Loading mismatched TF checkpoint would cause Fatal '
                'Python error: Aborted. Skipping.')

        # Unmatched format. Should raise ValueError.
        with test_obj.assertRaisesRegexp(ValueError,
                                         'Error loading file from'):
            model.fit(x=train_ds,
                      epochs=num_epoch,
                      batch_size=8,
                      callbacks=[
                          callbacks.ModelCheckpoint(
                              filepath=saving_filepath,
                              save_weights_only=True,
                              load_weights_on_restart=True)
                      ])
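
Example #4 verifies that restoring a checkpoint into a structurally different model fails. The same failure mode can be reproduced with the public Keras API alone; in the sketch below the models, the `.h5` path, and the caught exception types are illustrative, and the exact error message varies across versions.

    import tensorflow as tf

    # Save weights of a one-layer model.
    model_a = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(3,))])
    model_a.save_weights('/tmp/mismatch_demo.h5')

    # A two-layer model with different shapes cannot load those weights.
    model_b = tf.keras.Sequential([
        tf.keras.layers.Dense(5, input_shape=(3,), activation='relu'),
        tf.keras.layers.Dense(2, activation='softmax'),
    ])
    try:
        model_b.load_weights('/tmp/mismatch_demo.h5')
    except (ValueError, OSError) as e:
        print('Mismatched weight file rejected:', e)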
Example #5
    def on_train_begin(self, logs=None):
        # pylint: disable=protected-access
        if self.model._in_multi_worker_mode():
            # MultiWorkerTrainingState is used to manage the training state needed
            # for preemption-recovery of a worker in multi-worker training.
            self.model._training_state = (
                training_state.MultiWorkerTrainingState(
                    self.model, self.filepath))
            self._training_state = self.model._training_state
            if self._training_state.restore():
                # If the training state needed to be restored and was restored
                # successfully, the worker is recovering from a previous failure
                # (or preemption). In that case, do not load weights from the
                # user-specified file path.
                return

        # If this is not multi-worker training, if restoring is not needed, or if
        # restoring failed, check whether weights should be loaded on restart.
        if self.load_weights_on_restart:
            if (not self.model._in_multi_worker_mode()
                    or multi_worker_util.should_load_checkpoint()):
                filepath_to_load = (
                    self._get_most_recently_modified_file_matching_pattern(
                        self.filepath))
                if (filepath_to_load is not None and
                        training_state.checkpoint_exists(filepath_to_load)):
                    try:
                        # `filepath` may contain placeholders such as
                        # `{epoch:02d}`, so attempt to load the most recently
                        # modified file whose name matches the pattern.
                        self.model.load_weights(filepath_to_load)
                    except (IOError, ValueError) as e:
                        raise ValueError(
                            'Error loading file from {}. Reason: {}'.format(
                                filepath_to_load, e))
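
The `on_train_begin` above resolves `filepath` patterns such as `{epoch:02d}` by loading the most recently modified matching file. The standalone helper below sketches that idea with only the standard library; it is a simplified stand-in, not the real `_get_most_recently_modified_file_matching_pattern`, which handles more pattern cases.

    import glob
    import os

    def most_recently_modified_matching(pattern_filepath):
        """E.g. 'ckpt.{epoch:02d}.h5' -> newest file matching 'ckpt.*.h5'."""
        # Crude translation of '{...}' placeholders into glob wildcards.
        glob_pattern = pattern_filepath
        while '{' in glob_pattern and '}' in glob_pattern:
            start = glob_pattern.index('{')
            end = glob_pattern.index('}', start)
            glob_pattern = glob_pattern[:start] + '*' + glob_pattern[end + 1:]
        candidates = glob.glob(glob_pattern)
        if not candidates:
            return None
        return max(candidates, key=os.path.getmtime)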
Example #6
  def callableForTestLoadWeightFromModelCheckpoint(model, test_obj, train_ds,
                                                   num_epoch, steps, strategy,
                                                   saving_filepath, **kwargs):
    filepaths = []
    real_mkstemp = tempfile.mkstemp
    def mocked_mkstemp():
      # Only non-chief workers should call tempfile.mkstemp() inside fit() in
      # sync training.
      assert not test_base.is_chief()
      file_handle, temp_file_name = real_mkstemp()
      extension = os.path.splitext(saving_filepath)[1]
      temp_filepath = temp_file_name + extension
      filepaths.append(temp_filepath)
      return file_handle, temp_file_name

    # Mock tempfile.mkstemp() so the filepaths can be stored and verified later.
    with test.mock.patch.object(tempfile, 'mkstemp', mocked_mkstemp):
      saving_filepath, history_after_one_more_epoch = \
          KerasMultiWorkerCallbackTest.initialFitting(
              test_obj, model, train_ds, num_epoch, steps, saving_filepath)

      with strategy.scope():
        model.load_weights(saving_filepath)

      history_after_loading_weight_and_one_more_epoch = model.fit(
          x=train_ds, epochs=1, steps_per_epoch=steps)

      test_obj.assertAllClose(
          history_after_one_more_epoch.history,
          history_after_loading_weight_and_one_more_epoch.history,
          rtol=5e-5)

    # Verify the temp files are indeed removed (no trace left behind).
    for filepath in filepaths:
      assert not training_state.checkpoint_exists(filepath)
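
The test above patches `tempfile.mkstemp` via `test.mock.patch.object` so the temporary filepaths created inside `fit()` can be recorded and checked afterwards. The same recording pattern works with the standard library's `unittest.mock` on its own, as in this small sketch (the wrapper function is illustrative):

    import os
    import tempfile
    from unittest import mock

    created = []
    real_mkstemp = tempfile.mkstemp

    def recording_mkstemp(*args, **kwargs):
        # Delegate to the real function but remember every file it creates.
        handle, name = real_mkstemp(*args, **kwargs)
        created.append(name)
        return handle, name

    with mock.patch.object(tempfile, 'mkstemp', recording_mkstemp):
        fd, path = tempfile.mkstemp()
        os.close(fd)

    assert path in created  # The patched function saw the call.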
Example #7
  def testCheckpointExists(self, file_format, save_weights_only):
    train_ds, _ = multi_worker_testing_utils.mnist_synthetic_dataset(64, 2)
    model = multi_worker_testing_utils.get_mnist_model((28, 28, 1))
    saving_dir = self.get_temp_dir()
    saving_filepath = os.path.join(saving_dir, 'checkpoint.' + file_format)
    callbacks_list = [
        callbacks.ModelCheckpoint(
            filepath=saving_filepath, save_weights_only=save_weights_only)
    ]
    self.assertFalse(training_state.checkpoint_exists(saving_filepath))

    model.fit(x=train_ds, epochs=2, steps_per_epoch=2, callbacks=callbacks_list)
    self.assertTrue(training_state.checkpoint_exists(saving_filepath))
    self.assertTrue(
        training_state.remove_checkpoint_if_exists(saving_dir, saving_filepath))
    self.assertFalse(training_state.checkpoint_exists(saving_filepath))
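
`training_state.checkpoint_exists` has to cover both formats used here: an `.h5` checkpoint is a single file, while a TensorFlow-format checkpoint is a file prefix accompanied by an `.index` file and data shards. The helper below is only a rough approximation of that check, based on an assumption about the internal helper's behavior rather than its actual source:

    import os

    def checkpoint_exists_approx(filepath):
        """Rough stand-in for the internal checkpoint_exists helper."""
        if filepath.endswith('.h5'):
            return os.path.exists(filepath)
        # TF-format checkpoints are prefixes; the '.index' file marks them.
        return os.path.exists(filepath) or os.path.exists(filepath + '.index')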
Example #8
        def proc_model_checkpoint_saves_on_chief_but_not_otherwise(
                test_obj, file_format):

            model, saving_filepath, train_ds, steps = _model_setup(
                test_obj, file_format)
            num_epoch = 2
            extension = os.path.splitext(saving_filepath)[1]

            # Incorporate task type/index information into saving_filepath to
            # ensure every worker has a unique path. Note that in a normal use
            # case the saving_filepath would be the same for all workers, but we
            # use different ones here just to verify that the chief saves the
            # checkpoint while non-chief workers do not.
            saving_filepath = os.path.join(
                test_obj.get_temp_dir(),
                'checkpoint_%s_%d%s' % (test_base.get_task_type(),
                                        test_base.get_task_index(), extension))

            # The saving_filepath shouldn't exist at the beginning (as it's unique).
            test_obj.assertFalse(
                training_state.checkpoint_exists(saving_filepath))

            model.fit(x=train_ds,
                      epochs=num_epoch,
                      steps_per_epoch=steps,
                      validation_data=train_ds,
                      validation_steps=steps,
                      callbacks=[
                          callbacks.ModelCheckpoint(
                              filepath=saving_filepath,
                              save_weights_only=save_weights_only)
                      ])

            # If this worker is the chief, the checkpoint should have been
            # saved; if not, it shouldn't.
            test_obj.assertEqual(
                training_state.checkpoint_exists(saving_filepath),
                test_base.is_chief())

            # If this worker is the chief, the checkpoint should exist
            # (`write_filepath` simply returns `saving_filepath`); if not, the
            # temporary path generated by `write_filepath` should no longer
            # contain a checkpoint, as it has already been deleted.
            test_obj.assertEqual(
                training_state.checkpoint_exists(
                    distributed_file_utils.write_filepath(
                        saving_filepath, model._distribution_strategy)),
                test_base.is_chief())
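
The final assertion in Example #8 depends on `distributed_file_utils.write_filepath`, which hands the chief the real path and redirects non-chief workers to a per-worker temporary location that is cleaned up after saving. The function below is a hypothetical sketch of that redirection idea; the `workertemp_` directory naming and the chief check are my own illustration, not the library's implementation.

    import os

    def write_filepath_sketch(filepath, task_type, task_id):
        """Chief writes to `filepath`; others write under a temp subdirectory."""
        is_chief = task_type == 'chief' or (task_type == 'worker' and task_id == 0)
        if is_chief:
            return filepath
        dirpath, base = os.path.split(filepath)
        temp_dir = os.path.join(dirpath, 'workertemp_%d' % task_id)
        os.makedirs(temp_dir, exist_ok=True)
        return os.path.join(temp_dir, base)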
Example #9
  def callableForTestBackupModelNotRemovedIfInterrupted(model, test_obj,
                                                        train_ds, num_epoch,
                                                        steps, strategy,
                                                        saving_filepath,
                                                        **kwargs):

    # The `barrier` object needs to be passed in from the parent thread so that
    # both worker threads refer to the same object.
    barrier = kwargs['barrier']

    num_epoch = 4

    # Test the backup filepath that `multi_worker_training_state` uses.
    _, backup_filepath = training_state._get_backup_filepath(saving_filepath)

    # The backup_filepath shouldn't exist at the beginning.
    test_obj.assertFalse(training_state.checkpoint_exists(backup_filepath))

    # Callback to interrupt in the middle of training.
    class InterruptingCallback(callbacks.Callback):

      def on_epoch_begin(self, epoch, logs=None):
        if epoch == 2:
          raise RuntimeError('Interrupting!')

    try:
      model.fit(
          x=train_ds,
          epochs=num_epoch,
          steps_per_epoch=steps,
          callbacks=[
              callbacks.ModelCheckpoint(
                  filepath=saving_filepath, save_weights_only=True),
              InterruptingCallback()
          ])
    except RuntimeError as e:
      if 'Interrupting!' not in str(e):
        raise

    # Sync on the two threads.
    barrier.wait()

    # The backup file should exist after `model.fit()` is interrupted.
    test_obj.assertTrue(training_state.checkpoint_exists(backup_filepath))
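
The `barrier.wait()` calls keep the two simulated workers in lockstep, so the assertion only runs after both threads have passed the interruption. The same synchronization can be demonstrated with the standard library's `threading.Barrier`, as in this self-contained sketch:

    import threading

    barrier = threading.Barrier(2)
    results = []

    def worker(worker_id):
        # ... per-worker training / interruption would happen here ...
        barrier.wait()             # Block until both workers reach this point.
        results.append(worker_id)  # Runs only after both threads have synced.

    threads = [threading.Thread(target=worker, args=(i,)) for i in range(2)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    assert sorted(results) == [0, 1]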
Example #10
    def callableForTestBackupModelRemoved(model, test_obj, train_ds, num_epoch,
                                          steps, strategy, saving_filepath,
                                          **kwargs):

        # The `barrier` object needs to be passed in from the parent thread so
        # that both worker threads refer to the same object.
        barrier = kwargs['barrier']

        num_epoch = 3

        # Test the backup filepath that `multi_worker_training_state` uses.
        _, backup_filepath = training_state._get_backup_filepath(
            saving_filepath)

        # The backup_filepath shouldn't exist at the beginning.
        test_obj.assertFalse(training_state.checkpoint_exists(backup_filepath))

        # Callback to verify that the backup file exists in the middle of training.
        class BackupFilepathVerifyingCallback(callbacks.Callback):
            def on_epoch_begin(self, epoch, logs=None):
                if epoch > 1:
                    # Assert that after the first two epochs the backup file
                    # exists.
                    test_obj.assertTrue(
                        training_state.checkpoint_exists(backup_filepath))

        model.fit(x=train_ds,
                  epochs=num_epoch,
                  steps_per_epoch=steps,
                  callbacks=[
                      callbacks.ModelCheckpoint(filepath=saving_filepath,
                                                save_weights_only=True),
                      BackupFilepathVerifyingCallback()
                  ])

        # Sync the two threads to make sure the backup file is removed before we
        # move on.
        barrier.wait()

        # The backup file should not exist after `model.fit()` exits
        # successfully.
        test_obj.assertFalse(training_state.checkpoint_exists(backup_filepath))
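
Examples #9 and #10 together exercise the backup-and-restore mechanism behind preemption recovery: the backup checkpoint survives an interrupted `fit()` and is removed after a successful one. In later TensorFlow releases the same behavior is available publicly through the `BackupAndRestore` callback (first shipped as `tf.keras.callbacks.experimental.BackupAndRestore`); availability and the exact module path depend on the TF version, so the sketch below is indicative only.

    import numpy as np
    import tensorflow as tf

    model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
    model.compile(optimizer='sgd', loss='mse')
    x, y = np.random.rand(16, 4), np.random.rand(16, 1)

    # If fit() is interrupted and re-run with the same callback, training
    # resumes from the last epoch recorded under backup_dir.
    model.fit(x, y, epochs=4,
              callbacks=[tf.keras.callbacks.BackupAndRestore(
                  backup_dir='/tmp/train_backup')])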
Example #11
    def test_template(self, strategy_cls, file_format):
        num_workers = 2
        num_epoch = 2

        cluster_spec = test_base.create_cluster_spec(num_workers=num_workers,
                                                     test_obj=self)
        self._barrier = dc._Barrier(2)

        def _independent_worker_fn(*args, **kwargs):  # pylint: disable=unused-argument
            """Simulates an Independent Worker inside of a thread."""
            with test.mock.patch.object(dc, '_run_std_server',
                                        self._make_mock_run_std_server()):
                strategy = get_strategy_object(strategy_cls)
                batch_size = 64
                steps = 2
                train_ds, _ = multi_worker_testing_utils.mnist_synthetic_dataset(
                    batch_size, steps)
                with strategy.scope():
                    model = multi_worker_testing_utils.get_mnist_model(
                        (28, 28, 1))

                custom_callable(model,
                                self,
                                train_ds,
                                num_epoch,
                                steps,
                                strategy,
                                saving_filepath=kwargs['saving_filepath'],
                                barrier=kwargs['barrier'],
                                threading_local=kwargs['threading_local'])

        # Pass saving_filepath from the parent thread to ensure every worker has
        # the same filepath to save to.
        saving_filepath = os.path.join(self.get_temp_dir(),
                                       'checkpoint.' + file_format)
        barrier = dc._Barrier(2)
        threading_local = threading.local()
        threads = self.run_multiple_tasks_in_threads(
            _independent_worker_fn,
            cluster_spec,
            saving_filepath=saving_filepath,
            barrier=barrier,
            threading_local=threading_local)
        self.assertFalse(training_state.checkpoint_exists(saving_filepath))

        threads_to_join = []
        strategy = get_strategy_object(strategy_cls)
        if strategy.extended.experimental_between_graph:
            for ts in threads.values():
                threads_to_join.extend(ts)
        else:
            threads_to_join = [threads['worker'][0]]
        self.join_independent_workers(threads_to_join)
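
`test_base.create_cluster_spec` builds the cluster definition that each simulated worker thread sees. Outside the test harness, a two-worker cluster is typically described to a multi-worker strategy through the `TF_CONFIG` environment variable; the snippet below sketches that configuration for worker 0, with placeholder addresses.

    import json
    import os

    os.environ['TF_CONFIG'] = json.dumps({
        'cluster': {
            'worker': ['localhost:12345', 'localhost:23456'],
        },
        'task': {'type': 'worker', 'index': 0},
    })
    # Each worker process sets the same 'cluster' dict but its own 'task' entry;
    # the strategy reads TF_CONFIG when it is constructed.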
Example #12
    def on_train_begin(self, logs=None):
        # pylint: disable=protected-access
        if self.model._in_multi_worker_mode():
            # MultiWorkerTrainingState is used to manage the training state needed
            # for preemption-recovery of a worker in multi-worker training.
            self.model._training_state = (
                training_state.MultiWorkerTrainingState(
                    self.model, self.filepath))
            self._training_state = self.model._training_state
            if self._training_state.restore():
                # If the training state needed to be restored and was restored
                # successfully, the worker is recovering from a previous failure
                # (or preemption). In that case, do not load weights from the
                # user-specified file path.
                return

        # If this is not multi-worker training, if restoring is not needed, or if
        # restoring failed, check whether weights should be loaded on restart.
        if (not self.model._in_multi_worker_mode()
                or multi_worker_util.should_load_checkpoint()):
            filepath_to_load = self.manager.latest_checkpoint
            if (filepath_to_load is not None
                    and training_state.checkpoint_exists(filepath_to_load)):
                try:
                    # Restore from the most recent checkpoint tracked by the
                    # CheckpointManager.
                    if not self.save_best_only:
                        self.ckpt.restore(filepath_to_load).expect_partial()
                        self.start_epoch[0] = self.ckpt.step.numpy() + 1
                        logging.info(
                            f"Restored checkpoint from {filepath_to_load}. "
                            f"Start from epoch {self.start_epoch[0]}.")
                    else:
                        # Try to restore best metric
                        self._load_metric(filepath_to_load)
                except (IOError, ValueError) as e:
                    raise ValueError(
                        'Error loading file from {}. Reason: {}'.format(
                            filepath_to_load, e))
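
Unlike Example #5, this variant restores through a `tf.train.Checkpoint`/`tf.train.CheckpointManager` pair (`self.ckpt`, `self.manager`) rather than `model.load_weights`. The sketch below shows that public-API pattern in isolation; the `step` variable mirrors the `ckpt.step` counter used above, and the model and directory are illustrative.

    import tensorflow as tf

    model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
    ckpt = tf.train.Checkpoint(step=tf.Variable(0, dtype=tf.int64), model=model)
    manager = tf.train.CheckpointManager(ckpt, '/tmp/ckpt_demo', max_to_keep=3)

    # Restore the latest checkpoint if one exists; expect_partial() silences
    # warnings about objects (e.g. optimizer slots) not present in this graph.
    if manager.latest_checkpoint:
        ckpt.restore(manager.latest_checkpoint).expect_partial()
        start_epoch = int(ckpt.step.numpy()) + 1
    else:
        start_epoch = 0

    # ... train for one epoch, then record progress ...
    ckpt.step.assign_add(1)
    manager.save()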
Example #13
    def on_epoch_begin(self, epoch, logs=None):
        if epoch > 1:
            # Assert that after the first two epochs the backup file exists.
            test_obj.assertTrue(
                training_state.checkpoint_exists(backup_filepath))