def init_dataset(opt):
    '''
    Initialize the datasets, samplers and dataloaders
    '''
    train_dataset = OmniglotDataset(mode='train',
                                    root=opt.dataset_root)

    val_dataset = OmniglotDataset(mode='val',
                                  root=opt.dataset_root)

    trainval_dataset = OmniglotDataset(mode='trainval',
                                       root=opt.dataset_root)

    test_dataset = OmniglotDataset(mode='test',
                                   root=opt.dataset_root)

    tr_sampler = PrototypicalBatchSampler(labels=train_dataset.y,
                                          classes_per_it=opt.classes_per_it_tr,
                                          num_support=opt.num_support_tr,
                                          num_query=opt.num_query_tr,
                                          iterations=opt.iterations)

    val_sampler = PrototypicalBatchSampler(labels=val_dataset.y,
                                           classes_per_it=opt.classes_per_it_val,
                                           num_support=opt.num_support_val,
                                           num_query=opt.num_query_val,
                                           iterations=opt.iterations)

    trainval_sampler = PrototypicalBatchSampler(labels=trainval_dataset.y,
                                                classes_per_it=opt.classes_per_it_tr,
                                                num_support=opt.num_support_tr,
                                                num_query=opt.num_query_tr,
                                                iterations=opt.iterations)

    test_sampler = PrototypicalBatchSampler(labels=test_dataset.y,
                                            classes_per_it=opt.classes_per_it_val,
                                            num_support=opt.num_support_val,
                                            num_query=opt.num_query_val,
                                            iterations=opt.iterations)

    tr_dataloader = torch.utils.data.DataLoader(train_dataset,
                                                batch_sampler=tr_sampler)

    val_dataloader = torch.utils.data.DataLoader(val_dataset,
                                                 batch_sampler=val_sampler)

    trainval_dataloader = torch.utils.data.DataLoader(trainval_dataset,
                                                      batch_sampler=trainval_sampler)

    test_dataloader = torch.utils.data.DataLoader(test_dataset,
                                                  batch_sampler=test_sampler)
    return tr_dataloader, val_dataloader, trainval_dataloader, test_dataloader
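# Usage sketch (not part of the original snippet): the attribute names below are
# assumptions that simply mirror what this init_dataset variant reads from `opt`;
# the values and the dataset root are placeholders.
from argparse import Namespace

opt = Namespace(dataset_root='../dataset',      # hypothetical location
                classes_per_it_tr=60, num_support_tr=5, num_query_tr=5,
                classes_per_it_val=5, num_support_val=5, num_query_val=15,
                iterations=100)
tr_dl, val_dl, trainval_dl, test_dl = init_dataset(opt)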
Example #2
def init_dataset(opt):
    '''
    Initialize the datasets, samplers and dataloaders
    '''
    if opt.dataset == 'omniglot':
        train_dataset = OmniglotDataset(mode='train')
        val_dataset = OmniglotDataset(mode='val')
        trainval_dataset = OmniglotDataset(mode='trainval')
        test_dataset = OmniglotDataset(mode='test')
    elif opt.dataset == 'mini_imagenet':
        train_dataset = MiniImagenetDataset(mode='train')
        val_dataset = MiniImagenetDataset(mode='val')
        trainval_dataset = MiniImagenetDataset(mode='val')
        test_dataset = MiniImagenetDataset(mode='test')
    else:
        raise ValueError('Unknown dataset: {}'.format(opt.dataset))

    tr_sampler = PrototypicalBatchSampler(labels=train_dataset.y,
                                          classes_per_it=opt.classes_per_it_tr,
                                          num_samples=opt.num_support_tr +
                                          opt.num_query_tr,
                                          iterations=opt.iterations)

    val_sampler = PrototypicalBatchSampler(
        labels=val_dataset.y,
        classes_per_it=opt.classes_per_it_val,
        num_samples=opt.num_support_val + opt.num_query_val,
        iterations=opt.iterations)

    trainval_sampler = PrototypicalBatchSampler(
        labels=trainval_dataset.y,
        classes_per_it=opt.classes_per_it_tr,
        num_samples=opt.num_support_tr + opt.num_query_tr,
        iterations=opt.iterations)

    test_sampler = PrototypicalBatchSampler(
        labels=test_dataset.y,
        classes_per_it=opt.classes_per_it_val,
        num_samples=opt.num_support_val + opt.num_query_val,
        iterations=opt.iterations)

    tr_dataloader = torch.utils.data.DataLoader(train_dataset,
                                                batch_sampler=tr_sampler)

    val_dataloader = torch.utils.data.DataLoader(val_dataset,
                                                 batch_sampler=val_sampler)

    trainval_dataloader = torch.utils.data.DataLoader(
        trainval_dataset, batch_sampler=trainval_sampler)

    test_dataloader = torch.utils.data.DataLoader(test_dataset,
                                                  batch_sampler=test_sampler)
    return tr_dataloader, val_dataloader, trainval_dataloader, test_dataloader
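# Episode-iteration sketch (an addition, not from the original source): with a
# batch_sampler, every DataLoader batch is one few-shot episode of
# classes_per_it * (num_support + num_query) samples, e.g. 60 * (5 + 5) = 600
# images for a 60-way 5-shot training episode.  Assuming the datasets return
# (image, label) pairs, the default collate yields stacked tensors:
tr_dl, val_dl, trainval_dl, test_dl = init_dataset(opt)
for x, y in tr_dl:
    pass  # x: episode images, y: episode labels; model / loss code goes here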
Example #3
def init_sampler(opt, labels, mode, n_proto_support):
    '''
    Initialize the sampler
    '''
    classes, counts = np.unique(labels, return_counts=True)
    classes_per_it = len(classes)
    if 'train' in mode:
        num_samples = n_proto_support + opt.num_query_tr
        phase = 'training'
    else:
        num_samples = n_proto_support + opt.num_query_val
        phase = 'testing'
    # abort if any class cannot supply a full support + query set
    for idx, count in enumerate(counts):
        if num_samples > count:
            print("*** Error ***: You do not have enough samples in class {} for {} -- {} samples vs {} needed."
                  .format(classes[idx], phase, count, num_samples))
            sys.exit()
    return PrototypicalBatchSampler(labels=labels,
                                    classes_per_it=classes_per_it,
                                    num_samples=num_samples,
                                    iterations=opt.iterations)
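# Usage sketch (hypothetical values): this variant infers classes_per_it from the
# labels themselves and exits if any class has fewer than
# n_proto_support + num_query samples.  It relies on numpy and sys being imported
# at module level, as the function body requires.
labels = np.repeat(np.arange(10), 20)   # toy labels: 10 classes, 20 samples each
sampler = init_sampler(opt, labels, mode='train', n_proto_support=5)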
Example #4
def init_sampler(opt, labels, mode):
    if 'train' in mode:
        classes_per_it = opt.classes_per_it_tr
        num_samples = opt.num_support_tr + opt.num_query_tr
    else:
        classes_per_it = opt.classes_per_it_val
        num_samples = opt.num_support_val + opt.num_query_val

    return PrototypicalBatchSampler(labels=labels,
                                    classes_per_it=classes_per_it,
                                    num_samples=num_samples,
                                    iterations=opt.iterations)
Example #5
def init_dataset(opt):
    '''
    Initialize the datasets, samplers and dataloaders
    '''
    train_dataset = OmniglotDataset(mode='train',
                                    root=opt.dataset_root,
                                    transform=transforms.Compose([
                                        transforms.ToTensor(),
                                        transforms.Normalize(
                                            [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
                                    ]))

    val_dataset = OmniglotDataset(mode='val',
                                  root=opt.dataset_root,
                                  transform=transforms.Compose([
                                      transforms.ToTensor(),
                                      transforms.Normalize(
                                          [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
                                  ]))

    tr_sampler = PrototypicalBatchSampler(labels=train_dataset.y,
                                          classes_per_it=opt.classes_per_it_tr,
                                          num_support=opt.num_support_tr,
                                          num_query=opt.num_query_tr,
                                          iterations=opt.iterations)

    val_sampler = PrototypicalBatchSampler(labels=val_dataset.y,
                                           classes_per_it=opt.classes_per_it_val,
                                           num_support=opt.num_support_val,
                                           num_query=opt.num_query_val,
                                           iterations=opt.iterations)

    tr_dataloader = torch.utils.data.DataLoader(train_dataset,
                                                batch_sampler=tr_sampler)

    val_dataloader = torch.utils.data.DataLoader(val_dataset,
                                                 batch_sampler=val_sampler)
    return tr_dataloader, val_dataloader
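# Note (an assumption, not from the original snippet): the 3-channel Normalize
# above presumes OmniglotDataset yields RGB tensors; if it returns single-channel
# images, a 1-channel transform would be needed instead:
grayscale_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),   # one mean / one std for a single channel
])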
Example #6
def init_sampler(opt, labels, mode):
    if 'train' in mode:
        # training mode: number of classes per episode
        classes_per_it = opt.classes_per_it_tr
        # samples drawn per class = training support set + query set
        num_samples = opt.num_support_tr + opt.num_query_tr
    else:
        # validation mode: use the validation-specific settings
        classes_per_it = opt.classes_per_it_val
        num_samples = opt.num_support_val + opt.num_query_val

    return PrototypicalBatchSampler(labels=labels,
                                    classes_per_it=classes_per_it,
                                    num_samples=num_samples,
                                    iterations=opt.iterations)
Example #7
def init_sampler(opt, labels, mode):
    if mode == 'train':
        classes_per_it = opt.classes_per_it_tr
        num_samples = opt.num_support_tr + opt.num_query_tr
    elif mode == 'val':
        classes_per_it = opt.classes_per_it_val
        num_samples = opt.num_support_val + opt.num_query_val
    elif mode == 'test':
        classes_per_it = opt.classes_per_it_test
        num_samples = opt.num_support_test + opt.num_query_test
    else:
        raise Exception("Invalid mode", mode)
    return PrototypicalBatchSampler(labels=labels,
                                    classes_per_it=classes_per_it,
                                    num_samples=num_samples,
                                    iterations=opt.iterations)
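# Composition sketch (init_dataloader is a hypothetical helper, not from the
# original source): the mode-aware sampler above is typically paired with a
# matching dataset and wrapped in a DataLoader in one place.
def init_dataloader(opt, dataset, mode):
    sampler = init_sampler(opt, dataset.y, mode)
    return torch.utils.data.DataLoader(dataset, batch_sampler=sampler)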
Example #8
def init_dataset(opt):
    '''
    Initialize the datasets, samplers and dataloaders
    '''
    if opt.dataset == 'omniglot':
        test_dataset = OmniglotDataset(mode='test')
    elif opt.dataset == 'mini_imagenet':
        test_dataset = MiniImagenetDataset(mode='val')
    else:
        raise ValueError('Dataset is not valid: {}'.format(opt.dataset))
    test_sampler = PrototypicalBatchSampler(labels=test_dataset.y,
                                            classes_per_it=opt.classes_per_it_val,
                                            num_samples=opt.num_support_val + opt.num_query_val,
                                            iterations=opt.iterations)
    test_dataloader = torch.utils.data.DataLoader(test_dataset,
                                                  batch_sampler=test_sampler)
    return test_dataloader
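# Evaluation sketch (device handling is a standard PyTorch idiom, not part of the
# original snippet): each batch drawn below is one test episode.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
test_dataloader = init_dataset(opt)
for x, y in test_dataloader:
    x, y = x.to(device), y.to(device)
    # forward pass through the trained model and accuracy bookkeeping go here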