Example #1
0
def main() -> None:
    """Train the DeepLPF neural network on Adobe5k image pairs.

    When both --checkpoint_filepath and --inference_img_dirpath are supplied,
    instead loads the checkpointed model and runs it over the images listed in
    <inference_img_dirpath>/images_inference.txt.

    Side effects: creates a timestamped log directory under /aiml/data/,
    configures the root logger, and writes model snapshots into that
    directory.
    """
    # One log directory per run, named by wall-clock timestamp.
    timestamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    log_dirpath = "/aiml/data/log_" + timestamp
    os.mkdir(log_dirpath)

    # Mirror all log output to a file in the run directory and to the console.
    handlers = [logging.FileHandler(
        log_dirpath + "/deep_lpf.log"), logging.StreamHandler()]
    logging.basicConfig(
        level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s', handlers=handlers)

    parser = argparse.ArgumentParser(
        description="Train the DeepLPF neural network on image pairs")

    parser.add_argument(
        "--num_epoch", type=int, required=False, help="Number of epoches (default 5000)", default=100000)
    parser.add_argument(
        "--valid_every", type=int, required=False, help="Number of epoches after which to compute validation accuracy",
        default=500)
    parser.add_argument(
        "--checkpoint_filepath", required=False, help="Location of checkpoint file", default=None)
    parser.add_argument(
        "--inference_img_dirpath", required=False,
        help="Directory containing images to run through a saved DeepLPF model instance", default=None)

    args = parser.parse_args()
    num_epoch = args.num_epoch
    valid_every = args.valid_every
    checkpoint_filepath = args.checkpoint_filepath
    inference_img_dirpath = args.inference_img_dirpath

    logging.info('######### Parameters #########')
    logging.info('Number of epochs: ' + str(num_epoch))
    logging.info('Logging directory: ' + str(log_dirpath))
    logging.info('Dump validation accuracy every: ' + str(valid_every))
    logging.info('##############################')

    # Dataset splits are driven by plain-text image-id lists under /aiml/data/.
    # Only the training split is augmented (random horizontal/vertical flips).
    # normaliser=2**8-1 rescales 8-bit pixel values into [0, 1].
    training_data_loader = Adobe5kDataLoader(data_dirpath="/aiml/data/",
                                             img_ids_filepath="/aiml/data/images_train.txt")
    training_data_dict = training_data_loader.load_data()
    training_dataset = Dataset(data_dict=training_data_dict, transform=transforms.Compose(
        [transforms.ToPILImage(), transforms.RandomHorizontalFlip(), transforms.RandomVerticalFlip(),
         transforms.ToTensor()]),
        normaliser=2 ** 8 - 1, is_valid=False)

    validation_data_loader = Adobe5kDataLoader(data_dirpath="/aiml/data/",
                                               img_ids_filepath="/aiml/data/images_valid.txt")
    validation_data_dict = validation_data_loader.load_data()
    validation_dataset = Dataset(data_dict=validation_data_dict,
                                 transform=transforms.Compose([transforms.ToTensor()]), normaliser=2 ** 8 - 1,
                                 is_valid=True)

    testing_data_loader = Adobe5kDataLoader(data_dirpath="/aiml/data/",
                                            img_ids_filepath="/aiml/data/images_test.txt")
    testing_data_dict = testing_data_loader.load_data()
    testing_dataset = Dataset(data_dict=testing_data_dict,
                              transform=transforms.Compose([transforms.ToTensor()]), normaliser=2 ** 8 - 1,
                              is_valid=True)

    # WARNING: the loops below assume batch_size == 1.
    training_data_loader = torch.utils.data.DataLoader(training_dataset, batch_size=1, shuffle=True,
                                                       num_workers=4)
    testing_data_loader = torch.utils.data.DataLoader(testing_dataset, batch_size=1, shuffle=False,
                                                      num_workers=4)
    validation_data_loader = torch.utils.data.DataLoader(validation_dataset, batch_size=1,
                                                         shuffle=False,
                                                         num_workers=4)

    if (checkpoint_filepath is not None) and (inference_img_dirpath is not None):
        # ---- Inference-only path -------------------------------------------
        inference_data_loader = Adobe5kDataLoader(data_dirpath=inference_img_dirpath,
                                                  img_ids_filepath=inference_img_dirpath + "/images_inference.txt")
        inference_data_dict = inference_data_loader.load_data()
        inference_dataset = Dataset(data_dict=inference_data_dict,
                                    transform=transforms.Compose([transforms.ToTensor()]), normaliser=2 ** 8 - 1,
                                    is_valid=True)

        inference_data_loader = torch.utils.data.DataLoader(inference_dataset, batch_size=1, shuffle=False,
                                                            num_workers=4)

        # Performs inference on all the images in inference_img_dirpath.
        logging.info(
            "Performing inference with images in directory: " + inference_img_dirpath)

        # NOTE(review): this expects a whole pickled model (torch.save(net, ...)),
        # not a state_dict; unpickling is only safe for trusted checkpoint files.
        net = torch.load(checkpoint_filepath,
                         map_location=lambda storage, location: storage)

        # Switch model to evaluation mode (freezes dropout/batch-norm behavior).
        net.eval()

        criterion = model.DeepLPFLoss()

        testing_evaluator = metric.Evaluator(
            criterion, inference_data_loader, "test", log_dirpath)

        testing_evaluator.evaluate(net, epoch=0)

    else:
        # ---- Training path -------------------------------------------------
        net = model.DeepLPFNet()
        net.cuda()  # move parameters to the GPU before building the optimizer

        logging.info('######### Network created #########')
        logging.info('Architecture:\n' + str(net))

        # List trainable parameter names for debugging.
        for name, param in net.named_parameters():
            if param.requires_grad:
                print(name)

        criterion = model.DeepLPFLoss(ssim_window_size=5)

        # Evaluators compute loss/PSNR/SSIM over the validation/test splits.
        validation_evaluator = metric.Evaluator(
            criterion, validation_data_loader, "valid", log_dirpath)
        testing_evaluator = metric.Evaluator(
            criterion, testing_data_loader, "test", log_dirpath)

        optimizer = optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=1e-4, betas=(0.9, 0.999),
                               eps=1e-08)
        best_valid_psnr = 0.0

        batch_size = 1  # must match the DataLoader batch_size above
        optimizer.zero_grad()
        net.train()

        for epoch in range(num_epoch):

            # ---- Training loss over one epoch ----
            examples = 0.0
            running_loss = 0.0

            for batch_num, data in enumerate(training_data_loader, 0):

                # Move the batch to the GPU; inputs need no gradients of
                # their own (requires_grad defaults to False).
                input_img_batch = data['input_img'].cuda()
                output_img_batch = data['output_img'].cuda()

                net_output_img_batch = net(input_img_batch)
                # The network output is an image: clamp into the valid range.
                net_output_img_batch = torch.clamp(
                    net_output_img_batch, 0.0, 1.0)

                loss = criterion(net_output_img_batch, output_img_batch)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # FIX: loss.data[0] raises on 0-dim tensors (PyTorch >= 0.5);
                # loss.item() is the supported way to read a scalar loss.
                running_loss += loss.item()
                examples += batch_size

            logging.info('[%d] train loss: %.15f' %
                         (epoch + 1, running_loss / examples))

            # ---- Validation loss over one epoch ----
            examples = 0.0
            running_loss = 0.0

            net.eval()  # hoisted out of the batch loop (loop-invariant)
            # no_grad: skip autograd bookkeeping for the evaluation pass.
            with torch.no_grad():
                for batch_num, data in enumerate(validation_data_loader, 0):

                    input_img_batch = data['input_img'].cuda()
                    output_img_batch = data['output_img'].cuda()

                    net_output_img_batch = net(input_img_batch)
                    net_output_img_batch = torch.clamp(
                        net_output_img_batch, 0.0, 1.0)

                    loss = criterion(net_output_img_batch, output_img_batch)

                    running_loss += loss.item()
                    examples += batch_size

            logging.info('[%d] valid loss: %.15f' %
                         (epoch + 1, running_loss / examples))

            net.train()

            # Periodically run the full evaluators and checkpoint whenever the
            # validation PSNR improves.
            if (epoch + 1) % valid_every == 0:

                logging.info("Evaluating model on validation and test dataset")

                valid_loss, valid_psnr, valid_ssim = validation_evaluator.evaluate(
                    net, epoch)
                test_loss, test_psnr, test_ssim = testing_evaluator.evaluate(
                    net, epoch)

                # Update best validation set PSNR.
                if valid_psnr > best_valid_psnr:

                    snapshot_name = 'deeplpf_validpsnr_{}_validloss_{}_testpsnr_{}_testloss_{}_epoch_{}_model.pt'.format(
                        valid_psnr, valid_loss.tolist()[0], test_psnr,
                        test_loss.tolist()[0], epoch)
                    logging.info(
                        "Validation PSNR has increased. Saving the more accurate model to file: " + snapshot_name)

                    best_valid_psnr = valid_psnr
                    # Saves the whole model object, matching the torch.load()
                    # in the inference branch above.
                    snapshot_path = os.path.join(log_dirpath, snapshot_name)
                    torch.save(net, snapshot_path)

                net.train()

        # Final evaluation over the test split and a last state_dict snapshot.
        testing_evaluator.evaluate(net, epoch=0)

        snapshot_prefix = os.path.join(log_dirpath, 'deep_lpf')
        snapshot_path = snapshot_prefix + "_" + str(num_epoch)
        torch.save(net.state_dict(), snapshot_path)
Example #2
0
def main() -> None:
    """Train the DeepLPF neural network on image pairs, or run inference.

    The inference branch is taken when BOTH --checkpoint_filepath and
    --inference_img_dirpath are non-None.  NOTE(review):
    --inference_img_dirpath has a non-None default here, so supplying only
    --checkpoint_filepath already selects the inference branch.

    Side effects: creates a timestamped ./log_* directory and saves model
    state_dict snapshots into it every `valid_every` epochs.
    """

    # One log directory per run, named by wall-clock timestamp.
    timestamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    log_dirpath = "./log_" + timestamp
    os.mkdir(log_dirpath)

    parser = argparse.ArgumentParser(
        description="Train the DeepLPF neural network on image pairs")

    parser.add_argument(
        "--num_epoch", type=int, required=False, help="Number of epoches (default 5000)", default=1000)
    parser.add_argument(
        "--valid_every", type=int, required=False, help="Number of epoches after which to compute validation accuracy",
        default=50)
    parser.add_argument(
        "--checkpoint_filepath", required=False, help="Location of checkpoint file",
        default=None)
    parser.add_argument(
        "--inference_img_dirpath", required=False,
        help="Directory containing images to run through a saved DeepLPF model instance",
        default="/home/ubuntu/Volume/Sunyong/Danbi/dataset_CURL/210308_paper_dataset/DeepLPF_only/test")
    parser.add_argument(
        "--training_img_dirpath", required=False,
        help="Directory containing images to train a DeepLPF model instance",
        default="/home/ubuntu/Volume/Sunyong/Danbi/dataset_CURL/210308_paper_dataset/DeepLPF_only/train")

    args = parser.parse_args()
    num_epoch = args.num_epoch
    valid_every = args.valid_every
    checkpoint_filepath = args.checkpoint_filepath
    inference_img_dirpath = args.inference_img_dirpath
    training_img_dirpath = args.training_img_dirpath
    # Worker processes per DataLoader.
    num_workers = 2

    print('######### Parameters #########')
    print('Number of epochs: ' + str(num_epoch))
    print('Logging directory: ' + str(log_dirpath))
    print('Dump validation accuracy every: ' + str(valid_every))
    print('Training image directory: ' + str(training_img_dirpath))
    print('##############################')


    if (checkpoint_filepath is not None) and (inference_img_dirpath is not None):

        '''
        inference_img_dirpath: the actual filepath should have "input" in the name an in the level above where the images 
        for inference are located, there should be a file "images_inference.txt with each image filename as one line i.e."
        
        images_inference.txt    ../
                                a1000.tif
                                a1242.tif
                                etc
        '''
        inference_data_loader = Adobe5kDataLoader(data_dirpath=inference_img_dirpath,
                                                  img_ids_filepath=inference_img_dirpath+"/images_inference.txt")
        inference_data_dict = inference_data_loader.load_data()
        # normaliser=1: pixel data is presumably already in [0, 1] —
        # TODO confirm against Dataset's implementation.
        inference_dataset = Dataset(data_dict=inference_data_dict,
                                    transform=transforms.Compose([transforms.ToTensor()]), normaliser=1,
                                    is_inference=True)

        inference_data_loader = torch.utils.data.DataLoader(inference_dataset, batch_size=1, shuffle=False,
                                                            num_workers=num_workers)

        '''
        Performs inference on all the images in inference_img_dirpath
        '''
        print(
            "Performing inference with images in directory: " + inference_img_dirpath)

        # Restore weights from a state_dict checkpoint (matches the
        # torch.save(net.state_dict(), ...) calls in the training branch).
        net = model.DeepLPFNet()
        net.load_state_dict(torch.load(checkpoint_filepath))
        net.eval()

        criterion = model.DeepLPFLoss()

        inference_evaluator = metric.Evaluator(
            criterion, inference_data_loader, "test", log_dirpath)

        inference_evaluator.evaluate(net, epoch=0)

    else:
        # Training branch.  NOTE(review): unlike the inference split, these
        # Dataset instances are built without an explicit transform —
        # presumably Dataset supplies a default; verify before relying on it.
        training_data_loader = Adobe5kDataLoader(data_dirpath=training_img_dirpath,
                                                 img_ids_filepath=training_img_dirpath+"/images_train.txt")
        training_data_dict = training_data_loader.load_data()
        training_dataset = Dataset(data_dict=training_data_dict, normaliser=1, is_valid=False)

        validation_data_loader = Adobe5kDataLoader(data_dirpath=training_img_dirpath,
                                               img_ids_filepath=training_img_dirpath+"/images_valid.txt")
        validation_data_dict = validation_data_loader.load_data()
        validation_dataset = Dataset(data_dict=validation_data_dict, normaliser=1, is_valid=True)

        # Test images come from the inference directory, not the training one.
        testing_data_loader = Adobe5kDataLoader(data_dirpath=inference_img_dirpath,
                                            img_ids_filepath=inference_img_dirpath+"/images_test.txt")
        testing_data_dict = testing_data_loader.load_data()
        testing_dataset = Dataset(data_dict=testing_data_dict, normaliser=1,is_valid=True)

        # WARNING: the loops below assume batch_size == 1.
        training_data_loader = torch.utils.data.DataLoader(training_dataset, batch_size=1, shuffle=True,
                                                       num_workers=num_workers)
        testing_data_loader = torch.utils.data.DataLoader(testing_dataset, batch_size=1, shuffle=False,
                                                      num_workers=num_workers)
        validation_data_loader = torch.utils.data.DataLoader(validation_dataset, batch_size=1,
                                                         shuffle=False,
                                                         num_workers=num_workers)
        net = model.DeepLPFNet()
        net.cuda(0)

        print('######### Network created #########')
        print('Architecture:\n' + str(net))

        # List trainable parameter names for debugging.
        for name, param in net.named_parameters():
            if param.requires_grad:
                print(name)

        criterion = model.DeepLPFLoss(ssim_window_size=5)

        '''
        The following objects allow for evaluation of a model on the testing and validation splits of a dataset
        '''
        validation_evaluator = metric.Evaluator(
            criterion, validation_data_loader, "valid", log_dirpath)
        testing_evaluator = metric.Evaluator(
            criterion, testing_data_loader, "test", log_dirpath)

        optimizer = optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=1e-4, betas=(0.9, 0.999),
                               eps=1e-08)
        best_valid_psnr = 0.0

        # NOTE(review): alpha, psnr_avg and ssim_avg are never read below.
        alpha = 0.0
        optimizer.zero_grad()
        net.train()

        running_loss = 0.0
        examples = 0
        psnr_avg = 0.0
        ssim_avg = 0.0
        batch_size = 1
        total_examples = 0
        # Print the per-batch training loss every `log_interval` batches.
        log_interval = 50

        for epoch in range(num_epoch):

            # Train loss
            examples = 0.0
            running_loss = 0.0
            
            for batch_num, data in enumerate(training_data_loader, 0):

                # Unpack the batch and move images to the GPU; `category` is
                # the image name and is not used in this loop.
                input_img_batch, gt_img_batch, category = Variable(data['input_img'],
                                                                       requires_grad=False).cuda(), Variable(data['output_img'],
                                                                                                             requires_grad=False).cuda(), data[
                    'name']

                start_time = time.time()
                net_img_batch = net(input_img_batch)
                # The network output is an image: clamp into the valid range.
                net_img_batch = torch.clamp(net_img_batch, 0.0, 1.0)

                elapsed_time = time.time() - start_time

                loss = criterion(net_img_batch, gt_img_batch)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # NOTE(review): loss.data[0] raises on 0-dim tensors in
                # PyTorch >= 0.5; loss.item() is the supported replacement.
                running_loss += loss.data[0]
                examples += batch_size
                total_examples+=batch_size

                if batch_num % log_interval == 0:
                    print('Loss/train: ', loss.data[0])
                # writer.add_scalar('Loss/train', loss.data[0], total_examples)

            print('[%d] train loss: %.15f' %
                         (epoch + 1, running_loss / examples))

            # writer.add_scalar('Loss/train_smooth', running_loss / examples, epoch + 1)

            # Valid loss
            # NOTE(review): the per-epoch validation pass below is disabled
            # (kept as a string literal); it also references a `writer` that
            # is not defined in this function.
            '''
            examples = 0.0
            running_loss = 0.0

            for batch_num, data in enumerate(validation_data_loader, 0):

                net.eval()

                input_img_batch, output_img_batch, category = Variable(
                    data['input_img'],
                    requires_grad=False).cuda(), Variable(data['output_img'],
                                                         requires_grad=False).cuda(), \
                    data[
                    'name']

                net_output_img_batch = net(
                    input_img_batch)
                net_output_img_batch = torch.clamp(
                    net_output_img_batch, 0.0, 1.0)

                optimizer.zero_grad()

                loss = criterion(net_output_img_batch, output_img_batch)

                running_loss += loss.data[0]
                examples += batch_size
                total_examples+=batch_size

                writer.add_scalar('Loss/train', loss.data[0], total_examples)

            logging.info('[%d] valid loss: %.15f' %
                         (epoch + 1, running_loss / examples))
            writer.add_scalar('Loss/valid_smooth', running_loss / examples, epoch + 1)

            net.train()
            '''

            # Snapshot every `valid_every` epochs.  The PSNR-based model
            # selection is commented out, so saving is unconditional.
            if (epoch + 1) % valid_every == 0:

                # print("Evaluating model on validation and test dataset")
                #
                # valid_loss, valid_psnr, valid_ssim = validation_evaluator.evaluate(
                #     net, epoch)
                # test_loss, test_psnr, test_ssim = testing_evaluator.evaluate(
                #     net, epoch)
                #
                # # update best validation set psnr
                # if valid_psnr > best_valid_psnr:
                #
                #     print(
                #         "Validation PSNR has increased. Saving the more accurate model to file: " + 'deeplpf_validpsnr_{}_validloss_{}_testpsnr_{}_testloss_{}_epoch_{}_model.pt'.format(valid_psnr,
                #                                                                                                                                                                          valid_loss.tolist()[0], test_psnr, test_loss.tolist()[
                #                                                                                                                                                                              0],
                #                                                                                                                                                                          epoch))

                snapshot_prefix = os.path.join(
                    log_dirpath, 'deeplpf')
                snapshot_path = snapshot_prefix + '_epoch_{}_model.pt'.format(epoch)
                torch.save(net.state_dict(), snapshot_path)

                net.train()

        '''
        Run the network over the testing dataset split
        '''
        testing_evaluator.evaluate(net, epoch=0)

        # Final state_dict snapshot after all epochs complete.
        snapshot_prefix = os.path.join(log_dirpath, 'deep_lpf')
        snapshot_path = snapshot_prefix + "_" + str(num_epoch)
        torch.save(net.state_dict(), snapshot_path)
Example #3
0
def main() -> None:
    """Train the DeepLPF neural network on image pairs, or run inference.

    Safety guard: the print()/exit() pair below intentionally aborts the
    program until the user has read the README note about batch size 1.
    Training logs to a timestamped ./log_* directory and to TensorBoard via
    `writer`, and checkpoints the best validation-PSNR model.
    """

    print(
        "*** Before running this code ensure you keep the default batch size of 1. The code has not been engineered to support higher batch sizes. See README for more detail. Remove the exit() statement to use code. ***"
    )
    # Deliberate hard stop — remove after reading the warning above.
    exit()

    # TensorBoard writer; SummaryWriter is presumably imported at file level.
    writer = SummaryWriter()

    # One log directory per run, named by wall-clock timestamp.
    timestamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    log_dirpath = "./log_" + timestamp
    os.mkdir(log_dirpath)

    # Mirror all log output to a file in the run directory and to the console.
    handlers = [
        logging.FileHandler(log_dirpath + "/deep_lpf.log"),
        logging.StreamHandler()
    ]
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s %(levelname)s %(message)s',
                        handlers=handlers)

    parser = argparse.ArgumentParser(
        description="Train the DeepLPF neural network on image pairs")

    parser.add_argument("--num_epoch",
                        type=int,
                        required=False,
                        help="Number of epoches (default 5000)",
                        default=100000)
    parser.add_argument(
        "--valid_every",
        type=int,
        required=False,
        help="Number of epoches after which to compute validation accuracy",
        default=25)
    parser.add_argument("--checkpoint_filepath",
                        required=False,
                        help="Location of checkpoint file",
                        default=None)
    parser.add_argument(
        "--inference_img_dirpath",
        required=False,
        help=
        "Directory containing images to run through a saved DeepLPF model instance",
        default=None)
    # NOTE(review): required=True makes argparse ignore the default below.
    parser.add_argument(
        "--training_img_dirpath",
        required=True,
        help="Directory containing images to train a DeepLPF model instance",
        default="/home/sjm213/adobe5k/adobe5k/")
    parser.add_argument(
        "--inference_img_list_path",
        required=False,
        help="Plain text file containing the names of the images to inference")
    parser.add_argument(
        "--train_img_list_path",
        required=True,
        help="Plain text file containing the names of the training images")
    parser.add_argument(
        "--valid_img_list_path",
        required=True,
        help="Plain text file containing the names of the validation images")
    parser.add_argument(
        "--test_img_list_path",
        required=False,
        help="Plain text file containing the names of the test images")

    args = parser.parse_args()
    num_epoch = args.num_epoch
    valid_every = args.valid_every
    checkpoint_filepath = args.checkpoint_filepath
    inference_img_dirpath = args.inference_img_dirpath
    training_img_dirpath = args.training_img_dirpath
    inference_img_list_path = args.inference_img_list_path
    test_img_list_path = args.test_img_list_path
    valid_img_list_path = args.valid_img_list_path
    train_img_list_path = args.train_img_list_path

    logging.info('######### Parameters #########')
    logging.info('Number of epochs: ' + str(num_epoch))
    logging.info('Logging directory: ' + str(log_dirpath))
    logging.info('Dump validation accuracy every: ' + str(valid_every))
    logging.info('Training image directory: ' + str(training_img_dirpath))
    logging.info('List of images to inference: ' +
                 str(inference_img_list_path))
    logging.info('List of test images: ' + str(test_img_list_path))
    logging.info('List of validation images: ' + str(valid_img_list_path))
    logging.info('List of training images: ' + str(train_img_list_path))

    logging.info('##############################')

    BATCH_SIZE = 1  # *** WARNING: batch size of > 1 not supported in current version of code ***

    # Inference branch: requires both a checkpoint and an image directory.
    if (checkpoint_filepath is not None) and (inference_img_dirpath
                                              is not None):
        '''
        inference_img_dirpath: the actual filepath should have "input" in the name an in the level above where the images 
        for inference are located, there should be a file "images_inference.txt with each image filename as one line i.e."
        
        images_inference.txt    ../
                                a1000.tif
                                a1242.tif
                                etc
        '''
        inference_data_loader = Adobe5kDataLoader(
            data_dirpath=inference_img_dirpath,
            img_ids_filepath=inference_img_list_path)
        inference_data_dict = inference_data_loader.load_data()
        # normaliser=1: pixel data is presumably already in [0, 1] —
        # TODO confirm against Dataset's implementation.
        inference_dataset = Dataset(data_dict=inference_data_dict,
                                    transform=transforms.Compose(
                                        [transforms.ToTensor()]),
                                    normaliser=1,
                                    is_inference=True)

        assert (BATCH_SIZE == 1)
        inference_data_loader = torch.utils.data.DataLoader(
            inference_dataset,
            batch_size=BATCH_SIZE,
            shuffle=False,
            num_workers=6)
        '''
        Performs inference on all the images in inference_img_dirpath
        '''
        logging.info("Performing inference with images in directory: " +
                     inference_img_dirpath)

        # Restore weights from a state_dict checkpoint.
        net = model.DeepLPFNet()
        net.load_state_dict(torch.load(checkpoint_filepath))
        net.eval()

        criterion = model.DeepLPFLoss()

        inference_evaluator = metric.Evaluator(criterion,
                                               inference_data_loader, "test",
                                               log_dirpath)

        inference_evaluator.evaluate(net, epoch=0)

    else:

        assert (BATCH_SIZE == 1)

        # All three splits read from training_img_dirpath, selected by the
        # per-split image-id list files.
        training_data_loader = Adobe5kDataLoader(
            data_dirpath=training_img_dirpath,
            img_ids_filepath=train_img_list_path)
        training_data_dict = training_data_loader.load_data()

        # NOTE(review): built without an explicit transform — presumably
        # Dataset supplies a default; verify before relying on it.
        training_dataset = Dataset(data_dict=training_data_dict,
                                   normaliser=1,
                                   is_valid=False)

        validation_data_loader = Adobe5kDataLoader(
            data_dirpath=training_img_dirpath,
            img_ids_filepath=valid_img_list_path)
        validation_data_dict = validation_data_loader.load_data()
        validation_dataset = Dataset(data_dict=validation_data_dict,
                                     normaliser=1,
                                     is_valid=True)

        testing_data_loader = Adobe5kDataLoader(
            data_dirpath=training_img_dirpath,
            img_ids_filepath=test_img_list_path)
        testing_data_dict = testing_data_loader.load_data()
        testing_dataset = Dataset(data_dict=testing_data_dict,
                                  normaliser=1,
                                  is_valid=True)

        training_data_loader = torch.utils.data.DataLoader(
            training_dataset,
            batch_size=BATCH_SIZE,
            shuffle=True,
            num_workers=6)
        testing_data_loader = torch.utils.data.DataLoader(
            testing_dataset,
            batch_size=BATCH_SIZE,
            shuffle=False,
            num_workers=6)
        validation_data_loader = torch.utils.data.DataLoader(
            validation_dataset,
            batch_size=BATCH_SIZE,
            shuffle=False,
            num_workers=6)
        net = model.DeepLPFNet()
        net.cuda(0)

        logging.info('######### Network created #########')
        logging.info('Architecture:\n' + str(net))

        # List trainable parameter names for debugging.
        for name, param in net.named_parameters():
            if param.requires_grad:
                print(name)

        criterion = model.DeepLPFLoss(ssim_window_size=5)
        '''
        The following objects allow for evaluation of a model on the testing and validation splits of a dataset
        '''
        validation_evaluator = metric.Evaluator(criterion,
                                                validation_data_loader,
                                                "valid", log_dirpath)
        testing_evaluator = metric.Evaluator(criterion, testing_data_loader,
                                             "test", log_dirpath)

        optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                      net.parameters()),
                               lr=1e-4,
                               betas=(0.9, 0.999),
                               eps=1e-08)
        best_valid_psnr = 0.0

        optimizer.zero_grad()
        net.train()

        running_loss = 0.0
        examples = 0
        # Running count of processed examples, used as the TensorBoard x-axis.
        total_examples = 0

        for epoch in range(num_epoch):

            # Train loss
            examples = 0.0
            running_loss = 0.0

            for batch_num, data in enumerate(training_data_loader, 0):

                # Unpack the batch and move images to the GPU; the image name
                # (data['name']) is discarded.
                input_img_batch, gt_img_batch, _ = Variable(
                    data['input_img'], requires_grad=False).cuda(), Variable(
                        data['output_img'],
                        requires_grad=False).cuda(), data['name']

                start_time = time.time()
                net_img_batch = net(input_img_batch)
                # The network output is an image: clamp into the valid range.
                net_img_batch = torch.clamp(net_img_batch, 0.0, 1.0)

                elapsed_time = time.time() - start_time

                loss = criterion(net_img_batch, gt_img_batch)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # NOTE(review): loss.data[0] raises on 0-dim tensors in
                # PyTorch >= 0.5; loss.item() is the supported replacement.
                running_loss += loss.data[0]
                examples += BATCH_SIZE
                total_examples += BATCH_SIZE

                writer.add_scalar('Loss/train', loss.data[0], total_examples)

            logging.info('[%d] train loss: %.15f' %
                         (epoch + 1, running_loss / examples))
            writer.add_scalar('Loss/train_smooth', running_loss / examples,
                              epoch + 1)

            # Valid loss
            # NOTE(review): the per-epoch validation pass below is disabled
            # (kept as a string literal); validation instead happens via the
            # evaluator every `valid_every` epochs.
            '''
            examples = 0.0
            running_loss = 0.0

            for batch_num, data in enumerate(validation_data_loader, 0):

                net.eval()

                input_img_batch, output_img_batch, category = Variable(
                    data['input_img'],
                    requires_grad=False).cuda(), Variable(data['output_img'],
                                                         requires_grad=False).cuda(), \
                    data[
                    'name']

                net_output_img_batch = net(
                    input_img_batch)
                net_output_img_batch = torch.clamp(
                    net_output_img_batch, 0.0, 1.0)

                optimizer.zero_grad()

                loss = criterion(net_output_img_batch, output_img_batch)

                running_loss += loss.data[0]
                examples += BATCH_SIZE
                total_examples+=BATCH_SIZE

                writer.add_scalar('Loss/train', loss.data[0], total_examples)

            logging.info('[%d] valid loss: %.15f' %
                         (epoch + 1, running_loss / examples))
            writer.add_scalar('Loss/valid_smooth', running_loss / examples, epoch + 1)

            net.train()
            '''

            # Periodically run the full evaluators and checkpoint whenever the
            # validation PSNR improves.
            if (epoch + 1) % valid_every == 0:

                logging.info("Evaluating model on validation and test dataset")

                valid_loss, valid_psnr, valid_ssim = validation_evaluator.evaluate(
                    net, epoch)
                test_loss, test_psnr, test_ssim = testing_evaluator.evaluate(
                    net, epoch)

                # update best validation set psnr
                if valid_psnr > best_valid_psnr:

                    logging.info(
                        "Validation PSNR has increased. Saving the more accurate model to file: "
                        +
                        'deeplpf_validpsnr_{}_validloss_{}_testpsnr_{}_testloss_{}_epoch_{}_model.pt'
                        .format(valid_psnr,
                                valid_loss.tolist()[0], test_psnr,
                                test_loss.tolist()[0], epoch))

                    best_valid_psnr = valid_psnr
                    snapshot_prefix = os.path.join(log_dirpath, 'deeplpf')
                    snapshot_path = snapshot_prefix + '_validpsnr_{}_validloss_{}_testpsnr_{}_testloss_{}_epoch_{}_model.pt'.format(
                        valid_psnr,
                        valid_loss.tolist()[0], test_psnr,
                        test_loss.tolist()[0], epoch)
                    torch.save(net.state_dict(), snapshot_path)

                net.train()
        '''
        Run the network over the testing dataset split
        '''
        testing_evaluator.evaluate(net, epoch=0)

        # Final state_dict snapshot after all epochs complete.
        snapshot_prefix = os.path.join(log_dirpath, 'deep_lpf')
        snapshot_path = snapshot_prefix + "_" + str(num_epoch)
        torch.save(net.state_dict(), snapshot_path)
Example #4
0
def main():
    """Run adversarial domain-adaptation training.

    Trains three networks jointly: a DeepLPF generator G(S, T) that maps
    source images toward the target domain, a discriminator D that tells
    translated-source from real-target images, and a semantic-segmentation
    backbone F trained on the translated images. After each epoch the
    backbone is evaluated on the target validation split (mIoU).

    Side effects: reads CLI args via ``arg_parser``, sets CUDA environment
    variables, writes logs/TensorBoard summaries/checkpoints through
    ``Logger`` and ``ImageAndLossSaver``.
    """
    args = arg_parser.Parse()
    # Pin device enumeration to the PCI bus so CUDA_VISIBLE_DEVICES
    # indices are stable across drivers.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
    logger = Logger(args.log_dir)
    logger.PrintAndLogArgs(args)
    saver = ImageAndLossSaver(args.tb_logs_dir, logger.log_folder,
                              args.checkpoints_dir, args.save_pics_every)
    source_loader, target_train_loader, target_eval_loader = CreateSrcDataLoader(
        args), CreateTrgDataLoader(args,
                                   'train'), CreateTrgDataLoader(args, 'val')
    # One "epoch" covers the larger of the two datasets; both loaders are
    # stretched to that size so they can be iterated in lockstep.
    epoch_size = np.maximum(len(target_train_loader.dataset),
                            len(source_loader.dataset))
    steps_per_epoch = int(np.floor(epoch_size / args.batch_size))
    source_loader.dataset.SetEpochSize(epoch_size)
    target_train_loader.dataset.SetEpochSize(epoch_size)

    generator = model.DeepLPFNet()
    generator = nn.DataParallel(generator.cuda())
    generator_criterion = model.GeneratorLoss()
    generator_optimizer = optim.Adam(generator.parameters(),
                                     lr=args.generator_lr,
                                     betas=(0.9, 0.999),
                                     eps=1e-08)
    discriminator = model.Discriminator()
    discriminator = nn.DataParallel(discriminator.cuda())
    discriminator_criterion = model.DiscriminatorLoss()
    discriminator_optimizer = optim.Adam(discriminator.parameters(),
                                         lr=args.discriminator_lr,
                                         betas=(0.9, 0.999),
                                         eps=1e-08)
    semseg_net, semseg_optimizer = CreateModel(args)
    semseg_net = nn.DataParallel(semseg_net.cuda())

    logger.info('######### Network created #########')
    logger.info('Architecture of Generator:\n' + str(generator))
    logger.info('Architecture of Discriminator:\n' + str(discriminator))
    logger.info('Architecture of Backbone net:\n' + str(semseg_net))

    for epoch in range(args.num_epochs):
        generator.train()
        discriminator.train()
        semseg_net.train()
        saver.Reset()
        # Alternate which half of the discriminator loss is computed each
        # discriminator step: D(G(S,T)) on one step, D(T) on the next.
        discriminate_src = True
        source_loader_iter, target_train_loader_iter, target_eval_loader_iter = iter(
            source_loader), iter(target_train_loader), iter(target_eval_loader)
        logger.info('#################[Epoch %d]#################' %
                    (epoch + 1))

        for batch_num in range(steps_per_epoch):
            start_time = time.time()
            # After an initial generator-only warm-up of `generator_boost`
            # steps, alternate: `discriminator_iters` discriminator steps,
            # then `generator_iters` generator/segmentation steps.
            training_discriminator = (batch_num >= args.generator_boost) and (
                batch_num - args.generator_boost) % (
                    args.discriminator_iters +
                    args.generator_iters) < args.discriminator_iters
            # BUGFIX: the Python 2 `iterator.next()` method does not exist
            # on Python 3 — use the built-in next().
            src_img, src_lbl, src_shapes, src_names = next(
                source_loader_iter)  # new batch source
            trg_eval_img, trg_eval_lbl, trg_shapes, trg_names = next(
                target_train_loader_iter)  # new batch target

            generator_optimizer.zero_grad()
            discriminator_optimizer.zero_grad()
            semseg_optimizer.zero_grad()

            src_input_batch = Variable(src_img, requires_grad=False).cuda()
            src_label_batch = Variable(src_lbl, requires_grad=False).cuda()
            trg_input_batch = Variable(trg_eval_img,
                                       requires_grad=False).cuda()
            # trg_label_batch = Variable(trg_lbl, requires_grad=False).cuda()
            src_in_trg = generator(src_input_batch, trg_input_batch)  # G(S,T)

            if training_discriminator:  #train discriminator
                if discriminate_src:
                    discriminator_src_in_trg = discriminator(
                        src_in_trg)  # D(G(S,T))
                    discriminator_trg = None  # D(T)
                else:
                    discriminator_src_in_trg = None  # D(G(S,T))
                    discriminator_trg = discriminator(trg_input_batch)  # D(T)
                discriminate_src = not discriminate_src
                loss = discriminator_criterion(discriminator_src_in_trg,
                                               discriminator_trg)
            else:  #train generator and semseg net
                discriminator_trg = discriminator(trg_input_batch)  # D(T)
                predicted, loss_seg, loss_ent = semseg_net(
                    src_in_trg, lbl=src_label_batch)  # F(G(S.T))
                src_in_trg_labels = torch.argmax(predicted, dim=1)
                loss = generator_criterion(loss_seg, loss_ent, args.entW,
                                           discriminator_trg)

            saver.WriteLossHistory(training_discriminator, loss.item())
            loss.backward()

            # Only the optimizer(s) of the side being trained step; the
            # other side's gradients were zeroed above and are discarded.
            if training_discriminator:  # train discriminator
                discriminator_optimizer.step()
            else:  # train generator and semseg net
                generator_optimizer.step()
                semseg_optimizer.step()

            saver.running_time += time.time() - start_time

            if (not training_discriminator) and saver.SaveImagesIteration:
                saver.SaveTrainImages(epoch, src_img[0, :, :, :],
                                      src_in_trg[0, :, :, :], src_lbl[0, :, :],
                                      src_in_trg_labels[0, :, :])

            if (batch_num + 1) % args.print_every == 0:
                logger.PrintAndLogData(saver, epoch, batch_num,
                                       args.print_every)

            if (batch_num + 1) % args.save_checkpoint == 0:
                saver.SaveModelsCheckpoint(semseg_net, discriminator,
                                           generator, epoch, batch_num)

        #Validation:
        semseg_net.eval()
        # Pick 5 random validation batches whose first-offset images will be
        # dumped to disk for visual inspection.
        rand_samp_inds = np.random.randint(0, len(target_eval_loader.dataset),
                                           5)
        # BUGFIX: np.int was deprecated in NumPy 1.20 and removed in 1.24;
        # the builtin int is the documented replacement.
        rand_batchs = np.floor(rand_samp_inds / args.batch_size).astype(int)
        # Confusion matrix accumulated over the whole validation split.
        cm = torch.zeros((NUM_CLASSES, NUM_CLASSES)).cuda()
        for val_batch_num, (trg_eval_img, trg_eval_lbl, _,
                            _) in enumerate(target_eval_loader):
            with torch.no_grad():
                trg_input_batch = Variable(trg_eval_img,
                                           requires_grad=False).cuda()
                trg_label_batch = Variable(trg_eval_lbl,
                                           requires_grad=False).cuda()
                pred_softs_batch = semseg_net(trg_input_batch)
                pred_batch = torch.argmax(pred_softs_batch, dim=1)
                cm += compute_cm_batch_torch(pred_batch, trg_label_batch,
                                             IGNORE_LABEL, NUM_CLASSES)
                print('Validation: saw', val_batch_num * args.batch_size,
                      'examples')
                if (val_batch_num + 1) in rand_batchs:
                    rand_offset = np.random.randint(0, args.batch_size)
                    saver.SaveValidationImages(
                        epoch, trg_input_batch[rand_offset, :, :, :],
                        trg_label_batch[rand_offset, :, :],
                        pred_batch[rand_offset, :, :])
        iou, miou = compute_iou_torch(cm)
        saver.SaveEpochAccuracy(iou, miou, epoch)
        logger.info(
            'Average accuracy of Epoch #%d on target domain: mIoU = %2f' %
            (epoch + 1, miou))
        logger.info(
            '-----------------------------------Epoch #%d Finished-----------------------------------'
            % (epoch + 1))
        # Free the large validation tensors before the next training epoch.
        del cm, trg_input_batch, trg_label_batch, pred_softs_batch, pred_batch

    saver.tb.close()
    logger.info('Finished training.')