Esempio n. 1
0
    def test_prioritized_pop(self):
        """ Insert two clusters and use three cluster elements for a lookup
        therefore increasing their priority. Then pop the unused keys.
        """
        keys = torch.tensor(
            [
                [0, 0, 2.1, 0, 0],
                [0, 0, 2.2, 0, 0],
                [0, 0, 1.9, 0, 0],
                [1, 0, 0.0, 0, 0],
                [0, 1, 0.0, 0, 0],
                [0, 0, 0.0, 1, 0],
                [0, 0, 0.0, 0, 1],
            ]
        )
        values = torch.tensor([2, 2, 2, 0, 0, 0, 0]).unsqueeze(1)
        N, key_size = keys.shape
        assert len(keys) == len(values) == N

        # filll DND
        dnd = DND(key_size, torch.device("cpu"), max_size=N, knn_no=3)
        for k, v in zip(keys, values):
            dnd.write(k, v, update_rule=lambda old, new: new)

        self.assertEqual(len(dnd), N)  # the DND should be of size N

        dnd.rebuild_tree()
        dnd.lookup(torch.tensor([[0, 0, 2.0, 0, 0]])).squeeze().item()

        print(dnd._priority)
        dnd.write(torch.ones(1, 5), 2)
        print(dnd._dict.keys())
Esempio n. 2
0
    def test_priority_increment(self):
        """ Insert two clusters and use three cluster elements for a lookup
        therefore increasing their priority.
        """
        keys = torch.tensor(
            [
                [0, 0, 2.1, 0, 0],
                [0, 0, 2.2, 0, 0],
                [0, 0, 1.9, 0, 0],
                [1, 0, 0.0, 0, 0],
                [0, 1, 0.0, 0, 0],
                [0, 0, 0.0, 1, 0],
                [0, 0, 0.0, 0, 1],
            ]
        )
        values = torch.tensor([2, 2, 2, 0, 0, 0, 0]).unsqueeze(1)
        N, key_size = keys.shape
        assert len(keys) == len(values) == N

        # filll DND
        dnd = DND(key_size, torch.device("cpu"), max_size=N, knn_no=3)
        for k, v in zip(keys, values):
            dnd.write(k, v, update_rule=lambda old, new: new)

        self.assertEqual(len(dnd), N)  # the DND should be of size N

        dnd.rebuild_tree()
        dnd.lookup(torch.tensor([[0, 0, 2.0, 0, 0]])).squeeze().item()
        should_be = [0, 1, 2]
        used_keys = [k for k,v in dnd._priority.items() if v == 1]
        for idx in used_keys:
            self.assertIn(idx, should_be)
Esempio n. 3
0
    def test_binary_cluster(self):
        """ Insert one-hot vectors and retrieve the proper value associated
        with each.
        """
        keys = torch.tensor(
            [
                [0, 0, 2.1, 0, 0],
                [0, 0, 2.2, 0, 0],
                [0, 0, 1.9, 0, 0],
                [1, 0, 0.0, 0, 0],
                [0, 1, 0.0, 0, 0],
                [0, 0, 0.0, 1, 0],
                [0, 0, 0.0, 0, 1],
            ]
        )
        values = torch.tensor([2, 2, 2, 0, 0, 0, 0]).unsqueeze(1)
        N, key_size = keys.shape
        assert len(keys) == len(values) == N

        # filll DND
        dnd = DND(key_size, torch.device("cpu"), max_size=N, knn_no=3)
        for k, v in zip(keys, values):
            dnd.write(k, v, update_rule=lambda old, new: new)

        self.assertEqual(len(dnd), N)  # the DND should be of size N

        dnd.rebuild_tree()
        value = dnd.lookup(torch.tensor([[0, 0, 2.0, 0, 0]])).squeeze().item()
        self.assertEqual(value, 2)
Esempio n. 4
0
    def test_clusters(self):
        """ Insert gaussian clusters and retrieve the proper value associated
        with each.
        """
        # generate clusters
        means, std = [-5.0, 5.0, 10.0], 0.5
        N_ = 100  # N per cluster
        N, key_size = N_ * len(means), 5
        keys = torch.cat([torch.randn(N_, key_size) * std + mu for mu in means])
        values = torch.cat([torch.zeros(N_, 1) + mu for mu in means])
        rnd_idxs = torch.randperm(len(keys))  # shuffle the rows
        keys = keys[rnd_idxs]
        values = values[rnd_idxs]

        # filll DND
        dnd = DND(key_size, torch.device("cpu"), max_size=N, knn_no=3)
        for k, v in zip(keys, values):
            dnd.write(k, v, update_rule=lambda old, new: new)

        self.assertEqual(len(dnd), N)  # the DND should be of size N

        # querry DND
        dnd.rebuild_tree()
        values = [dnd.lookup(torch.ones(1, key_size).fill_(mu)) for mu in means]
        values = [v.squeeze().item() for v in values]

        print(f"clusters: {means}\nvalues:   {values}")

        for value, expected in zip(values, means):
            self.assertAlmostEqual(value, expected, places=self.precision)
Esempio n. 5
0
    def test_backprop(self):
        """ Backprop through the DND. """
        keys = torch.tensor(
            [
                [0, 0, 2.1, 0, 0],
                [0, 0, 2.2, 0, 0],
                [0, 0, 1.9, 0, 0],
                # [1, 0, 0.0, 0, 0],
                # [0, 1, 0.0, 0, 0],
                # [0, 0, 0.0, 1, 0],
                # [0, 0, 0.0, 0, 1],
            ]
        )
        values = torch.tensor([2.1, 2.3, 2.]).unsqueeze(1)
        N, key_size = keys.shape
        assert len(keys) == len(values) == N

        # filll DND
        dnd = DND(key_size, torch.device("cpu"), max_size=N, knn_no=3)
        for k, v in zip(keys, values):
            dnd.write(k, v, update_rule=lambda old, new: new)

        self.assertEqual(len(dnd), N)  # the DND should be of size N

        # high dimensional new state
        obs = torch.tensor([[0, 0, 1.0, 2.0, 1.0, 0, 0]])
        params = torch.randn(key_size, obs.shape[1])
        params.requires_grad_(True)

        h = obs @ params.t()
        dnd.rebuild_tree()
        val = dnd.lookup(h)

        h.register_hook(lambda g: print("h hook:   ", g))
        val.register_hook(lambda g: print("v hook:   ", g))

        # print("h:      ", h)
        # print("h.grad? ", h.requires_grad)
        # print("h.grad: ", h.grad)
        # print("val:    ", val)
        # loss = (val - torch.tensor([0.0])).pow(2).squeeze()
        # print("loss:   ", loss)
        # loss.backward()
        # print("h.grad: ", h.grad)
        # print("w.grad: ", params.grad)
        raise NotImplementedError("Implement DND backprop test!")