# Example #1
# 0
def test_postprocessor_str():
    """``str(_Postprocessor)`` lists each wrapped stage on its own tab-indented line."""
    # Positional order mirrors the _Postprocessor(...) call used when building
    # postprocessors: uncollate_fn, per_batch_transform, per_sample_transform, serializer.
    uncollate_fn, batch_fn, sample_fn = default_uncollate, torch.relu, torch.softmax
    postprocessor = _Postprocessor(uncollate_fn, batch_fn, sample_fn, None)

    expected = (
        "_Postprocessor:\n"
        "\t(per_batch_transform): FuncModule(relu)\n"
        "\t(uncollate_fn): FuncModule(default_uncollate)\n"
        "\t(per_sample_transform): FuncModule(softmax)\n"
        "\t(serializer): None"
    )
    assert str(postprocessor) == expected
    def _create_uncollate_postprocessors(
        self,
        stage: RunningStage,
        is_serving: bool = False,
    ) -> _Postprocessor:
        """Build the ``_Postprocessor`` applied to model outputs for ``stage``.

        Every hook name in ``self.POSTPROCESS_FUNCS`` is resolved, via
        ``self._resolve_function_hierarchy``, to the name of the attribute on
        ``self._postprocess_pipeline`` to use for ``stage``. If the postprocess
        object has a truthy ``_save_path``, a saving hook is chosen as well:
        the resolved ``save_sample`` when it is overridden for this stage's
        prefix, otherwise the resolved ``save_data``.

        Args:
            stage: Running stage whose prefixed hook overrides are looked up.
            is_serving: When True, no serializer is attached; otherwise
                ``self._serializer`` is passed through. Also forwarded to
                ``_Postprocessor`` as ``is_serving``.

        Returns:
            A ``_Postprocessor`` built from the resolved ``uncollate``,
            ``per_batch_transform`` and ``per_sample_transform`` hooks plus the
            ``save_fn`` / ``save_per_sample`` settings (both ``None`` when no
            save path is configured).
        """
        save_per_sample = None
        save_fn = None

        postprocess: Postprocess = self._postprocess_pipeline

        # Hook name -> attribute name on `postprocess` to use for this stage.
        func_names: Dict[str, str] = {
            k: self._resolve_function_hierarchy(k,
                                                postprocess,
                                                stage,
                                                object_type=Postprocess)
            for k in self.POSTPROCESS_FUNCS
        }

        # since postprocessing is exclusive for prediction, we don't have to check the resolution hierarchy here.
        if postprocess._save_path:
            # `save_per_sample` stays the override flag; the hook to call always
            # goes into `save_fn`. Previously the flag variable was re-annotated
            # as Callable and rebound to the `save_sample` hook, leaving
            # `save_fn` as None in that branch.
            # NOTE(review): assumes `_Postprocessor` invokes `save_fn` and only
            # tests `save_per_sample` for truthiness — confirm against its definition.
            save_per_sample = self._is_overriden_recursive(
                "save_sample",
                postprocess,
                Postprocess,
                prefix=_STAGES_PREFIX[stage],
            )
            save_hook = "save_sample" if save_per_sample else "save_data"
            save_fn = getattr(postprocess, func_names[save_hook])

        return _Postprocessor(
            getattr(postprocess, func_names["uncollate"]),
            getattr(postprocess, func_names["per_batch_transform"]),
            getattr(postprocess, func_names["per_sample_transform"]),
            serializer=None if is_serving else self._serializer,
            save_fn=save_fn,
            save_per_sample=save_per_sample,
            is_serving=is_serving,
        )