Example #1
0
def eval(model, epoch):
    """Run inference on the test split, save each prediction resized back to
    its source image's dimensions, and return average MSE / SSIM.

    NOTE(review): the name shadows the builtin ``eval``; kept unchanged so
    existing callers keep working.

    Args:
        model: inpainting network; forward(masked_img, mask) -> prediction.
        epoch: unused here; kept to preserve the caller-facing signature.

    Returns:
        Tuple (mse_avg, ssim_avg) averaged over the test set.
    """
    dataset = data.DATA(data_dir='Data_Challenge2', mode='test')
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=1,
                                             num_workers=2,
                                             shuffle=False)

    model.eval()
    model.cuda()
    preds = []
    gts = []
    masked_imgs = []
    with torch.no_grad():
        for idx, (masked_img, mask, gt) in enumerate(dataloader):
            masked_img = masked_img.cuda()
            gt = gt.cuda()
            mask = mask.cuda()
            pred = model(masked_img, mask)
            gts.append(gt.squeeze())
            masked_imgs.append(masked_img.squeeze())
            preds.append(pred.squeeze())
            # Test images are numbered starting at 401.
            torchvision.utils.save_image(preds[idx],
                                         'output/{}.jpg'.format(idx + 401))

        # Resize every saved prediction to its source image's dimensions.
        # BUGFIX: PIL's Image.size is (width, height); the original unpacked
        # it as ``height, width`` and only worked because cv2.resize's dsize
        # is also (width, height) -- the two swaps cancelled out. The dead
        # uint8/swapaxes conversion of ``pred`` has been dropped.
        for i in range(len(preds)):
            src = Image.open('Data_Challenge2/test/{}_masked.jpg'.format(401 + i))
            width, height = src.size
            out_path = 'output/{}.jpg'.format(401 + i)
            saved = cv2.imread(out_path)
            saved = cv2.resize(saved, (width, height))  # dsize = (w, h)
            cv2.imwrite(out_path, saved)

        mse_total = 0
        ssim_total = 0
        for i in range(len(preds)):
            pred = preds[i].cpu().numpy()
            gt = gts[i].cpu().numpy()
            mse_total += get_mse(pred, gt)
            ssim_total += get_ssim(pred, gt)
        # BUGFIX: divide by the dataset size instead of the last loop index,
        # which raised NameError when the test set was empty.
        n = max(len(preds), 1)
        mse_avg = mse_total / n
        ssim_avg = ssim_total / n

    return mse_avg, ssim_avg
Example #2
0
    # NOTE(review): fragment of a larger evaluation script -- the enclosing
    # function/entry point and the rest of the loop are outside this view.
    # sys.argv[1]/[2] are presumably output directories; confirm against the
    # missing argument handling above.
    if not os.path.exists(sys.argv[1]):
        os.makedirs(sys.argv[1])
    if not os.path.exists(sys.argv[2]):
        os.makedirs(sys.argv[2])
    # Pick GPU when available (the .cuda() calls below assume one exists).
    device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")

    net = model.Base_Network()

    # fixed seed for reproducible evaluation
    manualSeed = 96
    random.seed(manualSeed)
    torch.manual_seed(manualSeed)

    # Restore the best saved weights for inference.
    net.load_state_dict(torch.load('./best_model.pth.tar'))

    # Test split, one image per batch, in deterministic order.
    dataset = data.DATA(data_dir='Data_Challenge2', mode='test')
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=1,
                                             num_workers=2,
                                             shuffle=False)

    net.eval()
    net.cuda()
    preds = []  # model outputs
    gts = []  # ground-truth images
    masked_imgs = []  # inputs with holes
    im = []
    # Inference only: no gradients needed.
    with torch.no_grad():
        for idx, (masked_img, mask, gt) in enumerate(dataloader):
            masked_img = masked_img.cuda()
            gt = gt.cuda()
Example #3
0
def fine_tune(device, model, model_pre_train_pth, model_fine_tune_pth):
    """Fine-tune an inpainting model on the Dunhuang Grottoes dataset.

    Resumes from a fine-tune checkpoint when one exists, otherwise starts
    from the best pre-trained weights (when ``args.train_mode`` is
    "w_pretrain"). Trains with an MSE loss on composited predictions,
    validates every ``args.val_epoch`` epochs, keeps the best-scoring
    model at ``model_fine_tune_pth`` plus only the latest per-epoch
    checkpoint, and logs losses/scores to CSV files.

    Args:
        device: torch.device batches and the model are moved to.
        model: network with forward(imgs_masked, masks) -> predictions and a
            ``freeze()`` method that locks batch-norm parameters.
        model_pre_train_pth: path of the best pre-trained checkpoint.
        model_fine_tune_pth: path where the best fine-tuned model is saved.
    """
    print()
    print("***** Start FINE-TUNING *****")
    print()

    # ------------
    #  Load Dunhuang Grottoes data
    # ------------

    print("---> preparing dataloader...")

    # Training dataloader. Length = dataset size / batch size
    train_dataset = data.DATA(mode="train", train_status="finetune")
    dataloader_train = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=argparser.n_cpu
    )
    print("---> length of training dataset: ", len(train_dataset))

    # Test images used for periodic validation
    test_dataset = data.DATA(mode="test", train_status="test")
    dataloader_test = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=args.batch_size_test,
        shuffle=False,
        num_workers=argparser.n_cpu
    )
    print("---> length of test dataset: ", len(test_dataset))

    # -------
    # Model
    # -------

    # Resume from a previous fine-tune checkpoint when one exists ...
    if os.path.exists(model_fine_tune_pth):
        print("---> found previously saved {}, loading checkpoint and CONTINUE fine-tuning"
              .format(args.saved_fine_tune_name))
        load_model(model, model_fine_tune_pth)
    # ... otherwise start from the best pre-trained weights.
    elif os.path.exists(model_pre_train_pth) and args.train_mode == "w_pretrain":
        print("---> found previously saved {}, loading checkpoint and START fine-tuning"
              .format(args.saved_pre_train_name))
        load_model(model, model_pre_train_pth)

    # Freeze batch-norm params during fine-tuning, but only after a
    # sufficiently long pre-training run.
    if args.train_mode == "w_pretrain" and args.pretrain_epochs > 10:
        model.freeze()

    # ----------------
    #  Optimizer
    # ----------------

    print("---> preparing optimizer...")
    optimizer = optim.Adam(model.parameters(), lr=argparser.LR_FT)
    criterion = nn.MSELoss()
    # Move model to device
    model.to(device)

    # ----------
    #  Training
    # ----------

    print("---> start training cycle ...")
    with open(os.path.join(args.output_dir, "finetune_losses.csv"), "w", newline="") as csv_losses:
        with open(os.path.join(args.output_dir, "finetune_scores.csv"), "w", newline="") as csv_scores:
            writer_losses = csv.writer(csv_losses)
            writer_losses.writerow(["Epoch", "Iteration", "Loss"])

            writer_scores = csv.writer(csv_scores)
            writer_scores.writerow(["Epoch", "Total Loss", "MSE", "SSIM", "Final Score"])

            iteration = 0
            highest_final_score = 0.0   # the higher the better, combines mse and ssim

            for epoch in range(args.finetune_epochs):

                model.train()

                loss_sum = 0    # accumulated loss over this epoch

                for idx, (imgs_masked, masks, gts) in enumerate(dataloader_train):

                    # Move batch to device
                    imgs_masked = imgs_masked.to(device)    # (N, 3, H, W)
                    masks = masks.to(device)                # (N, 1, H, W)
                    gts = gts.to(device)                    # (N, 3, H, W)

                    # Model forward path => predicted images
                    preds = model(imgs_masked, masks)

                    # Composite: keep known pixels from the input, take only
                    # hole pixels from the network output.
                    # BUGFIX: torch.ones_like keeps the tensor on `device`;
                    # the original torch.ones(...).cuda() crashed on CPU.
                    reversed_masks = torch.ones_like(masks) - masks
                    preds = masks * imgs_masked + reversed_masks * preds

                    # MSE between composited prediction and ground truth
                    train_loss = criterion(preds, gts)
                    # Back-propagation
                    optimizer.zero_grad()
                    train_loss.backward()
                    optimizer.step()

                    print("\r[Epoch %d/%d] [Batch %d/%d] [Loss: %f]" %
                          (epoch + 1, args.finetune_epochs, (idx + 1), len(dataloader_train), train_loss), end="")

                    loss_sum += train_loss.item()
                    writer_losses.writerow([epoch+1, iteration+1, train_loss.item()])
                    iteration += 1

                # ------------------
                #  Evaluate & Save Model
                # ------------------

                if (epoch+1) % args.val_epoch == 0:
                    mse, ssim = test.test(args, model, device, dataloader_test, mode="validate")
                    # Higher is better: low MSE and high SSIM both raise it.
                    final_score = 1 - mse / 100 + ssim
                    print("\nMetrics on test set @ epoch {}:".format(epoch+1))
                    print("-> Average MSE:  {:.5f}".format(mse))
                    print("-> Average SSIM: {:.5f}".format(ssim))
                    print("-> Final Score:  {:.5f}".format(final_score))

                    # Keep only the best-scoring weights at the fixed path.
                    if final_score > highest_final_score:
                        save_model(model, model_fine_tune_pth)
                        highest_final_score = final_score

                    writer_scores.writerow([epoch+1, loss_sum, mse, ssim, final_score])

                # Always save the latest epoch and drop the previous one.
                save_model(model, os.path.join(args.model_dir_fine_tune, "Net_finetune_epoch{}.pth.tar".format(epoch+1)))
                if epoch > 0:
                    remove_prev_model(os.path.join(args.model_dir_fine_tune, "Net_finetune_epoch{}.pth.tar".format(epoch)))

    print("\n***** Fine-tuning FINISHED *****")
Example #4
0
def pre_train(device, model, model_pre_train_pth):
    """Pre-train an inpainting model on the Places2 dataset.

    Resumes from ``model_pre_train_pth`` when that checkpoint exists,
    otherwise trains from scratch. Uses an MSE loss on composited
    predictions, validates every ``args.val_epoch`` epochs, keeps the
    best-scoring weights at ``model_pre_train_pth`` plus only the latest
    per-epoch checkpoint, and logs losses/scores to CSV files.

    Args:
        device: torch.device batches and the model are moved to.
        model: network with forward(imgs_masked, masks) -> predictions.
        model_pre_train_pth: path where the best pre-trained model is saved.
    """
    print()
    print("***** PRE-TRAINING *****")
    print()

    # ------------
    #  Load Places2 data
    # ------------

    print("---> preparing dataloader...")

    # Training dataloader. Length = dataset size / batch size
    train_dataset = data.DATA(mode="train", train_status="pretrain")
    dataloader_train = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=argparser.n_cpu
    )
    print("---> length of training dataset: ", len(train_dataset))

    # Test images used for periodic validation
    test_dataset = data.DATA(mode="test", train_status="test")
    dataloader_test = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=args.batch_size_test,
        shuffle=False,
        num_workers=argparser.n_cpu
    )
    print("---> length of test dataset: ", len(test_dataset))

    # -------
    # Model
    # -------

    # Resume from a checkpoint when one exists, otherwise start fresh.
    if os.path.exists(model_pre_train_pth):
        print("---> Found previously saved {}, loading checkpoint and CONTINUE pre-training"
              .format(args.saved_pre_train_name))
        load_model(model, model_pre_train_pth)
    else:
        print("---> Start pre-training from scratch: no checkpoint found")

    # ----------------
    #  Optimizer
    # ----------------

    print("---> preparing optimizer...")
    optimizer = optim.Adam(model.parameters(), lr=argparser.LR)
    criterion = nn.MSELoss()

    # Move model to device
    model.to(device)

    # ----------
    #  Training
    # ----------

    print("---> start training cycle ...")
    with open(os.path.join(args.output_dir, "pretrain_losses.csv"), "w", newline="") as csv_losses:
        with open(os.path.join(args.output_dir, "pretrain_scores.csv"), "w", newline="") as csv_scores:
            writer_losses = csv.writer(csv_losses)
            writer_losses.writerow(["Epoch", "Iteration", "Loss"])

            writer_scores = csv.writer(csv_scores)
            writer_scores.writerow(["Epoch", "Total Loss", "MSE", "SSIM", "Final Score"])

            highest_final_score = 0.0   # the higher the better, combines mse and ssim
            iteration = 0

            for epoch in range(args.pretrain_epochs):

                model.train()

                loss_sum = 0  # accumulated loss over this epoch

                for idx, (imgs_masked, masks, gts) in enumerate(dataloader_train):
                    # Move batch to device
                    imgs_masked = imgs_masked.to(device)  # (N, 3, H, W)
                    masks = masks.to(device)  # (N, 1, H, W)
                    gts = gts.to(device)  # (N, 3, H, W)

                    # Model forward path => predicted images
                    preds = model(imgs_masked, masks)

                    # Composite: keep known pixels from the input, take only
                    # hole pixels from the network output.
                    # BUGFIX: torch.ones_like keeps the tensor on `device`;
                    # the original torch.ones(...).cuda() crashed on CPU.
                    reversed_masks = torch.ones_like(masks) - masks
                    preds = masks * imgs_masked + reversed_masks * preds

                    # MSE between composited prediction and ground truth
                    train_loss = criterion(preds, gts)

                    # Back-propagation
                    optimizer.zero_grad()
                    train_loss.backward()
                    optimizer.step()

                    print("\r[Epoch %d/%d] [Batch %d/%d] [Loss: %f]" %
                          (epoch + 1, args.pretrain_epochs, (idx + 1), len(dataloader_train), train_loss), end="")

                    loss_sum += train_loss.item()
                    writer_losses.writerow([epoch+1, iteration+1, train_loss.item()])
                    iteration += 1

                # ------------------
                #  Evaluate & Save Model
                # ------------------

                if (epoch + 1) % args.val_epoch == 0:
                    mse, ssim = test.test(args, model, device, dataloader_test, mode="validate")
                    # Higher is better: low MSE and high SSIM both raise it.
                    final_score = 1 - mse / 100 + ssim
                    print("\nMetrics on test set @ epoch {}:".format(epoch+1))
                    print("-> Average MSE:  {:.5f}".format(mse))
                    print("-> Average SSIM: {:.5f}".format(ssim))
                    print("-> Final Score:  {:.5f}".format(final_score))

                    # Keep only the best-scoring weights at the fixed path.
                    if final_score > highest_final_score:
                        save_model(model, model_pre_train_pth)
                        highest_final_score = final_score

                    writer_scores.writerow([epoch+1, loss_sum, mse, ssim, final_score])

                # Always save the latest epoch and drop the previous one.
                save_model(model, os.path.join(args.model_dir_pre_train, "Net_pretrain_epoch{}.pth.tar".format(epoch + 1)))
                if epoch > 0:
                    remove_prev_model(os.path.join(args.model_dir_pre_train, "Net_pretrain_epoch{}.pth.tar".format(epoch)))

    print("\n***** Pre-Training FINISHED *****")
Example #5
0
    # NOTE(review): fragment of a larger entry-point script -- the enclosing
    # function/main and model construction are outside this view.
    # print(model)
    # print(list(model.parameters()))

    # SET paths to best models
    model_pre_train_pth = os.path.join(args.model_dir_pre_train, args.saved_pre_train_name)
    model_fine_tune_pth = os.path.join(args.model_dir_fine_tune, args.saved_fine_tune_name)

    # -------
    #  Test Evaluate
    # ------

    # checkpoint = torch.load('Net_best_fine_tune.pth.tar', map_location='cpu')
    # model.load_state_dict(checkpoint)
    #
    # Load test images; deterministic order (shuffle=False) so outputs map
    # back to input filenames.
    test_dataset = data.DATA(mode="test", train_status="TA")
    dataloader_test = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=args.batch_size_test,
        shuffle=False,
        num_workers=argparser.n_cpu)

    # Training split loaded with evaluation settings (no shuffling).
    # NOTE(review): train_status="x" looks like a placeholder -- verify the
    # intended status string against data.DATA.
    train_dataset = data.DATA(mode="train", train_status="x")
    dataloader_train = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=args.batch_size_test,
        shuffle=False,
        num_workers=argparser.n_cpu)
    # #
    # # for idx, (imgs_masked, masks, gts, _) in enumerate(dataloader_test):
    # #     print(imgs_masked.size())