Example 1
0
 def _get_file_path(self, epoch, logs):
     """Returns the file path for checkpoint."""
     # pylint: disable=protected-access
     try:
         file_path = self.filepath.format(epoch=epoch + 1, **logs)
     except KeyError as e:
         raise KeyError('Failed to format this callback filepath: "{}". '
                        'Reason: {}'.format(self.filepath, e))
     self._write_filepath = distributed_file_utils.write_filepath(
         file_path, self.model.distribute_strategy)
     if self.opt:
         li = file_path.split(".")
         li[-2] += "_opt"
         filr_path_opt = ".".join(li)
         self._write_filepath_opt = distributed_file_utils.write_filepath(
             filr_path_opt, self.model.distribute_strategy)
     return self._write_filepath
 def testChiefWriteDirAndFilePath(self):
     """For a chief worker, write paths are returned unchanged."""
     temp_dir = self.get_temp_dir()
     chief_strategy = DistributedFileUtilsTest.MockedChiefStrategy()
     target_file = os.path.join(temp_dir, 'foo.bar')
     # The chief writes directly to the requested locations, so both
     # helpers should hand back exactly what they were given.
     written_file = distributed_file_utils.write_filepath(
         target_file, chief_strategy)
     self.assertEqual(written_file, target_file)
     written_dir = distributed_file_utils.write_dirpath(
         temp_dir, chief_strategy)
     self.assertEqual(written_dir, temp_dir)
 def testWorkerWriteDirAndFilePath(self):
     """Non-chief workers are redirected to a per-worker temp subdir."""
     temp_dir = self.get_temp_dir()
     worker_strategy = DistributedFileUtilsTest.MockedWorkerStrategy()
     target_file = os.path.join(temp_dir, 'foo.bar')
     # The mocked worker has task index 3, so its writes land under a
     # 'workertemp_3' staging directory inside the requested dir.
     staging_dir = os.path.join(temp_dir, 'workertemp_3')
     self.assertEqual(
         distributed_file_utils.write_filepath(target_file, worker_strategy),
         os.path.join(staging_dir, 'foo.bar'))
     self.assertEqual(
         distributed_file_utils.write_dirpath(temp_dir, worker_strategy),
         staging_dir)
Example 4
0
 def _get_file_path(self, epoch, logs):
     """Returns the file path for checkpoint."""
     try:
         file_path = self.filepath.format(
             epoch=epoch + 1,
             timer=datetime.datetime.now().strftime('%m%d_%H%M%S'),
             **logs)
     except KeyError as e:
         raise KeyError('Failed to format this callback filepath: "{}". '
                        'Reason: {}'.format(self.filepath, e))
     self._write_filepath = distributed_file_utils.write_filepath(
         file_path, self.model.distribute_strategy)
     return self._write_filepath
Example 5
0
 def _get_file_path(self, epoch, logs):
     """Returns the file path for checkpoint."""
     # pylint: disable=protected-access
     try:
         # `filepath` may contain placeholders such as `{epoch:02d}` and
         # `{mape:.2f}`. A mismatch between logged metrics and the path's
         # placeholders can cause formatting to fail.
         file_path = self.filepath.format(epoch=epoch + 1, **logs)
     except KeyError as e:
         raise KeyError('Failed to format this callback filepath: "{}". '
                        'Reason: {}'.format(self.filepath, e))
     self._write_filepath = distributed_file_utils.write_filepath(
         file_path, self.model.distribute_strategy)
     return self._write_filepath
Example 6
0
        def proc_model_checkpoint_saves_on_chief_but_not_otherwise(
                test_obj, file_format):
            """Trains with ModelCheckpoint and asserts only the chief saves.

            Runs a short multi-worker fit with a `ModelCheckpoint` callback
            and verifies that a checkpoint file exists afterwards if and
            only if this process is the chief worker.

            NOTE(review): closes over `save_weights_only` from the enclosing
            scope — confirm it is bound where this helper is defined.

            Args:
                test_obj: The test case instance providing assertions and
                    a temp directory.
                file_format: Checkpoint file format passed to
                    `_model_setup` (e.g. weights-file extension selector).
            """

            model, saving_filepath, train_ds, steps = _model_setup(
                test_obj, file_format)
            num_epoch = 2
            extension = os.path.splitext(saving_filepath)[1]

            # Incorporate type/index information in saving_filepath so every
            # worker gets a unique path. In normal usage saving_filepath is
            # the same for all workers; distinct paths here let us observe
            # that the chief saves a checkpoint while non-chiefs do not.
            saving_filepath = os.path.join(
                test_obj.get_temp_dir(),
                'checkpoint_%s_%d%s' % (test_base.get_task_type(),
                                        test_base.get_task_index(), extension))

            # The saving_filepath shouldn't exist at the beginning (as it's unique).
            test_obj.assertFalse(
                training_state.checkpoint_exists(saving_filepath))

            model.fit(x=train_ds,
                      epochs=num_epoch,
                      steps_per_epoch=steps,
                      validation_data=train_ds,
                      validation_steps=steps,
                      callbacks=[
                          callbacks.ModelCheckpoint(
                              filepath=saving_filepath,
                              save_weights_only=save_weights_only)
                      ])

            # After training, a checkpoint at saving_filepath should exist
            # exactly when this process is the chief.
            test_obj.assertEqual(
                training_state.checkpoint_exists(saving_filepath),
                test_base.is_chief())

            # For the chief, `write_filepath` returns `saving_filepath`
            # itself, so a checkpoint exists there. For non-chief workers it
            # returns a worker-local temporary path; presumably that staging
            # location has been cleaned up after training, so no checkpoint
            # should remain there — NOTE(review): confirm the cleanup
            # semantics of `write_filepath` temp dirs.
            test_obj.assertEqual(
                training_state.checkpoint_exists(
                    distributed_file_utils.write_filepath(
                        saving_filepath, model._distribution_strategy)),
                test_base.is_chief())