def get_transform(train):
    """Build the preprocessing pipeline for training or evaluation.

    When ``train`` is truthy the pipeline is ``Rescale(256)`` followed by
    ``RandomCrop(224)``; otherwise it is only ``Rescale(224)``. In both cases
    ``ToTensor()`` is appended as the final step and the steps are wrapped in
    ``Compose``.
    """
    if train:
        steps = [Rescale(256), RandomCrop(224)]
    else:
        steps = [Rescale(224)]
    steps.append(ToTensor())

    return Compose(steps)
Example #2
0
def prepare_dataloaders(dataset_split,
                        dataset_path,
                        metadata_filename,
                        batch_size=32,
                        sample_size=-1,
                        valid_split=0.8,
                        num_workers=4):
    '''
    Utility function to prepare dataloaders for training or testing.

    Parameters
    ----------
    dataset_split : str
        Any of 'train', 'extra', 'test'.
    dataset_path : str
        Absolute path to the dataset. (i.e. .../data/SVHN/train')
    metadata_filename : str
        Absolute path to the metadata pickle file.
    batch_size : int
        Mini-batch size.
    sample_size : int
        Number of elements to use as sample size,
        for debugging purposes only. If -1, use all samples.
    valid_split : float
        Fraction of the selected samples that goes to the *training*
        loader; the remaining (1 - valid_split) goes to the validation
        loader. Despite its name this is the training share, so the
        default of 0.8 yields an 80/20 train/validation split.
        Should be in range [0,1].
    num_workers : int
        Number of worker processes used by every DataLoader created here.

    Returns
    -------
    if dataset_split in ['train', 'extra']:
        train_loader: torch.utils.DataLoader
            Dataloader containing training data.
        valid_loader: torch.utils.DataLoader
            Dataloader containing validation data.

    if dataset_split in ['test']:
        test_loader: torch.utils.DataLoader
            Dataloader containing test data.

    '''

    assert dataset_split in ['train', 'test', 'extra'], "check dataset_split"

    metadata = load_obj(metadata_filename)

    # Declare transformations
    firstcrop = FirstCrop(0.3)
    rescale = Rescale((64, 64))
    random_crop = RandomCrop((54, 54))
    to_tensor = ToTensor()

    transform = transforms.Compose(
        [firstcrop, rescale, random_crop, to_tensor])

    dataset = SVHNDataset(metadata, data_dir=dataset_path, transform=transform)

    # Indices are kept in metadata order (no permutation), so the split below
    # is positional: the head goes to training, the tail to validation.
    indices = np.arange(len(metadata))

    # Only use a sample amount of data
    if sample_size != -1:
        indices = indices[:sample_size]

    if dataset_split in ['train', 'extra']:

        split_point = round(valid_split * len(indices))
        train_idx = indices[:split_point]
        valid_idx = indices[split_point:]

        # Each subset is shuffled independently by its sampler; `shuffle`
        # must stay False because a sampler is supplied.
        train_sampler = torch.utils.data.SubsetRandomSampler(train_idx)
        valid_sampler = torch.utils.data.SubsetRandomSampler(valid_idx)

        # Prepare a train and validation dataloader
        train_loader = DataLoader(dataset,
                                  batch_size=batch_size,
                                  shuffle=False,
                                  num_workers=num_workers,
                                  sampler=train_sampler)

        valid_loader = DataLoader(dataset,
                                  batch_size=batch_size,
                                  shuffle=False,
                                  num_workers=num_workers,
                                  sampler=valid_sampler)

        return train_loader, valid_loader

    elif dataset_split in ['test']:

        # SequentialSampler yields 0..len(indices)-1, i.e. the first
        # len(indices) samples of the dataset in order.
        test_sampler = torch.utils.data.SequentialSampler(indices)

        # Prepare a test dataloader
        test_loader = DataLoader(dataset,
                                 batch_size=batch_size,
                                 num_workers=num_workers,
                                 shuffle=False,
                                 sampler=test_sampler)

        return test_loader
Example #3
0
def init(batch_size_labeled, batch_size_pseudo, state, split, input_sizes,
         sets_id, std, mean, keep_scale, reverse_channels, data_set, valtiny,
         no_aug):
    """Build the data loader(s) required for the given run state.

    Parameters
    ----------
    batch_size_labeled : int
        Batch size of the labeled loader (also used for the unlabeled loader
        when pseudo labeling).
    batch_size_pseudo : int
        Batch size of the pseudo-labeled loader. The validation loader uses
        ``batch_size_labeled + batch_size_pseudo``.
    state : int
        0: pseudo labeling, 1: semi-supervised training,
        2: fully-supervised training, 3: just testing.
        Any value other than 1, 2 or 3 is handled as pseudo labeling.
    split : str
        Prefix of the labeled image-set name. The unlabeled image-set name
        uses the same prefix with '-r' and '-l' removed.
    input_sizes : sequence
        ``input_sizes[0]`` is the training crop/resize size and
        ``input_sizes[2]`` the test-time pad/resize size.
    sets_id
        Suffix appended to the labeled/unlabeled image-set names.
    std, mean
        Forwarded to ``Normalize``.
    keep_scale, reverse_channels
        Forwarded to ``ToTensor``.
    data_set : str
        Either 'voc' or 'city'.
    valtiny : bool
        Use the 'valtiny' image set instead of 'val' for validation.
    no_aug : bool
        If true, pseudo-labeled training data is only resized and normalized
        (no random scale/crop/flip).

    Returns
    -------
    state 3: ``val_loader``
    state 2: ``(labeled_loader, val_loader)``
    state 1: ``(labeled_loader, pseudo_labeled_loader, val_loader)``
    otherwise: ``unlabeled_loader``

    Raises
    ------
    ValueError
        If ``data_set`` is neither 'voc' nor 'city'.
    """
    # For labeled set divisions
    split_u = split.replace('-r', '').replace('-l', '')

    # Transformations (compatible with unlabeled data/pseudo labeled data)
    # ! Can't use torchvision.Transforms.Compose
    if data_set == 'voc':
        base = base_voc
        workers = 4
        transform_train = Compose([
            ToTensor(keep_scale=keep_scale, reverse_channels=reverse_channels),
            RandomScale(min_scale=0.5, max_scale=1.5),
            RandomCrop(size=input_sizes[0]),
            RandomHorizontalFlip(flip_prob=0.5),
            Normalize(mean=mean, std=std)
        ])
        if no_aug:
            transform_train_pseudo = Compose([
                ToTensor(keep_scale=keep_scale,
                         reverse_channels=reverse_channels),
                Resize(size_image=input_sizes[0], size_label=input_sizes[0]),
                Normalize(mean=mean, std=std)
            ])
        else:
            transform_train_pseudo = Compose([
                ToTensor(keep_scale=keep_scale,
                         reverse_channels=reverse_channels),
                RandomScale(min_scale=0.5, max_scale=1.5),
                RandomCrop(size=input_sizes[0]),
                RandomHorizontalFlip(flip_prob=0.5),
                Normalize(mean=mean, std=std)
            ])
        transform_test = Compose([
            ToTensor(keep_scale=keep_scale, reverse_channels=reverse_channels),
            ZeroPad(size=input_sizes[2]),
            Normalize(mean=mean, std=std)
        ])
    elif data_set == 'city':  # All the same size (whole set is down-sampled by 2)
        base = base_city
        workers = 8
        transform_train = Compose([
            ToTensor(keep_scale=keep_scale, reverse_channels=reverse_channels),
            Resize(size_image=input_sizes[2], size_label=input_sizes[2]),
            RandomScale(min_scale=0.5, max_scale=1.5),
            RandomCrop(size=input_sizes[0]),
            RandomHorizontalFlip(flip_prob=0.5),
            Normalize(mean=mean, std=std),
            LabelMap(label_id_map_city)
        ])
        # NOTE(review): the pseudo-labeled pipelines apply no LabelMap,
        # presumably because pseudo labels are already in train ids — confirm.
        if no_aug:
            transform_train_pseudo = Compose([
                ToTensor(keep_scale=keep_scale,
                         reverse_channels=reverse_channels),
                Resize(size_image=input_sizes[0], size_label=input_sizes[0]),
                Normalize(mean=mean, std=std)
            ])
        else:
            transform_train_pseudo = Compose([
                ToTensor(keep_scale=keep_scale,
                         reverse_channels=reverse_channels),
                Resize(size_image=input_sizes[2], size_label=input_sizes[2]),
                RandomScale(min_scale=0.5, max_scale=1.5),
                RandomCrop(size=input_sizes[0]),
                RandomHorizontalFlip(flip_prob=0.5),
                Normalize(mean=mean, std=std)
            ])
        transform_test = Compose([
            ToTensor(keep_scale=keep_scale, reverse_channels=reverse_channels),
            Resize(size_image=input_sizes[2], size_label=input_sizes[2]),
            Normalize(mean=mean, std=std),
            LabelMap(label_id_map_city)
        ])
    else:
        # Without this guard the function would fail further down with an
        # UnboundLocalError, because no transforms or worker count exist.
        raise ValueError(
            "Unsupported data_set {!r}; expected 'voc' or 'city'".format(
                data_set))

    labeled_image_set = str(split) + '_labeled_' + str(sets_id)
    unlabeled_image_set = str(split_u) + '_unlabeled_' + str(sets_id)

    # Not the actual test set (i.e.validation set)
    test_set = StandardSegmentationDataset(
        root=base,
        image_set='valtiny' if valtiny else 'val',
        transforms=transform_test,
        label_state=0,
        data_set=data_set)
    val_loader = torch.utils.data.DataLoader(dataset=test_set,
                                             batch_size=batch_size_labeled +
                                             batch_size_pseudo,
                                             num_workers=workers,
                                             shuffle=False)

    # Testing
    if state == 3:
        return val_loader

    # Fully-supervised training
    if state == 2:
        labeled_set = StandardSegmentationDataset(
            root=base,
            image_set=labeled_image_set,
            transforms=transform_train,
            label_state=0,
            data_set=data_set)
        labeled_loader = torch.utils.data.DataLoader(
            dataset=labeled_set,
            batch_size=batch_size_labeled,
            num_workers=workers,
            shuffle=True)
        return labeled_loader, val_loader

    # Semi-supervised training
    if state == 1:
        pseudo_labeled_set = StandardSegmentationDataset(
            root=base,
            data_set=data_set,
            mask_type='.npy',
            image_set=unlabeled_image_set,
            transforms=transform_train_pseudo,
            label_state=1)
        labeled_set = StandardSegmentationDataset(
            root=base,
            data_set=data_set,
            image_set=labeled_image_set,
            transforms=transform_train,
            label_state=0)
        pseudo_labeled_loader = torch.utils.data.DataLoader(
            dataset=pseudo_labeled_set,
            batch_size=batch_size_pseudo,
            num_workers=workers,
            shuffle=True)
        labeled_loader = torch.utils.data.DataLoader(
            dataset=labeled_set,
            batch_size=batch_size_labeled,
            num_workers=workers,
            shuffle=True)
        return labeled_loader, pseudo_labeled_loader, val_loader

    # Labeling: the unlabeled images go through the test-time transforms
    unlabeled_set = StandardSegmentationDataset(
        root=base,
        data_set=data_set,
        mask_type='.npy',
        image_set=unlabeled_image_set,
        transforms=transform_test,
        label_state=2)
    unlabeled_loader = torch.utils.data.DataLoader(
        dataset=unlabeled_set,
        batch_size=batch_size_labeled,
        num_workers=workers,
        shuffle=False)
    return unlabeled_loader
Example #4
0
def prepare_dataloaders(
    dataset_split,
    dataset_path,
    metadata_filename,
    batch_size=32,
    sample_size=-1,
    valid_split=0.1,
    test_split=0.1,
    num_worker=0,
    valid_metadata_filename=None,
    valid_dataset_dir=None,
):
    """
    Utility function to prepare dataloaders for training or testing.

    Parameters
    ----------
    dataset_split : str
        Any of 'train', 'extra', 'test'.
    dataset_path : str
        Absolute path to the dataset. (i.e. .../data/SVHN/train')
    metadata_filename : str
        Absolute path to the metadata pickle file.
    batch_size : int
        Mini-batch size.
    sample_size : int
        Number of elements to use as sample size,
        for debugging purposes only. If -1, use all samples.
    valid_split : float
        Accepted for interface compatibility; not used by this function.
        Validation data comes from `valid_dataset_dir` instead.
    test_split : float
        Accepted for interface compatibility; not used by this function.
    num_worker : int
        Number of worker processes used by every DataLoader created here.
    valid_metadata_filename : str or None
        Path to the metadata pickle file of the validation set. Only read
        when `valid_dataset_dir` is given.
    valid_dataset_dir : str or None
        Directory of a separate validation dataset. If None, no validation
        loader is built.

    Returns
    -------
    if dataset_split in ['train', 'extra']:
        train_loader: torch.utils.DataLoader
            Dataloader containing training data.
        valid_loader: torch.utils.DataLoader or None
            Dataloader containing validation data, or None when
            `valid_dataset_dir` is None.

    if dataset_split in ['test']:
        test_loader: torch.utils.DataLoader
            Dataloader containing test data.

    """

    assert dataset_split in ["train", "test", "extra"], "check dataset_split"

    metadata = load_obj(metadata_filename)

    firstcrop = FirstCrop(0.3)
    downscale = Rescale((64, 64))
    random_crop = RandomCrop((54, 54))
    to_tensor = ToTensor()
    # Normalisation is currently disabled; the datasets receive None.
    normalize = None

    # Declare transformations: random crop for training, a fixed rescale
    # (no random crop) for validation/test.
    transform = transforms.Compose(
        [firstcrop, downscale, random_crop, to_tensor])
    test_transform = transforms.Compose(
        [FirstCrop(0.1), Rescale((54, 54)), to_tensor])

    dataset = SVHNDataset(
        metadata,
        data_dir=dataset_path,
        transform=transform,
        normalize_transform=normalize,
    )

    indices = np.arange(len(metadata))
    # Only use a sample amount of data
    if sample_size != -1:
        indices = indices[:sample_size]

    if dataset_split in ["train", "extra"]:

        # Prepare a train and validation dataloader
        valid_loader = None
        if valid_dataset_dir is not None:
            valid_metadata = load_obj(valid_metadata_filename)
            valid_dataset = SVHNDataset(
                valid_metadata,
                data_dir=valid_dataset_dir,
                transform=test_transform,
                normalize_transform=normalize,
            )
            valid_loader = DataLoader(
                valid_dataset,
                batch_size=batch_size,
                shuffle=False,
                num_workers=num_worker,
            )

        # The sampler shuffles the selected indices; `shuffle` must stay
        # False because a sampler is supplied.
        train_sampler = torch.utils.data.SubsetRandomSampler(indices)
        train_loader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            sampler=train_sampler,
            num_workers=num_worker,
        )

        return train_loader, valid_loader

    elif dataset_split in ["test"]:

        # SequentialSampler yields 0..len(indices)-1, i.e. the first
        # len(indices) samples of the dataset in order.
        test_sampler = torch.utils.data.SequentialSampler(indices)

        # change the transformer pipeline
        dataset.transform = test_transform

        # Prepare a test dataloader. The worker count follows `num_worker`
        # like the other loaders instead of a hard-coded value.
        test_loader = DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=num_worker,
            shuffle=False,
            sampler=test_sampler,
        )

        return test_loader
Example #5
0
def prepare_test_dataloader(dataset_path, metadata_filename, batch_size,
                            sample_size):
    """
    Build the dataloader used for testing.

    Parameters
    ----------
    dataset_path : str
        Absolute path to the test dataset. (i.e. .../data/SVHN/test')
    metadata_filename : str
        Absolute path to the metadata pickle file.
    batch_size : int
        Mini-batch size.
    sample_size : int
        Number of elements to use as sample size,
        for debugging purposes only. If -1, use all samples.

    Returns
    -------
        test_loader: torch.utils.DataLoader
            Dataloader containing test data, iterated in order.

    """

    # Transformation steps, applied in this order
    steps = [
        FirstCrop(0.3),
        Rescale((64, 64)),
        RandomCrop((54, 54)),
        ToTensor(),
    ]
    pipeline = transforms.Compose(steps)

    # Metadata drives both the dataset contents and the index range
    records = load_obj(metadata_filename)
    svhn_dataset = SVHNDataset(records,
                               data_dir=dataset_path,
                               transform=pipeline)

    selected = np.arange(len(records))
    if sample_size != -1:
        # Debug mode: restrict to a prefix of the samples
        selected = selected[:sample_size]

    # A sampler is supplied, so `shuffle` has to stay False
    sequential = torch.utils.data.SequentialSampler(selected)
    return DataLoader(svhn_dataset,
                      batch_size=batch_size,
                      num_workers=4,
                      shuffle=False,
                      sampler=sequential)
Example #6
0
def prepare_dataloaders(cfg):
    """
    Utility function to prepare dataloaders for training.

    The 'train' and 'extra' subsets are combined into one dataset; each
    subset is split into a training and a validation part separately and
    the parts are then merged.

    Parameters of the configuration (cfg)
    -------------------------------------
    cfg.INPUT_DIR : sequence of str
        Absolute paths to the datasets, index 0 for the 'train' subset and
        index 1 for the 'extra' subset. (i.e. .../data/SVHN/train')
    cfg.METADATA_FILENAME : sequence of str
        Absolute paths to the metadata pickle files (index 0: train,
        index 1: extra).
    cfg.TRAIN.BATCH_SIZE : int
        Mini-batch size.
    cfg.TRAIN.SAMPLE_SIZE : sequence of int
        Number of elements to use from each subset (index 0: train,
        index 1: extra), for debugging purposes only.
        If -1, use all samples of that subset.
    cfg.TRAIN.VALID_SPLIT : sequence of float
        Per-subset fraction of samples that goes to the *training* loader
        (index 0: train, index 1: extra); the remainder goes to the
        validation loader. Values should be in range [0,1].

    Returns
    -------
        train_loader: torch.utils.DataLoader
            Dataloader containing training data.
        valid_loader: torch.utils.DataLoader
            Dataloader containing validation data.

    """

    dataset_path = cfg.INPUT_DIR
    metadata_filename = cfg.METADATA_FILENAME
    batch_size = cfg.TRAIN.BATCH_SIZE
    sample_size = cfg.TRAIN.SAMPLE_SIZE
    valid_split = cfg.TRAIN.VALID_SPLIT

    firstcrop = FirstCrop(0.3)
    rescale = Rescale((64, 64))
    random_crop = RandomCrop((54, 54))
    to_tensor = ToTensor()

    # Declare transformations

    transform = transforms.Compose([
        firstcrop,
        rescale,
        random_crop,
        to_tensor,
    ])

    # index 0 for train subset, 1 for extra subset
    metadata_train = load_obj(metadata_filename[0])
    metadata_extra = load_obj(metadata_filename[1])

    train_data_dir = dataset_path[0]
    extra_data_dir = dataset_path[1]

    valid_split_train = valid_split[0]
    valid_split_extra = valid_split[1]

    # Initialize the combined Dataset
    dataset = FullSVHNDataset(metadata_train,
                              metadata_extra,
                              train_data_dir,
                              extra_data_dir,
                              transform=transform)

    # Indices of the 'extra' samples start right after the 'train' samples
    # (assumes FullSVHNDataset orders train before extra — TODO confirm).
    indices_train = np.arange(len(metadata_train))
    indices_extra = np.arange(len(metadata_train),
                              len(metadata_extra) + len(metadata_train))

    # Only use a sample amount of data
    if sample_size[0] != -1:
        indices_train = indices_train[:sample_size[0]]

    if sample_size[1] != -1:
        indices_extra = indices_extra[:sample_size[1]]

    # Select the indices to use for the train/valid split from the 'train' subset
    cut_train = round(valid_split_train * len(indices_train))
    train_idx_train = indices_train[:cut_train]
    valid_idx_train = indices_train[cut_train:]

    # Select the indices to use for the train/valid split from the 'extra' subset
    cut_extra = round(valid_split_extra * len(indices_extra))
    train_idx_extra = indices_extra[:cut_extra]
    valid_idx_extra = indices_extra[cut_extra:]

    # Combine indices from 'train' and 'extra' as one single train/validation split
    train_idx = np.concatenate((train_idx_train, train_idx_extra))
    valid_idx = np.concatenate((valid_idx_train, valid_idx_extra))

    # Define the data samplers
    train_sampler = torch.utils.data.SubsetRandomSampler(train_idx)
    valid_sampler = torch.utils.data.SubsetRandomSampler(valid_idx)

    # Prepare a train and validation dataloader
    train_loader = DataLoader(dataset,
                              batch_size=batch_size,
                              shuffle=False,
                              num_workers=4,
                              sampler=train_sampler)

    valid_loader = DataLoader(dataset,
                              batch_size=batch_size,
                              shuffle=False,
                              num_workers=4,
                              sampler=valid_sampler)

    return train_loader, valid_loader
Example #7
0
def main():
    """Train and validate a segmentation model as described by a YAML config.

    Reads the config path from the parsed command-line arguments, builds the
    PASCALVOC train/val loaders, the model, optimizer, optional LR scheduler
    and loss, then runs the train/validation loop for ``CONFIG.max_epoch``
    epochs. Along the way it writes periodic checkpoints, the best-mean-IoU
    weights and the final weights into ``CONFIG.result_path`` and, when
    ``CONFIG.writer_flag`` is set, logs scalars through a ``SummaryWriter``.
    """

    args = get_arguments()

    # configuration
    # Use a context manager so the config file handle is closed right away.
    with open(args.config) as config_file:
        CONFIG = Dict(yaml.safe_load(config_file))

    # writer
    if CONFIG.writer_flag:
        writer = SummaryWriter(CONFIG.result_path)
    else:
        writer = None

    # DataLoaders
    train_data = PASCALVOC(
        CONFIG,
        mode="train",
        transform=Compose([
            RandomCrop(CONFIG),
            Resize(CONFIG),
            RandomFlip(),
            ToTensor(),
            Normalize(mean=get_mean(), std=get_std()),
        ])
    )

    # NOTE(review): validation also uses RandomCrop, so validation metrics
    # may vary between runs — confirm this is intended.
    val_data = PASCALVOC(
        CONFIG,
        mode="val",
        transform=Compose([
            RandomCrop(CONFIG),
            Resize(CONFIG),
            ToTensor(),
            Normalize(mean=get_mean(), std=get_std()),
        ])
    )

    train_loader = DataLoader(
        train_data,
        batch_size=CONFIG.batch_size,
        shuffle=True,
        num_workers=CONFIG.num_workers,
        drop_last=True
    )

    val_loader = DataLoader(
        val_data,
        batch_size=CONFIG.batch_size,
        shuffle=False,
        num_workers=CONFIG.num_workers
    )

    # load model
    print('\n------------------------Loading Model------------------------\n')

    if CONFIG.attention == 'dual':
        model = DANet(CONFIG)
        print('Dual Attintion modules will be added to this base model')
    elif CONFIG.attention == 'channel':
        model = CANet(CONFIG)
        print('Channel Attintion modules will be added to this base model')
    else:
        if CONFIG.model == 'drn_d_22':
            print(
                'Dilated ResNet D 22 w/o Dual Attention modules will be used as a model.')
            model = drn_d_22(pretrained=True, num_classes=CONFIG.n_classes)
        elif CONFIG.model == 'drn_d_38':
            print(
                'Dilated ResNet D 38 w/o Dual Attention modules will be used as a model.')
            model = drn_d_38(pretrained=True, num_classes=CONFIG.n_classes)
        else:
            print('There is no option you chose as a model.')
            print(
                'Therefore, Dilated ResNet D 22 w/o Dual Attention modules will be used as a model.')
            model = drn_d_22(pretrained=True, num_classes=CONFIG.n_classes)

    # set optimizer, lr_scheduler
    if CONFIG.optimizer == 'Adam':
        print(CONFIG.optimizer + ' will be used as an optimizer.')
        optimizer = optim.Adam(model.parameters(), lr=CONFIG.learning_rate)
    elif CONFIG.optimizer == 'SGD':
        print(CONFIG.optimizer + ' will be used as an optimizer.')
        optimizer = optim.SGD(
            model.parameters(),
            lr=CONFIG.learning_rate,
            momentum=CONFIG.momentum,
            dampening=CONFIG.dampening,
            weight_decay=CONFIG.weight_decay,
            nesterov=CONFIG.nesterov)
    elif CONFIG.optimizer == 'AdaBound':
        print(CONFIG.optimizer + ' will be used as an optimizer.')
        optimizer = adabound.AdaBound(
            model.parameters(),
            lr=CONFIG.learning_rate,
            final_lr=CONFIG.final_lr,
            weight_decay=CONFIG.weight_decay)
    else:
        # Adjacent literals are concatenated, which avoids embedding the
        # source indentation in the printed message.
        print('There is no optimizer which suits to your option. '
              'Instead, SGD will be used as an optimizer.')
        optimizer = optim.SGD(
            model.parameters(),
            lr=CONFIG.learning_rate,
            momentum=CONFIG.momentum,
            dampening=CONFIG.dampening,
            weight_decay=CONFIG.weight_decay,
            nesterov=CONFIG.nesterov)

    # learning rate scheduler
    # NOTE(review): the SGD fallback above (unknown optimizer name) gets no
    # scheduler because the check is on the configured name — confirm intent.
    if CONFIG.optimizer == 'SGD':
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, 'min', patience=CONFIG.lr_patience)
    else:
        scheduler = None

    # send the model to cuda/cpu
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.to(device)
    if device == 'cuda':
        model = torch.nn.DataParallel(model)  # make parallel
        torch.backends.cudnn.benchmark = True

    # resume if you want
    begin_epoch = 0
    if args.resume:
        if os.path.exists(os.path.join(CONFIG.result_path, 'checkpoint.pth')):
            print('loading the checkpoint...')
            begin_epoch, model, optimizer, scheduler = \
                resume(CONFIG, model, optimizer, scheduler)
            print('training will start from {} epoch'.format(begin_epoch))

    # criterion for loss
    if CONFIG.class_weight:
        criterion = nn.CrossEntropyLoss(
            weight=get_class_weight().to(device),
            ignore_index=255
        )
    else:
        criterion = nn.CrossEntropyLoss(ignore_index=255)

    # train and validate model
    print('\n------------------------Start training------------------------\n')
    losses_train = []
    losses_val = []
    val_ious = []
    mean_ious = []
    mean_ious_without_bg = []
    best_mean_iou = 0.0

    for epoch in range(begin_epoch, CONFIG.max_epoch):
        # training
        loss_train = train(
            model, train_loader, criterion, optimizer, CONFIG, device)
        losses_train.append(loss_train)

        # validation
        val_iou, loss_val = validation(
            model, val_loader, criterion, CONFIG, device)
        val_ious.append(val_iou)
        losses_val.append(loss_val)
        if CONFIG.optimizer == 'SGD':
            scheduler.step(loss_val)

        # Mean over all entries, and over all entries but the first
        # (index 0 is treated as background).
        mean_ious.append(val_ious[-1].mean().item())
        mean_ious_without_bg.append(val_ious[-1][1:].mean().item())

        # save checkpoint every 5 epoch
        if epoch % 5 == 0 and epoch != 0:
            save_checkpoint(CONFIG, epoch, model, optimizer, scheduler)

        # save a model every 50 epoch
        if epoch % 50 == 0 and epoch != 0:
            torch.save(
                model.state_dict(), os.path.join(CONFIG.result_path, 'epoch_{}_model.prm'.format(epoch)))

        if best_mean_iou < mean_ious[-1]:
            best_mean_iou = mean_ious[-1]
            torch.save(
                model.state_dict(), os.path.join(CONFIG.result_path, 'best_mean_iou_model.prm'))

        # tensorboardx
        if writer:
            writer.add_scalars(
                "loss", {
                    'loss_train': losses_train[-1],
                    'loss_val': losses_val[-1]}, epoch)
            writer.add_scalar(
                "mean_iou", mean_ious[-1], epoch)
            writer.add_scalar(
                "mean_iou_w/o_bg", mean_ious_without_bg[-1], epoch)

        print(
            'epoch: {}\tloss_train: {:.5f}\tloss_val: {:.5f}\tmean IOU: {:.3f}\tmean IOU w/o bg: {:.3f}'.format(
                epoch, losses_train[-1], losses_val[-1], mean_ious[-1], mean_ious_without_bg[-1])
        )

    torch.save(
        model.state_dict(), os.path.join(CONFIG.result_path, 'final_model.prm'))

    # Flush pending events and release the writer's file handle.
    if writer is not None:
        writer.close()
Example #8
0
def create_loaders(dataset, inputs, train_dir, val_dir, train_list, val_list,
                   shorter_side, crop_size, input_size, low_scale, high_scale,
                   normalise_params, batch_size, num_workers, ignore_label):
    """
    Create the training and validation data loaders.

    Args:
      dataset (str) : dataset name; 'nyudv2' selects CropAlignToMask, any other
                      value selects ResizeAlignToMask.
      inputs : accepted but not used here; the input names are fixed to
               ['rgb', 'depth'].
      train_dir (str) : path to the root directory of the training set.
      val_dir (str) : path to the root directory of the validation set.
      train_list (str) : path to the training list.
      val_list (str) : path to the validation list.
      shorter_side (int) : parameter of the shorter_side resize transformation.
      crop_size (int) : square crop to apply during the training.
      input_size : size passed to the ResizeInputs transformation.
      low_scale (float) : lowest scale ratio for augmentations.
      high_scale (float) : highest scale ratio for augmentations.
      normalise_params (list / tuple) : img_scale, img_mean, img_std.
      batch_size (int) : training batch size.
      num_workers (int) : number of workers to parallelise data loading operations.
      ignore_label (int) : label to pad segmentation masks with

    Returns:
      train_loader, val loader (the validation loader uses batch size 1)

    """
    # Torch libraries
    from torchvision import transforms
    from torch.utils.data import DataLoader, random_split
    # Custom libraries
    from utils.datasets import SegDataset as Dataset
    from utils.transforms import (
        Normalise, Pad, RandomCrop, RandomMirror, ResizeAndScale,
        CropAlignToMask, ResizeAlignToMask, ToTensor, ResizeInputs,
    )

    input_names = ['rgb', 'depth']
    input_mask_idxs = [0, 2, 1]

    if dataset == 'nyudv2':
        AlignToMask = CropAlignToMask
    else:
        AlignToMask = ResizeAlignToMask

    # Training pipeline: alignment, scale/pad/mirror/crop augmentation,
    # then resize, normalise and tensor conversion.
    # NOTE(review): the Pad constants are presumably per-channel fill values
    # for the image — confirm.
    train_steps = [
        AlignToMask(),
        ResizeAndScale(shorter_side, low_scale, high_scale),
        Pad(crop_size, [123.675, 116.28 , 103.53], ignore_label),
        RandomMirror(),
        RandomCrop(crop_size),
        ResizeInputs(input_size),
        Normalise(*normalise_params),
        ToTensor(),
    ]
    composed_trn = transforms.Compose(train_steps)

    # Validation pipeline: no augmentation.
    val_steps = [
        AlignToMask(),
        ResizeInputs(input_size),
        Normalise(*normalise_params),
        ToTensor(),
    ]
    composed_val = transforms.Compose(val_steps)

    # Keyword arguments shared by the training and validation sets
    shared_kwargs = dict(dataset=dataset,
                         input_names=input_names,
                         input_mask_idxs=input_mask_idxs,
                         transform_val=composed_val,
                         ignore_label=ignore_label)

    trainset = Dataset(data_file=train_list, data_dir=train_dir,
                       transform_trn=composed_trn, stage='train',
                       **shared_kwargs)
    validset = Dataset(data_file=val_list, data_dir=val_dir,
                       transform_trn=None, stage='val',
                       **shared_kwargs)

    n_train, n_val = len(trainset), len(validset)
    print_log('Created train set {} examples, val set {} examples'.format(n_train, n_val))

    # Training and validation loaders
    train_loader = DataLoader(trainset,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=num_workers,
                              pin_memory=True,
                              drop_last=True)
    val_loader = DataLoader(validset,
                            batch_size=1,
                            shuffle=False,
                            num_workers=num_workers,
                            pin_memory=True)

    return train_loader, val_loader