Example #1
def plot_images(
    model,
    logger,
    test_set,
    dest_task="normal",
    ood_images=None,
    show_masks=False,
    loss_models={},
    preds_name=None,
    target_name=None,
    ood_name=None,
):

    import torch  # needed below for torch.no_grad()
    from task_configs import get_task, ImageTask

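    # Run the model on the held-out set; predict_with_data returns the input
    # images, predictions, targets, and per-sample losses.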
    test_images, preds, targets, losses, _ = model.predict_with_data(test_set)

    if isinstance(dest_task, str):
        dest_task = get_task(dest_task)

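    # Optionally log the validity masks built from the target images.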
    if show_masks and isinstance(dest_task, ImageTask):
        test_masks = ImageTask.build_mask(targets,
                                          dest_task.mask_val,
                                          tol=1e-3)
        logger.images(test_masks.float(), f"{dest_task}_masks", resize=64)

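    # Plot predictions and ground-truth targets for the destination task.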
    dest_task.plot_func(preds, preds_name or f"{dest_task.name}_preds", logger)
    dest_task.plot_func(targets, target_name or f"{dest_task.name}_target",
                        logger)

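    # If out-of-distribution images are provided, plot the model's predictions on them.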
    if ood_images is not None:
        ood_preds = model.predict(ood_images)
        dest_task.plot_func(ood_preds,
                            ood_name or f"{dest_task.name}_ood_preds", logger)

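    # Visualize the output of each auxiliary loss model on the current predictions.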
    for name, loss_model in loss_models.items():
        with torch.no_grad():
            output = loss_model(preds, targets, test_images)
            if hasattr(output, "task"):
                output.task.plot_func(output, name, logger, resize=128)
            else:
                logger.images(output.clamp(min=0, max=1), name, resize=128)
Example #2
    def plot_paths(self,
                   graph,
                   logger,
                   realities=[],
                   plot_names=None,
                   epochs=0,
                   tr_step=0,
                   prefix=""):
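        # Compute the requested prediction paths for each reality, split each
        # path's output into mean/scale channels, merge them with a
        # winner-take-all rule over a Laplace mixture, and assemble one image
        # grid per reality.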

        sqrt2 = math.sqrt(2)  # a Laplace distribution with scale b has std sqrt(2) * b

        cmap = get_cmap("jet")
        path_values = {}
        realities_map = {reality.name: reality for reality in realities}
        for name, config in (plot_names or self.plots.items()):
            paths = config["paths"]

            realities = config["realities"]

            for reality in realities:
                with torch.no_grad():
                    path_values[reality] = self.compute_paths(
                        graph,
                        paths={path: self.paths[path]
                               for path in paths},
                        reality=realities_map[reality])

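                    # Collect each expert path's predicted mean (channel 0) and
                    # Laplace scale (channel 1, stored as a log-scale).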
                    pred_mu = torch.Tensor().cuda()
                    pred_sig = torch.Tensor().cuda()
                    paths_ = [
                        'f(emboss4d(x))', 'f(grey(x))', 'f(sobel(x))',
                        'f(wav(x))', 'n(x)', 'f(gauss(x))', 'f(laplace(x))',
                        'f(sharp(x))'
                    ]
                    for p in paths_:
                        if p in [
                                'x', 'y^', 'bin(x)', 'grey(x)', 'emboss(x)',
                                'sobel(x)', 'stacked2reshade(x)', 'gauss(x)',
                                'laplace(x)', 'sharp(x)'
                        ]:
                            continue
                        if 'y^' in p and reality == 'ood': continue
                        path_values[reality][f'{p}_m'] = path_values[reality][
                            p][:, :1]
                        pred_mu = torch.cat(
                            (pred_mu, path_values[reality][f'{p}_m']), dim=1)
                        path_values[reality][f'{p}_s'] = path_values[reality][
                            p][:, 1:2].exp() * sqrt2
                        pred_sig = torch.cat(
                            (pred_sig, path_values[reality][f'{p}_s']), dim=1)
                        del path_values[reality][p]

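                    # Score every path's mean under the weighted Laplace mixture
                    # and keep, per pixel, the mean/scale of the best-scoring
                    # path (winner-take-all merge into 'stacked2reshade(x)_m/_s').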
                    lap_dist = torch.distributions.Laplace(loc=pred_mu,
                                                           scale=pred_sig +
                                                           1e-15)
                    pdfs = []
                    for i in range(len(paths_)):
                        pdfs.append(
                            (lap_dist.log_prob(
                                path_values[reality][f'{paths_[i]}_m']).exp() *
                             path_values[reality]['stacked2reshade(x)']).sum(
                                 1, keepdim=True))
                    pi = torch.cat(pdfs, dim=1)
                    onehot = torch.zeros_like(pi)  # one-hot selector of the winning path per pixel
                    onehot.scatter_(1, pi.argmax(1, keepdim=True), 1.0)
                    path_values[reality]['stacked2reshade(x)_m'] = (
                        pred_mu * onehot).sum(1, keepdim=True)
                    path_values[reality]['stacked2reshade(x)_s'] = (
                        pred_sig * onehot).sum(1, keepdim=True)

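                    # Log each path's merging weight as a separate channel.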
                    for i in [4, 0, 1, 2, 3, 5, 6, 7]:
                        path_values[reality][
                            f'stacked2reshade(x)_w{i+1}'] = path_values[
                                reality][f'stacked2reshade(x)'][:, i:i + 1]
                    del path_values[reality]['stacked2reshade(x)']

                    path_values[reality]['emboss4d(x)'] = path_values[reality][
                        'emboss4d(x)'][:, :3]

                    if reality == 'test':  # compute an error map against the ground truth
                        mask_task = self.paths["y^"][-1]
                        mask = ImageTask.build_mask(path_values[reality]["y^"],
                                                    val=mask_task.mask_val)
                        errors = ((
                            path_values[reality]["y^"][:, :1] -
                            path_values[reality]["stacked2reshade(x)_m"][:, :1]
                        )**2).mean(dim=1, keepdim=True)
                        errors = (3 * errors / (mask_task.variance)).clamp(
                            min=0, max=1)
                        log_errors = torch.log(errors + 1)
                        log_errors = log_errors / log_errors.max()
                        log_errors = torch.tensor(cmap(
                            log_errors.cpu()))[:, 0].permute(
                                (0, 3, 1, 2)).float()[:, 0:3]
                        log_errors = log_errors.clamp(min=0, max=1).to(DEVICE)
                        log_errors[~mask.expand_as(log_errors)] = 0.505
                        path_values[reality]['error'] = log_errors

                    path_values[reality] = {
                        k: v.clamp(min=0, max=1).cpu()
                        for k, v in path_values[reality].items()
                    }

                    del path_values[reality]['wav(x)']

        # Assemble the final grids: each path becomes one row of half-resolution images.
        def reshape_img_to_rows(x_):
            downsample = lambda x: F.interpolate(
                x.unsqueeze(0), scale_factor=0.5, mode='bilinear').squeeze(0)
            x_list = [downsample(x_[i]) for i in range(x_.size(0))]
            x = torch.cat(x_list, dim=-1)
            return x

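        # Stack one labelled row per remaining path into a single image per reality.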
        all_images = {}
        for reality in realities:
            all_imgs_reality = []
            plot_name = ''

            keys_list = list(path_values[reality].keys())
            keys_list.remove('x')
            keys_list.insert(0, 'x')
            if 'y^' in keys_list:
                keys_list.remove('y^')
                keys_list.insert(1, 'y^')
            if 'error' in keys_list:
                keys_list.remove('error')
                keys_list.insert(-5, 'error')

            for k in keys_list:
                plot_name += k + '|'
                img_row = reshape_img_to_rows(path_values[reality][k])
                if img_row.size(0) == 1:
                    # broadcast single-channel rows to 3 channels for display
                    img_row = img_row.repeat(3, 1, 1)
                all_imgs_reality.append(img_row)
            plot_name = plot_name[:-1]
            plot_name = plot_name.replace("stacked2reshade", "st")
            plot_name = plot_name.replace("emboss4d", "em")
            plot_name = plot_name.replace("grey", "gr")
            plot_name = plot_name.replace("sobel", "sb")
            plot_name = plot_name.replace("laplace", "lp")
            plot_name = plot_name.replace("gauss", "gs")
            plot_name = plot_name.replace("(x)", "")
            all_images[reality] = torch.cat(all_imgs_reality, dim=-2)

        return all_images
Example #3
    def plot_paths(self,
                   graph,
                   logger,
                   realities=[],
                   plot_names=None,
                   epochs=0,
                   tr_step=0,
                   prefix=""):
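        # Map each prediction path to the ground-truth path used for its error map.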
        error_pairs = {"n(x)": "y^"}
        realities_map = {reality.name: reality for reality in realities}
        for name, config in (plot_names or self.plots.items()):
            paths = config["paths"]

            realities = config["realities"]
            images = []
            error = False
            cmap = get_cmap("jet")

            first = True
            error_passed_ood = 0
            for reality in realities:
                with torch.no_grad():
                    path_values = self.compute_paths(
                        graph,
                        paths={path: self.paths[path]
                               for path in paths},
                        reality=realities_map[reality])

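                # Display shape: the first path's shape with the channel
                # dimension forced to 3.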
                shape = list(path_values[list(path_values.keys())[0]].shape)
                shape[1] = 3

                for i, path in enumerate(paths):
                    if path == 'depth': continue
                    X = path_values.get(path, torch.zeros(shape,
                                                          device=DEVICE))
                    if first: images += [[]]

                    if reality == 'ood' and error_passed_ood == 0:
                        images[i].append(X.clamp(min=0, max=1).expand(*shape))
                    elif reality == 'ood' and error_passed_ood == 1:
                        images[i + 1].append(
                            X.clamp(min=0, max=1).expand(*shape))
                    else:
                        images[-1].append(X.clamp(min=0, max=1).expand(*shape))

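                    # When a path has a ground-truth counterpart, also render a
                    # jet-colormapped error map as an extra row.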
                    if path in error_pairs:

                        error = True
                        if first:
                            images += [[]]

                    if error:

                        Y = path_values.get(path,
                                            torch.zeros(shape, device=DEVICE))
                        Y_hat = path_values.get(
                            error_pairs[path], torch.zeros(shape,
                                                           device=DEVICE))

                        out_task = self.paths[path][-1]

                        if self.target_task == "reshading":  # use depth mask
                            Y_mask = path_values.get(
                                "depth", torch.zeros(shape, device=DEVICE))
                            mask_task = self.paths["r(x)"][-1]
                            mask = ImageTask.build_mask(Y_mask,
                                                        val=mask_task.mask_val)
                        else:
                            mask = ImageTask.build_mask(Y_hat,
                                                        val=out_task.mask_val)

                        errors = ((Y - Y_hat)**2).mean(dim=1, keepdim=True)
                        errors = (3 * errors / out_task.variance).clamp(
                            min=0, max=1)

                        log_errors = torch.log(errors + 1)
                        log_errors = log_errors / log_errors.max()
                        log_errors = torch.tensor(cmap(
                            log_errors.cpu()))[:, 0].permute(
                                (0, 3, 1, 2)).float()[:, 0:3]
                        log_errors = log_errors.clamp(
                            min=0, max=1).expand(*shape).to(DEVICE)
                        log_errors[~mask.expand_as(log_errors)] = 0.505
                        if reality == 'ood':
                            images[i + 1].append(log_errors)
                            error_passed_ood = 1
                        else:
                            images[-1].append(log_errors)

                        error = False
                first = False

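            # Concatenate each row along the batch dimension, then log the grouped image grid.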
            for i in range(0, len(images)):
                images[i] = torch.cat(images[i], dim=0)

            logger.images_grouped(
                images,
                f"{prefix}_{name}_[{', '.join(realities)}]_[{', '.join(paths)}]",
                resize=config["size"])
Example #4
    def plot_paths(self,
                   graph,
                   logger,
                   realities=[],
                   plot_names=None,
                   epochs=0,
                   tr_step=0,
                   prefix=""):
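        # Compute the requested paths for each reality, build an error map for
        # the test reality, split predictions into mean/scale channels, and
        # assemble one image grid per reality.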

        sqrt2 = math.sqrt(2)  # a Laplace distribution with scale b has std sqrt(2) * b

        cmap = get_cmap("jet")
        path_values = {}
        realities_map = {reality.name: reality for reality in realities}
        for name, config in (plot_names or self.plots.items()):
            paths = config["paths"]

            realities = config["realities"]

            for reality in realities:
                with torch.no_grad():
                    path_values[reality] = self.compute_paths(
                        graph,
                        paths={path: self.paths[path]
                               for path in paths},
                        reality=realities_map[reality])
                    if reality == 'test':  # compute an error map against the ground truth
                        mask_task = self.paths["y^"][-1]
                        mask = ImageTask.build_mask(path_values[reality]["y^"],
                                                    val=mask_task.mask_val)
                        errors = (
                            (path_values[reality]["y^"][:, :3] -
                             path_values[reality]["f(g(x))"][:, :3])**2).mean(
                                 dim=1, keepdim=True)
                        errors = (3 * errors / (mask_task.variance)).clamp(
                            min=0, max=1)
                        log_errors = torch.log(errors + 1)
                        log_errors = log_errors / log_errors.max()
                        log_errors = torch.tensor(cmap(
                            log_errors.cpu()))[:, 0].permute(
                                (0, 3, 1, 2)).float()[:, 0:3]
                        log_errors = log_errors.clamp(min=0, max=1).to(DEVICE)
                        log_errors[~mask.expand_as(log_errors)] = 0.505
                        path_values[reality]['error'] = log_errors

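                    # Split each prediction into its mean (first half of the
                    # channels) and its scale (second half, stored as a log-scale).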
                    nchannels = path_values[reality]['f(g(x))'].size(1) // 2
                    path_values[reality]['f(g(x))_m'] = path_values[reality][
                        'f(g(x))'][:, :nchannels]
                    path_values[reality]['f(g(x))_s'] = path_values[reality][
                        'f(g(x))'][:, nchannels:].exp() * sqrt2
                    path_values[reality]['f0(g(x))_m'] = path_values[reality][
                        'f0(g(x))'][:, :nchannels]
                    path_values[reality]['f0(g(x))_s'] = path_values[reality][
                        'f0(g(x))'][:, nchannels:].exp() * sqrt2
                    path_values[reality] = {
                        k: v.clamp(min=0, max=1).cpu()
                        for k, v in path_values[reality].items()
                    }
                    del path_values[reality]['f(g(x))']
                    del path_values[reality]['f0(g(x))']
                    # del path_values[reality]['emboss(x)']

        # Assemble the final grids: each path becomes one row of half-resolution images.
        def reshape_img_to_rows(x_):
            downsample = lambda x: F.interpolate(
                x.unsqueeze(0), scale_factor=0.5, mode='bilinear').squeeze(0)
            x_list = [downsample(x_[i]) for i in range(x_.size(0))]
            x = torch.cat(x_list, dim=-1)
            return x

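        # Stack one labelled row per path into a single image per reality.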
        all_images = {}
        for reality in realities:
            all_imgs_reality = []
            plot_name = ''
            for k in path_values[reality].keys():
                plot_name += k + '|'
                if path_values[reality][k].size(1) > 3:
                    path_values[reality][k] = path_values[reality][k][:, :3]
                img_row = reshape_img_to_rows(path_values[reality][k])
                if img_row.size(0) == 1:
                    # broadcast single-channel rows to 3 channels for display
                    img_row = img_row.repeat(3, 1, 1)
                all_imgs_reality.append(img_row)
            plot_name = plot_name[:-1]
            all_images[reality] = torch.cat(all_imgs_reality, dim=-2)

        return all_images