Example 1
def get_result_G(G_name, z):
    """Sample from the named generator and collect label visualizations and captions for display."""
    G = Gs[G_name]
    P = Ps[G_name]
    if hasattr(G, "truncation"):
        wp = G.truncation(G.mapping(z))
        image, feature = G.synthesis(wp, generate_feature=True)
    else:
        image, feature = G(z, generate_feature=True)
    label = P(image, size=256)
    label_viz = segviz_numpy(label.cpu())
    image_set = []
    text_set = []
    SE = viz_models[G_name]
    segs = SE(feature, size=label.shape[2])
    for i, seg in enumerate(segs):
        est_label = bu(seg, 256).argmax(1)
        est_label_viz = segviz_numpy(est_label[0].cpu())
        image_set.append(est_label_viz)
        text_set.append(f"{i}")
    text_set[-1] = "LSE"  # the last output is the fused, final segmentation
    text_set.append(formal_name(G_name))
    is_face = "ffhq" in G_name or "celebahq" in G_name
    text_set.append("UNet" if is_face else "DeeplabV3")
    image_set.append(torch2image(bu(image, 256))[0])
    image_set.append(label_viz)
    return image_set, text_set
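These snippets rely on a resize helper bu (aliased as op.bu elsewhere) that is not shown. A minimal sketch, assuming it wraps bilinear interpolation and also accepts a list of tensors (Example 6 calls op.bu on a list):

import torch.nn.functional as F

def bu(x, size, align_corners=True):
    # Guessed reconstruction: bilinearly resize a 4D tensor, or each
    # tensor in a list, to the target size (int or (H, W) tuple).
    if isinstance(x, (list, tuple)):
        return [bu(t, size, align_corners) for t in x]
    if isinstance(size, int):
        size = (size, size)
    return F.interpolate(x, size=size, mode="bilinear",
                         align_corners=align_corners)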
Example 2
    def send_command(self, cmd):
        if cmd == "start":
            if self.running:
                return False
            self.lock.acquire()
            self.running = True
            self.lock.release()
            self.start()
            return True
        elif cmd == "pause":
            if not self.running:
                return False
            self.lock.acquire()
            self.running = False
            self.lock.release()
            return True
        elif cmd == "stop":
            self.lock.acquire()
            self.exit = True
            self.lock.release()
            return True
        elif cmd == "val":
            self.lock.acquire()
            images = []
            segvizs = []
            for _ in range(6):
                z = torch.randn(1, 512).cuda()
                with torch.no_grad():
                    image, segs = self.learner(z)
                seg = bu(segs[-1], 128).argmax(1)
                segvizs.append(segviz_numpy(seg.detach().cpu().numpy()))
                images.append(torch2image(bu(image, 128)).astype("uint8"))
            self.lock.release()
            return np.concatenate(images), np.stack(segvizs)
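The explicit acquire/release pairs could equally be written with the lock's context manager, which also releases the lock if an exception fires in between; a stylistic sketch of the "start" branch, not the original code:

        if cmd == "start":
            if self.running:
                return False
            with self.lock:  # released automatically, even on exceptions
                self.running = True
            self.start()
            return True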
Example 3
    def __call__(self, images, size=None):
        """Expects a torch.Tensor as input; returns hard label maps."""
        images = op.bu(images, self.resolution)
        x = self.net(images.clamp(-1, 1))  # (N, M, H, W)
        if size:
            x = op.bu(x, size)
        return x.argmax(1)  # (N, H, W)
Example 4
    def raw_prediction(self, images, size=None):
        """Expects a torch.Tensor as input; returns raw logits."""
        images = op.bu(images, self.resolution)
        x = self.net(images.clamp(-1, 1))  # (N, M, H, W)
        if size:
            x = op.bu(x, size)
        return x
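A hypothetical usage of the two entry points above, assuming P is an instance of this predictor class and images are normalized to [-1, 1]:

images = torch.randn(4, 3, 256, 256).clamp(-1, 1)  # stand-in batch
labels = P(images, size=256)        # (4, 256, 256) hard label maps
logits = P.raw_prediction(images)   # (4, M, H, W) raw class scores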
Example 5
def eval_single(Gs, Ps, eval_file):
  G_name = listkey_convert(eval_file,
    ["stylegan2_ffhq", "stylegan2_bedroom", "stylegan2_church"])

  G = Gs[G_name]
  P = Ps[G_name]

  target_labels = read_labels(G_name, G, P)
  size = target_labels.shape[2]

  print(f"=> Loading from {eval_file}")
  z, wp = torch.load(eval_file, map_location='cpu')
  print(z.shape, wp.shape, target_labels.shape)
  N, M = z.shape[0] // 10, 10 # 10 repeats
  N_show = 4

  res_file = eval_file.replace(".pth", ".txt")
  is_gen = True  # args.generate_image == "1"
  is_eval = not os.path.exists(res_file)

  if is_gen or is_eval:
    images = []
    sample_labels = []
    for i in tqdm(range(wp.shape[0])):
      if not is_eval and i >= N_show * M:
        break
      with torch.no_grad():
        image = G.synthesis(wp[i:i+1].cuda())
        sample_labels.append(P(image, size=size).cpu())
      if i < N_show * M:
        images.append((bu(image, size).cpu() + 1) / 2)
    images = torch.cat(images)
    sample_labels = torch.cat(sample_labels)
    sample_labels = sample_labels.view(
      -1, M, *sample_labels.shape[1:])
    target_label_viz = bu(torch.stack([
      segviz_torch(x) for x in target_labels[:N_show]]), size)
    if is_gen:
      show_labels = bu(target_label_viz.cpu(), 256).unsqueeze(1)
      show_images = bu(images, 256).view(N_show, M, 3, 256, 256).cpu()
      all_images = torch.cat([show_labels, show_images], 1)
      disp_image = vutils.make_grid(all_images.view(
        -1, *all_images.shape[2:]),
        nrow=M+1, padding=10, pad_value=1)
      fpath = eval_file.replace(".pth", ".pdf")
      vutils.save_image(disp_image.unsqueeze(0), fpath)

    if is_eval:
      mIoU, c_ious = aggregate_iou(evaluate_predictions(
        target_labels, sample_labels))
      write_results(res_file, mIoU, c_ious)
  else:
    mIoU, c_ious = read_results(res_file)
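aggregate_iou is not shown. A plausible reduction, assuming evaluate_predictions yields per-sample, per-class IoUs with -1 marking classes absent from both prediction and target (matching the absent_score=-1 convention in Example 7):

import torch

def aggregate_iou(ious):
    # Hypothetical reconstruction: ious is (N, C), -1 means absent.
    ious = torch.stack(ious) if isinstance(ious, (list, tuple)) else ious
    valid = ious >= 0
    c_ious = (ious * valid).sum(0) / valid.sum(0).clamp(min=1)
    return c_ious.mean().item(), c_ious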
Example 6
    def forward(self, features, size=None):
        """Given a set of features, return the segmentation.

    Args:
      features : A list of feature maps. When len(features) > len(self.layers)
                 , it is assumed that the features if taken from all the layers
                 from the generator and will be selected here. Otherwise, it is
                 assumed that the features correspond to self.layers.
      size : The target output size. Bilinear resize will be used if this
             argument is specified.
    Returns:
      A list of segmentations corresponding to each layer, with the last one 
      being the final segmentation integrating all layers.
    """

        outputs = []
        for i in range(len(self.layers)):
            feat = self._index_feature(features, i)
            x = self.extractor[i](feat)
            outputs.append(x)

        # detect final output size, if not specified
        size = size if size else outputs[-1].shape[2:]
        layers = op.bu(outputs, size)

        weight = self._calc_layer_weight()
        if self.lw_type == "none":
            final = sum(layers)
        else:
            final = sum([r * w for r, w in zip(layers, weight)])
        outputs.append(final)
        return outputs
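_calc_layer_weight is not shown; given the lw_type switch, one plausible form is a softmax over learnable per-layer logits (self.layer_weight is an assumed parameter, not from the source):

    def _calc_layer_weight(self):
        # Hypothetical sketch: normalize learnable logits into a convex
        # combination over layers.
        return torch.softmax(self.layer_weight, dim=0)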
Example 7
    def training_step(self, batch, batch_idx):
        segs, label = self(batch)
        seg = op.bu(segs[-1], label.size(2))

        if hasattr(self, "loss_fn_layer"):
            segloss = loss.segloss(segs, label, self.loss_fn_layer)
            n_layers = len(segloss) - 1  # The last one is final segmentation
            total_loss = 0
            for i in range(len(segloss)):  # 0 ~ len(segloss) - 1
                layer = 'final' if i == n_layers else f'{i}'
                layer_loss = (segloss[i] * self.loss_layer_weight
                              if i < n_layers else segloss[i])
                self.log(f'layer/{layer}', layer_loss)
                total_loss = total_loss + layer_loss
        else:
            total_loss = self.loss_fn_final(seg, label)
        self.log("main/total", total_loss)

        dt = seg.argmax(1).detach()
        gt = label.detach()
        IoU = iou(dt,
                  gt,
                  num_classes=self.model.n_class,
                  ignore_index=0,
                  absent_score=-1,
                  reduction='none')
        pixelacc = (dt == gt).sum() / float(dt.numel())
        # store (pixel accuracy, per-class IoU) for epoch-level aggregation
        self.train_evaluation.append([pixelacc, IoU])
        return total_loss
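The [pixelacc, IoU] pairs stored above would presumably be reduced once per epoch; a hedged sketch of such a hook (the body is an assumption; only train_evaluation and the absent_score=-1 convention come from the snippet):

    def training_epoch_end(self, outputs):
        accs = torch.stack([acc for acc, _ in self.train_evaluation])
        ious = torch.stack([iou for _, iou in self.train_evaluation])
        valid = ious >= 0  # absent classes were scored -1
        c_ious = (ious * valid).sum(0) / valid.sum(0).clamp(min=1)
        self.log("train/pixelacc", accs.mean())
        self.log("train/mIoU", c_ious.mean())
        self.train_evaluation.clear()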
Example 8
    def fuse_stroke(G, SE, P, wp, image_stroke, image_mask, label_stroke,
                    label_mask):
        """
    Args:
      image_stroke : [3, H, W]
      image_mask : [1, H, W]
      label_stroke : [H, W]
      label_mask : [H, W]
    """
        size = label_mask.shape[1]
        image, feature = G.synthesis(wp, generate_feature=True)
        origin_image = bu(image, size=size).cpu()
        int_label = SE(feature, size=size)[-1].argmax(1).cpu() if SE else None
        ext_label = P(image, size=size).cpu() if P else None

        m = label_mask.cpu()
        fused_int_label = None if label_stroke is None or int_label is None else \
          ((1 - m) * int_label + m * label_stroke.cpu()).long()
        fused_ext_label = None if label_stroke is None or ext_label is None else \
          ((1 - m) * ext_label + m * label_stroke.cpu()).long()

        m = image_mask.cpu()
        fused_image = None if image_stroke is None else \
          (1 - m) * origin_image + m * image_stroke.cpu()

        return {
            "fused_image": fused_image,
            "fused_int_label": fused_int_label,
            "fused_ext_label": fused_ext_label,
            "origin_image": origin_image,
            "int_label": int_label,
            "ext_label": ext_label
        }
Example 9
def segloss(segs, label, loss_fn):
    """The final version of the loss: apply loss_fn to every output, resized to the label."""
    losses = []
    size = label.size(2)
    for seg in segs:
        seg = op.bu(seg, size) if seg.size(2) != size else seg
        losses.append(loss_fn(seg, label))
    return losses
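A usage sketch, pairing this with the multi-scale outputs of Example 6 and plain cross-entropy (the concrete loss_fn used in training is not shown):

import torch.nn.functional as F

losses = segloss(segs, label, F.cross_entropy)  # one loss per output
total = sum(losses)  # or weight the per-layer terms as in Example 7
total.backward()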
Example 10
    def training_step(self, batch, batch_idx):
        """The batch is a dummy; batch_idx is used to cycle through the cached
        samples (gradient accumulation)."""
        idx = batch_idx % len(self.label)
        feature = [f[idx:idx + 1].cuda() for f in self.feature]
        segs = self.model(feature, size=self.resolution)
        seg = op.bu(segs[-1], self.label.size(2))
        segloss = self.loss_fn_final(seg, self.label[idx:idx + 1])
        return segloss
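This step assumes self.feature (per-layer feature tensors) and self.label were cached beforehand; a hedged sketch of that preparation, where the method name, G, P, K, and resolution are all assumptions:

    def cache_few_shot_data(self, G, P, K, resolution):
        # Hypothetical preparation of the cache read by training_step.
        with torch.no_grad():
            z = torch.randn(K, 512).cuda()
            wp = G.truncation(G.mapping(z))
            image, feature = G.synthesis(wp, generate_feature=True)
        self.feature = [f.cpu() for f in feature]  # sliced per-sample above
        self.label = P(image, size=resolution)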
Example 11
def segloss_bce(segs, label, loss_fn_layer, loss_fn_final):
    """Use BCE for each layer. It is slow and CPU intensive."""
    N = len(segs[0])
    seglosses = []
    for cat_id in range(label.shape[0]):
        segloss = []
        onehot = int2onehot(label[cat_id].unsqueeze(1),
                            segs[cat_id][0].shape[1])
        # BCE loss on each layer output
        for i in range(N):
            seg = segs[cat_id][i]
            if seg.size(2) != label.size(3):
                seg = op.bu(seg, label.size(3))
            segloss.append(loss_fn_layer(seg, onehot))
        # CE loss on the final output
        final = segs[cat_id][-1]
        if final.size(2) != label.size(3):
            final = op.bu(final, label.size(3))
        segloss.append(loss_fn_final(final, label[cat_id]))
        seglosses.append(segloss)
    return seglosses
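int2onehot is external; a minimal sketch, assuming it scatters (N, 1, H, W) integer labels into a (N, C, H, W) one-hot map:

import torch

def int2onehot(x, n_class):
    # Guessed reconstruction: scatter integer labels into one-hot channels.
    onehot = torch.zeros(x.shape[0], n_class, *x.shape[2:], device=x.device)
    return onehot.scatter_(1, x.long(), 1)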
Example 12
    def raw_prediction(self, images, size=256):
        """
        Args:
          images : torch.Tensor in [-1, 1]
          size : The target resolution
        """
        x = torch.stack([self.input_transform((1 + i) / 2) for i in images])
        y = self.net(x)[0]
        if hasattr(self, "label_indice"):
            y = y[:, self.label_indice]
        if y.size(2) != size:
            y = bu(y, size)
        return torch.cat([torch.zeros_like(y[:, :1]), y], 1)
Example 13
    def forward(self, features, size=None):
        for i in range(len(self.layers)):
            feat = self._index_feature(features, i)
            if i == 0:
                hidden = self.extractor[i](feat)
            else:
                # coarse-to-fine: grow the hidden state when the feature
                # resolution doubles, then revise and add the new features
                if hidden.size(2) * 2 == feat.size(2):
                    hidden = F.interpolate(hidden,
                                           scale_factor=2,
                                           mode="nearest")
                hidden = self.reviser[i - 1](hidden)
                hidden = hidden + self.extractor[i](feat)
        x = self.visualizer(hidden)
        if size is not None and size != x.size(3):
            x = op.bu(x, size)
        return [x]
Example 14
def get_result_G(G_name, z):
  """Sample from the named generator and compare the semantic extractors' label maps."""
  G = Gs[G_name]
  P = Ps[G_name]
  if hasattr(G, "truncation"):
    wp = G.truncation(G.mapping(z))
    image, feature = G.synthesis(wp, generate_feature=True)
  else:
    image, feature = G(z, generate_feature=True)
  label = P(image, size=256)
  label_viz = segviz_numpy(label.cpu())
  image_set = [torch2image(bu(image, 256))[0], label_viz]
  is_face = "ffhq" in G_name or "celebahq" in G_name
  text_set = [
    formal_name(G_name).split("_")[0],
    "UNet" if is_face else "DeeplabV3"]
  for i, (SE_name, SE) in enumerate(SE_models[G_name].items()):
    seg = SE(feature, size=label.shape[2])[-1]
    est_label = seg.argmax(1)
    est_label_viz = segviz_numpy(est_label[0].cpu())
    image_set.append(est_label_viz)
    text_set.append(SE_name)
  return image_set, text_set
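segviz_numpy is likewise external; a minimal stand-in that maps an integer label map to RGB through a fixed random palette (an assumption; the real colormap is unknown):

import numpy as np

def segviz_numpy(label):
    # Hypothetical stand-in: (H, W) integer labels -> (H, W, 3) uint8 RGB.
    label = np.asarray(label)
    rng = np.random.RandomState(1)  # fixed seed keeps colors stable
    palette = rng.randint(0, 255, size=(int(label.max()) + 1, 3), dtype=np.uint8)
    return palette[label]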
Example 15
def __sseg_image_z(G, z, tar, tar_mask, edit_strategy, P=None):
    """Precise semantic editing using the semantic extractor.

    Args:
      P : The semantic extractor.
      G : The generator supporting the edit.
      z : The initial z to be edited.
      tar : The target semantic mask.
      tar_mask : Denotes the user-changed region; used to split the
                 reconstruction loss into edited and unedited parts.
    """
    edit_strategy.setup(z)
    z0 = edit_strategy.z0
    for i in range(edit_strategy.n_iter):
        z, wps = edit_strategy.to_std_form()
        image = bu(G.synthesis(wps), tar.shape[2])
        diff = (tar - image) ** 2
        mseregloss = ((1 - tar_mask) * diff).sum() / (1 - tar_mask).sum()
        mseeditloss = (tar_mask * diff).sum() / tar_mask.sum()
        regloss = 1e-3 * ((z - z0) ** 2).sum()
        priorloss = 1e-3 * (z ** 2).sum() / z.shape[0]
        edit_strategy.step(mseregloss + mseeditloss + regloss + priorloss)
    return edit_strategy.to_std_form()
Example 16
def fusion(feats, masks):
    masks = bu(masks, feats[0].shape[2])
    return sum([f * m for f, m in zip(feats, masks)])
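A usage sketch for fusion: per-pixel convex blending of two feature maps, assuming the list-aware bu sketched under Example 1 (shapes are assumptions):

feats = [torch.randn(1, 64, 32, 32), torch.randn(1, 64, 32, 32)]
m = torch.rand(1, 1, 32, 32)
blended = fusion(feats, [m, 1 - m])  # weights sum to one per pixel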
Example 17
    z, image_stroke, image_mask, label_stroke, label_mask = \
      read_data(args.data_dir, name_list)
    param_dict = dict(latent_strategy="mixwp",
                      optimizer='adam',
                      n_iter=50,
                      base_lr=0.002)
    print(z.shape, image_stroke.shape, image_mask.shape, label_stroke.shape,
          label_mask.shape)
    G_name = "stylegan2_ffhq"
    DIR = "predictors/pretrain/"
    G = build_generator(G_name).net
    P = build_predictor("face_seg")
    SE_full = load_semantic_extractor(f"{DIR}/{G_name}_LSE.pth").cuda()
    SE_fewshot = load_semantic_extractor(
        f"{DIR}/{G_name}_8shot_LSE.pth").cuda()

    proc = lambda x: ((bu(x.detach().cpu(), 256) + 1) / 2)
    vizproc = lambda x: (bu(segviz_torch(x.detach().cpu()).unsqueeze(0), 256))
    origins, labels, baselines, fewshots, fulls = [], [], [], [], []
    for i in tqdm(range(z.shape[0])):
        zs = z[i].cuda()
        with torch.no_grad():
            wp = EditStrategy.z_to_wp(G,
                                      zs,
                                      in_type="zs",
                                      out_type="notrunc-wp")

        res = ImageEditing.fuse_stroke(G, SE_full, P, wp, image_stroke[i],
                                       image_mask[i], label_stroke[i],
                                       label_mask[i])
        origins.append(proc(res["origin_image"]))
        labels.append(vizproc(res["ext_label"]))