Exemplo n.º 1
0
    def setUp(self):
        """Create one Omniglot task and one miniImageNet task for the tests.

        Both tasks use the same episode configuration: 5 classes, 5 support
        samples per class and 3 query samples per class. The Omniglot
        character split is requested with a fixed seed of 0.

        Sets on ``self``: ``num_classes``, ``num_samples``, ``num_query``,
        ``train_chars``, ``test_chars``, ``task`` (an ``OmniglotTask``),
        ``train_images``, ``test_images`` and ``task_mini_image`` (an
        ``ImageNetTask``).
        """
        # Episode configuration shared by both tasks (5-way, 5-shot, 3 query).
        # TODO: extend these fixtures to tieredImageNet.
        self.num_classes, self.num_samples, self.num_query = 5, 5, 3
        episode = (self.num_classes, self.num_samples, self.num_query)

        omniglot_dir = 'data/Omniglot/'
        imagenet_dir = 'data/miniImageNet/'
        seed = 0

        # Omniglot: split the characters, then build a task on the train part.
        self.train_chars, self.test_chars = split_omniglot_characters(
            omniglot_dir, seed)
        self.task = OmniglotTask(self.train_chars, *episode)

        # miniImageNet: load the image splits, then build a task on the
        # train part.
        self.train_images, self.test_images = load_imagenet_images(
            imagenet_dir)
        self.task_mini_image = ImageNetTask(self.train_images, *episode)
Exemplo n.º 2
0
    torch.manual_seed(SEED)
    if params.cuda: torch.cuda.manual_seed(SEED)

    # Split meta-training and meta-testing characters
    if 'Omniglot' in args.data_dir and params.dataset == 'Omniglot':
        params.in_channels = 1
        params.in_features_fc = 1
        (meta_train_classes, meta_val_classes,
         meta_test_classes) = split_omniglot_characters(args.data_dir, SEED)
        task_type = OmniglotTask
    elif ('miniImageNet' in args.data_dir or 'tieredImageNet'
          in args.data_dir) and params.dataset == 'ImageNet':
        params.in_channels = 3
        params.in_features_fc = 4
        (meta_train_classes, meta_val_classes,
         meta_test_classes) = load_imagenet_images(args.data_dir)
        task_type = ImageNetTask
    else:
        raise ValueError("I don't know your dataset")

    # Define the model and optimizer
    if params.cuda:
        model = TPN(params).cuda()
    else:
        model = TPN(params)

    # fetch loss function and metrics
    loss_fn = nn.NLLLoss()
    model_metrics = metrics

    # Reload weights from the saved file
Exemplo n.º 3
0
    if params.cuda: torch.cuda.manual_seed(SEED)

    # Set the logger
    utils.set_logger(os.path.join(args.model_dir, 'train.log'))

    # NOTE These params are only applicable to pre-specified model architecture.
    # Split meta-training and meta-testing characters
    if 'Omniglot' in args.data_dir and params.dataset == 'Omniglot':
        params.in_channels = 1
        meta_train_classes, meta_test_classes = split_omniglot_characters(
            args.data_dir, SEED)
        task_type = OmniglotTask
    elif ('miniImageNet' in args.data_dir or 'tieredImageNet'
          in args.data_dir) and params.dataset == 'ImageNet':
        params.in_channels = 3
        meta_train_classes, meta_test_classes = load_imagenet_images(
            args.data_dir)
        task_type = ImageNetTask
    else:
        raise ValueError("I don't know your dataset")

    # Define the model and optimizer
    if params.cuda:
        model = MetaLearner(params).cuda()
    else:
        model = MetaLearner(params)
    # NOTE we need to define task_lr after defining model
    model.define_task_lr_params()
    model_params = list(model.parameters()) + list(model.task_lr.values())
    meta_optimizer = torch.optim.Adam(model_params, lr=meta_lr)

    # fetch loss function and metrics