Пример #1
0
def LoadData(MNISTsaveFolder, imsize=(54, 98), train=True, batch_size=32, num_works=0, DataSetName=MNIST):
    """Build a shuffled DataLoader over one or several torchvision-style datasets.

    Parameters
    ----------
    MNISTsaveFolder : str
        Root directory where the dataset is stored / downloaded to.
    imsize : sequence of two ints, default (54, 98)
        Target image size handed to the transform pipeline.
        (The original MNIST images are 28x28.)
    train : bool, default True
        Select the training split when True, otherwise the test split.
    batch_size : int, default 32
        Mini-batch size; incomplete final batches are dropped.
    num_works : int, default 0
        Number of DataLoader worker processes.
    DataSetName : dataset class or tuple of classes, default MNIST
        A single torchvision dataset class, or a tuple of such classes
        whose instances are concatenated into one dataset.

    Returns
    -------
    torch.utils.data.DataLoader
    """
    # NOTE: the default was the mutable list [54, 98]; a tuple avoids the
    # shared-mutable-default pitfall while remaining backward compatible.
    datamean = 0.5
    datastd = 0.5
    Trans = trasnFcn(imsize, datamean=datamean, datastd=datastd)

    # One code path for both splits. The original duplicated the whole
    # branch for train/test, and the non-tuple test branch forgot
    # download=True, so a missing test set raised instead of downloading.
    if isinstance(DataSetName, tuple):
        subsets = [SUBsetName(root=MNISTsaveFolder, train=train,
                              transform=Trans, download=True)
                   for SUBsetName in DataSetName]
        data_set = ConcatDataset(subsets)
    else:
        data_set = DataSetName(root=MNISTsaveFolder, train=train,
                               transform=Trans, download=True)

    dataLoader = DataLoader(dataset=data_set, batch_size=batch_size,
                            shuffle=True, num_workers=num_works,
                            drop_last=True)
    return dataLoader
Пример #2
0
def get_train_loader(conf):
    """Construct the training DataLoader selected by ``conf.data_mode``.

    Supported modes: 'ms1m', 'vgg', 'concat' (ms1m + vgg with shifted
    labels), 'emore', one of the four race folders ('African', 'Caucasian',
    'Asian', 'Indian'), or 'ccf' (all race sub-folders concatenated and
    balanced with a WeightedRandomSampler).

    Side effects: sets ``conf.race_num`` to the per-dataset class counts
    and prints progress diagnostics.

    Returns
    -------
    (torch.utils.data.DataLoader, int)
        The loader and the total number of classes.
    """
    if conf.data_mode in ['ms1m', 'concat']:
        ms1m_ds, ms1m_class_num = get_train_dataset(conf.ms1m_folder/'imgs')
        print('ms1m loader generated')
    if conf.data_mode in ['vgg', 'concat']:
        vgg_ds, vgg_class_num = get_train_dataset(conf.vgg_folder/'imgs')
        print('vgg loader generated')        
    if conf.data_mode == 'vgg':
        ds = vgg_ds
        class_num = vgg_class_num
    elif conf.data_mode == 'ms1m':
        ds = ms1m_ds
        class_num = ms1m_class_num
    elif conf.data_mode == 'concat':
        # Shift vgg labels past the ms1m label range before concatenating
        # so the two label spaces do not collide.
        for i, (url, label) in enumerate(vgg_ds.imgs):
            vgg_ds.imgs[i] = (url, label + ms1m_class_num)
        ds = ConcatDataset([ms1m_ds, vgg_ds])
        class_num = vgg_class_num + ms1m_class_num
    elif conf.data_mode == 'emore':
        ds, class_num = get_train_dataset(conf.emore_folder/'imgs')
    elif conf.data_mode in ('African', 'Caucasian', 'Asian', 'Indian'):
        # Collapses four copy-pasted branches: the mode name is the folder.
        ds, class_num = get_train_dataset(conf.ccf_folder/conf.data_mode)
    elif conf.data_mode == 'ccf':
        ds = []
        class_num = []
        imgs_num = []
        for path in conf.ccf_folder.iterdir():
            if path.is_file():
                continue
            ds_tmp, class_num_tmp = get_train_dataset(path)
            ds.append(ds_tmp)
            class_num.append(class_num_tmp)
            imgs_num.append(len(ds_tmp))
        # Offset each sub-dataset's labels by the classes that precede it
        # so the concatenated label space is disjoint.
        for j, sub_ds in enumerate(ds):
            if j > 0:
                offset = sum(class_num[:j])
                for i, (url, label) in enumerate(sub_ds.imgs):
                    sub_ds.imgs[i] = (url, label + offset)
        ds = ConcatDataset(ds)
    print('##################################')
    print(conf.batch_size)
    conf.race_num = class_num

    if isinstance(class_num, list):
        # 'ccf' mode: balance the sub-datasets with a weighted sampler.
        # Every image in sub-dataset i gets weight total_classes // classes_i
        # (same value as the original per-image list comprehension).
        total_classes = sum(class_num)
        weights = []
        # BUG FIX: was hard-coded range(4); now follows however many
        # sub-folders were actually found.
        for i in range(len(class_num)):
            weights += [total_classes // class_num[i]] * imgs_num[i]
        print(len(ds))
        print(len(weights))
        assert len(ds) == len(weights)
        weights = torch.FloatTensor(weights)

        train_sampler = WeightedRandomSampler(weights, len(ds), replacement=True)
        loader = DataLoader(ds, batch_size=conf.batch_size, sampler=train_sampler, pin_memory=conf.pin_memory, num_workers=conf.num_workers)
        class_num = total_classes
    else:
        # BUG FIX: the original ran the weight computation unconditionally,
        # raising NameError on `imgs_num` for every non-'ccf' mode. Those
        # modes now get a plain shuffled loader.
        loader = DataLoader(ds, batch_size=conf.batch_size, shuffle=True, pin_memory=conf.pin_memory, num_workers=conf.num_workers)

    return loader, class_num 
Пример #3
0
    dataset = ConcatDataset(dataset) if len(dataset) > 1 else dataset[0]
    val_dataset = ScaperLoader(folder=params['validation_folder'],
                               length=params['initial_length'],
                               n_fft=params['n_fft'],
                               hop_length=params['hop_length'],
                               output_type=params['target_type'],
                               group_sources=params['group_sources'],
                               ignore_sources=params['ignore_sources'],
                               source_labels=params['source_labels'])
elif params['dataset_type'] == 'wsj':
    for i in range(len(params['training_folder'])):
        dataset.append(WSJ0(folder=params['training_folder'][i],
                            length=params['initial_length'],
                            n_fft=params['n_fft'],
                            hop_length=params['hop_length'],
                            output_type=params['target_type'],
                            weight_method=params['weight_method'],
                            create_cache=params['create_cache'],
                            num_channels=1))

    dataset = ConcatDataset(dataset) if len(dataset) > 1 else dataset[0]
    target_type = params['target_type'] if params['target_type'] != 'spatial_bootstrap' else 'psa'
    val_dataset = WSJ0(folder=params['validation_folder'],
                       length='full',
                       n_fft=params['n_fft'],
                       hop_length=params['hop_length'],
                       output_type=target_type,
                       create_cache=True, #params['create_cache'],
                       num_channels=1)

if args.sample_strategy == 'sequential':