Example #1
0
 def run_G(self, z, c, sync):
     """Map latents z (conditioned on c) to W-space, optionally apply style
     mixing, and synthesize an image.

     Returns (img, ws): the generated image and the W latents used for it.
     """
     with misc.ddp_sync(self.G_mapping, sync):
         ws = self.G_mapping(z, c)
         if self.style_mixing_prob > 0:
             with torch.autograd.profiler.record_function('style_mixing'):
                 num_layers = ws.shape[1]
                 # Random crossover layer in [1, num_layers).
                 mix_point = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, num_layers)
                 # With prob (1 - style_mixing_prob) push the crossover past the
                 # last layer, which disables mixing for this batch.
                 no_mix = torch.full_like(mix_point, num_layers)
                 mix_point = torch.where(torch.rand([], device=ws.device) < self.style_mixing_prob, mix_point, no_mix)
                 # Splice in latents from a fresh z draw after the crossover.
                 ws_alt = self.G_mapping(torch.randn_like(z), c, skip_w_avg_update=True)
                 ws[:, mix_point:] = ws_alt[:, mix_point:]
     with misc.ddp_sync(self.G_synthesis, sync):
         img = self.G_synthesis(ws)
     return img, ws
Example #2
0
 def run_G(self, codes, c, sync, truncation_psi=1, truncation_cutoff=None):
     """Map a stack of codes (one per W layer, indexed along dim 1) through the
     mapping network individually and synthesize an image from the stacked result.

     Truncation is applied only to layers below `truncation_cutoff` (all layers
     when the cutoff is None).

     Returns (img, ws).
     """
     with misc.ddp_sync(self.G_mapping, sync):
         # BUG FIX: the original wrapped this build in a second, identical
         # `for i in range(codes.shape[1])` loop that reset `ws = []` each pass,
         # recomputing every mapping codes.shape[1] times and keeping only the
         # last pass — O(n^2) mapping calls for an identical result.
         ws = []
         for i in range(codes.shape[1]):
             ws.append(self.G_mapping(codes[:, i], c,
                                      truncation_psi=truncation_psi if truncation_cutoff is None or i < truncation_cutoff else 1,
                                      truncation_cutoff=truncation_cutoff,
                                      skip_w_avg_update=True,
                                      broadcast=False))
         ws = torch.stack(ws, dim=1)
     with misc.ddp_sync(self.G_synthesis, sync):
         img = self.G_synthesis(ws, noise_mode='none')
     return img, ws
Example #3
0
    def run_G(self, z, c, sync):
        """Run the generator: mapping (with optional multi-mode sampling, mode
        mixing for patchwise context latents, and style mixing) followed by
        synthesis.

        Returns (img, ws): the synthesized image and the main W latents.
        """
        with misc.ddp_sync(self.G_mapping, sync):
            if self.synthesis_cfg.patchwise.enabled:
                # Sample one mode index per batch element when the generator is
                # multi-modal; otherwise no mode conditioning.
                if self.synthesis_cfg.num_modes > 1:
                    modes_idx = torch.randint(low=0, high=self.synthesis_cfg.num_modes, size=(len(z),), device=z.device)
                else:
                    modes_idx = None

                # For each of the two context latents, optionally resample its
                # mode with prob mode_mixing_prob (per sample); otherwise both
                # contexts reuse the main mode.
                if self.synthesis_cfg.num_modes > 1 and self.synthesis_cfg.patchwise.mode_mixing_prob > 0:
                    context_modes_idx = []
                    for i in [0, 1]:
                        new_modes_idx = torch.randint(low=0, high=self.synthesis_cfg.num_modes, size=(len(z),), device=z.device)
                        mode_mixing_mask = (torch.rand(len(z), device=z.device) < self.synthesis_cfg.patchwise.mode_mixing_prob).float()
                        # Blend old/new indices in float space, cast back to long.
                        new_modes_idx = new_modes_idx.float() * mode_mixing_mask + modes_idx.float() * (1 - mode_mixing_mask)
                        context_modes_idx.append(new_modes_idx.long())
                else:
                    context_modes_idx = [modes_idx, modes_idx]

                # Two independent context latents from fresh z draws, stacked on
                # a new dim 1; w_avg tracking is skipped for these.
                ws_context = torch.stack([
                    self.G_mapping(torch.randn_like(z), c, skip_w_avg_update=True, modes_idx=context_modes_idx[0]),
                    self.G_mapping(torch.randn_like(z), c, skip_w_avg_update=True, modes_idx=context_modes_idx[1]),
                ], dim=1)
                # Random horizontal offset for the patchwise crop, drawn from
                # [0, 2*w_dist - grid_size). NOTE(review): assumes
                # 2*w_dist > grid_size — torch.randint requires high > low.
                w_dist = int(0.5 * self.synthesis_cfg.patchwise.w_coord_dist * self.synthesis_cfg.patchwise.grid_size)
                left_borders_idx = torch.randint(low=0, high=(2 * w_dist - self.synthesis_cfg.patchwise.grid_size), size=z.shape[:1], device=z.device)
            else:
                modes_idx = None
                ws_context = None
                left_borders_idx = None

            ws = self.G_mapping(z, c, modes_idx=modes_idx)

            if self.style_mixing_prob > 0:
                with torch.autograd.profiler.record_function('style_mixing'):
                    # Random crossover layer in [1, ws.shape[1]); disabled (pushed
                    # past the end) with prob (1 - style_mixing_prob).
                    cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1])
                    cutoff = torch.where(torch.rand([], device=ws.device) < self.style_mixing_prob, cutoff, torch.full_like(cutoff, ws.shape[1]))
                    ws[:, cutoff:] = self.G_mapping(torch.randn_like(z), c, skip_w_avg_update=True, modes_idx=modes_idx)[:, cutoff:]

                    if self.synthesis_cfg.patchwise.enabled:
                        # Apply style mixing to each context latent as well,
                        # with an independently drawn cutoff per context.
                        for i in [0, 1]:
                            cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1])
                            cutoff = torch.where(torch.rand([], device=ws.device) < self.style_mixing_prob, cutoff, torch.full_like(cutoff, ws.shape[1]))
                            ws_context[:, i, cutoff:] = self.G_mapping(torch.randn_like(z), c, skip_w_avg_update=True, modes_idx=context_modes_idx[i])[:, cutoff:]

        with misc.ddp_sync(self.G_synthesis, sync):
            img = self.G_synthesis(ws, ws_context=ws_context, left_borders_idx=left_borders_idx)

        return img, ws
Example #4
0
 def run_D(self, img, c, sync, phase):
     """Score images with the discriminator, applying DiffAugment (for the
     configured phases) and the ADA augmentation pipeline when enabled.
     """
     # DiffAugment only runs in phases listed in diffaugment_placement.
     use_diffaug = self.diffaugment and phase in self.diffaugment_placement.split(',')
     if use_diffaug:
         img = DiffAugment(img, policy=self.diffaugment)
     if self.augment_pipe is not None:
         img = self.augment_pipe(img)
     with misc.ddp_sync(self.D, sync):
         scores = self.D(img, c)
     return scores
Example #5
0
 def run_D(self, img, c, sync):
     """Score images with the discriminator, applying DiffAugment and the ADA
     augmentation pipeline when they are configured.
     """
     augmented = img
     if self.diffaugment:
         augmented = DiffAugment(augmented, policy=self.diffaugment)
     if self.augment_pipe is not None:
         augmented = self.augment_pipe(augmented)
     with misc.ddp_sync(self.D, sync):
         out = self.D(augmented, c)
     return out
Example #6
0
 def run_G(self, z, c, sync):
     """Run the generator (single module with mapping/synthesis subnets):
     map z to latents, optionally mix styles along dim 2 and components along
     dim 1, then synthesize.

     Returns (img, ws).
     """
     with misc.ddp_sync(self.G, sync):
         ws = self.G(z, c, subnet = "mapping")
         if self.style_mixing > 0:
             with torch.autograd.profiler.record_function("style_mixing"):
                 # Crossover over the style axis (dim 2); disabled with
                 # prob (1 - style_mixing) by pushing the cutoff past the end.
                 cutoff = torch.empty([], dtype = torch.int64, device = ws.device).random_(1, ws.shape[2])
                 cutoff = torch.where(torch.rand([], device = ws.device) < self.style_mixing, cutoff, torch.full_like(cutoff, ws.shape[2]))
                 ws[:, :, cutoff:] = self.G(torch.randn_like(z), c, skip_w_avg_update = True, subnet = "mapping")[:, :, cutoff:]
         if self.component_mixing > 0:
             with torch.autograd.profiler.record_function("component_mixing"):
                 # Crossover over the component axis (dim 1).
                 cutoff = torch.empty([], dtype = torch.int64, device = ws.device).random_(1, ws.shape[1])
                 # BUG FIX: gate on component_mixing here — the original compared
                 # against self.style_mixing (copy-paste from the branch above),
                 # so component mixing fired with the wrong probability.
                 cutoff = torch.where(torch.rand([], device = ws.device) < self.component_mixing, cutoff, torch.full_like(cutoff, ws.shape[1]))
                 ws[:, cutoff:] = self.G(torch.randn_like(z), c, skip_w_avg_update = True, subnet = "mapping")[:, cutoff:]
         # Synthesis runs inside the same ddp_sync context, since mapping and
         # synthesis are subnets of the single self.G module.
         img = self.G(ws = ws, subnet = "synthesis")
     return img, ws
Example #7
0
 def run_D(self, img, c, sync):
     """Score images with the discriminator, passing them through the ADA
     augmentation pipeline first when one is configured.
     """
     if self.augment_pipe is not None:
         img = self.augment_pipe(img)
     with misc.ddp_sync(self.D, sync):
         scores = self.D(img, c)
     return scores
Example #8
0
 def run_D(self, img, c, sync):
     """Score images with the discriminator under the given DDP sync mode."""
     with misc.ddp_sync(self.D, sync):
         scores = self.D(img, c)
     return scores
Example #9
0
 def run_E(self, img, c, sync):
     """Encode images into latent codes under the given DDP sync mode."""
     with misc.ddp_sync(self.E, sync):
         encoded = self.E(img, c)
     return encoded