def __init__(self, opt):
    super().__init__()
    self.opt = opt
    self.use_E = opt.netE is not None and len(opt.netE)
    self.netSR, self.netD, self.netE = self.initialize_networks(opt)
    self.mp = self.opt.model_parallel_mode
    self.model_variant = "guided" if "full" in self.opt.netE else "independent"
    self.FloatTensor = torch.cuda.FloatTensor if self.use_gpu() \
        else torch.FloatTensor

    # Set loss functions
    if opt.isTrain:
        self.criterionGAN = networks.GANLoss(opt.gan_mode,
                                             tensor=self.FloatTensor,
                                             opt=self.opt)
        self.criterionFeat = torch.nn.L1Loss()
        if not opt.no_vgg_loss:
            self.criterionVGG = networks.VGGLoss(self.opt.gpu_ids)

    # Holding variables for logging; overwritten at every iteration.
    self.logs = OrderedDict()
    self.last_encoded_style_is_full = True
    self.last_encoded_style_is_noisy = False
    gpu_info("Init SR Model", self.opt)
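# Usage sketch (not part of the original source): how a training script might
# drive this model through the mode-dispatching forward() further below. Only
# the mode names come from this file; the function name, the optimizer handling
# and the assumption that both modes return dicts of named losses are
# illustrative guesses.
def training_step(model, data_i, optimizer_G, optimizer_D):
    # Generator update: 'generator' mode returns named losses plus the generated image.
    g_losses, generated = model(data_i, mode='generator')
    g_loss = sum(g_losses.values()).mean()
    optimizer_G.zero_grad()
    g_loss.backward()
    optimizer_G.step()

    # Discriminator update: 'discriminator' mode returns the named D losses.
    d_losses = model(data_i, mode='discriminator')
    d_loss = sum(d_losses.values()).mean()
    optimizer_D.zero_grad()
    d_loss.backward()
    optimizer_D.step()
    return generated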
def forward(self, x=None, seg=None, mode="mini"): gpu_info("Mini style encoder start", self.opt) x, activations = self.forward_main(x) x = self.final(x) # x has new height and width so we need to adjust the segmask size if seg.size(2) != x.size(2) or seg.size(3) != x.size(3): seg = F.interpolate(seg, size=(x.size(2), x.size(3)), mode='nearest') style_matrix = self.extract_style_matrix(x, seg) gpu_info("Mini style encoder end", self.opt) return style_matrix, activations
def forward(self, x=None, seg=None, mode=None, no_noise=False):
    gpu_info("Combined encoder start", self.opt)
    if mode == "full":
        x, activations = self.encoder_full.forward_main(x)
    elif mode == "mini":
        x, activations = self.encoder_mini.forward_main(x)
    else:
        raise NotImplementedError()
    x = self.final(x)  # Shared layer
    style_matrix = self.extract_style_matrix(x, seg)
    if self.noisy_style and not no_noise:
        style_matrix = self.corrupt_style_matrix(style_matrix, self.max_range_noise)
    gpu_info("Combined encoder end", self.opt)
    return style_matrix, activations
def forward(self, x, seg, style=None, split_location=-1):
    if self.add_noise:
        x = self.noise_in(x)
    # style can be None
    x_s = self.shortcut(x, seg, style)
    # gpu_info("resblk line 84", self.opt)
    if split_location == 1:
        x = x.cuda(1)
        seg = seg.cuda(1)
        style = style.cuda(1)
        self.norm_0 = self.norm_0.cuda(1)
        self.conv_0 = self.conv_0.cuda(1)
        self.norm_1 = self.norm_1.cuda(1)
    gpu_info("resblk line 94", self.opt)
    x = self.actvn(self.norm_0(x, seg, style))
    gpu_info("resblk line 96", self.opt)
    if self.efficient:
        dx = checkpoint(self.conv_0, x)
    else:
        dx = self.conv_0(x)
    gpu_info("resblk line 104", self.opt)
    if split_location == 2:
        dx = dx.cuda(3)
        seg = seg.cuda(3)
        style = style.cuda(3)
        self.norm_1 = self.norm_1.cuda(3)
        self.conv_1 = self.conv_1.cuda(3)
        if self.add_noise:
            self.noise_middle = self.noise_middle.cuda(3)
        gpu_info("resblk line 99", self.opt)
    if self.add_noise:
        dx = self.noise_middle(dx)
    dx = self.actvn(self.norm_1(dx, seg, style))
    if split_location == 1:
        dx = dx.cuda(0)
    if self.efficient:
        dx = checkpoint(self.conv_1, dx)
    else:
        dx = self.conv_1(dx)
    if split_location == 2:
        x_s = x_s.cuda(3)
    out = x_s + dx
    # if split_location == 1:
    #     out = out.cuda(0)
    return out
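# A minimal, standalone illustration of the `self.efficient` branch above:
# torch.utils.checkpoint trades compute for memory by not storing the
# convolution's intermediate activations and recomputing them during the
# backward pass. The layer and tensor shapes below are placeholders, not values
# taken from this repo.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

conv = nn.Conv2d(64, 64, kernel_size=3, padding=1)
x = torch.randn(1, 64, 128, 128, requires_grad=True)  # input must require grad for checkpointing to help

y_plain = conv(x)
y_ckpt = checkpoint(conv, x)  # same output values, activations recomputed in backward
assert torch.allclose(y_plain, y_ckpt)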
def generate_fake(self, input_semantics, image_downsized, encoded_style=None,
                  full_image=None, no_noise=False, guiding_image=None,
                  guiding_label=None):
    gpu_info("Generate fake start", self.opt)
    encoder_activations = None
    if encoded_style is None and "style" in self.opt.netE:
        encoded_style, encoder_activations = self.encode_style(
            downscaled_image=image_downsized,
            input_semantics=input_semantics,
            full_image=full_image,
            no_noise=no_noise,
            guiding_image=guiding_image,
            guiding_label=guiding_label,
            encode_full=self.opt.full_style_image)
    gpu_info("Generate fake before SR", self.opt)
    fake_image = self.netSR(image_downsized, seg=input_semantics, z=encoded_style)
    gpu_info("Generate fake after SR", self.opt)
    if self.mp == 1:
        fake_image = fake_image.cuda(0)
        encoded_style = encoded_style.cuda(0)
    return fake_image, encoder_activations, encoded_style
def compute_discriminator_loss(self, input_semantics, image_full,
                               image_downsized, guiding_image, guiding_label):
    gpu_info("D loss start", self.opt)
    D_losses = {}
    with torch.no_grad():
        fake_image, _, _ = self.generate_fake(
            input_semantics=input_semantics,
            image_downsized=image_downsized,
            full_image=image_full,
            guiding_image=guiding_image,
            guiding_label=guiding_label)
        fake_image = fake_image.detach()
        fake_image.requires_grad_()
    gpu_info("D loss after generate", self.opt)
    pred_fake, pred_real = self.discriminate(input_semantics, fake_image, image_full)
    D_losses['D_Fake'] = self.criterionGAN(pred_fake, False, for_discriminator=True)
    D_losses['D_Real'] = self.criterionGAN(pred_real, True, for_discriminator=True)
    gpu_info("D loss end", self.opt)
    return D_losses
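# The GANLoss wrapper used above is defined elsewhere in the repo and is not
# shown here. For reference, this is a generic hinge formulation, one of the
# usual `gan_mode` choices such a wrapper can implement; it is a sketch, not
# the project's implementation, and the function names are hypothetical.
import torch
import torch.nn.functional as F

def hinge_d_loss(pred_real, pred_fake):
    # Discriminator side: push real logits above +1 and fake logits below -1.
    loss_real = F.relu(1.0 - pred_real).mean()
    loss_fake = F.relu(1.0 + pred_fake).mean()
    return loss_real, loss_fake

def hinge_g_loss(pred_fake):
    # Generator side: raise the discriminator's score on generated images.
    return -pred_fake.mean()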
def forward(self, x=None, seg=None, mode="full", no_noise=False): gpu_info("full style encoder start", self.opt) if self.opt.random_style_matrix: x = torch.randn((seg.size(0), seg.size(1), self.opt.crop_size, self.opt.crop_size), device=seg.device) x = x * seg # We set the unused regions to zero. x, activations = self.forward_main(x) x = self.final(x) # x has new height and width so we need to adjust the segmask size if seg.size(2) != x.size(2) or seg.size(3) != x.size(3): seg = F.interpolate(seg, size=(x.size(2), x.size(3)), mode='nearest') style_matrix = self.extract_style_matrix(x, seg) if self.noisy_style and not no_noise: # For reconstruction or also for demo we want no noise style_matrix = self.corrupt_style_matrix(style_matrix, self.max_range_noise) gpu_info("full style encoder end", self.opt) return style_matrix, activations
def forward(self, x_downsized, seg=None, z=None):
    split_location = -1
    gpu_info("SR start", self.opt)
    x = self.initial(x_downsized)
    x = self.head_0(x, seg, z)

    x = self.up(x)
    x = self.G_middle_0(x, seg, z)
    x = self.G_middle_1(x, seg, z)

    for i in range(self.n_blocks - 1):
        if self.mp == 1 and i == 3:
            split_location = 1
        if self.mp == 2 and i == 3:
            self.up_list[i] = self.up_list[i].cuda(1)
            # self.conv_img = self.conv_img.cuda(1)
            x = x.cuda(1)
            seg = seg.cuda(1)
            z = z.cuda(1)
        if self.mp >= 2 and i == 4:
            self.up_list[i] = self.up_list[i].cuda(2)
            # self.conv_img = self.conv_img.cuda(1)
            x = x.cuda(2)
            seg = seg.cuda(2)
            z = z.cuda(2)
            split_location = 2
        x = self.up(x)
        x = self.up_list[i](x, seg, z, split_location=split_location)
        gpu_info("SR after up {}".format(i), self.opt)

    if self.mp > 0:
        x = x.cuda(0)
    x = self.conv_img(F.leaky_relu(x, 2e-1))
    x = torch.tanh(x)  # F.tanh is deprecated
    gpu_info("SR end", self.opt)
    return x
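# A minimal, standalone sketch of the pipeline-style model parallelism used in
# the SR forward above and in the resnet block: sub-modules live on different
# GPUs, the activation is handed over with .cuda(device) mid-forward, and the
# result is brought back to GPU 0 at the end. The two-stage toy module below is
# an illustration only and requires at least two visible GPUs.
import torch
import torch.nn as nn

if torch.cuda.device_count() >= 2:
    stage_0 = nn.Conv2d(3, 16, 3, padding=1).cuda(0)
    stage_1 = nn.Conv2d(16, 3, 3, padding=1).cuda(1)

    x = torch.randn(1, 3, 64, 64).cuda(0)
    x = stage_0(x)
    x = x.cuda(1)   # move the activation to the device holding the next stage
    x = stage_1(x)
    x = x.cuda(0)   # bring the result back, as the SR forward does when mp > 0
    print(x.shape)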
def forward(self, data, mode, **kwargs):
    if self.opt.model_parallel_mode == 1:
        with torch.cuda.device(1):
            torch.cuda.empty_cache()
    gpu_info("Forward started", self.opt)

    input_semantics = data.get("input_semantics", None)
    image_lr = data.get("image_lr", None)
    image_hr = data.get("image_hr", None)  # Only required if training
    # Only required if netE is "fullstyle"
    guiding_image = data.get("guiding_image", None)
    guiding_label = data.get("guiding_label", None)
    encoded_style = data.get("encoded_style", None)

    if mode == 'generator':
        g_loss, generated = self.compute_generator_loss(
            input_semantics, image_hr, image_lr, guiding_image, guiding_label)
        self.logs['image/downsized'] = image_lr
        return g_loss, generated
    elif mode == 'discriminator':
        d_loss = self.compute_discriminator_loss(
            input_semantics, image_hr, image_lr, guiding_image, guiding_label)
        return d_loss
    elif mode == 'inference':
        with torch.no_grad():
            fake_image, _, _ = self.generate_fake(
                input_semantics=input_semantics,
                image_downsized=image_lr,
                full_image=image_hr,
                no_noise=True,
                guiding_image=guiding_image,
                guiding_label=guiding_label)
            data["fake_image"] = fake_image
            # Filter out None values
            data = util.filter_none(data)
        return data
    elif mode == 'encode_only':
        encoded_style, encoder_activations = self.encode_style(
            downscaled_image=image_lr,
            input_semantics=input_semantics,
            full_image=image_hr,
            no_noise=True,
            guiding_image=guiding_image,
            guiding_label=guiding_label,
            encode_full=self.opt.full_style_image)
        return encoded_style
    elif mode == 'demo':
        with torch.no_grad():
            fake_image = self.netSR(image_lr, seg=input_semantics, z=encoded_style)
            out = data
            out["fake_image"] = fake_image
            # Filter out None values
            out = util.filter_none(out)
        return out
    elif mode == 'baseline':
        image_baseline = networks.F.interpolate(
            image_lr, (image_hr.shape[-2:]), mode='bicubic').clamp(-1, 1)
        return OrderedDict([("input_label", input_semantics),
                            ("image_downsized", image_lr),
                            ("fake_image", image_baseline),
                            ("image_full", image_hr)])
    elif mode == "inference_noise":
        with torch.no_grad():
            n = self.opt.batchSize
            image_downsized_repeated = image_lr.repeat_interleave(n, 0)
            input_semantics_repeated = input_semantics.repeat_interleave(n, 0)
            # encoded_style = torch.randn((image_downsized_repeated.size(0), self.opt.label_nc, self.opt.regional_style_size), device=image_downsized_repeated.device) / 100
            fake_image, _, _ = self.generate_fake(
                input_semantics=input_semantics_repeated,
                image_downsized=image_downsized_repeated,
                encoded_style=None)
            fake_image = torch.stack(
                [fake_image[i * n:i * n + n] for i in range(n)],
                dim=0)  # Shape bs, bs, 3, H, W
        return OrderedDict([("input_label", input_semantics),
                            ("image_downsized", image_lr),
                            ("fake_image", fake_image),
                            ("image_full", image_hr)])
    elif mode == "inference_multi_modal":
        # Randomly varies the appearance for the given regions
        with torch.no_grad():
            n = self.opt.n_interpolation  # TODO: rename to something else
            consistent_regions = np.array([4, 6, 8, 11])  # TODO: store in dataset class
            encoded_style, _ = self.encode_style(
                downscaled_image=image_lr,
                input_semantics=input_semantics,
                full_image=image_hr,
                no_noise=True,
                guiding_image=guiding_image,
                guiding_label=guiding_label)
            region_idx = self.opt.region_idx if self.opt.region_idx else list(
                range(input_semantics.size(1)))
            # region_idx = np.random.choice(region_idx, 1)
            delta = self.opt.noise_delta
            fake_images = list()
            applied_style = list()
            for b in range(self.opt.batchSize):
                fake_samples = list()
                style_samples = list()
                for i in range(n):
                    encoded_style_in = encoded_style[b].clone().detach()
                    noise = self.get_noise(encoded_style_in[region_idx].shape, delta)
                    encoded_style_in[region_idx] = (
                        encoded_style_in[region_idx] + noise).clamp(-1, 1)
                    encoded_style_in[consistent_regions] = encoded_style_in[
                        consistent_regions + 1]
                    fake_image, _, _ = self.generate_fake(
                        input_semantics=input_semantics[b].unsqueeze(0),
                        image_downsized=image_lr[b].unsqueeze(0),
                        encoded_style=encoded_style_in.unsqueeze(0))
                    fake_samples.append(fake_image)
                    style_samples.append(encoded_style_in)
                if not self.opt.dont_merge_fake:
                    to_append = torch.cat(fake_samples, -1)
                else:
                    to_append = torch.stack(fake_samples, 1)
                to_append_style = torch.stack(style_samples)
                fake_images.append(to_append)
                applied_style.append(to_append_style)
            fake_out = torch.cat(fake_images, 0)
        # Assemble the output as in the other inference modes
        out = OrderedDict([("input_label", input_semantics),
                           ("image_downsized", image_lr),
                           ("fake_image", fake_out),
                           ("image_full", image_hr),
                           ("style", applied_style)])
        if self.opt.guiding_style_image:
            out['guiding_image_id'] = data['guiding_image_id']
            out['guiding_image'] = data['guiding_image']
            out['guiding_input_label'] = data['guiding_label']
        return out
    elif mode == "inference_replace_semantics":
        with torch.no_grad():
            # region_idx = self.opt.region_idx if self.opt.region_idx else list(range(input_semantics.size(1)))
            fake_images = list()
            fake_image, _, _ = self.generate_fake(
                input_semantics=input_semantics, image_downsized=image_lr)
            fake_images.append(fake_image)
            regions_replace = [10]
            new_region_idx = 12
            for i, rp in enumerate(regions_replace):
                if isinstance(new_region_idx, list):
                    data['label'][data['label'] == rp] = new_region_idx[i]
                else:
                    data['label'][data['label'] == rp] = new_region_idx
            (input_semantics, image_hr, image_lr, guiding_image, guiding_label,
             guiding_image2, guiding_label2) = self.preprocess_input(
                 data)  # guiding image can be None
            fake_image, _, _ = self.generate_fake(
                input_semantics=input_semantics, image_downsized=image_lr)
            fake_images.append(fake_image)
            fake_out = torch.cat(fake_images, -1)
        out = OrderedDict([("input_label", input_semantics),
                           ("image_downsized", image_lr),
                           ("fake_image", fake_out),
                           ("image_full", image_hr)])
        if self.opt.guiding_style_image:
            out['guiding_image_id'] = data['guiding_image_id']
            out['guiding_image'] = data['guiding_image']
            out['guiding_input_label'] = data['guiding_label']
        return out
    elif mode == "inference_reference_semantics":
        with torch.no_grad():
            # region_idx = self.opt.region_idx if self.opt.region_idx else list(range(input_semantics.size(1)))
            fake_images = list()
            bak_input_semantics = input_semantics.clone().detach()
            for b in range(self.opt.batchSize):
                current_semantics = input_semantics.clone().detach()
                for b_sem in range(self.opt.batchSize):
                    current_semantics[b] = bak_input_semantics[b_sem]
                    fake_image, _, _ = self.generate_fake(
                        input_semantics=current_semantics,
                        image_downsized=image_lr)
                    fake_images.append(fake_image)
            fake_out = torch.cat(fake_images, -1)
        out = OrderedDict([("input_label", input_semantics),
                           ("image_downsized", image_lr),
                           ("fake_image", fake_out),
                           ("image_full", image_hr)])
        if self.opt.guiding_style_image:
            out['guiding_image_id'] = data['guiding_image_id']
            out['guiding_image'] = data['guiding_image']
            out['guiding_input_label'] = data['guiding_label']
        return out
    elif mode == "inference_interpolation":
        with torch.no_grad():
            if "style_matrix" in data:
                encoded_style = data["style_matrix"]
            else:
                encoded_style, _ = self.encode_style(
                    downscaled_image=image_lr,
                    input_semantics=input_semantics,
                    full_image=image_hr,
                    no_noise=True,
                    guiding_image=guiding_image,
                    guiding_label=guiding_label)
            n = self.opt.n_interpolation
            assert n % 2 == 1, "Please use an odd n such that the middle image has delta=0"
            delta = self.opt.noise_delta
            region_idx = self.opt.region_idx if self.opt.region_idx else list(
                range(input_semantics.size(1)))
            fake_images = list()
            applied_style = list()
            for b in range(self.opt.batchSize):
                fake_samples = list()
                style_samples = list()
                for delta_step in np.linspace(-delta, delta, num=n):
                    encoded_style_in = encoded_style[b].clone().detach()
                    encoded_style_in[region_idx] = (
                        encoded_style_in[region_idx] + delta_step).clamp(-1, 1)
                    fake_image, _, _ = self.generate_fake(
                        input_semantics=input_semantics[b].unsqueeze(0),
                        image_downsized=image_lr[b].unsqueeze(0),
                        encoded_style=encoded_style_in.unsqueeze(0))
                    fake_samples.append(fake_image)
                    style_samples.append(encoded_style_in)
                if not self.opt.dont_merge_fake:
                    to_append = torch.cat(fake_samples, -1)
                else:
                    to_append = torch.stack(fake_samples, 1)
                to_append_style = torch.stack(style_samples)
                applied_style.append(to_append_style)
                fake_images.append(to_append)
            fake_out = torch.cat(fake_images, 0)
        out = OrderedDict([("input_label", input_semantics),
                           ("image_downsized", image_lr),
                           ("fake_image", fake_out),
                           ("image_full", image_hr),
                           ("style", applied_style)])
        if self.opt.guiding_style_image:
            out['guiding_image_id'] = data['guiding_image_id']
            out['guiding_image'] = data['guiding_image']
            out['guiding_input_label'] = data['guiding_label']
        return out
    elif mode == "inference_interpolation_style":
        with torch.no_grad():
            encoded_style_from = data["style_from"].to(input_semantics.device)
            encoded_style_to = data["style_to"].to(input_semantics.device)
            n = self.opt.n_interpolation
            assert n % 2 == 1, "Please use an odd n such that the middle image has delta=0"
            fake_images = list()
            applied_style = list()
            for b in range(self.opt.batchSize):
                fake_samples = list()
                style_samples = list()
                for delta_step in np.linspace(0, 1, num=n):
                    encoded_style_in = (
                        (1 - delta_step) * encoded_style_from[b].clone().detach()
                        + delta_step * encoded_style_to[b].clone().detach())
                    fake_image, _, _ = self.generate_fake(
                        input_semantics=input_semantics[b].unsqueeze(0),
                        image_downsized=image_lr[b].unsqueeze(0),
                        encoded_style=encoded_style_in.unsqueeze(0))
                    fake_samples.append(fake_image)
                    style_samples.append(encoded_style_in)
                if not self.opt.dont_merge_fake:
                    to_append = torch.cat(fake_samples, -1)
                else:
                    to_append = torch.stack(fake_samples, 1)
                to_append_style = torch.stack(style_samples)
                applied_style.append(to_append_style)
                fake_images.append(to_append)
            fake_out = torch.cat(fake_images, 0)
        out = OrderedDict([("input_label", input_semantics),
                           ("image_downsized", image_lr),
                           ("fake_image", fake_out),
                           ("image_full", image_hr),
                           ("style", applied_style)])
        if self.opt.guiding_style_image:
            out['guiding_image_id'] = data['guiding_image_id']
            out['guiding_image'] = data['guiding_image']
            out['guiding_input_label'] = data['guiding_label']
        return out
    elif mode == "inference_particular_combined":
        with torch.no_grad():
            encoded_style_mini, _ = self.encode_style(
                input_semantics=input_semantics,
                downscaled_image=image_lr,
                no_noise=True,
                encode_full=False,
                guiding_image=None,
                guiding_label=None)
            if self.opt.noise_delta > 0:
                region_idx = self.opt.region_idx if self.opt.region_idx else list(
                    range(input_semantics.size(1)))
                print("Adding noise to style for regions {}".format(region_idx))
                noise = self.get_noise(encoded_style_mini[:, region_idx].shape,
                                       self.opt.noise_delta)
                encoded_style_mini[:, region_idx] = (
                    encoded_style_mini[:, region_idx] + noise).clamp(-1, 1)
                consistent_regions = np.array([4, 6, 8, 11])  # TODO: store in dataset class
                encoded_style_mini[:, consistent_regions] = encoded_style_mini[
                    :, consistent_regions + 1]
                fake_image, _, _ = self.generate_fake(
                    input_semantics=input_semantics,
                    image_downsized=image_lr,
                    encoded_style=encoded_style_mini)
            else:
                fake_image, _, _ = self.generate_fake(
                    input_semantics=input_semantics,
                    image_downsized=image_lr,
                    encoded_style=encoded_style_mini)
            # encoded_style_guided, _, _, _ = self.encode_style(downscaled_image=None,
            #                                                   no_noise=True, encode_full=True,
            #                                                   guiding_image=guiding_image,
            #                                                   guiding_label=guiding_label)
            # encoded_style_modified = encoded_style_mini.clone().detach()
            # encoded_style_modified[0, region_idx] = encoded_style_guided[0, region_idx]
            # fake_image_modified, _, _, _ = self.generate_fake(input_semantics=input_semantics,
            #                                                   image_downsized=image_downsized,
            #                                                   encoded_style=encoded_style_modified)
        out = OrderedDict([
            ("input_label", input_semantics),
            ("image_downsized", image_lr),
            ("fake_image_original", fake_image),
            # ("fake_image_modified", fake_image_modified),
            ("image_full", image_hr)
        ])
        if self.opt.guiding_style_image:
            out['guiding_image_id'] = data['guiding_image_id']
            out['guiding_image'] = data['guiding_image']
            out['guiding_input_label'] = data['guiding_label']
        return out
    elif mode == "inference_particular_full":
        with torch.no_grad():
            region_idx = self.opt.region_idx if self.opt.region_idx else list(
                range(input_semantics.size(1)))
            encoded_style_full, _ = self.encode_style(
                input_semantics=None,
                downscaled_image=None,
                no_noise=True,
                encode_full=True,
                guiding_image=image_hr,
                guiding_label=input_semantics)
            fake_image_original, _, _ = self.generate_fake(
                input_semantics=input_semantics,
                image_downsized=image_lr,
                encoded_style=encoded_style_full)
        out = OrderedDict([("input_label", input_semantics),
                           ("image_downsized", image_lr),
                           ("fake_image_original", fake_image_original),
                           ("image_full", image_hr)])
        if self.opt.guiding_style_image:
            guiding_style, _ = self.encode_style(
                input_semantics=None,
                downscaled_image=None,
                no_noise=True,
                encode_full=True,
                guiding_image=guiding_image,
                guiding_label=guiding_label)
            # The fake image produced with a guiding style
            fake_image_guiding, _, _ = self.generate_fake(
                input_semantics=input_semantics,
                image_downsized=image_lr,
                encoded_style=guiding_style)
            out["fake_image_guiding"] = fake_image_guiding
            out['guiding_image_id'] = data['guiding_image_id']
            out['guiding_image'] = data['guiding_image']
            out['guiding_input_label'] = data['guiding_label']
        return out
    elif mode == "inference_reference":
        with torch.no_grad():
            # encoded_style_mini, _, _, _ = self.encode_style(downscaled_image=image_downsized,
            #                                                 input_semantics=input_semantics, full_image=None,
            #                                                 no_noise=True)
            encoded_style_full, _ = self.encode_style(
                downscaled_image=None,
                input_semantics=input_semantics,
                full_image=image_hr,
                no_noise=True,
                encode_full=True,
                guiding_image=guiding_image,
                guiding_label=guiding_label)
            region_idx = self.opt.region_idx if self.opt.region_idx else list(
                range(input_semantics.size(1)))
            fake_images = list()
            for b in range(self.opt.batchSize):
                fake_samples = list()
                for semantics_b in range(self.opt.batchSize):
                    encoded_style_in = encoded_style_full[b].clone().detach()
                    encoded_style_in[region_idx] = (
                        encoded_style_full[semantics_b, region_idx]).clamp(-1, 1)
                    fake_image, _, _ = self.generate_fake(
                        input_semantics=input_semantics[b].unsqueeze(0),
                        image_downsized=image_lr[b].unsqueeze(0),
                        encoded_style=encoded_style_in.unsqueeze(0))
                    fake_samples.append(fake_image)
                fake_images.append(torch.cat(fake_samples, -1))
            fake_out = torch.cat(fake_images, 0)
        out = OrderedDict([("input_label", input_semantics),
                           ("image_downsized", image_lr),
                           ("fake_image", fake_out),
                           ("image_full", image_hr)])
        if self.opt.guiding_style_image:
            out['guiding_image_id'] = data['guiding_image_id']
            out['guiding_image'] = data['guiding_image']
            out['guiding_input_label'] = data['guiding_label']
        return out
    elif mode == "inference_reference_interpolation":
        with torch.no_grad():
            # encoded_style_mini, _, _, _ = self.encode_style(downscaled_image=image_downsized,
            #                                                 input_semantics=input_semantics, full_image=None,
            #                                                 no_noise=True)
            encoded_style_full, _ = self.encode_style(
                downscaled_image=None,
                input_semantics=input_semantics,
                full_image=image_hr,
                no_noise=True,
                encode_full=True)
            region_idx = self.opt.region_idx if self.opt.region_idx else list(
                range(input_semantics.size(1)))
            fake_images = list()
            for b in range(self.opt.batchSize):
                fake_samples = list()
                idx_style_a = b % self.opt.batchSize
                style_a = encoded_style_full[idx_style_a].clone().detach()
                idx_style_b = (b + 1) % self.opt.batchSize
                semantics_b = encoded_style_full[idx_style_b].clone().detach() \
                    * self.opt.manipulate_scale
                for delta_step in np.linspace(0, 1, num=self.opt.n_interpolation):
                    # Clone so every step interpolates from the original style_a
                    encoded_style_in = style_a.clone()
                    encoded_style_in[region_idx] = (
                        (1 - delta_step) * style_a[region_idx]
                        + delta_step * semantics_b[region_idx]).clamp(-1, 1)
                    fake_image, _, _ = self.generate_fake(
                        input_semantics=input_semantics[b].unsqueeze(0),
                        image_downsized=image_lr[b].unsqueeze(0),
                        encoded_style=encoded_style_in.unsqueeze(0))
                    fake_samples.append(fake_image)
                fake_images.append(torch.cat(fake_samples, -1))
            fake_out = torch.cat(fake_images, 0)
        out = OrderedDict([("input_label", input_semantics),
                           ("image_downsized", image_lr),
                           ("fake_image", fake_out),
                           ("image_full", image_hr)])
        if self.opt.guiding_style_image:
            out['guiding_image_id'] = data['guiding_image_id']
            out['guiding_image'] = data['guiding_image']
            out['guiding_input_label'] = data['guiding_label']
        return out
    else:
        raise ValueError("|mode| is invalid")
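# Usage sketch (not part of the original source): assembling the dict-based
# interface expected by the forward() above for plain inference. The key names
# mirror the data.get(...) calls at the top of forward(); the function name and
# the tensor shapes in the comments are assumptions for illustration.
def run_inference(model, input_semantics, image_lr, image_hr=None):
    data = {
        "input_semantics": input_semantics,  # one-hot label map, e.g. (B, label_nc, H, W)
        "image_lr": image_lr,                # low-resolution input image
        "image_hr": image_hr,                # optional; used when the style is encoded from the full image
    }
    out = model(data, mode='inference')      # returns the input dict extended with "fake_image"
    return out["fake_image"]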