style_loss = 0. for m in range(len(features_y)): gram_s = gram_style[m] gram_y = gram_matrix(features_y[m]) style_loss += args.style_weight * loss(gram_y, gram_s.expand_as(gram_y)) total_loss = content_loss + style_loss + reg_loss total_loss.backward() optimizer.step() agg_content_loss += content_loss.data[0] agg_style_loss += style_loss.data[0] agg_reg_loss += reg_loss.data[0] if (batch_id + 1) % args.log_interval == 0: mesg = "[{}/{}] content: {:.6f} style: {:.6f} reg: {:.6f} total: {:.6f}".format( count, len(train_dataset), agg_content_loss / count, agg_style_loss / count, agg_reg_loss / count, (agg_content_loss + agg_style_loss + agg_reg_loss) / count) print(mesg) # save model transformer.eval() if torch.cuda.is_available(): transformer.cpu() model_file = 'model_' + str(epoch) + '.pth' torch.save(transformer.state_dict(), model_file) print('\nSaved model to ' + model_file + '.')
def train(args):
    """Train a ``TransformerNet`` against a single style image and save it.

    Reads from ``args``: ``seed``, ``cuda``, ``image_size``, ``dataset``,
    ``batch_size``, ``lr``, ``vgg_model_dir``, ``style_image``,
    ``style_size``, ``epochs``, ``content_weight``, ``style_weight``,
    ``log_interval`` and ``save_model_dir``.

    The trained weights are written with ``torch.save`` into
    ``args.save_model_dir``; nothing is returned.
    """
    # Seed every RNG that is in play before any model/data object is built.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    loader_options = {}
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        loader_options = {'num_workers': 0, 'pin_memory': False}

    # Resize + center-crop to a square, then scale ToTensor's output by 255.
    preprocessing = transforms.Compose([
        transforms.Scale(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255)),
    ])
    train_dataset = datasets.ImageFolder(args.dataset, preprocessing)
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              **loader_options)

    transformer = TransformerNet()
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    # Loss network: weights are prepared by utils.init_vgg16 and then loaded.
    vgg = Vgg16()
    utils.init_vgg16(args.vgg_model_dir)
    vgg_weight_path = os.path.join(args.vgg_model_dir, "vgg16.weight")
    vgg.load_state_dict(torch.load(vgg_weight_path))

    if args.cuda:
        transformer.cuda()
        vgg.cuda()

    # Style targets are computed once, on a batch made of copies of the
    # style image, and reused for every training batch.
    style = utils.tensor_load_rgbimage(args.style_image, size=args.style_size)
    style = utils.preprocess_batch(style.repeat(args.batch_size, 1, 1, 1))
    if args.cuda:
        style = style.cuda()
    style_v = Variable(style, volatile=True)
    # Return value is ignored here, so the helper is presumably in-place
    # -- TODO confirm against utils.
    utils.subtract_imagenet_mean_batch(style_v)
    gram_style = [utils.gram_matrix(feat) for feat in vgg(style_v)]

    for e in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()

            x = Variable(utils.preprocess_batch(x))
            if args.cuda:
                x = x.cuda()

            y = transformer(x)
            # Detached copy of the input used as the content reference.
            xc = Variable(x.data.clone(), volatile=True)

            utils.subtract_imagenet_mean_batch(y)
            utils.subtract_imagenet_mean_batch(xc)

            features_y = vgg(y)
            features_xc = vgg(xc)

            # Content loss compares the second feature map of both passes.
            content_target = Variable(features_xc[1].data,
                                      requires_grad=False)
            content_loss = args.content_weight * mse_loss(features_y[1],
                                                          content_target)

            style_loss = 0.
            for m, feature in enumerate(features_y):
                target_gram = Variable(gram_style[m].data,
                                       requires_grad=False)
                current_gram = utils.gram_matrix(feature)
                # Slice the target so a short final batch still lines up.
                style_loss += args.style_weight * mse_loss(
                    current_gram, target_gram[:n_batch, :, :])

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.data[0]
            agg_style_loss += style_loss.data[0]

            batches_done = batch_id + 1
            if batches_done % args.log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                    agg_content_loss / batches_done,
                    agg_style_loss / batches_done,
                    (agg_content_loss + agg_style_loss) / batches_done)
                print(mesg)

    # save model
    transformer.eval()
    transformer.cpu()
    save_model_filename = "".join([
        "epoch_", str(args.epochs), "_",
        str(time.ctime()).replace(' ', '_'), "_",
        str(args.content_weight), "_", str(args.style_weight), ".model",
    ])
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
def train(args):
    """Train a ``TransformerNet`` for style transfer, logging scalars to ``writer``.

    Reads from ``args``: ``seed``, ``cuda``, ``image_size``, ``dataset``,
    ``batch_size``, ``lr``, ``style_image``, ``style_size``, ``epochs``,
    ``content_weight``, ``style_weight``, ``log_interval``,
    ``checkpoint_model_dir``, ``checkpoint_interval`` and ``save_model_dir``.

    Side effects: prints progress, pushes running-average losses to the
    module-level ``writer`` every ``log_interval`` batches, optionally
    writes checkpoints, and saves the final state dict. Returns nothing.
    """
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    transform = transforms.Compose([
        transforms.Scale(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size)

    transformer = TransformerNet()
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16(requires_grad=False)
    style_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    style = utils.load_image(args.style_image, size=args.style_size)
    style = style_transform(style)
    # One copy of the style image per batch slot, so the Gram targets can be
    # sliced to match short batches below.
    style = style.repeat(args.batch_size, 1, 1, 1)

    if args.cuda:
        transformer.cuda()
        vgg.cuda()
        style = style.cuda()

    style_v = Variable(style)
    style_v = utils.normalize_batch(style_v)
    features_style = vgg(style_v)
    gram_style = [utils.gram_matrix(y) for y in features_style]

    for e in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()
            x = Variable(x)
            if args.cuda:
                x = x.cuda()

            y = transformer(x)

            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)

            features_y = vgg(y)
            features_x = vgg(x)

            content_loss = args.content_weight * mse_loss(
                features_y.relu2_2, features_x.relu2_2)

            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gm_y = utils.gram_matrix(ft_y)
                style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
            style_loss *= args.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.data[0]
            agg_style_loss += style_loss.data[0]

            if (batch_id + 1) % args.log_interval == 0:
                avg_content = agg_content_loss / (batch_id + 1)
                avg_style = agg_style_loss / (batch_id + 1)
                avg_total = (agg_content_loss + agg_style_loss) / (batch_id + 1)

                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                    avg_content, avg_style, avg_total)
                print(mesg)

                # Global step counted in batches: batch_id indexes batches,
                # so the per-epoch stride must be the number of batches
                # (len(train_loader)), not the number of samples.
                niter = e * len(train_loader) + batch_id
                writer.add_scalar('content loss', avg_content, niter)
                writer.add_scalar('style loss', avg_style, niter)
                # Previously logged content / (content + style) / batches,
                # which is a ratio and not the total loss.
                writer.add_scalar('total loss', avg_total, niter)

            if args.checkpoint_model_dir is not None and (
                    batch_id + 1) % args.checkpoint_interval == 0:
                # Save from CPU in eval mode, then restore the training setup.
                transformer.eval()
                if args.cuda:
                    transformer.cpu()
                ckpt_model_filename = "ckpt_epoch_" + str(
                    e) + "_batch_id_" + str(batch_id + 1) + ".pth"
                ckpt_model_path = os.path.join(args.checkpoint_model_dir,
                                               ckpt_model_filename)
                torch.save(transformer.state_dict(), ckpt_model_path)
                if args.cuda:
                    transformer.cuda()
                transformer.train()

    # save model
    transformer.eval()
    if args.cuda:
        transformer.cpu()
    save_model_filename = "epoch_" + str(args.epochs) + "_" + str(
        time.ctime()).replace(' ', '_') + "_" + str(
            args.content_weight) + "_" + str(args.style_weight) + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
def train(args):
    """Train a ``TransformerNet`` to predict channels 1-2 of a YUV image from channel 0.

    Reads from ``args``: ``seed``, ``cuda``, ``image_size``, ``dataset``,
    ``batch_size``, ``lr``, ``epochs``, ``log_interval``,
    ``content_weight``, ``style_weight`` (the last two only appear in the
    output filename) and ``save_model_dir``.

    Raises:
        RuntimeError: if ``args.cuda`` is set but CUDA is not available.

    The state dict of the ``nn.DataParallel``-wrapped model is saved under
    ``args.save_model_dir``; nothing is returned.
    """
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        kwargs = {'num_workers': 0, 'pin_memory': False}
    else:
        kwargs = {}

    class RGB2YUV(object):
        """Transform converting a PIL RGB image into a PIL image of YUV data."""

        def __call__(self, img):
            # Local import keeps OpenCV an optional dependency of the module.
            import cv2
            yuv = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2YUV)
            return Image.fromarray(yuv)

    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        RGB2YUV(),
        transforms.ToTensor(),
    ])
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              **kwargs)

    # input: Y, predict: UV
    transformer = TransformerNet(in_channels=1, out_channels=2)
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    transformer = nn.DataParallel(transformer)

    if args.cuda:
        if not torch.cuda.is_available():
            raise RuntimeError(
                "CUDA is requested, but related driver/device is not set properly."
            )
        transformer.cuda()

    for e in range(args.epochs):
        transformer.train()
        # Running sum of per-batch losses for this epoch. The previous code
        # printed `total_loss / (batch_id + 1)`, i.e. the *current* batch
        # loss divided by the batch index, which is not an average of
        # anything.
        agg_loss = 0.
        count = 0
        for batch_id, (imgs, _) in enumerate(train_loader):
            n_batch = len(imgs)
            count += n_batch
            optimizer.zero_grad()
            # First channel is the network input.
            x = imgs[:, :1, :, :].clone()
            # Second and third channels are the regression target.
            gt = imgs[:, 1:, :, :].clone()
            if args.cuda:
                x = x.cuda()
                gt = gt.cuda()

            y = transformer(x)
            total_loss = mse_loss(y, gt)
            total_loss.backward()
            optimizer.step()

            agg_loss += total_loss.item()

            if (batch_id + 1) % args.log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                    agg_loss / (batch_id + 1))
                print(mesg)

    # save model
    transformer.eval()
    transformer.cpu()
    save_model_filename = "epoch_" + str(args.epochs) + "_" + str(
        time.ctime()).replace(' ', '_') + "_" + str(
            args.content_weight) + "_" + str(args.style_weight) + ".model"
    os.makedirs(args.save_model_dir, exist_ok=True)
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
def train(args):
    """Train a ``TransformerNet`` with style, content, CAM ("focus") and category losses.

    Reads from ``args``: ``seed``, ``cuda``, ``image_size``, ``dataset``,
    ``batch_size``, ``premodel``, ``lr``, ``vgg_model_dir``,
    ``style_image``, ``style_size``, ``logdir``, ``epochs``,
    ``content_weight``, ``style_weight``, ``log_interval`` and
    ``save_model_dir``.

    Also relies on module-level names that are not defined in this
    function: ``net``, ``features_blobs``, ``weight_softmax`` and
    ``returnCAM``.

    Side effects: writes per-batch running averages to a ``SummaryWriter``,
    prints progress, saves a snapshot every 2500 batches and a final model.
    Returns nothing.
    """
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        kwargs = {'num_workers': 0, 'pin_memory': False}
    else:
        kwargs = {}

    transform = transforms.Compose([transforms.Scale(args.image_size),
                                    transforms.CenterCrop(args.image_size),
                                    transforms.ToTensor(),
                                    transforms.Lambda(lambda x: x.mul(255))])
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, **kwargs)

    transformer = TransformerNet()
    if args.premodel != "":
        # Optionally resume from previously trained weights.
        transformer.load_state_dict(torch.load(args.premodel))
        print("load pretrain model:"+args.premodel)

    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16()
    utils.init_vgg16(args.vgg_model_dir)
    vgg.load_state_dict(torch.load(os.path.join(args.vgg_model_dir, "vgg16.weight")))

    if args.cuda:
        transformer.cuda()
        vgg.cuda()

    style = utils.tensor_load_rgbimage(args.style_image, size=args.style_size)
    style = style.repeat(args.batch_size, 1, 1, 1)
    style = utils.preprocess_batch(style)
    if args.cuda:
        style = style.cuda()
    style_v = Variable(style, volatile=True)
    style_v = utils.subtract_imagenet_mean_batch(style_v)
    features_style = vgg(style_v)
    gram_style = [utils.gram_matrix(y) for y in features_style]

    def save_model(epoch_label):
        """Save the transformer's state dict from CPU in eval mode; return the path."""
        transformer.eval()
        transformer.cpu()
        save_model_filename = "epoch_" + str(epoch_label) + "_" + str(time.ctime()).replace(' ', '_') + "_" + str(
            args.content_weight) + "_" + str(args.style_weight) + ".model"
        save_model_path = os.path.join(args.save_model_dir, save_model_filename)
        torch.save(transformer.state_dict(), save_model_path)
        return save_model_path

    hori = 0  # global step used as the x-axis of the logged scalars
    writer = SummaryWriter(args.logdir, comment=args.logdir)
    for e in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        agg_cate_loss = 0.
        agg_cam_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()
            x = Variable(utils.preprocess_batch(x))
            if args.cuda:
                x = x.cuda()

            y = transformer(x)
            xc = Variable(x.data.clone(), volatile=True)

            # Calculate focus loss and category loss
            y_cam = utils.depreprocess_batch(y)
            y_cam = utils.subtract_mean_std_batch(y_cam)

            xc_cam = utils.depreprocess_batch(xc)
            xc_cam = utils.subtract_mean_std_batch(xc_cam)

            # Empty the shared list before the two forward passes; the code
            # below reads entry 0 for the content pass and entry 1 for the
            # stylized pass (presumably filled by a hook on `net` -- verify).
            del features_blobs[:]
            logit_x = net(xc_cam)
            logit_y = net(y_cam)

            label = []
            cam_loss = 0
            for i in range(len(xc_cam)):
                # Top-1 class of the content image and of the stylized image.
                h_x = F.softmax(logit_x[i])
                _, idx_x = h_x.data.sort(0, True)
                label.append(idx_x[0])

                h_y = F.softmax(logit_y[i])
                _, idx_y = h_y.data.sort(0, True)

                x_cam_i = returnCAM(features_blobs[0][i], weight_softmax, idx_x[0])
                x_cam_i = Variable(x_cam_i.data, requires_grad=False)

                # Separate name: the original rebound the batch variable
                # `y_cam` here.
                y_cam_i = returnCAM(features_blobs[1][i], weight_softmax, idx_y[0])

                cam_loss += mse_loss(y_cam_i, x_cam_i)

            # the focus loss
            cam_loss *= 80
            # the category loss
            label = Variable(torch.LongTensor(label), requires_grad=False)
            if args.cuda:
                # Only move to the GPU when requested; the unconditional
                # .cuda() broke CPU runs.
                label = label.cuda()
            cate_loss = 10000 * torch.nn.CrossEntropyLoss()(logit_y, label)

            y = utils.subtract_imagenet_mean_batch(y)
            xc = utils.subtract_imagenet_mean_batch(xc)

            features_y = vgg(y)
            features_xc = vgg(xc)

            # Content loss uses the third feature map returned by vgg.
            f_xc_c = Variable(features_xc[2].data, requires_grad=False)
            content_loss = args.content_weight * mse_loss(features_y[2], f_xc_c)

            style_loss = 0.
            for m in range(len(features_y)):
                gram_s = Variable(gram_style[m].data, requires_grad=False)
                gram_y = utils.gram_matrix(features_y[m])
                style_loss += args.style_weight * mse_loss(gram_y, gram_s[:n_batch, :, :])

            # add the total four loss and backward
            total_loss = style_loss + content_loss + cam_loss + cate_loss
            total_loss.backward()
            optimizer.step()

            # something for display
            agg_content_loss += content_loss.data[0]
            agg_style_loss += style_loss.data[0]
            agg_cate_loss += cate_loss.data[0]
            agg_cam_loss += cam_loss.data[0]
            writer.add_scalar("Loss_Cont", agg_content_loss / (batch_id + 1), hori)
            writer.add_scalar("Loss_Style", agg_style_loss / (batch_id + 1), hori)
            writer.add_scalar("Loss_CAM", agg_cam_loss / (batch_id + 1), hori)
            writer.add_scalar("Loss_Cate", agg_cate_loss / (batch_id + 1), hori)
            hori += 1

            if (batch_id + 1) % args.log_interval == 0:
                mesg = "{}Epoch{}:[{}/{}] content:{:.2f} style:{:.2f} cate:{:.2f} cam:{:.2f} total:{:.2f}".format(
                    time.strftime("%a %H:%M:%S"), e + 1, count, len(train_dataset),
                    agg_content_loss / (batch_id + 1),
                    agg_style_loss / (batch_id + 1),
                    agg_cate_loss / (batch_id + 1),
                    agg_cam_loss / (batch_id + 1),
                    (agg_content_loss + agg_style_loss + agg_cate_loss + agg_cam_loss) / (batch_id + 1)
                )
                print(mesg)

            if (batch_id + 1) % 2500 == 0:
                save_model(e + 1)
                # Restore the training setup; only go back to the GPU if the
                # run was started on it.
                if args.cuda:
                    transformer.cuda()
                transformer.train()
                print("saved at ", count)

    # save model
    save_model_path = save_model(args.epochs)

    writer.close()
    print("\nDone, trained model saved at", save_model_path)
def train(args):
    """Train a ``TransformerNet`` that predicts the middle channel from the outer two.

    Images are passed through a colour pipeline (``Linearize`` ->
    ``SRGB2XYZ`` -> ``XYZ2CIE``); the network input is channels 0 and 2 of
    the result and the target is channel 1.

    Reads from ``args``: ``seed``, ``cuda``, ``image_size``, ``dataset``,
    ``batch_size``, ``lr``, ``epochs``, ``log_interval``,
    ``content_weight``, ``style_weight`` (the last two only appear in the
    output filename) and ``save_model_dir``.

    Raises:
        RuntimeError: if ``args.cuda`` is set but CUDA is not available.
    """
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        kwargs = {'num_workers': 12, 'pin_memory': False}
    else:
        kwargs = {}

    from transform.color_op import Linearize, SRGB2XYZ, XYZ2CIE
    color_pipeline = transforms.Compose([
        Linearize(), SRGB2XYZ(), XYZ2CIE()
    ])

    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        # Pass the Compose instance itself. The original wrote `RGB2YUV()`,
        # which *calls* the pipeline with no image and raises TypeError
        # before training can start.
        color_pipeline,
        transforms.ToTensor(),
    ])
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, **kwargs)

    # input: LS, predict: M
    transformer = TransformerNet(in_channels=2, out_channels=1)
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    transformer = nn.DataParallel(transformer)

    if args.cuda:
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA is requested, but related driver/device is not set properly.")
        transformer.cuda()

    for e in range(args.epochs):
        transformer.train()
        # Running sum of batch losses; the log line reports its mean. The
        # original divided the current batch loss by the batch index.
        agg_loss = 0.
        count = 0
        for batch_id, (imgs, _) in enumerate(train_loader):
            n_batch = len(imgs)
            count += n_batch
            optimizer.zero_grad()
            # First and last channels form the two-channel input.
            x = torch.cat([imgs[:, :1, :, :].clone(), imgs[:, -1:, :, :].clone()], dim=1)
            # The middle channel is the target.
            gt = imgs[:, 1:2, :, :].clone()
            if args.cuda:
                x = x.cuda()
                gt = gt.cuda()

            y = transformer(x)
            total_loss = mse_loss(y, gt)
            total_loss.backward()
            optimizer.step()

            agg_loss += total_loss.item()

            if (batch_id + 1) % args.log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                    agg_loss / (batch_id + 1)
                )
                print(mesg)

    # save model
    transformer.eval()
    transformer.cpu()
    save_model_filename = "epoch_" + str(args.epochs) + "_" + str(time.ctime()).replace(' ', '_') + "_" + str(
        args.content_weight) + "_" + str(args.style_weight) + ".model"
    os.makedirs(args.save_model_dir, exist_ok=True)
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
def train(args):
    """Fit a ``TransformerNet`` to one style image and write it to disk.

    Reads from ``args``: ``seed``, ``cuda``, ``image_size``, ``dataset``,
    ``batch_size``, ``lr``, ``style_image``, ``style_size``, ``epochs``,
    ``content_weight``, ``style_weight``, ``log_interval``,
    ``checkpoint_model_dir``, ``checkpoint_interval`` and
    ``save_model_dir``.

    Prints progress every ``log_interval`` batches, optionally writes
    checkpoints, and saves the final state dict. Returns nothing.
    """
    use_cuda = args.cuda

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if use_cuda:
        torch.cuda.manual_seed(args.seed)

    # Content images: square resize/crop, values scaled by 255.
    content_pipeline = transforms.Compose([
        transforms.Scale(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255)),
    ])
    train_dataset = datasets.ImageFolder(args.dataset, content_pipeline)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size)

    transformer = TransformerNet()
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16(requires_grad=False)

    # Style image: tensor scaled by 255, replicated once per batch slot.
    style_pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255)),
    ])
    style_image = utils.load_image(args.style_image, size=args.style_size)
    style = style_pipeline(style_image).repeat(args.batch_size, 1, 1, 1)

    if use_cuda:
        transformer.cuda()
        vgg.cuda()
        style = style.cuda()

    # Gram matrices of the style features are fixed targets for the run.
    normalized_style = utils.normalize_batch(Variable(style))
    gram_style = [utils.gram_matrix(feat) for feat in vgg(normalized_style)]

    for e in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            batches_done = batch_id + 1
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()

            x = Variable(x)
            if use_cuda:
                x = x.cuda()

            y = utils.normalize_batch(transformer(x))
            x = utils.normalize_batch(x)

            features_y = vgg(y)
            features_x = vgg(x)

            content_loss = args.content_weight * mse_loss(
                features_y.relu2_2, features_x.relu2_2)

            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                # Trim the target so a short final batch still lines up.
                style_loss += mse_loss(utils.gram_matrix(ft_y),
                                       gm_s[:n_batch, :, :])
            style_loss *= args.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.data[0]
            agg_style_loss += style_loss.data[0]

            if batches_done % args.log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                    agg_content_loss / batches_done,
                    agg_style_loss / batches_done,
                    (agg_content_loss + agg_style_loss) / batches_done
                )
                print(mesg)

            checkpoint_due = (args.checkpoint_model_dir is not None
                              and batches_done % args.checkpoint_interval == 0)
            if checkpoint_due:
                # Snapshot from CPU in eval mode, then resume training.
                transformer.eval()
                if use_cuda:
                    transformer.cpu()
                ckpt_model_filename = "".join([
                    "ckpt_epoch_", str(e), "_batch_id_", str(batches_done), ".pth",
                ])
                ckpt_model_path = os.path.join(args.checkpoint_model_dir,
                                               ckpt_model_filename)
                torch.save(transformer.state_dict(), ckpt_model_path)
                if use_cuda:
                    transformer.cuda()
                transformer.train()

    # save model
    transformer.eval()
    if use_cuda:
        transformer.cpu()
    save_model_filename = "".join([
        "epoch_", str(args.epochs), "_",
        str(time.ctime()).replace(' ', '_'), "_",
        str(args.content_weight), "_", str(args.style_weight), ".model",
    ])
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
def train(args):
    """Train a ``TransformerNet`` on single-channel samples with a style loss and optional hard-data loss.

    The training set is a text matrix loaded with ``np.loadtxt`` whose
    columns are samples; each column is reshaped to
    ``(1, image_size_x, image_size_y)``. Trailing samples that do not fill
    a whole batch are not used.

    Reads from ``args``: ``seed``, ``cuda``, ``dataset``, ``batch_size``,
    ``lr``, ``vgg_model_dir``, ``style_image``, ``style_size_x``,
    ``style_size_y``, ``hard_data``, ``hard_data_file``,
    ``hard_data_weight``, ``epochs``, ``image_size_x``, ``image_size_y``,
    ``style_weight``, ``content_weight`` (filename only),
    ``log_interval`` and ``save_model_dir``.

    Each hard-data row is read as ``(hd[0], hd[1], hd[2])`` and used as
    ``y[:, 0, hd[1], hd[0]]`` against ``hd[2] * 255``.
    """
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    training_set = np.loadtxt(args.dataset, dtype=np.float32)
    training_set_size = training_set.shape[1]
    num_batch = training_set_size // args.batch_size
    # Samples actually visited per epoch (the partial last batch is dropped).
    samples_per_epoch = num_batch * args.batch_size

    transformer = TransformerNet()
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16()
    utils.init_vgg16(args.vgg_model_dir)
    vgg.load_state_dict(
        torch.load(os.path.join(args.vgg_model_dir, "vgg16.weight")))

    if args.cuda:
        transformer.cuda()
        vgg.cuda()

    # The style sample is single-channel; replicate it to 3 channels for VGG
    # and once per batch slot so Gram targets can be sliced per batch.
    style = np.loadtxt(args.style_image, dtype=np.float32)
    style = style.reshape((1, 1, args.style_size_x, args.style_size_y))
    style = torch.from_numpy(style)
    style = style.repeat(args.batch_size, 3, 1, 1)
    if args.cuda:
        style = style.cuda()
    style_v = Variable(style, volatile=True)
    style_v = utils.subtract_imagenet_mean_batch(style_v)
    features_style = vgg(style_v)
    gram_style = [utils.gram_matrix(y) for y in features_style]

    # Hard data
    if args.hard_data:
        # atleast_2d: a file with a single row is returned 1-D by loadtxt,
        # which would make the per-row loop below iterate over scalars.
        hard_data = np.atleast_2d(np.loadtxt(args.hard_data_file))

    for e in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.  # content loss is disabled; stays 0 in the log
        agg_style_loss = 0.
        count = 0
        for batch_id in range(num_batch):
            x = training_set[:, batch_id * args.batch_size:(batch_id + 1) *
                             args.batch_size]
            n_batch = x.shape[1]
            count += n_batch

            # Columns are samples: transpose to (n_batch, pixels), then
            # reshape to single-channel images.
            x = x.transpose()
            x = x.reshape((n_batch, 1, args.image_size_x, args.image_size_y))
            x = torch.from_numpy(x).float()

            optimizer.zero_grad()
            x = Variable(x)
            if args.cuda:
                x = x.cuda()

            y = transformer(x)

            if args.hard_data:
                hard_data_loss = 0
                for hd in hard_data:
                    # loadtxt yields floats; tensor indices must be ints.
                    col, row = int(hd[0]), int(hd[1])
                    target = float(hd[2]) * 255.0
                    hard_data_loss += args.hard_data_weight * (
                        y[:, 0, row, col] - target).norm()**2 / n_batch
                hard_data_loss /= len(hard_data)

            y = y.repeat(1, 3, 1, 1)
            y = utils.subtract_imagenet_mean_batch(y)

            features_y = vgg(y)

            style_loss = 0.
            for m in range(len(features_y)):
                gram_s = Variable(gram_style[m].data, requires_grad=False)
                gram_y = utils.gram_matrix(features_y[m])
                style_loss += args.style_weight * mse_loss(
                    gram_y, gram_s[:n_batch, :, :])

            # Build a new tensor for the total. The original aliased
            # `total_loss = style_loss` and then used `+=`, which added the
            # hard-data term into `style_loss` itself and so into the logged
            # style average.
            if args.hard_data:
                total_loss = style_loss + hard_data_loss
            else:
                total_loss = style_loss

            total_loss.backward()
            optimizer.step()

            agg_style_loss += style_loss.data[0]

            if (batch_id + 1) % args.log_interval == 0:
                # Progress is reported in samples on both sides of the slash
                # (the original divided a sample count by a batch count).
                if args.hard_data:
                    mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\thard_data: {:.6f}\ttotal: {:.6f}".format(
                        time.ctime(), e + 1, count, samples_per_epoch,
                        agg_content_loss / (batch_id + 1),
                        agg_style_loss / (batch_id + 1),
                        hard_data_loss.data[0],
                        (agg_content_loss + agg_style_loss) / (batch_id + 1))
                else:
                    mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                        time.ctime(), e + 1, count, samples_per_epoch,
                        agg_content_loss / (batch_id + 1),
                        agg_style_loss / (batch_id + 1),
                        (agg_content_loss + agg_style_loss) / (batch_id + 1))
                print(mesg)

    # save model
    transformer.eval()
    transformer.cpu()
    save_model_filename = "epoch_" + str(args.epochs) + "_" + str(
        time.ctime()).replace(' ', '_') + "_" + str(
            args.content_weight) + "_" + str(args.style_weight) + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
def train(args):
    """Run style-transfer training for ``args.epochs`` epochs and save the model.

    Reads from ``args``: ``seed``, ``cuda``, ``image_size``, ``dataset``,
    ``batch_size``, ``lr``, ``style_image``, ``style_size``, ``epochs``,
    ``content_weight``, ``style_weight``, ``log_interval``,
    ``checkpoint_model_dir``, ``checkpoint_interval`` and
    ``save_model_dir``. Returns nothing.
    """
    on_gpu = args.cuda

    # Fixed seeds: two runs with the same args.seed draw the same numbers.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if on_gpu:
        torch.cuda.manual_seed(args.seed)

    to_scaled_tensor = [
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255)),  # 0-1 to 0-255
    ]
    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
    ] + to_scaled_tensor)
    # Folder of images -> transformed samples -> batches.
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size)

    transformer = TransformerNet()
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16(requires_grad=False)

    style_transform = transforms.Compose(list(to_scaled_tensor))
    style = style_transform(
        utils.load_image(args.style_image, size=args.style_size))
    # One copy of the style tensor per batch slot (args.batch_size copies).
    style = style.repeat(args.batch_size, 1, 1, 1)

    if on_gpu:
        transformer.cuda()
        vgg.cuda()
        style = style.cuda()

    # Style targets: Gram matrices of the VGG features of the style batch.
    features_style = vgg(utils.normalize_batch(Variable(style)))
    gram_style = [utils.gram_matrix(feat) for feat in features_style]

    for e in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            step = batch_id + 1
            n_batch = len(x)
            count += n_batch
            # Gradients accumulate across backward() calls; reset per batch.
            optimizer.zero_grad()

            x = Variable(x)
            if on_gpu:
                x = x.cuda()

            # Forward pass through the transformer, then the loss network.
            y = utils.normalize_batch(transformer(x))
            x = utils.normalize_batch(x)
            features_y = vgg(y)
            features_x = vgg(x)

            # TODO: the layer used for the content loss could be changed.
            content_loss = args.content_weight * mse_loss(
                features_y.relu2_2, features_x.relu2_2)

            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                style_loss += mse_loss(utils.gram_matrix(ft_y),
                                       gm_s[:n_batch, :, :])
            style_loss *= args.style_weight

            total_loss = content_loss + style_loss
            # Backward pass computes gradients; step() applies the update.
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.data[0]
            agg_style_loss += style_loss.data[0]

            if step % args.log_interval == 0:
                msg = "".join([
                    "Epoch ", str(e + 1), " ", str(count), "/",
                    str(len(train_dataset)),
                    " content loss : ", str(agg_content_loss / step),
                    " style loss : ", str(agg_style_loss / step),
                    " total loss : ",
                    str((agg_content_loss + agg_style_loss) / step),
                ])
                print(msg)

            if (args.checkpoint_model_dir is not None
                    and step % args.checkpoint_interval == 0):
                transformer.eval()
                if on_gpu:
                    transformer.cpu()
                ckpt_model_filename = "".join([
                    "ckpt_epoch_", str(e), "_batch_id_", str(step), ".pth",
                ])
                ckpt_model_path = os.path.join(args.checkpoint_model_dir,
                                               ckpt_model_filename)
                torch.save(transformer.state_dict(), ckpt_model_path)
                # Put the model back where training expects it.
                if on_gpu:
                    transformer.cuda()
                transformer.train()

    # save model
    transformer.eval()
    if on_gpu:
        transformer.cpu()
    save_model_filename = "".join([
        "epoch_", str(args.epochs), "_",
        str(time.ctime()).replace(' ', '_'), "_",
        str(args.content_weight), "_", str(args.style_weight), ".model",
    ])
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
def train(args):
    """Train a ``TransformerNet`` inside a MissingLink experiment and save it.

    Reads from ``args``: ``seed``, ``cuda``, ``image_size``, ``dataset``,
    ``batch_size``, ``lr``, ``style_image``, ``style_size``, ``epochs``,
    ``content_weight``, ``style_weight``, ``log_interval``,
    ``checkpoint_model_dir``, ``checkpoint_interval`` and
    ``save_model_dir``.

    Uses the module-level ``missinglink_project`` to create the experiment;
    the content, style and total losses are produced through the
    experiment's wrapped metric callables. Returns nothing.
    """
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    print("Loading data")
    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size)

    # Was a Python 2 print statement, which is a SyntaxError on Python 3 and
    # inconsistent with every other print call in this function.
    print("Building the model")
    transformer = TransformerNet()
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16(requires_grad=False)
    style_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    style = utils.load_image(args.style_image, size=args.style_size)
    style = style_transform(style)
    style = style.repeat(args.batch_size, 1, 1, 1)

    if args.cuda:
        transformer.cuda()
        vgg.cuda()
        style = style.cuda()

    style_v = Variable(style)
    style_v = utils.normalize_batch(style_v)
    features_style = vgg(style_v)
    gram_style = [utils.gram_matrix(y) for y in features_style]

    # Plain functions handed to the experiment as metrics; the wrapped
    # versions returned by the experiment are what the loop calls.
    def multiply(loss, weight):
        return loss * weight

    def add(loss1, loss2):
        return loss1 + loss2

    metrics_names = ['Content Loss', 'Style Loss', 'Total Loss']

    with missinglink_project.create_experiment(
            transformer,
            display_name='Style Transfer PyTorch',
            optimizer=optimizer,
            train_data_object=train_loader,
            metrics={metrics_names[0]: multiply, metrics_names[1]: multiply, metrics_names[2]: add}
    ) as experiment:

        (wrapped_content_loss,
         wrapped_style_loss,
         wrapped_total_loss) = [experiment.metrics[metric_name] for metric_name in metrics_names]

        print("Starting to train")
        for e in experiment.epoch_loop(args.epochs):
            transformer.train()
            agg_content_loss = 0.
            agg_style_loss = 0.
            count = 0
            for batch_id, (x, _) in experiment.batch_loop(iterable=train_loader):
                n_batch = len(x)
                count += n_batch
                optimizer.zero_grad()
                x = Variable(x)
                if args.cuda:
                    x = x.cuda()

                y = transformer(x)

                y = utils.normalize_batch(y)
                x = utils.normalize_batch(x)

                features_y = vgg(y)
                features_x = vgg(x)

                content_loss = mse_loss(features_y.relu2_2, features_x.relu2_2)
                content_loss = wrapped_content_loss(content_loss, args.content_weight)

                style_loss = 0.
                for ft_y, gm_s in zip(features_y, gram_style):
                    gm_y = utils.gram_matrix(ft_y)
                    # Slice the target so a short final batch still lines up.
                    style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
                style_loss = wrapped_style_loss(style_loss, args.style_weight)

                total_loss = wrapped_total_loss(content_loss, style_loss)
                total_loss.backward()
                optimizer.step()

                agg_content_loss += content_loss.data[0]
                agg_style_loss += style_loss.data[0]

                if (batch_id + 1) % args.log_interval == 0:
                    mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                        time.ctime(), e + 1, count, len(train_dataset),
                        agg_content_loss / (batch_id + 1),
                        agg_style_loss / (batch_id + 1),
                        (agg_content_loss + agg_style_loss) / (batch_id + 1)
                    )
                    print(mesg)

                if args.checkpoint_model_dir is not None and (batch_id + 1) % args.checkpoint_interval == 0:
                    # Snapshot from CPU in eval mode, then resume training.
                    transformer.eval()
                    if args.cuda:
                        transformer.cpu()
                    ckpt_model_filename = "ckpt_epoch_" + str(e) + "_batch_id_" + str(batch_id + 1) + ".pth"
                    ckpt_model_path = os.path.join(args.checkpoint_model_dir, ckpt_model_filename)
                    torch.save(transformer.state_dict(), ckpt_model_path)
                    if args.cuda:
                        transformer.cuda()
                    transformer.train()

    # save model
    transformer.eval()
    if args.cuda:
        transformer.cpu()
    save_model_filename = "epoch_" + str(args.epochs) + "_" + str(time.ctime()).replace(' ', '_') + "_" + str(
        args.content_weight) + "_" + str(args.style_weight) + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
def train(args):
    """Train a style-transfer network, optionally logging to MySQL/Visdom.

    A run counter kept in ``serialNum.txt`` is incremented on every call and
    becomes the prefix of the saved model's file name. When ``args.mysql`` is
    set, a MySQL connection is opened for the duration of the run and is
    always closed afterwards, even if training raises.

    Args:
        args: Parsed command-line namespace. The attributes read are
            ``mysql``, ``dataset``, ``style_image``, ``epochs``,
            ``content_weight``, ``style_weight``, ``image_size``, ``seed``,
            ``cuda``, ``batch_size``, ``lr``, ``l1``, ``visdom``,
            ``vgg_model_dir``, ``model``, ``style_size``, ``log_interval``
            and ``save_model_dir``.
    """
    serialNumFile = "serialNum.txt"
    serial = 0
    if os.path.isfile(serialNumFile):
        with open(serialNumFile, "r") as t:
            serial = int(t.read())
    serial += 1
    with open(serialNumFile, "w") as t:
        t.write(str(serial))

    cnx = None
    cursor = None
    try:
        if args.mysql:
            cnx = mysql.connector.connect(user='******',
                                          database='midburn',
                                          password='******')
            cursor = cnx.cursor()
        _run_training(args, serial, cnx, cursor)
    finally:
        # Release the database handles; previously they were never closed.
        if cursor is not None:
            cursor.close()
        if cnx is not None:
            cnx.close()


def _run_training(args, serial, cnx, cursor):
    """Run the training loop for ``train`` and save the resulting model.

    ``cnx`` and ``cursor`` are the open MySQL connection and cursor, or
    ``None`` when ``args.mysql`` is false; they are only used in that case.
    """
    # Last non-empty path component of the dataset directory.
    location = args.dataset.split("/")
    if location[-1] == "":
        location = location[-2]
    else:
        location = location[-1]
    save_model_filename = str(serial) + "_" + extractName(
        args.style_image) + "_" + str(args.epochs) + "_" + str(
            int(args.content_weight)) + "_" + str(int(
                args.style_weight)) + "_size_" + str(
                    args.image_size) + "_dataset_" + str(location) + ".model"
    print(save_model_filename)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    m_epoch = 0  # epochs already completed by a model resumed via args.model

    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        kwargs = {'num_workers': 4, 'pin_memory': True}
    else:
        kwargs = {}

    transform = transforms.Compose([
        # `Resize` is the current name of the deprecated `Scale` transform.
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              **kwargs)

    transformer = TransformerNet()
    transformer_type = transformer.__class__.__name__
    optimizer = Adam(transformer.parameters(), args.lr)
    if args.l1:
        loss_criterion = torch.nn.L1Loss()
    else:
        loss_criterion = torch.nn.MSELoss()
    loss_type = loss_criterion.__class__.__name__

    if args.visdom:
        vis = VisdomLinePlotter("Style Transfer: " + transformer_type)
    else:
        vis = None

    vgg = Vgg16()
    utils.init_vgg16(args.vgg_model_dir)
    vgg.load_state_dict(
        torch.load(os.path.join(args.vgg_model_dir, "vgg16.weight")))

    if args.cuda:
        transformer.cuda()
        vgg.cuda()

    if args.model is not None:
        # Resume from an existing model and record the cumulative epoch
        # count in the output file name.
        transformer.load_state_dict(torch.load(args.model))
        save_model_filename = save_model_filename + "@@@@@@" + str(
            int(getEpoch(args.model)) + int(args.epochs))
        m_epoch += int(getEpoch(args.model))
        print("loaded model\n")

    # VGG only provides features for the losses; it is not trained.
    for param in vgg.parameters():
        param.requires_grad = False

    with torch.no_grad():
        style = utils.tensor_load_rgbimage(args.style_image,
                                           size=args.style_size)
        style = style.repeat(args.batch_size, 1, 1, 1)
        style = utils.preprocess_batch(style)
        if args.cuda:
            style = style.cuda()
        style = utils.subtract_imagenet_mean_batch(style)
        features_style = vgg(style)
        gram_style = [utils.gram_matrix(y) for y in features_style]
        del features_style
        del style

    # TODO: scheduler and style-loss criterion unused at the moment
    scheduler = StepLR(optimizer, step_size=15000 // args.batch_size)
    style_loss_criterion = torch.nn.CosineSimilarity()

    total_count = 0

    if args.mysql:
        # Parameterized so that quotes in the image path cannot break (or
        # inject into) the statement.
        cursor.execute("REPLACE INTO `images`(`name`) VALUES (%s)",
                       (args.style_image,))
        cnx.commit()
        imgId = cursor.lastrowid

    for e in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            total_count += n_batch
            optimizer.zero_grad()
            x = utils.preprocess_batch(x)
            if args.cuda:
                x = x.cuda()

            y = transformer(x)

            y = utils.subtract_imagenet_mean_batch(y)
            xc = utils.subtract_imagenet_mean_batch(x)

            features_y = vgg(y)
            f_xc_c = vgg.content_features(xc)

            content_loss = args.content_weight * loss_criterion(
                features_y[1], f_xc_c)

            style_loss = 0.
            for feature_y, gram_s in zip(features_y, gram_style):
                gram_y = utils.gram_matrix(feature_y)
                # The last batch may be short, so trim the style Grams.
                style_loss += loss_criterion(gram_y, gram_s[:n_batch, :, :])
            style_loss *= args.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()
            # TODO: enable scheduler.step() once the scheduler is in use.

            agg_content_loss += content_loss.item()
            agg_style_loss += style_loss.item()

            if (batch_id + 1) % args.log_interval == 0:
                avg_content = agg_content_loss / (batch_id + 1)
                avg_style = agg_style_loss / (batch_id + 1)
                avg_total = (agg_content_loss + agg_style_loss) / (batch_id + 1)

                if args.mysql:
                    cursor.execute(
                        "REPLACE INTO `statistics`(`imgId`,`epoch`, `iteration_id`, `content_loss`, `style_loss`, `loss`) VALUES (%s, %s, %s, %s, %s, %s)",
                        (imgId, int(e) + m_epoch, batch_id,
                         avg_content, avg_style, avg_total))
                    cnx.commit()

                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}\n".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                    avg_content, avg_style, avg_total)
                print(mesg)
                # Flush after printing so the line appears immediately when
                # stdout is redirected (the flush used to precede the print).
                sys.stdout.flush()

                if vis is not None:
                    vis.plot(loss_type, "Content Loss", total_count,
                             content_loss.item())
                    vis.plot(loss_type, "Style Loss", total_count,
                             style_loss.item())
                    vis.plot(loss_type, "Total Loss", total_count,
                             total_loss.item())

    # save model
    transformer.eval()
    transformer.cpu()
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)