def calc_loss(self, centroids, embeddings_query, y_query):
    """
    Compute the prototypical-network loss using tensor broadcasting.

    Follows algorithm 1 (page 3) of "Prototypical Networks for Few-shot
    Learning": for every query sample, the distance to the centroid of its
    own class is added to the log-sum-exp of the negated distances to all
    centroids.

    :param centroids: Centroids calculated from the support set.
        Shape: [batch_size, num_classes, num_embedding_features]
    :param embeddings_query: Embeddings of the query-set images.
        Shape: [batch_size, query_size * num_classes, num_embedding_features]
    :param y_query: Labels of the query samples.
        Shape: [batch_size, query_size * num_classes]
    :return: the loss scalar value; the sum over all batch entries and
        queries, divided by ``y_query.shape[1]``
    """
    n_classes = centroids.shape[1]

    # Singleton axes pair every query with every centroid:
    # centroids gain an axis at position 1, queries at position 2.
    distances = self.distance_fun(
        centroids.unsqueeze(1),
        embeddings_query.unsqueeze(2),
    )
    tg.guard(distances, "*, QUERY_SIZE, NUM_CLASSES")

    # 0/1 mask marking, for each query, the column of its own class.
    class_ids = torch.arange(n_classes, device=y_query.device).view(1, 1, n_classes)
    is_target = y_query.unsqueeze(-1) == class_ids
    is_target = is_target.expand(is_target.shape[0], is_target.shape[1], n_classes)
    is_target = is_target.int()

    log_partition = torch.logsumexp(-1 * distances, dim=-1).sum()
    target_distance = (distances * is_target).sum()

    total = log_partition + target_distance
    return total / y_query.shape[1]
def forward(self, X_supp, y_supp, X_query):
    """
    Embed support and query images and build the class centroids.

    Support and query images are flattened over their first two axes,
    concatenated and passed through ``self.embedding_nn`` in a single call.
    The support embeddings are then summed per class with ``scatter_add``.

    :param X_supp: Support images, unpacked as [batch_size, supp_size, C, H, W]
    :param y_supp: Support labels, guarded as [batch_size, supp_size]; used
        as class indices, with ``y_supp.max() + 1`` taken as the class count
    :param X_query: Query images, guarded as [batch_size, query_size, C, H, W]
    :return: dict with keys ``centroids``, ``embeddings_support`` and
        ``embeddings_query``; ``prob_query`` is added when
        ``self.get_probabilities`` is truthy
    """
    num_classes = y_supp.max() + 1
    bs, supp_size, _c, _h, _w = X_supp.shape

    tg.guard(X_supp, "*, SUPP_SIZE, C, H, W")
    tg.guard(y_supp, "*, SUPP_SIZE")
    tg.guard(X_query, "*, QUERY_SIZE, C, H, W")
    n_query = X_query.shape[1]

    # Merge the batch and sample axes so the embedding network sees a flat batch.
    flat_supp = X_supp.flatten(0, 1).contiguous()
    flat_query = X_query.flatten(0, 1).contiguous()
    n_supp_rows = flat_supp.shape[0]

    embedded = self.embedding_nn(torch.cat((flat_supp, flat_query), dim=0))

    # Split the joint output back into support and query parts.
    embeddings_supp = embedded[:n_supp_rows].view(bs, supp_size, -1)
    embeddings_query = embedded[n_supp_rows:].view(bs, n_query, -1)
    tg.guard(embeddings_supp, "*, SUPP_SIZE, NUM_FEATURES")
    tg.guard(embeddings_query, "*, QUERY_SIZE, NUM_FEATURES")

    n_features = embeddings_supp.shape[-1]
    scatter_index = y_supp.unsqueeze(-1).expand(bs, supp_size, n_features)
    class_sums = torch.zeros(
        bs,
        num_classes,
        n_features,
        device=flat_supp.device,
        dtype=embeddings_supp.dtype,
    ).scatter_add(1, scatter_index, embeddings_supp)

    # NOTE: this equals the per-class mean only when every class has
    # supp_size / num_classes support samples.
    centroids = class_sums / supp_size * num_classes

    result = {
        'centroids': centroids,
        'embeddings_support': embeddings_supp,
        'embeddings_query': embeddings_query,
    }
    if self.get_probabilities:
        result['prob_query'] = self._get_probabilities(result, self.distance_function)
    return result
def test_reset():
    """``tg.reset()`` clears the dimensions recorded by ``tg.guard``."""
    import tensorguard as tg

    tg.reset()

    shape = [3, 5, 4]
    tg.guard(shape, "A, B, C")
    recorded = {"A": 3, "B": 5, "C": 4}
    assert tg.get_dims() == recorded

    tg.reset()
    assert tg.get_dims() == {}
def test_get_dims():
    """``tg.get_dims`` is empty after a reset and evaluates a spec string."""
    import tensorguard as tg

    tg.reset()
    assert tg.get_dims() == {}

    tensor = np.zeros((15, 4, 32))
    tg.guard(tensor, "B, C, W")

    # With B=15 and W=32 the expressions evaluate to 30 and 8.
    evaluated = tg.get_dims("B * 2, W/4")
    assert evaluated == [30, 8]
def test_guard_raises_inferred_global():
    """A second guard that contradicts already-inferred dims raises ``ShapeError``."""
    import tensorguard as tg

    tg.reset()

    first = np.ones((1, 2, 3))
    second = np.ones((3, 2, 5))

    # Binds A=1, B=2, C=3.
    tg.guard(first, "A, B, C")

    # The last axis of `second` is 5, which conflicts with A=1.
    with pytest.raises(ShapeError):
        tg.guard(second, "C, B, A")
def test_get_dim():
    """``get_dim`` returns a known dim and raises ``KeyError`` otherwise;
    ``safe_get_dim`` returns ``None`` for an unknown name."""
    import tensorguard as tg

    tg.reset()

    tensor = np.zeros((32, 2, 5))
    tg.guard(tensor, "B, C, W")
    assert tg.get_dim("W") == 5

    missing = "W_FAKE"
    with pytest.raises(KeyError):
        tg.get_dim(missing)
    assert tg.safe_get_dim(missing) is None
def test_guard_infers_dimensions_complex_global():
    """Dimensions are inferred from arithmetic expressions in the spec."""
    import tensorguard as tg

    tg.reset()

    tensor = np.ones((1, 2, 3))
    # A=1 from the first axis; B*2 == 2 gives B=1; A+C == 3 gives C=2.
    tg.guard(tensor, "A, B*2, A+C")

    expected = {"A": 1, "B": 1, "C": 2}
    assert tg.get_dims() == expected, f'{tg.get_dims()}' + ' != {"A": 1, "B": 1, "C": 2}'
def test_set_dim():
    """``set_dim`` / ``set_dims`` overwrite existing dims and register new ones."""
    import tensorguard as tg

    tg.reset()

    tensor = np.zeros((32, 2, 5))
    tg.guard(tensor, "B, C, W")
    assert tg.get_dim("W") == 5

    # Overwrite a dimension that was inferred by guard.
    tg.set_dim("W", 10)
    assert tg.get_dim("W") == 10

    # Register a name that was never guarded.
    tg.set_dim("WF", 40)
    assert tg.get_dim("WF") == 40
    assert tg.safe_get_dim("W_FAKE") is None

    # Several dimensions at once via keyword arguments.
    tg.set_dims(WW=32, HH=55)
    assert tg.get_dim("WW") == 32
    assert tg.safe_get_dim("HH") == 55
def test_guard_dynamic_shape_global():
    """A ``None`` axis is rejected for a plain name but accepted for ``?`` and ``C?``."""
    import tensorguard as tg

    tg.reset()

    dynamic_shape = [None, 2, 3]
    static_shape = [1, 2, 3]

    with pytest.raises(ShapeError):
        tg.guard(dynamic_shape, "C, B, A")

    tg.guard(dynamic_shape, "?, B, A")
    tg.guard(static_shape, "C?, B, A")
    tg.guard(dynamic_shape, "C?, B, A")
def test_guard_ellipsis_infer_dims_global():
    """Named dims on both sides of ``...`` are inferred; skipped axes are not recorded."""
    import tensorguard as tg

    tg.reset()

    tensor = np.ones((1, 2, 3, 4, 5))
    tg.guard(tensor, "A, B, ..., C")

    expected = {"A": 1, "B": 2, "C": 5}
    assert tg.get_dims() == expected
def test_guard_ellipsis_global():
    """``...`` matches any run of axes; specs with more fixed dims than the tensor raise."""
    import tensorguard as tg

    tg.reset()

    tensor = np.ones((1, 2, 3, 4, 5))

    accepted_specs = (
        "...",
        "..., 5",
        "..., 4, 5",
        "1, ...",
        "1, 2, ...",
        "1, 2, ..., 4, 5",
        "1, 2, 3, ..., 4, 5",
    )
    for spec in accepted_specs:
        tg.guard(tensor, spec)

    # Six fixed dims cannot match a five-axis tensor, whichever side the ellipsis is on.
    rejected_specs = (
        "1, 2, 3, 4, 5, 6,...",
        "..., 1, 2, 3, 4, 5, 6",
    )
    for spec in rejected_specs:
        with pytest.raises(ShapeError):
            tg.guard(tensor, spec)
def test_guard_ignores_wildcard_global():
    """``*`` entries and literal sizes do not register any named dimension."""
    import tensorguard as tg

    tg.reset()

    tensor = np.ones((1, 2, 3))
    tg.guard(tensor, "*, *, 3")

    assert tg.get_dims() == {}
def test_guard_raises_complex_global():
    """Reusing one name for axes of different sizes in a single spec raises ``ShapeError``."""
    import tensorguard as tg

    tg.reset()

    tensor = np.ones((1, 2, 3))
    # B would have to be both 2 and 3.
    with pytest.raises(ShapeError):
        tg.guard(tensor, "A, B, B")
def test_guard_infers_dimensions_operator_priority_global():
    """Inference honours operator precedence: ``A+C*2+1`` is ``A + (C*2) + 1``."""
    import tensorguard as tg

    tg.reset()

    tensor = np.ones((1, 2, 8))
    # With A=1: 1 + C*2 + 1 == 8 gives C=3.
    tg.guard(tensor, "A, B, A+C*2+1")

    expected = {"A": 1, "B": 2, "C": 3}
    assert tg.get_dims() == expected
def test_guard_infers_dimensions_global():
    """Plain names in the spec are bound to the matching axis sizes."""
    import tensorguard as tg

    tg.reset()

    tensor = np.ones((1, 2, 3))
    tg.guard(tensor, "A, B, C")

    expected = {"A": 1, "B": 2, "C": 3}
    assert tg.get_dims() == expected