def run_G(self, z, c, sync):
    """Run the mapping network (with optional style mixing), then synthesis.

    Args:
        z: Latent batch passed to ``self.G_mapping``; ``torch.randn_like(z)``
            supplies the second latent used for mixing.
        c: Conditioning forwarded to every ``self.G_mapping`` call.
        sync: Forwarded to ``misc.ddp_sync`` for each sub-network.

    Returns:
        Tuple ``(img, ws)``: the ``self.G_synthesis`` output and the (possibly
        mixed) codes that were fed to it.
    """
    with misc.ddp_sync(self.G_mapping, sync):
        ws = self.G_mapping(z, c)
        if self.style_mixing_prob > 0:
            with torch.autograd.profiler.record_function('style_mixing'):
                num_ws = ws.shape[1]
                # Candidate crossover index drawn from [1, num_ws).
                candidate = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, num_ws)
                # With probability 1 - style_mixing_prob the crossover is moved to
                # num_ws, which makes the slice assignment below an empty no-op.
                apply_mix = torch.rand([], device=ws.device) < self.style_mixing_prob
                cutoff = torch.where(apply_mix, candidate, torch.full_like(candidate, num_ws))
                mixed_ws = self.G_mapping(torch.randn_like(z), c, skip_w_avg_update=True)
                ws[:, cutoff:] = mixed_ws[:, cutoff:]
    with misc.ddp_sync(self.G_synthesis, sync):
        img = self.G_synthesis(ws)
    return img, ws
def run_G(self, codes, c, sync, truncation_psi=1, truncation_cutoff=None):
    """Map each latent code slice through ``self.G_mapping`` and synthesize.

    Every slice ``codes[:, i]`` (for ``i`` in ``range(codes.shape[1])``) is
    mapped exactly once, and the results are stacked along dim 1 to form ``ws``.

    Args:
        codes: Tensor indexed as ``codes[:, i]``; ``codes.shape[1]`` is the
            number of slices mapped.
        c: Conditioning forwarded to every ``G_mapping`` call.
        sync: Forwarded to ``misc.ddp_sync`` for each sub-network.
        truncation_psi: Passed to ``G_mapping`` for slices with
            ``i < truncation_cutoff`` (or for all slices when
            ``truncation_cutoff`` is None); other slices get ``1``.
        truncation_cutoff: Forwarded unchanged to ``G_mapping`` and also used
            to decide which slice indices receive ``truncation_psi``.

    Returns:
        Tuple ``(img, ws)`` where ``img = self.G_synthesis(ws, noise_mode='none')``.
    """
    with misc.ddp_sync(self.G_mapping, sync):
        # One G_mapping call per slice. A nested loop over the same range would
        # rebuild the identical list on every outer pass (quadratic calls).
        ws = torch.stack([
            self.G_mapping(
                codes[:, i],
                c,
                truncation_psi=truncation_psi if truncation_cutoff is None or i < truncation_cutoff else 1,
                truncation_cutoff=truncation_cutoff,
                skip_w_avg_update=True,
                broadcast=False,
            )
            for i in range(codes.shape[1])
        ], dim=1)
    with misc.ddp_sync(self.G_synthesis, sync):
        img = self.G_synthesis(ws, noise_mode='none')
    return img, ws
def run_G(self, z, c, sync):
    """Run mapping (plus optional patchwise context codes) and then synthesis.

    Always computes ``ws = self.G_mapping(z, c, modes_idx=modes_idx)``. When
    ``self.synthesis_cfg.patchwise.enabled`` is true it also builds two context
    code sets from fresh ``torch.randn_like(z)`` latents (stacked along dim 1
    into ``ws_context``) and draws per-sample ``left_borders_idx``; otherwise
    ``modes_idx``, ``ws_context`` and ``left_borders_idx`` are all None.
    Optional style mixing (``self.style_mixing_prob``) is applied to ``ws``
    and, in patchwise mode, independently to each context entry.

    Args:
        z: Latent batch; ``len(z)`` / ``z.shape[:1]`` is used as the batch size
            for the per-sample random draws.
        c: Conditioning forwarded to every ``G_mapping`` call.
        sync: Forwarded to ``misc.ddp_sync`` for both sub-networks.

    Returns:
        Tuple ``(img, ws)`` where ``img`` is
        ``self.G_synthesis(ws, ws_context=ws_context, left_borders_idx=left_borders_idx)``.
    """
    # NOTE(review): the order of the random draws below determines the
    # sampled outputs for a given seed; reordering statements changes results.
    with misc.ddp_sync(self.G_mapping, sync):
        if self.synthesis_cfg.patchwise.enabled:
            # One mode index per sample, only when more than one mode exists.
            if self.synthesis_cfg.num_modes > 1:
                modes_idx = torch.randint(low=0, high=self.synthesis_cfg.num_modes, size=(len(z),), device=z.device)
            else:
                modes_idx = None

            if self.synthesis_cfg.num_modes > 1 and self.synthesis_cfg.patchwise.mode_mixing_prob > 0:
                # Mode indices for each of the two context codes. Per sample,
                # with probability mode_mixing_prob take a freshly drawn mode,
                # otherwise keep modes_idx. The float blend with a 0/1 mask
                # followed by .long() acts as an elementwise select.
                context_modes_idx = []
                for i in [0, 1]:
                    new_modes_idx = torch.randint(low=0, high=self.synthesis_cfg.num_modes, size=(len(z),), device=z.device)
                    mode_mixing_mask = (torch.rand(len(z), device=z.device) < self.synthesis_cfg.patchwise.mode_mixing_prob).float()
                    new_modes_idx = new_modes_idx.float() * mode_mixing_mask + modes_idx.float() * (1 - mode_mixing_mask)
                    context_modes_idx.append(new_modes_idx.long())
            else:
                # Both context codes share modes_idx (which may be None).
                context_modes_idx = [modes_idx, modes_idx]

            # Two context codes from independent latents, stacked on dim 1 so
            # that ws_context[:, i] is the i-th context code.
            ws_context = torch.stack([
                self.G_mapping(torch.randn_like(z), c, skip_w_avg_update=True, modes_idx=context_modes_idx[0]),
                self.G_mapping(torch.randn_like(z), c, skip_w_avg_update=True, modes_idx=context_modes_idx[1]),
            ], dim=1)

            # Per-sample left border drawn from [0, 2 * w_dist - grid_size).
            # NOTE(review): torch.randint needs high > 0, i.e.
            # 2 * w_dist > grid_size -- confirm the config guarantees this.
            w_dist = int(0.5 * self.synthesis_cfg.patchwise.w_coord_dist * self.synthesis_cfg.patchwise.grid_size)
            left_borders_idx = torch.randint(low=0, high=(2 * w_dist - self.synthesis_cfg.patchwise.grid_size), size=z.shape[:1], device=z.device)
        else:
            modes_idx = None
            ws_context = None
            left_borders_idx = None

        ws = self.G_mapping(z, c, modes_idx=modes_idx)
        if self.style_mixing_prob > 0:
            with torch.autograd.profiler.record_function('style_mixing'):
                # Crossover index drawn from [1, ws.shape[1]); replaced by
                # ws.shape[1] (empty slice, i.e. no mixing) with probability
                # 1 - style_mixing_prob.
                cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1])
                cutoff = torch.where(torch.rand([], device=ws.device) < self.style_mixing_prob, cutoff, torch.full_like(cutoff, ws.shape[1]))
                ws[:, cutoff:] = self.G_mapping(torch.randn_like(z), c, skip_w_avg_update=True, modes_idx=modes_idx)[:, cutoff:]
                # NOTE(review): nesting of this block under the style-mixing
                # branch was reconstructed from collapsed source -- confirm.
                if self.synthesis_cfg.patchwise.enabled:
                    # Independent crossover for each context code, mapped with
                    # that code's own mode indices.
                    for i in [0, 1]:
                        cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1])
                        cutoff = torch.where(torch.rand([], device=ws.device) < self.style_mixing_prob, cutoff, torch.full_like(cutoff, ws.shape[1]))
                        ws_context[:, i, cutoff:] = self.G_mapping(torch.randn_like(z), c, skip_w_avg_update=True, modes_idx=context_modes_idx[i])[:, cutoff:]
    with misc.ddp_sync(self.G_synthesis, sync):
        img = self.G_synthesis(ws, ws_context=ws_context, left_borders_idx=left_borders_idx)
    return img, ws
def run_D(self, img, c, sync, phase):
    """Optionally augment ``img`` and return the discriminator logits.

    DiffAugment runs only when ``self.diffaugment`` is truthy and ``phase`` is
    one of the comma-separated entries of ``self.diffaugment_placement``.
    ``self.augment_pipe`` (when not None) is applied afterwards.

    Args:
        img: Images to score.
        c: Conditioning forwarded to ``self.D``.
        sync: Forwarded to ``misc.ddp_sync``.
        phase: Name checked against ``self.diffaugment_placement``.

    Returns:
        The value of ``self.D(img, c)`` on the (possibly augmented) images.
    """
    policy = self.diffaugment
    if policy:
        # The placement string is only parsed when a policy is configured.
        if phase in self.diffaugment_placement.split(','):
            img = DiffAugment(img, policy=policy)
    pipe = self.augment_pipe
    if pipe is not None:
        img = pipe(img)
    with misc.ddp_sync(self.D, sync):
        return self.D(img, c)
def run_D(self, img, c, sync):
    """Apply optional DiffAugment and augment_pipe, then score with ``self.D``.

    Args:
        img: Images to score.
        c: Conditioning forwarded to ``self.D``.
        sync: Forwarded to ``misc.ddp_sync``.

    Returns:
        The value of ``self.D(img, c)`` on the (possibly augmented) images.
    """
    policy = self.diffaugment
    augmented = DiffAugment(img, policy=policy) if policy else img
    pipe = self.augment_pipe
    if pipe is not None:
        augmented = pipe(augmented)
    with misc.ddp_sync(self.D, sync):
        return self.D(augmented, c)
def run_G(self, z, c, sync):
    """Run ``self.G`` mapping with optional style/component mixing, then synthesis.

    Style mixing crosses over along dim 2 of ``ws``, gated by
    ``self.style_mixing``. Component mixing crosses over along dim 1, gated
    by ``self.component_mixing``. Each uses a second latent drawn with
    ``torch.randn_like(z)``.

    Args:
        z: Latent batch passed to the mapping subnet.
        c: Conditioning forwarded to every mapping call.
        sync: Forwarded to ``misc.ddp_sync``.

    Returns:
        Tuple ``(img, ws)`` where ``img = self.G(ws = ws, subnet = "synthesis")``.
    """
    with misc.ddp_sync(self.G, sync):
        ws = self.G(z, c, subnet = "mapping")
        if self.style_mixing > 0:
            with torch.autograd.profiler.record_function("style_mixing"):
                # Crossover in [1, ws.shape[2]); ws.shape[2] (empty slice) means no mixing.
                cutoff = torch.empty([], dtype = torch.int64, device = ws.device).random_(1, ws.shape[2])
                cutoff = torch.where(torch.rand([], device = ws.device) < self.style_mixing, cutoff, torch.full_like(cutoff, ws.shape[2]))
                ws[:, :, cutoff:] = self.G(torch.randn_like(z), c, skip_w_avg_update = True, subnet = "mapping")[:, :, cutoff:]
        if self.component_mixing > 0:
            with torch.autograd.profiler.record_function("component_mixing"):
                # Crossover in [1, ws.shape[1]); ws.shape[1] (empty slice) means no mixing.
                cutoff = torch.empty([], dtype = torch.int64, device = ws.device).random_(1, ws.shape[1])
                # Gate on component_mixing: gating on style_mixing here would
                # disable component mixing whenever style_mixing is 0.
                cutoff = torch.where(torch.rand([], device = ws.device) < self.component_mixing, cutoff, torch.full_like(cutoff, ws.shape[1]))
                ws[:, cutoff:] = self.G(torch.randn_like(z), c, skip_w_avg_update = True, subnet = "mapping")[:, cutoff:]
        # Synthesis shares the same ddp_sync context as mapping, since both
        # calls go through the single module self.G.
        img = self.G(ws = ws, subnet = "synthesis")
    return img, ws
def run_D(self, img, c, sync):
    """Optionally pass ``img`` through ``self.augment_pipe``, then score with ``self.D``.

    Args:
        img: Images to score.
        c: Conditioning forwarded to ``self.D``.
        sync: Forwarded to ``misc.ddp_sync``.

    Returns:
        The value of ``self.D(img, c)`` on the (possibly augmented) images.
    """
    augmented = img if self.augment_pipe is None else self.augment_pipe(img)
    with misc.ddp_sync(self.D, sync):
        return self.D(augmented, c)
def run_D(self, img, c, sync):
    """Return ``self.D(img, c)`` evaluated inside ``misc.ddp_sync(self.D, sync)``.

    Args:
        img: Images to score.
        c: Conditioning forwarded to ``self.D``.
        sync: Forwarded to ``misc.ddp_sync``.
    """
    discriminator = self.D
    with misc.ddp_sync(discriminator, sync):
        return discriminator(img, c)
def run_E(self, img, c, sync):
    """Return ``self.E(img, c)`` evaluated inside ``misc.ddp_sync(self.E, sync)``.

    Args:
        img: Images to encode.
        c: Conditioning forwarded to ``self.E``.
        sync: Forwarded to ``misc.ddp_sync``.
    """
    encoder = self.E
    with misc.ddp_sync(encoder, sync):
        return encoder(img, c)