def test_add_new_element(self):
    """Check what ``ObjectsMemory.add_new_element`` leaves in ``M``, ``seq_ids`` and ``G``.

    Two scenarios are run, each on a fresh memory: one inserted row, then
    three inserted rows with sequence ids 1, 2 and 3.
    """
    # Scenario 1: a single (1, 10) row registered under sequence id 1.
    memory = ag.ObjectsMemory()
    single_row = torch.arange(10, 20)[None, ...]
    memory.add_new_element(single_row, 1)

    npt.assert_equal(
        utils.t2a(memory.M),
        np.arange(10, 20).reshape((1, 10)),
    )
    npt.assert_equal(memory.seq_ids, np.arange(1, 2))
    self.assertEqual(set(memory.G.nodes), {1})

    # Scenario 2: three rows covering 10..39, inserted in order.
    memory = ag.ObjectsMemory()
    rows = [
        torch.arange(start, start + 10)[None, ...]
        for start in (10, 20, 30)
    ]
    for seq_id, row in enumerate(rows, start=1):
        memory.add_new_element(row, seq_id)

    npt.assert_equal(
        utils.t2a(memory.M),
        np.arange(10, 40).reshape((3, 10)),
    )
    npt.assert_equal(memory.seq_ids, np.arange(1, 4))
    self.assertEqual(set(memory.G.nodes), {1, 2, 3})
def test_predict(self):
    """Compare per-sample ``predict`` calls with a single batched ``predict`` call.

    The agent first processes every sample; afterwards the second element of
    each per-sample prediction is compared with ``data[2]`` and the batched
    prediction must agree with the concatenated per-sample results.
    """
    data = simpledataset()
    samples, labels = data[0], data[2]
    sup = simplesupervisor(data[1])
    agent = ag.ActiveAgent(
        0.5,
        1,
        ag.ObjectsMemory(),
        ag.SupervisionMemory(),
        simplemodel(),
        sup,
        bootstrap=2,
        max_neigh_check=1,
        add_seen_element=ag.add_seen_separate,
    )

    # Feed the whole dataset through the agent; the return values are not
    # needed here, only the resulting agent state.
    for idx in range(len(samples)):
        agent.process_next([samples[idx]], labels[idx])

    # One predict call per sample, collected in two separate passes.
    predictions = [agent.predict([sample])[1] for sample in samples]
    is_known = [agent.predict([sample])[0] for sample in samples]
    npt.assert_equal(np.concatenate(predictions), labels)

    # The batched call must match the per-sample calls and be repeatable.
    all_pred = agent.predict(list(samples))
    npt.assert_equal(np.concatenate(is_known), all_pred[0])
    npt.assert_equal(np.concatenate(predictions), all_pred[1])
    npt.assert_equal(agent.predict(list(samples)), all_pred)
def test_process_next_active(self):
    """Check the outputs of ``ActiveAgent.process_next`` over the whole dataset.

    For every step after the first, the second output element must be smaller
    than the step index. At the end, the size of the supervision memory is
    compared with the number of steps whose third output element was truthy.
    """
    data = simpledataset()
    sup = simplesupervisor(data[1])
    agent = ag.ActiveAgent(
        0.5,
        1,
        ag.ObjectsMemory(),
        ag.SupervisionMemory(),
        simplemodel(),
        sup,
        bootstrap=2,
        max_neigh_check=1,
        add_seen_element=ag.add_seen_separate,
    )
    output = [
        agent.process_next([data[0][itx]], data[2][itx])
        for itx in range(len(data[0]))
    ]
    asked_sup = np.array([o[2] for o in output])

    for itx in range(len(output)):
        with self.subTest(n=itx):
            if itx > 0:
                self.assertTrue(output[itx][1] < itx)

    # The original used assertTrue(len(...), asked_sup.sum()); assertTrue
    # takes the second argument as the failure message, so that form only
    # verified that the supervision memory was non-empty. assertEqual
    # performs the comparison the two arguments suggest.
    self.assertEqual(len(agent.sup_mem), asked_sup.sum())
def test_refine(self):
    """Run an ``ag.Agent`` configured with a ``refine`` callback over the dataset.

    After each processed sample the object memory must hold ``itx + 1``
    entries and the supervision memory ``itx`` entries; at step 1 the third
    output element must be truthy.
    """
    data = simpledataset()
    sup = simplesupervisor(data[1])

    def refine(agent):
        """Train on the agent's supervision memory for two epochs."""
        # NOTE(review): ``torch.optim.sgd`` is not the ``SGD`` optimizer class
        # (which also requires a parameter iterable); this call is expected
        # to fail if ``refine`` is ever invoked -- confirm and fix upstream.
        optim = torch.optim.sgd()
        criterion = loss.ContrastiveLoss()
        trainer = ag.create_siamese_trainer(agent, optim, criterion)
        sampler = samp.SeadableRandomSampler(agent.sup_mem, 1)
        data_loader = torch.utils.data.DataLoader(
            agent.sup_mem,
            sampler=sampler,
        )
        trainer.run(data_loader, max_epochs=2)

    agent = ag.Agent(
        1,
        ag.ObjectsMemory(),
        ag.SupervisionMemory(),
        simplemodel(),
        sup,
        bootstrap=2,
        max_neigh_check=1,
        add_seen_element=ag.add_seen_separate,
        refine=refine,
    )

    n_samples = len(data[0])
    for itx in range(n_samples):
        with self.subTest(n=itx):
            out = agent.process_next([data[0][itx]], data[2][itx])
            # Only step 1 satisfies "below 2 and not 0".
            if itx != 0 and itx < 2:
                self.assertTrue(out[2])
            self.assertEqual(len(agent.obj_mem), itx + 1)
            self.assertEqual(len(agent.sup_mem), itx)
def test_process_next_internals(self):
    """Track memory sizes of an ``ag.Agent`` while it processes the dataset.

    After step ``itx`` the object memory must contain ``itx + 1`` entries and
    the supervision memory ``itx`` entries; at step 1 the third output
    element must be truthy.
    """
    data = simpledataset()
    sup = simplesupervisor(data[1])
    agent = ag.Agent(
        1,
        ag.ObjectsMemory(),
        ag.SupervisionMemory(),
        simplemodel(),
        sup,
        bootstrap=2,
        max_neigh_check=1,
        add_seen_element=ag.add_seen_separate,
    )

    step = 0
    total = len(data[0])
    while step < total:
        with self.subTest(n=step):
            out = agent.process_next([data[0][step]], data[2][step])
            # Equivalent to "step < 2 and step != 0" for the non-negative
            # indices produced here.
            if 0 < step < 2:
                self.assertTrue(out[2])
            self.assertEqual(len(agent.obj_mem), step + 1)
            self.assertEqual(len(agent.sup_mem), step)
        step += 1

    # NOTE(review): the dataset and supervisor are rebuilt here but nothing
    # visible in this block uses them afterwards -- confirm whether this is
    # leftover code or setup for something outside this view.
    data = simpledataset()
    sup = simplesupervisor(data[1])
def af(seed, sup):
    """Seed torch's global RNG with ``seed`` and return a freshly built ``ag.Agent``.

    ``seed`` is also passed as the agent's first positional argument and
    ``sup`` as its supervisor argument.
    """
    # Seed before any component is constructed.
    torch.manual_seed(seed)
    return ag.Agent(
        seed,
        ag.ObjectsMemory(),
        ag.SupervisionMemory(),
        simplemodel(),
        sup,
        bootstrap=2,
        max_neigh_check=1,
        add_seen_element=ag.add_seen_separate,
    )
def tets_len(self):
    """Check ``len()`` and ``sequences`` of an ``ObjectsMemory`` holding three elements.

    NOTE(review): the method name does not start with ``test``; if this lives
    in a ``unittest.TestCase`` the default loader will not collect it --
    confirm whether the spelling is intentional.
    """
    memory = ag.ObjectsMemory()
    # Build all three tensors first, then insert them with ids 1, 2, 3.
    tensors = [torch.arange(lo, lo + 10) for lo in (10, 20, 30)]
    for seq_id, tensor in enumerate(tensors, start=1):
        memory.add_new_element(tensor, seq_id)

    self.assertEqual(len(memory), 3)
    self.assertEqual(memory.sequences, 3)
def test_get_something(self):
    """Check ``get_sid`` and ``get_embed`` for index 0 of a three-element memory."""
    memory = ag.ObjectsMemory()
    first = torch.arange(10, 20)
    second = torch.arange(20, 30)
    third = torch.arange(30, 40)
    for seq_id, tensor in ((1, first), (2, second), (3, third)):
        memory.add_new_element(tensor, seq_id)

    # Index 0 must map back to the first inserted id and embedding.
    npt.assert_equal(memory.get_sid(0), 1)
    npt.assert_equal(
        utils.t2a(memory.get_embed(0)),
        utils.t2a(first),
    )
def test_supervision(self):
    """Run two ``ActiveAgent`` configurations and inspect output column 2.

    With a first constructor argument of ``1.0`` at least one step after the
    first must have a truthy third output; with ``0.01`` not every such step
    may have one.
    """
    data = simpledataset()
    sup = simplesupervisor(data[1])

    def collect_outputs(first_arg):
        """Build an agent with ``first_arg`` and return its stacked outputs."""
        agent = ag.ActiveAgent(
            first_arg,
            1,
            ag.ObjectsMemory(),
            ag.SupervisionMemory(),
            simplemodel(),
            sup,
            bootstrap=2,
            max_neigh_check=1,
            add_seen_element=ag.add_seen_separate,
        )
        return np.array([
            agent.process_next([data[0][itx]], data[2][itx])
            for itx in range(len(data[0]))
        ])

    # Row 0 is skipped in both checks.
    output = collect_outputs(1.0)
    self.assertTrue(output[1:, 2].any())

    output = collect_outputs(0.01)
    self.assertFalse(output[1:, 2].all())
def test_get_knn(self):
    """Check the second element returned by ``ObjectsMemory.get_knn``.

    Three float vectors are stored; queries are issued with a single vector
    and with a stack of two vectors, for ``k=1`` and ``k=2``.
    """
    memory = ag.ObjectsMemory()
    vec_a = torch.arange(10, 20).float()
    vec_b = torch.arange(20, 30).float()
    vec_c = torch.arange(40, 50).float()
    memory.add_new_element(vec_a, 1)
    memory.add_new_element(vec_b, 2)
    memory.add_new_element(vec_c, 3)

    # Single query vector.
    nearest_one = memory.get_knn(vec_a, k=1)[1]
    npt.assert_equal(utils.t2a(nearest_one), 0)
    nearest_two = memory.get_knn(vec_a, k=2)[1]
    npt.assert_equal(utils.t2a(nearest_two), np.array([[0, 1]]))

    # Two stacked query vectors.
    queries = torch.stack([vec_a, vec_b])
    batch_one = memory.get_knn(queries, k=1)[1]
    npt.assert_equal(utils.t2a(batch_one), np.array([[0], [1]]))
    batch_two = memory.get_knn(queries, k=2)[1]
    npt.assert_equal(utils.t2a(batch_two), np.array([[0, 1], [1, 0]]))
def test_add_neighbors(self):
    """Check memory contents and graph edges after ``add_neighbors(3, [0])``.

    Three rows are inserted with ids 1, 2 and 3; afterwards the stored
    matrix, the sequence ids and the node set must be as inserted, and the
    edge set must be exactly ``{(1, 3)}`` once each edge is sorted.
    """
    memory = ag.ObjectsMemory()
    rows = [
        torch.arange(start, start + 10)[None, ...]
        for start in (10, 20, 30)
    ]
    for seq_id, row in enumerate(rows, start=1):
        memory.add_new_element(row, seq_id)

    memory.add_neighbors(3, [0])

    npt.assert_equal(
        utils.t2a(memory.M),
        np.arange(10, 40).reshape((3, 10)),
    )
    npt.assert_equal(memory.seq_ids, np.arange(1, 4))
    self.assertEqual(set(memory.G.nodes), {1, 2, 3})

    # Sort each edge's endpoints so the comparison ignores their order.
    observed_edges = {tuple(sorted(edge)) for edge in memory.G.edges}
    self.assertEqual(observed_edges, {(1, 3)})
def test_process_next_out(self):
    """Check the outputs of ``ag.Agent.process_next`` for every step after the first.

    For each such step the second output element must be smaller than the
    step index and the third output element must be truthy.
    """
    data = simpledataset()
    sup = simplesupervisor(data[1])
    agent = ag.Agent(
        1,
        ag.ObjectsMemory(),
        ag.SupervisionMemory(),
        simplemodel(),
        sup,
        bootstrap=2,
        max_neigh_check=1,
        add_seen_element=ag.add_seen_separate,
    )

    # Process the complete dataset before asserting anything.
    output = []
    for itx in range(len(data[0])):
        output.append(agent.process_next([data[0][itx]], data[2][itx]))

    # Step 0 is skipped on purpose.
    for itx, step_out in enumerate(output[1:], start=1):
        with self.subTest(n=itx):
            self.assertTrue(step_out[1] < itx)
            self.assertTrue(step_out[2])