def style_loss(features_frame_style, gram_style, batch_size):
    mse_loss = torch.nn.MSELoss()
    style_loss = 0.
    for ft_frame_style, gm_s in zip(features_frame_style, gram_style):  # loop on feature layers
        gm_frame_style = utils.gram_matrix(ft_frame_style)
        style_loss += mse_loss(gm_frame_style, gm_s[:batch_size, :, :])
    return style_loss
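# The style loss above relies on utils.gram_matrix, which is defined elsewhere
# in the repo. The helper below is only a minimal sketch of what such a Gram
# matrix computation typically looks like for N x C x H x W feature maps
# (normalized by C*H*W); it is an assumption for illustration, not necessarily
# the exact utils implementation.
def gram_matrix_sketch(features):
    """Batched Gram matrix: (N, C, H, W) features -> (N, C, C) correlations."""
    (n, c, h, w) = features.size()
    feats = features.view(n, c, h * w)          # flatten spatial dims
    feats_t = feats.transpose(1, 2)             # (N, H*W, C)
    return feats.bmm(feats_t) / (c * h * w)     # channel-by-channel correlations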
def train(dataset_path, style_image_path, save_model_dir, has_cuda, epochs=2,
          image_limit=None, checkpoint_model_dir=None, image_size=256,
          style_size=None, seed=42, content_weight=1, style_weight=10,
          temporal_weight=10, tv_weight=1e-3, lr=1e-3, log_interval=500,
          checkpoint_interval=2000, model_filename="myModel"):
    device = torch.device("cuda" if has_cuda else "cpu")
    np.random.seed(seed)
    torch.manual_seed(seed)

    batch_size = 1  # needs to be 1, batch is created using MyDataSet
    loss_list = []
    loss_filename = model_filename + '_losses.txt'

    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        # transforms.Resize(image_size),
        # transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    # videos_list = os.listdir(dataset_path)
    # train_dataset = {}
    # train_loader = {}
    # for video_name in videos_list:
    #     video_dataset_path = os.path.join(dataset_path, video_name)
    #     train_dataset[video_name] = MyDataSet(video_dataset_path, transform)
    #     train_loader[video_name] = DataLoader(train_dataset[video_name], batch_size=batch_size)
    # video_dataset_path = os.path.join(dataset_path, "Monkaa")
    # dataset_path = "Data/Monkaa"

    train_dataset_path = os.path.join(dataset_path, "frames_cleanpass")
    flow_path = os.path.join(dataset_path, "optical_flow_resized")
    train_dataset = MyDataSet(
        train_dataset_path, flow_path, transform,
        image_limit=image_limit)  # remove if using all datasets
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    transformer_net = TransformerNet().to(device)
    optimizer = Adam(transformer_net.parameters(), lr)
    vgg = Vgg16(requires_grad=False).to(device)

    style_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    style_image = utils.load_image(style_image_path, size=style_size)
    style_image = style_transform(style_image)
    style_image = style_image.repeat(batch_size, 1, 1, 1).to(device)

    features_style = vgg(style_image)
    gram_style = [utils.gram_matrix(y) for y in features_style]

    for e in range(epochs):
        batch_num = 0
        # for video_name in videos_list:
        for frames_and_flow in tqdm(train_loader):
            (frames_curr_lr, flow_lr, frames_next_lr) = frames_and_flow
            batch_num += 1
            for i in [0, 1]:  # Left, Right
                to_save = (batch_num + 1) % (checkpoint_interval / 4) == 0
                frame_curr = frames_curr_lr[i]
                frame_next = frames_next_lr[i]
                flow = flow_lr[i]
                batch_size = len(frame_curr)
                optimizer.zero_grad()

                frame_curr_to_save = frame_curr.permute(2, 3, 1, 0).squeeze(3)
                frame_next_to_save = frame_next.permute(2, 3, 1, 0).squeeze(3)
                namefile_frame_curr = ('test_images/frame_curr/frame_curr_epo' +
                                       str(e) + 'batch_num' + str(batch_num) + '.png')
                namefile_frame_next = ('test_images/frame_next/frame_next_epo' +
                                       str(e) + 'batch_num' + str(batch_num) + '.png')
                if to_save:
                    utils.save_image_loss(frame_curr_to_save, namefile_frame_curr)
                    utils.save_image_loss(frame_next_to_save, namefile_frame_next)

                frame_curr = frame_curr.to(device)
                frame_next = frame_next.to(device)
                frame_style = transformer_net(frame_curr)
                frame_next_style = transformer_net(frame_next)
                # TODO: input frames to net as batch (frame_curr, frame_next)

                features_frame = vgg(frame_curr)
                features_frame_style = vgg(frame_style)
                # print(frame_curr.shape)

                content_loss = losses.content_loss(features_frame,
                                                   features_frame_style)
                style_loss = losses.style_loss(features_frame_style, gram_style,
                                               batch_size)
                if to_save:
                    temporal_loss = losses.temporal_loss(frame_style,
                                                         frame_next_style,
                                                         flow, device,
                                                         to_save=to_save,
                                                         batch_num=batch_num, e=e)
                else:
                    temporal_loss = losses.temporal_loss(
                        frame_style, frame_next_style, flow, device)
                tv_loss = losses.tv_loss(frame_curr)

                total_loss = (content_weight * content_loss +
                              style_weight * style_loss +
                              temporal_weight * temporal_loss)
                total_loss.backward()
                optimizer.step()

                if (batch_num + 1) % log_interval == 0:  # TODO: Choose between TQDM and printing
                    mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttemporal: {:.6f}" \
                           "\ttotal: {:.6f}".format(
                               time.ctime(), e + 1, batch_num + 1, len(train_dataset),
                               content_loss.item(), style_loss.item(),
                               temporal_loss.item(), total_loss.item())
                    # print(mesg)
                    losses_string = (str(content_loss.item()) + "," +
                                     str(style_loss.item()) + "," +
                                     str(temporal_loss.item()) + "," +
                                     str(total_loss.item()))
                    loss_list.append(losses_string)
                    utils.save_loss_file(loss_list, loss_filename)

                if (checkpoint_model_dir is not None and
                        (batch_num + 1) % checkpoint_interval == 0):
                    transformer_net.eval().cpu()
                    ckpt_model_filename = (model_filename + "_ckpt_epoch_" +
                                           str(e + 1) + "_batch_id_" +
                                           str(batch_num + 1) + ".pth")
                    ckpt_model_path = os.path.join(checkpoint_model_dir,
                                                   ckpt_model_filename)
                    torch.save(transformer_net.state_dict(), ckpt_model_path)
                    utils.save_loss_file(loss_list, loss_filename)
                    transformer_net.to(device).train()

    # save model
    transformer_net.eval().cpu()
    # save_model_filename = "epoch_" + str(epochs) + "_" + str(time.ctime()).replace(' ', '_') + "_" + str(
    #     content_weight) + "_" + str(style_weight) + ".model"
    save_model_path = os.path.join(save_model_dir, model_filename + ".pth")
    torch.save(transformer_net.state_dict(), save_model_path)
    print("\nDone, trained model saved at", save_model_path)
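# losses.temporal_loss is called in the training loop above but defined in a
# separate module. A common form of the short-term temporal consistency loss
# warps the stylized current frame to the next frame using the optical flow and
# penalizes the difference from the stylized next frame. The sketch below
# (including the assumed flow layout of (1, H, W, 2) pixel offsets) is an
# illustrative assumption, not the project's exact implementation, which may
# also weight the loss with an occlusion mask.
import torch
import torch.nn.functional as F


def temporal_loss_sketch(frame_style, frame_next_style, flow, device):
    """frame_style, frame_next_style: (1, 3, H, W); flow: (1, H, W, 2) in pixels."""
    _, _, h, w = frame_style.shape
    # Base sampling grid in pixel coordinates, then displaced by the flow.
    ys, xs = torch.meshgrid(torch.arange(h, device=device),
                            torch.arange(w, device=device), indexing="ij")
    grid = torch.stack((xs, ys), dim=-1).float().unsqueeze(0)  # (1, H, W, 2)
    grid = grid + flow.to(device)
    # Normalize to the [-1, 1] range expected by grid_sample.
    grid[..., 0] = 2.0 * grid[..., 0] / (w - 1) - 1.0
    grid[..., 1] = 2.0 * grid[..., 1] / (h - 1) - 1.0
    warped = F.grid_sample(frame_style.to(device), grid, align_corners=True)
    return F.mse_loss(warped, frame_next_style.to(device))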
def train(dataset_path, style_image_path, save_model_dir, has_cuda, epochs=2,
          image_limit=None, checkpoint_model_dir=None, image_size=256,
          style_size=None, seed=42, content_weight=1e5, style_weight=1e10,
          lr=1e-3, log_interval=500, checkpoint_interval=2000):
    device = torch.device("cuda" if has_cuda else "cpu")
    np.random.seed(seed)
    torch.manual_seed(seed)

    batch_size = 1  # needs to be 1, batch is created using MyDataSet

    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    # videos_list = os.listdir(dataset_path)
    # train_dataset = {}
    # train_loader = {}
    # for video_name in videos_list:
    #     video_dataset_path = os.path.join(dataset_path, video_name)
    #     train_dataset[video_name] = MyDataSet(video_dataset_path, transform)
    #     train_loader[video_name] = DataLoader(train_dataset[video_name], batch_size=batch_size)
    # video_dataset_path = os.path.join(dataset_path, "Monkaa")
    # dataset_path = "Data/Monkaa"

    train_dataset = MyDataSet(
        dataset_path, transform,
        image_limit=image_limit)  # remove if using all datasets
    train_loader = DataLoader(train_dataset, batch_size=batch_size)

    transformer_net = TransformerNet().to(device)
    optimizer = Adam(transformer_net.parameters(), lr)
    mse_loss = torch.nn.MSELoss()
    vgg = Vgg16(requires_grad=False).to(device)

    style_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    style_image = utils.load_image(style_image_path, size=style_size)
    style_image = style_transform(style_image)
    style_image = style_image.repeat(batch_size, 1, 1, 1).to(device)

    features_style = vgg(style_image)
    gram_style = [utils.gram_matrix(y) for y in features_style]

    for e in range(epochs):
        batch_num = 0
        # for video_name in videos_list:
        for frames_curr_next in tqdm(train_loader):
            (frames_curr, frames_next) = frames_curr_next
            batch_num += 1
            for frame in frames_curr:  # Left + Right
                batch_size = len(frame)
                optimizer.zero_grad()
                frame = frame.to(device)
                frame_style = transformer_net(frame)

                features_frame = vgg(frame)
                features_frame_style = vgg(frame_style)

                content_loss = losses.content_loss(features_frame,
                                                   features_frame_style)
                style_loss = losses.style_loss(features_frame_style, gram_style,
                                               batch_size)

                total_loss = content_weight * content_loss + style_weight * style_loss
                total_loss.backward()
                optimizer.step()

                if (batch_num + 1) % log_interval == 0:  # TODO: Choose between TQDM and printing
                    mesg = "\n{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                        time.ctime(), e + 1, batch_num, 2 * len(train_dataset),
                        content_loss.item(), style_loss.item(), total_loss)
                    print(mesg)

                if checkpoint_model_dir is not None and (
                        batch_num + 1) % checkpoint_interval == 0:
                    transformer_net.eval().cpu()
                    ckpt_model_filename = ("ckpt_epoch_" + str(e) +
                                           "_batch_id_" + str(batch_num + 1) + ".pth")
                    ckpt_model_path = os.path.join(checkpoint_model_dir,
                                                   ckpt_model_filename)
                    torch.save(transformer_net.state_dict(), ckpt_model_path)
                    transformer_net.to(device).train()

    # save model
    transformer_net.eval().cpu()
    # save_model_filename = "epoch_" + str(epochs) + "_" + str(time.ctime()).replace(' ', '_') + "_" + str(
    #     content_weight) + "_" + str(style_weight) + ".model"
    save_model_filename = "myModel.pth"
    save_model_path = os.path.join(save_model_dir, save_model_filename)
    torch.save(transformer_net.state_dict(), save_model_path)
    print("\nDone, trained model saved at", save_model_path)
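# losses.content_loss is also defined elsewhere. Judging by the reference
# training loop below (which compares relu2_2 activations of the original and
# stylized images), it is most likely a plain MSE between VGG feature maps.
# A minimal sketch under that assumption:
def content_loss_sketch(features_frame, features_frame_style):
    mse_loss = torch.nn.MSELoss()
    # Compare relu2_2 activations of the input frame and its stylized version.
    return mse_loss(features_frame_style.relu2_2, features_frame.relu2_2)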
def train(args):
    device = torch.device("cuda" if args.cuda else "cpu")
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size)

    transformer = TransformerNet().to(device)
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()
    vgg = Vgg16(requires_grad=False).to(device)

    style_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Lambda(lambda x: x.mul(255))])
    style = utils.load_image(args.style_image, size=args.style_size)
    style = style_transform(style)
    style = style.repeat(args.batch_size, 1, 1, 1).to(device)

    features_style = vgg(utils.normalize_batch(style))
    gram_style = [utils.gram_matrix(y) for y in features_style]

    for e in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()

            x = x.to(device)
            y = transformer(x)

            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)

            features_y = vgg(y)
            features_x = vgg(x)

            content_loss = args.content_weight * mse_loss(
                features_y.relu2_2, features_x.relu2_2)

            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gm_y = utils.gram_matrix(ft_y)
                style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
            style_loss *= args.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.item()
            agg_style_loss += style_loss.item()

            if (batch_id + 1) % args.log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                    agg_content_loss / (batch_id + 1),
                    agg_style_loss / (batch_id + 1),
                    (agg_content_loss + agg_style_loss) / (batch_id + 1))
                print(mesg)

            if args.checkpoint_model_dir is not None and (
                    batch_id + 1) % args.checkpoint_interval == 0:
                transformer.eval().cpu()
                ckpt_model_filename = "ckpt_epoch_" + str(
                    e) + "_batch_id_" + str(batch_id + 1) + ".pth"
                ckpt_model_path = os.path.join(args.checkpoint_model_dir,
                                               ckpt_model_filename)
                torch.save(transformer.state_dict(), ckpt_model_path)
                transformer.to(device).train()

    # save model
    transformer.eval().cpu()
    save_model_filename = "epoch_" + str(args.epochs) + "_" + str(
        time.ctime()).replace(' ', '_') + "_" + str(
            args.content_weight) + "_" + str(args.style_weight) + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)
    print("\nDone, trained model saved at", save_model_path)
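# utils.normalize_batch is not shown in this file. In the reference
# fast-neural-style setup, images enter the network in the 0-255 range (note
# the Lambda(x.mul(255)) transform above) and are rescaled to 0-1 and
# normalized with the ImageNet statistics before being fed to VGG. A sketch of
# that behaviour, assuming the same convention is used here:
def normalize_batch_sketch(batch):
    mean = batch.new_tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
    std = batch.new_tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
    batch = batch.div(255.0)
    return (batch - mean) / std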
def train(dataset_path, style_image_path, save_model_dir, has_cuda, epochs=2,
          image_limit=None, checkpoint_model_dir=None, image_size=(360, 640),
          style_size=None, seed=42, content_weight=1, style_weight=10,
          temporal_weight=10, tv_weight=1e-3, disp_weight=1e-3, lr=1e-3,
          log_interval=500, checkpoint_interval=2000, model_filename="myModel",
          model_init=None):
    device = torch.device("cuda" if has_cuda else "cpu")
    np.random.seed(seed)
    torch.manual_seed(seed)

    batch_size = 1  # needs to be 1, batch is created using MyDataSet
    loss_list = []
    loss_filename = model_filename + '_losses.txt'

    transform = transforms.Compose([
        transforms.Resize(image_size),
        # transforms.Resize(image_size),
        # transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    # videos_list = os.listdir(dataset_path)
    # train_dataset = {}
    # train_loader = {}
    # for video_name in videos_list:
    #     video_dataset_path = os.path.join(dataset_path, video_name)
    #     train_dataset[video_name] = MyDataSet(video_dataset_path, transform)
    #     train_loader[video_name] = DataLoader(train_dataset[video_name], batch_size=batch_size)
    # video_dataset_path = os.path.join(dataset_path, "Monkaa")
    # dataset_path = "Data/Monkaa"

    train_dataset_path = os.path.join(dataset_path, "frames_cleanpass")
    flow_path = os.path.join(dataset_path, "optical_flow_resized")
    train_dataset = MyDataSet(
        train_dataset_path, flow_path, transform,
        image_limit=image_limit)  # remove if using all datasets
    train_loader = DataLoader(train_dataset, batch_size=batch_size)

    if model_init is not None:
        transformer_net = TransformerNet()
        state_dict = torch.load(model_init)
        # remove saved deprecated running_* keys in InstanceNorm from the checkpoint
        # for k in list(state_dict.keys()):
        #     if re.search(r'in\d+\.running_(mean|var)$', k):
        #         del state_dict[k]
        transformer_net.load_state_dict(state_dict)
        transformer_net.to(device)
    else:
        transformer_net = TransformerNet().to(device)

    optimizer = Adam(transformer_net.parameters(), lr)
    vgg = Vgg16(requires_grad=False).to(device)

    style_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    style_image = utils.load_image(style_image_path, size=style_size)
    style_image = style_transform(style_image)
    style_image = style_image.repeat(batch_size, 1, 1, 1).to(device)

    features_style = vgg(style_image)
    gram_style = [utils.gram_matrix(y) for y in features_style]

    for e in range(epochs):
        batch_num = 0
        # for video_name in videos_list:
        for frames_and_flow in tqdm(train_loader):
            (frames_curr_lr, flow_lr, frames_next_lr) = frames_and_flow
            batch_num += 1
            total_loss = 0
            to_save = (batch_num + 1) % checkpoint_interval == 0
            optimizer.zero_grad()

            frames_left_batch = torch.cat(
                (frames_curr_lr[0], frames_next_lr[0]), 0)
            frames_right_batch = torch.cat(
                (frames_curr_lr[1], frames_next_lr[1]), 0)
            frames_left_batch = frames_left_batch.to(device)
            frames_right_batch = frames_right_batch.to(device)
            frame_style_left, frame_style_right = transformer_net(
                frames_left_batch, frames_right_batch)  # Two batches 2 x 3 x H x W

            frame_curr_style_combined = (frame_style_left[0, ::].unsqueeze(0),
                                         frame_style_right[0, ::].unsqueeze(0))
            frame_next_style_combined = (frame_style_left[1, ::].unsqueeze(0),
                                         frame_style_right[1, ::].unsqueeze(0))

            # NOTE: `disparity` is not defined anywhere in this function; the
            # left/right disparity maps presumably have to be loaded alongside
            # the frames and flow (e.g. returned by MyDataSet). As written,
            # the next two calls raise a NameError.
            disparity_loss_l2r = losses.disparity_loss(
                frame_curr_style_combined[0], frame_curr_style_combined[1],
                disparity[0], device)
            disparity_loss_r2l = losses.disparity_loss(
                frame_curr_style_combined[1], frame_curr_style_combined[0],
                disparity[1], device, to_save, batch_num, e)
            total_loss = disp_weight * (disparity_loss_l2r + disparity_loss_r2l)
            # total_loss = disp_weight * disparity_loss_l2r

            for i in [0, 1]:  # Left, Right
                to_save = (batch_num + 1) % checkpoint_interval == 0
                frame_curr = frames_curr_lr[i]
                frame_next = frames_next_lr[i]
                flow = flow_lr[i]
                batch_size = len(frame_curr)
                frame_style = frame_curr_style_combined[i]
                frame_next_style = frame_next_style_combined[i]

                features_frame = vgg(frame_curr)
                features_frame_style = vgg(frame_style)

                content_loss = losses.content_loss(features_frame,
                                                   features_frame_style)
                style_loss = losses.style_loss(features_frame_style, gram_style,
                                               batch_size)
                if to_save:
                    temporal_loss = losses.temporal_loss(frame_style,
                                                         frame_next_style,
                                                         flow, device,
                                                         to_save=to_save,
                                                         batch_num=batch_num, e=e)
                else:
                    temporal_loss = losses.temporal_loss(
                        frame_style, frame_next_style, flow, device)
                tv_loss = losses.tv_loss(frame_curr)

                total_loss = total_loss + (content_weight * content_loss +
                                           style_weight * style_loss +
                                           temporal_weight * temporal_loss)

                # Save stuff:
                frame_curr_to_save = frame_curr.permute(2, 3, 1, 0).squeeze(3)
                frame_next_to_save = frame_next.permute(2, 3, 1, 0).squeeze(3)
                namefile_frame_curr = ('test_images/frame_curr/frame_curr_epo' +
                                       str(e) + 'batch_num' + str(batch_num) +
                                       "eye" + str(i) + '.png')
                namefile_frame_next = ('test_images/frame_next/frame_next_epo' +
                                       str(e) + 'batch_num' + str(batch_num) +
                                       "eye" + str(i) + '.png')
                namefile_frame_flow = ('test_images/frame_flow/frame_next_epo' +
                                       str(e) + 'batch_num' + str(batch_num) +
                                       "eye" + str(i) + '.png')
                frame_flow, _ = utils.apply_flow(frame_curr, flow)
                if to_save:
                    utils.save_image_loss(frame_curr_to_save, namefile_frame_curr)
                    utils.save_image_loss(frame_next_to_save, namefile_frame_next)
                    utils.save_image_loss(frame_flow, namefile_frame_flow)

                # NOTE: the block below is a leftover of the earlier per-eye
                # training path: it re-runs the transformer on (frame_curr,
                # frame_next), recomputes the same losses, and then overwrites
                # total_loss, so the accumulated stereo/disparity term above is
                # effectively discarded before backward().
                frame_curr = frame_curr.to(device)
                frame_next = frame_next.to(device)
                frames_batch = torch.cat((frame_curr, frame_next))
                frames_style_batch = transformer_net(frames_batch)
                # frame_style = transformer_net(frame_curr)
                # frame_next_style = transformer_net(frame_next)
                frame_style = frames_style_batch[0, ::].unsqueeze(0)  # add batch dim
                frame_next_style = frames_style_batch[1, ::].unsqueeze(0)  # add batch dim

                features_frame = vgg(frame_curr)
                features_frame_style = vgg(frame_style)
                # print(frame_curr.shape)

                content_loss = losses.content_loss(features_frame,
                                                   features_frame_style)
                style_loss = losses.style_loss(features_frame_style, gram_style,
                                               batch_size)
                if to_save:
                    temporal_loss = losses.temporal_loss(frame_style,
                                                         frame_next_style,
                                                         flow, device,
                                                         to_save=to_save,
                                                         batch_num=batch_num, e=e)
                else:
                    temporal_loss = losses.temporal_loss(
                        frame_style, frame_next_style, flow, device)
                tv_loss = losses.tv_loss(frame_curr)

                total_loss = (content_weight * content_loss +
                              style_weight * style_loss +
                              temporal_weight * temporal_loss)
                total_loss.backward()
                optimizer.step()

                if (batch_num + 1) % log_interval == 0:  # TODO: Choose between TQDM and printing
                    mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttemporal: {:.6f}" \
                           "\ttotal: {:.6f}".format(
                               time.ctime(), e + 1, batch_num + 1, len(train_dataset),
                               content_loss.item(), style_loss.item(),
                               temporal_loss.item(), total_loss.item())
                    # print(mesg)
                    losses_string = (str(content_loss.item()) + "," +
                                     str(style_loss.item()) + "," +
                                     str(temporal_loss.item()) + "," +
                                     str(total_loss.item()))
                    loss_list.append(losses_string)
                    utils.save_loss_file(loss_list, loss_filename)

                if (checkpoint_model_dir is not None and
                        (batch_num + 1) % checkpoint_interval == 0):
                    transformer_net.eval().cpu()
                    ckpt_model_filename = (model_filename + "_ckpt_epoch_" +
                                           str(e + 1) + "_batch_id_" +
                                           str(batch_num + 1) + ".pth")
                    ckpt_model_path = os.path.join(checkpoint_model_dir,
                                                   ckpt_model_filename)
                    torch.save(transformer_net.state_dict(), ckpt_model_path)
                    transformer_net.to(device).train()

    # save model
    transformer_net.eval().cpu()
    save_model_path = os.path.join(save_model_dir, model_filename + ".pth")
    torch.save(transformer_net.state_dict(), save_model_path)
    print("\nDone, trained model saved at", save_model_path)
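# losses.disparity_loss is defined elsewhere. The usual stereo-consistency term
# warps one stylized view toward the other along the horizontal axis using the
# disparity map and penalizes the difference, analogous to the temporal sketch
# earlier in this file. The interface assumed below (frames as (1, 3, H, W),
# disparity as (1, H, W) in pixels) is an illustrative guess, not the project's
# exact code.
import torch
import torch.nn.functional as F


def disparity_loss_sketch(frame_a_style, frame_b_style, disparity, device):
    _, _, h, w = frame_a_style.shape
    ys, xs = torch.meshgrid(torch.arange(h, device=device),
                            torch.arange(w, device=device), indexing="ij")
    grid = torch.stack((xs, ys), dim=-1).float().unsqueeze(0)  # (1, H, W, 2)
    grid[..., 0] = grid[..., 0] - disparity.to(device)         # shift columns by disparity
    grid[..., 0] = 2.0 * grid[..., 0] / (w - 1) - 1.0
    grid[..., 1] = 2.0 * grid[..., 1] / (h - 1) - 1.0
    warped = F.grid_sample(frame_a_style.to(device), grid, align_corners=True)
    return F.mse_loss(warped, frame_b_style.to(device))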