def unet_tta(config, suffix=''):
    """Build the test-time-augmentation (TTA) inference pipeline for the U-Net.

    Graph: TTA preprocessing -> ``PyTorchUNet`` -> TTA aggregator -> rename
    ``aggregated_prediction`` to ``mask_prediction`` -> size adjustment of the
    predicted masks to ``config.general.original_size``.

    Args:
        config: Experiment configuration. Reads ``general.loader_mode``,
            ``general.original_size``, ``model['unet']``, ``tta_aggregator``
            and ``execution.experiment_dir``.
        suffix (str): Appended to the name of every step created here.

    Returns:
        Step: The final ``mask_resize`` step; its output key is
        ``resized_images``.

    Raises:
        NotImplementedError: If ``config.general.loader_mode`` is neither
            ``'resize_and_pad'`` nor ``'resize'``. This is checked before any
            step is constructed.
    """
    # Resolve the size-adjustment function first, so an unsupported loader
    # mode fails fast instead of after the model/aggregator steps exist.
    loader_mode = config.general.loader_mode
    if loader_mode == 'resize_and_pad':
        size_adjustment_function = partial(
            postprocessing.crop_image,
            target_size=config.general.original_size)
    elif loader_mode == 'resize':
        size_adjustment_function = partial(
            postprocessing.resize_image,
            target_size=config.general.original_size)
    else:
        # NOTE(review): network_tta also accepts 'stacking' (mapped to
        # resize_image); confirm whether the U-Net pipeline should as well.
        raise NotImplementedError(
            'Unsupported loader_mode: {}'.format(loader_mode))

    preprocessing, tta_generator = pipelines.preprocessing_inference_tta(
        config, model_name='unet')

    unet = Step(name='unet{}'.format(suffix),
                transformer=models.PyTorchUNet(**config.model['unet']),
                input_data=['callback_input'],
                input_steps=[preprocessing],
                is_trainable=True,
                experiment_directory=config.execution.experiment_dir)

    tta_aggregator = pipelines.aggregator(
        'tta_aggregator{}'.format(suffix),
        unet,
        tta_generator=tta_generator,
        experiment_directory=config.execution.experiment_dir,
        config=config.tta_aggregator)

    # Expose the aggregated output under the 'mask_prediction' key that the
    # resize step below reads.
    prediction_renamed = Step(
        name='prediction_renamed{}'.format(suffix),
        transformer=IdentityOperation(),
        input_steps=[tta_aggregator],
        adapter=Adapter({
            'mask_prediction':
            E(tta_aggregator.name, 'aggregated_prediction')
        }),
        experiment_directory=config.execution.experiment_dir)

    mask_resize = Step(name='mask_resize{}'.format(suffix),
                       transformer=utils.make_apply_transformer(
                           size_adjustment_function,
                           output_name='resized_images',
                           apply_on=['images']),
                       input_steps=[prediction_renamed],
                       adapter=Adapter({
                           'images':
                           E(prediction_renamed.name, 'mask_prediction'),
                       }),
                       experiment_directory=config.execution.experiment_dir)

    return mask_resize


def network_tta(config, suffix=''):
    """Build the test-time-augmentation (TTA) inference pipeline for the network.

    Graph: TTA preprocessing -> segmentation model (the depth-aware variant
    when the module-level ``USE_DEPTH`` flag is set) -> TTA aggregator ->
    rename ``aggregated_prediction`` to ``mask_prediction`` -> size adjustment
    of the predicted masks to ``config.general.original_size``.

    Args:
        config: Experiment configuration. Reads ``general.loader_mode``,
            ``general.original_size``, ``model['network']``,
            ``tta_aggregator`` and ``execution.experiment_dir``.
        suffix (str): Appended to the name of every step created here.

    Returns:
        Step: The final ``mask_resize`` step; its output key is
        ``resized_images``.

    Raises:
        NotImplementedError: If the module-level ``SECOND_LEVEL`` flag is set,
            or if ``config.general.loader_mode`` is not one of
            ``'resize_and_pad'``, ``'resize'`` or ``'stacking'``. Both checks
            run before any step is constructed.
    """
    if SECOND_LEVEL:
        raise NotImplementedError('Second level does not work with TTA')

    # Resolve the size-adjustment function first, so an unsupported loader
    # mode fails fast instead of after the model/aggregator steps exist.
    loader_mode = config.general.loader_mode
    if loader_mode == 'resize_and_pad':
        size_adjustment_function = partial(postprocessing.crop_image, target_size=config.general.original_size)
    elif loader_mode in ('resize', 'stacking'):
        size_adjustment_function = partial(postprocessing.resize_image, target_size=config.general.original_size)
    else:
        raise NotImplementedError('Unsupported loader_mode: {}'.format(loader_mode))

    preprocessing, tta_generator = pipelines.preprocessing_inference_tta(config, model_name='network')

    Network = models.SegmentationModelWithDepth if USE_DEPTH else models.SegmentationModel

    network = Step(name='network{}'.format(suffix),
                   transformer=Network(**config.model['network']),
                   input_data=['callback_input'],
                   input_steps=[preprocessing],
                   is_trainable=True,
                   experiment_directory=config.execution.experiment_dir)

    tta_aggregator = pipelines.aggregator('tta_aggregator{}'.format(suffix), network,
                                          tta_generator=tta_generator,
                                          experiment_directory=config.execution.experiment_dir,
                                          config=config.tta_aggregator)

    # Expose the aggregated output under the 'mask_prediction' key that the
    # resize step below reads.
    prediction_renamed = Step(name='prediction_renamed{}'.format(suffix),
                              transformer=IdentityOperation(),
                              input_steps=[tta_aggregator],
                              adapter=Adapter({'mask_prediction': E(tta_aggregator.name, 'aggregated_prediction')}),
                              experiment_directory=config.execution.experiment_dir)

    mask_resize = Step(name='mask_resize{}'.format(suffix),
                       transformer=utils.make_apply_transformer(size_adjustment_function,
                                                                output_name='resized_images',
                                                                apply_on=['images']),
                       input_steps=[prediction_renamed],
                       adapter=Adapter({'images': E(prediction_renamed.name, 'mask_prediction')}),
                       experiment_directory=config.execution.experiment_dir)

    return mask_resize


def inference_segmentation_pipeline(config):
    """Build the segmentation inference pipeline and return its final step.

    Graph: preprocessing -> ``segmentation_network`` -> (TTA aggregation and
    rename, only when the module-level ``USE_TTA`` flag is set) ->
    ``mask_resize`` -> ``binarizer`` -> ``labeler`` -> ``mask_postprocessing``.

    Args:
        config: Experiment configuration. Reads ``general.loader_mode``,
            ``general.original_size``, ``model['segmentation_network']``,
            ``tta_aggregator`` (TTA branch only), ``thresholder.threshold_masks``,
            ``execution.num_threads`` and ``execution.experiment_dir``.

    Returns:
        Step: The ``mask_postprocessing`` step, switched to inference mode.
        Its output key is ``labeled_images``.

    Raises:
        NotImplementedError: If ``config.general.loader_mode`` is not one of
            ``'resize_and_pad'``, ``'resize'`` or ``'stacking'``. This is
            checked before any step is constructed.
    """
    # Pick how predicted masks are brought back to config.general.original_size:
    # 'resize_and_pad' crops them, 'resize'/'stacking' resizes them.
    if config.general.loader_mode == 'resize_and_pad':
        size_adjustment_function = partial(
            postprocessing.crop_image,
            target_size=config.general.original_size)
    elif config.general.loader_mode == 'resize' or config.general.loader_mode == 'stacking':
        size_adjustment_function = partial(
            postprocessing.resize_image,
            target_size=config.general.original_size)
    else:
        raise NotImplementedError

    if USE_TTA:
        # Test-time-augmentation branch (module-level USE_TTA flag).
        preprocessing, tta_generator = pipelines.preprocessing_inference_tta(
            config, model_name='segmentation_network')

        # NOTE(review): this branch uses a plain Step, while the non-TTA branch
        # wraps the same model in misc.FineTuneStep. Confirm the difference is
        # intended.
        segmentation_network = Step(
            name='segmentation_network',
            transformer=models.SegmentationModel(
                **config.model['segmentation_network']),
            input_steps=[preprocessing])

        tta_aggregator = pipelines.aggregator('tta_aggregator',
                                              segmentation_network,
                                              tta_generator=tta_generator,
                                              config=config.tta_aggregator)

        # Rename 'aggregated_prediction' to 'mask_prediction', the key that
        # mask_resize reads from the network step in the non-TTA branch too.
        prediction_renamed = Step(name='prediction_renamed',
                                  transformer=IdentityOperation(),
                                  input_steps=[tta_aggregator],
                                  adapter=Adapter({
                                      'mask_prediction':
                                      E(tta_aggregator.name,
                                        'aggregated_prediction')
                                  }))

        mask_resize = Step(name='mask_resize',
                           transformer=misc.make_apply_transformer(
                               size_adjustment_function,
                               output_name='resized_images',
                               apply_on=['images'],
                               n_threads=config.execution.num_threads,
                           ),
                           input_steps=[prediction_renamed],
                           adapter=Adapter({
                               'images':
                               E(prediction_renamed.name, 'mask_prediction'),
                           }))
    else:

        # Plain inference branch: single forward pass, no augmentation.
        preprocessing = pipelines.preprocessing_inference(
            config, model_name='segmentation_network')

        segmentation_network = misc.FineTuneStep(
            name='segmentation_network',
            transformer=models.SegmentationModel(
                **config.model['segmentation_network']),
            input_steps=[preprocessing],
        )

        mask_resize = Step(
            name='mask_resize',
            transformer=misc.make_apply_transformer(
                size_adjustment_function,
                output_name='resized_images',
                apply_on=['images'],
                n_threads=config.execution.num_threads,
            ),
            input_steps=[segmentation_network],
            adapter=Adapter({
                'images':
                E(segmentation_network.name, 'mask_prediction'),
            }),
        )

    # Binarize each resized mask using config.thresholder.threshold_masks.
    binarizer = Step(name='binarizer',
                     transformer=misc.make_apply_transformer(
                         partial(postprocessing.binarize,
                                 threshold=config.thresholder.threshold_masks),
                         output_name='binarized_images',
                         apply_on=['images'],
                         n_threads=config.execution.num_threads),
                     input_steps=[mask_resize],
                     adapter=Adapter({
                         'images':
                         E(mask_resize.name, 'resized_images'),
                     }))

    # Apply postprocessing.label to each binarized mask.
    labeler = Step(name='labeler',
                   transformer=misc.make_apply_transformer(
                       postprocessing.label,
                       output_name='labeled_images',
                       apply_on=['images'],
                       n_threads=config.execution.num_threads,
                   ),
                   input_steps=[binarizer],
                   adapter=Adapter({
                       'images':
                       E(binarizer.name, 'binarized_images'),
                   }))
    # Apply postprocessing.mask_postprocessing to each labeled image. The
    # output reuses the 'labeled_images' key.
    mask_postprocessing = Step(name='mask_postprocessing',
                               transformer=misc.make_apply_transformer(
                                   postprocessing.mask_postprocessing,
                                   output_name='labeled_images',
                                   apply_on=['images'],
                                   n_threads=config.execution.num_threads,
                               ),
                               input_steps=[labeler],
                               adapter=Adapter({
                                   'images':
                                   E(labeler.name, 'labeled_images'),
                               }))

    # Put the final step in inference mode and push experiment_directory and
    # is_fittable=False upstream. Presumably this reaches every step above,
    # including the TTA-branch steps built without experiment_directory;
    # depends on set_parameters_upstream, so confirm.
    mask_postprocessing.set_mode_inference()
    mask_postprocessing.set_parameters_upstream({
        'experiment_directory':
        config.execution.experiment_dir,
        'is_fittable':
        False
    })
    # Mark only the network step as fittable again. This runs after the
    # upstream is_fittable=False above, presumably to override it.
    segmentation_network.is_fittable = True
    return mask_postprocessing