示例#1
0
def get_dataloader(module_name, module_args, num_label):
    """Build the train and validation DataLoaders from a config dict.

    Args:
        module_name: name of the dataset module passed through to ``get_dataset``.
        module_args: config dict with ``'dataset'`` and ``'loader'`` sections.
            NOTE(review): ``module_args['dataset']`` is mutated in place
            (``pop`` of path/ratio keys, insertion of ``num_label``).
        num_label: number of labels, forwarded to the dataset via ``dataset_args``.

    Returns:
        tuple ``(train_loader, val_loader)``; ``val_loader`` is ``None`` when
        no validation data is available.

    Raises:
        Exception: if no training images are found.
    """
    # Training pipeline includes color-jitter augmentation; validation does not.
    train_transforms = transforms.Compose(
        [transforms.RandomColorJitter(brightness=0.5),
         transforms.ToTensor()])
    val_transforms = transforms.ToTensor()

    # Build the datasets.
    dataset_args = module_args['dataset']
    train_data_path = dataset_args.pop('train_data_path')
    train_data_ratio = dataset_args.pop('train_data_ratio')
    val_data_path = dataset_args.pop('val_data_path')

    train_data_list, val_data_list = get_datalist(
        train_data_path, val_data_path,
        module_args['loader']['validation_split'])

    dataset_args['num_label'] = num_label
    train_dataset_list = [
        get_dataset(data_list=train_data,
                    module_name=module_name,
                    phase='train',
                    dataset_args=dataset_args)
        for train_data in train_data_list]

    if len(train_dataset_list) > 1:
        # Multiple sources: sample from each according to train_data_ratio.
        train_loader = dataset.Batch_Balanced_Dataset(
            dataset_list=train_dataset_list,
            ratio_list=train_data_ratio,
            module_args=module_args,
            dataset_transfroms=train_transforms,
            phase='train')
    elif len(train_dataset_list) == 1:
        # BUG FIX: the original applied val_transforms here, silently dropping
        # the augmentation pipeline for single-source training data.
        train_loader = DataLoader(
            dataset=train_dataset_list[0].transform_first(train_transforms),
            batch_size=module_args['loader']['train_batch_size'],
            shuffle=module_args['loader']['shuffle'],
            last_batch='rollover',
            num_workers=module_args['loader']['num_workers'])
        # Expose the underlying dataset length for consumers of the loader.
        train_loader.dataset_len = len(train_dataset_list[0])
    else:
        raise Exception('no images found')

    if len(val_data_list):
        val_dataset = get_dataset(data_list=val_data_list,
                                  module_name=module_name,
                                  phase='test',
                                  dataset_args=dataset_args)
        val_loader = DataLoader(
            dataset=val_dataset.transform_first(val_transforms),
            batch_size=module_args['loader']['val_batch_size'],
            shuffle=module_args['loader']['shuffle'],
            last_batch='rollover',
            num_workers=module_args['loader']['num_workers'])
        val_loader.dataset_len = len(val_dataset)
    else:
        val_loader = None
    return train_loader, val_loader
示例#2
0
def get_dataloader(module_config, num_label, alphabet):
    """Build a DataLoader (or balanced multi-dataset loader) from a config dict.

    Args:
        module_config: dict with ``'dataset'`` (``type``/``args``) and
            ``'loader'`` sections, or ``None``. A deep copy is taken, so the
            caller's config is not mutated.
        num_label: number of labels, injected into the dataset args.
        alphabet: character alphabet, injected into the dataset args.

    Returns:
        A ``Batch_Balanced_Dataset`` when several datasets with matching
        ratios are configured, otherwise a plain ``DataLoader`` over the
        first dataset; ``None`` when ``module_config`` is ``None``.
    """
    if module_config is None:
        return None
    # Work on a copy: dataset_args is mutated (pop / key insertion) below.
    config = copy.deepcopy(module_config)
    dataset_args = config['dataset']['args']
    dataset_args['num_label'] = num_label
    dataset_args['alphabet'] = alphabet
    if 'transforms' in dataset_args:
        img_transforms = get_transforms(dataset_args.pop('transforms'))
    else:
        img_transforms = None

    # Build the datasets, one per configured data path.
    dataset_name = config['dataset']['type']
    data_path_list = dataset_args.pop('data_path')
    data_ratio = dataset_args.pop('data_ratio') if 'data_ratio' in dataset_args else [1.0]

    _dataset_list = [
        get_dataset(data_path=data_path,
                    module_name=dataset_name,
                    dataset_args=dataset_args)
        for data_path in data_path_list]

    # BUG FIX: the original checked len(dataset_args['data_ratio']), but
    # 'data_ratio' was already popped from dataset_args above, so that lookup
    # raised KeyError whenever ratios were configured. Use the local instead.
    if len(data_ratio) > 1 and len(data_ratio) == len(_dataset_list):
        from . import dataset
        loader = dataset.Batch_Balanced_Dataset(
            dataset_list=_dataset_list,
            ratio_list=data_ratio,
            loader_args=config['loader'],
            dataset_transfroms=img_transforms,
            phase='train')
    else:
        _dataset = _dataset_list[0]
        loader = DataLoader(dataset=_dataset.transform_first(img_transforms),
                            **config['loader'])
        # Expose the underlying dataset length for consumers of the loader.
        loader.dataset_len = len(_dataset)
    return loader