def network_test(args):
    """Stylize a single content image with a trained transform network.

    Args:
        args: parsed command-line namespace. Attributes read directly here are
            ``cuda_device_no`` (a negative value selects the CPU),
            ``test_content`` (input image path), ``imsize`` and ``output``
            (destination path); ``load_transform_network`` may read more.

    Returns:
        None. The stylized image is written to ``args.output`` via ``imsave``.
    """
    device = torch.device("cuda" if args.cuda_device_no >= 0 else 'cpu')

    transform_network = load_transform_network(args).to(device)
    # NOTE(review): the network is not switched to eval() here; confirm that
    # load_transform_network does so if the network contains BatchNorm/Dropout.
    input_image = imload(args.test_content, args.imsize).to(device)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        output_image = transform_network(input_image)

    # Move the result to host memory before saving, matching how
    # network_train hands images to imsave.
    imsave(output_image.cpu(), args.output)
    return None
return input_image if __name__ == '__main__': # get arguments parser = build_parser() args = parser.parse_args() # gpu device set if args.cuda_device_no >= 0: os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device_no) device = torch.device('cuda' if args.cuda_device_no >= 0 else 'cpu') # load target images content_image = imload(args.target_content_filename, args.imsize, args.cropsize) content_image = content_image.to(device) style_image = imload(args.target_style_filename, args.imsize, args.cropsize) style_image = style_image.to(device) # load pre-trianed vgg vgg = get_vgg_feature_network(args.vgg_flag) vgg = vgg.to(device) # stylize image output_image = stylize_image(vgg=vgg, device=device, content_image=content_image, style_image=style_image,
def _load_style_batch(path, imsize, batch_size, device):
    """Load the style image at ``path`` and expand it to ``batch_size`` copies.

    ``Tensor.expand`` returns a broadcast view, so the copies share memory.
    """
    style_image = imload(path, imsize=imsize).to(device)
    _, c, h, w = style_image.size()
    return style_image.expand(batch_size, c, h, w)


def network_train(args):
    """Train a feed-forward style-transfer ``TransformNetwork``.

    Each iteration draws a random batch of content images, runs it through the
    transform network and minimises a weighted sum of content, Gram (style)
    and total-variation losses, measured on features from a frozen pretrained
    VGG network.

    Args:
        args: parsed command-line namespace. Attributes read here:
            ``cuda_device_no`` (negative selects CPU), ``train_content``,
            ``imsize``, ``cropsize``, ``vgg_flag`` (torchvision model name),
            ``lr``, ``train_style`` (single style image path, or ``None``),
            ``train_style_folder`` (used when ``train_style`` is ``None``),
            ``batchs`` (batch size), ``max_iter``, ``content_layers``,
            ``style_layers``, ``content_weight``, ``style_weight``,
            ``tv_weight``, ``check_iter`` and ``save_path`` (prefix that the
            output file names are appended to).

    Returns:
        The trained transform network, on the selected device.

    Raises:
        ValueError: if ``train_style`` is not given and ``train_style_folder``
            is missing or empty.
    """
    device = torch.device("cuda" if args.cuda_device_no >= 0 else 'cpu')

    # Transform network being trained.
    transform_network = TransformNetwork().to(device)

    # Content data set. The loader is built once; taking a fresh iterator each
    # step reshuffles and yields a new random batch, as before.
    train_dataset = ImageFolder(args.train_content, get_transformer(args.imsize, args.cropsize))
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batchs, shuffle=True)

    # Loss network: pretrained VGG features used purely as a fixed feature
    # extractor. eval() keeps any normalization layers (e.g. *_bn variants) on
    # their running statistics. Freezing the weights stops autograd from
    # computing gradients for parameters the optimizer never updates;
    # gradients still flow through the network back to ``output_image``.
    loss_network = torchvision.models.__dict__[args.vgg_flag](pretrained=True).features.to(device)
    loss_network.eval()
    for param in loss_network.parameters():
        param.requires_grad_(False)

    # Optimizer over the transform network's parameters only.
    optimizer = torch.optim.Adam(params=transform_network.parameters(), lr=args.lr)

    # Target style: one fixed image, or a random image from a folder each step.
    single_style_img = args.train_style is not None
    if single_style_img:
        print("args.train_style=", args.train_style)
        target_style_image = _load_style_batch(args.train_style, args.imsize, args.batchs, device)
        # The target never changes, so its features are computed once.
        with torch.no_grad():
            target_style_features = extract_features(loss_network, target_style_image, args.style_layers)
    else:
        if not args.train_style_folder:
            raise ValueError("either train_style or train_style_folder must be given")
        img_names = os.listdir(args.train_style_folder)
        if not img_names:
            raise ValueError("style folder %r is empty" % args.train_style_folder)

    # Train
    loss_logs = {'content_loss': [], 'style_loss': [], 'tv_loss': [], 'total_loss': []}
    for iteration in range(args.max_iter):
        if not single_style_img:
            target_img_path = os.path.join(args.train_style_folder, random.choice(img_names))
            target_style_image = _load_style_batch(target_img_path, args.imsize, args.batchs, device)
            with torch.no_grad():
                target_style_features = extract_features(loss_network, target_style_image, args.style_layers)

        # NOTE(review): assumes ImageFolder yields bare image tensors rather
        # than (image, label) pairs — confirm against its definition.
        image = next(iter(train_dataloader)).to(device)
        output_image = transform_network(image)

        # Targets need no gradient; only the output features do.
        with torch.no_grad():
            target_content_features = extract_features(loss_network, image, args.content_layers)
        output_content_features = extract_features(loss_network, output_image, args.content_layers)
        output_style_features = extract_features(loss_network, output_image, args.style_layers)

        content_loss = calc_Content_Loss(output_content_features, target_content_features)
        style_loss = calc_Gram_Loss(output_style_features, target_style_features)
        tv_loss = calc_TV_Loss(output_image)
        total_loss = content_loss * args.content_weight + style_loss * args.style_weight + tv_loss * args.tv_weight

        loss_logs['content_loss'].append(content_loss.item())
        loss_logs['style_loss'].append(style_loss.item())
        loss_logs['tv_loss'].append(tv_loss.item())
        loss_logs['total_loss'].append(total_loss.item())

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        # Print loss logs and checkpoint periodically.
        if iteration % args.check_iter == 0:
            str_ = '%s: iteration: [%d/%d],\t' % (time.ctime(), iteration, args.max_iter)
            for key, value in loss_logs.items():
                # Mean of the most recent (up to) 100 values; dividing by the
                # actual count avoids under-reporting early in training.
                recent = value[-100:]
                str_ += '%s: %2.2f,\t' % (key, sum(recent) / len(recent))
            print(str_)

            imsave(output_image.cpu(), args.save_path + "training_images.png")
            torch.save(transform_network.state_dict(), args.save_path + "transform_network.pth")

    # Save train results.
    torch.save(loss_logs, args.save_path + "loss_logs.pth")
    torch.save(transform_network.state_dict(), args.save_path + "transform_network.pth")

    return transform_network