def proc_model_checkpoint_works_with_same_file_path(
                test_obj, saving_filepath):
            """Exercise ModelCheckpoint together with BackupAndRestore.

            The model is fitted twice with the same `saving_filepath` and the
            same backup directory. The first run ends with
            `InterruptingCallback`; a `RuntimeError` whose message contains
            'Interrupting!' is tolerated, any other one propagates. The
            second run ends with `AssertCallback`.

            Args:
                test_obj: Test case object providing `skipTest`,
                    `assertTrue` and `assertFalse`.
                saving_filepath: Path handed to `ModelCheckpoint`; it must
                    not exist when this function starts.
            """
            if multi_process_runner.is_oss():
                test_obj.skipTest('TODO(b/170838633): Failing in OSS')

            net, _, dataset, steps_per_epoch = _model_setup(
                test_obj, file_format='')
            total_epochs = 4

            # A unique path is expected, so nothing may be there yet.
            already_there = file_io.file_exists_v2(saving_filepath)
            test_obj.assertFalse(already_there)

            backup_root = os.path.join(
                os.path.dirname(saving_filepath), 'backup')

            def run_fit(last_callback_cls):
                # Every run gets freshly constructed callback instances.
                net.fit(
                    x=dataset,
                    epochs=total_epochs,
                    steps_per_epoch=steps_per_epoch,
                    callbacks=[
                        callbacks.ModelCheckpoint(filepath=saving_filepath),
                        callbacks.BackupAndRestore(backup_dir=backup_root),
                        last_callback_cls(),
                    ])

            # First run: only the deliberate interruption is swallowed.
            try:
                run_fit(InterruptingCallback)
            except RuntimeError as err:
                deliberate = 'Interrupting!' in str(err)
                if not deliberate:
                    raise

            multi_process_runner.get_barrier().wait()

            # After the interrupted run both the backup checkpoint and the
            # regular checkpoint are expected on disk.
            backup_ckpt = os.path.join(backup_root, 'chief', 'checkpoint')
            for expected_path in (backup_ckpt, saving_filepath):
                test_obj.assertTrue(file_io.file_exists_v2(expected_path))

            # Second run, reusing the same checkpoint path and backup dir.
            run_fit(AssertCallback)

            multi_process_runner.get_barrier().wait()

            # The backup checkpoint must be gone, the saved model must remain.
            test_obj.assertFalse(file_io.file_exists_v2(backup_ckpt))
            test_obj.assertTrue(file_io.file_exists_v2(saving_filepath))
Exemple #2
0
        def proc_model_checkpoint_works_with_same_file_path(
                test_obj, saving_filepath):
            """Exercise ModelCheckpoint together with BackupAndRestore.

            The model is fitted twice with the same `saving_filepath` and the
            same backup directory. The first run ends with
            `InterruptingCallback`; a `RuntimeError` whose message contains
            'Interrupting!' is tolerated, any other one propagates. The
            second run ends with `AssertCallback`.

            Args:
                test_obj: Test case object providing `assertTrue` and
                    `assertFalse`.
                saving_filepath: Path handed to `ModelCheckpoint`; it must
                    not exist when this function starts.
            """
            net, _, dataset, steps_per_epoch = _model_setup(
                test_obj, file_format='')
            total_epochs = 4

            # A unique path is expected, so nothing may be there yet.
            already_there = file_io.file_exists(saving_filepath)
            test_obj.assertFalse(already_there)

            backup_root = os.path.join(
                os.path.dirname(saving_filepath), 'backup')

            def run_fit(last_callback_cls):
                # Every run gets freshly constructed callback instances.
                net.fit(
                    x=dataset,
                    epochs=total_epochs,
                    steps_per_epoch=steps_per_epoch,
                    callbacks=[
                        callbacks.ModelCheckpoint(filepath=saving_filepath),
                        callbacks.BackupAndRestore(backup_dir=backup_root),
                        last_callback_cls(),
                    ])

            # First run: only the deliberate interruption is swallowed.
            try:
                run_fit(InterruptingCallback)
            except RuntimeError as err:
                deliberate = 'Interrupting!' in str(err)
                if not deliberate:
                    raise

            # After the interrupted run both the backup checkpoint and the
            # regular checkpoint are expected on disk.
            backup_ckpt = os.path.join(backup_root, 'checkpoint')
            for expected_path in (backup_ckpt, saving_filepath):
                test_obj.assertTrue(file_io.file_exists(expected_path))

            # Second run, reusing the same checkpoint path and backup dir.
            run_fit(AssertCallback)

            # The backup checkpoint must be gone, the saved model must remain.
            test_obj.assertFalse(file_io.file_exists(backup_ckpt))
            test_obj.assertTrue(file_io.file_exists(saving_filepath))