class CyclicPhotoWCT(nn.Module):
    """Pair of PhotoWCT models run as a forward/backward (cyclic) stylization.

    ``fw`` stylizes the content image with the style image. ``bw`` then maps
    the stylized result back, using the original content image as the style
    reference.
    """

    def __init__(self):
        super(CyclicPhotoWCT, self).__init__()
        self.fw = PhotoWCT()  # forward pass: content stylized by style
        self.bw = PhotoWCT()  # backward pass: stylized result restyled by content

    def transform(self, cont_img, styl_img, cont_seg, styl_seg):
        """Stylize ``cont_img`` with ``styl_img``, then map the result back.

        Args:
            cont_img: content image handed to ``PhotoWCT.transform``
                (presumably a (1, C, H, W) tensor -- TODO confirm).
            styl_img: style image, same convention as ``cont_img``.
            cont_seg: segmentation for ``cont_img`` (presumably a label array,
                or an empty list/array when segmentation is unused -- verify
                against callers).
            styl_seg: segmentation for ``styl_img``, same convention.

        Returns:
            Tuple ``(stylized_img, reversed_img)``: the forward-pass output and
            the backward-pass output computed from it.
        """
        stylized_img = self.fw.transform(cont_img, styl_img, cont_seg, styl_seg)
        # Backward pass: the content image is ``stylized_img``, which keeps the
        # layout of ``cont_img``. The style reference is ``cont_img`` itself.
        # Both segmentation slots must therefore describe ``cont_img``, i.e.
        # ``cont_seg``. ``styl_seg`` belongs to a different image and must not
        # be used here.
        reversed_img = self.bw.transform(stylized_img, cont_img, cont_seg, cont_seg)
        return stylized_img, reversed_img

    def forward(self, *inputs):
        """No-op placeholder for ``nn.Module.forward``; use ``transform``.

        Accepts any positional arguments and always returns ``None``.
        """
        return None
cont_seg = Image.open(content_seg_path) styl_seg = Image.open(style_seg_path) except: cont_seg = [] styl_seg = [] cont_img = transforms.ToTensor()(cont_img).unsqueeze(0) styl_img = transforms.ToTensor()(styl_img).unsqueeze(0) cont_img = Variable(cont_img.cuda(0), volatile=True) styl_img = Variable(styl_img.cuda(0), volatile=True) cont_seg = np.asarray(cont_seg) styl_seg = np.asarray(styl_seg) start_style_time = time.time() stylized_img = p_wct.transform(cont_img, styl_img, cont_seg, styl_seg) end_style_time = time.time() print('Elapsed time in stylization: %f' % (end_style_time - start_style_time)) utils.save_image(stylized_img.data.cpu().float(), output_image_path, nrow=1) start_propagation_time = time.time() out_img = p_pro.process(output_image_path, content_image_path) end_propagation_time = time.time() print('Elapsed time in propagation: %f' % (end_propagation_time - start_propagation_time)) cv2.imwrite(output_image_path, out_img) start_postprocessing_time = time.time() out_img = smooth_filter(output_image_path, content_image_path, f_radius=15,