Example No. 1
def train(net, trainloader, device, optimizer, criterion, memory_back=False):
    """One epoch training of a network.
    
    :param net: The given network.
    :param trainloader: PyTorch DataLoader (train set).
    :param device: Either 'cuda' or 'cpu'.
    :param optimizer: The optimizer to use.
    :param criterion: The loss function.
    :param memory_back: If True, for each batch the optimizer's memory (if it has one) is
                        temporarily added back to the net's parameters before the metrics
                        are computed. The net's final parameters are left unchanged.
    
    :return:    (train_loss, train_acc, train_loss_with_memory_back, train_acc_with_memory_back, 
                L1/L2 norm ratio of the gradients, L1/L2 norm ratio of g)
    """
    net.train()
    train_loss = 0
    mback_train_loss = 0
    correct = 0
    mback_correct = 0
    total = 0
    norm_ratio_val = 0
    corrected_norm_ratio_val = 0

    # -------------------------------------
    # Loop for training
    # -------------------------------------
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)

        if memory_back:
            with TemporarilyAddMemory(optimizer):  # context manager from convex_opt
                outputs = net(inputs)
                loss = criterion(outputs, targets)
                mback_train_loss += loss.item()
                _, predicted = outputs.max(1)
                mback_correct += predicted.eq(targets).sum().item()

        # train using optimizer
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        norm_ratio_val += optimizer.gradient_norms_ratio()
        corrected_norm_ratio_val += optimizer.corrected_gradient_norms_ratio()

        train_loss += loss.item()  # loss value
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()  # acc value

        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                     % (train_loss / (batch_idx + 1), 100. * correct / total, correct, total))

    loss = train_loss/(batch_idx + 1)
    acc = 100. * correct/total

    return loss, acc, mback_train_loss/(batch_idx + 1), 100. * mback_correct/total, norm_ratio_val/(batch_idx + 1), \
           corrected_norm_ratio_val/(batch_idx + 1)
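The TemporarilyAddMemory context manager used above comes from the repo's convex_opt module and is not shown here. A minimal sketch of the contract the loop relies on, assuming the optimizer keeps a per-parameter 'memory' buffer in its state dict (that key is an assumption for illustration):

import torch

class TemporarilyAddMemory:
    """Shift parameters by the optimizer's memory on enter, undo on exit."""

    def __init__(self, optimizer):
        self.optimizer = optimizer

    def _shift(self, sign):
        for group in self.optimizer.param_groups:
            for p in group['params']:
                memory = self.optimizer.state[p].get('memory')  # assumed state key
                if memory is not None:
                    p.data.add_(memory, alpha=sign)

    def __enter__(self):
        self._shift(1.0)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self._shift(-1.0)  # leaves the net's parameters unchanged, as documented
        return False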
Example No. 2
def get_inception_activations(imgs_path, batch_size):
    print("Calculating Activation")
    num_imgs = len(imgs_path)
    inception_network = PartialInceptionNetwork()
    inception_network = to_cuda(inception_network)
    inception_network.eval()
    n_batches = int(np.ceil(num_imgs / batch_size))
    inception_activations = np.zeros((num_imgs, 2048), dtype=np.float32)
    index = 0
    while len(imgs_path) != 0:
        imgs_path, new_batch_size, tensors = preprocess_imgs(
            imgs_path, batch_size)
        with torch.no_grad():  # inference only; no autograd graph needed
            activations = inception_network(tensors)
        activations = activations.cpu().numpy()
        assert activations.shape == (
            tensors.shape[0],
            2048), "Expected output shape to be {}, but was {}".format(
                (tensors.shape[0], 2048), activations.shape)
        start_index = index * batch_size
        end_index = start_index + new_batch_size
        index += 1
        inception_activations[start_index:end_index, :] = activations
        current = num_imgs - len(imgs_path)
        progress_bar(current, num_imgs, barLength=20)
    print("--Done--")

    return inception_activations
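These 2048-dimensional pool features are typically consumed downstream to compute the Frechet Inception Distance. A minimal sketch of that step, assuming scipy is available (this helper is not part of the example above):

import numpy as np
from scipy import linalg

def fid_from_activations(act_real, act_fake, eps=1e-6):
    """Frechet distance between Gaussians fitted to two activation sets."""
    mu1, sigma1 = act_real.mean(axis=0), np.cov(act_real, rowvar=False)
    mu2, sigma2 = act_fake.mean(axis=0), np.cov(act_fake, rowvar=False)
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        offset = np.eye(sigma1.shape[0]) * eps  # regularize if sqrtm failed
        covmean, _ = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset), disp=False)
    if np.iscomplexobj(covmean):
        covmean = covmean.real  # tiny imaginary parts are numerical noise
    return diff.dot(diff) + np.trace(sigma1 + sigma2 - 2 * covmean)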
Example No. 3
def generate_img(model, config, saved_path, show, generation_size, alpha=0.4, batch_size=32):
    if not os.path.exists(saved_path):
        os.makedirs(saved_path)
    cnt = 0
    n_batches = math.ceil(generation_size / batch_size)
    for i in range(n_batches):
        if i == n_batches - 1:
            # the last batch may be smaller than batch_size
            batch_size = generation_size - i * batch_size
        test_samples_z = torch.randn(batch_size, config['z_dim'], dtype=torch.float32).to(device)
        with torch.no_grad():
            generated_images = model.generate(test_samples_z, final_resolution_idx=model.res_idx, alpha=alpha)
            generated_images = generated_images * 0.5 + 0.5
            # generated_images = inverse_normalize(tensor=generated_images, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))

        for tensor in generated_images:
            cvimg = single_tensor_to_img(tensor, 128, 128)
            image_name = str(cnt) + ".jpg"
            cnt += 1
            file_name = os.path.join(saved_path, image_name)
            if show:
                cv2.imshow("restored_image", cvimg)
                cv2.waitKey(0)
            cvimg = cvimg * 255
            cv2.imwrite(file_name, cvimg)
        progress_bar(i, n_batches)
    cv2.destroyAllWindows()
    print("Finished")
Example No. 4
def test(epoch):
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    idx_list = []
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))

    # Save checkpoint.
    #joblib.dump(idx_list,'masked_contour_correct_idx.z')
    acc = 100.*correct/total
    if acc > best_acc:
        print('Saving..')
        state = {
            'net': net.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        if not os.path.isdir('checkpoints'):
            os.mkdir('checkpoints')
        torch.save(state, pth_path)
        best_acc = acc
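A minimal resume sketch matching the checkpoint layout saved above; `net`, `device`, and `pth_path` are whatever the surrounding script defines:

# Restore the state dict written by the test() function above.
checkpoint = torch.load(pth_path, map_location=device)
net.load_state_dict(checkpoint['net'])
best_acc = checkpoint['acc']
start_epoch = checkpoint['epoch'] + 1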
Example No. 5
def test(net, testloader, device, optimizer, criterion, memory_back=False):
    """One test evaluation of a network.
    
    :param net: The given network.
    :param testloader: PyTorch DataLoader (test set).
    :param device: Either 'cuda' or 'cpu'.
    :param optimizer: The optimizer to use.
    :param criterion: The loss function.
    :param memory_back: If True, for each batch the optimizer's memory (if it has one) is
                        temporarily added back to the net's parameters before the metrics
                        are computed. The net's final parameters are left unchanged.
    
    :return: (test_loss, test_acc, test_loss_with_memory_back, test_acc_with_memory_back)
    """
    net.eval()
    test_loss = 0
    mback_test_loss = 0
    correct = 0
    mback_correct = 0
    total = 0

    # -------------------------------------
    # Loop for testing
    # -------------------------------------
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)

            if memory_back:
                with TemporarilyAddMemory(optimizer):
                    outputs = net(inputs)
                    loss = criterion(outputs, targets)
                    mback_test_loss += loss.item()
                    _, predicted = outputs.max(1)
                    mback_correct += predicted.eq(targets).sum().item()

            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                         % (test_loss / (batch_idx + 1), 100. * correct / total, correct, total))

    loss = test_loss / (batch_idx + 1)
    acc = 100. * correct / total

    return loss, acc, mback_test_loss / (batch_idx + 1), 100. * mback_correct / total
Example No. 6
def pareto_test_main(opt, logger, model, test_loader, export_images=False):
    test_set_name = test_loader.dataset.opt['name']
    logger.info('\nTesting [{:s}]...'.format(test_set_name))

    idx = 0
    avg_psnr = 0
    avg_niqe = 0

    for batch_num, val_data in enumerate(test_loader):
        img_name = osp.splitext(osp.basename(val_data['LQ_path'][0]))[0]
        dataset_dir = osp.join(opt['path']['results_root'], test_set_name)
        util.mkdir(dataset_dir)

        model.feed_data(val_data)
        model.test()

        visuals = model.get_current_visuals()
        sr_img = util.tensor2img(visuals['SR'])  # uint8
        # ground truth image
        # gt_img = util.tensor2img(visuals['GT'])  # uint8

        # Save SR images for reference
        if export_images:
            suffix = opt['suffix']
            if suffix:
                save_img_path = osp.join(dataset_dir,
                                         img_name + suffix + '.png')
            else:
                save_img_path = osp.join(dataset_dir, img_name + '.png')
            util.save_img(sr_img, save_img_path)

        # calculate PSNR
        item_psnr = util.tensor_psnr(model.real_H, model.fake_H)
        if math.isfinite(item_psnr):
            avg_psnr += item_psnr
            idx += 1

        # calculate NIQE; note that avg_niqe is divided by the PSNR-based idx
        # counter below, which assumes both metrics are finite for the same items
        item_niqe = util.tensor_niqe(model.fake_H)
        # item_niqe = 0
        if math.isfinite(item_niqe):
            avg_niqe += item_niqe

        progress_bar(batch_num, len(test_loader), msg=None)

    avg_psnr = avg_psnr / idx
    avg_niqe = avg_niqe / idx
    logger.info("PSNR {} NIQE {}".format(avg_psnr, avg_niqe))

    return avg_psnr, avg_niqe
Example No. 7
def train(epoch):
    print('\nEpoch: %d' % epoch)
    encoder.train()
    decoder.train()
    classifier.train()
    train_loss1 = 0
    train_loss2 = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer1.zero_grad()
        optimizer2.zero_grad()
        optimizer3.zero_grad()
        feature = encoder(inputs)
        picture = decoder(feature)
        # Detach the reconstruction so the classifier runs on a fresh leaf
        # tensor whose input gradient can be captured separately.
        picture_ = picture.detach().cpu()
        input_pic = torch.Tensor(
            normalize(picture_, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))).to(device)
        input_pic.requires_grad = True
        y = classifier(input_pic)
        loss1 = criterion1(picture, inputs)
        loss2 = criterion2(y, targets)
        loss1.backward(retain_graph=True)
        if attenuation != "nc":
            loss2.backward()
            # Route the classifier's input gradient back through the decoder,
            # attenuated by the given factor.
            picture.backward(input_pic.grad / attenuation)
        optimizer1.step()
        optimizer2.step()
        optimizer3.step()

        train_loss1 += loss1.item()
        train_loss2 += loss2.item()
        _, predicted = y.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        progress_bar(
            batch_idx, len(trainloader),
            'Loss1: %.3f | Loss2: %.3f | Acc: %.3f%% (%d/%d)' %
            (train_loss1 / (batch_idx + 1), train_loss2 /
             (batch_idx + 1), 100. * correct / total, correct, total))
Example No. 8
def train(epoch):
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
            % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))
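Every example in this listing calls a progress_bar helper that ships with its own repo, and the signatures vary (a positional message string here, msg=None in Example No. 6, barLength=20 in Example No. 2). A minimal sketch of a compatible single-line console bar; the exact keyword names are assumptions for illustration:

import sys

def progress_bar(current, total, msg=None, barLength=40):
    """Render a one-line text progress bar to stdout."""
    done = int(barLength * (current + 1) / total)
    bar = '=' * done + '-' * (barLength - done)
    line = '[%s] %d/%d' % (bar, current + 1, total)
    if msg:
        line += ' | ' + msg
    sys.stdout.write('\r' + line)
    if current + 1 >= total:
        sys.stdout.write('\n')  # finish the line on the last step
    sys.stdout.flush()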
Example No. 9
predecessors_list: dict = {}

vv = set()
config.logger.info(str(config))
count_up = 0
count_down = config.depth_size
while True:
    """
        This while True loop makes sure that we get all of the transitions from the transition generator.
        When no more transitions are left, the buildin call "next" will throw a StopIteration exception.
        This is catched, and we append a few labels to make atl_generation easier. (on the last state, and on the first)
    """
    try:
        count_up += 1
        count_down -= 1
        progress_bar(count_up, config.depth_size)
        transition = next(transitions)
        result['labeling'].append(transition[0])
        result['transitions'].append(transition[1])
        result['moves'].append(transition[2])
        config.logger.info(f'Countdown: {count_down}')
        for key, value in transition[3].items():
            if predecessors_list.get(key) is None:
                predecessors_list[key] = value
            else:
                predecessors_list[key].update(value)
    except StopIteration:
        result['labeling'][-1].append(10)
        result['labeling'][0].append(11)
        break
config.logger.info('done')
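For reference, a purely hypothetical generator compatible with the loop above: each yielded item is a 4-tuple of (labeling, transition, move, predecessors-by-state mapping), with list labelings so .append works and set values so .update works:

def make_transitions():
    # Illustrative data only; the real generator is produced elsewhere.
    yield ([1], ('s0', 's1'), 'up', {'s1': {'s0'}})
    yield ([2], ('s1', 's2'), 'down', {'s2': {'s1'}})

transitions = make_transitions()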
Example No. 10
def test(epoch):
    global best_loss
    global save_df
    encoder.eval()
    decoder.eval()
    classifier.eval()
    train_loss1 = 0
    train_loss2 = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(testloader):
        # print_real_pic(inputs)
        inputs, targets = inputs.to(device), targets.to(device)
        feature = encoder(inputs)
        picture = decoder(feature)
        # print_comp_pic(picture.cpu())
        # print('Done.')
        # sleep(1000)
        picture_ = picture.detach().cpu()
        input_pic = torch.Tensor(
            normalize(picture_, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))).to(device)
        y = classifier(input_pic)
        loss1 = criterion1(picture, inputs)
        loss2 = criterion2(y, targets)

        train_loss1 += loss1.item()
        train_loss2 += loss2.item()
        _, predicted = y.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        progress_bar(
            batch_idx, len(testloader),
            'Loss1: %.6f | Loss2: %.4f | Acc: %.3f%% (%d/%d)' %
            (train_loss1 / (batch_idx + 1), train_loss2 /
             (batch_idx + 1), 100. * correct / total, correct, total))

        if batch_idx == 585:  # hard-coded index of the final test batch
            # DataFrame.append was removed in pandas 2.0; use pd.concat instead
            save_df = pd.concat([
                save_df,
                pd.DataFrame(
                    {
                        'epoch': epoch,
                        'MSE_loss': train_loss1 / (batch_idx + 1),
                        'classifier_loss': train_loss2 / (batch_idx + 1),
                        'accuracy': correct / total
                    },
                    index=[0])
            ], ignore_index=True)
            if not os.path.isdir('dataframe'):
                os.mkdir('dataframe')
            save_df.to_csv('./dataframe/result_' + codir + '.csv', index=False)

    # Save checkpoint.
    acc = 100. * correct / total
    if train_loss1 < best_loss:
        print('Saving..')
        state = {
            'encoder': encoder.state_dict(),
            'decoder': decoder.state_dict(),
            'classifier': classifier.state_dict(),
            'loss': train_loss1,
            'epoch': epoch,
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(state, './checkpoint/ckpt_' + codir + '.t7')
        best_loss = train_loss1
Example No. 11
def psnr_main(opt,
              train_loader,
              val_loader,
              train_sampler,
              logger,
              resume_state=None,
              tb_logger=None,
              rank=-1):
    # create model
    model = create_model(opt)

    # resume training
    if resume_state:
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))

        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        model.resume_training(resume_state)  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0

    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(
        start_epoch, current_step))
    total_epochs = int(opt['train']['nepochs'])
    best_psnr = 0
    patience = 0
    all_results = []
    for epoch in range(start_epoch, total_epochs):
        if opt['dist']:
            train_sampler.set_epoch(epoch)
        for batch_num, train_data in enumerate(train_loader):
            current_step += 1

            # training
            model.feed_data(train_data)
            model.optimize_parameters(current_step)

            progress_bar(batch_num, len(train_loader), msg=None)

        # log
        if epoch % opt['logger']['print_freq'] == 0:
            logs = model.get_current_log()
            message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> '.format(
                epoch, current_step, model.get_current_learning_rate())
            for k, v in logs.items():
                message += '{:s}: {:.4e} '.format(k, v)
                # tensorboard logger
                if opt['use_tb_logger'] and 'debug' not in opt['name']:
                    if rank <= 0:
                        tb_logger.add_scalar(k, v, current_step)
            if rank <= 0:
                logger.info(message)

        # batched validation
        if epoch % opt['train']['val_freq'] == 0 and rank <= 0:
            avg_psnr = 0.0
            idx = 0
            for batch_num, val_data in enumerate(val_loader):
                img_name = os.path.splitext(
                    os.path.basename(val_data['LQ_path'][0]))[0]
                img_dir = os.path.join(opt['path']['val_images'], img_name)
                util.mkdir(img_dir)

                model.feed_data(val_data)
                model.test()

                # calculate PSNR
                item_psnr = util.tensor_psnr(model.real_H, model.fake_H)
                if math.isfinite(item_psnr):
                    avg_psnr += item_psnr
                    idx += 1

                progress_bar(batch_num, len(val_loader), msg=None)

            avg_psnr = avg_psnr / idx
            all_results.append(avg_psnr)

            if avg_psnr < best_psnr:
                patience += 1
                if patience == opt['train']['epoch_patience']:
                    model.update_learning_rate(opt['train']['lr_decay'])
                    print(
                        "no improvement, final patience, updating learning rate to {}"
                        .format(model.get_current_learning_rate()))
                    patience = 0
                else:
                    print("no improvement, patience {} out of {}".format(
                        patience, opt['train']['epoch_patience']))
                if model.get_current_learning_rate() < opt['train']['min_lr']:
                    break

            else:
                best_psnr = avg_psnr
                if rank <= 0:
                    logger.info('Saving models and training states.')
                    model.save('latest')

            # log
            logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr))
            logger_val = logging.getLogger('val')  # validation logger
            logger_val.info(
                '<epoch:{:3d}, iter:{:8,d}> psnr: {:.4e} (best: {:.4e})'.
                format(epoch, current_step, avg_psnr, best_psnr))
            # tensorboard logger
            if opt['use_tb_logger'] and 'debug' not in opt['name']:
                tb_logger.add_scalar('psnr', avg_psnr, current_step)

    if rank <= 0:
        logger.info('End of training.')
        json.dump(all_results,
                  open(
                      os.path.join(opt['path']['log'],
                                   'validation_results.json'), 'w'),
                  indent=2)
Example No. 12
            lpips_t0, lpips_t1 = self.vgg16(img, resize_images=True, return_lpips=True).chunk(2)
            dist = (lpips_t0 - lpips_t1).square().sum(1) / self.epsilon ** 2
        return dist

if __name__ == '__main__':
    batch_size = 16
    test_size = 100000
    total_iterations = math.ceil(test_size/batch_size)
    vgg = torch.jit.load("./score_calculator/vgg16.pt").eval().to(device)
    model = StyleALAE(model_config=config, device=device)
    model.load_train_state('./archived/FFHQ/StyleALAE-z-256_w-256_prog-(4,256)-(8,256)-(16,128)-(32,128)-(64,64)-(64,32)/checkpoints/ckpt_gs-120000_res-5=64x64_alpha-0.40.pt')
    model.G.eval()
    model.F.eval()
    for space in ['w', 'z']:
        for path_len in ['full', 'end']:
            sampler = PPLNet(G=model.G, F=model.F, config=config, batch_size=batch_size, res_idx=model.res_idx, alpha=0.4, epsilon=0.0001, space=space, sampling=path_len, vgg16=vgg)
            sampler.eval().requires_grad_(False).to(device)
            dists = []
            print("Calculating on ",space," latent space , path length is ", path_len )
            for itr in range(total_iterations):
                dist = sampler.forward()
                dists.append(dist)
                progress_bar(itr, total_iterations)
            print("Finished")
            dists = torch.cat(dists)[:test_size].cpu().numpy()
            # NumPy >= 1.22 renames the interpolation kwarg to method=
            lo = np.percentile(dists, 1, interpolation='lower')
            hi = np.percentile(dists, 99, interpolation='higher')
            ppl = np.extract(np.logical_and(dists >= lo, dists <= hi), dists).mean()
            print("The PPL is ", ppl)

Example No. 13
def train_main(opt,
               train_loader,
               val_loader,
               train_sampler,
               logger,
               resume_state=None,
               tb_logger=None,
               rank=-1):
    # create model
    model = create_model(opt)

    try:
        total_nfe = model.netG.module.conv_trunk.nfe
        nfe = True
        try:
            nfe_count = json.load(
                open(os.path.join(opt['path']['log'], "nfe_count.json")))
            print("resuming NFE count from {}".format(
                os.path.join(opt['path']['log'], "nfe_count.json")))
        except FileNotFoundError:
            print("no previous NFE count file found, starting from scratch")
            nfe_count = []
    except AttributeError:
        nfe = False
        total_nfe = None
        nfe_count = None

    best_niqe = 1e10
    best_psnr = 0
    patience = 0

    # resume training
    if resume_state:
        logger.info(
            'Resuming training from epoch: {}, iter: {}., psnr: {}, niqe: {}, patience: {}'
            .format(resume_state['epoch'], resume_state['iter'],
                    resume_state['psnr'], resume_state['niqe'],
                    resume_state['patience']))

        start_epoch = resume_state['epoch'] + 1
        current_step = resume_state['iter']
        best_psnr = resume_state.get('psnr', 0)
        best_niqe = resume_state.get('niqe', 1e10)
        patience = resume_state.get('patience', 0)
        model.resume_training(resume_state)  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0

    try:
        if opt['train']['G_pretraining'] >= 1:
            pretraining_epochs = opt['train']['G_pretraining']
        else:
            pretraining_epochs = 0
    except (KeyError, TypeError):
        pretraining_epochs = 0
    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(
        start_epoch, current_step))
    total_epochs = int(opt['train']['nepochs'])
    lr_decay = opt['train']['lr_decay']
    min_lr = opt['train']['min_lr']
    pretraining = False
    all_results = []
    start_time = time()
    for epoch in range(start_epoch, total_epochs):

        if pretraining_epochs > 0:
            if epoch == 0:
                pretraining = True
                logger.info('Starting pretraining.')
            if epoch == pretraining_epochs:
                pretraining = False
                logger.info(
                    'Pretraining done, adding feature and discriminator loss.')

        if opt['dist']:
            train_sampler.set_epoch(epoch)

        if nfe:
            epoch_nfe = []

        for batch_num, train_data in enumerate(train_loader):
            # try:
            current_step += 1

            # training
            model.feed_data(train_data)
            model.optimize_parameters(current_step, pretraining=pretraining)

            if nfe:
                last_nfe = model.netG.module.conv_trunk.nfe - total_nfe
                total_nfe = model.netG.module.conv_trunk.nfe
                epoch_nfe.append(last_nfe)

            progress_bar(batch_num, len(train_loader), msg=None)

        # except RuntimeError:
        #     continue
        if nfe:
            nfe_count.append(epoch_nfe)
        # log
        if epoch % opt['logger']['print_freq'] == 0:
            logs = model.get_current_log()
            message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> '.format(
                epoch, current_step, model.get_current_learning_rate())
            for k, v in logs.items():
                message += '{:s}: {:.4e} '.format(k, v)
                # tensorboard logger
                if opt['use_tb_logger'] and 'debug' not in opt['name']:
                    if rank <= 0:
                        tb_logger.add_scalar(k, v, current_step)
            if rank <= 0:
                logger.info(message)

        # batched validation

        if nfe:
            epoch_nfe = []

        if (epoch % opt['train']['val_freq'] == 0 and rank <= 0
                and epoch >= pretraining_epochs - 1):
            avg_psnr = 0.0
            avg_niqe = 0.0
            idx = 0
            for batch_num, val_data in enumerate(val_loader):
                # try:
                img_name = os.path.splitext(
                    os.path.basename(val_data['LQ_path'][0]))[0]
                img_dir = os.path.join(opt['path']['val_images'], img_name)
                util.mkdir(img_dir)

                model.feed_data(val_data)
                model.test()
                if nfe:
                    last_nfe = model.netG.module.conv_trunk.nfe - total_nfe
                    total_nfe = model.netG.module.conv_trunk.nfe
                    epoch_nfe.append(last_nfe)

                visuals = model.get_current_visuals()
                sr_img = util.tensor2img(visuals['SR'])  # uint8
                # ground truth image
                # gt_img = util.tensor2img(visuals['GT'])  # uint8

                # Save SR images for reference
                save_img_path = os.path.join(
                    img_dir, '{:s}_{:d}.png'.format(img_name, current_step))
                util.save_img(sr_img, save_img_path)

                # calculate PSNR
                item_psnr = util.tensor_psnr(model.real_H, model.fake_H)
                if math.isfinite(item_psnr):
                    avg_psnr += item_psnr
                    idx += 1

                # calculate NIQE
                if opt['niqe']:
                    item_niqe = util.tensor_niqe(model.fake_H)
                    # item_niqe = 0
                    if math.isfinite(item_niqe):
                        avg_niqe += item_niqe

                progress_bar(batch_num, len(val_loader), msg=None)
            if nfe:
                nfe_count.append(epoch_nfe)
                json.dump(nfe_count,
                          open(
                              os.path.join(opt['path']['log'],
                                           'nfe_count.json'), 'w'),
                          indent=2)

            avg_psnr = avg_psnr / idx
            avg_niqe = avg_niqe / idx
            all_results.append((time() - start_time, avg_psnr, avg_niqe))

            # save models and training states
            if rank <= 0 and (avg_psnr > best_psnr
                              or avg_niqe < best_niqe - 10e-6):
                logger.info('Saving models and training states.')
                model.save(epoch)
                model.save_training_state(epoch, current_step, best_psnr,
                                          best_niqe, patience)

            else:
                patience += 1
                if patience == opt['train']['epoch_patience']:
                    model.update_learning_rate(lr_decay)
                    print(
                        "no improvement, final patience, updating learning rate to {}"
                        .format(model.get_current_learning_rate()))
                    patience = 0
                else:
                    print("no improvement, patience {} out of {}".format(
                        patience, opt['train']['epoch_patience']))
                if model.get_current_learning_rate() < min_lr:
                    break

            if avg_niqe < best_niqe:
                best_niqe = avg_niqe

            if avg_psnr > best_psnr:
                best_psnr = avg_psnr

            # log
            logger.info('# Validation # PSNR: {:.4e} # NIQE: {:.4e}'.format(
                avg_psnr, avg_niqe))
            logger_val = logging.getLogger('val')  # validation logger
            logger_val.info(
                '<epoch:{:3d}, iter:{:8,d}> psnr: {:.4e} niqe: {:.4e} (best: {:.4e}/{:.4e})'
                .format(epoch, current_step, avg_psnr, avg_niqe, best_psnr,
                        best_niqe))
            # tensorboard logger
            if opt['use_tb_logger'] and 'debug' not in opt['name']:
                tb_logger.add_scalar('psnr', avg_psnr, current_step)

        print('\n')

    if rank <= 0:
        # save results
        logger.info('Saving the final model.')
        model.save('latest')
        logger.info('End of training.')
        json.dump(all_results,
                  open(
                      os.path.join(opt['path']['log'],
                                   'validation_results.json'), 'w'),
                  indent=2)
        if nfe:
            nfe_count.append(epoch_nfe)
            json.dump(nfe_count,
                      open(os.path.join(opt['path']['log'], 'nfe_count.json'),
                           'w'),
                      indent=2)

        # clear validation logger
        logger_val.handlers.clear()

        # print out graph of val psnr with time
        fig, ax = plt.subplots()
        y = list(zip(*all_results))
        runtime, dev_psnr, dev_niqe = y[0], y[1], y[2]
        ax.plot(runtime, dev_psnr, color='blue', label='Validation PSNR')

        ax.set(xlabel='Time (s)', ylabel='Dev. PSNR.')
        ax.legend(loc='upper right')
        ax.grid()

        plt.savefig(os.path.join(opt['path']['log'], "psnr_evolution.png"))