def get_data_loader(args):
    content_transform = transforms.Compose([
        transforms.Resize(args.content_size),
        transforms.CenterCrop(args.content_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))])
    style_transform = transforms.Compose([
        transforms.Resize((args.style_size, args.style_size)),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))])
    content_dataset = datasets.ImageFolder(args.content_dataset, content_transform)
    style_dataset = datasets.ImageFolder(args.style_dataset, style_transform)
    content_loader = DataLoader(content_dataset, batch_size=args.iter_batch_size,
                                sampler=InfiniteSamplerWrapper(content_dataset),
                                num_workers=args.n_workers)
    style_loader = DataLoader(style_dataset, batch_size=1,
                              sampler=InfiniteSamplerWrapper(style_dataset),
                              num_workers=args.n_workers)
    query_loader = DataLoader(content_dataset, batch_size=args.iter_batch_size,
                              sampler=InfiniteSamplerWrapper(content_dataset),
                              num_workers=args.n_workers)
    return iter(content_loader), iter(style_loader), iter(query_loader)
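# All of the loaders above rely on an `InfiniteSamplerWrapper` that is not
# defined in this file. A minimal sketch, assuming the widely used
# pytorch-AdaIN convention of an endless, reshuffling permutation sampler:
import numpy as np
from torch.utils import data


def InfiniteSampler(n):
    # Yield indices forever, reshuffling after each full pass over the dataset.
    i = n - 1
    order = np.random.permutation(n)
    while True:
        yield order[i]
        i += 1
        if i >= n:
            np.random.seed()
            order = np.random.permutation(n)
            i = 0


class InfiniteSamplerWrapper(data.sampler.Sampler):
    def __init__(self, data_source):
        self.num_samples = len(data_source)

    def __iter__(self):
        return iter(InfiniteSampler(self.num_samples))

    def __len__(self):
        return 2 ** 31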
def load_dataset(content_dir, style_dir):
    # Note: reads `args` (batch_size, n_threads) from module scope.
    content_tf = train_transform()
    style_tf = train_transform()
    content_dataset = FlatFolderDataset(content_dir, content_tf)
    style_dataset = FlatFolderDataset(style_dir, style_tf)
    content_iter = iter(data.DataLoader(
        content_dataset, batch_size=args.batch_size,
        sampler=InfiniteSamplerWrapper(content_dataset),
        num_workers=args.n_threads))
    style_iter = iter(data.DataLoader(
        style_dataset, batch_size=args.batch_size,
        sampler=InfiniteSamplerWrapper(style_dataset),
        num_workers=args.n_threads))
    return content_iter, style_iter
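# `train_transform` and `FlatFolderDataset` are assumed helpers. A sketch
# following the pytorch-AdaIN conventions (resize to 512, random 256 crop;
# a flat image directory with no class subfolders):
from pathlib import Path
from PIL import Image
from torch.utils import data
from torchvision import transforms


def train_transform():
    return transforms.Compose([
        transforms.Resize(size=(512, 512)),
        transforms.RandomCrop(256),
        transforms.ToTensor()])


class FlatFolderDataset(data.Dataset):
    # Reads every image directly under `root`.
    def __init__(self, root, transform):
        super(FlatFolderDataset, self).__init__()
        self.paths = list(Path(root).glob('*'))
        self.transform = transform

    def __getitem__(self, index):
        img = Image.open(str(self.paths[index])).convert('RGB')
        return self.transform(img)

    def __len__(self):
        return len(self.paths)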
# Fragment: assumes `vgg`, `device`, and `args` are already defined above.
decoder.load_state_dict(torch.load(args.decoder))
network = net.Net(vgg, decoder)
network.train()
network.to(device)

content_tf = train_transform()
style_tf = train_transform()
content_dataset = FlatFolderDataset(args.content_dir, content_tf)
style_dataset = FlatFolderDataset(args.style_dir, style_tf)
content_iter = iter(
    data.DataLoader(content_dataset, batch_size=args.batch_size,
                    sampler=InfiniteSamplerWrapper(content_dataset),
                    num_workers=args.n_threads))
style_iter = iter(
    data.DataLoader(style_dataset, batch_size=args.batch_size,
                    sampler=InfiniteSamplerWrapper(style_dataset),
                    num_workers=args.n_threads))

optimizer = torch.optim.Adam(network.decoder.parameters(), lr=args.lr)
for i in tqdm(range(args.max_iter)):
    adjust_learning_rate(optimizer, iteration_count=i)
    content_images = next(content_iter).to(device)
    style_images = next(style_iter).to(device)
    loss_c, loss_s = network(content_images, style_images)
    loss_c = args.content_weight * loss_c
    # The fragment ends mid-loop; completed here following the identical
    # loop in the full script below.
    loss_s = args.style_weight * loss_s
    loss = loss_c + loss_s

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
def train(args):
    # Device, save and log configuration
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    save_dir = Path(os.path.join(args.save_dir, args.name))
    save_dir.mkdir(exist_ok=True, parents=True)
    log_dir = Path(os.path.join(args.log_dir, args.name))
    log_dir.mkdir(exist_ok=True, parents=True)
    writer = SummaryWriter(log_dir=str(log_dir))

    # Prepare datasets
    content_dataset = TrainDataset(args.content_dir, args.img_size)
    texture_dataset = TrainDataset(args.texture_dir, args.img_size, gray_only=True)
    color_dataset = TrainDataset(args.color_dir, args.img_size)
    content_iter = iter(
        data.DataLoader(content_dataset, batch_size=args.batch_size,
                        sampler=InfiniteSamplerWrapper(content_dataset),
                        num_workers=args.n_threads))
    texture_iter = iter(
        data.DataLoader(texture_dataset, batch_size=args.batch_size,
                        sampler=InfiniteSamplerWrapper(texture_dataset),
                        num_workers=args.n_threads))
    color_iter = iter(
        data.DataLoader(color_dataset, batch_size=args.batch_size,
                        sampler=InfiniteSamplerWrapper(color_dataset),
                        num_workers=args.n_threads))

    # Prepare network
    network = Net(args)
    network.train()
    network.to(device)

    # Training options: one optimizer per path (L = luminance, AB = chrominance)
    opt_L = torch.optim.Adam(network.L_path.parameters(), lr=args.lr)
    opt_AB = torch.optim.Adam(network.AB_path.parameters(), lr=args.lr)
    opts = [opt_L, opt_AB]

    # Start training
    for i in tqdm(range(args.max_iter)):
        # S1: Adjust lr and prepare data
        adjust_learning_rate(opts, iteration_count=i, args=args)
        content_l, content_ab = [x.to(device) for x in next(content_iter)]
        texture_l = next(texture_iter).to(device)
        color_l, color_ab = [x.to(device) for x in next(color_iter)]

        # S2: Forward
        l_pred, ab_pred = network(content_l, content_ab, texture_l, color_ab)

        # S3: Calculate loss
        loss_ct, loss_t = network.ct_t_loss(l_pred, content_l, texture_l)
        loss_cr = network.cr_loss(ab_pred, color_ab)
        loss_ctw = args.content_weight * loss_ct
        loss_tw = args.texture_weight * loss_t
        loss_crw = args.color_weight * loss_cr
        loss = loss_ctw + loss_tw + loss_crw

        # S4: Backward
        for opt in opts:
            opt.zero_grad()
        loss.backward()
        for opt in opts:
            opt.step()

        # S5: Summarize losses and save subnets
        writer.add_scalar('loss_content', loss_ct.item(), i + 1)
        writer.add_scalar('loss_texture', loss_t.item(), i + 1)
        writer.add_scalar('loss_color', loss_cr.item(), i + 1)
        if (i + 1) % args.save_model_interval == 0 or (i + 1) == args.max_iter:
            state_dict = network.state_dict()
            for key in state_dict.keys():
                state_dict[key] = state_dict[key].to(torch.device('cpu'))
            torch.save(state_dict,
                       save_dir / 'network_iter_{:d}.pth.tar'.format(i + 1))
    writer.close()
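# `adjust_learning_rate` is called here with a *list* of optimizers. A minimal
# sketch, assuming the same inverse-decay schedule that the single-optimizer
# version later in this file uses (lr / (1 + lr_decay * iteration)):
def adjust_learning_rate(optimizers, iteration_count, args):
    lr = args.lr / (1.0 + args.lr_decay * iteration_count)
    for optimizer in optimizers:
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr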
def run_train(config):
    device = 'cpu' if config.cpu or not torch.cuda.is_available() else 'cuda:0'
    device = torch.device(device)

    transfer_at = set()
    if config.transfer_at_encoder:
        transfer_at.add('encoder')
    if config.transfer_at_decoder:
        transfer_at.add('decoder')
    if config.transfer_at_skip:
        transfer_at.add('skip')

    save_dir = Path(config.save_dir)
    save_dir.mkdir(exist_ok=True, parents=True)
    log_dir = Path(config.log_dir)
    log_dir.mkdir(exist_ok=True, parents=True)
    writer = SummaryWriter(log_dir=str(log_dir))

    # Build networks; note that `vgg` is truncated to its first 31 layers
    # after its weights are loaded, and the same object is shared with wct2.
    vgg = net.vgg
    wct2 = Lap_Sob_Gaus(transfer_at=transfer_at, option_unpool=config.option_unpool,
                        device=device, verbose=config.verbose, vgg=vgg)
    encoder = Lap_Sob_GausEncoder(config.option_unpool).to(device)
    decoder = Lap_Sob_GausDecoder(config.option_unpool).to(device)
    vgg.load_state_dict(torch.load(config.vgg))
    vgg = nn.Sequential(*list(vgg.children())[:31])
    network = net.Net(encoder, decoder, vgg=vgg)
    network.train()
    network.to(device)

    # Data loading
    content_tf = train_transform()
    style_tf = train_transform()
    content_dataset = FlatFolderDataset(config.content_dir, content_tf)
    style_dataset = FlatFolderDataset(config.style_dir, style_tf)
    content_iter = iter(
        data.DataLoader(content_dataset, batch_size=config.batch_size,
                        sampler=InfiniteSamplerWrapper(content_dataset),
                        num_workers=config.n_threads))
    style_iter = iter(
        data.DataLoader(style_dataset, batch_size=config.batch_size,
                        sampler=InfiniteSamplerWrapper(style_dataset),
                        num_workers=config.n_threads))

    # Optimizers: encoder and decoder are trained with separate Adam instances
    enoptimizer = torch.optim.Adam(network.encoder.parameters(), lr=config.lr)
    deoptimizer = torch.optim.Adam(network.decoder.parameters(), lr=config.lr)

    # Visdom windows for live loss curves
    vis = Visdom(env="loss")
    content_loss, style_loss, iters = 0, 0, 0
    win_c = vis.line(np.array([content_loss]), np.array([iters]), win='content_loss')
    win_s = vis.line(np.array([style_loss]), np.array([iters]), win='style_loss')

    # Train
    for i in tqdm(range(config.max_iter)):
        enoptimizer.zero_grad()
        deoptimizer.zero_grad()
        adjust_learning_rate(enoptimizer, iteration_count=i)
        adjust_learning_rate(deoptimizer, iteration_count=i)
        content_images = next(content_iter).to(device)
        style_images = next(style_iter).to(device)
        content_images.requires_grad_()
        style_images.requires_grad_()

        loss_c, loss_s = network(content_images, style_images, wct2)
        loss_c = config.content_weight * loss_c
        loss_s = config.style_weight * loss_s
        loss = loss_c + loss_s
        loss.backward()
        enoptimizer.step()
        deoptimizer.step()

        if i % 50 == 1:
            print('\niters:', i, 'loss:', loss.item(),
                  'loss_c:', loss_c.item(), 'loss_s:', loss_s.item())
        if i % 20 == 0:
            iters = np.array([i])
            content_loss = np.array([loss_c.item()])
            style_loss = np.array([loss_s.item()])
            vis.line(content_loss, iters, win_c, update='append')
            vis.line(style_loss, iters, win_s, update='append')

        writer.add_scalar('loss_content', loss_c.item(), i + 1)
        writer.add_scalar('loss_style', loss_s.item(), i + 1)

        if (i + 1) % config.save_model_interval == 0 or (i + 1) == config.max_iter:
            # Save encoder and decoder weights on CPU
            for name, module in [('decoder', network.decoder),
                                 ('encoder', network.encoder)]:
                state_dict = module.state_dict()
                for key in state_dict.keys():
                    state_dict[key] = state_dict[key].to(torch.device('cpu'))
                torch.save(state_dict,
                           save_dir / '{:s}_iter_{:d}.pth.tar'.format(name, i + 1))
    writer.close()
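# To reuse the checkpoints written above at inference time, the state dicts
# can be loaded back into fresh modules. A minimal sketch; the iteration
# number in the paths is an example:
encoder = Lap_Sob_GausEncoder(config.option_unpool)
decoder = Lap_Sob_GausDecoder(config.option_unpool)
encoder.load_state_dict(torch.load('experiments/encoder_iter_160000.pth.tar'))
decoder.load_state_dict(torch.load('experiments/decoder_iter_160000.pth.tar'))
encoder.eval()
decoder.eval()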
def fast_train(args):
    """Fast training"""
    device = torch.device("cuda" if args.cuda else "cpu")

    transformer = TransformerNet().to(device)
    if args.model:
        transformer.load_state_dict(torch.load(args.model))
    vgg = Vgg16(requires_grad=False).to(device)

    global mse_loss
    mse_loss = torch.nn.MSELoss()

    content_weight = args.content_weight
    style_weight = args.style_weight
    lr = args.lr

    content_transform = transforms.Compose([
        transforms.Resize(args.content_size),
        transforms.CenterCrop(args.content_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))])
    content_dataset = datasets.ImageFolder(args.content_dataset, content_transform)
    content_loader = DataLoader(content_dataset, batch_size=args.iter_batch_size,
                                sampler=InfiniteSamplerWrapper(content_dataset),
                                num_workers=args.n_workers)
    content_loader = iter(content_loader)

    # Precompute the style targets (Gram matrices) once
    style_transform = transforms.Compose([
        transforms.Resize((args.style_size, args.style_size)),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))])
    style_image = utils.load_image(args.style_image)
    style_image = style_transform(style_image)
    style_image = style_image.unsqueeze(0).to(device)
    features_style = vgg(utils.normalize_batch(
        style_image.repeat(args.iter_batch_size, 1, 1, 1)))
    gram_style = [utils.gram_matrix(y) for y in features_style]

    if args.only_in:
        # Only update the instance-normalization parameters
        optimizer = Adam([param for (name, param) in transformer.named_parameters()
                          if "in" in name], lr=lr)
    else:
        optimizer = Adam(transformer.parameters(), lr=lr)

    for i in trange(args.update_step):
        contents = next(content_loader)[0].to(device)
        features_contents = vgg(utils.normalize_batch(contents))
        transformed = transformer(contents)
        features_transformed = vgg(utils.standardize_batch(transformed))
        loss, c_loss, s_loss = loss_fn(features_transformed, features_contents,
                                       gram_style, content_weight, style_weight)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Save model
    transformer.eval().cpu()
    style_name = os.path.basename(args.style_image).split(".")[0]
    save_model_filename = style_name + ".pth"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)
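# `utils.gram_matrix` and `loss_fn` are assumed helpers. The Gram matrix below
# follows the standard PyTorch fast-neural-style example; `loss_fn` is a
# hypothetical reconstruction matching the call site above (content loss on
# the relu2_2 activations, Gram-based style loss over all returned layers):
def gram_matrix(y):
    b, ch, h, w = y.size()
    features = y.view(b, ch, w * h)
    # Normalized feature correlations per image in the batch
    gram = features.bmm(features.transpose(1, 2)) / (ch * h * w)
    return gram


def loss_fn(features_transformed, features_contents, gram_style,
            content_weight, style_weight):
    c_loss = content_weight * mse_loss(features_transformed.relu2_2,
                                       features_contents.relu2_2)
    s_loss = 0.
    for ft_y, gm_s in zip(features_transformed, gram_style):
        s_loss += mse_loss(gram_matrix(ft_y), gm_s)
    s_loss *= style_weight
    return c_loss + s_loss, c_loss, s_loss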
def main():
    parser = argparse.ArgumentParser()
    # Basic options
    parser.add_argument('--content_dir', type=str, required=True,
                        help='Directory path to a batch of content images')
    parser.add_argument('--style_dir', type=str, required=True,
                        help='Directory path to a batch of style images')
    parser.add_argument('--vgg', type=str, default='models/vgg_normalised.pth')
    # Training options
    parser.add_argument('--save_dir', default='./experiments',
                        help='Directory to save the model')
    parser.add_argument('--log_dir', default='./logs',
                        help='Directory to save the log')
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--lr_decay', type=float, default=5e-5)
    parser.add_argument('--max_iter', type=int, default=160000)
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--style_weight', type=float, default=1.0)  # default 10
    parser.add_argument('--content_weight', type=float, default=1.0)
    parser.add_argument('--n_threads', type=int, default=8)
    parser.add_argument('--save_model_interval', type=int, default=20000)
    args = parser.parse_args()  # 80000 iters at batch_size=1; 160000 iters at batch_size=4

    def adjust_learning_rate(optimizer, iteration_count):
        """Imitating the original implementation"""
        lr = args.lr / (1.0 + args.lr_decay * iteration_count)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    device = torch.device('cuda')
    save_dir = Path(args.save_dir)
    save_dir.mkdir(exist_ok=True, parents=True)
    log_dir = Path(args.log_dir)
    log_dir.mkdir(exist_ok=True, parents=True)
    writer = SummaryWriter(log_dir=str(log_dir))

    # Resume the decoder from a previous checkpoint
    decoder = net.decoder
    checkpoint = torch.load("experiments\\ex3\\decoder_iter_80000.pth.tar")
    decoder.load_state_dict(checkpoint)
    vgg = net.vgg
    vgg.load_state_dict(torch.load(args.vgg))
    vgg = nn.Sequential(*list(vgg.children())[:31])
    network = net.Net(vgg, decoder)
    network.train()
    network.to(device)

    content_tf = train_transform()
    style_tf = train_transform()
    content_dataset = FlatFolderDataset(args.content_dir, content_tf)
    style_dataset = FlatFolderDataset(args.style_dir, style_tf)
    content_iter = iter(
        data.DataLoader(content_dataset, batch_size=args.batch_size,
                        sampler=InfiniteSamplerWrapper(content_dataset),
                        num_workers=args.n_threads))
    style_iter = iter(
        data.DataLoader(style_dataset, batch_size=args.batch_size,
                        sampler=InfiniteSamplerWrapper(style_dataset),
                        num_workers=args.n_threads))

    optimizer = torch.optim.Adam(network.decoder.parameters(), lr=args.lr)
    for i in tqdm(range(args.max_iter)):
        adjust_learning_rate(optimizer, iteration_count=i)
        content_images = next(content_iter).to(device)
        style_images = next(style_iter).to(device)
        loss_c, loss_s = network(content_images, style_images)
        loss_c = args.content_weight * loss_c
        loss_s = args.style_weight * loss_s
        loss = loss_c + loss_s

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        writer.add_scalar('loss_content', loss_c.item(), i + 1)
        writer.add_scalar('loss_style', loss_s.item(), i + 1)

        if (i + 1) % args.save_model_interval == 0 or (i + 1) == args.max_iter:
            state_dict = network.decoder.state_dict()
            for key in state_dict.keys():
                state_dict[key] = state_dict[key].to(torch.device('cpu'))
            torch.save(state_dict,
                       save_dir / 'decoder_iter_{:d}.pth.tar'.format(i + 1))
    writer.close()
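# Example invocation (script name and directory paths are placeholders):
#   python train.py --content_dir ./data/coco --style_dir ./data/wikiart --batch_size 8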
# Load datasets from path
style_dataset = datasets.ImageFolder(root=args.style_dir, transform=train_transform())
content_dataset = datasets.ImageFolder(root=args.content_dir, transform=train_transform())

# Set up samplers: only the content loader is sharded across ranks
content_sampler = None
# style_sampler = None
if args.distributed:
    content_sampler = torch.utils.data.distributed.DistributedSampler(content_dataset)
    # style_sampler = torch.utils.data.distributed.DistributedSampler(style_dataset)

# Make data loaders; the per-rank batch size is the global batch size divided
# by the world size. `kwargs` holds extra DataLoader options defined earlier.
args.dist_batch_size = (int(args.batch_size / torch.distributed.get_world_size())
                        if args.distributed else args.batch_size)
content_loader = torch.utils.data.DataLoader(content_dataset,
                                             sampler=content_sampler,
                                             batch_size=args.dist_batch_size,
                                             shuffle=(content_sampler is None),
                                             drop_last=True, **kwargs)
style_loader = torch.utils.data.DataLoader(style_dataset,
                                           sampler=InfiniteSamplerWrapper(style_dataset),
                                           batch_size=args.dist_batch_size, **kwargs)

# Only rank 0 creates the output directories
if args.local_rank == 0:
    os.makedirs(args.save_dir, exist_ok=True)
    os.makedirs(args.log_dir, exist_ok=True)
# writer = SummaryWriter(log_dir=args.log_dir)

# Create model objects: vgg and decoder need to be created as objects
# when using distributed training.
decoder = model.get_decoder()
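# The flags used above (`args.distributed`, `args.local_rank`) presuppose that
# the process group has already been initialized. A minimal sketch of that
# setup for a torch.distributed.launch-style entry point (the helper name
# `init_distributed` is ours, not from the original script):
import torch
import torch.distributed as dist


def init_distributed(args):
    # One process per GPU; the NCCL backend reads rank and world size
    # from the environment variables set by the launcher.
    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(backend='nccl', init_method='env://')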