def train(args):
    """Train a ``TransformerNet`` for fast neural style transfer.

    Each batch is passed through the transformer. The content loss compares
    the VGG16 ``relu2_2`` features of the output with those of the input. The
    style loss compares Gram matrices of every VGG16 feature map with those of
    ``args.style_image``. Running averages are printed every
    ``args.log_interval`` batches.

    Args:
        args: parsed command-line namespace. Fields read: ``seed``, ``cuda``,
            ``image_size``, ``dataset``, ``batch_size``, ``lr``,
            ``style_image``, ``style_size``, ``epochs``, ``content_weight``,
            ``style_weight`` and ``log_interval``.
    """
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    # Training images are resized, center-cropped and scaled by 255.
    transform = transforms.Compose([
        transforms.Scale(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255)),
    ])
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    # Bug fix: this was bound to ``train_folder`` while the loop below
    # iterates ``train_loader``, raising NameError on the first epoch.
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size)

    transformer = TransformerNet()
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16(requires_grad=False)
    style_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255)),
    ])
    style = utils.load_image(args.style_image, size=args.style_size)
    style = style_transform(style)
    # One copy of the style image per batch slot. The Gram targets are sliced
    # to the real batch size (``gm_s[:n_batch]``) when the last batch is smaller.
    style = style.repeat(args.batch_size, 1, 1, 1)

    if args.cuda:
        transformer.cuda()
        vgg.cuda()
        style = style.cuda()

    style_v = Variable(style)
    style_v = utils.normalize_batch(style_v)
    features_style = vgg(style_v)
    # Target Gram matrices are computed once and reused for every batch.
    gram_style = [utils.gram_matrix(y) for y in features_style]

    for e in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()
            x = Variable(x)
            if args.cuda:
                x = x.cuda()

            y = transformer(x)

            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)

            features_y = vgg(y)
            features_x = vgg(x)

            content_loss = args.content_weight * mse_loss(
                features_y.relu2_2, features_x.relu2_2)

            style_loss = 0
            for ft_y, gm_s in zip(features_y, gram_style):
                gm_y = utils.gram_matrix(ft_y)
                style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
            style_loss *= args.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            # ``.data[0]`` reads the scalar loss the way pre-0.4 PyTorch
            # required. ``Variable`` is used throughout this function; newer
            # releases would need ``.item()`` instead.
            agg_content_loss += content_loss.data[0]
            agg_style_loss += style_loss.data[0]

            if (batch_id + 1) % args.log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                    agg_content_loss / (batch_id + 1),
                    agg_style_loss / (batch_id + 1),
                    # Bug fix: was ``agg_contetn_loss`` (NameError).
                    (agg_content_loss + agg_style_loss) / (batch_id + 1)
                )
                print(mesg)
# Notebook-style training setup: build the transform network, loss and
# optimizer, then iterate over ``train_loader`` for 3 epochs.
# NOTE(review): this fragment relies on ``train_loader``, ``tqdm_notebook``,
# ``TransformerNet``, ``Variable`` and ``Adam`` being defined elsewhere.
transformer = TransformerNet()
mse_loss = torch.nn.MSELoss()
# l1_loss = torch.nn.L1Loss()
if torch.cuda.is_available():
    transformer.cuda()

# Loss weights and optimisation hyper-parameters.
# NOTE(review): LOG_INTERVAL, REGULARIZATION, CONTENT_WEIGHT and STYLE_WEIGHT
# are not used in the visible code. REGULARIZATION is presumably a
# total-variation weight; confirm.
CONTENT_WEIGHT = 1e4
STYLE_WEIGHT = 1e10
LOG_INTERVAL = 200
REGULARIZATION = 1e-7
LR = 1e-4

optimizer = Adam(transformer.parameters(), LR)
transformer.train()
for epoch in range(3):
    # Running sums of each loss term, reset every epoch.
    agg_content_loss = 0.
    agg_style_loss = 0.
    agg_reg_loss = 0.
    count = 0  # images seen so far in this epoch
    for batch_id, (x, _) in tqdm_notebook(enumerate(train_loader), total=len(train_loader)):
        n_batch = len(x)
        count += n_batch
        optimizer.zero_grad()
        x = Variable(x)
        if torch.cuda.is_available():
            x = x.cuda()
        # NOTE(review): the loop body appears truncated here. ``y`` is
        # computed, but no loss, backward() or optimizer.step() follows in
        # the visible code.
        y = transformer(x)
def run_train(args):
    """Train a feed-forward style-transfer network, optionally mask-aware.

    With ``args.semantic == 1`` the multi-label semantic variant is trained.
    The content masks returned by ``train_preparation_mask`` are concatenated
    to both the network input and output, and the style loss is computed per
    mask region. With ``args.semantic == 0`` the plain variant built by
    ``train_preparation`` is trained.

    If ``args.checkpoint_model_dir`` is set, a checkpoint is written every
    ``args.checkpoint_interval`` batches. The final weights are saved under
    ``args.save_model_dir``.

    Raises:
        ValueError: if ``args.semantic`` is neither 0 nor 1. Previously this
            surfaced later as an unbound-local error.
    """
    if args.semantic not in (0, 1):
        raise ValueError(
            'args.semantic must be 0 or 1, got {!r}'.format(args.semantic))
    semantic = args.semantic == 1

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    print('running training process...')
    if semantic:
        print('multilabels semantic feedforward neural style transfer training...')
        loss_net, content_losses, style_losses, content_masks, n_channels = \
            train_preparation_mask(args)
    else:
        print('normal feedforward neural style transfer training...')
        loss_net, content_losses, style_losses, n_channels = \
            train_preparation(args)

    if args.backend == 'cudnn':
        torch.backends.cudnn.enabled = True

    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
    ])
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size)

    transform_net = TransformerNet(n_channels).to(device)
    mse_loss = nn.MSELoss()
    optimizer = optim.Adam(transform_net.parameters(), lr=args.learning_rate)

    for epoch in range(args.epochs):
        transform_net.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            stloss = 0.
            ctloss = 0.
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()

            # ``x`` becomes the network input (with masks stacked on as extra
            # channels in semantic mode). ``x_ori`` keeps the plain image for
            # the content target.
            x = x.to(device)
            x_ori = x.clone()
            x = preprocess(x)
            x_ori = preprocess(x_ori)
            if semantic:
                # NOTE(review): assumes content_masks has the same batch size
                # as ``x``. A smaller final batch may make this cat fail;
                # confirm against train_preparation_mask.
                x = torch.cat((x, content_masks), 1)

            y = transform_net(x)

            # Pixel loss compares like with like: the masks are appended to the
            # output too when they were appended to the input.
            y_pix = torch.cat((y, content_masks), 1) if semantic else y
            pixloss = 0.
            if args.pixel_weight > 0:
                pixloss = mse_loss(x, y_pix) * args.pixel_weight

            # Pass 1: content modules capture their targets from the input.
            for ctl in content_losses:
                ctl.mode = 'capture'
            loss_net(x_ori)
            # Pass 2: run the stylised output through the loss network.
            for ctl in content_losses:
                ctl.mode = 'loss'
            for stl in style_losses:
                stl.mode = 'loss'
            loss_net(y)

            for ctl in content_losses:
                ctloss += mse_loss(ctl.input, ctl.target) * args.content_weight

            if semantic:
                for stl in style_losses:
                    for u in range(len(stl.color_codes)):
                        input_msk = stl.input_masks[u].expand_as(stl.input)
                        input_masked = torch.mul(stl.input, input_msk)
                        input_msk_mean = torch.mean(stl.input_masks[u])
                        input_local_G = gram_matrix(input_masked)
                        # NOTE(review): the previous code called
                        # ``input_local_G.div(stl.input.nelement() * input_msk_mean)``
                        # when the mask mean was > 0. ``Tensor.div`` is not
                        # in-place and its result was discarded, so that extra
                        # normalisation never took effect. The no-op was
                        # removed and behaviour is unchanged. If normalisation
                        # is intended, the targets built by
                        # train_preparation_mask must be scaled the same way
                        # and the style weights re-tuned.
                        loss_local = mse_loss(input_local_G, stl.target[u])
                        loss_local *= input_msk_mean
                        # Per the original comments: regions with a larger
                        # mask mean use style_weights[0] (meant to be the
                        # smaller weight), and smaller regions (mean <= 0.2)
                        # use style_weights[1].
                        if input_msk_mean > 0.2:
                            stloss += loss_local * args.style_weights[0]
                        else:
                            stloss += loss_local * args.style_weights[1]
            else:
                for stl in style_losses:
                    gram = gram_matrix(stl.input)
                    stloss += mse_loss(gram, stl.target) * args.style_weights[0]

            loss = ctloss + stloss + pixloss
            loss.backward()
            optimizer.step()

            # float() accepts both a one-element tensor and the 0. starting
            # value (e.g. when content_losses is empty), unlike ``.item()``.
            agg_content_loss += float(ctloss)
            agg_style_loss += float(stloss)

            if (batch_id + 1) % args.log_interval == 0:
                # Epochs are reported 1-based, matching the checkpoint names.
                mesg = "{}, Epoch {}:\t[{}/{}], content: {:.6f}, style: {:.6f}, total: {:.6f}".format(
                    time.ctime(), epoch + 1, count, len(train_dataset),
                    agg_content_loss / (batch_id + 1),
                    agg_style_loss / (batch_id + 1),
                    (agg_content_loss + agg_style_loss) / (batch_id + 1))
                print(mesg)

            if args.checkpoint_model_dir is not None and (
                    batch_id + 1) % args.checkpoint_interval == 0:
                transform_net.eval().cpu()
                ckpt_model_filename = "ckpt_epoch_" + str(
                    epoch + 1) + "_batch_id_" + str(
                        batch_id + 1) + "_semantic_" + str(
                            args.semantic) + ".pth"
                ckpt_model_path = os.path.join(args.checkpoint_model_dir,
                                               ckpt_model_filename)
                torch.save(transform_net.state_dict(), ckpt_model_path)
                transform_net.to(device).train()

    # Save the final model.
    transform_net.eval().cpu()
    save_model_filename = "epoch_" + str(args.epochs) + "_" + str(
        time.ctime()).replace(' ', '_') + "_content_" + str(
            args.content_weight) + "_style_" + str(
                args.style_weights[0]) + "_semantic_" + str(
                    args.semantic) + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transform_net.state_dict(), save_model_path)
    print("\n training process is Done!, trained model saved at",
          save_model_path)
def train(args):
    """Fit a TransformerNet to one style image using VGG16 perceptual losses.

    VGG16 weights are prepared by ``utils.init_vgg16`` in
    ``args.vgg_model_dir`` and loaded from ``vgg16.weight`` there. The content
    term uses VGG feature map index 1. The style term adds a weighted
    Gram-matrix MSE for every feature map. After the last epoch the
    transformer weights are written to ``args.save_model_dir``.
    """
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
    loader_kwargs = {'num_workers': 0, 'pin_memory': False} if args.cuda else {}

    image_pipeline = transforms.Compose([
        transforms.Scale(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda t: t.mul(255)),
    ])
    train_dataset = datasets.ImageFolder(args.dataset, image_pipeline)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size,
                              **loader_kwargs)

    transformer = TransformerNet()
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    # Loss network: prepare the VGG16 weight file, then load it.
    vgg = Vgg16()
    utils.init_vgg16(args.vgg_model_dir)
    weights_path = os.path.join(args.vgg_model_dir, "vgg16.weight")
    vgg.load_state_dict(torch.load(weights_path))

    if args.cuda:
        transformer.cuda()
        vgg.cuda()

    style_batch = utils.tensor_load_rgbimage(args.style_image,
                                             size=args.style_size)
    style_batch = utils.preprocess_batch(
        style_batch.repeat(args.batch_size, 1, 1, 1))
    if args.cuda:
        style_batch = style_batch.cuda()
    style_var = Variable(style_batch, volatile=True)
    # Return value is ignored, so this helper presumably works in place.
    utils.subtract_imagenet_mean_batch(style_var)
    gram_targets = [utils.gram_matrix(fmap) for fmap in vgg(style_var)]

    for epoch_idx in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        seen = 0
        for batch_id, (images, _) in enumerate(train_loader):
            batch_len = len(images)
            seen += batch_len
            optimizer.zero_grad()

            x = Variable(utils.preprocess_batch(images))
            if args.cuda:
                x = x.cuda()
            y = transformer(x)

            # Volatile copy of the input, used only to build the content target.
            xc = Variable(x.data.clone(), volatile=True)
            utils.subtract_imagenet_mean_batch(y)
            utils.subtract_imagenet_mean_batch(xc)

            feats_y = vgg(y)
            feats_xc = vgg(xc)

            content_target = Variable(feats_xc[1].data, requires_grad=False)
            content_loss = args.content_weight * mse_loss(feats_y[1],
                                                          content_target)

            style_loss = 0.
            for idx, feat in enumerate(feats_y):
                target = Variable(gram_targets[idx].data, requires_grad=False)
                style_loss += args.style_weight * mse_loss(
                    utils.gram_matrix(feat), target[:batch_len, :, :])

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.data[0]
            agg_style_loss += style_loss.data[0]

            done = batch_id + 1
            if done % args.log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), epoch_idx + 1, seen, len(train_dataset),
                    agg_content_loss / done,
                    agg_style_loss / done,
                    (agg_content_loss + agg_style_loss) / done)
                print(mesg)

    # Persist the trained transformer with CPU tensors.
    transformer.eval()
    transformer.cpu()
    save_model_filename = "_".join([
        "epoch",
        str(args.epochs),
        str(time.ctime()).replace(' ', '_'),
        str(args.content_weight),
        str(args.style_weight),
    ]) + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)
    print("\nDone, trained model saved at", save_model_path)
def train(args):
    """Train one conditional TransformerNet on every style image in a folder.

    Each file in the directory ``args.style_image`` becomes one style. Its
    position in the sorted listing is the ``style_id`` passed to the network.
    Within an epoch, images get styles round-robin by their running index, and
    an incomplete final batch is skipped.

    Args:
        args: namespace with ``cuda``, ``seed``, ``image_size``, ``dataset``,
            ``batch_size``, ``style_image`` (a directory), ``style_size``,
            ``lr``, ``epochs``, ``content_weight``, ``style_weight``,
            ``log_interval``, ``checkpoint_model_dir``,
            ``checkpoint_interval`` and ``save_model_dir``.

    Raises:
        ValueError: if ``args.style_image`` contains no (non-hidden) files.
    """
    device = torch.device("cuda" if args.cuda else "cpu")
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    transform = transforms.Compose([
        transforms.Resize(args.image_size),  # shorter side resized to image_size
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),  # to tensor in [0, 1]
        transforms.Lambda(lambda x: x.mul(255)),  # back to [0, 255]
    ])
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size,
                              shuffle=True)

    # Sorted so the style-id <-> image mapping is reproducible; os.listdir
    # returns entries in arbitrary order. Hidden files (e.g. .DS_Store) are
    # skipped because load_image cannot open them.
    # NOTE(review): any inference code that maps style ids back to images
    # should use the same ordering.
    style_image = sorted(name for name in os.listdir(args.style_image)
                         if not name.startswith('.'))
    style_num = len(style_image)
    print(style_num)
    if not style_image:
        raise ValueError(
            "no style images found in {}".format(args.style_image))

    transformer = TransformerNet(style_num=style_num).to(device)
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16(requires_grad=False).to(device)
    style_transform = transforms.Compose([
        transforms.Resize(args.style_size),
        transforms.CenterCrop(args.style_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255)),
    ])
    # os.path.join works whether or not args.style_image ends with a
    # separator; plain string concatenation required a trailing '/'.
    style = torch.stack([
        style_transform(
            utils.load_image(os.path.join(args.style_image, name),
                             size=args.style_size))
        for name in style_image
    ]).to(device)
    features_style = vgg(utils.normalize_batch(style))
    # One Gram tensor per VGG layer; row i belongs to style i.
    gram_style = [utils.gram_matrix(y) for y in features_style]

    for e in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            if n_batch < args.batch_size:
                break  # skip the incomplete last batch of this epoch
            count += n_batch
            optimizer.zero_grad()

            # Round-robin style assignment by running image index.
            batch_style_id = [
                i % style_num for i in range(count - n_batch, count)
            ]

            # Move the batch once; everything below stays on ``device``.
            x = x.to(device)
            y = transformer(x, style_id=batch_style_id)

            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)

            features_y = vgg(y)
            features_x = vgg(x)

            content_loss = args.content_weight * mse_loss(
                features_y.relu2_2, features_x.relu2_2)

            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gm_y = utils.gram_matrix(ft_y)
                style_loss += mse_loss(gm_y, gm_s[batch_style_id, :, :])
            style_loss *= args.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.item()
            agg_style_loss += style_loss.item()

            if (batch_id + 1) % args.log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                    agg_content_loss / (batch_id + 1),
                    agg_style_loss / (batch_id + 1),
                    (agg_content_loss + agg_style_loss) / (batch_id + 1))
                print(mesg)

            if args.checkpoint_model_dir is not None and (
                    batch_id + 1) % args.checkpoint_interval == 0:
                transformer.eval().cpu()
                ckpt_model_filename = "ckpt_epoch_" + str(
                    e) + "_batch_id_" + str(batch_id + 1) + ".pth"
                ckpt_model_path = os.path.join(args.checkpoint_model_dir,
                                               ckpt_model_filename)
                torch.save(transformer.state_dict(), ckpt_model_path)
                transformer.to(device).train()

    # Save the final model. Colons are stripped from the timestamp so the
    # filename is valid on every platform.
    transformer.eval().cpu()
    save_model_filename = "epoch_" + str(
        args.epochs) + "_" + str(time.ctime()).replace(' ', '_').replace(
            ':', '') + "_" + str(int(args.content_weight)) + "_" + str(
                int(args.style_weight)) + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)
    print("\nDone, trained model saved at", save_model_path)
def train():
    """Train a TransformerNet with content, style and depth-preservation losses.

    Configuration comes from module-level settings (``random_seed``,
    ``image_size``, ``dataset_path``, ``batch_size``, ``lr``, ``epochs``, the
    ``*_weight`` values, ``style_image_path``, ``FCRN_path``, checkpoint and
    save directories, ...).

    Besides the VGG16 content and style losses, a frozen FCRN depth network
    runs on the normalised input and output. The MSE between the two depth
    predictions is added to the objective as ``depth_loss``.
    """
    device = torch.device("cuda")
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)

    # NOTE(review): this transform does not multiply by 255. If
    # utils.normalize_batch divides by 255 (as in the PyTorch
    # fast-neural-style example), inputs here end up mis-scaled. Confirm
    # against utils.
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
    ])
    train_dataset = datasets.ImageFolder(dataset_path, transform)
    train_loader = DataLoader(train_dataset, batch_size=batch_size)

    transformer = TransformerNet().to(device)
    optimizer = Adam(transformer.parameters(), lr)
    mse_loss = torch.nn.MSELoss()

    if resume_TransformerNet_from_file:
        if os.path.isfile(TransformerNet_path):
            print("=> loading checkpoint '{}'".format(TransformerNet_path))
            TransformerNet_par = torch.load(TransformerNet_path)
            # Strip InstanceNorm running-statistics keys
            # (``in<N>.running_mean`` / ``in<N>.running_var``) before loading.
            for k in list(TransformerNet_par.keys()):
                if re.search(r'in\d+\.running_(mean|var)$', k):
                    del TransformerNet_par[k]
            transformer.load_state_dict(TransformerNet_par)
            print("=> loaded checkpoint '{}'".format(TransformerNet_path))
        else:
            print("=> no checkpoint found at '{}'".format(TransformerNet_path))

    vgg = Vgg16(requires_grad=False).to(device)
    # convert('RGB') guards against grayscale, RGBA or palette style images,
    # which would otherwise give a tensor with the wrong number of channels.
    style = Image.open(style_image_path).convert('RGB')
    style = transform(style)
    style = style.repeat(batch_size, 1, 1, 1).to(device)
    features_style = vgg(utils.normalize_batch(style))
    gram_style = [utils.gram_matrix(y) for y in features_style]

    # Frozen depth network, used only inside the loss.
    # NOTE(review): it is built with a fixed batch_size. Confirm it accepts
    # the smaller final batch the loader can yield.
    model_fcrn = FCRN_for_transfer(batch_size=batch_size,
                                   requires_grad=False).to(device)
    model_fcrn_par = torch.load(FCRN_path)
    model_fcrn.load_state_dict(model_fcrn_par['state_dict'])
    print("=> loaded checkpoint '{}' (epoch {})".format(
        FCRN_path, model_fcrn_par['epoch']))

    for e in range(epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_depth_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()

            x = x.to(device)
            y = transformer(x)

            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)

            features_y = vgg(y)
            features_x = vgg(x)
            depth_y = model_fcrn(y)
            depth_x = model_fcrn(x)

            content_loss = content_weight * mse_loss(features_y.relu2_2,
                                                     features_x.relu2_2)
            depth_loss = depth_weight * mse_loss(depth_y, depth_x)

            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gm_y = utils.gram_matrix(ft_y)
                style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
            style_loss *= style_weight

            total_loss = content_loss + depth_loss + style_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.item()
            agg_depth_loss += depth_loss.item()
            agg_style_loss += style_loss.item()

            if (batch_id + 1) % log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tdepth: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                    agg_content_loss / (batch_id + 1),
                    agg_depth_loss / (batch_id + 1),
                    agg_style_loss / (batch_id + 1),
                    # Bug fix: the logged total now matches total_loss, which
                    # includes depth_loss (it was previously left out here).
                    (agg_content_loss + agg_depth_loss + agg_style_loss) /
                    (batch_id + 1))
                print(mesg)

            if checkpoint_model_dir is not None and (
                    batch_id + 1) % checkpoint_interval == 0:
                transformer.eval().cpu()
                ckpt_model_filename = "ckpt_epoch_" + str(
                    e) + "_batch_id_" + str(batch_id + 1) + ".pth"
                ckpt_model_path = os.path.join(checkpoint_model_dir,
                                               ckpt_model_filename)
                torch.save(transformer.state_dict(), ckpt_model_path)
                transformer.to(device).train()

    # Save the final model.
    transformer.eval().cpu()
    save_model_filename = "epoch_" + str(epochs) + "_" + str(
        time.ctime()).replace(' ', '_') + "_" + str(
            content_weight) + "_" + str(style_weight) + ".model"
    save_model_path = os.path.join(save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)
    print("\nDone, trained model saved at", save_model_path)
def train(args):
    """Train a TransformerNet while reporting losses to a MissingLink experiment.

    The loss arithmetic (weighting content/style, summing the total) goes
    through the wrapped callables that ``experiment.metrics`` returns for
    ``multiply`` and ``add``. Presumably this is how MissingLink records each
    value; confirm against the MissingLink SDK. Checkpoints are written every
    ``args.checkpoint_interval`` batches when ``args.checkpoint_model_dir`` is
    set. The final model goes to ``args.save_model_dir``.
    """
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    print("Loading data")
    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255)),
    ])
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size)

    # Bug fix: this was the Python 2 statement ``print "Building the model"``,
    # a SyntaxError on Python 3. Every other print here is a call.
    print("Building the model")
    transformer = TransformerNet()
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16(requires_grad=False)
    style_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255)),
    ])
    style = utils.load_image(args.style_image, size=args.style_size)
    style = style_transform(style)
    style = style.repeat(args.batch_size, 1, 1, 1)

    if args.cuda:
        transformer.cuda()
        vgg.cuda()
        style = style.cuda()

    style_v = Variable(style)
    style_v = utils.normalize_batch(style_v)
    features_style = vgg(style_v)
    gram_style = [utils.gram_matrix(y) for y in features_style]

    # Metric functions handed to MissingLink. The experiment hands back
    # wrapped versions, which replace the plain arithmetic in the loop.
    def multiply(loss, weight):
        return loss * weight

    def add(loss1, loss2):
        return loss1 + loss2

    metrics_names = ['Content Loss', 'Style Loss', 'Total Loss']
    with missinglink_project.create_experiment(
            transformer,
            display_name='Style Transfer PyTorch',
            optimizer=optimizer,
            train_data_object=train_loader,
            metrics={
                metrics_names[0]: multiply,
                metrics_names[1]: multiply,
                metrics_names[2]: add
            }
    ) as experiment:
        (wrapped_content_loss, wrapped_style_loss, wrapped_total_loss) = [
            experiment.metrics[metric_name] for metric_name in metrics_names
        ]

        print("Starting to train")
        for e in experiment.epoch_loop(args.epochs):
            transformer.train()
            agg_content_loss = 0.
            agg_style_loss = 0.
            count = 0
            for batch_id, (x, _) in experiment.batch_loop(iterable=train_loader):
                n_batch = len(x)
                count += n_batch
                optimizer.zero_grad()
                x = Variable(x)
                if args.cuda:
                    x = x.cuda()

                y = transformer(x)

                y = utils.normalize_batch(y)
                x = utils.normalize_batch(x)

                features_y = vgg(y)
                features_x = vgg(x)

                content_loss = mse_loss(features_y.relu2_2, features_x.relu2_2)
                content_loss = wrapped_content_loss(content_loss,
                                                    args.content_weight)

                style_loss = 0.
                for ft_y, gm_s in zip(features_y, gram_style):
                    gm_y = utils.gram_matrix(ft_y)
                    style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
                style_loss = wrapped_style_loss(style_loss, args.style_weight)

                total_loss = wrapped_total_loss(content_loss, style_loss)
                total_loss.backward()
                optimizer.step()

                agg_content_loss += content_loss.data[0]
                agg_style_loss += style_loss.data[0]

                if (batch_id + 1) % args.log_interval == 0:
                    mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                        time.ctime(), e + 1, count, len(train_dataset),
                        agg_content_loss / (batch_id + 1),
                        agg_style_loss / (batch_id + 1),
                        (agg_content_loss + agg_style_loss) / (batch_id + 1)
                    )
                    print(mesg)

                if args.checkpoint_model_dir is not None and (
                        batch_id + 1) % args.checkpoint_interval == 0:
                    transformer.eval()
                    if args.cuda:
                        transformer.cpu()
                    ckpt_model_filename = "ckpt_epoch_" + str(
                        e) + "_batch_id_" + str(batch_id + 1) + ".pth"
                    ckpt_model_path = os.path.join(args.checkpoint_model_dir,
                                                   ckpt_model_filename)
                    torch.save(transformer.state_dict(), ckpt_model_path)
                    if args.cuda:
                        transformer.cuda()
                    transformer.train()

    # Save the final model.
    transformer.eval()
    if args.cuda:
        transformer.cpu()
    save_model_filename = "epoch_" + str(args.epochs) + "_" + str(
        time.ctime()).replace(' ', '_') + "_" + str(
            args.content_weight) + "_" + str(args.style_weight) + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)
    print("\nDone, trained model saved at", save_model_path)
def train(args):
    """Train a TransformerNet with perceptual, focus (CAM) and category losses.

    Besides the VGG16 content and style losses, the module-level classifier
    ``net`` runs on the input and on the stylised output:

    * focus loss: MSE between the class-activation maps (``returnCAM``) of
      the output's top-1 class and the input's top-1 class, scaled by
      ``cam_weight``;
    * category loss: cross-entropy of the output logits against the input's
      top-1 class, scaled by ``cate_weight``.

    Running averages are logged to TensorBoard under ``args.logdir``. A
    snapshot is saved every ``save_interval`` batches and once more at the end.

    Optional ``args`` fields (the defaults keep the previous hard-coded
    values): ``cam_weight`` (80), ``cate_weight`` (10000), ``save_interval``
    (2500).
    """
    cam_weight = getattr(args, 'cam_weight', 80)
    cate_weight = getattr(args, 'cate_weight', 10000)
    save_interval = getattr(args, 'save_interval', 2500)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        kwargs = {'num_workers': 0, 'pin_memory': False}
    else:
        kwargs = {}

    transform = transforms.Compose([
        transforms.Scale(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255)),
    ])
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size,
                              **kwargs)

    transformer = TransformerNet()
    if args.premodel != "":
        transformer.load_state_dict(torch.load(args.premodel))
        print("load pretrain model:" + args.premodel)
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16()
    utils.init_vgg16(args.vgg_model_dir)
    vgg.load_state_dict(
        torch.load(os.path.join(args.vgg_model_dir, "vgg16.weight")))

    if args.cuda:
        transformer.cuda()
        vgg.cuda()

    style = utils.tensor_load_rgbimage(args.style_image, size=args.style_size)
    style = style.repeat(args.batch_size, 1, 1, 1)
    style = utils.preprocess_batch(style)
    if args.cuda:
        style = style.cuda()
    style_v = Variable(style, volatile=True)
    style_v = utils.subtract_imagenet_mean_batch(style_v)
    features_style = vgg(style_v)
    gram_style = [utils.gram_matrix(y) for y in features_style]

    global_step = 0  # TensorBoard x-axis: batches seen across all epochs
    writer = SummaryWriter(args.logdir, comment=args.logdir)
    for e in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        agg_cate_loss = 0.
        agg_cam_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()
            x = Variable(utils.preprocess_batch(x))
            if args.cuda:
                x = x.cuda()

            y = transformer(x)
            xc = Variable(x.data.clone(), volatile=True)

            # ---- focus (CAM) loss and category loss ----
            y_cls_in = utils.subtract_mean_std_batch(utils.depreprocess_batch(y))
            xc_cls_in = utils.subtract_mean_std_batch(
                utils.depreprocess_batch(xc))

            # ``features_blobs`` (module level) is emptied here. Each of the
            # two forward passes below is expected to append one feature blob
            # (index 0: input, index 1: output), presumably through a forward
            # hook on ``net``.
            del features_blobs[:]
            logit_x = net(xc_cls_in)
            logit_y = net(y_cls_in)

            label = []
            cam_loss = 0
            for i in range(len(xc_cls_in)):
                # dim=0 is what the deprecated implicit-dim softmax picked for
                # these 1-D logit vectors.
                h_x = F.softmax(logit_x[i], dim=0)
                _, idx_x = h_x.data.sort(0, True)
                label.append(idx_x[0])
                h_y = F.softmax(logit_y[i], dim=0)
                _, idx_y = h_y.data.sort(0, True)
                # The input's CAM is re-wrapped without gradient so that it
                # acts as a fixed target. (Renamed: the old code reused
                # ``y_cam`` here, shadowing the classifier input above.)
                cam_x = returnCAM(features_blobs[0][i], weight_softmax, idx_x[0])
                cam_x = Variable(cam_x.data, requires_grad=False)
                cam_y = returnCAM(features_blobs[1][i], weight_softmax, idx_y[0])
                cam_loss += mse_loss(cam_y, cam_x)
            cam_loss *= cam_weight

            label = Variable(torch.LongTensor(label), requires_grad=False)
            # Bug fix: previously an unconditional .cuda(), which broke
            # CPU-only runs.
            if args.cuda:
                label = label.cuda()
            cate_loss = cate_weight * torch.nn.CrossEntropyLoss()(logit_y, label)

            # ---- perceptual (VGG) losses ----
            y = utils.subtract_imagenet_mean_batch(y)
            xc = utils.subtract_imagenet_mean_batch(xc)
            features_y = vgg(y)
            features_xc = vgg(xc)

            # Content loss uses VGG feature map index 2.
            f_xc_c = Variable(features_xc[2].data, requires_grad=False)
            content_loss = args.content_weight * mse_loss(features_y[2], f_xc_c)

            style_loss = 0.
            for m in range(len(features_y)):
                gram_s = Variable(gram_style[m].data, requires_grad=False)
                gram_y = utils.gram_matrix(features_y[m])
                style_loss += args.style_weight * mse_loss(
                    gram_y, gram_s[:n_batch, :, :])

            total_loss = style_loss + content_loss + cam_loss + cate_loss
            total_loss.backward()
            optimizer.step()

            # Running averages for display and TensorBoard.
            agg_content_loss += content_loss.data[0]
            agg_style_loss += style_loss.data[0]
            agg_cate_loss += cate_loss.data[0]
            agg_cam_loss += cam_loss.data[0]
            writer.add_scalar("Loss_Cont", agg_content_loss / (batch_id + 1), global_step)
            writer.add_scalar("Loss_Style", agg_style_loss / (batch_id + 1), global_step)
            writer.add_scalar("Loss_CAM", agg_cam_loss / (batch_id + 1), global_step)
            writer.add_scalar("Loss_Cate", agg_cate_loss / (batch_id + 1), global_step)
            global_step += 1

            if (batch_id + 1) % args.log_interval == 0:
                mesg = "{}Epoch{}:[{}/{}] content:{:.2f} style:{:.2f} cate:{:.2f} cam:{:.2f} total:{:.2f}".format(
                    time.strftime("%a %H:%M:%S"), e + 1, count,
                    len(train_dataset),
                    agg_content_loss / (batch_id + 1),
                    agg_style_loss / (batch_id + 1),
                    agg_cate_loss / (batch_id + 1),
                    agg_cam_loss / (batch_id + 1),
                    (agg_content_loss + agg_style_loss + agg_cate_loss +
                     agg_cam_loss) / (batch_id + 1))
                print(mesg)

            if (batch_id + 1) % save_interval == 0:
                transformer.eval()
                transformer.cpu()
                save_model_filename = "epoch_" + str(e + 1) + "_" + str(
                    time.ctime()).replace(' ', '_') + "_" + str(
                        args.content_weight) + "_" + str(
                            args.style_weight) + ".model"
                save_model_path = os.path.join(args.save_model_dir,
                                               save_model_filename)
                torch.save(transformer.state_dict(), save_model_path)
                # Bug fix: previously an unconditional transformer.cuda().
                if args.cuda:
                    transformer.cuda()
                transformer.train()
                print("saved at ", count)

    # Save the final model.
    transformer.eval()
    transformer.cpu()
    save_model_filename = "epoch_" + str(args.epochs) + "_" + str(
        time.ctime()).replace(' ', '_') + "_" + str(
            args.content_weight) + "_" + str(args.style_weight) + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)
    writer.close()
    print("\nDone, trained model saved at", save_model_path)
def train(args):
    """Train a TransformerNet with content, style and total-variation losses.

    Args:
        args: namespace with ``cuda``, ``seed``, ``image_size``, ``dataset``,
            ``batch_size``, ``lr``, ``style_image``, ``style_size``,
            ``epochs``, ``content_weight``, ``style_weight``,
            ``log_interval``, ``checkpoint_model_dir``,
            ``checkpoint_interval`` and ``save_model_dir``. It may also carry
            ``tv_weight``, which defaults to the previously hard-coded 1e-7.
    """
    device = torch.device("cuda" if args.cuda else "cpu")
    tv_weight = getattr(args, 'tv_weight', 1e-7)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255)),
    ])
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    # Pinned host memory only helps host-to-GPU copies, so pin only when
    # training on CUDA.
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size,
                              num_workers=0, pin_memory=bool(args.cuda))

    transformer = TransformerNet().to(device)
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16(requires_grad=False).to(device)
    style_transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Lambda(lambda x: x.mul(255))])
    style = utils.load_image(args.style_image, size=args.style_size)
    style = style_transform(style)
    style = style.repeat(args.batch_size, 1, 1, 1).to(device)
    features_style = vgg(utils.normalize_batch(style))
    gram_style = [utils.gram_matrix(y) for y in features_style]

    # These sizes do not change between epochs, so report them once.
    print("Number of Data : {}".format(len(train_dataset)))
    print("Number of Batch : {}".format(len(train_loader)))

    for e in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        agg_tv_loss = 0.
        count = 0
        # enumerate() hides the loader length, so pass total= to let tqdm
        # show real progress.
        for batch_id, (x, _) in tqdm(enumerate(train_loader),
                                     total=len(train_loader)):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()

            x = x.to(device)
            y = transformer(x)

            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)

            features_y = vgg(y)
            features_x = vgg(x)

            # Content loss
            content_loss = args.content_weight * mse_loss(
                features_y.relu2_2, features_x.relu2_2)

            # Style loss
            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gm_y = utils.gram_matrix(ft_y)
                style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
            style_loss *= args.style_weight

            # Total-variation loss: L1 norm of horizontal and vertical
            # neighbour differences of the (normalised) output.
            tv_loss = tv_weight * (
                torch.sum(torch.abs(y[:, :, :, :-1] - y[:, :, :, 1:])) +
                torch.sum(torch.abs(y[:, :, :-1, :] - y[:, :, 1:, :])))

            total_loss = content_loss + style_loss + tv_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.item()
            agg_style_loss += style_loss.item()
            agg_tv_loss += tv_loss.item()

            if (batch_id + 1) % args.log_interval == 0:
                ctime = datetime.today().strftime('%Y.%m.%d %H:%M')
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttv: {:.6f}\ttotal: {:.6f}".format(
                    ctime, e + 1, count, len(train_dataset),
                    agg_content_loss / (batch_id + 1),
                    agg_style_loss / (batch_id + 1),
                    agg_tv_loss / (batch_id + 1),
                    (agg_content_loss + agg_style_loss + agg_tv_loss) /
                    (batch_id + 1))
                print(mesg)

            if args.checkpoint_model_dir is not None and (
                    batch_id + 1) % args.checkpoint_interval == 0:
                transformer.eval().cpu()
                ckpt_model_filename = "ckpt_epoch_" + str(
                    e) + "_batch_id_" + str(batch_id + 1) + ".pth"
                ckpt_model_path = os.path.join(args.checkpoint_model_dir,
                                               ckpt_model_filename)
                torch.save(transformer.state_dict(), ckpt_model_path)
                transformer.to(device).train()

    # Save the final model.
    transformer.eval().cpu()
    curr_time = datetime.today().strftime("%Y%m%d_%H%M")
    save_model_filename = "epoch_" + str(
        args.epochs) + "_" + curr_time + ".pth"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)
    print("\nDone, trained model saved at", save_model_path)
def run_train(args):
    """Train a feed-forward style-transfer network against one style image.

    ``build_loss_model`` builds the loss network from a pretrained VGG
    backbone (``args.loss_model``: 'vgg19' or 'vgg16') and the preprocessed
    style image. Content modules capture targets from the input. Style
    modules compare Gram matrices of the output with their stored targets. A
    pixel MSE is added when ``args.pixel_weight > 0``.

    If ``args.checkpoint_model_dir`` is set, a checkpoint is written every
    ``args.checkpoint_interval`` batches. The final weights are saved under
    ``args.save_model_dir``.
    """
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    print('running training processing...')

    style_image = load_image(args.style_image, mask=False,
                             size=args.image_style_size,
                             scale=args.style_scale, square=True)
    style_image = preprocess(style_image)

    # NOTE(review): any other loss_model value leaves cnn as None for
    # build_loss_model; confirm whether that is supported.
    cnn = None
    if args.loss_model == 'vgg19':
        cnn = models.vgg19(pretrained=True).features.to(device).eval()
    elif args.loss_model == 'vgg16':
        cnn = models.vgg16(pretrained=True).features.to(device).eval()

    # Loss network plus its content-loss, style-loss and (possibly None) TV
    # modules.
    loss_net, content_losses, style_losses, tv_loss = build_loss_model(
        cnn, args, style_image)
    # Drop the local reference to the backbone; only loss_net is used below.
    del cnn

    if args.backend == 'cudnn':
        torch.backends.cudnn.enabled = True

    in_channels = 3  # plain RGB input for the transform network

    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
    ])
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size)

    transform_net = TransformerNet(in_channels).to(device)
    mse_loss = nn.MSELoss()
    optimizer = optim.Adam(transform_net.parameters(), lr=args.learning_rate)

    for epoch in range(args.epochs):
        transform_net.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            stloss = 0.
            ctloss = 0.
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()

            # ``x`` feeds the network; ``x_ori`` is a separate copy used as
            # the content target.
            x = x.to(device)
            x_ori = x.clone()
            x = preprocess(x)
            x_ori = preprocess(x_ori)

            y = transform_net(x)

            pixloss = 0.
            if args.pixel_weight > 0:
                pixloss = mse_loss(x, y) * args.pixel_weight

            # Pass 1: content modules capture targets from the input. Style
            # modules are switched to 'None' mode (presumably idle).
            for ctl in content_losses:
                ctl.mode = 'capture'
            for stl in style_losses:
                stl.mode = 'None'
            loss_net(x_ori)
            # Pass 2: run the stylised output through the loss network.
            for ctl in content_losses:
                ctl.mode = 'loss'
            for stl in style_losses:
                stl.mode = 'loss'
            loss_net(y)

            for ctl in content_losses:
                ctloss += mse_loss(ctl.target, ctl.input) * args.content_weight
            for stl in style_losses:
                local_G = gram_matrix(stl.input)
                stloss += mse_loss(local_G, stl.target) * args.style_weight

            # The TV module returned by build_loss_model (``tv_loss``) is
            # deliberately left out of the objective, as in the previous code
            # where it was commented out.
            loss = ctloss + stloss + pixloss
            loss.backward()
            optimizer.step()

            # float() accepts both a one-element tensor and the 0. starting
            # value (e.g. when content_losses is empty), unlike ``.item()``.
            agg_content_loss += float(ctloss)
            agg_style_loss += float(stloss)

            if (batch_id + 1) % args.log_interval == 0:
                # Epochs are reported 1-based, matching the checkpoint names.
                mesg = "{}, Epoch {}:\t[{}/{}], content: {:.6f}, style: {:.6f}, total: {:.6f}".format(
                    time.ctime(), epoch + 1, count, len(train_dataset),
                    agg_content_loss / (batch_id + 1),
                    agg_style_loss / (batch_id + 1),
                    (agg_content_loss + agg_style_loss) / (batch_id + 1))
                print(mesg)

            if args.checkpoint_model_dir is not None and (
                    batch_id + 1) % args.checkpoint_interval == 0:
                transform_net.eval().cpu()
                ckpt_model_filename = "ckpt_epoch_" + str(
                    epoch + 1) + "_batch_id_" + str(batch_id + 1) + ".pth"
                ckpt_model_path = os.path.join(args.checkpoint_model_dir,
                                               ckpt_model_filename)
                torch.save(transform_net.state_dict(), ckpt_model_path)
                transform_net.to(device).train()

    # Save the final model.
    transform_net.eval().cpu()
    save_model_filename = "epoch_" + str(args.epochs) + "_" + str(
        time.ctime()).replace(' ', '_') + "_content_" + str(
            args.content_weight) + "_style_" + str(
                args.style_weight) + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transform_net.state_dict(), save_model_path)
    print("\n training process is Done!, trained model saved at",
          save_model_path)
def train(args):
    """Train a TransformerNet with Horovod data parallelism.

    Every worker trains on its own shard of ``args.dataset`` (via
    ``DistributedSampler``); gradients are averaged by
    ``hvd.DistributedOptimizer``. Rank 0 logs hyper-parameters, writes
    periodic checkpoints and saves the final model, either as a state dict
    (``.pth``) or, with ``args.export_to_onnx``, as an ONNX graph.
    """
    device = torch.device("cuda" if args.cuda else "cpu")

    # log content and style weight parameters
    # (builtin float: the np.float alias was removed in NumPy 1.24)
    if hvd.rank() == 0:
        run.log('content_weight', float(args.content_weight))
        run.log('style_weight', float(args.style_weight))

    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
    train_dataset = datasets.ImageFolder(args.dataset, transform)

    # Horovod: partition dataset among workers using DistributedSampler
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=hvd.size(), rank=hvd.rank())
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size,
                              sampler=train_sampler, **kwargs)

    transformer = TransformerNet().to(device)
    # Horovod: broadcast parameters from rank 0 to all other processes
    hvd.broadcast_parameters(transformer.state_dict(), root_rank=0)
    # Horovod: scale learning rate by the number of GPUs
    optimizer = Adam(transformer.parameters(), args.lr * hvd.size())
    # Horovod: wrap optimizer with DistributedOptimizer
    optimizer = hvd.DistributedOptimizer(
        optimizer, named_parameters=transformer.named_parameters())

    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16(requires_grad=False).to(device)
    style_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Lambda(lambda x: x.mul(255))])
    style = utils.load_image(args.style_image, size=args.style_size)
    style = style_transform(style)
    style = style.repeat(args.batch_size, 1, 1, 1).to(device)
    features_style = vgg(utils.normalize_batch(style))
    gram_style = [utils.gram_matrix(y) for y in features_style]

    print("starting training...")
    for e in range(args.epochs):
        print("epoch {}...".format(e))
        # DistributedSampler derives its shuffle seed from the epoch; without
        # this call every epoch would reuse the same ordering.
        train_sampler.set_epoch(e)
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()

            x = x.to(device)
            y = transformer(x)

            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)

            features_y = vgg(y)
            features_x = vgg(x)

            content_loss = args.content_weight * mse_loss(
                features_y.relu2_2, features_x.relu2_2)

            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gm_y = utils.gram_matrix(ft_y)
                # slice the precomputed style grams to the (possibly smaller) last batch
                style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
            style_loss *= args.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.item()
            agg_style_loss += style_loss.item()

            if (batch_id + 1) % args.log_interval == 0:
                avg_content_loss = agg_content_loss / (batch_id + 1)
                avg_style_loss = agg_style_loss / (batch_id + 1)
                avg_total_loss = (agg_content_loss + agg_style_loss) / (batch_id + 1)
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_sampler),
                    avg_content_loss, avg_style_loss, avg_total_loss)
                print(mesg)

                # log the losses the run history
                run.log('avg_content_loss', float(avg_content_loss))
                run.log('avg_style_loss', float(avg_style_loss))
                run.log('avg_total_loss', float(avg_total_loss))

            if hvd.rank() == 0 and args.checkpoint_model_dir is not None and (
                    batch_id + 1) % args.checkpoint_interval == 0:
                transformer.eval().cpu()
                ckpt_model_filename = "ckpt_epoch_" + str(
                    e) + "_batch_id_" + str(batch_id + 1) + ".pth"
                ckpt_model_path = os.path.join(args.checkpoint_model_dir,
                                               ckpt_model_filename)
                torch.save(transformer.state_dict(), ckpt_model_path)
                transformer.to(device).train()

    # save model (rank 0 only; save_model_path exists only on that rank)
    if hvd.rank() == 0:
        transformer.eval().cpu()
        if args.export_to_onnx:
            # export model to ONNX format
            dummy_input = torch.randn(1, 3, 1024, 1024, device='cpu')
            save_model_path = os.path.join(args.save_model_dir,
                                           '{}.onnx'.format(args.model_name))
            torch.onnx.export(transformer, dummy_input, save_model_path)
        else:
            save_model_path = os.path.join(args.save_model_dir,
                                           '{}.pth'.format(args.model_name))
            torch.save(transformer.state_dict(), save_model_path)

        print("\nDone, trained model saved at", save_model_path)
def train(args):
    """Train a TransformerNet for one style image and save its weights.

    Side effects in this function:

    * increments a run counter stored in ``serialNum.txt``, used as a
      filename prefix;
    * with ``args.mysql``, records the style image and periodic loss
      statistics in MySQL;
    * with ``args.visdom``, plots losses through ``VisdomLinePlotter``;
    * writes the final state dict to ``args.save_model_dir``.

    When ``args.model`` is given, training resumes from that state dict and
    the epoch offset comes from ``getEpoch(args.model)``.
    """
    # Persistent run counter; the first run (no file yet) gets serial 1.
    serialNumFile = "serialNum.txt"
    serial = 0
    if os.path.isfile(serialNumFile):
        with open(serialNumFile, "r") as t:
            serial = int(t.read())
    serial += 1
    with open(serialNumFile, "w") as t:
        t.write(str(serial))

    if args.mysql:
        cnx = mysql.connector.connect(user='******', database='midburn', password='******')
        cursor = cnx.cursor()

    # Last non-empty path component of the dataset directory.
    location = args.dataset.split("/")
    if location[-1] == "":
        location = location[-2]
    else:
        location = location[-1]
    save_model_filename = (str(serial) + "_" + extractName(args.style_image)
                           + "_" + str(args.epochs)
                           + "_" + str(int(args.content_weight))
                           + "_" + str(int(args.style_weight))
                           + "_size_" + str(args.image_size)
                           + "_dataset_" + str(location) + ".model")
    print(save_model_filename)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    m_epoch = 0

    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        kwargs = {'num_workers': 4, 'pin_memory': True}
    else:
        kwargs = {}

    # transforms.Scale was removed from torchvision; Resize is its replacement.
    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size,
                              shuffle=True, **kwargs)

    transformer = TransformerNet()
    #transformer = ResNeXtNet()
    transformer_type = transformer.__class__.__name__
    optimizer = Adam(transformer.parameters(), args.lr)
    if args.l1:
        loss_criterion = torch.nn.L1Loss()
    else:
        loss_criterion = torch.nn.MSELoss()
    loss_type = loss_criterion.__class__.__name__

    if args.visdom:
        vis = VisdomLinePlotter("Style Transfer: " + transformer_type)
    else:
        vis = None

    vgg = Vgg16()
    utils.init_vgg16(args.vgg_model_dir)
    vgg.load_state_dict(
        torch.load(os.path.join(args.vgg_model_dir, "vgg16.weight")))

    if args.cuda:
        transformer.cuda()
        vgg.cuda()

    if args.model is not None:
        # Resume: the output name records the cumulative epoch count.
        transformer.load_state_dict(torch.load(args.model))
        save_model_filename = save_model_filename + "@@@@@@" + str(
            int(getEpoch(args.model)) + int(args.epochs))
        m_epoch += int(getEpoch(args.model))
        print("loaded model\n")

    for param in vgg.parameters():
        param.requires_grad = False

    # Style Gram matrices are fixed targets; compute them without autograd.
    with torch.no_grad():
        style = utils.tensor_load_rgbimage(args.style_image, size=args.style_size)
        style = style.repeat(args.batch_size, 1, 1, 1)
        style = utils.preprocess_batch(style)
        if args.cuda:
            style = style.cuda()
        style = utils.subtract_imagenet_mean_batch(style)
        features_style = vgg(style)
        gram_style = [utils.gram_matrix(y) for y in features_style]
        del features_style
        del style

    # TODO: scheduler and style-loss criterion unused at the moment
    scheduler = StepLR(optimizer, step_size=15000 // args.batch_size)
    style_loss_criterion = torch.nn.CosineSimilarity()
    total_count = 0

    if args.mysql:
        # Parameterized query: the path is passed as data, never spliced into SQL.
        cursor.execute("REPLACE INTO `images`(`name`) VALUES (%s)",
                       (args.style_image,))
        cnx.commit()
        imgId = cursor.lastrowid

    for e in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            total_count += n_batch
            optimizer.zero_grad()
            x = utils.preprocess_batch(x)
            if args.cuda:
                x = x.cuda()
            y = transformer(x)
            y = utils.subtract_imagenet_mean_batch(y)
            xc = utils.subtract_imagenet_mean_batch(x)
            features_y = vgg(y)
            f_xc_c = vgg.content_features(xc)
            content_loss = args.content_weight * loss_criterion(
                features_y[1], f_xc_c)

            style_loss = 0.
            for m in range(len(features_y)):
                gram_s = gram_style[m]
                gram_y = utils.gram_matrix(features_y[m])
                style_loss += loss_criterion(gram_y, gram_s[:n_batch, :, :])
                #style_loss -= style_loss_criterion(gram_y, gram_s[:n_batch, :, :])
            style_loss *= args.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()
            # TODO: enable
            #scheduler.step()

            agg_content_loss += content_loss.item()
            agg_style_loss += style_loss.item()

            if (batch_id + 1) % args.log_interval == 0:
                avg_content = agg_content_loss / (batch_id + 1)
                avg_style = agg_style_loss / (batch_id + 1)
                avg_total = (agg_content_loss + agg_style_loss) / (batch_id + 1)
                if args.mysql:
                    cursor.execute(
                        "REPLACE INTO `statistics`(`imgId`,`epoch`, `iteration_id`, "
                        "`content_loss`, `style_loss`, `loss`) "
                        "VALUES (%s, %s, %s, %s, %s, %s)",
                        (imgId, int(e) + m_epoch, batch_id,
                         avg_content, avg_style, avg_total))
                    cnx.commit()
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}\n".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                    avg_content, avg_style, avg_total)
                # flush after printing so the message itself reaches the log
                print(mesg)
                sys.stdout.flush()
                # NOTE(review): original layout was lost; plotting is placed
                # at log-interval granularity — confirm.
                if vis is not None:
                    vis.plot(loss_type, "Content Loss", total_count, content_loss.item())
                    vis.plot(loss_type, "Style Loss", total_count, style_loss.item())
                    vis.plot(loss_type, "Total Loss", total_count, total_loss.item())

    if args.mysql:
        cursor.close()
        cnx.close()

    # save model
    transformer.eval()
    transformer.cpu()
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
def train(args):
    """Train a TransformerNet with OneFlow, one image per step.

    Images are listed from ``args.dataset`` (files ending in "txt" are
    skipped), shuffled once, and fed individually (``n_batch = 1``). Every
    ``args.log_interval`` steps a loss line is printed and, with
    ``args.style_log_dir``, a style | input | output strip is written with
    cv2. Checkpoints and the final model are written with ``flow.save``.

    NOTE(review): ``random`` is not seeded here (only ``np.random`` is), so
    the image order is not reproducible from ``args.seed`` — confirm intent.
    """
    # Unused: devices below are hard-coded as "cuda".
    device = "cuda"
    np.random.seed(args.seed)
    # load path of train images
    train_images = os.listdir(args.dataset)
    train_images = [
        image for image in train_images if not image.endswith("txt")
    ]
    random.shuffle(train_images)
    images_num = len(train_images)
    print("dataset size: %d" % images_num)
    # Initialize transformer net, optimizer, and loss function
    transformer = TransformerNet().to("cuda")
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = flow.nn.MSELoss()

    if args.load_checkpoint_dir is not None:
        state_dict = flow.load(args.load_checkpoint_dir)
        transformer.load_state_dict(state_dict)
        print("successfully load checkpoint from " + args.load_checkpoint_dir)

    # load pretrained vgg16 (or vgg19 when args.vgg == "vgg19")
    if args.vgg == "vgg19":
        vgg = vgg19(pretrained=True)
    else:
        vgg = vgg16(pretrained=True)
    vgg = VGG_WITH_FEATURES(vgg.features, requires_grad=False)
    vgg.to("cuda")

    style_image = utils.load_image(args.style_image)
    # kept for the visual log strip below
    style_image_recover = recover_image(style_image)
    features_style = vgg(
        utils.normalize_batch(flow.Tensor(style_image).to("cuda")))
    gram_style = [utils.gram_matrix(y) for y in features_style]

    for e in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.0
        agg_style_loss = 0.0
        count = 0
        for i in range(images_num):
            image = load_image("%s/%s" % (args.dataset, train_images[i]))
            n_batch = 1
            count += n_batch

            x_gpu = flow.tensor(image, requires_grad=True).to("cuda")
            y_origin = transformer(x_gpu)

            x_gpu = utils.normalize_batch(x_gpu)
            y = utils.normalize_batch(y_origin)

            features_x = vgg(x_gpu)
            features_y = vgg(y)
            content_loss = args.content_weight * mse_loss(
                features_y.relu2_2, features_x.relu2_2)
            style_loss = 0.0
            for ft_y, gm_s in zip(features_y, gram_style):
                gm_y = utils.gram_matrix(ft_y)
                style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
            style_loss *= args.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()
            # gradients are cleared after the step, ready for the next image
            optimizer.zero_grad()

            agg_content_loss += content_loss.numpy()
            agg_style_loss += style_loss.numpy()
            if (i + 1) % args.log_interval == 0:
                if args.style_log_dir is not None:
                    # Side-by-side strip: style | input | output. The path is
                    # plain string concatenation, so style_log_dir needs its
                    # own trailing separator.
                    y_recover = recover_image(y_origin.numpy())
                    image_recover = recover_image(image)
                    result = np.concatenate(
                        (style_image_recover, image_recover), axis=1)
                    result = np.concatenate((result, y_recover), axis=1)
                    cv2.imwrite(args.style_log_dir + str(i + 1) + ".jpg", result)
                    print(args.style_log_dir + str(i + 1) + ".jpg" + " saved")
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(),
                    e + 1,
                    count,
                    images_num,
                    agg_content_loss / (i + 1),
                    agg_style_loss / (i + 1),
                    (agg_content_loss + agg_style_loss) / (i + 1),
                )
                print(mesg)

            if (args.checkpoint_model_dir is not None
                    and (i + 1) % args.checkpoint_interval == 0):
                transformer.eval()
                ckpt_model_filename = ("CW_" + str(int(args.content_weight)) +
                                       "_lr_" + str(args.lr) + "ckpt_epoch" +
                                       str(e) + "_" + str(i + 1))
                ckpt_model_path = os.path.join(args.checkpoint_model_dir,
                                               ckpt_model_filename)
                flow.save(transformer.state_dict(), ckpt_model_path)
                transformer.train()

    # save model
    transformer.eval()
    save_model_filename = ("CW_" + str(args.content_weight) + "_lr_" +
                           str(args.lr) + "sketch_epoch_" + str(args.epochs) +
                           "_" + str(time.ctime()).replace(" ", "_") + "_" +
                           str(args.content_weight) + "_" +
                           str(args.style_weight))
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    flow.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
def train(args): log(json.dumps({"type": "status_update", "status": "Setting up training"})) device = torch.device("cuda" if args.cuda else "cpu") np.random.seed(args.seed) torch.manual_seed(args.seed) transform = transforms.Compose([ transforms.Resize(args.image_size), transforms.CenterCrop(args.image_size), transforms.ToTensor(), transforms.Lambda(lambda x: x.mul(255)) ]) log(json.dumps({"type": "status_update", "status": "Loading dataset"})) train_dataset = datasets.ImageFolder(args.dataset, transform) log( json.dumps({ "type": "dataset_info", "dataset_length": len(train_dataset) * args.epochs })) train_loader = DataLoader(train_dataset, batch_size=args.batch_size) transformer = TransformerNet().to(device) optimizer = Adam(transformer.parameters(), args.lr) mse_loss = torch.nn.MSELoss() vgg = Vgg16(requires_grad=False).to(device) style_transform = transforms.Compose( [transforms.ToTensor(), transforms.Lambda(lambda x: x.mul(255))]) log(json.dumps({"type": "status_update", "status": "Dataset loaded"})) log(json.dumps({"type": "status_update", "status": "Loading image"})) style = utils.load_image(args.style_image, size=args.style_size) style = style_transform(style) style = style.repeat(args.batch_size, 1, 1, 1).to(device) features_style = vgg(utils.normalize_batch(style)) gram_style = [utils.gram_matrix(y) for y in features_style] log(json.dumps({"type": "status_update", "status": "Image loaded"})) log(json.dumps({"type": "status_update", "status": "Training setup done"})) progress_count = 0 log(json.dumps({"type": "status_update", "status": "Starting training"})) for e in range(args.epochs): transformer.train() agg_content_loss = 0. agg_style_loss = 0. 
count = 0 for batch_id, (x, _) in enumerate(train_loader): n_batch = len(x) count += n_batch optimizer.zero_grad() x = x.to(device) y = transformer(x) y = utils.normalize_batch(y) x = utils.normalize_batch(x) features_y = vgg(y) features_x = vgg(x) content_loss = args.content_weight * mse_loss( features_y.relu2_2, features_x.relu2_2) style_loss = 0. for ft_y, gm_s in zip(features_y, gram_style): gm_y = utils.gram_matrix(ft_y) style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :]) style_loss *= args.style_weight total_loss = content_loss + style_loss total_loss.backward() optimizer.step() agg_content_loss += content_loss.item() agg_style_loss += style_loss.item() log( json.dumps({ "type": "training_progress", "progress": str(progress_count), "percent": str( round( progress_count / (len(train_dataset) * args.epochs) * 100, 2)) })) progress_count = progress_count + args.batch_size if args.checkpoint_model_dir is not None and ( batch_id + 1) % args.checkpoint_interval == 0: transformer.eval().cpu() if args.name is None: ckpt_model_filename = str( os.path.normpath(os.path.basename( args.style_image))[0:int( os.path. normpath(os.path.basename(args.style_image)). 
rfind("."))]) + "_" + str(batch_id + 1) + ".pth" else: ckpt_model_filename = str( args.name) + "_" + str(batch_id + 1) + ".pth" ckpt_model_path = os.path.join(args.checkpoint_model_dir, ckpt_model_filename) torch.save(transformer.state_dict(), ckpt_model_path) transformer.to(device).train() log(json.dumps({"type": "status_update", "status": "training done"})) # save model log(json.dumps({"type": "status_update", "status": "saving model"})) transformer.eval().cpu() if args.name is None: save_model_filename = str( os.path.normpath(os.path.basename(args.style_image))[0:int( os.path.normpath(os.path.basename(args.style_image)).rfind(".") )]) + ".pth" else: save_model_filename = str(args.name + ".pth") save_model_path = os.path.join(args.save_model_dir, save_model_filename) torch.save(transformer.state_dict(), save_model_path) log(json.dumps({"type": "status_update", "status": "model saved"}))
def train(args):
    """Train a TransformerNet and periodically dump output snapshots.

    A loss line is printed after every batch. With ``args.is_quickrun`` an
    epoch stops once more than 10 images have been seen. At epochs divisible
    by 50 (and every 10th epoch past 400) the last batch's normalized output
    is written to ``./imgs/before``, ``./imgs/non`` and ``./imgs/after``. The
    final model is saved to ``args.save_model_dir``, which is created if
    missing.
    """
    os.makedirs(args.save_model_dir, exist_ok=True)
    device = torch.device("cuda" if args.is_cuda else "cpu")

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size)

    transformer = TransformerNet().to(device)
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16(requires_grad=False).to(device)
    style_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    style = utils.load_image(args.style_image, size=args.style_size)
    style = style_transform(style)
    style = style.repeat(args.batch_size, 1, 1, 1).to(device)
    features_style = vgg(utils.normalize_batch(style))
    gram_style = [utils.gram_matrix(y) for y in features_style]

    for e in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        # last (normalized) output batch of this epoch, used for snapshots
        y = None
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()

            x = x.to(device)
            y = transformer(x)

            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)

            features_y = vgg(y)
            features_x = vgg(x)

            content_loss = args.content_weight * mse_loss(features_y.relu2_2, features_x.relu2_2)

            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gm_y = utils.gram_matrix(ft_y)
                style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
            style_loss *= args.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.item()
            agg_style_loss += style_loss.item()

            # Logged after every batch (args.log_interval is not consulted).
            mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                time.ctime(), e + 1, count, len(train_dataset),
                agg_content_loss / (batch_id + 1),
                agg_style_loss / (batch_id + 1),
                (agg_content_loss + agg_style_loss) / (batch_id + 1)
            )
            print(mesg)
            if args.is_quickrun and count > 10:
                break

        if y is not None and ((e % 50 == 0) or (e > 400 and e % 10 == 0)):
            # save_image does not create parent directories.
            for snapshot_dir in ('./imgs/before', './imgs/non', './imgs/after'):
                os.makedirs(snapshot_dir, exist_ok=True)
            snapshot = y.detach()
            torchvision.utils.save_image(snapshot, './imgs/before/epoch_{}.png'.format(e), normalize=True)
            snapshot = snapshot.clamp(0, 255)
            torchvision.utils.save_image(snapshot, './imgs/non/epoch_{}.png'.format(e))
            torchvision.utils.save_image(snapshot, './imgs/after/epoch_{}.png'.format(e), normalize=True)

    # save model
    transformer.eval().cpu()
    save_model_filename = "style_" + args.style_name + "_epoch_" + str(args.epochs) + "_" + str(time.ctime()).replace(' ', '_') + "_" + str(
        args.content_weight) + "_" + str(args.style_weight) + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
def train(start_epoch=0):
    """Train a TransformerNet from the ``enums`` configuration, with resume support.

    When ``enums.subcommand == 'resume'``, model weights, optimizer state and
    the starting epoch are restored from ``enums.checkpoint_model``.
    Checkpoints (weights, optimizer, epoch) are written every
    ``enums.checkpoint_interval`` epochs, and the final checkpoint dict is
    saved to ``enums.save_model_dir``.

    Args:
        start_epoch: first epoch index; overridden by the checkpoint on resume.
    """
    np.random.seed(enums.seed)
    torch.manual_seed(enums.seed)

    if enums.cuda:
        torch.cuda.manual_seed(enums.seed)

    transform = transforms.Compose([
        transforms.Resize(enums.image_size),
        transforms.CenterCrop(enums.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    train_dataset = datasets.ImageFolder(enums.dataset, transform)
    train_loader = DataLoader(train_dataset, batch_size=enums.batch_size)

    transformer = TransformerNet()
    # Move to the GPU *before* restoring optimizer state: load_state_dict
    # casts that state to the parameters' current device, so loading while
    # the parameters are still on the CPU would leave Adam's buffers there.
    if enums.cuda:
        transformer.cuda()
    optimizer = Adam(transformer.parameters(), enums.lr)
    if enums.subcommand == 'resume':
        ckpt_state = torch.load(enums.checkpoint_model)
        transformer.load_state_dict(ckpt_state['state_dict'])
        start_epoch = ckpt_state['epoch']
        optimizer.load_state_dict(ckpt_state['optimizer'])

    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16(requires_grad=False)
    style_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Lambda(lambda x: x.mul(255))])
    style = utils.load_image(enums.style_image, size=enums.style_size)
    style = style_transform(style)
    style = style.expand(enums.batch_size, *style.size())  # N,C,H,W

    if enums.cuda:
        vgg.cuda()
        style = style.cuda()

    style_v = Variable(style)
    style_v = utils.normalize_batch(style_v)
    features_style = vgg(style_v)
    gram_style = [utils.gram_matrix(y) for y in features_style]

    # Epoch count recorded in the final checkpoint; stays valid even when the
    # loop below runs zero times (e.g. resuming an already finished run).
    completed_epochs = start_epoch
    for e in range(start_epoch, enums.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()
            x = Variable(x)
            if enums.cuda:
                x = x.cuda()

            y = transformer(x)

            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)

            features_y = vgg(y)
            features_x = vgg(x)

            content_loss = enums.content_weight * mse_loss(
                features_y.relu2_2, features_x.relu2_2)

            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gm_y = utils.gram_matrix(ft_y)
                style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
            style_loss *= enums.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            # .item(): indexing a 0-dim loss tensor (.data[0]) raises IndexError
            agg_content_loss += content_loss.item()
            agg_style_loss += style_loss.item()

            if (batch_id + 1) % enums.log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                    agg_content_loss / (batch_id + 1),
                    agg_style_loss / (batch_id + 1),
                    (agg_content_loss + agg_style_loss) / (batch_id + 1))
                print(mesg)

        completed_epochs = e + 1
        if enums.checkpoint_model_dir is not None and (
                e + 1) % enums.checkpoint_interval == 0:
            # Save CPU weights so the checkpoint loads on CPU-only machines.
            if enums.cuda:
                transformer.cpu()
            ckpt_model_filename = "ckpt_epoch_" + str(e + 1) + ".pth"
            ckpt_model_path = os.path.join(enums.checkpoint_model_dir,
                                           ckpt_model_filename)
            save_checkpoint(
                {
                    'epoch': e + 1,
                    'state_dict': transformer.state_dict(),
                    'optimizer': optimizer.state_dict()
                }, ckpt_model_path)
            if enums.cuda:
                transformer.cuda()

    # save model
    if enums.cuda:
        transformer.cpu()
    save_model_filename = "epoch_" + str(enums.epochs) + "_" + str(
        time.ctime()).replace(' ', '_') + "_" + str(
            enums.content_weight) + "_" + str(enums.style_weight) + ".model"
    save_model_path = os.path.join(enums.save_model_dir, save_model_filename)
    save_checkpoint(
        {
            'epoch': completed_epochs,
            'state_dict': transformer.state_dict(),
            'optimizer': optimizer.state_dict()
        }, save_model_path)

    print("\nDone, trained model saved at", save_model_path)
def train(args):
    """Train a TransformerNet for one style image (fast neural style transfer).

    Content loss compares VGG ``relu2_2`` features of input and output; style
    loss compares Gram matrices of every VGG feature map with those of the
    style image. Optional checkpoints go to ``args.checkpoint_model_dir``; the
    final state dict goes to ``args.save_model_dir``.
    """
    # Seed every RNG so equal args.seed values give repeatable runs.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))  # 0-1 to 0-255
    ])
    # image root + per-image transform, then batching
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size)

    transformer = TransformerNet()
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16(requires_grad=False)
    style_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    style = utils.load_image(args.style_image, size=args.style_size)
    style = style_transform(style)
    # one copy of the style image per batch element
    style = style.repeat(args.batch_size, 1, 1, 1)

    if args.cuda:
        transformer.cuda()
        vgg.cuda()
        style = style.cuda()

    style_v = Variable(style)
    style_v = utils.normalize_batch(style_v)
    features_style = vgg(style_v)
    # style targets: Gram matrices of the style image's feature maps
    gram_style = [utils.gram_matrix(y) for y in features_style]

    for e in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()  # PyTorch accumulates gradients; clear them per minibatch
            x = Variable(x)
            if args.cuda:
                x = x.cuda()

            # forward pass
            y = transformer(x)

            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)

            features_y = vgg(y)
            features_x = vgg(x)

            # TODO: the layer used for the content loss could be made configurable
            content_loss = args.content_weight * mse_loss(features_y.relu2_2, features_x.relu2_2)

            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gm_y = utils.gram_matrix(ft_y)
                style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
            style_loss *= args.style_weight

            total_loss = content_loss + style_loss
            # backward pass computes gradients for every learnable parameter
            total_loss.backward()
            # update weights
            optimizer.step()

            # .item(): indexing a 0-dim loss tensor (.data[0]) raises IndexError
            agg_content_loss += content_loss.item()
            agg_style_loss += style_loss.item()

            if (batch_id + 1) % args.log_interval == 0:
                msg = "Epoch "+str(e + 1)+" "+str(count)+"/"+str(len(train_dataset))
                msg += " content loss : "+str(agg_content_loss / (batch_id + 1))
                msg += " style loss : " +str(agg_style_loss / (batch_id + 1))
                msg += " total loss : " +str((agg_content_loss + agg_style_loss) / (batch_id + 1))
                print(msg)

            if args.checkpoint_model_dir is not None and (batch_id + 1) % args.checkpoint_interval == 0:
                transformer.eval()
                if args.cuda:
                    transformer.cpu()
                ckpt_model_filename = "ckpt_epoch_" + str(e) + "_batch_id_" + str(batch_id + 1) + ".pth"
                ckpt_model_path = os.path.join(args.checkpoint_model_dir, ckpt_model_filename)
                torch.save(transformer.state_dict(), ckpt_model_path)
                if args.cuda:
                    transformer.cuda()
                transformer.train()

    # save model
    transformer.eval()
    if args.cuda:
        transformer.cpu()
    save_model_filename = "epoch_" + str(args.epochs) + "_" + str(time.ctime()).replace(' ', '_') + "_" + str(
        args.content_weight) + "_" + str(args.style_weight) + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
def train(args):
    """Train a TransformerNet to predict one color channel from the other two.

    Images are converted by the ``Linearize -> SRGB2XYZ -> XYZ2CIE`` pipeline
    from ``transform.color_op``. Channels 0 and 2 form the 2-channel input,
    and channel 1 is the regression target (plain MSE). The final state dict
    (of the ``nn.DataParallel`` wrapper) is saved to ``args.save_model_dir``.

    Raises:
        RuntimeError: if ``args.cuda`` is set but CUDA is unavailable.
    """
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        kwargs = {'num_workers': 12, 'pin_memory': False}
    else:
        kwargs = {}

    from transform.color_op import Linearize, SRGB2XYZ, XYZ2CIE
    RGB2YUV = transforms.Compose([
        Linearize(),
        SRGB2XYZ(),
        XYZ2CIE()
    ])
    # RGB2YUV is already a callable transform: put it in the pipeline as-is.
    # (Calling it with no image, RGB2YUV(), raises TypeError.)
    # NOTE(review): it runs on the CenterCrop output, before ToTensor —
    # confirm that the color_op transforms accept that input type.
    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        RGB2YUV,
        transforms.ToTensor(),
    ])
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, **kwargs)

    transformer = TransformerNet(in_channels=2, out_channels=1)  # input: LS, predict: M
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    transformer = nn.DataParallel(transformer)

    if args.cuda:
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA is requested, but related driver/device is not set properly.")
        transformer.cuda()

    for e in range(args.epochs):
        transformer.train()
        agg_loss = 0.
        count = 0
        for batch_id, (imgs, _) in enumerate(train_loader):
            n_batch = len(imgs)
            count += n_batch
            optimizer.zero_grad()

            # Input: the first and last channels stacked along dim 1
            x = torch.cat([imgs[:, :1, :, :].clone(), imgs[:, -1:, :, :].clone()], dim=1)
            # Target: the middle (second) channel
            gt = imgs[:, 1:2, :, :].clone()

            if args.cuda:
                x = x.cuda()
                gt = gt.cuda()

            y = transformer(x)

            total_loss = mse_loss(y, gt)
            total_loss.backward()
            optimizer.step()

            # Running mean over the epoch, not current loss / batch count.
            agg_loss += total_loss.item()

            if (batch_id + 1) % args.log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                    agg_loss / (batch_id + 1)
                )
                print(mesg)

    # save model
    transformer.eval()
    transformer.cpu()
    save_model_filename = "epoch_" + str(args.epochs) + "_" + str(time.ctime()).replace(' ', '_') + "_" + str(
        args.content_weight) + "_" + str(args.style_weight) + ".model"
    os.makedirs(args.save_model_dir, exist_ok=True)
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
def train(args):
    """Train a TransformerNet on ``dataset.CustomImageDataset`` (CUDA only).

    Content loss uses VGG feature index 2. Style loss averages, over the
    batch, the MSE between each sample's Gram matrix and the single style
    image's Gram matrix. Batches whose size differs from ``args.batch_size``
    are skipped. Every 1000 batches an input/output image pair is written to
    ``args.save_model_dir``.
    """
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    kwargs = {'num_workers': 0, 'pin_memory': False}

    transform = transforms.Compose([
        transforms.Resize((args.image_size, args.image_size)),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    train_dataset = dataset.CustomImageDataset(args.dataset, transform=transform, img_size=args.image_size)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, **kwargs)

    transformer = TransformerNet(args.pad_type)
    transformer = transformer.train()
    optimizer = torch.optim.Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()
    #print(transformer)

    vgg = Vgg16()
    vgg.load_state_dict(
        torch.load(os.path.join(args.vgg_model_dir, "vgg16.weight")))
    vgg.eval()

    # Unconditional .cuda(): this function requires a GPU.
    transformer = transformer.cuda()
    vgg = vgg.cuda()

    style = utils.tensor_load_resize(args.style_image, args.style_size)
    style = style.unsqueeze(0)
    print("=> Style image size: " + str(style.size()))  # author's note: (1, H, W, C)
    style = utils.preprocess_batch(style).cuda()
    utils.tensor_save_bgrimage(
        style[0].detach(), os.path.join(args.save_model_dir, 'train_style.jpg'), True)
    style = utils.subtract_imagenet_mean_batch(style)
    features_style = vgg(style)
    # detached: the style Gram matrices are fixed targets
    gram_style = [utils.gram_matrix(y).detach() for y in features_style]

    for e in range(args.epochs):
        train_loader.dataset.reset()
        agg_content_loss = 0.
        agg_style_loss = 0.
        iters = 0
        for batch_id, (x, _) in enumerate(train_loader):
            if x.size(0) != args.batch_size:
                print("=> Skip incomplete batch")
                continue
            iters += 1

            optimizer.zero_grad()
            x = utils.preprocess_batch(x).cuda()
            y = transformer(x)

            # Every 1000 batches, dump the first sample's output and input.
            if (batch_id + 1) % 1000 == 0:
                idx = (batch_id + 1) // 1000
                utils.tensor_save_bgrimage(
                    y.data[0], os.path.join(args.save_model_dir, "out_%d.png" % idx), True)
                utils.tensor_save_bgrimage(
                    x.data[0], os.path.join(args.save_model_dir, "in_%d.png" % idx), True)

            y = utils.subtract_imagenet_mean_batch(y)
            x = utils.subtract_imagenet_mean_batch(x)

            features_y = vgg(y)
            # crop the input to the output's spatial size before comparing
            features_x = vgg(center_crop(x, y.size(2), y.size(3)))

            #content target
            f_x = features_x[2].detach()
            # content
            f_y = features_y[2]

            content_loss = args.content_weight * mse_loss(f_y, f_x)

            style_loss = 0.
            for m in range(len(features_y)):
                gram_s = gram_style[m]
                gram_y = utils.gram_matrix(features_y[m])
                # every sample is compared with the single style Gram matrix
                batch_style_loss = 0
                for n in range(gram_y.shape[0]):
                    batch_style_loss += args.style_weight * mse_loss(
                        gram_y[n], gram_s[0])
                style_loss += batch_style_loss / gram_y.shape[0]

            total_loss = content_loss + style_loss

            total_loss.backward()
            optimizer.step()
            agg_content_loss += content_loss.data
            agg_style_loss += style_loss.data

            # NOTE(review): as laid out, this prints and resets the
            # accumulators after every batch, so the "average" covers a single
            # batch — the original indentation was lost; confirm whether this
            # belonged under a log-interval check.
            mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                time.ctime(), e + 1, batch_id + 1, len(train_loader),
                agg_content_loss / iters,
                agg_style_loss / iters,
                (agg_content_loss + agg_style_loss) / iters)
            print(mesg)
            agg_content_loss = agg_style_loss = 0.0
            iters = 0

        # save model
        # NOTE(review): placed at epoch level because the filename uses the
        # epoch index e; the original nesting was lost — confirm.
        save_model_filename = "epoch_" + str(e) + "_" + str(
            args.content_weight) + "_" + str(args.style_weight) + ".model"
        save_model_path = os.path.join(args.save_model_dir, save_model_filename)
        torch.save(transformer.state_dict(), save_model_path)

        print("\nDone, trained model saved at", save_model_path)
def train(args):
    """Train a TransformerNet to match the targets yielded by ``VFDataset``.

    Each batch yields ``(x, vf)``. The network output ``y = transformer(x)``
    is compared with ``vf`` through ``CosineEmbeddingLoss`` with an all-ones
    target, i.e. the loss is ``1 - cos(y, vf)`` and pushes the output toward
    the target's direction.

    Side effects:
        Writes ``ckpt_epoch_*_batch_id_*.pth`` every ``args.checkpoint_interval``
        batches when ``args.checkpoint_model_dir`` is set.
        Writes the final state dict to ``args.save_model_dir``.
    """
    device = torch.device("cuda" if args.cuda else "cpu")
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    target_transform = transforms.ToTensor()
    train_dataset = VFDataset(args.dataset, transform, target_transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size)

    transformer = TransformerNet().to(device)
    if args.load_model is not None:
        transformer.load_state_dict(torch.load(args.load_model))
    optimizer = Adam(transformer.parameters(), args.lr)
    cosine_loss = torch.nn.CosineEmbeddingLoss()
    # All-ones target: every position should be maximally similar.
    # It is sized for a full batch and sliced per batch below, so the final,
    # possibly smaller, batch still lines up with y and vf.
    label = torch.ones(args.batch_size, 1, args.image_size, args.image_size).to(device)

    for e in range(args.epochs):
        transformer.train()
        agg_loss = 0.
        count = 0
        for batch_id, (x, vf) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()

            x = utils.subtract_imagenet_mean_batch(x)
            x = x.to(device)
            y = transformer(x)
            vf = vf.to(device)

            loss = cosine_loss(y, vf, label[:n_batch])
            loss.backward()
            optimizer.step()

            agg_loss += loss.item()
            if (batch_id + 1) % args.log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                    agg_loss / (batch_id + 1))
                print(mesg)

            if args.checkpoint_model_dir is not None and (
                    batch_id + 1) % args.checkpoint_interval == 0:
                # Save CPU weights in eval mode, then resume on the device.
                transformer.eval().cpu()
                ckpt_model_filename = "ckpt_epoch_" + str(
                    e) + "_batch_id_" + str(batch_id + 1) + ".pth"
                ckpt_model_path = os.path.join(args.checkpoint_model_dir,
                                               ckpt_model_filename)
                torch.save(transformer.state_dict(), ckpt_model_path)
                transformer.to(device).train()

    # save model
    transformer.eval().cpu()
    save_model_filename = "epoch_" + str(args.epochs) + "_" + str(
        time.ctime()).replace(' ', '_') + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)
    print("\nDone, trained model saved at", save_model_path)
def train(args):
    """Train a single-channel TransformerNet with a Gram-matrix style loss.

    Training samples are the columns of the matrix loaded from
    ``args.dataset``. Each column is reshaped to
    ``(1, args.image_size_x, args.image_size_y)``, and a trailing partial
    batch is dropped.

    The style target is a single-channel array from ``args.style_image``,
    replicated to three channels for VGG. Only the style term (plus the
    optional hard-data term) is optimised. The content loss is disabled, so
    the logged content loss is always 0.

    When ``args.hard_data`` is set, each row ``hd`` of ``args.hard_data_file``
    pins output pixel ``y[:, 0, hd[1], hd[0]]`` toward ``hd[2] * 255`` with
    weight ``args.hard_data_weight``.
    """
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    # Each column is one flattened training sample.
    training_set = np.loadtxt(args.dataset, dtype=np.float32)
    training_set_size = training_set.shape[1]
    num_batch = int(training_set_size / args.batch_size)

    transformer = TransformerNet()
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16()
    utils.init_vgg16(args.vgg_model_dir)
    vgg.load_state_dict(
        torch.load(os.path.join(args.vgg_model_dir, "vgg16.weight")))

    if args.cuda:
        transformer.cuda()
        vgg.cuda()

    style = np.loadtxt(args.style_image, dtype=np.float32)
    style = style.reshape((1, 1, args.style_size_x, args.style_size_y))
    style = torch.from_numpy(style)
    # Replicate across the batch and to 3 channels for VGG.
    style = style.repeat(args.batch_size, 3, 1, 1)
    if args.cuda:
        style = style.cuda()
    # Style targets are constants: compute them without building a graph
    # (replaces the deprecated Variable(..., volatile=True)).
    with torch.no_grad():
        style_v = utils.subtract_imagenet_mean_batch(style)
        features_style = vgg(style_v)
        gram_style = [utils.gram_matrix(y) for y in features_style]

    # Hard data
    hard_data = None
    if args.hard_data:
        # ndmin=2 keeps a one-row file as shape (1, k), so iteration below
        # still yields whole rows rather than scalars.
        hard_data = np.loadtxt(args.hard_data_file, ndmin=2)

    for e in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id in range(num_batch):
            x = training_set[:, batch_id * args.batch_size:(batch_id + 1) * args.batch_size]
            n_batch = x.shape[1]
            count += n_batch
            x = x.transpose()
            x = x.reshape((n_batch, 1, args.image_size_x, args.image_size_y))
            x = torch.from_numpy(x).float()

            optimizer.zero_grad()
            if args.cuda:
                x = x.cuda()
            y = transformer(x)

            hard_data_loss = None
            if hard_data is not None and len(hard_data) > 0:
                hard_data_loss = 0
                for hd in hard_data:
                    # np.loadtxt yields floats; tensor indices must be ints.
                    hard_data_loss = hard_data_loss + args.hard_data_weight * (
                        y[:, 0, int(hd[1]), int(hd[0])] - hd[2] * 255.0).norm() ** 2 / n_batch
                hard_data_loss = hard_data_loss / len(hard_data)

            # VGG expects 3-channel input.
            y = y.repeat(1, 3, 1, 1)
            y = utils.subtract_imagenet_mean_batch(y)
            features_y = vgg(y)

            style_loss = 0.
            for m in range(len(features_y)):
                gram_s = gram_style[m]
                gram_y = utils.gram_matrix(features_y[m])
                style_loss += args.style_weight * mse_loss(gram_y, gram_s[:n_batch, :, :])

            # Build a new tensor instead of "total_loss += ...". That in-place
            # add would alias and mutate style_loss and corrupt its logged value.
            total_loss = style_loss
            if hard_data_loss is not None:
                total_loss = total_loss + hard_data_loss
            total_loss.backward()
            optimizer.step()

            agg_style_loss += style_loss.item()

            if (batch_id + 1) % args.log_interval == 0:
                # count is a sample count, so compare it against the sample total.
                if args.hard_data:
                    hd_value = hard_data_loss.item() if hard_data_loss is not None else 0.
                    mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\thard_data: {:.6f}\ttotal: {:.6f}".format(
                        time.ctime(), e + 1, count, training_set_size,
                        agg_content_loss / (batch_id + 1),
                        agg_style_loss / (batch_id + 1),
                        hd_value,
                        (agg_content_loss + agg_style_loss) / (batch_id + 1))
                else:
                    mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                        time.ctime(), e + 1, count, training_set_size,
                        agg_content_loss / (batch_id + 1),
                        agg_style_loss / (batch_id + 1),
                        (agg_content_loss + agg_style_loss) / (batch_id + 1))
                print(mesg)

    # save model
    transformer.eval()
    transformer.cpu()
    save_model_filename = "epoch_" + str(args.epochs) + "_" + str(
        time.ctime()).replace(' ', '_') + "_" + str(
        args.content_weight) + "_" + str(args.style_weight) + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)
    print("\nDone, trained model saved at", save_model_path)
def train(args):
    """Meta train the model.

    Each outer iteration (``args.max_iter`` total) does the following:

    1. Draws one content batch and one query batch.
    2. For each of ``args.meta_batch_size`` sampled styles, runs
       ``args.meta_step`` differentiable inner SGD steps (step size ``lr``)
       on the parameters whose names match ``in<digits>.``, using the
       content batch.
    3. Evaluates the adapted "fast weights" on the query batch and collects
       that loss's gradient w.r.t. all transformer parameters.
    4. Passes the collected gradients to ``meta_updates`` together with a
       dummy loss.

    Averages of the inner (train) and query (eval) losses are logged to
    TensorBoard through ``SummaryWriter``.
    """
    device = torch.device("cuda" if args.cuda else "cpu")

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    # first move parameters to GPU
    transformer = TransformerNet().to(device)
    vgg = Vgg16(requires_grad=False).to(device)
    # optimizer / mse_loss are published as module globals; presumably read
    # by meta_updates / loss_fn -- TODO confirm.
    global optimizer
    optimizer = Adam(transformer.parameters(), args.meta_lr)
    global mse_loss
    mse_loss = torch.nn.MSELoss()

    content_loader, style_loader, query_loader = get_data_loader(args)

    content_weight = args.content_weight
    style_weight = args.style_weight
    lr = args.lr

    writer = SummaryWriter(args.log_dir)

    for iteration in trange(args.max_iter):
        transformer.train()

        # bookkeeping
        # using state_dict causes problems, use named_parameters instead
        all_meta_grads = []
        avg_train_c_loss = 0.0
        avg_train_s_loss = 0.0
        avg_train_loss = 0.0
        avg_eval_c_loss = 0.0
        avg_eval_s_loss = 0.0
        avg_eval_loss = 0.0

        contents = content_loader.next()[0].to(device)
        features_contents = vgg(utils.normalize_batch(contents))

        querys = query_loader.next()[0].to(device)
        features_querys = vgg(utils.normalize_batch(querys))

        # learning rate scheduling: inverse-time decay of inner and meta LR
        lr = args.lr / (1.0 + iteration * 2.5e-5)
        meta_lr = args.meta_lr / (1.0 + iteration * 2.5e-5)
        for param_group in optimizer.param_groups:
            param_group['lr'] = meta_lr

        for i in range(args.meta_batch_size):
            # sample a style
            style = style_loader.next()[0].to(device)
            style = style.repeat(args.iter_batch_size, 1, 1, 1)
            features_style = vgg(utils.normalize_batch(style))
            gram_style = [utils.gram_matrix(y) for y in features_style]

            # Only parameters whose names match "in<digits>." are adapted in
            # the inner loop.
            fast_weights = OrderedDict((name, param) for (name, param) in transformer.named_parameters() if re.search(r'in\d+\.', name))

            for j in range(args.meta_step):
                # run forward transformation on contents
                transformed = transformer(contents, fast_weights)

                # compute loss
                features_transformed = vgg(utils.standardize_batch(transformed))
                loss, c_loss, s_loss = loss_fn(features_transformed, features_contents, gram_style, content_weight, style_weight)

                # compute grad; create_graph=True keeps the inner update
                # differentiable so the query loss can flow back through it
                grads = torch.autograd.grad(loss, fast_weights.values(), create_graph=True)

                # update fast weights
                fast_weights = OrderedDict((name, param - lr * grad) for ((name, param), grad) in zip(fast_weights.items(), grads))

            avg_train_c_loss += c_loss.item()
            avg_train_s_loss += s_loss.item()
            avg_train_loss += loss.item()

            # run forward transformation on querys
            transformed = transformer(querys, fast_weights)

            # compute loss
            features_transformed = vgg(utils.standardize_batch(transformed))
            loss, c_loss, s_loss = loss_fn(features_transformed, features_querys, gram_style, content_weight, style_weight)

            grads = torch.autograd.grad(loss / args.meta_batch_size, transformer.parameters())
            all_meta_grads.append({name: g for ((name, _), g) in zip(transformer.named_parameters(), grads)})

            avg_eval_c_loss += c_loss.item()
            avg_eval_s_loss += s_loss.item()
            avg_eval_loss += loss.item()

        writer.add_scalar("Avg_Train_C_Loss", avg_train_c_loss / args.meta_batch_size, iteration + 1)
        writer.add_scalar("Avg_Train_S_Loss", avg_train_s_loss / args.meta_batch_size, iteration + 1)
        writer.add_scalar("Avg_Train_Loss", avg_train_loss / args.meta_batch_size, iteration + 1)
        writer.add_scalar("Avg_Eval_C_Loss", avg_eval_c_loss / args.meta_batch_size, iteration + 1)
        writer.add_scalar("Avg_Eval_S_Loss", avg_eval_s_loss / args.meta_batch_size, iteration + 1)
        writer.add_scalar("Avg_Eval_Loss", avg_eval_loss / args.meta_batch_size, iteration + 1)

        # compute dummy loss to refresh buffer
        transformed = transformer(querys)
        features_transformed = vgg(utils.standardize_batch(transformed))
        dummy_loss, _, _ = loss_fn(features_transformed, features_querys, gram_style, content_weight, style_weight)

        meta_updates(transformer, dummy_loss, all_meta_grads)

        if args.checkpoint_model_dir is not None and (iteration + 1) % args.checkpoint_interval == 0:
            transformer.eval().cpu()
            ckpt_model_filename = "iter_" + str(iteration + 1) + ".pth"
            ckpt_model_path = os.path.join(args.checkpoint_model_dir, ckpt_model_filename)
            torch.save(transformer.state_dict(), ckpt_model_path)
            transformer.to(device).train()

    # save model
    transformer.eval().cpu()
    save_model_filename = "Final_iter_" + str(args.max_iter) + "_" + \
                          str(args.content_weight) + "_" + \
                          str(args.style_weight) + "_" + \
                          str(args.lr) + "_" + \
                          str(args.meta_lr) + "_" + \
                          str(args.meta_batch_size) + "_" + \
                          str(args.meta_step) + "_" + \
                          time.ctime() + ".pth"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    # Function-call form: valid in Python 3 (and prints identically in Python 2).
    print("Done, trained model saved at {}".format(save_model_path))
def main(args):
    """Train a fast neural style-transfer network for one style image.

    Content images come from the ``ImageFolder`` at ``args.dataset`` and the
    style image is ``args.style``. Every ``args.save_interval`` batches the
    current weights are written, overwriting the previous file, to
    ``<args.checkpoint_dir>/<style name>.pth``. The style name is the style
    file's base name up to its first '.'.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # DATA
    # Transform and Dataloader for COCO dataset
    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),  # / 255.
        transforms.Lambda(lambda x: x.mul(255))
    ])
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size)

    # MODEL
    # Define Image Transformation Network with MSE loss and Adam optimizer
    transformer = TransformerNet().to(device)
    mse_loss = nn.MSELoss()
    optimizer = optim.Adam(transformer.parameters(), args.learning_rate)

    # Pretrained VGG
    vgg = VGG16(requires_grad=False).to(device)

    # FEATURES
    style_transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Lambda(lambda x: x.mul(255))])

    # Load the style image. Force 3-channel RGB so RGBA / palette / greyscale
    # files match the channel count VGG expects.
    style = Image.open(args.style).convert('RGB')
    style = style_transform(style)
    style = style.repeat(args.batch_size, 1, 1, 1).to(device)

    # Compute the style features
    features_style = vgg(normalize_batch(style))

    # Loop through VGG style layers to calculate Gram Matrix
    gram_style = [gram_matrix(y) for y in features_style]

    # The checkpoint path is fixed for the whole run; compute it once.
    style_name = args.style.split('/')[-1].split('.')[0]
    checkpoint_file = os.path.join(args.checkpoint_dir, '{}.pth'.format(style_name))

    # TRAIN
    for epoch in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.

        for batch_id, (x, _) in tqdm(enumerate(train_loader),
                                     total=len(train_loader), unit='batch'):
            x = x.to(device)
            n_batch = len(x)

            optimizer.zero_grad()

            # Parse throught Image Transformation network
            y = transformer(x)

            y = normalize_batch(y)
            x = normalize_batch(x)

            # Parse through VGG layers
            features_y = vgg(y)
            features_x = vgg(x)

            # Calculate content loss
            content_loss = args.content_weight * mse_loss(
                features_y.relu2_2, features_x.relu2_2)

            # Calculate style loss; slice the precomputed style Grams to the
            # current (possibly smaller, final) batch.
            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gm_y = gram_matrix(ft_y)
                style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
            style_loss *= args.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.item()
            agg_style_loss += style_loss.item()

            # Monitor
            if (batch_id + 1) % args.log_interval == 0:
                tqdm.write('[{}] ({})\t'
                           'content: {:.6f}\t'
                           'style: {:.6f}\t'
                           'total: {:.6f}'.format(
                               epoch + 1, batch_id + 1,
                               agg_content_loss / (batch_id + 1),
                               agg_style_loss / (batch_id + 1),
                               (agg_content_loss + agg_style_loss) / (batch_id + 1)))

            # Checkpoint
            if (batch_id + 1) % args.save_interval == 0:
                # eval mode
                transformer.eval().cpu()
                tqdm.write('Checkpoint {}'.format(checkpoint_file))
                torch.save(transformer.state_dict(), checkpoint_file)

                # back to train mode
                transformer.to(device).train()
def train(args):
    """Train a colourisation network that predicts the U/V channels from Y.

    Each image is resized and center-cropped, converted RGB -> YUV with
    OpenCV, and turned into a tensor. Channel 0 (Y) is the network input and
    channels 1-2 (U, V) are the MSE regression target. The final weights are
    saved under ``args.save_model_dir``, which is created if missing.

    NOTE(review): the network is wrapped in ``nn.DataParallel`` before saving,
    so the stored state-dict keys carry a ``module.`` prefix.
    """
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        kwargs = {'num_workers': 0, 'pin_memory': False}
    else:
        kwargs = {}

    class RGB2YUV(object):
        """PIL RGB image -> PIL image whose three channels hold Y, U, V."""

        def __call__(self, img):
            # Imported lazily so cv2 is only required when training runs.
            import cv2
            npimg = np.array(img)
            yuvnpimg = cv2.cvtColor(npimg, cv2.COLOR_RGB2YUV)
            pilimg = Image.fromarray(yuvnpimg)
            return pilimg

    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        RGB2YUV(),
        transforms.ToTensor(),
    ])
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, **kwargs)

    transformer = TransformerNet(in_channels=1, out_channels=2)  # input: Y, predict: UV
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    transformer = nn.DataParallel(transformer)

    if args.cuda:
        if not torch.cuda.is_available():
            raise RuntimeError(
                "CUDA is requested, but related driver/device is not set properly."
            )
        transformer.cuda()

    for e in range(args.epochs):
        transformer.train()
        agg_loss = 0.
        count = 0
        for batch_id, (imgs, _) in enumerate(train_loader):
            n_batch = len(imgs)
            count += n_batch
            optimizer.zero_grad()

            # First channel
            x = imgs[:, :1, :, :].clone()
            # Second and third channels
            gt = imgs[:, 1:, :, :].clone()

            if args.cuda:
                x = x.cuda()
                gt = gt.cuda()

            y = transformer(x)

            total_loss = mse_loss(y, gt)
            total_loss.backward()
            optimizer.step()

            # Report the running mean of per-batch losses over this epoch. The
            # old code divided only the current batch's loss by the batch count.
            agg_loss += total_loss.item()
            if (batch_id + 1) % args.log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                    agg_loss / (batch_id + 1))
                print(mesg)

    # save model
    transformer.eval()
    transformer.cpu()
    save_model_filename = "epoch_" + str(args.epochs) + "_" + str(
        time.ctime()).replace(' ', '_') + "_" + str(
        args.content_weight) + "_" + str(args.style_weight) + ".model"
    os.makedirs(args.save_model_dir, exist_ok=True)
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
def train(exp, args):
    """Run the fast-style-transfer training loop under the ``exp`` harness.

    ``exp`` provides the device, a chronometer and the loss/ETA reporting
    hooks. Each of the ``args.repeat`` passes is timed as 'train' and stops
    once the batch index exceeds ``args.number``. Per-batch total loss and
    per-pass summed loss are reported through ``exp``.
    """
    device = exp.get_device()
    chrono = exp.chrono()

    image_tf = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda t: t.mul(255)),
    ])
    train_loader = DataLoader(
        datasets.ImageFolder(args.dataset, image_tf),
        batch_size=args.batch_size,
        num_workers=args.workers,
    )

    transformer = TransformerNet().to(device)
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16(requires_grad=False).to(device)
    vgg_input_shape = (3, args.image_size, args.image_size)
    print(memory_size(vgg, batch_size=args.batch_size, input_size=vgg_input_shape) * 4)

    style_tf = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda t: t.mul(255)),
    ])
    style_img = style_tf(utils.load_image(args.style_image, size=args.style_size))
    style_img = style_img.repeat(args.batch_size, 1, 1, 1).to(device)
    style_grams = [utils.gram_matrix(f) for f in vgg(utils.normalize_batch(style_img))]

    for rep in range(args.repeat):
        transformer.train()

        with chrono.time('train') as timer:
            content_total = 0.
            style_total = 0.

            for batch_id, (images, _) in enumerate(train_loader):
                if batch_id > args.number:
                    break

                n_batch = len(images)
                images = images.to(device)

                stylized = utils.normalize_batch(transformer(images))
                images = utils.normalize_batch(images)

                optimizer.zero_grad()

                feats_out = vgg(stylized)
                feats_in = vgg(images)

                content_loss = args.content_weight * mse_loss(feats_out.relu2_2, feats_in.relu2_2)
                style_loss = sum(
                    mse_loss(utils.gram_matrix(f_out), g[:n_batch, :, :])
                    for f_out, g in zip(feats_out, style_grams)
                ) * args.style_weight

                total_loss = content_loss + style_loss
                total_loss.backward()

                exp.log_batch_loss(total_loss.item())
                optimizer.step()

                content_total += content_loss.item()
                style_total += style_loss.item()

        exp.log_epoch_loss(content_total + style_total)
        exp.show_eta(rep, timer)

    exp.report()
def train(args):
    """Train a TransformerNet for fast neural style transfer.

    Uses CUDA automatically when available. If ``args.model`` is set, training
    warm-starts from that state dict. Content loss is measured on VGG
    ``relu3_3`` (changed from ``relu2_2``); style loss compares Gram matrices
    of every VGG feature level with the style image's.

    Side effects:
        Prints progress every ``args.log_interval`` batches.
        Writes checkpoints every ``args.checkpoint_interval`` batches when
        ``args.checkpoint_model_dir`` is set.
        Writes the final state dict to ``args.save_model_dir``.
    """
    if torch.cuda.is_available():
        print('CUDA available, using GPU.')
        device = torch.device('cuda')
    else:
        print('GPU training unavailable... using CPU.')
        device = torch.device('cpu')

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size)

    # Image transformation network.
    transformer = TransformerNet()
    if args.model:
        # map_location lets weights saved on another device (e.g. a GPU)
        # load on the device selected above.
        state_dict = torch.load(args.model, map_location=device)
        transformer.load_state_dict(state_dict)
    transformer.to(device)

    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    # Loss Network: VGG16
    vgg = Vgg16(requires_grad=False).to(device)
    style_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    style = utils.load_image(args.style_image, size=args.style_size)
    style = style_transform(style)
    style = style.repeat(args.batch_size, 1, 1, 1).to(device)

    features_style = vgg(utils.normalize_batch(style))
    gram_style = [utils.gram_matrix(y) for y in features_style]

    for e in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()

            # CUDA if available
            x = x.to(device)

            # Transform image
            y = transformer(x)

            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)

            # Feature map of original image
            features_x = vgg(x)
            # Feature Map of transformed image
            features_y = vgg(y)

            # Content loss: transformed vs. original features at relu3_3.
            content_loss = args.content_weight * mse_loss(features_y.relu3_3, features_x.relu3_3)

            # Style loss: MSE between each layer's output Gram matrix and the
            # style Gram matrix (sliced to the current batch), summed over
            # layers, then weighted.
            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gm_y = utils.gram_matrix(ft_y)
                style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
            style_loss *= args.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.item()
            agg_style_loss += style_loss.item()

            # Honour args.log_interval; a debugging "if True:" had disabled it.
            if (batch_id + 1) % args.log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                    agg_content_loss / (batch_id + 1),
                    agg_style_loss / (batch_id + 1),
                    (agg_content_loss + agg_style_loss) / (batch_id + 1)
                )
                print(mesg)

            if args.checkpoint_model_dir is not None and (batch_id + 1) % args.checkpoint_interval == 0:
                transformer.eval().cpu()
                ckpt_model_filename = "ckpt_epoch_" + str(e) + "_batch_id_" + str(batch_id + 1) + ".pth"
                ckpt_model_path = os.path.join(args.checkpoint_model_dir, ckpt_model_filename)
                torch.save(transformer.state_dict(), ckpt_model_path)
                transformer.to(device).train()

    # save model
    transformer.eval().cpu()
    save_model_filename = "epoch_" + str(args.epochs) + "_" + str(time.ctime()).replace(' ', '_') + "_" + str(
        args.content_weight) + "_" + str(args.style_weight) + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
def train(args):
    """Train a TransformerNet for fast neural style transfer on one style image.

    Runs on CUDA when ``args.cuda`` is set, otherwise on CPU. Content loss uses
    VGG ``relu2_2``; style loss compares Gram matrices of every VGG feature
    level with the style image's.

    Side effects:
        When ``args.checkpoint_model_dir`` is set, a checkpoint is written at
        the end of every epoch.
        The final model goes to ``args.save_model_dir``.
    """
    # Device that tensors and models are placed on (CPU or GPU).
    device = torch.device("cuda" if args.cuda else "cpu")
    # Seed the RNGs for reproducibility.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Content-image preprocessing pipeline.
    transform = transforms.Compose([
        # Resize the shorter side to image_size.
        transforms.Resize(args.image_size),
        # Center-crop to image_size x image_size.
        transforms.CenterCrop(args.image_size),
        # PIL image -> tensor in [0, 1].
        transforms.ToTensor(),
        # Rescale to [0, 255].
        transforms.Lambda(lambda x: x.mul(255))
    ])
    # ImageFolder dataset rooted at args.dataset, preprocessed by `transform`.
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    # Batches of args.batch_size samples.
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size)

    # Image transformation network on the selected device.
    transformer = TransformerNet().to(device)
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()
    # Frozen VGG16 used as the loss network.
    vgg = Vgg16(requires_grad=False).to(device)

    # Style-image preprocessing.
    style_transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Lambda(lambda x: x.mul(255))])
    style = utils.load_image(args.style_image, size=args.style_size)
    style = style_transform(style)
    # Replicate the style image across the batch dimension.
    style = style.repeat(args.batch_size, 1, 1, 1).to(device)
    # VGG features of the normalized style batch.
    features_style = vgg(utils.normalize_batch(style))
    # Gram matrix of each style feature map.
    gram_style = [utils.gram_matrix(y) for y in features_style]

    for e in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            # Clear gradients from the previous step.
            optimizer.zero_grad()

            # Move the batch to the selected device once, so both the
            # transformer input and the VGG content target live there. The old
            # code left x on the CPU and then hard-coded .cuda().
            x = x.to(device)
            y = transformer(x)

            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)

            features_y = vgg(y)
            features_x = vgg(x)

            # Content loss.
            content_loss = args.content_weight * mse_loss(
                features_y.relu2_2, features_x.relu2_2)

            # Style loss, with style Grams sliced to the current batch size.
            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gm_y = utils.gram_matrix(ft_y)
                style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
            style_loss *= args.style_weight

            # Total loss, backpropagation, parameter update.
            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.item()
            agg_style_loss += style_loss.item()

            # Progress report every args.log_interval batches.
            if (batch_id + 1) % args.log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                    agg_content_loss / (batch_id + 1),
                    agg_style_loss / (batch_id + 1),
                    (agg_content_loss + agg_style_loss) / (batch_id + 1))
                print(mesg)

        # Checkpoint once per epoch. The per-batch checkpoint_interval
        # condition was commented out in the original, and saving on every
        # batch would be prohibitively slow.
        if args.checkpoint_model_dir is not None:
            transformer.eval().cpu()
            ckpt_model_filename = "ckpt_epoch_" + str(
                e) + "_batch_id_" + str(batch_id + 1) + ".pth"
            ckpt_model_path = os.path.join(args.checkpoint_model_dir,
                                           ckpt_model_filename)
            torch.save(transformer.state_dict(), ckpt_model_path)
            transformer.to(device).train()

    # save model
    transformer.eval().cpu()
    save_model_filename = "epoch_" + str(args.epochs) + "_" + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
def train(args):
    """Fit a TransformerNet to a single style image (fast neural style).

    Loss = ``content_weight`` * MSE on VGG ``relu2_2`` features
         + ``style_weight`` * sum over VGG levels of Gram-matrix MSE.

    Side effects:
        Prints progress every ``args.log_interval`` batches.
        Writes checkpoints every ``args.checkpoint_interval`` batches when
        ``args.checkpoint_model_dir`` is set.
        Writes the final state dict to ``args.save_model_dir``.
    """
    device = torch.device("cuda" if args.cuda else "cpu")
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    content_pipeline = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda t: t.mul(255)),
    ])
    train_dataset = datasets.ImageFolder(args.dataset, content_pipeline)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size)

    transformer = TransformerNet().to(device)
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()
    vgg = Vgg16(requires_grad=False).to(device)

    style_pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda t: t.mul(255)),
    ])
    style_img = style_pipeline(utils.load_image(args.style_image, size=args.style_size))
    style_batch = style_img.repeat(args.batch_size, 1, 1, 1).to(device)
    gram_style = [utils.gram_matrix(feat) for feat in vgg(utils.normalize_batch(style_batch))]

    def _style_loss(features, n):
        """Weighted Gram-matrix MSE against the first ``n`` style targets."""
        total = 0.
        for feat, target in zip(features, gram_style):
            total += mse_loss(utils.gram_matrix(feat), target[:n, :, :])
        return total * args.style_weight

    for epoch in range(args.epochs):
        transformer.train()
        content_sum = 0.
        style_sum = 0.
        seen = 0

        for batch_id, (images, _) in enumerate(train_loader):
            n_batch = len(images)
            seen += n_batch
            optimizer.zero_grad()

            images = images.to(device)
            stylized = utils.normalize_batch(transformer(images))
            images = utils.normalize_batch(images)

            feats_out = vgg(stylized)
            feats_in = vgg(images)

            content_loss = args.content_weight * mse_loss(feats_out.relu2_2, feats_in.relu2_2)
            style_loss = _style_loss(feats_out, n_batch)

            (content_loss + style_loss).backward()
            optimizer.step()

            content_sum += content_loss.item()
            style_sum += style_loss.item()

            steps = batch_id + 1
            if steps % args.log_interval == 0:
                print("{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), epoch + 1, seen, len(train_dataset),
                    content_sum / steps,
                    style_sum / steps,
                    (content_sum + style_sum) / steps))

            if args.checkpoint_model_dir is not None and steps % args.checkpoint_interval == 0:
                transformer.eval().cpu()
                ckpt_path = os.path.join(
                    args.checkpoint_model_dir,
                    "ckpt_epoch_" + str(epoch) + "_batch_id_" + str(steps) + ".pth")
                torch.save(transformer.state_dict(), ckpt_path)
                transformer.to(device).train()

    # save model
    transformer.eval().cpu()
    stamp = str(time.ctime()).replace(' ', '_')
    save_model_filename = "epoch_" + str(args.epochs) + "_" + stamp + "_" + str(
        args.content_weight) + "_" + str(args.style_weight) + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
def train(args):
    """Train a TransformerNet to reconstruct images from perturbed copies.

    ``utils.FlatImageFolder`` yields ``(pert_img, ori_img)`` batches (presumably
    a ``utils.ColorPerturb``-ed copy and the original image). The model is
    trained so that ``model(pert_img)`` matches ``ori_img`` under MSE.

    Side effects:
        When ``args.checkpoint_dir`` is set, a checkpoint is written after
        every epoch except the last.
        The final weights go to ``args.save_model_dir``, named
        ``args.save_model_name`` when given.
    """
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
    ])
    pert_transform = transforms.Compose([utils.ColorPerturb()])
    trainset = utils.FlatImageFolder(args.dataset, transform, pert_transform)
    trainloader = DataLoader(trainset, batch_size=args.batch_size, shuffle=True,
                             pin_memory=True, num_workers=4)

    model = TransformerNet()
    if args.gpus is not None:
        model = nn.DataParallel(model, device_ids=args.gpus)
    else:
        model = nn.DataParallel(model)
    if args.resume:
        # The model is still on the CPU here; map_location lets a checkpoint
        # containing GPU tensors load on CPU-only machines too.
        state_dict = torch.load(args.resume, map_location='cpu')
        model.load_state_dict(state_dict)
    if args.cuda:
        model.cuda()

    optimizer = torch.optim.Adam(model.parameters(), args.lr)
    criterion = nn.MSELoss()

    start_time = datetime.now()
    for e in range(args.epochs):
        model.train()
        count = 0
        acc_loss = 0.0
        for batchi, (pert_img, ori_img) in enumerate(trainloader):
            count += len(pert_img)
            if args.cuda:
                pert_img = pert_img.cuda(non_blocking=True)
                ori_img = ori_img.cuda(non_blocking=True)

            optimizer.zero_grad()
            rec_img = model(pert_img)
            loss = criterion(rec_img, ori_img)
            loss.backward()
            optimizer.step()

            acc_loss += loss.item()
            if (batchi + 1) % args.log_interval == 0:
                # Mean loss over the last log_interval batches (reset below).
                mesg = '{}\tEpoch {}: [{}/{}]\ttotal loss: {:.6f}'.format(
                    time.ctime(), e + 1, count, len(trainset),
                    acc_loss / (args.log_interval))
                print(mesg)
                acc_loss = 0.0

        if args.checkpoint_dir and e + 1 != args.epochs:
            model.eval().cpu()
            ckpt_filename = 'ckpt_epoch_' + str(e + 1) + '.pth'
            ckpt_path = osp.join(args.checkpoint_dir, ckpt_filename)
            torch.save(model.state_dict(), ckpt_path)
            # Move back to the GPU only when training on it; an unconditional
            # .cuda() crashes CPU-only runs.
            if args.cuda:
                model.cuda()
            model.train()
            print('Checkpoint model at epoch %d saved' % (e + 1))

    model.eval().cpu()
    if args.save_model_name:
        model_filename = args.save_model_name
    else:
        model_filename = "epoch_" + str(args.epochs) + "_" + str(
            time.ctime()).replace(' ', '_') + ".model"
    model_path = osp.join(args.save_model_dir, model_filename)
    torch.save(model.state_dict(), model_path)
    end_time = datetime.now()

    print('Finished training after %s, trained model saved at %s' %
          (end_time - start_time, model_path))
def train(args):
    """Train a fast neural style TransformerNet against one style image.

    Content images come from ``ImageFolder(args.dataset)``. The objective is a
    weighted content MSE on VGG ``relu2_2`` plus a weighted sum of Gram-matrix
    MSEs over all VGG feature levels.

    Side effects:
        Writes checkpoints to ``args.checkpoint_model_dir`` (when set) every
        ``args.checkpoint_interval`` batches.
        Writes the final weights to ``args.save_model_dir``.
    """
    device = torch.device("cuda" if args.cuda else "cpu")
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    def _to_255(t):
        # Rescale ToTensor()'s [0, 1] output to [0, 255].
        return t.mul(255)

    train_dataset = datasets.ImageFolder(
        args.dataset,
        transforms.Compose([
            transforms.Resize(args.image_size),
            transforms.CenterCrop(args.image_size),
            transforms.ToTensor(),
            transforms.Lambda(_to_255),
        ]))
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size)

    transformer = TransformerNet().to(device)
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()
    vgg = Vgg16(requires_grad=False).to(device)

    style_to_tensor = transforms.Compose([transforms.ToTensor(), transforms.Lambda(_to_255)])
    style = style_to_tensor(utils.load_image(args.style_image, size=args.style_size))
    style = style.repeat(args.batch_size, 1, 1, 1).to(device)
    style_grams = [utils.gram_matrix(feat) for feat in vgg(utils.normalize_batch(style))]

    def _checkpoint(path):
        # Save CPU weights in eval mode, then resume training on the device.
        transformer.eval().cpu()
        torch.save(transformer.state_dict(), path)
        transformer.to(device).train()

    log_line = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}"

    for e in range(args.epochs):
        transformer.train()
        running_content = 0.
        running_style = 0.
        processed = 0

        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            processed += n_batch
            optimizer.zero_grad()

            x = x.to(device)
            y_norm = utils.normalize_batch(transformer(x))
            x_norm = utils.normalize_batch(x)
            feats_y, feats_x = vgg(y_norm), vgg(x_norm)

            content_loss = args.content_weight * mse_loss(feats_y.relu2_2, feats_x.relu2_2)

            style_loss = 0.
            for layer_feat, layer_gram in zip(feats_y, style_grams):
                style_loss += mse_loss(utils.gram_matrix(layer_feat), layer_gram[:n_batch, :, :])
            style_loss *= args.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            running_content += content_loss.item()
            running_style += style_loss.item()

            done = batch_id + 1
            if done % args.log_interval == 0:
                print(log_line.format(
                    time.ctime(), e + 1, processed, len(train_dataset),
                    running_content / done,
                    running_style / done,
                    (running_content + running_style) / done))

            if args.checkpoint_model_dir is not None and done % args.checkpoint_interval == 0:
                _checkpoint(os.path.join(args.checkpoint_model_dir,
                                         "ckpt_epoch_%s_batch_id_%s.pth" % (e, done)))

    # save model
    transformer.eval().cpu()
    name_parts = [
        "epoch",
        str(args.epochs),
        str(time.ctime()).replace(' ', '_'),
        str(args.content_weight),
        str(args.style_weight),
    ]
    save_model_filename = "_".join(name_parts) + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
def train(args):
    """Train a multi-style TransformerNet using VGG-feature content/style losses.

    Every entry of the directory ``args.style_image`` (except Jupyter
    ``.ipynb`` artefacts) is one style.  Entries are sorted by name, so style
    id ``k`` always refers to the same file.  The style id assigned to each
    training image cycles through the styles as images are consumed within
    an epoch.

    :param args: parsed options; reads ``cuda``, ``seed``, ``image_size``,
        ``dataset``, ``batch_size``, ``style_image`` (a directory),
        ``style_size``, ``epochs``, ``checkpoint_model_dir``,
        ``save_model_dir`` and, if present, ``log_file`` (log path override).
    :return: None; saves periodic checkpoints (when ``checkpoint_model_dir``
        is set) and the final ``state_dict`` under ``save_model_dir``.
    :raises ValueError: if the style directory contains no usable entries.

    NOTE(review): ``learning_rate``, ``content_weight``, ``style_weight``,
    ``log_interval`` and ``checkpoint_interval`` are read from names defined
    outside this function (not from ``args``) - confirm they exist at module
    level.
    """
    device = torch.device("cuda" if args.cuda else "cpu")
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Keep the historical location as the default, but allow an override.
    log_path = getattr(args, "log_file", None) or \
        '/home/sbanda/Fall20-DL-CG/Project3/log.txt'

    image_transform = transforms.Compose([
        transforms.Resize(
            args.image_size),  # the shorter side is resized to match image_size
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),  # to tensor [0,1]
        transforms.Lambda(lambda x: x.mul(255))  # convert back to [0, 255]
    ])
    train_dataset = datasets.ImageFolder(args.dataset, image_transform)
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True)

    # Filter out notebook artefacts *before* counting so that style_num equals
    # the number of stacked style images (otherwise style ids could point past
    # the end of gram_style).  Sorting makes the id -> file mapping stable,
    # since os.listdir order is arbitrary.
    style_files = sorted(
        name for name in os.listdir(args.style_image) if ".ipynb" not in name)
    style_num = len(style_files)
    if style_num == 0:
        raise ValueError(
            "no style images found in {!r}".format(args.style_image))
    print(style_num)

    transformer = TransformerNet(style_number=style_num).to(device)
    adam_optimizer = Adam(transformer.parameters(), learning_rate)
    mse_loss = torch.nn.MSELoss()
    vgg = Vgg16(requires_grad=False).to(device)

    style_transform = transforms.Compose([
        transforms.Resize(args.style_size),
        transforms.CenterCrop(args.style_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])

    style_batch = []
    for style_name in style_files:
        # os.path.join works whether or not the directory ends with a slash.
        style = utils.load_image(os.path.join(args.style_image, style_name),
                                 size=args.style_size)
        style = style_transform(style)
        print(style.shape, style_name)
        style_batch.append(style)
    style = torch.stack(style_batch).to(device)

    # Row k of each Gram tensor is the target for style id k.
    features_style = vgg(utils.normalize_batch(style))
    gram_style = [utils.gram_matrix(y) for y in features_style]

    for e in range(args.epochs):
        with open(log_path, 'a') as log_fh:
            log_fh.write("Epoch " + str(e) + ":->\n")
        transformer.train()
        aggregate_content_loss = 0.
        aggregate_style_loss = 0.
        counter = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            # Only full batches are trained on; the trailing partial batch
            # (the only one that can be short) ends the epoch.
            if n_batch < args.batch_size:
                break
            counter += n_batch

            # Initialize gradients to zero
            adam_optimizer.zero_grad()

            # Style ids cycle over the images consumed so far this epoch.
            batch_style_id = [
                i % style_num for i in range(counter - n_batch, counter)
            ]

            # Move once so normalization and VGG both run on the target device.
            x = x.to(device)
            y = transformer(x, style_id=batch_style_id)

            x = utils.normalize_batch(x)
            y = utils.normalize_batch(y)

            features_x = vgg(x)
            features_y = vgg(y)

            content_loss = content_weight * mse_loss(features_y.relu2_2,
                                                     features_x.relu2_2)

            style_loss = 0.
            for feature_y, gm_style in zip(features_y, gram_style):
                gm_y = utils.gram_matrix(feature_y)
                # Gather each image's own style target by id.
                style_loss += mse_loss(gm_y, gm_style[batch_style_id, :, :])
            style_loss *= style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            adam_optimizer.step()

            aggregate_content_loss += content_loss.item()
            aggregate_style_loss += style_loss.item()

            if (batch_id + 1) % log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, counter, len(train_dataset),
                    aggregate_content_loss / (batch_id + 1),
                    aggregate_style_loss / (batch_id + 1),
                    (aggregate_content_loss + aggregate_style_loss) /
                    (batch_id + 1))
                with open(log_path, 'a') as log_fh:
                    log_fh.write(mesg + "\n")
                print(mesg)

            if args.checkpoint_model_dir is not None and (
                    batch_id + 1) % checkpoint_interval == 0:
                # Save CPU tensors, then move back to the device and resume.
                transformer.eval().cpu()
                ckpt_model_filename = "ckpt_epoch_" + str(
                    e) + "_batch_id_" + str(batch_id + 1) + ".pth"
                ckpt_model_path = os.path.join(args.checkpoint_model_dir,
                                               ckpt_model_filename)
                torch.save(transformer.state_dict(), ckpt_model_path)
                transformer.to(device).train()

    # save model
    transformer.eval().cpu()
    # ':' is stripped from the timestamp because it is invalid in Windows
    # file names.
    save_model_filename = "epoch_" + str(args.epochs) + "_" + str(
        time.ctime()).replace(' ', '_').replace(':', '') + "_" + str(
            int(content_weight)) + "_" + str(int(style_weight)) + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)
    print("\nDone, trained model saved at", save_model_path)