Esempio n. 1
0
    def setUp(self):
        """Build the 5-way 5-shot few-shot task fixtures used by the tests.

        Splits the Omniglot characters (with a fixed seed of 0) and loads the
        miniImageNet images, storing both splits on ``self``. It then builds
        one ``OmniglotTask`` (``self.task``) and one ``ImageNetTask``
        (``self.task_mini_image``) from the respective training splits.
        Both tasks use 5 classes, 5 samples and 3 query examples.
        """
        # Episode shape shared by both tasks: N-way, K-shot, number of queries.
        self.num_classes, self.num_samples, self.num_query = 5, 5, 3
        episode = (self.num_classes, self.num_samples, self.num_query)

        omniglot_root = 'data/Omniglot/'
        imagenet_root = 'data/miniImageNet/'
        seed = 0  # fixed seed handed to the Omniglot character split
        # TODO: no tieredImageNet fixture is set up yet.

        # NOTE(review): other callers unpack three splits (train/val/test)
        # from these loaders; confirm the two-way unpacking below matches
        # the loader API this test targets.
        self.train_chars, self.test_chars = split_omniglot_characters(
            omniglot_root, seed)
        self.task = OmniglotTask(self.train_chars, *episode)

        self.train_images, self.test_images = load_imagenet_images(
            imagenet_root)
        self.task_mini_image = ImageNetTask(self.train_images, *episode)
Esempio n. 2
0
    SEED = params.SEED

    # use GPU if available
    params.cuda = torch.cuda.is_available()  # use GPU is available

    # Set the random seed for reproducible experiments
    torch.manual_seed(SEED)
    if params.cuda: torch.cuda.manual_seed(SEED)

    # Split meta-training and meta-testing characters
    if 'Omniglot' in args.data_dir and params.dataset == 'Omniglot':
        params.in_channels = 1
        params.in_features_fc = 1
        (meta_train_classes, meta_val_classes,
         meta_test_classes) = split_omniglot_characters(args.data_dir, SEED)
        task_type = OmniglotTask
    elif ('miniImageNet' in args.data_dir or 'tieredImageNet'
          in args.data_dir) and params.dataset == 'ImageNet':
        params.in_channels = 3
        params.in_features_fc = 4
        (meta_train_classes, meta_val_classes,
         meta_test_classes) = load_imagenet_images(args.data_dir)
        task_type = ImageNetTask
    else:
        raise ValueError("I don't know your dataset")

    # Define the model and optimizer
    if params.cuda:
        model = TPN(params).cuda()
    else: