def __call__(self, checkpoint: Mapping, filename: str, metadata: Optional[Mapping] = None) -> None:
    """Save ``checkpoint`` to ``filename``, wrapping the underlying save with
    ClearML weight-file callbacks so the framework upload is intercepted.

    Args:
        checkpoint: mapping holding the state objects to serialize.
        filename: target file name for the checkpoint.
        metadata: optional mapping; ``metadata["basename"]`` keys the slot
            bookkeeping and falls back to ``"checkpoint"`` when absent.

    Raises:
        RuntimeError: if the ``clearml`` package is not installed.
    """
    try:
        # Imported lazily so this module stays importable without clearml.
        # NOTE: the previously-imported ``clearml.Model`` was never used in
        # this method and has been dropped (the availability check is
        # unchanged - this import fails identically when clearml is missing).
        from clearml.binding.frameworks import WeightsFileHandler
    except ImportError:
        raise RuntimeError(
            "This contrib module requires clearml to be installed. "
            "You may install clearml using: \n pip install clearml \n"
        )

    # EAFP: metadata may be None (TypeError) or lack "basename" (KeyError).
    try:
        basename = metadata["basename"]  # type: ignore[index]
    except (TypeError, KeyError):
        warnings.warn("Checkpoint metadata missing or basename cannot be found")
        basename = "checkpoint"

    # Slots are shared per (dirname, basename) so rotated checkpoints with the
    # same basename reuse the same slot list.
    checkpoint_key = (self.dirname, basename)

    cb_context = self._CallbacksContext(
        callback_type=WeightsFileHandler.CallbackType,
        slots=self._checkpoint_slots[checkpoint_key],
        checkpoint_key=str(checkpoint_key),
        filename=filename,
        basename=basename,
        metadata=metadata,
    )

    pre_cb_id = WeightsFileHandler.add_pre_callback(cb_context.pre_callback)
    post_cb_id = WeightsFileHandler.add_post_callback(cb_context.post_callback)

    try:
        super(ClearMLSaver, self).__call__(checkpoint, filename, metadata)
    finally:
        # Always deregister the global callbacks, even if the save raises,
        # so they cannot leak into unrelated saves.
        WeightsFileHandler.remove_pre_callback(pre_cb_id)
        WeightsFileHandler.remove_post_callback(post_cb_id)
def test_clearml_saver_callbacks():
    """Exercise ``ClearMLSaver._CallbacksContext`` pre/post callbacks directly,
    simulating a rotating checkpoint with ``n_saved=2`` slots (4 saves, the
    last 2 of which evict the oldest slot first)."""
    # Mocks constrained to the real ClearML interfaces via spec/spec_set.
    mock_task = MagicMock(spec=clearml.Task)
    mock_task.name = "check-task"

    mock_model = MagicMock(spec=clearml.OutputModel)

    model_info = WeightsFileHandler.ModelInfo(
        model=mock_model,
        upload_filename="test.pt",
        local_model_path="",
        local_model_id="",
        framework=Framework.pytorch,
        task=mock_task,
    )

    # spec_set forbids setting attributes that ModelInfo does not define.
    mock_model_info = MagicMock(spec_set=model_info)

    # Simulate 4 calls to save model and 2 to remove (n_saved=2)
    # NOTE(review): third filename says 0.356 while the matching metadata
    # priority is 0.345 - the two are never asserted against each other, but
    # confirm whether this mismatch is intentional.
    filenames = [
        "best_model_5_val_acc=0.123.pt",
        "best_model_6_val_acc=0.234.pt",
        "best_model_7_val_acc=0.356.pt",
        "best_model_8_val_acc=0.456.pt",
    ]
    metadata_list = [
        {"basename": "best_model", "score_name": "val_acc", "priority": 0.123},
        {"basename": "best_model", "score_name": "val_acc", "priority": 0.234},
        {"basename": "best_model", "score_name": "val_acc", "priority": 0.345},
        {"basename": "best_model", "score_name": "val_acc", "priority": 0.456},
    ]
    dirname = "/tmp/test"

    # Mirrors ClearMLSaver's per-(dirname, basename) slot bookkeeping.
    _checkpoint_slots = defaultdict(list)

    n_saved = 2

    for i, (filename, metadata) in enumerate(zip(filenames, metadata_list)):
        mock_model_info.upload_filename = filename

        if i >= n_saved:
            # Remove
            # Evict the oldest retained checkpoint: free its slot by
            # replacing the filename with None, then reuse that slot index.
            filename_to_remove = filenames[i % n_saved]
            for slots in _checkpoint_slots.values():
                try:
                    slots[slots.index(filename_to_remove)] = None
                except ValueError:
                    # filename not in this slot list - keep scanning.
                    pass
                else:
                    # Slot freed: rebind i to the recycled slot index so the
                    # upload_filename assertion below matches the reused slot.
                    i = i % n_saved
                    break

        basename = metadata["basename"]
        checkpoint_key = (dirname, basename)

        context = ClearMLSaver._CallbacksContext(
            callback_type=WeightsFileHandler.CallbackType,
            slots=_checkpoint_slots[checkpoint_key],
            checkpoint_key=str(checkpoint_key),
            filename=filename,
            basename=basename,
            metadata=metadata,
        )

        # Pre-callback should rename the upload to "<basename>_<slot>.pt" and
        # stamp the checkpoint key into the local model id.
        output_model_info = context.pre_callback(str(WeightsFileHandler.CallbackType.save), mock_model_info)
        assert (
            hasattr(output_model_info, "upload_filename") and f"{basename}_{i}.pt" in output_model_info.upload_filename
        )
        assert hasattr(output_model_info, "local_model_id") and str(checkpoint_key) in output_model_info.local_model_id

        # Post-callback should fill in a human-readable model name/comment
        # containing the filename, basename and score name.
        output_model_info = context.post_callback(str(WeightsFileHandler.CallbackType.save), mock_model_info)
        assert hasattr(output_model_info, "model") and hasattr(output_model_info.model, "name")
        assert hasattr(output_model_info, "model") and hasattr(output_model_info.model, "comment")
        assert isinstance(output_model_info.model.name, str) and filename in output_model_info.model.name
        assert (
            isinstance(output_model_info.model.comment, str)
            and metadata["basename"] in output_model_info.model.comment
            and metadata["score_name"] in output_model_info.model.comment
        )