コード例 #1
0
def get_result_G(G_name, z):
    """Render ``z`` with generator ``G_name`` and visualize its segmentations.

    The generator ``Gs[G_name]`` produces an image and a feature; the image is
    labelled by ``Ps[G_name]`` and the feature is decoded by
    ``viz_models[G_name]``, whose outputs are each resized with ``bu`` to 256
    and turned into a visualization.

    Args:
        G_name: key into the module-level ``Gs``, ``Ps`` and ``viz_models``.
        z: latent input passed to the generator.

    Returns:
        ``(image_set, text_set)`` - parallel lists. First come one viz per
        ``viz_models[G_name]`` output, captioned by index, with the last
        caption replaced by "LSE". Then come the generated image (captioned
        ``formal_name(G_name)``) and the ``Ps[G_name]`` label viz (captioned
        "UNet" for ffhq/celebahq names, else "DeeplabV3").

    Raises:
        IndexError: if the viz model yields no outputs (the "LSE" caption is
        written into ``text_set[-1]``).
    """
    G, P = Gs[G_name], Ps[G_name]
    if not hasattr(G, "truncation"):
        image, feature = G(z, generate_feature=True)
    else:
        wp = G.truncation(G.mapping(z))
        image, feature = G.synthesis(wp, generate_feature=True)

    label = P(image, size=256)
    label_viz = segviz_numpy(label.cpu())

    # One visualization per viz-model output, argmax over dim 1 of the
    # 256-resized output, first batch element only.
    seg_outputs = viz_models[G_name](feature, size=label.shape[2])
    image_set = [segviz_numpy(bu(out, 256).argmax(1)[0].cpu())
                 for out in seg_outputs]
    text_set = [str(idx) for idx, _ in enumerate(image_set)]
    text_set[-1] = "LSE"  # caption the final output "LSE"

    is_face = "ffhq" in G_name or "celebahq" in G_name
    text_set.extend([formal_name(G_name), "UNet" if is_face else "DeeplabV3"])
    image_set.extend([torch2image(bu(image, 256))[0], label_viz])
    return image_set, text_set
コード例 #2
0
 def send_command(self, cmd):
     """Handle a control command for this worker.

     Every mutation of ``self.running`` / ``self.exit`` and the whole "val"
     preview run under ``self.lock``, acquired with ``with`` so the lock is
     always released, even if the model or a helper raises.

     Args:
         cmd: one of "start", "pause", "stop" or "val".

     Returns:
         "start": True if started, False if already running.
         "pause": True if paused, False if not running.
         "stop": True (sets ``self.exit``).
         "val": ``(images, segvizs)`` numpy arrays built from 6 random
             latents passed through ``self.learner``.
         anything else: None.
     """
     if cmd == "start":
         # Check-and-set atomically so concurrent "start" commands cannot
         # both see running == False.
         with self.lock:
             if self.running:
                 return False
             self.running = True
         # self.start() is invoked without holding the lock.
         self.start()
         return True
     elif cmd == "pause":
         with self.lock:
             if not self.running:
                 return False
             self.running = False
         return True
     elif cmd == "stop":
         with self.lock:
             self.exit = True
         return True
     elif cmd == "val":
         images = []
         segvizs = []
         # Hold the lock for the whole preview so it does not interleave with
         # other holders of self.lock (presumably the training loop - confirm).
         # A manual acquire/release here would leak the lock on any exception.
         with self.lock:
             for _ in range(6):
                 z = torch.randn(1, 512).cuda()
                 with torch.no_grad():
                     image, segs = self.learner(z)
                 # Visualize only the last segmentation output, resized to 128.
                 seg = bu(segs[-1], 128).argmax(1)
                 segvizs.append(segviz_numpy(seg.detach().cpu().numpy()))
                 images.append(torch2image(bu(image, 128)).astype("uint8"))
         return np.concatenate(images), np.stack(segvizs)
     return None
コード例 #3
0
 def generate_new_image(self, model_name):
     """Sample a random latent and render it with ``self.Gs[model_name]``.

     Progress messages are printed before and after generation.

     Args:
         model_name: key into ``self.Gs``.

     Returns:
         ``(image, zs)``: the first image of the batch as a uint8 array, and
         the per-layer latents (the same z repeated ``num_layers`` times)
         flattened into a plain Python list.
     """
     print("=> [TrainerAPI] generate new image")
     generator = self.Gs[model_name]
     n_layers = generator.num_layers
     z = torch.randn(1, 512).cuda()
     per_layer_z = z.repeat(n_layers, 1)  # use mixwp: (n_layers, 512)
     # Broadcast the single w to every layer; assumes mapping(z) is (1, D),
     # giving wp of shape (1, n_layers, D) - TODO confirm.
     w = generator.mapping(z)
     wp = w.unsqueeze(1).repeat(1, n_layers, 1)
     rendered, _ = generator.synthesis(wp, generate_feature=True)
     flat_z = per_layer_z.detach().cpu().view(-1).numpy().tolist()
     image = torch2image(rendered).astype("uint8")[0]
     print("=> [TrainerAPI] done")
     return image, flat_z
コード例 #4
0
 def generate_new_image(self, model_name):
     """Sample a random latent, render it, and segment the result.

     Args:
         model_name: key into ``self.Gs`` and ``self.SE``.

     Returns:
         ``(image, label_viz, zs)``: the first image of the batch as a uint8
         array; a visualization of the argmax label of the last
         ``self.SE[model_name]`` output; and the per-layer latents (z repeated
         ``num_layers`` times) flattened into a plain Python list.
     """
     generator = self.Gs[model_name]
     n_layers = generator.num_layers
     z = torch.randn(1, 512).cuda()
     per_layer_z = z.repeat(n_layers, 1)  # use mixwp: (n_layers, 512)
     # Assumes mapping(z) is (1, D): repeat -> (n_layers, D),
     # unsqueeze -> (1, n_layers, D) - TODO confirm.
     wp = generator.mapping(z).repeat(n_layers, 1).unsqueeze(0)
     rendered, feature = generator.synthesis(wp, generate_feature=True)
     seg_out = self.SE[model_name](feature)[-1]
     label = seg_out[0].argmax(0)  # first batch element, argmax over dim 0
     image = torch2image(rendered).astype("uint8")[0]
     label_viz = segviz_numpy(torch2numpy(label))
     return image, label_viz, per_layer_z.detach().cpu().view(-1).numpy().tolist()
コード例 #5
0
    def generate_image_given_stroke(self, model_name, zs, image_stroke,
                                    image_mask, label_stroke, label_mask):
        """Edit a latent so the generated segmentation follows user strokes.

        The incoming latent and the four stroke/mask inputs are written under
        ``self.data_dir`` with a shared timestamp prefix. The strokes are fused
        via ``ImageEditing.fuse_stroke``, and the latent is edited with
        ``ImageEditing.sseg_edit`` (adam, 50 iterations, lr 0.01, "mixwp"
        strategy). The new image and its label visualization are written out
        as well.

        Args:
            model_name: key into ``self.Gs``, ``self.SE`` and
                ``self.ma.models_config``.
            zs: flat sequence of latent values, reshaped to
                ``(G.num_layers, -1)``.
            image_stroke, image_mask, label_stroke, label_mask: user inputs
                (presumably image arrays - confirm) saved with ``imwrite`` and
                then preprocessed to the model's ``output_size``.

        Returns:
            ``(image, label_viz, zs)``: the edited image, a visualization of
            its argmax label, and the edited latent as a flat list.
        """
        generator = self.Gs[model_name]
        seg_net = self.SE[model_name]
        latent = np.array(zs, dtype=np.float32).reshape((generator.num_layers, -1))

        # Archive the request: original latent plus the raw strokes/masks.
        stamp = get_time_str()
        prefix = f"{self.data_dir}/{stamp}"
        np.save(f"{prefix}_origin-zs.npy", latent)
        imwrite(f"{prefix}_image-stroke.png", image_stroke)
        imwrite(f"{prefix}_label-stroke.png", label_stroke)
        imwrite(f"{prefix}_image-mask.png", image_mask)
        imwrite(f"{prefix}_label-mask.png", label_mask)

        out_size = self.ma.models_config[model_name]["output_size"]
        # (1, num_layers, latent_dim) tensor on the GPU.
        latent_t = torch.from_numpy(latent).float().cuda().unsqueeze(0)
        wp = EditStrategy.z_to_wp(generator, latent_t,
                                  in_type="zs", out_type="notrunc-wp")

        img_stroke_t = preprocess_image(image_stroke, out_size).cuda()
        img_mask_t = preprocess_mask(image_mask, out_size).cuda()
        # NOTE(review): unlike the other inputs this is not moved with
        # .cuda(); confirm preprocess_label already returns a GPU tensor.
        lbl_stroke_t = preprocess_label(label_stroke, seg_net.n_class, out_size)
        lbl_mask_t = preprocess_mask(label_mask, out_size).squeeze(1).cuda()

        fused = ImageEditing.fuse_stroke(generator, seg_net, None, wp,
                                         img_stroke_t[0], img_mask_t[0],
                                         lbl_stroke_t[0], lbl_mask_t[0])
        target_label = fused["fused_int_label"]
        new_latent, new_wp = ImageEditing.sseg_edit(
            generator, latent_t, target_label, lbl_mask_t, seg_net,
            op="internal", latent_strategy="mixwp", optimizer='adam',
            n_iter=50, base_lr=0.01)

        new_image, new_feature = generator.synthesis(new_wp.cuda(),
                                                     generate_feature=True)
        new_label = seg_net(new_feature)[-1].argmax(1)
        new_image = torch2image(new_image)[0]
        label_viz = segviz_numpy(torch2numpy(new_label))
        flat_latent = new_latent.detach().cpu().view(-1).numpy().tolist()
        imwrite(f"{prefix}_new-image.png", new_image)  # generated
        imwrite(f"{prefix}_new-label.png", label_viz)
        return new_image, label_viz, flat_latent
コード例 #6
0
def get_result_G(G_name, z):
    """Render ``z`` with generator ``G_name`` and compare segmentations.

    The generator ``Gs[G_name]`` produces an image and a feature. The image
    is labelled by ``Ps[G_name]``, and the feature is decoded by every model
    in ``SE_models[G_name]``.

    Args:
        G_name: key into the module-level ``Gs``, ``Ps`` and ``SE_models``.
        z: latent input passed to the generator.

    Returns:
        ``(image_set, text_set)`` - parallel lists:
            [0] the generated image (resized with ``bu`` to 256), captioned
                with the first "_"-separated part of ``formal_name(G_name)``;
            [1] the ``Ps[G_name]`` label viz, captioned "UNet" for
                ffhq/celebahq names, else "DeeplabV3";
            [2:] one viz per ``SE_models[G_name]`` entry, captioned with
                its name.
    """
    G = Gs[G_name]
    P = Ps[G_name]
    if hasattr(G, "truncation"):
        wp = G.truncation(G.mapping(z))
        image, feature = G.synthesis(wp, generate_feature=True)
    else:
        image, feature = G(z, generate_feature=True)
    label = P(image, size=256)
    label_viz = segviz_numpy(label.cpu())
    # is_face picks the caption for the P label; it must be computed here
    # because nothing else in this function defines it (NameError otherwise).
    is_face = "ffhq" in G_name or "celebahq" in G_name
    image_set = [torch2image(bu(image, 256))[0], label_viz]
    text_set = [
        formal_name(G_name).split("_")[0],
        "UNet" if is_face else "DeeplabV3",
    ]
    for SE_name, SE in SE_models[G_name].items():
        # Only the last output of each SE model is visualized.
        seg = SE(feature, size=label.shape[2])[-1]
        est_label = seg.argmax(1)
        image_set.append(segviz_numpy(est_label[0].cpu()))
        text_set.append(SE_name)
    return image_set, text_set