def omniglot_single_alphabet(dataroot, alphabet):
    tf = transforms.ToTensor()
    # Load either background or evaluation dataset depending on alphabet
    dataset = Omniglot(root=dataroot, background=True, transform=tf)
    if alphabet not in dataset._alphabets:
        dataset = Omniglot(root=dataroot, background=False, transform=tf)
    # Filter to the single specified alphabet and split into train-test
    return omniglot_filter(dataset, alphabet)
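# --- Added usage sketch (not from the original snippet) ---
# A minimal way to discover valid `alphabet` names, assuming torchvision can
# download Omniglot into ./data. `_alphabets` is a private torchvision
# attribute (the same one the function above relies on), so this may break
# across versions. `omniglot_filter` itself is defined elsewhere.
from torchvision import transforms
from torchvision.datasets import Omniglot

def list_alphabets(dataroot="./data"):
    bg = Omniglot(root=dataroot, background=True, download=True)
    ev = Omniglot(root=dataroot, background=False, download=True)
    return bg._alphabets + ev._alphabets

# train, test = omniglot_single_alphabet("./data", list_alphabets()[0])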
def __init__(self, num_data_per_dataset, split, mnist):
    """
    :param num_data_per_dataset: int, number of data per dataset
    :param split: str, type of dataset
    :param mnist: boolean, whether to test on mnist
    """
    self.split = split
    self.num_data_per_dataset = num_data_per_dataset
    self.mnist = mnist
    if self.split == 'test' and self.mnist:
        data = MNIST('mnist', download=True, train=True)
        x_train, y_train = data.data.float() / 255, data.targets
        data = MNIST('mnist', download=True, train=False)
        x_test, y_test = data.data.float() / 255, data.targets
        self.x = torch.cat([x_train.round(), x_test.round()], dim=0)[:, None]
        self.y = torch.cat([y_train, y_test], dim=0)
    else:
        self.dataset_background = Omniglot('omniglot', download=True)
        self.dataset_eval = Omniglot('omniglot', download=True, background=False)
        images = []
        labels = []
        # Invert so strokes are white on black, downsample to 28x28,
        # and rescale pixels to [0, 1].
        for image, label in self.dataset_background:
            image = PIL.ImageOps.invert(image)
            images += [(np.asarray(image.resize((28, 28))) / 255)[None]]
            labels += [label]
        for image, label in self.dataset_eval:
            image = PIL.ImageOps.invert(image)
            images += [(np.asarray(image.resize((28, 28))) / 255)[None]]
            labels += [label]
        images = np.stack(images, axis=0)
        labels = np.array(labels, dtype=np.int64)  # np.int is removed in NumPy >= 1.24
        # Character-level split: the first 1200 characters go to train,
        # the next 100 to val, and the remainder to test (20 drawings each).
        if self.split == 'test':
            self.images = images[1300 * 20:].round()
            self.labels = labels[1300 * 20:]
        elif self.split == 'val':
            self.images = images[1200 * 20:1300 * 20].round()
            self.labels = labels[1200 * 20:1300 * 20]
        else:
            self.images = images[:1200 * 20]
            self.labels = labels[:1200 * 20]
    self.sample_data()
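# --- Added sanity check (assumes standard Omniglot sizes) ---
# Omniglot ships 964 background + 659 evaluation characters with 20 drawings
# each, so the character-level slices above yield 1200 train / 100 val /
# 323 test characters once both splits are concatenated.
n_background, n_evaluation, per_char = 964, 659, 20
total = (n_background + n_evaluation) * per_char
assert total == 1623 * per_char  # 32460 images overall
print("train:", 1200 * per_char, "val:", 100 * per_char, "test:", total - 1300 * per_char)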
def create(cls, args):
    trainset = Omniglot(args.dataset_root,
                        background=True,
                        transform=cls.get_transforms(args, True),
                        download=True)
    testset = Omniglot(args.dataset_root,
                       background=False,
                       transform=cls.get_transforms(args, False),
                       download=True)
    train, valid = _split(args, trainset)
    return train, valid, testset
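# --- Added usage sketch (hypothetical; class and field names are assumptions) ---
# `create` demonstrably needs `args.dataset_root`; whatever other fields
# `get_transforms` and `_split` consume are not visible in this excerpt.
# from types import SimpleNamespace
# args = SimpleNamespace(dataset_root="./data")
# train, valid, testset = SomeDatasetFactory.create(args)  # hypothetical owner class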
def __init__(self, root, meta_train=True, transform=None, class_transforms=None, download=False):
    TorchvisionOmniglot.__init__(self,
                                 root,
                                 background=meta_train,
                                 transform=transform,
                                 download=download)
    Dataset.__init__(self, class_transforms=class_transforms)
    self._num_classes = len(self._characters)
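# --- Added note (assumes `Dataset` is the project's own base class) ---
# `_characters` is a private torchvision attribute listing the character
# folders: 964 entries for background (meta_train=True) and 659 for
# evaluation, which is what `_num_classes` captures above.
# ds = MetaOmniglot(root="./data", meta_train=True, download=True)  # hypothetical name
# print(ds._num_classes)  # 964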
def __init__(
    self,
    root,
    transform=None,
    target_transform=None,
    rotations=(0, 90, 180, 270),
    download=False,
):
    self.transform = transform
    self.target_transform = target_transform
    self.rotations = rotations
    self.n_rotations = len(rotations)
    self.n_per_class = 20
    self._bg = Omniglot(root=root, background=True, download=download)
    self._bg_n_classes = n_bg = len(self._bg._character_images)
    # Offset evaluation targets so they follow directly after the
    # background classes.
    self._eval = Omniglot(
        root=root,
        background=False,
        target_transform=lambda t: t + n_bg,
        download=download,
    )
    self.n_base_classes = n_bg + len(self._eval._character_images)
    self.base = data.ConcatDataset([self._bg, self._eval])
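# --- Added note on the augmented label space (sizes are standard Omniglot) ---
# With 964 background + 659 evaluation characters and 4 rotations, a wrapper
# like the one above exposes n_base_classes * n_rotations distinct classes:
n_base_classes = 964 + 659           # 1623 characters after the target offset
n_rotations = len((0, 90, 180, 270))
print(n_base_classes * n_rotations)  # 6492 rotation-augmented classes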
def get_loader(self, subset):
    if self.name == "miniimagenet":
        dataset = MiniImageNet(subset)
        labels = dataset.labels
    elif self.name == "omniglot":
        dataset = Omniglot(root="data/",
                           download=True,
                           transform=transforms.ToTensor(),
                           background=subset == 'train')
        labels = list(map(lambda x: x[1], dataset._flat_character_images))
    else:
        raise ValueError
    sampler = CategoriesSamplerMult(
        labels,
        n_batches=self.n_batches if subset == 'train' else 400,
        ways=dict(train=self.train_ways, valid=self.valid_ways)[subset],
        n_images=self.shots + self.queries,
        n_combinations=2)
    return DataLoader(dataset=dataset,
                      batch_sampler=sampler,
                      num_workers=8,
                      pin_memory=True)
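# --- Added sketch: why `_flat_character_images` is used for labels ---
# It is a private torchvision attribute holding (filename, character_index)
# pairs, so per-sample labels can be read without decoding and transforming
# every image. Assumes Omniglot is downloadable into data/.
from torchvision.datasets import Omniglot
ds = Omniglot(root="data/", download=True, background=True)
labels = [idx for _, idx in ds._flat_character_images]
print(len(labels), max(labels) + 1)  # 19280 samples across 964 classes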
def __init__(self, test_size=0.2, eval_size=0.2, random_state=0):
    self.background = Omniglot(root='data/',
                               background=True,
                               download=True,
                               transform=transforms.ToTensor())
    self.evaluation = Omniglot(root='data/',
                               background=False,
                               download=True,
                               transform=transforms.ToTensor())
    self.num_classes, self.img_rows, self.img_cols = 1623, 105, 105
    self.background_data = np.array(
        [t.numpy()[0] for t, _ in self.background])
    self.evaluation_data = np.array(
        [t.numpy()[0] for t, _ in self.evaluation])
    self.background_labels = np.array([l for _, l in self.background])
    # Because labels in the evaluation data also start from 0, they are
    # offset so they follow immediately after the background labels.
    self.evaluation_labels = np.array(
        [l for _, l in self.evaluation]) + np.max(self.background_labels) + 1
    # Split the entire data set into a training and testing set.
    self.x_train, self.x_test, self.y_train, self.y_test = train_test_split(
        np.concatenate((self.background_data, self.evaluation_data)),
        np.concatenate((self.background_labels, self.evaluation_labels)),
        test_size=test_size,
        random_state=random_state)
    # Split the training set further into a final training set and a validation set.
    self.x_train, self.x_valid, self.y_train, self.y_valid = train_test_split(
        self.x_train, self.y_train, test_size=eval_size, random_state=random_state)
    self.y_test = utils.to_categorical(self.y_test, self.num_classes)
    self.y_train = utils.to_categorical(self.y_train, self.num_classes)
    self.y_valid = utils.to_categorical(self.y_valid, self.num_classes)
    if K.image_data_format() == 'channels_first':
        self.x_train = self.x_train.reshape(self.x_train.shape[0], 1,
                                            self.img_rows, self.img_cols)
        self.x_test = self.x_test.reshape(self.x_test.shape[0], 1,
                                          self.img_rows, self.img_cols)
        self.x_valid = self.x_valid.reshape(self.x_valid.shape[0], 1,
                                            self.img_rows, self.img_cols)
        self.input_shape = (1, self.img_rows, self.img_cols)
    else:
        self.x_train = self.x_train.reshape(self.x_train.shape[0],
                                            self.img_rows, self.img_cols, 1)
        self.x_test = self.x_test.reshape(self.x_test.shape[0],
                                          self.img_rows, self.img_cols, 1)
        self.x_valid = self.x_valid.reshape(self.x_valid.shape[0],
                                            self.img_rows, self.img_cols, 1)
        self.input_shape = (self.img_rows, self.img_cols, 1)
    self.x_train = self.x_train.astype('float32')
    self.x_test = self.x_test.astype('float32')
    self.x_valid = self.x_valid.astype('float32')
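# --- Added toy demonstration of the label-offset trick above ---
# Evaluation labels restart at 0, so they are shifted past the largest
# background label before the two arrays are concatenated.
import numpy as np
background_labels = np.array([0, 1, 2, 2])
evaluation_labels = np.array([0, 0, 1])
shifted = evaluation_labels + background_labels.max() + 1
print(np.concatenate([background_labels, shifted]))  # [0 1 2 2 3 3 4]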
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from torchvision.datasets import Omniglot
import matplotlib.pyplot as plt
import numpy as np

# Download Omniglot and inspect a single sample (download=True already
# fetches the data, so no separate download() call is needed).
omniglot = Omniglot(root="./data", download=True)
image = omniglot[300][0]  # (image, character_class) tuple; take the PIL image
print(image)

# Histogram of the pixel intensities of that image.
plt.hist(np.asarray(image).ravel(), bins=256)
plt.show()

# Build a DataLoader over the tensor-transformed dataset.
transform = transforms.ToTensor()
trainset = torchvision.datasets.Omniglot(root='./data', download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)
sample = trainset[0]  # (tensor, class) pair

# CrossEntropyLoss expects raw logits and integer class targets.
criterion = nn.CrossEntropyLoss()
loss = criterion(torch.randn(4, 10), torch.randint(0, 10, (4,)))
def __init__(self, root, n_classes, train=True, download=True, noise_std=0):
    """
    There is already a realization of Omniglot in torchvision.datasets;
    the main differences of this version are:
    * The dataset considers only a subset of all classes at each time and
      may reshuffle classes
    * Different breakdown of characters into training and validation
      (check https://github.com/syed-ahmed/MatchingNetworks-OSL for more details)
    * Additional classes are added by rotating images by 90, 180 and 270 degrees
    * All data is preloaded into memory - because of that there is no option
      for image and target transforms; instead we hardcode .Resize() and
      ToTensor() transforms
    * Noise augmentation is added for the training set

    Args:
        root: root directory of images as in torchvision.datasets.Omniglot
        n_classes: total number of classes considered
        train: training or validation dataset
        download: flag to download the data
        noise_std: standard deviation of Gaussian noise applied to training data
    """
    super(RestrictedOmniglot, self).__init__()
    self.root = root
    self.n_classes = n_classes
    self.train = train
    self.download = download
    self.noise_std = noise_std

    # Hardcoding some dataset settings
    self.rotations = [0, 90, 180, 270]
    self.images_per_class = 20
    self.image_size = 28

    cache_pkl = os.path.join(self.root, 'omniglot_packed.pkl')
    try:
        with open(cache_pkl, 'rb') as f:
            class_members = pickle.load(f)
        self.data, self.target, self.n_all_classes, self.target_mapping = class_members
    except IOError:
        # Relying on the pytorch Omniglot class to do the downloads and data checks
        self.old_train = Omniglot(self.root, True, None, None, self.download)
        self.old_test = Omniglot(self.root, False, None, None, self.download)

        # After download, images lie in
        # root/omniglot-py/images_{background,evaluation}/alphabet/character/image.png.
        # Retain only those that lie in the training or test set.
        # TODO: look for more elegant solutions with glob and trailing slashes
        trailing_slash = "" if self.root[-1] == "/" else "/"
        image_paths = glob.glob(self.root + trailing_slash + "*/*/*/*/*.png")
        is_test_class = lambda path: any([
            alphabet == path.split("/")[-3]
            for alphabet in OMNIGLOT_TEST_CLASSES
        ])
        if self.train:
            image_paths = [x for x in image_paths if not is_test_class(x)]
        else:
            image_paths = [x for x in image_paths if is_test_class(x)]

        # Mapping remaining characters to classes
        extract_character = lambda path: path.split("/")[-3] + "/" + path.split("/")[-2]
        characters = set([extract_character(x) for x in image_paths])
        character_mapping = dict(zip(list(characters), range(len(characters))))
        self.n_all_classes = len(self.rotations) * len(characters)

        # Reading images into memory
        self.data = torch.zeros((self.images_per_class * self.n_all_classes, 1,
                                 self.image_size, self.image_size))
        self.target = torch.zeros((self.images_per_class * self.n_all_classes, ),
                                  dtype=torch.long)
        for rotation_idx, rotation in enumerate(self.rotations):
            for image_idx, image_path in enumerate(image_paths):
                target_idx = character_mapping[extract_character(
                    image_path)] + rotation_idx * len(characters)
                image = Image.open(image_path, mode="r").convert("L")
                processed_image = image.rotate(rotation).resize(
                    (self.image_size, self.image_size))
                self.data[rotation_idx * len(image_paths) +
                          image_idx] = ToTensor()(processed_image)
                self.target[rotation_idx * len(image_paths) +
                            image_idx] = target_idx

        # Recording the mapping of classes to their corresponding sample indices
        self.target_mapping = {x: [] for x in range(self.n_all_classes)}
        for (target_idx, idx) in zip(self.target, range(self.target.shape[0])):
            self.target_mapping[int(target_idx)].append(idx)

        with open(cache_pkl, 'wb') as f:
            class_members = [
                self.data, self.target, self.n_all_classes, self.target_mapping
            ]
            pickle.dump(class_members, f)

    self.shuffle_classes()
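# --- Added usage sketch (assumes the class definition above is complete) ---
# The first construction is slow because every PNG is read and resized; later
# runs load the omniglot_packed.pkl cache instead. Arguments are as documented:
# train_ds = RestrictedOmniglot(root="./data", n_classes=200, train=True,
#                               download=True, noise_std=0.1)
# x, y = train_ds.data[0], train_ds.target[0]  # preloaded 1x28x28 tensor, class id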
# grid_img = torchvision.utils.make_grid(y_pred, nrow=8)
# plt.imshow(grid_img.detach().numpy()[0])
# plt.show()

# Define transforms
tsfm = transforms.Compose([
    transforms.Grayscale(1),
    transforms.Resize(params.resize_dim),
    transforms.ToTensor()
])

# Load Omniglot from torchvision.datasets
dataset = Omniglot(data_path, background=True, transform=tsfm, download=True)
dataloader = DataLoader(dataset,
                        params.batch_size,
                        shuffle=True,
                        num_workers=params.num_workers,
                        drop_last=True)

# Load visual cortex model here.
model = modules.ECPretrain(D_in=1,
                           D_out=121,
                           KERNEL_SIZE=9,
                           STRIDE=5,
                           PADDING=1)

# Binary cross-entropy loss for the autoencoder.
loss_fn = nn.BCELoss()
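# --- Added training-loop sketch (assumptions: `model` returns reconstructions
# in [0, 1] matching the input shape; the optimizer choice is illustrative) ---
# import torch
# optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# for x, _ in dataloader:
#     recon = model(x)
#     loss = loss_fn(recon, x)   # BCELoss defined above
#     optimizer.zero_grad()
#     loss.backward()
#     optimizer.step()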
def test(flags):
    if flags.xpid is None:
        checkpointpath = "./latest/model.tar"
    else:
        checkpointpath = os.path.expandvars(
            os.path.expanduser("%s/%s/%s" %
                               (flags.savedir, flags.xpid, "model.tar")))
    config = dict(
        episode_length=flags.episode_length,
        canvas_width=flags.canvas_width,
        grid_width=grid_width,
        brush_sizes=flags.brush_sizes,
    )
    if flags.dataset == "celeba" or flags.dataset == "celeba-hq":
        use_color = True
    else:
        use_color = False
    if flags.env_type == "fluid":
        env_name = "Fluid"
        config["shaders_basedir"] = SHADERS_BASEDIR
    elif flags.env_type == "libmypaint":
        env_name = "Libmypaint"
        config.update(
            dict(
                brush_type=flags.brush_type,
                use_color=use_color,
                use_pressure=flags.use_pressure,
                use_alpha=False,
                background="white",
                brushes_basedir=BRUSHES_BASEDIR,
            ))
    if flags.use_compound:
        env_name += "-v1"
        config.update(
            dict(
                new_stroke_penalty=flags.new_stroke_penalty,
                stroke_length_penalty=flags.stroke_length_penalty,
            ))
    else:
        env_name += "-v0"

    env = env_wrapper.make_raw(env_name, config)
    if frame_width != flags.canvas_width:
        env = env_wrapper.WarpFrame(env, height=frame_width, width=frame_width)
    env = env_wrapper.wrap_pytorch(env)
    env = env_wrapper.AddDim(env)

    obs_shape = env.observation_space.shape
    if flags.condition:
        c, h, w = obs_shape
        c *= 2
        obs_shape = (c, h, w)
    action_shape = env.action_space.nvec.tolist()
    order = env.order

    model = models.Net(
        obs_shape=obs_shape,
        action_shape=action_shape,
        grid_shape=(grid_width, grid_width),
        order=order,
    )
    if flags.condition:
        model = models.Condition(model)
    model.eval()

    D = models.Discriminator(obs_shape, flags.power_iters)
    if flags.condition:
        D = models.Conditional(D)
    D.eval()

    checkpoint = torch.load(checkpointpath, map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"])
    D.load_state_dict(checkpoint["D_state_dict"])

    if flags.condition:
        from random import randrange

        c, h, w = obs_shape
        tsfm = transforms.Compose(
            [transforms.Resize((h, w)),
             transforms.ToTensor()])
        dataset = flags.dataset
        if dataset == "mnist":
            dataset = MNIST(root="./", train=True, transform=tsfm, download=True)
        elif dataset == "omniglot":
            dataset = Omniglot(root="./", background=True, transform=tsfm, download=True)
        elif dataset == "celeba":
            dataset = CelebA(
                root="./",
                split="train",
                target_type=None,
                transform=tsfm,
                download=True,
            )
        elif dataset == "celeba-hq":
            dataset = datasets.CelebAHQ(root="./",
                                        split="train",
                                        transform=tsfm,
                                        download=True)
        else:
            raise NotImplementedError
        condition = dataset[randrange(len(dataset))].view((1, 1) + obs_shape)
    else:
        condition = None

    frame = env.reset()
    action = model.initial_action()
    agent_state = model.initial_state()
    done = torch.tensor(False).view(1, 1)
    rewards = []
    frames = [frame]
    for i in range(flags.episode_length - 1):
        if flags.mode == "test_render":
            env.render()
        noise = torch.randn(1, 1, 10)
        agent_outputs, agent_state = model(
            dict(
                obs=frame,
                condition=condition,
                action=action,
                noise=noise,
                done=done,
            ),
            agent_state,
        )
        action, *_ = agent_outputs
        frame, reward, done, _ = env.step(action)
        rewards.append(reward)
        frames.append(frame)

    reward = torch.cat(rewards)
    frame = torch.cat(frames)
    if flags.use_tca:
        frame = torch.flatten(frame, 0, 1)
        if flags.condition:
            condition = torch.flatten(condition, 0, 1)
    else:
        frame = frame[-1]
        if flags.condition:
            condition = condition[-1]

    D = D.eval()
    with torch.no_grad():
        if flags.condition:
            p = D(frame, condition).view(-1, 1)
        else:
            p = D(frame).view(-1, 1)
    if flags.use_tca:
        d_reward = p[1:] - p[:-1]
        reward = reward[1:] + d_reward
    else:
        reward[-1] = reward[-1] + p
        reward = reward[1:]

    # empty condition
    condition = None
    logging.info(
        "Episode ended after %d steps. Final reward: %.4f. Episode reward: %.4f.",
        flags.episode_length,
        reward[-1].item(),
        reward.sum(),
    )
    env.close()
def train(flags):
    if flags.xpid is None:
        flags.xpid = "torchbeast-%s" % time.strftime("%Y%m%d-%H%M%S")
    plogger = file_writer.FileWriter(xpid=flags.xpid,
                                     xp_args=flags.__dict__,
                                     rootdir=flags.savedir)
    checkpointpath = os.path.expandvars(
        os.path.expanduser("%s/%s/%s" %
                           (flags.savedir, flags.xpid, "model.tar")))

    if not flags.disable_cuda and torch.cuda.is_available():
        logging.info("Using CUDA.")
        flags.learner_device = torch.device("cuda")
        flags.actor_device = torch.device("cuda")
    else:
        logging.info("Not using CUDA.")
        flags.learner_device = torch.device("cpu")
        flags.actor_device = torch.device("cpu")

    if flags.max_learner_queue_size is None:
        flags.max_learner_queue_size = flags.batch_size

    # The queue the learner threads will get their data from.
    # Setting `minimum_batch_size == maximum_batch_size`
    # makes the batch size static.
    learner_queue = actorpool.BatchingQueue(
        batch_dim=1,
        minimum_batch_size=flags.batch_size,
        maximum_batch_size=flags.batch_size,
        check_inputs=True,
        maximum_queue_size=flags.max_learner_queue_size,
    )
    d_queue = Queue(maxsize=flags.max_learner_queue_size // flags.batch_size)
    image_queue = Queue(maxsize=flags.max_learner_queue_size)

    # The "batcher", a queue for the inference call. Will yield
    # "batch" objects with `get_inputs` and `set_outputs` methods.
    # The batch size of the tensors will be dynamic.
    inference_batcher = actorpool.DynamicBatcher(
        batch_dim=1,
        minimum_batch_size=1,
        maximum_batch_size=512,
        timeout_ms=100,
        check_outputs=True,
    )

    addresses = []
    connections_per_server = 1
    pipe_id = 0
    while len(addresses) < flags.num_actors:
        for _ in range(connections_per_server):
            addresses.append(f"{flags.pipes_basename}.{pipe_id}")
            if len(addresses) == flags.num_actors:
                break
        pipe_id += 1

    config = dict(
        episode_length=flags.episode_length,
        canvas_width=flags.canvas_width,
        grid_width=grid_width,
        brush_sizes=flags.brush_sizes,
    )
    if flags.dataset == "celeba" or flags.dataset == "celeba-hq":
        use_color = True
    else:
        use_color = False
    if flags.env_type == "fluid":
        env_name = "Fluid"
        config["shaders_basedir"] = SHADERS_BASEDIR
    elif flags.env_type == "libmypaint":
        env_name = "Libmypaint"
        config.update(
            dict(
                brush_type=flags.brush_type,
                use_color=use_color,
                use_pressure=flags.use_pressure,
                use_alpha=False,
                background="white",
                brushes_basedir=BRUSHES_BASEDIR,
            ))
    if flags.use_compound:
        env_name += "-v1"
    else:
        env_name += "-v0"

    env = env_wrapper.make_raw(env_name, config)
    if frame_width != flags.canvas_width:
        env = env_wrapper.WarpFrame(env, height=frame_width, width=frame_width)
    env = env_wrapper.wrap_pytorch(env)

    obs_shape = env.observation_space.shape
    if flags.condition:
        c, h, w = obs_shape
        c *= 2
        obs_shape = (c, h, w)
    action_shape = env.action_space.nvec.tolist()
    order = env.order
    env.close()

    model = models.Net(
        obs_shape=obs_shape,
        action_shape=action_shape,
        grid_shape=(grid_width, grid_width),
        order=order,
    )
    if flags.condition:
        model = models.Condition(model)
    model = model.to(device=flags.learner_device)

    actor_model = models.Net(
        obs_shape=obs_shape,
        action_shape=action_shape,
        grid_shape=(grid_width, grid_width),
        order=order,
    )
    if flags.condition:
        actor_model = models.Condition(actor_model)
    actor_model.to(device=flags.actor_device)

    D = models.Discriminator(obs_shape, flags.power_iters)
    if flags.condition:
        D = models.Conditional(D)
    D.to(device=flags.learner_device)

    D_eval = models.Discriminator(obs_shape, flags.power_iters)
    if flags.condition:
        D_eval = models.Conditional(D_eval)
    D_eval = D_eval.to(device=flags.learner_device)

    optimizer = optim.Adam(model.parameters(), lr=flags.policy_learning_rate)
    D_optimizer = optim.Adam(D.parameters(),
                             lr=flags.discriminator_learning_rate,
                             betas=(0.5, 0.999))

    def lr_lambda(epoch):
        return (1 - min(epoch * flags.unroll_length * flags.batch_size,
                        flags.total_steps) / flags.total_steps)

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    D_scheduler = torch.optim.lr_scheduler.LambdaLR(D_optimizer, lr_lambda)

    C, H, W = obs_shape
    if flags.condition:
        C //= 2

    # The ActorPool that will run `flags.num_actors` many loops.
    actors = actorpool.ActorPool(
        unroll_length=flags.unroll_length,
        learner_queue=learner_queue,
        inference_batcher=inference_batcher,
        env_server_addresses=addresses,
        initial_action=actor_model.initial_action(),
        initial_agent_state=actor_model.initial_state(),
        image=torch.zeros(1, 1, C, H, W),
    )

    def run():
        try:
            actors.run()
            print("actors are running")
        except Exception as e:
            logging.error("Exception in actorpool thread!")
            traceback.print_exc()
            print()
            raise e

    actorpool_thread = threading.Thread(target=run, name="actorpool-thread")

    c, h, w = obs_shape
    tsfm = transforms.Compose(
        [transforms.Resize((h, w)),
         transforms.ToTensor()])
    dataset = flags.dataset
    if dataset == "mnist":
        dataset = MNIST(root="./", train=True, transform=tsfm, download=True)
    elif dataset == "omniglot":
        dataset = Omniglot(root="./", background=True, transform=tsfm, download=True)
    elif dataset == "celeba":
        dataset = CelebA(root="./",
                         split="train",
                         target_type=None,
                         transform=tsfm,
                         download=True)
    elif dataset == "celeba-hq":
        dataset = datasets.CelebAHQ(root="./",
                                    split="train",
                                    transform=tsfm,
                                    download=True)
    else:
        raise NotImplementedError
    dataloader = DataLoader(dataset,
                            batch_size=1,
                            shuffle=True,
                            drop_last=True,
                            pin_memory=True)

    stats = {}

    # Load state from a checkpoint, if possible.
    if os.path.exists(checkpointpath):
        checkpoint_states = torch.load(checkpointpath,
                                       map_location=flags.learner_device)
        model.load_state_dict(checkpoint_states["model_state_dict"])
        D.load_state_dict(checkpoint_states["D_state_dict"])
        optimizer.load_state_dict(checkpoint_states["optimizer_state_dict"])
        D_optimizer.load_state_dict(
            checkpoint_states["D_optimizer_state_dict"])
        # Keys must match what checkpoint() below saves: scheduler state
        # goes to scheduler, D_scheduler state to D_scheduler.
        scheduler.load_state_dict(checkpoint_states["scheduler_state_dict"])
        D_scheduler.load_state_dict(
            checkpoint_states["D_scheduler_state_dict"])
        stats = checkpoint_states["stats"]
        logging.info(f"Resuming preempted job, current stats:\n{stats}")

    # Initialize actor model like learner model.
    actor_model.load_state_dict(model.state_dict())
    D_eval.load_state_dict(D.state_dict())

    learner_threads = [
        threading.Thread(
            target=learn,
            name="learner-thread-%i" % i,
            args=(
                flags,
                learner_queue,
                d_queue,
                model,
                actor_model,
                D_eval,
                optimizer,
                scheduler,
                stats,
                plogger,
            ),
        ) for i in range(flags.num_learner_threads)
    ]
    inference_threads = [
        threading.Thread(
            target=inference,
            name="inference-thread-%i" % i,
            args=(
                flags,
                inference_batcher,
                actor_model,
                image_queue,
            ),
        ) for i in range(flags.num_inference_threads)
    ]
    d_learner = [
        threading.Thread(
            target=learn_D,
            name="d_learner-thread-%i" % i,
            args=(
                flags,
                d_queue,
                D,
                D_eval,
                D_optimizer,
                D_scheduler,
                stats,
                plogger,
            ),
        ) for i in range(flags.num_learner_threads)
    ]
    for thread in d_learner:
        thread.daemon = True
    dataloader_thread = threading.Thread(target=data_loader,
                                         args=(
                                             flags,
                                             dataloader,
                                             image_queue,
                                         ))
    dataloader_thread.daemon = True

    actorpool_thread.start()
    threads = learner_threads + inference_threads
    daemons = d_learner + [dataloader_thread]
    for t in threads + daemons:
        t.start()

    def checkpoint():
        if flags.disable_checkpoint:
            return
        logging.info("Saving checkpoint to %s", checkpointpath)
        torch.save(
            {
                "model_state_dict": model.state_dict(),
                "D_state_dict": D.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "D_optimizer_state_dict": D_optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
                "D_scheduler_state_dict": D_scheduler.state_dict(),
                "stats": stats,
                "flags": vars(flags),
            },
            checkpointpath,
        )

    def format_value(x):
        return f"{x:1.5}" if isinstance(x, float) else str(x)

    try:
        last_checkpoint_time = timeit.default_timer()
        while True:
            start_time = timeit.default_timer()
            start_step = stats.get("step", 0)
            if start_step >= flags.total_steps:
                break
            time.sleep(5)
            end_step = stats.get("step", 0)

            if timeit.default_timer() - last_checkpoint_time > 10 * 60:
                # Save every 10 min.
                checkpoint()
                last_checkpoint_time = timeit.default_timer()

            logging.info(
                "Step %i @ %.1f SPS. Inference batcher size: %i."
                " Learner queue size: %i."
                " Other stats: (%s)",
                end_step,
                (end_step - start_step) / (timeit.default_timer() - start_time),
                inference_batcher.size(),
                learner_queue.size(),
                ", ".join(f"{key} = {format_value(value)}"
                          for key, value in stats.items()),
            )
    except KeyboardInterrupt:
        pass  # Close properly.
    else:
        logging.info("Learning finished after %i steps.", stats["step"])
        checkpoint()

    # Done with learning. Stop all the ongoing work.
    inference_batcher.close()
    learner_queue.close()
    actorpool_thread.join()
    for t in threads:
        t.join()
import numpy as np
from torchvision.datasets import Omniglot

def create_tasks(dataroot, num_tasks):
    # Concatenate background and evaluation alphabets (50 in total),
    # shuffle them with a fixed seed, and keep the first `num_tasks`.
    alphabets = Omniglot(dataroot, background=True)._alphabets + Omniglot(
        dataroot, background=False)._alphabets
    order = np.random.RandomState(666).permutation(len(alphabets))
    tasks = np.array(alphabets)[order]
    return tasks[:num_tasks]
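# --- Added usage sketch (assumes Omniglot is already downloaded to ./data) ---
# The fixed RandomState(666) makes the permutation deterministic, so the same
# alphabets are selected as tasks on every run.
tasks = create_tasks("./data", num_tasks=10)
print(tasks)  # 10 of the 50 alphabet names (30 background + 20 evaluation)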
# --
# IO

stats = {
    "mean": (0.07793742418289185,),
    "std": (0.2154727578163147,)
}

transform = transforms.Compose([
    transforms.Resize(28),
    transforms.ToTensor(),
    transforms.Lambda(lambda x: 1 - x),  # Make sparse
    transforms.Normalize(**stats),
])

back_dataset = Omniglot(root='./data', background=True, transform=transform)
test_dataset = Omniglot(root='./data', background=False, transform=transform)

back_wrapper = OmniglotTaskWrapper(back_dataset)
test_wrapper = OmniglotTaskWrapper(test_dataset)

back_loader = precompute_batches(
    wrapper=back_wrapper,
    num_epochs=args.num_epochs,
    num_steps=args.num_steps,
    batch_size=args.batch_size,
    num_classes=args.num_classes,
    num_shots=1
)

test_loader = precompute_batches(