Пример #1
0
                print(predict_value)
                label = label.data.cpu().numpy()
                # print(output)
                # print(label)
                acc = np.mean((predict_value == label).astype(int))
                speed = config.log_interval / (time.time() - start)

                time_str = time.asctime(time.localtime(time.time()))
                print(
                    '{} || train epoch {} || iter {} || {} iters/s || loss {} || acc {}'
                    .format(time_str, current_epoch, current_epoch_step, speed,
                            loss.item(), acc))

                if config.display:
                    visualizer.display_current_results(current_global_step,
                                                       loss.item(),
                                                       name='train_loss')
                    visualizer.display_current_results(current_global_step,
                                                       acc,
                                                       name='train_acc')

                start = time.time()

        if current_epoch % config.save_interval == 0 or current_epoch == config.max_epoch:
            save_model(model, config.checkpoints_path, config.backbone,
                       current_epoch)

        model.eval()
        acc = lfw_test(model, test_img_paths, test_identity_list,
                       config.lfw_test_list, config.test_batch_size)
        if config.display:
Пример #2
0
def train(gpu, opt):
    """Train the VAE-GAN model configured by ``opt``.

    Alternates generator (VAE) and discriminator updates over an image-folder
    dataset, periodically printing/plotting losses, displaying sample outputs,
    and checkpointing both the model and the (epoch, iteration) resume state.

    Args:
        gpu: unused here; device selection comes from ``opt.gpu_ids[0]``.
        opt: options namespace (checkpoints_dir, name, dataroot, batch_size,
            print_freq, display_freq, save_latest_freq, save_epoch_freq,
            niter, niter_decay, continue_train, num_workers, ...).
    """
    # NOTE(review): `use_cuda` is a module-level flag defined outside this
    # function; the `gpu` argument itself is never consulted.
    device = torch.device(f"cuda:{opt.gpu_ids[0]}" if use_cuda else "cpu")
    # Align print frequency with batch size so the modulo tests below can fire.
    opt.print_freq = lcm(opt.print_freq, opt.batch_size)

    folder_name = "DPFIP"

    # File recording (epoch, iteration) so an interrupted run can resume.
    iter_path = os.path.join(opt.checkpoints_dir, folder_name, opt.name,
                             'iter.txt')
    if opt.continue_train:
        try:
            start_epoch, epoch_iter = np.loadtxt(iter_path,
                                                 delimiter=',',
                                                 dtype=int)
        except FileNotFoundError:
            # No saved state yet: fall back to a fresh run.
            start_epoch, epoch_iter = 1, 0
        print('Resuming from epoch %d at iteration %d' %
              (start_epoch, epoch_iter))
    else:
        start_epoch, epoch_iter = 1, 0
    dataset = dset.ImageFolder(root=opt.dataroot,
                               transform=transforms.Compose([
                                   transforms.CenterCrop(178),
                                   transforms.Resize(128),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5),
                                                        (0.5, 0.5, 0.5)),
                               ]))
    # Create the dataloader
    trainloader = torch.utils.data.DataLoader(dataset,
                                              batch_size=opt.batch_size,
                                              shuffle=True,
                                              num_workers=opt.num_workers)

    dataset_size = len(trainloader) * opt.batch_size
    visualizer = Visualizer(opt)
    total_steps = (start_epoch - 1) * dataset_size + epoch_iter
    # Phase offsets so a resumed run keeps firing on the same step boundaries.
    display_delta = total_steps % opt.display_freq
    print_delta = total_steps % opt.print_freq
    save_delta = total_steps % opt.save_latest_freq
    model = VaeGanModule(opt, device)

    if use_cuda:
        model = model.cuda()
    optimizer_vae, optimizer_D = model.optimizer_vae, model.optimizer_D

    for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        if epoch != start_epoch:
            epoch_iter = epoch_iter % dataset_size

        for i, data in enumerate(trainloader):
            images = data[0].to(device)
            if total_steps % opt.print_freq == print_delta:
                iter_start_time = time.time()
            total_steps += opt.batch_size
            epoch_iter += opt.batch_size
            losses, fake_images = model(images)
            output_images = fake_images.detach()
            # Average per-device losses (ints/None pass through unchanged).
            losses = [
                torch.mean(x)
                if not isinstance(x, int) and x is not None else x
                for x in losses
            ]
            # The original had identical bodies in an `if use_cuda` / `else`
            # here; collapsed to the single statement.
            loss_dict = dict(zip(model.loss_names, losses))
            # Final discriminator loss scalar.
            loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5
            ############### Backward Pass ####################
            # Update generator (VAE) weights.
            loss_G = loss_dict['G_GAN'] + \
                         loss_dict.get("G_Image_Rec", 0) + \
                         loss_dict.get("G_KL_image", 0)
            optimizer_vae.zero_grad()
            loss_G.backward()
            optimizer_vae.step()
            optimizer_vae.zero_grad()

            # Update discriminator weights.
            loss_D.backward()
            optimizer_D.step()
            optimizer_D.zero_grad()

            ############## Display results and errors ##########
            ### print out errors every print_freq steps
            if total_steps % opt.print_freq == print_delta:
                errors = {
                    k: v.data.item()
                    if not isinstance(v, int) and v is not None else v
                    for k, v in loss_dict.items()
                }
                t = (time.time() - iter_start_time) / opt.print_freq
                visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                visualizer.plot_current_errors(errors, total_steps)

            ### display output images
            visuals = OrderedDict([('real_image', util.tensor2im(images[0])),
                                   ('fake_image',
                                    util.tensor2im(output_images[0]))])
            visualizer.display_current_results(visuals, epoch, total_steps)
            ### save latest model
            if (total_steps % opt.save_latest_freq == save_delta):
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (epoch, total_steps))
                # The original had identical bodies in an `if use_cuda` /
                # `else` here; collapsed to one call.
                model.save('latest')
                np.savetxt(iter_path, (epoch, epoch_iter),
                           delimiter=',',
                           fmt='%d')
            if epoch_iter >= dataset_size:
                break
            # end of epoch
        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch, opt.niter + opt.niter_decay,
               time.time() - epoch_start_time))

        ### save model for this epoch
        if (epoch % opt.save_epoch_freq == 0):
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, total_steps))
            model.save('latest')
            model.save(epoch)
            np.savetxt(iter_path, (epoch + 1, 0), delimiter=',', fmt='%d')

        ### linearly decay learning rate after certain iterations
        if epoch > opt.niter:
            model.update_learning_rate()
Пример #3
0
            iters = i * len(trainloader) + ii

            if iters % opt.print_freq == 0:
                output = output.data.cpu().numpy()
                output = np.argmax(output, axis=1)
                label = label.data.cpu().numpy()
                # print(output)
                # print(label)
                acc = np.mean((output == label).astype(int))
                speed = opt.print_freq / (time.time() - start)
                time_str = time.asctime(time.localtime(time.time()))
                print('{} train epoch {} iter {} {} iters/s loss {} acc {}'.
                      format(time_str, i, ii, speed, loss.item(), acc))
                if opt.display:
                    visualizer.display_current_results(iters,
                                                       loss.item(),
                                                       name='train_loss')
                    visualizer.display_current_results(iters,
                                                       acc,
                                                       name='train_acc')

                start = time.time()

        if i % opt.save_interval == 0 or i == opt.max_epoch:
            save_model(model, opt.checkpoints_path, opt.backbone, i)

        model.eval()
        acc = lfw_test(model, img_paths, identity_list, opt.lfw_test_list,
                       opt.test_batch_size)
        if opt.display:
            visualizer.display_current_results(iters, acc, name='test_acc')
Пример #4
0
            for i, data in enumerate(dataset):

                iter_start_time = time.time()
                if total_iters % opt.print_freq == 0:
                    t_data = iter_start_time - iter_data_time

                total_iters += opt.batch_size
                epoch_iter += opt.batch_size

                pruned_model.set_input(data)
                pruned_model.optimize_parameters()

                if total_iters % opt.display_freq == 0:
                    save_result = total_iters % opt.update_html_freq == 0
                    # pruned_model.compute_visuals()
                    visualizer.display_current_results(
                        pruned_model.get_current_visuals(), epoch, save_result)

                if total_iters % opt.print_freq == 0:
                    losses = pruned_model.get_current_losses()
                    t_comp = (time.time() - iter_start_time) / opt.batch_size
                    loss_message = visualizer.print_current_losses(
                        epoch, epoch_iter, losses, t_comp, t_data)
                    logger.info(loss_message)
                    if opt.display_id > 0:
                        visualizer.plot_current_losses(
                            epoch,
                            float(epoch_iter) / dataset_size, losses)

                    iter_data_time = time.time()

            if epoch % opt.save_epoch_freq == 0:
Пример #5
0
# Top-level training driver: relies on `opt`, `model`, `data_loader`, and
# `Visualizer` being defined earlier in the file.
visualizer = Visualizer(opt)
total_steps = 0  # cumulative samples processed across all epochs
epoch_count = 0  # completed-epoch counter, used for checkpoint naming below

for epoch in range(opt.epoch):
    epoch_start_time = time.time()
    iter_count = 0  # samples processed within the current epoch

    for i, data in enumerate(data_loader):
        batch_start_time = time.time()
        total_steps += opt.batch_size
        iter_count += opt.batch_size
        # data : list
        model.set_input(data[0])
        model.optimize_parameters()
        batch_end_time = time.time()

        # Report losses every `print_freq` samples, timed per batch.
        if iter_count % opt.print_freq == 0:
            errors = model.get_losses()
            visualizer.print_current_errors(epoch, iter_count, errors, (batch_end_time - batch_start_time))

        if total_steps % opt.plot_freq == 0:
            # NOTE(review): `save_result` is always True inside this branch
            # (it repeats the enclosing `if`'s modulo test).
            save_result = total_steps % opt.plot_freq == 0
            visualizer.display_current_results(model.get_visuals(), int(total_steps/opt.plot_freq), save_result)
            if opt.display_id > 0:
                # NOTE(review): `errors` may be unbound here if the
                # print_freq branch has not fired yet this run — verify that
                # print_freq always triggers before plot_freq.
                visualizer.plot_current_errors(epoch, total_steps, errors)

    # Drop the previous epoch's checkpoint, then save the new one.
    model.remove(epoch_count)
    epoch_count += 1
    model.save(epoch_count)
Пример #6
0
def run():
    """Train a face-recognition backbone with a margin-based classifier head.

    Builds the dataset/loader, loss, backbone, metric head, and optimizer
    from the global ``Config``, trains for ``opt.max_epoch`` epochs, and runs
    an LFW verification test after every epoch.

    Raises:
        ValueError: if ``opt.backbone`` names an unknown architecture.
    """
    opt = Config()

    if opt.display:
        visualizer = Visualizer()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_dataset = FaceDataset(opt.train_root,
                                opt.train_list,
                                phase='train',
                                input_shape=opt.input_shape)
    trainloader = torch.utils.data.DataLoader(train_dataset,
                                              batch_size=opt.train_batch_size,
                                              shuffle=True,
                                              num_workers=opt.num_workers)
    print('{} train iters per epoch:'.format(len(trainloader)))

    # Focal Loss down-weights easy examples so training focuses on hard,
    # misclassified samples — helps with class imbalance.
    # https://blog.csdn.net/u014380165/article/details/77019084
    if opt.loss == 'focal_loss':
        criterion = FocalLoss(gamma=2)
    else:
        criterion = torch.nn.CrossEntropyLoss()

    # Backbone network producing 512-d embeddings.
    if opt.backbone == 'resnet18':
        model = resnet_face18(use_se=opt.use_se)
    elif opt.backbone == 'resnet34':
        model = resnet34()
    elif opt.backbone == 'resnet50':
        model = resnet50()
    else:
        # Fail fast instead of hitting a NameError on first use of `model`.
        raise ValueError('unknown backbone: {}'.format(opt.backbone))

    # Metric head mapping embeddings to class logits (margin variants take
    # the label to apply an angular/additive margin during training).
    if opt.metric == 'add_margin':
        metric_fc = AddMarginProduct(512, opt.num_classes, s=30, m=0.35)
    elif opt.metric == 'arc_margin':
        metric_fc = ArcMarginProduct(512,
                                     opt.num_classes,
                                     s=30,
                                     m=0.5,
                                     easy_margin=opt.easy_margin)
    elif opt.metric == 'sphere':
        metric_fc = SphereProduct(512, opt.num_classes, m=4)
    else:
        metric_fc = nn.Linear(512, opt.num_classes)

    # view_model(model, opt.input_shape)
    print(model)
    model.to(device)
    model = DataParallel(model)
    metric_fc.to(device)
    metric_fc = DataParallel(metric_fc)

    # Optimizer over both the backbone and the metric head.
    if opt.optimizer == 'sgd':
        optimizer = torch.optim.SGD([{
            'params': model.parameters()
        }, {
            'params': metric_fc.parameters()
        }],
                                    lr=opt.lr,
                                    weight_decay=opt.weight_decay)
    else:
        optimizer = torch.optim.Adam([{
            'params': model.parameters()
        }, {
            'params': metric_fc.parameters()
        }],
                                     lr=opt.lr,
                                     weight_decay=opt.weight_decay)

    # Decay the learning rate by 10x every `lr_step` epochs.
    scheduler = StepLR(optimizer, step_size=opt.lr_step, gamma=0.1)

    # LFW test fixtures do not change between epochs; build them once.
    identity_list = get_lfw_list(opt.lfw_test_list)
    img_paths = [
        os.path.join(opt.lfw_root, each) for each in identity_list
    ]

    start = time.time()
    for i in range(opt.max_epoch):
        model.train()
        for ii, data in enumerate(trainloader):
            data_input, label = data
            data_input = data_input.to(device)
            label = label.to(device).long()

            feature = model(data_input)
            output = metric_fc(feature, label)
            loss = criterion(output, label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            iters = i * len(trainloader) + ii

            if iters % opt.print_freq == 0:
                output = output.data.cpu().numpy()
                output = np.argmax(output, axis=1)  # predicted class indices
                label = label.data.cpu().numpy()
                acc = np.mean((output == label).astype(int))
                speed = opt.print_freq / (time.time() - start)
                time_str = time.asctime(time.localtime(time.time()))
                print('{} train epoch {} iter {} {} iters/s loss {} acc {}'.
                      format(time_str, i, ii, speed, loss.item(), acc))
                if opt.display:
                    visualizer.display_current_results(iters,
                                                       loss.item(),
                                                       name='train_loss')
                    visualizer.display_current_results(iters,
                                                       acc,
                                                       name='train_acc')

                start = time.time()

        # Step the scheduler AFTER the epoch's optimizer steps — the
        # required ordering since PyTorch 1.1; stepping before training
        # skipped the first learning-rate value.
        scheduler.step()

        # BUG FIX: the original tested `i == opt.max_epoch`, which is never
        # true (range stops at max_epoch - 1), so the final epoch could go
        # unsaved when max_epoch - 1 is not a multiple of save_interval.
        if i % opt.save_interval == 0 or i == opt.max_epoch - 1:
            save_model(model, opt.checkpoints_path, opt.backbone, i)

        # Evaluate LFW verification accuracy after each training epoch.
        model.eval()
        acc = lfw_test(model, img_paths, identity_list, opt.lfw_test_list,
                       opt.test_batch_size)

        if opt.display:
            visualizer.display_current_results(iters, acc, name='test_acc')
Пример #7
0
# Training-loop fragment: assumes `opt`, `model`, `data_loader`,
# `visualizer`, `total_steps`, and `epoch_count` are initialized earlier
# in the file (they are not defined in this span).
for epoch in range(opt.epoch):
    epoch_start_time = time.time()
    iter_count = 0  # samples processed within the current epoch

    for i, data in enumerate(data_loader):
        batch_start_time = time.time()
        total_steps += opt.batch_size
        iter_count += opt.batch_size
        # data : list
        model.set_input(data[0])
        model.optimize_parameters()
        batch_end_time = time.time()

        # Report losses every `print_freq` samples, timed per batch.
        if iter_count % opt.print_freq == 0:
            errors = model.get_losses()
            visualizer.print_current_errors(
                epoch, iter_count, errors, (batch_end_time - batch_start_time))

        if total_steps % opt.plot_freq == 0:
            # NOTE(review): `save_result` is always True inside this branch
            # (it repeats the enclosing `if`'s modulo test).
            save_result = total_steps % opt.plot_freq == 0
            visualizer.display_current_results(
                model.get_visuals(), int(total_steps / opt.plot_freq),
                save_result)
            if opt.display_id > 0:
                # NOTE(review): `errors` may be unbound if the print_freq
                # branch has not fired yet — verify the two frequencies.
                visualizer.plot_current_errors(epoch, total_steps, errors)

    # Drop the previous epoch's checkpoint, then save the new one.
    model.remove(epoch_count)
    epoch_count += 1
    model.save(epoch_count)
Пример #8
0
        if total_iters % opt['print_freq'] == 0:
            t_data = iter_start_time - iter_data_time
        visualizer.reset()
        total_iters += opt['batch_size']
        epoch_iter += opt['batch_size']
        model.set_input(
            data)  # unpack data from dataset and apply preprocessing
        model.optimize_parameters(
        )  # calculate loss functions, get gradients, update network weights

        if total_iters % opt[
                'display_freq'] == 0:  # display images on visdom and save images to a HTML file
            save_result = total_iters % opt['update_html_freq'] == 0
            model.compute_visuals()
            visualizer.display_current_results(model.get_current_visuals(),
                                               total_iters, len(dataset),
                                               time.time() - train_start_time,
                                               save_result)

        if total_iters % opt[
                'print_freq'] == 0:  # print training losses and save logging information to the disk
            losses = model.get_current_losses()
            t_comp = (time.time() - iter_start_time) / opt['batch_size']
            visualizer.print_current_losses(epoch, epoch_iter, losses, t_comp,
                                            t_data)
            if opt['display_id'] > 0:
                visualizer.plot_current_losses(
                    epoch,
                    float(epoch_iter) / len(dataset), losses)
                visualizer.display_train_time(time.time(), epoch_end_time)

        if total_iters % opt[
Пример #9
0
                train_opt.niter + train_opt.niter_decay))

            train_psnr = train_model.cal_psnr()
            train_psnr_list.append(train_psnr)

            if epoch % train_opt.print_freq == 0:
                losses = train_model.get_current_losses()
                t = (time.time() - iter_start_time) / train_opt.batchsize
                visualizer.print_current_losses(epoch, epoch_iter, losses, t)
                if train_opt.display_id > 0:
                    visualizer.plot_current_losses(
                        epoch,
                        float(epoch_iter) / dataset_size, train_opt, losses)
                    visualizer.display_current_results(
                        train_model.get_current_visuals(),
                        train_model.get_image_name(),
                        epoch,
                        True,
                        win_id=[1])

                    visualizer.plot_spectral_lines(
                        train_model.get_current_visuals(),
                        train_model.get_image_name(),
                        visual_corresponding_name=train_model.
                        get_visual_corresponding_name(),
                        win_id=[2, 3])
                    visualizer.plot_psnr_sam(
                        train_model.get_current_visuals(),
                        train_model.get_image_name(), epoch,
                        float(epoch_iter) / dataset_size,
                        train_model.get_visual_corresponding_name())
Пример #10
0
    # training process
    while (keep_training):
        epoch_start_time = time.time()
        epoch += 1
        print('\n Training epoch: %d' % epoch)

        for i, data in enumerate(dataset):
            iter_start_time = time.time()
            total_iteration += 1
            model.set_input(data)
            model.optimize_parameters()

            # display images on visdom and save images
            if total_iteration % opt.display_freq == 0:
                visualizer.display_current_results(model.get_current_visuals(),
                                                   epoch)
                visualizer.plot_current_distribution(model.get_current_dis())

            # print training loss and save logging information to the disk
            if total_iteration % opt.print_freq == 0:
                losses = model.get_current_errors()
                t = (time.time() - iter_start_time) / opt.batchSize
                visualizer.print_current_errors(epoch, total_iteration, losses,
                                                t)
                # print(f"epoch: {epoch}, total_iteration: {total_iteration}, losses: {losses}, t: {t}")
                if opt.display_id > 0:
                    visualizer.plot_current_errors(total_iteration, losses)

            # save the latest model every <save_latest_freq> iterations to the disk
            if total_iteration % opt.save_latest_freq == 0:
                print('saving the latest model (epoch %d, total_steps %d)' %
Пример #11
0
class Treainer(object):
    # NOTE(review): `dis_list=[]` is a mutable default argument; it is only
    # read here, but consider `dis_list=None` as a hardening follow-up.
    def __init__(self,
                 opt=None,
                 train_dt=None,
                 train_dt_warm=None,
                 dis_list=[],
                 val_dt_warm=None):
        """Build the GAN super-resolution trainer.

        Constructs the generator, two discriminators (image-space and
        VGG-feature-space), a frozen VGG feature extractor, their optimizers
        and LR schedulers, and the (optionally distributed) train/validation
        dataloaders.

        Args:
            opt: configuration namespace (ge_net, vgg_type, init, dis, gen,
                warm_opt, gan_loss_fn, batch_size, ...).
            train_dt: main training dataset (concatenated with the warm set).
            train_dt_warm: warm-up training dataset.
            dis_list: distributed settings — indexed as [rank, num_gpus,
                group_name, dist_config]; assumption from usage, confirm
                against the caller.
            val_dt_warm: validation dataset for the warm-up task.
        """
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

        self.opt = opt

        self.visualizer = Visualizer(opt)

        num_gpus = torch.cuda.device_count()
        #dis_list[1]
        print(dis_list)
        #torch.cuda.device_count()
        # Process rank within the distributed group (0 for single-GPU runs).
        self.rank = dis_list[0]
        print(self.rank)

        #=====START: ADDED FOR DISTRIBUTED======
        if num_gpus > 1:
            #init_distributed(rank, num_gpus, group_name, **dist_config)
            dist_config = dis_list[3]
            init_distributed(dis_list[0], dis_list[1], dis_list[2],
                             **dist_config)
        #=====END:   ADDED FOR DISTRIBUTED======

        # Generator architecture selected by config.
        if opt.ge_net == "srfeat":
            self.netG = model.G()
        elif opt.ge_net == "carn":
            self.netG = model.G1()
        elif opt.ge_net == "carnm":
            self.netG = model.G2()
        else:
            raise Exception("unknow ")

        # Discriminator over VGG feature maps (512 channels, 18x18 input).
        self.netD_vgg = model.D(input_c=512, input_width=18)

        # Discriminator over raw images.
        self.netD = model.D()

        # Pretrained VGG used only as a fixed feature extractor.
        if opt.vgg_type == "style":
            self.vgg = load_vgg16(opt.vgg_model_path + '/models')
        elif opt.vgg_type == "classify":
            self.vgg = model.vgg19_withoutbn_customefinetune()

        # Freeze VGG: eval mode and no gradients.
        self.vgg.eval()
        for param in self.vgg.parameters():
            param.requires_grad = False

#         for p in self.vgg.parameters():
#             p.requires_grad = False

        init_weights(self.netD, init_type=opt.init)
        init_weights(self.netD_vgg, init_type=opt.init)
        init_weights(self.netG, init_type=opt.init)

        self.vgg = self.vgg.to(self.device)
        self.netD = self.netD.to(self.device)
        self.netD_vgg = self.netD_vgg.to(self.device)
        self.netG = self.netG.to(self.device)

        #=====START: ADDED FOR DISTRIBUTED======
        if num_gpus > 1:
            #self.vgg = apply_gradient_allreduce(self.vgg)
            self.netD_vgg = apply_gradient_allreduce(self.netD_vgg)
            self.netD = apply_gradient_allreduce(self.netD)
            self.netG = apply_gradient_allreduce(self.netG)

        #=====END:   ADDED FOR DISTRIBUTED======

        print(opt)

        # Generator optimizer uses the warm-up hyper-parameters.
        self.optim_G= torch. optim.Adam(filter(lambda p: p.requires_grad, self.netG.parameters()),\
         lr=opt.warm_opt.lr, betas=opt.warm_opt.betas, weight_decay=0.0)

        #        self.optim_G= torch.optim.Adam(filter(lambda p: p.requires_grad, self.netG.parameters()),\
        #         lr=opt.gen.lr, betas=opt.gen.betas, weight_decay=0.0)

        # One optimizer drives BOTH discriminators (parameters chained).
        if opt.dis.optim == "sgd":
            self.optim_D= torch.optim.SGD( filter(lambda p: p.requires_grad, \
                itertools.chain(self.netD_vgg.parameters(),self.netD.parameters() ) ),\
                lr=opt.dis.lr,
             )
        elif opt.dis.optim == "adam":
            self.optim_D= torch.optim.Adam( filter(lambda p: p.requires_grad, \
                itertools.chain(self.netD_vgg.parameters(),self.netD.parameters() ) ),\
                lr=opt.dis.lr,betas=opt.dis.betas, weight_decay=0.0
             )
        else:
            raise Exception("unknown")

        print("create schedule ")

        lr_sc_G = get_scheduler(self.optim_G, opt.gen)
        lr_sc_D = get_scheduler(self.optim_D, opt.dis)

        self.schedulers = []

        self.schedulers.append(lr_sc_G)
        self.schedulers.append(lr_sc_D)

        # =====START: ADDED FOR DISTRIBUTED======
        # Train on the union of the main and warm-up datasets.
        train_dt = torch.utils.data.ConcatDataset([train_dt, train_dt_warm])

        train_sampler = DistributedSampler(train_dt) if num_gpus > 1 else None
        val_sampler_warm = DistributedSampler(
            val_dt_warm) if num_gpus > 1 else None
        # =====END:   ADDED FOR DISTRIBUTED======

        # Loader tuning only makes sense when CUDA is present.
        kw = {
            "pin_memory": True,
            "num_workers": 8
        } if torch.cuda.is_available() else {}
        dl_c =t_data.DataLoader(train_dt ,batch_size=opt.batch_size,\
             sampler=train_sampler , drop_last=True, **kw )

        dl_val_warm = t_data.DataLoader(
            val_dt_warm,
            batch_size=opt.batch_size
            if not hasattr(opt, "batch_size_warm") else opt.batch_size_warm,
            sampler=val_sampler_warm,
            drop_last=True,
            **kw)

        self.dt_train = dl_c
        self.dt_val_warm = dl_val_warm

        # Pixel-space reconstruction criterion for the warm-up objective.
        if opt.warm_opt.loss_fn == "mse":
            self.critic_pixel = torch.nn.MSELoss()
        elif opt.warm_opt.loss_fn == "l1":
            self.critic_pixel = torch.nn.L1Loss()
        elif opt.warm_opt.loss_fn == "smooth_l1":
            self.critic_pixel = torch.nn.SmoothL1Loss()
        else:
            raise Exception("unknown")

        self.critic_pixel = self.critic_pixel.to(self.device)

        self.gan_loss = GANLoss(gan_mode=opt.gan_loss_fn).to(self.device)
        print("init ....")

        # Checkpoints are written next to the visualizer's log file.
        self.save_dir = os.path.dirname(self.visualizer.log_name)

    def _validate_(self):
        """Evaluate the generator on the warm-up validation set.

        For each batch, computes the pixel loss, MS-SSIM and PSNR of the
        generator output against ground truth, plus the same metrics for the
        bicubic-upsampled baseline. Also renders a small grid of
        (bicubic | output | ground-truth) images to the visualizer.

        Returns:
            tuple: (loss, ssim, psnr, cub_loss, cub_ssim, cub_psnr) — batch
            means for the generator and the bicubic baseline respectively.
        """
        with torch.no_grad():
            print("val ," * 8, "warm start...", len(self.dt_val_warm))
            iter_start_time = time.time()
            # Per-batch metrics for the generator output.
            ssim = []
            batch_loss = []
            psnr = []

            # Per-batch metrics for the bicubic baseline.
            cub_ssim = []
            cub_batch_loss = []
            cub_psnr = []

            save_image_list_1 = []

            for ii, data in tqdm.tqdm(enumerate(self.dt_val_warm)):
                # Datasets may yield extra fields; only the first three are
                # used here (LR input, HR target, bicubic-upsampled HR).
                if len(data) > 3:
                    input_lr, input_hr, cubic_hr, _, _ = data
                else:
                    input_lr, input_hr, cubic_hr = data

                self.input_lr = input_lr.to(self.device)
                self.input_hr = input_hr.to(self.device)
                self.input_cubic_hr = cubic_hr.to(self.device)

                self.forward()

                # Side-by-side strip: bicubic | generator output | target.
                save_image_list_1.append(torch.cat( [self.input_cubic_hr ,\
                 self.output_hr ,\
                 self.input_hr ],dim=3)  )

                loss = self.critic_pixel(self.output_hr, self.input_hr)
                batch_loss.append(loss.item())
                ssim.append(
                    image_quality.msssim(self.output_hr, self.input_hr).item())
                psnr.append(
                    image_quality.psnr(self.output_hr, self.input_hr).item())

                # Baseline: bicubic upsampling vs ground truth.
                cub_loss = self.critic_pixel(self.input_cubic_hr,
                                             self.input_hr)
                cub_batch_loss.append(cub_loss.item())
                cub_ssim.append(
                    image_quality.msssim(self.input_cubic_hr,
                                         self.input_hr).item())
                cub_psnr.append(
                    image_quality.psnr(self.input_cubic_hr,
                                       self.input_hr).item())

            # Show a random sample of 8 comparison strips.
            np.random.shuffle(save_image_list_1)
            save_image_list = save_image_list_1[:8]
            save_image_list = util.tensor2im(torch.cat(save_image_list, dim=2))
            save_image_list = OrderedDict([("cub_out_gt", save_image_list)])
            self.visualizer.display_current_results(save_image_list,
                                                    self.epoch,
                                                    save_result=True,
                                                    offset=20,
                                                    title="val_imag")

            val_info = (np.mean(batch_loss), np.mean(ssim), np.mean(psnr),
                        np.mean(cub_batch_loss), np.mean(cub_ssim),
                        np.mean(cub_psnr))
            errors = dict(
                zip(("loss", "ssim", "psnr", "cub_loss", "cub_ssim",
                     "cub_psnr"), val_info))
            t = (time.time() - iter_start_time)
            self.visualizer.print_current_errors(self.epoch,
                                                 self.epoch,
                                                 errors,
                                                 t,
                                                 log_name="loss_log_val.txt")
            self.visualizer.plot_current_errors(self.epoch,
                                                self.epoch,
                                                opt=None,
                                                errors=errors,
                                                display_id_offset=3,
                                                loss_name="val")

            return val_info

    def run(self):
        current_epoch = self.load_networks()
        self._run_train()

    def _run_train(self):
        """Main training loop: validate, checkpoint, then train each epoch.

        Per epoch: runs validation first (tracking the best pixel loss and
        saving a "best" checkpoint), saves the per-epoch checkpoint, then
        alternates generator and discriminator updates over the training
        loader, with periodic visualization/logging, and finally steps the
        learning-rate schedulers.
        """
        print("train.i..." * 8)
        total_steps = 0
        opt = self.opt

        self.model_names = ["G", "D", "D_vgg"]

        self.loss_w_g = torch.tensor(0)
        dataset_size = len(self.dt_train) * opt.batch_size
        best_loss = 10e5

        for epoch in range(0, self.opt.epoches_warm + self.opt.epoches):
            self.epoch = epoch
            #             epoch_start_time = time.time()
            epoch_iter = 0

            # Validation runs BEFORE this epoch's training, so the epoch-N
            # checkpoint reflects the state produced by epoch N-1.
            val_loss = self._validate_()
            val_loss = val_loss[0]  # first element is the mean pixel loss
            if best_loss > val_loss:
                best_loss = val_loss
                self.save_networks("best")
            self.save_networks(epoch)

            for data in self.dt_train:
                # Datasets may yield extra fields; only the first three are
                # used (LR input, HR target, bicubic-upsampled HR).
                if len(data) > 3:
                    input_lr, input_hr, cubic_hr, _, _ = data
                else:
                    input_lr, input_hr, cubic_hr = data

                iter_start_time = time.time()

                self.input_lr = input_lr.to(self.device)
                self.input_hr = input_hr.to(self.device)
                # NOTE(review): unlike _validate_, cubic_hr is NOT moved to
                # self.device here — confirm whether that is intentional.
                self.input_cubic_hr = cubic_hr

                self.forward()

                # Generator update.
                self.optim_G.zero_grad()
                self.g_loss()
                self.optim_G.step()

                # Discriminator update.
                self.optim_D.zero_grad()
                self.d_loss()
                self.optim_D.step()

                self.visualizer.reset()
                total_steps += opt.batch_size
                epoch_iter += opt.batch_size

                if total_steps % opt.display_freq == 0:
                    save_result = total_steps % opt.update_html_freq == 0
                    self.visualizer.display_current_results(
                        self.get_current_visuals(), epoch, save_result)

                if total_steps % opt.print_freq == 0:
                    errors = self.get_current_errors()
                    t = (time.time() - iter_start_time) / opt.batch_size
                    self.visualizer.print_current_errors(
                        epoch, epoch_iter, errors, t)
                    if opt.display_id > 0:
                        self.visualizer.plot_current_errors(
                            epoch,
                            float(epoch_iter) / dataset_size, opt, errors)

                # NOTE(review): `continue` as the last statement of the loop
                # body is a no-op; presumably intended to skip logging on
                # non-zero ranks — verify the intended placement.
                if self.rank != 0:
                    continue
            # Step schedulers once per epoch and log the new learning rates.
            lr_g, lr_d = self.update_learning_rate(is_warm=False)
            self.visualizer.plot_current_lrs(epoch,0,opt=None,\
                errors=OrderedDict([ ('lr_warm_g',0),("lr_g",lr_g),("lr_d",lr_d) ]) ,  loss_name="lr_warm"  ,display_id_offset=1)

    def forward(self, ):
        """Run the generator on the cached LR input and store the SR output."""
        self.output_hr = self.netG(self.input_lr)

    def g_loss(self):
        """Compute the generator losses, backprop, and optionally clip grads.

        Stores each component on ``self`` for later logging:
          * loss_G_g  -- adversarial loss from the image discriminator
          * loss_G_fg -- adversarial loss from the VGG-feature discriminator
          * loss_G_p  -- perceptual loss between fake and real VGG features
          * loss_w_g  -- weighted pixel loss between SR output and HR target

        The caller is expected to zero grads before and step optim_G after.
        """
        # Input scaling applied before VGG feature extraction.
        vgg_r = self.opt.gen.lambda_vgg_input

        # VGG features of the generated image.
        x_f_fake = self.vgg(vgg_r * self.output_hr)

        # Adversarial losses (image-space and feature-space discriminators).
        # NOTE(review): both are weighted by lambda_vgg_loss — confirm this
        # weight is intended for the image-space GAN term as well.
        d_fake = self.netD(self.output_hr)
        self.loss_G_g = self.opt.gen.lambda_vgg_loss * self.gan_loss(
            d_fake, True)

        fd_fake = self.netD_vgg(x_f_fake)
        self.loss_G_fg = self.opt.gen.lambda_vgg_loss * self.gan_loss(
            fd_fake, True)

        # Perceptual loss in VGG feature space.
        x_f_real = self.vgg(vgg_r * self.input_hr)
        self.loss_G_p = self.critic_pixel(x_f_fake, x_f_real)

        # Weighted pixel reconstruction loss.
        self.loss_w_g = self.opt.warm_opt.lambda_warm_loss * self.critic_pixel(
            self.output_hr, self.input_hr)

        self.loss_g = (self.loss_G_g + self.loss_G_fg + self.loss_G_p +
                       self.loss_w_g)

        self.loss_g.backward()

        if hasattr(self.opt.warm_opt, "clip"):
            # Fix: nn.utils.clip_grad_norm was deprecated (and later removed)
            # in favor of the in-place clip_grad_norm_.  Clipping happens
            # after backward() and before the caller's optimizer step.
            nn.utils.clip_grad_norm_(self.netG.parameters(),
                                     self.opt.warm_opt.clip)

    def d_loss(self, ):
        """Compute and backprop both discriminators' losses.

        Trains ``netD`` on images and ``netD_vgg`` on VGG features.  When the
        GAN loss is WGAN-GP, the gradient penalties are backprop'd separately
        before the main loss; gradients accumulate across the backward calls.
        The caller zeroes grads before and steps optim_D afterwards.
        """
        # Generator output is detached so no generator grads are produced.
        d_fake = self.netD(self.output_hr.detach())
        d_real = self.netD(self.input_hr)

        # Same input scaling as g_loss before VGG feature extraction.
        vgg_r = self.opt.gen.lambda_vgg_input
        x_f_fake = self.vgg(vgg_r * self.output_hr.detach())
        x_f_real = self.vgg(vgg_r * self.input_hr)

        vgg_d_fake = self.netD_vgg(x_f_fake)
        vgg_d_real = self.netD_vgg(x_f_real)

        # Image-space discriminator: fake -> False target, real -> True.
        self.loss_D_f = self.gan_loss(d_fake, False)
        self.loss_D_r = self.gan_loss(d_real, True)

        # Feature-space (VGG) discriminator.
        self.loss_Df_f = self.gan_loss(vgg_d_fake, False)
        self.loss_Df_r = self.gan_loss(vgg_d_real, True)

        #self.loss_d_f_fake = 0
        #self.loss_d_f_real = 0
        if self.opt.gan_loss_fn == "wgangp":
            # train with gradient penalty
            # Each penalty is backprop'd immediately; its grads accumulate
            # with the main loss_d.backward() below.
            gradient_penalty_vgg,_ = cal_gradient_penalty(netD=self.netD_vgg, real_data=x_f_real.data,\
                fake_data=x_f_fake.data,device=self.device)
            gradient_penalty_vgg.backward()

            gradient_penalty,_ = cal_gradient_penalty(netD=self.netD, real_data=self.input_hr.data, \
                fake_data = self.output_hr.data,  device=self.device)
            gradient_penalty.backward()



        loss_d =self.loss_D_f+ self.loss_D_r +self.loss_Df_f+\
            self.loss_Df_r
        #print ("loss_d",loss_d.item() )

        loss_d.backward()

    def get_current_errors(self):
        """Collect the latest loss values as plain floats for logging.

        Any loss that has not been computed yet is reported as 0.
        """
        spec = [
            ('G_p', 'loss_G_p'),
            ('G_fg', 'loss_G_fg'),
            ('G_g', 'loss_G_g'),
            ('D_f_real', 'loss_Df_r'),
            ('D_f_fake', 'loss_Df_f'),
            ('D_real', 'loss_D_r'),
            ('D_fake', 'loss_D_f'),
            ('warm_p', 'loss_w_g'),
        ]
        errors = OrderedDict()
        for label, attr in spec:
            errors[label] = getattr(self, attr).item() if hasattr(self, attr) else 0
        return errors

    def get_current_visuals(self):
        """Return the cubic-upsampled input, the generator output, and the
        HR target as display-ready images for the visualizer."""
        visuals = OrderedDict()
        visuals['input'] = util.tensor2im(self.input_cubic_hr)
        visuals['fake'] = util.tensor2im(self.output_hr.detach())
        visuals['target'] = util.tensor2im(self.input_hr)
        return visuals

    def update_learning_rate(self, is_warm=True):
        """Step every LR scheduler and return the resulting learning rates.

        Parameters:
            is_warm (bool) -- kept for interface compatibility; currently
                unused (warm-up and regular phases share the schedulers).

        Returns:
            tuple -- (lr_g, lr_d), the generator / discriminator learning
                rates read back from the first param group of each optimizer.
        """
        # Fix: the original wrapped this body in a dead ``if True:`` block.
        for scheduler in self.schedulers:
            scheduler.step()

        lr_g = self.optim_G.param_groups[0]['lr']
        lr_d = self.optim_D.param_groups[0]['lr']
        return (lr_g, lr_d)

    def save_networks(self, epoch):
        """Save all the networks to the disk.
        Parameters:
            epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
        """
        for name in self.model_names:
            if isinstance(name, str):
                save_filename = '%s_net_%s.pth' % (epoch, name)
                save_path = os.path.join(self.save_dir, save_filename)
                net = getattr(self, 'net' + name)

                # Heuristic unwrap of (Distributed)DataParallel wrappers —
                # matched by class-name substring — so the checkpoint holds
                # the bare module's weights, saved from CPU for portability.
                if "parallel" in str(type(net)) and torch.cuda.is_available():
                    torch.save(net.module.cpu().state_dict(), save_path)
                    net.cuda(self.gpu_ids[0])
                else:
                    torch.save(net.cpu().state_dict(), save_path)

                # Move the (now CPU-resident) network back to the training
                # device in either branch.
                net.to(self.device)

    def load_networks(self, epoch=None):
        """Load the most recent saved networks from ``self.save_dir``.

        Scans the checkpoint directory for files named
        '<epoch>_net_<name>.pth' (ignoring 'best' checkpoints), derives the
        epoch to resume from, and loads the G / D / D_vgg weights.

        Parameters:
            epoch (int) -- ignored; the epoch is derived from files on disk.

        Returns:
            int -- the resumed epoch, or 0 when nothing could be loaded.
        """
        pth_list = [os.path.basename(x) for x in os.listdir(self.save_dir)]
        pth_list = [
            x.split("_")[0] for x in pth_list
            if "_net" in x and ".pth" in x and "best" not in x
        ]
        # NOTE(review): this drops the lexicographically-largest entry before
        # the numeric sort — presumably to skip a possibly-incomplete latest
        # checkpoint, but the comparison is on strings (e.g. "9" > "10");
        # confirm the intent.
        pth_list = sorted(pth_list)[:-1]
        pth_list = list(map(int, pth_list))
        pth_list = sorted(pth_list)
        current_epoch = 0
        try:
            current_epoch = int(pth_list[-1])
        except IndexError:
            # Fix: was a bare ``except:``; only an empty list can raise here
            # (the int conversions above run outside this try).
            pass

        if current_epoch <= 0:
            return current_epoch

        epoch = current_epoch
        for name in ["G", "D", "D_vgg"]:
            if isinstance(name, str):
                load_filename = '%s_net_%s.pth' % (epoch, name)
                load_path = os.path.join(self.save_dir, load_filename)
                if not os.path.isfile(load_path):
                    print("***", "fail find%s" % (load_path))
                    continue
                net = getattr(self, 'net' + name)
                if isinstance(net, torch.nn.DataParallel):
                    net = net.module
                print('loading the model from %s' % load_path)
                # map_location keeps the load on the training device.
                state_dict = torch.load(load_path,
                                        map_location=str(self.device))
                if hasattr(state_dict, '_metadata'):
                    del state_dict._metadata
                net.load_state_dict(state_dict)

        return current_epoch
Пример #12
0
    epoch_iter = 0

    for i, data in enumerate(dataset):
        iter_start_time = time.time()
        total_iters += 1
        total_steps += opt.batchSize
        epoch_iter += opt.batchSize

        model.tick()
        data_loader.tick()
        model.set_input(data)
        model.optimize_parameters()
        errors = model.get_current_errors()

        if total_steps % opt.display_freq == 0:
            visualizer.display_current_results(model.get_current_visuals(),
                                               epoch, total_steps)

        if total_steps % opt.print_freq == 0:
            errors = model.get_current_errors()
            lrs = model.get_current_lr()
            t = (time.time() - iter_start_time) / opt.batchSize
            visualizer.print_current_errors(epoch, epoch_iter, errors, t,
                                            opt.name)
            visualizer.plot_current_errors(epoch, total_steps, errors, 'loss')
            visualizer.plot_current_errors(epoch, total_steps, lrs, 'lr')

        if opt.eval and total_steps % opt.eval_freq == 0:
            eval_start_time = time.time()
            mse = model.eval_network()
            visualizer.plot_current_errors(epoch, total_steps, mse, 'mse')
            t = (time.time() - eval_start_time)
Пример #13
0
    def train(self):
        """Train the stereo model for ``params.niter`` epochs.

        Builds the dataloader, creates the model and visualizer, runs
        generator updates batch by batch (with periodic ETA reporting and
        result display), and saves the model at the end of every epoch.
        """
        params = self.params
        # loading data
        # loading training images

        #set_start_method('spawn')

        data_loader = StereoDataloader(params)  # create dataloader
        train_data = DataLoader(data_loader,
                                batch_size=params.batchsize,
                                shuffle=True)  #, num_workers=1)
        dataset_size = len(data_loader)
        print('#training images: %d' % dataset_size)

        start_epoch, epoch_iter = 1, 0
        total_steps = (start_epoch - 1) * dataset_size + epoch_iter
        iter_ = 1  # NOTE(review): never used after this point

        # create/load model
        model = Model(params)
        visualizer = Visualizer(params)

        # Total number of examples processed over the whole run (for ETA).
        total_train_size = dataset_size * params.niter

        print('>>> dataset_size: %d, total train size: %d' %
              (dataset_size, total_train_size))
        print('\nTraining started...')
        train_start_time = time.time()
        # for epoch in range(start_epoch, params.niter + params.niter_decay + 1):
        for epoch in range(start_epoch, params.niter + 1):
            # epoch start time
            epoch_start_time = time.time()
            # params.current = epoch

            for i, data in enumerate(train_data, start=epoch_iter):
                iter_start_time = time.time()  # NOTE(review): unused
                total_steps += params.batchsize
                epoch_iter += params.batchsize

                # whether to collect output images
                params.save_fake = total_steps % params.display_freq == 0

                # Progress / ETA report on display iterations.
                if params.save_fake:
                    time_spent = time.time() - train_start_time

                    left_dataize = total_train_size - total_steps
                    fps = round(total_steps * 1.0 / time_spent, 1)
                    time_left = round(left_dataize / fps / 3600, 2)

                    print('\nEpoch: %d, Iteration: %d, Time Left: %2.2f hrs, examples/s: %3.2f' \
                        %(epoch, total_steps, time_left, fps))

                # when to start dis traning
                if epoch > params.headstart:
                    params.headstart_switch = -1

                # forward
                loss_G = model(Variable(data['left_img']),
                               Variable(data['right_img']))

                # backward G
                model.optimizer_G.zero_grad()
                loss_G.backward()
                model.optimizer_G.step()

                # # display input & output and save ouput images
                if params.save_fake:
                    result_img = model.get_result_img(
                        Variable(data['left_img']),
                        Variable(data['right_img']))
                    visualizer.display_current_results(result_img, epoch,
                                                       total_steps)

            # epoch end time
            iter_end_time = time.time()
            print('End of epoch %d / %d \t Time Taken: %d sec' %
                  (epoch, params.niter, time.time() - epoch_start_time))

            # save mdodel
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, total_steps))
            model.save(epoch)
                t_data = iter_start_time - iter_data_time
            visualizer.reset()
            total_iters += opt.batch_size
            epoch_iter += opt.batch_size

            rec_.set_input(data, vgg_)
            rec_.optimize_parameters(vgg_)

            if total_iters % opt.display_freq == 0:  # display images on visdom and save images to a HTML file
                save_result = total_iters % opt.update_html_freq == 0
                images = {
                    'input': rec_.input_im,
                    'reconstruct': rec_.rec_im,
                    'real': rec_.real_im
                }
                visualizer.display_current_results(images, epoch, save_result)

            if total_iters % opt.print_freq == 0:  # print training losses and save logging information to the disk
                t_comp = (time.time() - iter_start_time) / opt.batch_size
                if opt.adain_loss:
                    loss = {
                        'rec_loss': float(rec_.rec_loss),
                        'ft_loss': float(rec_.ft_loss),
                        'adain_loss': float(rec_.adain_loss)
                    }
                else:
                    loss = {
                        'ft_loss': float(rec_.ft_loss) * opt.ft_loss,
                        'D_loss': float(rec_.loss_D) * opt.gan_loss,
                        'G_loss': float(rec_.loss_G * opt.gan_loss)
                    }
Пример #15
0
def train(opt):
    """Train a real-vs-fake image classifier on paired data.

    Builds a paired real/manipulated dataset, then loops indefinitely:
    each epoch trains on concatenated real+fake batches, validates, tracks
    the best validation metric, and stops when patience is exceeded, the
    metric reaches 1, or ``opt.max_epochs`` is passed.  Checkpoints are
    written for 'latest', 'bestval', and periodic epochs.
    """
    torch.manual_seed(opt.seed)

    # load the train dataset
    dset = PairedDataset(opt, os.path.join(opt.real_im_path, 'train'),
                         os.path.join(opt.fake_im_path, 'train'))
    # halves batch size since each batch returns both real and fake ims
    dl = DataLoader(dset, batch_size=opt.batch_size // 2,
                    num_workers=opt.nThreads, pin_memory=False,
                    shuffle=True)

    # setup class labeling
    assert(opt.fake_class_id in [0, 1])
    fake_label = opt.fake_class_id
    real_label = 1 - fake_label
    logging.info("real label = %d" % real_label)
    logging.info("fake label = %d" % fake_label)
    dataset_size = 2 * len(dset)  # each item yields a real and a fake image
    logging.info('# total images = %d' % dataset_size)
    logging.info('# total batches = %d' % len(dl))

    # setup model and visualizer
    model = create_model(opt)
    epoch, best_val_metric, best_val_ep = model.setup(opt)

    # Track both training and per-epoch validation variants of each loss.
    visualizer_losses = model.loss_names + [n + '_val' for n in model.loss_names]
    visualizer = Visualizer(opt, visualizer_losses, model.visual_names)
    total_batches = epoch * len(dl)
    t_data = 0

    now = time.strftime("%c")
    logging.info('================ Training Loss (%s) ================\n' % now)

    # Runs until one of the break conditions at the bottom fires.
    while True:
        epoch_start_time = time.time()
        iter_data_time = time.time()
        epoch_iter = 0

        for i, ims in enumerate(dl):
            ims_real = ims['original'].to(opt.gpu_ids[0])
            ims_fake = ims['manipulated'].to(opt.gpu_ids[0])
            labels_real = real_label * torch.ones(ims_real.shape[0], dtype=torch.long).to(opt.gpu_ids[0])
            labels_fake = fake_label * torch.ones(ims_fake.shape[0], dtype=torch.long).to(opt.gpu_ids[0])

            # One combined batch: reals first, then fakes.
            batch_im = torch.cat((ims_real, ims_fake), axis=0)
            batch_label = torch.cat((labels_real, labels_fake), axis=0)
            batch_data = dict(ims=batch_im, labels=batch_label)

            iter_start_time = time.time()
            if total_batches % opt.print_freq == 0:
                # time to load data
                t_data = iter_start_time - iter_data_time

            total_batches += 1
            epoch_iter += 1
            model.reset()
            model.set_input(batch_data)
            model.optimize_parameters()

            if epoch_iter % opt.print_freq == 0:
                losses = model.get_current_losses()
                t = time.time() - iter_start_time
                visualizer.print_current_losses(
                    epoch, float(epoch_iter)/len(dl), total_batches,
                    losses, t, t_data)
                visualizer.plot_current_losses(total_batches, losses)

            if epoch_iter % opt.display_freq == 0:
                visualizer.display_current_results(model.get_current_visuals(),
                                                   total_batches)

            if epoch_iter % opt.save_latest_freq == 0:
                logging.info('saving the latest model (epoch %d, total_batches %d)' %
                      (epoch, total_batches))
                model.save_networks('latest', epoch, best_val_metric,
                                    best_val_ep)

            model.reset()
            iter_data_time = time.time()

        # do validation loop at end of each epoch
        model.eval()
        val_start_time = time.time()
        val_losses = validate(model, opt)
        visualizer.plot_current_losses(epoch, val_losses)
        logging.info("Printing validation losses:")
        visualizer.print_current_losses(
            epoch, 0.0, total_batches, val_losses,
            time.time()-val_start_time, 0.0)
        model.train()
        model.reset()
        # Sanity check that eval() was undone before the next epoch.
        assert(model.net_D.training)

        # update best model and determine stopping conditions
        if val_losses[model.val_metric + '_val'] > best_val_metric:
            logging.info("Updating best val mode at ep %d" % epoch)
            logging.info("The previous values: ep %d, val %0.2f" %
                         (best_val_ep, best_val_metric))
            best_val_ep = epoch
            best_val_metric = val_losses[model.val_metric + '_val']
            logging.info("The updated values: ep %d, val %0.2f" %
                         (best_val_ep, best_val_metric))
            model.save_networks('bestval', epoch, best_val_metric, best_val_ep)
            with open(os.path.join(model.save_dir, 'bestval_ep.txt'), 'a') as f:
                f.write('ep: %d %s: %f\n' % (epoch, model.val_metric + '_val',
                                           best_val_metric))
        elif epoch > (best_val_ep + 5*opt.patience):
            # Early stopping: no improvement for 5*patience epochs.
            logging.info("Current epoch %d, last updated val at ep %d" %
                         (epoch, best_val_ep))
            logging.info("Stopping training...")
            break
        elif best_val_metric == 1:
            logging.info("Reached perfect val accuracy metric")
            logging.info("Stopping training...")
            break
        elif opt.max_epochs and epoch > opt.max_epochs:
            logging.info("Reached max epoch count")
            logging.info("Stopping training...")
            break

        logging.info("Best val ep: %d" % best_val_ep)
        logging.info("Best val metric: %0.2f" % best_val_metric)

        # save final plots at end of each epoch
        visualizer.save_final_plots()

        if epoch % opt.save_epoch_freq == 0 and epoch > 0:
            logging.info('saving the model at the end of epoch %d, total batches %d' % (epoch, total_batches))
            model.save_networks('latest', epoch, best_val_metric,
                                best_val_ep)
            model.save_networks(epoch, epoch, best_val_metric, best_val_ep)

        logging.info('End of epoch %d \t Time Taken: %d sec' %
              (epoch, time.time() - epoch_start_time))
        # Plateau-style scheduling keyed on the validation metric.
        model.update_learning_rate(metric=val_losses[model.val_metric + '_val'])
        epoch += 1

    # save model at the end of training
    visualizer.save_final_plots()
    model.save_networks('latest', epoch, best_val_metric,
                        best_val_ep)
    model.save_networks(epoch, epoch, best_val_metric, best_val_ep)
    logging.info("Finished Training")
Пример #16
0
                # time_str = time.asctime(time.localtime(time.time()))
                # print('{} train epoch {} iter {} {} iters/s loss {} acc {}'.format(time_str, i, ii, speed, loss.item(), acc))

                progress_bar.set_description(
                    'Epoch: {}/{} Iter: {}/{} Loss: {:.5f} Acc: {:.5f} Speed: {:.4f}iters/s'.format(epoch, opt.max_epoch,
                                                                                                    iter,
                                                                                                    num_iter_per_epoch,
                                                                                                    loss.item(), acc,
                                                                                                    speed))
                writer.add_scalars('Loss', {'train': loss}, iters)
                writer.add_scalars('Accuracy', {'train': acc}, iters)
                # writer.add_graph(model, data_input)


                if opt.display:
                    visualizer.display_current_results(iters, loss.item(), name='train_loss')
                    visualizer.display_current_results(iters, acc, name='train_acc')

                start = time.time()

        # write
        f.write(f'\nEpochs : {epoch} / {opt.max_epoch}\n')
        f.write('[train_loss] : {:.5f} [train_acc] : {:.5f}\n'.format(loss.item(), acc))

        if epoch % opt.val_interval == 0:
            model.eval()

            total = 0
            total_acc = 0
            total_loss = 0
Пример #17
0
            iter_start_time = time.time(
            )  # timer for computation per iteration
            if total_iters % opt.print_freq == 0:
                t_data = iter_start_time - iter_data_time

            total_iters += opt.batch_size
            epoch_iter += opt.batch_size

            model.set_input(
                data)  # unpack data from dataset and apply preprocessing
            model.optimize_parameters(
            )  # calculate loss functions, get gradients, update network weights

            if total_iters % opt.display_freq == 0:  # display images on visdom and save images to a HTML file
                visualizer.display_current_results(model.get_current_visuals(),
                                                   total_iters)

            if total_iters % opt.print_freq == 0:  # print training losses
                losses = model.get_current_losses()
                t_comp = (time.time() - iter_start_time) / opt.batch_size
                visualizer.print_current_losses(epoch, epoch_iter, losses,
                                                t_comp, t_data)
                visualizer.plot_current_losses(total_iters, losses)

            iter_data_time = time.time()

        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch, opt.num_epoch, time.time() - epoch_start_time))

    # evaluate and save result tqdm
    with tqdm(total=len(dataset)) as progress_bar:
Пример #18
0
    epoch_start_time = time.time()
    if epoch != start_epoch:
        epoch_iter = epoch_iter % dataset_size
    model.model.train()
    model.freeze_bn()
    for i, data in enumerate(dataset, start=epoch_iter):
        iter_start_time = time.time()
        total_steps += opt.batchSize
        epoch_iter += opt.batchSize
        # add some commits
        model.forward(data)
        model.backward(total_steps,
                       opt.nepochs * dataset.__len__() * opt.batchSize + 1)
        if total_steps % opt.display_freq == 0:
            visuals = model.get_visuals(total_steps)
            visualizer.display_current_results(visuals, epoch, total_steps)
        if total_steps % opt.save_latest_freq == 0:
            print('saving the latest model (epoch %d, total_steps %d)' %
                  (epoch, total_steps))
            model.save('latest')
            np.savetxt(iter_path, (epoch, epoch_iter), delimiter=',', fmt='%d')
        # print time.time()-iter_start_time

    # end of epoch
    model.model.eval()
    if dataset_val != None:
        label_trues, label_preds = [], []
        for i, data in enumerate(dataset_val):
            with torch.no_grad():
                seggt, segpred = model.forward(data, False)
            seggt = seggt.data.cpu().numpy()
Пример #19
0
def main():
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt',
                        type=str,
                        required=True,
                        help='Path to option JSON file.')
    opt = option.parse(parser.parse_args().opt, is_train=True)
    opt = option.dict_to_nonedict(
        opt)  # Convert to NoneDict, which return None for missing key.
    visualizer = Visualizer()

    # train from scratch OR resume training
    if opt['path']['resume_state']:  # resuming training
        resume_state = torch.load(opt['path']['resume_state'])
    else:  # training from scratch
        resume_state = None
        util.mkdir_and_rename(
            opt['path']['experiments_root'])  # rename old folder if exists
        util.mkdirs((path for key, path in opt['path'].items()
                     if not key == 'experiments_root'
                     and 'pretrain_model' not in key and 'resume' not in key))

    # config loggers. Before it, the log will not work
    util.setup_logger(None,
                      opt['path']['log'],
                      'train',
                      level=logging.INFO,
                      screen=True)
    util.setup_logger('val', opt['path']['log'], 'val', level=logging.INFO)
    logger = logging.getLogger('base')

    if resume_state:
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))
        option.check_resume(opt)  # check resume options

    logger.info(option.dict2str(opt))
    # tensorboard logger
    if opt['use_tb_logger'] and 'debug' not in opt['name']:
        from tensorboardX import SummaryWriter
        tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name'])

    # random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    logger.info('Random seed: {}'.format(seed))
    util.set_random_seed(seed)

    torch.backends.cudnn.benckmark = True
    # torch.backends.cudnn.deterministic = True
    using_spanet_dadaset = True  # added by he
    # create train and val dataloader
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            if not using_spanet_dadaset:  # changed by he
                train_set = create_dataset(dataset_opt)
                train_size = int(
                    math.ceil(len(train_set) / dataset_opt['batch_size']))
                logger.info(
                    'Number of train images: {:,d}, iters: {:,d}'.format(
                        len(train_set), train_size))
                total_iters = int(opt['train']['niter'])
                total_epochs = int(math.ceil(total_iters / train_size))
                logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                    total_epochs, total_iters))
                train_loader = create_dataloader(train_set, dataset_opt)
            else:
                train_set = LRHRDataset_FromTXT(
                    '/home/spl/anaconda2/envs/pth10-py36-cu10/CODE-Net/datasets/Real_Rain_Streaks_Dataset_CVPR19_spanet/Training/real_world_refined.txt',
                    dataset_opt)
                train_size = int(
                    math.ceil(len(train_set) / dataset_opt['batch_size']))
                logger.info(
                    'Number of train images: {:,d}, iters: {:,d}'.format(
                        len(train_set), train_size))
                total_iters = int(opt['train']['niter'])
                total_epochs = int(math.ceil(total_iters / train_size))
                logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                    total_epochs, total_iters))
                train_loader = create_dataloader(train_set, dataset_opt)

        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt)
            logger.info('Number of val images in [{:s}]: {:d}'.format(
                dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError(
                'Phase [{:s}] is not recognized.'.format(phase))
    assert train_loader is not None

    # create model
    model = create_model(opt)

    # resume training
    if resume_state:
        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        model.resume_training(resume_state)  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0

    # training
    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(
        start_epoch, current_step))
    for epoch in range(start_epoch, total_epochs):

        epoch_iter = 0
        for _, train_data in enumerate(train_loader):

            current_step += 1
            epoch_iter += 1
            if current_step > total_iters:
                break

            # update learning rate
            model.update_learning_rate()

            # training
            model.feed_data(train_data)
            model.optimize_parameters(current_step)

            # log
            if current_step % opt['logger']['print_freq'] == 0:
                # Collect the model's current scalar logs (losses etc.)
                # for this logging step.
                logs = model.get_current_log()
                ###################################

                # Progress within the epoch is reported as a fraction
                # (epoch_iter / train_size) alongside the scalar logs.
                visualizer.plot_current_losses(epoch,
                                               float(epoch_iter) / train_size,
                                               logs)  #

                # Display the LR / SR / HR image triplet for this step.
                visuals = model.get_current_visuals()
                vvis = OrderedDict()
                vvis['LR'] = visuals['LR']
                vvis['SR'] = visuals['SR']
                vvis['HR'] = visuals['HR']

                visualizer.display_current_results(vvis, epoch, False)

                ####################################
                # Build one human-readable log line: epoch, step, learning
                # rate, then every scalar in `logs`.
                message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> '.format(
                    epoch, current_step, model.get_current_learning_rate())
                for k, v in logs.items():
                    message += '{:s}: {:.4e} '.format(k, v)
                    # tensorboard logger
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        tb_logger.add_scalar(k, v, current_step)
                logger.info(message)

            # validation
            if current_step % opt['train']['val_freq'] == 0:
                avg_psnr = 0.0
                idx = 0  # number of validation samples processed
                for val_data in val_loader:
                    idx += 1
                    # Derive a per-image output directory from the LR file name.
                    img_name = os.path.splitext(
                        os.path.basename(val_data['LR_path'][0]))[0]
                    img_dir = os.path.join(opt['path']['val_images'], img_name)
                    #print(img_dir)
                    util.mkdir(img_dir)

                    model.feed_data(
                        val_data, need_HR=True
                    )  # need_HR=true when using LRHR_dataset.py in training phase
                    model.test()

                    visuals = model.get_current_visuals()
                    sr_img = util.tensor2img(visuals['SR'])  # uint8
                    gt_img = util.tensor2img(
                        visuals['HR']
                    )  # uint8       # ['HR'] when using LRHR_dataset.py in training phase
                    #gt_img = util.tensor2img(visuals['LR'])  # uint8        # ['LR'] when using LR_dataset.py in training phase

                    # Save SR images for reference
                    save_img_path = os.path.join(img_dir, '{:s}_{:d}.png'.format(\
                        img_name, current_step))
                    util.save_img(sr_img, save_img_path)

                    # calculate PSNR
                    # NOTE(review): crop_size is assigned but never used below —
                    # border cropping before the PSNR call appears to have been
                    # removed; confirm whether cropping by `scale` is intended.
                    crop_size = opt['scale']
                    gt_img = gt_img / 255.
                    sr_img = sr_img / 255.

                    ################################################################################
                    # NOTE(review): the /255. normalization above is undone by the
                    # *255 rescale here, so PSNR is effectively computed on the
                    # original 0-255 range — presumably intentional, but verify.
                    avg_psnr += util.calculate_psnr(sr_img * 255, gt_img * 255)
                    ##################################################################################

                # Mean PSNR over the validation set.
                # NOTE(review): raises ZeroDivisionError if val_loader yields
                # nothing (idx stays 0).
                avg_psnr = avg_psnr / idx
                psnr_dic = OrderedDict()
                psnr_dic['psnr'] = avg_psnr
                visualizer.plot_current_psnr(epoch,
                                             float(epoch_iter) / train_size,
                                             psnr_dic)

                # log
                logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr))
                logger_val = logging.getLogger('val')  # validation logger
                logger_val.info(
                    '<epoch:{:3d}, iter:{:8,d}> psnr: {:.4e}'.format(
                        epoch, current_step, avg_psnr))
                # tensorboard logger
                if opt['use_tb_logger'] and 'debug' not in opt['name']:
                    tb_logger.add_scalar('psnr', avg_psnr, current_step)

            # save models and training states
            if current_step % opt['logger']['save_checkpoint_freq'] == 0:
                logger.info('Saving models and training states.')
                model.save(current_step)
                model.save_training_state(epoch, current_step)

    # After all epochs complete, persist a final 'latest' checkpoint.
    logger.info('Saving the final model.')
    model.save('latest')
    logger.info('End of training.')