Beispiel #1
0
def loader(config, ctx):
    """
    Build train/test dataloaders for the Omniglot dataset.

    Parameters
    ----------
    config : namespace carrying N (classes per iteration), K (samples
        per class), iterations, batch_size and download.
    ctx : sized collection (presumably compute devices) across which
        the test batch size is divided.
        NOTE(review): assumes len(ctx) > 0 — confirm with callers.

    Returns
    -------
    (train_dataloader, test_dataloader) tuple.
    """
    N = config.N
    K = config.K
    iterations = config.iterations
    batch_size = config.batch_size
    download = config.download

    train_dataset = OmniglotDataset(mode='train', download=download)
    test_dataset = OmniglotDataset(mode='test', download=download)

    # Backslash continuations inside parentheses were redundant;
    # implicit line continuation is the idiomatic form.
    tr_sampler = BatchSampler(labels=train_dataset.y,
                              classes_per_it=N,
                              num_samples=K,
                              iterations=iterations,
                              batch_size=batch_size)

    # Evaluation batches are split evenly across the contexts; floor
    # division replaces int(batch_size / len(ctx)) — identical for
    # positive ints, and exact even for huge values.
    te_sampler = BatchSampler(labels=test_dataset.y,
                              classes_per_it=N,
                              num_samples=K,
                              iterations=iterations,
                              batch_size=batch_size // len(ctx))

    tr_dataloader = DataLoader(train_dataset, batch_sampler=tr_sampler)
    te_dataloader = DataLoader(test_dataset, batch_sampler=te_sampler)

    return tr_dataloader, te_dataloader
Beispiel #2
0
def init_dataset(opt):
    '''
    Initialize the datasets, samplers and dataloaders.

    Builds the train/val/trainval/test splits for the configured
    dataset, wraps each in an episodic batch sampler and returns the
    four corresponding DataLoaders.

    Raises
    ------
    ValueError for an unrecognised opt.dataset — previously an unknown
    value fell through to a NameError on the unbound dataset variables.
    '''
    if opt.dataset == 'omniglot':
        train_dataset = OmniglotDataset(mode='train')
        val_dataset = OmniglotDataset(mode='val')
        trainval_dataset = OmniglotDataset(mode='trainval')
        test_dataset = OmniglotDataset(mode='test')
    elif opt.dataset == 'mini_imagenet':
        train_dataset = MiniImagenetDataset(mode='train')
        val_dataset = MiniImagenetDataset(mode='val')
        # NOTE(review): trainval maps to the 'val' split here — looks
        # deliberate (no trainval split for mini-imagenet); confirm.
        trainval_dataset = MiniImagenetDataset(mode='val')
        test_dataset = MiniImagenetDataset(mode='test')
    else:
        raise ValueError('Unknown dataset: {}'.format(opt.dataset))

    # Training may use an alternative task-shuffling sampler; any other
    # value keeps the plain BatchSampler (matches original behaviour).
    train_bs_class = BatchSampler
    eval_bs_class = BatchSampler
    if opt.task_shuffling == 'non_overlapping':
        train_bs_class = NonOverlappingTasksBatchSampler
    elif opt.task_shuffling == 'intratask':
        train_bs_class = IntraTaskBatchSampler

    # Opt for mini_imagenet:
    # Namespace(batch_size=32, cuda=True, dataset='mini_imagenet', epochs=100,
    # exp='mini_imagenet_5way_1shot', iterations=10000, lr=0.0001, num_cls=5, num_samples=1)
    def _make_sampler(sampler_cls, labels):
        # All four splits share the same episodic configuration.
        return sampler_cls(labels=labels,
                           classes_per_it=opt.num_cls,
                           num_samples=opt.num_samples,
                           iterations=opt.iterations,
                           batch_size=opt.batch_size)

    tr_sampler = _make_sampler(train_bs_class, train_dataset.y)
    val_sampler = _make_sampler(eval_bs_class, val_dataset.y)
    trainval_sampler = _make_sampler(eval_bs_class, trainval_dataset.y)
    test_sampler = _make_sampler(eval_bs_class, test_dataset.y)

    tr_dataloader = torch.utils.data.DataLoader(train_dataset,
                                                batch_sampler=tr_sampler)

    val_dataloader = torch.utils.data.DataLoader(val_dataset,
                                                 batch_sampler=val_sampler)

    trainval_dataloader = torch.utils.data.DataLoader(
        trainval_dataset, batch_sampler=trainval_sampler)

    test_dataloader = torch.utils.data.DataLoader(test_dataset,
                                                  batch_sampler=test_sampler)
    return tr_dataloader, val_dataloader, trainval_dataloader, test_dataloader
def init_dataset(opt):
    '''
    Initialize the datasets, samplers and dataloaders.

    Builds one Omniglot split per mode, pairs each with a
    PrototypicalBatchSampler and returns the train, val, trainval and
    test DataLoaders in that order.
    '''
    split_order = ('train', 'val', 'trainval', 'test')

    # Datasets are created first, in the same order as before.
    datasets = {name: OmniglotDataset(mode=name, root=opt.dataset_root)
                for name in split_order}

    # train/trainval episodes use the *_tr options, val/test the *_val ones.
    episode_cfg = {
        'train': (opt.classes_per_it_tr, opt.num_support_tr, opt.num_query_tr),
        'trainval': (opt.classes_per_it_tr, opt.num_support_tr, opt.num_query_tr),
        'val': (opt.classes_per_it_val, opt.num_support_val, opt.num_query_val),
        'test': (opt.classes_per_it_val, opt.num_support_val, opt.num_query_val),
    }

    loaders = {}
    for name in split_order:
        cpi, n_support, n_query = episode_cfg[name]
        sampler = PrototypicalBatchSampler(labels=datasets[name].y,
                                           classes_per_it=cpi,
                                           num_support=n_support,
                                           num_query=n_query,
                                           iterations=opt.iterations)
        loaders[name] = torch.utils.data.DataLoader(datasets[name],
                                                    batch_sampler=sampler)

    return (loaders['train'], loaders['val'],
            loaders['trainval'], loaders['test'])
Beispiel #4
0
def init_dataset(opt):
    '''
    Initialize the datasets, samplers and dataloaders.

    Raises
    ------
    ValueError for an unknown opt.dataset — previously an unknown value
    fell through to a NameError on the unbound dataset variables.
    '''
    if opt.dataset == 'omniglot':
        train_dataset = OmniglotDataset(mode='train')
        val_dataset = OmniglotDataset(mode='val')
        trainval_dataset = OmniglotDataset(mode='trainval')
        test_dataset = OmniglotDataset(mode='test')
    elif opt.dataset == 'mini_imagenet':
        train_dataset = MiniImagenetDataset(mode='train')
        val_dataset = MiniImagenetDataset(mode='val')
        # NOTE(review): trainval maps to the 'val' split here — looks
        # deliberate (no trainval split for mini-imagenet); confirm.
        trainval_dataset = MiniImagenetDataset(mode='val')
        test_dataset = MiniImagenetDataset(mode='test')
    else:
        raise ValueError('Unknown dataset: {}'.format(opt.dataset))

    # Samples per class in each episode = support + query.
    tr_num_samples = opt.num_support_tr + opt.num_query_tr
    eval_num_samples = opt.num_support_val + opt.num_query_val

    tr_sampler = PrototypicalBatchSampler(labels=train_dataset.y,
                                          classes_per_it=opt.classes_per_it_tr,
                                          num_samples=tr_num_samples,
                                          iterations=opt.iterations)

    val_sampler = PrototypicalBatchSampler(
        labels=val_dataset.y,
        classes_per_it=opt.classes_per_it_val,
        num_samples=eval_num_samples,
        iterations=opt.iterations)

    trainval_sampler = PrototypicalBatchSampler(
        labels=trainval_dataset.y,
        classes_per_it=opt.classes_per_it_tr,
        num_samples=tr_num_samples,
        iterations=opt.iterations)

    test_sampler = PrototypicalBatchSampler(
        labels=test_dataset.y,
        classes_per_it=opt.classes_per_it_val,
        num_samples=eval_num_samples,
        iterations=opt.iterations)

    tr_dataloader = torch.utils.data.DataLoader(train_dataset,
                                                batch_sampler=tr_sampler)

    val_dataloader = torch.utils.data.DataLoader(val_dataset,
                                                 batch_sampler=val_sampler)

    trainval_dataloader = torch.utils.data.DataLoader(
        trainval_dataset, batch_sampler=trainval_sampler)

    test_dataloader = torch.utils.data.DataLoader(test_dataset,
                                                  batch_sampler=test_sampler)
    return tr_dataloader, val_dataloader, trainval_dataloader, test_dataloader
Beispiel #5
0
def init_dataset(opt):
    '''
    Initialize the datasets, samplers and dataloaders.

    Raises
    ------
    ValueError for an unknown opt.dataset — previously an unknown value
    fell through to a NameError on the unbound dataset variables.
    '''
    if opt.dataset == 'omniglot':
        train_dataset = OmniglotDataset(mode='train')
        val_dataset = OmniglotDataset(mode='val')
        trainval_dataset = OmniglotDataset(mode='trainval')
        test_dataset = OmniglotDataset(mode='test')
    elif opt.dataset == 'mini_imagenet':
        train_dataset = MiniImagenetDataset(mode='train')
        val_dataset = MiniImagenetDataset(mode='val')
        # NOTE(review): trainval maps to the 'val' split here — looks
        # deliberate (no trainval split for mini-imagenet); confirm.
        trainval_dataset = MiniImagenetDataset(mode='val')
        test_dataset = MiniImagenetDataset(mode='test')
    else:
        raise ValueError('Unknown dataset: {}'.format(opt.dataset))

    def _make_sampler(labels):
        # All four splits share the same episodic configuration.
        return BatchSampler(labels=labels,
                            classes_per_it=opt.num_cls,
                            num_samples=opt.num_samples,
                            iterations=opt.iterations,
                            batch_size=opt.batch_size)

    tr_sampler = _make_sampler(train_dataset.y)
    val_sampler = _make_sampler(val_dataset.y)
    trainval_sampler = _make_sampler(trainval_dataset.y)
    test_sampler = _make_sampler(test_dataset.y)

    tr_dataloader = torch.utils.data.DataLoader(train_dataset,
                                                batch_sampler=tr_sampler)

    val_dataloader = torch.utils.data.DataLoader(val_dataset,
                                                 batch_sampler=val_sampler)

    trainval_dataloader = torch.utils.data.DataLoader(
        trainval_dataset, batch_sampler=trainval_sampler)

    test_dataloader = torch.utils.data.DataLoader(test_dataset,
                                                  batch_sampler=test_sampler)
    return tr_dataloader, val_dataloader, trainval_dataloader, test_dataloader
Beispiel #6
0
def init_dataset(opt, mode):
    """
    Create the dataset for *mode* and validate its class count.

    opt.dataset selects the backend: 0 -> Omniglot, 1 -> MiniImageNet.

    Raises
    ------
    Exception when the dataset holds fewer classes than the configured
    classes_per_it, or when opt.dataset is unrecognised.
    """
    # Single copy of the message the original repeated four times.
    insufficient = ('There are not enough classes in the dataset in order ' +
                    'to satisfy the chosen classes_per_it. Decrease the ' +
                    'classes_per_it_{tr/val} option and try again.')
    if opt.dataset == 0:
        dataset = OmniglotDataset(mode=mode, root=opt.dataset_root)
        n_classes = len(np.unique(dataset.y))
        # Omniglot is checked against both train and val requirements.
        if n_classes < opt.classes_per_it_tr or n_classes < opt.classes_per_it_val:
            raise Exception(insufficient)
    elif opt.dataset == 1:
        dataset = MiniImageNet(mode)
        n_classes = len(dataset.wnids)
        # 'train' is checked against the train requirement, 'val'/'test'
        # against the val requirement; any other mode is not checked
        # (matches the original branch structure).
        required = (opt.classes_per_it_tr if mode == "train"
                    else opt.classes_per_it_val)
        if mode in ("train", "val", "test") and n_classes < required:
            raise Exception(insufficient)
    else:
        raise Exception("No such dataset!!")
    return dataset
def init_dataset(opt, mode):
    """
    Create an OmniglotDataset for *mode* rooted at opt.dataset_root.

    NOTE(review): the class-count sanity check performed by sibling
    versions of this function was commented out here; the dead
    commented-out code and the now-unused n_classes computation (a
    wasted np.unique pass over all labels) have been removed.
    """
    return OmniglotDataset(mode=mode, root=opt.dataset_root)
Beispiel #8
0
def init_dataset(opt, mode):
    """
    Build an OmniglotDataset for *mode* and verify it contains enough
    distinct classes for the configured episodes.

    Raises Exception when the number of unique labels is below either
    opt.classes_per_it_tr or opt.classes_per_it_val.
    """
    dataset = OmniglotDataset(mode=mode, root=opt.dataset_root)
    # np.unique returns the distinct labels without duplicates, so its
    # length is the number of classes.
    # NOTE(review): presumably dataset.x holds the image tensors and
    # dataset.y the labels — confirm against OmniglotDataset.
    n_classes = len(np.unique(dataset.y))
    if n_classes < opt.classes_per_it_tr or n_classes < opt.classes_per_it_val:
        raise (
            Exception('There are not enough classes in the dataset in order ' +
                      'to satisfy the chosen classes_per_it. Decrease the ' +
                      'classes_per_it_{tr/val} option and try again.'))
    return dataset
def init_dataset(opt):
    '''
    Initialize the datasets, samplers and dataloaders.

    Builds the train and val Omniglot splits with a ToTensor+Normalize
    transform, wraps each in a PrototypicalBatchSampler and returns the
    two DataLoaders.

    NOTE(review): Normalize uses three channel means/stds while
    Omniglot images are typically grayscale — confirm the dataset
    yields 3-channel tensors.
    '''
    def _build_loader(mode, classes_per_it, num_support, num_query):
        # One dataset + sampler + loader per split; each split gets its
        # own transform pipeline, as in the original.
        dataset = OmniglotDataset(mode=mode,
                                  root=opt.dataset_root,
                                  transform=transforms.Compose([
                                      transforms.ToTensor(),
                                      transforms.Normalize(
                                          [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
                                  ]))
        sampler = PrototypicalBatchSampler(labels=dataset.y,
                                           classes_per_it=classes_per_it,
                                           num_support=num_support,
                                           num_query=num_query,
                                           iterations=opt.iterations)
        return torch.utils.data.DataLoader(dataset, batch_sampler=sampler)

    tr_dataloader = _build_loader('train', opt.classes_per_it_tr,
                                  opt.num_support_tr, opt.num_query_tr)
    val_dataloader = _build_loader('val', opt.classes_per_it_val,
                                   opt.num_support_val, opt.num_query_val)
    return tr_dataloader, val_dataloader
Beispiel #10
0
def init_dataset(opt):
    '''
    Initialize the test dataset, sampler and dataloader.

    Note: for mini_imagenet the 'val' split is used as the test set
    (as in the original).

    Raises
    ------
    ValueError for an invalid opt.dataset — previously this only
    printed a message and then crashed with a NameError on the
    unbound test_dataset.
    '''
    if opt.dataset == 'omniglot':
        test_dataset = OmniglotDataset(mode='test')
    elif opt.dataset == 'mini_imagenet':
        test_dataset = MiniImagenetDataset(mode='val')
    else:
        raise ValueError('Dataset is not valid')
    # Samples per class in each episode = support + query.
    test_sampler = PrototypicalBatchSampler(labels=test_dataset.y,
                                            classes_per_it=opt.classes_per_it_val,
                                            num_samples=opt.num_support_val + opt.num_query_val,
                                            iterations=opt.iterations)
    test_dataloader = torch.utils.data.DataLoader(test_dataset,
                                                  batch_sampler=test_sampler)
    return test_dataloader
    def set_data_loader(self):
        """
        Build a DataLoader over the configured dataset for self.mode.

        Uses self.create_sampler to construct the batch sampler from
        the dataset labels.

        Raises
        ------
        ValueError for an unsupported self.arg_settings.data —
        previously a non-'omniglot' value fell through to a NameError
        on the unbound num_classes/sampler/dataset locals.
        Exception when the dataset has fewer classes than the
        configured classes_per_it.
        """
        if self.arg_settings.data == 'omniglot':
            dataset = OmniglotDataset(mode=self.mode,
                                      root=self.arg_settings.dataset_root)
            num_classes = len(np.unique(dataset.y))
            sampler = self.create_sampler(dataset.y)
        else:
            raise ValueError(
                'Unsupported dataset: {}'.format(self.arg_settings.data))

        # check if number of classes in dataset is sufficient
        if num_classes < self.arg_settings.classes_per_it_tr or num_classes < self.arg_settings.classes_per_it_val:
            raise (Exception(
                'There are not enough classes in the dataset in order ' +
                'to satisfy the chosen classes_per_it. Decrease the ' +
                'classes_per_it_{tr/val} option and try again.'))

        dataloader = torch.utils.data.DataLoader(dataset,
                                                 batch_sampler=sampler)
        return dataloader