def eval(model, epoch):
    """Run inference on the Dunhuang test split, save each prediction to
    output/<id>.jpg resized back to its source image size, and return the
    average (MSE, SSIM) over the whole split.

    Parameters
    ----------
    model : torch.nn.Module
        Inpainting network taking (masked_img, mask); moved to CUDA here.
    epoch : int
        Unused; kept for backward compatibility with existing callers.

    Returns
    -------
    (float, float)
        Mean MSE and mean SSIM across the test set (0.0, 0.0 if empty).

    NOTE(review): shadows the builtin ``eval``; name kept so callers work.
    """
    dataset = data.DATA(data_dir='Data_Challenge2', mode='test')
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1,
                                             num_workers=2, shuffle=False)
    model.eval()
    model.cuda()
    preds = []
    gts = []
    with torch.no_grad():
        for idx, (masked_img, mask, gt) in enumerate(dataloader):
            masked_img = masked_img.cuda()
            gt = gt.cuda()
            mask = mask.cuda()
            pred = model(masked_img, mask)
            gts.append(gt.squeeze())
            preds.append(pred.squeeze())
            # Test images are numbered starting at 401.
            torchvision.utils.save_image(preds[idx],
                                         'output/{}.jpg'.format(idx + 401))
    # Resize every saved prediction back to its original (pre-crop) size,
    # taken from the corresponding masked input on disk.
    for i in range(len(preds)):
        src = Image.open('Data_Challenge2/test/{}_masked.jpg'.format(401 + i))
        width, height = src.size  # PIL's .size is (width, height)
        out_path = 'output/{}.jpg'.format(401 + i)
        saved = cv2.imread(out_path)
        # cv2.resize takes dsize as (width, height) as well.
        saved = cv2.resize(saved, (width, height))
        cv2.imwrite(out_path, saved)
    # Accumulate metrics on the raw (un-resized) tensors.
    mse_total = 0
    ssim_total = 0
    for i in range(len(preds)):
        pred = preds[i].cpu().detach().numpy()
        gt = gts[i].cpu().detach().numpy()
        mse_total += get_mse(pred, gt)
        ssim_total += get_ssim(pred, gt)
    # Guard against an empty test set (original divided by a loop variable
    # that would be undefined here).
    count = max(len(preds), 1)
    mse_avg = mse_total / count
    ssim_avg = ssim_total / count
    return mse_avg, ssim_avg
# Script entry: prepare output dirs, load the best checkpoint, and start
# inference over the test split.
# Usage: argv[1] and argv[2] are directories to create if missing
# (presumably input/output locations — verify against how this script is run).
if not os.path.exists(sys.argv[1]):
    os.makedirs(sys.argv[1])
if not os.path.exists(sys.argv[2]):
    os.makedirs(sys.argv[2])
device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")
net = model.Base_Network()
# fixed seed
manualSeed = 96
random.seed(manualSeed)
torch.manual_seed(manualSeed)
# Restore the best trained weights from the working directory.
net.load_state_dict(torch.load('./best_model.pth.tar'))
dataset = data.DATA(data_dir='Data_Challenge2', mode='test')
# batch_size=1 so each prediction maps 1:1 to a test image.
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1,
                                         num_workers=2, shuffle=False)
net.eval()
net.cuda()
preds = []
gts = []
masked_imgs = []
im = []
with torch.no_grad():
    for idx, (masked_img, mask, gt) in enumerate(dataloader):
        masked_img = masked_img.cuda()
        gt = gt.cuda()
        # NOTE(review): this chunk appears truncated here — the rest of the
        # inference loop body is not visible in this view of the file.
def fine_tune(device, model, model_pre_train_pth, model_fine_tune_pth):
    """Fine-tune ``model`` on the Dunhuang Grottoes data.

    Loads the newest available checkpoint (a previous fine-tune run, else the
    best pre-train weights when ``args.train_mode == "w_pretrain"``), trains
    with MSE on mask-composited outputs, periodically evaluates on the test
    split, and keeps the checkpoint with the highest final score
    (``1 - mse/100 + ssim``).

    Parameters
    ----------
    device : torch.device
        Device to train on.
    model : torch.nn.Module
        Inpainting network; must accept (imgs_masked, masks).
    model_pre_train_pth : str
        Path of the best pre-training checkpoint.
    model_fine_tune_pth : str
        Path where the best fine-tuning checkpoint is saved.

    Relies on module-level ``args`` / ``argparser`` configuration and the
    ``data``, ``test``, ``load_model``, ``save_model``, ``remove_prev_model``
    helpers.
    """
    print()
    print("***** Start FINE-TUNING *****")
    print()
    # ------------
    # Load Dunhuang Grottoes data
    # ------------
    print("---> preparing dataloader...")
    # Training dataloader. Length = dataset size / batch size
    train_dataset = data.DATA(mode="train", train_status="finetune")
    dataloader_train = torch.utils.data.DataLoader(
        dataset=train_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=argparser.n_cpu
    )
    print("---> length of training dataset: ", len(train_dataset))
    # Load test images
    test_dataset = data.DATA(mode="test", train_status="test")
    dataloader_test = torch.utils.data.DataLoader(
        dataset=test_dataset, batch_size=args.batch_size_test, shuffle=False,
        num_workers=argparser.n_cpu
    )
    print("---> length of test dataset: ", len(test_dataset))
    # -------
    # Model
    # -------
    # Resume from a fine-tune checkpoint if one exists; otherwise start from
    # the best pre-trained weights (only in "w_pretrain" mode).
    if os.path.exists(model_fine_tune_pth):
        print("---> found previously saved {}, loading checkpoint and CONTINUE fine-tuning"
              .format(args.saved_fine_tune_name))
        load_model(model, model_fine_tune_pth)
    elif os.path.exists(model_pre_train_pth) and args.train_mode == "w_pretrain":
        print("---> found previously saved {}, loading checkpoint and START fine-tuning"
              .format(args.saved_pre_train_name))
        load_model(model, model_pre_train_pth)
    # freeze batch-norm params in fine-tuning
    if args.train_mode == "w_pretrain" and args.pretrain_epochs > 10:
        model.freeze()
    # ----------------
    # Optimizer
    # ----------------
    print("---> preparing optimizer...")
    optimizer = optim.Adam(model.parameters(), lr=argparser.LR_FT)
    criterion = nn.MSELoss()
    # Move model to device
    model.to(device)
    # ----------
    # Training
    # ----------
    print("---> start training cycle ...")
    with open(os.path.join(args.output_dir, "finetune_losses.csv"), "w", newline="") as csv_losses:
        with open(os.path.join(args.output_dir, "finetune_scores.csv"), "w", newline="") as csv_scores:
            writer_losses = csv.writer(csv_losses)
            writer_losses.writerow(["Epoch", "Iteration", "Loss"])
            writer_scores = csv.writer(csv_scores)
            writer_scores.writerow(["Epoch", "Total Loss", "MSE", "SSIM", "Final Score"])
            iteration = 0
            highest_final_score = 0.0  # the higher the better, combines mse and ssim
            for epoch in range(args.finetune_epochs):
                model.train()
                loss_sum = 0  # store accumulated loss for one epoch
                for idx, (imgs_masked, masks, gts) in enumerate(dataloader_train):
                    # Move to device
                    imgs_masked = imgs_masked.to(device)  # (N, 3, H, W)
                    masks = masks.to(device)              # (N, 1, H, W)
                    gts = gts.to(device)                  # (N, 3, H, W)
                    # Model forward path => predicted images
                    preds = model(imgs_masked, masks)
                    # Composite: keep the known pixels from the input and take
                    # only the hole region from the prediction. (Device-
                    # agnostic: the original allocated torch.ones(...).cuda()
                    # every step, which broke CPU runs.)
                    preds = masks * imgs_masked + (1.0 - masks) * preds
                    # Calculate total loss
                    train_loss = criterion(preds, gts)
                    # Execute Back-Propagation
                    optimizer.zero_grad()
                    train_loss.backward()
                    optimizer.step()
                    print("\r[Epoch %d/%d] [Batch %d/%d] [Loss: %f]"
                          % (epoch + 1, args.finetune_epochs, (idx + 1),
                             len(dataloader_train), train_loss), end="")
                    loss_sum += train_loss.item()
                    writer_losses.writerow([epoch+1, iteration+1, train_loss.item()])
                    iteration += 1
                # ------------------
                # Evaluate & Save Model
                # ------------------
                if (epoch+1) % args.val_epoch == 0:
                    mse, ssim = test.test(args, model, device, dataloader_test, mode="validate")
                    final_score = 1 - mse / 100 + ssim
                    print("\nMetrics on test set @ epoch {}:".format(epoch+1))
                    print("-> Average MSE: {:.5f}".format(mse))
                    print("-> Average SSIM: {:.5f}".format(ssim))
                    print("-> Final Score: {:.5f}".format(final_score))
                    # Keep only the best-scoring checkpoint at the fixed path.
                    if final_score > highest_final_score:
                        save_model(model, model_fine_tune_pth)
                        highest_final_score = final_score
                    writer_scores.writerow([epoch+1, loss_sum, mse, ssim, final_score])
                # Rolling per-epoch checkpoint: save this epoch, drop the previous.
                save_model(model, os.path.join(args.model_dir_fine_tune,
                                               "Net_finetune_epoch{}.pth.tar".format(epoch+1)))
                if epoch > 0:
                    remove_prev_model(os.path.join(args.model_dir_fine_tune,
                                                   "Net_finetune_epoch{}.pth.tar".format(epoch)))
    print("\n***** Fine-tuning FINISHED *****")
def pre_train(device, model, model_pre_train_pth):
    """Pre-train ``model`` on the Places2 data.

    Resumes from ``model_pre_train_pth`` when present, trains with MSE on
    mask-composited outputs, periodically evaluates on the test split, and
    keeps the checkpoint with the highest final score (``1 - mse/100 + ssim``).

    Parameters
    ----------
    device : torch.device
        Device to train on.
    model : torch.nn.Module
        Inpainting network; must accept (imgs_masked, masks).
    model_pre_train_pth : str
        Path where the best pre-training checkpoint is saved.

    Relies on module-level ``args`` / ``argparser`` configuration and the
    ``data``, ``test``, ``load_model``, ``save_model``, ``remove_prev_model``
    helpers.
    """
    print()
    print("***** PRE-TRAINING *****")
    print()
    # ------------
    # Load Places2 data
    # ------------
    print("---> preparing dataloader...")
    # Training dataloader. Length = dataset size / batch size
    train_dataset = data.DATA(mode="train", train_status="pretrain")
    dataloader_train = torch.utils.data.DataLoader(
        dataset=train_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=argparser.n_cpu
    )
    print("---> length of training dataset: ", len(train_dataset))
    # Load test images
    test_dataset = data.DATA(mode="test", train_status="test")
    dataloader_test = torch.utils.data.DataLoader(
        dataset=test_dataset, batch_size=args.batch_size_test, shuffle=False,
        num_workers=argparser.n_cpu
    )
    print("---> length of test dataset: ", len(test_dataset))
    # -------
    # Model
    # -------
    # load model from checkpoint if available
    if os.path.exists(model_pre_train_pth):
        print("---> Found previously saved {}, loading checkpoint and CONTINUE pre-training"
              .format(args.saved_pre_train_name))
        load_model(model, model_pre_train_pth)
    else:
        print("---> Start pre-training from scratch: no checkpoint found")
    # ----------------
    # Optimizer
    # ----------------
    print("---> preparing optimizer...")
    optimizer = optim.Adam(model.parameters(), lr=argparser.LR)
    criterion = nn.MSELoss()
    # Move model to device
    model.to(device)
    # ----------
    # Training
    # ----------
    print("---> start training cycle ...")
    with open(os.path.join(args.output_dir, "pretrain_losses.csv"), "w", newline="") as csv_losses:
        with open(os.path.join(args.output_dir, "pretrain_scores.csv"), "w", newline="") as csv_scores:
            writer_losses = csv.writer(csv_losses)
            writer_losses.writerow(["Epoch", "Iteration", "Loss"])
            writer_scores = csv.writer(csv_scores)
            writer_scores.writerow(["Epoch", "Total Loss", "MSE", "SSIM", "Final Score"])
            highest_final_score = 0.0  # the higher the better, combines mse and ssim
            iteration = 0
            for epoch in range(args.pretrain_epochs):
                model.train()
                loss_sum = 0  # store accumulated loss for one epoch
                for idx, (imgs_masked, masks, gts) in enumerate(dataloader_train):
                    # Move to device
                    imgs_masked = imgs_masked.to(device)  # (N, 3, H, W)
                    masks = masks.to(device)              # (N, 1, H, W)
                    gts = gts.to(device)                  # (N, 3, H, W)
                    # Model forward path => predicted images
                    preds = model(imgs_masked, masks)
                    # Composite: keep the known pixels from the input and take
                    # only the hole region from the prediction. (Device-
                    # agnostic: the original allocated torch.ones(...).cuda()
                    # every step, which broke CPU runs.)
                    preds = masks * imgs_masked + (1.0 - masks) * preds
                    # Calculate total loss
                    train_loss = criterion(preds, gts)
                    # Execute Back-Propagation
                    optimizer.zero_grad()
                    train_loss.backward()
                    optimizer.step()
                    print("\r[Epoch %d/%d] [Batch %d/%d] [Loss: %f]"
                          % (epoch + 1, args.pretrain_epochs, (idx + 1),
                             len(dataloader_train), train_loss), end="")
                    loss_sum += train_loss.item()
                    writer_losses.writerow([epoch+1, iteration+1, train_loss.item()])
                    iteration += 1
                # ------------------
                # Evaluate & Save Model
                # ------------------
                if (epoch + 1) % args.val_epoch == 0:
                    mse, ssim = test.test(args, model, device, dataloader_test, mode="validate")
                    final_score = 1 - mse / 100 + ssim
                    print("\nMetrics on test set @ epoch {}:".format(epoch+1))
                    print("-> Average MSE: {:.5f}".format(mse))
                    print("-> Average SSIM: {:.5f}".format(ssim))
                    print("-> Final Score: {:.5f}".format(final_score))
                    # Keep only the best-scoring checkpoint at the fixed path.
                    if final_score > highest_final_score:
                        save_model(model, model_pre_train_pth)
                        highest_final_score = final_score
                    writer_scores.writerow([epoch+1, loss_sum, mse, ssim, final_score])
                # Rolling per-epoch checkpoint: save this epoch, drop the previous.
                save_model(model, os.path.join(args.model_dir_pre_train,
                                               "Net_pretrain_epoch{}.pth.tar".format(epoch + 1)))
                if epoch > 0:
                    remove_prev_model(os.path.join(args.model_dir_pre_train,
                                                   "Net_pretrain_epoch{}.pth.tar".format(epoch)))
    print("\n***** Pre-Training FINISHED *****")
# print(model) # print(list(model.parameters())) # SET paths to best models model_pre_train_pth = os.path.join(args.model_dir_pre_train, args.saved_pre_train_name) model_fine_tune_pth = os.path.join(args.model_dir_fine_tune, args.saved_fine_tune_name) # ------- # Test Evaluate # ------ # checkpoint = torch.load('Net_best_fine_tune.pth.tar', map_location='cpu') # model.load_state_dict(checkpoint) # # Load test images test_dataset = data.DATA(mode="test", train_status="TA") dataloader_test = torch.utils.data.DataLoader( dataset=test_dataset, batch_size=args.batch_size_test, shuffle=False, num_workers=argparser.n_cpu) train_dataset = data.DATA(mode="train", train_status="x") dataloader_train = torch.utils.data.DataLoader( dataset=train_dataset, batch_size=args.batch_size_test, shuffle=False, num_workers=argparser.n_cpu) # # # # for idx, (imgs_masked, masks, gts, _) in enumerate(dataloader_test): # # print(imgs_masked.size())