Example #1
0
        def __init__(self,
                     height=64,
                     width=64,
                     with_r=False,
                     with_boundary=False):
            """Precompute the coordinate channels for a ``height x width`` grid.

            Args:
                height: number of rows of the coordinate grid.
                width: number of columns of the coordinate grid.
                with_r: when true, a third channel is appended that holds
                    ``sqrt(x**2 + y**2)`` of the two rescaled coordinate
                    maps, divided by its own maximum.
                with_boundary: stored on the instance; not used further in
                    this constructor.
            """
            super(AddCoordsTh, self).__init__()
            self.with_r = with_r
            self.with_boundary = with_boundary
            target = torch.device(
                'cuda' if torch.cuda.is_available() else 'cpu')

            with torch.no_grad():
                # x varies along the first (height) axis, y along the second
                row_index = torch.arange(height).unsqueeze(1)
                col_index = torch.arange(width).unsqueeze(0)
                x_coords = row_index.expand(height, width).float()
                y_coords = col_index.expand(height, width).float()

                # map indices 0 .. size - 1 onto -1 .. 1
                x_coords = (x_coords / (height - 1)) * 2 - 1
                y_coords = (y_coords / (width - 1)) * 2 - 1

                channels = [x_coords, y_coords]
                coords = torch.stack(channels, dim=0)  # (2, height, width)

                if self.with_r:
                    radius = torch.sqrt(
                        torch.pow(x_coords, 2) +
                        torch.pow(y_coords, 2))  # (height, width)
                    radius = (radius / torch.max(radius)).unsqueeze(0)
                    coords = torch.cat([coords, radius], dim=0)

                # (1, 2 or 3, height, width)
                self.coords = coords.unsqueeze(0).to(target)
                self.x_coords = x_coords.to(target)
                self.y_coords = y_coords.to(target)
Example #2
0
    def finetune(self, z, y):
        """Run the shared trunk and every unshared branch on ``z``.

        Args:
            z: input passed to ``self.shared``.
            y: tensor whose values are cast to ``int`` and paired with the
                sample index to build the indices handed to ``porch.take``
                (presumably one domain index per sample -- confirm against
                the caller).

        Returns:
            A tuple ``(s, h, out)``: the result of ``porch.take`` on the
            stacked branch outputs, the shared features, and the stacked
            branch outputs themselves.
        """
        h = self.shared(z)
        branch_outputs = [branch(h) for branch in self.unshared]
        # (batch, num_domains, style_dim)
        out = porch.stack(branch_outputs, dim=1)

        sample_ids = range(y.size(0))
        domain_ids = y.numpy().astype(int).tolist()
        s = porch.take(out, list(zip(sample_ids, domain_ids)))
        return s, h, out
Example #3
0
    def forward(self, x, y):
        """Encode ``x`` and pick one branch output per sample.

        Args:
            x: input passed to ``self.shared``.
            y: tensor whose values are cast to ``int`` and paired with the
                sample index to build the indices handed to ``porch.take``
                (presumably one domain index per sample -- confirm against
                the caller).

        Returns:
            The result of ``porch.take`` on the stacked branch outputs.
        """
        features = porch.varbase_to_tensor(self.shared(x))
        # flatten everything after the batch axis
        features = features.view(features.size(0), -1)

        branch_outputs = [branch(features) for branch in self.unshared]
        # (batch, num_domains, style_dim)
        stacked = porch.stack(branch_outputs, dim=1)

        sample_ids = range(y.shape[0])
        domain_ids = y.numpy().astype(int).tolist()
        return porch.take(stacked, list(zip(sample_ids, domain_ids)))
Example #4
0
def calculate_lpips_given_images(group_of_images):
    """Average the LPIPS distance over every unordered pair of inputs.

    Args:
        group_of_images: sequence of image batches, e.g.
            ``[porch.randn(N, C, H, W) for _ in range(10)]``.

    Returns:
        ``.numpy()`` of the mean over all pairwise LPIPS values.
    """
    # NOTE(review): `device` is computed but never used below.
    device = porch.device('cuda' if porch.cuda.is_available() else 'cpu')
    lpips = LPIPS(pretrained_weights_fn="./metrics/LPIPS_pretrained.pdparams")
    lpips.eval()

    count = len(group_of_images)
    # every pair (first, second) with first < second, in row-major order
    pairwise = [
        lpips(group_of_images[first], group_of_images[second])
        for first in range(count - 1)
        for second in range(first + 1, count)
    ]
    return porch.mean(porch.stack(pairwise, dim=0)).numpy()
Example #5
0
def shift(x, N):
    """Cyclically shift a 4-D tensor along its third axis.

    Args:
        x: tensor with exactly four dimensions; rows live on axis 2.
        N: number of rows to shift. A positive value moves content up
            (output row ``i`` is input row ``i + N``, wrapping around), a
            negative value moves it down. ``0`` returns ``x`` itself.

    Returns:
        ``x`` when ``N == 0``, otherwise a new tensor of the same shape
        with the rows reordered.
    """
    _, _, H, _ = x.shape
    offset = abs(N)
    if offset == 0:
        return x

    if N >= 0:
        # rows N .. H-1 come first, then the wrapped rows 0 .. N-1
        leading = torch.arange(H - offset) + offset
        trailing = torch.arange(offset)
    else:
        # the last |N| rows come first, then rows 0 .. H-|N|-1
        leading = torch.arange(offset) + (H - offset)
        trailing = torch.arange(H - offset)

    row_order = torch.cat([leading, trailing]).numpy()
    rows = [x[:, :, int(row)] for row in row_order]
    return torch.stack(rows, dim=2)
Example #6
0
def calculate_metrics(nets, args, step, mode):
    """Generate translated images for every domain pair and report LPIPS/FID.

    Every sub-directory of ``args.val_img_dir`` is treated as a domain. For
    each ordered (source, target) pair the directory
    ``args.eval_dir/<source>2<target>`` is wiped and recreated, then filled
    with ``args.num_outs_per_domain`` generated images per source image.
    The per-task and mean LPIPS values are written to
    ``args.eval_dir/LPIPS_<step>_<mode>.json``; afterwards FID is computed
    by ``calculate_fid_for_all_tasks``.

    Args:
        nets: mapping from name to network. It is iterated by key and also
            accessed by attribute (``generator``, ``mapping_network``,
            ``style_encoder``, ``fan``). Every network is switched to
            eval mode.
        args: options object; reads ``val_img_dir``, ``eval_dir``,
            ``img_size``, ``val_batch_size``, ``w_hpf``,
            ``num_outs_per_domain`` and ``latent_dim``.
        step: training step, used in the LPIPS file name and forwarded to
            the FID computation.
        mode: ``'latent'`` (styles come from random latent codes) or
            ``'reference'`` (styles come from images of the target domain).

    Returns:
        Whatever ``calculate_fid_for_all_tasks`` returns.

    Raises:
        AssertionError: if ``mode`` is not one of the two supported values.
    """
    print('Calculating evaluation metrics...')
    assert mode in ['latent', 'reference']
    for name in nets:
        nets[name].eval()
    domains = os.listdir(args.val_img_dir)
    domains.sort()
    num_domains = len(domains)
    print('Number of domains: %d' % num_domains)
    enable_lpips = True  # save time to check FID result
    if enable_lpips:
        lpips_dict = OrderedDict()
        for trg_idx, trg_domain in enumerate(domains):
            src_domains = [x for x in domains if x != trg_domain]

            # Bind all three names up front so the cleanup at the end of
            # this iteration is valid even if an inner loop never runs.
            loader_src = None
            loader_ref = None
            iter_ref = None

            if mode == 'reference':
                path_ref = os.path.join(args.val_img_dir, trg_domain)
                loader_ref = get_eval_loader(root=path_ref,
                                             img_size=args.img_size,
                                             batch_size=args.val_batch_size,
                                             imagenet_normalize=False,
                                             drop_last=True)

            for src_idx, src_domain in enumerate(src_domains):
                path_src = os.path.join(args.val_img_dir, src_domain)
                loader_src = get_eval_loader(root=path_src,
                                             img_size=args.img_size,
                                             batch_size=args.val_batch_size,
                                             imagenet_normalize=False)

                task = '%s2%s' % (src_domain, trg_domain)
                path_fake = os.path.join(args.eval_dir, task)
                shutil.rmtree(path_fake, ignore_errors=True)
                os.makedirs(path_fake)

                lpips_values = []
                print('Generating images and calculating LPIPS for %s...' %
                      task)
                for i, x_src in enumerate(
                        tqdm(loader_src, total=len(loader_src))):
                    x_src = porch.stack(x_src, dim=0)
                    N = x_src.size(0)
                    y_trg = porch.tensor([trg_idx] * N)
                    masks = nets.fan.get_heatmap(
                        x_src) if args.w_hpf > 0 else None

                    # generate several outputs from the same input
                    group_of_images = []
                    for j in range(args.num_outs_per_domain):
                        if mode == 'latent':
                            z_trg = porch.randn(N, args.latent_dim)
                            s_trg = nets.mapping_network(z_trg, y_trg)
                        else:
                            # The reference iterator is created lazily and
                            # shared by all source domains of this target.
                            if iter_ref is None:
                                iter_ref = iter(loader_ref)
                            try:
                                x_ref = next(iter_ref)
                            except StopIteration:
                                # reference loader exhausted: start over
                                iter_ref = iter(loader_ref)
                                x_ref = next(iter_ref)
                            x_ref = porch.stack(x_ref, dim=0)
                            if x_ref.size(0) > N:
                                x_ref = x_ref[:N]
                            s_trg = nets.style_encoder(x_ref, y_trg)
                        x_fake = nets.generator(x_src, s_trg, masks=masks)

                        group_of_images.append(x_fake)

                        # save generated images to calculate FID later
                        for k in range(N):
                            filename = os.path.join(
                                path_fake, '%.4i_%.2i.png' %
                                (i * args.val_batch_size + (k + 1), j + 1))
                            utils.save_image(x_fake[k],
                                             ncol=1,
                                             filename=filename)

                    lpips_value = calculate_lpips_given_images(group_of_images)
                    lpips_values.append(lpips_value)

                # calculate LPIPS for each task (e.g. cat2dog, dog2cat)
                lpips_mean = np.array(lpips_values).mean().astype(float)
                lpips_dict['LPIPS_%s/%s' % (mode, task)] = lpips_mean

            # drop the dataloaders (and the reference iterator) before
            # moving on to the next target domain
            del loader_src, loader_ref, iter_ref

        # calculate the average LPIPS for all tasks
        lpips_mean = 0
        for value in lpips_dict.values():
            lpips_mean += value / len(lpips_dict)
        lpips_dict['LPIPS_%s/mean' % mode] = lpips_mean

        # report LPIPS values
        filename = os.path.join(args.eval_dir,
                                'LPIPS_%.5i_%s.json' % (step, mode))
        utils.save_json(lpips_dict, filename)

    # calculate and report fid values
    # NOTE(review): the networks are left in eval mode. The previous
    # version had a loop calling `.train()` on every network placed after
    # this `return`, so it never executed; callers that continue training
    # with these networks must switch them back themselves.
    return calculate_fid_for_all_tasks(args, domains, step=step, mode=mode)
Example #7
0
def make_grid(tensor,
              nrow=8,
              padding=2,
              normalize=False,
              range=None,
              scale_each=False,
              pad_value=0):
    """Tile a mini-batch of images into a single grid image.

    Args:
        tensor (Tensor or list): 4D mini-batch of shape (B x C x H x W), a
            single image (C x H x W or H x W), or a list of equally sized
            images.
        nrow (int, optional): maximum number of images per grid row; the
            grid has ``ceil(B / nrow)`` rows. Default: ``8``.
        padding (int, optional): width in pixels of the border drawn
            around every image. Default: ``2``.
        normalize (bool, optional): if ``True``, work on a copy that is
            clamped to (min, max) and rescaled towards the range (0, 1).
            Default: ``False``.
        range (tuple, optional): the (min, max) pair used for
            normalization; when omitted, min and max are taken from the
            data being normalized.
        scale_each (bool, optional): if ``True``, normalize every image of
            the batch on its own instead of the batch as a whole.
            Default: ``False``.
        pad_value (float, optional): value of the border pixels.
            Default: ``0``.

    Returns:
        For a batch holding a single image, that image with the batch axis
        squeezed away (no border is added). Otherwise a tensor of shape
        ``(C, rows * (H + padding) + padding, cols * (W + padding) + padding)``.

    Raises:
        TypeError: if the (converted) input is neither a tensor nor a list
            of tensors.
    """
    # NOTE(review): this conversion runs before the type check, so the
    # list handling below only matters if `torch.Tensor(...)` can hand a
    # list back -- confirm against the `torch` module this file imports.
    tensor = torch.Tensor(tensor)
    if not torch.is_tensor(tensor) and not (
            isinstance(tensor, list)
            and all(torch.is_tensor(item) for item in tensor)):
        raise TypeError('tensor or list of tensors expected, got {}'.format(
            type(tensor)))

    # a list of images becomes one 4D mini-batch
    if isinstance(tensor, list):
        tensor = torch.stack(tensor, dim=0)

    # promote H x W -> 1 x H x W
    if tensor.dim() == 2:
        tensor = tensor.unsqueeze(0)
    # promote C x H x W -> 1 x C x H x W, replicating a lone channel 3 times
    if tensor.dim() == 3:
        if tensor.size(0) == 1:
            tensor = torch.cat((tensor, ) * 3, 0)
        tensor = tensor.unsqueeze(0)

    # batches of single-channel images are replicated to 3 channels as well
    if tensor.dim() == 4 and tensor.size(1) == 1:
        tensor = torch.cat((tensor, ) * 3, 1)

    if normalize is True:
        tensor = tensor.clone()  # the helpers below work in place
        if range is not None:
            assert isinstance(range, tuple), \
                "range has to be a tuple (min, max) if specified. min and max are numbers"

        def rescale_(img, low, high):
            img.clamp_(min=low, max=high)
            # the small epsilon keeps the division defined when high == low
            img.add_(-low).div_(high - low + 1e-5)

        def normalize_(img, bounds):
            if bounds is None:
                rescale_(img, float(img.min()), float(img.max()))
            else:
                rescale_(img, bounds[0], bounds[1])

        # either each image of the batch on its own, or the batch at once
        targets = tensor if scale_each is True else [tensor]
        for target in targets:
            normalize_(target, range)

    if tensor.size(0) == 1:
        return tensor.squeeze(0)

    # lay the mini-batch out on a (rows x cols) grid
    count = tensor.size(0)
    cols = min(nrow, count)
    rows = int(math.ceil(float(count) / cols))
    cell_h = int(tensor.size(2) + padding)
    cell_w = int(tensor.size(3) + padding)
    channels = tensor.size(1)
    canvas = np.zeros((channels, cell_h * rows + padding,
                       cell_w * cols + padding)) + pad_value

    for index in irange(count):
        # images are placed row by row, left to right
        row, col = divmod(index, cols)
        top = row * cell_h + padding
        left = col * cell_w + padding
        canvas[:, top:top + cell_h - padding,
               left:left + cell_w - padding] = tensor[index].numpy()
    return torch.Tensor(canvas)