Example 1
def main(opt):
    """
    Trains SRVP and saves the resulting model.

    Parameters
    ----------
    opt : DotDict
        Contains the training configuration.
    """
    ##################################################################################################################
    # Setup
    ##################################################################################################################
    opt.hostname = os.uname()[1]
    # Device handling (CPU, GPU, multi GPU)
    if opt.device is None:
        device = torch.device('cpu')
        opt.n_gpu = 0
    else:
        opt.n_gpu = len(opt.device)
        os.environ["CUDA_VISIBLE_DEVICES"] = str(opt.device[opt.local_rank])
        device = torch.device('cuda:0')
        torch.cuda.set_device(0)
        # In the case of multi GPU: sets up distributed training
        if opt.n_gpu > 1 or opt.local_rank > 0:
            torch.distributed.init_process_group(backend='nccl')
            # Since we are in distributed mode, divide batch size by the number of GPUs
            assert opt.batch_size % opt.n_gpu == 0
            opt.batch_size = opt.batch_size // opt.n_gpu
    # Seed
    if opt.seed is None:
        opt.seed = random.randint(1, 10000)
    else:
        assert isinstance(opt.seed, int) and opt.seed > 0
    print(f"Learning on {opt.n_gpu} GPU(s) (seed: {opt.seed})")
    random.seed(opt.seed)
    np.random.seed(opt.seed + opt.local_rank)
    torch.manual_seed(opt.seed)
    # cuDNN
    if opt.n_gpu > 1 or opt.local_rank > 0:
        assert torch.backends.cudnn.enabled
        cudnn.deterministic = True

    ##################################################################################################################
    # Data
    ##################################################################################################################
    print('Loading data...')
    # Load data
    dataset = data.load_dataset(opt, train=True)
    trainset = dataset.get_fold('train')

    # Handle random seed for dataloader workers (`itr`, the global step counter
    # defined further below, is resolved when the workers are spawned)
    def worker_init_fn(worker_id):
        np.random.seed(
            (opt.seed + itr + opt.local_rank * 10 + worker_id) % (2**32 - 1))

    # Dataloader
    sampler = None
    shuffle = True
    if opt.n_gpu > 1:
        # Let the distributed sampler shuffle for the distributed case
        sampler = torch.utils.data.distributed.DistributedSampler(trainset)
        shuffle = False
    train_loader = DataLoader(trainset,
                              batch_size=opt.batch_size,
                              collate_fn=data.collate_fn,
                              sampler=sampler,
                              num_workers=opt.num_workers,
                              shuffle=shuffle,
                              drop_last=True,
                              pin_memory=True,
                              worker_init_fn=worker_init_fn)

    ##################################################################################################################
    # Model
    ##################################################################################################################
    # Build model
    print('Building model...')
    model = srvp.StochasticLatentResidualVideoPredictor(
        opt.nx, opt.nc, opt.nf, opt.nhx, opt.ny, opt.nz, opt.skipco,
        opt.nt_inf, opt.nh_inf, opt.nlayers_inf, opt.nh_res, opt.nlayers_res,
        opt.archi)
    model.init(res_gain=opt.res_gain)
    # Synchronize the model's batch norms in the distributed case
    if opt.n_gpu > 1:
        if opt.amp_opt_lvl != 'none':
            try:
                from apex.parallel import convert_syncbn_model
            except ImportError:
                raise ImportError(
                    'Please install apex: https://github.com/NVIDIA/apex')
            model = convert_syncbn_model(model)
        else:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    ##################################################################################################################
    # Optimizer
    ##################################################################################################################
    optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)
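    # Linear learning-rate decay: once the burn-in phase is over, the scheduler
    # multiplies the base rate by max(0, (niter - i) / niter), taking it from
    # opt.lr down to zero over lr_scheduling_niter steps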
    opt.niter = opt.lr_scheduling_burnin + opt.lr_scheduling_niter
    niter = opt.lr_scheduling_niter
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lr_lambda=lambda i: max(0, (niter - i) / niter))

    ##################################################################################################################
    # Apex's Automatic Mixed Precision
    ##################################################################################################################
    model.to(device)
    if opt.amp_opt_lvl != 'none':
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                'Please install apex: https://github.com/NVIDIA/apex')
        model, optimizer = amp.initialize(
            model,
            optimizer,
            opt_level=opt.amp_opt_lvl,
            keep_batchnorm_fp32=opt.keep_batchnorm_fp32)

    ##################################################################################################################
    # Multi GPU
    ##################################################################################################################
    if opt.n_gpu > 1:
        if opt.amp_opt_lvl != 'none':
            from apex.parallel import DistributedDataParallel
            forward_fn = DistributedDataParallel(model)
        else:
            forward_fn = torch.nn.parallel.DistributedDataParallel(model)
    else:
        forward_fn = model

    ##################################################################################################################
    # Training
    ##################################################################################################################
    cudnn.benchmark = True  # Enable cuDNN benchmarking to select the fastest convolution algorithms
    assert opt.niter > 0
    # Progress bar
    if opt.local_rank == 0:
        pb = tqdm(total=opt.niter, ncols=0)
    itr = 0
    finished = False
    try:
        while not finished:
            if sampler is not None:
                sampler.set_epoch(opt.seed + itr)
            # -------- TRAIN --------
            for batch in train_loader:
                # Stop once the requested number of optimization steps has been performed
                if itr >= opt.niter:
                    finished = True
                    status_code = 0
                    break
                itr += 1
                # Closure
                model.train()
                loss, disto, rate_y_0, rate_z = train(forward_fn, optimizer,
                                                      batch, device, opt)
                # Learning rate scheduling
                if itr >= opt.lr_scheduling_burnin:
                    lr_scheduler.step()
                # Progress bar
                if opt.local_rank == 0:
                    pb.set_postfix(loss=loss,
                                   disto=disto,
                                   rate_y_0=rate_y_0,
                                   rate_z=rate_z,
                                   refresh=False)
                    pb.update()
    except KeyboardInterrupt:
        status_code = 130

    if opt.local_rank == 0:
        pb.close()
    print('Done')
    # Save model
    print('Saving...')
    torch.save(model.state_dict(), os.path.join(opt.save_path, 'model.pt'))
    return status_code
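
For reference, a minimal sketch of how this entry point might be invoked. The DotDict below is a stand-in for the repository's helper.DotDict; the field names are taken from the code above, while the values and the omitted model hyperparameters are illustrative only.

class DotDict(dict):
    """Minimal stand-in for helper.DotDict: attribute access on dict keys."""
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__

opt = DotDict(
    device=None,              # CPU; use e.g. [0] for one GPU, [0, 1] for two
    local_rank=0,
    seed=42,
    batch_size=128,
    num_workers=4,
    lr=3e-4,
    lr_scheduling_burnin=100000,
    lr_scheduling_niter=50000,
    amp_opt_lvl='none',       # disable Apex AMP
    save_path='out',
    # model hyperparameters (nx, nc, nf, nhx, ny, nz, skipco, nt_inf, nh_inf,
    # nlayers_inf, nh_res, nlayers_res, archi, res_gain) omitted here
)
status_code = main(opt)
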
Example 2
def main(opt):
    """
    Tests SRVP.

    Parameters
    ----------
    opt : DotDict
        Contains the testing configuration.
    """
    ##################################################################################################################
    # Setup
    ##################################################################################################################
    # Device handling (CPU, GPU)
    opt.train = False
    if opt.device is None:
        device = torch.device('cpu')
    else:
        os.environ["CUDA_VISIBLE_DEVICES"] = str(opt.device)
        device = torch.device('cuda:0')
        torch.cuda.set_device(0)
    # Seed
    random.seed(opt.test_seed)
    np.random.seed(opt.test_seed)
    torch.manual_seed(opt.test_seed)
    # cuDNN
    assert torch.backends.cudnn.enabled
    # Load LPIPS model
    global lpips_model
    lpips_model = PerceptualLoss(opt.lpips_dir)

    ##################################################################################################################
    # Load XP config
    ##################################################################################################################
    xp_config = helper.load_json(os.path.join(opt.xp_dir, 'config.json'))
    nt_cond = opt.nt_cond if opt.nt_cond is not None else xp_config.nt_cond
    nt_test = opt.nt_gen if opt.nt_gen is not None else xp_config.seq_len_test

    ##################################################################################################################
    # Load test data
    ##################################################################################################################
    print('Loading data...')
    xp_config.data_dir = opt.data_dir
    xp_config.seq_len = nt_test
    dataset = data.load_dataset(xp_config, train=False)
    testset = dataset.get_fold('test')
    test_loader = DataLoader(testset,
                             batch_size=opt.batch_size,
                             collate_fn=data.collate_fn,
                             pin_memory=True)

    ##################################################################################################################
    # Load model
    ##################################################################################################################
    print('Loading model...')
    model = srvp.StochasticLatentResidualVideoPredictor(
        xp_config.nx, xp_config.nc, xp_config.nf, xp_config.nhx, xp_config.ny,
        xp_config.nz, xp_config.skipco, xp_config.nt_inf, xp_config.nh_inf,
        xp_config.nlayers_inf, xp_config.nh_res, xp_config.nlayers_res,
        xp_config.archi)
    state_dict = torch.load(os.path.join(opt.xp_dir, 'model.pt'),
                            map_location='cpu')
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    ##################################################################################################################
    # Eval
    ##################################################################################################################
    print('Generating samples...')
    torch.set_grad_enabled(False)
    best_samples = defaultdict(list)
    worst_samples = defaultdict(list)
    results = defaultdict(list)
    cond = []
    cond_rec = []
    gt = []
    random_samples = [[] for _ in range(5)]
    # Evaluation is performed one batch at a time
    for batch in tqdm(test_loader, ncols=80, desc='evaluation'):
        # Data
        x = batch.to(device)
        assert nt_test <= len(x)
        x = x[:nt_test]
        x_cond = x[:nt_cond]
        x_target = x[nt_cond:]
        cond.append(x_cond.cpu().mul(255).byte().permute(1, 0, 3, 4, 2))
        gt.append(x_target.cpu().mul(255).byte().permute(1, 0, 3, 4, 2))
        # Predictions
        metric_best = {}
        sample_best = {}
        metric_worst = {}
        sample_worst = {}
        # Encode conditioning frames and extract skip connections
        skip = model.encode(x_cond)[1] if model.skipco != 'none' else None
        # Generate opt.n_samples predictions
        for i in range(opt.n_samples):
            # Infer latent variables
            x_rec, y, _, w, _, _, _, _ = model(x_cond,
                                               nt_cond,
                                               dt=1 / xp_config.n_euler_steps)
            y_0 = y[-1]
            if i == 0:
                x_rec = x_rec[::xp_config.n_euler_steps]
                cond_rec.append(x_rec.cpu().mul(255).byte().permute(
                    1, 0, 3, 4, 2))
            # Use the model in prediction mode starting from the last inferred state
            y_os = model.generate(y_0, [],
                                  nt_test - nt_cond + 1,
                                  dt=1 / xp_config.n_euler_steps)[0]
            y = y_os[xp_config.n_euler_steps::xp_config.n_euler_steps].contiguous()
            x_pred = model.decode(w, y, skip).clamp(0, 1)
            # Pixelwise quantitative eval
            mse = torch.mean(F.mse_loss(x_pred, x_target, reduction='none'),
                             dim=[3, 4])
            metrics_batch = {
                'psnr': 10 * torch.log10(1 / mse).mean(2).mean(0).cpu(),
                'ssim': _ssim_wrapper(x_pred, x_target).mean(2).mean(0).cpu(),
                'lpips': _lpips_wrapper(x_pred, x_target).mean(0).cpu()
            }
            x_pred_byte = x_pred.cpu().mul(255).byte().permute(1, 0, 3, 4, 2)
            if i < 5:
                random_samples[i].append(x_pred_byte)
            for name, values in metrics_batch.items():
                if i == 0:
                    metric_best[name] = values.clone()
                    sample_best[name] = x_pred_byte.clone()
                    metric_worst[name] = values.clone()
                    sample_worst[name] = x_pred_byte.clone()
                    continue
                # Best samples
                idx_better = _get_idx_better(name, metric_best[name], values)
                metric_best[name][idx_better] = values[idx_better]
                sample_best[name][idx_better] = x_pred_byte[idx_better]
                # Worst samples
                idx_worst = _get_idx_worst(name, metric_worst[name], values)
                metric_worst[name][idx_worst] = values[idx_worst]
                sample_worst[name][idx_worst] = x_pred_byte[idx_worst]
        # Compute metrics for best samples and register
        for name in sample_best.keys():
            best_samples[name].append(sample_best[name])
            worst_samples[name].append(sample_worst[name])
            results[name].append(metric_best[name])
    # Store best, worst and random samples
    samples = {
        f'random_{i + 1}': torch.cat(random_sample).numpy()
        for i, random_sample in enumerate(random_samples)
    }
    samples['cond_rec'] = torch.cat(cond_rec).numpy()
    for name in best_samples.keys():
        samples[f'{name}_best'] = torch.cat(best_samples[name]).numpy()
        samples[f'{name}_worst'] = torch.cat(worst_samples[name]).numpy()
        results[name] = torch.cat(results[name]).numpy()

    ##################################################################################################################
    # Compute FVD
    ##################################################################################################################
    print('Computing FVD...')
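    # FVD compares full sequences: conditioning frames followed by the ground
    # truth (ref) against conditioning frames followed by the first set of
    # random samples (hyp), both rescaled back to floats in [0, 1]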
    cond = torch.cat(cond, 0).permute(1, 0, 4, 2, 3).float().div(255)
    gt = torch.cat(gt, 0).permute(1, 0, 4, 2, 3).float().div(255)
    ref = torch.cat([cond, gt], 0)
    hyp = torch.from_numpy(samples['random_1']).clone().permute(
        1, 0, 4, 2, 3).float().div(255)
    hyp = torch.cat([cond, hyp], 0)
    fvd = fvd_score(ref, hyp)

    ##################################################################################################################
    # Print results
    ##################################################################################################################
    print('\n')
    print('Results:')
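    # Report each metric's mean with a 95% confidence half-width
    # (1.96 standard errors)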
    for name, res in results.items():
        print(name, res.mean(), '+/-', 1.960 * res.std() / np.sqrt(len(res)))
    print('FVD', fvd)

    ##################################################################################################################
    # Save samples
    ##################################################################################################################
    np.savez_compressed(os.path.join(opt.xp_dir, 'results.npz'), **results)
    for name, res in samples.items():
        np.savez_compressed(os.path.join(opt.xp_dir, f'{name}.npz'),
                            samples=res)
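
The selection helpers _get_idx_better and _get_idx_worst used in the evaluation loop are not part of this listing. A plausible sketch, assuming that PSNR and SSIM improve as they increase while LPIPS improves as it decreases:

import torch

def _get_idx_better(name, ref, hyp):
    # Batch indices where the new sample `hyp` beats the current best `ref`
    if name == 'lpips':
        return torch.where(hyp < ref)[0]  # lower LPIPS is better
    return torch.where(hyp > ref)[0]      # higher PSNR / SSIM is better

def _get_idx_worst(name, ref, hyp):
    # Batch indices where the new sample `hyp` falls behind the current worst `ref`
    if name == 'lpips':
        return torch.where(hyp > ref)[0]
    return torch.where(hyp < ref)[0]
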
Example 3
def main(opt):
    """
    Trains SRVP and saves the resulting model.

    Parameters
    ----------
    opt : helper.DotDict
        Contains the training configuration.
    """
    ##################################################################################################################
    # Setup
    ##################################################################################################################
    # Device handling (CPU, GPU, multi GPU)
    if opt.device is None:
        device = torch.device('cpu')
        opt.n_gpu = 0
    else:
        opt.n_gpu = len(opt.device)
        os.environ['CUDA_VISIBLE_DEVICES'] = str(opt.device[opt.local_rank])
        device = torch.device('cuda:0')
        torch.cuda.set_device(0)
        # In the case of multi GPU: sets up distributed training
        if opt.n_gpu > 1 or opt.local_rank > 0:
            torch.distributed.init_process_group(backend='nccl')
            # Since we are in distributed mode, divide batch size by the number of GPUs
            assert opt.batch_size % opt.n_gpu == 0
            opt.batch_size = opt.batch_size // opt.n_gpu
    # Seed
    if opt.seed is None:
        opt.seed = random.randint(1, 10000)
    else:
        assert isinstance(opt.seed, int) and opt.seed > 0
    print(f'Learning on {opt.n_gpu} GPU(s) (seed: {opt.seed})')
    random.seed(opt.seed)
    np.random.seed(opt.seed + opt.local_rank)
    torch.manual_seed(opt.seed)
    # cuDNN
    if opt.n_gpu > 1 or opt.local_rank > 0:
        assert torch.backends.cudnn.enabled
        cudnn.deterministic = True
    # Mixed-precision training
    if opt.torch_amp and not torch_amp_imported:
        raise ImportError(
            'Mixed-precision training is not supported by this PyTorch version; upgrade PyTorch or use Apex'
        )
    if opt.apex_amp and not apex_amp_imported:
        raise ImportError(
            'Apex not installed (https://github.com/NVIDIA/apex)')

    ##################################################################################################################
    # Data
    ##################################################################################################################
    print('Loading data...')
    # Load data
    dataset = data.load_dataset(opt, True)
    trainset = dataset.get_fold('train')
    valset = dataset.get_fold('val')
    # Change validation sequence length, if specified
    if opt.seq_len_test is not None:
        valset.change_seq_len(opt.seq_len_test)

    # Handle random seed for dataloader workers (`itr`, the global step counter
    # defined further below, is resolved when the workers are spawned)
    def worker_init_fn(worker_id):
        np.random.seed(
            (opt.seed + itr + opt.local_rank * opt.n_workers + worker_id) %
            (2**32 - 1))

    # Dataloader
    sampler = None
    shuffle = True
    if opt.n_gpu > 1:
        # Let the distributed sampler shuffle for the distributed case
        sampler = torch.utils.data.distributed.DistributedSampler(trainset)
        shuffle = False
    train_loader = DataLoader(trainset,
                              batch_size=opt.batch_size,
                              collate_fn=data.collate_fn,
                              sampler=sampler,
                              num_workers=opt.n_workers,
                              shuffle=shuffle,
                              drop_last=True,
                              pin_memory=True,
                              worker_init_fn=worker_init_fn)
    val_loader = DataLoader(
        valset,
        batch_size=opt.batch_size_test,
        collate_fn=data.collate_fn,
        num_workers=opt.n_workers,
        shuffle=True,
        drop_last=True,
        pin_memory=True,
        worker_init_fn=worker_init_fn) if opt.local_rank == 0 else None

    ##################################################################################################################
    # Model
    ##################################################################################################################
    # Build model
    print('Building model...')
    model = srvp.StochasticLatentResidualVideoPredictor(
        opt.nx, opt.nc, opt.nf, opt.nhx, opt.ny, opt.nz, opt.skipco,
        opt.nt_inf, opt.nh_inf, opt.nlayers_inf, opt.nh_res, opt.nlayers_res,
        opt.archi)
    model.init(res_gain=opt.res_gain)
    # Synchronize the model's batch norms in the distributed case
    if opt.n_gpu > 1:
        if opt.apex_amp:
            from apex.parallel import convert_syncbn_model
            model = convert_syncbn_model(model)
        else:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model.to(device)

    ##################################################################################################################
    # Optimizer
    ##################################################################################################################
    optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)
    opt.n_iter = opt.lr_scheduling_burnin + opt.lr_scheduling_n_iter
    lr_sch_n_iter = opt.lr_scheduling_n_iter
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda i: max(0, (lr_sch_n_iter - i) / lr_sch_n_iter))

    ##################################################################################################################
    # Automatic Mixed Precision
    ##################################################################################################################
    scaler = None
    if opt.torch_amp:
        scaler = torch_amp.GradScaler()
    if opt.apex_amp:
        model, optimizer = apex_amp.initialize(
            model,
            optimizer,
            opt_level=opt.amp_opt_lvl,
            keep_batchnorm_fp32=opt.keep_batchnorm_fp32,
            verbosity=opt.apex_verbose)

    ##################################################################################################################
    # Multi GPU
    ##################################################################################################################
    if opt.n_gpu > 1:
        if opt.apex_amp:
            from apex.parallel import DistributedDataParallel
            forward_fn = DistributedDataParallel(model)
        else:
            forward_fn = torch.nn.parallel.DistributedDataParallel(model)
    else:
        forward_fn = model

    ##################################################################################################################
    # Training
    ##################################################################################################################
    cudnn.benchmark = True  # Enable cuDNN benchmarking to select the fastest convolution algorithms
    assert opt.n_iter > 0
    itr = 0
    finished = False
    # Progress bar
    if opt.local_rank == 0:
        pb = tqdm(total=opt.n_iter, ncols=0)
    # Current and best model evaluation metric (lower is better)
    val_metric = None
    best_val_metric = None
    try:
        while not finished:
            if sampler is not None:
                sampler.set_epoch(opt.seed + itr)
            # -------- TRAIN --------
            for batch in train_loader:
                # Stop once the requested number of optimization steps has been performed
                if itr >= opt.n_iter:
                    finished = True
                    status_code = 0
                    break

                itr += 1
                model.train()
                # Optimization step on batch
                # Allow PyTorch's mixed-precision computations if required, while preserving backward compatibility
                with (torch_amp.autocast()
                      if opt.torch_amp else nullcontext()):
                    loss, nll, kl_y_0, kl_z = train(forward_fn, optimizer,
                                                    scaler, batch, device, opt)

                # Learning rate scheduling
                if itr >= opt.lr_scheduling_burnin:
                    lr_scheduler.step()

                # Evaluation and model saving are performed on the process with local rank zero
                if opt.local_rank == 0:
                    # Evaluation
                    if itr % opt.val_interval == 0:
                        model.eval()
                        val_metric = evaluate(forward_fn, val_loader, device,
                                              opt)
                        if best_val_metric is None or best_val_metric > val_metric:
                            best_val_metric = val_metric
                            torch.save(
                                model.state_dict(),
                                os.path.join(opt.save_path, 'model_best.pt'))

                    # Checkpointing
                    if opt.chkpt_interval is not None and itr % opt.chkpt_interval == 0:
                        torch.save(
                            model.state_dict(),
                            os.path.join(opt.save_path, f'model_{itr}.pt'))

                # Progress bar
                if opt.local_rank == 0:
                    pb.set_postfix(
                        {
                            'loss': loss,
                            'nll': nll,
                            'kl_y_0': kl_y_0,
                            'kl_z': kl_z,
                            'val_metric': val_metric,
                            'best_val_metric': best_val_metric
                        },
                        refresh=False)
                    pb.update()

    except KeyboardInterrupt:
        status_code = 130

    if opt.local_rank == 0:
        pb.close()
    # Save model
    print('Saving...')
    if opt.local_rank == 0:
        torch.save(model.state_dict(), os.path.join(opt.save_path, 'model.pt'))
    print('Done')
    return status_code
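
The train closure invoked in the loop above is not shown. Below is a minimal sketch of an optimization step matching its call signature in Example 3; autocast is already applied at the call site, so it does not reappear here. compute_loss is a hypothetical stand-in for the model's actual loss computation, and the real closure also returns the individual loss terms (nll, kl_y_0, kl_z).

def train(forward_fn, optimizer, scaler, batch, device, opt):
    # Illustrative sketch only; `compute_loss` is hypothetical
    x = batch.to(device)
    optimizer.zero_grad()
    loss = compute_loss(forward_fn, x, opt)  # hypothetical loss helper
    if scaler is not None:
        # Native PyTorch AMP: scale the loss before backward, then step and
        # update through the scaler so steps with inf/NaN gradients are skipped
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    else:
        loss.backward()
        optimizer.step()
    return loss.item()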