Example #1
def make_data_loader(config, phase, batch_size, num_threads=0, shuffle=None):
    assert phase in ['train', 'trainval', 'val', 'test']
    if shuffle is None:
        # shuffle by default for every phase except test
        shuffle = phase != 'test'

    if config.dataset not in dataset_str_mapping.keys():
        logging.error(f'Dataset {config.dataset} does not exist in ' +
                      ', '.join(dataset_str_mapping.keys()))

    Dataset = dataset_str_mapping[config.dataset]

    use_random_scale = False
    use_random_rotation = False
    transforms = []
    if phase in ['train', 'trainval']:
        # enable augmentation only for the training phases
        use_random_rotation = config.use_random_rotation
        use_random_scale = config.use_random_scale
        transforms += [t.Jitter()]

    dset = Dataset(
        phase,
        transform=t.Compose(transforms),
        random_scale=use_random_scale,
        random_rotation=use_random_rotation,
        config=config)

    loader = torch.utils.data.DataLoader(
        dset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_threads,
        collate_fn=collate_pair_fn,
        pin_memory=False,
        drop_last=True)

    return loader
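
For reference, here is a minimal, self-contained sketch of the same registry-plus-factory pattern, including the shuffle default. DummyDataset, dataset_registry, and make_loader are hypothetical stand-ins, not names from the codebase above.

import torch


class DummyDataset(torch.utils.data.Dataset):
    """Hypothetical dataset standing in for the project's datasets."""

    def __init__(self, phase, size=8):
        self.phase = phase
        self.data = torch.arange(size, dtype=torch.float32)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


dataset_registry = {'DummyDataset': DummyDataset}


def make_loader(name, phase, batch_size, shuffle=None):
    if shuffle is None:
        # same default as above: shuffle every phase except test
        shuffle = phase != 'test'
    dset = dataset_registry[name](phase)
    return torch.utils.data.DataLoader(
        dset, batch_size=batch_size, shuffle=shuffle, drop_last=True)


loader = make_loader('DummyDataset', 'train', batch_size=4)
print(next(iter(loader)))  # one shuffled batch of 4 values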
Example #2
def make_data_loader(config, batch_size, num_threads=0):
    if config.data.dataset not in dataset_str_mapping.keys():
        logging.error(f'Dataset {config.data.dataset} does not exist in ' +
                      ', '.join(dataset_str_mapping.keys()))

    Dataset = dataset_str_mapping[config.data.dataset]

    transforms = []
    use_random_rotation = config.trainer.use_random_rotation
    use_random_scale = config.trainer.use_random_scale
    transforms += [t.Jitter()]

    dset = Dataset(phase="train",
                   transform=t.Compose(transforms),
                   random_scale=use_random_scale,
                   random_rotation=use_random_rotation,
                   config=config)
    collate_pair_fn = default_collate_pair_fn
    # split the global batch size evenly across GPUs
    batch_size = batch_size // config.misc.num_gpus

    if config.misc.num_gpus > 1:
        # the distributed sampler shards and shuffles the data per rank
        sampler = DistributedInfSampler(dset)
    else:
        sampler = None

    loader = torch.utils.data.DataLoader(dset,
                                         batch_size=batch_size,
                                         shuffle=sampler is None,
                                         num_workers=num_threads,
                                         collate_fn=collate_pair_fn,
                                         pin_memory=False,
                                         sampler=sampler,
                                         drop_last=True)

    return loader
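
The per-GPU batch-size split and the sampler/shuffle interplay above can be illustrated with torch's built-in DistributedSampler; DistributedInfSampler is project-specific, so this sketch only approximates it. Passing num_replicas and rank explicitly avoids needing an initialized process group for the illustration.

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dset = TensorDataset(torch.arange(16))
global_batch_size, num_gpus, rank = 8, 2, 0
per_gpu_batch_size = global_batch_size // num_gpus  # 4 samples per rank

sampler = DistributedSampler(dset, num_replicas=num_gpus, rank=rank,
                             shuffle=True)
# a loader given a sampler must not also shuffle, which is why the code
# above sets shuffle=False whenever a sampler is passed
loader = DataLoader(dset, batch_size=per_gpu_batch_size, sampler=sampler,
                    drop_last=True)
print([batch[0].tolist() for batch in loader])  # this rank's half of the data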
Example #3
def make_data_loader(config, phase, batch_size, num_threads=0, shuffle=None):
    assert phase in ['train', 'trainval', 'val', 'test']
    if shuffle is None:
        shuffle = phase != 'test'

    if config.dataset not in dataset_str_mapping.keys():
        logging.error(f'Dataset {config.dataset} does not exist in ' +
                      ', '.join(dataset_str_mapping.keys()))

    Dataset = dataset_str_mapping[config.dataset]

    use_random_scale = False
    use_random_rotation = False
    transforms = []
    if phase in ['train', 'trainval']:
        use_random_rotation = config.use_random_rotation
        use_random_scale = config.use_random_scale
        transforms += [t.Jitter()]

    if config.dataset == "KITTIMapDataset":
        #import sys
        import importlib

        split = phase
        cfg = importlib.import_module("ext.benchmark_tools.configs.config")
        dset = Dataset(split, cfg,
                       phase,
                       transform=t.Compose(transforms),
                       random_scale=use_random_scale,
                       random_rotation=use_random_rotation,
                       config=config)

    elif config.dataset == "ArgoverseMapDataset":
        #import sys
        import importlib

        split = phase
        cfg = importlib.import_module("ext.benchmark_tools.configs.config")
        dset = Dataset(split, cfg,
                       phase,
                       transform=t.Compose(transforms),
                       random_scale=use_random_scale,
                       random_rotation=use_random_rotation,
                       config=config)

    else:
        dset = Dataset(
            phase,
            transform=t.Compose(transforms),
            random_scale=use_random_scale,
            random_rotation=use_random_rotation,
            config=config)


    loader = torch.utils.data.DataLoader(
        dset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_threads,
        collate_fn=collate_pair_fn,
        pin_memory=False,
        drop_last=True)

    return loader
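
The importlib.import_module call above resolves the benchmark config from its dotted path at runtime rather than through a hard-coded import. A minimal sketch of that pattern, using the standard-library math module as a stand-in for ext.benchmark_tools.configs.config:

import importlib

module_path = "math"  # stand-in; the code above uses a project dotted path
cfg = importlib.import_module(module_path)
print(cfg.pi)  # attributes read exactly as after a normal import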