def decode(self, content, model_code):
    """Decode a content code into an image, conditioned on a model code.

    ``model_code`` is passed through ``self.mlp``; the result is handed to
    ``assign_adain_params`` together with ``self.dec`` before ``self.dec`` is
    run on ``content``.

    Returns whatever ``self.dec(content)`` produces.
    """
    DebugNet.setName("FewShotGen_Decode")
    # The decoder must be configured from the model code before it is run.
    assign_adain_params(self.mlp(model_code), self.dec)
    return self.dec(content)
def forward(self, one_image, model_set):
    """Reconstruct an image from ``one_image`` and the images in ``model_set``.

    The inputs are encoded with ``self.encode``; the returned model codes are
    averaged over dim 0 (keeping that dim with size 1) and the result is
    decoded together with the content code via ``self.decode``.
    """
    DebugNet.setName("FewShotGen")
    content_code, codes = self.encode(one_image, model_set)
    # keepdim=True yields the same tensor as mean(dim=0) followed by unsqueeze(0).
    averaged_code = torch.mean(codes, dim=0, keepdim=True)
    return self.decode(content_code, averaged_code)
def useFUNIT(class_img_folder, content_img_pth, output_path, trainer, transform):
    """Translate one content image with a class code and save the result.

    Args:
        class_img_folder: Folder whose direct files are the class images.
            Only the first file listed by ``os.walk`` is actually encoded.
        content_img_pth: Path of the content image to translate.
        output_path: Destination passed to ``imsave``.
        trainer: Object exposing ``model.gen_test.enc_class_model`` and
            ``model.translate_simple``.
        transform: Callable applied to each loaded image before it is
            batched with ``unsqueeze(0)``.

    Raises:
        IndexError: If ``class_img_folder`` contains no files.
    """
    # Only files directly inside the folder are considered (first os.walk entry).
    file_names = next(os.walk(class_img_folder))[2]
    class_img_paths = [os.path.join(class_img_folder, name) for name in file_names]
    if not class_img_paths:
        # Same exception type as the bare ``imgPths[0]`` lookup, but with context.
        raise IndexError("no class images found in %r" % (class_img_folder,))

    # NOTE(review): os.walk order is filesystem dependent, so "first image" is
    # not guaranteed to be stable across machines -- confirm this is intended.
    class_img = transform(default_loader_custom(class_img_paths[0]))
    # Inference only: do not build an autograd graph for the class encoder.
    with torch.no_grad():
        final_class_code = trainer.model.gen_test.enc_class_model(
            class_img.unsqueeze(0).cuda())

    DebugNet.setName("input")
    content_img = transform(default_loader_custom(content_img_pth)).unsqueeze(0)
    with torch.no_grad():
        output_image = trainer.model.translate_simple(content_img, final_class_code)

    image = output_image.detach().cpu().squeeze().numpy()
    # Presumably the generator output lies in [-1, 1]; this rescales it to
    # [0, 255] -- TODO confirm the output range.
    image = (image + 1) * 0.5 * 255.0
    if image.ndim == 3:
        # Move the leading axis last before saving (assumes CHW -> HWC).
        image = np.transpose(image, (1, 2, 0))
    imsave(output_path, image)
    print('Save output to %s' % output_path)
def encode(self, one_image, model_set):
    """Encode the input image and the model set.

    Returns:
        A tuple ``(content, class_code)`` where ``content`` is
        ``self.enc_content(one_image)`` and ``class_code`` is the mean over
        dim 0 of ``self.enc_class_model(model_set)``, with a leading dim of
        size 1 added back.
    """
    DebugNet.setName("FewShotGen_Encode")
    # Content code comes from the single input image.
    content_code = self.enc_content(one_image)
    # Model code: encode every image in the set, then average over dim 0.
    per_image_codes = self.enc_class_model(model_set)
    mean_code = torch.mean(per_image_codes, dim=0)
    return content_code, mean_code.unsqueeze(0)
def forward(self, x, y):
    """Run the discriminator and pick, per sample, the map selected by ``y``.

    ``x`` is passed through ``self.cnn_f`` to obtain ``feat`` and then through
    ``self.cnn_c``. For sample ``i`` the slice ``out[i, y[i], :, :]`` is kept.
    Presumably ``y`` holds per-sample class labels -- confirm against callers.

    Returns:
        A tuple ``(out, feat)``: the selected maps and the ``self.cnn_f``
        features.

    Raises:
        AssertionError: If ``x`` and ``y`` differ in size along dim 0.
    """
    DebugNet.setName("GPPatchMcResDis")
    assert x.size(0) == y.size(0)
    feat = self.cnn_f(x)
    out = self.cnn_c(feat)
    # Build the sample index on the same device as ``out`` rather than
    # hard-coding ``.cuda()``, so this also works on CPU or a non-default GPU.
    index = torch.arange(out.size(0), device=out.device)
    out = out[index, y, :, :]
    return out, feat
for i, f in enumerate(imgPths): img = default_loader_custom(f) img_tensor = transform(img).unsqueeze(0).cuda() with torch.no_grad(): class_code = trainer.model.compute_k_style(img_tensor, 1) if i == 0: new_class_code = class_code else: new_class_code += class_code final_class_code = new_class_code / len(imgPths) """ final_class_code = trainer.model.gen_test.enc_class_model( transform(default_loader_custom(imgPths[0])).unsqueeze(0).cuda()) print("Shape: ", final_class_code.shape) DebugNet.setName("input") image = default_loader_custom(opts.input) content_img = transform(image) #DebugNet.safeImage(content_img) content_img = content_img.unsqueeze(0) print('Compute translation for %s' % opts.input) with torch.no_grad(): output_image = trainer.model.translate_simple(content_img, final_class_code) image = output_image.detach().cpu().squeeze().numpy() print("Image has shape: ", image.shape) print("MIN: ", image.min()) print("MAX: ", image.max()) #image = np.transpose(image, (1, 2, 0)) image = ((image + 1) * 0.5 * 255.0)
def forward(self, x):
    """Flatten every sample of ``x`` and feed the result to ``self.model``.

    ``x`` is viewed as ``(x.size(0), -1)``, i.e. dim 0 is kept and all
    remaining dims are collapsed into one.
    """
    DebugNet.setName("MLP")
    batch_size = x.size(0)
    flattened = x.view(batch_size, -1)
    return self.model(flattened)
def forward(self, x):
    """Apply ``self.model`` to ``x`` and return its output unchanged."""
    DebugNet.setName("Decoder")
    decoded = self.model(x)
    return decoded