def initialFitting(test_obj, model, train_ds, num_epoch, steps,
                   saving_filepath):
  # The saving_filepath shouldn't exist at the beginning.
  test_obj.assertFalse(training_state.checkpoint_exists(saving_filepath))

  model.fit(
      x=train_ds,
      epochs=num_epoch,
      steps_per_epoch=steps,
      callbacks=[
          callbacks.ModelCheckpoint(
              filepath=saving_filepath, save_weights_only=True)
      ])

  # The saving_filepath should exist after fitting with the callback. Both
  # chief and non-chief workers should see it, even though it was saved only
  # by the chief.
  test_obj.assertTrue(training_state.checkpoint_exists(saving_filepath))

  history_after_one_more_epoch = model.fit(
      x=train_ds, epochs=1, steps_per_epoch=steps)

  # The saving_filepath should continue to exist (if it did) after fitting
  # without the callback.
  test_obj.assertTrue(training_state.checkpoint_exists(saving_filepath))

  return saving_filepath, history_after_one_more_epoch
def testCheckpointExists(self, file_format, save_weights_only):
  train_ds, _ = multi_worker_testing_utils.mnist_synthetic_dataset(64, 2)
  model = multi_worker_testing_utils.get_mnist_model((28, 28, 1))
  saving_dir = self.get_temp_dir()
  saving_filepath = os.path.join(saving_dir, 'checkpoint.' + file_format)
  callbacks_list = [
      callbacks.ModelCheckpoint(
          filepath=saving_filepath, save_weights_only=save_weights_only)
  ]
  self.assertFalse(training_state.checkpoint_exists(saving_filepath))
  try:
    model.fit(
        x=train_ds, epochs=2, steps_per_epoch=2, callbacks=callbacks_list)
  except NotFoundError as e:
    if 'Failed to create a NewWriteableFile' in e.message:
      self.skipTest('b/138941852, path not found error in Windows py35.')
  self.assertTrue(training_state.checkpoint_exists(saving_filepath))
  self.assertTrue(
      training_state.remove_checkpoint_if_exists(saving_dir, saving_filepath))
  self.assertFalse(training_state.checkpoint_exists(saving_filepath))
def callableForTestModelCheckpointSavesOnChiefButNotOtherwise(
    model, test_obj, train_ds, num_epoch, steps, strategy, saving_filepath,
    **kwargs):
  extension = os.path.splitext(saving_filepath)[1]

  # Incorporate type/index information and thread id in saving_filepath to
  # ensure every worker has a unique path. Note that in the normal use case
  # the saving_filepath is the same for all workers; we use different ones
  # here only to verify that the chief saves the checkpoint while non-chief
  # workers don't.
  saving_filepath = os.path.join(
      test_obj.get_temp_dir(), 'checkpoint_%s_%d%s' %
      (test_base.get_task_type(), test_base.get_task_index(), extension))

  # The saving_filepath shouldn't exist at the beginning (as it's unique).
  test_obj.assertFalse(training_state.checkpoint_exists(saving_filepath))

  model.fit(
      x=train_ds,
      epochs=num_epoch,
      steps_per_epoch=steps,
      callbacks=[callbacks.ModelCheckpoint(filepath=saving_filepath)])

  # If it's chief, the model should be saved; if not, the model shouldn't.
  test_obj.assertEqual(
      training_state.checkpoint_exists(saving_filepath), test_base.is_chief())
def callableForTestUnmatchedModelFile(model, test_obj, train_ds, num_epoch,
                                      steps, strategy, saving_filepath,
                                      **kwargs):
  # The saving_filepath shouldn't exist at the beginning.
  test_obj.assertFalse(training_state.checkpoint_exists(saving_filepath))

  model.fit(
      x=train_ds,
      epochs=num_epoch,
      steps_per_epoch=steps,
      callbacks=[
          callbacks.ModelCheckpoint(
              filepath=saving_filepath, save_weights_only=True)
      ])

  (train_ds, _), (_, _) = testing_utils.get_test_data(
      train_samples=10, test_samples=10, input_shape=(3,), num_classes=2)

  # Switch to a model of different structure.
  with strategy.scope():
    model = keras.models.Sequential()
    model.add(keras.layers.Dense(5, input_dim=3, activation='relu'))
    model.add(keras.layers.Dense(2, activation='softmax'))
    model.compile(
        loss='categorical_crossentropy', optimizer='rmsprop', metrics=['acc'])

  test_obj.assertTrue(training_state.checkpoint_exists(saving_filepath))

  if saving_filepath.endswith('.tf'):
    test_obj.skipTest('Loading mismatched TF checkpoint would cause Fatal '
                      'Python error: Aborted. Skipping.')

  # Unmatched format. Should raise ValueError.
  with test_obj.assertRaisesRegexp(ValueError, 'Error loading file from'):
    model.fit(
        x=train_ds,
        epochs=num_epoch,
        batch_size=8,
        callbacks=[
            callbacks.ModelCheckpoint(
                filepath=saving_filepath,
                save_weights_only=True,
                load_weights_on_restart=True)
        ])
def on_train_begin(self, logs=None):
  # pylint: disable=protected-access
  if self.model._in_multi_worker_mode():
    # MultiWorkerTrainingState is used to manage the training state needed
    # for preemption-recovery of a worker in multi-worker training.
    self.model._training_state = (
        training_state.MultiWorkerTrainingState(self.model, self.filepath))
    self._training_state = self.model._training_state
    if self._training_state.restore():
      # If the training state needs to be and is successfully restored, the
      # model is recovering from a previous failure (or preemption). In that
      # case, do not load the weights from the user-specified file path.
      return

  # If this is not multi-worker training, restoring is not needed, or
  # restoring failed, check whether weights should be loaded on restart.
  if self.load_weights_on_restart:
    if (not self.model._in_multi_worker_mode() or
        multi_worker_util.should_load_checkpoint()):
      filepath_to_load = (
          self._get_most_recently_modified_file_matching_pattern(
              self.filepath))
      if (filepath_to_load is not None and
          training_state.checkpoint_exists(filepath_to_load)):
        try:
          # `filepath` may contain placeholders such as `{epoch:02d}`, so
          # load the most recently modified file whose name matches the
          # pattern.
          self.model.load_weights(filepath_to_load)
        except (IOError, ValueError) as e:
          raise ValueError('Error loading file from {}. Reason: {}'.format(
              filepath_to_load, e))
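# The following helper is not part of the original file; it is a minimal,
# hedged usage sketch of the restore path implemented above. It assumes a
# compiled Keras `model`, a `train_ds` dataset and a `steps` count supplied by
# the caller; the checkpoint path is an arbitrary placeholder.
def _example_load_weights_on_restart_usage(model, train_ds, steps):
  # With a fixed `filepath` pattern and `load_weights_on_restart=True`, a
  # restarted run reloads the most recently modified matching checkpoint in
  # `on_train_begin` before continuing training.
  checkpoint_cb = callbacks.ModelCheckpoint(
      filepath='/tmp/training/weights.{epoch:02d}.h5',
      save_weights_only=True,
      load_weights_on_restart=True)
  model.fit(
      x=train_ds, epochs=2, steps_per_epoch=steps, callbacks=[checkpoint_cb])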
def callableForTestLoadWeightFromModelCheckpoint(model, test_obj, train_ds,
                                                 num_epoch, steps, strategy,
                                                 saving_filepath, **kwargs):
  filepaths = []
  real_mkstemp = tempfile.mkstemp

  def mocked_mkstemp():
    # Only non-chief workers should call tempfile.mkstemp() inside fit() in
    # sync training.
    assert not test_base.is_chief()
    file_handle, temp_file_name = real_mkstemp()
    extension = os.path.splitext(saving_filepath)[1]
    temp_filepath = temp_file_name + extension
    filepaths.append(temp_filepath)
    return file_handle, temp_file_name

  # Mock tempfile.mkstemp() so the filepaths can be stored and verified later.
  with test.mock.patch.object(tempfile, 'mkstemp', mocked_mkstemp):
    saving_filepath, history_after_one_more_epoch = \
        KerasMultiWorkerCallbackTest.initialFitting(
            test_obj, model, train_ds, num_epoch, steps, saving_filepath)

    with strategy.scope():
      model.load_weights(saving_filepath)

    history_after_loading_weight_and_one_more_epoch = model.fit(
        x=train_ds, epochs=1, steps_per_epoch=steps)

    test_obj.assertAllClose(
        history_after_one_more_epoch.history,
        history_after_loading_weight_and_one_more_epoch.history,
        rtol=5e-5)

  # Verify the temp files are indeed removed (no trace left behind).
  for filepath in filepaths:
    assert not training_state.checkpoint_exists(filepath)
def proc_model_checkpoint_saves_on_chief_but_not_otherwise(
    test_obj, file_format):
  model, saving_filepath, train_ds, steps = _model_setup(
      test_obj, file_format)
  num_epoch = 2
  extension = os.path.splitext(saving_filepath)[1]

  # Incorporate type/index information and thread id in saving_filepath to
  # ensure every worker has a unique path. Note that in the normal use case
  # the saving_filepath is the same for all workers; we use different ones
  # here only to verify that the chief saves the checkpoint while non-chief
  # workers don't.
  saving_filepath = os.path.join(
      test_obj.get_temp_dir(), 'checkpoint_%s_%d%s' %
      (test_base.get_task_type(), test_base.get_task_index(), extension))

  # The saving_filepath shouldn't exist at the beginning (as it's unique).
  test_obj.assertFalse(training_state.checkpoint_exists(saving_filepath))

  model.fit(
      x=train_ds,
      epochs=num_epoch,
      steps_per_epoch=steps,
      validation_data=train_ds,
      validation_steps=steps,
      callbacks=[
          callbacks.ModelCheckpoint(
              # `save_weights_only` is expected to be provided by the
              # enclosing test (e.g. as a parameterized value).
              filepath=saving_filepath, save_weights_only=save_weights_only)
      ])

  # If it's chief, the model should be saved; if not, the model shouldn't.
  test_obj.assertEqual(
      training_state.checkpoint_exists(saving_filepath), test_base.is_chief())

  # For the chief, `write_filepath` simply returns `saving_filepath`, so the
  # checkpoint should exist there; for non-chief workers, the temporary path
  # generated by `write_filepath` should no longer contain a checkpoint,
  # since it has been deleted.
  test_obj.assertEqual(
      training_state.checkpoint_exists(
          distributed_file_utils.write_filepath(saving_filepath,
                                                model._distribution_strategy)),
      test_base.is_chief())
def callableForTestBackupModelNotRemovedIfInterrupted(model, test_obj,
                                                      train_ds, num_epoch,
                                                      steps, strategy,
                                                      saving_filepath,
                                                      **kwargs):
  # The `barrier` object needs to be passed in from the parent thread so both
  # threads refer to the same object.
  barrier = kwargs['barrier']
  num_epoch = 4

  # Testing the backup filepath `multi_worker_training_state` uses.
  _, backup_filepath = training_state._get_backup_filepath(saving_filepath)

  # The backup_filepath shouldn't exist at the beginning.
  test_obj.assertFalse(training_state.checkpoint_exists(backup_filepath))

  # Callback to interrupt in the middle of training.
  class InterruptingCallback(callbacks.Callback):

    def on_epoch_begin(self, epoch, logs=None):
      if epoch == 2:
        raise RuntimeError('Interrupting!')

  try:
    model.fit(
        x=train_ds,
        epochs=num_epoch,
        steps_per_epoch=steps,
        callbacks=[
            callbacks.ModelCheckpoint(
                filepath=saving_filepath, save_weights_only=True),
            InterruptingCallback()
        ])
  except RuntimeError as e:
    # `RuntimeError` has no `message` attribute in Python 3; inspect str(e).
    if 'Interrupting!' not in str(e):
      raise

  # Sync on the two threads.
  barrier.wait()

  # The backup file should exist after `model.fit()` is interrupted.
  test_obj.assertTrue(training_state.checkpoint_exists(backup_filepath))
def callableForTestBackupModelRemoved(model, test_obj, train_ds, num_epoch,
                                      steps, strategy, saving_filepath,
                                      **kwargs):
  # The `barrier` object needs to be passed in from the parent thread so both
  # threads refer to the same object.
  barrier = kwargs['barrier']
  num_epoch = 3

  # Testing the backup filepath `multi_worker_training_state` uses.
  _, backup_filepath = training_state._get_backup_filepath(saving_filepath)

  # The backup_filepath shouldn't exist at the beginning.
  test_obj.assertFalse(training_state.checkpoint_exists(backup_filepath))

  # Callback to verify that the backup file exists in the middle of training.
  class BackupFilepathVerifyingCallback(callbacks.Callback):

    def on_epoch_begin(self, epoch, logs=None):
      if epoch > 1:
        # After the first two epochs, the backup file should exist.
        test_obj.assertTrue(training_state.checkpoint_exists(backup_filepath))

  model.fit(
      x=train_ds,
      epochs=num_epoch,
      steps_per_epoch=steps,
      callbacks=[
          callbacks.ModelCheckpoint(
              filepath=saving_filepath, save_weights_only=True),
          BackupFilepathVerifyingCallback()
      ])

  # Sync on the two threads so we make sure the backup file is removed before
  # we move on.
  barrier.wait()

  # The backup file should not exist on successful exit of `model.fit()`.
  test_obj.assertFalse(training_state.checkpoint_exists(backup_filepath))
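# Hedged sketch (not part of the original tests) of the recovery flow the two
# backup tests above exercise. It assumes multi-worker training with the same
# `ModelCheckpoint` callback on every run: an interrupted `fit()` leaves the
# backup checkpoint on disk, a re-run restores from it in `on_train_begin`
# instead of starting over, and a run that completes cleanly deletes it.
def _example_preemption_recovery_flow(model, train_ds, steps, saving_filepath):
  checkpoint_cb = callbacks.ModelCheckpoint(
      filepath=saving_filepath, save_weights_only=True)
  try:
    model.fit(
        x=train_ds, epochs=4, steps_per_epoch=steps, callbacks=[checkpoint_cb])
  except RuntimeError:
    # Simulated preemption (e.g. raised by an interrupting callback); the
    # backup file is intentionally left behind at this point.
    pass
  # A subsequent fit with the same callback resumes from the backup state.
  model.fit(
      x=train_ds, epochs=4, steps_per_epoch=steps, callbacks=[checkpoint_cb])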
def test_template(self, strategy_cls, file_format):
  num_workers = 2
  num_epoch = 2

  cluster_spec = test_base.create_cluster_spec(
      num_workers=num_workers, test_obj=self)
  self._barrier = dc._Barrier(2)

  def _independent_worker_fn(*args, **kwargs):  # pylint: disable=unused-argument
    """Simulates an independent worker running inside a thread."""
    with test.mock.patch.object(dc, '_run_std_server',
                                self._make_mock_run_std_server()):
      strategy = get_strategy_object(strategy_cls)
      batch_size = 64
      steps = 2
      train_ds, _ = multi_worker_testing_utils.mnist_synthetic_dataset(
          batch_size, steps)
      with strategy.scope():
        model = multi_worker_testing_utils.get_mnist_model((28, 28, 1))

      custom_callable(
          model,
          self,
          train_ds,
          num_epoch,
          steps,
          strategy,
          saving_filepath=kwargs['saving_filepath'],
          barrier=kwargs['barrier'],
          threading_local=kwargs['threading_local'])

  # Pass saving_filepath from the parent thread to ensure every worker has
  # the same filepath to save to.
  saving_filepath = os.path.join(self.get_temp_dir(),
                                 'checkpoint.' + file_format)
  barrier = dc._Barrier(2)
  threading_local = threading.local()
  threads = self.run_multiple_tasks_in_threads(
      _independent_worker_fn,
      cluster_spec,
      saving_filepath=saving_filepath,
      barrier=barrier,
      threading_local=threading_local)
  self.assertFalse(training_state.checkpoint_exists(saving_filepath))

  threads_to_join = []
  strategy = get_strategy_object(strategy_cls)
  if strategy.extended.experimental_between_graph:
    for ts in threads.values():
      threads_to_join.extend(ts)
  else:
    threads_to_join = [threads['worker'][0]]
  self.join_independent_workers(threads_to_join)
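# Hypothetical illustration (not from the original file) of how `test_template`
# is meant to be paired with one of the `callableForTest*` functions above:
# `custom_callable` is a free variable of `test_template`, so a small factory
# can close over it and return a concrete test method. The factory name
# `generate_callback_test` is an assumption for illustration only.
def generate_callback_test(custom_callable):

  def test_template(self, strategy_cls, file_format):
    ...  # Body as defined above, closing over `custom_callable`.

  return test_template


# Example binding (hypothetical): produce a test method for one callable.
test_load_weight_from_model_checkpoint = generate_callback_test(
    callableForTestLoadWeightFromModelCheckpoint)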
def on_train_begin(self, logs=None):
  # pylint: disable=protected-access
  if self.model._in_multi_worker_mode():
    # MultiWorkerTrainingState is used to manage the training state needed
    # for preemption-recovery of a worker in multi-worker training.
    self.model._training_state = (
        training_state.MultiWorkerTrainingState(self.model, self.filepath))
    self._training_state = self.model._training_state
    if self._training_state.restore():
      # If the training state needs to be and is successfully restored, the
      # model is recovering from a previous failure (or preemption). In that
      # case, do not load the weights from the user-specified file path.
      return

  # If this is not multi-worker training, restoring is not needed, or
  # restoring failed, check whether a checkpoint should be loaded on restart.
  if (not self.model._in_multi_worker_mode() or
      multi_worker_util.should_load_checkpoint()):
    filepath_to_load = self.manager.latest_checkpoint
    if (filepath_to_load is not None and
        training_state.checkpoint_exists(filepath_to_load)):
      try:
        if not self.save_best_only:
          # Restore the full training state (weights, optimizer and step
          # counter) from the latest checkpoint tracked by the manager.
          self.ckpt.restore(filepath_to_load).expect_partial()
          self.start_epoch[0] = self.ckpt.step.numpy() + 1
          logging.info('Restored checkpoint from %s. Starting from epoch %d.',
                       filepath_to_load, self.start_epoch[0])
        else:
          # Only try to restore the best metric seen so far.
          self._load_metric(filepath_to_load)
      except (IOError, ValueError) as e:
        raise ValueError('Error loading file from {}. Reason: {}'.format(
            filepath_to_load, e))
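# Hedged sketch (not from the original file) of the state this variant of the
# callback assumes on `self`: a `tf.train.Checkpoint` with a `step` counter
# (`self.ckpt`), a `tf.train.CheckpointManager` tracking it (`self.manager`),
# and a mutable `self.start_epoch` holder. The attribute names mirror the
# method above; where and how they are created is an assumption.
def _example_checkpoint_manager_setup(model, checkpoint_dir):
  import tensorflow as tf  # Local import to keep the sketch self-contained.
  ckpt = tf.train.Checkpoint(
      step=tf.Variable(0, dtype=tf.int64),
      model=model,
      optimizer=model.optimizer)
  manager = tf.train.CheckpointManager(ckpt, checkpoint_dir, max_to_keep=3)
  # `manager.latest_checkpoint` is what the restore logic above reads.
  return ckpt, manager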