コード例 #1
0
ファイル: common.py プロジェクト: wx-b/trivialaugment
def copy_and_replace_transform(ds: Union[CIFAR10, ImageFolder, Subset, ConcatDataset], transform):
    """Return a copy of ``ds`` whose underlying dataset(s) use ``transform``.

    Neither ``ds`` nor any dataset it wraps is modified: every object that
    would change is first duplicated with a shallow ``copy``.

    * ``Subset``: the wrapper is copied and its ``dataset`` attribute is
      replaced by a re-transformed copy of the wrapped dataset.
    * ``ConcatDataset``: a new ``ConcatDataset`` is built from re-transformed
      copies of the member datasets.
    * anything else: the dataset is copied and the copy's ``transform``
      attribute is overwritten.

    Wrappers are processed recursively, so a ``ConcatDataset`` whose members
    are ``Subset`` objects, or a ``Subset`` that wraps another wrapper, is
    handled as well.

    Args:
        ds: Dataset or dataset wrapper to copy.
        transform: Value assigned to the ``transform`` attribute of every
            copied leaf (non-wrapper) dataset.

    Returns:
        The copy described above; for a ``ConcatDataset`` input this is a
        freshly constructed ``ConcatDataset``.

    Raises:
        AssertionError: If a leaf dataset's current ``transform`` is ``None``.
    """
    if isinstance(ds, Subset):
        new_ds = copy(ds)
        # Recurse so the wrapped dataset may itself be a wrapper.
        new_ds.dataset = copy_and_replace_transform(ds.dataset, transform)
        return new_ds

    if isinstance(ds, ConcatDataset):
        # Rebuild the container rather than copying it, so it holds the new
        # member copies instead of the original members.
        return ConcatDataset(
            [copy_and_replace_transform(member, transform) for member in ds.datasets]
        )

    # Sanity check kept from the original code ("make sure still uses old
    # style transform"): a leaf dataset must already have a transform set.
    assert ds.transform is not None, "dataset has no transform to replace"
    new_ds = copy(ds)
    new_ds.transform = transform
    return new_ds