Example #1
0
    def meta_learn(self, batch, batch_idx, ways, shots, queries):
        """Adapt a clone of ``self.classifier`` on the support split of an
        episode, then evaluate the adapted clone on the query split.

        Args:
            batch: ``(data, labels)`` pair for one episode; ``data`` is passed
                through ``self.features`` before splitting.
            batch_idx: Unused here.
            ways: Number of class groups in the episode.
            shots: Number of support samples at the start of each group.
            queries: Number of query samples following the support samples
                in each group.

        NOTE(review): the split assumes the batch is already laid out as
        ``ways`` consecutive groups of ``shots + queries`` samples. No
        reordering happens here, so confirm that upstream ordering. Samples
        past the first ``ways * (shots + queries)`` positions land in the
        query split. A batch too short for the layout makes numpy raise
        ``IndexError``.

        Returns:
            Tuple ``(loss, accuracy)`` on the query samples, computed by
            ``self.loss`` and ``accuracy``.
        """
        self.features.train()
        fast_head = self.classifier.clone()
        fast_head.train()
        inputs, targets = batch
        embedded = self.features(inputs)

        # Support positions are the first `shots` offsets inside each of the
        # `ways` groups: outer sum of group starts and in-group offsets.
        group_starts = np.arange(ways) * (shots + queries)
        support_positions = np.add.outer(group_starts, np.arange(shots)).ravel()
        is_support = np.zeros(embedded.size(0), dtype=bool)
        is_support[support_positions] = True
        query_mask = torch.from_numpy(~is_support)
        support_mask = torch.from_numpy(is_support)

        support, support_targets = embedded[support_mask], targets[support_mask]
        query, query_targets = embedded[query_mask], targets[query_mask]

        # Inner loop: feed the support loss to `adapt` once per adaptation step.
        for _ in range(self.adaptation_steps):
            support_loss = self.loss(fast_head(support), support_targets)
            fast_head.adapt(support_loss)

        # Score the adapted clone on the held-out query samples.
        query_logits = fast_head(query)
        return (
            self.loss(query_logits, query_targets),
            accuracy(query_logits, query_targets),
        )
Example #2
0
    def meta_learn(self, batch, batch_idx, ways, shots, queries):
        """Sort an episode by label, fit ``self.classifier`` on the support
        embeddings and evaluate it on the query embeddings.

        Args:
            batch: ``(data, labels)`` pair for one episode.
            batch_idx: Unused here.
            ways: Number of class groups in the episode.
            shots: Number of support samples at the start of each group.
            queries: Number of query samples following the support samples
                in each group.

        Returns:
            Tuple ``(loss, accuracy)`` on the query samples, computed by
            ``self.loss`` and ``accuracy``.
        """
        self.features.train()
        inputs, targets = batch

        # Group samples by label so each class occupies a contiguous block.
        # squeeze(0) drops a leading size-1 dimension when one is present.
        # Presumably the batch carries one task, but confirm against the caller.
        _, order = torch.sort(targets)
        inputs = inputs.squeeze(0)[order].squeeze(0)
        targets = targets.squeeze(0)[order].squeeze(0)

        embedded = self.features(inputs)

        # Mark the first `shots` samples of each of the `ways` blocks as support.
        # All remaining samples are queries.
        group_starts = np.arange(ways) * (shots + queries)
        is_support = np.zeros(inputs.size(0), dtype=bool)
        is_support[np.add.outer(group_starts, np.arange(shots)).ravel()] = True
        query_mask = torch.from_numpy(~is_support)
        support_mask = torch.from_numpy(is_support)

        support, support_targets = embedded[support_mask], targets[support_mask]
        query, query_targets = embedded[query_mask], targets[query_mask]

        self.classifier.fit_(support, support_targets)
        query_logits = self.classifier(query)
        return (
            self.loss(query_logits, query_targets),
            accuracy(query_logits, query_targets),
        )
Example #3
0
    def test_simple(self):
        """SVClassifier should label noisy copies of per-class images correctly.

        For each ``normalize`` setting, every one of NUM_CLASSES classes is
        NUM_SHOTS copies of a single random image. The support and query sets
        are independent NOISE-scaled Gaussian perturbations of those copies,
        flattened to vectors. Query accuracy must reach at least 0.95.
        """
        for normalize in [True, False]:
            # subTest names the failing `normalize` value and still runs the
            # other setting.
            with self.subTest(normalize=normalize):
                X = []
                y = []
                for i in range(NUM_CLASSES):
                    # One random image per class, repeated NUM_SHOTS times.
                    images = torch.randn(1, *IMAGE_SHAPES).expand(NUM_SHOTS, *IMAGE_SHAPES)
                    X.append(images)
                    y.append(torch.full((NUM_SHOTS,), i, dtype=torch.long))
                X = torch.cat(X, dim=0)
                y = torch.cat(y)
                # NOTE(review): presumably checks that the classifier accepts
                # inputs that require grad. Nothing calls backward here.
                X.requires_grad = True
                X_support = X + torch.randn_like(X) * NOISE
                X_query = X + torch.randn_like(X) * NOISE

                # Flatten each image into a vector used as its embedding.
                X_support = X_support.view(NUM_CLASSES * NUM_SHOTS, -1)
                X_query = X_query.view(NUM_CLASSES * NUM_SHOTS, -1)

                classifier = l2l.nn.SVClassifier(
                    support=X_support,
                    labels=y,
                    normalize=normalize,
                )
                predictions = classifier(X_query)
                acc = accuracy(predictions, y)
                # Unlike assertTrue(acc >= 0.95), this reports the observed
                # accuracy on failure.
                self.assertGreaterEqual(acc, 0.95)
Example #4
0
    def test_simple(self):
        """PrototypicalClassifier should label noisy copies of per-class images.

        Runs every combination of ``distance`` ("euclidean", "cosine") and
        ``normalize`` (True, False). Every one of NUM_CLASSES classes is
        NUM_SHOTS copies of a single random image. The support and query sets
        are independent NOISE-scaled Gaussian perturbations of those copies,
        flattened to vectors. Query accuracy must reach at least 0.95.
        """
        for distance in ["euclidean", "cosine"]:
            for normalize in [True, False]:
                # subTest names the failing (distance, normalize) combination
                # and still runs the remaining combinations.
                with self.subTest(distance=distance, normalize=normalize):
                    # Create some fake data: one random image per class,
                    # repeated NUM_SHOTS times and labelled with the class index.
                    X = []
                    y = []
                    for i in range(NUM_CLASSES):
                        images = torch.randn(1, *IMAGE_SHAPES).expand(
                            NUM_SHOTS, *IMAGE_SHAPES)
                        X.append(images)
                        y.append(torch.full((NUM_SHOTS,), i, dtype=torch.long))
                    X = torch.cat(X, dim=0)
                    y = torch.cat(y)
                    X_support = X + torch.randn_like(X) * NOISE
                    X_query = X + torch.randn_like(X) * NOISE

                    # Flatten each image into a vector used as its embedding.
                    X_support = X_support.view(NUM_CLASSES * NUM_SHOTS, -1)
                    X_query = X_query.view(NUM_CLASSES * NUM_SHOTS, -1)

                    classifier = l2l.nn.PrototypicalClassifier(
                        support=X_support,
                        labels=y,
                        distance=distance,
                        normalize=normalize,
                    )
                    predictions = classifier(X_query)
                    acc = accuracy(predictions, y)
                    # Unlike assertTrue(acc >= 0.95), this reports the observed
                    # accuracy on failure.
                    self.assertGreaterEqual(acc, 0.95)
Example #5
0
    # Check SVClassifier on CUDA tensors. For each `normalize` setting, every
    # one of NUM_CLASSES classes is NUM_SHOTS copies of a single random image.
    # The support and query sets are independent NOISE-scaled Gaussian
    # perturbations of those copies, flattened to vectors. Query accuracy must
    # reach at least 0.95.
    # NOTE(review): the enclosing `def` line is not visible here. This loop
    # needs a CUDA device because of the X.cuda() / y.cuda() calls.
    for normalize in [True, False]:
        X = []
        y = []
        for i in range(NUM_CLASSES):
            # One random image per class, repeated NUM_SHOTS times.
            # expand returns a view, not copies.
            images = torch.randn(1,
                                 *IMAGE_SHAPES).expand(NUM_SHOTS,
                                                       *IMAGE_SHAPES)
            labels = torch.ones(NUM_SHOTS).long()
            X.append(images)
            # Every sample of class i gets label i.
            y.append(i * labels)
        X = torch.cat(X, dim=0)
        y = torch.cat(y)
        # The flag is set on the CPU tensor before the .cuda() copy.
        # NOTE(review): presumably checks that the classifier accepts inputs
        # that require grad. Nothing calls backward here.
        X.requires_grad = True
        X = X.cuda()
        y = y.cuda()
        X_support = X + torch.randn_like(X) * NOISE
        X_query = X + torch.randn_like(X) * NOISE

        # Compute embeddings: flatten each image into a vector.
        X_support = X_support.view(NUM_CLASSES * NUM_SHOTS, -1)
        X_query = X_query.view(NUM_CLASSES * NUM_SHOTS, -1)

        classifier = SVClassifier(
            support=X_support,
            labels=y,
            normalize=normalize,
        )
        predictions = classifier(X_query)
        acc = accuracy(predictions, y)
        assert acc >= 0.95