Example #1
0
from otdd.pytorch.datasets import load_torchvision_data
from otdd.pytorch.distance import DatasetDistance

# Optimal Transport Dataset Distance (OTDD) between two user-supplied datasets.
#
# NOTE(review): `data_src`, `label_src`, `data_tgt` and `label_tgt` are not
# defined in this snippet; they must be provided by the surrounding code.
# NOTE(review): every other `load_torchvision_data` call in these examples
# passes a dataset name as its first positional argument, and none passes a
# `label=` keyword. Check both against the installed otdd version's
# `load_torchvision_data` signature -- as written, this call may raise a
# TypeError.

# Options shared by the source and target loaders, kept in one place so the
# two datasets are always preprocessed identically.
loader_options = dict(valid_size=0,
                      resize=28,
                      maxsize=2000)

# Element [0] of the returned value is indexed by split name ('train') below.
loaders_src = load_torchvision_data(data=data_src,
                                    label=label_src,
                                    **loader_options)[0]
loaders_tgt = load_torchvision_data(data=data_tgt,
                                    label=label_tgt,
                                    **loader_options)[0]

# Instantiate distance
dist = DatasetDistance(loaders_src['train'],
                       loaders_tgt['train'],
                       inner_ot_method='exact',
                       debiased_loss=True,
                       p=2,
                       entreg=1e-1,
                       device='cpu')

d = dist.distance(maxsamples=1000)
# The inputs are arbitrary user data rather than MNIST/USPS, so the printed
# label is generic instead of naming specific datasets.
print(f'OTDD(src,tgt)={d:8.2f}')
Example #2
0
from otdd.pytorch.datasets import load_torchvision_data
from otdd.pytorch.distance import DatasetDistance

# OTDD between MNIST (source) and USPS (target), computed on the CPU.

# Both datasets go through identical preprocessing.
common_load_args = {'valid_size': 0, 'resize': 28, 'maxsize': 2000}
loaders_src = load_torchvision_data('MNIST', **common_load_args)[0]
loaders_tgt = load_torchvision_data('USPS', **common_load_args)[0]

# Build the distance object over the two 'train' loaders.
dist = DatasetDistance(
    loaders_src['train'],
    loaders_tgt['train'],
    inner_ot_method='exact',
    debiased_loss=True,
    p=2,
    entreg=1e-1,
    device='cpu',
)

# maxsamples=1000 presumably caps the samples used per dataset -- confirm in
# the otdd documentation.
d = dist.distance(maxsamples=1000)
print(f'OTDD(MNIST,USPS)={d:8.2f}')
Example #3
0
from otdd.pytorch.datasets import load_torchvision_data
from otdd.pytorch.distance import DatasetDistance
import torch
import numpy as np  # not used in this snippet

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# MNIST (source) and FashionMNIST (target), both resized to 32x32, converted
# to 3 channels and capped via maxsize=2000.
dataset_opts = dict(valid_size=0, resize=32, to3channels=True, maxsize=2000)
loaders_src = load_torchvision_data('MNIST', **dataset_opts)[0]
loaders_tgt = load_torchvision_data('FashionMNIST', **dataset_opts)[0]

# Build the distance object; this snippet does not go on to call
# dist.distance().
dist = DatasetDistance(
    loaders_src['train'],
    loaders_tgt['train'],
    inner_ot_method='exact',
    debiased_loss=True,
    p=2,
    entreg=1e-1,
    device=device,
)
Example #4
0
import torch
from torchvision.models import resnet18

from otdd.pytorch.datasets import load_torchvision_data
from otdd.pytorch.distance import DatasetDistance, FeatureCost

# torchvision models need 3-channel input, so MNIST is loaded with
# to3channels=True (no conversion is requested for CIFAR10). Both datasets
# are resized to 28x28 and capped via maxsize=2000.
loaders_src = load_torchvision_data('CIFAR10', resize=28, maxsize=2000)[0]
loaders_tgt = load_torchvision_data('MNIST', resize=28, to3channels=True,
                                    maxsize=2000)[0]

# Pretrained ResNet-18 used as a frozen feature extractor:
#   - put into eval mode;
#   - its `fc` layer is replaced by an Identity, so the network returns the
#     activations that would have fed that layer;
#   - gradients are disabled on every parameter.
# NOTE(review): recent torchvision releases deprecate `pretrained=` in favour
# of `weights=`; check against the installed version.
embedder = resnet18(pretrained=True).eval()
embedder.fc = torch.nn.Identity()
embedder.requires_grad_(False)

# A single embedder serves both datasets; both sides declare (3, 28, 28) inputs.
input_shape = (3, 28, 28)
feature_cost = FeatureCost(src_embedding=embedder,
                           src_dim=input_shape,
                           tgt_embedding=embedder,
                           tgt_dim=input_shape,
                           p=2,
                           device='cpu')

dist = DatasetDistance(loaders_src['train'], loaders_tgt['train'],
                          inner_ot_method = 'exact',
                          debiased_loss = True,
                          feature_cost = feature_cost,
                          sqrt_method = 'spectral',
                          sqrt_niters=10,
                          precision='single',
Example #5
0
    'between - distances between dataset \n inter - distances between classes')

args = parser.parse_args()

TO3CHANNELS = True

# Load datasets.
# Every dataset is loaded with the same options and only its 'train' loader
# is kept. Not included in the list: 'USPS', 'EMNIST'.
DATASETS = ['MNIST', 'FashionMNIST', 'KMNIST', 'CIFAR10', 'SVHN']
datasets = {
    name: load_torchvision_data(name,
                                to3channels=TO3CHANNELS,
                                resize=34,
                                valid_size=0,
                                maxsize=5000)[0]['train']
    for name in DATASETS
}
n_datasets = len(DATASETS)

if args.dist_type == 'between':
    print('Calculating distance between datasets')
    distances = np.zeros((n_datasets, n_datasets))
    for i, set1 in enumerate(datasets):
        for j, set2 in enumerate(datasets):
            if i >= j:
                continue
            dist = DatasetDistance(datasets[set1],
                                   datasets[set2],
                                   inner_ot_method='exact',
                                   debiased_loss=True,
                                   p=2,