Esempio n. 1
0
def omniglot_single_alphabet(dataroot, alphabet):
    """Return the train/test split for one Omniglot alphabet.

    The background split is checked first; if ``alphabet`` is not one of its
    alphabets, the evaluation split is loaded instead.

    Args:
        dataroot: root directory handed to ``Omniglot``.
        alphabet: name of the alphabet to keep.

    Returns:
        Whatever ``omniglot_filter`` produces for the selected split.
    """
    to_tensor = transforms.ToTensor()

    source = Omniglot(root=dataroot, background=True, transform=to_tensor)
    if alphabet in source._alphabets:
        return omniglot_filter(source, alphabet)

    # Not a background alphabet: fall back to the evaluation split.
    source = Omniglot(root=dataroot, background=False, transform=to_tensor)
    return omniglot_filter(source, alphabet)
Esempio n. 2
0
    def __init__(self, num_data_per_dataset, split, mnist):
        """Load the data for one split into memory.

        With ``split == 'test'`` and ``mnist`` set, the MNIST train and test
        images are pooled, binarised by rounding and stored in ``self.x`` /
        ``self.y``. Otherwise both Omniglot splits are loaded, inverted,
        resized to 28x28, sliced by position into ``self.images`` /
        ``self.labels``, and ``self.sample_data()`` is called.

        :param num_data_per_dataset: int, number of data per dataset
        :param split: str, 'test', 'val', or anything else for train
        :param mnist: boolean, whether to test on mnist
        """
        self.split = split
        self.num_data_per_dataset = num_data_per_dataset
        self.mnist = mnist

        if self.split == 'test' and self.mnist:
            data = MNIST('mnist', download=True, train=True)
            x_train, y_train = data.data.float() / 255, data.targets
            data = MNIST('mnist', download=True, train=False)
            x_test, y_test = data.data.float() / 255, data.targets
            # Binarise by rounding and insert a channel axis after the batch axis.
            self.x = torch.cat(
                [x_train.round(), x_test.round()], dim=0)[:, None]
            self.y = torch.cat([y_train, y_test], dim=0)

        else:
            self.dataset_background = Omniglot('omniglot', download=True)
            self.dataset_eval = Omniglot('omniglot',
                                         download=True,
                                         background=False)

            images = []
            labels = []
            # Background images first, then evaluation images.
            # NOTE(review): each torchvision split numbers its characters from
            # 0, so background and evaluation labels overlap here; confirm
            # that sample_data() does not group by label across that boundary.
            for source in (self.dataset_background, self.dataset_eval):
                for image, label in source:
                    # Invert intensities, resize to 28x28 and scale by 1/255.
                    image = PIL.ImageOps.invert(image)
                    images.append(
                        (np.asarray(image.resize((28, 28))) / 255)[None])
                    labels.append(label)

            images = np.stack(images, axis=0)
            # ``np.int`` was removed in NumPy 1.24; it was merely an alias of
            # the builtin ``int``, so the resulting dtype is unchanged.
            labels = np.array(labels, dtype=int)

            # Positional split in blocks of 20 images -- presumably one
            # character per 20 consecutive images (TODO confirm ordering).
            # Val/test images are binarised by rounding; train images are
            # kept continuous. Labels are integers and need no rounding.
            if self.split == 'test':
                self.images = images[1300 * 20:].round()
                self.labels = labels[1300 * 20:]
            elif self.split == 'val':
                self.images = images[1200 * 20:1300 * 20].round()
                self.labels = labels[1200 * 20:1300 * 20]
            else:
                self.images = images[:1200 * 20]
                self.labels = labels[:1200 * 20]

            self.sample_data()
Esempio n. 3
0
    def create(cls, args):
        """Build the Omniglot train, validation and test datasets.

        The background split (with training transforms) is divided into train
        and validation parts by ``_split``. The evaluation split (with
        non-training transforms) is returned unchanged as the test set.

        Returns:
            ``(train, valid, testset)``.
        """
        def _load(background):
            # The training-transform flag mirrors the background flag.
            return Omniglot(args.dataset_root,
                            background=background,
                            transform=cls.get_transforms(args, background),
                            download=True)

        trainset = _load(True)
        testset = _load(False)

        train, valid = _split(args, trainset)
        return train, valid, testset
Esempio n. 4
0
 def __init__(self,
              root,
              meta_train=True,
              transform=None,
              class_transforms=None,
              download=False):
     """Initialise both the torchvision Omniglot base and ``Dataset``.

     ``meta_train`` is forwarded as torchvision's ``background`` flag;
     ``class_transforms`` is handed to ``Dataset.__init__``.
     """
     omniglot_kwargs = dict(background=meta_train,
                            transform=transform,
                            download=download)
     TorchvisionOmniglot.__init__(self, root, **omniglot_kwargs)
     Dataset.__init__(self, class_transforms=class_transforms)
     # One class per character listed by the torchvision base class.
     self._num_classes = len(self._characters)
Esempio n. 5
0
    def __init__(self, num_data_per_dataset, split, mnist):
        """Store the configuration and load the backing data.

        In MNIST test mode the MNIST train and test sets are pooled into
        ``self.x`` / ``self.y``; otherwise handles to both Omniglot splits are
        kept on the instance.

        :param num_data_per_dataset: int, number of data per dataset
        :param split: str, type of dataset
        :param mnist: boolean, whether to test on mnist
        """
        self.split = split
        self.num_data_per_dataset = num_data_per_dataset
        self.mnist = mnist

        if not (self.split == 'test' and self.mnist):
            # Omniglot: keep both torchvision splits.
            self.dataset_background = Omniglot('omniglot', download=True)
            self.dataset_eval = Omniglot('omniglot',
                                         download=True,
                                         background=False)
            return

        # MNIST test mode: pool train then test, binarise by rounding, and add
        # a channel axis after the batch axis.
        pixel_parts, target_parts = [], []
        for is_train in (True, False):
            mnist_part = MNIST('mnist', download=True, train=is_train)
            pixel_parts.append((mnist_part.data.float() / 255).round())
            target_parts.append(mnist_part.targets)
        self.x = torch.cat(pixel_parts, dim=0)[:, None]
        self.y = torch.cat(target_parts, dim=0)
Esempio n. 6
0
    def __init__(
        self,
        root,
        transform=None,
        target_transform=None,
        rotations=(0, 90, 180, 270),
        download=False,
    ):
        """Concatenate the background and evaluation Omniglot splits.

        Evaluation targets are shifted by the number of background characters
        so that evaluation class ids follow the background ones.
        ``self.base`` is the concatenation, background first.
        """
        self.transform = transform
        self.target_transform = target_transform
        self.rotations = rotations
        self.n_rotations = len(rotations)
        self.n_per_class = 20

        background = Omniglot(root=root, background=True, download=download)
        self._bg = background
        offset = len(background._character_images)
        self._bg_n_classes = offset

        def shift_target(target):
            return target + offset

        evaluation = Omniglot(
            root=root,
            background=False,
            target_transform=shift_target,
            download=download,
        )
        self._eval = evaluation
        self.n_base_classes = offset + len(evaluation._character_images)
        self.base = data.ConcatDataset([background, evaluation])
Esempio n. 7
0
 def get_loader(self, subset):
     """Build an episodic ``DataLoader`` for ``subset``.

     Args:
         subset: 'train' or 'valid' (selects the number of ways and, for
             Omniglot, the background vs. evaluation split).

     Returns:
         A ``DataLoader`` driven by a ``CategoriesSamplerMult`` batch sampler.

     Raises:
         ValueError: if ``self.name`` is not 'miniimagenet' or 'omniglot'.
     """
     if self.name == "miniimagenet":
         dataset = MiniImageNet(subset)
         labels = dataset.labels
     elif self.name == "omniglot":
         # Background split for training, evaluation split otherwise.
         dataset = Omniglot(root="data/", download=True,
                            transform=transforms.ToTensor(),
                            background=subset == 'train')
         # The second field of each flat entry is used as the class label.
         labels = [entry[1] for entry in dataset._flat_character_images]
     else:
         raise ValueError(
             f"Unsupported dataset {self.name!r}; "
             "expected 'miniimagenet' or 'omniglot'")
     sampler = CategoriesSamplerMult(
             labels,
             n_batches=self.n_batches if subset == 'train' else 400,
             ways=dict(train=self.train_ways, valid=self.valid_ways)[subset],
             n_images=self.shots + self.queries,
             n_combinations=2)
     return DataLoader(dataset=dataset, batch_sampler=sampler,
                       num_workers=8, pin_memory=True)
Esempio n. 8
0
    def __init__(self, test_size=0.2, eval_size=0.2, random_state=0):
        """Load all of Omniglot and split it into train/validation/test arrays.

        The background and evaluation splits are pooled, with evaluation
        labels offset to follow the background labels. The result is split
        into train/test, and train is split again into train/validation.
        Labels are one-hot encoded and images reshaped for Keras' configured
        ``image_data_format`` and cast to float32.

        Args:
            test_size: fraction of the pooled data held out as the test set.
            eval_size: fraction of the remaining training data held out for
                validation.
            random_state: seed passed to both ``train_test_split`` calls.
        """
        self.background = Omniglot(root='data/',
                                   background=True,
                                   download=True,
                                   transform=transforms.ToTensor())

        self.evaluation = Omniglot(root='data/',
                                   background=False,
                                   download=True,
                                   transform=transforms.ToTensor())

        self.num_classes, self.img_rows, self.img_cols = 1623, 105, 105

        def _images_and_labels(dataset):
            # One pass per split: iterating twice loaded and transformed every
            # image twice just to collect the labels.
            images, labels = [], []
            for tensor, label in dataset:
                images.append(tensor.numpy()[0])  # keep the first channel
                labels.append(label)
            return np.array(images), np.array(labels)

        self.background_data, self.background_labels = _images_and_labels(
            self.background)
        self.evaluation_data, evaluation_labels = _images_and_labels(
            self.evaluation)
        # Evaluation labels also start from 0, so they are offset in order to
        # follow immediately after the background labels.
        self.evaluation_labels = (evaluation_labels +
                                  np.max(self.background_labels) + 1)

        # Split the entire data set into a training and testing set.
        self.x_train, self.x_test, self.y_train, self.y_test = train_test_split(
            np.concatenate((self.background_data, self.evaluation_data)),
            np.concatenate((self.background_labels, self.evaluation_labels)),
            test_size=test_size,
            random_state=random_state)

        # Split the training further into a final training set and a validation set.
        self.x_train, self.x_valid, self.y_train, self.y_valid = train_test_split(
            self.x_train,
            self.y_train,
            test_size=eval_size,
            random_state=random_state)

        self.y_test = utils.to_categorical(self.y_test, self.num_classes)
        self.y_train = utils.to_categorical(self.y_train, self.num_classes)
        self.y_valid = utils.to_categorical(self.y_valid, self.num_classes)

        # Place the single channel axis where the Keras backend expects it.
        if K.image_data_format() == 'channels_first':
            self.input_shape = (1, self.img_rows, self.img_cols)
        else:
            self.input_shape = (self.img_rows, self.img_cols, 1)

        def _to_model_input(x):
            return x.reshape((x.shape[0],) + self.input_shape).astype('float32')

        self.x_train = _to_model_input(self.x_train)
        self.x_test = _to_model_input(self.x_test)
        self.x_valid = _to_model_input(self.x_valid)
Esempio n. 9
0
# Scratch exploration of torchvision's Omniglot dataset.
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import torchvision
from torchvision import transforms
from torchvision.datasets import Omniglot

# Inspect one raw (PIL) sample. ``download=True`` already fetches and checks
# the archive, so a separate ``omniglot.download()`` call is redundant.
omniglot = Omniglot(root="./data", download=True)
image = omniglot[300][0]
print(image)

# Histogram of the sample's pixel values (``plt.hist`` needs data to plot).
plt.hist(list(image.getdata()))

# Tensor-valued dataset for batching; ``transform`` was previously undefined.
transform = transforms.ToTensor()
trainset = torchvision.datasets.Omniglot(root='./data',
                                         download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

first_image, first_label = trainset[0]
print(first_image.shape, first_label)

# Classification loss; the former ``criterion([], [])`` call could not run
# (it needs logits and target tensors), so it was dropped.
criterion = nn.CrossEntropyLoss()
Esempio n. 10
0
    def __init__(self,
                 root,
                 n_classes,
                 train=True,
                 download=True,
                 noise_std=0):
        """
        There is already a realization of Omniglot in torchvision.datasets, main differences of this version are:
        * Dataset considers only a subset of all classes at each time and may reshuffle classes
        * Different breakdown of characters into training and validation
          Check https://github.com/syed-ahmed/MatchingNetworks-OSL for more details
        * Adding additional classes by rotating images by 90, 180 and 270 degrees
        * All data is preloaded into memory - because of that no option for image and target transforms
          - instead we hardcode .Resize() and toTensor() transforms
        * Adding noise augmentation for training set
        The packed data is cached per split in
        ``root/omniglot_packed_train.pkl`` / ``root/omniglot_packed_test.pkl``.
        Args:
            root: root directory of images as in torchvision.datasets.Omniglot
            n_classes: total number of classes considered
            train: training or validation dataset
            download: flag to download the data
            noise_std: standard deviation of Gaussian noise applied to training data
        """
        super(RestrictedOmniglot, self).__init__()
        self.root = root
        self.n_classes = n_classes
        self.train = train
        self.download = download
        self.noise_std = noise_std

        # Hardcoding some dataset settings
        self.rotations = [0, 90, 180, 270]
        self.images_per_class = 20
        self.image_size = 28

        # One cache per split: a single shared file made whichever split was
        # built second silently load the other split's data.
        split_name = "train" if self.train else "test"
        cache_pkl = os.path.join(self.root,
                                 'omniglot_packed_%s.pkl' % split_name)
        try:
            # NOTE: unpickling can execute arbitrary code; only point ``root``
            # at a trusted directory.
            with open(cache_pkl, 'rb') as f:
                class_members = pickle.load(f)
            self.data, self.target, self.n_all_classes, self.target_mapping = class_members
        except IOError:
            # Relying on pytorch Omniglot class to do the downloads and data checks
            self.old_train = Omniglot(self.root, True, None, None,
                                      self.download)
            self.old_test = Omniglot(self.root, False, None, None,
                                     self.download)

            # After downloads images lie in
            # root/omniglot-py/images_{background, evaluation}/alphabet/character/image.png
            # Sorted so the class numbering below is reproducible across runs.
            image_paths = sorted(
                glob.glob(os.path.join(self.root, "*", "*", "*", "*", "*.png")))

            def path_parts(path):
                # Split on the platform separator rather than a hard-coded "/".
                return os.path.normpath(path).split(os.sep)

            test_alphabets = set(OMNIGLOT_TEST_CLASSES)

            def is_test_class(path):
                return path_parts(path)[-3] in test_alphabets

            # Retaining only those that lie in training or test set
            if self.train:
                image_paths = [x for x in image_paths if not is_test_class(x)]
            else:
                image_paths = [x for x in image_paths if is_test_class(x)]

            def extract_character(path):
                parts = path_parts(path)
                return parts[-3] + "/" + parts[-2]

            # Mapping remaining characters to classes (deterministic order).
            characters = sorted({extract_character(x) for x in image_paths})
            character_mapping = {
                character: idx for idx, character in enumerate(characters)
            }
            n_characters = len(characters)
            n_images = len(image_paths)
            self.n_all_classes = len(self.rotations) * n_characters

            # Reading images into memory. Sized from the actual image count,
            # which equals images_per_class * n_all_classes for the standard
            # 20-images-per-character layout.
            self.data = torch.zeros(
                (len(self.rotations) * n_images, 1,
                 self.image_size, self.image_size))
            self.target = torch.zeros(
                (len(self.rotations) * n_images, ),
                dtype=torch.long)
            to_tensor = ToTensor()
            for image_idx, image_path in enumerate(image_paths):
                class_idx = character_mapping[extract_character(image_path)]
                # Open each file once and close it promptly.
                with Image.open(image_path, mode="r") as raw_image:
                    image = raw_image.convert("L")
                for rotation_idx, rotation in enumerate(self.rotations):
                    # Layout: all images for rotation 0, then rotation 1, ...
                    flat_idx = rotation_idx * n_images + image_idx
                    processed_image = image.rotate(rotation).resize(
                        (self.image_size, self.image_size))
                    self.data[flat_idx] = to_tensor(processed_image)
                    # Each rotation of a character is its own class.
                    self.target[flat_idx] = (class_idx +
                                             rotation_idx * n_characters)

            # Recording the mapping of classes to corresponding indices
            self.target_mapping = {x: [] for x in range(self.n_all_classes)}
            for idx, target_idx in enumerate(self.target.tolist()):
                self.target_mapping[target_idx].append(idx)

            with open(cache_pkl, 'wb') as f:
                class_members = [
                    self.data, self.target, self.n_all_classes,
                    self.target_mapping
                ]
                pickle.dump(class_members, f)

        self.shuffle_classes()
Esempio n. 11
0
                                  silent=False)

        # grid_img = torchvision.utils.make_grid(y_pred, nrow=8)
        # plt.imshow(grid_img.detach().numpy()[0])
        # plt.show()


# Preprocessing: single-channel grayscale, resized to params.resize_dim, as tensor.
tsfm = transforms.Compose(
    [
        transforms.Grayscale(1),
        transforms.Resize(params.resize_dim),
        transforms.ToTensor(),
    ]
)

# Omniglot background split from torchvision.datasets, shuffled into batches
# of params.batch_size with the last incomplete batch dropped.
dataset = Omniglot(data_path, background=True, transform=tsfm, download=True)
dataloader = DataLoader(
    dataset,
    batch_size=params.batch_size,
    shuffle=True,
    num_workers=params.num_workers,
    drop_last=True,
)

# Visual cortex model to pretrain.
model = modules.ECPretrain(
    D_in=1,
    D_out=121,
    KERNEL_SIZE=9,
    STRIDE=5,
    PADDING=1,
)

# Binary cross entropy as the autoencoder's reconstruction loss.
loss_fn = nn.BCELoss()
Esempio n. 12
0
def test(flags):
    """Roll out one episode with a trained agent and log its reward.

    The policy and discriminator are loaded from ``./latest/model.tar`` when
    ``flags.xpid`` is None, else from ``<savedir>/<xpid>/model.tar``. The
    function plays ``flags.episode_length`` steps, adds the discriminator
    score to the environment reward, and logs the final and total episode
    reward. The environment is always closed, even on error.

    Raises:
        ValueError: if ``flags.env_type`` is not 'fluid' or 'libmypaint'.
        NotImplementedError: if ``flags.condition`` is set and
            ``flags.dataset`` is not supported.
    """
    if flags.xpid is None:
        checkpointpath = "./latest/model.tar"
    else:
        checkpointpath = os.path.expandvars(
            os.path.expanduser("%s/%s/%s" %
                               (flags.savedir, flags.xpid, "model.tar")))

    config = dict(
        episode_length=flags.episode_length,
        canvas_width=flags.canvas_width,
        grid_width=grid_width,
        brush_sizes=flags.brush_sizes,
    )

    use_color = flags.dataset in ("celeba", "celeba-hq")

    if flags.env_type == "fluid":
        env_name = "Fluid"
        config["shaders_basedir"] = SHADERS_BASEDIR
    elif flags.env_type == "libmypaint":
        env_name = "Libmypaint"
        config.update(
            dict(
                brush_type=flags.brush_type,
                use_color=use_color,
                use_pressure=flags.use_pressure,
                use_alpha=False,
                background="white",
                brushes_basedir=BRUSHES_BASEDIR,
            ))
    else:
        # Previously fell through and failed with an UnboundLocalError on
        # ``env_name`` below.
        raise ValueError("Unknown env_type: %r" % (flags.env_type,))

    if flags.use_compound:
        env_name += "-v1"
        config.update(
            dict(
                new_stroke_penalty=flags.new_stroke_penalty,
                stroke_length_penalty=flags.stroke_length_penalty,
            ))
    else:
        env_name += "-v0"

    env = env_wrapper.make_raw(env_name, config)
    try:
        if frame_width != flags.canvas_width:
            env = env_wrapper.WarpFrame(env, height=frame_width,
                                        width=frame_width)
        env = env_wrapper.wrap_pytorch(env)
        env = env_wrapper.AddDim(env)

        obs_shape = env.observation_space.shape
        if flags.condition:
            # Conditioned models are built for twice the observation channels.
            c, h, w = obs_shape
            c *= 2
            obs_shape = (c, h, w)

        action_shape = env.action_space.nvec.tolist()
        order = env.order

        model = models.Net(
            obs_shape=obs_shape,
            action_shape=action_shape,
            grid_shape=(grid_width, grid_width),
            order=order,
        )
        if flags.condition:
            model = models.Condition(model)
        model.eval()

        D = models.Discriminator(obs_shape, flags.power_iters)
        if flags.condition:
            D = models.Conditional(D)
        D.eval()

        checkpoint = torch.load(checkpointpath, map_location="cpu")
        model.load_state_dict(checkpoint["model_state_dict"])
        D.load_state_dict(checkpoint["D_state_dict"])

        if flags.condition:
            from random import randrange

            c, h, w = obs_shape
            tsfm = transforms.Compose(
                [transforms.Resize((h, w)),
                 transforms.ToTensor()])
            dataset = flags.dataset

            if dataset == "mnist":
                dataset = MNIST(root="./",
                                train=True,
                                transform=tsfm,
                                download=True)
            elif dataset == "omniglot":
                dataset = Omniglot(root="./",
                                   background=True,
                                   transform=tsfm,
                                   download=True)
            elif dataset == "celeba":
                dataset = CelebA(
                    root="./",
                    split="train",
                    target_type=None,
                    transform=tsfm,
                    download=True,
                )
            elif dataset == "celeba-hq":
                dataset = datasets.CelebAHQ(root="./",
                                            split="train",
                                            transform=tsfm,
                                            download=True)
            else:
                raise NotImplementedError

            sample = dataset[randrange(len(dataset))]
            # torchvision datasets return (image, target); keep the image only.
            if isinstance(sample, (tuple, list)):
                sample = sample[0]
            # ``c`` is the doubled channel count set above; a single condition
            # image carries half of it.
            condition = sample.view((1, 1, c // 2, h, w))
        else:
            condition = None

        frame = env.reset()
        action = model.initial_action()
        agent_state = model.initial_state()
        done = torch.tensor(False).view(1, 1)
        rewards = []
        frames = [frame]

        for _ in range(flags.episode_length - 1):
            if flags.mode == "test_render":
                env.render()
            noise = torch.randn(1, 1, 10)
            agent_outputs, agent_state = model(
                dict(
                    obs=frame,
                    condition=condition,
                    action=action,
                    noise=noise,
                    done=done,
                ),
                agent_state,
            )
            action, *_ = agent_outputs
            frame, reward, done, _ = env.step(action)

            rewards.append(reward)
            frames.append(frame)

        reward = torch.cat(rewards)
        frame = torch.cat(frames)

        if flags.use_tca:
            frame = torch.flatten(frame, 0, 1)
            if flags.condition:
                condition = torch.flatten(condition, 0, 1)
        else:
            frame = frame[-1]
            if flags.condition:
                condition = condition[-1]

        # D is already in eval mode (set right after construction).
        with torch.no_grad():
            if flags.condition:
                p = D(frame, condition).view(-1, 1)
            else:
                p = D(frame).view(-1, 1)

            if flags.use_tca:
                # TCA mode: each step is rewarded with the change in D's score.
                d_reward = p[1:] - p[:-1]
                reward = reward[1:] + d_reward
            else:
                # Only the final canvas is scored by D.
                reward[-1] = reward[-1] + p
                reward = reward[1:]

        logging.info(
            "Episode ended after %d steps. Final reward: %.4f. Episode reward: %.4f,",
            flags.episode_length,
            reward[-1].item(),
            reward.sum().item(),
        )
    finally:
        env.close()
Esempio n. 13
0
def train(flags):
    """Train the painting agent and its discriminator with an actor pool.

    The function:

    * sets up the actor pool, learner, inference, discriminator-learner and
      data-loader threads;
    * resumes from ``<savedir>/<xpid>/model.tar`` when that file exists;
    * logs progress roughly every 5 seconds until ``stats["step"]`` reaches
      ``flags.total_steps`` or the run is interrupted with Ctrl-C.

    A checkpoint is written every 10 minutes and once more when training
    ends normally.

    Raises:
        ValueError: if ``flags.env_type`` is not 'fluid' or 'libmypaint'.
        NotImplementedError: if ``flags.dataset`` is not supported.
    """
    if flags.xpid is None:
        flags.xpid = "torchbeast-%s" % time.strftime("%Y%m%d-%H%M%S")
    plogger = file_writer.FileWriter(xpid=flags.xpid,
                                     xp_args=flags.__dict__,
                                     rootdir=flags.savedir)
    checkpointpath = os.path.expandvars(
        os.path.expanduser("%s/%s/%s" %
                           (flags.savedir, flags.xpid, "model.tar")))

    if not flags.disable_cuda and torch.cuda.is_available():
        logging.info("Using CUDA.")
        flags.learner_device = torch.device("cuda")
        flags.actor_device = torch.device("cuda")
    else:
        logging.info("Not using CUDA.")
        flags.learner_device = torch.device("cpu")
        flags.actor_device = torch.device("cpu")

    if flags.max_learner_queue_size is None:
        flags.max_learner_queue_size = flags.batch_size

    # The queue the learner threads will get their data from.
    # Setting `minimum_batch_size == maximum_batch_size`
    # makes the batch size static.
    learner_queue = actorpool.BatchingQueue(
        batch_dim=1,
        minimum_batch_size=flags.batch_size,
        maximum_batch_size=flags.batch_size,
        check_inputs=True,
        maximum_queue_size=flags.max_learner_queue_size,
    )

    d_queue = Queue(maxsize=flags.max_learner_queue_size // flags.batch_size)
    image_queue = Queue(maxsize=flags.max_learner_queue_size)

    # The "batcher", a queue for the inference call. Will yield
    # "batch" objects with `get_inputs` and `set_outputs` methods.
    # The batch size of the tensors will be dynamic.
    inference_batcher = actorpool.DynamicBatcher(
        batch_dim=1,
        minimum_batch_size=1,
        maximum_batch_size=512,
        timeout_ms=100,
        check_outputs=True,
    )

    addresses = []
    connections_per_server = 1
    pipe_id = 0
    while len(addresses) < flags.num_actors:
        for _ in range(connections_per_server):
            addresses.append(f"{flags.pipes_basename}.{pipe_id}")
            if len(addresses) == flags.num_actors:
                break
        pipe_id += 1

    config = dict(
        episode_length=flags.episode_length,
        canvas_width=flags.canvas_width,
        grid_width=grid_width,
        brush_sizes=flags.brush_sizes,
    )

    if flags.dataset == "celeba" or flags.dataset == "celeba-hq":
        use_color = True
    else:
        use_color = False

    if flags.env_type == "fluid":
        env_name = "Fluid"
        config["shaders_basedir"] = SHADERS_BASEDIR
    elif flags.env_type == "libmypaint":
        env_name = "Libmypaint"
        config.update(
            dict(
                brush_type=flags.brush_type,
                use_color=use_color,
                use_pressure=flags.use_pressure,
                use_alpha=False,
                background="white",
                brushes_basedir=BRUSHES_BASEDIR,
            ))
    else:
        # Previously fell through and failed with an UnboundLocalError on
        # ``env_name`` below.
        raise ValueError("Unknown env_type: %r" % (flags.env_type,))

    if flags.use_compound:
        env_name += "-v1"
    else:
        env_name += "-v0"

    # A local env is only needed to read observation/action specs.
    env = env_wrapper.make_raw(env_name, config)
    try:
        if frame_width != flags.canvas_width:
            env = env_wrapper.WarpFrame(env, height=frame_width,
                                        width=frame_width)
        env = env_wrapper.wrap_pytorch(env)

        obs_shape = env.observation_space.shape
        if flags.condition:
            # Conditioned models are built for twice the observation channels.
            c, h, w = obs_shape
            c *= 2
            obs_shape = (c, h, w)

        action_shape = env.action_space.nvec.tolist()
        order = env.order
    finally:
        env.close()

    model = models.Net(
        obs_shape=obs_shape,
        action_shape=action_shape,
        grid_shape=(grid_width, grid_width),
        order=order,
    )
    if flags.condition:
        model = models.Condition(model)
    model = model.to(device=flags.learner_device)

    actor_model = models.Net(
        obs_shape=obs_shape,
        action_shape=action_shape,
        grid_shape=(grid_width, grid_width),
        order=order,
    )
    if flags.condition:
        actor_model = models.Condition(actor_model)
    actor_model.to(device=flags.actor_device)

    D = models.Discriminator(obs_shape, flags.power_iters)
    if flags.condition:
        D = models.Conditional(D)
    D.to(device=flags.learner_device)

    D_eval = models.Discriminator(obs_shape, flags.power_iters)
    if flags.condition:
        D_eval = models.Conditional(D_eval)
    D_eval = D_eval.to(device=flags.learner_device)

    optimizer = optim.Adam(model.parameters(), lr=flags.policy_learning_rate)
    D_optimizer = optim.Adam(D.parameters(),
                             lr=flags.discriminator_learning_rate,
                             betas=(0.5, 0.999))

    def lr_lambda(epoch):
        # Linear decay to 0 over flags.total_steps environment steps.
        return (1 - min(epoch * flags.unroll_length * flags.batch_size,
                        flags.total_steps) / flags.total_steps)

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    D_scheduler = torch.optim.lr_scheduler.LambdaLR(D_optimizer, lr_lambda)

    C, H, W = obs_shape
    if flags.condition:
        C //= 2
    # The ActorPool that will run `flags.num_actors` many loops.
    actors = actorpool.ActorPool(
        unroll_length=flags.unroll_length,
        learner_queue=learner_queue,
        inference_batcher=inference_batcher,
        env_server_addresses=addresses,
        initial_action=actor_model.initial_action(),
        initial_agent_state=actor_model.initial_state(),
        image=torch.zeros(1, 1, C, H, W),
    )

    def run():
        try:
            actors.run()
            print("actors are running")
        except Exception as e:
            logging.error("Exception in actorpool thread!")
            traceback.print_exc()
            print()
            raise e

    actorpool_thread = threading.Thread(target=run, name="actorpool-thread")

    c, h, w = obs_shape
    tsfm = transforms.Compose(
        [transforms.Resize((h, w)),
         transforms.ToTensor()])

    dataset = flags.dataset

    if dataset == "mnist":
        dataset = MNIST(root="./", train=True, transform=tsfm, download=True)
    elif dataset == "omniglot":
        dataset = Omniglot(root="./",
                           background=True,
                           transform=tsfm,
                           download=True)
    elif dataset == "celeba":
        dataset = CelebA(root="./",
                         split="train",
                         target_type=None,
                         transform=tsfm,
                         download=True)
    elif dataset == "celeba-hq":
        dataset = datasets.CelebAHQ(root="./",
                                    split="train",
                                    transform=tsfm,
                                    download=True)
    else:
        raise NotImplementedError

    dataloader = DataLoader(dataset,
                            batch_size=1,
                            shuffle=True,
                            drop_last=True,
                            pin_memory=True)

    stats = {}

    # Load state from a checkpoint, if possible.
    if os.path.exists(checkpointpath):
        checkpoint_states = torch.load(checkpointpath,
                                       map_location=flags.learner_device)
        model.load_state_dict(checkpoint_states["model_state_dict"])
        D.load_state_dict(checkpoint_states["D_state_dict"])
        optimizer.load_state_dict(checkpoint_states["optimizer_state_dict"])
        D_optimizer.load_state_dict(
            checkpoint_states["D_optimizer_state_dict"])
        # Each scheduler must get its own state: the two were previously
        # swapped, which exchanged their base learning rates on resume.
        scheduler.load_state_dict(checkpoint_states["scheduler_state_dict"])
        D_scheduler.load_state_dict(checkpoint_states["D_scheduler_state_dict"])
        stats = checkpoint_states["stats"]
        logging.info(f"Resuming preempted job, current stats:\n{stats}")

    # Initialize actor model like learner model.
    actor_model.load_state_dict(model.state_dict())
    D_eval.load_state_dict(D.state_dict())

    learner_threads = [
        threading.Thread(
            target=learn,
            name="learner-thread-%i" % i,
            args=(
                flags,
                learner_queue,
                d_queue,
                model,
                actor_model,
                D_eval,
                optimizer,
                scheduler,
                stats,
                plogger,
            ),
        ) for i in range(flags.num_learner_threads)
    ]
    inference_threads = [
        threading.Thread(
            target=inference,
            name="inference-thread-%i" % i,
            args=(
                flags,
                inference_batcher,
                actor_model,
                image_queue,
            ),
        ) for i in range(flags.num_inference_threads)
    ]

    d_learner = [
        threading.Thread(
            target=learn_D,
            name="d_learner-thread-%i" % i,
            args=(
                flags,
                d_queue,
                D,
                D_eval,
                D_optimizer,
                D_scheduler,
                stats,
                plogger,
            ),
        ) for i in range(flags.num_learner_threads)
    ]
    for thread in d_learner:
        thread.daemon = True

    dataloader_thread = threading.Thread(target=data_loader,
                                         args=(
                                             flags,
                                             dataloader,
                                             image_queue,
                                         ))
    dataloader_thread.daemon = True

    actorpool_thread.start()

    threads = learner_threads + inference_threads
    daemons = d_learner + [dataloader_thread]

    for t in threads + daemons:
        t.start()

    def checkpoint():
        if flags.disable_checkpoint:
            return
        logging.info("Saving checkpoint to %s", checkpointpath)
        torch.save(
            {
                "model_state_dict": model.state_dict(),
                "D_state_dict": D.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "D_optimizer_state_dict": D_optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
                "D_scheduler_state_dict": D_scheduler.state_dict(),
                "stats": stats,
                "flags": vars(flags),
            },
            checkpointpath,
        )

    def format_value(x):
        return f"{x:1.5}" if isinstance(x, float) else str(x)

    try:
        last_checkpoint_time = timeit.default_timer()
        while True:
            start_time = timeit.default_timer()
            start_step = stats.get("step", 0)
            if start_step >= flags.total_steps:
                break
            time.sleep(5)
            end_step = stats.get("step", 0)

            if timeit.default_timer() - last_checkpoint_time > 10 * 60:
                # Save every 10 min.
                checkpoint()
                last_checkpoint_time = timeit.default_timer()

            logging.info(
                "Step %i @ %.1f SPS. Inference batcher size: %i."
                " Learner queue size: %i."
                " Other stats: (%s)",
                end_step,
                (end_step - start_step) /
                (timeit.default_timer() - start_time),
                inference_batcher.size(),
                learner_queue.size(),
                ", ".join(f"{key} = {format_value(value)}"
                          for key, value in stats.items()),
            )
    except KeyboardInterrupt:
        pass  # Close properly.
    else:
        # ``stats`` may lack "step" (e.g. total_steps <= 0 on a fresh run).
        logging.info("Learning finished after %i steps.",
                     stats.get("step", 0))
        checkpoint()

    # Done with learning. Stop all the ongoing work.
    inference_batcher.close()
    learner_queue.close()

    actorpool_thread.join()

    for t in threads:
        t.join()
Esempio n. 14
0
def create_tasks(dataroot, num_tasks):
    """Return the first ``num_tasks`` Omniglot alphabets in a fixed shuffle.

    Background alphabets are listed before evaluation alphabets, then
    permuted with a ``RandomState`` seeded with 666, so the result is the
    same on every call. The result is a NumPy array of alphabet names.
    """
    background_alphabets = Omniglot(dataroot, background=True)._alphabets
    evaluation_alphabets = Omniglot(dataroot, background=False)._alphabets
    all_alphabets = background_alphabets + evaluation_alphabets

    rng = np.random.RandomState(666)
    shuffled = np.array(all_alphabets)[rng.permutation(len(all_alphabets))]
    return shuffled[:num_tasks]
Esempio n. 15
0
File: mtc.py Progetto: bkj/mtcnet
    # --
    # IO

    stats = {
        "mean" : (0.07793742418289185,),
        "std"  : (0.2154727578163147,)
    }

    transform = transforms.Compose([
        transforms.Resize(28),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: 1 - x), # Make sparse
        transforms.Normalize(**stats),
    ])

    back_dataset = Omniglot(root='./data', background=True, transform=transform)
    test_dataset = Omniglot(root='./data', background=False, transform=transform)

    back_wrapper = OmniglotTaskWrapper(back_dataset)
    test_wrapper = OmniglotTaskWrapper(test_dataset)

    back_loader  = precompute_batches(
        wrapper=back_wrapper,
        num_epochs=args.num_epochs,
        num_steps=args.num_steps,
        batch_size=args.batch_size,
        num_classes=args.num_classes,
        num_shots=1
    )

    test_loader  = precompute_batches(