예제 #1
0
def postprocessing_pipeline_simplified(cache_dirpath, loader_mode):
    """Build the mask post-processing chain and return its final step.

    Three steps are chained: ``mask_resize`` -> ``binarizer`` -> ``labeler``.
    ``mask_resize`` takes ``mask_prediction`` from the ``network_output``
    input data; each following step reads the named output of the step
    before it.

    Args:
        cache_dirpath: Value passed upstream as ``experiment_directory``.
        loader_mode: Loader mode the masks were produced with. Only
            ``'resize'`` is supported.

    Returns:
        The ``labeler`` step, after ``set_mode_inference()`` and after
        ``is_fittable=False`` has been passed to ``set_parameters_upstream``.

    Raises:
        NotImplementedError: If ``loader_mode`` is anything but ``'resize'``.
    """
    # Fail early, and say which mode was rejected, before any step is built.
    if loader_mode != 'resize':
        raise NotImplementedError(
            "Unsupported loader_mode {!r}; expected 'resize'".format(loader_mode))

    size_adjustment_function = partial(resize_image, target_size=ORIGINAL_SIZE)

    # Entry step: fed from raw input data rather than from another step.
    mask_resize = Step(
        name='mask_resize',
        transformer=make_apply_transformer(
            size_adjustment_function,
            output_name='resized_images',
            apply_on=['images'],
            n_threads=NUM_THREADS),
        input_data=['network_output'],
        adapter=Adapter({
            'images': E('network_output', 'mask_prediction'),
        }))

    binarizer = Step(
        name='binarizer',
        transformer=make_apply_transformer(
            partial(binarize, threshold=THRESHOLD),
            output_name='binarized_images',
            apply_on=['images'],
            n_threads=NUM_THREADS),
        input_steps=[mask_resize],
        adapter=Adapter({
            'images': E(mask_resize.name, 'resized_images'),
        }))

    labeler = Step(
        name='labeler',
        transformer=make_apply_transformer(
            label,
            output_name='labeled_images',
            apply_on=['images'],
            n_threads=NUM_THREADS),
        input_steps=[binarizer],
        adapter=Adapter({
            'images': E(binarizer.name, 'binarized_images'),
        }))

    labeler.set_mode_inference()
    labeler.set_parameters_upstream({
        'experiment_directory': cache_dirpath,
        'is_fittable': False,
    })
    return labeler
def inference_segmentation_pipeline(config):
    """Assemble the segmentation inference pipeline and return its last step.

    Step order: preprocessing -> ``segmentation_network`` -> (when
    ``USE_TTA`` is truthy: ``tta_aggregator`` -> ``prediction_renamed``) ->
    ``mask_resize`` -> ``binarizer`` -> ``labeler`` ->
    ``mask_postprocessing``.

    Args:
        config: Configuration object. Attributes read here are
            ``general.loader_mode``, ``general.original_size``,
            ``model['segmentation_network']``, ``tta_aggregator`` (TTA
            branch only), ``execution.num_threads``,
            ``thresholder.threshold_masks`` and ``execution.experiment_dir``.

    Returns:
        The ``mask_postprocessing`` step, after ``set_mode_inference()`` and
        after ``experiment_directory`` / ``is_fittable=False`` have been
        passed to ``set_parameters_upstream``.

    Raises:
        NotImplementedError: If ``config.general.loader_mode`` is not one of
            ``'resize_and_pad'``, ``'resize'`` or ``'stacking'``.
    """
    loader_mode = config.general.loader_mode
    if loader_mode == 'resize_and_pad':
        size_adjuster = postprocessing.crop_image
    elif loader_mode in ('resize', 'stacking'):
        size_adjuster = postprocessing.resize_image
    else:
        raise NotImplementedError(
            "Unsupported loader_mode {!r}; expected 'resize_and_pad', "
            "'resize' or 'stacking'".format(loader_mode))
    size_adjustment_function = partial(
        size_adjuster, target_size=config.general.original_size)

    def apply_step(name, function, output_name, source, source_key):
        # Wrap `function` in an apply-transformer step that reads
        # `source_key` from `source` as 'images' and emits `output_name`.
        return Step(name=name,
                    transformer=misc.make_apply_transformer(
                        function,
                        output_name=output_name,
                        apply_on=['images'],
                        n_threads=config.execution.num_threads,
                    ),
                    input_steps=[source],
                    adapter=Adapter({
                        'images': E(source.name, source_key),
                    }))

    # The two branches differ only in how 'mask_prediction' is produced;
    # `mask_source` is whichever step exposes it.
    if USE_TTA:
        preprocessing, tta_generator = pipelines.preprocessing_inference_tta(
            config, model_name='segmentation_network')

        segmentation_network = Step(
            name='segmentation_network',
            transformer=models.SegmentationModel(
                **config.model['segmentation_network']),
            input_steps=[preprocessing])

        tta_aggregator = pipelines.aggregator('tta_aggregator',
                                              segmentation_network,
                                              tta_generator=tta_generator,
                                              config=config.tta_aggregator)

        # Re-expose the aggregated prediction under the 'mask_prediction'
        # key so downstream steps read the same key in both branches.
        mask_source = Step(name='prediction_renamed',
                           transformer=IdentityOperation(),
                           input_steps=[tta_aggregator],
                           adapter=Adapter({
                               'mask_prediction':
                               E(tta_aggregator.name,
                                 'aggregated_prediction'),
                           }))
    else:
        preprocessing = pipelines.preprocessing_inference(
            config, model_name='segmentation_network')

        segmentation_network = misc.FineTuneStep(
            name='segmentation_network',
            transformer=models.SegmentationModel(
                **config.model['segmentation_network']),
            input_steps=[preprocessing],
        )
        mask_source = segmentation_network

    mask_resize = apply_step('mask_resize',
                             size_adjustment_function,
                             'resized_images',
                             mask_source,
                             'mask_prediction')

    binarizer = apply_step('binarizer',
                           partial(postprocessing.binarize,
                                   threshold=config.thresholder.threshold_masks),
                           'binarized_images',
                           mask_resize,
                           'resized_images')

    labeler = apply_step('labeler',
                         postprocessing.label,
                         'labeled_images',
                         binarizer,
                         'binarized_images')

    # Emits under the same 'labeled_images' key that it consumes.
    mask_postprocessing = apply_step('mask_postprocessing',
                                     postprocessing.mask_postprocessing,
                                     'labeled_images',
                                     labeler,
                                     'labeled_images')

    mask_postprocessing.set_mode_inference()
    mask_postprocessing.set_parameters_upstream({
        'experiment_directory': config.execution.experiment_dir,
        'is_fittable': False,
    })
    # Set after the upstream call so this value is the one that sticks on the
    # network step. NOTE(review): presumably intended for the FineTuneStep
    # branch; confirm it is also wanted when USE_TTA builds a plain Step.
    segmentation_network.is_fittable = True
    return mask_postprocessing