Code example #1
0
File: validate.py  Project: dvlab-research/Simple-SR
        avg_psnr = sum(psnr_l) / len(psnr_l)
        avg_ssim = sum(ssim_l) / len(ssim_l)

    return avg_psnr, avg_ssim


if __name__ == '__main__':
    # Standalone validation entry point: load a trained checkpoint and
    # report PSNR/SSIM on the configured validation set.
    from config import config
    from network import Network
    from dataset import get_dataset
    from utils import dataloader
    from utils.model_opr import load_model

    # Evaluate on Set5 regardless of what the training config says.
    config.VAL.DATASET = 'Set5'

    model = Network(config)
    # Prefer GPU when available, fall back to CPU.
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    model = model.to(device)

    model_path = 'log/models/200000.pth'
    # cpu=True loads the checkpoint onto CPU first; the model was already
    # moved to `device` above.
    load_model(model, model_path, cpu=True)
    # Removed a leftover debug `sys.exit()` here that made the validation
    # below unreachable dead code.

    val_dataset = get_dataset(config.VAL)
    # rank 0, world size 1: single-process evaluation.
    val_loader = dataloader.val_loader(val_dataset, config, 0, 1)
    psnr, ssim = validate(model, val_loader, config, device, 0, save_path='.')
    print('PSNR: %.4f, SSIM: %.4f' % (psnr, ssim))
Code example #2
0
File: train.py  Project: dvlab-research/Simple-SR
def main():
    """Entry point for (optionally distributed) super-resolution training.

    Parses ``--local_rank``, sets up log/checkpoint directories and a
    TensorBoard writer (rank 0 only), builds the train/val data loaders,
    model, optimizer and LR scheduler, then runs the training loop up to
    ``config.SOLVER.MAX_ITER`` with periodic logging, checkpointing and
    validation.
    """
    parser = argparse.ArgumentParser()
    # --local_rank is injected by torch.distributed launch utilities.
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()

    # initialization
    rank = 0
    num_gpu = 1
    distributed = False
    # WORLD_SIZE is set by the distributed launcher; more than one process
    # means distributed data-parallel training.
    if 'WORLD_SIZE' in os.environ:
        num_gpu = int(os.environ['WORLD_SIZE'])
        distributed = num_gpu > 1
    if distributed:
        rank = args.local_rank
        init_dist(rank)
    # Offset the seed by rank so each process sees different random draws.
    common.init_random_seed(config.DATASET.SEED + rank)

    # set up dirs and log
    # Derive log paths from this file's location: logs live beside the repo
    # under <root>/logs/<experiment-dir-name>.
    exp_dir, cur_dir = osp.split(osp.split(osp.realpath(__file__))[0])
    root_dir = osp.split(exp_dir)[0]
    log_dir = osp.join(root_dir, 'logs', cur_dir)
    model_dir = osp.join(log_dir, 'models')
    solver_dir = osp.join(log_dir, 'solvers')
    # rank <= 0 covers both the non-distributed case and rank 0 of a
    # distributed run: only one process touches the filesystem / logger.
    if rank <= 0:
        common.mkdir(log_dir)
        ln_log_dir = osp.join(exp_dir, cur_dir, 'log')
        if not osp.exists(ln_log_dir):
            # NOTE(review): the symlink named 'log' is created in the current
            # working directory, not at ln_log_dir — assumes the script is
            # launched from the experiment directory; confirm.
            os.system('ln -s %s log' % log_dir)
        common.mkdir(model_dir)
        common.mkdir(solver_dir)
        save_dir = osp.join(log_dir, 'saved_imgs')
        common.mkdir(save_dir)
        tb_dir = osp.join(log_dir, 'tb_log')
        tb_writer = SummaryWriter(tb_dir)
        common.setup_logger('base',
                            log_dir,
                            'train',
                            level=logging.INFO,
                            screen=True,
                            to_file=True)
        logger = logging.getLogger('base')

    # dataset
    train_dataset = get_dataset(config.DATASET)
    train_loader = dataloader.train_loader(train_dataset,
                                           config,
                                           rank=rank,
                                           seed=config.DATASET.SEED,
                                           is_dist=distributed)
    # Validation is performed by rank 0 only, so only it builds the loader.
    if rank <= 0:
        val_dataset = get_dataset(config.VAL)
        val_loader = dataloader.val_loader(val_dataset, config, rank, 1)
        data_len = val_dataset.data_len

    # model
    model = Network(config)
    if rank <= 0:
        print(model)

    # Resume takes priority over weight initialization; both load strictly
    # onto CPU first (cpu=True) before the .to(device) below.
    if config.CONTINUE_ITER:
        model_path = osp.join(model_dir, '%d.pth' % config.CONTINUE_ITER)
        if rank <= 0:
            logger.info('[Continue] Iter: %d' % config.CONTINUE_ITER)
        model_opr.load_model(model, model_path, strict=True, cpu=True)
    elif config.INIT_MODEL:
        if rank <= 0:
            logger.info('[Initialize] Model: %s' % config.INIT_MODEL)
        model_opr.load_model(model, config.INIT_MODEL, strict=True, cpu=True)

    device = torch.device(config.MODEL.DEVICE)
    model.to(device)
    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[torch.cuda.current_device()])

    # solvers
    optimizer = solver.make_optimizer(config, model)  # lr without X num_gpu
    lr_scheduler = solver.CosineAnnealingLR_warmup(config, optimizer,
                                                   config.SOLVER.BASE_LR)
    iteration = 0

    # When resuming, restore optimizer/scheduler state and the iteration
    # counter from the matching .solver file.
    if config.CONTINUE_ITER:
        solver_path = osp.join(solver_dir, '%d.solver' % config.CONTINUE_ITER)
        iteration = model_opr.load_solver(optimizer, lr_scheduler, solver_path)

    # Track the best validation result seen so far.
    max_iter = max_psnr = max_ssim = 0
    for lr_img, hr_img in train_loader:
        model.train()
        iteration = iteration + 1

        optimizer.zero_grad()

        lr_img = lr_img.to(device)
        hr_img = hr_img.to(device)

        # In training mode the network returns a dict of named loss terms;
        # the total loss is their sum.
        loss_dict = model(lr_img, gt=hr_img)
        total_loss = sum(loss for loss in loss_dict.values())
        total_loss.backward()

        optimizer.step()
        # Scheduler is stepped per iteration, not per epoch.
        lr_scheduler.step()

        # Logging / checkpointing / validation happen on rank 0 only.
        if rank <= 0:
            if iteration % config.LOG_PERIOD == 0 or iteration == config.SOLVER.MAX_ITER:
                log_str = 'Iter: %d, LR: %.3e, ' % (
                    iteration, optimizer.param_groups[0]['lr'])
                for key in loss_dict:
                    tb_writer.add_scalar(key,
                                         loss_dict[key].mean(),
                                         global_step=iteration)
                    log_str += key + ': %.4f, ' % float(loss_dict[key])
                logger.info(log_str)

            if iteration % config.SAVE_PERIOD == 0 or iteration == config.SOLVER.MAX_ITER:
                logger.info('[Saving] Iter: %d' % iteration)
                model_path = osp.join(model_dir, '%d.pth' % iteration)
                solver_path = osp.join(solver_dir, '%d.solver' % iteration)
                model_opr.save_model(model, model_path)
                model_opr.save_solver(optimizer, lr_scheduler, iteration,
                                      solver_path)

            if iteration % config.VAL.PERIOD == 0 or iteration == config.SOLVER.MAX_ITER:
                logger.info('[Validating] Iter: %d' % iteration)
                model.eval()
                with torch.no_grad():
                    psnr, ssim = validate(model,
                                          val_loader,
                                          config,
                                          device,
                                          iteration,
                                          save_path=save_dir)
                # Best model is selected by PSNR; SSIM is recorded alongside.
                if psnr > max_psnr:
                    max_psnr, max_ssim, max_iter = psnr, ssim, iteration
                logger.info('[Val Result] Iter: %d, PSNR: %.4f, SSIM: %.4f' %
                            (iteration, psnr, ssim))
                logger.info('[Best Result] Iter: %d, PSNR: %.4f, SSIM: %.4f' %
                            (max_iter, max_psnr, max_ssim))

        # Stop once the configured iteration budget is exhausted.
        if iteration >= config.SOLVER.MAX_ITER:
            break

    if rank <= 0:
        logger.info('Finish training process!')
        logger.info('[Final Best Result] Iter: %d, PSNR: %.4f, SSIM: %.4f' %
                    (max_iter, max_psnr, max_ssim))
Code example #3
0

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', type=str, default=None)
    parser.add_argument('--input_path', type=str, default=None)
    parser.add_argument('--output_path', type=str, default=None)
    args = parser.parse_args()

    if not os.path.exists(args.output_path):
        os.makedirs(args.output_path)

    config, model = get_network(args.model_path)
    device = torch.device('cuda')
    model = model.to(device)
    load_model(model, args.model_path, strict=True)

    ipath_l = []
    for f in sorted(os.listdir(args.input_path)):
        if f.endswith('png') or f.endswith('jpg'):
            ipath_l.append(os.path.join(args.input_path, f))
    num_img = len(ipath_l)

    down = config.MODEL.DOWN
    scale = config.MODEL.SCALE
    half_n = config.MODEL.N_FRAME // 2
    with torch.no_grad():
        for i, f in enumerate(ipath_l):
            img_name = f.split('/')[-1]
            print(img_name)
            nbr_l = []
Code example #4
0
        bic = F.interpolate(x_center, scale_factor=4, mode='bilinear', align_corners=False)
        out = bic + residual

        if gt is not None:
            loss_dict = dict(CB=self.criterion(out, gt))
            return loss_dict
        else:
            return out


if __name__ == '__main__':
    from utils import model_opr
    from config import config

    net = Network(config)
    model_opr.load_model(net, 'mucan_vimeo.pth', strict=True, cpu=True)
    
    device = torch.device('cuda')
    net = net.to(device)

    dir_path = '~/vincent/datasets/Vimeo-90K/vimeo_septuplet/sequences_matlabLRx4/00001/1000'


    sys.exit()

    import json
    saved = json.load(open('saved_parameters.json', 'r'))
    new = json.load(open('mapping.json', 'r'))

    mapping = dict()
    for i, name in enumerate(saved[0]):