Example #1
0
class Evaluator(object):
    """Evaluation helper for a two-branch (global image + local patch) segmentation model.

    Three ConfusionMatrix accumulators are kept:
      - metrics_global: predictions from the downsampled global branch
      - metrics_local:  predictions merged back from local patches
      - metrics:        ensemble (combined global/local) predictions

    ``mode`` selects the evaluation path in ``eval_test``:
      1 - evaluate with only the resized global image
      2 - evaluate with patches (patch + global outputs)
      3 - evaluate global with help from patch feature maps

    With ``test=True``, horizontal-flip plus 4x90-degree rotation test-time
    augmentation is enabled and samples are assumed to carry no labels.
    """

    # def __init__(self, n_class, size0, size_g, size_p, n, sub_batch_size=6, mode=1, test=False):
    def __init__(self, n_class, size_g, size_p, sub_batch_size=6, mode=1, test=False):
        """Store configuration and create the metric accumulators.

        n_class: number of segmentation classes.
        size_g: (h, w) the inputs are resized to for the global branch.
        size_p: (h, w) of each local patch.
        sub_batch_size: number of patches fed through the model per forward pass.
        mode: evaluation mode, 1/2/3 (see class docstring).
        test: enables flip/rotation TTA and skips label handling.
        """
        self.metrics_global = ConfusionMatrix(n_class)
        self.metrics_local = ConfusionMatrix(n_class)
        self.metrics = ConfusionMatrix(n_class)
        self.n_class = n_class
        # self.size0 = size0
        self.size_g = size_g
        self.size_p = size_p
        # self.n = n
        self.sub_batch_size = sub_batch_size
        self.mode = mode
        self.test = test

        # self.ratio = float(size_p[0]) / size0[0]
        # self.step = (size0[0] - size_p[0]) // (n - 1)
        # self.template, self.coordinates = template_patch2global(size0, size_p, n, self.step)

        if test:
            # TTA: identity + horizontal flip, each evaluated at 4 rotations.
            self.flip_range = [False, True]
            self.rotate_range = [0, 1, 2, 3]
        else:
            self.flip_range = [False]
            self.rotate_range = [0]

    def get_scores(self):
        """Return the accumulated score dicts as (ensemble, global, local)."""
        score_train = self.metrics.get_scores()
        score_train_local = self.metrics_local.get_scores()
        score_train_global = self.metrics_global.get_scores()
        return score_train, score_train_global, score_train_local

    def reset_metrics(self):
        """Clear all three confusion matrices."""
        self.metrics.reset()
        self.metrics_local.reset()
        self.metrics_global.reset()

    def eval_test(self, sample, model, global_fixed):
        """Evaluate one batch of images, optionally with flip/rotate TTA.

        sample: dict with key "image" (a list of PIL images) and, unless
            ``self.test`` is True, key "label" (PIL label images).
        model: the two-branch segmentation model; ``model.module`` is used in
            mode 3, so the model is presumably DataParallel-wrapped -- confirm.
        global_fixed: frozen global model, only used in mode 3 via
            ``model.module.collect_local_fm``.

        Returns ``(predictions, predictions_global, predictions_local)`` in
        modes 2/3 and ``(None, predictions_global, None)`` in mode 1, where
        each entry is a list of per-image label maps at the original image
        resolution. Confusion matrices are updated in place when labels exist.
        """
        with torch.no_grad():
            images = sample["image"]
            if not self.test:
                labels = sample["label"]  # PIL images
                #lbls = [RGB_mapping_to_class(np.array(label)) for label in labels]
                #labels = [Image.fromarray(lbl) for lbl in lbls]
                labels_npy = masks_transform(labels, numpy=True)

            images_global = resize(images, self.size_g)
            # Accumulator for global-branch logits over all TTA variants.
            # Spatial size is size_g // 4; presumably the global output
            # stride is 4 -- TODO confirm against the model definition.
            outputs_global = np.zeros(
                (len(images), self.n_class, self.size_g[0] // 4, self.size_g[1] // 4)
            )
            # outputs_global = np.zeros((len(images), self.n_class, self.size_g[0], self.size_g[1]))
            if self.mode == 2 or self.mode == 3:
                images_local = [image.copy() for image in images]
                # scores_local = np.zeros((len(images), self.n_class, self.size0[0], self.size0[1]))
                # scores = np.zeros((len(images), self.n_class, self.size0[0], self.size0[1]))
                # Per-image full-resolution score accumulators; PIL .size is
                # (w, h), hence the swapped indices below.
                scores_local = [
                    np.zeros((1, self.n_class, images[i].size[1], images[i].size[0]))
                    for i in range(len(images))
                ]
                scores = [
                    np.zeros((1, self.n_class, images[i].size[1], images[i].size[0]))
                    for i in range(len(images))
                ]

            for flip in self.flip_range:
                if flip:
                    # we already rotated images for 270'
                    for b in range(len(images)):
                        images_global[b] = transforms.functional.rotate(
                            images_global[b], 90
                        )  # rotate back!
                        images_global[b] = transforms.functional.hflip(images_global[b])
                        if self.mode == 2 or self.mode == 3:
                            images_local[b] = transforms.functional.rotate(
                                images_local[b], 90
                            )  # rotate back!
                            images_local[b] = transforms.functional.hflip(
                                images_local[b]
                            )
                for angle in self.rotate_range:
                    if angle > 0:
                        # Rotate the inputs a further 90 degrees per step; the
                        # predictions are rotated back below via np.rot90(..., k=angle).
                        for b in range(len(images)):
                            images_global[b] = transforms.functional.rotate(
                                images_global[b], 90
                            )
                            if self.mode == 2 or self.mode == 3:
                                images_local[b] = transforms.functional.rotate(
                                    images_local[b], 90
                                )

                    # prepare global images onto cuda
                    images_glb = images_transform(images_global)  # b, c, h, w

                    if self.mode == 2 or self.mode == 3:
                        # NOTE(review): patches come from the un-augmented
                        # `images`, not from `images_local` -- confirm that is
                        # intended when flip/rotation TTA is active.
                        # patches = global2patch(images_local, self.n, self.step, self.size_p)
                        patches, coordinates, templates, sizes, ratios = global2patch(
                        # patches, coordinates, _, sizes, ratios = global2patch(
                            images, self.size_p
                        )
                        # predicted_patches = [ np.zeros((self.n**2, self.n_class, self.size_p[0], self.size_p[1])) for i in range(len(images)) ]
                        # predicted_ensembles = [ np.zeros((self.n**2, self.n_class, self.size_p[0], self.size_p[1])) for i in range(len(images)) ]
                        # Per-image buffers for patch-level and ensemble logits,
                        # one row per patch coordinate.
                        predicted_patches = [
                            np.zeros(
                                (
                                    len(coordinates[i]),
                                    self.n_class,
                                    self.size_p[0],
                                    self.size_p[1],
                                )
                            )
                            for i in range(len(images))
                        ]
                        predicted_ensembles = [
                            np.zeros(
                                (
                                    len(coordinates[i]),
                                    self.n_class,
                                    self.size_p[0],
                                    self.size_p[1],
                                )
                            )
                            for i in range(len(images))
                        ]

                    if self.mode == 1:
                        # eval with only resized global image 
                        if flip:
                            # Undo TTA on the logits: rotate back by k=angle, then un-flip.
                            outputs_global += np.flip(
                                np.rot90(
                                    model.forward(images_glb, None, None, None)[0]
                                    .data.cpu()
                                    .numpy(),
                                    k=angle,
                                    axes=(3, 2),
                                ),
                                axis=3,
                            )
                        else:
                            outputs_global += np.rot90(
                                model.forward(images_glb, None, None, None)[0]
                                .data.cpu()
                                .numpy(),
                                k=angle,
                                axes=(3, 2),
                            )

                    if self.mode == 2:
                        # eval with patches 
                        for i in range(len(images)):
                            j = 0
                            # while j < self.n**2:
                            # Process this image's patches in chunks of sub_batch_size.
                            while j < len(coordinates[i]):
                                patches_var = images_transform(
                                    patches[i][j : j + self.sub_batch_size]
                                )  # b, c, h, w
                                # output_ensembles, output_global, output_patches, _ = model.forward(images_glb[i:i+1], patches_var, self.coordinates[j : j+self.sub_batch_size], self.ratio, mode=self.mode, n_patch=self.n**2) # include cordinates
                                output_ensembles, output_global, output_patches, _ = model.forward(
                                    images_glb[i : i + 1],
                                    patches_var,
                                    coordinates[i][j : j + self.sub_batch_size],
                                    ratios[i],
                                    mode=self.mode,
                                    n_patch=len(coordinates[i]),
                                )

                                # patch predictions
                                predicted_patches[i][
                                    j : j + output_patches.size()[0]
                                ] += (
                                    F.interpolate(
                                        output_patches, size=self.size_p, mode="nearest"
                                    )
                                    .data.cpu()
                                    .numpy()
                                )
                                predicted_ensembles[i][
                                    j : j + output_ensembles.size()[0]
                                ] += (
                                    F.interpolate(
                                        output_ensembles,
                                        size=self.size_p,
                                        mode="nearest",
                                    )
                                    .data.cpu()
                                    .numpy()
                                )
                                j += patches_var.size()[0]
                            if flip:
                                outputs_global[i] += np.flip(
                                    np.rot90(
                                        output_global[0].data.cpu().numpy(),
                                        k=angle,
                                        axes=(2, 1),
                                    ),
                                    axis=2,
                                )
                                # scores_local[i] += np.flip(np.rot90(np.array(patch2global(predicted_patches[i:i+1], self.n_class, self.n, self.step, self.size0, self.size_p, len(images))), k=angle, axes=(3, 2)), axis=3) # merge softmax scores from patches (overlaps)
                                # scores[i] += np.flip(np.rot90(np.array(patch2global(predicted_ensembles[i:i+1], self.n_class, self.n, self.step, self.size0, self.size_p, len(images))), k=angle, axes=(3, 2)), axis=3) # merge softmax scores from patches (overlaps)
                                # NOTE(review): unlike the mode-3 branch below, the
                                # merged result is NOT indexed with [0] here --
                                # confirm the accumulated shapes actually match.
                                scores_local[i] += np.flip(
                                    np.rot90(
                                        np.array(
                                            patch2global(
                                                predicted_patches[i : i + 1],
                                                self.n_class,
                                                sizes[i : i + 1],
                                                coordinates[i : i + 1],
                                                self.size_p,
                                            )
                                        ),
                                        k=angle,
                                        axes=(3, 2),
                                    ),
                                    axis=3,
                                )  # merge softmax scores from patches (overlaps)
                                scores[i] += np.flip(
                                    np.rot90(
                                        np.array(
                                            patch2global(
                                                predicted_ensembles[i : i + 1],
                                                self.n_class,
                                                sizes[i : i + 1],
                                                coordinates[i : i + 1],
                                                self.size_p,
                                            )
                                        ),
                                        k=angle,
                                        axes=(3, 2),
                                    ),
                                    axis=3,
                                )  # merge softmax scores from patches (overlaps)
                            else:
                                outputs_global[i] += np.rot90(
                                    output_global[0].data.cpu().numpy(),
                                    k=angle,
                                    axes=(2, 1),
                                )
                                # scores_local[i] += np.rot90(np.array(patch2global(predicted_patches[i:i+1], self.n_class, self.n, self.step, self.size0, self.size_p, len(images))), k=angle, axes=(3, 2)) # merge softmax scores from patches (overlaps)
                                # scores[i] += np.rot90(np.array(patch2global(predicted_ensembles[i:i+1], self.n_class, self.n, self.step, self.size0, self.size_p, len(images))), k=angle, axes=(3, 2)) # merge softmax scores from patches (overlaps)
                                scores_local[i] += np.rot90(
                                    np.array(
                                        patch2global(
                                            predicted_patches[i : i + 1],
                                            self.n_class,
                                            sizes[i : i + 1],
                                            coordinates[i : i + 1],
                                            self.size_p,
                                        )
                                    ),
                                    k=angle,
                                    axes=(3, 2),
                                )  # merge softmax scores from patches (overlaps)
                                scores[i] += np.rot90(
                                    np.array(
                                        patch2global(
                                            predicted_ensembles[i : i + 1],
                                            self.n_class,
                                            sizes[i : i + 1],
                                            coordinates[i : i + 1],
                                            self.size_p,
                                        )
                                    ),
                                    k=angle,
                                    axes=(3, 2),
                                )  # merge softmax scores from patches (overlaps)

                    if self.mode == 3:
                        # eval global with help from patches 
                        # go through local patches to collect feature maps
                        # collect predictions from patches
                        for i in range(len(images)):
                            j = 0
                            # while j < self.n**2:
                            while j < len(coordinates[i]):
                                patches_var = images_transform(
                                    patches[i][j : j + self.sub_batch_size]
                                )  # b, c, h, w
                                # fm_patches, output_patches = model.module.collect_local_fm(images_glb[i:i+1], patches_var, self.ratio, self.coordinates, [j, j+self.sub_batch_size], len(images), global_model=global_fixed, template=self.template, n_patch_all=self.n**2) # include cordinates
                                # NOTE(review): templates[0] is passed for every
                                # image i -- confirm templates[i] was not intended.
                                fm_patches, output_patches = model.module.collect_local_fm(
                                    images_glb[i : i + 1],
                                    patches_var,
                                    ratios[i],
                                    coordinates[i],
                                    [j, j + self.sub_batch_size],
                                    len(images),
                                    global_model=global_fixed,
                                    template=templates[0],
                                    n_patch_all=len(coordinates[i]),
                                )
                                predicted_patches[i][
                                    j : j + output_patches.size()[0]
                                ] += (
                                    F.interpolate(
                                        output_patches, size=self.size_p, mode="nearest"
                                    )
                                    .data.cpu()
                                    .numpy()
                                )
                                # NOTE(review): advances by sub_batch_size even when
                                # the final chunk is smaller (mode 2 advances by the
                                # actual chunk size) -- harmless for the loop bound,
                                # but worth confirming for slice bookkeeping.
                                j += self.sub_batch_size
                        # go through global image
                        # tmp, fm_global = model.forward(images_glb, None, self.coordinates, self.ratio, mode=self.mode, global_model=None, n_patch=self.n**2) # include cordinates
                        tmp, fm_global = model.forward(
                            images_glb, None, None, None, mode=self.mode
                        )
                        if flip:
                            outputs_global += np.flip(
                                np.rot90(tmp.data.cpu().numpy(), k=angle, axes=(3, 2)),
                                axis=3,
                            )
                        else:
                            outputs_global += np.rot90(
                                tmp.data.cpu().numpy(), k=angle, axes=(3, 2)
                            )
                        # generate ensembles
                        for i in range(len(images)):
                            j = 0
                            # while j < self.n ** 2:
                            while j < len(coordinates[i]):
                                # Pair each chunk of cached local feature maps with
                                # the matching crop of the global feature map.
                                fl = fm_patches[i][j : j + self.sub_batch_size].cuda()
                                # fg = model.module._crop_global(fm_global[i:i+1], self.coordinates[j:j+self.sub_batch_size], self.ratio)[0]
                                fg = model.module._crop_global(
                                    fm_global[i : i + 1],
                                    coordinates[i][j : j + self.sub_batch_size],
                                    ratios[i],
                                )[0]
                                fg = F.interpolate(
                                    fg, size=fl.size()[2:], mode="bilinear"
                                )
                                output_ensembles = model.module.ensemble(
                                    fl, fg
                                )  # include cordinates
                                # output_ensembles = F.interpolate(model.module.ensemble(fl, fg), self.size_p, **model.module._up_kwargs)

                                # ensemble predictions
                                predicted_ensembles[i][
                                    j : j + output_ensembles.size()[0]
                                ] += (
                                    F.interpolate(
                                        output_ensembles,
                                        size=self.size_p,
                                        mode="nearest",
                                    )
                                    .data.cpu()
                                    .numpy()
                                )
                                j += self.sub_batch_size
                            if flip:
                                # scores_local[i] += np.flip(np.rot90(np.array(patch2global(predicted_patches[i:i+1], self.n_class, self.n, self.step, self.size0, self.size_p, len(images))), k=angle, axes=(3, 2)), axis=3) # merge softmax scores from patches (overlaps)
                                # scores[i] += np.flip(np.rot90(np.array(patch2global(predicted_ensembles[i:i+1], self.n_class, self.n, self.step, self.size0, self.size_p, len(images))), k=angle, axes=(3, 2)), axis=3) # merge softmax scores from patches (overlaps)
                                scores_local[i] += np.flip(
                                    np.rot90(
                                        np.array(
                                            patch2global(
                                                predicted_patches[i : i + 1],
                                                self.n_class,
                                                sizes[i : i + 1],
                                                coordinates[i : i + 1],
                                                self.size_p,
                                            )
                                        ),
                                        k=angle,
                                        axes=(3, 2),
                                    ),
                                    axis=3,
                                )[
                                    0
                                ]  # merge softmax scores from patches (overlaps)
                                scores[i] += np.flip(
                                    np.rot90(
                                        np.array(
                                            patch2global(
                                                predicted_ensembles[i : i + 1],
                                                self.n_class,
                                                sizes[i : i + 1],
                                                coordinates[i : i + 1],
                                                self.size_p,
                                            )
                                        ),
                                        k=angle,
                                        axes=(3, 2),
                                    ),
                                    axis=3,
                                )[
                                    0
                                ]  # merge softmax scores from patches (overlaps)
                            else:
                                # scores_local[i] += np.rot90(np.array(patch2global(predicted_patches[i:i+1], self.n_class, self.n, self.step, self.size0, self.size_p, len(images))), k=angle, axes=(3, 2)) # merge softmax scores from patches (overlaps)
                                # scores[i] += np.rot90(np.array(patch2global(predicted_ensembles[i:i+1], self.n_class, self.n, self.step, self.size0, self.size_p, len(images))), k=angle, axes=(3, 2)) # merge softmax scores from patches (overlaps)
                                scores_local[i] += np.rot90(
                                    np.array(
                                        patch2global(
                                            predicted_patches[i : i + 1],
                                            self.n_class,
                                            sizes[i : i + 1],
                                            coordinates[i : i + 1],
                                            self.size_p,
                                        )
                                    ),
                                    k=angle,
                                    axes=(3, 2),
                                )  # merge softmax scores from patches (overlaps)
                                scores[i] += np.rot90(
                                    np.array(
                                        patch2global(
                                            predicted_ensembles[i : i + 1],
                                            self.n_class,
                                            sizes[i : i + 1],
                                            coordinates[i : i + 1],
                                            self.size_p,
                                        )
                                    ),
                                    k=angle,
                                    axes=(3, 2),
                                )  # merge softmax scores from patches (overlaps)

            # global predictions 
            # predictions_global = F.interpolate(torch.Tensor(outputs_global), self.size0, mode='nearest').argmax(1).detach().numpy()
            # Upsample accumulated global logits to each image's original
            # resolution; PIL .size is (w, h), so [::-1] yields (h, w).
            outputs_global = torch.Tensor(outputs_global)
            predictions_global = [
                F.interpolate(
                    outputs_global[i : i + 1], images[i].size[::-1], mode="nearest"
                )
                .argmax(1)
                .detach()
                .numpy()[0]
                for i in range(len(images))
            ]
            if not self.test:
                self.metrics_global.update(labels_npy, predictions_global)

            if self.mode == 2 or self.mode == 3:
                # patch predictions 
                # predictions_local = scores_local.argmax(1) # b, h, w
                predictions_local = [score.argmax(1)[0] for score in scores_local]
                if not self.test:
                    self.metrics_local.update(labels_npy, predictions_local)
                
                # combined/ensemble predictions 
                # predictions = scores.argmax(1) # b, h, w
                predictions = [score.argmax(1)[0] for score in scores]
                if not self.test:
                    self.metrics.update(labels_npy, predictions)
                return predictions, predictions_global, predictions_local
            else:
                return None, predictions_global, None
Example #2
0
class Evaluator(object):
    def __init__(self, n_class, size_g, size_p, sub_batch_size=6, mode=1, test=False):
        """Configure the evaluator: metric accumulators, sizes and TTA settings.

        n_class: number of segmentation classes.
        size_g: (h, w) target size of the resized global image.
        size_p: (h, w) size of each local patch.
        sub_batch_size: patches pushed through the model per forward pass.
        mode: evaluation mode selector (1, 2 or 3).
        test: when True, enables flip/rotation test-time augmentation.
        """
        # One confusion matrix per prediction stream.
        self.metrics_global = ConfusionMatrix(n_class)
        self.metrics_local = ConfusionMatrix(n_class)
        self.metrics = ConfusionMatrix(n_class)

        self.n_class = n_class
        self.size_g = size_g
        self.size_p = size_p
        self.sub_batch_size = sub_batch_size
        self.mode = mode
        self.test = test

        # Flip + 4x90-degree rotation augmentation is only used at test time.
        self.flip_range = [False, True] if test else [False]
        self.rotate_range = [0, 1, 2, 3] if test else [0]

    def get_scores(self):
        score_train = self.metrics.get_scores()
        score_train_local = self.metrics_local.get_scores()
        score_train_global = self.metrics_global.get_scores()
        return score_train, score_train_global, score_train_local

    def reset_metrics(self):
        self.metrics.reset()
        self.metrics_local.reset()
        self.metrics_global.reset()

    def eval_test(self, sample, model, global_fixed):
        with torch.no_grad():
            images = sample['image']
            if not self.test:
                labels = sample['label']  # PIL images
                labels_npy = masks_transform(labels, numpy=True)

            images_global = resize(images, self.size_g)
            outputs_global = np.zeros(
                (len(images), self.n_class, self.size_g[0] // 4,
                 self.size_g[1] // 4))
            if self.mode == 2 or self.mode == 3:
                images_local = [image.copy() for image in images]
                scores_local = [
                    np.zeros((1, self.n_class, images[i].size[1],
                              images[i].size[0])) for i in range(len(images))
                ]
                scores = [
                    np.zeros((1, self.n_class, images[i].size[1],
                              images[i].size[0])) for i in range(len(images))
                ]

            for flip in self.flip_range:
                if flip:
                    # we already rotated images for 270'
                    for b in range(len(images)):
                        images_global[b] = transforms.functional.rotate(
                            images_global[b], 90)  # rotate back!
                        images_global[b] = transforms.functional.hflip(
                            images_global[b])
                        if self.mode == 2 or self.mode == 3:
                            images_local[b] = transforms.functional.rotate(
                                images_local[b], 90)  # rotate back!
                            images_local[b] = transforms.functional.hflip(
                                images_local[b])
                for angle in self.rotate_range:
                    if angle > 0:
                        for b in range(len(images)):
                            images_global[b] = transforms.functional.rotate(
                                images_global[b], 90)
                            if self.mode == 2 or self.mode == 3:
                                images_local[b] = transforms.functional.rotate(
                                    images_local[b], 90)

                    # prepare global images onto cuda
                    images_glb = images_transform(images_global)  # b, c, h, w

                    if self.mode == 2 or self.mode == 3:
                        patches, coordinates, templates, sizes, ratios = global2patch(
                            images, self.size_p)
                        predicted_patches = [
                            np.zeros((len(coordinates[i]), self.n_class,
                                      self.size_p[0], self.size_p[1]))
                            for i in range(len(images))
                        ]
                        predicted_ensembles = [
                            np.zeros((len(coordinates[i]), self.n_class,
                                      self.size_p[0], self.size_p[1]))
                            for i in range(len(images))
                        ]

                    if self.mode == 1:
                        # eval with only resized global image ##########################
                        if flip:
                            outputs_global += np.flip(np.rot90(model.forward(
                                images_glb, None, None,
                                None)[0].data.cpu().numpy(),
                                                               k=angle,
                                                               axes=(3, 2)),
                                                      axis=3)
                        else:
                            outputs_global += np.rot90(model.forward(
                                images_glb, None, None,
                                None)[0].data.cpu().numpy(),
                                                       k=angle,
                                                       axes=(3, 2))
                        ################################################################

                    if self.mode == 2:
                        # eval with patches ###########################################
                        for i in range(len(images)):
                            j = 0
                            while j < len(coordinates[i]):
                                patches_var = images_transform(
                                    patches[i]
                                    [j:j + self.sub_batch_size])  # b, c, h, w
                                output_ensembles, output_global, output_patches, _ = model.forward(
                                    images_glb[i:i + 1],
                                    patches_var,
                                    coordinates[i][j:j + self.sub_batch_size],
                                    ratios[i],
                                    mode=self.mode,
                                    n_patch=len(coordinates[i]))

                                # patch predictions
                                predicted_patches[i][j:j + output_patches.size(
                                )[0]] += F.interpolate(
                                    output_patches,
                                    size=self.size_p,
                                    mode='nearest').data.cpu().numpy()
                                predicted_ensembles[i][j:j +
                                                       output_ensembles.size(
                                                       )[0]] += F.interpolate(
                                                           output_ensembles,
                                                           size=self.size_p,
                                                           mode='nearest'
                                                       ).data.cpu().numpy()
                                j += patches_var.size()[0]
                            if flip:
                                outputs_global[i] += np.flip(np.rot90(
                                    output_global[0].data.cpu().numpy(),
                                    k=angle,
                                    axes=(2, 1)),
                                                             axis=2)
                                scores_local[i] += np.flip(
                                    np.rot90(np.array(
                                        patch2global(
                                            predicted_patches[i:i + 1],
                                            self.n_class, sizes[i:i + 1],
                                            coordinates[i:i + 1],
                                            self.size_p)),
                                             k=angle,
                                             axes=(3, 2)),
                                    axis=3
                                )  # merge softmax scores from patches (overlaps)
                                scores[i] += np.flip(
                                    np.rot90(np.array(
                                        patch2global(
                                            predicted_ensembles[i:i + 1],
                                            self.n_class, sizes[i:i + 1],
                                            coordinates[i:i + 1],
                                            self.size_p)),
                                             k=angle,
                                             axes=(3, 2)),
                                    axis=3
                                )  # merge softmax scores from patches (overlaps)
                            else:
                                outputs_global[i] += np.rot90(
                                    output_global[0].data.cpu().numpy(),
                                    k=angle,
                                    axes=(2, 1))
                                scores_local[i] += np.rot90(
                                    np.array(
                                        patch2global(
                                            predicted_patches[i:i + 1],
                                            self.n_class, sizes[i:i + 1],
                                            coordinates[i:i + 1],
                                            self.size_p)),
                                    k=angle,
                                    axes=(3, 2)
                                )  # merge softmax scores from patches (overlaps)
                                scores[i] += np.rot90(
                                    np.array(
                                        patch2global(
                                            predicted_ensembles[i:i + 1],
                                            self.n_class, sizes[i:i + 1],
                                            coordinates[i:i + 1],
                                            self.size_p)),
                                    k=angle,
                                    axes=(3, 2)
                                )  # merge softmax scores from patches (overlaps)
                        ###############################################################

                    if self.mode == 3:
                        # eval global with help from patches ##################################################
                        # go through local patches to collect feature maps
                        # collect predictions from patches
                        for i in range(len(images)):
                            j = 0
                            while j < len(coordinates[i]):
                                patches_var = images_transform(
                                    patches[i]
                                    [j:j + self.sub_batch_size])  # b, c, h, w
                                fm_patches, output_patches = model.module.collect_local_fm(
                                    images_glb[i:i + 1],
                                    patches_var,
                                    ratios[i],
                                    coordinates[i],
                                    [j, j + self.sub_batch_size],
                                    len(images),
                                    global_model=global_fixed,
                                    template=templates[i],
                                    n_patch_all=len(coordinates[i]))
                                predicted_patches[i][j:j + output_patches.size(
                                )[0]] += F.interpolate(
                                    output_patches,
                                    size=self.size_p,
                                    mode='nearest').data.cpu().numpy()
                                j += self.sub_batch_size
                        # go through global image
                        tmp, fm_global = model.forward(images_glb,
                                                       None,
                                                       None,
                                                       None,
                                                       mode=self.mode)
                        if flip:
                            outputs_global += np.flip(np.rot90(
                                tmp.data.cpu().numpy(), k=angle, axes=(3, 2)),
                                                      axis=3)
                        else:
                            outputs_global += np.rot90(tmp.data.cpu().numpy(),
                                                       k=angle,
                                                       axes=(3, 2))
                        # generate ensembles
                        for i in range(len(images)):
                            j = 0
                            while j < len(coordinates[i]):
                                fl = fm_patches[i][j:j +
                                                   self.sub_batch_size].cuda()
                                fg = model.module._crop_global(
                                    fm_global[i:i + 1],
                                    coordinates[i][j:j + self.sub_batch_size],
                                    ratios[i])[0]
                                fg = F.interpolate(fg,
                                                   size=fl.size()[2:],
                                                   mode='bilinear')
                                output_ensembles = model.module.ensemble(
                                    fl, fg)  # include cordinates

                                # ensemble predictions
                                predicted_ensembles[i][j:j +
                                                       output_ensembles.size(
                                                       )[0]] += F.interpolate(
                                                           output_ensembles,
                                                           size=self.size_p,
                                                           mode='nearest'
                                                       ).data.cpu().numpy()
                                j += self.sub_batch_size
                            if flip:
                                scores_local[i] += np.flip(
                                    np.rot90(np.array(
                                        patch2global(
                                            predicted_patches[i:i + 1],
                                            self.n_class, sizes[i:i + 1],
                                            coordinates[i:i + 1],
                                            self.size_p)),
                                             k=angle,
                                             axes=(3, 2)),
                                    axis=3
                                )[0]  # merge softmax scores from patches (overlaps)
                                scores[i] += np.flip(
                                    np.rot90(np.array(
                                        patch2global(
                                            predicted_ensembles[i:i + 1],
                                            self.n_class, sizes[i:i + 1],
                                            coordinates[i:i + 1],
                                            self.size_p)),
                                             k=angle,
                                             axes=(3, 2)),
                                    axis=3
                                )[0]  # merge softmax scores from patches (overlaps)
                            else:
                                scores_local[i] += np.rot90(
                                    np.array(
                                        patch2global(
                                            predicted_patches[i:i + 1],
                                            self.n_class, sizes[i:i + 1],
                                            coordinates[i:i + 1],
                                            self.size_p)),
                                    k=angle,
                                    axes=(3, 2)
                                )  # merge softmax scores from patches (overlaps)
                                scores[i] += np.rot90(
                                    np.array(
                                        patch2global(
                                            predicted_ensembles[i:i + 1],
                                            self.n_class, sizes[i:i + 1],
                                            coordinates[i:i + 1],
                                            self.size_p)),
                                    k=angle,
                                    axes=(3, 2)
                                )  # merge softmax scores from patches (overlaps)
                        ###################################################

            # global predictions ###########################
            outputs_global = torch.Tensor(outputs_global)
            predictions_global = [
                F.interpolate(outputs_global[i:i + 1],
                              images[i].size[::-1],
                              mode='nearest').argmax(1).detach().numpy()[0]
                for i in range(len(images))
            ]
            if not self.test:
                self.metrics_global.update(labels_npy, predictions_global)

            if self.mode == 2 or self.mode == 3:
                # patch predictions ###########################
                predictions_local = [
                    score.argmax(1)[0] for score in scores_local
                ]
                if not self.test:
                    self.metrics_local.update(labels_npy, predictions_local)
                ###################################################
                # combined/ensemble predictions ###########################
                predictions = [score.argmax(1)[0] for score in scores]
                if not self.test:
                    self.metrics.update(labels_npy, predictions)
                return predictions, predictions_global, predictions_local
            else:
                return None, predictions_global, None
Example #3
0
class Trainer(object):
    """One-step trainer for the global/local (GLNet-style) segmentation model.

    Three training modes are supported:
      1: train on the downsampled global image only.
      2: train on local patches jointly with the global branch through the
         ensemble head, with a feature-map regularization term.
      3: train the global branch, reusing cached local feature maps to
         train the ensemble head.
    """

    def __init__(
        self,
        criterion,
        optimizer,
        n_class,
        size_g,
        size_p,
        sub_batch_size=6,
        mode=1,
        lamb_fmreg=0.15,
    ):
        """
        Args:
            criterion: segmentation loss function.
            optimizer: optimizer updating the model parameters.
            n_class: number of segmentation classes.
            size_g: (h, w) the global image is resized to.
            size_p: (h, w) of each local patch.
            sub_batch_size: number of patches per forward pass.
            mode: training mode, 1, 2 or 3 (see class docstring).
            lamb_fmreg: weight of the feature-map regularization loss (mode 2).
        """
        self.criterion = criterion
        self.optimizer = optimizer
        self.metrics_global = ConfusionMatrix(n_class)
        self.metrics_local = ConfusionMatrix(n_class)
        self.metrics = ConfusionMatrix(n_class)
        self.n_class = n_class
        self.size_g = size_g
        self.size_p = size_p
        self.sub_batch_size = sub_batch_size
        self.mode = mode
        self.lamb_fmreg = lamb_fmreg

    def set_train(self, model):
        """Switch the submodules updated in the current mode to train mode.

        Modes 1 and 3 train the global branch; mode 2 trains the local
        branch. The ensemble head is trained in every mode.
        """
        model.module.ensemble_conv.train()
        if self.mode == 1 or self.mode == 3:
            model.module.resnet_global.train()
            model.module.fpn_global.train()
        else:
            model.module.resnet_local.train()
            model.module.fpn_local.train()

    def get_scores(self):
        """Return the (combined, global, local) metric scores."""
        score_train = self.metrics.get_scores()
        score_train_local = self.metrics_local.get_scores()
        score_train_global = self.metrics_global.get_scores()
        return score_train, score_train_global, score_train_local

    def reset_metrics(self):
        """Reset the combined, local and global confusion matrices."""
        self.metrics.reset()
        self.metrics_local.reset()
        self.metrics_global.reset()

    def train(self, sample, model, global_fixed):
        """Run one optimization step on a batch and update the metrics.

        Args:
            sample: dict with "image" and "label" lists of PIL images.
            model: the (DataParallel-wrapped) global/local model.
            global_fixed: frozen global model used to cache local feature
                maps in mode 3.

        Returns:
            The last loss tensor computed during this step.
        """
        images, labels = sample["image"], sample["label"]  # PIL images
        labels_npy = masks_transform(
            labels, numpy=True
        )  # label of origin size in numpy

        images_glb = resize(images, self.size_g)  # list of resized PIL images
        images_glb = images_transform(images_glb)
        labels_glb = resize(
            labels, (self.size_g[0] // 4, self.size_g[1] // 4), label=True
        )  # down 1/4 for loss
        labels_glb = masks_transform(labels_glb)

        if self.mode == 2 or self.mode == 3:
            # per-image patches plus the metadata needed to stitch them back
            patches, coordinates, templates, sizes, ratios = global2patch(
                images, self.size_p
            )
            label_patches, _, _, _, _ = global2patch(labels, self.size_p)
            predicted_patches = [
                np.zeros(
                    (len(coordinates[i]), self.n_class, self.size_p[0], self.size_p[1])
                )
                for i in range(len(images))
            ]
            predicted_ensembles = [
                np.zeros(
                    (len(coordinates[i]), self.n_class, self.size_p[0], self.size_p[1])
                )
                for i in range(len(images))
            ]
            outputs_global = [None for i in range(len(images))]

        if self.mode == 1:
            # training with only the (resized) global image
            outputs_global, _ = model.forward(images_glb, None, None, None)
            loss = self.criterion(outputs_global, labels_glb)
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()

        if self.mode == 2:
            # training with patches
            for i in range(len(images)):
                j = 0
                while j < len(coordinates[i]):
                    patches_var = images_transform(
                        patches[i][j : j + self.sub_batch_size]
                    )  # b, c, h, w
                    label_patches_var = masks_transform(
                        resize(
                            label_patches[i][j : j + self.sub_batch_size],
                            (self.size_p[0] // 4, self.size_p[1] // 4),
                            label=True,
                        )
                    )  # down 1/4 for loss

                    output_ensembles, output_global, output_patches, fmreg_l2 = model.forward(
                        images_glb[i : i + 1],
                        patches_var,
                        coordinates[i][j : j + self.sub_batch_size],
                        ratios[i],
                        mode=self.mode,
                        n_patch=len(coordinates[i]),
                    )
                    loss = (
                        self.criterion(output_patches, label_patches_var)
                        + self.criterion(output_ensembles, label_patches_var)
                        + self.lamb_fmreg * fmreg_l2
                    )
                    loss.backward()

                    # patch predictions, upsampled back to patch resolution
                    predicted_patches[i][j : j + output_patches.size()[0]] = (
                        F.interpolate(output_patches, size=self.size_p, mode="nearest")
                        .data.cpu()
                        .numpy()
                    )
                    predicted_ensembles[i][j : j + output_ensembles.size()[0]] = (
                        F.interpolate(
                            output_ensembles, size=self.size_p, mode="nearest"
                        )
                        .data.cpu()
                        .numpy()
                    )
                    j += self.sub_batch_size
                outputs_global[i] = output_global
            outputs_global = torch.cat(outputs_global, dim=0)

            self.optimizer.step()
            self.optimizer.zero_grad()

        if self.mode == 3:
            # train global with help from patches:
            # first pass over local patches collects feature maps and
            # patch predictions
            for i in range(len(images)):
                j = 0
                while j < len(coordinates[i]):
                    patches_var = images_transform(
                        patches[i][j : j + self.sub_batch_size]
                    )  # b, c, h, w
                    fm_patches, output_patches = model.module.collect_local_fm(
                        images_glb[i : i + 1],
                        patches_var,
                        ratios[i],
                        coordinates[i],
                        [j, j + self.sub_batch_size],
                        len(images),
                        global_model=global_fixed,
                        # BUG FIX: use the template of the current image, not
                        # templates[0] — global2patch returns one template per
                        # image, and image sizes (hence templates) can differ.
                        # This matches Evaluator.eval_test, which passes
                        # templates[i] here.
                        template=templates[i],
                        n_patch_all=len(coordinates[i]),
                    )
                    predicted_patches[i][j : j + output_patches.size()[0]] = (
                        F.interpolate(output_patches, size=self.size_p, mode="nearest")
                        .data.cpu()
                        .numpy()
                    )
                    j += self.sub_batch_size
            # train on global image
            outputs_global, fm_global = model.forward(
                images_glb, None, None, None, mode=self.mode
            )
            loss = self.criterion(outputs_global, labels_glb)
            # keep the graph alive: the ensemble losses below reuse fm_global
            loss.backward(retain_graph=True)
            # generate ensembles & calc loss
            for i in range(len(images)):
                j = 0
                while j < len(coordinates[i]):
                    label_patches_var = masks_transform(
                        resize(
                            label_patches[i][j : j + self.sub_batch_size],
                            (self.size_p[0] // 4, self.size_p[1] // 4),
                            label=True,
                        )
                    )
                    fl = fm_patches[i][j : j + self.sub_batch_size].cuda()
                    fg = model.module._crop_global(
                        fm_global[i : i + 1],
                        coordinates[i][j : j + self.sub_batch_size],
                        ratios[i],
                    )[0]
                    fg = F.interpolate(fg, size=fl.size()[2:], mode="bilinear")
                    output_ensembles = model.module.ensemble(fl, fg)
                    loss = self.criterion(
                        output_ensembles, label_patches_var
                    )  # + 0.15 * mse(fl, fg)
                    # release the graph only on the very last sub-batch
                    if i == len(images) - 1 and j + self.sub_batch_size >= len(
                        coordinates[i]
                    ):
                        loss.backward()
                    else:
                        loss.backward(retain_graph=True)

                    # ensemble predictions
                    predicted_ensembles[i][j : j + output_ensembles.size()[0]] = (
                        F.interpolate(
                            output_ensembles, size=self.size_p, mode="nearest"
                        )
                        .data.cpu()
                        .numpy()
                    )
                    j += self.sub_batch_size
            self.optimizer.step()
            self.optimizer.zero_grad()

        # global predictions: upsample each image's logits to its original
        # size (PIL .size is (w, h), hence the [::-1])
        outputs_global = outputs_global.cpu()
        predictions_global = [
            F.interpolate(
                outputs_global[i : i + 1], images[i].size[::-1], mode="nearest"
            )
            .argmax(1)
            .detach()
            .numpy()
            for i in range(len(images))
        ]
        self.metrics_global.update(labels_npy, predictions_global)

        if self.mode == 2 or self.mode == 3:
            # patch predictions: merge softmax scores from patches (overlaps)
            scores_local = np.array(
                patch2global(
                    predicted_patches, self.n_class, sizes, coordinates, self.size_p
                )
            )
            predictions_local = scores_local.argmax(1)  # b, h, w
            self.metrics_local.update(labels_npy, predictions_local)

            # combined/ensemble predictions, merged the same way
            scores = np.array(
                patch2global(
                    predicted_ensembles, self.n_class, sizes, coordinates, self.size_p
                )
            )
            predictions = scores.argmax(1)  # b, h, w
            self.metrics.update(labels_npy, predictions)
        return loss
Example #4
0
class Trainer(object):
    # def __init__(self, criterion, optimizer, n_class, size0, size_g, size_p, n, sub_batch_size=6, mode=1, lamb_fmreg=0.15):
    def __init__(
        self,
        criterion,
        optimizer,
        n_class,
        size_g,
        size_p,
        sub_batch_size=6,
        mode=1,
        lamb_fmreg=0.15,
    ):
        """Store the training configuration and create fresh metrics.

        Args:
            criterion: segmentation loss function.
            optimizer: optimizer updating the model parameters.
            n_class: number of segmentation classes.
            size_g: (h, w) the global image is resized to.
            size_p: (h, w) of each local patch.
            sub_batch_size: number of patches per forward pass.
            mode: training mode (1, 2 or 3).
            lamb_fmreg: weight of the feature-map regularization loss.
        """
        # optimization components
        self.criterion = criterion
        self.optimizer = optimizer
        self.lamb_fmreg = lamb_fmreg
        # confusion matrices for combined, local and global predictions
        self.metrics_global = ConfusionMatrix(n_class)
        self.metrics_local = ConfusionMatrix(n_class)
        self.metrics = ConfusionMatrix(n_class)
        # geometry / batching configuration
        self.n_class = n_class
        self.size_g = size_g
        self.size_p = size_p
        self.sub_batch_size = sub_batch_size
        self.mode = mode
    
    def set_train(self, model, parallel=True):
        if not parallel:
            model.ensemble_conv.train()
            if self.mode == 1 or self.mode == 3:
                model.resnet_global.train()
                model.fpn_global.train()
            else:
                model.resnet_local.train()
                model.fpn_local.train()
        else:
            model.module.ensemble_conv.train()
            if self.mode == 1 or self.mode == 3:
                model.module.resnet_global.train()
                model.module.fpn_global.train()
            else:
                model.module.resnet_local.train()
                model.module.fpn_local.train()

    def get_scores(self):
        score_train = self.metrics.get_scores()
        score_train_local = self.metrics_local.get_scores()
        score_train_global = self.metrics_global.get_scores()
        return score_train, score_train_global, score_train_local

    def reset_metrics(self):
        self.metrics.reset()
        self.metrics_local.reset()
        self.metrics_global.reset()

    def train(self, sample, model, global_fixed, parallel=True):
        images, labels = sample['image'], sample['label'] # PIL images
        labels_npy = masks_transform(labels, numpy=True) # label of origin size in numpy

        images_glb = resize(images, self.size_g) # list of resized PIL images
        images_glb = images_transform(images_glb)
        labels_glb = resize(labels, (self.size_g[0] // 4, self.size_g[1] // 4), label=True) # down 1/4 for loss
        # labels_glb = resize(labels, self.size_g, label=True) # must downsample image for reduced GPU memory
        labels_glb = masks_transform(labels_glb)

        if self.mode == 2 or self.mode == 3:
            patches, coordinates, templates, sizes, ratios = global2patch(images, self.size_p)
            label_patches, _, _, _, _ = global2patch(labels, self.size_p)
            # patches, label_patches = global2patch(images, self.n, self.step, self.size_p), global2patch(labels, self.n, self.step, self.size_p)
            # predicted_patches = [ np.zeros((self.n**2, self.n_class, self.size_p[0], self.size_p[1])) for i in range(len(images)) ]
            # predicted_ensembles = [ np.zeros((self.n**2, self.n_class, self.size_p[0], self.size_p[1])) for i in range(len(images)) ]
            predicted_patches = [ np.zeros((len(coordinates[i]), self.n_class, self.size_p[0], self.size_p[1])) for i in range(len(images)) ]
            predicted_ensembles = [ np.zeros((len(coordinates[i]), self.n_class, self.size_p[0], self.size_p[1])) for i in range(len(images)) ]
            outputs_global = [ None for i in range(len(images)) ]

        if self.mode == 1:
            # training with only (resized) global image #########################################
            outputs_global, _ = model.forward(images_glb, None, None, None)
            loss = self.criterion(outputs_global, labels_glb)
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()
            ##############################################

        if self.mode == 2:
            # training with patches ###########################################
            for i in range(len(images)):
                # sync the start for each global image
                torch.distributed.barrier()
                #print("Will start training global image:", i, " on rank  :", torch.distributed.get_rank())
                #group = torch.distributed.new_group(range(torch.distributed.get_world_size()))
                coordinate_size = torch.tensor([len(coordinates[i])], dtype=torch.int32).cuda()
                torch.distributed.all_reduce(coordinate_size, op=torch.distributed.ReduceOp.MAX)
                #print("Will train:", coordinate_size.item(), "loops in global image.", "On rank  :", torch.distributed.get_rank())
                
                j = 0
                _loop = 0
                # while j < self.n**2:
                #print("==== ==== len(coordinates[i]):", len(coordinates[i]), ", at rank:", torch.distributed.get_rank())
                #while j < len(coordinates[i]):
                #while j in range(6):
                #while _ in range(coordinate_size.item()):
                while _loop < coordinate_size.item():
                    patches_var = images_transform(patches[i][j : j+self.sub_batch_size]) # b, c, h, w
                    label_patches_var = masks_transform(resize(label_patches[i][j : j+self.sub_batch_size], (self.size_p[0] // 4, self.size_p[1] // 4), label=True)) # down 1/4 for loss
                    # label_patches_var = masks_transform(label_patches[i][j : j+self.sub_batch_size])

                    # output_ensembles, output_global, output_patches, fmreg_l2 = model.forward(images_glb[i:i+1], patches_var, self.coordinates[j : j+self.sub_batch_size], self.ratio, mode=self.mode, n_patch=self.n**2) # include cordinates
                    output_ensembles, output_global, output_patches, fmreg_l2 = model.forward(images_glb[i:i+1], patches_var, coordinates[i][j : j+self.sub_batch_size], ratios[i], mode=self.mode, n_patch=len(coordinates[i]))
                    loss = self.criterion(output_patches, label_patches_var) + self.criterion(output_ensembles, label_patches_var) + self.lamb_fmreg * fmreg_l2
                    loss.backward()

                    # patch predictions
                    predicted_patches[i][j:j+output_patches.size()[0]] = F.interpolate(output_patches, size=self.size_p, mode='nearest').data.cpu().numpy()
                    predicted_ensembles[i][j:j+output_ensembles.size()[0]] = F.interpolate(output_ensembles, size=self.size_p, mode='nearest').data.cpu().numpy()
                    # Because we choose loop the biggest coordinate_size in all ranks, 
                    # make sure not cross the border for current rank
                    j = (j+self.sub_batch_size)%len(coordinates[i])
                    _loop += self.sub_batch_size
                    #print("==== ==== nested loop:", j, ", at rank:", torch.distributed.get_rank())
                outputs_global[i] = output_global
            outputs_global = torch.cat(outputs_global, dim=0)

            self.optimizer.step()
            self.optimizer.zero_grad()
            #####################################################################################

        if self.mode == 3:
            # train global with help from patches ##################################################
            # go through local patches to collect feature maps
            # collect predictions from patches
            for i in range(len(images)):
                j = 0
                # while j < self.n**2:
                while j < len(coordinates[i]):
                    patches_var = images_transform(patches[i][j : j+self.sub_batch_size]) # b, c, h, w
                    if parallel:
                        # fm_patches, output_patches = model.module.collect_local_fm(images_glb[i:i+1], patches_var, self.ratio, self.coordinates, [j, j+self.sub_batch_size], len(images), global_model=global_fixed, template=self.template, n_patch_all=self.n**2) # include cordinates
                        fm_patches, output_patches = model.module.collect_local_fm(images_glb[i:i+1], patches_var, ratios[i], coordinates[i], [j, j+self.sub_batch_size], len(images), global_model=global_fixed, template=self.template, n_patch_all=len(coordinates[i]))
                    else:
                        # fm_patches, output_patches = model.module.collect_local_fm(images_glb[i:i+1], patches_var, self.ratio, self.coordinates, [j, j+self.sub_batch_size], len(images), global_model=global_fixed, template=self.template, n_patch_all=self.n**2) # include cordinates
                        fm_patches, output_patches = model.collect_local_fm(images_glb[i:i+1], patches_var, ratios[i], coordinates[i], [j, j+self.sub_batch_size], len(images), global_model=global_fixed, template=self.template, n_patch_all=len(coordinates[i]))
                    
                    predicted_patches[i][j:j+output_patches.size()[0]] = F.interpolate(output_patches, size=self.size_p, mode='nearest').data.cpu().numpy()
                    j += self.sub_batch_size
            # train on global image
            outputs_global, fm_global = model.forward(images_glb, None, None, None, mode=self.mode)
            loss = self.criterion(outputs_global, labels_glb)
            loss.backward(retain_graph=True)
            # fmreg loss
            # generate ensembles & calc loss
            for i in range(len(images)):
                j = 0
                # while j < self.n**2:
                while j < len(coordinates[i]):
                    label_patches_var = masks_transform(resize(label_patches[i][j : j+self.sub_batch_size], (self.size_p[0] // 4, self.size_p[1] // 4), label=True))
                    # label_patches_var = masks_transform(resize(label_patches[i][j : j+self.sub_batch_size], self.size_p, label=True))
                    fl = fm_patches[i][j : j+self.sub_batch_size].cuda()
                    if parallel:
                        # fg = model.module._crop_global(fm_global[i:i+1], self.coordinates[j:j+self.sub_batch_size], self.ratio)[0]
                        fg = model.module._crop_global(fm_global[i:i+1], coordinates[i][j:j+self.sub_batch_size], ratios[i])[0]
                    else:
                        # fg = model.module._crop_global(fm_global[i:i+1], self.coordinates[j:j+self.sub_batch_size], self.ratio)[0]
                        fg = model._crop_global(fm_global[i:i+1], coordinates[i][j:j+self.sub_batch_size], ratios[i])[0]

                    fg = F.interpolate(fg, size=fl.size()[2:], mode='bilinear')
                    if parallel:
                        output_ensembles = model.module.ensemble(fl, fg)
                        # output_ensembles = F.interpolate(model.module.ensemble(fl, fg), self.size_p, **model.module._up_kwargs)
                    else:
                        output_ensembles = model.ensemble(fl, fg)
                        # output_ensembles = F.interpolate(model.module.ensemble(fl, fg), self.size_p, **model.module._up_kwargs)

                    loss = self.criterion(output_ensembles, label_patches_var)# + 0.15 * mse(fl, fg)
                    # if i == len(images) - 1 and j + self.sub_batch_size >= self.n**2:
                    if i == len(images) - 1 and j + self.sub_batch_size >= len(coordinates[i]):
                        loss.backward()
                    else:
                        loss.backward(retain_graph=True)

                    # ensemble predictions
                    predicted_ensembles[i][j:j+output_ensembles.size()[0]] = F.interpolate(output_ensembles, size=self.size_p, mode='nearest').data.cpu().numpy()
                    j += self.sub_batch_size
            self.optimizer.step()
            self.optimizer.zero_grad()

        # global predictions ###########################
        # predictions_global = F.interpolate(outputs_global.cpu(), self.size0, mode='nearest').argmax(1).detach().numpy()
        outputs_global = outputs_global.cpu()
        predictions_global = [F.interpolate(outputs_global[i:i+1], images[i].size[::-1], mode='nearest').argmax(1).detach().numpy() for i in range(len(images))]
        self.metrics_global.update(labels_npy, predictions_global)

        if self.mode == 2 or self.mode == 3:
            # patch predictions ###########################
            # scores_local = np.array(patch2global(predicted_patches, self.n_class, self.n, self.step, self.size0, self.size_p, len(images))) # merge softmax scores from patches (overlaps)
            scores_local = np.array(patch2global(predicted_patches, self.n_class, sizes, coordinates, self.size_p)) # merge softmax scores from patches (overlaps)
            predictions_local = scores_local.argmax(1) # b, h, w
            self.metrics_local.update(labels_npy, predictions_local)
            ###################################################
            # combined/ensemble predictions ###########################
            # scores = np.array(patch2global(predicted_ensembles, self.n_class, self.n, self.step, self.size0, self.size_p, len(images))) # merge softmax scores from patches (overlaps)
            scores = np.array(patch2global(predicted_ensembles, self.n_class, sizes, coordinates, self.size_p)) # merge softmax scores from patches (overlaps)
            predictions = scores.argmax(1) # b, h, w
            self.metrics.update(labels_npy, predictions)
        return loss
Пример #5
0
class Trainer(object):
    """Training wrapper for a global/local patch-based segmentation model.

    mode 1: train on the (resized) global image only.
    mode 2: train the local patch branch, regularized against the global branch.
    mode 3: train the global branch with feature maps collected from patches.
    """

    def __init__(self,
                 criterion,
                 optimizer,
                 n_class,
                 size_g,
                 size_p,
                 sub_batch_size=6,
                 mode=1,
                 lamb_fmreg=0.15):
        # criterion: segmentation loss; optimizer: shared by all training modes
        self.criterion = criterion
        self.optimizer = optimizer
        # Separate confusion matrices for global-only, patch-only and
        # combined/ensemble predictions.
        self.metrics_global = ConfusionMatrix(n_class)
        self.metrics_local = ConfusionMatrix(n_class)
        self.metrics = ConfusionMatrix(n_class)
        self.n_class = n_class
        self.size_g = size_g  # (h, w) of the resized global image
        self.size_p = size_p  # (h, w) of each local patch
        self.sub_batch_size = sub_batch_size  # patches per forward pass
        self.mode = mode
        self.lamb_fmreg = lamb_fmreg  # weight of the feature-map regularization term

    def set_train(self, model):
        """Switch to train() exactly the sub-modules updated by the current mode."""
        model.module.ensemble_conv.train()
        if self.mode == 1 or self.mode == 3:
            model.module.resnet_global.train()
            model.module.fpn_global.train()
        else:
            model.module.resnet_local.train()
            model.module.fpn_local.train()

    def get_scores(self):
        """Return metric dicts as (ensemble, global, local)."""
        score_train = self.metrics.get_scores()
        score_train_local = self.metrics_local.get_scores()
        score_train_global = self.metrics_global.get_scores()
        return score_train, score_train_global, score_train_local

    def reset_metrics(self):
        """Clear all three confusion matrices (call at epoch start)."""
        self.metrics.reset()
        self.metrics_local.reset()
        self.metrics_global.reset()

    def train(self, sample, model, global_fixed):
        """Run one optimization step on a batch and update training metrics.

        Args:
            sample: dict with 'image' and 'label' lists of PIL images.
            model: the trainable global/local model (wrapped, exposes .module).
            global_fixed: frozen global model used to collect patch feature
                maps in mode 3.

        Returns:
            The last computed loss tensor (mode-dependent; for modes 2/3 this
            is the loss of the final sub-batch, not an epoch aggregate).
        """
        images, labels = sample['image'], sample['label']  # PIL images
        labels_npy = masks_transform(
            labels, numpy=True)  # label of origin size in numpy

        images_glb = resize(images, self.size_g)  # list of resized PIL images
        images_glb = images_transform(images_glb)
        labels_glb = resize(labels, (self.size_g[0] // 4, self.size_g[1] // 4),
                            label=True)  # FPN down 1/4, for loss
        labels_glb = masks_transform(labels_glb)

        if self.mode == 2 or self.mode == 3:
            # Split each image (and its label) into overlapping patches;
            # the number of patches can differ per image, hence per-image
            # score buffers below.
            patches, coordinates, templates, sizes, ratios = global2patch(
                images, self.size_p)
            label_patches, _, _, _, _ = global2patch(labels, self.size_p)
            predicted_patches = [
                np.zeros((len(coordinates[i]), self.n_class, self.size_p[0],
                          self.size_p[1])) for i in range(len(images))
            ]
            predicted_ensembles = [
                np.zeros((len(coordinates[i]), self.n_class, self.size_p[0],
                          self.size_p[1])) for i in range(len(images))
            ]
            outputs_global = [None for i in range(len(images))]

        if self.mode == 1:
            # training with only (resized) global image #########################################
            outputs_global, _ = model.forward(images_glb, None, None, None)
            loss = self.criterion(outputs_global, labels_glb)
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()
            ##############################################

        if self.mode == 2:
            # training with patches ###########################################
            # Gradients accumulate over all sub-batches of all images; the
            # single optimizer.step() comes after the loop.
            for i in range(len(images)):
                j = 0
                while j < len(coordinates[i]):
                    patches_var = images_transform(
                        patches[i][j:j + self.sub_batch_size])  # b, c, h, w
                    label_patches_var = masks_transform(
                        resize(label_patches[i][j:j + self.sub_batch_size],
                               (self.size_p[0] // 4, self.size_p[1] // 4),
                               label=True))  # down 1/4 for loss

                    output_ensembles, output_global, output_patches, fmreg_l2 = model.forward(
                        images_glb[i:i + 1],
                        patches_var,
                        coordinates[i][j:j + self.sub_batch_size],
                        ratios[i],
                        mode=self.mode,
                        n_patch=len(coordinates[i]))
                    # Patch loss + ensemble loss + feature-map regularization.
                    loss = self.criterion(
                        output_patches, label_patches_var) + self.criterion(
                            output_ensembles,
                            label_patches_var) + self.lamb_fmreg * fmreg_l2
                    loss.backward()

                    # patch predictions
                    predicted_patches[i][j:j + output_patches.size(
                    )[0]] = F.interpolate(output_patches,
                                          size=self.size_p,
                                          mode='nearest').data.cpu().numpy()
                    predicted_ensembles[i][j:j + output_ensembles.size(
                    )[0]] = F.interpolate(output_ensembles,
                                          size=self.size_p,
                                          mode='nearest').data.cpu().numpy()
                    j += self.sub_batch_size
                outputs_global[i] = output_global
            outputs_global = torch.cat(outputs_global, dim=0)

            self.optimizer.step()
            self.optimizer.zero_grad()
            #####################################################################################

        if self.mode == 3:
            # train global with help from patches ##################################################
            # go through local patches to collect feature maps
            # collect predictions from patches
            for i in range(len(images)):
                j = 0
                while j < len(coordinates[i]):
                    patches_var = images_transform(
                        patches[i][j:j + self.sub_batch_size])  # b, c, h, w
                    fm_patches, output_patches = model.module.collect_local_fm(
                        images_glb[i:i + 1],
                        patches_var,
                        ratios[i],
                        coordinates[i], [j, j + self.sub_batch_size],
                        len(images),
                        global_model=global_fixed,
                        template=templates[i],
                        n_patch_all=len(coordinates[i]))
                    predicted_patches[i][j:j + output_patches.size(
                    )[0]] = F.interpolate(output_patches,
                                          size=self.size_p,
                                          mode='nearest').data.cpu().numpy()
                    j += self.sub_batch_size
            # train on global image
            outputs_global, fm_global = model.forward(images_glb,
                                                      None,
                                                      None,
                                                      None,
                                                      mode=self.mode)
            loss = self.criterion(outputs_global, labels_glb)
            # retain_graph: fm_global is reused below for every ensemble loss.
            loss.backward(retain_graph=True)
            # fmreg loss
            # generate ensembles & calc loss
            for i in range(len(images)):
                j = 0
                while j < len(coordinates[i]):
                    label_patches_var = masks_transform(
                        resize(label_patches[i][j:j + self.sub_batch_size],
                               (self.size_p[0] // 4, self.size_p[1] // 4),
                               label=True))
                    fl = fm_patches[i][j:j + self.sub_batch_size].cuda()
                    fg = model.module._crop_global(
                        fm_global[i:i + 1],
                        coordinates[i][j:j + self.sub_batch_size],
                        ratios[i])[0]
                    fg = F.interpolate(fg, size=fl.size()[2:], mode='bilinear')
                    output_ensembles = model.module.ensemble(fl, fg)
                    loss = self.criterion(
                        output_ensembles,
                        label_patches_var)  # + 0.15 * mse(fl, fg)
                    # Free the graph only on the very last sub-batch.
                    if i == len(images) - 1 and j + self.sub_batch_size >= len(
                            coordinates[i]):
                        loss.backward()
                    else:
                        loss.backward(retain_graph=True)

                    # ensemble predictions
                    predicted_ensembles[i][j:j + output_ensembles.size(
                    )[0]] = F.interpolate(output_ensembles,
                                          size=self.size_p,
                                          mode='nearest').data.cpu().numpy()
                    j += self.sub_batch_size
            self.optimizer.step()
            self.optimizer.zero_grad()

        # global predictions ###########################
        outputs_global = outputs_global.cpu()
        # PIL .size is (w, h); [::-1] gives the (h, w) expected by interpolate.
        predictions_global = [
            F.interpolate(outputs_global[i:i + 1],
                          images[i].size[::-1],
                          mode='nearest').argmax(1).detach().numpy()
            for i in range(len(images))
        ]
        self.metrics_global.update(labels_npy, predictions_global)

        if self.mode == 2 or self.mode == 3:
            # patch predictions ###########################
            scores_local = np.array(
                patch2global(predicted_patches, self.n_class, sizes,
                             coordinates, self.size_p)
            )  # merge softmax scores from patches (overlaps)
            predictions_local = scores_local.argmax(1)  # b, h, w
            self.metrics_local.update(labels_npy, predictions_local)
            ###################################################
            # combined/ensemble predictions ###########################
            scores = np.array(
                patch2global(predicted_ensembles, self.n_class, sizes,
                             coordinates, self.size_p)
            )  # merge softmax scores from patches (overlaps)
            predictions = scores.argmax(1)  # b, h, w
            self.metrics.update(labels_npy, predictions)
        return loss
Пример #6
0
class Trainer(object):
    """Training wrapper for a global/local patch-based classification model.

    Patch-level class predictions are aggregated by majority vote per image.

    mode 1: train on the (resized) global image only.
    mode 2: train the local patch branch, regularized against the global branch.
    mode 3: train the global branch with feature maps collected from patches.
    """

    def __init__(self,
                 criterion,
                 optimizer,
                 n_class,
                 size_g,
                 size_p,
                 sub_batch_size=6,
                 mode=1,
                 lamb_fmreg=0.15):
        # criterion: classification loss; optimizer: shared by all modes
        self.criterion = criterion
        self.optimizer = optimizer
        # Separate confusion matrices for global-only, patch-vote and
        # ensemble-vote predictions.
        self.metrics_global = ConfusionMatrix(n_class)
        self.metrics_local = ConfusionMatrix(n_class)
        self.metrics = ConfusionMatrix(n_class)
        self.n_class = n_class
        self.size_g = size_g  # (w, h) of the global input
        self.size_p = size_p  # size of each local patch
        self.sub_batch_size = sub_batch_size  # patches per forward pass
        self.mode = mode
        self.lamb_fmreg = lamb_fmreg  # weight of the feature-map regularization term

    def set_train(self, model):
        """Switch to train() exactly the sub-modules updated by the current mode."""
        model.module.ensemble_conv.train()
        if self.mode == 1 or self.mode == 3:
            model.module.resnet_global.train()
            model.module.fpn_global.train()
        else:
            model.module.resnet_local.train()
            model.module.fpn_local.train()

    def get_scores(self):
        """Return metric dicts as (ensemble, global, local)."""
        score_train = self.metrics.get_scores()
        score_train_local = self.metrics_local.get_scores()
        score_train_global = self.metrics_global.get_scores()
        return score_train, score_train_global, score_train_local

    def reset_metrics(self):
        """Clear all three confusion matrices (call at epoch start)."""
        self.metrics.reset()
        self.metrics_local.reset()
        self.metrics_global.reset()

    def train(self, sample, model, global_fixed):
        """Run one optimization step on a batch and update training metrics.

        Args:
            sample: dict with 'image'/'label' (PIL images, labels) and 'id'.
            model: the trainable global/local model (wrapped, exposes .module).
            global_fixed: frozen global model used to collect patch feature
                maps in mode 3.

        Returns:
            The last computed loss tensor (mode-dependent; for modes 2/3 this
            is the loss of the final sub-batch, not an epoch aggregate).
        """
        images, labels = sample['image'], sample['label']  # PIL images
        ids = sample['id']
        width, height = images[0].size

        # Skip the resize when the input is already at the global size.
        if width != self.size_g[0] or height != self.size_g[1]:
            images_glb = resize(images,
                                self.size_g)  # list of resized PIL images
        else:
            images_glb = list(images)

        images_glb = images_transform(images_glb)
        labels_glb = masks_transform(labels)

        if self.mode == 2 or self.mode == 3:
            patches, coordinates, sizes, ratios, label_patches = global2patch(
                images, labels_glb, self.size_p, ids)
            # Per-image vote histograms over the class set. Previously sized
            # with a hard-coded 4, which breaks for n_class > 4; using
            # self.n_class generalizes while preserving behavior for 4 classes.
            predicted_patches = np.zeros((len(images), self.n_class))
            predicted_ensembles = np.zeros((len(images), self.n_class))
            outputs_global = [None for i in range(len(images))]

        if self.mode == 1:
            # training with only (resized) global image #########################################
            outputs_global, _ = model.forward(images_glb, None, None, None)
            loss = self.criterion(outputs_global, labels_glb)
            loss.backward()

            self.optimizer.step()
            self.optimizer.zero_grad()
            ##############################################

        if self.mode == 2:
            # training with patches ###########################################
            # Gradients accumulate over all sub-batches of all images; the
            # single optimizer.step() comes after the loop.
            for i in range(len(images)):
                j = 0
                while j < len(coordinates[i]):
                    patches_var = images_transform(
                        patches[i][j:j + self.sub_batch_size])  # b, c, h, w
                    label_patches_var = masks_transform(
                        label_patches[i][j:j + self.sub_batch_size])

                    output_ensembles, output_global, output_patches, fmreg_l2 = model.forward(
                        images_glb[i:i + 1],
                        patches_var,
                        coordinates[i][j:j + self.sub_batch_size],
                        ratios[i],
                        mode=self.mode,
                        n_patch=len(coordinates[i]))
                    # Patch loss + ensemble loss + feature-map regularization.
                    loss = self.criterion(
                        output_patches, label_patches_var) + self.criterion(
                            output_ensembles,
                            label_patches_var) + self.lamb_fmreg * fmreg_l2
                    loss.backward()

                    # patch predictions: accumulate one vote per patch for the
                    # argmax class of the patch and of the ensemble outputs
                    for pred_patch, pred_ensemble in zip(
                            torch.max(output_patches.data, 1)[1].data,
                            torch.max(output_ensembles.data, 1)[1].data):
                        predicted_patches[i][int(pred_patch)] += 1
                        predicted_ensembles[i][int(pred_ensemble)] += 1

                    j += self.sub_batch_size

                outputs_global[i] = output_global

            outputs_global = torch.cat(outputs_global, dim=0)

            self.optimizer.step()
            self.optimizer.zero_grad()
            #####################################################################################

        if self.mode == 3:
            # train global with help from patches ##################################################
            # go through local patches to collect feature maps
            # collect predictions from patches

            for i in range(len(images)):
                j = 0
                while j < len(coordinates[i]):
                    patches_var = images_transform(
                        patches[i][j:j + self.sub_batch_size])  # b, c, h, w
                    _, output_patches = model.module.collect_local_fm(
                        images_glb[i:i + 1],
                        patches_var,
                        ratios[i],
                        coordinates[i], [j, j + self.sub_batch_size],
                        len(images),
                        global_model=global_fixed,
                        n_patch_all=len(coordinates[i]))

                    for pred_patch in torch.max(output_patches.data,
                                                1)[1].data:
                        predicted_patches[i][int(pred_patch)] += 1

                    j += self.sub_batch_size

            # train on global image

            outputs_global, fm_global = model.forward(images_glb,
                                                      None,
                                                      None,
                                                      None,
                                                      mode=self.mode)

            loss = self.criterion(outputs_global, labels_glb)
            # retain_graph: fm_global is reused below for every ensemble loss.
            loss.backward(retain_graph=True)

            # fmreg loss
            # generate ensembles & calc loss
            for i in range(len(images)):
                j = 0
                while j < len(coordinates[i]):
                    label_patches_var = masks_transform(
                        label_patches[i][j:j + self.sub_batch_size])
                    patches_var = images_transform(
                        patches[i][j:j + self.sub_batch_size])  # b, c, h, w

                    # Local feature maps from the trainable model; global
                    # feature map cropped to the same patch locations.
                    fl = model.module.generate_local_fm(
                        images_glb[i:i + 1],
                        patches_var,
                        ratios[i],
                        coordinates[i], [j, j + self.sub_batch_size],
                        len(images),
                        global_model=global_fixed,
                        n_patch_all=len(coordinates[i]))
                    fg = model.module._crop_global(
                        fm_global[i:i + 1],
                        coordinates[i][j:j + self.sub_batch_size],
                        ratios[i])[0]
                    fg = F.interpolate(fg, size=fl.size()[2:], mode='bilinear')
                    output_ensembles = model.module.ensemble(fl, fg)

                    loss = self.criterion(
                        output_ensembles,
                        label_patches_var)  # + 0.15 * mse(fl, fg)
                    # Free the graph only on the very last sub-batch.
                    if i == len(images) - 1 and j + self.sub_batch_size >= len(
                            coordinates[i]):
                        loss.backward()
                    else:
                        loss.backward(retain_graph=True)

                    # ensemble predictions
                    for pred_ensemble in torch.max(output_ensembles.data,
                                                   1)[1].data:
                        predicted_ensembles[i][int(pred_ensemble)] += 1

                    j += self.sub_batch_size

            self.optimizer.step()
            self.optimizer.zero_grad()

        # global predictions ###########################
        _, predictions_global = torch.max(outputs_global.data, 1)
        self.metrics_global.update(labels_glb, predictions_global)

        if self.mode == 2 or self.mode == 3:
            # patch predictions: majority vote over patch-level argmax classes
            predictions_local = predicted_patches.argmax(1)
            self.metrics_local.update(labels_glb, predictions_local)
            ###################################################
            # combined/ensemble predictions ###########################
            predictions = predicted_ensembles.argmax(1)
            self.metrics.update(labels_glb, predictions)
        return loss
Пример #7
0
class Trainer(object):
    def __init__(self, criterion, optimizer, n_class, size_g, size_p, sub_batch_size=6, mode=1, lamb_fmreg=0.15):
        self.criterion = criterion
        self.optimizer = optimizer
        self.metrics_global = ConfusionMatrix(n_class)
        self.metrics_local = ConfusionMatrix(n_class)
        self.metrics = ConfusionMatrix(n_class)
        self.n_class = n_class
        self.size_g = size_g
        self.size_p = size_p
        self.sub_batch_size = sub_batch_size
        self.mode = mode
        self.lamb_fmreg = lamb_fmreg
    
    def set_train(self, model):
        model.module.ensemble_conv.train()
        if self.mode == 1 or self.mode == 3:
            model.module.resnet_global.train()
            model.module.fpn_global.train()
        else:
            model.module.resnet_local.train()
            model.module.fpn_local.train()

    def get_scores(self):
        score_train = self.metrics.get_scores()
        score_train_local = self.metrics_local.get_scores()
        score_train_global = self.metrics_global.get_scores()
        return score_train, score_train_global, score_train_local

    def reset_metrics(self):
        self.metrics.reset()
        self.metrics_local.reset()
        self.metrics_global.reset()

    def train(self, sample, model, global_fixed):
        images, labels, labels_npy, images_glb = sample['image'], sample['label'], sample['label_npy'], sample['image_glb'] # PIL images
        #labels_npy = masks_transform(labels, numpy=True) # label of origin size in numpy
        #images_glb = resize(images, self.size_g) # list of resized PIL images
        #images_glb = images_transform(images_glb)
        labels_glb = resize(labels, (self.size_g[0] // 4, self.size_g[1] // 4), label=True) # FPN down 1/4, for loss
        labels_glb = masks_transform(labels_glb)
        if self.mode == 2 or self.mode == 3:
            patches, coordinates, templates, sizes, ratios = global2patch(images, self.size_p)
            label_patches, _, _, _, _ = global2patch(labels, self.size_p)
            #predicted_patches = [ np.zeros((len(coordinates[i]), self.n_class, self.size_p[0], self.size_p[1])) for i in range(len(images)) ]
            #predicted_ensembles = [ np.zeros((len(coordinates[i]), self.n_class, self.size_p[0], self.size_p[1])) for i in range(len(images)) ]
            #outputs_global = [ None for i in range(len(images)) ]

        if self.mode == 1:
            # training with only (resized) global image #########################################
            outputs_global, _ = model.forward(images_glb, None, None, None)
            loss = self.criterion(outputs_global, labels_glb)
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()
            ##############################################

        if self.mode == 2:
            
            # training with patches ###########################################
            subdataset = AerialSubdatasetMode2(images_glb, ratios, coordinates, patches, label_patches, (self.size_p[0] // 4, self.size_p[1] // 4))
            data_loader = torch.utils.data.DataLoader(dataset=subdataset, \
                                                    batch_size=self.sub_batch_size, \
                                                    num_workers=20, \
                                                    collate_fn=collate_mode2, \
                                                    shuffle=False, pin_memory=True)
            for batch_sample in data_loader:
                for sub_batch_id in range(len(batch_sample['n_patch'])):
                    patches_var = batch_sample['patches'][sub_batch_id].cuda()
                    label_patches_var = batch_sample['labels'][sub_batch_id].cuda()
                    output_ensembles, output_global, output_patches, fmreg_l2 = model.forward(batch_sample['images_glob'][sub_batch_id].cuda(), \
                                                                                            patches_var, \
                                                                                            batch_sample['coords'][sub_batch_id], \
                                                                                            batch_sample['ratio'][sub_batch_id], mode=self.mode, \
                                                                                            n_patch=batch_sample['n_patch'][sub_batch_id])

                    loss = self.criterion(output_patches, label_patches_var) + self.criterion(output_ensembles, label_patches_var) + self.lamb_fmreg * fmreg_l2
                    loss.backward()
            
            ''' 
            for i in range(len(images)):
                j = 0
                print("LEN", len(coordinates[i]))
                while j < len(coordinates[i]):
                    track.start("transform_internal")
                    patches_var = images_transform(patches[i][j : j+self.sub_batch_size]) # b, c, h, w
                    label_patches_var = masks_transform(resize(label_patches[i][j : j+self.sub_batch_size], (self.size_p[0] // 4, self.size_p[1] // 4), label=True)) # down 1/4 for loss
                    track.end("transform_internal")
                    
                    track.start("ff_internal")
                    output_ensembles, output_global, output_patches, fmreg_l2 = model.forward(images_glb[i:i+1], patches_var, coordinates[i][j : j+self.sub_batch_size], ratios[i], mode=self.mode, n_patch=len(coordinates[i]))
                    track.end("ff_internal")
                    loss = self.criterion(output_patches, label_patches_var) + self.criterion(output_ensembles, label_patches_var) + self.lamb_fmreg * fmreg_l2
                    loss.backward()

                    # patch predictions
                    #predicted_patches[i][j:j+output_patches.size()[0]] = F.interpolate(output_patches, size=self.size_p, mode='nearest').data.cpu().numpy()
                    #predicted_ensembles[i][j:j+output_ensembles.size()[0]] = F.interpolate(output_ensembles, size=self.size_p, mode='nearest').data.cpu().numpy()
                    j += self.sub_batch_size
                #outputs_global[i] = output_global
            #outputs_global = torch.cat(outputs_global, dim=0)
            '''
            self.optimizer.step()
            self.optimizer.zero_grad()
            #####################################################################################

        if self.mode == 3:
            # train global with help from patches ##################################################
            # go through local patches to collect feature maps
            # collect predictions from patches
            
            track.start("Collect patches")
            # import pdb; pdb.set_trace();
            subdataset = AerialSubdatasetMode3a(patches, coordinates, images_glb, ratios, templates)
            data_loader = torch.utils.data.DataLoader(dataset=subdataset, \
                                                    batch_size=self.sub_batch_size, \
                                                    num_workers=20, \
                                                    collate_fn=collate_mode3a, \
                                                    shuffle=False, pin_memory=True)
            for batch_sample in data_loader:
                for sub_batch_id in range(len(batch_sample['ratios'])):
                    patches_var = batch_sample['patches'][sub_batch_id].cuda()
                    coord = batch_sample['coords'][sub_batch_id]
                    j = batch_sample['coord_ids'][sub_batch_id]
                    fm_patches, _ = model.module.collect_local_fm(batch_sample['images_glb'][sub_batch_id].cuda(), \
                                                                  patches_var, \
                                                                  batch_sample['ratios'][sub_batch_id], \
                                                                  coord, \
                                                                  [min(j), max(j) + 1], \
                                                                  len(images), \
                                                                  global_model=global_fixed, \
                                                                  template= batch_sample['templates'][sub_batch_id].cuda(), \
                                                                  n_patch_all=len(coord))
            # for i in range(len(images)):
            #     j = 0
            #     while j < len(coordinates[i]):
            #         patches_var = images_transform(patches[i][j : j+self.sub_batch_size]) # b, c, h, w
            #         fm_patches, _ = model.module.collect_local_fm(images_glb[i:i+1], patches_var, ratios[i], coordinates[i], [j, j+self.sub_batch_size], len(images), global_model=global_fixed, template=templates[i], n_patch_all=len(coordinates[i]))
            #         j += self.sub_batch_size
            
            track.end("Collect patches")
            
            images_glb = images_glb.cuda()
            # train on global image
            outputs_global, fm_global = model.forward(images_glb, None, None, None, mode=self.mode)
            loss = self.criterion(outputs_global, labels_glb)
            loss.backward(retain_graph=True)
            
            subdataset = AerialSubdatasetMode3b(label_patches, \
                                                (self.size_p[0] // 4, self.size_p[1] // 4), \
                                                fm_patches,\
                                                coordinates, ratios)
            data_loader = torch.utils.data.DataLoader(dataset=subdataset, \
                                                    batch_size=self.sub_batch_size, \
                                                    num_workers=20, \
                                                    collate_fn=collate_mode3b, \
                                                    shuffle=False, pin_memory=True)
            track.start("load_mode_3b")
            for batch_idx, batch_sample in enumerate(data_loader):
                for sub_batch_id in range(len(batch_sample['ratios'])):
                    label_patches_var = batch_sample['label_patches'][sub_batch_id].cuda()
                    fl = batch_sample['fl'][sub_batch_id].cuda()
                    image_id = batch_sample['id'][sub_batch_id]
                    track.end("load_mode_3b")
                    fg = model.module._crop_global(fm_global[image_id: image_id+1], \
                                                   batch_sample['coords'][sub_batch_id], \
                                                   batch_sample['ratios'][sub_batch_id])[0]
                    fg = F.interpolate(fg, size=fl.size()[2:], mode='bilinear')
                    output_ensembles = model.module.ensemble(fl, fg)
                    loss = self.criterion(output_ensembles, label_patches_var)# + 0.15 * mse(fl, fg)
                    if batch_idx == len(data_loader) - 1 and sub_batch_id == len(batch_sample['ratios']) - 1:
                        loss.backward()
                    else:
                        loss.backward(retain_graph=True)
                    track.start("load_mode_3b")
            # fmreg loss
            # generate ensembles & calc loss
            """
            track.start("load_mode_3b")
            for i in range(len(images)):
                j = 0
                while j < len(coordinates[i]):
                    label_patches_var = masks_transform(resize(label_patches[i][j : j+self.sub_batch_size], (self.size_p[0] // 4, self.size_p[1] // 4), label=True))
                    fl = fm_patches[i][j : j+self.sub_batch_size].cuda()
                    track.end("load_mode_3b")
                    fg = model.module._crop_global(fm_global[i:i+1], coordinates[i][j:j+self.sub_batch_size], ratios[i])[0]
                    fg = F.interpolate(fg, size=fl.size()[2:], mode='bilinear')
                    output_ensembles = model.module.ensemble(fl, fg)
                    loss = self.criterion(output_ensembles, label_patches_var)# + 0.15 * mse(fl, fg)
                    if i == len(images) - 1 and j + self.sub_batch_size >= len(coordinates[i]):
                        loss.backward()
                    else:
                        loss.backward(retain_graph=True)
                    track.start("load_mode_3b")
                    # ensemble predictions
                    #predicted_ensembles[i][j:j+output_ensembles.size()[0]] = F.interpolate(output_ensembles, size=self.size_p, mode='nearest').data.cpu().numpy()
                    j += self.sub_batch_size
            """
            self.optimizer.step()
            self.optimizer.zero_grad()
        '''
        # global predictions ###########################
        outputs_global = outputs_global.cpu()
        predictions_global = [F.interpolate(outputs_global[i:i+1], images[i].size[::-1], mode='nearest').argmax(1).detach().numpy() for i in range(len(images))]
        self.metrics_global.update(labels_npy, predictions_global)
        
        if self.mode == 2 or self.mode == 3:
            # patch predictions ###########################
            scores_local = np.array(patch2global(predicted_patches, self.n_class, sizes, coordinates, self.size_p)) # merge softmax scores from patches (overlaps)
            predictions_local = scores_local.argmax(1) # b, h, w
            self.metrics_local.update(labels_npy, predictions_local)
            ###################################################
            # combined/ensemble predictions ###########################
            scores = np.array(patch2global(predicted_ensembles, self.n_class, sizes, coordinates, self.size_p)) # merge softmax scores from patches (overlaps)
            predictions = scores.argmax(1) # b, h, w
            self.metrics.update(labels_npy, predictions)
        '''
        return loss
Пример #8
0
class Evaluator(object):
    """Evaluator for global/local patch-based segmentation models (GLNet-style).

    mode 1: evaluate with the resized global image only.
    mode 2: evaluate with local patches in addition to the global branch.
    mode 3: evaluate the global branch refined by feature maps collected
            from local patches through a fixed global model.

    When ``test`` is True, flip/rotation test-time augmentation is enabled
    and samples are assumed to carry no ground-truth labels, so confusion
    matrices are not updated.
    """

    def __init__(self,
                 n_class,
                 size_g,
                 size_p,
                 sub_batch_size=6,
                 mode=1,
                 test=False):
        # One confusion matrix per prediction head: ensemble, local, global.
        self.metrics_global = ConfusionMatrix(n_class)
        self.metrics_local = ConfusionMatrix(n_class)
        self.metrics = ConfusionMatrix(n_class)
        self.n_class = n_class
        self.size_g = size_g  # target size for the resized global image
        self.size_p = size_p  # size of each local patch
        self.sub_batch_size = sub_batch_size  # patches per forward pass
        self.mode = mode
        self.test = test

        if test:
            # Test-time augmentation: horizontal flip x 0/90/180/270 rotations.
            self.flip_range = [False, True]
            self.rotate_range = [0, 1, 2, 3]
        else:
            self.flip_range = [False]
            self.rotate_range = [0]

    def get_scores(self):
        """Return (ensemble, global, local) score dicts."""
        score_train = self.metrics.get_scores()
        score_train_local = self.metrics_local.get_scores()
        score_train_global = self.metrics_global.get_scores()
        return score_train, score_train_global, score_train_local

    def reset_metrics(self):
        """Reset all three confusion matrices."""
        self.metrics.reset()
        self.metrics_local.reset()
        self.metrics_global.reset()

    def eval_test(self, sample, model, global_fixed):
        """Evaluate one batch.

        Returns a (predictions, predictions_global, predictions_local)
        tuple; the ensemble and local entries are None in mode 1.
        """
        with torch.no_grad():
            images = sample['image']
            ids = sample['id']
            if not self.test:
                labels = sample['label']  # PIL images
                labels_glb = masks_transform(labels)
            else:
                # BUG FIX: labels_glb was undefined in test mode but is still
                # passed to global2patch below for modes 2/3 (NameError).
                # NOTE(review): confirm global2patch tolerates None labels.
                labels_glb = None

            width, height = images[0].size

            # Only shrink oversized images; smaller ones are used as-is.
            if width > self.size_g[0] or height > self.size_g[1]:
                images_global = resize(
                    images, self.size_g)  # list of resized PIL images
            else:
                images_global = list(images)

            if self.mode == 2 or self.mode == 3:
                # Work on copies so the TTA rotations below do not mutate the
                # caller's PIL images.
                images_local = [image.copy() for image in images]

            for flip in self.flip_range:
                if flip:
                    # The rotation loop below left the images at 270 degrees:
                    # rotate back to upright, then mirror for the flipped pass.
                    for b in range(len(images)):
                        images_global[b] = transforms.functional.rotate(
                            images_global[b], 90)  # rotate back!
                        images_global[b] = transforms.functional.hflip(
                            images_global[b])
                        if self.mode == 2 or self.mode == 3:
                            images_local[b] = transforms.functional.rotate(
                                images_local[b], 90)  # rotate back!
                            images_local[b] = transforms.functional.hflip(
                                images_local[b])
                for angle in self.rotate_range:
                    if angle > 0:
                        for b in range(len(images)):
                            images_global[b] = transforms.functional.rotate(
                                images_global[b], 90)
                            if self.mode == 2 or self.mode == 3:
                                images_local[b] = transforms.functional.rotate(
                                    images_local[b], 90)

                    # prepare global images onto cuda
                    images_glb = images_transform(images_global)  # b, c, h, w

                    if self.mode == 2 or self.mode == 3:
                        patches, coordinates, sizes, ratios, label_patches = global2patch(
                            images, labels_glb, self.size_p, ids)
                        # Per-image class-vote counters (one vote per patch).
                        # BUG FIX: these were hard-coded to 4 classes; size
                        # them by n_class so indexing/argmax stay valid.
                        predicted_patches = np.zeros(
                            (len(images), self.n_class))
                        predicted_ensembles = np.zeros(
                            (len(images), self.n_class))
                        outputs_global = [None for i in range(len(images))]
                    if self.mode == 1:
                        # eval with only resized global image ##########################
                        if flip:
                            # Undo the TTA transform on the logits and
                            # accumulate over augmentations.
                            # NOTE(review): first (non-flip) pass leaves a
                            # torch tensor and later passes add numpy arrays;
                            # confirm this mixed accumulation is intended.
                            outputs_global += np.flip(np.rot90(model.forward(
                                images_glb, None, None,
                                None)[0].data.cpu().numpy(),
                                                               k=angle,
                                                               axes=(3, 2)),
                                                      axis=3)
                        else:
                            outputs_global, _ = model.forward(
                                images_glb, None, None, None)
                        ################################################################

                    if self.mode == 2:
                        # eval with patches ###########################################
                        for i in range(len(images)):
                            j = 0
                            while j < len(coordinates[i]):
                                patches_var = images_transform(
                                    patches[i]
                                    [j:j + self.sub_batch_size])  # b, c, h, w
                                output_ensembles, output_global, output_patches, _ = model.forward(
                                    images_glb[i:i + 1],
                                    patches_var,
                                    coordinates[i][j:j + self.sub_batch_size],
                                    ratios[i],
                                    mode=self.mode,
                                    n_patch=len(coordinates[i]))

                                # Vote: one predicted class per patch, for
                                # both the patch and ensemble heads.
                                for pred_patch, pred_ensemble in zip(
                                        torch.max(output_patches.data,
                                                  1)[1].data,
                                        torch.max(output_ensembles.data,
                                                  1)[1].data):
                                    predicted_patches[i][int(pred_patch)] += 1
                                    predicted_ensembles[i][int(
                                        pred_ensemble)] += 1

                                j += patches_var.size()[0]
                            outputs_global[i] = output_global

                        outputs_global = torch.cat(outputs_global, dim=0)
                        ###############################################################

                    if self.mode == 3:
                        # eval global with help from patches ##################################################
                        # go through local patches to collect feature maps
                        # collect predictions from patches
                        for i in range(len(images)):
                            j = 0
                            while j < len(coordinates[i]):
                                patches_var = images_transform(
                                    patches[i]
                                    [j:j + self.sub_batch_size])  # b, c, h, w
                                _, output_patches = model.module.collect_local_fm(
                                    images_glb[i:i + 1],
                                    patches_var,
                                    ratios[i],
                                    coordinates[i],
                                    [j, j + self.sub_batch_size],
                                    len(images),
                                    global_model=global_fixed,
                                    n_patch_all=len(coordinates[i]))

                                for pred_patch in torch.max(
                                        output_patches.data, 1)[1].data:
                                    predicted_patches[i][int(pred_patch)] += 1

                                j += self.sub_batch_size
                        # go through global image

                        tmp, fm_global = model.forward(images_glb,
                                                       None,
                                                       None,
                                                       None,
                                                       mode=self.mode)

                        if flip:
                            # Undo the TTA transform and accumulate.
                            outputs_global += np.flip(np.rot90(
                                tmp.data.cpu().numpy(), k=angle, axes=(3, 2)),
                                                      axis=3)
                        else:
                            outputs_global = tmp
                        # generate ensembles
                        for i in range(len(images)):
                            j = 0
                            while j < len(coordinates[i]):
                                patches_var = images_transform(
                                    patches[i]
                                    [j:j + self.sub_batch_size])  # b, c, h, w
                                fl = model.module.generate_local_fm(
                                    images_glb[i:i + 1],
                                    patches_var,
                                    ratios[i],
                                    coordinates[i],
                                    [j, j + self.sub_batch_size],
                                    len(images),
                                    global_model=global_fixed,
                                    n_patch_all=len(coordinates[i]))
                                fg = model.module._crop_global(
                                    fm_global[i:i + 1],
                                    coordinates[i][j:j + self.sub_batch_size],
                                    ratios[i])[0]
                                fg = F.interpolate(fg,
                                                   size=fl.size()[2:],
                                                   mode='bilinear')
                                output_ensembles = model.module.ensemble(
                                    fl, fg)  # include coordinates

                                # ensemble predictions
                                for pred_ensemble in torch.max(
                                        output_ensembles.data, 1)[1].data:
                                    predicted_ensembles[i][int(
                                        pred_ensemble)] += 1

                                j += self.sub_batch_size
                        ###################################################

            _, predictions_global = torch.max(outputs_global.data, 1)

            if not self.test:
                self.metrics_global.update(labels_glb, predictions_global)

            if self.mode == 2 or self.mode == 3:
                # patch predictions ###########################
                predictions_local = predicted_patches.argmax(1)
                if not self.test:
                    self.metrics_local.update(labels_glb, predictions_local)
                ###################################################
                predictions = predicted_ensembles.argmax(1)
                if not self.test:
                    self.metrics.update(labels_glb, predictions)
                return predictions, predictions_global, predictions_local
            else:
                return None, predictions_global, None
Пример #9
0
class Trainer(object):
    """Trainer for global (mode 1), local (mode 2) and combined global&local
    (mode 3) models, accumulating gradients over sub-batches before a single
    optimizer step.
    """

    def __init__(self, criterion, optimizer, n_class, sub_batchsize, mode=1, fmreg=0.15):
        self.criterion = criterion
        self.optimizer = optimizer
        # One confusion matrix per prediction head: ensemble, local, global.
        self.metrics_global = ConfusionMatrix(n_class)
        self.metrics_local = ConfusionMatrix(n_class)
        self.metrics = ConfusionMatrix(n_class)
        self.n_class = n_class
        self.sub_batchsize = sub_batchsize  # samples per accumulation chunk
        self.mode = mode
        self.fmreg = fmreg  # weight of the feature-map regularization (mse) term

    def get_scores(self):
        """Return (ensemble, global, local) score dicts."""
        score_train = self.metrics.get_scores()
        score_train_local = self.metrics_local.get_scores()
        score_train_global = self.metrics_global.get_scores()

        return score_train, score_train_global, score_train_local

    def reset_metrics(self):
        """Reset all three confusion matrices."""
        self.metrics.reset()
        self.metrics_local.reset()
        self.metrics_global.reset()

    def train(self, sample, model):
        """Run one optimization step on a batch and update the metrics.

        Returns the loss of the last processed sub-batch.
        """
        model.train()
        labels = sample['label'].squeeze(1).long()
        labels_npy = np.array(labels)
        labels_torch = labels.cuda()
        h, w = sample['output_size'][0]
        if self.mode == 1:  # global
            img_g = sample['image_g'].cuda()
            outputs_g = model.forward(img_g)
            # Upsample logits to the label resolution before computing loss.
            outputs_g = F.interpolate(outputs_g, size=(h, w), mode='bilinear')
            loss = self.criterion(outputs_g, labels_torch)
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()

        if self.mode == 2:  # local
            img_l = sample['image_l'].cuda()
            batch_size = img_l.size(0)
            idx = 0
            outputs_l = []
            # BUG FIX: was `while idx + sub_batchsize <= batch_size`, which
            # silently dropped the final partial chunk when batch_size is not
            # a multiple of sub_batchsize; slicing clamps, so iterate while
            # any samples remain.
            while idx < batch_size:
                output_l = model.forward(img_l[idx:idx+self.sub_batchsize])
                output_l = F.interpolate(output_l, size=(h, w), mode='bilinear')
                outputs_l.append(output_l)
                loss = self.criterion(output_l, labels_torch[idx:idx+self.sub_batchsize])
                loss.backward()
                idx += self.sub_batchsize

            outputs_l = torch.cat(outputs_l, dim=0)
            self.optimizer.step()
            self.optimizer.zero_grad()

        if self.mode == 3:  # global&local
            img_g = sample['image_g'].cuda()
            img_l = sample['image_l'].cuda()
            batch_size = img_l.size(0)
            idx = 0
            outputs = []; outputs_g = []; outputs_l = []
            # BUG FIX: include the trailing partial chunk (see mode 2 above).
            while idx < batch_size:
                output, output_g, output_l, mse = model.forward(img_g[idx:idx+self.sub_batchsize], img_l[idx:idx+self.sub_batchsize], target=labels_torch[idx:idx+self.sub_batchsize])
                outputs.append(output); outputs_g.append(output_g); outputs_l.append(output_l)
                # Ensemble loss is weighted 2x; mse regularizes agreement of
                # local/global feature maps with weight fmreg.
                loss = 2* self.criterion(output, labels_torch[idx:idx+self.sub_batchsize]) + self.criterion(output_g, labels_torch[idx:idx+self.sub_batchsize]) + \
                        self.criterion(output_l, labels_torch[idx:idx+self.sub_batchsize]) + self.fmreg * mse
                loss.backward()
                idx += self.sub_batchsize
            outputs = torch.cat(outputs, dim=0); outputs_g = torch.cat(outputs_g, dim=0); outputs_l = torch.cat(outputs_l, dim=0)
            self.optimizer.step()
            self.optimizer.zero_grad()

        # predictions (per-sample argmax maps fed to the confusion matrices)
        if self.mode == 1:
            outputs_g = outputs_g.cpu()
            predictions_global = [outputs_g[i:i+1].argmax(1).detach().numpy() for i in range(len(labels))]
            self.metrics_global.update(labels_npy, predictions_global)

        if self.mode == 2:
            outputs_l = outputs_l.cpu()
            predictions_local = [outputs_l[i:i+1].argmax(1).detach().numpy() for i in range(len(labels))]
            self.metrics_local.update(labels_npy, predictions_local)

        if self.mode == 3:
            outputs_g = outputs_g.cpu(); outputs_l = outputs_l.cpu(); outputs = outputs.cpu()
            predictions_global = [outputs_g[i:i+1].argmax(1).detach().numpy() for i in range(len(labels))]
            predictions_local = [outputs_l[i:i+1].argmax(1).detach().numpy() for i in range(len(labels))]
            predictions = [outputs[i:i+1].argmax(1).detach().numpy() for i in range(len(labels))]
            self.metrics_global.update(labels_npy, predictions_global)
            self.metrics_local.update(labels_npy, predictions_local)
            self.metrics.update(labels_npy, predictions)

        return loss
Пример #10
0
class Evaluator(object):
    """Evaluation counterpart of Trainer for modes 1 (global), 2 (local) and
    3 (global&local). In ``test`` mode no labels are expected in the sample
    and the confusion matrices are not updated.
    """

    def __init__(self, n_class, sub_batchsize, mode=1, test=False):
        # One confusion matrix per prediction head: ensemble, local, global.
        self.metrics_global = ConfusionMatrix(n_class)
        self.metrics_local = ConfusionMatrix(n_class)
        self.metrics = ConfusionMatrix(n_class)
        self.n_class = n_class
        self.sub_batchsize = sub_batchsize  # samples per inference chunk
        self.mode = mode
        self.test = test

    def get_scores(self):
        """Return (ensemble, global, local) score dicts."""
        score_train = self.metrics.get_scores()
        score_train_local = self.metrics_local.get_scores()
        score_train_global = self.metrics_global.get_scores()

        return score_train, score_train_global, score_train_local

    def reset_metrics(self):
        """Reset all three confusion matrices."""
        self.metrics.reset()
        self.metrics_local.reset()
        self.metrics_global.reset()

    def eval_test(self, sample, model):
        """Evaluate one batch.

        Returns (predictions, predictions_global, predictions_local) with
        None placeholders for the heads the current mode does not produce.
        """
        with torch.no_grad():
            ids = sample['id']
            h, w = sample['output_size'][0]
            if not self.test:
                labels = sample['label'].squeeze(1).long()
                labels_npy = np.array(labels)

            if self.mode == 1:  # global
                img_g = sample['image_g'].cuda()
                outputs_g = model.forward(img_g)
                outputs_g = F.interpolate(outputs_g, size=(h, w), mode='bilinear')

            if self.mode == 2:  # local
                img_l = sample['image_l'].cuda()
                batch_size = img_l.size(0)
                idx = 0
                outputs_l = []
                # BUG FIX: was `while idx + sub_batchsize <= batch_size`,
                # which skipped the trailing partial chunk; slicing clamps,
                # so iterate while any samples remain.
                while idx < batch_size:
                    output_l = model.forward(img_l[idx:idx+self.sub_batchsize])
                    output_l = F.interpolate(output_l, size=(h, w), mode='bilinear')
                    outputs_l.append(output_l)
                    idx += self.sub_batchsize

                outputs_l = torch.cat(outputs_l, dim=0)

            if self.mode == 3:  # global&local
                img_g = sample['image_g'].cuda()
                img_l = sample['image_l'].cuda()
                batch_size = img_l.size(0)
                idx = 0
                outputs = []; outputs_g = []; outputs_l = []
                # BUG FIX: include the trailing partial chunk (see mode 2).
                while idx < batch_size:
                    output, output_g, output_l, mse = model.forward(img_g[idx:idx+self.sub_batchsize], img_l[idx:idx+self.sub_batchsize])
                    outputs.append(output); outputs_g.append(output_g); outputs_l.append(output_l)
                    idx += self.sub_batchsize

                outputs = torch.cat(outputs, dim=0); outputs_g = torch.cat(outputs_g, dim=0); outputs_l = torch.cat(outputs_l, dim=0)

        # predictions (per-sample argmax maps)
        if self.mode == 1:
            outputs_g = outputs_g.cpu()
            # BUG FIX: iterate over the actual output batch size; `labels`
            # is undefined when self.test is True (NameError before).
            predictions_global = [outputs_g[i:i+1].argmax(1).detach().numpy() for i in range(outputs_g.size(0))]
            if not self.test:
                self.metrics_global.update(labels_npy, predictions_global)

            return None, predictions_global, None

        if self.mode == 2:
            outputs_l = outputs_l.cpu()
            # BUG FIX: use the output batch size instead of `labels` (see above).
            predictions_local = [outputs_l[i:i+1].argmax(1).detach().numpy() for i in range(outputs_l.size(0))]
            if not self.test:
                self.metrics_local.update(labels_npy, predictions_local)

            return None, None, predictions_local

        if self.mode == 3:
            outputs_g = outputs_g.cpu(); outputs_l = outputs_l.cpu(); outputs = outputs.cpu()
            # BUG FIX: use the output batch size instead of `labels` (see above).
            predictions_global = [outputs_g[i:i+1].argmax(1).detach().numpy() for i in range(outputs.size(0))]
            predictions_local = [outputs_l[i:i+1].argmax(1).detach().numpy() for i in range(outputs.size(0))]
            predictions = [outputs[i:i+1].argmax(1).detach().numpy() for i in range(outputs.size(0))]
            if not self.test:
                self.metrics_global.update(labels_npy, predictions_global)
                self.metrics_local.update(labels_npy, predictions_local)
                self.metrics.update(labels_npy, predictions)

            return predictions, predictions_global, predictions_local