Ejemplo n.º 1
0
def main():
    """Smoke-test the dataset/dataloader pipeline for one hard-coded dataset.

    Builds the option dict for the dataset named in ``dataset``, creates the
    dataset and its loader, then walks the first six batches (indices 0-5),
    printing tensor shapes and saving one image grid per index along
    dimension 1 of ``LQs`` into a temporary directory.

    Raises:
        ValueError: if ``dataset`` names a configuration that has no branch
            below.
    """
    # Select the configuration to exercise. Branches available below:
    # LQGT | KWAI | REDS | Vimeo90K | Vimeo90K_test | Vimeo90K_LQ | DIV2K800_sub
    dataset = 'Vimeo90K_LQ'  # REDS | Vimeo90K | DIV2K800_sub
    opt = {}
    # Single-process, single-GPU run.
    opt['dist'] = False
    opt['gpu_ids'] = [0]
    if dataset == 'LQGT':
        # Paired LQ/GT still images read from raw image folders, scale 1.
        opt['name'] = 'test_KWAI'
        opt['dataroot_GT'] = '/home/web_server/zhouhuanxiang/disk/data/HD_UGC_raw'
        opt['dataroot_LQ'] = '/home/web_server/zhouhuanxiang/disk/data/HD_UGC_crf40_raw'
        opt['mode'] = 'LQGT'
        opt['color'] = 'RGB'
        opt['phase'] = 'train'
        opt['use_shuffle'] = True
        opt['n_workers'] = 1
        opt['batch_size'] = 4
        opt['GT_size'] = 256
        opt['LQ_size'] = 256
        opt['scale'] = 1
        opt['use_flip'] = True
        opt['use_rot'] = True
        opt['interval_list'] = [1]
        opt['random_reverse'] = False
        opt['border_mode'] = False
        opt['cache_keys'] = None
        opt['data_type'] = 'img'  # img | lmdb | mc
    elif dataset == 'KWAI':
        # 5-frame clips read from LMDB, scale 1.
        opt['name'] = 'test_KWAI'
        opt['dataroot_GT'] = '/home/web_server/zhouhuanxiang/disk/vdata/HD_UGC.lmdb'
        opt['dataroot_LQ'] = '/home/web_server/zhouhuanxiang/disk/vdata/HD_UGC_crf40.lmdb'
        opt['mode'] = 'KWAI'
        opt['N_frames'] = 5
        opt['phase'] = 'train'
        opt['use_shuffle'] = True
        opt['n_workers'] = 1
        opt['batch_size'] = 4
        opt['GT_size'] = 256
        opt['LQ_size'] = 256
        opt['scale'] = 1
        opt['use_flip'] = True
        opt['use_rot'] = True
        opt['interval_list'] = [1]
        opt['random_reverse'] = False
        opt['border_mode'] = False
        opt['cache_keys'] = None
        opt['data_type'] = 'lmdb'  # img | lmdb | mc
    elif dataset == 'REDS':
        # REDS, 5-frame clips, x4 (GT 256 / LQ 64), LMDB backed.
        opt['name'] = 'test_REDS'
        opt['dataroot_GT'] = '../../datasets/REDS/train_sharp_wval.lmdb'
        opt['dataroot_LQ'] = '../../datasets/REDS/train_sharp_bicubic_wval.lmdb'
        opt['mode'] = 'REDS'
        opt['N_frames'] = 5
        opt['phase'] = 'train'
        opt['use_shuffle'] = True
        opt['n_workers'] = 8
        opt['batch_size'] = 16
        opt['GT_size'] = 256
        opt['LQ_size'] = 64
        opt['scale'] = 4
        opt['use_flip'] = True
        opt['use_rot'] = True
        opt['interval_list'] = [1]
        opt['random_reverse'] = False
        opt['border_mode'] = False
        opt['cache_keys'] = None
        opt['data_type'] = 'lmdb'  # img | lmdb | mc
    elif dataset == 'Vimeo90K':
        # Vimeo90K, 7-frame clips, x4 (GT 256 / LQ 64), LMDB backed.
        opt['name'] = 'test_Vimeo90K'
        opt['dataroot_GT'] = '../../datasets/vimeo90k/vimeo90k_train_GT.lmdb'
        opt['dataroot_LQ'] = '../../datasets/vimeo90k/vimeo90k_train_LR7frames.lmdb'
        opt['mode'] = 'Vimeo90K'
        opt['N_frames'] = 7
        opt['phase'] = 'train'
        opt['use_shuffle'] = True
        opt['n_workers'] = 8
        opt['batch_size'] = 16
        opt['GT_size'] = 256
        opt['LQ_size'] = 64
        opt['scale'] = 4
        opt['use_flip'] = True
        opt['use_rot'] = True
        opt['interval_list'] = [1]
        opt['random_reverse'] = False
        opt['border_mode'] = False
        opt['cache_keys'] = None
        opt['data_type'] = 'lmdb'  # img | lmdb | mc
    elif dataset == 'Vimeo90K_test':
        # Vimeo septuplet image folders, scale 1, with the 'all_gt' flag set.
        opt['name'] = 'vimeo90k-test'
        opt['dataroot_GT'] = '/home/web_server/zhouhuanxiang/disk/vimeo/vimeo_septuplet/sequences'
        opt['dataroot_LQ'] = '/home/web_server/zhouhuanxiang/disk/vimeo/vimeo_septuplet/sequences_blocky32'
        opt['mode'] = 'Vimeo90K_test'
        opt['all_gt'] = True
        opt['N_frames'] = 7
        opt['phase'] = 'train'
        opt['use_shuffle'] = True
        opt['n_workers'] = 8
        opt['batch_size'] = 16
        opt['GT_size'] = 256
        opt['LQ_size'] = 256
        opt['scale'] = 1
        opt['use_flip'] = True
        opt['use_rot'] = True
        opt['interval_list'] = [1]
        opt['random_reverse'] = False
        opt['border_mode'] = False
        opt['cache_keys'] = None
        opt['data_type'] = 'img'  # img | lmdb | mc
    elif dataset == 'Vimeo90K_LQ':
        # Three roots (blocky32 / blocky37 / blocky42) and no GT root; also
        # sets the patch options 'patch_size' and 'patch_repeat'.
        opt['name'] = 'Vimeo90K-LQ'
        opt['dataroot_LHQ'] = '/home/web_server/zhouhuanxiang/disk/vimeo/vimeo_septuplet/sequences_blocky32'
        opt['dataroot_LQ'] = '/home/web_server/zhouhuanxiang/disk/vimeo/vimeo_septuplet/sequences_blocky37'
        opt['dataroot_LLQ'] = '/home/web_server/zhouhuanxiang/disk/vimeo/vimeo_septuplet/sequences_blocky42'
        opt['mode'] = 'Vimeo90K_LQ'
        opt['patch_size'] = 32
        opt['patch_repeat'] = 5
        opt['N_frames'] = 7
        opt['phase'] = 'train'
        opt['use_shuffle'] = True
        opt['n_workers'] = 8
        opt['batch_size'] = 16
        opt['GT_size'] = 256
        opt['LQ_size'] = 256
        opt['scale'] = 1
        opt['use_flip'] = True
        opt['use_rot'] = True
        opt['interval_list'] = [1]
        opt['random_reverse'] = False
        opt['border_mode'] = False
        opt['cache_keys'] = None
        opt['data_type'] = 'img'  # img | lmdb | mc
    elif dataset == 'DIV2K800_sub':
        # DIV2K sub-images, x4, LMDB backed; reuses the 'LQGT' mode.
        opt['name'] = 'DIV2K800'
        opt['dataroot_GT'] = '../../datasets/DIV2K/DIV2K800_sub.lmdb'
        opt['dataroot_LQ'] = '../../datasets/DIV2K/DIV2K800_sub_bicLRx4.lmdb'
        opt['mode'] = 'LQGT'
        opt['phase'] = 'train'
        opt['use_shuffle'] = True
        opt['n_workers'] = 8
        opt['batch_size'] = 16
        opt['GT_size'] = 128
        opt['scale'] = 4
        opt['use_flip'] = True
        opt['use_rot'] = True
        opt['color'] = 'RGB'
        opt['data_type'] = 'lmdb'  # img | lmdb
    else:
        raise ValueError('Please implement by yourself.')

    # Output directory for the saved grids (same path is used in the loop).
    util.mkdir('/home/web_server/zhouhuanxiang/tmp')
    train_set = create_dataset(opt)
    # The same flat dict is passed as both the dataset options and the
    # global options; no sampler is used.
    train_loader = create_dataloader(train_set, opt, opt, None)
    # Grid layout: floor(sqrt(batch_size)) images per row.
    nrow = int(math.sqrt(opt['batch_size']))
    padding = 2 if opt['phase'] == 'train' else 0

    print('start...')
    for i, data in enumerate(train_loader):
        # Only inspect batches 0..5.
        if i > 5:
            break
        print(i)


        # NOTE(review): the keys read here ('LQs', 'patch_labels',
        # 'patch_offsets') presumably only exist for the 'Vimeo90K_LQ'
        # dataset mode; other branches above may need different keys.
        LQs = data['LQs']
        # LLQs = data['LLQs']
        # LHQs = data['LHQs']
        patch_labels = data['patch_labels']
        patch_offsets = data['patch_offsets']
        print(patch_labels.shape)
        print(patch_offsets.shape)
        print(LQs.shape)

        # LQs is indexed as a 5-D tensor; dimension 1 is presumably the frame
        # axis. One grid of the whole batch is written per index j.
        for j in range(LQs.size(1)):
            torchvision.utils.save_image(LQs[:, j, :, :, :],
                                         '/home/web_server/zhouhuanxiang/tmp/LQ_{:03d}_{}.png'.format(i, j), nrow=nrow,
                                         padding=padding, normalize=False)
Ejemplo n.º 2
0
from data import create_dataset
import time

if __name__ == '__main__':
    # Script entry point: build the dataset once with num_predict_words=7.
    #
    # The commented-out lines below are a disabled timing experiment that
    # called create_dataset with num_predict_words=2**i and printed the
    # elapsed wall-clock time for each run.
    # for i in range(2, 6):
    # for i in [1]:
    #     start = time.time()
    #     create_dataset(num_predict_words=2**i)
    #     print(time.time() - start)
    create_dataset(num_predict_words=7)
Ejemplo n.º 3
0
def batch_generate(opt):
    """Run generator ``netG_A`` over a dataset and save the generated images.

    The options object is mutated in place with fixed test-time settings,
    the generator weights are loaded from ``opt.MODEL_FILE``, and every
    generated image is written into a freshly named sub-directory of
    ``opt.OUTPUT_PATH``.

    Args:
        opt: options object. Attributes read here: ``DATA_PATH``,
            ``MODEL_FILE``, ``OUTPUT_PATH``, ``crop_size``, ``verbose``,
            ``name``, ``phase``, ``epoch``, ``eval`` and ``num_test``.

    Returns:
        str: the output directory the images were written to.
    """
    opt.dataroot = opt.DATA_PATH
    #     opt.epoch = 200
    # hard-code some parameters for test
    # NOTE(review): the upstream comment said the test code only supports
    # num_threads = 1, but the value actually set here is 0 — confirm intent.
    opt.num_threads = 0
    #opt.batch_size = 1    # test code only supports batch_size = 1
    opt.serial_batches = True  # disable data shuffling; comment this line if results on randomly chosen images are needed.
    opt.no_flip = True  # no flip; comment this line if results on flipped images are needed.
    opt.load_size = opt.crop_size
    opt.display_id = -1  # no visdom display; the test code saves the results to a HTML file.
    dataset = create_dataset(opt)  # create a dataset given opt.dataset_mode and other options
    model = create_model(opt)  # create a model given opt.model and other options

    # load model from a path - for platform
    load_path = opt.MODEL_FILE
    net = getattr(model, 'netG_A')
    if isinstance(net, torch.nn.DataParallel):
        net = net.module  # load into the wrapped module, not the wrapper
    print('loading the model from %s' % load_path)
    state_dict = torch.load(load_path, map_location=str(model.device))
    if hasattr(state_dict, '_metadata'):
        del state_dict._metadata
    # patch InstanceNorm checkpoints prior to 0.4
    # (iterate over a copy of the keys because the helper mutates the dict;
    # the helper is a name-mangled private method of BaseModel)
    for key in list(state_dict.keys()):
        model._BaseModel__patch_instance_norm_state_dict(
            state_dict, net, key.split('.'))
    net.load_state_dict(state_dict)
    model.print_networks(opt.verbose)

    # create a website
    # The directory name is the SHA-256 hex digest of the current timestamp,
    # so each call writes into its own directory.
    sha = hashlib.sha256()
    sha.update(str(time.time()).encode('utf-8'))
    web_dir = opt.OUTPUT_PATH + "/" + sha.hexdigest()  # specific output dir - for platform
    webpage = html.HTML(
        web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' %
        (opt.name, opt.phase, opt.epoch))
    # test with eval mode. This only affects layers like batchnorm and dropout.
    # For [pix2pix]: we use batchnorm and dropout in the original pix2pix. You can experiment it with and without eval() mode.
    # For [CycleGAN]: It should not affect CycleGAN as CycleGAN uses instancenorm without dropout.
    if opt.eval:
        model.eval()

    starttime = time.time()
    num_generated = 0  # images actually written (may be fewer than the dataset)
    for i, data in enumerate(dataset):
        if i >= opt.num_test:  # only apply our model to opt.num_test images.
            break
        model.set_input_predict(data)  # unpack data from data loader
        model.test()  # run inference

        # Use a separate index so the batch counter `i` is not shadowed.
        for j, visual in enumerate(model.fake_B):
            img_path = model.image_paths[j]

            im = tensor2im([visual])[0]
            img_name = img_path.split('/')[-1]
            save_image(im, "{}/{}".format(web_dir, img_name))
            num_generated += 1

    lasttime = time.time()
    print("Work Done!!!")
    # Report what was really produced rather than len(dataset), which is
    # wrong whenever opt.num_test cuts the loop short.
    print('Generated', num_generated, 'maps. Total Time Cost: ',
          lasttime - starttime, 'seconds')

    return web_dir
Ejemplo n.º 4
0
def main():
    """Train a model described by a YAML option file.

    Command-line arguments:
        -opt: path to the option file (defaults to
            ``options/train/train_KPSAGAN.yml``).
        --launcher: ``none`` for single-process training, ``pytorch`` for
            distributed training.
        --local_rank: accepted for the distributed launcher; not read here.

    The function sets up directories and loggers (on rank <= 0 only), seeds
    the RNGs, builds the train/val loaders, optionally resumes from a saved
    training state, then runs the iteration loop with periodic logging,
    validation (PSNR) and checkpointing.

    Raises:
        NotImplementedError: if ``opt['datasets']`` contains a phase other
            than ``train`` or ``val``.
        AssertionError: if no ``train`` phase is configured.
    """
    #### options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt',
                        type=str,
                        default='options/train/train_KPSAGAN.yml',
                        help='Path to option YMAL file.')
    parser.add_argument('--launcher',
                        choices=['none', 'pytorch'],
                        default='none',
                        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    opt = option.parse(args.opt, is_train=True)

    #### distributed training settings
    if args.launcher == 'none':  # disabled distributed training
        opt['dist'] = False
        rank = -1
        print('Disabled distributed training.')
    else:
        opt['dist'] = True
        init_dist()
        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()

    #### loading resume state if exists
    if opt['path'].get('resume_state', None):
        # distributed resuming: all load into default GPU
        device_id = torch.cuda.current_device()
        resume_state = torch.load(
            opt['path']['resume_state'],
            map_location=lambda storage, loc: storage.cuda(device_id))
        option.check_resume(opt, resume_state['iter'])  # check resume options
    else:
        resume_state = None

    #### mkdir and loggers
    if rank <= 0:  # normal training (rank -1) OR distributed training (rank 0)
        if resume_state is None:
            # rename experiment folder if exists
            util.mkdir_and_rename(opt['path']['experiments_root'])
            util.mkdirs(
                (path for key, path in opt['path'].items()
                 if not key == 'experiments_root'
                 and 'pretrain_model' not in key and 'resume' not in key))

        # config loggers. Before it, the log will not work
        util.setup_logger('base',
                          opt['path']['log'],
                          'train_' + opt['name'],
                          level=logging.INFO,
                          screen=True,
                          tofile=True)
        util.setup_logger('val',
                          opt['path']['log'],
                          'val_' + opt['name'],
                          level=logging.INFO,
                          screen=True,
                          tofile=True)
        logger = logging.getLogger('base')
        logger.info(option.dict2str(opt))
        # tensorboard logger
        if opt['use_tb_logger'] and 'debug' not in opt['name']:
            version = float(torch.__version__[0:3])
            if version >= 1.1:  # PyTorch 1.1
                from torch.utils.tensorboard import SummaryWriter
            else:
                logger.info(
                    'You are using PyTorch {}. Tensorboard will use [tensorboardX]'
                    .format(version))
                from tensorboardX import SummaryWriter
            tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name'])
    else:
        # Non-zero ranks only log to the screen.
        util.setup_logger('base',
                          opt['path']['log'],
                          'train',
                          level=logging.INFO,
                          screen=True)
        logger = logging.getLogger('base')

    # convert to NoneDict, which returns None for missing keys
    opt = option.dict_to_nonedict(opt)

    #### random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    if rank <= 0:
        logger.info('Random seed: {}'.format(seed))
    util.set_random_seed(seed)

    # Fixed: the attribute was misspelled ('benckmark'), which silently
    # created an unused attribute and never enabled cuDNN autotuning.
    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    #### create train and val dataloader
    dataset_ratio = 200  # enlarge the size of each epoch
    # Pre-bind both loaders so a missing phase is detected explicitly
    # instead of surfacing later as a NameError.
    train_loader = None
    val_loader = None
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            # iterations per epoch
            train_size = int(
                math.ceil(len(train_set) / dataset_opt['batch_size']))
            total_iters = int(opt['train']['niter'])
            total_epochs = int(math.ceil(total_iters / train_size))
            if opt['dist']:
                train_sampler = DistIterSampler(train_set, world_size, rank,
                                                dataset_ratio)
                # the sampler enlarges each epoch by dataset_ratio
                total_epochs = int(
                    math.ceil(total_iters / (train_size * dataset_ratio)))
            else:
                train_sampler = None
            train_loader = create_dataloader(train_set, dataset_opt, opt,
                                             train_sampler)
            if rank <= 0:
                logger.info(
                    'Number of train images: {:,d}, iters: {:,d}'.format(
                        len(train_set), train_size))
                logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                    total_epochs, total_iters))
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt, opt, None)
            if rank <= 0:
                logger.info('Number of val images in [{:s}]: {:d}'.format(
                    dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError(
                'Phase [{:s}] is not recognized.'.format(phase))
    assert train_loader is not None, 'No train phase found in the datasets options.'

    #### create model
    model = create_model(opt)

    #### resume training
    if resume_state:
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))

        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        model.resume_training(resume_state)  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0

    #### training
    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(
        start_epoch, current_step))
    for epoch in range(start_epoch, total_epochs + 1):
        if opt['dist']:
            train_sampler.set_epoch(epoch)
        for _, train_data in enumerate(train_loader):
            current_step += 1
            if current_step > total_iters:
                break
            #### update learning rate
            model.update_learning_rate(current_step,
                                       warmup_iter=opt['train']['warmup_iter'])
            #### training
            model.feed_data(train_data)
            model.optimize_parameters(current_step)
            #### log
            if current_step % opt['logger']['print_freq'] == 0:
                logs = model.get_current_log()
                message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> '.format(
                    epoch, current_step, model.get_current_learning_rate())
                for k, v in logs.items():
                    message += '{:s}: {:.4e} '.format(k, v)
                    # tensorboard logger
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        if rank <= 0:
                            tb_logger.add_scalar(k, v, current_step)
                if rank <= 0:
                    logger.info(message)

            # validation (rank <= 0 only; skipped when no val set is configured)
            if (val_loader is not None
                    and current_step % opt['train']['val_freq'] == 0
                    and rank <= 0):
                avg_psnr = 0.0
                idx = 0
                for val_data in val_loader:
                    idx += 1
                    img_name = os.path.splitext(
                        os.path.basename(val_data['LQ_path'][0]))[0]
                    img_dir = os.path.join(opt['path']['val_images'], img_name)
                    util.mkdir(img_dir)

                    model.feed_data(val_data)
                    model.test()

                    visuals = model.get_current_visuals()
                    sr_img = util.tensor2img(visuals['SR'])  # uint8
                    gt_img = util.tensor2img(visuals['GT'])  # uint8

                    # Save SR images for reference
                    save_img_path = os.path.join(
                        img_dir,
                        '{:s}_{:d}.png'.format(img_name, current_step))
                    util.save_img(sr_img, save_img_path)

                    # calculate PSNR on images cropped by `scale` pixels
                    # on every border
                    crop_size = opt['scale']
                    gt_img = gt_img / 255.
                    sr_img = sr_img / 255.
                    cropped_sr_img = sr_img[crop_size:-crop_size,
                                            crop_size:-crop_size, :]
                    cropped_gt_img = gt_img[crop_size:-crop_size,
                                            crop_size:-crop_size, :]
                    avg_psnr += util.calculate_psnr(cropped_sr_img * 255,
                                                    cropped_gt_img * 255)

                avg_psnr = avg_psnr / idx

                # log
                logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr))
                logger_val = logging.getLogger('val')  # validation logger
                logger_val.info(
                    '<epoch:{:3d}, iter:{:8,d}> psnr: {:.4e}'.format(
                        epoch, current_step, avg_psnr))
                # tensorboard logger
                if opt['use_tb_logger'] and 'debug' not in opt['name']:
                    tb_logger.add_scalar('psnr', avg_psnr, current_step)

            #### save models and training states
            if current_step % opt['logger']['save_checkpoint_freq'] == 0:
                if rank <= 0:
                    logger.info('Saving models and training states.')
                    model.save(current_step)
                    model.save_training_state(epoch, current_step)

    if rank <= 0:
        logger.info('Saving the final model.')
        model.save('latest')
        logger.info('End of training.')
Ejemplo n.º 5
0
def main():
    """Train a super-resolution model described by a JSON option file.

    Command-line arguments:
        -opt: path to the options JSON file (required).

    Steps: prepare the experiment directories (or check the resume path),
    seed the RNGs, build the train/val loaders, create the solver selected
    by ``opt['mode']``, then train for ``NUM_EPOCH`` epochs. Every
    ``solver.val_step`` epochs the model is validated (PSNR/SSIM on the
    Y-channel-converted, border-cropped images), results are recorded on the
    solver and a checkpoint is saved. At the end the recorded results are
    written to ``train_results.csv`` in ``solver.results_dir``.

    Raises:
        ValueError: if resuming without ``resume_path``, if no training
            loader was created, or if validation is due but no validation
            loader was created.
        NotImplementedError: for an unrecognized dataset phase or an
            unsupported ``opt['mode']``.
    """
    # os.environ['CUDA_VISIBLE_DEVICES']="1" # You can specify your GPU device here. I failed to perform it by `torch.cuda.set_device()`.
    parser = argparse.ArgumentParser(
        description='Train Super Resolution Models')
    parser.add_argument('-opt',
                        type=str,
                        required=True,
                        help='Path to options JSON file.')
    opt = option.parse(parser.parse_args().opt)

    if opt['train']['resume'] is False:
        # rename old experiments if exists
        util.mkdir_and_rename(opt['path']['exp_root'])
        util.mkdirs(
            path for key, path in opt['path'].items()
            if key not in ('exp_root', 'pretrain_G', 'pretrain_D'))
        option.save(opt)
        # Convert to NoneDict, which return None for missing key.
        opt = option.dict_to_nonedict(opt)
    else:
        opt = option.dict_to_nonedict(opt)
        if opt['train']['resume_path'] is None:
            raise ValueError("The 'resume_path' does not declarate")

    if opt['exec_debug']:
        # Debug run: fixed epoch count and the *_debug data roots.
        NUM_EPOCH = 100
        opt['datasets']['train']['dataroot_HR'] = opt['datasets']['train'][
            'dataroot_HR_debug']  #"./dataset/TrainData/DIV2K_train_HR_sub",
        opt['datasets']['train']['dataroot_LR'] = opt['datasets']['train'][
            'dataroot_LR_debug']  #./dataset/TrainData/DIV2K_train_HR_sub_LRx3"
    else:
        NUM_EPOCH = int(opt['train']['num_epochs'])

    # random seed
    seed = opt['train']['manual_seed']  #0
    if seed is None:
        seed = random.randint(1, 10000)
    print("Random Seed: ", seed)
    random.seed(seed)
    torch.manual_seed(seed)

    # create train and val dataloader
    # Pre-bind the loaders so a missing phase is reported by the explicit
    # checks below rather than as a NameError.
    train_loader = None
    val_loader = None
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            train_loader = create_dataloader(train_set, dataset_opt)
            print('Number of train images in [%s]: %d' %
                  (dataset_opt['name'], len(train_set)))
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt)
            print('Number of val images in [%s]: %d' %
                  (dataset_opt['name'], len(val_set)))
        elif phase == 'test':
            pass
        else:
            raise NotImplementedError("Phase [%s] is not recognized." % phase)

    if train_loader is None:
        raise ValueError("The training data does not exist")

    # TODO: design an exp that can obtain the location of the biggest error
    mode = opt['mode']
    if mode in ('sr', 'fi', 'msan'):
        solver = SRModel1(opt)
    elif mode == 'srgan':
        solver = SRModelGAN(opt)
    elif mode == 'sr_curriculum':
        solver = SRModelCurriculum(opt)
    else:
        # Previously an unknown mode left `solver` undefined and failed
        # later with a NameError.
        raise NotImplementedError("Mode [%s] is not recognized." % mode)

    solver.summary(train_set[0]['LR'].size())
    solver.net_init()
    print('[Start Training]')

    start_time = time.time()

    start_epoch = 1
    if opt['train']['resume']:
        start_epoch = solver.load()

    for epoch in range(start_epoch, NUM_EPOCH + 1):
        # Initialization
        solver.training_loss = 0.0
        epoch_loss_log = 0.0
        training_results = {'batch_size': 0, 'training_loss': 0.0}
        train_bar = tqdm(train_loader)

        # Train model
        for batch in train_bar:
            solver.feed_data(batch)
            iter_loss = solver.train_step()
            epoch_loss_log += iter_loss.item()
            batch_size = batch['LR'].size(0)
            training_results['batch_size'] += batch_size

            # Sample-weighted running sum of the loss. Detached so the sum
            # does not keep autograd graphs alive if train_step returns a
            # loss that is still attached to one; the sum stays a tensor,
            # which the statistics code below relies on.
            training_results['training_loss'] += iter_loss.detach() * batch_size
            train_bar.set_description(desc='[%d/%d] Loss: %.4f ' %
                                      (epoch, NUM_EPOCH, iter_loss))

        # mean of the per-iteration losses over this epoch
        solver.last_epoch_loss = epoch_loss_log / (len(train_bar))

        train_bar.close()
        time_elapse = time.time() - start_time
        start_time = time.time()
        print('Train Loss: %.4f' % (training_results['training_loss'] /
                                    training_results['batch_size']))

        # validate
        val_results = {
            'batch_size': 0,
            'val_loss': 0.0,
            'psnr': 0.0,
            'ssim': 0.0
        }

        if epoch % solver.val_step == 0 and epoch != 0:
            if val_loader is None:
                raise ValueError("The validation data does not exist")
            print('[Validating...]')
            start_time = time.time()
            solver.val_loss = 0.0

            # Only used by the disabled image-dump code further down.
            vis_index = 1

            for batch in val_loader:
                # Only used by the disabled image-dump code further down.
                visuals_list = []

                solver.feed_data(batch)
                iter_loss = solver.test(opt['chop'])
                batch_size = batch['LR'].size(0)
                val_results['batch_size'] += batch_size

                visuals = solver.get_current_visual()  # float cpu tensor

                # Quantize, move the first axis last and convert to uint8.
                sr_img = np.transpose(
                    util.quantize(visuals['SR'], opt['rgb_range']).numpy(),
                    (1, 2, 0)).astype(np.uint8)
                gt_img = np.transpose(
                    util.quantize(visuals['HR'], opt['rgb_range']).numpy(),
                    (1, 2, 0)).astype(np.uint8)

                # calculate PSNR on images cropped by `scale` pixels on
                # every border
                crop_size = opt['scale']
                cropped_sr_img = sr_img[crop_size:-crop_size,
                                        crop_size:-crop_size, :]
                cropped_gt_img = gt_img[crop_size:-crop_size,
                                        crop_size:-crop_size, :]

                # Scale to [0, 1] before the colour conversion; the metrics
                # below multiply by 255 again.
                cropped_sr_img = cropped_sr_img / 255.
                cropped_gt_img = cropped_gt_img / 255.
                cropped_sr_img = rgb2ycbcr(cropped_sr_img).astype(np.float32)
                cropped_gt_img = rgb2ycbcr(cropped_gt_img).astype(np.float32)

                ##################################################################################
                # Disabled experiment: colourfulness term (uicm) of the SR image.
                # b, r, g = cv2.split(cropped_sr_img)
                #
                # RG = r - g
                # YB = (r + g) / 2 - b
                # m, n, o = np.shape(cropped_sr_img)  # img is 3-D, the colour planes are 2-D; o is unused
                # K = m * n
                # alpha_L = 0.1
                # alpha_R = 0.1  # alpha parameter, tunable
                # T_alpha_L = math.ceil(alpha_L * K)  # round up; size of the trimmed interval
                # T_alpha_R = math.floor(alpha_R * K)  # round down
                #
                # RG_list = RG.flatten()  # flatten the 2-D array to 1-D (easier to compute)
                # RG_list = sorted(RG_list)  # sort
                # sum_RG = 0  # compute the mean
                # for i in range(T_alpha_L + 1, K - T_alpha_R):
                #     sum_RG = sum_RG + RG_list[i]
                # U_RG = sum_RG / (K - T_alpha_R - T_alpha_L)
                # squ_RG = 0  # compute the variance
                # for i in range(K):
                #     squ_RG = squ_RG + np.square(RG_list[i] - U_RG)
                # sigma2_RG = squ_RG / K
                #
                # # YB is computed the same way as RG
                # YB_list = YB.flatten()
                # YB_list = sorted(YB_list)
                # sum_YB = 0
                # for i in range(T_alpha_L + 1, K - T_alpha_R):
                #     sum_YB = sum_YB + YB_list[i]
                # U_YB = sum_YB / (K - T_alpha_R - T_alpha_L)
                # squ_YB = 0
                # for i in range(K):
                #     squ_YB = squ_YB + np.square(YB_list[i] - U_YB)
                # sigma2_YB = squ_YB / K
                #
                # uicm = -0.0268 * np.sqrt(np.square(U_RG) + np.square(U_YB)) + 0.1586 * np.sqrt(sigma2_RG + sigma2_RG)
                ##################################################################################

                val_results['val_loss'] += iter_loss * batch_size

                val_results['psnr'] += util.calc_psnr(cropped_sr_img * 255,
                                                      cropped_gt_img * 255)
                val_results['ssim'] += util.compute_ssim1(
                    cropped_sr_img * 255, cropped_gt_img * 255)

                if opt['mode'] == 'srgan':
                    pass  # TODO

                # if opt['save_image']:
                #     visuals_list.extend([util.quantize(visuals['HR'].squeeze(0), opt['rgb_range']),
                #                          util.quantize(visuals['SR'].squeeze(0), opt['rgb_range'])])
                #
                #     images = torch.stack(visuals_list)
                #     img = thutil.make_grid(images, nrow=2, padding=5)
                #     ndarr = img.byte().permute(1, 2, 0).numpy()
                #     misc.imsave(os.path.join(solver.vis_dir, 'epoch_%d_%d.png' % (epoch, vis_index)), ndarr)
                #     vis_index += 1

            avg_psnr = val_results['psnr'] / val_results['batch_size']
            avg_ssim = val_results['ssim'] / val_results['batch_size']
            print(
                'Valid Loss: %.4f | Avg. PSNR: %.4f | Avg. SSIM: %.4f | Learning Rate: %f'
                % (val_results['val_loss'] / val_results['batch_size'],
                   avg_psnr, avg_ssim, solver.current_learning_rate()))

            # Fixed: the operands were swapped, yielding a negative duration.
            time_elapse = time.time() - start_time

            #if epoch%solver.log_step == 0 and epoch != 0:
            # tensorboard visualization
            solver.training_loss = training_results[
                'training_loss'] / training_results['batch_size']
            solver.val_loss = val_results['val_loss'] / val_results[
                'batch_size']

            solver.tf_log(epoch)

            # statistics
            solver.results['training_loss'].append(
                solver.training_loss.cpu().data.item())
            solver.results['val_loss'].append(
                solver.val_loss.cpu().data.item())
            solver.results['psnr'].append(avg_psnr)
            solver.results['ssim'].append(avg_ssim)

            # Track the best PSNR seen so far.
            is_best = False
            if solver.best_prec < solver.results['psnr'][-1]:
                solver.best_prec = solver.results['psnr'][-1]
                is_best = True

            print(
                '#############################################################'
            )
            print(solver.best_prec)
            print(solver.results['psnr'][-1])
            print(
                '***************************************************************'
            )
            # print(is_best)
            # print('+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++')
            # print('+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++')
            solver.save(epoch, is_best)

        # update lr
        solver.update_learning_rate(epoch)

    # NOTE(review): the index assumes exactly one result row per epoch from
    # 1 to NUM_EPOCH. Rows are only appended on validation epochs, so this
    # looks like it will not line up when solver.val_step != 1 (and possibly
    # when resuming) — confirm before relying on the CSV.
    data_frame = pd.DataFrame(
        data={
            'training_loss': solver.results['training_loss'],
            'val_loss': solver.results['val_loss'],
            'psnr': solver.results['psnr'],
            'ssim': solver.results['ssim']
        },
        index=range(1, NUM_EPOCH + 1))
    data_frame.to_csv(os.path.join(solver.results_dir, 'train_results.csv'),
                      index_label='Epoch')
Ejemplo n.º 6
0
from train_sate_encoder import option as SateOption
from models import networks
import torch
import torchvision
import random
import numpy as np
from PIL import Image

# options
# Parse the test options, then force the settings this test script relies on.
opt = TestOptions().parse()
opt.num_threads = 1   # test code only supports num_threads=1
opt.batch_size = 1   # test code only supports batch_size=1
opt.serial_batches = True  # no shuffle

# create the dataset and the model, then put the model in evaluation mode
dataset = create_dataset(opt)
model = create_model(opt)
model.setup(opt)
model.eval()
print('Loading model %s' % opt.model)

# Build the encoder network described by SateOption and restore its weights.
# The checkpoint is read from a fixed relative path, and the weights are
# expected under the 'model_state_dict' key of the loaded object.
# NOTE(review): torch.load is called without map_location, so the checkpoint
# presumably has to be loadable on the device it was saved from -- confirm.
sateOpt = SateOption()
sateE = networks.define_E(sateOpt.output_nc, sateOpt.nz, sateOpt.nef, netE=sateOpt.netE, norm=sateOpt.norm, nl=sateOpt.nl,
                            init_type=sateOpt.init_type, init_gain=sateOpt.init_gain, gpu_ids=sateOpt.gpu_ids, vaeLike=sateOpt.use_vae)
sateCheckpoint = torch.load('sate_encoder/sate_encoder_latest.pth')
sateE.load_state_dict(sateCheckpoint['model_state_dict'])
sateE.eval()
transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
Ejemplo n.º 7
0
def main():
    """Run single-process training driven by a YAML option file.

    Steps, in order:

    1. Parse the command line and the option file given by ``-opt``.
    2. Load a resume state when ``opt['path']['resume_state']`` is set.
    3. Create the experiment folders, the 'base'/'val' loggers and, when
       ``use_tb_logger`` is enabled and the run name does not contain
       'debug', the TensorBoard writers.
    4. Build the train/val dataloaders and the model.
    5. Iterate until ``opt['train']['niter']`` steps were taken, logging,
       validating (PSNR and NLL) and checkpointing at the configured
       frequencies.
    6. Write a ``TRAIN_DONE`` marker file into ``opt['path']['root']`` and
       save the final model as 'latest'.

    Distributed training is disabled in this entry point: ``opt['dist']`` is
    forced to ``False`` and ``rank`` to ``-1``.

    Raises:
        NotImplementedError: if ``opt['datasets']`` contains a phase other
            than 'train' or 'val'.
        AssertionError: if no 'train' dataset is configured.
    """
    #### options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, help='Path to option YMAL file.')
    parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    opt = option.parse(args.opt, is_train=True)

    #### distributed training settings
    opt['dist'] = False
    rank = -1
    print('Disabled distributed training.')

    #### loading resume state if exists
    if opt['path'].get('resume_state', None):
        resume_state_path, _ = get_resume_paths(opt)

        # distributed resuming: all load into default GPU
        if resume_state_path is None:
            resume_state = None
        else:
            device_id = torch.cuda.current_device()
            resume_state = torch.load(resume_state_path,
                                      map_location=lambda storage, loc: storage.cuda(device_id))
            option.check_resume(opt, resume_state['iter'])  # check resume options
    else:
        resume_state = None

    #### mkdir and loggers
    # The TensorBoard writers are optional. Keep them bound to None when they
    # are not created so that later code can test for them instead of failing
    # with a NameError.
    tb_logger_train = None
    tb_logger_valid = None
    if rank <= 0:  # normal training (rank -1) OR distributed training (rank 0)
        if resume_state is None:
            util.mkdir_and_rename(
                opt['path']['experiments_root'])  # rename experiment folder if exists
            util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root'
                         and 'pretrain_model' not in key and 'resume' not in key))

        # config loggers. Before it, the log will not work
        util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO,
                          screen=True, tofile=True)
        util.setup_logger('val', opt['path']['log'], 'val_' + opt['name'], level=logging.INFO,
                          screen=True, tofile=True)
        logger = logging.getLogger('base')
        logger.info(option.dict2str(opt))

        # tensorboard logger
        if opt.get('use_tb_logger', False) and 'debug' not in opt['name']:
            version = float(torch.__version__[0:3])
            if version >= 1.1:  # PyTorch 1.1
                from torch.utils.tensorboard import SummaryWriter
            else:
                logger.info(
                    'You are using PyTorch {}. Tensorboard will use [tensorboardX]'.format(version))
                from tensorboardX import SummaryWriter
            conf_name = basename(args.opt).replace(".yml", "")
            exp_dir = opt['path']['experiments_root']
            log_dir_train = os.path.join(exp_dir, 'tb', conf_name, 'train')
            log_dir_valid = os.path.join(exp_dir, 'tb', conf_name, 'valid')
            tb_logger_train = SummaryWriter(log_dir=log_dir_train)
            tb_logger_valid = SummaryWriter(log_dir=log_dir_valid)
    else:
        util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True)
        logger = logging.getLogger('base')

    # convert to NoneDict, which returns None for missing keys
    opt = option.dict_to_nonedict(opt)

    #### random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    if rank <= 0:
        logger.info('Random seed: {}'.format(seed))
    util.set_random_seed(seed)

    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    #### create train and val dataloader
    dataset_ratio = 200  # enlarge the size of each epoch
    # Pre-bind both loaders so the assertion below reports a missing 'train'
    # phase and a missing 'val' phase simply disables validation.
    train_loader = None
    val_loader = None
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            print('Dataset created')
            train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size']))
            total_iters = int(opt['train']['niter'])
            total_epochs = int(math.ceil(total_iters / train_size))
            train_sampler = None
            train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler)
            if rank <= 0:
                logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
                    len(train_set), train_size))
                logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                    total_epochs, total_iters))
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt, opt, None)
            if rank <= 0:
                logger.info('Number of val images in [{:s}]: {:d}'.format(
                    dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))
    assert train_loader is not None

    #### create model
    current_step = 0 if resume_state is None else resume_state['iter']
    model = create_model(opt, current_step)

    #### resume training
    if resume_state:
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))

        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        model.resume_training(resume_state)  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0

    #### training
    timer = Timer()
    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step))
    timerData = TickTock()

    def eta(t_iter):
        # Estimated remaining time in hours, given the time of one iteration.
        # Reads the current value of `current_step` each time it is called.
        return (t_iter * (opt['train']['niter'] - current_step)) / 3600

    for epoch in range(start_epoch, total_epochs + 1):
        if opt['dist']:
            train_sampler.set_epoch(epoch)

        timerData.tick()
        for _, train_data in enumerate(train_loader):
            timerData.tock()
            current_step += 1
            if current_step > total_iters:
                break

            #### training
            model.feed_data(train_data)

            #### update learning rate
            model.update_learning_rate(current_step, warmup_iter=opt['train']['warmup_iter'])

            try:
                nll = model.optimize_parameters(current_step)
            except RuntimeError as e:
                print("Skipping ERROR caught in nll = model.optimize_parameters(current_step): ")
                print(e)
                # No loss was produced for this step: do not reuse the value
                # of the previous step (or hit an unbound name on step one).
                nll = None

            if nll is None:
                nll = 0

            #### log
            if current_step % opt['logger']['print_freq'] == 0 \
                    or current_step - (resume_state['iter'] if resume_state else 0) < 25:
                avg_time = timer.get_average_and_reset()
                avg_data_time = timerData.get_average_and_reset()
                message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}, t:{:.2e}, td:{:.2e}, eta:{:.2e}, nll:{:.3e}> '.format(
                    epoch, current_step, model.get_current_learning_rate(), avg_time, avg_data_time,
                    eta(avg_time), nll)
                print(message)
            timer.tick()
            # Reduce number of logs
            if tb_logger_train is not None and current_step % 5 == 0:
                tb_logger_train.add_scalar('loss/nll', nll, current_step)
                tb_logger_train.add_scalar('lr/base', model.get_current_learning_rate(), current_step)
                tb_logger_train.add_scalar('time/iteration', timer.get_last_iteration(), current_step)
                tb_logger_train.add_scalar('time/data', timerData.get_last_iteration(), current_step)
                tb_logger_train.add_scalar('time/eta', eta(timer.get_last_iteration()), current_step)
                for k, v in model.get_current_log().items():
                    tb_logger_train.add_scalar(k, v, current_step)

            # validation
            if val_loader is not None and current_step % opt['train']['val_freq'] == 0 and rank <= 0:
                avg_psnr = 0.0
                idx = 0
                nlls = []
                for val_data in val_loader:
                    idx += 1
                    img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][0]))[0]
                    img_dir = os.path.join(opt['path']['val_images'], img_name)
                    util.mkdir(img_dir)

                    model.feed_data(val_data)

                    nll = model.test()
                    if nll is None:
                        nll = 0
                    nlls.append(nll)

                    visuals = model.get_current_visuals()

                    sr_img = None
                    # Save SR images for reference
                    if hasattr(model, 'heats'):
                        # One image per (heat, sample) pair; the last one
                        # saved is the one used for the PSNR below.
                        for heat in model.heats:
                            for i in range(model.n_sample):
                                sr_img = util.tensor2img(visuals['SR', heat, i])  # uint8
                                save_img_path = os.path.join(img_dir,
                                                             '{:s}_{:09d}_h{:03d}_s{:d}.png'.format(img_name,
                                                                                                    current_step,
                                                                                                    int(heat * 100), i))
                                util.save_img(sr_img, save_img_path)
                    else:
                        sr_img = util.tensor2img(visuals['SR'])  # uint8
                        save_img_path = os.path.join(img_dir,
                                                     '{:s}_{:d}.png'.format(img_name, current_step))
                        util.save_img(sr_img, save_img_path)
                    assert sr_img is not None

                    # Save LQ images for reference
                    save_img_path_lq = os.path.join(img_dir,
                                                    '{:s}_LQ.png'.format(img_name))
                    if not os.path.isfile(save_img_path_lq):
                        lq_img = util.tensor2img(visuals['LQ'])  # uint8
                        util.save_img(
                            cv2.resize(lq_img, dsize=None, fx=opt['scale'], fy=opt['scale'],
                                       interpolation=cv2.INTER_NEAREST),
                            save_img_path_lq)

                    # Save GT images for reference
                    gt_img = util.tensor2img(visuals['GT'])  # uint8
                    save_img_path_gt = os.path.join(img_dir,
                                                    '{:s}_GT.png'.format(img_name))
                    if not os.path.isfile(save_img_path_gt):
                        util.save_img(gt_img, save_img_path_gt)

                    # calculate PSNR
                    # A border of `scale` pixels is cropped from both images
                    # before the comparison.
                    crop_size = opt['scale']
                    gt_img = gt_img / 255.
                    sr_img = sr_img / 255.
                    cropped_sr_img = sr_img[crop_size:-crop_size, crop_size:-crop_size, :]
                    cropped_gt_img = gt_img[crop_size:-crop_size, crop_size:-crop_size, :]
                    avg_psnr += util.calculate_psnr(cropped_sr_img * 255, cropped_gt_img * 255)

                avg_psnr = avg_psnr / idx
                avg_nll = sum(nlls) / len(nlls)

                # log
                logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr))
                logger_val = logging.getLogger('val')  # validation logger
                logger_val.info('<epoch:{:3d}, iter:{:8,d}> psnr: {:.4e}'.format(
                    epoch, current_step, avg_psnr))

                # tensorboard logger
                if tb_logger_valid is not None:
                    tb_logger_valid.add_scalar('loss/psnr', avg_psnr, current_step)
                    tb_logger_valid.add_scalar('loss/nll', avg_nll, current_step)
                    tb_logger_valid.flush()
                if tb_logger_train is not None:
                    tb_logger_train.flush()

            #### save models and training states
            if current_step % opt['logger']['save_checkpoint_freq'] == 0:
                if rank <= 0:
                    logger.info('Saving models and training states.')
                    model.save(current_step)
                    model.save_training_state(epoch, current_step)

            timerData.tick()

    # Marker file signalling that the training loop ran to completion.
    with open(os.path.join(opt['path']['root'], "TRAIN_DONE"), 'w') as f:
        f.write("TRAIN_DONE")

    if rank <= 0:
        logger.info('Saving the final model.')
        model.save('latest')
        logger.info('End of training.')
Ejemplo n.º 8
0
def main():
    """Train a model from a JSON option file, with optional resume support.

    The option file given by ``-opt`` drives everything: resume state,
    loggers, optional tensorboardX logging, dataloaders and the model.

    Two step counters are kept. ``virtual_step`` is incremented for every
    batch read from the train loader, while ``current_step`` only advances
    when ``virtual_step * batch_size`` is a multiple of
    ``virtual_batch_size``. Logging, checkpointing, validation and learning
    rate updates are keyed on ``current_step``.

    A ``KeyboardInterrupt`` during training is caught: the current model and
    training state are saved so the run can be resumed later.

    Raises:
        NotImplementedError: if ``opt['datasets']`` contains a phase other
            than 'train' or 'val'.
        AssertionError: if no 'train' dataset is configured.
    """
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, required=True, help='Path to option JSON file.')
    opt = option.parse(parser.parse_args().opt, is_train=True)
    opt = option.dict_to_nonedict(opt)  # Convert to NoneDict, which return None for missing key.

    # train from scratch OR resume training
    if opt['path']['resume_state']:
        if os.path.isdir(opt['path']['resume_state']):
            # A directory was given: pick the last '*.state' file in the
            # order produced by util.sorted_nicely.
            import glob
            resume_state_path = util.sorted_nicely(glob.glob(os.path.normpath(opt['path']['resume_state']) + '/*.state'))[-1]
        else:
            resume_state_path = opt['path']['resume_state']
        resume_state = torch.load(resume_state_path)
    else:  # training from scratch
        resume_state = None
        util.mkdir_and_rename(opt['path']['experiments_root'])  # rename old folder if exists
        util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root'
                     and 'pretrain_model' not in key and 'resume' not in key))

    # config loggers. Before it, the log will not work
    util.setup_logger(None, opt['path']['log'], 'train', level=logging.INFO, screen=True)
    util.setup_logger('val', opt['path']['log'], 'val', level=logging.INFO)
    logger = logging.getLogger('base')

    if resume_state:
        logger.info('Set [resume_state] to ' + resume_state_path)
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))
        option.check_resume(opt)  # check resume options

    logger.info(option.dict2str(opt))
    # tensorboard logger
    if opt['use_tb_logger'] and 'debug' not in opt['name']:
        from tensorboardX import SummaryWriter
        try:
            tb_logger = SummaryWriter(logdir='../tb_logger/' + opt['name']) #for version tensorboardX >= 1.7
        except Exception:
            # Fall back to the older keyword name. `Exception` (rather than a
            # bare except) keeps KeyboardInterrupt/SystemExit propagating.
            tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name']) #for version tensorboardX < 1.6

    # random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    logger.info('Random seed: {}'.format(seed))
    util.set_random_seed(seed)

    # if the model does not change and input sizes remain the same during training then there may be benefit
    # from setting torch.backends.cudnn.benchmark = True, otherwise it may stall training
    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    # create train and val dataloader
    # Pre-bind train_loader so the assertion below reports a missing 'train'
    # phase instead of raising NameError.
    train_loader = None
    val_loader = False
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            batch_size = dataset_opt.get('batch_size', 4)
            virtual_batch_size = dataset_opt.get('virtual_batch_size', batch_size)
            # the virtual batch size is never smaller than the real one
            virtual_batch_size = virtual_batch_size if virtual_batch_size > batch_size else batch_size
            train_size = int(math.ceil(len(train_set) / batch_size))
            logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
                len(train_set), train_size))
            total_iters = int(opt['train']['niter'])
            total_epochs = int(math.ceil(total_iters / train_size))
            logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                total_epochs, total_iters))
            train_loader = create_dataloader(train_set, dataset_opt)
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt)
            logger.info('Number of val images in [{:s}]: {:d}'.format(dataset_opt['name'],
                                                                      len(val_set)))
        else:
            raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))
    assert train_loader is not None

    # create model
    model = create_model(opt)

    # resume training
    if resume_state:
        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        virtual_step = current_step * virtual_batch_size / batch_size \
            if virtual_batch_size and virtual_batch_size > batch_size else current_step
        model.resume_training(resume_state)  # handle optimizers and schedulers
        model.update_schedulers(opt['train']) # updated schedulers in case JSON configuration has changed
        del resume_state
        # start the iteration time when resuming
        t0 = time.time()
    else:
        current_step = 0
        virtual_step = 0
        start_epoch = 0

    # training
    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step))
    # Bind the loop variables up front: the KeyboardInterrupt handler below
    # reads them and may run before the first batch has been fetched.
    epoch = start_epoch
    n = 0
    try:
        for epoch in range(start_epoch, total_epochs*(virtual_batch_size//batch_size)):
            for n, train_data in enumerate(train_loader,start=1):

                if virtual_step == 0:
                    # first iteration start time
                    t0 = time.time()

                virtual_step += 1
                take_step = False
                if virtual_step > 0 and virtual_step * batch_size % virtual_batch_size == 0:
                    current_step += 1
                    take_step = True
                    if current_step > total_iters:
                        break

                # training
                model.feed_data(train_data)
                model.optimize_parameters(virtual_step)

                # log
                if current_step % opt['logger']['print_freq'] == 0 and take_step:
                    # iteration end time
                    t1 = time.time()

                    logs = model.get_current_log()
                    message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}, i_time: {:.4f} sec.> '.format(
                        epoch, current_step, model.get_current_learning_rate(current_step), (t1 - t0))
                    for k, v in logs.items():
                        message += '{:s}: {:.4e} '.format(k, v)
                        # tensorboard logger
                        if opt['use_tb_logger'] and 'debug' not in opt['name']:
                            tb_logger.add_scalar(k, v, current_step)
                    logger.info(message)

                    # # start time for next iteration
                    # t0 = time.time()

                # update learning rate
                if model.optGstep and model.optDstep and take_step:
                    model.update_learning_rate(current_step, warmup_iter=opt['train'].get('warmup_iter', -1))

                # save models and training states (changed to save models before validation)
                if current_step % opt['logger']['save_checkpoint_freq'] == 0 and take_step:
                    if model.swa:
                        model.save(current_step, opt['logger']['overwrite_chkp'], loader=train_loader)
                    else:
                        model.save(current_step, opt['logger']['overwrite_chkp'])
                    # Store the next epoch number when the checkpoint falls on
                    # the last batch of the current epoch.
                    model.save_training_state(epoch + (n >= len(train_loader)), current_step, opt['logger']['overwrite_chkp'])
                    logger.info('Models and training states saved.')

                # validation
                if val_loader and current_step % opt['train']['val_freq'] == 0 and take_step:
                    val_sr_imgs_list = []
                    val_gt_imgs_list = []
                    val_metrics = metrics.MetricsDict(metrics=opt['train'].get('metrics', None))
                    for val_data in val_loader:
                        img_name = os.path.splitext(os.path.basename(val_data['LR_path'][0]))[0]
                        img_dir = os.path.join(opt['path']['val_images'], img_name)
                        util.mkdir(img_dir)

                        model.feed_data(val_data)
                        model.test(val_data)

                        # Get Visuals
                        visuals = model.get_current_visuals()
                        sr_img = tensor2np(visuals['SR'], denormalize=opt['datasets']['train']['znorm'])
                        gt_img = tensor2np(visuals['HR'], denormalize=opt['datasets']['train']['znorm'])

                        # Save SR images for reference
                        if opt['train']['overwrite_val_imgs']:
                            save_img_path = os.path.join(img_dir, '{:s}.png'.format(\
                                img_name))
                        else:
                            save_img_path = os.path.join(img_dir, '{:s}_{:d}.png'.format(\
                                img_name, current_step))

                        # save single images or lr / sr comparison
                        if opt['train']['val_comparison']:
                            lr_img = tensor2np(visuals['LR'], denormalize=opt['datasets']['train']['znorm'])
                            util.save_img_comp(lr_img, sr_img, save_img_path)
                        else:
                            util.save_img(sr_img, save_img_path)

                        # Get Metrics
                        # TODO: test using tensor based metrics (batch) instead of numpy.
                        crop_size = opt['scale']
                        val_metrics.calculate_metrics(sr_img, gt_img, crop_size = crop_size)  #, only_y=True)

                    avg_metrics = val_metrics.get_averages()
                    del val_metrics

                    # log
                    logger_m = ''
                    for r in avg_metrics:
                        #print(r)
                        formatted_res = r['name'].upper()+': {:.5g}, '.format(r['average'])
                        logger_m += formatted_res

                    # [:-2] drops the trailing ', ' left by the loop above
                    logger.info('# Validation # '+logger_m[:-2])
                    logger_val = logging.getLogger('val')  # validation logger
                    logger_val.info('<epoch:{:3d}, iter:{:8,d}> '.format(epoch, current_step)+logger_m[:-2])
                    # memory_usage = torch.cuda.memory_allocated()/(1024.0 ** 3) # in GB

                    # tensorboard logger
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        for r in avg_metrics:
                            tb_logger.add_scalar(r['name'], r['average'], current_step)

                    # # reset time for next iteration to skip the validation time from calculation
                    # t0 = time.time()

                if current_step % opt['logger']['print_freq'] == 0 and take_step or \
                    (val_loader and current_step % opt['train']['val_freq'] == 0 and take_step):
                    # reset time for next iteration to skip the validation time from calculation
                    t0 = time.time()

        logger.info('Saving the final model.')
        if model.swa:
            model.save('latest', loader=train_loader)
        else:
            model.save('latest')
        logger.info('End of training.')

    except KeyboardInterrupt:
        # catch a KeyboardInterrupt and save the model and state to resume later
        if model.swa:
            model.save(current_step, True, loader=train_loader)
        else:
            model.save(current_step, True)
        model.save_training_state(epoch + (n >= len(train_loader)), current_step, True)
        logger.info('Training interrupted. Latest models and training states saved.')
Ejemplo n.º 9
0
from models import create_model
from util.visualizer import Visualizer
from util.visualizer import save_segment_result
from util.metrics import RunningScore
from util import util
import time
import os
import numpy as np
import torch.nn as nn

if __name__ == '__main__':
    # Load the training options
    opt_train = TrainOptions().parse()

    # Load the training dataset and report its size
    dataset_train = create_dataset(opt_train)
    dataset_train_size = len(dataset_train)
    print('The number of training images = %d' % dataset_train_size)

    # Create the training model
    model_train = create_model(opt_train)
    model_train.train()
    # continue_train is forced on before setup() is called
    # NOTE(review): presumably this makes setup() load existing weights -- confirm.
    opt_train.continue_train = True
    model_train.setup(opt_train)

    # Set up the class that displays the training results
    visualizer = Visualizer(opt_train)
    # Epochs run from epoch_count through niter + niter_decay (inclusive)
    for epoch in range(opt_train.epoch_count,
                       opt_train.niter + opt_train.niter_decay + 1):
        epoch_iters = 0
        epoch_start_time = time.time()
from mpl_toolkits.axes_grid1 import make_axes_locatable
from utils.colors import rgb2ycbcr
import torch.nn.functional as F

input_config = "../../options/train_imgset_pixgan_srg4_fdpl.yml"
output_file = "fdpr_diff_means.pt"
device = 'cuda'
patch_size = 128

if __name__ == '__main__':
    opt = option.parse(input_config, is_train=True)
    opt['dist'] = False

    # Create a dataset to load from (this dataset loads HR/LR images and performs any distortions specified by the YML).
    dataset_opt = opt['datasets']['train']
    train_set = create_dataset(dataset_opt)
    train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size']))
    total_iters = int(opt['train']['niter'])
    total_epochs = int(math.ceil(total_iters / train_size))
    train_loader = create_dataloader(train_set, dataset_opt, opt, None)
    print('Number of train images: {:,d}, iters: {:,d}'.format(
        len(train_set), train_size))

    # calculate the perceptual weights
    master_diff = np.zeros((patch_size, patch_size))
    num_patches = 0
    all_diff_patches = []
    tq = tqdm(train_loader)
    sampled = 0
    for train_data in tq:
        if sampled > 200:
Ejemplo n.º 11
0
def train(cfg, writer, logger):
    """Train on source and target batches with periodic logging and validation.

    Args:
        cfg: nested configuration mapping; the keys read here are
            ``model.default_gpu``, ``valset``, ``data.{source,target}`` and
            ``training.*``.
        writer: scalar writer exposing ``add_scalar(tag, value, step)``.
        logger: logger exposing ``info``.

    The loop iterates over the target train loader, drawing one source batch
    per target batch, and stops once ``model.iter`` exceeds
    ``cfg['training']['train_iters']``.
    """
    init_random()

    if torch.cuda.is_available():
        device = torch.device("cuda:{}".format(cfg['model']['default_gpu']))
    else:
        device = torch.device('cpu')

    def _announce(log_text, console_text=None):
        # Send a message to the logger and echo it (or a variant) on stdout.
        logger.info(log_text)
        print(log_text if console_text is None else console_text)

    # create dataSet
    # provides source_train / target_train / source_valid / target_valid loaders
    data_sets = create_dataset(cfg, writer, logger)
    validate_on_source = cfg.get('valset') == 'gta5'
    val_loader = (data_sets.source_valid_loader if validate_on_source
                  else data_sets.target_valid_loader)
    _announce('source train batchsize is {}'.format(
        data_sets.source_train_loader.args.get('batch_size')))
    _announce('target train batchsize is {}'.format(
        data_sets.target_train_loader.batch_size))
    _announce('valset is {}'.format(cfg.get('valset')),
              'val_set is {}'.format(cfg.get('valset')))
    _announce('val batch_size is {}'.format(val_loader.batch_size))

    # create model
    model = CustomModel(cfg, writer, logger)

    # LOSS function
    loss_fn = get_loss_function(cfg)

    # load category anchors
    anchors = torch.load('category_anchors')
    model.objective_vectors = anchors['objective_vectors']
    model.objective_vectors_num = anchors['objective_num']

    # Setup Metrics
    running_metrics_val = RunningScore(cfg['data']['target']['n_class'])
    source_running_metrics_val = RunningScore(cfg['data']['source']['n_class'])
    val_loss_meter = AverageMeter()
    source_val_loss_meter = AverageMeter()
    time_meter = AverageMeter()

    # begin training
    training_cfg = cfg['training']
    model.iter = 0
    epochs = training_cfg['epochs']
    for epoch in tqdm(range(epochs)):
        if model.iter > training_cfg['train_iters']:
            break

        for (target_image, target_label, target_img_name) in tqdm(data_sets.target_train_loader):
            batch_started = time.time()
            model.iter += 1
            if model.iter > training_cfg['train_iters']:
                break

            ############################
            # train on source & target #
            ############################
            # get data
            images, labels, source_img_name = data_sets.source_train_loader.next()
            images = images.to(device)
            labels = labels.to(device)
            target_image = target_image.to(device)
            target_label = target_label.to(device)

            # init model
            model.train(logger=logger)
            if training_cfg.get('freeze_bn'):
                model.freeze_bn_apply()
            model.optimizer_zero_grad()

            # train for one batch
            loss, loss_cls_L2, loss_pseudo = model.step(images, labels, target_image, target_label)
            model.scheduler_step()

            if loss_cls_L2 > 10:
                logger.info('loss_cls_l2 abnormal!!')

            # print
            time_meter.update(time.time() - batch_started)
            shown_iter = model.iter + 1
            if shown_iter % training_cfg['print_interval'] == 0:
                unchanged_cls_num = 0
                print_str = (
                    "Epoches [{:d}/{:d}] Iter [{:d}/{:d}]  Loss: {:.4f} "
                    "Loss_cls_L2: {:.4f}  Loss_pseudo: {:.4f}  Time/Image: {:.4f} "
                ).format(epoch + 1, epochs, shown_iter, training_cfg['train_iters'],
                         loss.item(), loss_cls_L2, loss_pseudo,
                         time_meter.avg / cfg['data']['source']['batch_size'])

                print(print_str)
                logger.info(print_str)
                logger.info('unchanged number of objective class vector: {}'.format(unchanged_cls_num))
                scalars = (
                    ('loss/train_loss', loss.item()),
                    ('loss/train_cls_L2Loss', loss_cls_L2),
                    ('loss/train_pseudoLoss', loss_pseudo),
                )
                for tag, value in scalars:
                    writer.add_scalar(tag, value, shown_iter)
                time_meter.reset()

                # cluster-based statistics
                cluster_scores, _ = model.metrics.running_metrics_val_clusters.get_scores()
                logger.info('clus_IoU: {}'.format(cluster_scores["Mean IoU : \t"]))
                logger.info('clus_Recall: {}'.format(model.metrics.calc_mean_Clu_recall()))
                clu_table = model.metrics.classes_recall_clu
                logger.info('clus_Acc: {}'.format(
                    np.mean(clu_table[:, 0] / clu_table[:, 2])))

                # threshold-based statistics
                threshold_scores, _ = model.metrics.running_metrics_val_threshold.get_scores()
                logger.info('thr_IoU: {}'.format(threshold_scores["Mean IoU : \t"]))
                logger.info('thr_Recall: {}'.format(model.metrics.calc_mean_Thr_recall()))
                thr_table = model.metrics.classes_recall_thr
                logger.info('thr_Acc: {}'.format(
                    np.mean(thr_table[:, 0] / thr_table[:, 2])))

            # evaluation
            next_iter = model.iter + 1
            reached_interval = next_iter % training_cfg['val_interval'] == 0
            if reached_interval or next_iter == training_cfg['train_iters']:
                validation(model, logger, writer, data_sets, device, running_metrics_val, val_loss_meter, loss_fn,
                           source_val_loss_meter, source_running_metrics_val, iters=model.iter)

                torch.cuda.empty_cache()
                logger.info('Best iou until now is {}'.format(model.best_iou))

            # monitoring the accuracy and recall of CAG-based PLA and probability-based PLA
            monitor(model)

            model.metrics.reset()
Ejemplo n.º 12
0
    # Command-line options: HR/LR dataset roots, number of clusters, a train
    # flag and the path of the cluster model.
    parser = argparse.ArgumentParser()
    parser.add_argument('--HR_Root', type = str, default = "/GPUFS/nsccgz_yfdu_16/ouyry/SISRC/FaceSR-ESRGAN/dataset/FFHQ/HR", 
                        help = 'Path to val HR.')
    parser.add_argument('--LR_Root', type = str, default = "/GPUFS/nsccgz_yfdu_16/ouyry/SISRC/FaceSR-ESRGAN/dataset/FFHQ/LR", 
                        help = 'Path to val LR.')
    parser.add_argument('--Clusters', type = int, default = 3, help = 'Number of clusters')
    parser.add_argument('--Train', type = int, default = 0, help = 'Train or not')
    parser.add_argument('--Model_Path', type = str, default = "/GPUFS/nsccgz_yfdu_16/ouyry/SISRC/FaceSR-ESRGAN/dataset/FFHQ/cluster.model", 
                        help = 'Path to Cluster model')
    args = parser.parse_args()

    Root = '/GPUFS/nsccgz_yfdu_16/ouyry/SISRC/FaceSR-ESRGAN/dataset/FFHQ'
    # The dataset roots from the command line override the option values.
    opt['dataset']['dataroot_LR'] = args.LR_Root
    opt['dataset']['dataroot_HR'] = args.HR_Root

    test_set = create_dataset(opt['dataset'])
    test_loader = create_dataloader(test_set, opt['dataset'])

    device = torch.device('cuda' if opt['gpu_ids'] is not None else 'cpu')
    sphere = networks.define_F(opt).to(device)

    # One output directory per cluster, named "<root><index>", for HR and LR.
    # exist_ok=True tolerates directories left by an earlier run. Unlike the
    # previous try/except-pass, an existing HR directory no longer prevents
    # the LR directory from being created, and real failures (e.g. missing
    # permissions) are no longer silently ignored.
    for i in range(args.Clusters):
        os.makedirs(args.HR_Root + str(i), exist_ok=True)
        os.makedirs(args.LR_Root + str(i), exist_ok=True)

    vectors = None
    LR_paths = []
    HR_paths = []
Ejemplo n.º 13
0
def main():

    ############################################
    #
    #           set options
    #
    ############################################

    parser = argparse.ArgumentParser()
    parser.add_argument('--opt', type=str, help='Path to option YAML file.')
    parser.add_argument('--launcher',
                        choices=['none', 'pytorch'],
                        default='none',
                        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    opt = option.parse(args.opt, is_train=True)

    ############################################
    #
    #           distributed training settings
    #
    ############################################

    if args.launcher == 'none':  # disabled distributed training
        opt['dist'] = False
        rank = -1
        print('Disabled distributed training.')
    else:
        opt['dist'] = True
        init_dist()
        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()

        print("Rank:", rank)
        print("------------------DIST-------------------------")

    ############################################
    #
    #           loading resume state if exists
    #
    ############################################

    if opt['path'].get('resume_state', None):
        # distributed resuming: all load into default GPU
        device_id = torch.cuda.current_device()
        resume_state = torch.load(
            opt['path']['resume_state'],
            map_location=lambda storage, loc: storage.cuda(device_id))
        option.check_resume(opt, resume_state['iter'])  # check resume options
    else:
        resume_state = None

    ############################################
    #
    #           mkdir and loggers
    #
    ############################################

    if rank <= 0:  # normal training (rank -1) OR distributed training (rank 0)
        if resume_state is None:
            util.mkdir_and_rename(
                opt['path']
                ['experiments_root'])  # rename experiment folder if exists

            util.mkdirs(
                (path for key, path in opt['path'].items()
                 if not key == 'experiments_root'
                 and 'pretrain_model' not in key and 'resume' not in key))

        # config loggers. Before it, the log will not work
        util.setup_logger('base',
                          opt['path']['log'],
                          'train_' + opt['name'],
                          level=logging.INFO,
                          screen=True,
                          tofile=True)

        util.setup_logger('base_val',
                          opt['path']['log'],
                          'val_' + opt['name'],
                          level=logging.INFO,
                          screen=True,
                          tofile=True)

        logger = logging.getLogger('base')
        logger_val = logging.getLogger('base_val')

        logger.info(option.dict2str(opt))
        # tensorboard logger
        if opt['use_tb_logger'] and 'debug' not in opt['name']:
            version = float(torch.__version__[0:3])
            if version >= 1.1:  # PyTorch 1.1
                from torch.utils.tensorboard import SummaryWriter
            else:
                logger.info(
                    'You are using PyTorch {}. Tensorboard will use [tensorboardX]'
                    .format(version))
                from tensorboardX import SummaryWriter
            tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name'])
    else:
        # config loggers. Before it, the log will not work
        util.setup_logger('base',
                          opt['path']['log'],
                          'train_',
                          level=logging.INFO,
                          screen=True)

        print("set train log")

        util.setup_logger('base_val',
                          opt['path']['log'],
                          'val_',
                          level=logging.INFO,
                          screen=True)

        print("set val log")

        logger = logging.getLogger('base')

        logger_val = logging.getLogger('base_val')

    # convert to NoneDict, which returns None for missing keys
    opt = option.dict_to_nonedict(opt)

    #### random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    if rank <= 0:
        logger.info('Random seed: {}'.format(seed))
    util.set_random_seed(seed)

    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    ############################################
    #
    #           create train and val dataloader
    #
    ############################################
    ####

    # dataset_ratio = 200  # enlarge the size of each epoch, todo: what it is
    dataset_ratio = 1  # enlarge the size of each epoch, todo: what it is
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            train_size = int(
                math.ceil(len(train_set) / dataset_opt['batch_size']))
            # total_iters = int(opt['train']['niter'])
            # total_epochs = int(math.ceil(total_iters / train_size))

            total_iters = train_size
            total_epochs = int(opt['train']['epoch'])

            if opt['dist']:
                train_sampler = DistIterSampler(train_set, world_size, rank,
                                                dataset_ratio)
                # total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio)))
                total_epochs = int(opt['train']['epoch'])
                if opt['train']['enable'] == False:
                    total_epochs = 1
            else:
                train_sampler = None
            train_loader = create_dataloader(train_set, dataset_opt, opt,
                                             train_sampler)
            if rank <= 0:
                logger.info(
                    'Number of train images: {:,d}, iters: {:,d}'.format(
                        len(train_set), train_size))
                logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                    total_epochs, total_iters))
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt, opt, None)
            if rank <= 0:
                logger.info('Number of val images in [{:s}]: {:d}'.format(
                    dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError(
                'Phase [{:s}] is not recognized.'.format(phase))

    assert train_loader is not None

    ############################################
    #
    #          create model
    #
    ############################################
    ####

    model = create_model(opt)

    #### resume training
    if resume_state:
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))

        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        model.resume_training(resume_state)  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0
        print("Not Resume Training")

    ############################################
    #
    #          training
    #
    ############################################
    ####

    ####
    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(
        start_epoch, current_step))
    Avg_train_loss = AverageMeter()  # total
    Avg_train_psnr = AverageMeter()
    if opt['datasets']['train']['color'] == 'YUV':
        Avg_train_yuv_psnr = AverageMeter()
    if (opt['train']['pixel_criterion'] == 'cb+ssim'):
        Avg_train_loss_pix = AverageMeter()
        Avg_train_loss_ssim = AverageMeter()

    saved_total_loss = 10e10
    saved_total_PSNR = -1

    for epoch in range(start_epoch, total_epochs):

        ############################################
        #
        #          Start a new epoch
        #
        ############################################

        # Turn into training mode
        #model = model.train()

        # reset total loss
        Avg_train_loss.reset()
        # reset psnr
        Avg_train_psnr.reset()
        if opt['datasets']['train']['color'] == 'YUV':
            Avg_train_yuv_psnr.reset()

        current_step = 0

        if (opt['train']['pixel_criterion'] == 'cb+ssim'):
            Avg_train_loss_pix.reset()
            Avg_train_loss_ssim.reset()

        if opt['dist']:
            train_sampler.set_epoch(epoch)

        for train_idx, train_data in enumerate(train_loader):

            if 'debug' in opt['name']:

                img_dir = os.path.join(opt['path']['train_images'])
                util.mkdir(img_dir)

                LQ = train_data['LQs']
                GT = train_data['GT']

                GT_img = util.tensor2img(GT)  # uint8  NCHW

                print('GT_img', GT_img.shape)
                print('LQ', LQ.shape)

                if opt['datasets']['train']['color'] == 'YUV':
                    GT_img = data_util.ycbcr2rgb(GT_img)

                save_img_path = os.path.join(
                    img_dir, '{:4d}_{:s}.png'.format(train_idx, 'debug_GT'))

                if opt['datasets']['train']['color'] == 'YUV':
                    util.save_img(GT_img, save_img_path, mode='RGB')
                else:
                    util.save_img(GT_img, save_img_path)

                for i in range(5):
                    LQ_img = util.tensor2img(LQ[:, i, ...])  # uint8
                    if opt['datasets']['train']['color'] == 'YUV':
                        LQ_img = data_util.ycbcr2rgb(LQ_img)
                    save_img_path = os.path.join(
                        img_dir,
                        '{:4d}_{:s}_{:1d}.png'.format(train_idx, 'debug_LQ',
                                                      i))
                    if opt['datasets']['train']['color'] == 'YUV':
                        util.save_img(LQ_img, save_img_path, mode='RGB')
                    else:
                        util.save_img(LQ_img, save_img_path)

                if (train_idx >= 10):
                    break

            if opt['train']['enable'] == False:
                message_train_loss = 'None'
                break

            current_step += 1
            if current_step > total_iters:
                print("Total Iteration Reached !")
                break
            #### update learning rate
            if opt['train']['lr_scheme'] == 'ReduceLROnPlateau':
                pass
            else:
                model.update_learning_rate(
                    current_step, warmup_iter=opt['train']['warmup_iter'])

            #### training
            model.feed_data(train_data)

            # if opt['train']['lr_scheme'] == 'ReduceLROnPlateau':
            #    model.optimize_parameters_without_schudlue(current_step)
            # else:
            model.optimize_parameters(current_step)

            visuals = model.get_current_visuals(need_GT=True, save=False)

            rlt_img = util.tensor2img(visuals['rlt'])  # uint8
            gt_img = util.tensor2img(visuals['GT'])  # uint8

            if opt['datasets']['train']['color'] == 'YUV':
                yuv_psnr = util.calculate_psnr(rlt_img, gt_img)
                rlt_img = data_util.ycbcr2rgb(rlt_img)
                gt_img = data_util.ycbcr2rgb(gt_img)

            # calculate PSNR
            psnr = util.calculate_psnr(rlt_img, gt_img)

            if opt['train']['pixel_criterion'] == 'cb+ssim':
                Avg_train_loss.update(model.log_dict['total_loss'], 1)
                Avg_train_loss_pix.update(model.log_dict['l_pix'], 1)
                Avg_train_loss_ssim.update(model.log_dict['ssim_loss'], 1)
                Avg_train_psnr.update(psnr, 1)
                if opt['datasets']['train']['color'] == 'YUV':
                    Avg_train_yuv_psnr.update(yuv_psnr, 1)
            else:
                Avg_train_loss.update(model.log_dict['l_pix'], 1)
                Avg_train_psnr.update(psnr, 1)
                if opt['datasets']['train']['color'] == 'YUV':
                    Avg_train_yuv_psnr.update(yuv_psnr, 1)

            # add total train loss
            if (opt['train']['pixel_criterion'] == 'cb+ssim'):
                message_train_loss = ' pix_avg_loss: {:.4e}'.format(
                    Avg_train_loss_pix.avg)
                message_train_loss += ' ssim_avg_loss: {:.4e}'.format(
                    Avg_train_loss_ssim.avg)
                message_train_loss += ' total_avg_loss: {:.4e}'.format(
                    Avg_train_loss.avg)
                message_train_loss += ' psnr_inst : {:.2f}'.format(psnr)
                message_train_loss += ' psnr_avg : {:.2f}'.format(
                    Avg_train_psnr.avg)
            else:
                message_train_loss = ' train_avg_loss: {:.4e}'.format(
                    Avg_train_loss.avg)
                if opt['datasets']['train']['color'] == 'YUV':
                    message_train_loss += ' yuv_psnr_inst : {:.2f}'.format(
                        yuv_psnr)
                message_train_loss += ' psnr_inst : {:.2f}'.format(psnr)
                if opt['datasets']['train']['color'] == 'YUV':
                    message_train_loss += ' yuv_psnr_avg : {:.2f}'.format(
                        Avg_train_yuv_psnr.avg)
                message_train_loss += ' psnr_avg : {:.2f}'.format(
                    Avg_train_psnr.avg)

            #### log
            if current_step % opt['logger']['print_freq'] == 0:
                logs = model.get_current_log()
                message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(
                    epoch, current_step)
                for v in model.get_current_learning_rate():
                    message += '{:.3e},'.format(v)
                message += ')] '
                for k, v in logs.items():
                    message += '{:s}: {:.4e} '.format(k, v)
                    # tensorboard logger
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        if rank <= 0:
                            tb_logger.add_scalar(k, v, current_step)

                # tensorboard logger - avg part
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        if rank <= 0:
                            tb_logger.add_scalar('train_avg_loss',
                                                 Avg_train_loss.avg,
                                                 current_step)
                            if opt['datasets']['train']['color'] == 'YUV':
                                tb_logger.add_scalar('yuv_psnr_avg',
                                                     Avg_train_yuv_psnr.avg,
                                                     current_step)
                            tb_logger.add_scalar('psnr_avg',
                                                 Avg_train_psnr.avg,
                                                 current_step)

                message += message_train_loss

                if rank <= 0:
                    logger.info(message)

        ############################################
        #
        #        end of one epoch, save epoch model
        #
        ############################################

        #### save models and training states

        if epoch == 1:
            save_filename = '{:04d}_{}.pth'.format(0, 'G')
            save_path = os.path.join(opt['path']['models'], save_filename)
            if os.path.exists(save_path):
                os.remove(save_path)

        save_filename = '{:04d}_{}.pth'.format(epoch - 1, 'G')
        save_path = os.path.join(opt['path']['models'], save_filename)
        if os.path.exists(save_path):
            os.remove(save_path)

        if rank <= 0:
            logger.info('Saving models and training states.')
            save_filename = '{:04d}'.format(epoch)
            model.save(save_filename)
            # model.save('latest')
            # model.save_training_state(epoch, current_step)

        ############################################
        #
        #          end of one epoch, do validation
        #
        ############################################

        #### validation
        #if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0:
        if opt['datasets'].get('val', None):
            if opt['model'] in [
                    'sr', 'srgan'
            ] and rank <= 0:  # image restoration validation
                # does not support multi-GPU validation
                pbar = util.ProgressBar(len(val_loader))
                avg_psnr = 0.
                idx = 0
                for val_data in val_loader:
                    idx += 1
                    img_name = os.path.splitext(
                        os.path.basename(val_data['LQ_path'][0]))[0]
                    img_dir = os.path.join(opt['path']['val_images'], img_name)
                    util.mkdir(img_dir)

                    model.feed_data(val_data)
                    model.test()
                    visuals = model.get_current_visuals()
                    sr_img = util.tensor2img(visuals['rlt'])  # uint8
                    gt_img = util.tensor2img(visuals['GT'])  # uint8

                    # Save SR images for reference
                    save_img_path = os.path.join(
                        img_dir,
                        '{:s}_{:d}.png'.format(img_name, current_step))
                    #util.save_img(sr_img, save_img_path)

                    # calculate PSNR
                    sr_img, gt_img = util.crop_border([sr_img, gt_img],
                                                      opt['scale'])
                    avg_psnr += util.calculate_psnr(sr_img, gt_img)
                    pbar.update('Test {}'.format(img_name))

                avg_psnr = avg_psnr / idx

                # log
                logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr))
                # tensorboard logger
                if opt['use_tb_logger'] and 'debug' not in opt['name']:
                    tb_logger.add_scalar('psnr', avg_psnr, current_step)
            else:  # video restoration validation
                if opt['dist']:
                    # todo : multi-GPU testing
                    psnr_rlt = {}  # with border and center frames
                    psnr_rlt_avg = {}
                    psnr_total_avg = 0.

                    ssim_rlt = {}  # with border and center frames
                    ssim_rlt_avg = {}
                    ssim_total_avg = 0.

                    val_loss_rlt = {}
                    val_loss_rlt_avg = {}
                    val_loss_total_avg = 0.

                    if rank == 0:
                        pbar = util.ProgressBar(len(val_set))

                    for idx in range(rank, len(val_set), world_size):

                        if 'debug' in opt['name']:
                            print('idx', idx)
                        #     if (idx >= 3):
                        #         break

                        val_data = val_set[idx]
                        val_data['LQs'].unsqueeze_(0)
                        val_data['GT'].unsqueeze_(0)
                        folder = val_data['folder']
                        idx_d, max_idx = val_data['idx'].split('/')
                        idx_d, max_idx = int(idx_d), int(max_idx)

                        if psnr_rlt.get(folder, None) is None:
                            psnr_rlt[folder] = torch.zeros(max_idx,
                                                           dtype=torch.float32,
                                                           device='cuda')

                        if ssim_rlt.get(folder, None) is None:
                            ssim_rlt[folder] = torch.zeros(max_idx,
                                                           dtype=torch.float32,
                                                           device='cuda')

                        if val_loss_rlt.get(folder, None) is None:
                            val_loss_rlt[folder] = torch.zeros(
                                max_idx, dtype=torch.float32, device='cuda')

                        # tmp = torch.zeros(max_idx, dtype=torch.float32, device='cuda')
                        model.feed_data(val_data)
                        # model.test()
                        # model.test_stitch()

                        if opt['stitch'] == True:
                            model.test_stitch()
                        else:
                            model.test()  # large GPU memory

                        # visuals = model.get_current_visuals()
                        visuals = model.get_current_visuals(
                            save=True,
                            name='{}_{}'.format(folder, idx),
                            save_path=opt['path']['val_images'])

                        rlt_img = util.tensor2img(visuals['rlt'])  # uint8
                        gt_img = util.tensor2img(visuals['GT'])  # uint8

                        if opt['datasets']['train']['color'] == 'YUV':
                            rlt_img = data_util.ycbcr2rgb(rlt_img)
                            gt_img = data_util.ycbcr2rgb(gt_img)

                        # calculate PSNR
                        psnr = util.calculate_psnr(rlt_img, gt_img)
                        psnr_rlt[folder][idx_d] = psnr

                        # calculate SSIM
                        # ssim = util.calculate_ssim(rlt_img, gt_img)
                        ssim = 0  # to do save time do not use it
                        ssim_rlt[folder][idx_d] = ssim

                        # calculate Val loss
                        val_loss = model.get_loss()
                        val_loss_rlt[folder][idx_d] = val_loss

                        logger.info(
                            '{}_{:02d} PSNR: {:.4f}, SSIM: {:.4f}'.format(
                                folder, idx, psnr, ssim))

                        if rank == 0:
                            for _ in range(world_size):
                                pbar.update('Test {} - {}/{}'.format(
                                    folder, idx_d, max_idx))

                    # # collect data
                    for _, v in psnr_rlt.items():
                        dist.reduce(v, 0)

                    for _, v in ssim_rlt.items():
                        dist.reduce(v, 0)

                    for _, v in val_loss_rlt.items():
                        dist.reduce(v, 0)

                    dist.barrier()

                    if rank == 0:
                        psnr_rlt_avg = {}
                        psnr_total_avg = 0.
                        for k, v in psnr_rlt.items():
                            psnr_rlt_avg[k] = torch.mean(v).cpu().item()
                            psnr_total_avg += psnr_rlt_avg[k]
                        psnr_total_avg /= len(psnr_rlt)
                        log_s = '# Validation # PSNR: {:.4e}:'.format(
                            psnr_total_avg)
                        for k, v in psnr_rlt_avg.items():
                            log_s += ' {}: {:.4e}'.format(k, v)
                        logger.info(log_s)

                        # ssim
                        ssim_rlt_avg = {}
                        ssim_total_avg = 0.
                        for k, v in ssim_rlt.items():
                            ssim_rlt_avg[k] = torch.mean(v).cpu().item()
                            ssim_total_avg += ssim_rlt_avg[k]
                        ssim_total_avg /= len(ssim_rlt)
                        log_s = '# Validation # SSIM: {:.4e}:'.format(
                            ssim_total_avg)
                        for k, v in ssim_rlt_avg.items():
                            log_s += ' {}: {:.4e}'.format(k, v)
                        logger.info(log_s)

                        # added
                        val_loss_rlt_avg = {}
                        val_loss_total_avg = 0.
                        for k, v in val_loss_rlt.items():
                            val_loss_rlt_avg[k] = torch.mean(v).cpu().item()
                            val_loss_total_avg += val_loss_rlt_avg[k]
                        val_loss_total_avg /= len(val_loss_rlt)
                        log_l = '# Validation # Loss: {:.4e}:'.format(
                            val_loss_total_avg)
                        for k, v in val_loss_rlt_avg.items():
                            log_l += ' {}: {:.4e}'.format(k, v)
                        logger.info(log_l)

                        message = ''
                        for v in model.get_current_learning_rate():
                            message += '{:.5e}'.format(v)

                        logger_val.info(
                            'Epoch {:02d}, LR {:s}, PSNR {:.4f}, SSIM {:.4f} Train {:s}, Val Total Loss {:.4e}'
                            .format(epoch, message, psnr_total_avg,
                                    ssim_total_avg, message_train_loss,
                                    val_loss_total_avg))

                        if opt['use_tb_logger'] and 'debug' not in opt['name']:
                            tb_logger.add_scalar('psnr_avg', psnr_total_avg,
                                                 current_step)
                            for k, v in psnr_rlt_avg.items():
                                tb_logger.add_scalar(k, v, current_step)
                            # add val loss
                            tb_logger.add_scalar('val_loss_avg',
                                                 val_loss_total_avg,
                                                 current_step)
                            for k, v in val_loss_rlt_avg.items():
                                tb_logger.add_scalar(k, v, current_step)

                else:  # Todo: our function One GPU
                    pbar = util.ProgressBar(len(val_loader))
                    psnr_rlt = {}  # with border and center frames
                    psnr_rlt_avg = {}
                    psnr_total_avg = 0.

                    ssim_rlt = {}  # with border and center frames
                    ssim_rlt_avg = {}
                    ssim_total_avg = 0.

                    val_loss_rlt = {}
                    val_loss_rlt_avg = {}
                    val_loss_total_avg = 0.

                    for val_inx, val_data in enumerate(val_loader):

                        # if 'debug' in opt['name']:
                        #     if (val_inx >= 5):
                        #         break

                        folder = val_data['folder'][0]
                        # idx_d = val_data['idx'].item()
                        idx_d = val_data['idx']
                        # border = val_data['border'].item()
                        if psnr_rlt.get(folder, None) is None:
                            psnr_rlt[folder] = []

                        if ssim_rlt.get(folder, None) is None:
                            ssim_rlt[folder] = []

                        if val_loss_rlt.get(folder, None) is None:
                            val_loss_rlt[folder] = []

                        # process the black blank [B N C H W]

                        # print(val_data['LQs'].size())

                        # H_S = val_data['LQs'].size(3)  # 540
                        # W_S = val_data['LQs'].size(4)  # 960

                        # print(H_S)
                        # print(W_S)

                        # blank_1_S = 0
                        # blank_2_S = 0

                        # print(val_data['LQs'][0,2,0,:,:].size())

                        # for i in range(H_S):
                        #     if not sum(val_data['LQs'][0,2,0,i,:]) == 0:
                        #         blank_1_S = i - 1
                        #         # assert not sum(data_S[:, :, 0][i+1]) == 0
                        #         break

                        # for i in range(H_S):
                        #     if not sum(val_data['LQs'][0,2,0,:,H_S - i - 1]) == 0:
                        #         blank_2_S = (H_S - 1) - i - 1
                        #         # assert not sum(data_S[:, :, 0][blank_2_S-1]) == 0
                        #         break
                        # print('LQ :', blank_1_S, blank_2_S)

                        # if blank_1_S == -1:
                        #     print('LQ has no blank')
                        #     blank_1_S = 0
                        #     blank_2_S = H_S

                        # val_data['LQs'] = val_data['LQs'][:,:,:,blank_1_S:blank_2_S,:]

                        # print("LQ",val_data['LQs'].size())

                        # end of process the black blank

                        model.feed_data(val_data)

                        if opt['stitch'] == True:
                            model.test_stitch()
                        else:
                            model.test()  # large GPU memory

                        # process blank

                        # blank_1_L = blank_1_S << 2
                        # blank_2_L = blank_2_S << 2
                        # print(blank_1_L, blank_2_L)

                        # print(model.fake_H.size())

                        # if not blank_1_S == 0:
                        #     # model.fake_H = model.fake_H[:,:,blank_1_L:blank_2_L,:]
                        #     model.fake_H[:, :,0:blank_1_L, :] = 0
                        #     model.fake_H[:, :,blank_2_L:H_S, :] = 0

                        # end of # process blank

                        visuals = model.get_current_visuals(
                            save=True,
                            name='{}_{:02d}'.format(folder, val_inx),
                            save_path=opt['path']['val_images'])

                        rlt_img = util.tensor2img(visuals['rlt'])  # uint8
                        gt_img = util.tensor2img(visuals['GT'])  # uint8

                        if opt['datasets']['train']['color'] == 'YUV':
                            rlt_img = data_util.ycbcr2rgb(rlt_img)
                            gt_img = data_util.ycbcr2rgb(gt_img)

                        # calculate PSNR
                        psnr = util.calculate_psnr(rlt_img, gt_img)
                        psnr_rlt[folder].append(psnr)

                        # calculate SSIM
                        # ssim = util.calculate_ssim(rlt_img, gt_img)
                        ssim = 0  # to do save time do not use it
                        ssim_rlt[folder].append(ssim)

                        # val loss
                        val_loss = model.get_loss()
                        val_loss_rlt[folder].append(val_loss.item())

                        logger.info(
                            '{}_{:02d} PSNR: {:.4f}, SSIM: {:.4f}'.format(
                                folder, val_inx, psnr, ssim))

                        pbar.update('Test {} - {}'.format(folder, idx_d))

                    # average PSNR
                    for k, v in psnr_rlt.items():
                        psnr_rlt_avg[k] = sum(v) / len(v)
                        psnr_total_avg += psnr_rlt_avg[k]
                    psnr_total_avg /= len(psnr_rlt)
                    log_s = '# Validation # PSNR: {:.4e}:'.format(
                        psnr_total_avg)
                    for k, v in psnr_rlt_avg.items():
                        log_s += ' {}: {:.4e}'.format(k, v)
                    logger.info(log_s)

                    # average SSIM
                    for k, v in ssim_rlt.items():
                        ssim_rlt_avg[k] = sum(v) / len(v)
                        ssim_total_avg += ssim_rlt_avg[k]
                    ssim_total_avg /= len(ssim_rlt)
                    log_s = '# Validation # SSIM: {:.4e}:'.format(
                        ssim_total_avg)
                    for k, v in ssim_rlt_avg.items():
                        log_s += ' {}: {:.4e}'.format(k, v)
                    logger.info(log_s)

                    # average Val LOSS
                    for k, v in val_loss_rlt.items():
                        val_loss_rlt_avg[k] = sum(v) / len(v)
                        val_loss_total_avg += val_loss_rlt_avg[k]
                    val_loss_total_avg /= len(val_loss_rlt)
                    log_l = '# Validation # Loss: {:.4e}:'.format(
                        val_loss_total_avg)
                    for k, v in val_loss_rlt_avg.items():
                        log_l += ' {}: {:.4e}'.format(k, v)
                    logger.info(log_l)

                    # toal validation log

                    message = ''
                    for v in model.get_current_learning_rate():
                        message += '{:.5e}'.format(v)

                    logger_val.info(
                        'Epoch {:02d}, LR {:s}, PSNR {:.4f}, SSIM {:.4f} Train {:s}, Val Total Loss {:.4e}'
                        .format(epoch, message, psnr_total_avg, ssim_total_avg,
                                message_train_loss, val_loss_total_avg))

                    # end add

                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        tb_logger.add_scalar('psnr_avg', psnr_total_avg,
                                             current_step)
                        for k, v in psnr_rlt_avg.items():
                            tb_logger.add_scalar(k, v, current_step)
                        # tb_logger.add_scalar('ssim_avg', ssim_total_avg, current_step)
                        # for k, v in ssim_rlt_avg.items():
                        #     tb_logger.add_scalar(k, v, current_step)
                        # add val loss
                        tb_logger.add_scalar('val_loss_avg',
                                             val_loss_total_avg, current_step)
                        for k, v in val_loss_rlt_avg.items():
                            tb_logger.add_scalar(k, v, current_step)

            ############################################
            #
            #          end of validation, save model
            #
            ############################################
            #
            if rank <= 0:
                logger.info(
                    "Finished an epoch, Check and Save the model weights")
                # we check the validation loss instead of training loss. OK~
                if saved_total_loss >= val_loss_total_avg:
                    saved_total_loss = val_loss_total_avg
                    #torch.save(model.state_dict(), args.save_path + "/best" + ".pth")
                    model.save('best')
                    logger.info(
                        "Best Weights updated for decreased validation loss")

                else:
                    logger.info(
                        "Weights Not updated for undecreased validation loss")

                if saved_total_PSNR <= psnr_total_avg:
                    saved_total_PSNR = psnr_total_avg
                    model.save('bestPSNR')
                    logger.info(
                        "Best Weights updated for increased validation PSNR")

                else:
                    logger.info(
                        "Weights Not updated for unincreased validation PSNR")

        ############################################
        #
        #          end of one epoch, schedule LR
        #
        ############################################

        # add scheduler  todo
        if opt['train']['lr_scheme'] == 'ReduceLROnPlateau':
            for scheduler in model.schedulers:
                # scheduler.step(val_loss_total_avg)
                scheduler.step(val_loss_total_avg)

    if rank <= 0:
        logger.info('Saving the final model.')
        model.save('last')
        logger.info('End of training.')
        tb_logger.close()
Ejemplo n.º 14
0
def main():
    """Evaluate a pretrained model on a test set and print PSNR/SSIM.

    Parses the command-line arguments, builds the test dataset/loader and
    the network, loads the pretrained weights, optionally measures the
    average CUDA time, then prints per-image and average PSNR/SSIM.
    """
    # file and stream logger
    log_path = 'log/logger_info.log'
    lg = logger('Base', log_path)
    pn = 40
    print('\n','-'*pn, 'General INFO', '-'*pn)

    # setting arguments
    parser = argparse.ArgumentParser(description='Test arguments')
    parser.add_argument('--opt', type=str, required=True, help='path to test yaml file')
    parser.add_argument('--dataset_name', type=str, default=None)
    parser.add_argument('--scale', type=int, required=True)
    parser.add_argument('--which_model', type=str, required=True, help='which pretrained model')
    parser.add_argument('--pretrained', type=str, required=True, help='pretrain path')

    args = parser.parse_args()
    # test_parse turns the namespace into the nested option dict used below
    args = test_parse(args, lg)

    # create test dataloader
    test_dataset = create_dataset(args['datasets']['test'])
    test_loader = create_loader(test_dataset, args['datasets']['test'])
    lg.info('\nHR root: [{}]\nLR root: [{}]'.format(args['datasets']['test']['dataroot_HR'], args['datasets']['test']['dataroot_LR']))
    lg.info('Number of test images: [{}]'.format(len(test_dataset)))

    # create model
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = create_model(args['networks']).to(device)
    lg.info('Create model: [{}]'.format(args['networks']['which_model']))
    scale = args['scale']
    # map_location keeps a GPU-saved checkpoint loadable on the CPU fallback
    state_dict = torch.load(args['networks']['pretrained'], map_location=device)
    lg.info('Load pretrained from: [{}]'.format(args['networks']['pretrained']))
    model.load_state_dict(state_dict)
    # inference only: freeze dropout / batch-norm statistics
    model.eval()

    # calculate cuda time
    if args['calc_cuda_time']:
        lg.info('Start calculating cuda time...')
        avg_test_time = calc_cuda_time(test_loader, model)
        lg.info('Average cuda time: [{:.5f}]'.format(avg_test_time))

    # Test
    print('\n', '-'*pn, 'Testing {}'.format(args['dataset_name']), '-'*pn)
    psnr_list = []
    ssim_list = []
    num_batches = len(test_loader)

    # no autograd graph is needed for evaluation; this also saves memory
    with torch.no_grad():
        for idx, data in enumerate(test_loader, start=1):
            lr = data['LR'].to(device)
            hr = data['HR']

            # calculate evaluation metrics
            sr = model(lr)
            psnr, ssim = calc_metrics(tensor2np(sr), tensor2np(hr), crop_border=scale, test_Y=True)
            psnr_list.append(psnr)
            ssim_list.append(ssim)

            print('[{:03d}/{:03d}] || PSNR/SSIM: {:.2f}/{:.4f} || {}'.format(idx, num_batches, psnr, ssim, data['filename']))

    print('\n','-'*pn, 'Summary', '-'*pn)
    if psnr_list:
        avg_psnr = sum(psnr_list) / len(psnr_list)
        avg_ssim = sum(ssim_list) / len(ssim_list)
        print('Average PSNR: {:.2f}  Average SSIM: {:.4f}'.format(avg_psnr, avg_ssim))
    else:
        # an empty loader would otherwise raise ZeroDivisionError
        print('No test images were evaluated.')


    print('\n','-'*pn, 'Finish', '-'*pn)
Ejemplo n.º 15
0
def main():
    """Train an RCGAN model as described by a JSON options file.

    Reads the options, prepares the experiment directories, builds the
    train/val loaders, runs the epoch loop (with validation every
    ``solver.val_step`` epochs) and finally writes the collected losses to
    ``train_results.csv`` in ``solver.results_dir``.

    Raises:
        ValueError: if resuming without ``resume_path``, if no training
            dataset is configured, or if validation is due but no
            validation dataset is configured.
        NotImplementedError: if an unknown dataset phase is configured.
    """
    parser = argparse.ArgumentParser(
        description='Train Super Resolution Models')
    parser.add_argument('-opt',
                        type=str,
                        required=True,
                        help='Path to options JSON file.')
    opt = option.parse(parser.parse_args().opt)

    if opt['train']['resume'] is False:
        util.mkdir_and_rename(
            opt['path']['exp_root'])  # rename old experiments if exists
        # create every configured path except the root and pretrained weights
        skipped_keys = ('exp_root', 'pretrain_G', 'pretrain_D')
        util.mkdirs((path for key, path in opt['path'].items()
                     if key not in skipped_keys))
        option.save(opt)
        opt = option.dict_to_nonedict(
            opt)  # Convert to NoneDict, which return None for missing key.
    else:
        opt = option.dict_to_nonedict(opt)
        if opt['train']['resume_path'] is None:
            raise ValueError("The 'resume_path' does not declarate")

    NUM_EPOCH = int(opt['train']['num_epochs'])

    # random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    print("Random Seed: ", seed)
    random.seed(seed)
    torch.manual_seed(seed)

    # create train and val dataloader
    # Pre-bind to None so the "missing dataset" checks below actually work
    # instead of failing with NameError.
    train_set = train_loader = None
    val_set = val_loader = None
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            train_loader = create_dataloader(train_set, dataset_opt)
            print('Number of train images in [%s]: %d' %
                  (dataset_opt['name'], len(train_set)))
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt)
            print('Number of val images in [%s]: %d' %
                  (dataset_opt['name'], len(val_set)))
        elif phase == 'test':
            pass
        else:
            raise NotImplementedError("Phase [%s] is not recognized." % phase)

    if train_loader is None:
        raise ValueError("The training data does not exist")

    solver = RCGANModel(opt)

    solver.summary(train_set[0]['CAT'].size(), train_set[0]['IR'].size())
    solver.net_init()
    print('[Start Training]')

    start_epoch = 1
    if opt['train']['resume']:
        start_epoch = solver.load()

    num_train = len(train_set)
    for epoch in range(start_epoch, NUM_EPOCH + 1):
        # Initialization: per-epoch sums of batch losses weighted by batch size
        train_loss_g1 = 0.0
        train_loss_g2 = 0.0
        train_loss_d1 = 0.0
        train_loss_d2 = 0.0

        train_bar = tqdm(train_loader)
        # Train model
        for batch in train_bar:
            solver.feed_data(batch)
            loss_g_total, loss_d_total = solver.train_step()
            cur_batch_size = batch['CAT'].size(0)
            train_loss_g1 += loss_g_total[0] * cur_batch_size
            train_loss_g2 += loss_g_total[1] * cur_batch_size
            train_loss_d1 += loss_d_total[0] * cur_batch_size
            train_loss_d2 += loss_d_total[1] * cur_batch_size
            train_bar.set_description(
                desc='[%d/%d] G-Loss: %.4f D-Loss: %.4f' %
                (epoch, NUM_EPOCH, loss_g_total[0] + loss_g_total[1],
                 loss_d_total[0] + loss_d_total[1]))

        solver.results['train_G_loss1'].append(train_loss_g1 / num_train)
        solver.results['train_G_loss2'].append(train_loss_g2 / num_train)
        solver.results['train_D_loss1'].append(train_loss_d1 / num_train)
        solver.results['train_D_loss2'].append(train_loss_d2 / num_train)
        print('Train G-Loss: %.4f' %
              ((train_loss_g1 + train_loss_g2) / num_train))
        print('Train D-Loss: %.4f' %
              ((train_loss_d1 + train_loss_d2) / num_train))

        train_bar.close()

        if epoch % solver.val_step == 0 and epoch != 0:
            if val_loader is None:
                raise ValueError("The validation data does not exist")
            print('[Validating...]')
            val_loss = 0.0
            # vis_index numbers the saved preview grids starting from 1
            for vis_index, batch in enumerate(val_loader, start=1):
                solver.feed_data(batch)
                loss_total = solver.test()
                batch_size = batch['VIS'].size(0)
                vis_list = solver.get_current_visual_list()
                images = torch.stack(vis_list)
                saveimg = thutil.make_grid(images, nrow=3, padding=5)
                saveimg_nd = saveimg.byte().permute(1, 2, 0).numpy()
                misc.imsave(
                    os.path.join(solver.vis_dir,
                                 'epoch_%d_%d.png' % (epoch, vis_index)),
                    saveimg_nd)
                val_loss += loss_total * batch_size

            solver.results['val_G_loss'].append(val_loss / len(val_set))
            print('Valid Loss: %.4f' % (val_loss / len(val_set)))
            # statistics: a lower validation loss is considered better
            is_best = False
            if solver.best_prec > solver.results['val_G_loss'][-1]:
                solver.best_prec = solver.results['val_G_loss'][-1]
                is_best = True

            solver.save(epoch, is_best)

    # NOTE(review): pandas requires every column to have NUM_EPOCH entries
    # here; that presumably holds only when validation runs every epoch and
    # the result lists cover all epochs (e.g. restored on resume) - confirm.
    data_frame = pd.DataFrame(data={
        'train_G_loss1': solver.results['train_G_loss1'],
        'train_G_loss2': solver.results['train_G_loss2'],
        'train_D_loss1': solver.results['train_D_loss1'],
        'train_D_loss2': solver.results['train_D_loss2'],
        'val_G_loss': solver.results['val_G_loss']
    },
                              index=range(1, NUM_EPOCH + 1))

    data_frame.to_csv(os.path.join(solver.results_dir, 'train_results.csv'),
                      index_label='Epoch')
Ejemplo n.º 16
0
def main():
    """Train a super-resolution model as described by a JSON options file.

    Sets up experiment folders and loggers, builds the train/val loaders
    and the model, optionally resumes from a saved training state, then
    runs the iteration-based training loop with periodic logging,
    validation (PSNR on border-cropped images) and checkpointing.

    Raises:
        NotImplementedError: if an unknown dataset phase is configured.
        AssertionError: if no training dataset is configured.
    """
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt',
                        type=str,
                        required=True,
                        help='Path to option JSON file.')
    opt = option.parse(parser.parse_args().opt, is_train=True)
    opt = option.dict_to_nonedict(
        opt)  # Convert to NoneDict, which return None for missing key.

    # train from scratch OR resume training
    if opt['path']['resume_state']:  # resuming training
        resume_state = torch.load(opt['path']['resume_state'])
    else:  # training from scratch
        resume_state = None
        util.mkdir_and_rename(
            opt['path']['experiments_root'])  # rename old folder if exists
        util.mkdirs((path for key, path in opt['path'].items()
                     if not key == 'experiments_root'
                     and 'pretrain_model' not in key and 'resume' not in key))

    # config loggers. Before it, the log will not work
    util.setup_logger(None,
                      opt['path']['log'],
                      'train',
                      level=logging.INFO,
                      screen=True)
    util.setup_logger('val', opt['path']['log'], 'val', level=logging.INFO)
    logger = logging.getLogger('base')

    if resume_state:
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))
        option.check_resume(opt)  # check resume options

    logger.info(option.dict2str(opt))
    # tensorboard logger (disabled for debug runs)
    use_tb = bool(opt['use_tb_logger'] and 'debug' not in opt['name'])
    if use_tb:
        from tensorboardX import SummaryWriter
        tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name'])

    # random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    logger.info('Random seed: {}'.format(seed))
    util.set_random_seed(seed)

    # The flag is spelled "benchmark"; a misspelled attribute is silently
    # accepted by Python and leaves the cuDNN autotuner disabled.
    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    # create train and val dataloader
    # Pre-bind to None so a missing phase is detected by the checks below
    # instead of surfacing as a NameError.
    train_loader = None
    val_loader = None
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            train_size = int(
                math.ceil(len(train_set) / dataset_opt['batch_size']))
            logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
                len(train_set), train_size))
            total_iters = int(opt['train']['niter'])
            total_epochs = int(math.ceil(total_iters / train_size))
            logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                total_epochs, total_iters))
            train_loader = create_dataloader(train_set, dataset_opt)
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt)
            logger.info('Number of val images in [{:s}]: {:d}'.format(
                dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError(
                'Phase [{:s}] is not recognized.'.format(phase))
    assert train_loader is not None

    # create model
    model = create_model(opt)

    # resume training
    if resume_state:
        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        model.resume_training(resume_state)  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0

    # training
    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(
        start_epoch, current_step))
    for epoch in range(start_epoch, total_epochs):
        for train_data in train_loader:
            current_step += 1
            if current_step > total_iters:
                break
            # update learning rate
            # model.update_learning_rate()

            # training
            model.feed_data(train_data)
            model.optimize_parameters(current_step)

            # log
            if current_step % opt['logger']['print_freq'] == 0:
                logs = model.get_current_log()
                message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> '.format(
                    epoch, current_step, model.get_current_learning_rate())
                for k, v in logs.items():
                    message += '{:s}: {:.4e} '.format(k, v)
                    # tensorboard logger
                    if use_tb:
                        tb_logger.add_scalar(k, v, current_step)
                logger.info(message)

            # validation (skipped when no validation set is configured)
            if (val_loader is not None
                    and current_step % opt['train']['val_freq'] == 0):
                avg_psnr = 0.0
                idx = 0
                for val_data in val_loader:
                    idx += 1
                    img_name = os.path.splitext(
                        os.path.basename(val_data['LR_path'][0]))[0]
                    img_dir = os.path.join(opt['path']['val_images'], img_name)
                    util.mkdir(img_dir)

                    model.feed_data(val_data)
                    # model.feed_data2(val_data)
                    model.test()

                    visuals = model.get_current_visuals()
                    sr_img = util.tensor2img(visuals['SR'])  # uint8
                    gt_img = util.tensor2img(visuals['HR'])  # uint8

                    # Save SR images for reference
                    save_img_path = os.path.join(img_dir, '{:s}_{:d}.png'.format(\
                        img_name, current_step))
                    util.save_img(sr_img, save_img_path)

                    # calculate PSNR on images with `scale` border pixels
                    # removed; the /255 then *255 round trip converts the
                    # uint8 images to float before the metric is computed
                    crop_size = opt['scale']
                    gt_img = gt_img / 255.
                    sr_img = sr_img / 255.
                    cropped_sr_img = sr_img[crop_size:-crop_size,
                                            crop_size:-crop_size, :]
                    cropped_gt_img = gt_img[crop_size:-crop_size,
                                            crop_size:-crop_size, :]
                    avg_psnr += util.calculate_psnr(cropped_sr_img * 255,
                                                    cropped_gt_img * 255)

                avg_psnr = avg_psnr / idx

                # log
                logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr))
                logger_val = logging.getLogger('val')  # validation logger
                logger_val.info(
                    '<epoch:{:3d}, iter:{:8,d}> psnr: {:.4e}'.format(
                        epoch, current_step, avg_psnr))
                # tensorboard logger
                if use_tb:
                    tb_logger.add_scalar('psnr', avg_psnr, current_step)

            model.update_learning_rate()

            # save models and training states
            if current_step % opt['logger']['save_checkpoint_freq'] == 0:
                logger.info('Saving models and training states.')
                model.save(current_step)
                model.save_training_state(epoch, current_step)

    logger.info('Saving the final model.')
    model.save('latest')
    logger.info('End of training.')
    if use_tb:
        # flush pending events to disk
        tb_logger.close()
Ejemplo n.º 17
0
def main():
    """Run a trained super-resolution solver over every configured test set.

    For each dataset the forward time of every image is measured; when the
    dataset class name contains ``LRHR`` (ground truth available) PSNR/SSIM
    are computed as well. If ``opt['save_image']`` is set, the SR outputs
    are handed to a background saver and written below ``./results/SR/``.
    """
    args = option.add_args()
    opt = option.parse(args.opt,
                       nblocks=args.nblocks,
                       nlayers=args.nlayers,
                       iterations=args.iterations,
                       trained_model=args.trained_model,
                       lr_path=args.lr_path
                       )
    opt = option.dict_to_nonedict(opt)

    # initial configure
    scale = opt['scale']
    degrad = opt['degradation']
    network_opt = opt['networks']
    model_name = network_opt['which_model'].upper()
    if opt['self_ensemble']:
        model_name += 'plus'
    # inserted between the file stem and its extension for saved images
    name_suffix = '_x{}_{}'.format(scale, model_name)

    # create test dataloader
    bm_names = []
    test_loaders = []
    for _, dataset_opt in sorted(opt['datasets'].items()):
        test_set = create_dataset(dataset_opt)
        test_loader = create_dataloader(test_set, dataset_opt)
        test_loaders.append(test_loader)
        print('===> Test Dataset: [%s]   Number of images: [%d]' % (test_set.name(), len(test_set)))
        bm_names.append(test_set.name())

    # create solver (and load model)
    solver = create_solver(opt)
    # Test phase
    print('===> Start Test')
    print("==================================================")
    print("Method: %s || Scale: %d || Degradation: %s"%(model_name, scale, degrad))

    # whether save the SR image?
    para_save = None
    if opt['save_image']:
        para_save = Paralle_save_img()
        para_save.begin_background()

    try:
        for bm, test_loader in zip(bm_names, test_loaders):
            print("Test set : [%s]" % bm)

            total_psnr = []
            total_ssim = []
            total_time = []

            # ground truth exists only for LRHR-style dataset classes
            need_HR = 'LRHR' in test_loader.dataset.__class__.__name__

            if need_HR:
                save_img_path = os.path.join('./results/SR/' + degrad, model_name, bm, "x%d" % scale)
            else:
                save_img_path = os.path.join('./results/SR/' + bm, model_name, "x%d" % scale)

            if not os.path.exists(save_img_path):
                os.makedirs(save_img_path)

            num_batches = len(test_loader)
            for idx, batch in enumerate(test_loader, start=1):
                solver.feed_data(batch, need_HR=need_HR)

                # calculate forward time
                t0 = time.time()
                solver.test()
                t1 = time.time()
                elapsed = t1 - t0
                total_time.append(elapsed)

                visuals = solver.get_current_visual(need_HR=need_HR)
                lr_basename = os.path.basename(batch['LR_path'][0])

                # calculate PSNR/SSIM metrics on Python
                if need_HR:
                    psnr, ssim = util.calc_metrics(visuals['SR'], visuals['HR'], crop_border=scale)
                    total_psnr.append(psnr)
                    total_ssim.append(ssim)

                    print("[%d/%d] %s || PSNR(dB)/SSIM: %.2f/%.4f || Timer: %.4f sec ." % (
                        idx, num_batches, lr_basename, psnr, ssim, elapsed))
                else:
                    print("[%d/%d] %s || Timer: %.4f sec ." % (
                        idx, num_batches, lr_basename, elapsed))

                if para_save is not None:
                    # Tag only the final extension; replacing every '.'
                    # would corrupt names that contain several dots.
                    stem, ext = os.path.splitext(lr_basename)
                    name = stem + name_suffix + ext
                    para_save.put_image_path(filename=os.path.join(save_img_path, name), img=visuals['SR'])

            total_psnr, total_ssim = np.array(total_psnr), np.array(total_ssim)
            # an empty dataset must not abort the remaining benchmarks
            avg_time = sum(total_time) / len(total_time) if total_time else float('nan')
            if need_HR:
                print("---- Average PSNR(dB) /SSIM /Speed(s) for [%s] ----" % bm)
                print("PSNR: %.2f(+/-%.2f)      SSIM: %.4f      Speed: %.4f" % (total_psnr.mean(), total_psnr.std(),
                                                                       total_ssim.mean(),
                                                                       avg_time))
            else:
                print("---- Average Speed(s) for [%s] is %.4f sec ----" % (bm, avg_time))
    finally:
        # always stop the background saver, even if testing failed midway
        if para_save is not None:
            para_save.end_background()

    print("==================================================")
    print("===> Finished !")
Ejemplo n.º 18
0
def main():
    #### options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, help='Path to option YAML file.')
    parser.add_argument('--launcher',
                        choices=['none', 'pytorch'],
                        default='none',
                        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    opt = option.parse(args.opt, is_train=True)

    #### distributed training settings
    if args.launcher == 'none':  # disabled distributed training
        opt['dist'] = False
        rank = -1
        print('Disabled distributed training.')
    else:
        opt['dist'] = True
        init_dist()
        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()

    #### loading resume state if exists
    if opt['path'].get('resume_state', None):
        # distributed resuming: all load into default GPU
        device_id = torch.cuda.current_device()
        resume_state = torch.load(
            opt['path']['resume_state'],
            map_location=lambda storage, loc: storage.cuda(device_id))
        option.check_resume(opt, resume_state['iter'])  # check resume options
    else:
        resume_state = None

    #### mkdir and loggers
    if rank <= 0:  # normal training (rank -1) OR distributed training (rank 0)
        if resume_state is None:
            util.mkdir_and_rename(
                opt['path']
                ['experiments_root'])  # rename experiment folder if exists
            util.mkdirs(
                (path for key, path in opt['path'].items()
                 if not key == 'experiments_root'
                 and 'pretrain_model' not in key and 'resume' not in key))

        # config loggers. Before it, the log will not work
        util.setup_logger('base',
                          opt['path']['log'],
                          'train_' + opt['name'],
                          level=logging.INFO,
                          screen=True,
                          tofile=True)
        logger = logging.getLogger('base')
        logger.info(option.dict2str(opt))
        # tensorboard logger
        if opt['use_tb_logger'] and 'debug' not in opt['name']:
            version = float(torch.__version__[0:3])
            if version >= 1.1:  # PyTorch 1.1
                from torch.utils.tensorboard import SummaryWriter
            else:
                logger.info(
                    'You are using PyTorch {}. Tensorboard will use [tensorboardX]'
                    .format(version))
                from tensorboardX import SummaryWriter
            tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name'])
    else:
        util.setup_logger('base',
                          opt['path']['log'],
                          'train',
                          level=logging.INFO,
                          screen=True)
        logger = logging.getLogger('base')

    # convert to NoneDict, which returns None for missing keys
    opt = option.dict_to_nonedict(opt)

    #### random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    if rank <= 0:
        logger.info('Random seed: {}'.format(seed))
    util.set_random_seed(seed)

    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    #### create train and val dataloader
    dataset_ratio = 200  # enlarge the size of each epoch
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            train_size = int(
                math.ceil(len(train_set) / dataset_opt['batch_size']))
            total_iters = int(opt['train']['niter'])
            total_epochs = int(math.ceil(total_iters / train_size))
            if opt['dist']:
                train_sampler = DistIterSampler(train_set, world_size, rank,
                                                dataset_ratio)
                total_epochs = int(
                    math.ceil(total_iters / (train_size * dataset_ratio)))
            else:
                train_sampler = None
            train_loader = create_dataloader(train_set, dataset_opt, opt,
                                             train_sampler)
            if rank <= 0:
                logger.info(
                    'Number of train images: {:,d}, iters: {:,d}'.format(
                        len(train_set), train_size))
                logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                    total_epochs, total_iters))
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt, opt, None)
            if rank <= 0:
                logger.info('Number of val images in [{:s}]: {:d}'.format(
                    dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError(
                'Phase [{:s}] is not recognized.'.format(phase))
    assert train_loader is not None

    #### create model
    model = create_model(opt)

    #### resume training
    if resume_state:
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))

        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        model.resume_training(resume_state)  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0

    #### training
    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(
        start_epoch, current_step))
    for epoch in range(start_epoch, total_epochs + 1):
        if opt['dist']:
            train_sampler.set_epoch(epoch)
        for _, train_data in enumerate(train_loader):
            current_step += 1
            if current_step > total_iters:
                break

            #### training
            model.feed_data(train_data)
            model.optimize_parameters(current_step)

            #### log
            if current_step % opt['logger']['print_freq'] == 0:
                logs = model.get_current_log()
                message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(
                    epoch, current_step)
                for v in model.get_current_learning_rate():
                    message += '{:.3e},'.format(v)
                message += ')] '
                for k, v in logs.items():
                    message += '{:s}: {:.4e} '.format(k, v)
                    # tensorboard logger
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        if rank <= 0:
                            tb_logger.add_scalar(k, v, current_step)
                if rank <= 0:
                    logger.info(message)
            #### validation
            if opt['datasets'].get(
                    'val',
                    None) and current_step % opt['train']['val_freq'] == 0:
                if opt['model'] in [
                        'sr', 'srgan'
                ] and rank <= 0:  # image restoration validation
                    # does not support multi-GPU validation
                    pbar = util.ProgressBar(len(val_loader))
                    avg_psnr = 0.
                    idx = 0
                    for val_data in val_loader:
                        idx += 1
                        img_name = os.path.splitext(
                            os.path.basename(val_data['LQ_path'][0]))[0]
                        img_dir = os.path.join(opt['path']['val_images'],
                                               img_name)
                        util.mkdir(img_dir)

                        model.feed_data(val_data)
                        model.test()

                        visuals = model.get_current_visuals()
                        sr_img = util.tensor2img(visuals['rlt'])  # uint8
                        gt_img = util.tensor2img(visuals['GT'])  # uint8

                        # Save SR images for reference
                        save_img_path = os.path.join(
                            img_dir,
                            '{:s}_{:d}.png'.format(img_name, current_step))
                        util.save_img(sr_img, save_img_path)

                        # calculate PSNR
                        sr_img, gt_img = util.crop_border([sr_img, gt_img],
                                                          opt['scale'])
                        avg_psnr += util.calculate_psnr(sr_img, gt_img)
                        pbar.update('Test {}'.format(img_name))

                    avg_psnr = avg_psnr / idx

                    # log
                    logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr))
                    # tensorboard logger
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        tb_logger.add_scalar('psnr', avg_psnr, current_step)
                else:  # video restoration validation
                    if opt['dist']:
                        # multi-GPU testing
                        psnr_rlt = {}  # with border and center frames
                        if rank == 0:
                            pbar = util.ProgressBar(len(val_set))
                        for idx in range(rank, len(val_set), world_size):
                            val_data = val_set[idx]
                            val_data['LQs'].unsqueeze_(0)
                            val_data['GT'].unsqueeze_(0)
                            folder = val_data['folder']
                            idx_d, max_idx = val_data['idx'].split('/')
                            idx_d, max_idx = int(idx_d), int(max_idx)
                            if psnr_rlt.get(folder, None) is None:
                                psnr_rlt[folder] = torch.zeros(
                                    max_idx,
                                    dtype=torch.float32,
                                    device='cuda')
                            # tmp = torch.zeros(max_idx, dtype=torch.float32, device='cuda')
                            model.feed_data(val_data)
                            model.test()
                            visuals = model.get_current_visuals()
                            rlt_img = util.tensor2img(visuals['rlt'])  # uint8
                            gt_img = util.tensor2img(visuals['GT'])  # uint8
                            # calculate PSNR
                            psnr_rlt[folder][idx_d] = util.calculate_psnr(
                                rlt_img, gt_img)

                            if rank == 0:
                                for _ in range(world_size):
                                    pbar.update('Test {} - {}/{}'.format(
                                        folder, idx_d, max_idx))
                        # # collect data
                        for _, v in psnr_rlt.items():
                            dist.reduce(v, 0)
                        dist.barrier()

                        if rank == 0:
                            psnr_rlt_avg = {}
                            psnr_total_avg = 0.
                            for k, v in psnr_rlt.items():
                                psnr_rlt_avg[k] = torch.mean(v).cpu().item()
                                psnr_total_avg += psnr_rlt_avg[k]
                            psnr_total_avg /= len(psnr_rlt)
                            log_s = '# Validation # PSNR: {:.4e}:'.format(
                                psnr_total_avg)
                            for k, v in psnr_rlt_avg.items():
                                log_s += ' {}: {:.4e}'.format(k, v)
                            logger.info(log_s)
                            if opt['use_tb_logger'] and 'debug' not in opt[
                                    'name']:
                                tb_logger.add_scalar('psnr_avg',
                                                     psnr_total_avg,
                                                     current_step)
                                for k, v in psnr_rlt_avg.items():
                                    tb_logger.add_scalar(k, v, current_step)
                    else:
                        pbar = util.ProgressBar(len(val_loader))
                        psnr_rlt = {}  # with border and center frames
                        psnr_rlt_avg = {}
                        psnr_total_avg = 0.
                        for val_data in val_loader:
                            folder = val_data['folder'][0]
                            idx_d = val_data['idx'].item()
                            # border = val_data['border'].item()
                            if psnr_rlt.get(folder, None) is None:
                                psnr_rlt[folder] = []

                            model.feed_data(val_data)
                            model.test()
                            visuals = model.get_current_visuals()
                            rlt_img = util.tensor2img(visuals['rlt'])  # uint8
                            gt_img = util.tensor2img(visuals['GT'])  # uint8

                            # calculate PSNR
                            psnr = util.calculate_psnr(rlt_img, gt_img)
                            psnr_rlt[folder].append(psnr)
                            pbar.update('Test {} - {}'.format(folder, idx_d))
                        for k, v in psnr_rlt.items():
                            psnr_rlt_avg[k] = sum(v) / len(v)
                            psnr_total_avg += psnr_rlt_avg[k]
                        psnr_total_avg /= len(psnr_rlt)
                        log_s = '# Validation # PSNR: {:.4e}:'.format(
                            psnr_total_avg)
                        for k, v in psnr_rlt_avg.items():
                            log_s += ' {}: {:.4e}'.format(k, v)
                        logger.info(log_s)
                        if opt['use_tb_logger'] and 'debug' not in opt['name']:
                            tb_logger.add_scalar('psnr_avg', psnr_total_avg,
                                                 current_step)
                            for k, v in psnr_rlt_avg.items():
                                tb_logger.add_scalar(k, v, current_step)

            # update learning rate - https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
            model.update_learning_rate(current_step,
                                       warmup_iter=opt['train']['warmup_iter'])

            #### save models and training states
            if current_step % opt['logger']['save_checkpoint_freq'] == 0:
                if rank <= 0:
                    logger.info('Saving models and training states.')
                    model.save(current_step)
                    model.save_training_state(epoch, current_step)

    if rank <= 0:
        logger.info('Saving the final model.')
        model.save('latest')
        logger.info('End of training.')
        tb_logger.close()
Ejemplo n.º 19
0
def main():
    """Evaluate a trained model on every configured test set.

    Reads the options file named by the required ``-opt`` command-line
    argument, builds one dataloader per entry of ``opt['datasets']``, runs the
    model on every sample and writes each output as a PNG into
    ``<opt['path']['results_root']>/<test set name>/``.

    When a test set defines ``dataroot_HR``, PSNR and SSIM are logged for each
    image (and additionally on the Y channel for 3-channel images), followed
    by the per-dataset averages.
    """
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt',
                        type=str,
                        required=True,
                        help='Path to options JSON file.')
    opt = option.parse(parser.parse_args().opt, is_train=False)
    # every configured path is created except the pretrained-model entry
    util.mkdirs((path for key, path in opt['path'].items()
                 if not key == 'pretrain_model_G'))
    opt = option.dict_to_nonedict(opt)

    util.setup_logger(None,
                      opt['path']['log'],
                      'test.log',
                      level=logging.INFO,
                      screen=True)
    logger = logging.getLogger('base')
    logger.info(option.dict2str(opt))

    # Create test dataset and dataloader
    test_loaders = []
    for _phase, dataset_opt in sorted(opt['datasets'].items()):
        test_set = create_dataset(dataset_opt)
        test_loader = create_dataloader(test_set, dataset_opt)
        logger.info('Number of test images in [{:s}]: {:d}'.format(
            dataset_opt['name'], len(test_set)))
        test_loaders.append(test_loader)

    # Create model
    model = create_model(opt)

    for test_loader in test_loaders:
        test_set_name = test_loader.dataset.opt['name']
        logger.info('\nTesting [{:s}]...'.format(test_set_name))
        dataset_dir = os.path.join(opt['path']['results_root'], test_set_name)
        util.mkdir(dataset_dir)

        test_results = OrderedDict()
        test_results['psnr'] = []
        test_results['ssim'] = []
        test_results['psnr_y'] = []
        test_results['ssim_y'] = []

        # Whether ground truth is available is a property of the dataset, not
        # of a sample: decide it once per loader so it is always defined (and
        # never stale from a previous loader) when the averages are reported.
        need_HR = test_loader.dataset.opt['dataroot_HR'] is not None

        for data in test_loader:
            model.feed_data(data, need_HR=need_HR)
            img_path = data['LR_path'][0]
            img_name = os.path.splitext(os.path.basename(img_path))[0]

            model.test()  # test
            visuals = model.get_current_visuals(need_HR=need_HR)

            sr_img = util.tensor2img(visuals['SR'])  # uint8

            # save images
            suffix = opt['suffix']
            if suffix:
                save_img_path = os.path.join(dataset_dir,
                                             img_name + suffix + '.png')
            else:
                save_img_path = os.path.join(dataset_dir, img_name + '.png')
            util.save_img(sr_img, save_img_path)

            if not need_HR:
                # no ground truth: nothing to measure, just report the name
                logger.info(img_name)
                continue

            # calculate PSNR and SSIM
            gt_img = util.tensor2img(visuals['HR'])
            # images are divided by 255 here and multiplied back right before
            # the metric calls; bgr2ycbcr below receives the divided images
            gt_img = gt_img / 255.
            sr_img = sr_img / 255.

            # the border width removed on every side equals the dataset scale
            crop_border = test_loader.dataset.opt['scale']
            cropped_sr_img = sr_img[crop_border:-crop_border,
                                    crop_border:-crop_border, :]
            cropped_gt_img = gt_img[crop_border:-crop_border,
                                    crop_border:-crop_border, :]

            psnr = util.calculate_psnr(cropped_sr_img * 255,
                                       cropped_gt_img * 255)
            ssim = util.calculate_ssim(cropped_sr_img * 255,
                                       cropped_gt_img * 255)
            test_results['psnr'].append(psnr)
            test_results['ssim'].append(ssim)

            if gt_img.shape[2] == 3:  # RGB image
                sr_img_y = bgr2ycbcr(sr_img, only_y=True)
                gt_img_y = bgr2ycbcr(gt_img, only_y=True)
                cropped_sr_img_y = sr_img_y[crop_border:-crop_border,
                                            crop_border:-crop_border]
                cropped_gt_img_y = gt_img_y[crop_border:-crop_border,
                                            crop_border:-crop_border]
                psnr_y = util.calculate_psnr(cropped_sr_img_y * 255,
                                             cropped_gt_img_y * 255)
                ssim_y = util.calculate_ssim(cropped_sr_img_y * 255,
                                             cropped_gt_img_y * 255)
                test_results['psnr_y'].append(psnr_y)
                test_results['ssim_y'].append(ssim_y)
                logger.info(
                    '{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}; PSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}.'
                    .format(img_name, psnr, ssim, psnr_y, ssim_y))
            else:
                logger.info(
                    '{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}.'.format(
                        img_name, psnr, ssim))

        # metrics; skipped when the loader produced no samples so the averages
        # never divide by zero
        if need_HR and test_results['psnr']:
            # Average PSNR/SSIM results
            ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
            ave_ssim = sum(test_results['ssim']) / len(test_results['ssim'])
            logger.info(
                '----Average PSNR/SSIM results for {}----\n\tPSNR: {:.6f} dB; SSIM: {:.6f}\n'
                .format(test_set_name, ave_psnr, ave_ssim))
            if test_results['psnr_y'] and test_results['ssim_y']:
                ave_psnr_y = sum(test_results['psnr_y']) / len(
                    test_results['psnr_y'])
                ave_ssim_y = sum(test_results['ssim_y']) / len(
                    test_results['ssim_y'])
                logger.info(
                    '----Y channel, average PSNR/SSIM----\n\tPSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}\n'
                    .format(ave_psnr_y, ave_ssim_y))
Ejemplo n.º 20
0
def main():
    """
    Performs training, validation and testing.

    Sets up the (optionally distributed / mixed precision) model, optimizer
    and summary writer from the parsed arguments, restores a previous training
    state, then runs up to ``args.max_epochs`` epochs. Each epoch is followed
    by a validation pass; the state is saved by the master process whenever
    the validation loss improves, and training stops once it has failed to
    improve ``args.patience`` times in a row. A final pass over the test split
    is run after training.
    """
    args = setup_train_args()
    distributed = args.local_rank != -1
    master_process = args.local_rank in [0, -1]

    logger = create_logger(args)

    if distributed and args.cuda:
        # use distributed training if local rank is given
        # and GPU training is requested
        torch.cuda.set_device(args.local_rank)
        device = torch.device('cuda', args.local_rank)

        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')

    else:
        device = torch.device('cuda' if args.cuda else 'cpu')

    # creating dataset and storing dataset splits
    # as individual variables for convenience
    datasets, tokenizer = create_dataset(args=args,
                                         device=device,
                                         distributed=distributed)

    vocab_size = len(tokenizer)

    model = create_model(args=args, vocab_size=vocab_size, device=device)

    optimizer = create_optimizer(args=args, parameters=model.parameters())

    writer = create_summary_writer(args=args,
                                   model=model,
                                   logger=logger,
                                   device=device)

    # loading previous state of the training
    best_avg_loss, init_epoch, step = load_state(args=args,
                                                 model=model,
                                                 optimizer=optimizer,
                                                 logger=logger,
                                                 device=device)

    if args.mixed and args.cuda:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O2')

    if distributed:
        model = DistributedDataParallel(model,
                                        device_ids=[args.local_rank],
                                        output_device=args.local_rank)

    # each split is paired with its number of steps (batches) per pass
    # TODO get world size here instead of 1
    train, valid, test = [(split, ceil(size / args.batch_size / 1))
                          for split, size in datasets]

    # computing the sizes of the dataset splits
    train_dataset, num_train_steps = train
    valid_dataset, num_valid_steps = valid
    test_dataset, num_test_steps = test

    patience = 0

    def convert_to_tensor(ids):
        """
        Convenience function for converting int32
        ndarray to torch int64.
        """
        return torch.as_tensor(ids).long().to(device)

    def forward_step(batch):
        """
        Applies forward pass with the given batch and
        returns the ``(loss, accuracy)`` pair.
        """
        inputs, labels = batch

        labels = convert_to_tensor(labels)

        # converting the batch of inputs to torch tensor
        inputs = [convert_to_tensor(m) for m in inputs]

        input_ids, token_type_ids, attn_mask, \
            perm_mask, target_map = inputs

        outputs = model(input_ids=input_ids,
                        token_type_ids=token_type_ids,
                        attention_mask=attn_mask,
                        perm_mask=perm_mask,
                        target_mapping=target_map.float())

        loss, accuracy = compute_loss(outputs=outputs, labels=labels)

        return loss, accuracy

    def train_step(batch):
        """
        Performs a single step of training.

        Returns ``(loss, accuracy)`` with the loss as a float already
        divided by ``args.grad_accum_steps``, or ``(None, None)`` when
        the loss is NaN and the step was skipped.
        """
        nonlocal step

        loss, accuracy = forward_step(batch)

        if torch.isnan(loss).item():
            logger.warning('skipping step (nan)')
            # returning None values when a NaN loss
            # is encountered and skipping backprop
            # so model grads will not be corrupted
            return None, None

        loss /= args.grad_accum_steps

        backward(loss)
        clip_grad_norm(1.0)

        step += 1

        # parameters are only updated every `grad_accum_steps` steps
        if step % args.grad_accum_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        return loss.item(), accuracy

    def backward(loss):
        """
        Backpropagates the loss in either mixed or
        normal precision mode.
        """
        # cuda is required for mixed precision training.
        if args.mixed and args.cuda:
            with amp.scale_loss(loss, optimizer) as scaled:
                scaled.backward()
        else:
            loss.backward()

    def clip_grad_norm(max_norm):
        """
        Applies gradient clipping.
        """
        if args.mixed and args.cuda:
            clip_grad_norm_(amp.master_params(optimizer), max_norm)
        else:
            clip_grad_norm_(model.parameters(), max_norm)

    # TODO fix scheduling
    # scheduler = WarmupLinearSchedule(
    #     optimizer=optimizer,
    #     warmup_steps=args.warmup_steps,
    #     t_total=args.total_steps)

    if master_process:
        logger.info(str(vars(args)))

    # stepping optimizer from initial (0) learning rate
    # if init_epoch == 0:
    #     scheduler.step()

    for epoch in range(init_epoch, args.max_epochs):
        # running training loop
        loop = tqdm(train_dataset(),
                    total=num_train_steps,
                    disable=not master_process)

        loop.set_description('{}'.format(epoch))
        model.train()
        avg_loss = []

        for batch in loop:
            try:
                loss, acc = train_step(batch)
            except RuntimeError as e:
                # only out-of-memory errors are safe to skip; anything
                # else is a real failure and must not be hidden
                if 'out of memory' not in str(e):
                    raise
                logger.warning('skipping step (oom)')
                continue

            # a skipped (NaN) step yields None values, which must
            # neither be averaged nor written to tensorboard
            if loss is not None:
                avg_loss.append(loss)

                # logging to tensorboard
                if master_process:
                    writer.add_scalar('train/loss', loss, step)
                    writer.add_scalar('train/acc', acc, step)

            loop.set_postfix(ordered_dict=OrderedDict(loss=loss, acc=acc))

        # scheduler.step(epoch=epoch)

        if len(avg_loss) > 0:
            avg_loss = sum(avg_loss) / len(avg_loss)
        else:
            avg_loss = 0.0

        if master_process:
            logger.info('train loss: {:.4}'.format(avg_loss))

        loop = tqdm(valid_dataset(),
                    total=num_valid_steps,
                    disable=not master_process)

        model.eval()
        avg_loss = []

        # running validation loop
        with torch.no_grad():
            for batch in loop:
                loss, accuracy = forward_step(batch)

                avg_loss.append(loss.item())

                loop.set_postfix(
                    ordered_dict=OrderedDict(loss=loss.item(), acc=accuracy))

            avg_loss = sum(avg_loss) / len(avg_loss)

            if master_process:
                writer.add_scalar('valid/loss', avg_loss, step)

        if master_process:
            logger.info('valid loss: {:.4}'.format(avg_loss))

        if avg_loss < best_avg_loss:
            patience = 0
            best_avg_loss = avg_loss
            if master_process:
                # NOTE(review): the 'optimzier' key looks like a typo of
                # 'optimizer'; it is kept because the loading side is not
                # visible here - confirm against load_state before renaming.
                save_state(args=args,
                           logger=logger,
                           state={
                               'model': model.state_dict(),
                               'optimzier': optimizer.state_dict(),
                               'avg_loss': avg_loss,
                               'epoch': epoch + 1,
                               'step': step
                           })

        else:
            patience += 1
            if patience == args.patience:
                # terminate when max patience level is hit
                break

    writer.close()

    loop = tqdm(test_dataset(),
                total=num_test_steps,
                disable=not master_process)

    model.eval()
    avg_loss = []

    # running testing loop
    with torch.no_grad():
        for batch in loop:
            loss, accuracy = forward_step(batch)

            avg_loss.append(loss.item())

            loop.set_postfix(
                ordered_dict=OrderedDict(loss=loss.item(), acc=accuracy))

        avg_loss = sum(avg_loss) / len(avg_loss)
Ejemplo n.º 21
0
# Crop sizes and scale: a 256 px GT size and a 64 px LQ size match scale 4.
opt['GT_size'] = 256
opt['LQ_size'] = 64
opt['scale'] = 4
# Augmentation and frame-sampling switches.
opt['use_flip'] = True
opt['use_rot'] = True
opt['interval_list'] = [1]
opt['random_reverse'] = False
opt['border_mode'] = False
opt['cache_keys'] = 'Vimeo90K_train_keys.pkl'
################################################################################
# Storage backend and device settings.
opt['data_type'] = 'lmdb'  # img | lmdb | mc
opt['dist'] = False
opt['gpu_ids'] = [0]

util.mkdir('tmp')
train_set = create_dataset(opt)
train_loader = create_dataloader(train_set, opt, opt, None)
# side length of a square grid derived from the batch size
nrow = int(math.sqrt(opt['batch_size']))
# padding is 2 in the training phase and 0 otherwise
padding = 2 if opt['phase'] == 'train' else 0

print('start...')
# Only the first six batches (indices 0..5) are inspected.
for i, data in enumerate(train_loader):
    if i >= 6:
        break
    print(i)
    LQs, GT, key = data['LQs'], data['GT'], data['key']
Ejemplo n.º 22
0
def main():
    """Train a model on a "month"/"day" schedule with periodic validation.

    Reads the options file named by the required ``-opt`` command-line
    argument, prepares the experiment directories, seeds ``random`` and
    ``torch``, and builds the train (required) and val (optional) loaders.

    Every batch drawn from the train loader is one "month": the codes for the
    whole batch are produced once, then ``num_days`` optimisation steps are
    run, each on a slice of ``batch_size_per_day`` samples that wraps around
    the month batch. Logging, checkpointing and validation are driven by the
    global step counter; the final model is saved as ``latest``.

    Raises:
        NotImplementedError: if ``opt['datasets']`` contains a phase other
            than ``train`` or ``val``.
        ValueError: if no ``train`` dataset is configured.
    """
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt',
                        type=str,
                        required=True,
                        help='Path to option JSON file.')
    opt = option.parse(parser.parse_args().opt, is_train=True)

    util.mkdir_and_rename(
        opt['path']['experiments_root'])  # rename old experiments if exists
    util.mkdirs(
        (path for key, path in opt['path'].items()
         if not key == 'experiments_root' and not key == 'pretrain_model_G'))
    option.save(opt)
    opt = option.dict_to_nonedict(
        opt)  # Convert to NoneDict, which return None for missing key.

    # print to file and std_out simultaneously
    sys.stdout = PrintLogger(opt['path']['log'])

    # random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    print("Random Seed: ", seed)
    random.seed(seed)
    torch.manual_seed(seed)

    # create train and val dataloader
    # Both loaders start as None so that a missing phase is detected
    # explicitly instead of surfacing later as an unbound-name error.
    train_loader = None
    val_loader = None
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            train_size = int(
                math.ceil(
                    len(train_set) / dataset_opt['batch_size_per_month']))
            print('Number of train images: {:,d}, iters: {:,d}'.format(
                len(train_set), train_size))
            num_months = int(opt['train']['num_months'])
            num_days = int(opt['train']['num_days'])
            total_iters = int(num_months * num_days)
            print('Total epochs needed: {:d} for iters {:,d}'.format(
                num_months, total_iters))
            train_loader = create_dataloader(train_set, dataset_opt)
            batch_size_per_month = dataset_opt['batch_size_per_month']
            batch_size_per_day = int(
                opt['datasets']['train']['batch_size_per_day'])
            use_dci = opt['train']['use_dci']
            inter_supervision = opt['train']['inter_supervision']
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt)
            print('Number of val images in [{:s}]: {:d}'.format(
                dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError(
                'Phase [{:s}] is not recognized.'.format(phase))
    if train_loader is None:
        raise ValueError('A [train] dataset is required for training.')

    # Create model
    model = create_model(opt)
    # create logger
    logger = Logger(opt)

    current_step = 0
    start_time = time.time()
    print('---------- Start training -------------')
    # baseline validation before any optimisation step (skipped without a
    # validation set)
    if val_loader is not None:
        validate(val_loader, opt, model, current_step, 0, logger)
    for epoch in range(num_months):
        for i, train_data in enumerate(train_loader):
            # Sample the codes used for training of the month
            if use_dci:
                cur_month_code = generate_code_samples(model, train_data, opt)
            else:
                tensor_type = torch.zeros if opt['train'][
                    'zero_code'] else torch.randn
                cur_month_code = model.gen_code(train_data['LR'].shape[0],
                                                train_data['LR'].shape[2],
                                                train_data['LR'].shape[3],
                                                tensor_type=tensor_type)
            # clear projection matrix to save memory
            model.clear_projection()
            for j in range(num_days):
                current_step += 1
                # get the sliced data
                cur_day_batch_start_idx = (
                    j * batch_size_per_day) % batch_size_per_month
                cur_day_batch_end_idx = cur_day_batch_start_idx + batch_size_per_day
                if cur_day_batch_end_idx > batch_size_per_month:
                    # the slice runs past the end of the month batch: take the
                    # tail and wrap around to the first samples
                    cur_day_batch_idx = np.hstack(
                        (np.arange(cur_day_batch_start_idx,
                                   batch_size_per_month),
                         np.arange(cur_day_batch_end_idx -
                                   batch_size_per_month)))
                else:
                    cur_day_batch_idx = slice(cur_day_batch_start_idx,
                                              cur_day_batch_end_idx)

                cur_day_train_data = {
                    key: val[cur_day_batch_idx]
                    for key, val in train_data.items()
                }
                code = [
                    gen_code[cur_day_batch_idx] for gen_code in cur_month_code
                ]
                # training
                model.feed_data(cur_day_train_data, code=code)
                model.optimize_parameters(current_step,
                                          inter_supervision=inter_supervision)

                # wall-clock time of this single step
                time_elapsed = time.time() - start_time
                start_time = time.time()

                # log
                if current_step % opt['logger'][
                        'print_freq'] == 0 or current_step == 1:
                    logs = model.get_current_log()
                    print_rlt = OrderedDict()
                    print_rlt['model'] = opt['model']
                    print_rlt['epoch'] = epoch
                    print_rlt['iters'] = current_step
                    print_rlt['time'] = time_elapsed
                    for k, v in logs.items():
                        print_rlt[k] = v
                    print_rlt['lr'] = model.get_current_learning_rate()
                    logger.print_format_results('train', print_rlt)

                # save models
                if current_step % opt['logger']['save_checkpoint_freq'] == 0:
                    print('Saving the model at the end of iter {:d}.'.format(
                        current_step))
                    model.save(current_step)

                # validation
                if (val_loader is not None
                        and current_step % opt['train']['val_freq'] == 0):
                    validate(val_loader, opt, model, current_step, epoch,
                             logger)

                # update learning rate
                model.update_learning_rate()

    print('Saving the final model.')
    model.save('latest')
    print('End of training.')
Ejemplo n.º 23
0
def test(number):
    """Generate and save samples for the dataset item(s) starting at ``number``.

    Parses the test options, overrides them with the fixed settings below,
    builds the dataset and model, and then processes ``opt.num_test`` items
    taken from the dataset starting at index ``number``. For each item the
    model is run ``opt.n_samples + 1`` times (the first run may use the
    encoder, the rest use random latent codes) and the resulting images are
    written to an HTML results page.

    Args:
        number: index of the first dataset item to process; it also
            determines the saved image path (``number + 101``).
    """
    # misc
    gpu_id = -1  # gpu id; NOTE(review): never applied to opt - confirm intent

    # options
    opt = TestOptions().parse()
    opt.num_threads = 1  # test code only supports num_threads=1
    opt.batch_size = 1  # test code only supports batch_size=1
    opt.serial_batches = True  # no shuffle
    # fixed overrides for this run
    opt.dataroot = './datasets/edges2handbags'
    opt.results_dir = './results/edges2handbags_2'
    opt.checkpoints_dir = './pretrained_models/'
    opt.name = 'edges2handbags'
    opt.direction = 'AtoB'  # from domain A to domain B
    opt.load_size = 256  # scale images to this size
    opt.crop_size = 256  # then crop to this size
    opt.input_nc = 1  # number of channels in the input image
    opt.num_test = 1  # number of input images during test
    opt.n_samples = 20  # number of samples per input image

    # create dataset and model
    dataset = create_dataset(opt)
    model = create_model(opt)
    model.setup(opt)
    model.eval()
    print('Loading model %s' % opt.model)
    print(dataset)

    # create website; synced runs get their own '<phase>_sync' directory
    if opt.sync:
        phase_dir = opt.phase + '_sync'
    else:
        phase_dir = opt.phase
    web_dir = os.path.join(opt.results_dir, phase_dir)
    page_title = 'Training = %s, Phase = %s, Class =%s' % (opt.name,
                                                          opt.phase,
                                                          opt.name)
    webpage = html.HTML(web_dir, page_title)

    # with sync enabled, one set of random z is shared by all input images
    if opt.sync:
        z_samples = model.get_z_random(opt.n_samples + 1, opt.nz)

    # test stage: only the items in [number, number + num_test) are processed
    selected = islice(dataset, number, number + opt.num_test)
    for i, data in enumerate(selected):
        print(data)
        model.set_input(data)
        print('process input image %3.3d/%3.3d' % (i + 1, opt.num_test))
        if not opt.sync:
            # otherwise draw a fresh set of random z for every input image
            z_samples = model.get_z_random(opt.n_samples + 1, opt.nz)
        for sample_idx in range(opt.n_samples + 1):
            # only the very first sample may go through the encoder
            encode = (sample_idx == 0) and (not opt.no_encode)
            real_A, fake_B, real_B = model.test(z_samples[[sample_idx]],
                                                encode=encode)
            if sample_idx:
                # later samples are labelled with their sample index
                images.append(fake_B)
                names.append(sample_idx)
            else:
                # first sample: input, ground truth and first output
                images = [real_A, real_B, fake_B]
                names = ['99999', '9999', '999']

        img_path = str(number + 101)
        print(img_path)
        save_images(webpage,
                    images,
                    names,
                    img_path,
                    aspect_ratio=opt.aspect_ratio,
                    width=opt.crop_size)

    webpage.save()
Ejemplo n.º 24
0
def train(cfg, writer, logger):
    """Run the source/target training loop with periodic logging and validation.

    Args:
        cfg: Configuration mapping. The keys read here are ``seed`` (optional,
            defaults to 1337), ``valset`` (optional), ``model.default_gpu``,
            ``data.source.batch_size``, ``data.target.n_class`` and
            ``training.{epoches, train_iters, print_interval, val_interval,
            freeze_bn}``.
        writer: Summary writer exposing ``add_scalar(tag, value, step)``.
        logger: Logger exposing ``info``.

    Side effects:
        Loads the file ``category_anchors`` from the current working directory
        and stores its contents on the model, trains the model in place, and
        calls ``validation`` every ``val_interval`` iterations.
    """
    # Seed every RNG in use from the same value.
    seed = cfg.get('seed', 1337)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    ## create dataset
    default_gpu = cfg['model']['default_gpu']
    device = torch.device(
        "cuda:{}".format(default_gpu) if torch.cuda.is_available() else 'cpu')
    datasets = create_dataset(
        cfg, writer, logger
    )  #source_train\ target_train\ source_valid\ target_valid + _loader

    model = CustomModel(cfg, writer, logger)

    # Setup Metrics
    running_metrics_val = runningScore(cfg['data']['target']['n_class'])
    source_running_metrics_val = runningScore(cfg['data']['target']['n_class'])
    val_loss_meter = averageMeter()
    source_val_loss_meter = averageMeter()
    time_meter = averageMeter()
    loss_fn = get_loss_function(cfg)
    # Cleared once the final iteration has been validated so that the epoch
    # loop stops as well.
    flag_train = True

    epoches = cfg['training']['epoches']

    source_train_loader = datasets.source_train_loader
    target_train_loader = datasets.target_train_loader
    logger.info('source train batchsize is {}'.format(
        source_train_loader.args.get('batch_size')))
    print('source train batchsize is {}'.format(
        source_train_loader.args.get('batch_size')))
    logger.info('target train batchsize is {}'.format(
        target_train_loader.batch_size))
    print('target train batchsize is {}'.format(
        target_train_loader.batch_size))

    # The validation loader is only used for logging here; ``validation``
    # receives the whole ``datasets`` object.
    val_loader = None
    if cfg.get('valset') == 'gta5':
        val_loader = datasets.source_valid_loader
        logger.info('valset is gta5')
        print('valset is gta5')
    else:
        val_loader = datasets.target_valid_loader
        logger.info('valset is cityscapes')
        print('valset is cityscapes')
    logger.info('val batchsize is {}'.format(val_loader.batch_size))
    print('val batchsize is {}'.format(val_loader.batch_size))

    # load category anchors
    objective_vectors = torch.load('category_anchors')
    model.objective_vectors = objective_vectors['objective_vectors']
    model.objective_vectors_num = objective_vectors['objective_num']

    # begin training
    model.iter = 0
    for epoch in range(epoches):
        if not flag_train:
            break
        if model.iter > cfg['training']['train_iters']:
            break

        # monitoring the accuracy and recall of CAG-based PLA and probability-based PLA
        score_cl, _ = model.metrics.running_metrics_val_clusters.get_scores()
        print('clus_IoU: {}'.format(score_cl["Mean IoU : \t"]))

        logger.info('clus_IoU: {}'.format(score_cl["Mean IoU : \t"]))
        logger.info('clus_Recall: {}'.format(
            model.metrics.calc_mean_Clu_recall()))
        logger.info(model.metrics.classes_recall_clu[:, 0] /
                    model.metrics.classes_recall_clu[:, 1])
        logger.info('clus_Acc: {}'.format(
            np.mean(model.metrics.classes_recall_clu[:, 0] /
                    model.metrics.classes_recall_clu[:, 1])))
        logger.info(model.metrics.classes_recall_clu[:, 0] /
                    model.metrics.classes_recall_clu[:, 2])

        score_cl, _ = model.metrics.running_metrics_val_threshold.get_scores()
        logger.info('thr_IoU: {}'.format(score_cl["Mean IoU : \t"]))
        logger.info('thr_Recall: {}'.format(
            model.metrics.calc_mean_Thr_recall()))
        logger.info(model.metrics.classes_recall_thr[:, 0] /
                    model.metrics.classes_recall_thr[:, 1])
        logger.info('thr_Acc: {}'.format(
            np.mean(model.metrics.classes_recall_thr[:, 0] /
                    model.metrics.classes_recall_thr[:, 1])))
        logger.info(model.metrics.classes_recall_thr[:, 0] /
                    model.metrics.classes_recall_thr[:, 2])
        model.metrics.reset()

        for (target_image, target_label,
             target_img_name) in datasets.target_train_loader:
            model.iter += 1
            i = model.iter
            if i > cfg['training']['train_iters']:
                break
            source_batchsize = cfg['data']['source']['batch_size']
            # One source batch is drawn per target batch.
            images, labels, source_img_name = datasets.source_train_loader.next(
            )
            start_ts = time.time()

            images = images.to(device)
            labels = labels.to(device)
            target_image = target_image.to(device)
            target_label = target_label.to(device)
            model.scheduler_step()
            model.train(logger=logger)
            if cfg['training'].get('freeze_bn') == True:
                model.freeze_bn_apply()
            model.optimizer_zerograd()

            loss, loss_cls_L2, loss_pseudo = model.step(
                images, labels, target_image, target_label)
            if loss_cls_L2 > 10:
                logger.info('loss_cls_l2 abnormal!!')

            time_meter.update(time.time() - start_ts)
            if (i + 1) % cfg['training']['print_interval'] == 0:
                unchanged_cls_num = 0
                fmt_str = "Epoches [{:d}/{:d}] Iter [{:d}/{:d}]  Loss: {:.4f}  Loss_cls_L2: {:.4f}  Loss_pseudo: {:.4f}  Time/Image: {:.4f}"
                print_str = fmt_str.format(
                    epoch + 1, epoches, i + 1, cfg['training']['train_iters'],
                    loss.item(), loss_cls_L2, loss_pseudo,
                    time_meter.avg / source_batchsize)

                print(print_str)
                logger.info(print_str)
                logger.info(
                    'unchanged number of objective class vector: {}'.format(
                        unchanged_cls_num))
                writer.add_scalar('loss/train_loss', loss.item(), i + 1)
                writer.add_scalar('loss/train_cls_L2Loss', loss_cls_L2, i + 1)
                writer.add_scalar('loss/train_pseudoLoss', loss_pseudo, i + 1)
                time_meter.reset()

                score_cl, _ = model.metrics.running_metrics_val_clusters.get_scores(
                )
                logger.info('clus_IoU: {}'.format(score_cl["Mean IoU : \t"]))
                logger.info('clus_Recall: {}'.format(
                    model.metrics.calc_mean_Clu_recall()))
                logger.info('clus_Acc: {}'.format(
                    np.mean(model.metrics.classes_recall_clu[:, 0] /
                            model.metrics.classes_recall_clu[:, 2])))

                score_cl, _ = model.metrics.running_metrics_val_threshold.get_scores(
                )
                logger.info('thr_IoU: {}'.format(score_cl["Mean IoU : \t"]))
                logger.info('thr_Recall: {}'.format(
                    model.metrics.calc_mean_Thr_recall()))
                logger.info('thr_Acc: {}'.format(
                    np.mean(model.metrics.classes_recall_thr[:, 0] /
                            model.metrics.classes_recall_thr[:, 2])))

            # evaluation
            if (i + 1) % cfg['training']['val_interval'] == 0 or \
                (i + 1) == cfg['training']['train_iters']:
                validation(
                    model, logger, writer, datasets, device, running_metrics_val, val_loss_meter, loss_fn,\
                    source_val_loss_meter, source_running_metrics_val, iters = model.iter
                    )
                torch.cuda.empty_cache()
                logger.info('Best iou until now is {}'.format(model.best_iou))
            if (i + 1) == cfg['training']['train_iters']:
                # Fix: the original assigned an unrelated name (`flag`), so
                # the epoch loop never saw the stop request and kept training
                # after the final validation.
                flag_train = False
                break
Ejemplo n.º 25
0
Archivo: test.py Proyecto: ryul99/mmsr
# Create every directory listed in opt['path'], except the experiments root
# and entries whose key mentions a pretrained model, a resume state or a
# wandb run path.
util.mkdirs((path for key, path in opt['path'].items()
             if not key == 'experiments_root' and 'pretrain_model' not in key
             and 'resume' not in key and 'wandb_load_run_path' not in key))
# Configure the 'base' logger to write to both the screen and a file named
# after the experiment, then dump the full option set.
util.setup_logger('base',
                  opt['path']['log'],
                  'test_' + opt['name'],
                  level=logging.INFO,
                  screen=True,
                  tofile=True)
logger = logging.getLogger('base')
logger.info(option.dict2str(opt))

#### Create one dataset/dataloader pair per configured test set
# Phases are visited in sorted key order so the test order is deterministic.
test_loaders = []
for phase, dataset_opt in sorted(opt['datasets'].items()):
    test_set = create_dataset(dataset_opt)
    test_loader = create_dataloader(test_set, dataset_opt)
    logger.info('Number of test images in [{:s}]: {:d}'.format(
        dataset_opt['name'], len(test_set)))
    test_loaders.append(test_loader)

model = create_model(opt)
for test_loader in test_loaders:
    # Each test set gets its own results directory under results_root.
    test_set_name = test_loader.dataset.opt['name']
    logger.info('\nTesting [{:s}]...'.format(test_set_name))
    test_start_time = time.time()
    dataset_dir = osp.join(opt['path']['results_root'], test_set_name)
    util.mkdir(dataset_dir)

    # Per-set metric accumulator; only the 'psnr' list is initialised in the
    # visible part of this snippet.
    test_results = OrderedDict()
    test_results['psnr'] = []
Ejemplo n.º 26
0
def main():
    """Train a model from a JSON option file, with periodic validation.

    The option file path is taken from the required ``-opt`` command-line
    argument. Training either resumes from ``opt['path']['resume_state']``
    (a ``.state`` file, or a directory from which the last ``*.state`` file in
    ``util.sorted_nicely`` order is taken) or starts from scratch, in which
    case the experiment directories are (re)created.

    During validation the three outputs exposed by the model's visuals
    (``img_c``, ``img_s``, ``img_p``) are saved as PNG files and compared with
    the ``HR`` image using PSNR and SSIM; LPIPS is computed once over all
    validation images using the ``img_p`` output only.

    Raises:
        NotImplementedError: if ``opt['datasets']`` contains a phase other
            than ``'train'`` or ``'val'``.
        AssertionError: if no ``'train'`` dataset is configured.
    """
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt',
                        type=str,
                        required=True,
                        help='Path to option JSON file.')
    opt = option.parse(parser.parse_args().opt, is_train=True)
    opt = option.dict_to_nonedict(
        opt)  # Convert to NoneDict, which return None for missing key.
    pytorch_ver = get_pytorch_ver()

    # train from scratch OR resume training
    if opt['path']['resume_state']:
        if os.path.isdir(opt['path']['resume_state']):
            import glob
            # A directory was given: pick the last state file it contains.
            resume_state_path = util.sorted_nicely(
                glob.glob(
                    os.path.normpath(opt['path']['resume_state']) +
                    '/*.state'))[-1]
        else:
            resume_state_path = opt['path']['resume_state']
        resume_state = torch.load(resume_state_path)
    else:  # training from scratch
        resume_state = None
        util.mkdir_and_rename(
            opt['path']['experiments_root'])  # rename old folder if exists
        util.mkdirs((path for key, path in opt['path'].items()
                     if not key == 'experiments_root'
                     and 'pretrain_model' not in key and 'resume' not in key))

    # config loggers. Before it, the log will not work
    util.setup_logger(None,
                      opt['path']['log'],
                      'train',
                      level=logging.INFO,
                      screen=True)
    util.setup_logger('val', opt['path']['log'], 'val', level=logging.INFO)
    logger = logging.getLogger('base')

    if resume_state:
        logger.info('Set [resume_state] to ' + resume_state_path)
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))
        option.check_resume(opt)  # check resume options

    logger.info(option.dict2str(opt))
    # tensorboard logger
    if opt['use_tb_logger'] and 'debug' not in opt['name']:
        from tensorboardX import SummaryWriter
        try:
            tb_logger = SummaryWriter(
                logdir='../tb_logger/' +
                opt['name'])  #for version tensorboardX >= 1.7
        except Exception:
            # Narrowed from a bare ``except`` so KeyboardInterrupt/SystemExit
            # are no longer swallowed while falling back to the old keyword.
            tb_logger = SummaryWriter(
                log_dir='../tb_logger/' +
                opt['name'])  #for version tensorboardX < 1.6

    # random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    logger.info('Random seed: {}'.format(seed))
    util.set_random_seed(seed)

    # if the model does not change and input sizes remain the same during training then there may be benefit
    # from setting torch.backends.cudnn.benchmark = True, otherwise it may stall training
    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    # create train and val dataloader
    # Both loaders start as None so that a missing phase is detected by the
    # checks below instead of surfacing as a NameError (the validation set is
    # optional; the training set is not).
    train_loader = None
    val_loader = None
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            train_size = int(
                math.ceil(len(train_set) / dataset_opt['batch_size']))
            logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
                len(train_set), train_size))
            total_iters = int(opt['train']['niter'])
            total_epochs = int(math.ceil(total_iters / train_size))
            logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                total_epochs, total_iters))
            train_loader = create_dataloader(train_set, dataset_opt)
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt)
            logger.info('Number of val images in [{:s}]: {:d}'.format(
                dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError(
                'Phase [{:s}] is not recognized.'.format(phase))
    assert train_loader is not None

    # create model
    model = create_model(opt)

    # resume training
    if resume_state:
        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        model.resume_training(resume_state)  # handle optimizers and schedulers
        model.update_schedulers(
            opt['train']
        )  # updated schedulers in case JSON configuration has changed
    else:
        current_step = 0
        start_epoch = 0

    # training
    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(
        start_epoch, current_step))
    for epoch in range(start_epoch, total_epochs):
        for n, train_data in enumerate(train_loader, start=1):
            current_step += 1
            if current_step > total_iters:
                break

            if pytorch_ver == "pre":  #Order for PyTorch ver < 1.1.0
                # update learning rate
                model.update_learning_rate(current_step - 1)
                # training
                model.feed_data(train_data)
                model.optimize_parameters(current_step)
            elif pytorch_ver == "post":  #Order for PyTorch ver > 1.1.0
                # training
                model.feed_data(train_data)
                model.optimize_parameters(current_step)
                # update learning rate
                model.update_learning_rate(current_step - 1)
            else:
                print('Error identifying PyTorch version. ', torch.__version__)
                break

            # log
            if current_step % opt['logger']['print_freq'] == 0:
                logs = model.get_current_log()
                message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> '.format(
                    epoch, current_step, model.get_current_learning_rate())
                for k, v in logs.items():
                    message += '{:s}: {:.4e} '.format(k, v)
                    # tensorboard logger
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        tb_logger.add_scalar(k, v, current_step)
                logger.info(message)

            # save models and training states (changed to save models before validation)
            if current_step % opt['logger']['save_checkpoint_freq'] == 0:
                model.save(current_step)
                # When the checkpoint falls on the last batch of the epoch,
                # record the next epoch so a resume does not repeat this one.
                model.save_training_state(epoch + (n >= len(train_loader)),
                                          current_step)
                logger.info('Models and training states saved.')

            # validation
            if val_loader and current_step % opt['train']['val_freq'] == 0:
                avg_psnr_c = 0.0
                avg_psnr_s = 0.0
                avg_psnr_p = 0.0

                avg_ssim_c = 0.0
                avg_ssim_s = 0.0
                avg_ssim_p = 0.0

                idx = 0

                val_sr_imgs_list = []
                val_gt_imgs_list = []
                for val_data in val_loader:
                    idx += 1
                    img_name = os.path.splitext(
                        os.path.basename(val_data['LR_path'][0]))[0]
                    img_dir = os.path.join(opt['path']['val_images'], img_name)
                    util.mkdir(img_dir)

                    model.feed_data(val_data)
                    model.test()

                    visuals = model.get_current_visuals()
                    if opt['datasets']['train'][
                            'znorm']:  # If the image range is [-1,1]
                        img_c = util.tensor2img(visuals['img_c'],
                                                min_max=(-1, 1))  # uint8
                        img_s = util.tensor2img(visuals['img_s'],
                                                min_max=(-1, 1))  # uint8
                        img_p = util.tensor2img(visuals['img_p'],
                                                min_max=(-1, 1))  # uint8
                        gt_img = util.tensor2img(visuals['HR'],
                                                 min_max=(-1, 1))  # uint8
                    else:  # Default: Image range is [0,1]
                        img_c = util.tensor2img(visuals['img_c'])  # uint8
                        img_s = util.tensor2img(visuals['img_s'])  # uint8
                        img_p = util.tensor2img(visuals['img_p'])  # uint8
                        gt_img = util.tensor2img(visuals['HR'])  # uint8

                    # Save SR images for reference
                    save_c_img_path = os.path.join(
                        img_dir,
                        '{:s}_{:d}_c.png'.format(img_name, current_step))
                    save_s_img_path = os.path.join(
                        img_dir,
                        '{:s}_{:d}_s.png'.format(img_name, current_step))
                    save_p_img_path = os.path.join(
                        img_dir,
                        '{:s}_{:d}_d.png'.format(img_name, current_step))

                    util.save_img(img_c, save_c_img_path)
                    util.save_img(img_s, save_s_img_path)
                    util.save_img(img_p, save_p_img_path)

                    # calculate PSNR, SSIM and LPIPS distance
                    # The border cropped before measuring is `scale` pixels
                    # wide on every side.
                    crop_size = opt['scale']
                    gt_img = gt_img / 255.
                    #sr_img = sr_img / 255. #ESRGAN
                    #PPON

                    sr_img_c = img_c / 255.  #C
                    sr_img_s = img_s / 255.  #S
                    sr_img_p = img_p / 255.  #D

                    # For training models with only one channel ndim==2, if RGB ndim==3, etc.
                    if gt_img.ndim == 2:
                        cropped_gt_img = gt_img[crop_size:-crop_size,
                                                crop_size:-crop_size]
                    else:  # gt_img.ndim == 3, # Default: RGB images
                        cropped_gt_img = gt_img[crop_size:-crop_size,
                                                crop_size:-crop_size, :]
                    # All 3 output images will have the same dimensions
                    if sr_img_c.ndim == 2:
                        cropped_sr_img_c = sr_img_c[crop_size:-crop_size,
                                                    crop_size:-crop_size]
                        cropped_sr_img_s = sr_img_s[crop_size:-crop_size,
                                                    crop_size:-crop_size]
                        cropped_sr_img_p = sr_img_p[crop_size:-crop_size,
                                                    crop_size:-crop_size]
                    else:  #sr_img_c.ndim == 3, # Default: RGB images
                        cropped_sr_img_c = sr_img_c[crop_size:-crop_size,
                                                    crop_size:-crop_size, :]
                        cropped_sr_img_s = sr_img_s[crop_size:-crop_size,
                                                    crop_size:-crop_size, :]
                        cropped_sr_img_p = sr_img_p[crop_size:-crop_size,
                                                    crop_size:-crop_size, :]

                    avg_psnr_c += util.calculate_psnr(cropped_sr_img_c * 255,
                                                      cropped_gt_img * 255)
                    avg_ssim_c += util.calculate_ssim(cropped_sr_img_c * 255,
                                                      cropped_gt_img * 255)

                    avg_psnr_s += util.calculate_psnr(cropped_sr_img_s * 255,
                                                      cropped_gt_img * 255)
                    avg_ssim_s += util.calculate_ssim(cropped_sr_img_s * 255,
                                                      cropped_gt_img * 255)

                    avg_psnr_p += util.calculate_psnr(cropped_sr_img_p * 255,
                                                      cropped_gt_img * 255)
                    avg_ssim_p += util.calculate_ssim(cropped_sr_img_p * 255,
                                                      cropped_gt_img * 255)

                    # LPIPS only works for RGB images
                    # Using only the final perceptual image to calculate LPIPS
                    if sr_img_c.ndim == 3:
                        #avg_lpips += lpips.calculate_lpips([cropped_sr_img], [cropped_gt_img]) # If calculating for each image
                        val_gt_imgs_list.append(
                            cropped_gt_img
                        )  # If calculating LPIPS only once for all images
                        val_sr_imgs_list.append(
                            cropped_sr_img_p
                        )  # If calculating LPIPS only once for all images

                # PSNR
                avg_psnr_c = avg_psnr_c / idx
                avg_psnr_s = avg_psnr_s / idx
                avg_psnr_p = avg_psnr_p / idx
                # SSIM
                avg_ssim_c = avg_ssim_c / idx
                avg_ssim_s = avg_ssim_s / idx
                avg_ssim_p = avg_ssim_p / idx
                # LPIPS
                #avg_lpips = avg_lpips / idx # If calculating for each image
                avg_lpips = lpips.calculate_lpips(
                    val_sr_imgs_list, val_gt_imgs_list
                )  # If calculating only once for all images

                # log
                # PSNR
                logger.info('# Validation # PSNR_c: {:.5g}'.format(avg_psnr_c))
                logger.info('# Validation # PSNR_s: {:.5g}'.format(avg_psnr_s))
                logger.info('# Validation # PSNR_p: {:.5g}'.format(avg_psnr_p))
                # SSIM
                logger.info('# Validation # SSIM_c: {:.5g}'.format(avg_ssim_c))
                logger.info('# Validation # SSIM_s: {:.5g}'.format(avg_ssim_s))
                logger.info('# Validation # SSIM_p: {:.5g}'.format(avg_ssim_p))
                # LPIPS
                logger.info('# Validation # LPIPS: {:.5g}'.format(avg_lpips))

                logger_val = logging.getLogger('val')  # validation logger
                # logger_val.info('<epoch:{:3d}, iter:{:8,d}> psnr_c: {:.5g}, psnr_s: {:.5g}, psnr_p: {:.5g}'.format(
                # epoch, current_step, avg_psnr_c, avg_psnr_s, avg_psnr_p))
                logger_val.info('<epoch:{:3d}, iter:{:8,d}>'.format(
                    epoch, current_step))
                logger_val.info(
                    'psnr_c: {:.5g}, psnr_s: {:.5g}, psnr_p: {:.5g}'.format(
                        avg_psnr_c, avg_psnr_s, avg_psnr_p))
                logger_val.info(
                    'ssim_c: {:.5g}, ssim_s: {:.5g}, ssim_p: {:.5g}'.format(
                        avg_ssim_c, avg_ssim_s, avg_ssim_p))
                logger_val.info('lpips: {:.5g}'.format(avg_lpips))
                # tensorboard logger
                if opt['use_tb_logger'] and 'debug' not in opt['name']:
                    tb_logger.add_scalar('psnr_c', avg_psnr_c, current_step)
                    tb_logger.add_scalar('psnr_s', avg_psnr_s, current_step)
                    tb_logger.add_scalar('psnr_p', avg_psnr_p, current_step)
                    tb_logger.add_scalar('ssim_c', avg_ssim_c, current_step)
                    tb_logger.add_scalar('ssim_s', avg_ssim_s, current_step)
                    tb_logger.add_scalar('ssim_p', avg_ssim_p, current_step)
                    tb_logger.add_scalar('lpips', avg_lpips, current_step)

    logger.info('Saving the final model.')
    model.save('latest')
    logger.info('End of training.')
Ejemplo n.º 27
0
    Train a pix2pix model:
        python train.py --dataroot ./datasets/maps --name maps_pix2pix --model pix2pix --direction BtoA
    Train a SMAPGAN model:
        python train.py --dataroot ./datasets/maps --name maps_smapgan --model smapgan
"""
import time
from options.train_options import TrainOptions
from data import create_dataset
from models import create_model
from util.visualizer import Visualizer

if __name__ == '__main__':
    # Parse the training options.
    opt = TrainOptions().parse()

    # Build the paired dataset and report how many items it holds.
    opt.dataset_mode = 'aligned'
    datasetP = create_dataset(opt)
    datasetP_size = len(datasetP)
    print('The number of paired training images = %d' % datasetP_size)

    # Unpaired dataset creation is currently disabled:
    # opt.dataset_mode = 'unaligned'
    # datasetU = create_dataset(opt)  # create an unpaired dataset
    # datasetU_size = len(datasetU)   # get the number of images in the dataset.
    # print('The number of unpaired training images = %d' % datasetU_size)

    # Create the model selected by opt.model and the other options.
    model = create_model(opt)
    # Regular setup: load and print networks; create schedulers.
    model.setup(opt)
    # Visualizer that displays/saves images and plots.
    visualizer = Visualizer(opt)
    # Running count of training iterations.
    total_iters = 0
Ejemplo n.º 28
0
def main():
    """Train and validate a model through a solver, one epoch at a time.

    Options are read from the hard-coded file
    ``./options/train/train_MSFN_example.json``; the ``-opt`` command-line
    argument is declared but is not parsed at present.

    Every epoch runs one pass over the training loader, then one pass over
    the validation loader, records the averaged losses and the two metrics
    returned by ``util.pan_calc_metrics`` (stored under the ``'CC'`` and
    ``'RMSE'`` record keys), saves a checkpoint and updates the learning rate.

    Raises:
        NotImplementedError: if ``opt['datasets']`` contains a phase other
            than ``'train'`` or ``'val'``.
        ValueError: if the training or validation data is not available.
    """
    parser = argparse.ArgumentParser(
        description='Train Super Resolution Models')
    parser.add_argument('-opt',
                        type=str,
                        required=True,
                        help='Path to options JSON file.')
    #    opt = option.parse(parser.parse_args().opt)
    opt = option.parse('./options/train/train_MSFN_example.json')

    # random seed
    seed = opt['solver']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    print("===> Random Seed: [%d]" % seed)
    random.seed(seed)
    torch.manual_seed(seed)

    # create train and val dataloader
    # Start from None so that a phase missing from the configuration is
    # reported up front instead of failing later with a NameError.
    train_set = train_loader = None
    val_set = val_loader = None
    for phase, dataset_opt in sorted(opt['datasets'].items()):
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            train_loader = create_dataloader(train_set, dataset_opt)
            print('===> Train Dataset: %s   Number of images: [%d]' %
                  (train_set.name(), len(train_set)))
            if train_loader is None:
                raise ValueError("[Error] The training data does not exist")

        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt)
            print('===> Val Dataset: %s   Number of images: [%d]' %
                  (val_set.name(), len(val_set)))

        else:
            raise NotImplementedError(
                "[Error] Dataset phase [%s] in *.json is not recognized." %
                phase)

    if train_loader is None:
        raise ValueError("[Error] The training data does not exist")
    if val_loader is None:
        # The epoch loop below validates unconditionally, so a validation set
        # is required.
        raise ValueError("[Error] The validation data does not exist")

    solver = create_solver(opt)

    scale = opt['scale']
    model_name = opt['networks']['which_model'].upper()

    print('===> Start Train')
    print("==================================================")

    solver_log = solver.get_current_log()

    NUM_EPOCH = int(opt['solver']['num_epochs'])
    start_epoch = solver_log['epoch']

    print("Method: %s || Scale: %d || Epoch Range: (%d ~ %d)" %
          (model_name, scale, start_epoch, NUM_EPOCH))

    for epoch in range(start_epoch, NUM_EPOCH + 1):
        print('\n===> Training Epoch: [%d/%d]...  Learning Rate: %f' %
              (epoch, NUM_EPOCH, solver.get_current_learning_rate()))

        # Initialization
        solver_log['epoch'] = epoch

        # Train model
        train_loss_list = []
        with tqdm(total=len(train_loader),
                  desc='Epoch: [%d/%d]' % (epoch, NUM_EPOCH),
                  miniters=1) as t:
            for batch in train_loader:
                solver.feed_data(batch)
                iter_loss = solver.train_step()
                batch_size = batch['LR'].size(0)
                # Weight each batch loss by its size so that dividing by the
                # dataset length yields a per-sample average.
                train_loss_list.append(iter_loss * batch_size)
                t.set_postfix_str("Batch Loss: %.4f" % iter_loss)
                t.update()

        avg_train_loss = sum(train_loss_list) / len(train_set)
        solver_log['records']['train_loss'].append(avg_train_loss)
        solver_log['records']['lr'].append(solver.get_current_learning_rate())

        print('\nEpoch: [%d/%d]   Avg Train Loss: %.6f' %
              (epoch, NUM_EPOCH, avg_train_loss))

        print('===> Validating...')

        psnr_list = []
        ssim_list = []
        val_loss_list = []

        # `batch_idx` replaces a loop variable that shadowed the builtin
        # `iter`.
        for batch_idx, batch in enumerate(val_loader):
            solver.feed_data(batch)
            iter_loss = solver.test()
            val_loss_list.append(iter_loss)

            # calculate evaluation metrics
            visuals = solver.get_current_visual()
            psnr, ssim = util.pan_calc_metrics(visuals['SR'],
                                               visuals['HR'],
                                               crop_border=scale,
                                               img_range=opt['img_range'])
            psnr_list.append(psnr)
            ssim_list.append(ssim)

            if opt["save_image"]:
                solver.save_current_visual(epoch, batch_idx)

        # Epoch-level averages, computed once and reused below.
        avg_val_loss = sum(val_loss_list) / len(val_loss_list)
        avg_psnr = sum(psnr_list) / len(psnr_list)
        avg_ssim = sum(ssim_list) / len(ssim_list)

        solver_log['records']['val_loss'].append(avg_val_loss)
        solver_log['records']['CC'].append(avg_psnr)
        solver_log['records']['RMSE'].append(avg_ssim)

        # record the best epoch
        epoch_is_best = False
        if solver_log['best_pred'] < avg_psnr:
            solver_log['best_pred'] = avg_psnr
            epoch_is_best = True
            solver_log['best_epoch'] = epoch

        print(
            "[%s] CC: %.4f   RMSE: %.4f   Loss: %.6f   Best PSNR: %.2f in Epoch: [%d]"
            % (val_set.name(), avg_psnr, avg_ssim, avg_val_loss,
               solver_log['best_pred'], solver_log['best_epoch']))

        solver.set_current_log(solver_log)
        solver.save_checkpoint(epoch, epoch_is_best)
        solver.save_current_log()

        # update lr
        solver.update_learning_rate(epoch)

    print('===> Finished !')
Ejemplo n.º 29
0
from models import create_model
from util.visualizer import save_images
from util import html
from util.util import eval_error_metrics


# Entry point: build the dataset and model from the test options, then write
# error metrics to <results_dir>/<name>/eval_metrics.txt.
if __name__ == '__main__':
    opt = TestOptions().parse()  # get test options
    # hard-code some parameters for test
    opt.eval_mode = True
    # NOTE(review): the previous comment said "test code only supports
    # num_threads = 1", which contradicts the value assigned here - confirm
    # which one is intended.
    opt.num_threads = 0
    opt.batch_size = 1    # test code only supports batch_size = 1
    opt.serial_batches = True  # disable data shuffling; comment this line if results on randomly chosen images are needed.
    opt.no_flip = True    # no flip; comment this line if results on flipped images are needed.
    opt.display_id = -1   # no visdom display; the test code saves the results to a HTML file.
    dataset = create_dataset(opt)  # create a dataset given opt.dataset_mode and other options
    model = create_model(opt)      # create a model given opt.model and other options
    model.setup(opt)               # regular setup: load and print networks; create schedulers

    metrics_log_file = os.path.join(opt.results_dir, opt.name, 'eval_metrics.txt')

    ##### NEW TESTING CODE
    # NOTE(review): eval mode is gated on `opt.eval`, not on the
    # `opt.eval_mode` flag set above - confirm these are meant to be
    # different options.
    if opt.eval:
        model.eval()
    # The meaning of the first argument (200) is defined by
    # eval_error_metrics, which is not visible here.
    eval_error_metrics(200, model, dataset, log_filename=metrics_log_file)

    # opt.full_slice = True
    # dataset2 = create_dataset(opt)
    # eval_error_metrics(200, model, dataset2, log_filename=metrics_log_file)

    ##### ORIGINAL CODE
import sys, os
from models.base_visual import BaseVisual
import torch.nn as nn

# Entry point: parse the test options, build the dataset and model, and run
# the BaseVisual test routine over the test data.
if __name__ == '__main__':

    ##################################
    # Preparing the logger and backup
    ##################################
    opt, logger = TestOptions().parse()
    opt.batch_size = 1
    # NOTE(review): os.environ only accepts str values, so opt.gpu_ids is
    # presumably a string such as '0' or '0,1' - confirm.
    os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_ids
    ##################################
    # Preparing the dataset
    ##################################
    test_loader = create_dataset(opt)

    ##################################
    # Initializing the model
    ##################################
    model = create_model(opt)
    # the_model = torch.load(osp.join(opt.backup_path, 'Best.pth'))
    #  model.load_state_dict(the_model)

    print(model)
    #    grad_cam = GradCam(model=model, target_layer_names=['layer4'],use_cuda=True)

    # NOTE(review): `criterion` is not used anywhere in the visible code, and
    # the .cuda() call requires a CUDA device.
    criterion = nn.CrossEntropyLoss().cuda()
    runner = BaseVisual(opt, model)

    runner.test(test_loader)