# Example #1
# 0
    def build_model(self):
        """Build the SRGAN networks, losses, optimizers and LR schedulers.

        Creates a ``Generator`` and a ``Discriminator`` (``num_channel=1``,
        ``base_filter=64``, both moved to ``self.device``), a pretrained VGG16
        ``feature_extractor``, an MSE criterion for the generator and a BCE
        criterion for the discriminator, Adam for G, Nesterov SGD at
        ``self.lr / 100`` for D, and one ``MultiStepLR`` per optimizer that
        multiplies the LR by 0.5 at epochs 50, 75 and 100.

        Attributes set: ``netG``, ``netD``, ``feature_extractor``,
        ``criterionG``, ``criterionD``, ``optimizerG``, ``optimizerD``,
        ``schedulerG``, ``schedulerD`` and ``scheduler``.
        """
        # Seed before the networks are built: ``weight_init(mean, std)``
        # presumably samples random weights, so seeding afterwards (as this
        # method used to) left the initial weights unreproducible.
        torch.manual_seed(self.seed)
        if self.GPU_IN_USE:
            torch.cuda.manual_seed(self.seed)

        self.netG = Generator(n_residual_blocks=self.num_residuals,
                              upsample_factor=self.upscale_factor,
                              base_filter=64,
                              num_channel=1).to(self.device)
        self.netD = Discriminator(base_filter=64,
                                  num_channel=1).to(self.device)
        self.feature_extractor = vgg16(pretrained=True)
        self.netG.weight_init(mean=0.0, std=0.2)
        self.netD.weight_init(mean=0.0, std=0.2)
        self.criterionG = nn.MSELoss()
        self.criterionD = nn.BCELoss()

        if self.GPU_IN_USE:
            self.feature_extractor.cuda()
            cudnn.benchmark = True
            self.criterionG.cuda()
            self.criterionD.cuda()

        self.optimizerG = optim.Adam(self.netG.parameters(),
                                     lr=self.lr,
                                     betas=(0.9, 0.999))
        self.optimizerD = optim.SGD(self.netD.parameters(),
                                    lr=self.lr / 100,
                                    momentum=0.9,
                                    nesterov=True)

        # One LR schedule per optimizer. Both used to be assigned to
        # ``self.scheduler``, so the generator's schedule was silently lost.
        self.schedulerG = optim.lr_scheduler.MultiStepLR(
            self.optimizerG, milestones=[50, 75, 100], gamma=0.5)  # lr decay
        self.schedulerD = optim.lr_scheduler.MultiStepLR(
            self.optimizerD, milestones=[50, 75, 100], gamma=0.5)  # lr decay
        # Backward compatibility: ``self.scheduler`` still refers to the
        # discriminator's schedule, exactly as before. The epoch loop must
        # also call ``self.schedulerG.step()`` for the generator's LR to decay.
        self.scheduler = self.schedulerD
# Example #2
# 0
HEIGHT = opt.height
batch_size = opt.batch_size
dataset_size = opt.dataset_size
lr = opt.lr

# Data loaders come from the HR-only dataset helper (validation_split=0.2).
# An older, commented-out VOC2012 folder-based loader setup used to sit here.
train_loader, val_loader = Dataset_OnlyHR.get_data_loaders(
    batch_size,
    dataset_size=dataset_size,
    validation_split=0.2,
)
num_train_batches, num_val_batches = len(train_loader), len(val_loader)

# Generator: FRVSR (WIDTH is expected to be defined earlier in the script).
netG = FRVSR(batch_size, lr_width=WIDTH, lr_height=HEIGHT)
print('# generator parameters:', sum(p.numel() for p in netG.parameters()))

# Discriminator.
netD = Discriminator()
print('# discriminator parameters:', sum(p.numel() for p in netD.parameters()))

generator_criterion = GeneratorLoss()

if torch.cuda.is_available():
    netG.cuda()
    netD.cuda()
    generator_criterion.cuda()

# Both networks use Adam with the same learning rate.
optimizerG = optim.Adam(netG.parameters(), lr=lr)
optimizerD = optim.Adam(netD.parameters(), lr=lr)

# Metric history: one independent list per tracked quantity.
results = {
    metric: []
    for metric in ('d_loss', 'g_loss', 'd_score', 'g_score', 'psnr', 'ssim')
}

if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
# Example #3
# 0
class SRGANTrainer(object):
    """Pretrain, adversarially train and evaluate an SRGAN super-resolver.

    Per-epoch checkpoints are written under ``models/SRGAN/<upscale_factor>``.
    After training, the generator of the epoch with the highest test PSNR is
    copied to ``best_model.pth``, and the per-epoch PSNRs go to
    ``metrics.txt`` in the same directory.
    """

    def __init__(self, config, training_loader, testing_loader):
        """Store hyper-parameters from ``config`` and the data loaders.

        ``config`` must provide ``lr``, ``nEpochs``, ``seed``,
        ``upscale_factor`` and ``load``. Networks, losses and optimizers are
        created later by ``build_model``.
        """
        super(SRGANTrainer, self).__init__()
        self.GPU_IN_USE = torch.cuda.is_available()
        self.device = torch.device('cuda' if self.GPU_IN_USE else 'cpu')
        self.netG = None
        self.netD = None
        self.lr = config.lr
        self.nEpochs = config.nEpochs
        self.epoch_pretrain = 10
        self.criterionG = None
        self.criterionD = None
        self.optimizerG = None
        self.optimizerD = None
        self.feature_extractor = None
        self.scheduler = None
        self.schedulerG = None
        self.schedulerD = None
        self.seed = config.seed
        self.upscale_factor = config.upscale_factor
        self.num_residuals = 16
        self.training_loader = training_loader
        self.testing_loader = testing_loader

        self.load = config.load  # read from config; not used in this class
        self.model_path = 'models/SRGAN/' + str(self.upscale_factor)

    def build_model(self):
        """Create networks, criteria, optimizers and one LR schedule per optimizer."""
        # Seed before the networks are built so that ``weight_init``
        # (presumably random) is reproducible; seeding afterwards was too late.
        torch.manual_seed(self.seed)
        if self.GPU_IN_USE:
            torch.cuda.manual_seed(self.seed)

        self.netG = Generator(n_residual_blocks=self.num_residuals,
                              upsample_factor=self.upscale_factor,
                              base_filter=64,
                              num_channel=1).to(self.device)
        self.netD = Discriminator(base_filter=64, num_channel=1).to(self.device)
        self.feature_extractor = vgg16(pretrained=True)
        self.netG.weight_init(mean=0.0, std=0.2)
        self.netD.weight_init(mean=0.0, std=0.2)
        self.criterionG = nn.MSELoss()
        self.criterionD = nn.BCELoss()

        if self.GPU_IN_USE:
            self.feature_extractor.cuda()
            cudnn.benchmark = True
            self.criterionG.cuda()
            self.criterionD.cuda()

        self.optimizerG = optim.Adam(self.netG.parameters(), lr=self.lr, betas=(0.9, 0.999))
        self.optimizerD = optim.SGD(self.netD.parameters(), lr=self.lr / 100, momentum=0.9, nesterov=True)
        # Keep both schedules: previously the generator's was overwritten by
        # the discriminator's, so G's learning rate never decayed.
        self.schedulerG = optim.lr_scheduler.MultiStepLR(self.optimizerG, milestones=[50, 75, 100], gamma=0.5)  # lr decay
        self.schedulerD = optim.lr_scheduler.MultiStepLR(self.optimizerD, milestones=[50, 75, 100], gamma=0.5)  # lr decay
        # Backward compatibility: ``scheduler`` still refers to D's schedule.
        self.scheduler = self.schedulerD

    @staticmethod
    def to_data(x):
        """Return ``x.data``, moving ``x`` to CPU first when CUDA is available."""
        if torch.cuda.is_available():
            x = x.cpu()
        return x.data

    def save(self, epoch):
        """Pickle the whole generator and discriminator for ``epoch``.

        Writes ``g_model_<epoch>.pth`` and ``d_model_<epoch>.pth`` into
        ``self.model_path``. The directory is created if missing, because
        ``torch.save`` does not create parent directories.
        """
        import os

        os.makedirs(self.model_path, exist_ok=True)
        g_model_out_path = self.model_path + "/g_model_{}.pth".format(epoch)
        d_model_out_path = self.model_path + "/d_model_{}.pth".format(epoch)
        torch.save(self.netG, g_model_out_path)
        torch.save(self.netD, d_model_out_path)
        print("Checkpoint saved to {}".format(g_model_out_path))
        print("Checkpoint saved to {}".format(d_model_out_path))

    def pretrain(self):
        """Run one epoch of generator-only training with the MSE criterion."""
        self.netG.train()
        for data, target in self.training_loader:
            data, target = data.to(self.device), target.to(self.device)
            self.netG.zero_grad()
            loss = self.criterionG(self.netG(data), target)
            loss.backward()
            self.optimizerG.step()

    def train(self):
        """Run one adversarial epoch and print running and average losses."""
        self.netG.train()
        self.netD.train()
        g_train_loss = 0
        d_train_loss = 0
        for batch_num, (data, target) in enumerate(self.training_loader):
            # BCE targets shaped (batch, channels); presumably this matches
            # netD's output shape, TODO confirm.
            real_label = torch.ones(data.size(0), data.size(1)).to(self.device)
            fake_label = torch.zeros(data.size(0), data.size(1)).to(self.device)
            data, target = data.to(self.device), target.to(self.device)

            # Train Discriminator
            self.optimizerD.zero_grad()
            d_real = self.netD(target)
            d_real_loss = self.criterionD(d_real, real_label)

            # Detach the fake batch: the D step must not backpropagate into
            # netG. Those gradients were cleared by the G step's zero_grad
            # anyway, so this only saves work.
            d_fake = self.netD(self.netG(data).detach())
            d_fake_loss = self.criterionD(d_fake, fake_label)
            d_total = d_real_loss + d_fake_loss
            d_train_loss += d_total.item()
            d_total.backward()
            self.optimizerD.step()

            # Train generator: pixel MSE plus a small adversarial term.
            self.optimizerG.zero_grad()
            g_real = self.netG(data)
            g_fake = self.netD(g_real)
            gan_loss = self.criterionD(g_fake, real_label)
            mse_loss = self.criterionG(g_real, target)

            g_total = mse_loss + 1e-3 * gan_loss
            g_train_loss += g_total.item()
            g_total.backward()
            self.optimizerG.step()

            progress_bar(batch_num, len(self.training_loader),
                         'G_Loss: %.4f | D_Loss: %.4f' % (g_train_loss / (batch_num + 1), d_train_loss / (batch_num + 1)))

        print("    Average G_Loss: {:.4f}".format(g_train_loss / len(self.training_loader)))

    def test(self):
        """Return the mean per-batch PSNR (dB) of netG on ``testing_loader``.

        PSNR is computed as ``10 * log10(1 / mse)``, which assumes pixel
        values in [0, 1].
        """
        self.netG.eval()
        avg_psnr = 0

        with torch.no_grad():
            for batch_num, (data, target) in enumerate(self.testing_loader):
                data, target = data.to(self.device), target.to(self.device)
                prediction = self.netG(data)
                mse = self.criterionG(prediction, target)
                psnr = 10 * log10(1 / mse.item())
                avg_psnr += psnr
                progress_bar(batch_num, len(self.testing_loader), 'PSNR: %.4f' % (avg_psnr / (batch_num + 1)))

        print("    Average PSNR: {:.4f} dB".format(avg_psnr / len(self.testing_loader)))
        return avg_psnr / len(self.testing_loader)

    def run(self):
        """Pretrain, then train/test/checkpoint every epoch and keep the best G."""
        self.build_model()
        all_epoch_psnrs = []
        for epoch in range(1, self.epoch_pretrain + 1):
            self.pretrain()
            print("{}/{} pretrained".format(epoch, self.epoch_pretrain))

        for epoch in range(1, self.nEpochs + 1):
            print("\n===> Epoch {} starts:".format(epoch))
            self.train()
            epoch_psnr = self.test()
            all_epoch_psnrs.append(epoch_psnr)
            # Decay both learning rates; previously only D's schedule existed.
            self.schedulerG.step()
            self.schedulerD.step()
            # ``save`` is this class's checkpoint method (``save_model`` did not exist).
            self.save(epoch)

        if not all_epoch_psnrs:
            # nEpochs < 1: no checkpoints exist, so there is nothing to rank.
            return

        best_epoch = argmax(all_epoch_psnrs) + 1
        best_psnr = all_epoch_psnrs[best_epoch - 1]
        print("Best epoch: model_{} with PSNR {}".format(best_epoch, best_psnr))
        # ``save`` writes ``g_model_<epoch>.pth``; ``model_<epoch>.pth`` never exists.
        copyfile(self.model_path + "/g_model_{}.pth".format(best_epoch),
                 self.model_path + "/best_model.pth")

        with open(self.model_path + '/metrics.txt', 'w+') as metricsfile:
            print("Saving metrics")
            for i, psnr in enumerate(all_epoch_psnrs):
                metricsfile.write("{},{}\n".format(i + 1, psnr))
            metricsfile.write("Best epoch: model_{} with PSNR {}\n".format(best_epoch, best_psnr))
# Example #4
# 0
# Data: train/validation loaders (validation_split=0.1).
trainLoader, valLoader = DatasetLoader.get_data_loaders(
    batchSize,
    dataset_size=dataset_size,
    validation_split=0.1,
)
numTrainBatches, numValBatches = len(trainLoader), len(valLoader)

# Logging is configured from the parsed ``args.debug`` flag.
logger.initLogger(args.debug)

# Generator: FRVSR.
netG = FRVSR(batchSize, lr_width=WIDTH, lr_height=HEIGHT)
print('# of Generator parameters:',
      sum(p.numel() for p in netG.parameters()))

# Discriminator borrowed from SRGAN.
netD = Discriminator()
print('# of Discriminator parameters:',
      sum(p.numel() for p in netD.parameters()))

generatorCriterion = GeneratorLoss()

if torch.cuda.is_available():

    def printCUDAStats():
        """Log the CUDA device count, the active device index and its name."""
        logger.info("# of CUDA devices detected: %s",
                    torch.cuda.device_count())
        current = torch.cuda.current_device()
        logger.info("Using CUDA device #: %s", current)
        logger.info("CUDA device name: %s", torch.cuda.get_device_name(current))

    printCUDAStats()


def main():
    """Parse command-line arguments and run the GAN training loop.

    Builds the training ``DataLoader``, an RBPN generator (optionally wrapped
    in ``DataParallel``) and a ``Discriminator``. The generator loss is
    ``nn.L1Loss`` unless ``args.APITLoss`` is set, in which case it is
    ``GeneratorLoss``. Everything is moved to ``cuda:0`` when ``args.gpu_mode``
    is set and CUDA is available. Pretrained generator weights are optionally
    loaded, the model is trained for epochs ``start_epoch..nEpochs``, and a
    snapshot is saved whenever ``(epoch + 1)`` is a multiple of
    ``args.snapshots``.
    """
    args = parser.parse_args()

    # Initialize Logger
    logger.initLogger(args.debug)

    # Load dataset
    logger.info('==> Loading datasets')
    train_set = get_training_set(args.data_dir, args.nFrames,
                                 args.upscale_factor, args.data_augmentation,
                                 args.file_list, args.other_dataset,
                                 args.patch_size, args.future_frame)
    training_data_loader = DataLoader(dataset=train_set,
                                      num_workers=args.threads,
                                      batch_size=args.batchSize,
                                      shuffle=True)

    # Use generator as RBPN
    netG = RBPN(num_channels=3,
                base_filter=256,
                feat=64,
                num_stages=3,
                n_resblock=5,
                nFrames=args.nFrames,
                scale_factor=args.upscale_factor)
    logger.info('# of Generator parameters: %s',
                sum(param.numel() for param in netG.parameters()))

    # Optionally replicate the generator over GPUs 0..args.gpus-1.
    if args.useDataParallel:
        gpus_list = range(args.gpus)
        netG = torch.nn.DataParallel(netG, device_ids=gpus_list)

    # Use discriminator from SRGAN
    netD = Discriminator()
    logger.info('# of Discriminator parameters: %s',
                sum(param.numel() for param in netD.parameters()))

    # Generator loss
    generatorCriterion = nn.L1Loss() if not args.APITLoss else GeneratorLoss()

    # Specify device
    use_cuda = torch.cuda.is_available() and args.gpu_mode
    device = torch.device("cuda:0" if use_cuda else "cpu")

    if use_cuda:
        utils.printCUDAStats()

        # ``.to(device)`` alone places everything on cuda:0; the extra
        # ``.cuda()`` calls that preceded it were redundant.
        netG.to(device)
        netD.to(device)
        generatorCriterion.to(device)

    # Use Adam optimizer
    optimizerG = optim.Adam(netG.parameters(),
                            lr=args.lr,
                            betas=(0.9, 0.999),
                            eps=1e-8)
    optimizerD = optim.Adam(netD.parameters(),
                            lr=args.lr,
                            betas=(0.9, 0.999),
                            eps=1e-8)

    if args.APITLoss:
        logger.info(
            "Generator Loss: Adversarial Loss + Perception Loss + Image Loss + TV Loss"
        )
    else:
        logger.info("Generator Loss: L1 Loss")

    # print iSeeBetter architecture
    utils.printNetworkArch(netG, netD)

    if args.pretrained:
        # Join as two path components. Concatenating the strings inside
        # os.path.join produced a wrong path whenever save_folder lacked a
        # trailing separator.
        modelPath = os.path.join(args.save_folder, args.pretrained_sr)
        utils.loadPreTrainedModel(gpuMode=args.gpu_mode,
                                  model=netG,
                                  modelPath=modelPath)

    for epoch in range(args.start_epoch, args.nEpochs + 1):
        runningResults = trainModel(epoch, training_data_loader, netG, netD,
                                    optimizerD, optimizerG, generatorCriterion,
                                    device, args)

        if (epoch + 1) % (args.snapshots) == 0:
            saveModelParams(epoch, runningResults, netG, netD)


class SRGANTrainer(Trainer):
    """SRGAN trainer built on the project's ``Trainer`` base class.

    Pretrains the generator with a pixel-wise MSE loss, then alternates
    discriminator (BCE) and generator (MSE + 1e-3 * adversarial) updates, and
    reports PSNR on ``testing_loader`` after each epoch. The generator's
    optimizer is stored as ``self.optimizer``; ``self.optimizerG`` is not
    assigned by this class.
    """

    def __init__(self, config, training_loader, testing_loader):
        """Store ``config`` hyper-parameters and the data loaders.

        ``config`` must provide ``lr``, ``nEpochs``, ``seed`` and
        ``upscale_factor``. Models are created later by ``build_model``.
        """
        super(SRGANTrainer, self).__init__()
        self.config = config
        self.GPU_IN_USE = torch.cuda.is_available()
        self.device = torch.device('cuda' if self.GPU_IN_USE else 'cpu')
        self.netG = None
        self.netD = None
        self.lr = config.lr
        self.nEpochs = config.nEpochs
        self.epoch_pretrain = 10
        self.criterionG = None
        self.criterionD = None
        self.optimizerG = None
        self.optimizerD = None
        self.feature_extractor = None
        self.scheduler = None
        self.schedulerG = None
        self.schedulerD = None
        self.seed = config.seed
        self.upscale_factor = config.upscale_factor
        self.num_residuals = 16
        self.training_loader = training_loader
        self.testing_loader = testing_loader

    def build_model(self):
        """Create networks, criteria, optimizers and one LR schedule per optimizer."""
        # Seed before the networks are built so that ``weight_init``
        # (presumably random) is reproducible; seeding afterwards was too late.
        torch.manual_seed(self.seed)
        if self.GPU_IN_USE:
            torch.cuda.manual_seed(self.seed)

        self.netG = Generator(n_residual_blocks=self.num_residuals,
                              upsample_factor=self.upscale_factor,
                              base_filter=64,
                              num_channel=1).to(self.device)
        self.netD = Discriminator(base_filter=64,
                                  num_channel=1).to(self.device)
        self.feature_extractor = vgg16(pretrained=True)
        self.netG.weight_init(mean=0.0, std=0.2)
        self.netD.weight_init(mean=0.0, std=0.2)
        self.criterionG = nn.MSELoss()
        self.criterionD = nn.BCELoss()

        if self.GPU_IN_USE:
            self.feature_extractor.cuda()
            cudnn.benchmark = True
            self.criterionG.cuda()
            self.criterionD.cuda()

        self.optimizer = optim.Adam(self.netG.parameters(),
                                    lr=self.lr,
                                    betas=(0.9, 0.999))
        self.optimizerD = optim.SGD(self.netD.parameters(),
                                    lr=self.lr / 100,
                                    momentum=0.9,
                                    nesterov=True)
        # Keep both schedules: previously the generator's was overwritten by
        # the discriminator's, so G's learning rate never decayed.
        self.schedulerG = optim.lr_scheduler.MultiStepLR(
            self.optimizer, milestones=[50, 75, 100], gamma=0.5)  # lr decay
        self.schedulerD = optim.lr_scheduler.MultiStepLR(
            self.optimizerD, milestones=[50, 75, 100], gamma=0.5)  # lr decay
        # Backward compatibility: ``scheduler`` still refers to D's schedule.
        self.scheduler = self.schedulerD

        # NOTE(review): ``set_optimizer`` is defined on ``Trainer`` (not
        # visible here). If it rebinds ``self.optimizer``/``self.optimizerD``,
        # the schedulers above still point at the old optimizers; confirm.
        self.set_optimizer(_type='gan')

    @staticmethod
    def to_data(x):
        """Return ``x.data``, moving ``x`` to CPU first when CUDA is available."""
        if torch.cuda.is_available():
            x = x.cpu()
        return x.data

    def save(self):
        """Pickle the whole generator and discriminator into the working directory."""
        g_model_out_path = "SRGAN_Generator_model_path.pth"
        d_model_out_path = "SRGAN_Discriminator_model_path.pth"
        torch.save(self.netG, g_model_out_path)
        torch.save(self.netD, d_model_out_path)
        print("Checkpoint saved to {}".format(g_model_out_path))
        print("Checkpoint saved to {}".format(d_model_out_path))

    def pretrain(self):
        """Run one epoch of generator-only training with the MSE criterion."""
        self.netG.train()
        for data, target in self.training_loader:
            data, target = data.to(self.device), target.to(self.device)
            self.netG.zero_grad()
            loss = self.criterionG(self.netG(data), target)
            loss.backward()
            self.optimizer.step()

    def train(self):
        """Run one adversarial epoch.

        Returns:
            ``[avg_loss, total_time]``: the mean generator loss over the
            epoch and the value returned by the last ``progress_bar`` call.
        """
        self.netG.train()
        self.netD.train()
        g_train_loss = 0
        d_train_loss = 0
        for batch_num, (data, target) in enumerate(self.training_loader):
            # BCE targets shaped (batch, channels); presumably this matches
            # netD's output shape, TODO confirm.
            real_label = torch.ones(data.size(0), data.size(1)).to(self.device)
            fake_label = torch.zeros(data.size(0),
                                     data.size(1)).to(self.device)
            data, target = data.to(self.device), target.to(self.device)

            # Train Discriminator
            self.optimizerD.zero_grad()
            d_real = self.netD(target)
            d_real_loss = self.criterionD(d_real, real_label)

            # Detach the fake batch: the D step must not backpropagate into
            # netG. Those gradients were cleared by the G step's zero_grad
            # anyway, so this only saves work.
            d_fake = self.netD(self.netG(data).detach())
            d_fake_loss = self.criterionD(d_fake, fake_label)
            d_total = d_real_loss + d_fake_loss
            d_train_loss += d_total.item()
            d_total.backward()
            self.optimizerD.step()

            # Train generator: pixel MSE plus a small adversarial term.
            self.optimizer.zero_grad()
            g_real = self.netG(data)
            g_fake = self.netD(g_real)
            gan_loss = self.criterionD(g_fake, real_label)
            mse_loss = self.criterionG(g_real, target)

            g_total = mse_loss + 1e-3 * gan_loss
            g_train_loss += g_total.item()
            g_total.backward()
            self.optimizer.step()

            total_time = progress_bar(
                batch_num, len(self.training_loader),
                'G_Loss: %.4f | D_Loss: %.4f' %
                (g_train_loss / (batch_num + 1), d_train_loss /
                 (batch_num + 1)))

        avg_loss = g_train_loss / len(self.training_loader)
        return [avg_loss, total_time]

    def test(self):
        """Evaluate netG on ``testing_loader``.

        Returns:
            ``[avg_psnr, total_time]``: the mean per-batch PSNR in dB
            (``10 * log10(1 / mse)``, which assumes pixels in [0, 1]) and the
            value returned by the last ``progress_bar`` call.
        """
        self.netG.eval()
        avg_psnr = 0

        with torch.no_grad():
            for batch_num, (data, target) in enumerate(self.testing_loader):
                data, target = data.to(self.device), target.to(self.device)
                prediction = self.netG(data)
                mse = self.criterionG(prediction, target)
                psnr = 10 * log10(1 / mse.item())
                avg_psnr += psnr
                total_time = progress_bar(
                    batch_num, len(self.testing_loader),
                    'PSNR: %.4f' % (avg_psnr / (batch_num + 1)))

        avg_psnr = avg_psnr / len(self.testing_loader)
        return [avg_psnr, total_time]

    def run(self):
        """Build, pretrain, then train/test every epoch and save at the end."""
        self.build_model()
        for epoch in range(1, self.epoch_pretrain + 1):
            self.pretrain()
            print("{}/{} pretrained".format(epoch, self.epoch_pretrain))

        for epoch in range(1, self.nEpochs + 1):
            print("\n===> Epoch {} starts:".format(epoch))
            avg_loss = self.train()
            avg_psnr = self.test()
            # One argument-free step per epoch reaches the same milestones as
            # the deprecated ``step(epoch)``, and now decays both optimizers.
            self.schedulerG.step()
            self.schedulerD.step()
            if epoch == self.nEpochs:
                self.save()
# Example #7
# 0
# Build the train/validation loaders (validation_split=0.1).
trainLoader, valLoader = DatasetLoader.get_data_loaders(
    batchSize,
    dataset_size=dataset_size,
    validation_split=0.1,
)
numTrainBatches = len(trainLoader)
numValBatches = len(valLoader)

# Configure logging from the parsed ``args.debug`` flag.
logger.initLogger(args.debug)

# FRVSR acts as the generator.
netG = FRVSR(batchSize, lr_width=WIDTH, lr_height=HEIGHT)
print('# of Generator parameters:',
      sum(weight.numel() for weight in netG.parameters()))

# The discriminator is taken from SRGAN.
netD = Discriminator()
print('# of Discriminator parameters:',
      sum(weight.numel() for weight in netD.parameters()))

generatorCriterion = GeneratorLoss()

if torch.cuda.is_available():

    def printCUDAStats():
        """Log how many CUDA devices exist plus the active device's index/name."""
        logger.info("# of CUDA devices detected: %s",
                    torch.cuda.device_count())
        active_index = torch.cuda.current_device()
        logger.info("Using CUDA device #: %s", active_index)
        logger.info("CUDA device name: %s",
                    torch.cuda.get_device_name(active_index))

    printCUDAStats()
# Example #8
# 0
class SRGANTrainer(object):
    """SRGAN trainer (variant written against the older ``Variable`` API).

    Pretrains the generator with MSE, then alternates discriminator (BCE) and
    generator (MSE + 1e-3 * adversarial) updates. ``validate`` drives the
    whole run and saves both networks after the final epoch.
    """

    def __init__(self, config, training_loader, testing_loader):
        """Store ``config`` hyper-parameters (``lr``, ``nEpochs``, ``seed``,
        ``upscale_factor``) and the data loaders; models come from ``build_model``.
        """
        self.netG = None
        self.netD = None
        self.lr = config.lr
        self.nEpochs = config.nEpochs
        self.epoch_pretrain = 10
        self.criterionG = None
        self.criterionD = None
        self.optimizerG = None
        self.optimizerD = None
        self.feature_extractor = None
        self.scheduler = None
        self.schedulerG = None
        self.schedulerD = None
        self.GPU_IN_USE = torch.cuda.is_available()
        self.seed = config.seed
        self.upscale_factor = config.upscale_factor
        self.num_residuals = 16
        self.training_loader = training_loader
        self.testing_loader = testing_loader

    def build_model(self):
        """Create networks, criteria, optimizers and one LR schedule per optimizer."""
        # Seed before the networks are built so that ``weight_init``
        # (presumably random) is reproducible; seeding afterwards was too late.
        torch.manual_seed(self.seed)
        if self.GPU_IN_USE:
            torch.cuda.manual_seed(self.seed)

        self.netG = Generator(n_residual_blocks=self.num_residuals,
                              upsample_factor=self.upscale_factor,
                              base_filter=64,
                              num_channel=1)
        self.netD = Discriminator(base_filter=64, num_channel=1)
        self.feature_extractor = vgg16(pretrained=True)
        self.netG.weight_init(mean=0.0, std=0.2)
        self.netD.weight_init(mean=0.0, std=0.2)
        self.criterionG = nn.MSELoss()
        self.criterionD = nn.BCELoss()

        if self.GPU_IN_USE:
            self.netG.cuda()
            self.netD.cuda()
            self.feature_extractor.cuda()
            cudnn.benchmark = True
            self.criterionG.cuda()
            self.criterionD.cuda()

        self.optimizerG = optim.Adam(self.netG.parameters(),
                                     lr=self.lr,
                                     betas=(0.9, 0.999))
        self.optimizerD = optim.SGD(self.netD.parameters(),
                                    lr=self.lr / 100,
                                    momentum=0.9,
                                    nesterov=True)
        # Keep both schedules: previously the generator's was overwritten by
        # the discriminator's, so G's learning rate never decayed.
        self.schedulerG = optim.lr_scheduler.MultiStepLR(
            self.optimizerG, milestones=[50, 75, 100], gamma=0.5)  # lr decay
        self.schedulerD = optim.lr_scheduler.MultiStepLR(
            self.optimizerD, milestones=[50, 75, 100], gamma=0.5)  # lr decay
        # Backward compatibility: ``scheduler`` still refers to D's schedule.
        self.scheduler = self.schedulerD

    @staticmethod
    def to_variable(x):
        """Convert tensor to variable (moved to GPU when CUDA is available)."""
        if torch.cuda.is_available():
            x = x.cuda()
        return Variable(x)

    @staticmethod
    def to_data(x):
        """Convert variable to tensor (moved to CPU when CUDA is available)."""
        if torch.cuda.is_available():
            x = x.cpu()
        return x.data

    def save(self):
        """Pickle the whole generator and discriminator into the working directory."""
        g_model_out_path = "SRGAN_Generator_model_path.pth"
        d_model_out_path = "SRGAN_Discriminator_model_path.pth"
        torch.save(self.netG, g_model_out_path)
        torch.save(self.netD, d_model_out_path)
        print("Checkpoint saved to {}".format(g_model_out_path))
        print("Checkpoint saved to {}".format(d_model_out_path))

    def pretrain(self):
        """Run one epoch of generator-only training with the MSE criterion."""
        self.netG.train()
        for data, target in self.training_loader:
            if self.GPU_IN_USE:
                # Tensors are autograd-aware since PyTorch 0.4; no Variable wrap needed.
                data, target = data.cuda(), target.cuda()
            self.netG.zero_grad()
            loss = self.criterionG(self.netG(data), target)
            loss.backward()
            self.optimizerG.step()

    def train(self):
        """Run one adversarial epoch and print running and average losses.

        Example from the original notes: data arrived as
        ``torch.cuda.FloatTensor`` in 4 batches of sizes [64, 64, 64, 8].
        """
        self.netG.train()
        self.netD.train()
        g_train_loss = 0
        d_train_loss = 0
        for batch_num, (data, target) in enumerate(self.training_loader):
            # BCE targets shaped (batch, channels); presumably this matches
            # netD's output shape, TODO confirm.
            real_label = self.to_variable(
                torch.ones(data.size(0), data.size(1)))
            fake_label = self.to_variable(
                torch.zeros(data.size(0), data.size(1)))
            if self.GPU_IN_USE:
                data, target = data.cuda(), target.cuda()

            # Train Discriminator
            self.optimizerD.zero_grad()
            d_real = self.netD(target)
            d_real_loss = self.criterionD(d_real, real_label)

            # Detach the fake batch: the D step must not backpropagate into
            # netG. Those gradients were cleared by the G step's zero_grad
            # anyway, so this only saves work.
            d_fake = self.netD(self.netG(data).detach())
            d_fake_loss = self.criterionD(d_fake, fake_label)
            d_total = d_real_loss + d_fake_loss
            # ``.data[0]`` raises IndexError on 0-dim losses in PyTorch >= 0.4.
            d_train_loss += d_total.item()
            d_total.backward()
            self.optimizerD.step()

            # Train generator: pixel MSE plus a small adversarial term.
            self.optimizerG.zero_grad()
            g_real = self.netG(data)
            g_fake = self.netD(g_real)
            gan_loss = self.criterionD(g_fake, real_label)
            mse_loss = self.criterionG(g_real, target)

            g_total = mse_loss + 1e-3 * gan_loss
            g_train_loss += g_total.item()
            g_total.backward()
            self.optimizerG.step()

            progress_bar(
                batch_num, len(self.training_loader),
                'G_Loss: %.4f | D_Loss: %.4f' %
                (g_train_loss / (batch_num + 1), d_train_loss /
                 (batch_num + 1)))

        print("    Average G_Loss: {:.4f}".format(g_train_loss /
                                                  len(self.training_loader)))

    def test(self):
        """Print the mean per-batch PSNR (dB) of netG on ``testing_loader``.

        PSNR is ``10 * log10(1 / mse)``, which assumes pixels in [0, 1].
        Example from the original notes: 10 batches of 10 samples each.
        """
        self.netG.eval()
        avg_psnr = 0
        # Inference only: avoid building autograd graphs for every batch.
        with torch.no_grad():
            for batch_num, (data, target) in enumerate(self.testing_loader):
                if self.GPU_IN_USE:
                    data, target = data.cuda(), target.cuda()

                prediction = self.netG(data)
                mse = self.criterionG(prediction, target)
                psnr = 10 * log10(1 / mse.item())
                avg_psnr += psnr
                progress_bar(batch_num, len(self.testing_loader),
                             'PSNR: %.4f' % (avg_psnr / (batch_num + 1)))

        print("    Average PSNR: {:.4f} dB".format(avg_psnr /
                                                   len(self.testing_loader)))

    def validate(self):
        """Build, pretrain, then train/test every epoch and save at the end."""
        self.build_model()
        for epoch in range(1, self.epoch_pretrain + 1):
            self.pretrain()
            print("{}/{} pretrained".format(epoch, self.epoch_pretrain))

        for epoch in range(1, self.nEpochs + 1):
            print("\n===> Epoch {} starts:".format(epoch))
            self.train()
            self.test()
            # One argument-free step per epoch reaches the same milestones as
            # the deprecated ``step(epoch)``, and now decays both optimizers.
            self.schedulerG.step()
            self.schedulerD.step()
            if epoch == self.nEpochs:
                self.save()