def test_postprocessor_str():
    """str() of a ``_Postprocessor`` should list each wrapped hook on its own line."""
    postprocessor = _Postprocessor(default_uncollate, torch.relu, torch.softmax, None)
    # Build the expected representation line-by-line; joining with "\n" yields
    # exactly the string the original literal concatenation produced.
    expected = "\n".join([
        "_Postprocessor:",
        "\t(per_batch_transform): FuncModule(relu)",
        "\t(uncollate_fn): FuncModule(default_uncollate)",
        "\t(per_sample_transform): FuncModule(softmax)",
        "\t(serializer): None",
    ])
    assert str(postprocessor) == expected
def _create_uncollate_postprocessors(
    self,
    stage: RunningStage,
    is_serving: bool = False,
) -> _Postprocessor:
    """Build the ``_Postprocessor`` for ``stage`` from the attached postprocess pipeline.

    Resolves each postprocess hook through the function hierarchy for the given
    stage, optionally wiring up per-sample or per-batch save hooks when a save
    path is configured, and returns the assembled ``_Postprocessor``.
    """
    save_per_sample = None
    save_fn = None
    pipeline: Postprocess = self._postprocess_pipeline

    # Map each hook name to its stage-resolved attribute name on the pipeline.
    resolved_names: Dict[str, str] = {
        hook: self._resolve_function_hierarchy(hook, pipeline, stage, object_type=Postprocess)
        for hook in self.POSTPROCESS_FUNCS
    }

    # Postprocessing is exclusive to prediction, so no further resolution-
    # hierarchy check is needed for the save hooks.
    if pipeline._save_path:
        # NOTE: ``save_per_sample`` first holds the override flag, then (when
        # truthy) is replaced by the resolved callable — mirroring the original
        # so a ``False`` flag is passed through unchanged.
        save_per_sample = self._is_overriden_recursive(
            "save_sample", pipeline, Postprocess, prefix=_STAGES_PREFIX[stage]
        )
        if save_per_sample:
            save_per_sample = getattr(pipeline, resolved_names["save_sample"])
        else:
            save_fn = getattr(pipeline, resolved_names["save_data"])

    uncollate_fn = getattr(pipeline, resolved_names["uncollate"])
    per_batch_fn = getattr(pipeline, resolved_names["per_batch_transform"])
    per_sample_fn = getattr(pipeline, resolved_names["per_sample_transform"])

    return _Postprocessor(
        uncollate_fn,
        per_batch_fn,
        per_sample_fn,
        serializer=None if is_serving else self._serializer,
        save_fn=save_fn,
        save_per_sample=save_per_sample,
        is_serving=is_serving,
    )