# Compute the OTDD between two torchvision datasets selected by name.
# NOTE(review): `data_src` and `data_tgt` are not defined in this snippet; they
# are assumed to be dataset-name strings (e.g. 'MNIST', 'USPS') bound by
# surrounding code before this runs -- TODO confirm.
from otdd.pytorch.datasets import load_torchvision_data
from otdd.pytorch.distance import DatasetDistance

# Load data.
# `load_torchvision_data` takes the dataset name as its first positional
# argument, as every other call in this file does. It has no `data=` or
# `label=` keywords, so the previous call failed with a TypeError.
# NOTE(review): `label_src`/`label_tgt` were dropped for that reason; if they
# were meant to restrict which classes are used, that needs a separate step.
loaders_src = load_torchvision_data(data_src, valid_size=0, resize=28, maxsize=2000)[0]
loaders_tgt = load_torchvision_data(data_tgt, valid_size=0, resize=28, maxsize=2000)[0]

# Instantiate distance on the 'train' loaders of both datasets.
dist = DatasetDistance(loaders_src['train'], loaders_tgt['train'],
                       inner_ot_method='exact',
                       debiased_loss=True,
                       p=2, entreg=1e-1,
                       device='cpu')

d = dist.distance(maxsamples=1000)
# Label the result with the datasets actually compared, not a hard-coded pair.
print(f'OTDD({data_src},{data_tgt})={d:8.2f}')
# OTDD between the MNIST and USPS 'train' loaders, computed on the CPU.
from otdd.pytorch.datasets import load_torchvision_data
from otdd.pytorch.distance import DatasetDistance

# Both datasets share the same loading settings; only the first element of
# each `load_torchvision_data` result (indexed by split name below) is kept.
loaders_src, loaders_tgt = (
    load_torchvision_data(name, valid_size=0, resize=28, maxsize=2000)[0]
    for name in ('MNIST', 'USPS')
)

# Distance object over the two 'train' loaders.
dist = DatasetDistance(
    loaders_src['train'],
    loaders_tgt['train'],
    inner_ot_method='exact',
    debiased_loss=True,
    p=2,
    entreg=1e-1,
    device='cpu',
)

# Evaluate with maxsamples=1000 and print in an 8-wide, 2-decimal field.
d = dist.distance(maxsamples=1000)
print(f'OTDD(MNIST,USPS)={d:8.2f}')
# Set up an OTDD computation between MNIST and FashionMNIST, running on the
# GPU when CUDA is available.
from otdd.pytorch.datasets import load_torchvision_data
from otdd.pytorch.distance import DatasetDistance
import torch
import numpy as np  # NOTE(review): unused in this snippet

# Pick CUDA when present, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load both datasets with identical settings (resize=32, to3channels=True,
# maxsize=2000), keeping the first element of each result.
loaders_src, loaders_tgt = (
    load_torchvision_data(name, valid_size=0, resize=32,
                          to3channels=True, maxsize=2000)[0]
    for name in ('MNIST', 'FashionMNIST')
)

# Build the distance object; no distance is evaluated in this snippet.
dist = DatasetDistance(
    loaders_src['train'],
    loaders_tgt['train'],
    inner_ot_method='exact',
    debiased_loss=True,
    p=2,
    entreg=1e-1,
    device=device,
)
# OTDD between CIFAR10 and MNIST where the feature cost is computed on
# embeddings from a frozen, pretrained torchvision ResNet-18.
# NOTE(review): collapsed onto one line, this is not valid Python: several
# statements run together, and the inline "# Load MNIST/CIFAR ..." comment
# turns the rest of the line into a comment. Split into statements to run it.
# NOTE(review): the DatasetDistance(...) call is not closed on this line (it
# ends at "precision='single',"); its remaining arguments are not visible here.
# - CIFAR10 is loaded with resize=28 and MNIST with resize=28 and
#   to3channels=True, presumably so both match the (3, 28, 28) src_dim/tgt_dim
#   passed to FeatureCost -- confirm.
# - embedder.fc is replaced by torch.nn.Identity(), the model is put in eval()
#   mode, and every parameter gets requires_grad = False. The same embedder is
#   used for source and target.
# NOTE(review): newer torchvision releases deprecate resnet18(pretrained=True)
# in favour of the `weights=` argument -- check the installed version.
import torch from torchvision.models import resnet18 from otdd.pytorch.datasets import load_torchvision_data from otdd.pytorch.distance import DatasetDistance, FeatureCost # Load MNIST/CIFAR in 3channels (needed by torchvision models) loaders_src = load_torchvision_data('CIFAR10', resize=28, maxsize=2000)[0] loaders_tgt = load_torchvision_data('MNIST', resize=28, to3channels=True, maxsize=2000)[0] # Embed using a pretrained (+frozen) resnet embedder = resnet18(pretrained=True).eval() embedder.fc = torch.nn.Identity() for p in embedder.parameters(): p.requires_grad = False # Here we use same embedder for both datasets feature_cost = FeatureCost(src_embedding = embedder, src_dim = (3,28,28), tgt_embedding = embedder, tgt_dim = (3,28,28), p = 2, device='cpu') dist = DatasetDistance(loaders_src['train'], loaders_tgt['train'], inner_ot_method = 'exact', debiased_loss = True, feature_cost = feature_cost, sqrt_method = 'spectral', sqrt_niters=10, precision='single',
# NOTE(review): this line starts mid-call -- presumably the tail of an
# argument-parser option whose help text describes the 'between'/'inter'
# modes -- and ends inside an unterminated DatasetDistance(...) call. The
# missing parts are not visible here.
# NOTE(review): as a single line this is not valid Python: statements run
# together, and the "# Load datasets" comment swallows the rest of the line.
# - For each name in DATASETS, stores the 'train' loader from
#   load_torchvision_data(to3channels=TO3CHANNELS, resize=34, valid_size=0,
#   maxsize=5000) in `datasets`, keyed by name.
# - For args.dist_type == 'between', allocates an n_datasets x n_datasets zero
#   array and visits only pairs with i < j (the `if i >= j: continue` guard).
#   Presumably only the upper triangle of `distances` gets filled -- confirm
#   with the code that follows.
# NOTE(review): 'SVHN' appears both in DATASETS and in the trailing comment
# listing other dataset names.
'between - distances between dataset \n inter - distances between classes') args = parser.parse_args() TO3CHANNELS = True # Load datasets DATASETS = ['MNIST', 'FashionMNIST', 'KMNIST', 'CIFAR10', 'SVHN'] # 'USPS', 'SVHN', 'EMNIST' datasets = {} n_datasets = len(DATASETS) for ds_name in DATASETS: datasets[ds_name] = load_torchvision_data(ds_name, to3channels=TO3CHANNELS, resize=34, valid_size=0, maxsize=5000)[0]['train'] if args.dist_type == 'between': print('Calculating distance between datasets') distances = np.zeros((n_datasets, n_datasets)) for i, set1 in enumerate(datasets): for j, set2 in enumerate(datasets): if i >= j: continue dist = DatasetDistance(datasets[set1], datasets[set2], inner_ot_method='exact', debiased_loss=True, p=2,