Example #1
def main(model, res=(512, ), pyramids=None, up_pyramid=False, max_depth=None):
    import torch
    import torch.nn.functional as F
    from hyperseg.utils.obj_factory import obj_factory
    from hyperseg.utils.utils import set_device
    from hyperseg.utils.img_utils import create_pyramid

    assert len(
        res
    ) <= 2, f'res must be either a single number or a pair of numbers: "{res}"'
    res = res * 2 if len(res) == 1 else res

    device, gpus = set_device()
    model = obj_factory(model).to(device)

    x = torch.rand(1, 3, *res).to(device)
    x = create_pyramid(x, pyramids) if pyramids is not None else x
    if up_pyramid:
        x.append(
            F.interpolate(x[0],
                          scale_factor=2,
                          mode='bilinear',
                          align_corners=False))  # Upsample x2

    # Run profile
    flops_summary, params_summary, meta_params_summary = profile(
        model, inputs=(x, ), max_depth=max_depth)
    print_summary(flops_summary, params_summary, meta_params_summary)
Example #2
def main(input, label, img_transforms=None, tensor_transforms=None):
    import numpy as np
    import matplotlib.pyplot as plt
    from torch.nn.functional import interpolate
    from hyperseg.utils.obj_factory import obj_factory
    from hyperseg.utils.img_utils import tensor2rgb
    from hyperseg.datasets.seg_transforms import Compose
    from PIL import Image

    # Initialize transforms
    img_transforms = obj_factory(
        img_transforms) if img_transforms is not None else []
    tensor_transforms = obj_factory(
        tensor_transforms) if tensor_transforms is not None else []
    transform = Compose(img_transforms + tensor_transforms)

    # Read input image and corresponding label
    img = Image.open(input).convert('RGB')
    lbl = Image.open(label)
    palette = lbl.getpalette()

    # Apply transformations
    img_t, lbl_t = transform(img, lbl)

    if isinstance(img_t, (list, tuple)):
        img_t = img_t[-1]
        if lbl_t.shape[-2:] != img_t.shape[-2:]:
            lbl_t = interpolate(lbl_t.float().view(1, 1, *lbl_t.shape),
                                img_t.shape[-2:],
                                mode='nearest').long().squeeze()

    # Render results
    img, lbl = np.array(img), np.array(lbl.convert('RGB'))
    img_t = img_t[0] if isinstance(img_t, (list, tuple)) else img_t
    img_t = tensor2rgb(img_t)
    lbl_t = Image.fromarray(lbl_t.squeeze().numpy().astype('uint8'), mode='P')
    lbl_t.putpalette(palette)
    lbl_t = np.array(lbl_t.convert('RGB'))

    render_img_orig = np.concatenate((img, lbl), axis=1)
    render_img_transformed = np.concatenate((img_t, lbl_t), axis=1)
    f, ax = plt.subplots(2, 1, figsize=(8, 8))
    ax[0].imshow(render_img_orig)
    ax[1].imshow(render_img_transformed)
    plt.show()
Example #3
def main(model="efficientnet_custom_02_scales.efficientnet('efficientnet-b0')", res=256):
    import torch
    from hyperseg.utils.obj_factory import obj_factory
    from hyperseg.utils.utils import set_device

    device, gpus = set_device()
    model = obj_factory(model).to(device)
    x = torch.rand(4, 3, res, res).to(device)
    pred = model(x)
    print(pred.__class__)
Example #4
def main(dataset='hyperseg.datasets.voc_sbd.VOCSBDDataset',
         train_img_transforms=None,
         val_img_transforms=None,
         tensor_transforms=('seg_transforms.ToTensor',
                            'seg_transforms.Normalize'),
         workers=4,
         batch_size=4):
    from hyperseg.utils.obj_factory import obj_factory

    dataset = obj_factory(dataset)
    print(len(dataset))
Example #5
def main(dataset='hyperseg.datasets.cityscapes.CityscapesDataset',
         train_img_transforms=None,
         val_img_transforms=None,
         tensor_transforms=('seg_transforms.ToTensor',
                            'seg_transforms.Normalize'),
         workers=4,
         batch_size=4):
    from hyperseg.utils.obj_factory import obj_factory

    dataset = obj_factory(dataset)
    data = dataset[0]
    print(len(dataset))
Example #6
def main(dataset='hyperseg.datasets.camvid.CamVidDataset',
         train_img_transforms=None,
         val_img_transforms=None,
         tensor_transforms=('seg_transforms.ToTensor',
                            'seg_transforms.Normalize'),
         workers=4,
         batch_size=4):
    from hyperseg.utils.obj_factory import obj_factory

    dataset = obj_factory(dataset)
    for img, target in dataset:
        print(img)
        print(target.shape)
    print(len(dataset))
Example #7
def main(model='hyperseg.models.layers.meta_linear.MetaLinear',
         in_features=3,
         out_features=5):
    import torch
    from hyperseg.utils.obj_factory import obj_factory
    from hyperseg.utils.utils import set_device

    device, gpus = set_device()
    model = obj_factory(model,
                        in_features=in_features,
                        out_features=out_features).to(device)
    print(model)
    x = torch.rand(2, in_features).to(device)
    w = torch.ones(2, out_features * in_features).to(device)
    out = model(x, w)
    print(out.shape)
Example #8
def load_model(model_path,
               name='',
               device=None,
               arch=None,
               return_checkpoint=False,
               train=False):
    """ Load a model from checkpoint.

    This is a utility function that combines the model weights and architecture (string representation) to easily
    load any model without explicit knowledge of its class.

    Args:
        model_path (str): Path to the model's checkpoint (.pth)
        name (str): The name of the model (for printing and error management)
        device (torch.device): The device to load the model to
        arch (str): The model's architecture (string representation)
        return_checkpoint (bool): If True, the checkpoint will be returned as well
        train (bool): If True, the model will be set to training mode; otherwise it will be set to evaluation mode

    Returns:
        (nn.Module, dict (optional)): A tuple that contains:
            - model (nn.Module): The loaded model
            - checkpoint (dict, optional): The model's checkpoint (only if return_checkpoint is True)
    """
    assert model_path is not None, '%s model must be specified!' % name
    assert os.path.exists(
        model_path), 'Couldn\'t find %s model in path: %s' % (name, model_path)
    print('=> Loading %s model: "%s"...' %
          (name, os.path.basename(model_path)))
    checkpoint = torch.load(model_path)
    assert arch is not None or 'arch' in checkpoint, 'Couldn\'t determine %s model architecture!' % name
    arch = checkpoint['arch'] if arch is None else arch
    model = obj_factory(arch)
    if device is not None:
        model.to(device)
    model.load_state_dict(
        remove_data_parallel_from_state_dict(checkpoint['state_dict']))
    model.train(train)

    if return_checkpoint:
        return model, checkpoint
    else:
        return model
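A minimal usage sketch for load_model (assumptions: the checkpoint path
'checkpoints/model_best.pth' is a hypothetical placeholder, and obj_factory and
remove_data_parallel_from_state_dict are provided at module level by the
enclosing script, as above):
def example_load_model_usage():
    from hyperseg.utils.utils import set_device

    device, gpus = set_device()
    # The 'arch' string stored in the checkpoint is instantiated via obj_factory,
    # then the checkpoint weights are loaded into the resulting module
    model = load_model('checkpoints/model_best.pth',  # hypothetical path
                       name='segmentation',
                       device=device,
                       train=False)
    print(model.__class__.__name__)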
Example #9
def main(model='hyperseg.models.hyperseg_v0_1.hyperseg_efficientnet',
         res=(512, ),
         pyramids=None,
         train=False):
    import torch
    from hyperseg.utils.obj_factory import obj_factory
    from hyperseg.utils.utils import set_device
    from hyperseg.utils.img_utils import create_pyramid

    assert len(
        res
    ) <= 2, f'res must be either a single number or a pair of numbers: "{res}"'
    res = res * 2 if len(res) == 1 else res

    device, gpus = set_device()
    model = obj_factory(model).to(device).train(train)
    x = torch.rand(2, 3, *res).to(device)
    x = create_pyramid(x, pyramids) if pyramids is not None else x
    pred = model(x)
    print(pred.shape)
Example #10
def main(model='hyperseg.models.hyperseg_v1_0.hypergen_efficientnet', res=(512,),
         pyramids=None,
         train=False):
    import torch
    from hyperseg.utils.obj_factory import obj_factory
    from hyperseg.utils.utils import set_device
    from hyperseg.utils.img_utils import create_pyramid
    from tqdm import tqdm

    assert len(res) <= 2, f'res must be either a single number or a pair of numbers: "{res}"'
    res = res * 2 if len(res) == 1 else res

    torch.set_grad_enabled(False)
    torch.backends.cudnn.benchmark = True
    device, gpus = set_device()
    model = obj_factory(model).to(device).train(train)
    x = torch.rand(1, 3, *res).to(device)
    x = create_pyramid(x, pyramids) if pyramids is not None else x
    pred = model(x)
    print(pred.shape)
Example #11
def main(model, res=(512, ), pyramids=None, max_depth=None):
    import torch
    from hyperseg.utils.obj_factory import obj_factory
    from hyperseg.utils.utils import set_device
    from hyperseg.utils.img_utils import create_pyramid

    assert len(
        res
    ) <= 2, f'res must be either a single number or a pair of numbers: "{res}"'
    res = res * 2 if len(res) == 1 else res

    device, gpus = set_device()
    model = obj_factory(model).to(device)

    x = torch.rand(1, 3, *res).to(device)
    x = create_pyramid(x, pyramids) if pyramids is not None else x

    # Run profile
    flops_summary, params_summary = profile(model,
                                            inputs=(x, ),
                                            max_depth=max_depth)
    print_summary(flops_summary, params_summary)
Example #12
def main(model='hyperseg.models.layers.meta_conv.MetaConv2d(kernel_size=3)', in_channels=10, out_channels=20,
         padding=0, test_fps=False):
    import torch
    from hyperseg.utils.obj_factory import obj_factory
    from hyperseg.utils.utils import set_device
    import time
    from tqdm import tqdm

    torch.set_grad_enabled(False)
    torch.backends.cudnn.benchmark = True
    device, gpus = set_device()
    model = obj_factory(model, in_channels=in_channels, out_channels=out_channels).to(device)
    patch_model = MetaPatch(model, padding=padding)

    x = torch.rand(2, in_channels, 256, 256).to(device)
    w = torch.ones(2, model.hyper_params, 8, 8).to(device)
    out = patch_model(x, w)
    print(out.shape)

    if test_fps:
        total_time = 0.
        total_iterations = 0
        pbar = tqdm(range(1000), unit='frames')
        for i in pbar:
            # Start measuring time
            torch.cuda.synchronize()
            start_time = time.perf_counter()

            out = patch_model(x[:1], w[:1])

            # Stop measuring time
            torch.cuda.synchronize()
            elapsed_time = time.perf_counter() - start_time
            total_time += elapsed_time
            total_iterations += out.shape[0]
            fps = total_iterations / total_time

            # Update progress bar info
            pbar.set_description(f'fps = {fps}')
Example #13
def main(
    # General arguments
    exp_dir,
    model=d('model'),
    gpus=d('gpus'),
    cpu_only=d('cpu_only'),
    workers=d('workers'),
    batch_size=d('batch_size'),
    arch=d('arch'),
    display_worst=d('display_worst'),
    display_best=d('display_best'),
    display_sources=d('display_sources'),
    display_with_input=d('display_with_input'),
    display_alpha=d('display_alpha'),
    display_background_index=d('display_background_index'),
    forced=d('forced'),

    # Data arguments
    test_dataset=d('test_dataset'),
    img_transforms=d('img_transforms'),
    tensor_transforms=d('tensor_transforms')):
    # Validation
    assert os.path.isdir(
        exp_dir), f'exp_dir "{exp_dir}" must be a path to a directory'
    model = 'model_best.pth' if model is None else model
    model = os.path.join(exp_dir,
                         model) if not os.path.isfile(model) else model
    assert os.path.isfile(model), f'model path "{model}" does not exist'

    # Initialize cache directory
    cache_dir = os.path.join(exp_dir,
                             os.path.splitext(os.path.basename(__file__))[0])
    scores_path = os.path.join(cache_dir, 'scores.npz')
    os.makedirs(cache_dir, exist_ok=True)

    # Initialize device
    torch.set_grad_enabled(False)
    torch.backends.cudnn.benchmark = True
    device, gpus = set_device(gpus, not cpu_only)

    # Load segmentation model
    model = load_model(model, 'segmentation', device, arch)

    # Support multiple GPUs
    if gpus and len(gpus) > 1:
        model = nn.DataParallel(model, gpus)

    # Initialize transforms
    img_transforms = obj_factory(
        img_transforms) if img_transforms is not None else []
    tensor_transforms = obj_factory(
        tensor_transforms) if tensor_transforms is not None else []
    test_transforms = Compose(img_transforms + tensor_transforms)

    # Initialize dataset
    test_dataset = obj_factory(test_dataset, transforms=test_transforms)
    test_loader = DataLoader(test_dataset,
                             batch_size=batch_size,
                             num_workers=workers,
                             pin_memory=True,
                             drop_last=False,
                             shuffle=False)

    # Initialize metric
    num_classes = len(test_dataset.classes)
    confmat = ConfusionMatrix(num_classes=num_classes)

    if forced or not os.path.isfile(scores_path):
        # For each batch in the test set
        ious = []
        for i, (input, target) in enumerate(
                tqdm(test_loader, unit='batches', file=sys.stdout)):
            # Prepare input
            if isinstance(input, (list, tuple)):
                for j in range(len(input)):
                    input[j] = input[j].to(device)
            else:
                input = input.to(device)
            target = target.to(device)

            # Execute model
            pred = model(input)
            if pred.shape[2:] != target.shape[
                    1:]:  # Make sure the prediction and target are of the same resolution
                pred = F.interpolate(pred,
                                     size=target.shape[1:],
                                     mode='bilinear')

            # Update confusion matrix
            confmat.update(
                target.flatten(),
                pred.argmax(1).flatten()
                if pred.dim() == 4 else pred.flatten())

            # Calculate IoU scores
            for b in range(target.shape[0]):
                ious.append(
                    jaccard(target[b].unsqueeze(0), pred[b].unsqueeze(0),
                            num_classes, 0).item())
        # Save metrics to file
        ious = np.array(ious)
        global_acc, class_acc, class_iou = confmat.compute()
        global_acc = global_acc.item()
        class_acc = class_acc.cpu().numpy()
        class_iou = class_iou.cpu().numpy()
        np.savez(scores_path,
                 ious=ious,
                 global_acc=global_acc,
                 class_acc=class_acc,
                 class_iou=class_iou)
    else:  # Load metrics from file
        scores_archive = np.load(scores_path)
        ious = scores_archive['ious']
        global_acc = scores_archive['global_acc']
        class_acc = scores_archive['class_acc']
        class_iou = scores_archive['class_iou']

    # Print results
    print(f'global_acc={global_acc}')
    print(f'class_acc={class_acc}')
    print(f'class_iou={class_iou}')
    print(f'mIoU={np.mean(class_iou)}')

    # Display edge predictions
    indices = np.argsort(ious)
    if display_worst:
        print('Displaying worst predictions...')
        display_subset(test_dataset,
                       indices[:display_worst],
                       model,
                       device,
                       batch_size,
                       scale=0.5,
                       alpha=display_alpha,
                       with_input=display_with_input,
                       display_sources=display_sources,
                       ignore_index=display_background_index)
    if display_best:
        print('Displaying best predictions...')
        display_subset(test_dataset,
                       indices[-display_best:],
                       model,
                       device,
                       batch_size,
                       scale=0.5,
                       alpha=display_alpha,
                       with_input=display_with_input,
                       display_sources=display_sources,
                       ignore_index=display_background_index)
Example #14
def main(
    # General arguments
    exp_dir,
    resume=d('resume'),
    start_epoch=d('start_epoch'),
    epochs=d('epochs'),
    train_iterations=d('train_iterations'),
    val_iterations=d('val_iterations'),
    gpus=d('gpus'),
    workers=d('workers'),
    batch_size=d('batch_size'),
    seed=d('seed'),
    log_freq=d('log_freq'),
    log_max_res=d('log_max_res'),

    # Data arguments
    train_dataset=d('train_dataset'),
    val_dataset=d('val_dataset'),
    train_img_transforms=d('train_img_transforms'),
    val_img_transforms=d('val_img_transforms'),
    tensor_transforms=d('tensor_transforms'),

    # Training arguments
    optimizer=d('optimizer'),
    scheduler=d('scheduler'),
    criterion=d('criterion'),
    model=d('model'),
    pretrained=d('pretrained'),
    benchmark=d('benchmark'),
    hard_negative_mining=d('hard_negative_mining'),
    batch_scheduler=d('batch_scheduler')):
    def process_epoch(dataset_loader, train=True):
        stage = 'TRAINING' if train else 'VALIDATION'
        total_iter = len(dataset_loader) * dataset_loader.batch_size * epoch
        pbar = tqdm(dataset_loader, unit='batches')
        logger.reset()

        # Set networks training mode
        model.train(train)

        # For each batch
        for i, (input, target) in enumerate(pbar):
            # Set logger prefix
            logger.prefix = f'{stage}: Epoch: {epoch + 1} / {epochs}; LR: {scheduler.get_last_lr()[0]:.1e}; '

            # Prepare input
            with torch.no_grad():
                if isinstance(input, (list, tuple)):
                    for j in range(len(input)):
                        input[j] = input[j].to(device)
                else:
                    input = input.to(device)
                target = target.to(device)

            # Execute model
            pred = model(input)
            if pred.shape[2:] != target.shape[
                    1:]:  # Make sure the prediction and target are of the same resolution
                pred = F.interpolate(pred,
                                     size=target.shape[1:],
                                     mode='bilinear')

            # Calculate loss
            loss_total = criterion(pred, target)

            # Benchmark
            running_metrics.update(target.cpu().numpy(),
                                   pred.argmax(1).cpu().numpy())

            if train:
                # Update generator weights
                optimizer.zero_grad()
                loss_total.backward()
                optimizer.step()

                # Scheduler step
                if batch_scheduler:
                    scheduler.step()

            logger.update('losses', total=loss_total)
            metric_scores = {
                'iou': running_metrics.get_scores()[0]["Mean IoU : \t"]
            }
            logger.update('bench', **metric_scores)
            total_iter += dataset_loader.batch_size

            # Batch logs
            pbar.set_description(str(logger))
            if train and i % log_freq == 0:
                logger.log_scalars_val('batch', total_iter)

        # Epoch logs
        logger.log_scalars_avg('epoch/%s' % ('train' if train else 'val'),
                               epoch,
                               category='losses')
        logger.log_scalars_val('epoch/%s' % ('train' if train else 'val'),
                               epoch,
                               category='bench')
        if not train:
            # Log images
            input = input[0] if isinstance(input, (list, tuple)) else input
            input = limit_resolution(input, log_max_res, 'bilinear')
            pred = limit_resolution(pred, log_max_res, 'bilinear')
            target = limit_resolution(target.unsqueeze(1), log_max_res,
                                      'nearest').squeeze(1)
            seg_pred = blend_seg(input,
                                 pred,
                                 train_dataset.color_map,
                                 alpha=0.75)
            seg_gt = blend_seg(input,
                               target,
                               train_dataset.color_map,
                               alpha=0.75)
            grid = make_grid(input, seg_pred, seg_gt)
            logger.log_image('vis', grid, epoch)

        return logger.log_dict['losses']['total'].avg, logger.log_dict[
            'bench']['iou'].val

    #################
    # Main pipeline #
    #################
    global_iterations = epochs * train_iterations

    # Initialize logger
    logger = TensorBoardLogger(log_dir=exp_dir)

    # Setup seeds
    if seed is not None:
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)

    # Setup device
    torch.backends.cudnn.benchmark = True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Initialize datasets
    train_img_transforms = obj_factory(
        train_img_transforms) if train_img_transforms is not None else []
    tensor_transforms = obj_factory(
        tensor_transforms) if tensor_transforms is not None else []
    train_transforms = Compose(train_img_transforms + tensor_transforms)
    train_dataset = obj_factory(train_dataset, transforms=train_transforms)
    if val_dataset is not None:
        val_img_transforms = obj_factory(
            val_img_transforms) if val_img_transforms is not None else []
        val_transforms = Compose(val_img_transforms + tensor_transforms)
        val_dataset = obj_factory(val_dataset, transforms=val_transforms)

    # Initialize loaders
    sampler = RandomSampler(
        train_dataset, True,
        train_iterations) if train_iterations is not None else None
    train_loader = DataLoader(train_dataset,
                              batch_size=batch_size,
                              num_workers=workers,
                              sampler=sampler,
                              shuffle=sampler is None,
                              pin_memory=True,
                              drop_last=True)
    val_loader = DataLoader(val_dataset,
                            batch_size=batch_size,
                            num_workers=workers,
                            shuffle=False,
                            pin_memory=True) if val_dataset is not None else None

    # Setup Metrics
    running_metrics = runningScore(len(train_dataset.classes))

    # Create model
    arch = get_arch(model, num_classes=len(train_dataset.classes))
    model = obj_factory(model,
                        num_classes=len(train_dataset.classes)).to(device)

    # Optimizer and scheduler
    optimizer = obj_factory(optimizer, model.parameters())
    scheduler = obj_factory(scheduler, optimizer)

    # Resume
    start_epoch = 0
    best_iou = 0.
    if resume is None:
        model_path, checkpoint_dir = os.path.join(exp_dir,
                                                  'model_latest.pth'), exp_dir
    elif os.path.isdir(resume):
        model_path, checkpoint_dir = os.path.join(resume,
                                                  'model_latest.pth'), resume
    else:  # resume is path to a checkpoint file
        model_path, checkpoint_dir = resume, os.path.split(resume)[0]
    if os.path.isfile(model_path):
        print("=> loading checkpoint from '{}'".format(checkpoint_dir))
        # model
        checkpoint = torch.load(model_path)
        start_epoch = checkpoint[
            'epoch'] if 'epoch' in checkpoint else start_epoch
        best_iou = checkpoint[
            'best_iou'] if 'best_iou' in checkpoint else best_iou
        model.apply(init_weights)
        model.load_state_dict(checkpoint['state_dict'], strict=False)
        optimizer.load_state_dict(checkpoint["optimizer"])
        scheduler.load_state_dict(checkpoint["scheduler"])
    else:
        print("=> no checkpoint found at '{}'".format(checkpoint_dir))
        if not pretrained:
            print("=> randomly initializing networks...")
            model.apply(init_weights)

    # Losses
    criterion = obj_factory(criterion).to(device)

    # Benchmark
    # benchmark = obj_factory(benchmark).to(device)

    # Support multiple GPUs
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model,
                                device_ids=range(torch.cuda.device_count()))

    # For each epoch
    for epoch in range(start_epoch, epochs):
        # Training step
        epoch_loss, epoch_iou = process_epoch(train_loader, train=True)

        # Validation step
        if val_loader is not None:
            with torch.no_grad():
                running_metrics.reset()
                epoch_loss, epoch_iou = process_epoch(val_loader, train=False)
        running_metrics.reset()

        # Schedulers step (in PyTorch 1.1.0+ it must follow after the epoch training and validation steps)
        if not batch_scheduler:
            if isinstance(scheduler,
                          torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(epoch_loss)
            else:
                scheduler.step()

        # Save models checkpoints
        is_best = epoch_iou > best_iou
        best_iou = max(epoch_iou, best_iou)
        save_checkpoint(
            exp_dir, 'model', {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'best_iou': best_iou,
                'arch': arch
            }, is_best)
Example #15
def main(
    # General arguments
    exp_dir,
    model=d('model'),
    gpus=d('gpus'),
    cpu_only=d('cpu_only'),
    workers=d('workers'),
    batch_size=d('batch_size'),
    arch=d('arch'),
    display_worst=d('display_worst'),
    display_best=d('display_best'),
    display_sources=d('display_sources'),
    forced=d('forced'),
    trace=d('trace'),
    iterations=d('iterations'),

    # Data arguments
    test_dataset=d('test_dataset'),
    img_transforms=d('img_transforms'),
    tensor_transforms=d('tensor_transforms')):
    # Validation
    assert os.path.isdir(
        exp_dir), f'exp_dir "{exp_dir}" must be a path to a directory'
    model = 'model_best.pth' if model is None else model
    model = os.path.join(exp_dir,
                         model) if not os.path.isfile(model) else model
    if not os.path.isfile(model):
        model = None

    # Initialize cache directory
    cache_dir = os.path.join(exp_dir,
                             os.path.splitext(os.path.basename(__file__))[0])
    scores_path = os.path.join(cache_dir, 'scores.npz')
    os.makedirs(cache_dir, exist_ok=True)

    # Initialize device
    torch.set_grad_enabled(False)
    torch.backends.cudnn.benchmark = True
    device, gpus = set_device(gpus, not cpu_only)

    # Initialize transforms
    img_transforms = obj_factory(
        img_transforms) if img_transforms is not None else []
    tensor_transforms = obj_factory(
        tensor_transforms) if tensor_transforms is not None else []
    test_transforms = Compose(img_transforms + tensor_transforms)

    # Initialize dataset
    test_dataset = obj_factory(test_dataset, transforms=test_transforms)
    test_sampler = None if iterations is None else RandomSampler(
        test_dataset, True, iterations)
    test_loader = DataLoader(test_dataset,
                             batch_size=batch_size,
                             num_workers=workers,
                             pin_memory=True,
                             drop_last=False,
                             shuffle=False,
                             sampler=test_sampler)

    # Load segmentation model
    if model is None:
        assert arch is not None
        model = obj_factory(arch).to(device)
    else:
        model = load_model(model, 'segmentation', device, arch)

    # Remove BN
    model = remove_bn(model)

    # Trace model
    if trace:
        sample, target = test_dataset[0]
        model = torch.jit.trace(model, sample.unsqueeze(0).to(device))

    # Support multiple GPUs
    if gpus and len(gpus) > 1:
        model = nn.DataParallel(model, gpus)

    # Initialize metric
    num_classes = len(test_dataset.classes)
    confmat = ConfusionMatrix(num_classes=num_classes)

    if forced or not os.path.isfile(scores_path):
        # The loop is run twice; the first pass effectively acts as a warm-up
        # for the fps measurement
        for _ in range(2):
            # For each batch in the test set
            ious = []
            total_time = 0.
            total_iterations = 0
            pbar = tqdm(test_loader, unit='batches', file=sys.stdout)
            for i, (input, target) in enumerate(pbar):
                target = target.to(device)

                # Start measuring time
                torch.cuda.synchronize()
                start_time = time.perf_counter()

                # Prepare input
                if isinstance(input, (list, tuple)):
                    for j in range(len(input)):
                        input[j] = input[j].to(device)
                else:
                    input = input.to(device)

                # Execute model
                pred = model(input)

                # Stop measuring time
                torch.cuda.synchronize()
                elapsed_time = time.perf_counter() - start_time
                total_time += elapsed_time
                total_iterations += pred.shape[0]
                fps = total_iterations / total_time

                # Update confusion matrix
                confmat.update(
                    target.flatten(),
                    pred.argmax(1).flatten()
                    if pred.dim() == 4 else pred.flatten())

                # Calculate IoU scores
                for b in range(target.shape[0]):
                    ious.append(
                        jaccard(target[b].unsqueeze(0), pred[b].unsqueeze(0),
                                num_classes, 0).item())

                # Update progress bar info
                pbar.set_description(f'fps = {fps}')

        # Save metrics to file
        ious = np.array(ious)
        global_acc, class_acc, class_iou = confmat.compute()
        global_acc = global_acc.item()
        class_acc = class_acc.cpu().numpy()
        class_iou = class_iou.cpu().numpy()
        fps = len(test_loader) / total_time
        np.savez(scores_path,
                 ious=ious,
                 global_acc=global_acc,
                 class_acc=class_acc,
                 class_iou=class_iou,
                 fps=fps)
    else:  # Load metrics from file
        scores_archive = np.load(scores_path)
        ious = scores_archive['ious']
        global_acc = scores_archive['global_acc']
        class_acc = scores_archive['class_acc']
        class_iou = scores_archive['class_iou']
        fps = scores_archive['fps']

    # Print results
    print(f'global_acc={global_acc}')
    print(f'class_acc={class_acc}')
    print(f'class_iou={class_iou}')
    print(f'mIoU={np.mean(class_iou)}')
    print(f'fps={fps}')

    # Display edge predictions
    indices = np.argsort(ious)
    if display_worst:
        print('Displaying worst predictions...')
        display_subset(test_dataset,
                       indices[:display_worst],
                       model,
                       device,
                       batch_size,
                       scale=0.5,
                       alpha=0.5,
                       with_input=False,
                       display_sources=display_sources)
    if display_best:
        print('Displaying best predictions...')
        display_subset(test_dataset,
                       indices[-display_best:],
                       model,
                       device,
                       batch_size,
                       scale=0.5,
                       alpha=0.5,
                       with_input=False,
                       display_sources=display_sources)