def proc_model_checkpoint_works_with_same_file_path(
        test_obj, saving_filepath):
    """Exercise ModelCheckpoint + BackupAndRestore across an interrupted fit.

    The model is fitted twice with the same ``saving_filepath`` and backup
    directory. The first fit runs with ``InterruptingCallback`` and only a
    ``RuntimeError`` whose message contains ``'Interrupting!'`` is tolerated;
    the second fit runs with ``AssertCallback``.

    Checks performed:
      * before training, ``saving_filepath`` does not exist;
      * after the first fit, both the backup checkpoint and
        ``saving_filepath`` exist;
      * after the second fit, the backup checkpoint no longer exists while
        ``saving_filepath`` still does.

    The test is skipped when ``multi_process_runner.is_oss()`` is true.

    Args:
        test_obj: Test-case object; its ``skipTest``, ``assertTrue`` and
            ``assertFalse`` methods are used, and it is forwarded to
            ``_model_setup``.
        saving_filepath: Path passed to ``ModelCheckpoint``. The backup
            directory is created next to it, under ``backup``.
    """
    if multi_process_runner.is_oss():
        test_obj.skipTest('TODO(b/170838633): Failing in OSS')

    model, _, train_ds, steps = _model_setup(test_obj, file_format='')
    epoch_count = 4

    # The saving_filepath shouldn't exist at the beginning (as it's unique).
    test_obj.assertFalse(file_io.file_exists_v2(saving_filepath))

    backup_root = os.path.join(os.path.dirname(saving_filepath), 'backup')

    def run_fit(last_callback_cls):
        # Every fit gets freshly constructed callbacks; only the final
        # callback differs between the two training runs.
        model.fit(
            x=train_ds,
            epochs=epoch_count,
            steps_per_epoch=steps,
            callbacks=[
                callbacks.ModelCheckpoint(filepath=saving_filepath),
                callbacks.BackupAndRestore(backup_dir=backup_root),
                last_callback_cls(),
            ])

    try:
        run_fit(InterruptingCallback)
    except RuntimeError as err:
        # Only the deliberate interruption is swallowed; any other
        # RuntimeError is a genuine failure and must propagate.
        if 'Interrupting!' not in str(err):
            raise

    # NOTE(review): presumably synchronizes all participating processes
    # before the filesystem is inspected -- confirm against
    # multi_process_runner.
    multi_process_runner.get_barrier().wait()

    backup_checkpoint = os.path.join(backup_root, 'chief', 'checkpoint')
    for expected_path in (backup_checkpoint, saving_filepath):
        test_obj.assertTrue(file_io.file_exists_v2(expected_path))

    run_fit(AssertCallback)

    multi_process_runner.get_barrier().wait()

    # After the uninterrupted run the backup must be gone, while the regular
    # checkpoint is still in place.
    test_obj.assertFalse(file_io.file_exists_v2(backup_checkpoint))
    test_obj.assertTrue(file_io.file_exists_v2(saving_filepath))
def proc_model_checkpoint_works_with_same_file_path(
        test_obj, saving_filepath):
    """Exercise ModelCheckpoint + BackupAndRestore across an interrupted fit.

    The model is fitted twice with the same ``saving_filepath`` and backup
    directory. The first fit runs with ``InterruptingCallback`` and only a
    ``RuntimeError`` whose message contains ``'Interrupting!'`` is tolerated;
    the second fit runs with ``AssertCallback``.

    Checks performed:
      * before training, ``saving_filepath`` does not exist;
      * after the first fit, both the backup checkpoint and
        ``saving_filepath`` exist;
      * after the second fit, the backup checkpoint no longer exists while
        ``saving_filepath`` still does.

    Args:
        test_obj: Test-case object; its ``assertTrue`` and ``assertFalse``
            methods are used, and it is forwarded to ``_model_setup``.
        saving_filepath: Path passed to ``ModelCheckpoint``. The backup
            directory is created next to it, under ``backup``.
    """
    model, _, train_ds, steps = _model_setup(test_obj, file_format='')
    num_epoch = 4

    # The saving_filepath shouldn't exist at the beginning (as it's unique).
    # Existence checks use file_exists_v2 rather than the legacy v1
    # file_exists entry point.
    test_obj.assertFalse(file_io.file_exists_v2(saving_filepath))

    bar_dir = os.path.join(os.path.dirname(saving_filepath), 'backup')

    try:
        model.fit(
            x=train_ds,
            epochs=num_epoch,
            steps_per_epoch=steps,
            callbacks=[
                callbacks.ModelCheckpoint(filepath=saving_filepath),
                callbacks.BackupAndRestore(backup_dir=bar_dir),
                InterruptingCallback(),
            ])
    except RuntimeError as e:
        # Only the deliberate interruption is swallowed; any other
        # RuntimeError is a genuine failure and must propagate.
        if 'Interrupting!' not in str(e):
            raise

    backup_filepath = os.path.join(bar_dir, 'checkpoint')
    test_obj.assertTrue(file_io.file_exists_v2(backup_filepath))
    test_obj.assertTrue(file_io.file_exists_v2(saving_filepath))

    model.fit(
        x=train_ds,
        epochs=num_epoch,
        steps_per_epoch=steps,
        callbacks=[
            callbacks.ModelCheckpoint(filepath=saving_filepath),
            callbacks.BackupAndRestore(backup_dir=bar_dir),
            AssertCallback(),
        ])

    # After the uninterrupted run the backup must be gone, while the regular
    # checkpoint is still in place.
    test_obj.assertFalse(file_io.file_exists_v2(backup_filepath))
    test_obj.assertTrue(file_io.file_exists_v2(saving_filepath))