Example #1
def train():
    loss_func = nn.MSELoss()
    net = MLP()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.005)
    train_set = TrainDatasetFromFolder('./data/train/',
                                       crop_size=200,
                                       upscale_factor=1)  # load the training set
    train_loader = DataLoader(dataset=train_set,
                              num_workers=6,
                              batch_size=1,
                              shuffle=True)  # build the training DataLoader
    for epoch in range(70):
        # rebuild the optimizer each epoch to apply an exponential LR decay
        optimizer = torch.optim.Adam(net.parameters(),
                                     lr=0.005 * (0.98**epoch))
        for target, ir, img in train_loader:
            # resize the target to the input's spatial size; note that
            # cv2.resize expects the size as (width, height)
            target = cv2.resize(target.squeeze().detach().numpy(),
                                (img.shape[3], img.shape[2]))
            target = torch.tensor(target, dtype=torch.float32).reshape(
                (1, 1, img.shape[2], img.shape[3]))
            predict = net(img)
            optimizer.zero_grad()
            loss = loss_func(predict, target)
            loss.backward()
            optimizer.step()
    torch.save(net.state_dict(), './net.pkl')
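
Recreating the optimizer every epoch implements an exponential learning-rate decay by hand; a minimal sketch of the same schedule using torch.optim.lr_scheduler.ExponentialLR instead:

import torch

optimizer = torch.optim.Adam(net.parameters(), lr=0.005)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.98)

for epoch in range(70):
    for target, ir, img in train_loader:
        ...  # forward / backward / optimizer.step() as above
    scheduler.step()  # lr for the next epoch: 0.005 * 0.98 ** (epoch + 1)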
Example #2
    valid = diff[:, :, shave:-shave, shave:-shave]
    mse = valid.pow(2).mean()

    return -10 * math.log10(mse)


if __name__ == "__main__":
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    CROP_SIZE = cfg.crop_size
    UPSCALE_FACTOR = cfg.upscale_factor
    NUM_EPOCHS = cfg.epochs

    train_set = TrainDatasetFromFolder(
        "data/training_hr_images",
        crop_size=CROP_SIZE,
        upscale_factor=UPSCALE_FACTOR,
    )
    val_set = ValDatasetFromFolder(
        "data/validation", upscale_factor=UPSCALE_FACTOR
    )
    train_loader = DataLoader(
        dataset=train_set,
        num_workers=cfg.workers,
        batch_size=cfg.batch_size,
        shuffle=True,
    )
    val_loader = DataLoader(
        dataset=val_set, num_workers=cfg.workers, batch_size=1, shuffle=False
    )
Example #3
                    choices=[2, 4, 8],
                    help='super resolution upscale factor')
parser.add_argument('--num_epochs',
                    default=200,
                    type=int,
                    help='train epoch number')

if __name__ == '__main__':
    opt = parser.parse_args()

    CROP_SIZE = opt.crop_size
    UPSCALE_FACTOR = opt.upscale_factor
    NUM_EPOCHS = opt.num_epochs

    train_set = TrainDatasetFromFolder(
        '/home/alex/SRGAN-master/data/VOC2012/train',
        crop_size=CROP_SIZE,
        upscale_factor=UPSCALE_FACTOR)
    val_set = ValDatasetFromFolder('/home/alex/SRGAN-master/data/VOC2012/val',
                                   upscale_factor=UPSCALE_FACTOR)
    train_loader = DataLoader(dataset=train_set,
                              num_workers=4,
                              batch_size=64,
                              shuffle=True)
    val_loader = DataLoader(dataset=val_set,
                            num_workers=4,
                            batch_size=1,
                            shuffle=False)

    netG = Generator(UPSCALE_FACTOR)
    print('# generator parameters:',
          sum(param.numel() for param in netG.parameters()))
Example #4
parser = argparse.ArgumentParser(description='Train Super Resolution Models')
parser.add_argument('--crop_size', default=256, type=int, help='training images crop size')
parser.add_argument('--upscale_factor', default=8, type=int, choices=[2, 4, 8],
                    help='super resolution upscale factor')
parser.add_argument('--num_epochs', default=100, type=int, help='train epoch number')


if __name__ == '__main__':
    opt = parser.parse_args()

    CROP_SIZE = opt.crop_size
    UPSCALE_FACTOR = opt.upscale_factor
    NUM_EPOCHS = opt.num_epochs

    train_set = TrainDatasetFromFolder('data/DIV2K_train_HR', crop_size=CROP_SIZE, upscale_factor=UPSCALE_FACTOR)
    val_set = ValDatasetFromFolder('data/DIV2K_valid_HR', crop_size=CROP_SIZE*2, upscale_factor=UPSCALE_FACTOR)
    train_loader = DataLoader(dataset=train_set, num_workers=4, batch_size=4, shuffle=True)
    val_loader = DataLoader(dataset=val_set, num_workers=4, batch_size=1, shuffle=False)

    netG = Generator(UPSCALE_FACTOR)
    print('# generator parameters:', sum(param.numel() for param in netG.parameters()))
    netD = Discriminator()
    print('# discriminator parameters:', sum(param.numel() for param in netD.parameters()))


    generator_criterion = GeneratorLoss()
    refinement_criterion = torch.nn.L1Loss()

    if torch.cuda.is_available():
        netG.cuda()
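
Every snippet in this collection builds on a TrainDatasetFromFolder; a minimal sketch of what such a dataset commonly looks like in SRGAN-style repositories (random HR crops plus bicubic downscaling), offered as an assumption rather than the exact class used above:

from os import listdir
from os.path import join

from PIL import Image
from torch.utils.data import Dataset
from torchvision.transforms import Compose, RandomCrop, Resize, ToPILImage, ToTensor

def is_image_file(filename):
    return filename.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp'))

class TrainDatasetFromFolder(Dataset):
    # builds (LR, HR) pairs: a random HR crop and its bicubic downscale
    def __init__(self, dataset_dir, crop_size, upscale_factor):
        super().__init__()
        self.image_filenames = [join(dataset_dir, x)
                                for x in listdir(dataset_dir) if is_image_file(x)]
        crop_size = crop_size - (crop_size % upscale_factor)  # keep divisible
        self.hr_transform = Compose([RandomCrop(crop_size), ToTensor()])
        self.lr_transform = Compose([
            ToPILImage(),
            Resize(crop_size // upscale_factor, interpolation=Image.BICUBIC),
            ToTensor(),
        ])

    def __getitem__(self, index):
        hr_image = self.hr_transform(Image.open(self.image_filenames[index]))
        return self.lr_transform(hr_image), hr_image

    def __len__(self):
        return len(self.image_filenames)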
Example #5
normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
unnormalize = Normalize(mean=[-2.118, -2.036, -1.804], std=[4.367, 4.464, 4.444])

if __name__ == "__main__":
    opt = parser.parse_args()

    CROP_SIZE = opt.crop_size
    UPSCALE_FACTOR = opt.upscale_factor
    NUM_EPOCHS = opt.num_epochs
    BATCH_SIZE = opt.batch_size
    INTERPOLATION = opt.interpolation

    train_set = TrainDatasetFromFolder(
        "../testing/small/testing",
        crop_size=CROP_SIZE,
        upscale_factor=UPSCALE_FACTOR,
        interpolation=INTERPOLATION,
    )
    val_set = ValDatasetFromFolder(
        "../testing/nowe", upscale_factor=UPSCALE_FACTOR, interpolation=INTERPOLATION
    )
    train_loader = DataLoader(
        dataset=train_set, num_workers=4, batch_size=BATCH_SIZE, shuffle=True
    )
    val_loader = DataLoader(dataset=val_set, num_workers=4, batch_size=1, shuffle=False)

    netG = Generator(16, UPSCALE_FACTOR)
    print("# generator parameters:", sum(param.numel() for param in netG.parameters()))
    netD = Discriminator()
    print("# discriminator parameters:", sum(param.numel() for param in netD.parameters()))
Example #6
                    type=int,
                    help='test batchSize')
parser.add_argument('--lr', default=1e-4, type=float, help='learning rate')
parser.add_argument('--start', default=1, type=int, help='start num')

if __name__ == '__main__':
    opt = parser.parse_args()

    CROP_SIZE = opt.crop_size
    UPSCALE_FACTOR = opt.upscale_factor
    NUM_EPOCHS = opt.num_epochs

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    train_set = TrainDatasetFromFolder('../../../CelebA-HQ-img/',
                                       crop_size=CROP_SIZE,
                                       upscale_factor=UPSCALE_FACTOR)
    val_set = ValDatasetFromFolder('../../../CelebA-HQ-img/',
                                   crop_size=CROP_SIZE,
                                   upscale_factor=UPSCALE_FACTOR)
    train_loader = DataLoader(dataset=train_set,
                              num_workers=4,
                              batch_size=opt.batchSize,
                              shuffle=True)
    val_loader = DataLoader(dataset=val_set,
                            num_workers=4,
                            batch_size=opt.testBatchSize,
                            shuffle=False)

    netG = Generator(UPSCALE_FACTOR).to(device)
    netD = Discriminator().to(device)
Example #7
                    help='noise type')

lr = 0.0002
beta1 = 0.5
best_model = None
best_gen = None
if __name__ == '__main__':
    opt = parser.parse_args()

    CROP_SIZE = opt.crop_size
    UPSCALE_FACTOR = opt.upscale_factor
    NUM_EPOCHS = opt.num_epochs
    DATASET = opt.data_set
    CASE = opt.case
    train_set = TrainDatasetFromFolder('data/train/' + DATASET + '/train_HR',
                                       crop_size=CROP_SIZE,
                                       upscale_factor=UPSCALE_FACTOR)
    val_set = ValDatasetFromFolder('data/train/' + DATASET + '/valid_HR',
                                   upscale_factor=UPSCALE_FACTOR)
    train_loader = DataLoader(dataset=train_set,
                              num_workers=4,
                              batch_size=64,
                              shuffle=True)
    val_loader = DataLoader(dataset=val_set,
                            num_workers=4,
                            batch_size=1,
                            shuffle=False)

    netG = Generator(UPSCALE_FACTOR)
    netG.apply(weights_init)
    print('# generator parameters:',
Example #8
                    choices=[2, 4, 8],
                    help='super resolution upscale factor')
parser.add_argument('--num_epochs',
                    default=100,
                    type=int,
                    help='train epoch number')

if __name__ == '__main__':
    opt = parser.parse_args()

    CROP_SIZE = opt.crop_size
    UPSCALE_FACTOR = opt.upscale_factor
    NUM_EPOCHS = opt.num_epochs

    train_set = TrainDatasetFromFolder(
        '../../../../split/ICDAR2015-TextSR-dataset/RELEASE_2015-08-31/DATA/TRAIN',
        crop_size=CROP_SIZE,
        upscale_factor=UPSCALE_FACTOR)
    val_set = ValDatasetFromFolder(
        '../../../../split/ICDAR2015-TextSR-dataset/RELEASE_2015-08-31/DATA/VAL',
        upscale_factor=UPSCALE_FACTOR)
    train_loader = DataLoader(dataset=train_set,
                              num_workers=4,
                              batch_size=1,
                              shuffle=True)
    val_loader = DataLoader(dataset=val_set,
                            num_workers=4,
                            batch_size=1,
                            shuffle=False)

    netG = Generator(UPSCALE_FACTOR)
    print('# generator parameters:',
Example #9
                    type=int,
                    help='train epoch number')

opt = parser.parse_args()

CROP_SIZE = opt.crop_size
UPSCALE_FACTOR = opt.upscale_factor
NUM_EPOCHS = opt.num_epochs
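
# `init_aux` is not defined in this snippet; a minimal sketch, assuming it
# only prepares an output directory (signature inferred from the calls below):
import os

def init_aux(base_dir, sub_dir):
    os.makedirs(os.path.join(base_dir, sub_dir), exist_ok=True)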

init_aux('', 'epochs/')
init_aux('results', str(UPSCALE_FACTOR) + 'x')
init_aux('results', 'statistics/')
# import pdb; pdb.set_trace()

train_set = TrainDatasetFromFolder('data/' + str(UPSCALE_FACTOR) + '/train/',
                                   crop_size=CROP_SIZE,
                                   upscale_factor=UPSCALE_FACTOR)
val_set = ValDatasetFromFolder('data/' + str(UPSCALE_FACTOR) + '/val/',
                               upscale_factor=UPSCALE_FACTOR)

train_loader = DataLoader(dataset=train_set,
                          num_workers=4,
                          batch_size=16,
                          shuffle=True)
val_loader = DataLoader(dataset=val_set,
                        num_workers=4,
                        batch_size=1,
                        shuffle=False)

netG = Generator(UPSCALE_FACTOR)
print('# generator parameters:',
Example #10
if __name__ == '__main__':
    #hyperparameters
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_epochs', type=int, default=100)
    parser.add_argument('--crop_size', type=int, default=88)
    parser.add_argument('--upscale_factor',
                        type=int,
                        default=4,
                        choices=[2, 4, 8])
    parser.add_argument('--lr', type=float, default=0.01)
    config = parser.parse_args()
    print(config)

    # Data Loader
    train_set = TrainDatasetFromFolder('../../Data/VOC2012/train',
                                       crop_size=config.crop_size,
                                       upscale_factor=config.upscale_factor)
    val_set = ValDatasetFromFolder('../../Data/VOC2012/val',
                                   upscale_factor=config.upscale_factor)
    train_loader = DataLoader(dataset=train_set,
                              num_workers=1,
                              batch_size=16,
                              shuffle=True)
    val_loader = DataLoader(dataset=val_set,
                            num_workers=1,
                            batch_size=1,
                            shuffle=False)

    G = Generator(config.upscale_factor)
    D = Discriminator()
Example #11
                    choices=[2, 4, 8],
                    help='super resolution upscale factor')
parser.add_argument('--num_epochs',
                    default=100,
                    type=int,
                    help='train epoch number')

if __name__ == '__main__':
    opt = parser.parse_args()

    CROP_SIZE = opt.crop_size
    UPSCALE_FACTOR = opt.upscale_factor
    NUM_EPOCHS = opt.num_epochs

    train_set = TrainDatasetFromFolder(
        '/home/ufuk/ufuk/MLSP_Project/data/train',
        crop_size=CROP_SIZE,
        upscale_factor=UPSCALE_FACTOR)
    val_set = ValDatasetFromFolder('/home/ufuk/ufuk/MLSP_Project/data/valid',
                                   upscale_factor=UPSCALE_FACTOR)
    train_loader = DataLoader(dataset=train_set,
                              num_workers=4,
                              batch_size=64,
                              shuffle=True)
    val_loader = DataLoader(dataset=val_set,
                            num_workers=4,
                            batch_size=1,
                            shuffle=False)

    netG = Generator(UPSCALE_FACTOR)
    print('# generator parameters:',
          sum(param.numel() for param in netG.parameters()))
Example #12
                    type=int,
                    help='train epoch number')
parser.add_argument('--dataset_name', default='DIV2K', type=str)

path = "drive/My Drive/Aerocosmos/1/model_start_4_asarticle/"

if __name__ == '__main__':
    opt = parser.parse_args()

    CROP_SIZE = opt.crop_size
    UPSCALE_FACTOR = opt.upscale_factor
    NUM_EPOCHS = opt.num_epochs
    DATASET_NAME = opt.dataset_name

    train_set = TrainDatasetFromFolder('drive/My Drive/Aerocosmos/data/' +
                                       DATASET_NAME + '/train',
                                       crop_size=CROP_SIZE,
                                       upscale_factor=UPSCALE_FACTOR)
    val_set = ValDatasetFromFolder('drive/My Drive/Aerocosmos/data/' +
                                   DATASET_NAME + '/valid',
                                   upscale_factor=UPSCALE_FACTOR)
    train_loader = DataLoader(dataset=train_set,
                              num_workers=4,
                              batch_size=64,
                              shuffle=True)
    val_loader = DataLoader(dataset=val_set,
                            num_workers=4,
                            batch_size=1,
                            shuffle=False)

    netG = Generator(UPSCALE_FACTOR)
    print('# generator parameters:',
Example #13
File: train.py  Project: gakarak/SRGAN
def main_train(path_trn: str, path_val: str,
               crop_size: int, upscale_factor: int, num_epochs: int,
               num_workers: int, to_device: str = 'cuda:0', batch_size: int = 64):
    to_device = get_device(to_device)
    train_set = TrainDatasetFromFolder(path_trn, crop_size=crop_size, upscale_factor=upscale_factor)
    val_set = ValDatasetFromFolder(path_val, upscale_factor=upscale_factor)
    # train_set = TrainDatasetFromFolder('data/VOC2012/train', crop_size=crop_size, upscale_factor=upscale_factor)
    # val_set = ValDatasetFromFolder('data/VOC2012/val', upscale_factor=upscale_factor)
    #
    train_loader = DataLoader(dataset=train_set, num_workers=num_workers, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(dataset=val_set, num_workers=num_workers, batch_size=1, shuffle=False)

    netG = Generator(upscale_factor)
    print('# generator parameters:', sum(param.numel() for param in netG.parameters()))
    netD = Discriminator()
    print('# discriminator parameters:', sum(param.numel() for param in netD.parameters()))

    generator_criterion = GeneratorLoss()

    if torch.cuda.is_available():
        netG.cuda()
        netD.cuda()
        generator_criterion.cuda()

    optimizerG = optim.Adam(netG.parameters())
    optimizerD = optim.Adam(netD.parameters())

    results = {'d_loss': [], 'g_loss': [], 'd_score': [], 'g_score': [], 'psnr': [], 'ssim': []}

    for epoch in range(1, num_epochs + 1):
        train_bar = tqdm(train_loader)
        running_results = {'batch_sizes': 0, 'd_loss': 0, 'g_loss': 0, 'd_score': 0, 'g_score': 0}

        netG.train()
        netD.train()
        # FIXME: separate function for epoch training
        for data, target in train_bar:
            g_update_first = True
            batch_size = data.size(0)
            #
            # img_hr = target.numpy().transpose((0, 2, 3, 1))[0]
            # img_lr = data.numpy().transpose((0, 2, 3, 1))[0]
            # img_lr_x4 = cv2.resize(img_lr, img_hr.shape[:2], interpolation=cv2.INTER_CUBIC)
            # #
            # plt.subplot(1, 3, 1)
            # plt.imshow(img_hr)
            # plt.subplot(1, 3, 2)
            # plt.imshow(img_lr)
            # plt.subplot(1, 3, 3)
            # plt.imshow(img_lr_x4)
            # plt.show()
            running_results['batch_sizes'] += batch_size

            ############################
            # (1) Update D network: maximize D(x)-1-D(G(z))
            ###########################
            # real_img = Variable(target)
            # if torch.cuda.is_available():
            #     real_img = real_img.cuda()
            # z = Variable(data)
            # if torch.cuda.is_available():
            #     z = z.cuda()
            z = data.to(to_device)
            real_img = target.to(to_device)
            fake_img = netG(z)

            netD.zero_grad()
            real_out = netD(real_img).mean()
            fake_out = netD(fake_img).mean()
            d_loss = 1 - real_out + fake_out
            d_loss.backward(retain_graph=True)
            optimizerD.step()

            ############################
            # (2) Update G network: minimize 1-D(G(z)) + Perception Loss + Image Loss + TV Loss
            ###########################
            netG.zero_grad()
            g_loss = generator_criterion(fake_out, fake_img, real_img)
            g_loss.backward()
            optimizerG.step()
            fake_img = netG(z)
            fake_out = netD(fake_img).mean()

            g_loss = generator_criterion(fake_out, fake_img, real_img)
            running_results['g_loss'] += float(g_loss) * batch_size
            d_loss = 1 - real_out + fake_out
            running_results['d_loss'] += float(d_loss) * batch_size
            running_results['d_score'] += float(real_out) * batch_size
            running_results['g_score'] += float(fake_out) * batch_size

            train_bar.set_description(desc='[%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f' % (
                epoch, num_epochs, running_results['d_loss'] / running_results['batch_sizes'],
                running_results['g_loss'] / running_results['batch_sizes'],
                running_results['d_score'] / running_results['batch_sizes'],
                running_results['g_score'] / running_results['batch_sizes']))

        netG.eval()
        # FIXME: separate function for epoch validation
        with torch.no_grad():
            out_path = 'training_results/SRF_' + str(upscale_factor) + '/'
            if not os.path.exists(out_path):
                os.makedirs(out_path)
            val_bar = tqdm(val_loader)
            valing_results = {'mse': 0, 'ssims': 0, 'psnr': 0, 'ssim': 0, 'batch_sizes': 0}
            val_images = []
            for val_lr, val_hr_restore, val_hr in val_bar:
                batch_size = val_lr.size(0)
                valing_results['batch_sizes'] += batch_size
                # lr = Variable(val_lr, volatile=True)
                # hr = Variable(val_hr, volatile=True)
                # if torch.cuda.is_available():
                #     lr = lr.cuda()
                #     hr = hr.cuda()
                lr = val_lr.to(to_device)
                hr = val_hr.to(to_device)
                sr = netG(lr)

                batch_mse = ((sr - hr) ** 2).mean()
                valing_results['mse'] += float(batch_mse) * batch_size
                batch_ssim = float(pytorch_ssim.ssim(sr, hr)) #.data[0]
                valing_results['ssims'] += batch_ssim * batch_size
                valing_results['psnr'] = 10 * log10(1 / (valing_results['mse'] / valing_results['batch_sizes']))
                valing_results['ssim'] = valing_results['ssims'] / valing_results['batch_sizes']
                val_bar.set_description(
                    desc='[converting LR images to SR images] PSNR: %.4f dB SSIM: %.4f' % (
                        valing_results['psnr'], valing_results['ssim']))

                val_images.extend(
                    [display_transform()(val_hr_restore.squeeze(0)), display_transform()(hr.data.cpu().squeeze(0)),
                     display_transform()(sr.data.cpu().squeeze(0))])
            val_images = torch.stack(val_images)
            val_images = torch.chunk(val_images, val_images.size(0) // 15)
            val_save_bar = tqdm(val_images, desc='[saving training results]')
            index = 1
            for image in val_save_bar:
                image = utils.make_grid(image, nrow=3, padding=5)
                utils.save_image(image, out_path + 'epoch_%d_index_%d.png' % (epoch, index), padding=5)
                index += 1

        # save model parameters
        torch.save(netG.state_dict(), 'epochs/netG_epoch_%d_%d.pth' % (upscale_factor, epoch))
        torch.save(netD.state_dict(), 'epochs/netD_epoch_%d_%d.pth' % (upscale_factor, epoch))
        # save loss / scores / psnr / ssim
        results['d_loss'].append(running_results['d_loss'] / running_results['batch_sizes'])
        results['g_loss'].append(running_results['g_loss'] / running_results['batch_sizes'])
        results['d_score'].append(running_results['d_score'] / running_results['batch_sizes'])
        results['g_score'].append(running_results['g_score'] / running_results['batch_sizes'])
        results['psnr'].append(valing_results['psnr'])
        results['ssim'].append(valing_results['ssim'])

        if epoch % 10 == 0 and epoch != 0:
            out_path = 'statistics/'
            data_frame = pd.DataFrame(
                data={'Loss_D': results['d_loss'], 'Loss_G': results['g_loss'], 'Score_D': results['d_score'],
                      'Score_G': results['g_score'], 'PSNR': results['psnr'], 'SSIM': results['ssim']},
                index=range(1, epoch + 1))
            data_frame.to_csv(out_path + 'srf_' + str(upscale_factor) + '_train_results.csv', index_label='Epoch')
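
The GeneratorLoss used above is not shown in the snippet; a minimal sketch of the usual SRGAN-PyTorch formulation it refers to (image MSE + adversarial + VGG perception + total-variation terms). The VGG layer cut-off and the loss weights are assumptions taken from common implementations, not verified against this project:

import torch
from torch import nn
from torchvision.models import vgg16

class GeneratorLoss(nn.Module):
    # image loss + adversarial loss + VGG-feature perception loss + TV term
    def __init__(self):
        super().__init__()
        vgg = vgg16(pretrained=True)
        self.features = nn.Sequential(*list(vgg.features)[:31]).eval()
        for p in self.features.parameters():
            p.requires_grad = False
        self.mse = nn.MSELoss()

    def forward(self, fake_out, fake_img, real_img):
        adversarial_loss = torch.mean(1 - fake_out)
        perception_loss = self.mse(self.features(fake_img),
                                   self.features(real_img))
        image_loss = self.mse(fake_img, real_img)
        # anisotropic total variation, averaged over elements
        tv_loss = (fake_img[:, :, 1:, :] - fake_img[:, :, :-1, :]).pow(2).mean() \
                + (fake_img[:, :, :, 1:] - fake_img[:, :, :, :-1]).pow(2).mean()
        return image_loss + 0.001 * adversarial_loss \
            + 0.006 * perception_loss + 2e-8 * tv_loss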
Example #14
                    type=int,
                    choices=[1, 2, 4],
                    help='super resolution upscale factor')
parser.add_argument('--num_epochs',
                    default=4000,
                    type=int,
                    help='train epoch number')

opt = parser.parse_args()

CROP_SIZE = opt.crop_size  # crop size (could cropping introduce stitching artifacts?)
UPSCALE_FACTOR = opt.upscale_factor  # upscale factor
NUM_EPOCHS = opt.num_epochs  # number of epochs

train_set = TrainDatasetFromFolder('/data/lpw/FusionDataset/train/',
                                   crop_size=CROP_SIZE,
                                   upscale_factor=UPSCALE_FACTOR)  # load the training set
val_set = ValDatasetFromFolder('/data/lpw/FusionDataset/val/',
                               upscale_factor=UPSCALE_FACTOR)  # load the validation set
train_loader = DataLoader(dataset=train_set,
                          num_workers=4,
                          batch_size=64,
                          shuffle=True)  # build the training DataLoader
val_loader = DataLoader(dataset=val_set,
                        num_workers=4,
                        batch_size=1,
                        shuffle=False)

netG = Generator(UPSCALE_FACTOR)  # network model
print('# generator parameters:',
      sum(param.numel() for param in netG.parameters()))
Example #15
                    help='load epoch number')
parser.add_argument('--generatorWeights',
                    type=str,
                    default='',
                    help="path to CSNet weights (to continue training)")

opt = parser.parse_args()

CROP_SIZE = opt.crop_size
BLOCK_SIZE = opt.block_size
NUM_EPOCHS = opt.num_epochs
PRE_EPOCHS = opt.pre_epochs
LOAD_EPOCH = 0

train_set = TrainDatasetFromFolder('/media/gdh-95/data/Train',
                                   crop_size=CROP_SIZE,
                                   blocksize=BLOCK_SIZE)
train_loader = DataLoader(dataset=train_set,
                          num_workers=4,
                          batch_size=opt.batchSize,
                          shuffle=True)

net = CSNet(BLOCK_SIZE, opt.sub_rate)

mse_loss = nn.MSELoss()

if opt.generatorWeights != '':
    net.load_state_dict(torch.load(opt.generatorWeights))
    LOAD_EPOCH = opt.loadEpoch

if torch.cuda.is_available():
Example #16
def main():
    parser = ArgumentParser()
    parser.add_argument("--augmentation", action='store_true')
    parser.add_argument("--train-dataset-percentage", type=float, default=100)
    parser.add_argument("--val-dataset-percentage", type=int, default=100)
    parser.add_argument("--label-smoothing", type=float, default=0.9)
    parser.add_argument("--validation-frequency", type=int, default=1)
    args = parser.parse_args()

    ENABLE_AUGMENTATION = args.augmentation
    TRAIN_DATASET_PERCENTAGE = args.train_dataset_percentage
    VAL_DATASET_PERCENTAGE = args.val_dataset_percentage
    LABEL_SMOOTHING_FACTOR = args.label_smoothing
    VALIDATION_FREQUENCY = args.validation_frequency

    if ENABLE_AUGMENTATION:
        augment_batch = AugmentPipe()
        augment_batch.to(device)
    else:
        augment_batch = lambda x: x
        augment_batch.p = 0

    NUM_ADV_EPOCHS = round(NUM_ADV_BASELINE_EPOCHS /
                           (TRAIN_DATASET_PERCENTAGE / 100))
    NUM_PRETRAIN_EPOCHS = round(NUM_BASELINE_PRETRAIN_EPOCHS /
                                (TRAIN_DATASET_PERCENTAGE / 100))
    VALIDATION_FREQUENCY = round(VALIDATION_FREQUENCY /
                                 (TRAIN_DATASET_PERCENTAGE / 100))

    training_start = datetime.datetime.now().isoformat()

    train_set = TrainDatasetFromFolder(train_dataset_dir,
                                       patch_size=PATCH_SIZE,
                                       upscale_factor=UPSCALE_FACTOR)
    len_train_set = len(train_set)
    train_set = Subset(
        train_set,
        list(
            np.random.choice(
                np.arange(len_train_set),
                int(len_train_set * TRAIN_DATASET_PERCENTAGE / 100), False)))

    val_set = ValDatasetFromFolder(val_dataset_dir,
                                   upscale_factor=UPSCALE_FACTOR)
    len_val_set = len(val_set)
    val_set = Subset(
        val_set,
        list(
            np.random.choice(np.arange(len_val_set),
                             int(len_val_set * VAL_DATASET_PERCENTAGE / 100),
                             False)))

    train_loader = DataLoader(dataset=train_set,
                              num_workers=8,
                              batch_size=BATCH_SIZE,
                              shuffle=True,
                              pin_memory=True,
                              prefetch_factor=8)
    val_loader = DataLoader(dataset=val_set,
                            num_workers=2,
                            batch_size=VAL_BATCH_SIZE,
                            shuffle=False,
                            pin_memory=True,
                            prefetch_factor=2)

    epoch_validation_hr_dataset = HrValDatasetFromFolder(
        val_dataset_dir)  # Useful to compute FID metric

    results_folder = Path(
        f"results_{training_start}_CS:{PATCH_SIZE}_US:{UPSCALE_FACTOR}x_TRAIN:{TRAIN_DATASET_PERCENTAGE}%_AUGMENTATION:{ENABLE_AUGMENTATION}"
    )
    results_folder.mkdir(exist_ok=True)
    writer = SummaryWriter(str(results_folder / "tensorboard_log"))
    g_net = Generator(n_residual_blocks=NUM_RESIDUAL_BLOCKS,
                      upsample_factor=UPSCALE_FACTOR)
    d_net = Discriminator(patch_size=PATCH_SIZE)
    lpips_metric = lpips.LPIPS(net='alex')

    g_net.to(device=device)
    d_net.to(device=device)
    lpips_metric.to(device=device)

    g_optimizer = optim.Adam(g_net.parameters(), lr=1e-4)
    d_optimizer = optim.Adam(d_net.parameters(), lr=1e-4)

    bce_loss = BCELoss()
    mse_loss = MSELoss()

    bce_loss.to(device=device)
    mse_loss.to(device=device)
    results = {
        'd_total_loss': [],
        'g_total_loss': [],
        'g_adv_loss': [],
        'g_content_loss': [],
        'd_real_mean': [],
        'd_fake_mean': [],
        'psnr': [],
        'ssim': [],
        'lpips': [],
        'fid': [],
        'rt': [],
        'augment_probability': []
    }

    augment_probability = 0
    num_images = len(train_set) * (NUM_PRETRAIN_EPOCHS + NUM_ADV_EPOCHS)
    prediction_list = []
    rt = 0

    for epoch in range(1, NUM_PRETRAIN_EPOCHS + NUM_ADV_EPOCHS + 1):
        train_bar = tqdm(train_loader, ncols=200)
        running_results = {
            'batch_sizes': 0,
            'd_epoch_total_loss': 0,
            'g_epoch_total_loss': 0,
            'g_epoch_adv_loss': 0,
            'g_epoch_content_loss': 0,
            'd_epoch_real_mean': 0,
            'd_epoch_fake_mean': 0,
            'rt': 0,
            'augment_probability': 0
        }
        image_percentage = epoch / (NUM_PRETRAIN_EPOCHS + NUM_ADV_EPOCHS) * 100
        g_net.train()
        d_net.train()

        for data, target in train_bar:
            augment_batch.p = torch.tensor([augment_probability],
                                           device=device)
            batch_size = data.size(0)
            running_results["batch_sizes"] += batch_size
            target = target.to(device)
            data = data.to(device)
            real_labels = torch.ones(batch_size, device=device)
            fake_labels = torch.zeros(batch_size, device=device)

            if epoch > NUM_PRETRAIN_EPOCHS:
                # Discriminator training
                d_optimizer.zero_grad(set_to_none=True)

                d_real_output = d_net(augment_batch(target))
                d_real_output_loss = bce_loss(
                    d_real_output, real_labels * LABEL_SMOOTHING_FACTOR)

                fake_img = g_net(data)
                d_fake_output = d_net(augment_batch(fake_img))
                d_fake_output_loss = bce_loss(d_fake_output, fake_labels)

                d_total_loss = d_real_output_loss + d_fake_output_loss
                d_total_loss.backward()
                d_optimizer.step()

                d_real_mean = d_real_output.mean()
                d_fake_mean = d_fake_output.mean()

            # Generator training
            g_optimizer.zero_grad(set_to_none=True)

            fake_img = g_net(data)
            if epoch > NUM_PRETRAIN_EPOCHS:
                adversarial_loss = bce_loss(d_net(augment_batch(fake_img)),
                                            real_labels) * ADV_LOSS_BALANCER
                content_loss = mse_loss(fake_img, target)
                g_total_loss = content_loss + adversarial_loss
            else:
                adversarial_loss = mse_loss(torch.zeros(
                    1, device=device), torch.zeros(
                        1,
                        device=device))  # Logging purposes, it is always zero
                content_loss = mse_loss(fake_img, target)
                g_total_loss = content_loss

            g_total_loss.backward()
            g_optimizer.step()

            if epoch > NUM_PRETRAIN_EPOCHS and ENABLE_AUGMENTATION:
                prediction_list.append(
                    (torch.sign(d_real_output - 0.5)).tolist())
                if len(prediction_list) == RT_BATCH_SMOOTHING_FACTOR:
                    rt_list = [
                        prediction for sublist in prediction_list
                        for prediction in sublist
                    ]
                    rt = mean(rt_list)
                    if mean(rt_list) > AUGMENT_PROB_TARGET:
                        augment_probability = min(
                            0.85,
                            augment_probability + AUGMENT_PROBABABILITY_STEP)
                    else:
                        augment_probability = max(
                            0.,
                            augment_probability - AUGMENT_PROBABABILITY_STEP)
                    prediction_list.clear()

            running_results['g_epoch_total_loss'] += g_total_loss.to(
                'cpu', non_blocking=True).detach() * batch_size
            running_results['g_epoch_adv_loss'] += adversarial_loss.to(
                'cpu', non_blocking=True).detach() * batch_size
            running_results['g_epoch_content_loss'] += content_loss.to(
                'cpu', non_blocking=True).detach() * batch_size
            if epoch > NUM_PRETRAIN_EPOCHS:
                running_results['d_epoch_total_loss'] += d_total_loss.to(
                    'cpu', non_blocking=True).detach() * batch_size
                running_results['d_epoch_real_mean'] += d_real_mean.to(
                    'cpu', non_blocking=True).detach() * batch_size
                running_results['d_epoch_fake_mean'] += d_fake_mean.to(
                    'cpu', non_blocking=True).detach() * batch_size
                running_results['rt'] += rt * batch_size
                running_results[
                    'augment_probability'] += augment_probability * batch_size

            train_bar.set_description(
                desc=f'[{epoch}/{NUM_ADV_EPOCHS + NUM_PRETRAIN_EPOCHS}] '
                f'Loss_D: {running_results["d_epoch_total_loss"] / running_results["batch_sizes"]:.4f} '
                f'Loss_G: {running_results["g_epoch_total_loss"] / running_results["batch_sizes"]:.4f} '
                f'Loss_G_adv: {running_results["g_epoch_adv_loss"] / running_results["batch_sizes"]:.4f} '
                f'Loss_G_content: {running_results["g_epoch_content_loss"] / running_results["batch_sizes"]:.4f} '
                f'D(x): {running_results["d_epoch_real_mean"] / running_results["batch_sizes"]:.4f} '
                f'D(G(z)): {running_results["d_epoch_fake_mean"] / running_results["batch_sizes"]:.4f} '
                f'rt: {running_results["rt"] / running_results["batch_sizes"]:.4f} '
                f'augment_probability: {running_results["augment_probability"] / running_results["batch_sizes"]:.4f}'
            )

        if epoch == 1 or epoch == (
                NUM_PRETRAIN_EPOCHS + NUM_ADV_EPOCHS
        ) or epoch % VALIDATION_FREQUENCY == 0 or VALIDATION_FREQUENCY == 1:
            torch.cuda.empty_cache()
            gc.collect()
            g_net.eval()
            # ...
            images_path = results_folder / "training_images_results"
            images_path.mkdir(exist_ok=True)

            with torch.no_grad():
                val_bar = tqdm(val_loader, ncols=160)
                val_results = {
                    'epoch_mse': 0,
                    'epoch_ssim': 0,
                    'epoch_psnr': 0,
                    'epoch_avg_psnr': 0,
                    'epoch_avg_ssim': 0,
                    'epoch_lpips': 0,
                    'epoch_avg_lpips': 0,
                    'epoch_fid': 0,
                    'batch_sizes': 0
                }
                val_images = torch.empty((0, 0))
                epoch_validation_sr_dataset = None
                for lr, val_hr_restore, hr in val_bar:
                    batch_size = lr.size(0)
                    val_results['batch_sizes'] += batch_size
                    hr = hr.to(device=device)
                    lr = lr.to(device=device)

                    sr = g_net(lr)
                    sr = torch.clamp(sr, 0., 1.)
                    if epoch_validation_sr_dataset is None:
                        epoch_validation_sr_dataset = SingleTensorDataset(
                            (sr.cpu() * 255).to(torch.uint8))

                    else:
                        epoch_validation_sr_dataset = ConcatDataset(
                            (epoch_validation_sr_dataset,
                             SingleTensorDataset(
                                 (sr.cpu() * 255).to(torch.uint8))))

                    batch_mse = ((sr - hr)**2).data.mean()  # Pixel-wise MSE
                    val_results['epoch_mse'] += batch_mse * batch_size
                    batch_ssim = pytorch_ssim.ssim(sr, hr).item()
                    val_results['epoch_ssim'] += batch_ssim * batch_size
                    val_results['epoch_avg_ssim'] = val_results[
                        'epoch_ssim'] / val_results['batch_sizes']
                    # PSNR = 20 * log10(MAX / RMSE); batch_mse is already a mean
                    val_results['epoch_psnr'] += 20 * log10(
                        hr.max() / batch_mse.sqrt()) * batch_size
                    val_results['epoch_avg_psnr'] = val_results[
                        'epoch_psnr'] / val_results['batch_sizes']
                    val_results['epoch_lpips'] += torch.mean(
                        lpips_metric(hr * 2 - 1, sr * 2 - 1)).to(
                            'cpu', non_blocking=True).detach() * batch_size
                    val_results['epoch_avg_lpips'] = val_results[
                        'epoch_lpips'] / val_results['batch_sizes']

                    val_bar.set_description(
                        desc=
                        f"[converting LR images to SR images] PSNR: {val_results['epoch_avg_psnr']:4f} dB "
                        f"SSIM: {val_results['epoch_avg_ssim']:4f} "
                        f"LPIPS: {val_results['epoch_avg_lpips']:.4f} ")
                    if val_images.size(0) * val_images.size(
                            1) < NUM_LOGGED_VALIDATION_IMAGES * 3:
                        if val_images.size(0) == 0:
                            val_images = torch.hstack(
                                (display_transform(CENTER_CROP_SIZE)
                                 (val_hr_restore).unsqueeze(0).transpose(0, 1),
                                 display_transform(CENTER_CROP_SIZE)(
                                     hr.data.cpu()).unsqueeze(0).transpose(
                                         0, 1),
                                 display_transform(CENTER_CROP_SIZE)(
                                     sr.data.cpu()).unsqueeze(0).transpose(
                                         0, 1)))
                        else:
                            val_images = torch.cat((
                                val_images,
                                torch.hstack(
                                    (display_transform(CENTER_CROP_SIZE)(
                                        val_hr_restore).unsqueeze(0).transpose(
                                            0, 1),
                                     display_transform(CENTER_CROP_SIZE)(
                                         hr.data.cpu()).unsqueeze(0).transpose(
                                             0, 1),
                                     display_transform(CENTER_CROP_SIZE)(
                                         sr.data.cpu()).unsqueeze(0).transpose(
                                             0, 1)))))
                val_results['epoch_fid'] = calculate_metrics(
                    epoch_validation_sr_dataset,
                    epoch_validation_hr_dataset,
                    cuda=True,
                    fid=True,
                    verbose=True
                )['frechet_inception_distance']  # Set batch_size=1 if you get memory error (inside calculate metric function)

                val_images = val_images.view(
                    (NUM_LOGGED_VALIDATION_IMAGES // 4, -1, 3,
                     CENTER_CROP_SIZE, CENTER_CROP_SIZE))
                val_save_bar = tqdm(val_images,
                                    desc='[saving validation results]',
                                    ncols=160)

                for index, image_batch in enumerate(val_save_bar, start=1):
                    image_grid = utils.make_grid(image_batch,
                                                 nrow=3,
                                                 padding=5)
                    writer.add_image(
                        f'progress{image_percentage:.1f}_index_{index}.png',
                        image_grid)

        # save loss / scores / psnr /ssim
        results['d_total_loss'].append(running_results['d_epoch_total_loss'] /
                                       running_results['batch_sizes'])
        results['g_total_loss'].append(running_results['g_epoch_total_loss'] /
                                       running_results['batch_sizes'])
        results['g_adv_loss'].append(running_results['g_epoch_adv_loss'] /
                                     running_results['batch_sizes'])
        results['g_content_loss'].append(
            running_results['g_epoch_content_loss'] /
            running_results['batch_sizes'])
        results['d_real_mean'].append(running_results['d_epoch_real_mean'] /
                                      running_results['batch_sizes'])
        results['d_fake_mean'].append(running_results['d_epoch_fake_mean'] /
                                      running_results['batch_sizes'])
        results['rt'].append(running_results['rt'] /
                             running_results['batch_sizes'])
        results['augment_probability'].append(
            running_results['augment_probability'] /
            running_results['batch_sizes'])
        if epoch == 1 or epoch == (
                NUM_PRETRAIN_EPOCHS + NUM_ADV_EPOCHS
        ) or epoch % VALIDATION_FREQUENCY == 0 or VALIDATION_FREQUENCY == 1:
            results['psnr'].append(val_results['epoch_avg_psnr'])
            results['ssim'].append(val_results['epoch_avg_ssim'])
            results['lpips'].append(val_results['epoch_avg_lpips'])
            results['fid'].append(val_results['epoch_fid'])

        for metric, metric_values in results.items():
            if epoch == 1 or epoch == (
                    NUM_PRETRAIN_EPOCHS + NUM_ADV_EPOCHS) or epoch % VALIDATION_FREQUENCY == 0 or VALIDATION_FREQUENCY == 1 or \
                    metric not in ["psnr", "ssim", "lpips", "fid"]:
                writer.add_scalar(metric, metric_values[-1],
                                  int(image_percentage * num_images * 0.01))

        if epoch == 1 or epoch == (
                NUM_PRETRAIN_EPOCHS + NUM_ADV_EPOCHS
        ) or epoch % VALIDATION_FREQUENCY == 0 or VALIDATION_FREQUENCY == 1:
            # save model parameters
            models_path = results_folder / "saved_models"
            models_path.mkdir(exist_ok=True)
            torch.save(
                {
                    'progress': image_percentage,
                    'g_net': g_net.state_dict(),
                    'd_net': d_net.state_dict(),
                    # 'g_optimizer': g_optimizer.state_dict(), Uncomment this if you want resume training
                    # 'd_optimizer': d_optimizer.state_dict(),
                },
                str(models_path / f'progress_{image_percentage:.1f}.tar'))
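
A matching sketch for loading one of these checkpoints later; the file name below is hypothetical, and the dictionary keys follow the torch.save call above:

import torch

checkpoint = torch.load('saved_models/progress_50.0.tar', map_location=device)
g_net.load_state_dict(checkpoint['g_net'])
d_net.load_state_dict(checkpoint['d_net'])
print(f"resuming from progress {checkpoint['progress']:.1f}%")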
Example #17
                    choices=[2, 4, 8],
                    help='super resolution upscale factor')
parser.add_argument('--num_epochs',
                    default=100,
                    type=int,
                    help='train epoch number')

if __name__ == '__main__':
    opt = parser.parse_args()

    CROP_SIZE = opt.crop_size
    UPSCALE_FACTOR = opt.upscale_factor
    NUM_EPOCHS = opt.num_epochs

    train_set = TrainDatasetFromFolder(
        'data/train', crop_size=CROP_SIZE,
        upscale_factor=UPSCALE_FACTOR)  # can be accessed directly via `train_set[xx]`
    val_set = ValDatasetFromFolder(
        'data/val', upscale_factor=UPSCALE_FACTOR)  # can be accessed directly via `val_set[xx]`
    train_loader = DataLoader(dataset=train_set,
                              num_workers=4,
                              batch_size=64,
                              shuffle=True)  # Iterable
    val_loader = DataLoader(dataset=val_set,
                            num_workers=4,
                            batch_size=1,
                            shuffle=False)  # Iterable

    netG = Generator(UPSCALE_FACTOR)
    print('# generator parameters:',
          sum(param.numel() for param in netG.parameters()))
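
As the comments note, the datasets support direct indexing while the loaders are plain iterables; a short usage sketch, assuming the usual (LR, HR) pair convention:

lr_img, hr_img = train_set[0]            # Dataset: direct __getitem__ access
for lr_batch, hr_batch in train_loader:  # DataLoader: iterate to get batches
    print(lr_batch.shape, hr_batch.shape)
    break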
Example #18
                    help='super resolution upscale factor')
parser.add_argument('--num_epochs',
                    default=200,
                    type=int,
                    help='train epoch number')
parser.add_argument(
    '--dataset',
    default='/home/lizhuangzi/Desktop/imageNetData2/ILSVRC2013_DET_test',
    type=str)
parser.add_argument('--valset', default='./Set5', type=str)

opt = parser.parse_args()

# Setting dataset
train_set = TrainDatasetFromFolder(opt.dataset,
                                   crop_size=opt.crop_size,
                                   upscale_factor=opt.upscale_factor)

val_set = ValDatasetFromFolder(opt.valset, upscale_factor=opt.upscale_factor)

train_loader = DataLoader(dataset=train_set,
                          num_workers=8,
                          batch_size=32,
                          shuffle=True)

val_loader = DataLoader(dataset=val_set,
                        num_workers=4,
                        batch_size=1,
                        shuffle=False)

# Define Network
Example #19
if __name__ == '__main__':
    opt = parser.parse_args()

    CROP_SIZE = opt.crop_size
    UPSCALE_FACTOR = opt.upscale_factor
    NUM_EPOCHS = opt.num_epochs
    train_set_foldpath = opt.train_set_foldpath
    val_set_foldpath = opt.val_set_foldpath
    pre_Gmodel = opt.pre_Gmodel
    pre_Dmodel = opt.pre_Dmodel

    # train_set = TrainDatasetFromFolder('data/DIV2K_train_HR', crop_size=CROP_SIZE, upscale_factor=UPSCALE_FACTOR)
    # val_set = ValDatasetFromFolder('data/DIV2K_valid_HR', upscale_factor=UPSCALE_FACTOR)
    # train_loader = DataLoader(dataset=train_set, num_workers=4, batch_size=64, shuffle=True)
    # val_loader = DataLoader(dataset=val_set, num_workers=4, batch_size=1, shuffle=False)
    train_set = TrainDatasetFromFolder(train_set_foldpath, crop_size=CROP_SIZE, upscale_factor=UPSCALE_FACTOR)
    val_set = ValDatasetFromFolder(val_set_foldpath, upscale_factor=UPSCALE_FACTOR)
    train_loader = DataLoader(dataset=train_set, num_workers=4, batch_size=64, shuffle=True)
    val_loader = DataLoader(dataset=val_set, num_workers=4, batch_size=1, shuffle=False)

    netG = Generator(UPSCALE_FACTOR)
    print('# generator parameters:', sum(param.numel() for param in netG.parameters()))
    netD = Discriminator()
    print('# discriminator parameters:', sum(param.numel() for param in netD.parameters()))

    generator_criterion = GeneratorLoss()

    if torch.cuda.is_available():
        netG.cuda()
        netG.load_state_dict(torch.load(pre_Gmodel))
        netD.cuda()
Example #20
    opt.sub_rate) + '_blocksize_' + str(BLOCK_SIZE)
argv = sys.argv[1:]
for arg in argv:
    if arg in "--group_num":
        save_dir = save_dir + "_g%d" % (opt.group_num)
    if arg in "--loss_mode":
        save_dir = save_dir + "_l%s" % (opt.loss_mode)
    if arg in "--fusion_mode":
        save_dir = save_dir + "_%s" % (opt.fusion_mode)
    if arg in "--zc":
        save_dir = save_dir + "_zc%d" % (opt.zc)
    if arg in "--weight":
        save_dir = save_dir + "_weight%d" % (opt.zc)

train_set = TrainDatasetFromFolder('data/train_crop',
                                   crop_size=CROP_SIZE,
                                   blocksize=BLOCK_SIZE)
train_loader = DataLoader(dataset=train_set,
                          num_workers=16,
                          batch_size=opt.batchSize,
                          shuffle=True)

use_variance_estimation = True
net = HierarchicalCSNet(BLOCK_SIZE,
                        opt.sub_rate,
                        group_num=GROUP_NUM,
                        mode=FUSION_MODE,
                        variance_estimation=use_variance_estimation,
                        z_channel=opt.zc)

Example #21
parser.add_argument('--val_data_dir',
                    default='data',
                    type=str,
                    help='validation data path')

if __name__ == '__main__':
    opt = parser.parse_args()

    CROP_SIZE = opt.crop_size
    UPSCALE_FACTOR = opt.upscale_factor
    NUM_EPOCHS = opt.num_epochs
    TRAIN_DATA_DIR = opt.train_data_dir
    VAL_DATA_DIR = opt.val_data_dir

    train_set = TrainDatasetFromFolder(TRAIN_DATA_DIR,
                                       crop_size=CROP_SIZE,
                                       upscale_factor=UPSCALE_FACTOR)
    val_set = ValDatasetFromFolder(VAL_DATA_DIR, upscale_factor=UPSCALE_FACTOR)
    train_loader = DataLoader(dataset=train_set,
                              num_workers=4,
                              batch_size=64,
                              shuffle=True)
    val_loader = DataLoader(dataset=val_set,
                            num_workers=4,
                            batch_size=1,
                            shuffle=False)

    netG = Generator(UPSCALE_FACTOR)
    print('# generator parameters:',
          sum(param.numel() for param in netG.parameters()))
    netD = Discriminator()
Example #22
parser.add_argument('--crop_size', default=88, type=int, help='training images crop size')
parser.add_argument('--upscale_factor', default=4, type=int, choices=[2, 4, 8],
                    help='super resolution upscale factor')
parser.add_argument('--num_epochs', default=100, type=int, help='train epoch number')
parser.add_argument('--img_dir', default='data/train', help='path to train dataset')
parser.add_argument('--val_dir', default='data/train', help='path to validation dataset')


if __name__ == '__main__':
    opt = parser.parse_args()
    
    CROP_SIZE = opt.crop_size
    UPSCALE_FACTOR = opt.upscale_factor
    NUM_EPOCHS = opt.num_epochs
    
    train_set = TrainDatasetFromFolder(opt.img_dir, crop_size=CROP_SIZE, upscale_factor=UPSCALE_FACTOR)
    val_set = ValDatasetFromFolder(opt.val_dir, upscale_factor=UPSCALE_FACTOR)
    train_loader = DataLoader(dataset=train_set, num_workers=1, batch_size=64, shuffle=True)
    val_loader = DataLoader(dataset=val_set, num_workers=1, batch_size=1, shuffle=False)
    
    netG = Generator(UPSCALE_FACTOR)
    print('# generator parameters:', sum(param.numel() for param in netG.parameters()))
    netD = Discriminator()
    print('# discriminator parameters:', sum(param.numel() for param in netD.parameters()))
    
    generator_criterion = GeneratorLoss()
    
    if torch.cuda.is_available():
        netG.cuda()
        netD.cuda()
        generator_criterion.cuda()
Example #23
                    help='super resolution upscale factor')
parser.add_argument('--num_epochs', default=30, type=int, help='train epoch number')

# train and valid set
# input_data = pd.read_csv('/nfs/masi/hansencb/CDMRI_2020/challenge_info.csv')
# print(input_data['ISOTROPIC'])

if __name__ == '__main__':
    #torch.backends.cudnn.enabled = False
    opt = parser.parse_args()
    
    CROP_SIZE = opt.crop_size
    UPSCALE_FACTOR = opt.upscale_factor
    NUM_EPOCHS = opt.num_epochs

    train_set = TrainDatasetFromFolder('/home/local/VANDERBILT/kanakap/challenge_info.csv', crop_size=CROP_SIZE, upscale_factor=UPSCALE_FACTOR)
    val_set = ValDatasetFromFolder('/home/local/VANDERBILT/kanakap/validation_info.csv', upscale_factor=UPSCALE_FACTOR)
    train_loader = DataLoader(dataset=train_set, num_workers=1, batch_size=1, shuffle=False)
    val_loader = DataLoader(dataset=val_set, num_workers=1, batch_size=1, shuffle=False)
    netG = Generator(UPSCALE_FACTOR)
    print('# generator parameters:', sum(param.numel() for param in netG.parameters()))
    netD = Discriminator()
    print('# discriminator parameters:', sum(param.numel() for param in netD.parameters()))
    
    generator_criterion = GeneratorLoss()
    
    if torch.cuda.is_available():
        print(netG)
        netG.cuda()
        netD.cuda()
        generator_criterion.cuda()
Example #24
                    type=int,
                    help='train epoch number')

if __name__ == '__main__':
    opt = parser.parse_args()

    CROP_SIZE = opt.crop_size
    UPSCALE_FACTOR = opt.upscale_factor
    NUM_EPOCHS = opt.num_epochs

    # train_set = TrainDatasetFromFolder('data/DIV2K_train_HR', crop_size=CROP_SIZE, upscale_factor=UPSCALE_FACTOR)
    # val_set = ValDatasetFromFolder('data/DIV2K_valid_HR', upscale_factor=UPSCALE_FACTOR)
    # train_loader = DataLoader(dataset=train_set, num_workers=4, batch_size=64, shuffle=True)
    # val_loader = DataLoader(dataset=val_set, num_workers=4, batch_size=1, shuffle=False)
    train_set = TrainDatasetFromFolder(
        '../input/sr-test/VOC2012/VOC2012/train',
        crop_size=CROP_SIZE,
        upscale_factor=UPSCALE_FACTOR)
    val_set = ValDatasetFromFolder('../input/sr-test/VOC2012/VOC2012/val',
                                   upscale_factor=UPSCALE_FACTOR)
    train_loader = DataLoader(dataset=train_set,
                              num_workers=4,
                              batch_size=64,
                              shuffle=True)
    val_loader = DataLoader(dataset=val_set,
                            num_workers=4,
                            batch_size=1,
                            shuffle=False)

    netG = Generator(UPSCALE_FACTOR)
    print('# generator parameters:',
          sum(param.numel() for param in netG.parameters()))
Example #25
parser.add_argument('--discriminator_pretrain_path', default=None, help='path to pretrained discriminator network')
parser.add_argument('--generator_optim_pretrain_path', default=None, help='path to pretrained generator optimizer state')
parser.add_argument('--discriminator_optim_pretrain_path', default=None, help='path to pretrained discriminator optimizer state')

opt = parser.parse_args()

SMALLEST_CROP_SIZE = opt.smallest_crop_size
LARGEST_CROP_SIZE = opt.largest_crop_size
UPSCALE_FACTOR = opt.upscale_factor
NUM_EPOCHS = opt.num_epochs
START_EPOCH = opt.start_epoch

print(NUM_EPOCHS)
print(START_EPOCH)

train_set = TrainDatasetFromFolder('data/VOC2012/train', smallest_crop_size=SMALLEST_CROP_SIZE, largest_crop_size=LARGEST_CROP_SIZE, upscale_factor=UPSCALE_FACTOR, batch_size=opt.batch_size)
val_set = ValDatasetFromFolder('data/VOC2012/val', crop_size=LARGEST_CROP_SIZE, upscale_factor=UPSCALE_FACTOR)
train_loader = DataLoader(dataset=train_set, num_workers=16, batch_size=opt.batch_size, shuffle=True)
val_loader = DataLoader(dataset=val_set, num_workers=16, batch_size=1, shuffle=False)

#div2kmean = ([0.4488, 0.4371, 0.4040])
#div2kstd = ([0.2845, 0.2701, 0.2920])
netG = Generator(UPSCALE_FACTOR)
if opt.generator_pretrain_path is not None:
    print('loading pretrained model at ' + opt.generator_pretrain_path)
    netG.load_state_dict(torch.load(opt.generator_pretrain_path))
else:
    print('initializing generator weights')
    netG.initialize_weights()
#netG = torch.nn.Sequential(
#            Normalize(div2kmean, div2kstd),
Example #26
                    help='generator update number')
parser.add_argument('--num_epochs',
                    default=100,
                    type=int,
                    help='train epoch number')

opt = parser.parse_args()

CROP_SIZE = opt.crop_size
UPSCALE_FACTOR = opt.upscale_factor
NUM_EPOCHS = opt.num_epochs
G_TRIGGER_THRESHOLD = opt.g_trigger_threshold
G_UPDATE_NUMBER = opt.g_update_number

train_set = TrainDatasetFromFolder('data/VOC2012/train',
                                   crop_size=CROP_SIZE,
                                   upscale_factor=UPSCALE_FACTOR)
val_set = ValDatasetFromFolder('data/VOC2012/val',
                               upscale_factor=UPSCALE_FACTOR)
train_loader = DataLoader(dataset=train_set,
                          num_workers=4,
                          batch_size=64,
                          shuffle=True)
val_loader = DataLoader(dataset=val_set,
                        num_workers=4,
                        batch_size=1,
                        shuffle=False)

netG = Generator(UPSCALE_FACTOR)
print('# generator parameters:',
      sum(param.numel() for param in netG.parameters()))
Example #27
def main():
    os.environ["CUDA_VISIBLE_DEVICES"] = '0, 1, 2, 3'

    train_set = TrainDatasetFromFolder('data/DIV2K_train_HR',
                                       crop_size=CROP_SIZE,
                                       upscale_factor=UPSCALE_FACTOR)
    val_set = ValDatasetFromFolder('data/DIV2K_valid_HR',
                                   upscale_factor=UPSCALE_FACTOR)
    train_loader = DataLoader(dataset=train_set,
                              num_workers=4,
                              batch_size=64,
                              shuffle=True)
    val_loader = DataLoader(dataset=val_set,
                            num_workers=4,
                            batch_size=1,
                            shuffle=False)

    netG = Generator(UPSCALE_FACTOR)
    print('# generator parameters:',
          sum(param.numel()
              for param in netG.parameters()))  # total number of Generator parameters
    netD = Discriminator()
    print('# discriminator parameters:',
          sum(param.numel()
              for param in netD.parameters()))  # total number of Discriminator parameters

    generator_criterion = GeneratorLoss()  # loss function
    netG = nn.DataParallel(netG).cuda()
    netD = nn.DataParallel(netD).cuda()
    #netG = netG.cuda()
    #netD = netD.cuda()
    generator_criterion.cuda()

    optimizerG = optim.Adam(netG.parameters())  # optimizer : adam
    optimizerD = optim.Adam(netD.parameters())  # optimizer : adam

    results = {
        'd_loss': [],
        'g_loss': [],
        'd_score': [],
        'g_score': [],
        'psnr': [],
        'ssim': []
    }

    for epoch in range(1, NUM_EPOCHS + 1):
        # model train
        d_loss, g_loss, d_score, g_score = train(netG, netD,
                                                 generator_criterion,
                                                 optimizerG, optimizerD,
                                                 train_loader, epoch)

        # validation data acc
        psnr, ssim = test(netG, netD, val_loader, epoch)

        # save loss / scores / psnr / ssim
        results['d_loss'].append(d_loss)
        results['g_loss'].append(g_loss)
        results['d_score'].append(d_score)
        results['g_score'].append(g_score)
        results['psnr'].append(psnr)
        results['ssim'].append(ssim)

        # save results
        if epoch % 10 == 0 and epoch != 0:
            out_path = 'statistics/'
            data_frame = pd.DataFrame(data={
                'Loss_D': results['d_loss'],
                'Loss_G': results['g_loss'],
                'Score_D': results['d_score'],
                'Score_G': results['g_score'],
                'PSNR': results['psnr'],
                'SSIM': results['ssim']
            },
                                      index=range(1, epoch + 1))
            data_frame.to_csv(out_path + 'srf_' + str(UPSCALE_FACTOR) +
                              '_train_results.csv',
                              index_label='Epoch')

        print()