Example #1
# Assumed imports for this snippet; torch_samplers and JointDataset are
# repository-local modules.
import torch
import torch.utils.data
import torchvision.transforms as transforms

def get_train_loader(cfg, num_gpu, is_dist=True, is_shuffle=True, start_iter=0,
                     use_augmentation=True, with_mds=False):
    # -------- get raw dataset interface -------- #
    normalize = transforms.Normalize(mean=cfg.INPUT.MEANS, std=cfg.INPUT.STDS)
    transform = transforms.Compose([transforms.ToTensor(), normalize])
    if cfg.DATASET.NAME == 'MIX':
        Dataset = JointDataset
    else:
        raise NameError("Dataset is not defined!", cfg.DATASET.NAME)

    dataset = Dataset(cfg, 'train', transform, use_augmentation, with_mds)

    # -------- make samplers -------- #
    if is_dist:
        sampler = torch_samplers.DistributedSampler(
                dataset, shuffle=is_shuffle)
    elif is_shuffle:
        sampler = torch.utils.data.sampler.RandomSampler(dataset)
    else:
        sampler = torch.utils.data.sampler.SequentialSampler(dataset)

    images_per_gpu = cfg.SOLVER.IMG_PER_GPU

    aspect_grouping = [1] if cfg.DATALOADER.ASPECT_RATIO_GROUPING else []
    if aspect_grouping:
        batch_sampler = torch_samplers.GroupedBatchSampler(
                sampler, dataset, aspect_grouping, images_per_gpu,
                drop_uneven=False)
    else:
        batch_sampler = torch.utils.data.sampler.BatchSampler(
                sampler, images_per_gpu, drop_last=False)

    # wrap so the loader yields exactly MAX_ITER batches, resumable at start_iter
    batch_sampler = torch_samplers.IterationBasedBatchSampler(
            batch_sampler, cfg.SOLVER.MAX_ITER, start_iter)

    # -------- make data_loader -------- #
    class BatchCollator(object):
        # Note: size_divisible is stored but never used here; all samples
        # are assumed to share the same spatial size, so plain stacking works.
        def __init__(self, size_divisible):
            self.size_divisible = size_divisible

        def __call__(self, batch):
            transposed_batch = list(zip(*batch))
            images = torch.stack(transposed_batch[0], dim=0)
            valids = torch.stack(transposed_batch[1], dim=0)
            labels = torch.stack(transposed_batch[2], dim=0)
            rdepth = torch.stack(transposed_batch[3], dim=0)
            return images, valids, labels, rdepth

    data_loader = torch.utils.data.DataLoader(
            dataset, num_workers=cfg.DATALOADER.NUM_WORKERS,
            batch_sampler=batch_sampler,
            collate_fn=BatchCollator(cfg.DATALOADER.SIZE_DIVISIBILITY))

    return data_loader
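
A minimal single-process usage sketch for the loader above; build_config() and the exact cfg keys are assumptions for illustration, not part of the source:

# Hedged usage sketch -- the cfg layout and build_config() helper below
# are assumptions, not part of the source snippet.
cfg = build_config()   # hypothetical yacs-style config providing
                       # INPUT.MEANS/STDS, DATASET.NAME='MIX',
                       # SOLVER.IMG_PER_GPU, SOLVER.MAX_ITER, and
                       # the DATALOADER.* keys used above
loader = get_train_loader(cfg, num_gpu=1, is_dist=False, is_shuffle=True)
for images, valids, labels, rdepth in loader:
    pass  # one training step per batch; iteration stops after MAX_ITER batches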
Example #2
# Assumed imports for this snippet; torch_samplers, load_dataset, and the
# dataset classes are repository-local modules.
import torch
import torch.utils.data
import torchvision.transforms as transforms

def get_train_loader(cfg,
                     num_gpu,
                     is_dist=True,
                     is_shuffle=True,
                     start_iter=0):
    # -------- get raw dataset interface -------- #
    normalize = transforms.Normalize(mean=cfg.INPUT.MEANS, std=cfg.INPUT.STDS)
    transform = transforms.Compose([transforms.ToTensor(), normalize])
    attr = load_dataset(cfg.DATASET.NAME)
    if cfg.DATASET.NAME == 'COCO':
        Dataset = COCODataset
    elif cfg.DATASET.NAME == 'MPII':
        Dataset = MPIIDataset
    else:
        raise NameError("Dataset is not defined!", cfg.DATASET.NAME)
    dataset = Dataset(attr, 'train', transform)
    #print("测试dataset类:"+str(dataset.data_num))

    # -------- make samplers -------- #
    if is_dist:
        sampler = torch_samplers.DistributedSampler(dataset,
                                                    shuffle=is_shuffle)
    elif is_shuffle:
        sampler = torch.utils.data.sampler.RandomSampler(dataset)
    else:
        sampler = torch.utils.data.sampler.SequentialSampler(dataset)

    images_per_gpu = cfg.SOLVER.IMS_PER_GPU
    # images_per_gpu = cfg.SOLVER.IMS_PER_BATCH // num_gpu

    aspect_grouping = [1] if cfg.DATALOADER.ASPECT_RATIO_GROUPING else []
    if aspect_grouping:
        batch_sampler = torch_samplers.GroupedBatchSampler(sampler,
                                                           dataset,
                                                           aspect_grouping,
                                                           images_per_gpu,
                                                           drop_uneven=False)
    else:
        batch_sampler = torch.utils.data.sampler.BatchSampler(sampler,
                                                              images_per_gpu,
                                                              drop_last=False)

    # wrap so the loader yields exactly MAX_ITER batches, resumable at start_iter
    batch_sampler = torch_samplers.IterationBasedBatchSampler(
        batch_sampler, cfg.SOLVER.MAX_ITER, start_iter)

    # -------- make data_loader -------- #
    class BatchCollator(object):
        def __init__(self, size_divisible):
            self.size_divisible = size_divisible

        def __call__(self, batch):
            transposed_batch = list(zip(*batch))
            images = torch.stack(transposed_batch[0], dim=0)
            valids = torch.stack(transposed_batch[1], dim=0)
            labels = torch.stack(transposed_batch[2], dim=0)

            return images, valids, labels

    data_loader = torch.utils.data.DataLoader(
        dataset,
        num_workers=cfg.DATALOADER.NUM_WORKERS,
        batch_sampler=batch_sampler,
        collate_fn=BatchCollator(cfg.DATALOADER.SIZE_DIVISIBILITY),
    )

    return data_loader
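
A sketch of distributed usage for this variant, where is_dist=True routes sampling through torch_samplers.DistributedSampler; the process-group setup and build_config() helper are assumptions, not part of the source:

# Hedged distributed usage sketch -- assumes launch via torchrun (which
# provides the env:// rendezvous variables) and a hypothetical config helper.
import torch.distributed as dist

dist.init_process_group(backend='nccl')
cfg = build_config()   # hypothetical yacs-style config with SOLVER.IMS_PER_GPU,
                       # SOLVER.MAX_ITER, and the DATALOADER.* keys used above
loader = get_train_loader(cfg, num_gpu=dist.get_world_size(),
                          is_dist=True, is_shuffle=True, start_iter=0)
for images, valids, labels in loader:
    pass  # each rank sees a disjoint shard via the distributed sampler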