예제 #1
0
    def __getitem__(self, index):
        """Build the task made of the classes selected by `index`.

        Args:
            index: a sequence holding exactly `self.num_classes_per_task`
                entries; each entry is used to index `self.dataset`.

        Returns:
            A `ConcatTask` over the selected datasets, or whatever
            `self.dataset_transform` returns for that task when a dataset
            transform is set.

        Raises:
            ValueError: if `index` is a single integer, or if it does not
                contain exactly `self.num_classes_per_task` entries.
        """
        num_classes = self.num_classes_per_task

        if isinstance(index, int):
            example = ', '.join(str(idx) for idx in range(num_classes))
            raise ValueError(
                'The index of a `CombinationMetaDataset` must be '
                'a tuple of integers, and not an integer. For example, call '
                '`dataset[({0})]` to get a task with classes from 0 to {1} '
                '(got `{2}`).'.format(example, num_classes - 1, index))

        # Raise explicitly instead of asserting: assertions are removed when
        # Python runs with `-O`, which would let a malformed index through.
        if len(index) != num_classes:
            raise ValueError(
                'The index of a `CombinationMetaDataset` must contain '
                'exactly {0} entries (got {1} in `{2}`).'.format(
                    num_classes, len(index), index))

        datasets = [self.dataset[i] for i in index]

        # `Categorical` target transforms are routed through
        # `self._copy_categorical` (intended, per the original note, to
        # deep-copy them) so tasks do not share side effects.
        target_transform = wrap_transform(self.target_transform,
                                          self._copy_categorical,
                                          transform_type=Categorical)
        task = ConcatTask(datasets,
                          num_classes,
                          target_transform=target_transform)

        if self.dataset_transform is not None:
            task = self.dataset_transform(task)

        return task
예제 #2
0
    def __init__(self,
                 dataset,
                 batch_size=1,
                 shuffle=True,
                 sampler=None,
                 batch_sampler=None,
                 num_workers=0,
                 collate_fn=None,
                 pin_memory=False,
                 drop_last=False,
                 timeout=0,
                 worker_init_fn=None):
        """Set up the loader, selecting a sampler that fits `dataset`.

        Sampler selection:
            * `CombinationMetaDataset` without an explicit `sampler`: a
              `CombinationRandomSampler` when `shuffle` is true, otherwise a
              `CombinationSequentialSampler`.
            * list/tuple of datasets: a `MultiCombinationRandomSampler` is
              built from the sequence, which is then merged into a single
              `ConcatTask` using the first element's `num_classes` and
              `target_transform`. Only supported when `shuffle` is true.
            * anything else: `sampler` and `shuffle` are passed through
              untouched.

        In the first two cases `shuffle` is reset to False before calling
        the parent initializer, since ordering is handled by the sampler.

        Args:
            dataset: a `CombinationMetaDataset`, a list/tuple of datasets,
                or any other object accepted by the parent initializer.
            shuffle: whether a random sampler is chosen (see above).
            sampler: explicit sampler; replaced when `dataset` is a
                list/tuple.
            collate_fn: falls back to `no_collate` when None.
            batch_size, batch_sampler, num_workers, pin_memory, drop_last,
            timeout, worker_init_fn: forwarded unchanged to the parent
                class initializer.

        Raises:
            NotImplementedError: if `dataset` is a list/tuple and `shuffle`
                is false.
        """
        collate_fn = no_collate if collate_fn is None else collate_fn

        is_combination = isinstance(dataset, CombinationMetaDataset)
        if is_combination and sampler is None:
            sampler_cls = (CombinationRandomSampler
                           if shuffle else CombinationSequentialSampler)
            sampler = sampler_cls(dataset)
            shuffle = False
        elif isinstance(dataset, (list, tuple)):
            if not shuffle:
                raise NotImplementedError()
            # Combinations are sampled within each dataset of the sequence;
            # per the original note this assumes the indices of the merged
            # task are concatenated in the same order.
            sampler = MultiCombinationRandomSampler(dataset)
            first = dataset[0]
            dataset = ConcatTask(dataset,
                                 first.num_classes,
                                 target_transform=first.target_transform)
            shuffle = False

        loader_options = dict(batch_size=batch_size,
                              shuffle=shuffle,
                              sampler=sampler,
                              batch_sampler=batch_sampler,
                              num_workers=num_workers,
                              collate_fn=collate_fn,
                              pin_memory=pin_memory,
                              drop_last=drop_last,
                              timeout=timeout,
                              worker_init_fn=worker_init_fn)
        super(MetaDataLoader, self).__init__(dataset, **loader_options)