def try_to_load_broden(directory, imgsize, broden_version, perturbation,
        download, size):
    """Build a BrodenDataset sized for ``imgsize``, or return None.

    The Broden release resolution is picked from the largest requested
    image side: 224 if it is <= 224, 227 if it is <= 227, otherwise 384.

    Args:
        directory: root directory expected to contain
            ``broden<broden_version>_<resolution>/index.csv``.
        imgsize: target size given to ``transforms.Resize``; either a single
            int or a sequence of ints.
        broden_version: Broden version number used in the directory name.
        perturbation: forwarded to ``AddPerturbation`` in the transform chain.
        download: forwarded to BrodenDataset. When true, the dataset is
            constructed even if ``index.csv`` is not on disk yet (presumably
            so BrodenDataset can fetch it — confirm against BrodenDataset).
            When false, a missing index yields None.
        size: forwarded to BrodenDataset as ``size``.

    Returns:
        A BrodenDataset, or None when the index file is missing and
        ``download`` is false.
    """
    # transforms.Resize accepts an int as well as a sequence; max() would
    # raise TypeError on a bare int, so handle that case explicitly.
    largest_side = imgsize if isinstance(imgsize, int) else max(imgsize)
    ds_resolution = (224 if largest_side <= 224 else
                     227 if largest_side <= 227 else 384)
    index_path = os.path.join(
        directory, 'broden%d_%d' % (broden_version, ds_resolution),
        'index.csv')
    # Only bail out early when we are not allowed to download; otherwise the
    # `download` flag passed below could never take effect for missing data.
    if not download and not os.path.isfile(index_path):
        return None
    return BrodenDataset(directory,
            resolution=ds_resolution,
            download=download,
            broden_version=broden_version,
            transform=transforms.Compose([
                transforms.Resize(imgsize),
                AddPerturbation(perturbation),
                transforms.ToTensor(),
                transforms.Normalize(IMAGE_MEAN, IMAGE_STDEV)]),
            size=size)
def test_dissection():
    """End-to-end smoke run of ``dissect`` on a pretrained AlexNet.

    Loads torchvision's pretrained AlexNet and wraps it in InstrumentedModel.
    It retains the five ``features.*`` conv layers under the aliases
    conv1..conv5, loads 100 Broden images from ``dataset/broden``, and runs
    ``dissect`` with output under ``dissect/test``.
    """
    verbose_progress(True)
    from torchvision.models import alexnet
    from torchvision import transforms

    # Pretrained AlexNet, wrapped so that layers can be retained, in eval mode.
    net = InstrumentedModel(alexnet(pretrained=True))
    net.eval()

    # (module path, alias) pairs for the convolutional layers to keep.
    conv_layers = [
        ('features.0', 'conv1'),
        ('features.3', 'conv2'),
        ('features.6', 'conv3'),
        ('features.8', 'conv4'),
        ('features.10', 'conv5'),
    ]
    net.retain_layers(conv_layers)

    # Tensor conversion plus normalization with the module's image statistics.
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(IMAGE_MEAN, IMAGE_STDEV),
    ])
    broden = BrodenDataset('dataset/broden', transform=preprocess, size=100)

    dissect('dissect/test', net, broden, examples_per_unit=10)