def _get_file_path(self, epoch, logs):
    """Returns the file path for checkpoint.

    Formats `self.filepath` with the 1-based epoch number and the entries of
    `logs`, then passes the result, together with the model's distribution
    strategy, through `distributed_file_utils.write_filepath`. The outcome is
    stored in `self._write_filepath`.

    When `self.opt` is truthy, a sibling path with `_opt` inserted right
    before the file extension (e.g. `ckpt.h5` -> `ckpt_opt.h5`) is derived
    the same way and stored in `self._write_filepath_opt`.

    Args:
      epoch: Zero-based epoch index; `epoch + 1` is exposed to the template
        as `{epoch}`.
      logs: Mapping whose entries are exposed to the template as keyword
        fields.

    Returns:
      The value stored in `self._write_filepath`.

    Raises:
      KeyError: If `self.filepath` references a placeholder that is neither
        `epoch` nor a key of `logs`.
    """
    # pylint: disable=protected-access
    # Imported locally so this method does not depend on the module's
    # top-level imports.
    import os

    try:
        file_path = self.filepath.format(epoch=epoch + 1, **logs)
    except KeyError as e:
        raise KeyError('Failed to format this callback filepath: "{}". '
                       'Reason: {}'.format(self.filepath, e))

    strategy = self.model.distribute_strategy
    self._write_filepath = distributed_file_utils.write_filepath(
        file_path, strategy)

    if self.opt:
        # Split on the extension of the final path component only. Splitting
        # the whole string on "." fails with IndexError when the path has no
        # dot, and edits a directory name when the only dot lives there.
        root, extension = os.path.splitext(file_path)
        file_path_opt = root + '_opt' + extension
        self._write_filepath_opt = distributed_file_utils.write_filepath(
            file_path_opt, strategy)

    return self._write_filepath
def testChiefWriteDirAndFilePath(self):
    """Checks that the mocked chief strategy gets its paths back unchanged."""
    temp_dir = self.get_temp_dir()
    target_file = os.path.join(temp_dir, 'foo.bar')
    chief = DistributedFileUtilsTest.MockedChiefStrategy()

    # File path: expected to come back exactly as passed in.
    resolved_file = distributed_file_utils.write_filepath(target_file, chief)
    self.assertEqual(resolved_file, target_file)

    # Directory path: likewise expected to be returned untouched.
    resolved_dir = distributed_file_utils.write_dirpath(temp_dir, chief)
    self.assertEqual(resolved_dir, temp_dir)
def testWorkerWriteDirAndFilePath(self):
    """Checks that the mocked worker strategy is redirected to `workertemp_3`."""
    temp_dir = self.get_temp_dir()
    target_file = os.path.join(temp_dir, 'foo.bar')
    worker = DistributedFileUtilsTest.MockedWorkerStrategy()

    # File path: expected to land inside the worker's temporary subdirectory.
    resolved_file = distributed_file_utils.write_filepath(target_file, worker)
    expected_file = os.path.join(temp_dir, 'workertemp_3', 'foo.bar')
    self.assertEqual(resolved_file, expected_file)

    # Directory path: expected to be that temporary subdirectory itself.
    resolved_dir = distributed_file_utils.write_dirpath(temp_dir, worker)
    expected_dir = os.path.join(temp_dir, 'workertemp_3')
    self.assertEqual(resolved_dir, expected_dir)
def _get_file_path(self, epoch, logs):
    """Builds the checkpoint path for this epoch and records it.

    The `self.filepath` template is filled with `epoch` (1-based), `timer`
    (the current `datetime.datetime.now()` rendered as `%m%d_%H%M%S`) and
    every entry of `logs`. The formatted path and the model's distribution
    strategy are then handed to `distributed_file_utils.write_filepath`, and
    the result is kept in `self._write_filepath`.

    Args:
      epoch: Zero-based epoch index.
      logs: Mapping of extra template fields.

    Returns:
      The value stored in `self._write_filepath`.

    Raises:
      KeyError: If the template names a field that was not supplied.
    """
    try:
        formatted = self.filepath.format(
            epoch=epoch + 1,
            timer=datetime.datetime.now().strftime('%m%d_%H%M%S'),
            **logs)
    except KeyError as err:
        message = ('Failed to format this callback filepath: "{}". '
                   'Reason: {}').format(self.filepath, err)
        raise KeyError(message)

    strategy = self.model.distribute_strategy
    self._write_filepath = distributed_file_utils.write_filepath(
        formatted, strategy)
    return self._write_filepath
def _get_file_path(self, epoch, logs):
    """Resolves the checkpoint path for the given epoch.

    Fills the `self.filepath` template with the 1-based epoch number and the
    entries of `logs`, passes the result and the model's distribution
    strategy to `distributed_file_utils.write_filepath`, and caches the
    outcome in `self._write_filepath`.

    Args:
      epoch: Zero-based epoch index.
      logs: Mapping of logged values made available to the template.

    Returns:
      The value stored in `self._write_filepath`.

    Raises:
      KeyError: If the template references a field missing from `logs`.
    """
    # pylint: disable=protected-access
    template = self.filepath
    try:
        # The template may hold placeholders such as `{epoch:02d}` or
        # `{mape:.2f}`; formatting fails when a placeholder has no matching
        # entry among the logged metrics.
        formatted = template.format(epoch=epoch + 1, **logs)
    except KeyError as err:
        raise KeyError(
            'Failed to format this callback filepath: "{}". '
            'Reason: {}'.format(self.filepath, err))
    else:
        self._write_filepath = distributed_file_utils.write_filepath(
            formatted, self.model.distribute_strategy)
    return self._write_filepath
def proc_model_checkpoint_saves_on_chief_but_not_otherwise(
        test_obj, file_format):
    """Verifies that only the chief ends up with a saved checkpoint.

    Trains the model from `_model_setup` for two epochs with a
    `ModelCheckpoint` callback pointed at a per-task path, then asserts that
    a checkpoint exists at that path, and at the path produced by
    `distributed_file_utils.write_filepath`, exactly when
    `test_base.is_chief()` is true.

    Args:
      test_obj: Test case providing `get_temp_dir`, `assertFalse` and
        `assertEqual`.
      file_format: Forwarded unchanged to `_model_setup`.
    """
    # NOTE(review): `save_weights_only` is neither a parameter nor a local
    # here; presumably it comes from an enclosing scope - confirm.
    model, template_path, dataset, step_count = _model_setup(
        test_obj, file_format)
    epoch_count = 2
    _, suffix = os.path.splitext(template_path)

    # Fold the task type and index into the file name so that every worker
    # gets its own path. In normal use all workers share one path; distinct
    # ones are used here only to observe that the chief saves a checkpoint
    # while non-chief workers do not.
    temp_dir = test_obj.get_temp_dir()
    unique_name = 'checkpoint_%s_%d%s' % (
        test_base.get_task_type(), test_base.get_task_index(), suffix)
    saving_filepath = os.path.join(temp_dir, unique_name)

    # Being unique, the path must not exist before training starts.
    existed_before = training_state.checkpoint_exists(saving_filepath)
    test_obj.assertFalse(existed_before)

    checkpoint_callback = callbacks.ModelCheckpoint(
        filepath=saving_filepath, save_weights_only=save_weights_only)
    model.fit(
        x=dataset,
        epochs=epoch_count,
        steps_per_epoch=step_count,
        validation_data=dataset,
        validation_steps=step_count,
        callbacks=[checkpoint_callback])

    # The checkpoint should be present on the chief and absent elsewhere.
    saved_at_target = training_state.checkpoint_exists(saving_filepath)
    test_obj.assertEqual(saved_at_target, test_base.is_chief())

    # On the chief, `write_filepath` should simply return `saving_filepath`,
    # so the checkpoint is found. On non-chief workers, the temporary path it
    # generates should no longer contain the checkpoint, which has been
    # deleted.
    write_path = distributed_file_utils.write_filepath(
        saving_filepath, model._distribution_strategy)
    saved_at_write_path = training_state.checkpoint_exists(write_path)
    test_obj.assertEqual(saved_at_write_path, test_base.is_chief())