def stylize(args):
    """Stylize a single content image with a trained model.

    Loads ``args.content_image``, runs it through either an ONNX model
    (dispatched to ``stylize_onnx_caffe2``) or a ``TransformerNet``
    checkpoint, and writes the result to ``args.output_image``.
    Optionally exports the PyTorch model to ONNX via ``args.export_onnx``.
    """
    device = torch.device("cuda" if args.cuda else "cpu")

    content_image = utils.load_image(args.content_image, scale=args.content_scale)
    content_transform = transforms.Compose([
        transforms.ToTensor(),
        # the network expects 0-255 pixel values, not ToTensor's 0-1 range
        transforms.Lambda(lambda x: x.mul(255))
    ])
    content_image = content_transform(content_image)
    content_image = content_image.unsqueeze(0).to(device)

    if args.model.endswith(".onnx"):
        output = stylize_onnx_caffe2(content_image, args)
    else:
        with torch.no_grad():
            style_model = TransformerNet()
            print(style_model)
            # map_location lets a checkpoint saved on GPU load on a CPU-only host
            state_dict = torch.load(args.model, map_location=device)
            # remove saved deprecated running_* keys in InstanceNorm from the checkpoint
            for k in list(state_dict.keys()):
                if re.search(r'in\d+\.running_(mean|var)$', k):
                    del state_dict[k]
            style_model.load_state_dict(state_dict)
            style_model.to(device)
            for name, param in style_model.named_parameters():
                if param.requires_grad:
                    print(name)
            if args.export_onnx:
                assert args.export_onnx.endswith(".onnx"), "Export model file should end with .onnx"
                # NOTE(review): torch.onnx._export is a private API; consider torch.onnx.export
                output = torch.onnx._export(style_model, content_image, args.export_onnx, verbose=False)
            else:
                output = style_model(content_image).cpu()
    utils.save_image(args.output_image, output[0])
def train(args):
    """Meta train the model.

    Runs MAML-style meta-training: for each outer iteration, samples
    ``args.meta_batch_size`` styles, adapts only the InstanceNorm
    parameters for ``args.meta_step`` inner steps on content images,
    evaluates the adapted weights on query images, and applies the
    averaged meta-gradient to the transformer. Checkpoints periodically
    and saves the final model under ``args.save_model_dir``.
    """
    device = torch.device("cuda" if args.cuda else "cpu")

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    # first move parameters to GPU
    transformer = TransformerNet().to(device)
    vgg = Vgg16(requires_grad=False).to(device)
    # globals so helper functions (e.g. meta_updates/loss_fn) can reach them
    global optimizer
    optimizer = Adam(transformer.parameters(), args.meta_lr)
    global mse_loss
    mse_loss = torch.nn.MSELoss()

    content_loader, style_loader, query_loader = get_data_loader(args)

    content_weight = args.content_weight
    style_weight = args.style_weight
    lr = args.lr

    writer = SummaryWriter(args.log_dir)

    for iteration in trange(args.max_iter):
        transformer.train()
        # bookkeeping
        # using state_dict causes problems, use named_parameters instead
        all_meta_grads = []
        avg_train_c_loss = 0.0
        avg_train_s_loss = 0.0
        avg_train_loss = 0.0
        avg_eval_c_loss = 0.0
        avg_eval_s_loss = 0.0
        avg_eval_loss = 0.0
        # next(iterator) works on Python 2 and 3; .next() is Python-2-only
        contents = next(content_loader)[0].to(device)
        features_contents = vgg(utils.normalize_batch(contents))
        querys = next(query_loader)[0].to(device)
        features_querys = vgg(utils.normalize_batch(querys))

        # learning rate scheduling: inverse decay for both inner and outer lr
        lr = args.lr / (1.0 + iteration * 2.5e-5)
        meta_lr = args.meta_lr / (1.0 + iteration * 2.5e-5)
        for param_group in optimizer.param_groups:
            param_group['lr'] = meta_lr

        for i in range(args.meta_batch_size):
            # sample a style
            style = next(style_loader)[0].to(device)
            style = style.repeat(args.iter_batch_size, 1, 1, 1)
            features_style = vgg(utils.normalize_batch(style))
            gram_style = [utils.gram_matrix(y) for y in features_style]
            # inner loop adapts only the InstanceNorm parameters ("inN.")
            fast_weights = OrderedDict((name, param)
                                       for (name, param) in transformer.named_parameters()
                                       if re.search(r'in\d+\.', name))
            for j in range(args.meta_step):
                # run forward transformation on contents
                transformed = transformer(contents, fast_weights)
                # compute loss
                features_transformed = vgg(utils.standardize_batch(transformed))
                loss, c_loss, s_loss = loss_fn(features_transformed, features_contents,
                                               gram_style, content_weight, style_weight)
                # compute grad; create_graph=True keeps the graph for second-order terms
                grads = torch.autograd.grad(loss, fast_weights.values(), create_graph=True)
                # update fast weights
                fast_weights = OrderedDict((name, param - lr * grad)
                                           for ((name, param), grad)
                                           in zip(fast_weights.items(), grads))
                avg_train_c_loss += c_loss.item()
                avg_train_s_loss += s_loss.item()
                avg_train_loss += loss.item()

            # run forward transformation on querys
            transformed = transformer(querys, fast_weights)
            # compute loss
            features_transformed = vgg(utils.standardize_batch(transformed))
            loss, c_loss, s_loss = loss_fn(features_transformed, features_querys,
                                           gram_style, content_weight, style_weight)
            # meta-gradient w.r.t. the slow weights, pre-averaged over the meta batch
            grads = torch.autograd.grad(loss / args.meta_batch_size, transformer.parameters())
            all_meta_grads.append({name: g for ((name, _), g)
                                   in zip(transformer.named_parameters(), grads)})
            avg_eval_c_loss += c_loss.item()
            avg_eval_s_loss += s_loss.item()
            avg_eval_loss += loss.item()

        writer.add_scalar("Avg_Train_C_Loss", avg_train_c_loss / args.meta_batch_size, iteration + 1)
        writer.add_scalar("Avg_Train_S_Loss", avg_train_s_loss / args.meta_batch_size, iteration + 1)
        writer.add_scalar("Avg_Train_Loss", avg_train_loss / args.meta_batch_size, iteration + 1)
        writer.add_scalar("Avg_Eval_C_Loss", avg_eval_c_loss / args.meta_batch_size, iteration + 1)
        writer.add_scalar("Avg_Eval_S_Loss", avg_eval_s_loss / args.meta_batch_size, iteration + 1)
        writer.add_scalar("Avg_Eval_Loss", avg_eval_loss / args.meta_batch_size, iteration + 1)

        # compute dummy loss to refresh buffer
        transformed = transformer(querys)
        features_transformed = vgg(utils.standardize_batch(transformed))
        dummy_loss, _, _ = loss_fn(features_transformed, features_querys,
                                   gram_style, content_weight, style_weight)
        meta_updates(transformer, dummy_loss, all_meta_grads)

        if args.checkpoint_model_dir is not None and (iteration + 1) % args.checkpoint_interval == 0:
            transformer.eval().cpu()
            ckpt_model_filename = "iter_" + str(iteration + 1) + ".pth"
            ckpt_model_path = os.path.join(args.checkpoint_model_dir, ckpt_model_filename)
            torch.save(transformer.state_dict(), ckpt_model_path)
            transformer.to(device).train()

    # save model
    transformer.eval().cpu()
    save_model_filename = "Final_iter_" + str(args.max_iter) + "_" + \
        str(args.content_weight) + "_" + \
        str(args.style_weight) + "_" + \
        str(args.lr) + "_" + \
        str(args.meta_lr) + "_" + \
        str(args.meta_batch_size) + "_" + \
        str(args.meta_step) + "_" + \
        time.ctime() + ".pth"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)
    # print() function form: the original Py2 "print ..." statement is a SyntaxError on Python 3
    print("Done, trained model saved at {}".format(save_model_path))
def fast_train(args):
    """Fast training.

    Fine-tunes a (optionally pre-trained, via ``args.model``)
    ``TransformerNet`` on a single style image for ``args.update_step``
    iterations, then saves the weights as ``<style_name>.pth`` in
    ``args.save_model_dir``. When ``args.only_in`` is set, only
    parameters whose names contain "in" (the InstanceNorm layers)
    are optimized.
    """
    device = torch.device("cuda" if args.cuda else "cpu")

    transformer = TransformerNet().to(device)
    if args.model:
        transformer.load_state_dict(torch.load(args.model))
    vgg = Vgg16(requires_grad=False).to(device)
    global mse_loss
    mse_loss = torch.nn.MSELoss()

    content_weight = args.content_weight
    style_weight = args.style_weight
    lr = args.lr

    # content pipeline: resize/crop and rescale pixels to 0-255
    content_transform = transforms.Compose([
        transforms.Resize(args.content_size),
        transforms.CenterCrop(args.content_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))])
    content_dataset = datasets.ImageFolder(args.content_dataset, content_transform)
    content_loader = DataLoader(content_dataset, batch_size=args.iter_batch_size,
                                sampler=InfiniteSamplerWrapper(content_dataset),
                                num_workers=args.n_workers)
    content_loader = iter(content_loader)

    style_transform = transforms.Compose([
        transforms.Resize((args.style_size, args.style_size)),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))])
    style_image = utils.load_image(args.style_image)
    style_image = style_transform(style_image)
    style_image = style_image.unsqueeze(0).to(device)
    # style targets are fixed, so compute the Gram matrices once up front
    features_style = vgg(utils.normalize_batch(style_image.repeat(args.iter_batch_size, 1, 1, 1)))
    gram_style = [utils.gram_matrix(y) for y in features_style]

    if args.only_in:
        optimizer = Adam([param for (name, param) in transformer.named_parameters()
                          if "in" in name], lr=lr)
    else:
        optimizer = Adam(transformer.parameters(), lr=lr)

    for i in trange(args.update_step):
        # next(iterator) works on Python 2 and 3; .next() is Python-2-only
        contents = next(content_loader)[0].to(device)
        features_contents = vgg(utils.normalize_batch(contents))
        transformed = transformer(contents)
        features_transformed = vgg(utils.standardize_batch(transformed))
        loss, c_loss, s_loss = loss_fn(features_transformed, features_contents,
                                       gram_style, content_weight, style_weight)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # save model
    transformer.eval().cpu()
    style_name = os.path.basename(args.style_image).split(".")[0]
    save_model_filename = style_name + ".pth"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)
def train(args):
    """Train a TransformerNet with Horovod data parallelism.

    Partitions the dataset across Horovod workers, trains for
    ``args.epochs`` epochs with perceptual (content + style) losses
    against a fixed VGG16, logs metrics to the Azure ML ``run``, and
    on rank 0 checkpoints periodically and saves the final model
    (optionally exported to ONNX).
    """
    device = torch.device("cuda" if args.cuda else "cpu")

    # log content and style weight parameters
    # (np.float was removed in NumPy 1.24; the builtin float is the replacement)
    if hvd.rank() == 0:
        run.log('content_weight', float(args.content_weight))
        run.log('style_weight', float(args.style_weight))

    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        # the network expects 0-255 pixel values
        transforms.Lambda(lambda x: x.mul(255))
    ])
    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    # Horovod: partition dataset among workers using DistributedSampler
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=hvd.size(), rank=hvd.rank())
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              sampler=train_sampler,
                              **kwargs)

    transformer = TransformerNet().to(device)
    # Horovod: broadcast parameters from rank 0 to all other processes
    hvd.broadcast_parameters(transformer.state_dict(), root_rank=0)
    # Horovod: scale learning rate by the number of GPUs
    optimizer = Adam(transformer.parameters(), args.lr * hvd.size())
    # Horovod: wrap optimizer with DistributedOptimizer
    optimizer = hvd.DistributedOptimizer(
        optimizer, named_parameters=transformer.named_parameters())
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16(requires_grad=False).to(device)
    style_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Lambda(lambda x: x.mul(255))])
    style = utils.load_image(args.style_image, size=args.style_size)
    style = style_transform(style)
    style = style.repeat(args.batch_size, 1, 1, 1).to(device)

    # style targets are fixed, so compute the Gram matrices once up front
    features_style = vgg(utils.normalize_batch(style))
    gram_style = [utils.gram_matrix(y) for y in features_style]

    print("starting training...")
    for e in range(args.epochs):
        print("epoch {}...".format(e))
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()

            x = x.to(device)
            y = transformer(x)

            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)

            features_y = vgg(y)
            features_x = vgg(x)

            content_loss = args.content_weight * mse_loss(
                features_y.relu2_2, features_x.relu2_2)

            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gm_y = utils.gram_matrix(ft_y)
                # slice gram target in case the last batch is smaller
                style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
            style_loss *= args.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.item()
            agg_style_loss += style_loss.item()

            if (batch_id + 1) % args.log_interval == 0:
                avg_content_loss = agg_content_loss / (batch_id + 1)
                avg_style_loss = agg_style_loss / (batch_id + 1)
                avg_total_loss = (agg_content_loss + agg_style_loss) / (batch_id + 1)
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_sampler),
                    avg_content_loss, avg_style_loss, avg_total_loss)
                print(mesg)
                # log the losses the run history
                # NOTE(review): logged from every rank, unlike the weight params above — confirm intended
                run.log('avg_content_loss', float(avg_content_loss))
                run.log('avg_style_loss', float(avg_style_loss))
                run.log('avg_total_loss', float(avg_total_loss))

            if hvd.rank() == 0 and args.checkpoint_model_dir is not None and (
                    batch_id + 1) % args.checkpoint_interval == 0:
                transformer.eval().cpu()
                ckpt_model_filename = "ckpt_epoch_" + str(
                    e) + "_batch_id_" + str(batch_id + 1) + ".pth"
                ckpt_model_path = os.path.join(args.checkpoint_model_dir, ckpt_model_filename)
                torch.save(transformer.state_dict(), ckpt_model_path)
                transformer.to(device).train()

    # save model
    if hvd.rank() == 0:
        transformer.eval().cpu()
        if args.export_to_onnx:
            # export model to ONNX format
            dummy_input = torch.randn(1, 3, 1024, 1024, device='cpu')
            save_model_path = os.path.join(args.save_model_dir, '{}.onnx'.format(args.model_name))
            torch.onnx.export(transformer, dummy_input, save_model_path)
        else:
            save_model_path = os.path.join(args.save_model_dir, '{}.pth'.format(args.model_name))
            torch.save(transformer.state_dict(), save_model_path)

        print("\nDone, trained model saved at", save_model_path)