Ejemplo n.º 1
0
    def __call__(self, checkpoint: Mapping, filename: str, metadata: Optional[Mapping] = None) -> None:
        """Save ``checkpoint`` while ClearML weights-file callbacks are registered.

        A ``_CallbacksContext`` is built for the checkpoint key ``(self.dirname, basename)``.
        Its ``pre_callback``/``post_callback`` are registered on ClearML's ``WeightsFileHandler``
        only for the duration of the parent class's ``__call__``. Every callback that was
        registered is removed again afterwards, even if registration or saving raises.

        Args:
            checkpoint: checkpoint object, forwarded unchanged to the parent saver.
            filename: target file name, forwarded to the parent saver and to the callbacks context.
            metadata: optional mapping. Its ``"basename"`` entry selects the slot list in
                ``self._checkpoint_slots``. If ``metadata`` is ``None`` or has no ``"basename"``
                key, a warning is emitted and ``"checkpoint"`` is used as the basename.

        Raises:
            RuntimeError: if ``clearml`` cannot be imported (chained from the ``ImportError``).
        """
        try:
            from clearml.binding.frameworks import WeightsFileHandler
        except ImportError as err:
            raise RuntimeError(
                "This contrib module requires clearml to be installed. "
                "You may install clearml using: \n pip install clearml \n"
            ) from err

        try:
            basename = metadata["basename"]  # type: ignore[index]
        except (TypeError, KeyError):
            # TypeError: metadata is None (not subscriptable); KeyError: no "basename" entry.
            warnings.warn("Checkpoint metadata missing or basename cannot be found")
            basename = "checkpoint"

        checkpoint_key = (self.dirname, basename)

        cb_context = self._CallbacksContext(
            callback_type=WeightsFileHandler.CallbackType,
            slots=self._checkpoint_slots[checkpoint_key],
            checkpoint_key=str(checkpoint_key),
            filename=filename,
            basename=basename,
            metadata=metadata,
        )

        # Callbacks are registered on the WeightsFileHandler class itself, so each one must be
        # removed as soon as it has been added. The post-callback registration is inside the
        # outer try so that a failure there still unregisters the pre-callback.
        pre_cb_id = WeightsFileHandler.add_pre_callback(cb_context.pre_callback)
        try:
            post_cb_id = WeightsFileHandler.add_post_callback(cb_context.post_callback)
            try:
                super(ClearMLSaver, self).__call__(checkpoint, filename, metadata)
            finally:
                WeightsFileHandler.remove_post_callback(post_cb_id)
        finally:
            WeightsFileHandler.remove_pre_callback(pre_cb_id)
Ejemplo n.º 2
0
def test_clearml_saver_callbacks():
    """Check ``ClearMLSaver._CallbacksContext`` pre/post callbacks across slot reuse.

    Simulates four saves of a ``"best_model"`` checkpoint with ``n_saved=2``. The first two
    saves are expected to use slots 0 and 1. Each later save first frees the slot holding
    the oldest remaining file, as a checkpoint handler does on removal, and is then expected
    to reuse that slot.

    For every save:

    * The pre-callback must produce an ``upload_filename`` containing
      ``"{basename}_{slot}.pt"`` and a ``local_model_id`` containing the checkpoint key.
    * The post-callback must produce a model whose string ``name`` contains the filename
      and whose string ``comment`` contains the basename and the score name.
    """
    mock_task = MagicMock(spec=clearml.Task)
    mock_task.name = "check-task"

    mock_model = MagicMock(spec=clearml.OutputModel)

    model_info = WeightsFileHandler.ModelInfo(
        model=mock_model,
        upload_filename="test.pt",
        local_model_path="",
        local_model_id="",
        framework=Framework.pytorch,
        task=mock_task,
    )

    # spec_set limits the mock to ModelInfo's attributes, so setting a misspelled
    # attribute raises instead of silently succeeding.
    mock_model_info = MagicMock(spec_set=model_info)

    # Simulate 4 calls to save model and 2 to remove (n_saved=2).
    # The score in each filename matches the "priority" of the metadata entry at the same index.
    filenames = [
        "best_model_5_val_acc=0.123.pt",
        "best_model_6_val_acc=0.234.pt",
        "best_model_7_val_acc=0.345.pt",
        "best_model_8_val_acc=0.456.pt",
    ]
    metadata_list = [
        {"basename": "best_model", "score_name": "val_acc", "priority": 0.123},
        {"basename": "best_model", "score_name": "val_acc", "priority": 0.234},
        {"basename": "best_model", "score_name": "val_acc", "priority": 0.345},
        {"basename": "best_model", "score_name": "val_acc", "priority": 0.456},
    ]
    dirname = "/tmp/test"
    n_saved = 2

    # One slot list per checkpoint key, shared by every context created below.
    checkpoint_slots = defaultdict(list)

    for i, (filename, metadata) in enumerate(zip(filenames, metadata_list)):
        mock_model_info.upload_filename = filename

        # Slot the pre-callback is expected to encode in upload_filename for this save.
        expected_slot = i
        if i >= n_saved:
            # Free the slot of the file being rotated out; the new file should reuse it.
            filename_to_remove = filenames[i % n_saved]
            for slots in checkpoint_slots.values():
                if filename_to_remove in slots:
                    slots[slots.index(filename_to_remove)] = None
                    expected_slot = i % n_saved
                    break

        basename = metadata["basename"]
        checkpoint_key = (dirname, basename)

        context = ClearMLSaver._CallbacksContext(
            callback_type=WeightsFileHandler.CallbackType,
            slots=checkpoint_slots[checkpoint_key],
            checkpoint_key=str(checkpoint_key),
            filename=filename,
            basename=basename,
            metadata=metadata,
        )

        output_model_info = context.pre_callback(
            str(WeightsFileHandler.CallbackType.save), mock_model_info
        )
        assert hasattr(output_model_info, "upload_filename")
        assert f"{basename}_{expected_slot}.pt" in output_model_info.upload_filename
        assert hasattr(output_model_info, "local_model_id")
        assert str(checkpoint_key) in output_model_info.local_model_id

        output_model_info = context.post_callback(
            str(WeightsFileHandler.CallbackType.save), mock_model_info
        )
        assert hasattr(output_model_info, "model")
        assert hasattr(output_model_info.model, "name")
        assert hasattr(output_model_info.model, "comment")
        assert isinstance(output_model_info.model.name, str)
        assert filename in output_model_info.model.name
        assert isinstance(output_model_info.model.comment, str)
        assert metadata["basename"] in output_model_info.model.comment
        assert metadata["score_name"] in output_model_info.model.comment