Example #1
0
def get_dataloader(module_name, module_args):
    """Build the train and validation DataLoaders described by *module_args*.

    Args:
        module_name: name of the dataset module passed through to get_dataset.
        module_args: config dict with 'dataset' and 'loader' sections;
            'dataset' must contain 'img_type', 'train_data_path',
            'train_data_ratio' and 'val_data_path'.

    Returns:
        (train_loader, val_loader); val_loader is None when the validation
        list is empty.

    Raises:
        Exception: when no training images are found.
    """
    # Pick the transform backend matching the configured image type.
    if module_args['dataset']['img_type'] == 'cv':
        from opencv_transforms import opencv_transforms as transforms
    else:
        from torchvision import transforms

    train_transform = transforms.Compose([
        transforms.ColorJitter(brightness=0.5),
        transforms.ToTensor(),
    ])
    val_transform = transforms.ToTensor()

    # Build the datasets. Path/ratio keys are popped so that the remaining
    # args can be forwarded verbatim to the dataset constructor.
    dataset_args = copy.deepcopy(module_args['dataset'])
    train_data_path = dataset_args.pop('train_data_path')
    train_data_ratio = dataset_args.pop('train_data_ratio')
    val_data_path = dataset_args.pop('val_data_path')

    train_data_list, val_data_list = get_datalist(
        train_data_path, val_data_path,
        module_args['loader']['validation_split'])

    train_dataset_list = [
        get_dataset(data_list=entry,
                    module_name=module_name,
                    transform=train_transform,
                    phase='train',
                    dataset_args=dataset_args)
        for entry in train_data_list
    ]

    loader_cfg = module_args['loader']
    if not train_dataset_list:
        raise Exception('no images found')
    if len(train_dataset_list) == 1:
        # Single source: a plain DataLoader is enough.
        train_loader = DataLoader(dataset=train_dataset_list[0],
                                  batch_size=loader_cfg['train_batch_size'],
                                  shuffle=loader_cfg['shuffle'],
                                  num_workers=loader_cfg['num_workers'])
        train_loader.dataset_len = len(train_dataset_list[0])
    else:
        # Multiple sources: balance batches across them by ratio.
        train_loader = dataset.Batch_Balanced_Dataset(
            dataset_list=train_dataset_list,
            ratio_list=train_data_ratio,
            module_args=module_args,
            phase='train')

    val_loader = None
    if val_data_list:
        val_dataset = get_dataset(data_list=val_data_list,
                                  module_name=module_name,
                                  transform=val_transform,
                                  phase='test',
                                  dataset_args=dataset_args)
        val_loader = DataLoader(dataset=val_dataset,
                                batch_size=loader_cfg['val_batch_size'],
                                shuffle=loader_cfg['shuffle'],
                                num_workers=loader_cfg['num_workers'])
        val_loader.dataset_len = len(val_dataset)
    return train_loader, val_loader
Example #2
0
def get_dataloader(module_config, num_label):
    """Create a data loader from *module_config*.

    Args:
        module_config: dict with a 'dataset' section ({'type': name,
            'args': {...}}) and a 'loader' section of DataLoader kwargs.
            May be None, in which case no loader is built.
        num_label: number of labels; injected into the dataset args.

    Returns:
        A DataLoader (or Batch_Balanced_Dataset when several datasets with
        matching ratios are configured) carrying a ``dataset_len`` attribute,
        or None when *module_config* is None.

    Raises:
        Exception: when 'data_path' yields no datasets.
    """
    if module_config is None:
        return None
    # Deep-copy so pops below don't mutate the caller's config.
    config = copy.deepcopy(module_config)
    dataset_args = config['dataset']['args']
    dataset_args['num_label'] = num_label
    # Optional transform pipeline.
    if 'transforms' in dataset_args:
        img_transforms = get_transforms(dataset_args.pop('transforms'))
    else:
        img_transforms = None
    # Build one dataset per configured data path.
    dataset_name = config['dataset']['type']
    data_path_list = dataset_args.pop('data_path')
    data_ratio = dataset_args.pop('data_ratio', [1.0])

    _dataset_list = [get_dataset(data_path=data_path,
                                 module_name=dataset_name,
                                 dataset_args=dataset_args,
                                 transform=img_transforms)
                     for data_path in data_path_list]
    # Robustness: fail loudly on an empty path list instead of an opaque
    # IndexError below (consistent with the sibling loader factories).
    if not _dataset_list:
        raise Exception('no images found')
    # BUG FIX: 'data_ratio' was already popped from dataset_args above, so
    # the original check `len(dataset_args['data_ratio'])` raised KeyError
    # whenever more than one ratio was configured. Compare the popped value.
    if len(data_ratio) > 1 and len(data_ratio) == len(_dataset_list):
        from . import dataset
        loader = dataset.Batch_Balanced_Dataset(dataset_list=_dataset_list,
                                                ratio_list=data_ratio,
                                                loader_args=config['loader'])
    else:
        _dataset = _dataset_list[0]
        loader = DataLoader(dataset=_dataset, **config['loader'])
        loader.dataset_len = len(_dataset)
    return loader
def get_dataloader(module_name, module_args):
    """Build the training DataLoader described by *module_args*.

    Args:
        module_name: dataset module name forwarded to get_dataset.
        module_args: config dict with 'dataset' (containing
            'train_data_path', 'train_data_ratio', 'val_data_path') and
            'loader' sections.

    Returns:
        A DataLoader for a single source, or a Batch_Balanced_Dataset when
        several training sources are configured.

    Raises:
        Exception: when no training images are found.
    """
    train_transform = transforms.Compose([
        transforms.ColorJitter(brightness=0.5),
        transforms.ToTensor(),
    ])

    # Build the datasets; pop path keys so the remaining args can be passed
    # straight to the dataset constructor ('val_data_path' is unused here).
    dataset_args = copy.deepcopy(module_args['dataset'])
    train_data_path = dataset_args.pop('train_data_path')
    train_data_ratio = dataset_args.pop('train_data_ratio')
    dataset_args.pop('val_data_path')

    train_data_list = get_datalist(train_data_path,
                                   module_args['loader']['validation_split'])
    train_dataset_list = [
        get_dataset(data_list=entry,
                    module_name=module_name,
                    transform=train_transform,
                    dataset_args=dataset_args)
        for entry in train_data_list
    ]

    if not train_dataset_list:
        raise Exception('no images found')
    if len(train_dataset_list) == 1:
        # Single source: plain DataLoader.
        loader_cfg = module_args['loader']
        train_loader = DataLoader(dataset=train_dataset_list[0],
                                  batch_size=loader_cfg['train_batch_size'],
                                  shuffle=loader_cfg['shuffle'],
                                  num_workers=loader_cfg['num_workers'])
        train_loader.dataset_len = len(train_dataset_list[0])
    else:
        # Multiple sources: balance batches across them by ratio.
        train_loader = dataset.Batch_Balanced_Dataset(
            dataset_list=train_dataset_list,
            ratio_list=train_data_ratio,
            module_args=module_args,
            phase='train')
    return train_loader