Example #1
def train_gan():
    [
        depth_net, color_net, d_net, depth_optimizer, color_optimizer,
        d_optimizer
    ] = load_networks(True)
    generator_criterion = GeneratorLoss()
    if param.useGPU:
        generator_criterion.cuda()
    train_system(depth_net, color_net, d_net, depth_optimizer, color_optimizer,
                 d_optimizer, generator_criterion)
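Every example on this page builds a project-specific GeneratorLoss. For orientation, below is a minimal sketch in the style of the widely copied SRGAN implementation; the class name, VGG cut-off, and weights are illustrative assumptions, not the code behind any example here (the full SRGAN version also adds a small total-variation term).

import torch
from torch import nn
from torchvision.models import vgg16

class GeneratorLossSketch(nn.Module):
    """Pixel + adversarial + perceptual loss for the generator (sketch)."""

    def __init__(self):
        super().__init__()
        vgg = vgg16(pretrained=True)
        # A truncated, frozen VGG16 serves as the perceptual feature extractor.
        self.features = nn.Sequential(*list(vgg.features)[:31]).eval()
        for p in self.features.parameters():
            p.requires_grad = False
        self.mse = nn.MSELoss()

    def forward(self, fake_out, fake_img, real_img):
        adversarial = torch.mean(1 - fake_out)  # reward fooling the discriminator
        perceptual = self.mse(self.features(fake_img), self.features(real_img))
        pixel = self.mse(fake_img, real_img)
        # Weights follow the common SRGAN defaults; projects tune their own.
        return pixel + 0.001 * adversarial + 0.006 * perceptual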
Example #2
  def __init__(self, config):
    with open(config, 'r') as f:
      config = json.load(f)
    self.config = config

    self.iterations = config['train']['iterations']
    self.critic_iters = config['train']['critic_iters']
    self.batch_size = config['train']['batch_size']
    self.lr = config['train']['learning_rate']
    img_size = config['model']['image_size']

    trainset = CelebaDataset(config['data']['image_path'], size=img_size)
    self.trainloader = DataLoader(trainset, batch_size=self.batch_size, pin_memory=True)

    self.z_size = config['train']['z_size']
    self.generator = Generator(self.z_size)
    self.descriminator = Descriminator()

    lam = config['train']['lambda']
    self.g_loss = GeneratorLoss()
    self.d_loss = DescriminatorLoss(lam=lam)

    betas = (config['train']['beta1'], config['train']['beta2'])
    self.g_optim = optim.Adam(self.generator.parameters(), lr=self.lr, betas=betas)
    self.d_optim = optim.Adam(self.descriminator.parameters(), lr=self.lr, betas=betas)

    self.dlosses = []
    self.glosses = []

    if torch.cuda.is_available():
      self.generator.cuda()
      self.descriminator.cuda()
      self.device = torch.device("cuda")
      print("Using GPU")
    else:
      self.device = torch.device("cpu")
      print("No GPU detected")

    self.write_interval = config['model']['write_interval']
    self.train_info_path = self.config['model']['trainer_save_path']
    self.generator_path = self.config['model']['generator_save_path'].split('.pt')[0]
    self.descriminator_path = self.config['model']['descriminator_save_path'].split('.pt')[0]
    self.img_path = self.config['model']['image_save_path'].split('.png')[0]
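The constructor above fixes the config layout it expects. A matching JSON file might look like this; the key paths are taken from the code (including its "descriminator" spelling), while all values are illustrative:

{
  "train": {
    "iterations": 100000,
    "critic_iters": 5,
    "batch_size": 64,
    "learning_rate": 1e-4,
    "z_size": 128,
    "lambda": 10,
    "beta1": 0.0,
    "beta2": 0.9
  },
  "model": {
    "image_size": 64,
    "write_interval": 100,
    "trainer_save_path": "checkpoints/trainer.pt",
    "generator_save_path": "checkpoints/generator.pt",
    "descriminator_save_path": "checkpoints/descriminator.pt",
    "image_save_path": "samples/sample.png"
  },
  "data": {
    "image_path": "data/celeba"
  }
}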
Example #3
    train_loader = DataLoader(dataset=train_set,
                              num_workers=4,
                              batch_size=64,
                              shuffle=True)
    val_loader = DataLoader(dataset=val_set,
                            num_workers=4,
                            batch_size=1,
                            shuffle=False)

    netG = Generator(upscale)
    print('# generator parameters:',
          sum(param.numel() for param in netG.parameters()))
    netD = Discriminator()
    print('# discriminator parameters:',
          sum(param.numel() for param in netD.parameters()))
    generator_loss = GeneratorLoss()

    if torch.cuda.is_available():
        netG.cuda()
        netD.cuda()
        generator_loss.cuda()

    optimizerG = optim.Adam(netG.parameters())
    optimizerD = optim.Adam(netD.parameters())

    results = {
        'd_loss': [],
        'g_loss': [],
        'd_score': [],
        'g_score': [],
        'psnr': [],
        'ssim': []
    }
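Like several examples here, this one calls optim.Adam on the network parameters with no explicit hyperparameters, so PyTorch's defaults apply. Spelled out, the equivalent is:

# Equivalent to optim.Adam(netG.parameters()) with PyTorch's defaults:
optimizerG = optim.Adam(netG.parameters(), lr=1e-3, betas=(0.9, 0.999),
                        eps=1e-8, weight_decay=0)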
Example #4
    logging.info("Creating the dataset...")

    # fetch dataloaders
    dataloaders = data_loader.fetch_dataloader(['test'], args.data_dir, params)
    test_dl = dataloaders['test']

    logging.info("- done.")

    # Define the model
    #     model = net.Net(params).cuda() if params.cuda else net.Net(params)
    netG = net.Generator(4).cuda(cuda_id)
    netD = net.Discriminator().cuda(cuda_id)

    loss_fn = None
    # NB: `net` is the model module imported above; these comparisons
    # presumably test a loss-variant flag defined in the full script.
    if net == GAN or net == CGAN:
        loss_fn = GeneratorLoss().cuda(cuda_id)
    elif net == GAN_adv:
        loss_fn = GeneratorLoss_adv().cuda(cuda_id)
    elif net == GAN_ssim:
        loss_fn = GeneratorLoss_ssim().cuda(cuda_id)
    elif net == GAN_notv:
        loss_fn = GeneratorLoss_notv().cuda(cuda_id)

    metrics = net.metrics

    logging.info("Starting evaluation")

    # Reload weights from the saved file
    restore_path_g = os.path.join(args.model_dir, 'best_g.pth.tar')
    restore_path_d = os.path.join(args.model_dir, 'best_d.pth.tar')
    utils.load_checkpoint(restore_path_g, netG)
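Only netG is restored in this excerpt, and utils.load_checkpoint itself is not shown. In the CS230-style utilities this pattern usually comes from, it looks roughly like the following; this is an assumption about the helper, not the project's actual code:

def load_checkpoint(checkpoint, model, optimizer=None):
    # Restore model (and optionally optimizer) state saved with torch.save.
    state = torch.load(checkpoint)
    model.load_state_dict(state['state_dict'])
    if optimizer:
        optimizer.load_state_dict(state['optim_dict'])
    return state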
Example #5
def main(args):
    if not os.path.exists('data/dataset.pt'):
        # Make sure the bit depth is 24 (8-bit images are grayscale)
        df = pd.read_pickle('data/dataset_files.gzip')
        df = df[(df['width'] > 100) & (df['height'] > 100)]
        train_df, val_df = train_test_split(df,
                                            test_size=0.2,
                                            random_state=42,
                                            shuffle=True)
        _, val_similar = dataframe_find_similar_images(
            val_df, batch_size=args.batch_size)

        # Create the train dataset
        train_filenames = train_df['filename'].tolist()
        train_set = TrainDatasetFromList(train_filenames,
                                         crop_size=args.crop_size,
                                         upscale_factor=args.upscale_factor)

        val_sets = list()
        for val_df in val_similar:
            val_filenames = val_df['filename'].tolist()
            val_set = ValDatasetFromList(val_filenames,
                                         upscale_factor=args.upscale_factor)
            val_sets.append(val_set)

        train_sampler = torch.utils.data.RandomSampler(train_set)
        # NB: uses the last val_set from the loop above.
        val_sampler = torch.utils.data.SequentialSampler(val_set)
        data_to_save = {
            'train_dataset': train_set,
            "val_datasets": val_sets,
            'train_sampler': train_sampler,
            'val_sampler': val_sampler
        }
        torch.save(data_to_save, 'data/dataset.pt')
    else:
        datasets = torch.load('data/dataset.pt')
        train_set = datasets['train_dataset']
        val_sets = datasets['val_datasets']
        train_sampler = datasets['train_sampler']
        val_sampler = datasets['val_sampler']

    train_loader = DataLoader(dataset=train_set,
                              batch_size=args.batch_size,
                              num_workers=args.num_workers,
                              sampler=train_sampler)
    val_loaders = list()
    for val_set in val_sets:
        val_loaders.append(
            DataLoader(dataset=val_set,
                       batch_size=args.batch_size,
                       num_workers=args.num_workers,
                       shuffle=False))

    netG = Generator(args.upscale_factor)
    print('# generator parameters:',
          sum(param.numel() for param in netG.parameters()))
    netD = Discriminator()
    print('# discriminator parameters:',
          sum(param.numel() for param in netD.parameters()))

    generator_criterion = GeneratorLoss()

    if torch.cuda.is_available():
        netG.cuda()
        netD.cuda()
        generator_criterion.cuda()

    optimizerG = optim.Adam(netG.parameters())
    optimizerD = optim.Adam(netD.parameters())

    results = {
        'd_loss': [],
        'g_loss': [],
        'd_score': [],
        'g_score': [],
        'psnr': [],
        'ssim': []
    }
    start_epoch = 1
    if args.resume:
        import glob
        netG_files = glob.glob(
            os.path.join(args.output_dir,
                         'netG_epoch_%d_*.pth' % (args.upscale_factor)))
        netD_files = glob.glob(
            os.path.join(args.output_dir,
                         'netD_epoch_%d_*.pth' % (args.upscale_factor)))
        if len(netG_files) > 0:
            netG_file = max(netG_files, key=os.path.getctime)
            netD_file = max(netD_files, key=os.path.getctime)
            netG.load_state_dict(torch.load(netG_file))
            netD.load_state_dict(torch.load(netD_file))
            start_epoch = len(netG_files)

    for epoch in range(start_epoch, args.epochs + 1):
        train_bar = tqdm(train_loader)
        running_results = {
            'batch_sizes': 0,
            'd_loss': 0,
            'g_loss': 0,
            'd_score': 0,
            'g_score': 0
        }

        # Mixed-precision gradient scalers. Note they are recreated here every
        # epoch, although GradScaler is meant to be created once before training.
        dscaler = torch.cuda.amp.GradScaler()  # discriminator
        gscaler = torch.cuda.amp.GradScaler()  # generator
        netG.train()
        netD.train()
        for data, target in train_bar:
            with torch.cuda.amp.autocast():  # Mix precision
                batch_size = data.size(0)
                running_results['batch_sizes'] += batch_size

                ############################
                # (1) Update D network: maximize D(x) - D(G(z)), i.e. minimize 1 - D(x) + D(G(z))
                ###########################
                netD.zero_grad()
                real_img = Variable(target, requires_grad=False)
                if torch.cuda.is_available():
                    real_img = real_img.cuda()
                z = Variable(data)
                if torch.cuda.is_available():
                    z = z.cuda()
                fake_img = netG(z)

                # Discriminator scores, averaged over the batch.
                real_out = netD(real_img).mean()  # score for the real image
                fake_out = netD(fake_img).mean()  # score for the fake image
                # Minimizing this loss drives real_out toward 1 and fake_out toward 0.
                d_loss = 1 - real_out + fake_out

                # d_loss.backward(retain_graph=True)
                # optimizerD.step()
                ############################
                # (2) Update G network: minimize 1-D(G(z)) + Perception Loss + Image Loss + TV Loss
                ###########################
                netG.zero_grad()
                # fake_out is detached, so the adversarial term contributes no
                # gradient to the generator in this variant.
                g_loss = generator_criterion(fake_out.detach(), fake_img,
                                             real_img.detach())

            dscaler.scale(d_loss).backward(retain_graph=True)
            gscaler.scale(g_loss).backward()

            dscaler.step(optimizerD)
            dscaler.update()
            gscaler.step(optimizerG)
            gscaler.update()

            fake_img = netG(z)
            fake_out = netD(fake_img).mean()

            # loss for current batch before optimization
            running_results['g_loss'] += g_loss.item() * batch_size
            running_results['d_loss'] += d_loss.item() * batch_size
            running_results['d_score'] += real_out.item() * batch_size
            running_results['g_score'] += fake_out.item() * batch_size

            train_bar.set_description(
                desc=
                '[%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f' %
                (epoch, args.epochs,
                 running_results['d_loss'] / running_results['batch_sizes'],
                 running_results['g_loss'] / running_results['batch_sizes'],
                 running_results['d_score'] / running_results['batch_sizes'],
                 running_results['g_score'] / running_results['batch_sizes']))

        # save model parameters
        torch.save(
            netG.state_dict(),
            os.path.join(args.output_dir, 'netG_epoch_%d_%d.pth' %
                         (args.upscale_factor, epoch)))
        torch.save(
            netD.state_dict(),
            os.path.join(args.output_dir, 'netD_epoch_%d_%d.pth' %
                         (args.upscale_factor, epoch)))

        if epoch % args.validation_epoch == 0 and epoch != 0:
            netG.eval()
            with torch.no_grad():
                val_results = {
                    'mse': 0,
                    'ssims': 0,
                    'psnr': 0,
                    'ssim': 0,
                    'batch_sizes': 0
                }
                val_images = []
                for i in trange(len(val_loaders), desc='Running validation'):
                    val_loader = val_loaders[i]
                    for val_lr, val_hr_restore, val_hr in val_loader:
                        batch_size = val_lr.size(0)
                        val_results['batch_sizes'] += batch_size
                        lr = val_lr
                        hr = val_hr
                        if torch.cuda.is_available():
                            lr = lr.cuda()
                            hr = hr.cuda()
                        sr = netG(lr)

                        batch_mse = ((sr - hr)**2).data.mean()
                        val_results['mse'] += batch_mse * batch_size
                        batch_ssim = pytorch_ssim.ssim(sr, hr).item()
                        val_results['ssims'] += batch_ssim * batch_size
                        val_results['psnr'] = 10 * log10(
                            (hr.max()**2) /
                            (val_results['mse'] / val_results['batch_sizes']))
                        val_results['ssim'] = val_results[
                            'ssims'] / val_results['batch_sizes']
                        # val_bar.set_description(
                        #     desc='[converting LR images to SR images] PSNR: %.4f dB SSIM: %.4f' % (
                        #         val_results['psnr'], val_results['ssim']))

                        # convert the validation images
                        val_hr_restore_squeeze = val_hr_restore.squeeze(0)
                        hr_squeeze = hr.data.cpu().squeeze(0)
                        sr_squeeze = sr.data.cpu().squeeze(0)
                        for b in range(batch_size):
                            val_hr = val_hr_restore_squeeze[b]
                            hr_temp = hr_squeeze[b]
                            sr_temp = sr_squeeze[b]
                            val_images.extend([
                                display_transform()(val_hr),
                                display_transform()(hr_temp),
                                display_transform()(sr_temp)
                            ])

                val_images = torch.stack(val_images)
                val_images = torch.chunk(val_images, val_images.size(0) // 15)
                val_save_bar = tqdm(val_images,
                                    desc='[saving training results]')
                index = 1
                for image in val_save_bar:
                    image = utils.make_grid(image, nrow=3, padding=5)
                    utils.save_image(
                        image,
                        os.path.join(
                            args.output_dir,
                            'epoch_%d_upscale_%d_index_%d.png' %
                            (epoch, args.upscale_factor, index)))
                    index += 1

            # save loss / scores / PSNR / SSIM
            results['d_loss'].append(running_results['d_loss'] /
                                     running_results['batch_sizes'])
            results['g_loss'].append(running_results['g_loss'] /
                                     running_results['batch_sizes'])
            results['d_score'].append(running_results['d_score'] /
                                      running_results['batch_sizes'])
            results['g_score'].append(running_results['g_score'] /
                                      running_results['batch_sizes'])
            results['psnr'].append(val_results['psnr'])
            results['ssim'].append(val_results['ssim'])

        if epoch % 10 == 0 and epoch != 0:
            out_path = 'statistics/'
            data_frame = pd.DataFrame(data={
                'Loss_D': results['d_loss'],
                'Loss_G': results['g_loss'],
                'Score_D': results['d_score'],
                'Score_G': results['g_score'],
                'PSNR': results['psnr'],
                'SSIM': results['ssim']
            },
                                      index=range(1, epoch + 1))
            data_frame.to_csv(out_path + 'srf_' + str(args.upscale_factor) +
                              '_train_results.csv',
                              index_label='Epoch')
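For comparison, the canonical single-optimizer AMP step from the PyTorch documentation is below; the example above adapts it to two optimizers sharing one autocast region, hence the retain_graph=True. The names model, criterion, loader, and optimizer are placeholders:

scaler = torch.cuda.amp.GradScaler()  # created once, before the epoch loop
for data, target in loader:
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = criterion(model(data), target)
    scaler.scale(loss).backward()  # backprop on the scaled loss
    scaler.step(optimizer)         # unscales grads; skips the step on inf/NaN
    scaler.update()                # adjusts the scale factor for the next step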
Example #6
from model import Generator, Discriminator
import architecture as arch
import pdb
import torch.nn.functional as F

gpu_id = 0
port_num = 8091
display = visualizer(port=port_num)
report_feq = 10
NUM_EPOCHS = 40

netG = arch.RRDB_Net(4, 3, 64, 12, gc=32, upscale=1, norm_type=None, act_type='leakyrelu', \
                        mode='CNA', res_scale=1, upsample_mode='upconv')
netD = Discriminator()

generator_criterion = GeneratorLoss()

if torch.cuda.is_available():
    netG.to(gpu_id)
    netD.to(gpu_id)
    generator_criterion.to(gpu_id)

optimizerG = optim.Adam(netG.parameters(), lr=0.0002)
optimizerD = optim.Adam(netD.parameters(), lr=0.0002)

train_set = MyDataLoader(hr_dir='../data/train_sample/HR/', hr_sample_dir='../data/train_sample/HR_Sample/4/', lap_dir='../data/train_sample/LAP_HR_Norm/')
train_loader = DataLoader(dataset=train_set, num_workers=4, batch_size=2, shuffle=True)

step = 0
for epoch in range(1, NUM_EPOCHS):  # NB: runs NUM_EPOCHS - 1 epochs
    netG.train()
Example #7
def main():
    os.environ["CUDA_VISIBLE_DEVICES"] = '0,1,2,3'  # comma-separated, without spaces

    train_set = TrainDatasetFromFolder('data/DIV2K_train_HR',
                                       crop_size=CROP_SIZE,
                                       upscale_factor=UPSCALE_FACTOR)
    val_set = ValDatasetFromFolder('data/DIV2K_valid_HR',
                                   upscale_factor=UPSCALE_FACTOR)
    train_loader = DataLoader(dataset=train_set,
                              num_workers=4,
                              batch_size=64,
                              shuffle=True)
    val_loader = DataLoader(dataset=val_set,
                            num_workers=4,
                            batch_size=1,
                            shuffle=False)

    netG = Generator(UPSCALE_FACTOR)
    print('# generator parameters:',
          sum(param.numel()
              for param in netG.parameters()))  # total Generator parameter count
    netD = Discriminator()
    print('# discriminator parameters:',
          sum(param.numel()
              for param in netD.parameters()))  # total Discriminator parameter count

    generator_criterion = GeneratorLoss()  # loss function
    netG = nn.DataParallel(netG).cuda()
    netD = nn.DataParallel(netD).cuda()
    #netG = netG.cuda()
    #netD = netD.cuda()
    generator_criterion.cuda()

    optimizerG = optim.Adam(netG.parameters())  # optimizer : adam
    optimizerD = optim.Adam(netD.parameters())  # optimizer : adam

    results = {
        'd_loss': [],
        'g_loss': [],
        'd_score': [],
        'g_score': [],
        'psnr': [],
        'ssim': []
    }

    for epoch in range(1, NUM_EPOCHS + 1):
        # model train
        d_loss, g_loss, d_score, g_score = train(netG, netD,
                                                 generator_criterion,
                                                 optimizerG, optimizerD,
                                                 train_loader, epoch)

        # validation data acc
        psnr, ssim = test(netG, netD, val_loader, epoch)

        # save loss / scores / PSNR / SSIM
        results['d_loss'].append(d_loss)
        results['g_loss'].append(g_loss)
        results['d_score'].append(d_score)
        results['g_score'].append(g_score)
        results['psnr'].append(psnr)
        results['ssim'].append(ssim)

        # save results
        if epoch % 10 == 0 and epoch != 0:
            out_path = 'statistics/'
            data_frame = pd.DataFrame(data={
                'Loss_D': results['d_loss'],
                'Loss_G': results['g_loss'],
                'Score_D': results['d_score'],
                'Score_G': results['g_score'],
                'PSNR': results['psnr'],
                'SSIM': results['ssim']
            },
                                      index=range(1, epoch + 1))
            data_frame.to_csv(out_path + 'srf_' + str(UPSCALE_FACTOR) +
                              '_train_results.csv',
                              index_label='Epoch')

        print()
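One caveat with the nn.DataParallel wrapping above: the wrapped model's state_dict keys gain a 'module.' prefix, so a checkpoint saved from netG.state_dict() will not load directly into an unwrapped Generator. This is standard PyTorch behavior, not something the example handles; two common workarounds:

# Save the underlying module so the keys carry no 'module.' prefix:
torch.save(netG.module.state_dict(), 'netG.pth')

# Or strip the prefix when loading into an unwrapped model:
state = torch.load('netG.pth')
state = {k.replace('module.', '', 1): v for k, v in state.items()}
model.load_state_dict(state)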
Example #8
def main():

    # SRGAN parameters

    batch_size = 10
    epochs = 100
    lr = 0.01
    threads = 4
    upscale_factor = args.upscale_factor

    img_path_low = args.inputDir
    img_path_ref = args.targetDir

    train_set = DatasetSuperRes(img_path_low, img_path_ref)

    training_data_loader = DataLoader(dataset=train_set,
                                      num_workers=threads,
                                      batch_size=batch_size,
                                      shuffle=True)

    netG = Generator(upscale_factor).to(device)
    print('# Generator parameters:',
          sum(param.numel() for param in netG.parameters()))
    netD = Discriminator().to(device)
    print('# Discriminator parameters:',
          sum(param.numel() for param in netD.parameters()))
    generator_criterion = GeneratorLoss().to(device)

    optimizerG = optim.Adam(netG.parameters())
    optimizerD = optim.Adam(netD.parameters())

    out_path = 'results/'
    out_model_path = 'models/'

    if not os.path.exists(out_path):
        os.makedirs(out_path)

    if not os.path.exists(out_model_path):
        os.makedirs(out_model_path)

    results = {
        'd_loss': [],
        'g_loss': [],
        'd_score': [],
        'g_score': [],
        'psnr': [],
        'ssim': []
    }

    # Training

    begin_counter = time.time()

    for epoch in range(1, epochs + 1):

        running_results = {
            'batch_sizes': 0,
            'd_loss': 0,
            'g_loss': 0,
            'd_score': 0,
            'g_score': 0
        }
        netG.train()
        netD.train()

        for data, target in training_data_loader:
            g_update_first = True
            batch_size = data.size(0)
            running_results['batch_sizes'] += batch_size

            ############################
            # (1) Update D network: maximize D(x) - D(G(z)), i.e. minimize 1 - D(x) + D(G(z))
            ###########################
            real_img = Variable(target).to(device)
            z = Variable(data).to(device)

            fake_img = netG(z)

            netD.zero_grad()

            real_out = netD(real_img).mean()
            fake_out = netD(fake_img).mean()

            d_loss = 1 - real_out + fake_out
            d_loss.backward(retain_graph=True)

            optimizerD.step()

            ############################
            # (2) Update G network: minimize 1-D(G(z)) + Perception Loss + Image Loss + TV Loss
            ###########################
            netG.zero_grad()

            g_loss = generator_criterion(fake_out, fake_img, real_img)
            g_loss.backward()

            # D was just updated, so recompute its score on a fresh fake image
            # for the logging below.
            fake_img = netG(z)
            fake_out = netD(fake_img).mean()

            optimizerG.step()

            # Loss for current batch before optimization

            running_results['g_loss'] += g_loss.item() * batch_size
            running_results['d_loss'] += d_loss.item() * batch_size
            running_results['d_score'] += real_out.item() * batch_size
            running_results['g_score'] += fake_out.item() * batch_size

        print(
            '[{}/{}] Loss_D: {:.4f} Loss_G: {:.4f} D(x): {:.4f} D(G(z)): {:.4f}'
            .format(
                epoch, epochs,
                running_results['d_loss'] / running_results['batch_sizes'],
                running_results['g_loss'] / running_results['batch_sizes'],
                running_results['d_score'] / running_results['batch_sizes'],
                running_results['g_score'] / running_results['batch_sizes']))

        netG.eval()

        # Metrics computed on the last training batch only; a quick proxy,
        # not a held-out validation set.
        batch_mse = ((fake_img - real_img)**2).data.mean()
        batch_ssim = ssim(fake_img, real_img).item()
        batch_psnr = 10 * log10(1 / batch_mse)

        # Save loss / scores / PSNR / SSIM
        results['d_loss'].append(running_results['d_loss'] /
                                 running_results['batch_sizes'])
        results['g_loss'].append(running_results['g_loss'] /
                                 running_results['batch_sizes'])
        results['d_score'].append(running_results['d_score'] /
                                  running_results['batch_sizes'])
        results['g_score'].append(running_results['g_score'] /
                                  running_results['batch_sizes'])
        results['psnr'].append(batch_psnr)
        results['ssim'].append(batch_ssim)

        # Checkpoint every epochs // 10 epochs (assumes epochs >= 10)
        if epoch % (epochs // 10) == 0:
            # Save model
            torch.save(
                netG, out_model_path + 'netG_x%d_epoch_%d.pth' %
                (upscale_factor, epoch))
            #torch.save(netD, 'netD_x%d_epoch_%d.pt' % (upscale_factor, epoch))

            data_frame = pd.DataFrame(data={
                'Loss_D': results['d_loss'],
                'Loss_G': results['g_loss'],
                'Score_D': results['d_score'],
                'Score_G': results['g_score'],
                'PSNR': results['psnr'],
                'SSIM': results['ssim']
            },
                                      index=range(1, epoch + 1))
            data_frame.to_csv(out_path + 'SRGAN_x' + str(upscale_factor) +
                              '_train_results.csv',
                              index_label='Epoch')

    end_counter = time.time()
    training_time = end_counter - begin_counter
    print("Seconds spent during training = ", training_time)
    with open(out_path + "SRGAN_model_x" + str(args.upscale_factor) + ".txt",
              "w") as report:
        report.write("Training time: {:.2f}".format(training_time))
Example #9
def main_train(path_trn: str, path_val: str,
               crop_size: int, upscale_factor: int, num_epochs: int,
               num_workers: int, to_device: str = 'cuda:0', batch_size: int = 64):
    to_device = get_device(to_device)
    train_set = TrainDatasetFromFolder(path_trn, crop_size=crop_size, upscale_factor=upscale_factor)
    val_set = ValDatasetFromFolder(path_val, upscale_factor=upscale_factor)
    # train_set = TrainDatasetFromFolder('data/VOC2012/train', crop_size=crop_size, upscale_factor=upscale_factor)
    # val_set = ValDatasetFromFolder('data/VOC2012/val', upscale_factor=upscale_factor)
    #
    train_loader = DataLoader(dataset=train_set, num_workers=num_workers, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(dataset=val_set, num_workers=num_workers, batch_size=1, shuffle=False)

    netG = Generator(upscale_factor)
    print('# generator parameters:', sum(param.numel() for param in netG.parameters()))
    netD = Discriminator()
    print('# discriminator parameters:', sum(param.numel() for param in netD.parameters()))

    generator_criterion = GeneratorLoss()

    if torch.cuda.is_available():
        netG.cuda()
        netD.cuda()
        generator_criterion.cuda()

    optimizerG = optim.Adam(netG.parameters())
    optimizerD = optim.Adam(netD.parameters())

    results = {'d_loss': [], 'g_loss': [], 'd_score': [], 'g_score': [], 'psnr': [], 'ssim': []}

    for epoch in range(1, num_epochs + 1):
        train_bar = tqdm(train_loader)
        running_results = {'batch_sizes': 0, 'd_loss': 0, 'g_loss': 0, 'd_score': 0, 'g_score': 0}

        netG.train()
        netD.train()
        # FIXME: separate function for epoch training
        for data, target in train_bar:
            g_update_first = True
            batch_size = data.size(0)
            #
            # img_hr = target.numpy().transpose((0, 2, 3, 1))[0]
            # img_lr = data.numpy().transpose((0, 2, 3, 1))[0]
            # img_lr_x4 = cv2.resize(img_lr, img_hr.shape[:2], interpolation=cv2.INTER_CUBIC)
            # #
            # plt.subplot(1, 3, 1)
            # plt.imshow(img_hr)
            # plt.subplot(1, 3, 2)
            # plt.imshow(img_lr)
            # plt.subplot(1, 3, 3)
            # plt.imshow(img_lr_x4)
            # plt.show()
            running_results['batch_sizes'] += batch_size

            ############################
            # (1) Update D network: maximize D(x) - D(G(z)), i.e. minimize 1 - D(x) + D(G(z))
            ###########################
            # real_img = Variable(target)
            # if torch.cuda.is_available():
            #     real_img = real_img.cuda()
            # z = Variable(data)
            # if torch.cuda.is_available():
            #     z = z.cuda()
            z = data.to(to_device)
            real_img = target.to(to_device)
            fake_img = netG(z)

            netD.zero_grad()
            real_out = netD(real_img).mean()
            fake_out = netD(fake_img).mean()
            d_loss = 1 - real_out + fake_out
            d_loss.backward(retain_graph=True)
            optimizerD.step()

            ############################
            # (2) Update G network: minimize 1-D(G(z)) + Perception Loss + Image Loss + TV Loss
            ###########################
            netG.zero_grad()
            g_loss = generator_criterion(fake_out, fake_img, real_img)
            g_loss.backward()
            optimizerG.step()
            # Recompute outputs and losses after the updates, for logging.
            fake_img = netG(z)
            fake_out = netD(fake_img).mean()

            g_loss = generator_criterion(fake_out, fake_img, real_img)
            running_results['g_loss'] += float(g_loss) * batch_size
            d_loss = 1 - real_out + fake_out
            running_results['d_loss'] += float(d_loss) * batch_size
            running_results['d_score'] += float(real_out) * batch_size
            running_results['g_score'] += float(fake_out) * batch_size

            train_bar.set_description(desc='[%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f' % (
                epoch, num_epochs, running_results['d_loss'] / running_results['batch_sizes'],
                running_results['g_loss'] / running_results['batch_sizes'],
                running_results['d_score'] / running_results['batch_sizes'],
                running_results['g_score'] / running_results['batch_sizes']))

        netG.eval()
        # FIXME: separate function for epoch validation
        with torch.no_grad():
            out_path = 'training_results/SRF_' + str(upscale_factor) + '/'
            if not os.path.exists(out_path):
                os.makedirs(out_path)
            val_bar = tqdm(val_loader)
            valing_results = {'mse': 0, 'ssims': 0, 'psnr': 0, 'ssim': 0, 'batch_sizes': 0}
            val_images = []
            for val_lr, val_hr_restore, val_hr in val_bar:
                batch_size = val_lr.size(0)
                valing_results['batch_sizes'] += batch_size
                # lr = Variable(val_lr, volatile=True)
                # hr = Variable(val_hr, volatile=True)
                # if torch.cuda.is_available():
                #     lr = lr.cuda()
                #     hr = hr.cuda()
                lr = val_lr.to(to_device)
                hr = val_hr.to(to_device)
                sr = netG(lr)

                batch_mse = ((sr - hr) ** 2).mean()
                valing_results['mse'] += float(batch_mse) * batch_size
                batch_ssim = float(pytorch_ssim.ssim(sr, hr)) #.data[0]
                valing_results['ssims'] += batch_ssim * batch_size
                valing_results['psnr'] = 10 * log10(1 / (valing_results['mse'] / valing_results['batch_sizes']))
                valing_results['ssim'] = valing_results['ssims'] / valing_results['batch_sizes']
                val_bar.set_description(
                    desc='[converting LR images to SR images] PSNR: %.4f dB SSIM: %.4f' % (
                        valing_results['psnr'], valing_results['ssim']))

                val_images.extend(
                    [display_transform()(val_hr_restore.squeeze(0)), display_transform()(hr.data.cpu().squeeze(0)),
                     display_transform()(sr.data.cpu().squeeze(0))])
            val_images = torch.stack(val_images)
            val_images = torch.chunk(val_images, val_images.size(0) // 15)
            val_save_bar = tqdm(val_images, desc='[saving training results]')
            index = 1
            for image in val_save_bar:
                image = utils.make_grid(image, nrow=3, padding=5)
                utils.save_image(image, out_path + 'epoch_%d_index_%d.png' % (epoch, index), padding=5)
                index += 1

        # save model parameters
        torch.save(netG.state_dict(), 'epochs/netG_epoch_%d_%d.pth' % (upscale_factor, epoch))
        torch.save(netD.state_dict(), 'epochs/netD_epoch_%d_%d.pth' % (upscale_factor, epoch))
        # save loss / scores / PSNR / SSIM
        results['d_loss'].append(running_results['d_loss'] / running_results['batch_sizes'])
        results['g_loss'].append(running_results['g_loss'] / running_results['batch_sizes'])
        results['d_score'].append(running_results['d_score'] / running_results['batch_sizes'])
        results['g_score'].append(running_results['g_score'] / running_results['batch_sizes'])
        results['psnr'].append(valing_results['psnr'])
        results['ssim'].append(valing_results['ssim'])

        if epoch % 10 == 0 and epoch != 0:
            out_path = 'statistics/'
            data_frame = pd.DataFrame(
                data={'Loss_D': results['d_loss'], 'Loss_G': results['g_loss'], 'Score_D': results['d_score'],
                      'Score_G': results['g_score'], 'PSNR': results['psnr'], 'SSIM': results['ssim']},
                index=range(1, epoch + 1))
            data_frame.to_csv(out_path + 'srf_' + str(upscale_factor) + '_train_results.csv', index_label='Epoch')
Example #10
                        num_workers=4,
                        batch_size=1,
                        shuffle=False)

netG = Generator(UPSCALE_FACTOR)  # network model
print('# generator parameters:',
      sum(param.numel() for param in netG.parameters()))
netD = Discriminator()
small_netD = nn.Sequential(*list((netD.net))[:23]).eval()
#small_shadow_netD = nn.Sequential(*list((netD.net))[:2]).eval()
#for param in small_netD.parameters():
#	param.requires_grad = False
print('# discriminator parameters:',
      sum(param.numel() for param in netD.parameters()))

generator_criterion = GeneratorLoss(batchSize=64)  # generator loss
adversarial_criterion = nn.BCELoss()

if torch.cuda.is_available():
    netG.cuda()
    netD.cuda()
    generator_criterion.cuda()

optimizerG = optim.Adam(netG.parameters())
optimizerD = optim.Adam(netD.parameters())

results = {
    'd_loss': [],
    'g_loss': [],
    'd_score': [],
    'g_score': [],
    'psnr': [],
    'ssim': []
}
Example #11
train_loader = DataLoader(dataset=train_set,
                          batch_size=BATCH_SIZE_TRAIN,
                          shuffle=True)
val_loader = DataLoader(dataset=val_set, batch_size=1, shuffle=False)

# Networks and Loss
netG = Generator(UPSCALE_FACTOR)
print('# generator parameters:',
      sum(param.numel() for param in netG.parameters()))
if USE_DISCRIMINATOR:
    netD = Discriminator()
    print('# discriminator parameters:',
          sum(param.numel() for param in netD.parameters()))

generator_criterion = GeneratorLoss(weight_perception=WEIGHT_PERCEPTION,
                                    weight_adversarial=WEIGHT_ADVERSARIAL,
                                    weight_image=WEIGHT_IMAGE,
                                    network=NETWORK)

if USE_CUDA:
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        netG = torch.nn.DataParallel(netG)
        # netG = torch.nn.ModelDataParallel(
        #     netG, device_ids=list(range(NUM_GPU)))
        if USE_DISCRIMINATOR:
            netD = torch.nn.DataParallel(netD)
            # netD = torch.nn.ModelDataParallel(
            #     netD, device_ids=list(range(NUM_GPU)))
    netG.cuda()
    if USE_DISCRIMINATOR:
        netD.cuda()
Example #12
 
CROP_SIZE = opt.crop_size
UPSCALE_FACTOR = opt.upscale_factor
NUM_EPOCHS = opt.num_epochs

train_set = TrainDatasetFromFolder('data/DIV2K_train_HR', crop_size=CROP_SIZE, upscale_factor=UPSCALE_FACTOR)
val_set = ValDatasetFromFolder('data/DIV2K_valid_HR', upscale_factor=UPSCALE_FACTOR)
train_loader = DataLoader(dataset=train_set, num_workers=4, batch_size=64, shuffle=True)
val_loader = DataLoader(dataset=val_set, num_workers=4, batch_size=1, shuffle=False)

netG = Generator(UPSCALE_FACTOR)
print('# generator parameters:', sum(param.numel() for param in netG.parameters()))
netD = Discriminator()
print('# discriminator parameters:', sum(param.numel() for param in netD.parameters()))

generator_criterion = GeneratorLoss(opt.loss_net)

if torch.cuda.is_available():
    netG.cuda()
    netD.cuda()
    generator_criterion.cuda()

optimizerG = optim.Adam(netG.parameters())
optimizerD = optim.Adam(netD.parameters())

results = {'d_loss': [], 'g_loss': [], 'd_score': [], 'g_score': [], 'psnr': [], 'ssim': []}

for epoch in range(1, NUM_EPOCHS + 1):
    train_bar = tqdm(train_loader)
    running_results = {'batch_sizes': 0, 'd_loss': 0, 'g_loss': 0, 'd_score': 0, 'g_score': 0}
Example #13
print('===> Loading validation dataset')
val_set = DatasetFromFolderEval(
    image_dir=Path(opt.data_root) / 'val',
    upscale_factor=opt.upscale_factor,
)
val_loader = DataLoader(val_set, shuffle=False)

print('===> Building model')
netG = Generator(opt.upscale_factor).to(device)
print('# Generator parameters:', sum(p.numel() for p in netG.parameters()))
netD = Discriminator(opt.patch_size).to(device)
print('# Discriminator parameters:', sum(p.numel() for p in netD.parameters()))

print('===> Defining criterions')
mse_loss = nn.MSELoss().to(device)
criterionG = GeneratorLoss(opt.loss_type, opt.adv_coefficient).to(device)
criterionD = DiscriminatorLoss().to(device)

print('===> Defining optimizers')
optimizerG = optim.Adam(netG.parameters(), lr=1e-4)
optimizerD = optim.Adam(netD.parameters(), lr=1e-4)

writer = SummaryWriter()
log_dir = Path(writer.log_dir)
sample_dir = log_dir / 'sample'
sample_dir.mkdir(exist_ok=True)
weight_dir = log_dir / 'weights'
weight_dir.mkdir(exist_ok=True)

global_step = 0
for epoch in range(1, opt.num_epochs + 1):
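The excerpt stops at the epoch loop. With the SummaryWriter and directories prepared above, a loop body would typically log scalars and periodic samples roughly as follows; lossG, lossD, and sr are placeholders for values computed in the loop:

# Hypothetical loop body using the writer created above:
writer.add_scalar('train/loss_G', lossG.item(), global_step)
writer.add_scalar('train/loss_D', lossD.item(), global_step)
if global_step % 100 == 0:
    # torchvision.utils.save_image; sample_dir is the Path created above
    save_image(sr, sample_dir / 'step_{}.png'.format(global_step))
global_step += 1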
Example #14
device = torch.device("cuda" if opt.cuda else "cpu")
seed = 1000
np.random.seed(seed)
torch.manual_seed(seed)

train_set = TrainDatasetFromFolder('data/VOC2012/train', crop_size=CROP_SIZE, upscale_factor=UPSCALE_FACTOR)
val_set = ValDatasetFromFolder('data/VOC2012/val', upscale_factor=UPSCALE_FACTOR)
train_loader = DataLoader(dataset=train_set, num_workers=4, batch_size=64, shuffle=True)
val_loader = DataLoader(dataset=val_set, num_workers=4, batch_size=1, shuffle=False)

netG = Generator(UPSCALE_FACTOR).to(device)
print('# generator parameters:', sum(param.numel() for param in netG.parameters()))
netD = Discriminator().to(device)
print('# discriminator parameters:', sum(param.numel() for param in netD.parameters()))

generator_criterion = GeneratorLoss().to(device)

optimizerG = optim.Adam(netG.parameters())
optimizerD = optim.Adam(netD.parameters())

results = {'d_loss': [], 'g_loss': [], 'd_score': [], 'g_score': [], 'psnr': [], 'ssim': []}

for epoch in range(1, NUM_EPOCHS + 1):
    train_bar = tqdm(train_loader)
    running_results = {'batch_sizes': 0, 'd_loss': 0, 'g_loss': 0, 'd_score': 0, 'g_score': 0}

    netG.train()
    netD.train()
    for data, target in train_bar:
        g_update_first = True
        batch_size = data.size(0)
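This example seeds NumPy and torch, which covers CPU-side randomness. For fuller reproducibility on CUDA, the usual additions are the standard PyTorch flags below (not part of the excerpt, and they can slow training):

torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False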
Example #15
class trainer(object):
    def __init__(self, cfg):
        self.cfg = cfg
        self.OldLabel_generator = U_Net(in_ch=cfg.DATASET.N_CLASS,
                                        out_ch=cfg.DATASET.N_CLASS,
                                        side='out')
        self.Image_generator = U_Net(in_ch=3,
                                     out_ch=cfg.DATASET.N_CLASS,
                                     side='in')
        self.discriminator = Discriminator(cfg.DATASET.N_CLASS + 3,
                                           cfg.DATASET.IMGSIZE,
                                           patch=True)

        self.criterion_G = GeneratorLoss(cfg.LOSS.LOSS_WEIGHT[0],
                                         cfg.LOSS.LOSS_WEIGHT[1],
                                         cfg.LOSS.LOSS_WEIGHT[2],
                                         ignore_index=cfg.LOSS.IGNORE_INDEX)
        self.criterion_D = DiscriminatorLoss()

        train_dataset = BaseDataset(cfg, split='train')
        valid_dataset = BaseDataset(cfg, split='val')
        self.train_dataloader = data.DataLoader(
            train_dataset,
            batch_size=cfg.DATASET.BATCHSIZE,
            num_workers=8,
            shuffle=True,
            drop_last=True)
        self.valid_dataloader = data.DataLoader(
            valid_dataset,
            batch_size=cfg.DATASET.BATCHSIZE,
            num_workers=8,
            shuffle=True,
            drop_last=True)

        self.ckpt_outdir = os.path.join(cfg.TRAIN.OUTDIR, 'checkpoints')
        if not os.path.isdir(self.ckpt_outdir):
            os.mkdir(self.ckpt_outdir)
        self.val_outdir = os.path.join(cfg.TRAIN.OUTDIR, 'val')
        if not os.path.isdir(self.val_outdir):
            os.mkdir(self.val_outdir)
        self.start_epoch = cfg.TRAIN.RESUME
        self.n_epoch = cfg.TRAIN.N_EPOCH

        self.optimizer_G = torch.optim.Adam(
            [{
                'params': self.OldLabel_generator.parameters()
            }, {
                'params': self.Image_generator.parameters()
            }],
            lr=cfg.OPTIMIZER.G_LR,
            betas=(cfg.OPTIMIZER.BETA1, cfg.OPTIMIZER.BETA2),
            # betas=(cfg.OPTIMIZER.BETA1, cfg.OPTIMIZER.BETA2),
            weight_decay=cfg.OPTIMIZER.WEIGHT_DECAY)

        self.optimizer_D = torch.optim.Adam(
            [{
                'params': self.discriminator.parameters(),
                'initial_lr': cfg.OPTIMIZER.D_LR
            }],
            lr=cfg.OPTIMIZER.D_LR,
            betas=(cfg.OPTIMIZER.BETA1, cfg.OPTIMIZER.BETA2),
            # betas=(cfg.OPTIMIZER.BETA1, cfg.OPTIMIZER.BETA2),
            weight_decay=cfg.OPTIMIZER.WEIGHT_DECAY)

        iter_per_epoch = len(train_dataset) // cfg.DATASET.BATCHSIZE
        lambda_poly = lambda iters: pow(
            (1.0 - iters / (cfg.TRAIN.N_EPOCH * iter_per_epoch)), 0.9)
        self.scheduler_G = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_G,
            lr_lambda=lambda_poly,
        )
        # last_epoch=(self.start_epoch+1)*iter_per_epoch)
        self.scheduler_D = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_D,
            lr_lambda=lambda_poly,
        )
        # last_epoch=(self.start_epoch+1)*iter_per_epoch)

        self.logger = logger(cfg.TRAIN.OUTDIR, name='train')
        self.running_metrics = runningScore(n_classes=cfg.DATASET.N_CLASS)

        if self.start_epoch >= 0:
            ckpt_path = os.path.join(cfg.TRAIN.OUTDIR, 'checkpoints',
                                     '{}epoch.pth'.format(self.start_epoch))
            # Load the checkpoint once instead of re-reading the file for
            # every component.
            checkpoint = torch.load(ckpt_path)
            self.OldLabel_generator.load_state_dict(checkpoint['model_G_N'])
            self.Image_generator.load_state_dict(checkpoint['model_G_I'])
            self.discriminator.load_state_dict(checkpoint['model_D'])
            self.optimizer_G.load_state_dict(checkpoint['optimizer_G'])
            self.optimizer_D.load_state_dict(checkpoint['optimizer_D'])

            log = "Using the {}th checkpoint".format(self.start_epoch)
            self.logger.info(log)
        self.Image_generator = self.Image_generator.cuda()
        self.OldLabel_generator = self.OldLabel_generator.cuda()
        self.discriminator = self.discriminator.cuda()
        self.criterion_G = self.criterion_G.cuda()
        self.criterion_D = self.criterion_D.cuda()

    def train(self):
        all_train_iter_total_loss = []
        all_train_iter_corr_loss = []
        all_train_iter_recover_loss = []
        all_train_iter_change_loss = []
        all_train_iter_gan_loss_gen = []
        all_train_iter_gan_loss_dis = []
        all_val_epo_iou = []
        all_val_epo_acc = []
        iter_num = [0]
        epoch_num = []
        num_batches = len(self.train_dataloader)

        for epoch_i in range(self.start_epoch + 1, self.n_epoch):
            iter_total_loss = AverageTracker()
            iter_corr_loss = AverageTracker()
            iter_recover_loss = AverageTracker()
            iter_change_loss = AverageTracker()
            iter_gan_loss_gen = AverageTracker()
            iter_gan_loss_dis = AverageTracker()
            batch_time = AverageTracker()
            tic = time.time()

            # train
            self.OldLabel_generator.train()
            self.Image_generator.train()
            self.discriminator.train()
            for i, meta in enumerate(self.train_dataloader):

                image, old_label, new_label = (meta[0].cuda(), meta[1].cuda(),
                                               meta[2].cuda())
                recover_pred, feats = self.OldLabel_generator(
                    label2onehot(old_label, self.cfg.DATASET.N_CLASS))
                corr_pred = self.Image_generator(image, feats)

                # -------------------
                # Train Discriminator
                # -------------------
                self.discriminator.set_requires_grad(True)
                self.optimizer_D.zero_grad()

                fake_sample = torch.cat((image, corr_pred), 1).detach()
                real_sample = torch.cat(
                    (image, label2onehot(new_label, self.cfg.DATASET.N_CLASS)), 1)

                score_fake_d = self.discriminator(fake_sample)
                score_real = self.discriminator(real_sample)

                gan_loss_dis = self.criterion_D(pred_score=score_fake_d,
                                                real_score=score_real)
                gan_loss_dis.backward()
                self.optimizer_D.step()
                self.scheduler_D.step()

                # ---------------
                # Train Generator
                # ---------------
                self.discriminator.set_requires_grad(False)
                self.optimizer_G.zero_grad()

                score_fake = self.discriminator(
                    torch.cat((image, corr_pred), 1))

                total_loss, corr_loss, recover_loss, change_loss, gan_loss_gen = self.criterion_G(
                    corr_pred, recover_pred, score_fake, old_label, new_label)

                total_loss.backward()
                self.optimizer_G.step()
                self.scheduler_G.step()

                iter_total_loss.update(total_loss.item())
                iter_corr_loss.update(corr_loss.item())
                iter_recover_loss.update(recover_loss.item())
                iter_change_loss.update(change_loss.item())
                iter_gan_loss_gen.update(gan_loss_gen.item())
                iter_gan_loss_dis.update(gan_loss_dis.item())
                batch_time.update(time.time() - tic)
                tic = time.time()

                log = ('{}: Epoch: [{}][{}/{}], Time: {:.2f}, '
                       'Total Loss: {:.6f}, Corr Loss: {:.6f}, '
                       'Recover Loss: {:.6f}, Change Loss: {:.6f}, '
                       'GAN_G Loss: {:.6f}, GAN_D Loss: {:.6f}').format(
                    datetime.now(), epoch_i, i, num_batches, batch_time.avg,
                    total_loss.item(), corr_loss.item(), recover_loss.item(),
                    change_loss.item(), gan_loss_gen.item(), gan_loss_dis.item())
                print(log)

                if (i + 1) % 10 == 0:
                    all_train_iter_total_loss.append(iter_total_loss.avg)
                    all_train_iter_corr_loss.append(iter_corr_loss.avg)
                    all_train_iter_recover_loss.append(iter_recover_loss.avg)
                    all_train_iter_change_loss.append(iter_change_loss.avg)
                    all_train_iter_gan_loss_gen.append(iter_gan_loss_gen.avg)
                    all_train_iter_gan_loss_dis.append(iter_gan_loss_dis.avg)
                    iter_total_loss.reset()
                    iter_corr_loss.reset()
                    iter_recover_loss.reset()
                    iter_change_loss.reset()
                    iter_gan_loss_gen.reset()
                    iter_gan_loss_dis.reset()

                    vis.line(X=np.column_stack(
                        np.repeat(np.expand_dims(iter_num, 0), 6, axis=0)),
                             Y=np.column_stack((all_train_iter_total_loss,
                                                all_train_iter_corr_loss,
                                                all_train_iter_recover_loss,
                                                all_train_iter_change_loss,
                                                all_train_iter_gan_loss_gen,
                                                all_train_iter_gan_loss_dis)),
                             opts={
                                 'legend': [
                                     'total_loss', 'corr_loss', 'recover_loss',
                                     'change_loss', 'gan_loss_gen',
                                     'gan_loss_dis'
                                 ],
                                 'linecolor':
                                 np.array([[255, 0, 0], [0, 255, 0],
                                           [0, 0, 255], [255, 255, 0],
                                           [0, 255, 255], [255, 0, 255]]),
                                 'title':
                                 'Train loss of generator and discriminator'
                             },
                             win='Train loss of generator and discriminator')
                    iter_num.append(iter_num[-1] + 1)

            # eval
            self.OldLabel_generator.eval()
            self.Image_generator.eval()
            self.discriminator.eval()
            with torch.no_grad():
                for j, meta in enumerate(self.valid_dataloader):
                    image, old_label, new_label = (meta[0].cuda(),
                                                   meta[1].cuda(),
                                                   meta[2].cuda())
                    recover_pred, feats = self.OldLabel_generator(
                        label2onehot(old_label, self.cfg.DATASET.N_CLASS))
                    corr_pred = self.Image_generator(image, feats)
                    preds = np.argmax(corr_pred.cpu().detach().numpy().copy(),
                                      axis=1)
                    target = new_label.cpu().detach().numpy().copy()
                    self.running_metrics.update(target, preds)

                    if j == 0:
                        color_map1 = gen_color_map(preds[0, :]).astype(
                            np.uint8)
                        color_map2 = gen_color_map(preds[1, :]).astype(
                            np.uint8)
                        color_map = cv2.hconcat([color_map1, color_map2])
                        cv2.imwrite(
                            os.path.join(
                                self.val_outdir, '{}epoch*{}*{}.png'.format(
                                    epoch_i, meta[3][0], meta[3][1])),
                            color_map)

            score = self.running_metrics.get_scores()
            oa = score['Overall Acc: \t']
            precision = score['Precision: \t'][1]
            recall = score['Recall: \t'][1]
            iou = score['Class IoU: \t'][1]
            miou = score['Mean IoU: \t']
            self.running_metrics.reset()

            epoch_num.append(epoch_i)
            all_val_epo_acc.append(oa)
            all_val_epo_iou.append(miou)
            vis.line(X=np.column_stack(
                np.repeat(np.expand_dims(epoch_num, 0), 2, axis=0)),
                     Y=np.column_stack((all_val_epo_acc, all_val_epo_iou)),
                     opts={
                         'legend':
                         ['val epoch Overall Acc', 'val epoch Mean IoU'],
                         'linecolor': np.array([[255, 0, 0], [0, 255, 0]]),
                         'title': 'Validate Accuracy and IoU'
                     },
                     win='validate Accuracy and IoU')

            log = '{}: Epoch Val: [{}], ACC: {:.2f}, Recall: {:.2f}, mIoU: {:.4f}' \
                .format(datetime.now(), epoch_i, oa, recall, miou)
            self.logger.info(log)

            state = {
                'epoch': epoch_i,
                "acc": oa,
                "recall": recall,
                "iou": miou,
                'model_G_N': self.OldLabel_generator.state_dict(),
                'model_G_I': self.Image_generator.state_dict(),
                'model_D': self.discriminator.state_dict(),
                'optimizer_G': self.optimizer_G.state_dict(),
                'optimizer_D': self.optimizer_D.state_dict()
            }
            save_path = os.path.join(self.cfg.TRAIN.OUTDIR, 'checkpoints',
                                     '{}epoch.pth'.format(epoch_i))
            torch.save(state, save_path)
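Note: the helpers label2onehot and gen_color_map called in the evaluation loop above are not part of this example. A minimal sketch of what they likely do, inferred from the call sites (the tensor shapes and the two-color palette are assumptions, not taken from the source):

import numpy as np
import torch

def label2onehot(label, n_class):
    # (N, H, W) integer class map -> (N, n_class, H, W) one-hot float tensor.
    onehot = torch.zeros(label.size(0), n_class, label.size(1), label.size(2),
                         device=label.device)
    return onehot.scatter_(1, label.unsqueeze(1).long(), 1.0)

def gen_color_map(pred):
    # (H, W) class-index map -> (H, W, 3) uint8 color image via a fixed palette.
    palette = np.array([[0, 0, 0], [255, 255, 255]], dtype=np.uint8)
    return palette[pred % len(palette)]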
Example #17
0
def main(step, dataset, data_dir, data_dir_bias, model_name):

    global args, model, netContent, lr

    args = parser.parse_args()
    lr = args.lr
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    netG = GeneratorMM(args.upscale_factor)
    n_parameters = sum([p.data.nelement() for p in netG.parameters()])
    print('  + Number of params: {}'.format(n_parameters))

    netD = DiscriminatorMM()
    n_parameters = sum([p.data.nelement() for p in netD.parameters()])
    print('  + Number of params: {}'.format(n_parameters))

    generator_criterion = GeneratorLoss()
    netG.set_multiple_gpus()
    netD.set_multiple_gpus()
    if step > 0:
        model_dir = data_dir + '/model/modelG_' + str(step) + '.pkl'
        netG.load_state_dict(torch.load(model_dir))

        model_dir = data_dir + '/model/modelD_' + str(step) + '.pkl'
        netD.load_state_dict(torch.load(model_dir))
    if args.cuda:
        netG = netG.cuda()
        netD = netD.cuda()
        generator_criterion = generator_criterion.cuda()
    cudnn.benchmark = True

    optimizerG = optim.Adam(netG.parameters(), lr=args.lr, betas=(0.9, 0.999))
    optimizerD = optim.Adam(netD.parameters(), lr=args.lr, betas=(0.9, 0.999))

    # Load the dataset
    kwargs = {'num_workers': 0, 'pin_memory': True} if args.cuda else {}

    train_loader = torch.utils.data.DataLoader(
        ShepardMetzler(root_dir=data_dir_bias + '/torch_super/' + model_name +
                       '/train/bias_0/'),
        batch_size=args.batch_size,
        shuffle=True,
        **kwargs)
    test_loader = torch.utils.data.DataLoader(
        ShepardMetzler(root_dir=data_dir_bias + '/torch_super/' + model_name +
                       '/test/bias_0/'),
        batch_size=args.test_batch_size,
        shuffle=False,
        **kwargs)

    lRecord = []
    generator_loss_train = []
    discriminator_loss_train = []
    a_loss_train = []
    p_loss_train = []
    i_loss_train = []
    t_loss_train = []

    generator_loss_test = []
    discriminator_loss_test = []
    a_loss_test = []
    p_loss_test = []
    i_loss_test = []
    t_loss_test = []
    start = 0
    for epoch in range(step + 1, args.epochs + step + 1):
        generator_loss, a_loss, p_loss, i_loss, t_loss, discriminator_loss = train(
            train_loader, optimizerG, optimizerD, netG, netD,
            generator_criterion, epoch, lRecord)
        generator_loss_train.append(generator_loss)
        a_loss_train.append(a_loss)
        p_loss_train.append(p_loss)
        i_loss_train.append(i_loss)
        t_loss_train.append(t_loss)
        discriminator_loss_train.append(discriminator_loss)

        lr = adjust_learning_rate(optimizerG, epoch - 1)
        for param_group in optimizerG.param_groups:
            param_group["lr"] = lr

        lr = adjust_learning_rate(optimizerD, epoch - 1)
        for param_group in optimizerD.param_groups:
            param_group["lr"] = lr

        if epoch % args.log_interval_test == 0:
            test_dir = data_dir + '/test/' + 'model' + str(
                epoch) + '_scene' + str(start + 1) + '/'
            if not os.path.exists(test_dir):
                os.makedirs(test_dir)

            generator_loss, a_loss, p_loss, i_loss, t_loss, discriminator_loss = test(
                netG, netD, start, test_loader, epoch, generator_criterion,
                lRecord, test_dir)
            start = (start + 1) % len(test_loader)
            generator_loss_test.append(generator_loss)
            a_loss_test.append(a_loss)
            p_loss_test.append(p_loss)
            i_loss_test.append(i_loss)
            t_loss_test.append(t_loss)
            discriminator_loss_test.append(discriminator_loss)

        if epoch % args.log_interval_record == 0:
            SaveRecord(data_dir, epoch, netG, netD, generator_loss_train,
                       a_loss_train, p_loss_train, i_loss_train, t_loss_train,
                       discriminator_loss_train, generator_loss_test,
                       a_loss_test, p_loss_test, i_loss_test, t_loss_test,
                       discriminator_loss_test, lRecord)
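Note: adjust_learning_rate is not defined in this example. A common step-decay sketch consistent with how it is called above (the base rate, decay factor, and step size are assumptions):

def adjust_learning_rate(optimizer, epoch, base_lr=1e-4, decay=0.5, step=30):
    # Step decay: halve the learning rate every `step` epochs.
    lr = base_lr * (decay ** (epoch // step))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr

With a variant like this, the explicit param_group update in the caller is redundant, since the helper already writes the new rate into the optimizer.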
Example #18
0
def main_train(path_trn: str,
               path_val: str,
               crop_size: int,
               upscale_factor: int,
               num_epochs: int,
               num_workers: int,
               to_device: str = 'cuda:0',
               in_memory_trn: bool = False,
               in_memory_val: bool = False,
               batch_size: int = 64,
               step_val: int = 5):
    out_dir = path_trn + '_results_c{}_s{}'.format(crop_size, upscale_factor)
    out_dir_states = out_dir + '_states'
    out_dir_statistics = out_dir + '_statistics'
    os.makedirs(out_dir, exist_ok=True)
    os.makedirs(out_dir_states, exist_ok=True)
    os.makedirs(out_dir_statistics, exist_ok=True)
    path_results_csv = os.path.join(
        out_dir, 'statistics_x{}_train_results.csv'.format(upscale_factor))
    #
    to_device = get_device(to_device)
    train_set = DatasetExtTrn(path_idx=path_trn,
                              crop_lr=crop_size,
                              scale=upscale_factor,
                              in_memory=in_memory_trn).build()
    val_set = DatasetExtVal(path_idx=path_val,
                            crop_lr=crop_size,
                            scale=upscale_factor,
                            in_memory=in_memory_val).build()
    #
    train_loader = DataLoader(dataset=train_set,
                              num_workers=num_workers,
                              batch_size=batch_size,
                              shuffle=True)
    val_loader = DataLoader(dataset=val_set,
                            num_workers=num_workers,
                            batch_size=1,
                            shuffle=False)
    #
    netG = Generator(upscale_factor).to(to_device)
    print('# generator parameters:',
          sum(param.numel() for param in netG.parameters()))
    netD = Discriminator().to(to_device)
    print('# discriminator parameters:',
          sum(param.numel() for param in netD.parameters()))
    generator_criterion = GeneratorLoss().to(to_device)
    #
    optimizerG = optim.Adam(netG.parameters())
    optimizerD = optim.Adam(netD.parameters())
    # results = {'d_loss': [], 'g_loss': [], 'd_score': [], 'g_score': [], 'psnr': [], 'ssim': []}
    for epoch in range(1, num_epochs + 1):
        results_train = train_step(train_loader, netD, netG, optimizerD,
                                   optimizerG, generator_criterion, epoch,
                                   num_epochs)
        # FIXME: separate function for epoch training
        if (epoch % step_val) == 0:
            results_validation = validation_step(val_loader, netG, out_dir,
                                                 epoch, num_epochs)
            results_save = {**results_train, **results_validation}
            results_save['epoch'] = epoch
            export_results(results_save, path_results_csv)
            # export model
            # save model parameters
            path_state_G = os.path.join(
                out_dir_states,
                'netG_epoch_x{}_{:05d}.pth'.format(upscale_factor, epoch))
            path_state_D = os.path.join(
                out_dir_states,
                'netD_epoch_x{}_{:05d}.pth'.format(upscale_factor, epoch))
            t1 = time.time()
            torch.save(netG.state_dict(), path_state_G)
            torch.save(netD.state_dict(), path_state_D)
            dt = time.time() - t1
            print(
                '\t\t:: dump:generator-model to [{}], dt ~ {:0.2f} (s)'.format(
                    path_state_G, dt))
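Note: export_results is not shown in this example. A minimal sketch that matches its usage above, appending one metrics row per validation step to the CSV (the header handling is an assumption):

import csv
import os

def export_results(results: dict, path_csv: str):
    # Append one row of metrics; write the header only on first use.
    write_header = not os.path.isfile(path_csv)
    with open(path_csv, 'a', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=sorted(results.keys()))
        if write_header:
            writer.writeheader()
        writer.writerow(results)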
Example #19
0
def train_lambda_class():
    data_dir_lr = '../data/split_dataset/'
    data_dir_hr = '../data/split_dataset_SRGAN_' + str(UPSCALE_FACTOR) + os.sep
    train_set = ImageFolderWithPaths_train(data_dir_hr + "train" + os.sep,
                                           data_dir_lr + "train", HR_SIZE,
                                           UPSCALE_FACTOR)
    val_set = ImageFolderWithPaths_val(data_dir_hr + "test" + os.sep,
                                       data_dir_lr + "test", HR_SIZE,
                                       UPSCALE_FACTOR)

    train_loader = DataLoader(dataset=train_set,
                              batch_size=1,
                              shuffle=True,
                              num_workers=4)
    val_loader = DataLoader(dataset=val_set,
                            num_workers=4,
                            batch_size=1,
                            shuffle=False)
    class_names = train_set.classes
    print(train_set.classes)
    print(val_set.classes)

    netD_weights = "epochs/weights_halfPipe/weights_"+str(UPSCALE_FACTOR)\
                   + "_dataAug/netD_dataAug_epoch_"+str(UPSCALE_FACTOR)+"_050.pth"
    netG_weights = "epochs/weights_halfPipe/weights_" + str(UPSCALE_FACTOR) \
                   + "_dataAug/netG_dataAug_epoch_" + str(UPSCALE_FACTOR) + "_050.pth"
    # netD_weights = "epochs/weights_halfPipe/weights_" + str(UPSCALE_FACTOR) + "_dataAug/best_netD.pth"
    # netG_weights = "epochs/weights_halfPipe/weights_" + str(UPSCALE_FACTOR) + "_dataAug/best_netG.pth"

    netG = Generator(UPSCALE_FACTOR)
    netG.load_state_dict(torch.load(netG_weights))
    netD = Discriminator()
    netD.load_state_dict(torch.load(netD_weights))

    generator_criterion = GeneratorLoss()

    if torch.cuda.is_available():
        netG.cuda()
        netD.cuda()
        generator_criterion.cuda()

    optimizerG = optim.Adam(netG.parameters())
    optimizerD = optim.Adam(netD.parameters())

    ####CLASSIFIER
    classifier = models.resnet50(pretrained=True)
    num_ftrs = classifier.fc.in_features
    classifier.fc = nn.Linear(num_ftrs, len(class_names))
    classifier.name = 'resnet50'
    classifier.cuda()
    # if UPSCALE_FACTOR == 1:
    #     weights_class_path = "/home/mrey/ESA/pruebas/multiclass_classification/weights/best_model.pth"
    # else:
    #     weights_class_path = "/home/mrey/ESA/pruebas/multiclass_classification/weights/best_model_SRGAN_"\
    #                          + str(UPSCALE_FACTOR)+".pth"
    weights_class_path = "../data/multiclass_classification/weights/"+\
                         classifier.name+"_best_model_kfold.pth"
    classifier.load_state_dict(torch.load(weights_class_path))
    criterion_classifier = nn.CrossEntropyLoss()
    classifier.eval()
    print(classifier.name)
    print("upscale factor %d" % UPSCALE_FACTOR)
    print("HR size %d" % HR_SIZE)
    #############

    best_netG = copy.deepcopy(netG.state_dict())
    best_netD = copy.deepcopy(netD.state_dict())
    old_PSNR = 0
    old_SSIM = 0
    basic_netG = copy.deepcopy(netG.state_dict())
    basic_netD = copy.deepcopy(netD.state_dict())

    miss_classifications = []

    lambda_values = [2, 1, 0.1, 0.01]

    for lambda_class in lambda_values:
        results = {
            'd_loss': [],
            'g_loss': [],
            'd_score': [],
            'g_score': [],
            'psnr': [],
            'ssim': []
        }
        netG.load_state_dict(basic_netG)
        netD.load_state_dict(basic_netD)
        print("CHOSE LAMBDA {}".format(lambda_class))
        for epoch in range(1, NUM_EPOCHS + 1):
            train_bar = tqdm(train_loader)
            running_results = {
                'batch_sizes': 0,
                'd_loss': 0,
                'g_loss': 0,
                'd_score': 0,
                'g_score': 0
            }

            netG.train()
            netD.train()
            for data, target, label in train_bar:

                # print('')
                # print([data_i.shape for data_i in data.data])
                # print([data_i.shape for data_i in target.data])
                g_update_first = True
                batch_size = data.size(0)
                running_results['batch_sizes'] += batch_size

                ############################
                # (1) Update D network: maximize D(x)-1-D(G(z))
                ###########################
                real_img = Variable(target)
                if torch.cuda.is_available():
                    real_img = real_img.cuda()
                    # `device` is not defined in this snippet; use .cuda() like
                    # the surrounding code.
                    label = label.cuda()
                z = Variable(data)
                if torch.cuda.is_available():
                    z = z.cuda()
                fake_img = netG(z)

                classifier_outputs = classifier(z)
                _, preds = torch.max(classifier_outputs, 1)
                if preds != label:  # scalar compare is safe: batch_size is 1 here
                    miss_classifications.append([preds, label])
                    # print(str(preds)+"--"+str(label))
                loss_classifier = criterion_classifier(classifier_outputs,
                                                       label)

                netD.zero_grad()
                real_out = netD(real_img).mean()
                fake_out = netD(fake_img).mean()
                d_loss = 1 - real_out + fake_out
                d_loss.backward(retain_graph=True)
                optimizerD.step()

                ############################
                # (2) Update G network: minimize 1-D(G(z)) + Perception Loss + Image Loss + TV Loss
                ###########################
                netG.zero_grad()
                g_loss = generator_criterion(fake_out, fake_img, real_img,
                                             loss_classifier,
                                             float(lambda_class))
                g_loss.backward()
                optimizerG.step()
                fake_img = netG(z)
                fake_out = netD(fake_img).mean()

                g_loss = generator_criterion(fake_out, fake_img, real_img,
                                             loss_classifier,
                                             float(lambda_class))
                running_results['g_loss'] += g_loss.data * batch_size
                d_loss = 1 - real_out + fake_out
                running_results['d_loss'] += d_loss.data * batch_size
                running_results['d_score'] += real_out.data * batch_size
                running_results['g_score'] += fake_out.data * batch_size

                train_bar.set_description(
                    desc=
                    '[%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f'
                    % (epoch, NUM_EPOCHS, running_results['d_loss'] /
                       running_results['batch_sizes'],
                       running_results['g_loss'] /
                       running_results['batch_sizes'],
                       running_results['d_score'] /
                       running_results['batch_sizes'],
                       running_results['g_score'] /
                       running_results['batch_sizes']))

            netG.eval()
            out_path = 'training_results/SRF_' + str(UPSCALE_FACTOR) + '/'
            if not os.path.exists(out_path):
                os.makedirs(out_path)
            val_bar = tqdm(val_loader)
            valing_results = {
                'mse': 0,
                'ssims': 0,
                'psnr': 0,
                'ssim': 0,
                'batch_sizes': 0
            }
            val_images = []
            for val_lr, val_hr_restore, val_hr in val_bar:
                batch_size = val_lr.size(0)
                valing_results['batch_sizes'] += batch_size
                # Variable(..., volatile=True) is deprecated; torch.no_grad()
                # now disables autograd for the validation forward pass.
                with torch.no_grad():
                    lr = val_lr
                    hr = val_hr
                    if torch.cuda.is_available():
                        lr = lr.cuda()
                        hr = hr.cuda()
                    sr = netG(lr)

                batch_mse = ((sr - hr)**2).data.mean()
                valing_results['mse'] += batch_mse * batch_size
                batch_ssim = pytorch_ssim.ssim(sr, hr).data
                valing_results['ssims'] += batch_ssim * batch_size
                # PSNR = 10*log10(MAX^2/MSE); pixels are assumed in [0, 1], so MAX = 1.
                valing_results['psnr'] = 10 * log10(
                    1 /
                    (valing_results['mse'] / valing_results['batch_sizes']))
                valing_results['ssim'] = valing_results[
                    'ssims'] / valing_results['batch_sizes']
                val_bar.set_description(
                    desc=
                    '[converting LR images to SR images] PSNR: %.4f dB SSIM: %.4f'
                    % (valing_results['psnr'], valing_results['ssim']))

                val_images.extend([
                    display_transform()(val_hr_restore.squeeze(0)),
                    display_transform()(hr.data.cpu().squeeze(0)),
                    display_transform()(sr.data.cpu().squeeze(0))
                ])
            # val_images = torch.stack(val_images)
            # val_images = torch.chunk(val_images, val_images.size(0) // 15)
            # val_save_bar = tqdm(val_images, desc='[saving training results]')
            # index = 1
            # for image in val_save_bar:
            #     image = utils.make_grid(image, nrow=3, padding=5)
            #     if DATA_AUG:
            #         utils.save_image(image, out_path + 'dataAug_epoch_%d_index_%d.png' % (epoch, index), padding=5)
            #     else:
            #         utils.save_image(image, out_path + 'epoch_%d_index_%d.png' % (epoch, index), padding=5)
            #     index += 1

            # SAVE model parameters
            if epoch % 5 == 0 and epoch != 0:
                out_folder = 'epochs/weights_' + str(UPSCALE_FACTOR)+'_'+str(classifier.name)+'_lambda' + \
                             str(lambda_class)+'_wholePipe/'
                if not os.path.exists(out_folder):
                    os.makedirs(out_folder)
                torch.save(
                    netG.state_dict(), out_folder + 'netG_epoch_%d_%03d.pth' %
                    (UPSCALE_FACTOR, epoch))
                torch.save(
                    netD.state_dict(), out_folder + 'netD_epoch_%d_%03d.pth' %
                    (UPSCALE_FACTOR, epoch))
            # save loss\scores\psnr\ssim
            results['d_loss'].append(running_results['d_loss'] /
                                     running_results['batch_sizes'])
            results['g_loss'].append(running_results['g_loss'] /
                                     running_results['batch_sizes'])
            results['d_score'].append(running_results['d_score'] /
                                      running_results['batch_sizes'])
            results['g_score'].append(running_results['g_score'] /
                                      running_results['batch_sizes'])
            results['psnr'].append(valing_results['psnr'])
            results['ssim'].append(valing_results['ssim'])

            if epoch % 5 == 0 and epoch != 0:
                out_path = 'statistics_wholePipe/'
                if not os.path.exists(out_path):
                    os.makedirs(out_path)
                data_frame = pd.DataFrame(data={
                    'Loss_D': results['d_loss'],
                    'Loss_G': results['g_loss'],
                    'Score_D': results['d_score'],
                    'Score_G': results['g_score'],
                    'PSNR': results['psnr'],
                    'SSIM': results['ssim']
                },
                                          index=range(1, epoch + 1))
                data_frame.to_csv(out_path + 'srf_' + str(UPSCALE_FACTOR) +
                                  '_' + str(classifier.name) + '_lambda' +
                                  str(lambda_class) + '_train_results.csv',
                                  index_label='Epoch')
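Note: the five-argument GeneratorLoss used in this example extends the usual SRGAN generator loss (image + adversarial + perceptual + TV terms, per the comment in the training loop) with a weighted classification term. A sketch under that assumption; the 0.001 / 0.006 / 2e-8 weights follow the common SRGAN recipe and are not confirmed by this source:

import torch
import torch.nn as nn
from torchvision.models import vgg16

class GeneratorLoss(nn.Module):
    def __init__(self):
        super(GeneratorLoss, self).__init__()
        vgg = vgg16(pretrained=True)
        self.loss_network = nn.Sequential(*list(vgg.features)[:31]).eval()
        for param in self.loss_network.parameters():
            param.requires_grad = False
        self.mse_loss = nn.MSELoss()

    def forward(self, fake_out, fake_img, real_img, loss_classifier,
                lambda_class):
        adversarial_loss = torch.mean(1 - fake_out)     # push D(G(z)) towards 1
        perception_loss = self.mse_loss(self.loss_network(fake_img),
                                        self.loss_network(real_img))
        image_loss = self.mse_loss(fake_img, real_img)  # pixel-wise MSE
        tv_loss = self.tv_loss(fake_img)                # smoothness prior
        return (image_loss + 0.001 * adversarial_loss +
                0.006 * perception_loss + tv_loss +
                lambda_class * loss_classifier)

    @staticmethod
    def tv_loss(x, weight=2e-8):
        # Total variation over height and width differences.
        dh = torch.pow(x[:, :, 1:, :] - x[:, :, :-1, :], 2).mean()
        dw = torch.pow(x[:, :, :, 1:] - x[:, :, :, :-1], 2).mean()
        return weight * (dh + dw)

Passing loss_classifier=0 and lambda_class=0.0, as train_half_pipe below does, reduces this to the plain SRGAN generator loss.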
Example #20
0
train_loader = DataLoader(dataset=train_set,
                          num_workers=4,
                          batch_size=64,
                          shuffle=True)
val_loader = DataLoader(dataset=val_set,
                        num_workers=4,
                        batch_size=1,
                        shuffle=False)

netG = Generator(UPSCALE_FACTOR)
print('# generator parameters:',
      sum(param.numel() for param in netG.parameters()))
netD = Discriminator()
print('# discriminator parameters:',
      sum(param.numel() for param in netD.parameters()))

generator_criterion = GeneratorLoss()

if torch.cuda.is_available():
    netG.cuda()
    netD.cuda()
    generator_criterion.cuda()

optimizerG = optim.Adam(netG.parameters())
optimizerD = optim.Adam(netD.parameters())

results = {
    'd_loss': [],
    'g_loss': [],
    'd_score': [],
    'g_score': [],
    'psnr': [],
    'ssim': []
}
Example #21
0
def train_half_pipe():
    if DATA_AUG:
        train_set = TrainDatasetFromFolder('../data/split_SRdataset/train',
                                           hr_size=HR_SIZE,
                                           upscale_factor=UPSCALE_FACTOR)
    else:
        train_set = TrainDatasetFromFolder(
            '/home/mrey/ESA/Dataset/Step2-SuperresolutionWhale/converted_jpg',
            hr_size=HR_SIZE,
            upscale_factor=UPSCALE_FACTOR)
    val_set = ValDatasetFromFolder('../data/split_SRdataset/test',
                                   hr_size=HR_SIZE,
                                   upscale_factor=UPSCALE_FACTOR)
    train_loader = DataLoader(dataset=train_set,
                              num_workers=4,
                              batch_size=BATCH_SIZE,
                              shuffle=True)
    val_loader = DataLoader(dataset=val_set,
                            num_workers=4,
                            batch_size=1,
                            shuffle=False)

    netG = Generator(UPSCALE_FACTOR)
    print('# generator parameters:',
          sum(param.numel() for param in netG.parameters()))
    netD = Discriminator()
    print('# discriminator parameters:',
          sum(param.numel() for param in netD.parameters()))

    generator_criterion = GeneratorLoss()

    if torch.cuda.is_available():
        netG.cuda()
        netD.cuda()
        generator_criterion.cuda()

    optimizerG = optim.Adam(netG.parameters())
    optimizerD = optim.Adam(netD.parameters())

    best_netG = copy.deepcopy(netG.state_dict())
    best_netD = copy.deepcopy(netD.state_dict())
    old_PSNR = 0.0
    old_SSIM = 0.0

    results = {
        'd_loss': [],
        'g_loss': [],
        'd_score': [],
        'g_score': [],
        'psnr': [],
        'ssim': []
    }

    for epoch in range(1, NUM_EPOCHS + 1):
        train_bar = tqdm(train_loader)
        running_results = {
            'batch_sizes': 0,
            'd_loss': 0,
            'g_loss': 0,
            'd_score': 0,
            'g_score': 0
        }

        netG.train()
        netD.train()
        for data, target in train_bar:

            # print('')
            # print([data_i.shape for data_i in data.data])
            # print([data_i.shape for data_i in target.data])
            g_update_first = True
            batch_size = data.size(0)
            running_results['batch_sizes'] += batch_size

            ############################
            # (1) Update D network: maximize D(x)-1-D(G(z))
            ###########################
            real_img = Variable(target)
            if torch.cuda.is_available():
                real_img = real_img.cuda()
            z = Variable(data)
            if torch.cuda.is_available():
                z = z.cuda()
            fake_img = netG(z)

            netD.zero_grad()
            real_out = netD(real_img).mean()
            fake_out = netD(fake_img).mean()
            d_loss = 1 - real_out + fake_out
            d_loss.backward(retain_graph=True)
            optimizerD.step()

            ############################
            # (2) Update G network: minimize 1-D(G(z)) + Perception Loss + Image Loss + TV Loss
            ###########################
            netG.zero_grad()
            g_loss = generator_criterion(fake_out, fake_img, real_img, 0, 0.0)
            g_loss.backward()
            optimizerG.step()
            fake_img = netG(z)
            fake_out = netD(fake_img).mean()

            g_loss = generator_criterion(fake_out, fake_img, real_img, 0, 0.0)
            running_results['g_loss'] += g_loss.data * batch_size
            d_loss = 1 - real_out + fake_out
            running_results['d_loss'] += d_loss.data * batch_size
            running_results['d_score'] += real_out.data * batch_size
            running_results['g_score'] += fake_out.data * batch_size

            train_bar.set_description(
                desc=
                '[%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f' %
                (epoch, NUM_EPOCHS,
                 running_results['d_loss'] / running_results['batch_sizes'],
                 running_results['g_loss'] / running_results['batch_sizes'],
                 running_results['d_score'] / running_results['batch_sizes'],
                 running_results['g_score'] / running_results['batch_sizes']))

        netG.eval()
        out_path = 'training_results/SRF_' + str(UPSCALE_FACTOR) + '/'
        if not os.path.exists(out_path):
            os.makedirs(out_path)
        val_bar = tqdm(val_loader)
        valing_results = {
            'mse': 0,
            'ssims': 0,
            'psnr': 0,
            'ssim': 0,
            'batch_sizes': 0
        }
        val_images = []
        for val_lr, val_hr_restore, val_hr in val_bar:
            batch_size = val_lr.size(0)
            valing_results['batch_sizes'] += batch_size
            # Variable(..., volatile=True) is deprecated; torch.no_grad()
            # now disables autograd for the validation forward pass.
            with torch.no_grad():
                lr = val_lr
                hr = val_hr
                if torch.cuda.is_available():
                    lr = lr.cuda()
                    hr = hr.cuda()
                sr = netG(lr)

            batch_mse = ((sr - hr)**2).data.mean()
            valing_results['mse'] += batch_mse * batch_size
            batch_ssim = pytorch_ssim.ssim(sr, hr).data
            valing_results['ssims'] += batch_ssim * batch_size
            valing_results['psnr'] = 10 * log10(
                1 / (valing_results['mse'] / valing_results['batch_sizes']))
            valing_results['ssim'] = valing_results['ssims'] / valing_results[
                'batch_sizes']
            val_bar.set_description(
                desc=
                '[converting LR images to SR images] PSNR: %.4f dB SSIM: %.4f'
                % (valing_results['psnr'], valing_results['ssim']))

            val_images.extend([
                display_transform()(val_hr_restore.squeeze(0)),
                display_transform()(hr.data.cpu().squeeze(0)),
                display_transform()(sr.data.cpu().squeeze(0))
            ])
        # val_images = torch.stack(val_images)
        # val_images = torch.chunk(val_images, val_images.size(0) // 15)
        # val_save_bar = tqdm(val_images, desc='[saving training results]')
        # index = 1
        # for image in val_save_bar:
        #     image = utils.make_grid(image, nrow=3, padding=5)
        #     if DATA_AUG:
        #         utils.save_image(image, out_path + 'dataAug_epoch_%d_index_%d.png' % (epoch, index), padding=5)
        #     else:
        #         utils.save_image(image, out_path + 'epoch_%d_index_%d.png' % (epoch, index), padding=5)
        #     index += 1

        # SAVE model parameters
        if valing_results['psnr'] > old_PSNR and valing_results[
                'ssim'] > old_SSIM:
            old_PSNR = valing_results['psnr']
            old_SSIM = valing_results['ssim']
            best_netG = copy.deepcopy(netG.state_dict())
            best_netD = copy.deepcopy(netD.state_dict())
            out_folder = 'epochs/weights_' + str(
                UPSCALE_FACTOR) + '_dataAug_halfPipe/'
            if not os.path.exists(out_folder):
                os.makedirs(out_folder)
            torch.save(best_netG, out_folder + 'best_netG.pth')
            torch.save(best_netD, out_folder + 'best_netD.pth')

        if epoch % 10 == 0 and epoch != 0:
            if DATA_AUG:
                out_folder = 'epochs/weights_' + str(
                    UPSCALE_FACTOR) + '_dataAug_halfPipe/'
                if not os.path.exists(out_folder):
                    os.makedirs(out_folder)
                torch.save(
                    netG.state_dict(), out_folder +
                    'netG_dataAug_epoch_%d_%03d.pth' % (UPSCALE_FACTOR, epoch))
                torch.save(
                    netD.state_dict(), out_folder +
                    'netD_dataAug_epoch_%d_%03d.pth' % (UPSCALE_FACTOR, epoch))
            else:
                out_folder = 'epochs/weights_' + str(
                    UPSCALE_FACTOR) + '_halfPipe/'
                if not os.path.exists(out_folder):
                    os.makedirs(out_folder)
                torch.save(
                    netG.state_dict(), out_folder + 'netG_epoch_%d_%03d.pth' %
                    (UPSCALE_FACTOR, epoch))
                torch.save(
                    netD.state_dict(), out_folder + 'netD_epoch_%d_%03d.pth' %
                    (UPSCALE_FACTOR, epoch))
        # save loss\scores\psnr\ssim
        results['d_loss'].append(running_results['d_loss'] /
                                 running_results['batch_sizes'])
        results['g_loss'].append(running_results['g_loss'] /
                                 running_results['batch_sizes'])
        results['d_score'].append(running_results['d_score'] /
                                  running_results['batch_sizes'])
        results['g_score'].append(running_results['g_score'] /
                                  running_results['batch_sizes'])
        results['psnr'].append(valing_results['psnr'])
        results['ssim'].append(valing_results['ssim'])

        if epoch % 10 == 0 and epoch != 0:
            out_path = 'statistics_halfPipe/'
            if not os.path.exists(out_path):
                os.makedirs(out_path)
            data_frame = pd.DataFrame(data={
                'Loss_D': results['d_loss'],
                'Loss_G': results['g_loss'],
                'Score_D': results['d_score'],
                'Score_G': results['g_score'],
                'PSNR': results['psnr'],
                'SSIM': results['ssim']
            },
                                      index=range(1, epoch + 1))
            if DATA_AUG:
                data_frame.to_csv(out_path + 'srf_' + str(UPSCALE_FACTOR) +
                                  '_dataAug_train_results.csv',
                                  index_label='Epoch')
            else:
                data_frame.to_csv(out_path + 'srf_' + str(UPSCALE_FACTOR) +
                                  '_train_results.csv',
                                  index_label='Epoch')
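Note: display_transform, used when collecting the validation triplets in the last two examples, is not defined here. In the SRGAN reference code this pattern comes from, it is a fixed resize/center-crop pipeline; a sketch under that assumption (the 400-pixel canvas size is not confirmed by this source):

from torchvision.transforms import CenterCrop, Compose, Resize, ToPILImage, ToTensor

def display_transform():
    # Bring every image (HR restore, HR, SR) to the same 400x400 canvas
    # so the triplets can be tiled side by side with make_grid.
    return Compose([
        ToPILImage(),
        Resize(400),
        CenterCrop(400),
        ToTensor()
    ])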