Code Example #1
    def test_rgb_to_grayscale(self):
        script_rgb_to_grayscale = torch.jit.script(F.rgb_to_grayscale)

        img_tensor, pil_img = self._create_data(32, 34, device=self.device)

        for num_output_channels in (3, 1):
            gray_pil_image = F.rgb_to_grayscale(
                pil_img, num_output_channels=num_output_channels)
            gray_tensor = F.rgb_to_grayscale(
                img_tensor, num_output_channels=num_output_channels)

            self.approxEqualTensorToPIL(gray_tensor.float(),
                                        gray_pil_image,
                                        tol=1.0 + 1e-10,
                                        agg_method="max")

            s_gray_tensor = script_rgb_to_grayscale(
                img_tensor, num_output_channels=num_output_channels)
            self.assertTrue(s_gray_tensor.equal(gray_tensor))

            batch_tensors = self._create_data_batch(16,
                                                    18,
                                                    num_samples=4,
                                                    device=self.device)
            self._test_fn_on_batch(batch_tensors,
                                   F.rgb_to_grayscale,
                                   num_output_channels=num_output_channels)
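
For orientation, the API this test exercises can be reproduced in a few lines. The following is a minimal sketch (not from the test suite), assuming a random uint8 tensor as a stand-in image:

import torch
import torchvision.transforms.functional as F

img = torch.randint(0, 256, (3, 32, 34), dtype=torch.uint8)  # RGB, C x H x W
gray1 = F.rgb_to_grayscale(img)                         # 1 x H x W
gray3 = F.rgb_to_grayscale(img, num_output_channels=3)  # luminance replicated to 3 channels

scripted = torch.jit.script(F.rgb_to_grayscale)         # the function is TorchScript-compatible
assert torch.equal(scripted(img, num_output_channels=3), gray3)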
Code Example #2
    def __init__(self, root_dir='./data/256_ObjectCategories', split='train'):
        self.num_classes = 50
        self.split = split

        self._read_all(root_dir)
        self.transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        data = []
        labels = []
        if split == 'train':
            for i in range(len(self.data)):
                data.append(self.transform(self.data[i]))
                data.append(self.transform(hflip(self.data[i])))
                data.append(self.transform(rgb_to_grayscale(self.data[i], 3)))
                data.append(
                    self.transform(rgb_to_grayscale(hflip(self.data[i]), 3)))
                labels += [
                    self.labels[i], self.labels[i], self.labels[i],
                    self.labels[i]
                ]
        else:
            for i in range(len(self.data)):
                data.append(self.transform(self.data[i]))
                labels += [self.labels[i]]
        self.seed = np.random.RandomState(0)

        self.data = torch.stack(data)
        self.labels = torch.tensor(np.asarray(labels))
Code Example #3
def test_rgb_to_grayscale(device, num_output_channels):
    script_rgb_to_grayscale = torch.jit.script(F.rgb_to_grayscale)

    img_tensor, pil_img = _create_data(32, 34, device=device)

    gray_pil_image = F.rgb_to_grayscale(pil_img, num_output_channels=num_output_channels)
    gray_tensor = F.rgb_to_grayscale(img_tensor, num_output_channels=num_output_channels)

    _assert_approx_equal_tensor_to_pil(gray_tensor.float(), gray_pil_image, tol=1.0 + 1e-10, agg_method="max")

    s_gray_tensor = script_rgb_to_grayscale(img_tensor, num_output_channels=num_output_channels)
    assert_equal(s_gray_tensor, gray_tensor)

    batch_tensors = _create_data_batch(16, 18, num_samples=4, device=device)
    _test_fn_on_batch(batch_tensors, F.rgb_to_grayscale, num_output_channels=num_output_channels)
Code Example #4
    def test_rgb_to_grayscale(self):
        script_rgb_to_grayscale = torch.jit.script(F.rgb_to_grayscale)

        img_tensor, pil_img = self._create_data(32, 34, device=self.device)

        for num_output_channels in (3, 1):
            gray_pil_image = F.rgb_to_grayscale(pil_img, num_output_channels=num_output_channels)
            gray_tensor = F.rgb_to_grayscale(img_tensor, num_output_channels=num_output_channels)

            if num_output_channels == 1:
                print(gray_tensor.shape)

            self.approxEqualTensorToPIL(gray_tensor.float(), gray_pil_image, tol=1.0 + 1e-10, agg_method="max")

            s_gray_tensor = script_rgb_to_grayscale(img_tensor, num_output_channels=num_output_channels)
            self.assertTrue(s_gray_tensor.equal(gray_tensor))
Code Example #5
    def __call__(self, image):
        # color_balance in [0.1, 1.9]
        color_balance = 0.1 + 1.8 * self.sample_magnitude(rand_negate=False)
        image = color_balance * image \
                + (1 - color_balance) * TF.rgb_to_grayscale(image, 3)
        image = image.clamp(0., 1.)
        return image
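
The blend above behaves like a saturation adjustment: color_balance = 1 leaves the image unchanged, color_balance = 0 yields a fully grayscale image. A minimal standalone check of those endpoints, assuming a float image in [0, 1]:

import torch
import torchvision.transforms.functional as TF

image = torch.rand(3, 8, 8)  # stand-in float image in [0, 1]
gray = TF.rgb_to_grayscale(image, 3)
assert torch.allclose(1.0 * image + (1 - 1.0) * gray, image)  # balance = 1: original
assert torch.allclose(0.0 * image + (1 - 0.0) * gray, gray)   # balance = 0: grayscale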
Code Example #6
    def figure_3_d(content_image, style_image):
        content_image_yuv = rgb_to_yuv(content_image)
        content_luminance = content_image_yuv[:, :1].repeat(1, 3, 1, 1)
        content_chromaticity = content_image_yuv[:, 1:]

        style_luminance = rgb_to_grayscale(style_image, num_output_channels=3)

        print("Replicating Figure 3 (d)")
        output_luminance = paper.nst(
            content_luminance,
            style_luminance,
            impl_params=args.impl_params,
        )
        output_luminance = torch.mean(output_luminance, dim=1, keepdim=True)
        output_chromaticity = resize(content_chromaticity,
                                     output_luminance.size()[2:])
        output_image_yuv = torch.cat((output_luminance, output_chromaticity),
                                     dim=1)
        output_image = yuv_to_rgb(output_image_yuv)
        filename = make_output_filename(
            "gatys_et_al_2017",
            "fig_3",
            "d",
            impl_params=args.impl_params,
        )
        output_file = path.join(args.image_results_dir, filename)
        save_result(output_image, output_file)
Code Example #7
def test_content_transform_grayscale_image(subtests, content_image,
                                           impl_params, instance_norm):
    content_image = F.rgb_to_grayscale(content_image)
    edge_size = 16

    hyper_parameters = paper.hyper_parameters(impl_params=impl_params,
                                              instance_norm=instance_norm)
    hyper_parameters.content_transform.edge_size = edge_size

    content_transform = paper.content_transform(
        impl_params=impl_params,
        instance_norm=instance_norm,
        hyper_parameters=hyper_parameters,
    )
    if instance_norm:
        utils.make_reproducible()
    actual = content_transform(content_image)

    if impl_params:
        if instance_norm:
            # Since the transform involves an uncontrollable random component, we can't
            # do a reference test here
            return

        transform_image = F.resize(content_image, edge_size)
    else:
        transform = transforms.CenterCrop(edge_size)
        transform_image = transform(content_image)

    desired = transform_image.repeat(1, 3, 1, 1)

    ptu.assert_allclose(actual, desired)
Code Example #8
File: impl.py Project: gauenk/cl_gen
    def __call__(self, pic):
        pic_bw = tvF.rgb_to_grayscale(pic, 1)
        poisson_pic = torch.poisson(self.alpha * pic_bw, generator=self.seed) / self.alpha
        if pic.shape[-3] == 3:
            repeat = [1 for _ in pic.shape]
            repeat[-3] = 3
            poisson_pic = poisson_pic.repeat(*repeat)
        return poisson_pic
Code Example #9
    def forward(self, x: torch.Tensor, **kwargs):
        trigger = x[..., self.start_h: self.end_h, self.start_w: self.end_w]
        if trigger.size(1) == 3:
            trigger = F.rgb_to_grayscale(trigger)
        mlp_output = self.mlp_model(trigger.flatten(1))
        mlp_output = self.amplify_rate * mlp_output.softmax(1)[:, :self.num_classes]
        org_output = self.org_model(x).softmax(1)
        return (self.alpha * mlp_output + (1 - self.alpha) * org_output) / self.temperature
Code Example #10
    def __call__(self, sample):
        image = sample["image"]

        gray_probability = torch.rand(1)

        if gray_probability[0] > 0.5:
            image = TF.rgb_to_grayscale(image)

        if "image2" in sample:
            image2 = sample["image2"]

            gray_probability = torch.rand(1)

            if gray_probability[0] > 0.5:
                image2 = TF.rgb_to_grayscale(image2)

            return {'image': image, 'image2': image2, 'label': sample['label']}

        return {'image': image, 'label': sample['label']}
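
For the common single-image case, torchvision's built-in transforms.RandomGrayscale covers the same pattern; note that, unlike the snippet above (which defaults to one output channel), the built-in preserves the input's channel count (r == g == b for 3-channel input). A sketch:

from torchvision import transforms

# Grayscale with probability p, keeping the number of channels.
random_gray = transforms.RandomGrayscale(p=0.5)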
Code Example #11
    def __call__(self, sample):
        """
        Args:
            img (PIL Image or Tensor): Image to be converted to grayscale.

        Returns:
            PIL Image or Tensor: Grayscaled image.
        """
        img, target = sample['image'], sample['target']
        return {'image': F.rgb_to_grayscale(img, num_output_channels=self.num_output_channels), 'target': target}
Code Example #12
    def __next__(self) -> torch.Tensor:
        # Loading image
        success, image = self.capture.read()
        if not success:
            raise StopIteration

        # NOTE: if self.capture is an OpenCV VideoCapture, frames are BGR, not RGB
        image = torch.movedim(torch.tensor(image), -1, 0)  # HWC -> CHW
        image = rgb_to_grayscale(image).squeeze()
        return image
Code Example #13
File: dataset.py Project: STomoya/animeface
    def transform(self, image):
        image = TF.resize(image, int(self.image_size * self.resize_ratio))
        image = TF.center_crop(image, self.image_size)
        gray = TF.rgb_to_grayscale(image)
        image = TF.adjust_hue(image, (random.random() - 0.5) / 5)
        image = TF.to_tensor(image)
        image = TF.normalize(image, 0.5, 0.5)
        gray = TF.to_tensor(gray)
        gray = TF.normalize(gray, 0.5, 0.5)
        return image, gray
Code Example #14
def load_frame(path: str, frame: int, color=False):
    """Load frame of video and turn into grayscale."""
    loader = DataLoader(path)
    loader.dataset.set_frame(frame)

    _, (_, image) = next(enumerate(loader))
    if not color:
        image = rgb_to_grayscale(image.permute(2, 0, 1)).squeeze()

    return image
Code Example #15
File: image_utils.py Project: yarenty/ludwig
def grayscale(img: torch.Tensor) -> torch.Tensor:
    try:
        import torchvision.transforms.functional as F
    except ImportError:
        logger.error("torchvision is not installed. "
                     "In order to install all image feature dependencies run "
                     "pip install ludwig[image]")
        sys.exit(-1)

    return F.rgb_to_grayscale(img)
Code Example #16
    def process_state(self, state):
        x = torch.tensor(data=state, dtype=torch.float, device=device)
        # Move channels last to first: (210, 160, 3) -> (3, 210, 160)
        x = x.permute(2, 0, 1)
        # From color to grayscale: (3, 210, 160) -> (1, 210, 160)
        x = rgb_to_grayscale(x)
        # Resize from (1, 210, 160) to (1, 80, 80)
        x = Resize([80, 80])(x)
        # Drop the size-1 channel dimension
        x = x.squeeze(0)

        return x
Code Example #17
    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be converted to grayscale.

        Returns:
            PIL Image or Tensor: Randomly grayscaled image.
        """
        num_output_channels = F._get_image_num_channels(img)
        if torch.rand(1) < self.p:
            return F.rgb_to_grayscale(
                img, num_output_channels=num_output_channels), True
        return img, False
Code Example #18
File: utils.py Project: laitalaj/cvpce
def build_mask(img, tolerance=1e-2):
    _, h, w = img.shape
    corners = [(0, 0), (w - 1, 0), (0, h - 1), (w - 1, h - 1)]
    gray_image = ttf.rgb_to_grayscale(img).numpy()
    white_corners = [(x, y) for x, y in corners
                     if gray_image[0, y, x] >= 1 - tolerance]
    sobel_image = sobel(gray_image)[0]
    mask = np.full((h, w), False)
    for x, y in white_corners:
        if mask[y, x]: continue
        cfill = flood(sobel_image, (y, x), tolerance=tolerance)
        mask = mask | cfill
    return torch.tensor(mask)
Code Example #19
    def process_state(self, state):
        x = torch.tensor(data=state, dtype=torch.float, device=self.device)
        # Move channels last to first: (210, 160, 3) -> (3, 210, 160)
        x = x.permute(2, 0, 1)
        # From color to grayscale: (3, 210, 160) -> (1, 210, 160)
        x = rgb_to_grayscale(x)
        # Resize from (1, 210, 160) to (1, 84, 84)
        x = Resize([RESIZE, RESIZE])(x)
        # Drop the size-1 channel dimension
        x = x.squeeze(0)
        # Normalize input to [0, 1]
        x = x.div(255)
        return x.detach().cpu().numpy()
Code Example #20
File: pytorch.py Project: mindee/doctr
def invert_colors(img: torch.Tensor, min_val: float = 0.6) -> torch.Tensor:
    out = F.rgb_to_grayscale(img, num_output_channels=3)
    # Random RGB shift
    shift_shape = [img.shape[0], 3, 1, 1] if img.ndim == 4 else [3, 1, 1]
    rgb_shift = min_val + (1 - min_val) * torch.rand(shift_shape)
    # Apply the shift (preserving dtype)
    if out.dtype == torch.uint8:
        out = (out.to(dtype=rgb_shift.dtype) * rgb_shift).to(dtype=torch.uint8)
    else:
        out = out * rgb_shift.to(dtype=out.dtype)
    # Inverse the color
    out = 255 - out if out.dtype == torch.uint8 else 1 - out
    return out
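
A quick sanity check (not from the project): on the float path the output stays in [0, 1], since the grayscale image is in [0, 1] and the shift factor lies in [min_val, 1] before the final inversion:

import torch

img = torch.rand(3, 32, 32)  # stand-in float image
out = invert_colors(img, min_val=0.6)
assert out.shape == img.shape
assert 0.0 <= out.min() and out.max() <= 1.0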
Code Example #21
def get_nmlz_img(cfg, raw_img):
    noise_type = cfg.noise_params.ntype
    if noise_type in ["g", "hg"]: nmlz_raw = raw_img - 0.5
    elif noise_type in ["qis"]:

        # -- convert to bw --
        raw_img_bw = tvF.rgb_to_grayscale(raw_img, 1)
        raw_img_bw = add_color_channel(raw_img_bw)

        # -- quantize as from adc(poisson_mean) --
        raw_img_bw = quantize_img(cfg, raw_img_bw)

        # -- start dnn normalization for optimization --
        nmlz_raw = raw_img_bw - 0.5
    else:
        print("[Warning]: Check normalize raw image.")
        nmlz_raw = raw_img
    return nmlz_raw
Code Example #22
File: transform.py Project: Embattled/ryuocr
def random_morph(trainData, op_name=None):
    # if op_name not in known_patterns:
    #     raise Exception("Unknown pattern " + op_name + "!")
    dil = ImageMorph.MorphOp(op_name="dilation8")
    ero = ImageMorph.MorphOp(op_name="erosion4")

    ops = [dil, ero]

    for i in range(len(trainData)):
        r = torch.randint(2, (1, 1)).item()
        if r == 1:
            continue

        data = F.rgb_to_grayscale(trainData[i].clone(), 1)

        _, data = ero.apply(F.to_pil_image(data, mode="L"))
        data = ImageOps.colorize(data, "black", "white")
        trainData[i] = F.pil_to_tensor(data)
Code Example #23
File: cnn_mnist.py Project: mint-lab/dl_tutorial
def predict(image, model):
    model.eval()
    with torch.no_grad():
        # Convert the given image to its 1 x 1 x 28 x 28 tensor
        if type(image) is torch.Tensor:
            tensor = image.type(torch.float) / 255  # Normalize to [0, 1]
        else:
            tensor = 1 - TF.to_tensor(image)  # Invert (white to black)
        if tensor.ndim < 3:
            tensor = tensor.unsqueeze(0)
        if tensor.shape[0] == 3:
            tensor = TF.rgb_to_grayscale(tensor)  # Make grayscale
        tensor = TF.resize(tensor, 28)  # Resize to 28 x 28
        dev = next(model.parameters()).device
        tensor = tensor.unsqueeze(0).to(dev)  # Add one more dim (batch)

        output = model(tensor)
        digit = torch.argmax(output, dim=1)
        return digit.item()
Code Example #24
File: impl.py Project: gauenk/cl_gen
    def __call__(self, pic):
        """
        :param pic: input image shaped [..., C, H, W]

        We assume C = 3 and convert the image to grayscale before adding noise.
        """
        # if pic.max() <= 1: pic *= 255.
        # print("noise", torch.get_rng_state())
        device = pic.device
        pix_max = 2**self.nbits - 1
        pic_bw = tvF.rgb_to_grayscale(pic, 1)
        ll_pic = torch.poisson(self.alpha * pic_bw, generator=self.seed)
        ll_pic += self.read_noise * torch.randn(ll_pic.shape, device=device)
        if pic.shape[-3] == 3: ll_pic = self._add_color_channel(ll_pic)
        if self.use_adc:
            ll_pic = torch.round(ll_pic)
            ll_pic = torch.clamp(ll_pic, 0, pix_max)
        ll_pic /= self.alpha
        return ll_pic
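
Stripped of the class plumbing, the noise model above is: scale the grayscale image by a photon level alpha, apply Poisson shot noise, add Gaussian read noise, optionally quantize through an nbits ADC, and rescale. A functional restatement (names and default values are assumptions, not from the project):

import torch

def qis_noise(pic_bw: torch.Tensor, alpha: float = 4.0,
              read_noise: float = 0.25, nbits: int = 3) -> torch.Tensor:
    pix_max = 2 ** nbits - 1
    noisy = torch.poisson(alpha * pic_bw)                  # shot noise at photon scale
    noisy = noisy + read_noise * torch.randn_like(noisy)   # Gaussian read noise
    noisy = torch.clamp(torch.round(noisy), 0, pix_max)    # ADC quantization
    return noisy / alpha                                   # back to image scale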
Code Example #25
File: learn.py Project: gauenk/cl_gen
def get_nmlz_tgt_img(cfg, raw_img):
    pix_max = 2**3 - 1
    noise_type = cfg.noise_params.ntype
    if noise_type in ["g", "hg"]: nmlz_raw = raw_img - 0.5
    elif noise_type in ["qis"]:
        params = cfg.noise_params[noise_type]
        pix_max = 2**params['nbits'] - 1
        raw_img_bw = tvF.rgb_to_grayscale(raw_img, 1)
        raw_img_bw = add_color_channel(raw_img_bw)
        # nmlz_raw = raw_scale * raw_img_bw - 0.5
        # raw_img_bw *= params['alpha']
        # raw_img_bw = torch.round(raw_img_bw)
        # print("ll",ll_pic.min().item(),ll_pic.max().item())
        # raw_img_bw = torch.clamp(raw_img_bw, 0, pix_max)
        # raw_img_bw /= params['alpha']
        # -- end of qis noise transform --

        # -- start dnn normalization for optimization --
        nmlz_raw = raw_img_bw - 0.5
    else:
        print("[Warning]: Check normalize raw image.")
        nmlz_raw = raw_img
    return nmlz_raw
Code Example #26
File: augmentation.py Project: Jasonlee1995/Dilation
    def forward(self, image, mask):
        num_output_channels = F._get_image_num_channels(image)
        if torch.rand(1) < self.p:
            return F.rgb_to_grayscale(
                image, num_output_channels=num_output_channels), mask
        return image, mask
Code Example #27
File: image_utils.py Project: ludwig-ai/ludwig
def grayscale(img: torch.Tensor) -> torch.Tensor:
    """Grayscales RGB image."""
    return F.rgb_to_grayscale(img)
Code Example #28
File: benchmark_noises.py Project: gauenk/cl_gen
def main():

    # -- init --
    cfg = get_main_config()
    cfg.gpuid = 0
    cfg.batch_size = 1
    cfg.N = 2
    cfg.num_workers = 0
    cfg.dynamic.frames = cfg.N
    cfg.rot = edict()
    cfg.rot.skip = 0  # big gap between 2 and 3.

    # -- dynamics --
    cfg.dataset.name = "rots"
    cfg.dataset.load_residual = True
    cfg.dynamic.frame_size = 256
    cfg.frame_size = cfg.dynamic.frame_size
    cfg.dynamic.ppf = 0
    cfg.dynamic.total_pixels = cfg.N * cfg.dynamic.ppf
    torch.cuda.set_device(cfg.gpuid)

    # -- sim params --
    K = 10
    patchsize = 9
    db_level = "frame"
    search_method = "l2"
    # database_str = f"burstAll"
    database_idx = 1
    database_str = "burst{}".format(database_idx)

    # -- grab grids for experiments --
    noise_settings = create_noise_level_grid(cfg)
    # sim_settings = create_sim_grid(cfg)
    # motion_settings = create_motion_grid(cfg)

    for ns in noise_settings:

        # -=-=-=-=-=-=-=-=-=-=-
        #     loop params
        # -=-=-=-=-=-=-=-=-=-=-
        noise_level = 0.
        noise_type = ns.ntype
        noise_str = set_noise_setting(cfg, ns)

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #    create path for results
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        path_args = (K, patchsize, cfg.batch_size, cfg.N, noise_str,
                     database_str, db_level, search_method)
        base = Path(f"output/benchmark_noise_types/{cfg.dataset.name}")
        root = Path(base /
                    "k{}_ps{}_b{}_n{}_{}_db-{}_sim-{}-{}".format(*path_args))
        print(f"Writing to {root}")
        if root.exists(): print("Running Experiment Again.")
        else: root.mkdir(parents=True)

        # -=-=-=-=-=-=-
        #   dataset
        # -=-=-=-=-=-=-
        data, loader = load_dataset(cfg, 'dynamic')
        if cfg.dataset.name == "voc":
            sample = next(iter(loader.tr))
        else:
            sample = data.tr[0]

        # -- load sample --
        burst, raw_img, res = sample['burst'], sample['clean'] - 0.5, sample['res']
        kindex_ds = kIndexPermLMDB(cfg.batch_size, cfg.N)
        N, B, C, H, W = burst.shape
        if 'clean_burst' in sample: clean = sample['clean_burst'] - 0.5
        else: clean = burst - res
        if noise_type in ["qis", "pn"]: tvF.rgb_to_grayscale(clean, 3)
        # burst = tvF.rgb_to_grayscale(burst,3)
        # raw_img = tvF.rgb_to_grayscale(raw_img,3)
        # clean = tvF.rgb_to_grayscale(clean,3)

        # -- temp (delete me soon) --
        search_rot_grid = np.linspace(.3, .32, 100)
        losses = np.zeros_like(search_rot_grid)
        for idx, angle in enumerate(search_rot_grid):
            save_alpha_burst = 0.5 * burst[0] + 0.5 * tvF.rotate(
                burst[1], angle)
            losses[idx] = F.mse_loss(save_alpha_burst, burst[0]).item()
        min_arg = np.argmin(losses)
        angle = search_rot_grid[min_arg]

        ref_img = tvF.rotate(burst[1], angle)
        shift_grid = np.linspace(-20, 20, 40 - 1).astype(int)
        losses = np.zeros_like(shift_grid).astype(float)
        for idx, shift in enumerate(shift_grid):
            save_alpha_burst = 0.5 * burst[0] + 0.5 * torch.roll(
                ref_img, shift, -2)
            losses[idx] = F.mse_loss(save_alpha_burst, burst[0]).item()
        min_arg = np.argmin(losses)
        shift = shift_grid[min_arg]

        # -- run search --
        kindex = kindex_ds[0]
        database = None
        if database_str == f"burstAll":
            database = burst
            clean_db = clean
        else:
            database = burst[[database_idx]]
            clean_db = clean[[database_idx]]
        query = burst[[0]]
        sim_outputs = compute_similar_bursts_analysis(
            cfg,
            query,
            database,
            clean_db,
            K,
            patchsize=patchsize,
            shuffle_k=False,
            kindex=kindex,
            only_middle=cfg.sim_only_middle,
            db_level=db_level,
            search_method=search_method,
            noise_level=noise_level / 255.)
        sims, csims, wsims, b_dist, b_indx = sim_outputs

        # -- save images --
        fs = cfg.frame_size
        save_K = 1
        save_sims = rearrange(sims[:, :, :save_K],
                              'n b k1 c h w -> (n b k1) c h w')
        save_csims = rearrange(csims[:, :, :save_K],
                               'n b k1 c h w -> (n b k1) c h w')
        save_cdelta = clean[0] - save_csims[0]
        save_alpha_burst = 0.5 * burst[0] + 0.5 * torch.roll(
            tvF.rotate(burst[1], angle), shift, -2)

        save_burst = rearrange(burst, 'n b c h w -> (b n) c h w')
        save_clean = rearrange(clean, 'n b c h w -> (b n) c h w')
        save_b_dist = rearrange(b_dist[:, :, :save_K],
                                'n b k1 h w -> (n b k1) 1 h w')
        save_b_indx = rearrange(b_indx[:, :, :save_K],
                                'n b k1 h w -> (n b k1) 1 h w')
        save_b_indx = torch.abs(
            torch.arange(fs * fs).reshape(fs, fs) - save_b_indx).float()
        save_b_indx /= (torch.sum(save_b_indx) + 1e-16)
        tv_utils.save_image(save_sims,
                            root / 'sims.png',
                            nrow=B,
                            normalize=True,
                            range=(-0.5, 0.5))
        tv_utils.save_image(save_csims,
                            root / 'csims.png',
                            nrow=B,
                            normalize=True,
                            range=(-0.5, 0.5))
        tv_utils.save_image(save_cdelta,
                            root / 'cdelta.png',
                            nrow=B,
                            normalize=True,
                            range=(-0.5, 0.5))
        tv_utils.save_image(save_clean,
                            root / 'clean.png',
                            nrow=N,
                            normalize=True,
                            range=(-0.5, 0.5))
        tv_utils.save_image(save_burst,
                            root / 'burst.png',
                            nrow=N,
                            normalize=True,
                            range=(-0.5, 0.5))
        tv_utils.save_image(save_b_dist,
                            root / 'b_dist.png',
                            nrow=B,
                            normalize=True)
        tv_utils.save_image(raw_img, root / 'raw.png', nrow=B, normalize=True)
        tv_utils.save_image(save_b_indx,
                            root / 'b_indx.png',
                            nrow=B,
                            normalize=True)
        tv_utils.save_image(save_alpha_burst,
                            root / 'alpha_burst.png',
                            nrow=B,
                            normalize=True)

        # -- save top K patches at location --
        b = 0
        ref_img = clean[0, b]
        ps, fs = patchsize, cfg.frame_size
        xx, yy = np.mgrid[32:48, 48:64]
        xx, yy = xx.ravel(), yy.ravel()
        clean_pad = F.pad(clean[database_idx, [b]],
                          (ps // 2, ps // 2, ps // 2, ps // 2),
                          mode='reflect')[0]
        patches = []
        for x, y in zip(xx, yy):
            gt_patch = tvF.crop(ref_img, x - ps // 2, y - ps // 2, ps, ps)
            patches_xy = [gt_patch]
            for k in range(save_K):
                indx = b_indx[0, 0, k, x, y]
                xp, yp = (indx // fs) + ps // 2, (indx % fs) + ps // 2
                t, l = xp - ps // 2, yp - ps // 2
                clean_patch = tvF.crop(clean_pad, t, l, ps, ps)
                patches_xy.append(clean_patch)
                pix_diff = F.mse_loss(gt_patch[:, ps // 2, ps // 2],
                                      clean_patch[:, ps // 2, ps // 2]).item()
                pix_diff_img = pix_diff * torch.ones_like(clean_patch)
                patches_xy.append(pix_diff_img)
            patches_xy = torch.stack(patches_xy, dim=0)
            patches.append(patches_xy)
        patches = torch.stack(patches, dim=0)
        R = patches.shape[1]
        patches = rearrange(patches, 'l k c h w -> (l k) c h w')
        fn = f"patches_{b}.png"
        tv_utils.save_image(patches, root / fn, nrow=R, normalize=True)

        # -- stats about distance --
        mean_along_k = reduce(b_dist, 'n b k1 h w -> k1', 'mean')
        std_along_k = torch.std(b_dist, dim=(0, 1, 3, 4))
        fig, ax = plt.subplots(figsize=(8, 8))
        R = mean_along_k.shape[0]
        ax.errorbar(np.arange(R), mean_along_k, yerr=std_along_k)
        plt.savefig(root / "distance_stats.png", dpi=300)
        plt.clf()
        plt.close("all")

        # -- psnr between 1st neighbor and clean --
        psnrs = pd.DataFrame({
            "b": [],
            "k": [],
            "psnr": [],
            'crop200_psnr': []
        })
        for b in range(B):
            for k in range(K):

                # -- psnr --
                crop_raw = clean[0, b]
                crop_cmp = csims[0, b, k]
                rc_mse = F.mse_loss(crop_raw, crop_cmp,
                                    reduction='none').reshape(1, -1)
                rc_mse = torch.mean(rc_mse, 1).numpy() + 1e-16
                psnr_bk = np.mean(mse_to_psnr(rc_mse))
                print(psnr_bk)

                # -- crop psnr --
                crop_raw = tvF.center_crop(clean[0, b], 200)
                crop_cmp = tvF.center_crop(csims[0, b, k], 200)
                rc_mse = F.mse_loss(crop_raw, crop_cmp,
                                    reduction='none').reshape(1, -1)
                rc_mse = torch.mean(rc_mse, 1).numpy() + 1e-16
                crop_psnr = np.mean(mse_to_psnr(rc_mse))
                # if np.isinf(psnr_bk): psnr_bk = 50.
                psnrs = pd.concat([
                    psnrs,
                    pd.DataFrame([{
                        'b': b,
                        'k': k,
                        'psnr': psnr_bk,
                        'crop200_psnr': crop_psnr
                    }])
                ], ignore_index=True)
        # psnr_ave = np.mean(psnrs)
        # psnr_std = np.std(psnrs)
        # print( "PSNR: %2.2f +/- %2.2f" % (psnr_ave,psnr_std) )
        psnrs = psnrs.astype({
            'b': int,
            'k': int,
            'psnr': float,
            'crop200_psnr': float
        })
        psnrs.to_csv(root / "psnrs.csv", sep=",", index=False)
Code Example #29
            else:
                if not os.path.isfile(mark_img) and \
                        not os.path.isfile(mark_img := os.path.join(dir_path, mark_img)):
                    raise FileNotFoundError(mark_img.removeprefix(dir_path))
                mark_img = F.convert_image_dtype(F.pil_to_tensor(Image.open(mark_img)))
        if isinstance(mark_img, np.ndarray):
            mark_img = torch.from_numpy(mark_img)
        mark: torch.Tensor = mark_img.to(device=env['device'])
        if not already_processed:
            mark = F.resize(mark, size=(self.mark_width, self.mark_height))
            alpha_mask = torch.ones_like(mark[0])
            if mark.size(0) == 4:
                alpha_mask = mark[-1]
                mark = mark[:-1]
            if self.data_shape[0] == 1 and mark.size(0) == 3:
                mark = F.rgb_to_grayscale(mark, num_output_channels=1)
            mark = torch.cat([mark, alpha_mask.unsqueeze(0)])

            if mark_background_color is not None:
                mark = update_mark_alpha_channel(mark, get_edge_color(mark, mark_background_color))
            if self.mark_random_init:
                mark[:-1] = torch.rand_like(mark[:-1])

            if self.mark_scattered:
                mark_scattered_shape = [mark.size(0), self.mark_scattered_height, self.mark_scattered_width]
                mark = self.scatter_mark(mark, mark_scattered_shape)
        self.mark_height, self.mark_width = mark.shape[-2:]
        self.mark = mark
        return mark

Code Example #30
        raise ValueError(f"Invalid image resize method: {resize_method}")
    return img


def grayscale(img: torch.Tensor) -> torch.Tensor:
    try:
        import torchvision.transforms.functional as F
    except ImportError:
        logger.error(
            "torchvision is not installed. "
            "In order to install all image feature dependencies run "
            "pip install ludwig[image]"
        )
        sys.exit(-1)

    return F.rgb_to_grayscale(img)


def num_channels_in_image(img: torch.Tensor):
    if img is None or img.ndim < 2:
        raise ValueError("Invalid image data")

    if img.ndim == 2:
        return 1
    else:
        return img.shape[0]


def to_np_tuple(prop: Union[int, Iterable]) -> np.ndarray:
    """Creates a np array of length 2 from a Conv2D property.