Esempio n. 1
0
def _colorize_labels(label_img, mapping):
    """Translate a label image into a flat list of per-pixel colours.

    ``mapping[class_id][1]`` is looked up once per distinct id found in
    ``label_img`` rather than once per pixel; the result is then spread over
    all pixels in row-major order.

    :param label_img: (np.ndarray) array of class ids
    :param mapping:   container indexed by class id whose entries hold the
                      colour at position 1
    :return: (np.ndarray) one colour entry per pixel, in row-major order
    """
    class_ids, inverse = np.unique(label_img.ravel(), return_inverse=True)
    palette = np.asarray([mapping[class_id][1] for class_id in class_ids])
    return palette[inverse]


def visualize_regression_prediction_i(iou_pred,
                                      i,
                                      input_dir=CONFIG.INPUT_DIR,
                                      seg_dir=CONFIG.IOU_SEG_VIS_DIR,
                                      components_dir=CONFIG.COMPONENTS_DIR):
    """
    store a four panel visualization for one image as ``image<i>.png``

    Top row: entropy heatmap and the segment-wise visualization of
    ``iou_pred``. Bottom row: ground truth overlay and predicted segmentation.
    Nothing happens if the input file of image ``i`` does not exist.

    :param iou_pred:       values handed to ``visualize_segments`` together
                           with the components of the image
    :param i:              (int) id of the image to be processed
    :param input_dir:      (str) directory the network output is loaded from
    :param seg_dir:        (str) directory the visualization is saved to
    :param components_dir: (str) directory the components are loaded from
    """
    if not os.path.isfile(get_save_path_input_i(i, input_dir=input_dir)):
        return

    # colour tables of the evaluated dataset (ground truth) and of the
    # training dataset (prediction)
    label_mapping = getattr(
        importlib.import_module(CONFIG.DATASET.module_name),
        CONFIG.DATASET.class_name)(**CONFIG.DATASET.kwargs).label_mapping
    pred_mapping = getattr(
        importlib.import_module(CONFIG.TRAIN_DATASET.module_name),
        CONFIG.TRAIN_DATASET.class_name)(
            **CONFIG.TRAIN_DATASET.kwargs).label_mapping

    probs, gt, path = probs_gt_load(i, input_dir=input_dir)
    input_image = Image.open(path).convert("RGB")
    # PIL expects (width, height), hence the reversed array shape
    input_image = np.asarray(input_image.resize(probs.shape[:2][::-1]))
    components = components_load(i, components_dir=components_dir)

    e = entropy(probs)
    pred = np.asarray(np.argmax(probs, axis=-1), dtype='int')
    # label 255 is drawn with the colour of label 0
    gt[gt == 255] = 0
    predc = _colorize_labels(pred, pred_mapping).reshape(input_image.shape)
    gtc = _colorize_labels(gt, label_mapping).reshape(input_image.shape)

    # factor 1.0 shows only the layer, 0.5 blends it evenly with the image
    overlay_factor = [1.0, 0.5, 1.0]
    layers = [predc, gtc, cm.jet(e)[:, :, :3] * 255.0]
    img_predc, img_gtc, img_entropy = [
        Image.fromarray(
            np.uint8(layer * alpha + input_image * (1 - alpha)))
        for layer, alpha in zip(layers, overlay_factor)
    ]

    img_ioupred = Image.fromarray(visualize_segments(components, iou_pred))

    images = [img_gtc, img_predc, img_entropy, img_ioupred]

    img_top = np.concatenate(images[2:], axis=1)
    img_bottom = np.concatenate(images[:2], axis=1)

    img_total = np.concatenate((img_top, img_bottom), axis=0)
    image = Image.fromarray(img_total.astype('uint8'), 'RGB')

    # exist_ok avoids failing when the directory appears between a check
    # and its creation
    os.makedirs(seg_dir, exist_ok=True)

    save_path = join(seg_dir, "image{}.png".format(i))
    image.save(save_path)
    plt.close()

    print("stored: {}".format(save_path))
Esempio n. 2
0
def get_ious_for_image(image_index, iou_pred, thresholds, args):
    """Build per-threshold confusion matrices for a single image.

    For every threshold ``t`` the matrix in the second returned dict is
    updated by ``iou`` using only the pixels whose segment has a predicted
    IoU below ``t``. The matrices in the first returned dict are created but
    never updated here, so they are returned as all-zero arrays.

    :param image_index: id of the image to be processed
    :param iou_pred:    array of predicted IoU values, indexed by
                        (component index - 1)
    :param thresholds:  iterable of IoU thresholds
    :param args:        dict providing the key "dataset"
    :return: tuple ``(confusion_matrices_pos, confusion_matrices_neg)`` of
             dicts keyed by threshold
    """
    confusion_matrices_pos = {}
    for t in thresholds:
        confusion_matrices_pos[t] = np.zeros((num_categories, num_categories))

    confusion_matrices_neg = {}
    for t in thresholds:
        confusion_matrices_neg[t] = np.zeros((num_categories, num_categories))

    prediction_dir = join(CONFIG.metaseg_io_path, "input", "deeplabv3plus",
                          args["dataset"])
    pred, gt, _ = probs_gt_load(image_index,
                                input_dir=prediction_dir,
                                preds=True)

    # map the dataset specific ground truth labels to category ids
    to_category = np.vectorize(label_mappings[args["dataset"]].get)
    gt = to_category(gt)

    # map the predicted train ids to category ids
    pred = np.vectorize(trainid_to_catid.get)(pred)

    # the components are needed to turn the segment-wise IoU prediction into
    # a pixel-wise mask
    segment_dir = join(CONFIG.metaseg_io_path, "components", "deeplabv3plus",
                       args["dataset"])
    components = components_load(image_index, components_dir=segment_dir)

    # segment borders carry the negative index of their segment; taking the
    # absolute value makes the border part of the segment for the evaluation
    components = np.absolute(components)

    # component indices start at 1, iou_pred at 0
    components = iou_pred[components - 1]

    for t in thresholds:
        below_threshold = components < t
        confusion_matrices_neg[t] = iou(
            pred,
            gt,
            n_classes=num_categories,
            update_matrix=confusion_matrices_neg[t],
            ignore_index=0,
            mask=below_threshold,
        )[1]

    return confusion_matrices_pos, confusion_matrices_neg
 def compute_metrics_i(self, i, input_dir, metrics_dir, components_dir):
     """
     compute and store metrics and components of one image

     The image is only processed if its input file exists and
     ``self.rewrite`` is truthy.

     :param i:              (int) id of the image to be processed
     :param input_dir:      directory the network output is loaded from
     :param metrics_dir:    directory the metrics are dumped to
     :param components_dir: directory the components are dumped to
     """
     input_file = get_save_path_input_i(i, input_dir=input_dir)
     if not (os.path.isfile(input_file) and self.rewrite):
         return

     started = time.time()
     probs, gt, _ = probs_gt_load(i, input_dir=input_dir)
     metrics, components = compute_metrics_components(probs, gt)
     metrics_dump(metrics, i, metrics_dir=metrics_dir)
     components_dump(components, i, components_dir=components_dir)
     elapsed = round(time.time() - started)
     print('image {} processed in {}s'.format(i, elapsed))
 def add_heatmap_as_metric_i(heat_dir, key, i):
     """
     derive aggregated metrics per image and add to metrics dictionary

     The heatmap file is expected to carry the base name of the input image
     with the extension replaced by ``.npy``. For every component id the
     values returned by ``compute_metrics_from_heatmap`` are appended to the
     lists stored under ``key``, ``key_in``, ``key_bd``, ``key_rel`` and
     ``key_rel_in``; the extended metrics are written back to
     ``CONFIG.METRICS_DIR``.

     :param heat_dir:  (str) directory with heatmaps as numpy arrays
     :param key:       (str) new key to access added metric
     :param i:         (int) id of the image to be processed
     """
     _, _, path = probs_gt_load(i)
     # splitext copes with extensions of any length, join with directories
     # given with or without a trailing separator
     heat_name = os.path.splitext(os.path.basename(path))[0] + ".npy"
     heatmap = np.load(os.path.join(heat_dir, heat_name))
     metrics = metrics_load(i, metrics_dir=CONFIG.METRICS_DIR)
     components = components_load(i, components_dir=CONFIG.COMPONENTS_DIR)
     keys = [key, key + "_in", key + "_bd", key + "_rel", key + "_rel_in"]
     heat_metric = {k: [] for k in keys}
     # the number of components is taken from the smallest (most negative)
     # entry of the components array
     num_components = abs(np.min(components))
     for comp_id in range(1, num_components + 1):
         values = compute_metrics_from_heatmap(heatmap, components, comp_id)
         for j, k in enumerate(keys):
             heat_metric[k].append(values[j])
     metrics.update(heat_metric)
     metrics_dump(metrics, i, metrics_dir=CONFIG.METRICS_DIR)
Esempio n. 5
0
    def show_full_image(self, ind, save=False):
        """Displays four panels of the full image belonging to a segment.

        Top left: Entropy heatmap of prediction.
        Top right: Predicted IoU of each segment.
        Bottom left: Source image with ground truth overlay.
        Bottom right: Predicted semantic segmentation.

        The box stored for the segment is outlined in red on every panel.

        :param ind: index of the segment; used directly for
            ``self.data["box"]`` and through ``self.gi[ind]`` for all
            image-level entries of ``self.data``
        :param save: if True the figure is written to
            ``detailed_image_<ind>.jpg`` inside ``self.save_dir``, otherwise
            it is shown
        """
        self.log.info(
            "{} detailed image...".format("Saving" if save else "Loading"))
        box = self.data["box"][ind]
        global_ind = self.gi[ind]
        image = np.asarray(
            Image.open(self.data["image_path"][global_ind]).convert("RGB"))
        image_index = self.data["image_index"][global_ind]
        iou_pred = self.data["iou_pred"][global_ind]
        dataset = self.data["dataset"][global_ind]
        model_name = self.data["model_name"][global_ind]

        pred, gt, image_path = probs_gt_load(
            image_index,
            input_dir=join(CONFIG.metaseg_io_path, "input", model_name,
                           dataset),
        )
        components = components_load(
            image_index,
            components_dir=join(CONFIG.metaseg_io_path, "components",
                                model_name, dataset),
        )
        e = entropy(pred)
        pred = pred.argmax(2)
        predc = np.asarray([
            self.pred_mapping[pred[ind_i, ind_j]][1]
            for ind_i in range(pred.shape[0]) for ind_j in range(pred.shape[1])
        ]).reshape(image.shape)
        # factor 1.0 shows only the layer, 0.5 blends it evenly with the image
        overlay_factor = [1.0, 0.5, 1.0]

        if self.label_mapping[dataset] is not None:
            gtc = np.asarray([
                self.label_mapping[dataset][gt[ind_i, ind_j]][1]
                for ind_i in range(gt.shape[0]) for ind_j in range(gt.shape[1])
            ]).reshape(image.shape)
        else:
            # no colour table for this dataset: show the plain source image
            gtc = np.zeros_like(image)
            overlay_factor[1] = 0.0

        layers = [predc, gtc, cm.jet(e)[:, :, :3] * 255.0]
        img_predc, img_gtc, img_entropy = [
            Image.fromarray(np.uint8(layer * alpha + image * (1 - alpha)))
            for layer, alpha in zip(layers, overlay_factor)
        ]

        img_ioupred = Image.fromarray(
            self.visualize_segments(components, iou_pred))

        images = [img_gtc, img_predc, img_entropy, img_ioupred]

        # enlarge the box by the line width and clip it to the image
        box_line_width = 5
        left, upper = max(0, box[0] - box_line_width), max(
            0, box[1] - box_line_width)
        right, lower = min(pred.shape[1], box[2] + box_line_width), min(
            pred.shape[0], box[3] + box_line_width)

        for panel in images:
            draw = ImageDraw.Draw(panel)
            draw.rectangle([left, upper, right, lower],
                           outline=(255, 0, 0),
                           width=box_line_width)
            del draw

        images = [np.asarray(panel).astype("uint8") for panel in images]

        img_top = np.concatenate(images[2:], axis=1)
        img_bottom = np.concatenate(images[:2], axis=1)

        img_total = np.concatenate((img_top, img_bottom), axis=0)
        # default=0 keeps this working when no figure is open at all
        fig_number = max(3, max(plt.get_fignums(), default=0) + 1)
        fig_tmp = plt.figure(fig_number, dpi=self.dpi)
        # the canvas-level set_window_title was deprecated in Matplotlib 3.4;
        # the figure manager provides the same method
        fig_tmp.canvas.manager.set_window_title(
            "Dataset: {}, Image index: {}".format(dataset, image_index))
        ax = fig_tmp.add_subplot(111)
        ax.set_axis_off()
        ax.imshow(img_total, interpolation="nearest")

        if save:
            fig_tmp.subplots_adjust(bottom=0,
                                    left=0,
                                    right=1,
                                    top=1,
                                    hspace=0,
                                    wspace=0)
            ax.margins(0.05, 0.05)
            fig_tmp.gca().xaxis.set_major_locator(plt.NullLocator())
            fig_tmp.gca().yaxis.set_major_locator(plt.NullLocator())
            save_path = join(self.save_dir,
                             "detailed_image_{}.jpg".format(ind))
            fig_tmp.savefig(
                save_path,
                bbox_inches="tight",
                pad_inches=0.0,
            )
            self.log.debug("Saved image to {}".format(save_path))
        else:
            fig_tmp.tight_layout(pad=0.0)
            fig_tmp.show()
def main(args, _run, _log):
    """Count instances per class and plot the share covered by extracted boxes.

    Unless ``args['only_plot']`` is set, the ground truth of every image is
    split into connected components per class. Components that reach the
    minimum height and width are counted as total instances (only if
    ``args['file_total_count']`` is None, otherwise the totals are loaded
    from that file). The instances reported by
    ``return_and_update_instances`` for the boxes stored in
    ``args['embeddings_file']`` are counted as filtered instances. Both
    count dicts are pickled. With ``args['only_plot']`` the counts are read
    back from ``args['save_file_total']`` and ``args['save_file_filtered']``.

    Afterwards the counts are aggregated per train id and one pie chart per
    class with at least one instance is written to ``args['plot_dir']``.

    :param args: dict with the configuration entries read above plus
                 'min_height', 'min_width', 'dpi' and 'plot_filetype'
    :param _run: run object, only handed to ``log_config``
    :param _log: logger used for progress messages
    """
    log_config(_run, _log)

    if not args['only_plot']:
        # NOTE(review): pickle.load executes arbitrary code from the file;
        # only point these arguments at trusted files.
        with open(args['embeddings_file'], 'rb') as f:
            data = pkl.load(f)

        image_indices = np.array(data['image_index'])
        image_level_index = np.array(data['image_level_index'])
        gt_segments = np.array(data['gt'])
        boxes = np.array(data['box'])

        input_dir = join(CONFIG.metaseg_io_path, 'input', 'deeplabv3plus',
                         'a2d2')
        inds = get_indices(input_dir)

        count_total = args['file_total_count'] is None
        if count_total:
            total_num_instances = {cl: 0 for cl in id_to_trainid}
        else:
            with open(args['file_total_count'], 'rb') as f:
                total_num_instances = pkl.load(f)
        filtered_num_instances = {cl: 0 for cl in id_to_trainid}

        for ind in tqdm.tqdm(inds):
            pred, gt, img_path = probs_gt_load(ind, input_dir, preds=True)

            # position of this image among the images with extracted boxes;
            # it does not depend on the class, so look it up once per image
            has_boxes = ind in image_indices
            if has_boxes:
                image_pos = np.argwhere(image_indices == ind).squeeze()

            # count number of instances of each class of the minimum size in
            # ground truth and prediction
            for cl in np.unique(gt):
                components_gt, counts_gt = label(gt == cl)
                if count_total:
                    for c in range(1, counts_gt + 1):
                        segment_indices = np.argwhere(components_gt == c)
                        top, left = segment_indices.min(0)
                        bottom, right = segment_indices.max(0)
                        if ((bottom - top) >= args['min_height']
                                and (right - left) >= args['min_width']):
                            total_num_instances[cl] += 1

                if has_boxes:
                    selected = (gt_segments == cl) & (image_level_index ==
                                                      image_pos)
                    for b in boxes[selected, :]:
                        components_gt, instance_counts = \
                            return_and_update_instances(components_gt, b)
                        filtered_num_instances[cl] += instance_counts

        _log.info('Saving file with total counts...')
        if count_total:
            with open(args['save_file_total'], 'wb') as f:
                pkl.dump(total_num_instances, f)

        _log.info('Saving file with filtered counts...')
        with open(args['save_file_filtered'], 'wb') as f:
            pkl.dump(filtered_num_instances, f)
    else:
        with open(args['save_file_total'], 'rb') as f:
            total_num_instances = pkl.load(f)
        with open(args['save_file_filtered'], 'rb') as f:
            filtered_num_instances = pkl.load(f)

    _log.info('Start plotting')

    # aggregate over training ids:
    num_instances = {k: 0 for k in trainid_to_name}
    f_num_instances = {k: 0 for k in trainid_to_name}
    for k, v in total_num_instances.items():
        num_instances[id_to_trainid[k]] += v
    for k, v in filtered_num_instances.items():
        f_num_instances[id_to_trainid[k]] += v

    sel_classes = None
    # sel_classes = [31, 22, 12, 34, 3, 35]  # classes with many extracted instances
    # sel_classes = [1, 4, 17, 24, 16, 18]  # classes with few extracted instances
    # start_angles = [45, 0, 10, 0, 0, 0]
    start_angles = [0] * 6
    fontsize = 8

    fig = plt.figure('Class occurances filtered and not filtered',
                     figsize=(3.3, 2.5) if sel_classes is not None else
                     (10, 10),
                     dpi=args['dpi'])
    plt.rcParams['font.family'] = 'serif'
    plt.rcParams['font.serif'] = ['Times New Roman'
                                  ] + plt.rcParams['font.serif']
    plt.rcParams['font.size'] = 6.0

    def label_autopct(pct, allvals):
        # Round instead of truncating: the percentage is a float, so the
        # product can land just below the true integer count.
        absolute = int(np.round(pct / 100.0 * np.sum(allvals)))
        return '{:.1f}%\n({:d})'.format(pct, absolute) if pct > 10 else ''

    # side length of the square grid that holds one pie per non-empty class
    n = math.ceil(math.sqrt(sum(1 for v in num_instances.values() if v > 0)))
    if sel_classes is None:
        plot_classes = [key for key, v in num_instances.items() if v > 0]
        n_rows, n_cols = n, n
    else:
        plot_classes = sel_classes
        n_rows, n_cols = 2, 3

    cmap = plt.get_cmap('tab20c')
    for i, k in enumerate(plot_classes):
        if num_instances[k] <= 0:
            continue

        ax = fig.add_subplot(n_rows, n_cols, i + 1)
        # names ending in a digit are shown without their last two characters
        name = trainid_to_name[k]
        ax.text(
            0.5,
            1.0,
            '{}'.format(name if not name[-1].isdigit() else name[:-2]),
            horizontalalignment='center',
            transform=ax.transAxes,
            fontdict=dict(size=fontsize),
        )
        counts = [num_instances[k] - f_num_instances[k], f_num_instances[k]]
        ax.pie(
            counts,
            radius=1.2,
            colors=cmap(np.array([10, 5])),
            startangle=start_angles[i] if sel_classes is not None else 0,
            autopct=lambda pct, counts=counts: label_autopct(pct, counts),
            pctdistance=0.65,
            wedgeprops=dict(
                width=1.0,
                edgecolor='w',
                linewidth=2,
            ),
            textprops=dict(),
        )
        ax.set(aspect='equal')
    fig.tight_layout(pad=0.0, h_pad=0.0, w_pad=0.6, rect=(0.0, 0.0, 1.0, 1.0))

    plot_path = join(
        args['plot_dir'], 'instance_counts{}.{}'.format(
            '' if sel_classes is None else '_selected',
            args['plot_filetype']))
    plt.savefig(plot_path, dpi=args['dpi'])
    _log.info('Saved instance counts plot to \'{}\''.format(plot_path))
Esempio n. 7
0
def main(args, _run, _log):
    """Extract low predicted-IoU segments and compute their feature embeddings.

    If ``args['load_file']`` is None, the metrics of all datasets in
    ``args['datasets']`` are loaded (the first one serves as source domain
    for normalization), the IoU of every segment is predicted with the
    configured meta model, segments are filtered by predicted IoU and
    predicted class, their crops are cut out in a multiprocessing pool and
    embedded with the selected network. Otherwise the crops stored in
    ``args['load_file']`` are re-embedded with the selected network. In both
    cases the resulting dict is pickled to ``args['save_file']``.

    :param args: dict with the configuration entries read in the body
    :param _run: run object, only handed to ``log_config``
    :param _log: logger used for progress messages
    :raises ValueError: if ``args['net']`` or ``args['meta_model']`` is not
                        supported
    """
    log_config(_run, _log)
    # load a network architecture
    _log.info('Loading {}...'.format(args['net']))
    if args['net'] == 'vgg16':
        net = feature_vgg16()
    elif args['net'] == 'resnet18':
        net = feature_resnet18()
    elif args['net'] == 'resnet101':
        net = feature_resnet101()
    elif args['net'] == 'resnet152':
        net = feature_resnet152()
    elif args['net'] == 'wide_resnet101':
        net = feature_wide_resnet101()
    elif args['net'] == 'densenet201':
        net = feature_densenet201()
    else:
        raise ValueError('Network {} not supported.'.format(args['net']))
    net = net.cuda(args['gpu'])
    net.eval()

    # if no precomputed segments have been supplied, they have to be computed
    if args['load_file'] is None:
        _log.info('Loading Metrics...')
        xa_all = []
        start_others = []
        pred_test = []
        dataset_assignments = []
        image_indices = []

        # the first dataset of the 'datasets' configuration serves as source
        # domain dataset. Metric statistics of this dataset are used to
        # normalize the target domain metric statistics. This is why it has
        # to get loaded too.
        _log.info('{}...'.format(args['datasets'][0]))
        (xa, ya, x_names, class_names, xa_mean, xa_std, classes_mean,
         classes_std, *_, start, pred) = load_data(args['datasets'][0])

        # Now load all other metric statistics and normalize them using the
        # source domain mean and standard deviation
        for i, d in enumerate(args['datasets'][1:], start=1):
            _log.info('{} ...'.format(d))
            num_imgs = get_indices(
                join(CONFIG.metaseg_io_path, 'metrics', 'deeplabv3plus', d))
            xa_tmp, *_, start_tmp, pred_tmp = load_data(
                d,
                num_imgs=num_imgs,
                xa_mean=xa_mean,
                xa_std=xa_std,
                classes_mean=classes_mean,
                classes_std=classes_std)
            xa_all.append(xa_tmp)
            pred_test.append(pred_tmp)
            dataset_assignments += [i] * len(num_imgs)
            image_indices += num_imgs
            start_others.append(start_tmp)

        # combine them into single arrays
        xa_all = np.concatenate(xa_all).squeeze()
        pred_test = np.concatenate(pred_test).squeeze()
        dataset_assignments = np.array(dataset_assignments).squeeze()
        image_indices = np.array(image_indices).squeeze()

        # chain the per-dataset start offsets into one list by shifting each
        # further list by the last offset collected so far
        for starts in start_others[1:]:
            start_others[0] += [s + start_others[0][-1] for s in starts[1:]]
        start_all = start_others[0]
        del xa_tmp, start_tmp, pred_tmp, start_others

        _log.debug('Shape of metrics array: {}'.format(xa_all.shape))

        # Using the normalized metric statistics use a meta segmentation
        # network pretrained on the source domain to predict IoU
        _log.info('Predicting IoU...')
        if args['meta_model'] == 'neural':
            ya_pred_test = meta_nn_predict(
                pretrained_model_path=args['meta_nn_path'],
                x_test=xa_all,
                gpu=args['gpu'])
        elif args['meta_model'] == 'linear':
            ya_pred_test, _ = regression_fit_and_predict(x_train=xa,
                                                         y_train=ya,
                                                         x_test=xa_all)
        else:
            raise ValueError('Meta model {} not supported.'.format(
                args['meta_model']))

        # This list will be used as an additional filter. Only segments with
        # class predictions in this list will be picked for further
        # processing.
        pred_class_selection = [
            # 0,  # road
            # 1,  # sidewalk
            # 2,  # building
            3,  # wall
            4,  # fence
            6,  # traffic light
            7,  # traffic sign
            # 8,  # vegetation
            # 9,  # terrain
            # 10,  # sky
            11,  # person
            12,  # rider
            13,  # car
            14,  # truck
            15,  # bus
            16,  # train
            17,  # motorcycle
            18,  # bicycle
        ]

        # Now the different filters are getting applied to the segments
        _log.info('Filtering segments...')
        # np.bool was removed from NumPy; the builtin bool is the same dtype
        inds = np.zeros(pred_test.shape[0], dtype=bool)

        # Filter for the predicted IoU to be less than the supplied threshold
        inds = np.logical_or(inds, (ya_pred_test < args['iou_threshold']))

        # Filter for extracting segments with predefined class predictions
        inds = np.logical_and(inds, np.isin(pred_test, pred_class_selection))

        _log.info('Filtered components (not checked for minimum size):')
        train_dat = getattr(
            importlib.import_module(CONFIG.TRAIN_DATASET.module_name),
            CONFIG.TRAIN_DATASET.class_name)(**CONFIG.TRAIN_DATASET.kwargs)
        # width of the class name column, identical for header and rows
        name_width = max([len(v[0]) for v in train_dat.pred_mapping.values()] +
                         [len('Class name')])
        _log.info('\t{:^{width}s} | Filtered | Total'.format(
            'Class name', width=name_width))
        for cl in np.unique(pred_test).flatten():
            _log.info('\t{:^{width}s} | {:>8d} | {:<8d}'.format(
                train_dat.pred_mapping[cl][0],
                inds[pred_test == cl].sum(), (pred_test == cl).sum(),
                width=name_width))

        # Aggregating arguments for extraction of component information.
        inds = np.argwhere(inds).flatten()
        component_image_mapping = get_image_index_to_components(
            inds, start_all)
        p_args = [
            (v, image_indices[k], ya_pred_test[start_all[k]:start_all[k + 1]],
             args['datasets'][dataset_assignments[k]], args['min_height'],
             args['min_width'], args['min_crop_height'],
             args['min_crop_width'], 'deeplabv3plus')
            for k, v in component_image_mapping.items()
        ]

        # Extracting component information can be parallelized in a
        # multiprocessing pool
        _log.info('Extracting component information...')
        with Pool(args['n_jobs']) as p:
            r = list(
                tqdm.tqdm(p.imap(wrapper_cutout_components, p_args),
                          total=len(p_args)))
        # drop images for which no component was cut out
        r = [c for c in r if len(c['component_indices']) > 0]

        _log.info('Computing embeddings...')
        crops = {
            'embeddings': [],
            'image_path': [],
            'image_index': [],
            'component_index': [],
            'box': [],
            'gt': [],
            'pred': [],
            'dataset': [],
            'model_name': [],
            'image_level_index': [],
            'iou_pred': []
        }
        # process all extracted crops and compute feature embeddings
        for c in tqdm.tqdm(r):
            # load image
            preds, gt, image_path = probs_gt_load(
                c['image_index'],
                input_dir=join(CONFIG.metaseg_io_path, 'input',
                               c['model_name'], c['dataset']),
                preds=True)

            # image-level entries: one per image
            crops['image_path'].append(image_path)
            crops['model_name'].append(c['model_name'])
            crops['dataset'].append(c['dataset'])
            crops['image_index'].append(c['image_index'])
            crops['iou_pred'].append(c['iou_pred'])

            image = Image.open(image_path).convert('RGB')
            # segment-level entries: one per box of the image
            for i, b in enumerate(c['boxes']):
                img = trans.ToTensor()(image.crop(b))
                img = trans.Normalize(mean=imagenet_mean,
                                      std=imagenet_std)(img)
                crops['embeddings'].append(get_embedding(
                    img.unsqueeze(0), net))
                crops['box'].append(b)
                crops['component_index'].append(c['component_indices'][i])
                # position of the owning image in the image-level lists
                crops['image_level_index'].append(len(crops['image_path']) - 1)
                crops['gt'].append(
                    get_component_gt(gt, c['segment_indices'][i]))
                crops['pred'].append(
                    get_component_pred(preds, c['segment_indices'][i]))

        _log.info('Saving data...')
        with open(args['save_file'], 'wb') as f:
            pkl.dump(crops, f)
    else:
        # NOTE(review): pickle.load executes arbitrary code from the file;
        # only point 'load_file' at trusted files.
        with open(args['load_file'], 'rb') as f:
            crops = pkl.load(f)

        _log.info('Computing embeddings...')
        boxes = np.array(crops['box']).squeeze()
        image_level_index = np.array(crops['image_level_index']).squeeze()
        crops['embeddings'] = []
        for i, image_path in tqdm.tqdm(enumerate(crops['image_path']),
                                       total=len(crops['image_path'])):
            image = Image.open(image_path).convert('RGB')
            for j in np.argwhere(image_level_index == i).flatten():
                img = trans.ToTensor()(image.crop(boxes[j]))
                img = trans.Normalize(mean=imagenet_mean,
                                      std=imagenet_std)(img)
                crops['embeddings'].append(get_embedding(
                    img.unsqueeze(0), net))

        # these entries are removed before the dict is stored again
        if 'plot_embeddings' in crops:
            del crops['plot_embeddings']
        if 'nn_embeddings' in crops:
            del crops['nn_embeddings']

        _log.info('Saving data...')
        with open(args['save_file'], 'wb') as f:
            pkl.dump(crops, f)
Esempio n. 8
0
def main(args, _run, _log):
    """Extract low-IoU segment crops and compute feature embeddings for them.

    Two modes are supported, selected by ``args["load_file"]``:

    * ``None``: metric statistics of all datasets in ``args["datasets"][1:]``
      are loaded and normalized, a meta model predicts the IoU per segment,
      segments are filtered, cut out of their images, and embedded with the
      selected feature network.
    * otherwise: a previously pickled ``crops`` dictionary is loaded and only
      its embeddings are recomputed with the selected feature network.

    In both modes the resulting ``crops`` dictionary is pickled to
    ``args["save_file"]``.

    Args:
        args: Configuration mapping. Keys read here: ``net``, ``gpu``,
            ``load_file``, ``save_file``, and (only when ``load_file`` is
            ``None``) ``meta_model``, ``meta_nn_path``, ``datasets``,
            ``iou_threshold``, ``min_height``, ``min_width``,
            ``min_crop_height``, ``min_crop_width``, ``n_jobs``.
        _run: Run object, only forwarded to ``log_config``.
        _log: Logger used for progress messages (``info`` / ``debug``).

    Raises:
        ValueError: If ``args["net"]`` or ``args["meta_model"]`` is not one of
            the supported values, or if fewer than two datasets are given.
    """
    log_config(_run, _log)
    # load a network architecture
    _log.info("Loading {}...".format(args["net"]))
    if args["net"] == "vgg16":
        net = feature_vgg16()
    elif args["net"] == "resnet18":
        net = feature_resnet18()
    elif args["net"] == "resnet101":
        net = feature_resnet101()
    elif args["net"] == "resnet152":
        net = feature_resnet152()
    elif args["net"] == "wide_resnet101":
        net = feature_wide_resnet101()
    elif args["net"] == "densenet201":
        net = feature_densenet201()
    else:
        raise ValueError("Network {} not supported.".format(args["net"]))
    net = net.cuda(args["gpu"])
    net.eval()

    # The preprocessing transforms are identical for every crop, so they are
    # built once instead of once per crop.
    to_tensor = trans.ToTensor()
    normalize = trans.Normalize(mean=imagenet_mean, std=imagenet_std)

    def embed_crop(image, box):
        """Crop ``box`` out of ``image`` and return its feature embedding."""
        tensor = normalize(to_tensor(image.crop(box)))
        return get_embedding(tensor.unsqueeze(0), net)

    # if no precomputed segments have been supplied, they have to be computed
    if args["load_file"] is None:
        _log.info("Loading Metrics...")
        if len(args["datasets"]) < 2:
            # Only datasets[1:] contribute segments; without them the
            # concatenation further down would fail with an opaque error.
            raise ValueError(
                "At least two datasets are required (source domain first), "
                "got {}.".format(len(args["datasets"]))
            )

        xa_all = []
        start_others = []
        pred_test = []
        dataset_assignments = []
        image_indices = []

        # the first dataset of the 'datasets' configuration serves as source domain
        # dataset. Metric statistics of this dataset are used to normalize the target
        # domain metric statistics. This is why it has to get loaded too.
        norm_keys = (
            "train_xa_mean",
            "train_xa_std",
            "train_classes_mean",
            "train_classes_std",
        )
        # Deserialize the saved meta model only once (and only if it is used).
        model_dict = (
            torch.load(args["meta_nn_path"])
            if args["meta_model"] == "neural"
            else None
        )
        if model_dict is not None and all(key in model_dict for key in norm_keys):
            _log.info(
                "Loading values for normalization from saved model file '{}'".format(
                    args["meta_nn_path"]
                )
            )
            xa_mean = model_dict["train_xa_mean"]
            xa_std = model_dict["train_xa_std"]
            classes_mean = model_dict["train_classes_mean"]
            classes_std = model_dict["train_classes_std"]
        else:
            _log.info("{}...".format(args["datasets"][0]))
            # xa / ya are only needed later by the linear meta model, which
            # always takes this branch.
            (
                xa,
                ya,
                _,
                _,
                xa_mean,
                xa_std,
                classes_mean,
                classes_std,
                *_,
            ) = load_data(args["datasets"][0])

        # Now load all other metric statistics and normalize them using the source
        # domain mean and standard deviation
        for i, d in enumerate(args["datasets"][1:], start=1):
            _log.info("{} ...".format(d))
            num_imgs = get_indices(
                join(CONFIG.metaseg_io_path, "metrics", "deeplabv3plus", d)
            )
            xa_tmp, *_, start_tmp, pred_tmp = load_data(
                d,
                num_imgs=num_imgs,
                xa_mean=xa_mean,
                xa_std=xa_std,
                classes_mean=classes_mean,
                classes_std=classes_std,
            )
            xa_all.append(xa_tmp)
            pred_test.append(pred_tmp)
            dataset_assignments += [i] * len(num_imgs)
            image_indices += num_imgs
            start_others.append(start_tmp)

        # combine them into single arrays
        xa_all = np.concatenate(xa_all).squeeze()
        pred_test = np.concatenate(pred_test).squeeze()
        dataset_assignments = np.array(dataset_assignments).squeeze()
        image_indices = np.array(image_indices).squeeze()

        # Chain the per-dataset start offsets into one running offset list:
        # every following dataset is shifted by the last offset seen so far
        # and its leading entry is dropped.
        start_all = start_others[0]
        for starts in start_others[1:]:
            offset = start_all[-1]
            start_all += [s + offset for s in starts[1:]]
        del xa_tmp, start_tmp, pred_tmp, start_others

        _log.debug("Shape of metrics array: {}".format(xa_all.shape))

        # Using the normalized metric statistics use a meta segmentation network
        # pretrained on the source domain to predict IoU
        _log.info("Predicting IoU...")
        if args["meta_model"] == "neural":
            ya_pred_test = meta_nn_predict(
                pretrained_model_path=args["meta_nn_path"],
                x_test=xa_all,
                gpu=args["gpu"],
            )
        elif args["meta_model"] == "linear":
            ya_pred_test, _ = regression_fit_and_predict(
                x_train=xa, y_train=ya, x_test=xa_all
            )
        else:
            raise ValueError("Meta model {} not supported.".format(args["meta_model"]))

        # Now the different filters are getting applied to the segments
        _log.info("Filtering segments...")
        # `np.bool` was removed from NumPy; the builtin `bool` is the
        # equivalent dtype.
        mask = np.zeros(pred_test.shape[0], dtype=bool)

        # Filter for the predicted IoU to be less than the supplied threshold
        mask = np.logical_or(mask, (ya_pred_test < args["iou_threshold"]))

        # Filter for extracting segments with predefined class predictions
        train_module = importlib.import_module(CONFIG.TRAIN_DATASET.module_name)
        if hasattr(train_module, "pred_class_selection"):
            pred_class_selection = getattr(train_module, "pred_class_selection")
            mask = np.logical_and(mask, np.isin(pred_test, pred_class_selection))

        _log.info("Filtered components (not checked for minimum size):")
        train_dat = getattr(train_module, CONFIG.TRAIN_DATASET.class_name)(
            **CONFIG.TRAIN_DATASET.kwargs
        )
        # Column width of the class-name column; constant for the whole table.
        name_width = max(
            [len(v[0]) for v in train_dat.pred_mapping.values()]
            + [len("Class name")]
        )
        _log.info(
            "\t{:^{width}s} | Filtered | Total".format("Class name", width=name_width)
        )
        for cl in np.unique(pred_test).flatten():
            _log.info(
                "\t{:^{width}s} | {:>8d} | {:<8d}".format(
                    train_dat.pred_mapping[cl][0],
                    mask[pred_test == cl].sum(),
                    (pred_test == cl).sum(),
                    width=name_width,
                )
            )

        # Aggregating arguments for extraction of component information.
        inds = np.argwhere(mask).flatten()
        component_image_mapping = get_image_index_to_components(inds, start_all)
        p_args = [
            (
                v,
                image_indices[k],
                ya_pred_test[start_all[k] : start_all[k + 1]],
                args["datasets"][dataset_assignments[k]],
                args["min_height"],
                args["min_width"],
                args["min_crop_height"],
                args["min_crop_width"],
                "deeplabv3plus",
            )
            for k, v in component_image_mapping.items()
        ]

        # Extracting component information can be parallelized in a multiprocessing pool
        _log.info("Extracting component information...")
        with Pool(args["n_jobs"]) as p:
            r = list(
                tqdm.tqdm(p.imap(wrapper_cutout_components, p_args), total=len(p_args))
            )
        # drop images for which no component survived the extraction
        r = [c for c in r if len(c["component_indices"]) > 0]

        _log.info("Computing embeddings...")
        crops = {
            "embeddings": [],
            "image_path": [],
            "image_index": [],
            "component_index": [],
            "box": [],
            "gt": [],
            "pred": [],
            "dataset": [],
            "model_name": [],
            "image_level_index": [],
            "iou_pred": [],
        }
        # process all extracted crops and compute feature embeddings
        for c in tqdm.tqdm(r):
            # load image
            preds, gt, image_path = probs_gt_load(
                c["image_index"],
                input_dir=join(
                    CONFIG.metaseg_io_path, "input", c["model_name"], c["dataset"]
                ),
                preds=True,
            )

            # image-level entries: one per image
            crops["image_path"].append(image_path)
            crops["model_name"].append(c["model_name"])
            crops["dataset"].append(c["dataset"])
            crops["image_index"].append(c["image_index"])
            crops["iou_pred"].append(c["iou_pred"])
            # position of this image in the image-level lists; every crop
            # below points back to it
            image_level_index = len(crops["image_path"]) - 1

            image = Image.open(image_path).convert("RGB")
            # crop-level entries: one per bounding box
            for i, b in enumerate(c["boxes"]):
                crops["embeddings"].append(embed_crop(image, b))
                crops["box"].append(b)
                crops["component_index"].append(c["component_indices"][i])
                crops["image_level_index"].append(image_level_index)
                crops["gt"].append(get_component_gt(gt, c["segment_indices"][i]))
                crops["pred"].append(get_component_pred(preds, c["segment_indices"][i]))
    else:
        # NOTE(review): pickle can execute arbitrary code while loading; only
        # pass files from trusted sources as 'load_file'.
        with open(args["load_file"], "rb") as f:
            crops = pkl.load(f)

        _log.info("Computing embeddings...")
        # Explicit reshapes instead of squeeze(): squeeze() collapses a file
        # holding a single crop to 1-D / 0-D arrays, which breaks the indexing
        # below. Assumes each box has four coordinates (as required by
        # PIL's Image.crop) -- TODO confirm against the producer of 'box'.
        boxes = np.array(crops["box"]).reshape(-1, 4)
        image_level_index = np.array(crops["image_level_index"]).reshape(-1)
        crops["embeddings"] = []
        for i, image_path in tqdm.tqdm(
            enumerate(crops["image_path"]), total=len(crops["image_path"])
        ):
            image = Image.open(image_path).convert("RGB")
            for j in np.argwhere(image_level_index == i).flatten():
                crops["embeddings"].append(embed_crop(image, boxes[j]))

        # Entries derived from the previous embeddings are removed so they do
        # not get saved alongside the new ones.
        if "plot_embeddings" in crops:
            del crops["plot_embeddings"]
        if "nn_embeddings" in crops:
            del crops["nn_embeddings"]

    _log.info("Saving data...")
    with open(args["save_file"], "wb") as f:
        pkl.dump(crops, f)