def calc_loss(self, centroids, embeddings_query, y_query):
        """
        Calculate loss as specified in "Prototypical Networks for Few-shot Learning" page 3 algorithm 1
        using tensor broadcast operations.

        For each query the term is ``d(q, c_y) + log(sum_k exp(-d(q, c_k)))``, i.e. the negative
        log-softmax of the negated distances evaluated at the query's true class.  Terms are summed
        over all queries of all episodes in the batch and divided by the number of queries per
        episode (``y_query.shape[1]``): per-episode mean, summed over the batch.

        :param centroids: Centroids calculated from support set. Shape: [batch_size, num_classes, num_embedding_features]
        :param embeddings_query: Embeddings from images of query set. Shape: [batch_size, query_size * num_classes, num_embedding_features]
        :param y_query: Integer labels of query samples, expected in ``[0, num_classes)``.
                        Shape: [batch_size, query_size * num_classes]
        :return: the loss scalar value (0-dim tensor)
        """
        # Broadcast queries against centroids:
        # [batch, 1, num_classes, features] vs [batch, queries, 1, features].
        centroids = centroids.unsqueeze(1)
        embeddings_query = embeddings_query.unsqueeze(2)
        loss_matrix = self.distance_fun(centroids, embeddings_query)  # [batch_size, query_size, num_classes]
        tg.guard(loss_matrix, "*, QUERY_SIZE, NUM_CLASSES")

        # Distance of each query to its own class centroid. gather (instead of a one-hot mask)
        # avoids 0 * inf = nan for infinite distances to other classes, and rejects
        # out-of-range labels instead of silently dropping their term.
        # gather needs int64 indices, hence .long().
        correct_class_distance = loss_matrix.gather(-1, y_query.unsqueeze(-1).long()).squeeze(-1)
        # log of the softmax normaliser over classes, computed stably.
        log_normaliser = torch.logsumexp(-loss_matrix, dim=-1)  # [batch_size, query_size]

        num_classes_queries = y_query.shape[1]
        # Out-of-place arithmetic: no in-place updates on autograd tensors.
        return (correct_class_distance + log_normaliser).sum() / num_classes_queries
Example #2
0
    def forward(self, X_supp, y_supp, X_query):
        """
        Embed support and query images in a single pass and compute one centroid per class.

        :param X_supp: Support images. Shape: [batch_size, supp_size, C, H, W]
        :param y_supp: Integer (int64) class labels of the support images. Shape: [batch_size, supp_size]
        :param X_query: Query images. Shape: [batch_size, query_size, C, H, W]
        :return: dict with ``centroids`` ([batch_size, num_classes, num_features]),
                 ``embeddings_support`` ([batch_size, supp_size, num_features]) and
                 ``embeddings_query`` ([batch_size, query_size, num_features]); when
                 ``self.get_probabilities`` is truthy it also holds ``prob_query`` as
                 returned by ``self._get_probabilities``.
        """
        # Class count is the largest support label + 1, shared by every episode in the batch.
        # int() so tensor sizes below are plain Python ints rather than 0-dim tensors.
        num_classes = int(y_supp.max()) + 1
        bs, supp_size, _, _, _ = X_supp.shape
        tg.guard(X_supp, "*, SUPP_SIZE, C, H, W")
        tg.guard(y_supp, "*, SUPP_SIZE")
        tg.guard(X_query, "*, QUERY_SIZE, C, H, W")
        query_size = X_query.shape[1]

        # Run support and query images through the embedding network together.
        # torch.cat copies its inputs, so no explicit .contiguous() is needed.
        X_supp = X_supp.flatten(0, 1)
        X_query = X_query.flatten(0, 1)
        len_supp = len(X_supp)
        embeddings_supp_query = self.embedding_nn(torch.cat([X_supp, X_query], dim=0))
        embeddings_supp = embeddings_supp_query[:len_supp].reshape(bs, supp_size, -1)
        embeddings_query = embeddings_supp_query[len_supp:].reshape(bs, query_size, -1)
        tg.guard(embeddings_supp, "*, SUPP_SIZE, NUM_FEATURES")
        tg.guard(embeddings_query, "*, QUERY_SIZE, NUM_FEATURES")
        num_features = embeddings_supp.shape[-1]
        device, dtype = embeddings_supp.device, embeddings_supp.dtype

        # Per-class sums of support embeddings.
        y_supp_broadcast = y_supp.unsqueeze(-1).expand(bs, supp_size, num_features)
        class_sums = torch.zeros(bs, num_classes, num_features, device=device, dtype=dtype)\
            .scatter_add(1, y_supp_broadcast, embeddings_supp)
        # Divide by the actual number of shots of each class, so the centroid is the class mean
        # even when classes have different numbers of support samples.
        class_counts = torch.zeros(bs, num_classes, device=device, dtype=dtype)\
            .scatter_add(1, y_supp, torch.ones_like(y_supp, dtype=dtype))
        # clamp: a class absent from an episode keeps a zero centroid instead of 0 / 0 = nan.
        centroids = class_sums / class_counts.clamp(min=1).unsqueeze(-1)

        result = dict(centroids=centroids,
                      embeddings_support=embeddings_supp,
                      embeddings_query=embeddings_query)
        if self.get_probabilities:
            # NOTE(review): calc_loss reads `self.distance_fun`, this reads `self.distance_function`;
            # confirm the class defines the attribute used here (they may be different classes).
            result['prob_query'] = self._get_probabilities(result, self.distance_function)
        return result
Example #3
0
def test_reset():
    """reset() forgets every dimension registered by earlier guard() calls."""
    from tensorguard import get_dims, guard, reset

    reset()
    guard([3, 5, 4], "A, B, C")
    expected = {"A": 3, "B": 5, "C": 4}
    assert get_dims() == expected

    reset()
    assert get_dims() == {}
Example #4
0
def test_get_dims():
    """get_dims() returns the registry, or evaluates a comma-separated dimension expression."""
    from tensorguard import get_dims, guard, reset

    reset()
    assert get_dims() == {}
    guard(np.zeros([15, 4, 32]), "B, C, W")
    # B * 2 -> 30, W / 4 -> 8
    assert get_dims("B * 2, W/4") == [30, 8]
Example #5
0
def test_guard_raises_inferred_global():
    """Dimensions inferred by one guard() call are enforced on later calls."""
    from tensorguard import guard, reset

    reset()
    first = np.ones([1, 2, 3])
    second = np.ones([3, 2, 5])
    guard(first, "A, B, C")  # registers A=1, B=2, C=3
    # Reversing the names conflicts with the values registered above.
    with pytest.raises(ShapeError):
        guard(second, "C, B, A")
Example #6
0
def test_get_dim():
    """get_dim() raises KeyError for unknown names; safe_get_dim() returns None instead."""
    from tensorguard import get_dim, guard, reset, safe_get_dim

    reset()
    guard(np.zeros([32, 2, 5]), "B, C, W")
    assert get_dim("W") == 5

    with pytest.raises(KeyError):
        get_dim("W_FAKE")
    assert safe_get_dim("W_FAKE") is None
Example #7
0
def test_guard_infers_dimensions_complex_global():
    """Arithmetic expressions in a spec are solved for their unknown names."""
    from tensorguard import get_dims, guard, reset

    reset()
    # 1 = A; 2 = B*2 -> B = 1; 3 = A+C -> C = 2
    guard(np.ones([1, 2, 3]), "A, B*2, A+C")
    expected = {"A": 1, "B": 1, "C": 2}
    assert get_dims() == expected, f'{get_dims()}' + ' != {"A": 1, "B": 1, "C": 2}'
Example #8
0
def test_set_dim():
    """set_dim()/set_dims() overwrite inferred dimensions and register new ones."""
    from tensorguard import get_dim, guard, reset, safe_get_dim, set_dim, set_dims

    reset()
    guard(np.zeros([32, 2, 5]), "B, C, W")
    assert get_dim("W") == 5

    # Overwrite an inferred dimension, then add a brand-new one.
    set_dim("W", 10)
    assert get_dim("W") == 10
    set_dim("WF", 40)
    assert get_dim("WF") == 40
    assert safe_get_dim("W_FAKE") is None

    # Keyword form registers several dimensions at once.
    set_dims(WW=32, HH=55)
    assert get_dim("WW") == 32
    assert safe_get_dim("HH") == 55
Example #9
0
def test_guard_dynamic_shape_global():
    """A None dimension is rejected by a plain name but accepted by `?` and by an optional `C?`."""
    from tensorguard import guard, reset

    reset()
    with pytest.raises(ShapeError):
        guard([None, 2, 3], "C, B, A")

    accepted_cases = (
        ([None, 2, 3], "?, B, A"),
        ([1, 2, 3], "C?, B, A"),
        ([None, 2, 3], "C?, B, A"),
    )
    for shape, spec in accepted_cases:
        guard(shape, spec)
Example #10
0
def test_guard_ellipsis_infer_dims_global():
    """Names around an ellipsis bind only to the leading and trailing dimensions."""
    from tensorguard import get_dims, guard, reset

    reset()
    guard(np.ones([1, 2, 3, 4, 5]), "A, B, ..., C")
    assert get_dims() == {"A": 1, "B": 2, "C": 5}
Example #11
0
def test_guard_ellipsis_global():
    """An ellipsis absorbs any number (including zero) of middle dimensions.

    Specs listing more explicit dimensions than the array has must raise ShapeError.
    """
    from tensorguard import guard, reset

    reset()
    arr = np.ones([1, 2, 3, 4, 5])
    accepted_specs = (
        "...",
        "..., 5",
        "..., 4, 5",
        "1, ...",
        "1, 2, ...",
        "1, 2, ..., 4, 5",
        "1, 2, 3, ..., 4, 5",
    )
    for spec in accepted_specs:
        guard(arr, spec)

    # Six explicit dimensions cannot fit a rank-5 array, wherever the ellipsis sits.
    for spec in ("1, 2, 3, 4, 5, 6,...", "..., 1, 2, 3, 4, 5, 6"):
        with pytest.raises(ShapeError):
            guard(arr, spec)
Example #12
0
def test_guard_ignores_wildcard_global():
    """Wildcard (*) and literal dimensions do not register any named dimension."""
    from tensorguard import get_dims, guard, reset

    reset()
    guard(np.ones([1, 2, 3]), "*, *, 3")
    assert get_dims() == {}
Example #13
0
def test_guard_raises_complex_global():
    """Binding one name to two different sizes in the same spec is a shape error."""
    from tensorguard import guard, reset

    reset()
    arr = np.ones([1, 2, 3])
    with pytest.raises(ShapeError):
        guard(arr, "A, B, B")  # B cannot be both 2 and 3
Example #14
0
def test_guard_infers_dimensions_operator_priority_global():
    """Dimension expressions respect operator precedence (* binds tighter than +)."""
    from tensorguard import get_dims, guard, reset

    reset()
    # 8 = A + C*2 + 1 with A = 1  ->  C = 3
    guard(np.ones([1, 2, 8]), "A, B, A+C*2+1")
    assert get_dims() == {"A": 1, "B": 2, "C": 3}
Example #15
0
def test_guard_infers_dimensions_global():
    """Plain names in a spec are bound to the corresponding array sizes."""
    from tensorguard import get_dims, guard, reset

    reset()
    guard(np.ones([1, 2, 3]), "A, B, C")
    assert get_dims() == {"A": 1, "B": 2, "C": 3}