コード例 #1
0
ファイル: data.py プロジェクト: wiwi/ssl-suite
def _split_dataset(
        dataset: VisionDataset, labeled_size: int, unlabeled_size: int,
        val_size: int, num_classes: int,
        balanced: bool) -> (VisionDataset, VisionDataset, VisionDataset):
    # split given dataset into labeled, unlabeled and val

    assert labeled_size + unlabeled_size + val_size == len(dataset)
    if balanced:
        _balanced_shuffle(dataset, num_classes)
    else:
        indices = torch.randperm(len(dataset))
        dataset.data = [dataset.data[i] for i in indices]
        dataset.targets = [dataset.targets for i in indices]

    labeled_set = dataset
    unlabeled_set = deepcopy(dataset)
    val_set = deepcopy(dataset)
    labeled_set.data = [labeled_set.data[i] for i in range(labeled_size)]
    unlabeled_set.data = [
        unlabeled_set.data[i]
        for i in range(labeled_size, labeled_size + unlabeled_size)
    ]
    val_set.data = [
        val_set.data[i]
        for i in range(labeled_size + unlabeled_size, labeled_size +
                       unlabeled_size + val_size)
    ]
    labeled_set.targets = [labeled_set.targets[i] for i in range(labeled_size)]
    unlabeled_set.targets = [-1 for _ in range(unlabeled_size)]
    val_set.targets = [
        val_set.targets[i]
        for i in range(labeled_size + unlabeled_size, labeled_size +
                       unlabeled_size + val_size)
    ]
    return labeled_set, unlabeled_set, val_set
コード例 #2
0
 def __init__(
     self,
     root: str = ".",
     download: bool = True,
     split: str = "train",
     transform: Optional[Callable] = None,
 ) -> None:
     """Initialise both bases of this multiple-inheritance dataset.

     Each base ``__init__`` is invoked explicitly rather than through
     ``super()`` so that every base receives only the keyword arguments
     it understands.
     """
     # NOTE(review): bypassing cooperative super().__init__ assumes the
     # two bases never need to chain into each other — confirm in the MRO.
     FER.__init__(self, split=split, root=root, download=download)
     VisionDataset.__init__(self, transform=transform, root=root)
コード例 #3
0
ファイル: datasets.py プロジェクト: Xiangyu-Han/homura
 def _sample_dataset(dataset: datasets.VisionDataset,
                     size: int) -> datasets.VisionDataset:
     indices = torch.randperm(len(dataset))[:size]
     if hasattr(dataset, 'data'):
         dataset.data = [dataset.data[i] for i in indices]
     if hasattr(dataset, 'samples'):
         # e.g., imagenet
         dataset.samples = [dataset.samples[i] for i in indices]
     dataset.targets = [dataset.targets[i] for i in indices]
     return dataset
コード例 #4
0
ファイル: data.py プロジェクト: taikiinoue45/dda
def _split_dataset(dataset: VisionDataset,
                   val_size: int) -> (VisionDataset, VisionDataset):
    indices = torch.randperm(len(dataset))
    valset = deepcopy(dataset)
    dataset.data = [dataset.data[i] for i in indices[val_size:]]
    dataset.targets = [dataset.targets[i] for i in indices[val_size:]]

    valset.data = [valset.data[i] for i in indices[:val_size]]
    valset.targets = [valset.targets[i] for i in indices[:val_size]]

    return dataset, valset
コード例 #5
0
ファイル: data_utils.py プロジェクト: ryanlindeborg/Sequoia
def keep_in_memory(dataset: VisionDataset) -> None:
    """Convert the dataset's `data` and `targets` attributes to Tensors.

    This has the consequence of keeping the entire dataset in memory.
    Attributes that are already `np.ndarray` or `Tensor` are left alone.
    """
    array_like = (np.ndarray, Tensor)

    if hasattr(dataset, "data") and not isinstance(dataset.data, array_like):
        dataset.data = torch.as_tensor(dataset.data)
    if not isinstance(dataset.targets, array_like):
        dataset.targets = torch.as_tensor(dataset.targets)

    if isinstance(dataset, CIFAR100):
        # TODO: Cifar100 seems to want its 'data' to a numpy ndarray.
        dataset.data = np.asarray(dataset.data)
コード例 #6
0
ファイル: datasets.py プロジェクト: matrixgame2018/homura
def _split_dataset(
        train_set: datasets.VisionDataset,
        val_size: int) -> (datasets.VisionDataset, datasets.VisionDataset):
    # split train_set to train_set and val_set
    assert len(train_set) >= val_size
    indices = torch.randperm(len(train_set))
    valset = copy.deepcopy(train_set)
    train_set.data = [train_set.data[i] for i in indices[val_size:]]
    train_set.targets = [train_set.targets[i] for i in indices[val_size:]]

    valset.data = [valset.data[i] for i in indices[:val_size]]
    valset.targets = [valset.targets[i] for i in indices[:val_size]]

    return train_set, valset