def unet_tta(config, suffix=''):
    """Assemble the U-Net inference pipeline with test-time augmentation.

    The steps are wired in this order: TTA preprocessing -> U-Net step ->
    TTA aggregator -> identity step exposing the aggregated prediction as
    ``'mask_prediction'`` -> size adjustment back to
    ``config.general.original_size``.

    Args:
        config: Configuration object. Fields read here: ``model['unet']``,
            ``execution.experiment_dir``, ``tta_aggregator``,
            ``general.loader_mode`` and ``general.original_size``.
        suffix: Text appended to the name of every step created directly
            in this function (it is not forwarded to the preprocessing
            builder).

    Returns:
        The final ``mask_resize`` step; its transformer is configured with
        ``output_name='resized_images'``.

    Raises:
        NotImplementedError: If ``config.general.loader_mode`` is neither
            ``'resize_and_pad'`` nor ``'resize'``.
    """
    preprocessing, tta_generator = pipelines.preprocessing_inference_tta(
        config, model_name='unet')

    unet_name = 'unet{}'.format(suffix)
    network = models.PyTorchUNet(**config.model['unet'])
    # Every step below shares the same experiment directory.
    experiment_dir = config.execution.experiment_dir
    unet_step = Step(
        name=unet_name,
        transformer=network,
        input_data=['callback_input'],
        input_steps=[preprocessing],
        is_trainable=True,
        experiment_directory=experiment_dir,
    )

    tta_aggregator = pipelines.aggregator(
        'tta_aggregator{}'.format(suffix),
        unet_step,
        tta_generator=tta_generator,
        experiment_directory=experiment_dir,
        config=config.tta_aggregator,
    )

    # Pass-through step: only re-labels the aggregator output so that the
    # resize step can consume it under the key 'mask_prediction'.
    renamed_name = 'prediction_renamed{}'.format(suffix)
    passthrough = IdentityOperation()
    rename_mapping = {
        'mask_prediction': E(tta_aggregator.name, 'aggregated_prediction'),
    }
    prediction_renamed = Step(
        name=renamed_name,
        transformer=passthrough,
        input_steps=[tta_aggregator],
        adapter=Adapter(rename_mapping),
        experiment_directory=experiment_dir,
    )

    # Choose the post-processing function for the loader mode, then bind
    # the target size once.
    loader_mode = config.general.loader_mode
    if loader_mode == 'resize_and_pad':
        adjust = postprocessing.crop_image
    elif loader_mode == 'resize':
        adjust = postprocessing.resize_image
    else:
        raise NotImplementedError
    size_adjustment_function = partial(
        adjust, target_size=config.general.original_size)

    resize_name = 'mask_resize{}'.format(suffix)
    resize_transformer = utils.make_apply_transformer(
        size_adjustment_function,
        output_name='resized_images',
        apply_on=['images'],
    )
    resize_mapping = {
        'images': E(prediction_renamed.name, 'mask_prediction'),
    }
    return Step(
        name=resize_name,
        transformer=resize_transformer,
        input_steps=[prediction_renamed],
        adapter=Adapter(resize_mapping),
        experiment_directory=experiment_dir,
    )
def unet(config, suffix='', train_mode=True):
    """Assemble the U-Net pipeline: preprocessing -> network -> mask resize.

    Args:
        config: Configuration object. Fields read here: ``model['unet']``,
            ``model.unet.training_config.fine_tuning``,
            ``execution.experiment_dir``, ``general.loader_mode`` and
            ``general.original_size``.
        suffix: Text appended to the step names created here and forwarded
            to the preprocessing builder.
        train_mode: When true, preprocessing comes from
            ``pipelines.preprocessing_train``; otherwise from
            ``pipelines.preprocessing_inference``.

    Returns:
        The final ``mask_resize`` step; its transformer is configured with
        ``output_name='resized_images'``.

    Raises:
        NotImplementedError: If ``config.general.loader_mode`` is neither
            ``'resize_and_pad'`` nor ``'resize'``.
    """
    if train_mode:
        preprocessing = pipelines.preprocessing_train(
            config, model_name='unet', suffix=suffix)
    else:
        preprocessing = pipelines.preprocessing_inference(
            config, suffix=suffix)

    unet_name = 'unet{}'.format(suffix)
    network = models.PyTorchUNet(**config.model['unet'])
    # Generators come from the preprocessing step; validation metadata is
    # taken from the externally supplied 'callback_input'.
    network_inputs = Adapter({
        'datagen': E(preprocessing.name, 'datagen'),
        'validation_datagen': E(preprocessing.name, 'validation_datagen'),
        'meta_valid': E('callback_input', 'meta_valid'),
    })
    fine_tuning = config.model.unet.training_config.fine_tuning
    # The resize step below shares the same experiment directory.
    experiment_dir = config.execution.experiment_dir
    unet_step = utils.FineTuneStep(
        name=unet_name,
        transformer=network,
        input_data=['callback_input'],
        input_steps=[preprocessing],
        adapter=network_inputs,
        is_trainable=True,
        fine_tuning=fine_tuning,
        experiment_directory=experiment_dir,
    )

    # Choose the post-processing function for the loader mode, then bind
    # the target size once.
    loader_mode = config.general.loader_mode
    if loader_mode == 'resize_and_pad':
        adjust = postprocessing.crop_image
    elif loader_mode == 'resize':
        adjust = postprocessing.resize_image
    else:
        raise NotImplementedError
    size_adjustment_function = partial(
        adjust, target_size=config.general.original_size)

    resize_name = 'mask_resize{}'.format(suffix)
    resize_transformer = utils.make_apply_transformer(
        size_adjustment_function,
        output_name='resized_images',
        apply_on=['images'],
    )
    resize_mapping = {
        'images': E(unet_step.name, 'mask_prediction'),
    }
    return Step(
        name=resize_name,
        transformer=resize_transformer,
        input_steps=[unet_step],
        adapter=Adapter(resize_mapping),
        experiment_directory=experiment_dir,
    )