def _setUp(self, batch_size, device=None):
    """Build the fixtures shared by the tests: a query time and a tree.

    Stores a scalar query time on `self.t` and a `BrownianTree` on `self.bm`.

    Args:
        batch_size: Number of rows of the endpoint tensors `w0` and `w1`
            (each has `D` columns).
        device: Forwarded to the `.to(...)` call of every tensor created here.
    """
    interval = torch.tensor([0., 1.]).to(device)
    t0, t1 = interval

    # Left endpoint value is all zeros; right endpoint value is random.
    start_value = torch.zeros(batch_size, D).to(device=device)
    end_value = torch.randn(batch_size, D).to(device=device)

    # Drawn after `end_value`, so the order of random draws is: w1, then t.
    self.t = torch.rand([]).to(device)
    self.bm = BrownianTree(
        t0=t0, t1=t1, w0=start_value, w1=end_value, entropy=0)
def run_torch(ks=(0, 1, 2, 5, 6, 7, 8, 9, 10, 11, 12)):
    """Profile `BrownianTree` construction and query time per cache depth.

    For every value in `ks` a tree is built with `cache_depth=k` and then
    queried at every time in the module-level `ts`. Construction time, query
    time, and their sum are logged and finally plotted against `ks`; the
    figure is saved to `./diagnostics/plots/profile_btree.png`.

    Reads the module-level names `b`, `d`, `t0`, `t1`, `ts`, `reps` and
    `device` (`reps` is only used in the plot title).

    Args:
        ks: Iterable of cache depths to profile.
    """
    w0 = torch.zeros(b, d)

    t_cons = []
    t_queries = []
    t_alls = []
    for k in tqdm.tqdm(ks):
        # perf_counter is monotonic and high-resolution, so the measured
        # durations cannot be distorted by system clock adjustments.
        now = time.perf_counter()
        bm_vanilla = BrownianTree(t0=t0, t1=t1, w0=w0, cache_depth=k)
        t_con = time.perf_counter() - now
        t_cons.append(t_con)

        now = time.perf_counter()
        for t in ts:
            # The result is discarded; only the cost of the query plus the
            # device transfer is of interest here.
            bm_vanilla(t).to(device)
        t_query = time.perf_counter() - now
        t_queries.append(t_query)

        t_all = t_con + t_query
        t_alls.append(t_all)
        logging.warning(
            f'k={k}, t_con={t_con:.4f}, t_query={t_query:.4f}, t_all={t_all:.4f}'
        )

    img_path = os.path.join('.', 'diagnostics', 'plots', 'profile_btree.png')
    # savefig does not create missing directories; without this the whole
    # profiling run would be lost to a FileNotFoundError at the very end.
    os.makedirs(os.path.dirname(img_path), exist_ok=True)

    plt.figure()
    plt.plot(ks, t_cons, label='cons')
    plt.plot(ks, t_queries, label='queries')
    plt.plot(ks, t_alls, label='all')
    plt.title(f'b={b}, d={d}, repetitions={reps}, device={w0.device}')
    plt.xlabel('Cache level')
    plt.ylabel('Time (secs)')
    plt.legend()
    plt.savefig(img_path)
    plt.close()
def test_normality(self):
    """Kolmogorov-Smirnov test.

    For `REPS` random right-endpoint values, builds a tree whose whole batch
    shares the same endpoints, queries it at `REPS` random interior times and
    checks each batch of samples against a normal distribution with mean
    `((t1 - t) * w0 + (t - t0) * w1) / (t1 - t0)` and standard deviation
    `sqrt((t1 - t) * (t - t0) / (t1 - t0))`. Every p-value must be at least
    `ALPHA`.
    """
    t0_, t1_ = 0.0, 1.0
    t0, t1 = torch.tensor([t0_, t1_])
    eps = 1e-5  # Keeps query times strictly inside (t0, t1).
    span_ = t1_ - t0_

    for _trial in range(REPS):
        w0_ = 0.0
        w1_ = npr.randn()

        # Use the same endpoint for the batch, so samples from same dist.
        w0 = torch.tensor(w0_).repeat(LARGE_BATCH_SIZE)
        w1 = torch.tensor(w1_).repeat(LARGE_BATCH_SIZE)
        bm = BrownianTree(
            t0=t0, t1=t1, w0=w0, w1=w1, pool_size=100, tol=1e-14)

        for _query in range(REPS):
            t_ = npr.uniform(low=t0_ + eps, high=t1_ - eps)
            samples_ = bm(t_).detach().numpy()

            to_end_ = t1_ - t_
            from_start_ = t_ - t0_
            mean_ = (to_end_ * w0_ + from_start_ * w1_) / span_
            std_ = np.sqrt(to_end_ * from_start_ / span_)

            expected = norm(loc=mean_, scale=std_)
            _stat, pval = kstest(samples_, expected.cdf)
            self.assertGreaterEqual(pval, ALPHA)
class TestBrownianTree(TorchTestCase):
    """Tests for `BrownianTree`: shapes, determinism, normality, device moves."""

    def _setUp(self, batch_size, device=None):
        """Build the fixtures shared by the tests: a query time and a tree.

        Stores a scalar query time on `self.t` and a `BrownianTree` on
        `self.bm`.

        Args:
            batch_size: Number of rows of the endpoint tensors `w0` and `w1`
                (each has `D` columns).
            device: Forwarded to the `.to(...)` call of every tensor created
                here.
        """
        interval = torch.tensor([0., 1.]).to(device)
        t0, t1 = interval

        start_value = torch.zeros(batch_size, D).to(device=device)
        end_value = torch.randn(batch_size, D).to(device=device)

        # Drawn after `end_value`, so the order of random draws is: w1, then t.
        self.t = torch.rand([]).to(device)
        self.bm = BrownianTree(
            t0=t0, t1=t1, w0=start_value, w1=end_value, entropy=0)

    def test_basic_cpu(self):
        # A single query on CPU must return one row per batch element.
        cpu = torch.device('cpu')
        self._setUp(batch_size=SMALL_BATCH_SIZE, device=cpu)
        expected_shape = (SMALL_BATCH_SIZE, D)
        self.assertEqual(self.bm(self.t).size(), expected_shape)

    def test_basic_gpu(self):
        # Same shape check as the CPU variant, skipped when CUDA is absent.
        if not torch.cuda.is_available():
            self.skipTest(reason='CUDA not available.')
        cuda = torch.device('cuda')
        self._setUp(batch_size=SMALL_BATCH_SIZE, device=cuda)
        expected_shape = (SMALL_BATCH_SIZE, D)
        self.assertEqual(self.bm(self.t).size(), expected_shape)

    def test_determinism(self):
        # Querying the same tree at the same time repeatedly must give values
        # that are all close to the first one.
        self._setUp(batch_size=SMALL_BATCH_SIZE)
        queried = [self.bm(self.t) for _ in range(REPS)]
        for repeat in queried[1:]:
            self.tensorAssertAllClose(repeat, queried[0])

    def test_normality(self):
        """Kolmogorov-Smirnov test.

        For `REPS` random right-endpoint values, builds a tree whose whole
        batch shares the same endpoints, queries it at `REPS` random interior
        times and checks each batch of samples against a normal distribution
        with mean `((t1 - t) * w0 + (t - t0) * w1) / (t1 - t0)` and standard
        deviation `sqrt((t1 - t) * (t - t0) / (t1 - t0))`. Every p-value must
        be at least `ALPHA`.
        """
        t0_, t1_ = 0.0, 1.0
        t0, t1 = torch.tensor([t0_, t1_])
        eps = 1e-5  # Keeps query times strictly inside (t0, t1).
        span_ = t1_ - t0_

        for _trial in range(REPS):
            w0_ = 0.0
            w1_ = npr.randn()

            # Use the same endpoint for the batch, so samples from same dist.
            w0 = torch.tensor(w0_).repeat(LARGE_BATCH_SIZE)
            w1 = torch.tensor(w1_).repeat(LARGE_BATCH_SIZE)
            bm = BrownianTree(
                t0=t0, t1=t1, w0=w0, w1=w1, pool_size=100, tol=1e-14)

            for _query in range(REPS):
                t_ = npr.uniform(low=t0_ + eps, high=t1_ - eps)
                samples_ = bm(t_).detach().numpy()

                to_end_ = t1_ - t_
                from_start_ = t_ - t0_
                mean_ = (to_end_ * w0_ + from_start_ * w1_) / span_
                std_ = np.sqrt(to_end_ * from_start_ / span_)

                expected = norm(loc=mean_, scale=std_)
                _stat, pval = kstest(samples_, expected.cdf)
                self.assertGreaterEqual(pval, ALPHA)

    def test_to(self):
        # Moving the tree to CUDA must relocate its cache without changing
        # the cached values.
        if not torch.cuda.is_available():
            self.skipTest(reason='CUDA not available.')

        def concat_cache(cache):
            # Join the three cached groups, in this order, into one tensor.
            pieces = [*cache['ws_prev'], *cache['ws'], *cache['ws_post']]
            return torch.cat(pieces, dim=0)

        self._setUp(batch_size=SMALL_BATCH_SIZE)
        old = concat_cache(self.bm.get_cache())

        self.bm.to(torch.device('cuda'))
        new = concat_cache(self.bm.get_cache())

        self.assertTrue(str(new.device).startswith('cuda'))
        self.tensorAssertAllClose(old, new.cpu())