Example 1
    def test_upper_confidence_bound(self, cuda=False):
        device = torch.device("cuda") if cuda else torch.device("cpu")
        for dtype in (torch.float, torch.double):
            mean = torch.tensor([[0.0]], device=device, dtype=dtype)
            variance = torch.tensor([[1.0]], device=device, dtype=dtype)
            mm = MockModel(MockPosterior(mean=mean, variance=variance))

            module = UpperConfidenceBound(model=mm, beta=1.0)
            X = torch.zeros(1, 1, device=device, dtype=dtype)
            ucb = module(X)
            ucb_expected = torch.tensor([1.0], device=device, dtype=dtype)
            self.assertTrue(torch.allclose(ucb, ucb_expected, atol=1e-4))

            module = UpperConfidenceBound(model=mm, beta=1.0, maximize=False)
            X = torch.zeros(1, 1, device=device, dtype=dtype)
            ucb = module(X)
            ucb_expected = torch.tensor([-1.0], device=device, dtype=dtype)
            self.assertTrue(torch.allclose(ucb, ucb_expected, atol=1e-4))

            # check for proper error if multi-output model
            mean2 = torch.rand(1, 2, device=device, dtype=dtype)
            variance2 = torch.rand(1, 2, device=device, dtype=dtype)
            mm2 = MockModel(MockPosterior(mean=mean2, variance=variance2))
            module2 = UpperConfidenceBound(model=mm2, beta=1.0)
            with self.assertRaises(UnsupportedError):
                module2(X)
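The expected values are consistent with the confidence-bound formula mean +/- sqrt(beta) * sigma (upper bound when maximizing, lower bound in the maximize=False case); a minimal sanity check in plain Python, with no botorch dependency:

import math

mean, variance, beta = 0.0, 1.0, 1.0
ucb_max = mean + math.sqrt(beta * variance)  # 1.0, matches ucb_expected
ucb_min = mean - math.sqrt(beta * variance)  # -1.0, the maximize=False expectation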
Example 2
    def test_probability_of_improvement(self, cuda=False):
        device = torch.device("cuda") if cuda else torch.device("cpu")
        for dtype in (torch.float, torch.double):
            mean = torch.tensor([0.0], device=device, dtype=dtype).view(1, 1)
            variance = torch.ones(1, 1, device=device, dtype=dtype)
            mm = MockModel(MockPosterior(mean=mean, variance=variance))

            module = ProbabilityOfImprovement(model=mm, best_f=1.96)
            X = torch.zeros(1, 1, device=device, dtype=dtype)
            pi = module(X)
            pi_expected = torch.tensor(0.0250, device=device, dtype=dtype)
            self.assertTrue(torch.allclose(pi, pi_expected, atol=1e-4))

            module = ProbabilityOfImprovement(model=mm, best_f=1.96, maximize=False)
            X = torch.zeros(1, 1, device=device, dtype=dtype)
            pi = module(X)
            pi_expected = torch.tensor(0.9750, device=device, dtype=dtype)
            self.assertTrue(torch.allclose(pi, pi_expected, atol=1e-4))

            # check for proper error if multi-output model
            mean2 = torch.rand(1, 2, device=device, dtype=dtype)
            variance2 = torch.ones_like(mean2)
            mm2 = MockModel(MockPosterior(mean=mean2, variance=variance2))
            module2 = ProbabilityOfImprovement(model=mm2, best_f=0.0)
            with self.assertRaises(UnsupportedError):
                module2(X)
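The expected probabilities are the standard-normal tails at the familiar 1.96 cutoff: PI = Phi((mean - best_f) / sigma) = Phi(-1.96) ~ 0.0250 when maximizing, and Phi(1.96) ~ 0.9750 when minimizing. A quick check:

import torch
from torch.distributions import Normal

N = Normal(0.0, 1.0)
print(N.cdf(torch.tensor(-1.96)))  # ~0.0250, matches pi_expected
print(N.cdf(torch.tensor(1.96)))   # ~0.9750, the maximize=False case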
Example 3
    def test_degenerate_GPyTorchPosterior(self, cuda=False):
        device = torch.device("cuda") if cuda else torch.device("cpu")
        for dtype in (torch.float, torch.double):
            # singular covariance matrix
            degenerate_covar = torch.tensor(
                [[1, 1, 0], [1, 1, 0], [0, 0, 2]], dtype=dtype, device=device
            )
            mean = torch.rand(3, dtype=dtype, device=device)
            mvn = MultivariateNormal(mean, lazify(degenerate_covar))
            posterior = GPyTorchPosterior(mvn=mvn)
            # basics
            self.assertEqual(posterior.device.type, device.type)
            self.assertTrue(posterior.dtype == dtype)
            self.assertEqual(posterior.event_shape, torch.Size([3, 1]))
            self.assertTrue(torch.equal(posterior.mean, mean.unsqueeze(-1)))
            variance_exp = degenerate_covar.diag().unsqueeze(-1)
            self.assertTrue(torch.equal(posterior.variance, variance_exp))

            # rsample
            with warnings.catch_warnings(record=True) as w:
                # we check that the p.d. warning is emitted - this only
                # happens once per posterior, so we need to check only once
                samples = posterior.rsample(sample_shape=torch.Size([4]))
                self.assertEqual(len(w), 1)
                self.assertTrue(issubclass(w[-1].category, RuntimeWarning))
                self.assertTrue("not p.d." in str(w[-1].message))
            self.assertEqual(samples.shape, torch.Size([4, 3, 1]))
            samples2 = posterior.rsample(sample_shape=torch.Size([4, 2]))
            self.assertEqual(samples2.shape, torch.Size([4, 2, 3, 1]))
            # rsample w/ base samples
            base_samples = torch.randn(4, 3, 1, device=device, dtype=dtype)
            samples_b1 = posterior.rsample(
                sample_shape=torch.Size([4]), base_samples=base_samples
            )
            samples_b2 = posterior.rsample(
                sample_shape=torch.Size([4]), base_samples=base_samples
            )
            self.assertTrue(torch.allclose(samples_b1, samples_b2))
            base_samples2 = torch.randn(4, 2, 3, 1, device=device, dtype=dtype)
            samples2_b1 = posterior.rsample(
                sample_shape=torch.Size([4, 2]), base_samples=base_samples2
            )
            samples2_b2 = posterior.rsample(
                sample_shape=torch.Size([4, 2]), base_samples=base_samples2
            )
            self.assertTrue(torch.allclose(samples2_b1, samples2_b2))
            # collapse_batch_dims
            b_mean = torch.rand(2, 3, dtype=dtype, device=device)
            b_degenerate_covar = degenerate_covar.expand(2, *degenerate_covar.shape)
            b_mvn = MultivariateNormal(b_mean, lazify(b_degenerate_covar))
            b_posterior = GPyTorchPosterior(mvn=b_mvn)
            b_base_samples = torch.randn(4, 2, 3, 1, device=device, dtype=dtype)
            with warnings.catch_warnings(record=True) as w:
                b_samples = b_posterior.rsample(
                    sample_shape=torch.Size([4]), base_samples=b_base_samples
                )
                self.assertEqual(len(w), 1)
                self.assertTrue(issubclass(w[-1].category, RuntimeWarning))
                self.assertTrue("not p.d." in str(w[-1].message))
            self.assertEqual(b_samples.shape, torch.Size([4, 2, 3, 1]))
Example 4
 def test_GPyTorchPosterior(self, cuda=False):
     device = torch.device("cuda") if cuda else torch.device("cpu")
     for dtype in (torch.float, torch.double):
         mean = torch.rand(3, dtype=dtype, device=device)
         variance = 1 + torch.rand(3, dtype=dtype, device=device)
         covar = variance.diag()
         mvn = MultivariateNormal(mean, lazify(covar))
         posterior = GPyTorchPosterior(mvn=mvn)
         # basics
         self.assertEqual(posterior.device.type, device.type)
         self.assertTrue(posterior.dtype == dtype)
         self.assertEqual(posterior.event_shape, torch.Size([3, 1]))
         self.assertTrue(torch.equal(posterior.mean, mean.unsqueeze(-1)))
         self.assertTrue(torch.equal(posterior.variance, variance.unsqueeze(-1)))
         # rsample
         samples = posterior.rsample()
         self.assertEqual(samples.shape, torch.Size([1, 3, 1]))
         samples = posterior.rsample(sample_shape=torch.Size([4]))
         self.assertEqual(samples.shape, torch.Size([4, 3, 1]))
         samples2 = posterior.rsample(sample_shape=torch.Size([4, 2]))
         self.assertEqual(samples2.shape, torch.Size([4, 2, 3, 1]))
         # rsample w/ base samples
         base_samples = torch.randn(4, 3, 1, device=device, dtype=dtype)
         # incompatible shapes
         with self.assertRaises(RuntimeError):
             posterior.rsample(
                 sample_shape=torch.Size([3]), base_samples=base_samples
             )
         samples_b1 = posterior.rsample(
             sample_shape=torch.Size([4]), base_samples=base_samples
         )
         samples_b2 = posterior.rsample(
             sample_shape=torch.Size([4]), base_samples=base_samples
         )
         self.assertTrue(torch.allclose(samples_b1, samples_b2))
         base_samples2 = torch.randn(4, 2, 3, 1, device=device, dtype=dtype)
         samples2_b1 = posterior.rsample(
             sample_shape=torch.Size([4, 2]), base_samples=base_samples2
         )
         samples2_b2 = posterior.rsample(
             sample_shape=torch.Size([4, 2]), base_samples=base_samples2
         )
         self.assertTrue(torch.allclose(samples2_b1, samples2_b2))
         # collapse_batch_dims
         b_mean = torch.rand(2, 3, dtype=dtype, device=device)
         b_variance = 1 + torch.rand(2, 3, dtype=dtype, device=device)
         b_covar = b_variance.unsqueeze(-1) * torch.eye(3).type_as(b_variance)
         b_mvn = MultivariateNormal(b_mean, lazify(b_covar))
         b_posterior = GPyTorchPosterior(mvn=b_mvn)
         b_base_samples = torch.randn(4, 1, 3, 1, device=device, dtype=dtype)
         b_samples = b_posterior.rsample(
             sample_shape=torch.Size([4]), base_samples=b_base_samples
         )
         self.assertEqual(b_samples.shape, torch.Size([4, 2, 3, 1]))
Example 5
def test_forget_mult():
    this_tests(forget_mult_CPU)
    x,f = torch.randn(5,3,20).chunk(2, dim=2)
    for (bf, bw) in [(True,True), (False,True), (True,False), (False,False)]:
        th_out = manual_forget_mult(x, f, batch_first=bf, backward=bw)
        out = forget_mult_CPU(x, f, batch_first=bf, backward=bw)
        assert torch.allclose(th_out,out)
        h = torch.randn((5 if bf else 3), 10)
        th_out = manual_forget_mult(x, f, h=h, batch_first=bf, backward=bw)
        out = forget_mult_CPU(x, f, hidden_init=h, batch_first=bf, backward=bw)
        assert torch.allclose(th_out,out)
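Neither manual_forget_mult nor forget_mult_CPU appears in this excerpt; both implement the QRNN forget-mult recurrence h_t = f_t*x_t + (1 - f_t)*h_{t-1}. A minimal batch-first reference under that assumption (the backward flag would simply run the loop in reverse):

import torch

def forget_mult_ref(x, f, h=None):
    # x, f: [bs, seq, nh] (batch_first); h: optional [bs, nh] initial state
    prev = torch.zeros_like(x[:, 0]) if h is None else h
    out = []
    for t in range(x.size(1)):
        prev = f[:, t] * x[:, t] + (1 - f[:, t]) * prev  # h_t = f_t*x_t + (1-f_t)*h_{t-1}
        out.append(prev)
    return torch.stack(out, dim=1)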
Example 6
def test_qrnn_bidir():
    this_tests(QRNN)
    qrnn = QRNN(10, 20, 2, bidirectional=True, batch_first=True, window=2, output_gate=False)
    x = torch.randn(7,5,10)
    y,h = qrnn(x)
    assert y.size() == torch.Size([7, 5, 40])
    assert h.size() == torch.Size([4, 7, 20])
    #Without an output gate, the last timestep of the forward output is the second-to-last hidden state
    #and the first timestep of the backward output is the last hidden state
    assert torch.allclose(y[:,-1,:20], h[2])
    assert torch.allclose(y[:,0,20:], h[3])
Example 7
def test_qrnn_layer():
    this_tests(QRNNLayer)
    qrnn_fwd = QRNNLayer(10, 20, save_prev_x=True, zoneout=0, window=2, output_gate=True)
    qrnn_bwd = QRNNLayer(10, 20, save_prev_x=True, zoneout=0, window=2, output_gate=True, backward=True)
    qrnn_bwd.load_state_dict(qrnn_fwd.state_dict())
    x_fwd = torch.randn(7,5,10)
    x_bwd = x_fwd.clone().flip(1)
    y_fwd,h_fwd = qrnn_fwd(x_fwd)
    y_bwd,h_bwd = qrnn_bwd(x_bwd)
    assert torch.allclose(y_fwd, y_bwd.flip(1), rtol=1e-4, atol=1e-5)
    assert torch.allclose(h_fwd, h_bwd, rtol=1e-4, atol=1e-5)
    y_fwd,h_fwd = qrnn_fwd(x_fwd, h_fwd)
    y_bwd,h_bwd = qrnn_bwd(x_bwd, h_bwd)
    assert torch.allclose(y_fwd, y_bwd.flip(1), rtol=1e-4, atol=1e-5)
    assert torch.allclose(h_fwd, h_bwd, rtol=1e-4, atol=1e-5)
Example 8
    def execute(self, *args):
        desc, arguments, expect = self.get_io(*args)

        outputs = type(self).func(**arguments)
        if not isinstance(outputs, tuple):
            outputs = (outputs,)

        msg = 'Got {} outputs while expecting {}'
        self.assertEqual(len(outputs), len(expect),
                         msg.format(len(outputs), len(expect)))

        passed_all = True
        msg = ('\nExpected output #{}:\n{}\n'
               '!#!-Got:\n{}\n'
               '!#!-min_abs_err    = {}\n'
               '!#!-mean_abs_err   = {}\n'
               '!#!-max_abs_err    = {}\n'
               '!#!-normalized_err = {}\n')
        for i, (o, e) in enumerate(zip(outputs, expect)):
            cond = o.size() == e.size() and torch.allclose(o, e, rtol=1e-1)
            if not cond:
                passed_all = False
                err = (e.data - o.data)
                abs_err = err.abs()
                desc += msg.format(
                    i, e.data.cpu().numpy(), o.data.cpu().numpy(),
                    abs_err.min().item(),
                    abs_err.mean().item(),
                    abs_err.max().item(),
                    (err.norm() / e.data.norm()).item(),
                )
            else:
                desc += '\nOutput #{} passed'.format(i)
        self.assertTrue(passed_all, desc)
        return tuple(zip(expect, outputs))
Example 9
 def test_forward(self):
     s = sqrtm(self.x)
     y = s.mm(s)
     self.assertTrue(
         torch.allclose(y, self.x),
         ((self.x - y).norm() / self.x.norm()).item()
     )
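sqrtm itself is defined elsewhere; assuming self.x is symmetric positive semi-definite, a matrix square root can be sketched via eigendecomposition (the tested implementation may use a different algorithm, e.g. Newton iterations):

import torch

def sqrtm_sym(x: torch.Tensor) -> torch.Tensor:
    # x = V diag(w) V^T  =>  sqrt(x) = V diag(sqrt(w)) V^T
    w, v = torch.linalg.eigh(x)
    return (v * w.clamp_min(0).sqrt()) @ v.transpose(-1, -2)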
Example 10
 def test_constrained_expected_improvement_batch(self, cuda=False):
     device = torch.device("cuda") if cuda else torch.device("cpu")
     for dtype in (torch.float, torch.double):
         mean = torch.tensor(
             [[-0.5, 0.0, 5.0, 0.0], [0.0, 0.0, 5.0, 0.0], [0.5, 0.0, 5.0, 0.0]],
             device=device,
             dtype=dtype,
         ).unsqueeze(dim=-2)
         variance = torch.ones(3, 4, device=device, dtype=dtype).unsqueeze(dim=-2)
         N = torch.distributions.Normal(loc=0.0, scale=1.0)
         a = N.icdf(torch.tensor(0.75))  # get a so that P(-a <= N <= a) = 0.5
         mm = MockModel(MockPosterior(mean=mean, variance=variance))
         module = ConstrainedExpectedImprovement(
             model=mm,
             best_f=0.0,
             objective_index=0,
             constraints={1: [None, 0], 2: [5.0, None], 3: [-a, a]},
         )
         X = torch.empty(3, 1, 1, device=device, dtype=dtype)  # dummy
         ei = module(X)
         ei_expected_unconstrained = torch.tensor(
             [0.19780, 0.39894, 0.69780], device=device, dtype=dtype
         )
          ei_expected = ei_expected_unconstrained * 0.5 * 0.5 * 0.5  # each constraint is satisfied with probability 0.5
         self.assertTrue(torch.allclose(ei, ei_expected, atol=1e-4))
Example 11
 def test_diagonal_covariance(self):
     single_cov = [covariance(c, self.A, True) for c in self.batch_var]
     batch_cov = covariance(self.batch_var, self.A, True)
     close = [torch.allclose(r, c)
              for r, c in zip(batch_cov, single_cov)]
     self.assertNotIn(False, close)
     return single_cov, batch_cov
Example 12
 def test_variance_batch_mean_and_variance(self):
     single_var = [variance(c, v)
                   for c, v in zip(self.batch_mean, self.batch_var)]
     batch_var = variance(self.batch_mean, self.batch_var)
     close = [torch.allclose(r, c) for r, c in zip(batch_var, single_var)]
     self.assertNotIn(False, close)
     return single_var, batch_var
Example 13
 def test_GPyTorchPosterior_Multitask(self, cuda=False):
     device = torch.device("cuda") if cuda else torch.device("cpu")
     for dtype in (torch.float, torch.double):
         mean = torch.rand(3, 2, dtype=dtype, device=device)
         variance = 1 + torch.rand(3, 2, dtype=dtype, device=device)
         covar = variance.view(-1).diag()
         mvn = MultitaskMultivariateNormal(mean, lazify(covar))
         posterior = GPyTorchPosterior(mvn=mvn)
         # basics
         self.assertEqual(posterior.device.type, device.type)
         self.assertTrue(posterior.dtype == dtype)
         self.assertEqual(posterior.event_shape, torch.Size([3, 2]))
         self.assertTrue(torch.equal(posterior.mean, mean))
         self.assertTrue(torch.equal(posterior.variance, variance))
         # rsample
         samples = posterior.rsample(sample_shape=torch.Size([4]))
         self.assertEqual(samples.shape, torch.Size([4, 3, 2]))
         samples2 = posterior.rsample(sample_shape=torch.Size([4, 2]))
         self.assertEqual(samples2.shape, torch.Size([4, 2, 3, 2]))
         # rsample w/ base samples
         base_samples = torch.randn(4, 3, 2, device=device, dtype=dtype)
         samples_b1 = posterior.rsample(
             sample_shape=torch.Size([4]), base_samples=base_samples
         )
         samples_b2 = posterior.rsample(
             sample_shape=torch.Size([4]), base_samples=base_samples
         )
         self.assertTrue(torch.allclose(samples_b1, samples_b2))
         base_samples2 = torch.randn(4, 2, 3, 2, device=device, dtype=dtype)
         samples2_b1 = posterior.rsample(
             sample_shape=torch.Size([4, 2]), base_samples=base_samples2
         )
         samples2_b2 = posterior.rsample(
             sample_shape=torch.Size([4, 2]), base_samples=base_samples2
         )
         self.assertTrue(torch.allclose(samples2_b1, samples2_b2))
         # collapse_batch_dims
         b_mean = torch.rand(2, 3, 2, dtype=dtype, device=device)
         b_variance = 1 + torch.rand(2, 3, 2, dtype=dtype, device=device)
         b_covar = b_variance.view(2, 6, 1) * torch.eye(6).type_as(b_variance)
         b_mvn = MultitaskMultivariateNormal(b_mean, lazify(b_covar))
         b_posterior = GPyTorchPosterior(mvn=b_mvn)
         b_base_samples = torch.randn(4, 1, 3, 2, device=device, dtype=dtype)
         b_samples = b_posterior.rsample(
             sample_shape=torch.Size([4]), base_samples=b_base_samples
         )
         self.assertEqual(b_samples.shape, torch.Size([4, 2, 3, 2]))
Example 14
    def test_forward_no_collapse(self, cuda=False):
        for dtype in (torch.float, torch.double):

            # no resample
            sampler = SobolQMCNormalSampler(
                num_samples=4, seed=1234, collapse_batch_dims=False
            )
            self.assertFalse(sampler.resample)
            self.assertEqual(sampler.seed, 1234)
            self.assertFalse(sampler.collapse_batch_dims)
            # check samples non-batched
            posterior = _get_posterior(cuda=cuda, dtype=dtype)
            samples = sampler(posterior)
            self.assertEqual(samples.shape, torch.Size([4, 2, 1]))
            self.assertEqual(sampler.seed, 1235)
            # ensure samples are the same
            samples2 = sampler(posterior)
            self.assertTrue(torch.allclose(samples, samples2))
            self.assertEqual(sampler.seed, 1235)
            # ensure this works with a differently shaped posterior
            posterior_batched = _get_posterior_batched(cuda=cuda, dtype=dtype)
            samples_batched = sampler(posterior_batched)
            self.assertEqual(samples_batched.shape, torch.Size([4, 3, 2, 1]))
            self.assertEqual(sampler.seed, 1236)

            # resample
            sampler = SobolQMCNormalSampler(
                num_samples=4, resample=True, collapse_batch_dims=False
            )
            self.assertTrue(sampler.resample)
            self.assertFalse(sampler.collapse_batch_dims)
            initial_seed = sampler.seed
            # check samples non-batched
            posterior = _get_posterior(cuda=cuda, dtype=dtype)
            samples = sampler(posterior=posterior)
            self.assertEqual(samples.shape, torch.Size([4, 2, 1]))
            self.assertEqual(sampler.seed, initial_seed + 1)
            # ensure samples are not the same
            samples2 = sampler(posterior)
            self.assertFalse(torch.allclose(samples, samples2))
            self.assertEqual(sampler.seed, initial_seed + 2)
            # ensure this works with a differently shaped posterior
            posterior_batched = _get_posterior_batched(cuda=cuda, dtype=dtype)
            samples_batched = sampler(posterior_batched)
            self.assertEqual(samples_batched.shape, torch.Size([4, 3, 2, 1]))
            self.assertEqual(sampler.seed, initial_seed + 3)
Example 15
    def test_MultivariateNormalQMCEngineSeededInvTransform(self, cuda=False):
        device = torch.device("cuda") if cuda else torch.device("cpu")
        for dtype in (torch.float, torch.double):
            # test even dimension
            with manual_seed(54321):
                a = torch.randn(2, 2)
                cov = a @ a.transpose(-1, -2) + torch.rand(2).diag()

            mean = torch.zeros(2, device=device, dtype=dtype)
            cov = cov.to(device=device, dtype=dtype)
            engine = MultivariateNormalQMCEngine(
                mean=mean, cov=cov, seed=12345, inv_transform=True
            )
            samples = engine.draw(n=2)
            self.assertEqual(samples.dtype, dtype)
            self.assertEqual(samples.device.type, device.type)
            samples_expected = torch.tensor(
                [[-0.560064316, 0.629113674], [-1.292604208, -0.048077226]],
                device=device,
                dtype=dtype,
            )
            self.assertTrue(torch.allclose(samples, samples_expected))

            # test odd dimension
            with manual_seed(54321):
                a = torch.randn(3, 3)
                cov = a @ a.transpose(-1, -2) + torch.rand(3).diag()

            mean = torch.zeros(3, device=device, dtype=dtype)
            cov = cov.to(device=device, dtype=dtype)
            engine = MultivariateNormalQMCEngine(
                mean=mean, cov=cov, seed=12345, inv_transform=True
            )
            samples = engine.draw(n=2)
            self.assertEqual(samples.dtype, dtype)
            self.assertEqual(samples.device.type, device.type)
            samples_expected = torch.tensor(
                [
                    [-2.388370037, 3.071142435, -0.319439292],
                    [-0.282978594, -4.350236893, -1.085214734],
                ],
                device=device,
                dtype=dtype,
            )
            self.assertTrue(torch.allclose(samples, samples_expected))
Example 16
def test_forget_mult_cuda():
    this_tests(ForgetMultGPU, BwdForgetMultGPU)
    x,f = torch.randn(5,3,20).cuda().chunk(2, dim=2)
    x,f = x.contiguous().requires_grad_(True),f.contiguous().requires_grad_(True)
    th_x,th_f = detach_and_clone(x),detach_and_clone(f)
    for (bf, bw) in [(True,True), (False,True), (True,False), (False,False)]:
        forget_mult = BwdForgetMultGPU if bw else ForgetMultGPU
        th_out = forget_mult_CPU(th_x, th_f, hidden_init=None, batch_first=bf, backward=bw)
        th_loss = th_out.pow(2).mean()
        th_loss.backward()
        out = forget_mult.apply(x, f, None, bf)
        loss = out.pow(2).mean()
        loss.backward()
        assert torch.allclose(th_out,out, rtol=1e-4, atol=1e-5)
        assert torch.allclose(th_x.grad,x.grad, rtol=1e-4, atol=1e-5)
        assert torch.allclose(th_f.grad,f.grad, rtol=1e-4, atol=1e-5)
        for p in [x, f, th_x, th_f]:
            p.grad = None  # reset accumulated grads on the original tensors before the next check
        h = torch.randn((5 if bf else 3), 10).cuda().requires_grad_(True)
        th_h = detach_and_clone(h)
        th_out = forget_mult_CPU(th_x, th_f, hidden_init=th_h, batch_first=bf, backward=bw)
        th_loss = th_out.pow(2).mean()
        th_loss.backward()
        out = forget_mult.apply(x.contiguous(), f.contiguous(), h, bf)
        loss = out.pow(2).mean()
        loss.backward()
        assert torch.allclose(th_out,out, rtol=1e-4, atol=1e-5)
        assert torch.allclose(th_x.grad,x.grad, rtol=1e-4, atol=1e-5)
        assert torch.allclose(th_f.grad,f.grad, rtol=1e-4, atol=1e-5)
        assert torch.allclose(th_h.grad,h.grad, rtol=1e-4, atol=1e-5)
        for p in [x, f, th_x, th_f]:
            p.grad = None  # reset accumulated grads on the original tensors before the next check
Example 17
    def test_expected_improvement(self, cuda=False):
        device = torch.device("cuda") if cuda else torch.device("cpu")
        for dtype in (torch.float, torch.double):
            mean = torch.tensor([[-0.5]], device=device, dtype=dtype)
            variance = torch.ones(1, 1, device=device, dtype=dtype)
            mm = MockModel(MockPosterior(mean=mean, variance=variance))

            module = ExpectedImprovement(model=mm, best_f=0.0)
            X = torch.empty(1, 1, device=device, dtype=dtype)  # dummy
            ei = module(X)
            ei_expected = torch.tensor(0.19780, device=device, dtype=dtype)
            self.assertTrue(torch.allclose(ei, ei_expected, atol=1e-4))

            module = ExpectedImprovement(model=mm, best_f=0.0, maximize=False)
            X = torch.empty(1, 1, device=device, dtype=dtype)  # dummy
            ei = module(X)
            ei_expected = torch.tensor(0.6978, device=device, dtype=dtype)
            self.assertTrue(torch.allclose(ei, ei_expected, atol=1e-4))
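The expected values follow from the closed-form EI under a Gaussian posterior, EI = sigma * (u*Phi(u) + phi(u)) with u = (mean - best_f) / sigma (sign flipped when minimizing); a quick check with sigma = 1:

import torch
from torch.distributions import Normal

N = Normal(0.0, 1.0)
u = torch.tensor(-0.5)                          # (mean - best_f) / sigma
ei_max = u * N.cdf(u) + N.log_prob(u).exp()     # ~0.19780
ei_min = -u * N.cdf(-u) + N.log_prob(u).exp()   # ~0.6978, the maximize=False case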
Example 18
def test_forget_mult_forward_gpu():
    this_tests(ForgetMultGPU)
    dtype = torch.double
    x,f,h,expected = range(3,8),[0.5]*5,1,range(1,7)
    x,f,expected = [torch.tensor(t, dtype=dtype)[None,:,None].cuda() for t in (x,f,expected)]
    #output = torch.zeros(1,6,1, dtype=dtype).cuda()
    #output[0,0,0] = h
    output = ForgetMultGPU.apply(x, f, h, True)
    assert torch.allclose(output,expected[:,1:])
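As a sanity check on the expectation: with x = 3..7, f = 0.5, and h = 1, the recurrence h_t = f*x_t + (1 - f)*h_{t-1} yields 2, 3, 4, 5, 6, which is exactly expected[:, 1:] for expected = range(1, 7).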
Example 19
 def test_NormalQMCEngineSeededInvTransform(self):
     # test even dimension
     engine = NormalQMCEngine(d=2, seed=12345, inv_transform=True)
     samples = engine.draw(n=2)
     self.assertEqual(samples.dtype, torch.float)
     samples_expected = torch.tensor(
         [[-0.41622922, 0.46622792], [-0.96063897, -0.75568963]]
     )
     self.assertTrue(torch.allclose(samples, samples_expected))
     # test odd dimension
     engine = NormalQMCEngine(d=3, seed=12345, inv_transform=True)
     samples = engine.draw(n=2)
     samples_expected = torch.tensor(
         [
             [-1.40525266, 1.37652443, -0.8519666],
             [-0.166497, -2.3153681, -0.15975676],
         ]
     )
     self.assertTrue(torch.allclose(samples, samples_expected))
Example 20
 def test_NormalQMCEngineSeeded(self):
     # test even dimension
     engine = NormalQMCEngine(d=2, seed=12345)
     samples = engine.draw(n=2)
     self.assertEqual(samples.dtype, torch.float)
     samples_expected = torch.tensor(
         [[-0.63099602, -1.32950772], [0.29625805, 1.86425618]]
     )
     self.assertTrue(torch.allclose(samples, samples_expected))
     # test odd dimension
     engine = NormalQMCEngine(d=3, seed=12345)
     samples = engine.draw(n=2)
     samples_expected = torch.tensor(
         [
             [1.83169884, -1.40473647, 0.24334828],
             [0.36596099, 1.2987395, -1.47556275],
         ]
     )
     self.assertTrue(torch.allclose(samples, samples_expected))
Example 21
    def test_MultivariateNormalQMCEngineSeeded(self, cuda=False):
        device = torch.device("cuda") if cuda else torch.device("cpu")
        for dtype in (torch.float, torch.double):

            # test even dimension
            with manual_seed(54321):
                a = torch.randn(2, 2)
                cov = a @ a.transpose(-1, -2) + torch.rand(2).diag()

            mean = torch.zeros(2, device=device, dtype=dtype)
            cov = cov.to(device=device, dtype=dtype)
            engine = MultivariateNormalQMCEngine(mean=mean, cov=cov, seed=12345)
            samples = engine.draw(n=2)
            self.assertEqual(samples.dtype, dtype)
            self.assertEqual(samples.device.type, device.type)
            samples_expected = torch.tensor(
                [[-0.849047422, -0.713852942], [0.398635030, 1.350660801]],
                device=device,
                dtype=dtype,
            )
            self.assertTrue(torch.allclose(samples, samples_expected))

            # test odd dimension
            with manual_seed(54321):
                a = torch.randn(3, 3)
                cov = a @ a.transpose(-1, -2) + torch.rand(3).diag()

            mean = torch.zeros(3, device=device, dtype=dtype)
            cov = cov.to(device=device, dtype=dtype)
            engine = MultivariateNormalQMCEngine(mean, cov, seed=12345)
            samples = engine.draw(n=2)
            self.assertEqual(samples.dtype, dtype)
            self.assertEqual(samples.device.type, device.type)
            samples_expected = torch.tensor(
                [
                    [3.113158941, -3.262257099, -0.819938779],
                    [0.621987879, 2.352285624, -1.992680788],
                ],
                device=device,
                dtype=dtype,
            )
            self.assertTrue(torch.allclose(samples, samples_expected))
Example 22
 def test_NormalQMCEngineSeededOut(self):
     # test even dimension
     engine = NormalQMCEngine(d=2, seed=12345)
     out = torch.empty(2, 2)
     self.assertIsNone(engine.draw(n=2, out=out))
     samples_expected = torch.tensor(
         [[-0.63099602, -1.32950772], [0.29625805, 1.86425618]]
     )
     self.assertTrue(torch.allclose(out, samples_expected))
     # test odd dimension
     engine = NormalQMCEngine(d=3, seed=12345)
     out = torch.empty(2, 3)
     self.assertIsNone(engine.draw(n=2, out=out))
     samples_expected = torch.tensor(
         [
             [1.83169884, -1.40473647, 0.24334828],
             [0.36596099, 1.2987395, -1.47556275],
         ]
     )
     self.assertTrue(torch.allclose(out, samples_expected))
Example 23
 def test_linear_mc_objective(self, cuda=False):
     device = torch.device("cuda") if cuda else torch.device("cpu")
     for dtype in (torch.float, torch.double):
         weights = torch.rand(3, device=device, dtype=dtype)
         obj = LinearMCObjective(weights=weights)
         samples = torch.randn(4, 2, 3, device=device, dtype=dtype)
         self.assertTrue(
             torch.allclose(obj(samples), (samples * weights).sum(dim=-1))
         )
         samples = torch.randn(5, 4, 2, 3, device=device, dtype=dtype)
         self.assertTrue(
             torch.allclose(obj(samples), (samples * weights).sum(dim=-1))
         )
         # make sure this errors if sample output dimensions are incompatible
         with self.assertRaises(RuntimeError):
             obj(samples=torch.randn(2, device=device, dtype=dtype))
         with self.assertRaises(RuntimeError):
             obj(samples=torch.randn(1, device=device, dtype=dtype))
         # make sure we can't construct objectives with multi-dim. weights
         with self.assertRaises(ValueError):
             LinearMCObjective(weights=torch.rand(2, 3, device=device, dtype=dtype))
         with self.assertRaises(ValueError):
             LinearMCObjective(weights=torch.tensor(1.0, device=device, dtype=dtype))
Example 24
 def test_expected_improvement_batch(self, cuda=False):
     device = torch.device("cuda") if cuda else torch.device("cpu")
     for dtype in (torch.float, torch.double):
         mean = torch.tensor([-0.5, 0.0, 0.5], device=device, dtype=dtype).view(
             3, 1, 1
         )
         variance = torch.ones(3, 1, 1, device=device, dtype=dtype)
         mm = MockModel(MockPosterior(mean=mean, variance=variance))
         module = ExpectedImprovement(model=mm, best_f=0.0)
         X = torch.empty(3, 1, 1, device=device, dtype=dtype)  # dummy
         ei = module(X)
         ei_expected = torch.tensor(
             [0.19780, 0.39894, 0.69780], device=device, dtype=dtype
         )
         self.assertTrue(torch.allclose(ei, ei_expected, atol=1e-4))
         # check for proper error if multi-output model
         mean2 = torch.rand(3, 1, 2, device=device, dtype=dtype)
         variance2 = torch.rand(3, 1, 2, device=device, dtype=dtype)
         mm2 = MockModel(MockPosterior(mean=mean2, variance=variance2))
         module2 = ExpectedImprovement(model=mm2, best_f=0.0)
         with self.assertRaises(UnsupportedError):
             module2(X)
Example 25
    def __init__(
        self,
        mean: Tensor,
        cov: Tensor,
        seed: Optional[int] = None,
        inv_transform: bool = False,
    ) -> None:
        r"""Engine for qMC sampling from a multivariate Normal `N(\mu, \Sigma)`.

        Args:
            mean: The mean vector.
            cov: The covariance matrix.
            seed: The seed with which to seed the random number generator of the
                underlying SobolEngine.
            inv_transform: If True, use inverse transform instead of Box-Muller.
        """
        # validate inputs
        if not cov.shape[0] == cov.shape[1]:
            raise ValueError("Covariance matrix is not square.")
        if not mean.shape[0] == cov.shape[0]:
            raise ValueError("Dimension mismatch between mean and covariance.")
        if not torch.allclose(cov, cov.transpose(-1, -2)):
            raise ValueError("Covariance matrix is not symmetric.")
        self._mean = mean
        self._normal_engine = NormalQMCEngine(
            d=mean.shape[0], seed=seed, inv_transform=inv_transform
        )
        # compute Cholesky decomp; if it fails, do the eigendecomposition
        try:
            self._corr_matrix = torch.cholesky(cov).transpose(-1, -2)
        except RuntimeError:
            eigval, eigvec = torch.symeig(cov, eigenvectors=True)
            if not torch.all(eigval >= -1e-8):
                raise ValueError("Covariance matrix not PSD.")
            eigval_root = eigval.clamp_min(0.0).sqrt()
            self._corr_matrix = (eigvec * eigval_root).transpose(-1, -2)
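A minimal usage sketch for the engine defined above, mirroring the seeded tests earlier in this section (the covariance here is an arbitrary symmetric positive-definite example):

import torch

mean = torch.zeros(2)
cov = torch.tensor([[1.0, 0.5], [0.5, 1.0]])
engine = MultivariateNormalQMCEngine(mean=mean, cov=cov, seed=12345)
samples = engine.draw(n=256)  # low-discrepancy draws from N(mean, cov)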
Example 26
        def _test_librosa_consistency_helper(n_fft, hop_length, power, n_mels,
                                             n_mfcc, sample_rate):
            input_path = os.path.join(self.test_dirpath, 'assets',
                                      'sinewave.wav')
            sound, sample_rate = torchaudio.load(input_path)
            sound_librosa = sound.cpu().numpy().squeeze()  # (64000)

            # test core spectrogram
            spect_transform = torchaudio.transforms.Spectrogram(
                n_fft=n_fft, hop_length=hop_length, power=2)
            out_librosa, _ = librosa.core.spectrum._spectrogram(
                y=sound_librosa, n_fft=n_fft, hop_length=hop_length, power=2)

            out_torch = spect_transform(sound).squeeze().cpu()
            self.assertTrue(
                torch.allclose(out_torch,
                               torch.from_numpy(out_librosa),
                               atol=1e-5))

            # test mel spectrogram
            melspect_transform = torchaudio.transforms.MelSpectrogram(
                sample_rate=sample_rate,
                window_fn=torch.hann_window,
                hop_length=hop_length,
                n_mels=n_mels,
                n_fft=n_fft)
            librosa_mel = librosa.feature.melspectrogram(y=sound_librosa,
                                                         sr=sample_rate,
                                                         n_fft=n_fft,
                                                         hop_length=hop_length,
                                                         n_mels=n_mels,
                                                         htk=True,
                                                         norm=None)
            librosa_mel_tensor = torch.from_numpy(librosa_mel)
            torch_mel = melspect_transform(sound).squeeze().cpu()

            self.assertTrue(
                torch.allclose(torch_mel.type(librosa_mel_tensor.dtype),
                               librosa_mel_tensor,
                               atol=5e-3))

            # test s2db
            db_transform = torchaudio.transforms.AmplitudeToDB('power', 80.)
            db_torch = db_transform(spect_transform(sound)).squeeze().cpu()
            db_librosa = librosa.core.spectrum.power_to_db(out_librosa)
            self.assertTrue(
                torch.allclose(db_torch,
                               torch.from_numpy(db_librosa),
                               atol=5e-3))

            db_torch = db_transform(melspect_transform(sound)).squeeze().cpu()
            db_librosa = librosa.core.spectrum.power_to_db(librosa_mel)
            db_librosa_tensor = torch.from_numpy(db_librosa)

            self.assertTrue(
                torch.allclose(db_torch.type(db_librosa_tensor.dtype),
                               db_librosa_tensor,
                               atol=5e-3))

            # test MFCC
            melkwargs = {'hop_length': hop_length, 'n_fft': n_fft}
            mfcc_transform = torchaudio.transforms.MFCC(
                sample_rate=sample_rate,
                n_mfcc=n_mfcc,
                norm='ortho',
                melkwargs=melkwargs)

            # librosa.feature.mfcc doesn't pass kwargs properly since some of the
            # kwargs for melspectrogram and mfcc are the same. We just follow the
            # function body in https://librosa.github.io/librosa/_modules/librosa/feature/spectral.html#melspectrogram
            # to mirror this function call with correct args:

            #         librosa_mfcc = librosa.feature.mfcc(y=sound_librosa,
            #                                             sr=sample_rate,
            #                                             n_mfcc = n_mfcc,
            #                                             hop_length=hop_length,
            #                                             n_fft=n_fft,
            #                                             htk=True,
            #                                             norm=None,
            #                                             n_mels=n_mels)

            librosa_mfcc = scipy.fftpack.dct(db_librosa,
                                             axis=0,
                                             type=2,
                                             norm='ortho')[:n_mfcc]
            librosa_mfcc_tensor = torch.from_numpy(librosa_mfcc)
            torch_mfcc = mfcc_transform(sound).squeeze().cpu()

            self.assertTrue(
                torch.allclose(torch_mfcc.type(librosa_mfcc_tensor.dtype),
                               librosa_mfcc_tensor,
                               atol=5e-3))
Example 27
def test_to_hetero_with_bases_and_rgcn_equal_output():
    torch.manual_seed(1234)

    # Run `RGCN` with basis decomposition:
    x = torch.randn(10, 16)  # 6 paper nodes, 4 author nodes
    adj = (torch.rand(10, 10) > 0.5)
    adj[6:, 6:] = False
    edge_index = adj.nonzero(as_tuple=False).t().contiguous()
    row, col = edge_index

    # 0 = paper<->paper, 1 = author->paper, 2 = paper->author
    edge_type = torch.full((edge_index.size(1), ), -1, dtype=torch.long)
    edge_type[(row < 6) & (col < 6)] = 0
    edge_type[(row < 6) & (col >= 6)] = 1
    edge_type[(row >= 6) & (col < 6)] = 2
    assert edge_type.min() == 0

    num_bases = 4
    conv = RGCNConv(16, 32, num_relations=3, num_bases=num_bases, aggr='add')
    out1 = conv(x, edge_index, edge_type)

    # Run `to_hetero_with_bases`:
    x_dict = {
        'paper': x[:6],
        'author': x[6:],
    }
    edge_index_dict = {
        ('paper', '_', 'paper'):
        edge_index[:, edge_type == 0],
        ('paper', '_', 'author'):
        edge_index[:, edge_type == 1] - torch.tensor([[0], [6]]),
        ('author', '_', 'paper'):
        edge_index[:, edge_type == 2] - torch.tensor([[6], [0]]),
    }

    adj_t_dict = {
        key: SparseTensor.from_edge_index(edge_index).t()
        for key, edge_index in edge_index_dict.items()
    }

    metadata = (list(x_dict.keys()), list(edge_index_dict.keys()))
    model = to_hetero_with_bases(RGCN(16, 32),
                                 metadata,
                                 num_bases=num_bases,
                                 debug=False)

    # Set model weights:
    for i in range(num_bases):
        model.conv.convs[i].lin.weight.data = conv.weight[i].data.t()
        model.conv.convs[i].edge_type_weight.data = conv.comp[:, i].data.t()

    model.lin.weight.data = conv.root.data.t()
    model.lin.bias.data = conv.bias.data

    out2 = model(x_dict, edge_index_dict)
    out2 = torch.cat([out2['paper'], out2['author']], dim=0)
    assert torch.allclose(out1, out2, atol=1e-6)

    out3 = model(x_dict, adj_t_dict)
    out3 = torch.cat([out3['paper'], out3['author']], dim=0)
    assert torch.allclose(out1, out3, atol=1e-6)
Example 28
    def test_load_with_different_shard_plan(self) -> None:
        path = self.get_file_path()

        # We hardcode the assumption of how many shards are around
        self.assertEqual(self.world_size, dist.get_world_size())

        specs = [
            # pyre-fixme [28]: Unexpected keyword argument `dim` to call `dist._sharding_spec.api.ChunkShardingSpec.__init__`.
            ChunkShardingSpec(
                dim=0,
                placements=[
                    "rank:0/cuda:0",
                    "rank:1/cuda:1",
                ],
            ),
            # pyre-fixme [28]: Unexpected keyword argument `dim` to call `dist._sharding_spec.api.ChunkShardingSpec.__init__`.
            ChunkShardingSpec(
                dim=0,
                placements=[
                    "rank:0/cuda:0",
                    "rank:1/cuda:1",
                    "rank:1/cuda:1",
                    "rank:0/cuda:0",
                ],
            ),
            # This requires the tensors to be [10, 20]
            EnumerableShardingSpec(shards=[
                ShardMetadata(
                    shard_offsets=[0, 0],
                    shard_sizes=[2, 20],
                    placement="rank:0/cuda:0",
                ),
                ShardMetadata(
                    shard_offsets=[2, 0],
                    shard_sizes=[1, 20],
                    placement="rank:1/cuda:1",
                ),
                ShardMetadata(
                    shard_offsets=[3, 0],
                    shard_sizes=[3, 20],
                    placement="rank:0/cuda:0",
                ),
                ShardMetadata(
                    shard_offsets=[6, 0],
                    shard_sizes=[3, 20],
                    placement="rank:1/cuda:1",
                ),
                ShardMetadata(
                    shard_offsets=[9, 0],
                    shard_sizes=[1, 20],
                    placement="rank:0/cuda:0",
                ),
            ]),
            # This requires the tensors to be [10, 20]
            EnumerableShardingSpec(shards=[
                ShardMetadata(
                    shard_offsets=[0, 0],
                    shard_sizes=[8, 20],
                    placement="rank:1/cuda:1",
                ),
                ShardMetadata(
                    shard_offsets=[8, 0],
                    shard_sizes=[2, 20],
                    placement="rank:0/cuda:0",
                ),
            ]),
        ]

        for s0 in specs:
            for s1 in specs:
                if s0 == s1:
                    continue

                if dist.get_rank() == 0:
                    shutil.rmtree(path, ignore_errors=True)
                    os.makedirs(path)
                dist.barrier()

                model_to_save = MyShardedModel3(s0)
                model_to_save._register_state_dict_hook(state_dict_hook)
                state_dict_to_save = model_to_save.state_dict()

                fs_writer = FileSystemWriter(path=path)
                save_state_dict(state_dict=state_dict_to_save,
                                storage_writer=fs_writer)

                dist.barrier()

                model_to_load = MyShardedModel3(s1)
                model_to_load._register_state_dict_hook(state_dict_hook)
                state_dict_to_load_to = model_to_load.state_dict()
                dist.barrier()

                fs_reader = FileSystemReader(path=path)
                load_state_dict(state_dict=state_dict_to_load_to,
                                storage_reader=fs_reader)

                dist.barrier()
                store_tensor = self.load_tensor(model_to_save.sharded_tensor)
                dist.barrier()
                load_tensor = self.load_tensor(model_to_load.sharded_tensor)

                if dist.get_rank() == 0:
                    self.assertTrue(torch.allclose(store_tensor, load_tensor),
                                    msg=f"{s0} vs {s1}")
Example 29
def test_precision_recall(pred, target, expected_prec, expected_rec):
    prec = precision(pred, target, reduction='none')
    rec = recall(pred, target, reduction='none')

    assert torch.allclose(torch.tensor(expected_prec).to(prec), prec)
    assert torch.allclose(torch.tensor(expected_rec).to(rec), rec)
Example 30
def reduce_tensor(formula, eps=1e-9, has_parity=None, **kw_Rs):
    """
    Usage
    Rs, Q = rs.reduce_tensor('ijkl=jikl=ikjl=ijlk', i=[(1, 1)])
    Rs = 0,2,4
    Q = tensor of shape [15, 81]
    """
    dtype = torch.get_default_dtype()
    with torch_default_dtype(torch.float64):
        # reformat `formulas` and make checks
        formulas = [
            (-1 if f.startswith('-') else 1, f.replace('-', ''))
            for f in formula.split('=')
        ]
        s0, f0 = formulas[0]
        assert s0 == 1

        for _s, f in formulas:
            if len(set(f)) != len(f) or set(f) != set(f0):
                raise RuntimeError(f'{f} is not a permutation of {f0}')
            if len(f0) != len(f):
                raise RuntimeError(f'{f0} and {f} don\'t have the same number of indices')

        # `formulas` is a list of (sign, permutation of indices)
        # each formula can be viewed as a permutation of the original formula
        formulas = {(s, tuple(f.index(i) for i in f0)) for s, f in formulas}  # set of generators (permutations)

        # they can be composed, for instance if you have ijk=jik=ikj
        # you also have ijk=jki
        # applying all possible compositions creates an entire group
        while True:
            n = len(formulas)
            formulas = formulas.union([(s, perm.inverse(p)) for s, p in formulas])
            formulas = formulas.union([
                (s1 * s2, perm.compose(p1, p2))
                for s1, p1 in formulas
                for s2, p2 in formulas
            ])
            if len(formulas) == n:
                break  # we break when the set is stable => it is now a group \o/

        # let's clean `kw_Rs` before checking that its entries are compatible with the formulas
        for i in kw_Rs:
            if not callable(kw_Rs[i]):
                Rs = o3.convention(kw_Rs[i])
                if has_parity is None:
                    has_parity = any(p != 0 for _, _, p in Rs)
                if not has_parity and not all(p == 0 for _, _, p in Rs):
                    raise RuntimeError(f'{o3.format_Rs(Rs)} parity has to be specified everywhere or nowhere')
                if has_parity and any(p == 0 for _, _, p in Rs):
                    raise RuntimeError(f'{o3.format_Rs(Rs)} parity has to be specified everywhere or nowhere')
                kw_Rs[i] = Rs

        if has_parity is None:
            raise RuntimeError(f'please specify the argument `has_parity`')

        group = O3() if has_parity else SO3()

        # here we check that each index has one and only one representation
        for _s, p in formulas:
            f = "".join(f0[i] for i in p)
            for i, j in zip(f0, f):
                if i in kw_Rs and j in kw_Rs and kw_Rs[i] != kw_Rs[j]:
                    raise RuntimeError(f'Rs of {i} (Rs={o3.format_Rs(kw_Rs[i])}) and {j} (Rs={o3.format_Rs(kw_Rs[j])}) should be the same')
                if i in kw_Rs:
                    kw_Rs[j] = kw_Rs[i]
                if j in kw_Rs:
                    kw_Rs[i] = kw_Rs[j]

        for i in f0:
            if i not in kw_Rs:
                raise RuntimeError(f'index {i} has no Rs associated to it')

        ide = group.identity()
        dims = {i: len(kw_Rs[i](*ide)) if callable(kw_Rs[i]) else o3.dim(kw_Rs[i]) for i in f0}  # dimension of each index
        full_base = list(itertools.product(*(range(dims[i]) for i in f0)))  # (0, 0, 0), (0, 0, 1), (0, 0, 2), ... (3, 3, 3)
        # len(full_base) degrees of freedom in an unconstrained tensor

        # but there is constraints given by the group `formulas`
        # For instance if `ij=-ji`, then 00=-00, 01=-01 and so on
        base = set()
        for x in full_base:
            # T[x] is a coefficient of the tensor T and is related to other coefficient T[y]
            # if x and y are related by a formula
            xs = {(s, tuple(x[i] for i in p)) for s, p in formulas}
            # s * T[x] are all equal for all (s, x) in xs
            # if T[x] = -T[x] it is then equal to 0 and we lose this degree of freedom
            if not (-1, x) in xs:
                # the sign is arbitrary, put both possibilities
                base.add(frozenset({
                    frozenset(xs),
                    frozenset({(-s, x) for s, x in xs})
                }))

        # len(base) is the number of degrees of freedom in the tensor.
        # Now we want to decompose these degrees of freedom into irreps

        base = sorted([sorted([sorted(xs) for xs in x]) for x in base])  # required for python 3.7 but not for 3.8 (probably a bug in 3.7)

        # First we compute the change of basis (projection) between full_base and base
        d_sym = len(base)
        d = len(full_base)
        Q = torch.zeros(d_sym, d)

        for i, x in enumerate(base):
            x = max(x, key=lambda xs: sum(s for s, x in xs))
            for s, e in x:
                j = full_base.index(e)
                Q[i, j] = s / len(x)**0.5

        assert torch.allclose(Q @ Q.T, torch.eye(d_sym))

        if d_sym == 0:
            return [], torch.zeros(d_sym, d).to(dtype=dtype)

        # We project the representation on the basis `base`
        def representation(g):
            def re(r):
                if callable(r):
                    return r(*g)
                return o3.rep(r, *g)

            m = kron(*(re(kw_Rs[i]) for i in f0))
            return Q @ m @ Q.T

        # And check that after this projection it is still a representation
        assert is_representation(group, representation, eps)

        # The rest of the code simply extract the irreps present in this representation
        Rs_out = []
        A = Q.clone()
        for r in group.irrep_indices():
            if group.irrep(r)(ide).shape[0] > d_sym - o3.dim(Rs_out):
                break

            mul, B, representation = reduce(group, representation, group.irrep(r), eps)
            A = direct_sum(torch.eye(d_sym - B.shape[0]), B) @ A
            A = _round_sqrt(A, eps)

            if has_parity:
                Rs_out += [(mul,) + r]
            else:
                Rs_out += [(mul, r)]

            if o3.dim(Rs_out) == d_sym:
                break

        if o3.dim(Rs_out) != d_sym:
            raise RuntimeError(f'unable to decompose into irreducible representations')

        return o3.simplify(Rs_out), A.to(dtype=dtype)
Example 31
    def test_uniform(self):
        rets = torch.ones(2, 5)
        weights = SparsemaxAllocator(5, temperature=1)(rets)

        assert torch.allclose(weights, rets / 5)
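With identical predicted returns for all five assets, sparsemax has no basis to prefer any of them, so the allocation is uniform: each weight is 1/5, which is why the test compares against rets / 5 (0.2 everywhere).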
Example 32
    def test_texture_map(self):
        """
        Test a mesh with a texture map is loaded and rendered correctly.
        The pupils in the eyes of the cow should always be looking to the left.
        """
        device = torch.device("cuda:0")
        obj_dir = Path(__file__).resolve().parent.parent / "docs/tutorials/data"
        obj_filename = obj_dir / "cow_mesh/cow.obj"

        # Load mesh + texture
        mesh = load_objs_as_meshes([obj_filename], device=device)

        # Init rasterizer settings
        R, T = look_at_view_transform(2.7, 0, 0)
        cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
        raster_settings = RasterizationSettings(
            image_size=512, blur_radius=0.0, faces_per_pixel=1
        )

        # Init shader settings
        materials = Materials(device=device)
        lights = PointLights(device=device)

        # Place light behind the cow in world space. The front of
        # the cow is facing the -z direction.
        lights.location = torch.tensor([0.0, 0.0, 2.0], device=device)[None]

        # Init renderer
        renderer = MeshRenderer(
            rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
            shader=TexturedSoftPhongShader(
                lights=lights, cameras=cameras, materials=materials
            ),
        )

        # Load reference image
        image_ref = load_rgb_image("test_texture_map_back.png", DATA_DIR)

        for bin_size in [0, None]:
            # Check both naive and coarse to fine produce the same output.
            renderer.rasterizer.raster_settings.bin_size = bin_size
            images = renderer(mesh)
            rgb = images[0, ..., :3].squeeze().cpu()

            if DEBUG:
                Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
                    DATA_DIR / "DEBUG_texture_map_back.png"
                )

            # NOTE some pixels can be flaky and will not lead to
            # `cond1` being true. Add `cond2` and check `cond1 or cond2`
            cond1 = torch.allclose(rgb, image_ref, atol=0.05)
            cond2 = ((rgb - image_ref).abs() > 0.05).sum() < 5
            self.assertTrue(cond1 or cond2)

        # Check grad exists
        [verts] = mesh.verts_list()
        verts.requires_grad = True
        mesh2 = Meshes(verts=[verts], faces=mesh.faces_list(), textures=mesh.textures)
        images = renderer(mesh2)
        images[0, ...].sum().backward()
        self.assertIsNotNone(verts.grad)

        ##########################################
        # Check rendering of the front of the cow
        ##########################################

        R, T = look_at_view_transform(2.7, 0, 180)
        cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)

        # Move light to the front of the cow in world space
        lights.location = torch.tensor([0.0, 0.0, -2.0], device=device)[None]

        # Load reference image
        image_ref = load_rgb_image("test_texture_map_front.png", DATA_DIR)

        for bin_size in [0, None]:
            # Check both naive and coarse to fine produce the same output.
            renderer.rasterizer.raster_settings.bin_size = bin_size

            images = renderer(mesh, cameras=cameras, lights=lights)
            rgb = images[0, ..., :3].squeeze().cpu()

            if DEBUG:
                Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
                    DATA_DIR / "DEBUG_texture_map_front.png"
                )

            # NOTE some pixels can be flaky and will not lead to
            # `cond1` being true. Add `cond2` and check `cond1 or cond2`
            cond1 = torch.allclose(rgb, image_ref, atol=0.05)
            cond2 = ((rgb - image_ref).abs() > 0.05).sum() < 5
            self.assertTrue(cond1 or cond2)

        #################################
        # Add blurring to rasterization
        #################################
        R, T = look_at_view_transform(2.7, 0, 180)
        cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
        blend_params = BlendParams(sigma=5e-4, gamma=1e-4)
        raster_settings = RasterizationSettings(
            image_size=512,
            blur_radius=np.log(1.0 / 1e-4 - 1.0) * blend_params.sigma,
            faces_per_pixel=100,
        )

        # Load reference image
        image_ref = load_rgb_image("test_blurry_textured_rendering.png", DATA_DIR)

        for bin_size in [0, None]:
            # Check both naive and coarse to fine produce the same output.
            renderer.rasterizer.raster_settings.bin_size = bin_size

            images = renderer(
                mesh.clone(),
                cameras=cameras,
                raster_settings=raster_settings,
                blend_params=blend_params,
            )
            rgb = images[0, ..., :3].squeeze().cpu()

            if DEBUG:
                Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
                    DATA_DIR / "DEBUG_blurry_textured_rendering.png"
                )

            self.assertClose(rgb, image_ref, atol=0.05)
Example 33
def test_complex_norm(complex_tensor, power):
    expected_norm_tensor = complex_tensor.pow(2).sum(-1).pow(power / 2)
    norm_tensor = F.complex_norm(complex_tensor, power)

    assert torch.allclose(expected_norm_tensor, norm_tensor, atol=1e-5)
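Here the trailing dimension of size 2 holds (real, imaginary) pairs, so the reference expression computes |z|**power elementwise. The identity is easy to verify against torch's native complex view:

import torch

z = torch.randn(4, 3, 2)                        # (..., 2) = (real, imag)
mag = torch.view_as_complex(z).abs()            # |z|
assert torch.allclose(mag, z.pow(2).sum(-1).sqrt())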
Example 34
def __wigner_3j(l1, l2, l3, _version=1):  # pragma: no cover
    """
    Computes the 3-j symbol
    https://en.wikipedia.org/wiki/3-j_symbol

    Closely related to the Clebsch–Gordan coefficients

    D(l1)_il D(l2)_jm D(l3)_kn Q_lmn == Q_ijk
    """
    # these three propositions are equivalent
    assert abs(l2 - l3) <= l1 <= l2 + l3
    assert abs(l3 - l1) <= l2 <= l3 + l1
    assert abs(l1 - l2) <= l3 <= l1 + l2

    def _DxDxD(a, b, c):
        D1 = irr_repr(l1, a, b, c)
        D2 = irr_repr(l2, a, b, c)
        D3 = irr_repr(l3, a, b, c)
        return torch.einsum('il,jm,kn->ijklmn', (D1, D2, D3)).reshape(n, n)

    n = (2 * l1 + 1) * (2 * l2 + 1) * (2 * l3 + 1)
    random_angles = [
        [4.41301023, 5.56684102, 4.59384642],
        [4.93325116, 6.12697327, 4.14574096],
        [0.53878964, 4.09050444, 5.36539036],
        [2.16017393, 3.48835314, 5.55174441],
        [2.52385107, 0.29089583, 3.90040975],
    ]

    with torch_default_dtype(torch.float64):
        B = torch.zeros((n, n))
        for abc in random_angles:
            D = _DxDxD(*abc) - torch.eye(n)
            B += D.T @ D
            del D
            gc.collect()

    # ask for the smallest eigenvalue/eigenvector pair if only one exists,
    # otherwise ask for the two smallest
    s, v = scipy.linalg.eigh(B.numpy(),
                             eigvals=(0, min(1, n - 1)),
                             overwrite_a=True)
    del B
    gc.collect()

    kernel = v.T[s < 1e-10]
    null_space = torch.from_numpy(kernel)

    assert null_space.size(0) == 1, null_space.size()  # unique subspace solution
    Q = null_space[0]
    Q = Q.reshape(2 * l1 + 1, 2 * l2 + 1, 2 * l3 + 1)

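    # fix the global sign: flip Q if its first significantly nonzero entry is negative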
    if next(x for x in Q.flatten() if x.abs() > 1e-10 * Q.abs().max()) < 0:
        Q.neg_()

    Q[Q.abs() < 1e-14] = 0

    with torch_default_dtype(torch.float64):
        abc = rand_angles()
        _Q = torch.einsum(
            "il,jm,kn,lmn",
            (irr_repr(l1, *abc), irr_repr(l2, *abc), irr_repr(l3, *abc), Q))
        assert torch.allclose(Q, _Q)

    assert Q.dtype == torch.float64
    return Q  # [m1, m2, m3]
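
# A minimal, hypothetical invariance check for __wigner_3j above, reusing
# irr_repr, rand_angles and torch_default_dtype from the same module; the
# l values (1, 1, 1) are illustrative assumptions, not from the source.
with torch_default_dtype(torch.float64):
    Q_chk = __wigner_3j(1, 1, 1)
    abc = rand_angles()
    D1, D2, D3 = (irr_repr(l, *abc) for l in (1, 1, 1))
    Q_rot = torch.einsum("il,jm,kn,lmn", (D1, D2, D3, Q_chk))
    assert torch.allclose(Q_chk, Q_rot)
Example n. 35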
    def create_and_check_reformer_model_with_attn_mask(self,
                                                       config,
                                                       input_ids,
                                                       input_mask,
                                                       choice_labels,
                                                       is_decoder=False):
        # no special position embeddings
        config.axial_pos_embds = False
        config.is_decoder = is_decoder

        if self.lsh_attn_chunk_length is not None:
            # need to set chunk length equal to the sequence length to be certain that chunking works
            config.lsh_attn_chunk_length = self.seq_length

        model = ReformerModel(config=config)
        model.to(torch_device)
        model.eval()
        # set all position encodings to zero so that positions don't matter
        with torch.no_grad():
            embedding = model.embeddings.position_embeddings.embedding
            embedding.weight = torch.nn.Parameter(
                torch.zeros(embedding.weight.shape).to(torch_device))
            embedding.weight.requires_grad = False

        half_seq_len = self.seq_length // 2
        roll = self.chunk_length

        half_input_ids = input_ids[:, :half_seq_len]

        # normal padded
        attn_mask = torch.cat(
            [
                torch.ones_like(half_input_ids),
                torch.zeros_like(half_input_ids)
            ],
            dim=-1,
        )
        input_ids_padded = torch.cat(
            [
                half_input_ids,
                ids_tensor((self.batch_size, half_seq_len), self.vocab_size)
            ],
            dim=-1,
        )

        # shifted padded
        input_ids_roll = torch.cat(
            [
                half_input_ids,
                ids_tensor((self.batch_size, half_seq_len), self.vocab_size)
            ],
            dim=-1,
        )
        input_ids_roll = torch.roll(input_ids_roll, roll, dims=-1)
        attn_mask_roll = torch.roll(attn_mask, roll, dims=-1)

        output_padded = model(input_ids_padded,
                              attention_mask=attn_mask)[0][:, :half_seq_len]
        output_padded_rolled = model(
            input_ids_roll,
            attention_mask=attn_mask_roll)[0][:, roll:half_seq_len + roll]

        self.parent.assertTrue(
            torch.allclose(output_padded, output_padded_rolled, atol=1e-3))
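
# Aside (not part of the original test): torch.roll shifts entries
# cyclically along a dimension, which is what makes the padded and the
# rolled outputs comparable above. A quick sanity check:
t = torch.tensor([1, 2, 3, 4])
assert torch.equal(torch.roll(t, 1, dims=-1), torch.tensor([4, 1, 2, 3]))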
Example n. 36
    def testALEBO(self):
        B = torch.tensor(
            [[1.0, 2.0, 3.0, 4.0, 5.0], [2.0, 3.0, 4.0, 5.0, 6.0]],
            dtype=torch.double)
        train_X = torch.tensor(
            [
                [0.0, 0.0, 0.0, 0.0, 0.0],
                [1.0, 1.0, 1.0, 1.0, 1.0],
                [2.0, 2.0, 2.0, 2.0, 2.0],
            ],
            dtype=torch.double,
        )
        train_Y = torch.tensor([[1.0], [2.0], [3.0]], dtype=torch.double)
        train_Yvar = 0.1 * torch.ones(3, 1, dtype=torch.double)

        m = ALEBO(B=B, laplace_nsamp=5, fit_restarts=1)
        self.assertTrue(torch.equal(B, m.B))
        self.assertEqual(m.laplace_nsamp, 5)
        self.assertEqual(m.fit_restarts, 1)
        self.assertEqual(m.refit_on_update, True)
        self.assertEqual(m.refit_on_cv, False)
        self.assertEqual(m.warm_start_refitting, False)

        # Test fit
        m.fit(
            Xs=[train_X, train_X],
            Ys=[train_Y, train_Y],
            Yvars=[train_Yvar, train_Yvar],
            search_space_digest=SearchSpaceDigest(
                feature_names=[],
                bounds=[(-1, 1)] * 5,
            ),
            metric_names=[],
        )
        self.assertIsInstance(m.model, ModelListGP)
        self.assertTrue(torch.allclose(m.Xs[0], (B @ train_X.t()).t()))

        # Test predict
        f, cov = m.predict(X=B)
        self.assertEqual(f.shape, torch.Size([2, 2]))
        self.assertEqual(cov.shape, torch.Size([2, 2, 2]))

        # Test best point
        objective_weights = torch.tensor([1.0, 0.0], dtype=torch.double)
        with self.assertRaises(NotImplementedError):
            m.best_point(bounds=[(-1, 1)] * 5,
                         objective_weights=objective_weights)

        # Test gen
        # With clipping
        with mock.patch(
                "ax.models.torch.alebo.optimize_acqf",
                autospec=True,
                return_value=(m.Xs[0], torch.tensor([])),
        ):
            Xopt, _, _, _ = m.gen(
                n=1,
                bounds=[(-1, 1)] * 5,
                objective_weights=torch.tensor([1.0, 0.0], dtype=torch.double),
            )

        self.assertFalse(torch.allclose(Xopt, train_X))
        self.assertTrue(Xopt.min() >= -1)
        self.assertTrue(Xopt.max() <= 1)
        # Without
        with mock.patch(
                "ax.models.torch.alebo.optimize_acqf",
                autospec=True,
                return_value=(torch.ones(1, 2, dtype=torch.double),
                              torch.tensor([])),
        ):
            Xopt, _, _, _ = m.gen(
                n=1,
                bounds=[(-1, 1)] * 5,
                objective_weights=torch.tensor([1.0, 0.0], dtype=torch.double),
            )

        self.assertTrue(
            torch.allclose(
                Xopt,
                torch.tensor([[-0.2, -0.1, 0.0, 0.1, 0.2]],
                             dtype=torch.double)))

        # Test update
        train_X2 = torch.tensor(
            [
                [3.0, 3.0, 3.0, 3.0, 3.0],
                [1.0, 1.0, 1.0, 1.0, 1.0],
                [2.0, 2.0, 2.0, 2.0, 2.0],
            ],
            dtype=torch.double,
        )
        m.update(
            Xs=[train_X, train_X2],
            Ys=[train_Y, train_Y],
            Yvars=[train_Yvar, train_Yvar],
        )
        self.assertTrue(torch.allclose(m.Xs[0], (B @ train_X.t()).t()))
        self.assertTrue(torch.allclose(m.Xs[1], (B @ train_X2.t()).t()))
        m.refit_on_update = False
        m.update(
            Xs=[train_X, train_X2],
            Ys=[train_Y, train_Y],
            Yvars=[train_Yvar, train_Yvar],
        )

        # Test get_and_fit with a single metric
        gp = m.get_and_fit_model(Xs=[(B @ train_X.t()).t()],
                                 Ys=[train_Y],
                                 Yvars=[train_Yvar])
        self.assertIsInstance(gp, ALEBOGP)

        # Test cross_validate
        f, cov = m.cross_validate(
            Xs_train=[train_X],
            Ys_train=[train_Y],
            Yvars_train=[train_Yvar],
            X_test=train_X2,
        )
        self.assertEqual(f.shape, torch.Size([3, 1]))
        self.assertEqual(cov.shape, torch.Size([3, 1, 1]))
        m.refit_on_cv = True
        f, cov = m.cross_validate(
            Xs_train=[train_X],
            Ys_train=[train_Y],
            Yvars_train=[train_Yvar],
            X_test=train_X2,
        )
        self.assertEqual(f.shape, torch.Size([3, 1]))
        self.assertEqual(cov.shape, torch.Size([3, 1, 1]))
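
# Aside (illustrative, not from the original suite): ALEBO models the
# objective in a low-dimensional linear embedding, which is why the
# assertions above compare m.Xs against (B @ X.t()).t(). A standalone
# sketch of that projection:
B = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0],
                  [2.0, 3.0, 4.0, 5.0, 6.0]], dtype=torch.double)
x_high = torch.ones(1, 5, dtype=torch.double)
x_low = (B @ x_high.t()).t()  # shape (1, 2): the embedded input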
Example n. 37
    def testAcq(self):
        B = torch.tensor(
            [[1.0, 2.0, 3.0, 4.0, 5.0], [2.0, 3.0, 4.0, 5.0, 6.0]],
            dtype=torch.double)
        train_X = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
                               dtype=torch.double)
        train_Y = torch.tensor([[1.0], [2.0], [3.0]], dtype=torch.double)
        train_Yvar = 0.1 * torch.ones(3, 1, dtype=torch.double)
        m = ALEBOGP(B=B,
                    train_X=train_X,
                    train_Y=train_Y,
                    train_Yvar=train_Yvar)
        m.eval()

        objective_weights = torch.tensor([1.0], dtype=torch.double)
        acq = ei_or_nei(
            model=m,
            objective_weights=objective_weights,
            outcome_constraints=None,
            X_observed=train_X,
            X_pending=None,
            q=1,
            noiseless=True,
        )
        self.assertIsInstance(acq, ExpectedImprovement)
        self.assertEqual(acq.best_f.item(), 3.0)

        objective_weights = torch.tensor([-1.0], dtype=torch.double)
        acq = ei_or_nei(
            model=m,
            objective_weights=objective_weights,
            outcome_constraints=None,
            X_observed=train_X,
            X_pending=None,
            q=1,
            noiseless=True,
        )
        self.assertEqual(acq.best_f.item(), 1.0)
        with mock.patch(
                "ax.models.torch.alebo.optimize_acqf",
                autospec=True,
                return_value=(train_X, train_Y),
        ) as optim_mock:
            alebo_acqf_optimizer(
                acq_function=acq,
                bounds=None,
                n=1,
                inequality_constraints=5.0,
                fixed_features=None,
                rounding_func=None,
                raw_samples=100,
                num_restarts=5,
                B=B,
            )
        self.assertEqual(optim_mock.call_count, 1)
        self.assertIsInstance(optim_mock.mock_calls[0][2]["acq_function"],
                              ExpectedImprovement)

        acq = ei_or_nei(
            model=m,
            objective_weights=objective_weights,
            outcome_constraints=None,
            X_observed=train_X,
            X_pending=None,
            q=1,
            noiseless=False,
        )
        self.assertIsInstance(acq, qNoisyExpectedImprovement)

        with mock.patch(
                "ax.models.torch.alebo.optimize_acqf",
                autospec=True,
                return_value=(train_X, train_Y),
        ) as optim_mock:
            alebo_acqf_optimizer(
                acq_function=acq,
                bounds=None,
                n=2,
                inequality_constraints=5.0,
                fixed_features=None,
                rounding_func=None,
                raw_samples=100,
                num_restarts=5,
                B=B,
            )

        self.assertEqual(optim_mock.call_count, 2)
        self.assertIsInstance(optim_mock.mock_calls[0][2]["acq_function"],
                              qNoisyExpectedImprovement)
        self.assertEqual(optim_mock.mock_calls[0][2]["num_restarts"], 5)
        self.assertEqual(optim_mock.mock_calls[0][2]["inequality_constraints"],
                         5.0)
        X = optim_mock.mock_calls[0][2]["batch_initial_conditions"]
        self.assertEqual(X.shape, torch.Size([5, 1, 2]))
        # Make sure initialization is inside subspace
        Z = (B @ torch.pinverse(B) @ X[:, 0, :].t()).t()
        self.assertTrue(torch.allclose(Z, X[:, 0, :]))
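        # Aside (added note): for the full-row-rank B used here (2 x 5),
        # B @ pinverse(B) is the 2 x 2 identity up to numerical error, so
        # points already in the embedding survive the round trip unchanged.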
Example n. 38
    def forward(self, feed, moc_init_done=False, debug=False):
        summ_writer = utils_improc.Summ_writer(
            writer = feed['writer'],
            global_step = feed['global_step'],
            set_name= feed['set_name'],
            fps=8)

        writer = feed['writer']
        global_step = feed['global_step']
        total_loss = torch.tensor(0.0).cuda()

        ### ... All things sensor ... ###
        sensor_rgbs = feed['sensor_imgs']
        sensor_depths = feed['sensor_depths']
        center_sensor_H, center_sensor_W = sensor_depths[0][0].shape[-1] // 2, sensor_depths[0][0].shape[-2] // 2
        ### ... All things sensor end ... ###

        # 1. Form the memory tensor using the feat net and visual images.
        # check what is needed for this and create only those things

        ##  .... Input images ....  ##
        rgb_camRs = feed['rgb_camRs']
        rgb_camXs = feed['rgb_camXs']
        ##  .... Input images end ....  ##

        ## ... Hyperparams ... ##
        B, H, W, V, S = hyp.B, hyp.H, hyp.W, hyp.V, hyp.S
        __p = lambda x: pack_seqdim(x, B)
        __u = lambda x: unpack_seqdim(x, B)
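        # __p presumably packs the batch and sequence dims (B, S, ...) into
        # (B*S, ...) and __u unpacks them again, so per-frame ops run batched.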
        PH, PW = hyp.PH, hyp.PW
        Z, Y, X = hyp.Z, hyp.Y, hyp.X
        Z2, Y2, X2 = int(Z/2), int(Y/2), int(X/2)
        ## ... Hyperparams end ... ##

        ## .... VISUAL TRANSFORMS BEGIN .... ##
        pix_T_cams = feed['pix_T_cams']
        pix_T_cams_ = __p(pix_T_cams)
        origin_T_camRs = feed['origin_T_camRs']
        origin_T_camRs_ = __p(origin_T_camRs)
        origin_T_camXs = feed['origin_T_camXs']
        origin_T_camXs_ = __p(origin_T_camXs)
        camRs_T_camXs_ = torch.matmul(utils_geom.safe_inverse(
            origin_T_camRs_), origin_T_camXs_)
        camXs_T_camRs_ = utils_geom.safe_inverse(camRs_T_camXs_)
        camRs_T_camXs = __u(camRs_T_camXs_)
        camXs_T_camRs = __u(camXs_T_camRs_)
        pix_T_cams_ = utils_geom.pack_intrinsics(pix_T_cams_[:, 0, 0], pix_T_cams_[:, 1, 1], pix_T_cams_[:, 0, 2],
            pix_T_cams_[:, 1, 2])
        pix_T_camRs_ = torch.matmul(pix_T_cams_, camXs_T_camRs_)
        pix_T_camRs = __u(pix_T_camRs_)
        ## ... VISUAL TRANSFORMS END ... ##

        ## ... SENSOR TRANSFORMS BEGIN ... ##
        sensor_origin_T_camXs = feed['sensor_extrinsics']
        sensor_origin_T_camXs_ = __p(sensor_origin_T_camXs)
        sensor_origin_T_camRs = feed['sensor_origin_T_camRs']
        sensor_origin_T_camRs_ = __p(sensor_origin_T_camRs)
        sensor_camRs_T_origin_ = utils_geom.safe_inverse(sensor_origin_T_camRs_)

        sensor_camRs_T_camXs_ = torch.matmul(utils_geom.safe_inverse(
            sensor_origin_T_camRs_), sensor_origin_T_camXs_)
        sensor_camXs_T_camRs_ = utils_geom.safe_inverse(sensor_camRs_T_camXs_)

        sensor_camRs_T_camXs = __u(sensor_camRs_T_camXs_)
        sensor_camXs_T_camRs = __u(sensor_camXs_T_camRs_)

        sensor_pix_T_cams = feed['sensor_intrinsics']
        sensor_pix_T_cams_ = __p(sensor_pix_T_cams)
        sensor_pix_T_cams_ = utils_geom.pack_intrinsics(sensor_pix_T_cams_[:, 0, 0], sensor_pix_T_cams_[:, 1, 1],
            sensor_pix_T_cams_[:, 0, 2], sensor_pix_T_cams_[:, 1, 2])
        sensor_pix_T_camRs_ = torch.matmul(sensor_pix_T_cams_, sensor_camXs_T_camRs_)
        sensor_pix_T_camRs = __u(sensor_pix_T_camRs_)
        ## .... SENSOR TRANSFORMS END .... ##

        ## .... Visual Input point clouds .... ##
        xyz_camXs = feed['xyz_camXs']
        xyz_camXs_ = __p(xyz_camXs)
        xyz_camRs_ = utils_geom.apply_4x4(camRs_T_camXs_, xyz_camXs_)  # (40, 4, 4) (B*S, N, 3)
        xyz_camRs = __u(xyz_camRs_)
        assert all([torch.allclose(xyz_camR, inp_xyz_camR) for xyz_camR, inp_xyz_camR in zip(
            xyz_camRs, feed['xyz_camRs']
        )]), "computation of xyz_camR here and those computed in input do not match"
        ## .... Visual Input point clouds end .... ##

        ## ... Sensor input point clouds ... ##
        sensor_xyz_camXs = feed['sensor_xyz_camXs']
        sensor_xyz_camXs_ = __p(sensor_xyz_camXs)
        sensor_xyz_camRs_ = utils_geom.apply_4x4(sensor_camRs_T_camXs_, sensor_xyz_camXs_)
        sensor_xyz_camRs = __u(sensor_xyz_camRs_)
        assert all([torch.allclose(sensor_xyz, inp_sensor_xyz) for sensor_xyz, inp_sensor_xyz in zip(
            sensor_xyz_camRs, feed['sensor_xyz_camRs']
        )]), "the sensor_xyz_camRs computed in forward do not match those computed in input"

        ## ... visual occupancy computation voxelize the pointcloud from above ... ##
        occRs_ = utils_vox.voxelize_xyz(xyz_camRs_, Z, Y, X)
        occXs_ = utils_vox.voxelize_xyz(xyz_camXs_, Z, Y, X)
        occRs_half_ = utils_vox.voxelize_xyz(xyz_camRs_, Z2, Y2, X2)
        occXs_half_ = utils_vox.voxelize_xyz(xyz_camXs_, Z2, Y2, X2)
        ## ... visual occupancy computation end ... NOTE: no unpacking ##

        ## .. visual occupancy computation for sensor inputs .. ##
        sensor_occRs_ = utils_vox.voxelize_xyz(sensor_xyz_camRs_, Z, Y, X)
        sensor_occXs_ = utils_vox.voxelize_xyz(sensor_xyz_camXs_, Z, Y, X)
        sensor_occRs_half_ = utils_vox.voxelize_xyz(sensor_xyz_camRs_, Z2, Y2, X2)
        sensor_occXs_half_ = utils_vox.voxelize_xyz(sensor_xyz_camXs_, Z2, Y2, X2)

        ## ... unproject rgb images ... ##
        unpRs_ = utils_vox.unproject_rgb_to_mem(__p(rgb_camXs), Z, Y, X, pix_T_camRs_)
        unpXs_ = utils_vox.unproject_rgb_to_mem(__p(rgb_camXs), Z, Y, X, pix_T_cams_)
        ## ... unproject rgb finish ... NOTE: no unpacking ##

        ## ... Make depth images ... ##
        depth_camXs_, valid_camXs_ = utils_geom.create_depth_image(pix_T_cams_, xyz_camXs_, H, W)
        dense_xyz_camXs_ = utils_geom.depth2pointcloud(depth_camXs_, pix_T_cams_)
        dense_xyz_camRs_ = utils_geom.apply_4x4(camRs_T_camXs_, dense_xyz_camXs_)
        inbound_camXs_ = utils_vox.get_inbounds(dense_xyz_camRs_, Z, Y, X).float()
        inbound_camXs_ = torch.reshape(inbound_camXs_, [B*S, 1, H, W])
        valid_camXs = __u(valid_camXs_) * __u(inbound_camXs_)
        ## ... Make depth images ... ##

        ## ... Make sensor depth images ... ##
        sensor_depth_camXs_, sensor_valid_camXs_ = utils_geom.create_depth_image(sensor_pix_T_cams_,
            sensor_xyz_camXs_, H, W)
        sensor_dense_xyz_camXs_ = utils_geom.depth2pointcloud(sensor_depth_camXs_, sensor_pix_T_cams_)
        sensor_dense_xyz_camRs_ = utils_geom.apply_4x4(sensor_camRs_T_camXs_, sensor_dense_xyz_camXs_)
        sensor_inbound_camXs_ = utils_vox.get_inbounds(sensor_dense_xyz_camRs_, Z, Y, X).float()
        sensor_inbound_camXs_ = torch.reshape(sensor_inbound_camXs_, [B*hyp.sensor_S, 1, H, W])
        sensor_valid_camXs = __u(sensor_valid_camXs_) * __u(sensor_inbound_camXs_)
        ### .. Done making sensor depth images .. ##

        ### ... Sanity check ... Write to tensorboard ... ###
        summ_writer.summ_oneds('2D_inputs/depth_camXs', torch.unbind(__u(depth_camXs_), dim=1))
        summ_writer.summ_oneds('2D_inputs/valid_camXs', torch.unbind(valid_camXs, dim=1))
        summ_writer.summ_rgbs('2D_inputs/rgb_camXs', torch.unbind(rgb_camXs, dim=1))
        summ_writer.summ_rgbs('2D_inputs/rgb_camRs', torch.unbind(rgb_camRs, dim=1))
        summ_writer.summ_occs('3d_inputs/occXs', torch.unbind(__u(occXs_), dim=1), reduce_axes=[2])
        summ_writer.summ_unps('3d_inputs/unpXs', torch.unbind(__u(unpXs_), dim=1),\
            torch.unbind(__u(occXs_), dim=1))

        # A different approach for viewing occRs of sensors
        sensor_occRs = __u(sensor_occRs_)
        vis_sensor_occRs = torch.max(sensor_occRs, dim=1, keepdim=True)[0]
        # summ_writer.summ_occs('3d_inputs/sensor_occXs', torch.unbind(__u(sensor_occXs_), dim=1),
        #     reduce_axes=[2])
        summ_writer.summ_occs('3d_inputs/sensor_occRs', torch.unbind(vis_sensor_occRs, dim=1), reduce_axes=[2])

        ### ... code for visualizing sensor depths and sensor rgbs ... ###
        # summ_writer.summ_oneds('2D_inputs/depths_sensor', torch.unbind(sensor_depths, dim=1))
        # summ_writer.summ_rgbs('2D_inputs/rgbs_sensor', torch.unbind(sensor_rgbs, dim=1))
        # summ_writer.summ_oneds('2D_inputs/validXs_sensor', torch.unbind(sensor_valid_camXs, dim=1))

        if summ_writer.save_this:
            unpRs_ = utils_vox.unproject_rgb_to_mem(__p(rgb_camXs), Z, Y, X, matmul2(pix_T_cams_, camXs_T_camRs_))
            unpRs = __u(unpRs_)
            occRs_ = utils_vox.voxelize_xyz(xyz_camRs_, Z, Y, X)
            summ_writer.summ_occs('3d_inputs/occRs', torch.unbind(__u(occRs_), dim=1), reduce_axes=[2])
            summ_writer.summ_unps('3d_inputs/unpRs', torch.unbind(unpRs, dim=1),\
                torch.unbind(__u(occRs_), dim=1))
        ### ... Sanity check ... Writing to tensorboard complete ... ###
        results = list()

        mask_ = None
        ### ... Visual featnet part .... ###
        if hyp.do_feat:
            featXs_input = torch.cat([__u(occXs_), __u(occXs_)*__u(unpXs_)], dim=2)  # B, S, 4, H, W, D
            featXs_input_ = __p(featXs_input)

            freeXs_ = utils_vox.get_freespace(__p(xyz_camXs), occXs_half_)
            freeXs = __u(freeXs_)
            visXs = torch.clamp(__u(occXs_half_) + freeXs, 0.0, 1.0)

            if mask_ is not None:
                assert(list(mask_.shape)[2:5] == list(featXs_input.shape)[2:5])
            featXs_, validXs_, _ = self.featnet(featXs_input_, summ_writer, mask=occXs_)
            # total_loss += feat_loss  # Note no need of loss

            validXs, featXs = __u(validXs_), __u(featXs_) # unpacked into B, S, C, D, H, W
            # bring everything to ref_frame
            validRs = utils_vox.apply_4x4_to_voxs(camRs_T_camXs, validXs)
            visRs = utils_vox.apply_4x4_to_voxs(camRs_T_camXs, visXs)
            featRs = utils_vox.apply_4x4_to_voxs(camRs_T_camXs, featXs)  # This is now in memory coordinates

            emb3D_e = torch.mean(featRs[:, 1:], dim=1)  # context, or the features of the scene
            emb3D_g = featRs[:, 0]  # this is to predict, basically I will pass emb3D_e as input and hope to predict emb3D_g
            vis3D_e = torch.max(validRs[:, 1:], dim=1)[0] * torch.max(visRs[:, 1:], dim=1)[0]
            vis3D_g = validRs[:, 0] * visRs[:, 0]

            #### ... I do not think I need this ... ####
            results = {}
        #     # if hyp.do_eval_recall:
        #     #     results['emb3D_e'] = emb3D_e
        #     #     results['emb3D_g'] = emb3D_g
        #     #### ... Check if you need the above

            summ_writer.summ_feats('3D_feats/featXs_input', torch.unbind(featXs_input, dim=1), pca=True)
            summ_writer.summ_feats('3D_feats/featXs_output', torch.unbind(featXs, dim=1), pca=True)
            summ_writer.summ_feats('3D_feats/featRs_output', torch.unbind(featRs, dim=1), pca=True)
            summ_writer.summ_feats('3D_feats/validRs', torch.unbind(validRs, dim=1), pca=False)
            summ_writer.summ_feat('3D_feats/vis3D_e', vis3D_e, pca=False)
            summ_writer.summ_feat('3D_feats/vis3D_g', vis3D_g, pca=False)

            # I need to aggregate the features and detach to prevent the backward pass on featnet
            featRs = torch.mean(featRs, dim=1)
            featRs = featRs.detach()
            #  ... HERE I HAVE THE VISUAL FEATURE TENSOR ... WHICH IS MADE USING 5 EVENLY SPACED VIEWS #

        # FOR THE TOUCH PART, I HAVE THE OCC and THE AIM IS TO PREDICT FEATURES FROM THEM #
        if hyp.do_touch_feat:
            # 1. Pass all the sensor depth images through the backbone network
            input_sensor_depths = __p(sensor_depths)
            sensor_features_ = self.backbone_2D(input_sensor_depths)

            # should normalize these feature tensors
            sensor_features_ = l2_normalize(sensor_features_, dim=1)

            sensor_features = __u(sensor_features_)
            assert torch.allclose(torch.norm(sensor_features_, dim=1), torch.Tensor([1.0]).cuda()),\
                "sensor features should be unit-norm after l2_normalize"

            if hyp.do_eval_recall:
              results['sensor_features'] = sensor_features_
              results['sensor_depths'] = input_sensor_depths
              results['object_img'] = rgb_camRs
              results['sensor_imgs'] = __p(sensor_rgbs)

            # if moco is used do the same procedure as above but with a different network #
            if hyp.do_moc or hyp.do_eval_recall:
                # 1. Pass all the sensor depth images through the key network
                key_input_sensor_depths = copy.deepcopy(__p(sensor_depths)) # bx1024x1x16x16->(2048x1x16x16)
                self.key_touch_featnet.eval()
                with torch.no_grad():
                    key_sensor_features_ = self.key_touch_featnet(key_input_sensor_depths)

                key_sensor_features_ = l2_normalize(key_sensor_features_, dim=1)
                key_sensor_features = __u(key_sensor_features_)
                assert torch.allclose(torch.norm(key_sensor_features_, dim=1), torch.Tensor([1.0]).cuda()),\
                    "key sensor features should be unit-norm after l2_normalize"

        # doing the same procedure for moco but with a different network end #

        # whether to do voxel-point-based metric learning between visual features and sensor features
        if hyp.do_touch_embML and not hyp.do_touch_forward:
            # trial 1: I do not pass the above obtained features through some encoder-decoder in 3D.
            # So compute the location in ref_frame which the center of each of these depth images will occupy;
            # at all of these locations I will sample from the visual tensor. These form the positive pairs;
            # negatives are simply everything except the positive.
            sensor_depths_centers_x = center_sensor_W * torch.ones((hyp.B, hyp.sensor_S))
            sensor_depths_centers_x = sensor_depths_centers_x.cuda()
            sensor_depths_centers_y = center_sensor_H * torch.ones((hyp.B, hyp.sensor_S))
            sensor_depths_centers_y = sensor_depths_centers_y.cuda()
            sensor_depths_centers_z = sensor_depths[:, :, 0, center_sensor_H, center_sensor_W]

            # Next use Pixels2Camera to unproject all of these together.
            # merge the batch and the sequence dimension
            sensor_depths_centers_x = sensor_depths_centers_x.reshape(-1, 1, 1)  # BxHxW as required by Pixels2Camera
            sensor_depths_centers_y = sensor_depths_centers_y.reshape(-1, 1, 1)
            sensor_depths_centers_z = sensor_depths_centers_z.reshape(-1, 1, 1)

            fx, fy, x0, y0 = utils_geom.split_intrinsics(sensor_pix_T_cams_)
            sensor_depths_centers_in_camXs_ = utils_geom.Pixels2Camera(sensor_depths_centers_x, sensor_depths_centers_y,
                sensor_depths_centers_z, fx, fy, x0, y0)

            # finally use apply4x4 to get the locations in ref_cam
            sensor_depths_centers_in_ref_cam_ = utils_geom.apply_4x4(sensor_camRs_T_camXs_, sensor_depths_centers_in_camXs_)

            # NOTE: convert them to memory coordinates, the name is xyz so I presume it returns xyz but talk to ADAM
            sensor_depths_centers_in_mem_ = utils_vox.Ref2Mem(sensor_depths_centers_in_ref_cam_, Z2, Y2, X2)
            sensor_depths_centers_in_mem = sensor_depths_centers_in_mem_.reshape(hyp.B, hyp.sensor_S, -1)

            if debug:
                print('assert that you are not entering here')
                from IPython import embed; embed()
                # form a (0, 1) volume here at these locations and see if it resembles a cup
                dim1 = X2 * Y2 * Z2
                dim2 = X2 * Y2
                dim3 = X2
                binary_voxel_grid = torch.zeros((hyp.B, X2, Y2, Z2))
                # NOTE: Z is the leading dimension
                rounded_idxs = torch.round(sensor_depths_centers_in_mem)
                flat_idxs = dim2 * rounded_idxs[0, :, 0] + dim3 * rounded_idxs[0, :, 1] + rounded_idxs[0, :, 2]
                flat_idxs1 = dim2 * rounded_idxs[1, :, 0] + dim3 * rounded_idxs[1, :, 1] + rounded_idxs[1, :, 2]
                flat_idxs1 = flat_idxs1 + dim1
                flat_idxs1 = flat_idxs1.long()
                flat_idxs = flat_idxs.long()

                flattened_grid = binary_voxel_grid.flatten()
                flattened_grid[flat_idxs] = 1.
                flattened_grid[flat_idxs1] = 1.

                binary_voxel_grid = flattened_grid.view(B, X2, Y2, Z2)

                assert binary_voxel_grid[0].sum() == len(torch.unique(flat_idxs)), "some indexes are missed here"
                assert binary_voxel_grid[1].sum() == len(torch.unique(flat_idxs1)), "some indexes are missed here"

                # o3d.io.write_voxel_grid("forward_pass_save/grid0.ply", binary_voxel_grid[0])
                # o3d.io.write_voxel_grid("forward_pass_save/grid1.ply", binary_voxel_grid[0])
                # need to save these voxels
                save_voxel(binary_voxel_grid[0].cpu().numpy(), "forward_pass_save/grid0.binvox")
                save_voxel(binary_voxel_grid[1].cpu().numpy(), "forward_pass_save/grid1.binvox")
                from IPython import embed; embed()

            # use grid sample to get the visual touch tensor at these locations, NOTE: visual tensor features shape is (B, C, N)
            visual_tensor_features = utils_samp.bilinear_sample3D(featRs, sensor_depths_centers_in_mem[:, :, 0],
                sensor_depths_centers_in_mem[:, :, 1], sensor_depths_centers_in_mem[:, :, 2])
            visual_feature_tensor = visual_tensor_features.permute(0, 2, 1)
            # pack it
            visual_feature_tensor_ = __p(visual_feature_tensor)
            C = list(visual_feature_tensor.shape)[-1]
            print('C=', C)

            # do the metric learning this is the same as before.
            # the code is basically copied from embnet3d.py but some changes are being made very minor
            emb_vec = torch.stack((sensor_features_, visual_feature_tensor_), dim=1).view(B*self.num_samples*self.batch_k, C)
            y = torch.stack([torch.range(0,self.num_samples*B-1), torch.range(0,self.num_samples*B-1)], dim=1).view(self.num_samples*B*self.batch_k)
            a_indices, anchors, positives, negatives, _ = self.sampler(emb_vec)

            # I need to write my own version of margin loss since the negatives and anchors may not be same dim
            d_ap = torch.sqrt(torch.sum((positives - anchors)**2, dim=1) + 1e-8)
            pos_loss = torch.clamp(d_ap - beta + self._margin, min=0.0)

            # TODO: expand the dims of anchors and tile them and compute the negative loss

            # do the pair count where you average by contributors only

            # this is your total loss


            # Further idea is to check what volumetric locations do each of the depth images corresponds to
            # unproject the entire depth image and convert to ref. and then sample.

        if hyp.do_touch_forward:
            ## ... Begin code for getting crops from visual memory ... ##
            sensor_depths_centers_x = center_sensor_W * torch.ones((hyp.B, hyp.sensor_S))
            sensor_depths_centers_x = sensor_depths_centers_x.cuda()
            sensor_depths_centers_y = center_sensor_H * torch.ones((hyp.B, hyp.sensor_S))
            sensor_depths_centers_y = sensor_depths_centers_y.cuda()
            sensor_depths_centers_z = sensor_depths[:, :, 0, center_sensor_H, center_sensor_W]

            # Next use Pixels2Camera to unproject all of these together.
            # merge the batch and the sequence dimension
            sensor_depths_centers_x = sensor_depths_centers_x.reshape(-1, 1, 1)
            sensor_depths_centers_y = sensor_depths_centers_y.reshape(-1, 1, 1)
            sensor_depths_centers_z = sensor_depths_centers_z.reshape(-1, 1, 1)

            fx, fy, x0, y0 = utils_geom.split_intrinsics(sensor_pix_T_cams_)
            sensor_depths_centers_in_camXs_ = utils_geom.Pixels2Camera(sensor_depths_centers_x, sensor_depths_centers_y,
                sensor_depths_centers_z, fx, fy, x0, y0)
            sensor_depths_centers_in_world_ = utils_geom.apply_4x4(sensor_origin_T_camXs_, sensor_depths_centers_in_camXs_)  # not used by the algorithm
            ## this will be later used for visualization hence saving it here for now
            sensor_depths_centers_in_ref_cam_ = utils_geom.apply_4x4(sensor_camRs_T_camXs_, sensor_depths_centers_in_camXs_)  # not used by the algorithm

            sensor_depths_centers_in_camXs = __u(sensor_depths_centers_in_camXs_).squeeze(2)

            # There has to be a better way to do this, for each of the cameras in the batch I want a box of size (ch, cw, cd)
            # TODO: rotation is the deviation of the box from axis-aligned; do I want this?
            tB, tN, _ = list(sensor_depths_centers_in_camXs.shape)  # 2, 512, _
            boxlist = torch.zeros(tB, tN, 9)  # 2, 512, 9
            boxlist[:, :, :3] = sensor_depths_centers_in_camXs  # this lies on the object
            boxlist[:, :, 3:6] = torch.FloatTensor([hyp.contextW, hyp.contextH, hyp.contextD])

            # convert the boxlist to lrtlist and to cuda
            # the rt here transforms the from box coordinates to camera coordinates
            box_lrtlist = utils_geom.convert_boxlist_to_lrtlist(boxlist)

            # Now I will use crop_zoom_from_mem functionality to get the features in each of the boxes
            # I will do it for each of the box separately as required by the api
            context_grid_list = list()
            for m in range(box_lrtlist.shape[1]):
                curr_box = box_lrtlist[:, m, :]
                context_grid = utils_vox.crop_zoom_from_mem(featRs, curr_box, 8, 8, 8,
                    sensor_camRs_T_camXs[:, m, :, :])
                context_grid_list.append(context_grid)
            context_grid_list = torch.stack(context_grid_list, dim=1)
            context_grid_list_ = __p(context_grid_list)
            ## ... till here I believe I have not introduced any randomness, so the points are still in
            ## ... End code for getting crops around this center of certain height, width and depth ... ##

            ## ... Begin code for passing the context grid through 3D CNN to obtain a vector ... ##
            sensor_cam_locs = feed['sensor_locs']  # these are in origin coordinates
            sensor_cam_quats = feed['sensor_quats']  # these too are in world coordinates
            sensor_cam_locs_ = __p(sensor_cam_locs)
            sensor_cam_quats_ = __p(sensor_cam_quats)
            sensor_cam_locs_in_R_ = utils_geom.apply_4x4(sensor_camRs_T_origin_, sensor_cam_locs_.unsqueeze(1)).squeeze(1)
            # TODO TODO TODO confirm that this is right? TODO TODO TODO
            get_r_mat = lambda cam_quat: transformations.quaternion_matrix_py(cam_quat)
            rot_mat_Xs_ = torch.from_numpy(np.stack(list(map(get_r_mat, sensor_cam_quats_.cpu().numpy())))).to(sensor_cam_locs_.device).float()
            rot_mat_Rs_ = torch.bmm(sensor_camRs_T_origin_, rot_mat_Xs_)
            get_quat = lambda r_mat: transformations.quaternion_from_matrix_py(r_mat)
            sensor_quats_in_R_ = torch.from_numpy(np.stack(list(map(get_quat, rot_mat_Rs_.cpu().numpy())))).to(sensor_cam_locs_.device).float()

            pred_features_ = self.context_net(context_grid_list_,\
                sensor_cam_locs_in_R_, sensor_quats_in_R_)

            # normalize
            pred_features_ = l2_normalize(pred_features_, dim=1)
            pred_features = __u(pred_features_)

            # if doing moco I have to pass the inputs through the key(slow) network as well #
            if hyp.do_moc or hyp.do_eval_recall:
                key_context_grid_list_ = copy.deepcopy(context_grid_list_)
                key_sensor_cam_locs_in_R_ = copy.deepcopy(sensor_cam_locs_in_R_)
                key_sensor_quats_in_R_ = copy.deepcopy(sensor_quats_in_R_)
                self.key_context_net.eval()
                with torch.no_grad():
                    key_pred_features_ = self.key_context_net(key_context_grid_list_,\
                        key_sensor_cam_locs_in_R_, key_sensor_quats_in_R_)

                # normalize; normalization is very important (why, though?)
                key_pred_features_ = l2_normalize(key_pred_features_, dim=1)
                key_pred_features = __u(key_pred_features_)
            # end passing of the input through the slow network this is necessary for moco #
            ## ... End code for passing the context grid through 3D CNN to obtain a vector ... ##

        ## ... Begin code for doing metric learning between pred_features and sensor features ... ##
        # 1. Subsample both based on the number of positive samples
        if hyp.do_touch_embML:
            assert(hyp.do_touch_forward)
            assert(hyp.do_touch_feat)
            perm = torch.randperm(len(pred_features_))  ## 1024
            chosen_sensor_feats_ = sensor_features_[perm[:self.num_pos_samples*hyp.B]]
            chosen_pred_feats_ = pred_features_[perm[:self.num_pos_samples*B]]

            # 2. form the emb_vec and get pos and negative samples for the batch
            emb_vec = torch.stack((chosen_sensor_feats_, chosen_pred_feats_), dim=1).view(hyp.B*self.num_pos_samples*self.batch_k, -1)
            y = torch.stack([torch.range(0, self.num_pos_samples*B-1), torch.range(0, self.num_pos_samples*B-1)],\
                dim=1).view(B*self.num_pos_samples*self.batch_k) # (0, 0, 1, 1, ..., 255, 255)

            a_indices, anchors, positives, negatives, _ = self.sampler(emb_vec)

            # 3. Compute the loss: the ML loss and the l2 distance between the embeddings
            margin_loss, _ = self.criterion(anchors, positives, negatives, self.beta, y[a_indices])
            total_loss = utils_misc.add_loss('embtouch/emb_touch_ml_loss', total_loss, margin_loss,
                hyp.emb_3D_ml_coeff, summ_writer)

            # the l2 loss between the embeddings
            l2_loss = torch.nn.functional.mse_loss(chosen_sensor_feats_, chosen_pred_feats_)
            total_loss = utils_misc.add_loss('embtouch/emb_l2_loss', total_loss, l2_loss,
                hyp.emb_3D_l2_coeff, summ_writer)
        ## ... End code for doing metric learning between pred_features and sensor_features ... ##

        ## ... Begin code for doing moc inspired ML between pred_features and sensor_features ... ##
        if hyp.do_moc and moc_init_done:
            moc_loss = self.moc_ml_net(sensor_features_, key_sensor_features_,\
                pred_features_, key_pred_features_, summ_writer)
            total_loss += moc_loss
        ## ... End code for doing moc inspired ML between pred_features and sensor_feature ... ##

        ## ... add code for filling up results needed for eval recall ... ##
        if hyp.do_eval_recall and moc_init_done:
            results['context_features'] = pred_features_
            results['sensor_depth_centers_in_world'] = sensor_depths_centers_in_world_
            results['sensor_depths_centers_in_ref_cam'] = sensor_depths_centers_in_ref_cam_
            results['object_name'] = feed['object_name']

            # I will do precision recall here at different recall values and summarize it using tensorboard
            recalls = [1, 5, 10, 50, 100, 200]
            # also should not include any gradients because of this
            # fast_sensor_emb_e = sensor_features_
            # fast_context_emb_e = pred_features_
            # slow_sensor_emb_g = key_sensor_features_
            # slow_context_emb_g = key_context_features_
            fast_sensor_emb_e = sensor_features_.clone().detach()
            fast_context_emb_e = pred_features_.clone().detach()

            # I will do multiple eval recalls here
            slow_sensor_emb_g = key_sensor_features_.clone().detach()
            slow_context_emb_g = key_pred_features_.clone().detach()

            # assuming the above thing goes well
            fast_sensor_emb_e = fast_sensor_emb_e.cpu().numpy()
            fast_context_emb_e = fast_context_emb_e.cpu().numpy()
            slow_sensor_emb_g = slow_sensor_emb_g.cpu().numpy()
            slow_context_emb_g = slow_context_emb_g.cpu().numpy()

            # now also move the vis to numpy and plot it using matplotlib
            vis_e = __p(sensor_rgbs)
            vis_g = __p(sensor_rgbs)
            np_vis_e = vis_e.cpu().detach().numpy()
            np_vis_e = np.transpose(np_vis_e, [0, 2, 3, 1])
            np_vis_g = vis_g.cpu().detach().numpy()
            np_vis_g = np.transpose(np_vis_g, [0, 2, 3, 1])

            # bring it back to original color
            np_vis_g = ((np_vis_g+0.5) * 255).astype(np.uint8)
            np_vis_e = ((np_vis_e+0.5) * 255).astype(np.uint8)

            # now compare fast_sensor_emb_e with slow_context_emb_g
            # since I am doing positive against this
            fast_sensor_emb_e_list = [fast_sensor_emb_e, np_vis_e]
            slow_context_emb_g_list = [slow_context_emb_g, np_vis_g]

            prec, vis, chosen_inds_and_neighbors_inds = compute_precision(
                fast_sensor_emb_e_list, slow_context_emb_g_list, recalls=recalls
            )

            # finally plot the nearest neighbour retrieval and move ahead
            if feed['global_step'] % 1 == 0:
                plot_nearest_neighbours(vis, step=feed['global_step'],
                                        save_dir='/home/gauravp/eval_results',
                                        name='fast_sensor_slow_context')

            # plot the precisions at different recalls
            for pr, re in enumerate(recalls):
                summ_writer.summ_scalar(f'evrefast_sensor_slow_context/recall@{re}',\
                    prec[pr])

            # now compare fast_context_emb_e with slow_sensor_emb_g
            fast_context_emb_e_list = [fast_context_emb_e, np_vis_e]
            slow_sensor_emb_g_list = [slow_sensor_emb_g, np_vis_g]

            prec, vis, chosen_inds_and_neighbors_inds = compute_precision(
                fast_context_emb_e_list, slow_sensor_emb_g_list, recalls=recalls
            )
            if feed['global_step'] % 1 == 0:
                plot_nearest_neighbours(vis, step=feed['global_step'],
                                        save_dir='/home/gauravp/eval_results',
                                        name='fast_context_slow_sensor')

            # plot the precisions at different recalls
            for pr, re in enumerate(recalls):
                summ_writer.summ_scalar(f'evrefast_context_slow_sensor/recall@{re}',\
                    prec[pr])


            # now finally compare both the fast, I presume we want them to go closer too
            fast_sensor_list = [fast_sensor_emb_e, np_vis_e]
            fast_context_list = [fast_context_emb_e, np_vis_g]

            prec, vis, chosen_inds_and_neighbors_inds = compute_precision(
                fast_sensor_list, fast_context_list, recalls=recalls
            )
            if feed['global_step'] % 1 == 0:
                plot_nearest_neighbours(vis, step=feed['global_step'],
                                        save_dir='/home/gauravp/eval_results',
                                        name='fast_sensor_fast_context')

            for pr, re in enumerate(recalls):
                summ_writer.summ_scalar(f'evrefast_sensor_fast_context/recall@{re}',\
                    prec[pr])

        ## ... done code for filling up results needed for eval recall ... ##
        summ_writer.summ_scalar('loss', total_loss.cpu().item())
        return total_loss, results, [key_sensor_features_, key_pred_features_]
Example n. 39
def test_fl_mnist_example_training_can_be_translated(hook, workers):
    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.fc1 = nn.Linear(784, 392)
            self.fc2 = nn.Linear(392, 10)

        def forward(self, x):
            x = self.fc1(x)
            x = F.relu(x)
            x = self.fc2(x)
            return x

    model = Net()

    def set_model_params(module, params_list, start_param_idx=0):
        """ Set params list into model recursively
        """
        param_idx = start_param_idx

        for name, param in module._parameters.items():
            module._parameters[name] = params_list[param_idx]
            param_idx += 1

        for name, child in module._modules.items():
            if child is not None:
                param_idx = set_model_params(child, params_list, param_idx)

        return param_idx

    def softmax_cross_entropy_with_logits(logits, targets, batch_size):
        """ Calculates softmax entropy
            Args:
                * logits: (NxC) outputs of dense layer
                * targets: (NxC) one-hot encoded labels
                * batch_size: value of N, temporarily required because Plan cannot trace .shape
        """
        # numerically stable log-softmax
        norm_logits = logits - logits.max()
        log_probs = norm_logits - norm_logits.exp().sum(dim=1, keepdim=True).log()
        # reduction = mean
        return -(targets * log_probs).sum() / batch_size
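
    # Added note: for one-hot targets this matches
    # F.cross_entropy(logits, targets.argmax(dim=1)) with mean reduction;
    # it is written out by hand, presumably so the Plan can trace plain
    # tensor ops.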

    def naive_sgd(param, **kwargs):
        return param - kwargs["lr"] * param.grad

    @sy.func2plan()
    def train(data, targets, lr, batch_size, model_parameters):
        # load model params
        set_model_params(model, model_parameters)

        # forward
        logits = model(data)

        # loss
        loss = softmax_cross_entropy_with_logits(logits, targets, batch_size)

        # backward
        loss.backward()

        # step
        updated_params = [naive_sgd(param, lr=lr) for param in model_parameters]

        # accuracy
        pred = th.argmax(logits, dim=1)
        targets_idx = th.argmax(targets, dim=1)
        acc = pred.eq(targets_idx).sum().float() / batch_size

        return (loss, acc, *updated_params)

    # Dummy inputs
    data = th.randn(3, 28 * 28)
    target = F.one_hot(th.tensor([1, 2, 3]), 10)
    lr = th.tensor([0.01])
    batch_size = th.tensor([3.0])
    model_state = list(model.parameters())

    # Build Plan
    train.build(data, target, lr, batch_size, model_state, trace_autograd=True)

    # Execute with original forward function (native torch autograd)
    res_torch = train(data, target, lr, batch_size, model_state)

    # Execute traced operations (traced autograd)
    train.forward = None
    res_syft_traced = train(data, target, lr, batch_size, model_state)

    # Translate syft Plan to torchscript and execute it
    train.add_translation(PlanTranslatorTorchscript)
    res_torchscript = train.torchscript(data, target, lr, batch_size, model_state)

    # (debug out)
    print(train.torchscript.code)

    # All variants should be equal
    for i, out in enumerate(res_torch):
        assert th.allclose(out, res_syft_traced[i])
        assert th.allclose(out, res_torchscript[i])
Example n. 40
def allclose(x, y, tol=1e-5):
    return torch.allclose(x, y, atol=tol, rtol=tol)
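
# A small illustration (hypothetical values) of the combined tolerance:
# the helper above passes when |x - y| <= tol + tol * |y|.
x = torch.tensor([1.00001])
y = torch.tensor([1.0])
assert allclose(x, y)                # diff ~1e-5 <= 1e-5 + 1e-5 * 1.0
assert not allclose(x, y, tol=1e-7)  # a tighter tol rejects the same pair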
Example n. 41
    def check_trained_model(self, model, alternate_seed=False):
        # Checks a training run seeded with learning_rate = 0.1
        (a, b) = self.alternate_trained_model if alternate_seed else self.default_trained_model
        self.assertTrue(torch.allclose(model.a, a))
        self.assertTrue(torch.allclose(model.b, b))
Example n. 42
    t_encoder = torch.nn.TransformerEncoderLayer(N,
                                                 H,
                                                 dim_feedforward=4 * N,
                                                 dropout=0,
                                                 activation="relu")  #.cuda()
    t_encoder.self_attn.in_proj_weight.data = torch.cat(
        (op.WQ.transpose(0, 1).reshape(N, N), op.WK.transpose(0, 1).reshape(
            N, N), op.WV.transpose(0, 1).reshape(N, N)),
        dim=0)
    t_encoder.self_attn.in_proj_bias.data = torch.cat(
        (op.BQ.transpose(0, 1).reshape(N), op.BK.transpose(0, 1).reshape(N),
         op.BV.transpose(0, 1).reshape(N)))
    t_encoder.self_attn.out_proj.weight.data = op.WO.transpose(0, 2).reshape(
        N, N)
    t_encoder.self_attn.out_proj.bias.data = op.BO
    t_encoder.linear1.weight.data = op.linear1_w
    t_encoder.linear1.bias.data = op.linear1_b
    t_encoder.linear2.weight.data = op.linear2_w
    t_encoder.linear2.bias.data = op.linear2_b
    t_encoder.norm1.weight.data = op.norm1_scale
    t_encoder.norm1.bias.data = op.norm1_bias
    t_encoder.norm2.weight.data = op.norm2_scale
    t_encoder.norm2.bias.data = op.norm2_bias
    t_encoder.train()
    res_torch = t_encoder.forward(X)

    result = torch.allclose(res_dace, res_torch, rtol=1e-03, atol=1e-05)
    print('Result:', result)
    if not result:
        exit(1)
Example n. 43
        # flatten from N x 1024 x 1 x 1 to N x 1024, where N is the batch size

        x = torch.flatten(input=x, start_dim=1)

        return x


# check that the two state dicts are equal

if __name__ == '__main__':

    param_path = 'mx-h64-1024_0d3-1.17.pkl'

    net = Net(param_path)

    for key, value in net.old_state_dict.items():

        # skip layer 19

        if '19' in key:
            continue

        param1 = value
        param2 = net.state_dict()[net.map_param_name(key)]

        # torch.allclose() because parameters are floats

        is_equal = torch.allclose(param1, param2)

        print(is_equal)
Example n. 44
                "sliced_tensor_names": ["x", "target", "output"],
                # Define pipeline stage partition by specifying cut points.
                # 2-stage cut. It's a cut on tensor "12".
                "pipeline_cut_info_string": "12",
            },
            "allreduce_post_accumulation": True,
        },
    }
)

trainer = ORTTrainer(model, schema, adam_config, apply_loss, trainer_config)

loss_history = []
for i in range(5):
    l, p = trainer.train_step(x.to(cuda_device), y.to(cuda_device))
    loss_history.append(l)

print("loss history: ", loss_history)

# Valid ranks are [0, 1, 2, 3].
# [0, 2] forms the 2-stage pipeline in the 1st data parallel group.
# [1, 3] forms the 2-stage pipeline in the 2nd data parallel group.
last_pipeline_stage_ranks = [2, 3]

# The loss values computed at the last pipeline stages. Note that intermediate
# stages may not have valid loss values, so we don't check them.
expected_loss_history = [0.9420, 0.6608, 0.8944, 1.2279, 1.1173]
if rank in last_pipeline_stage_ranks:
    for result, expected in zip(loss_history, expected_loss_history):
        assert torch.allclose(result.cpu(), torch.tensor([expected], device="cpu"), 1e-03)
Example n. 45
def test_vote_fusion():
    img_meta = {
        'ori_shape': (530, 730, 3),
        'img_shape': (600, 826, 3),
        'pad_shape': (608, 832, 3),
        'scale_factor': torch.tensor([1.1315, 1.1321, 1.1315, 1.1321]),
        'flip': False,
        'pcd_horizontal_flip': False,
        'pcd_vertical_flip': False,
        'pcd_trans': torch.tensor([0., 0., 0.]),
        'pcd_scale_factor': 1.0308290128214932,
        'pcd_rotation': torch.tensor([[0.9747, 0.2234, 0.0000],
                                      [-0.2234, 0.9747, 0.0000],
                                      [0.0000, 0.0000, 1.0000]]),
        'transformation_3d_flow': ['HF', 'R', 'S', 'T']
    }

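    # Added note: rt_mat and k_mat below play the role of camera extrinsics
    # and intrinsics; depth2img composes them (with an axis permutation and
    # a transpose) so that points in the depth frame project to image pixels.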
    rt_mat = torch.tensor([[0.979570, 0.047954, -0.195330],
                           [0.047954, 0.887470, 0.458370],
                           [0.195330, -0.458370, 0.867030]])
    k_mat = torch.tensor([[529.5000, 0.0000, 365.0000],
                          [0.0000, 529.5000, 265.0000],
                          [0.0000, 0.0000, 1.0000]])
    rt_mat = rt_mat.new_tensor([[1, 0, 0], [0, 0, -1], [0, 1, 0]
                                ]) @ rt_mat.transpose(1, 0)
    depth2img = k_mat @ rt_mat
    img_meta['depth2img'] = depth2img

    bboxes = torch.tensor([[[
        5.4286e+02, 9.8283e+01, 6.1700e+02, 1.6742e+02, 9.7922e-01, 3.0000e+00
    ], [
        4.2613e+02, 8.4646e+01, 4.9091e+02, 1.6237e+02, 9.7848e-01, 3.0000e+00
    ], [
        2.5606e+02, 7.3244e+01, 3.7883e+02, 1.8471e+02, 9.7317e-01, 3.0000e+00
    ], [
        6.0104e+02, 1.0648e+02, 6.6757e+02, 1.9216e+02, 8.4607e-01, 3.0000e+00
    ], [
        2.2923e+02, 1.4984e+02, 7.0163e+02, 4.6537e+02, 3.5719e-01, 0.0000e+00
    ], [
        2.5614e+02, 7.4965e+01, 3.3275e+02, 1.5908e+02, 2.8688e-01, 3.0000e+00
    ], [
        9.8718e+00, 1.4142e+02, 2.0213e+02, 3.3878e+02, 1.0935e-01, 3.0000e+00
    ], [
        6.1930e+02, 1.1768e+02, 6.8505e+02, 2.0318e+02, 1.0720e-01, 3.0000e+00
    ]]])

    seeds_3d = torch.tensor([[[0.044544, 1.675476, -1.531831],
                              [2.500625, 7.238662, -0.737675],
                              [-0.600003, 4.827733, -0.084022],
                              [1.396212, 3.994484, -1.551180],
                              [-2.054746, 2.012759, -0.357472],
                              [-0.582477, 6.580470, -1.466052],
                              [1.313331, 5.722039, 0.123904],
                              [-1.107057, 3.450359, -1.043422],
                              [1.759746, 5.655951, -1.519564],
                              [-0.203003, 6.453243, 0.137703],
                              [-0.910429, 0.904407, -0.512307],
                              [0.434049, 3.032374, -0.763842],
                              [1.438146, 2.289263, -1.546332],
                              [0.575622, 5.041906, -0.891143],
                              [-1.675931, 1.417597, -1.588347]]])

    imgs = torch.linspace(-1, 1, steps=608 * 832).reshape(1, 608, 832)
    imgs = imgs.repeat(3, 1, 1)[None]

    expected_tensor1 = torch.tensor(
        [[[
            0.000000e+00, -0.000000e+00, 0.000000e+00, -0.000000e+00,
            0.000000e+00, 1.193706e-01, -0.000000e+00, -2.879214e-01,
            -0.000000e+00, 0.000000e+00, 1.422463e-01, -6.474612e-01,
            -0.000000e+00, 1.490057e-02, 0.000000e+00
        ],
          [
              0.000000e+00, -0.000000e+00, -0.000000e+00, 0.000000e+00,
              0.000000e+00, -1.873745e+00, -0.000000e+00, 1.576240e-01,
              0.000000e+00, -0.000000e+00, -3.646177e-02, -7.751858e-01,
              0.000000e+00, 9.593642e-02, 0.000000e+00
          ],
          [
              0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
              0.000000e+00, -6.263277e-02, 0.000000e+00, -3.646387e-01,
              0.000000e+00, 0.000000e+00, -5.875812e-01, -6.263450e-02,
              0.000000e+00, 1.149264e-01, 0.000000e+00
          ],
          [
              0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
              0.000000e+00, 8.899736e-01, 0.000000e+00, 9.019017e-01,
              0.000000e+00, 0.000000e+00, 6.917775e-01, 8.899733e-01,
              0.000000e+00, 9.812444e-01, 0.000000e+00
          ],
          [
              -0.000000e+00, -0.000000e+00, -0.000000e+00, -0.000000e+00,
              -0.000000e+00, -4.516903e-01, -0.000000e+00, -2.315422e-01,
              -0.000000e+00, -0.000000e+00, -4.197519e-01, -4.516906e-01,
              -0.000000e+00, -1.547615e-01, -0.000000e+00
          ],
          [
              0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
              0.000000e+00, 3.571937e-01, 0.000000e+00, 0.000000e+00,
              0.000000e+00, 0.000000e+00, 0.000000e+00, 3.571937e-01,
              0.000000e+00, 0.000000e+00, 0.000000e+00
          ],
          [
              0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
              0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
              0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
              0.000000e+00, 0.000000e+00, 0.000000e+00
          ],
          [
              0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
              0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
              0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
              0.000000e+00, 0.000000e+00, 0.000000e+00
          ],
          [
              0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
              0.000000e+00, 0.000000e+00, 0.000000e+00, 9.731653e-01,
              0.000000e+00, 0.000000e+00, 1.093455e-01, 0.000000e+00,
              0.000000e+00, 8.460656e-01, 0.000000e+00
          ],
          [
              0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
              0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
              0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
              0.000000e+00, 0.000000e+00, 0.000000e+00
          ],
          [
              0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
              0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
              0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
              0.000000e+00, 0.000000e+00, 0.000000e+00
          ],
          [
              0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
              0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
              0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
              0.000000e+00, 0.000000e+00, 0.000000e+00
          ],
          [
              0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
              0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
              0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
              0.000000e+00, 0.000000e+00, 0.000000e+00
          ],
          [
              0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
              0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
              0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
              0.000000e+00, 0.000000e+00, 0.000000e+00
          ],
          [
              0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
              0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
              0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
              0.000000e+00, 0.000000e+00, 0.000000e+00
          ],
          [
              2.316288e-03, -1.948284e-03, -3.694394e-03, 2.176163e-04,
              -3.882605e-03, -1.901490e-03, -3.355042e-03, -1.774631e-03,
              -6.981542e-04, -3.886823e-03, -1.302233e-03, -1.189933e-03,
              2.540967e-03, -1.834944e-03, 1.032048e-03
          ],
          [
              2.316288e-03, -1.948284e-03, -3.694394e-03, 2.176163e-04,
              -3.882605e-03, -1.901490e-03, -3.355042e-03, -1.774631e-03,
              -6.981542e-04, -3.886823e-03, -1.302233e-03, -1.189933e-03,
              2.540967e-03, -1.834944e-03, 1.032048e-03
          ],
          [
              2.316288e-03, -1.948284e-03, -3.694394e-03, 2.176163e-04,
              -3.882605e-03, -1.901490e-03, -3.355042e-03, -1.774631e-03,
              -6.981542e-04, -3.886823e-03, -1.302233e-03, -1.189933e-03,
              2.540967e-03, -1.834944e-03, 1.032048e-03
          ]]])

    expected_tensor2 = torch.tensor([[
        False, False, False, False, False, True, False, True, False, False,
        True, True, False, True, False, False, False, False, False, False,
        False, False, True, False, False, False, False, False, True, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, True, False
    ]])

    expected_tensor3 = torch.tensor(
        [[[
            -0.000000e+00, -0.000000e+00, -0.000000e+00, -0.000000e+00,
            0.000000e+00, -0.000000e+00, -0.000000e+00, 0.000000e+00,
            -0.000000e+00, -0.000000e+00, 0.000000e+00, -0.000000e+00,
            -0.000000e+00, 1.720988e-01, 0.000000e+00
        ],
          [
              0.000000e+00, -0.000000e+00, -0.000000e+00, 0.000000e+00,
              -0.000000e+00, 0.000000e+00, -0.000000e+00, 0.000000e+00,
              0.000000e+00, -0.000000e+00, 0.000000e+00, 0.000000e+00,
              0.000000e+00, 4.824460e-02, 0.000000e+00
          ],
          [
              -0.000000e+00, -0.000000e+00, -0.000000e+00, -0.000000e+00,
              -0.000000e+00, -0.000000e+00, -0.000000e+00, 0.000000e+00,
              -0.000000e+00, -0.000000e+00, -0.000000e+00, -0.000000e+00,
              -0.000000e+00, 1.447314e-01, -0.000000e+00
          ],
          [
              0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
              0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
              0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
              0.000000e+00, 9.759269e-01, 0.000000e+00
          ],
          [
              -0.000000e+00, -0.000000e+00, -0.000000e+00, -0.000000e+00,
              -0.000000e+00, -0.000000e+00, -0.000000e+00, -0.000000e+00,
              -0.000000e+00, -0.000000e+00, -0.000000e+00, -0.000000e+00,
              -0.000000e+00, -1.631542e-01, -0.000000e+00
          ],
          [
              0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
              0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
              0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
              0.000000e+00, 0.000000e+00, 0.000000e+00
          ],
          [
              0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
              0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
              0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
              0.000000e+00, 0.000000e+00, 0.000000e+00
          ],
          [
              0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
              0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
              0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
              0.000000e+00, 0.000000e+00, 0.000000e+00
          ],
          [
              0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
              0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
              0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
              0.000000e+00, 1.072001e-01, 0.000000e+00
          ],
          [
              0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
              0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
              0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
              0.000000e+00, 0.000000e+00, 0.000000e+00
          ],
          [
              0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
              0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
              0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
              0.000000e+00, 0.000000e+00, 0.000000e+00
          ],
          [
              0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
              0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
              0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
              0.000000e+00, 0.000000e+00, 0.000000e+00
          ],
          [
              0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
              0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
              0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
              0.000000e+00, 0.000000e+00, 0.000000e+00
          ],
          [
              0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
              0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
              0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
              0.000000e+00, 0.000000e+00, 0.000000e+00
          ],
          [
              0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
              0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
              0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
              0.000000e+00, 0.000000e+00, 0.000000e+00
          ],
          [
              2.316288e-03, -1.948284e-03, -3.694394e-03, 2.176163e-04,
              -3.882605e-03, -1.901490e-03, -3.355042e-03, -1.774631e-03,
              -6.981542e-04, -3.886823e-03, -1.302233e-03, -1.189933e-03,
              2.540967e-03, -1.834944e-03, 1.032048e-03
          ],
          [
              2.316288e-03, -1.948284e-03, -3.694394e-03, 2.176163e-04,
              -3.882605e-03, -1.901490e-03, -3.355042e-03, -1.774631e-03,
              -6.981542e-04, -3.886823e-03, -1.302233e-03, -1.189933e-03,
              2.540967e-03, -1.834944e-03, 1.032048e-03
          ],
          [
              2.316288e-03, -1.948284e-03, -3.694394e-03, 2.176163e-04,
              -3.882605e-03, -1.901490e-03, -3.355042e-03, -1.774631e-03,
              -6.981542e-04, -3.886823e-03, -1.302233e-03, -1.189933e-03,
              2.540967e-03, -1.834944e-03, 1.032048e-03
          ]]])

    fusion = VoteFusion()
    out1, out2 = fusion(imgs, bboxes, seeds_3d, [img_meta])
    assert torch.allclose(expected_tensor1, out1[:, :, :15], rtol=1e-3)
    assert torch.allclose(expected_tensor2.float(), out2.float(), rtol=1e-3)
    assert torch.allclose(expected_tensor3, out1[:, :, 30:45], rtol=1e-3)

    out1, out2 = fusion(imgs, bboxes[:, :2], seeds_3d, [img_meta])
    out1 = out1[:, :15, 30:45]
    out2 = out2[:, 30:45].float()
    assert torch.allclose(torch.zeros_like(out1), out1, rtol=1e-3)
    assert torch.allclose(torch.zeros_like(out2), out2, rtol=1e-3)
Example n. 46
0
    def _test_state_dict(self, weight, bias, input, constructor):
        weight = Variable(weight, requires_grad=True)
        bias = Variable(bias, requires_grad=True)
        input = Variable(input)

        def fn_base(optimizer, weight, bias):
            optimizer.zero_grad()
            i = input_cuda if weight.is_cuda else input
            loss = (weight.mv(i) + bias).pow(2).sum()
            loss.backward()
            return loss

        optimizer = constructor(weight, bias)
        fn = functools.partial(fn_base, optimizer, weight, bias)

        # Prime the optimizer
        for _i in range(20):
            optimizer.step(fn)
        # Clone the weights and construct new optimizer for them
        weight_c = Variable(weight.data.clone(), requires_grad=True)
        bias_c = Variable(bias.data.clone(), requires_grad=True)
        optimizer_c = constructor(weight_c, bias_c)
        fn_c = functools.partial(fn_base, optimizer_c, weight_c, bias_c)
        # Load state dict
        state_dict = deepcopy(optimizer.state_dict())
        state_dict_c = deepcopy(optimizer.state_dict())
        optimizer_c.load_state_dict(state_dict_c)

        precision = 0.0001
        # Run both optimizations in parallel
        for _i in range(20):
            optimizer.step(fn)
            optimizer_c.step(fn_c)
            assert torch.allclose(weight, weight_c, atol=precision)
            assert torch.allclose(bias, bias_c, atol=precision)

        # Make sure state dict wasn't modified
        assert assert_dict_equal(state_dict, state_dict_c)

        # Check that state dict can be loaded even when we cast parameters
        # to a different type and move to a different device.
        if not torch.cuda.is_available():
            return

        input_cuda = Variable(input.data.float().cuda())
        weight_cuda = Variable(weight.data.float().cuda(), requires_grad=True)
        bias_cuda = Variable(bias.data.float().cuda(), requires_grad=True)
        optimizer_cuda = constructor(weight_cuda, bias_cuda)
        fn_cuda = functools.partial(fn_base, optimizer_cuda, weight_cuda,
                                    bias_cuda)

        state_dict = deepcopy(optimizer.state_dict())
        state_dict_c = deepcopy(optimizer.state_dict())
        optimizer_cuda.load_state_dict(state_dict_c)

        # Make sure state dict wasn't modified
        assert assert_dict_equal(state_dict, state_dict_c)

        for _i in range(20):
            optimizer.step(fn)
            optimizer_cuda.step(fn_cuda)
            assert torch.allclose(weight, weight_cuda.cpu().to(weight.dtype),
                                  atol=precision)
            assert torch.allclose(bias, bias_cuda.cpu().to(bias.dtype),
                                  atol=precision)

        # validate deepcopy() copies all public attributes
        def getPublicAttr(obj):
            return set(k for k in obj.__dict__ if not k.startswith('_'))

        assert getPublicAttr(optimizer) == getPublicAttr(deepcopy(optimizer))
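
The round trip exercised above boils down to the standard `state_dict()` / `load_state_dict()` contract of `torch.optim`; below is a minimal self-contained sketch of that contract, using a toy parameter and plain SGD rather than the test's actual constructor:

import torch

# Toy parameter and optimizer; any torch.optim optimizer follows the
# same state_dict()/load_state_dict() contract exercised above.
w = torch.randn(3, requires_grad=True)
opt = torch.optim.SGD([w], lr=0.1, momentum=0.9)

for _ in range(5):  # prime the optimizer so momentum buffers exist
    opt.zero_grad()
    (w ** 2).sum().backward()
    opt.step()

# Clone the parameter, rebuild an identically configured optimizer,
# then copy the state (momentum buffers etc.) across.
w_c = w.detach().clone().requires_grad_(True)
opt_c = torch.optim.SGD([w_c], lr=0.1, momentum=0.9)
opt_c.load_state_dict(opt.state_dict())

# Both optimizers should now take identical steps.
for _ in range(5):
    for o, p in ((opt, w), (opt_c, w_c)):
        o.zero_grad()
        (p ** 2).sum().backward()
        o.step()
assert torch.allclose(w, w_c, atol=1e-4)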
Example n. 47
0
    def test_mean(self):
        mu = mean(self.mu, self.A, self.b)
        self.assertTrue(torch.allclose(mu, self.mc_mu, rtol=1e-1))
        return self.mc_mu, mu
Example n. 48
0
    def test_covariance(self):
        cov = covariance(self.cov, self.A)
        self.assertTrue(torch.allclose(cov, self.mc_cov, rtol=1e-1))
        return self.mc_cov, cov
Example n. 49
0
def test_neighbor_sampler_on_cora(get_dataset):
    dataset = get_dataset(name='Cora')
    data = dataset[0]

    batch = torch.arange(10)
    loader = NeighborSampler(data.edge_index,
                             sizes=[-1, -1, -1],
                             node_idx=batch,
                             batch_size=10)

    class SAGE(torch.nn.Module):
        def __init__(self, in_channels, out_channels):
            super().__init__()

            self.convs = torch.nn.ModuleList()
            self.convs.append(SAGEConv(in_channels, 16))
            self.convs.append(SAGEConv(16, 16))
            self.convs.append(SAGEConv(16, out_channels))

        def batch(self, x, adjs):
            for i, (edge_index, _, size) in enumerate(adjs):
                x_target = x[:size[1]]  # Target nodes are always placed first.
                x = self.convs[i]((x, x_target), edge_index)
            return x

        def full(self, x, edge_index):
            for conv in self.convs:
                x = conv(x, edge_index)
            return x

    model = SAGE(dataset.num_features, dataset.num_classes)

    _, n_id, adjs = next(iter(loader))
    out1 = model.batch(data.x[n_id], adjs)
    out2 = model.full(data.x, data.edge_index)[batch]
    assert torch.allclose(out1, out2, atol=1e-7)

    class GAT(torch.nn.Module):
        def __init__(self, in_channels, out_channels):
            super().__init__()

            self.convs = torch.nn.ModuleList()
            self.convs.append(GATConv(in_channels, 16, heads=2))
            self.convs.append(GATConv(32, 16, heads=2))
            self.convs.append(GATConv(32, out_channels, heads=2, concat=False))

        def batch(self, x, adjs):
            for i, (edge_index, _, size) in enumerate(adjs):
                x_target = x[:size[1]]  # Target nodes are always placed first.
                x = self.convs[i]((x, x_target), edge_index)
            return x

        def full(self, x, edge_index):
            for conv in self.convs:
                x = conv(x, edge_index)
            return x

    model = GAT(dataset.num_features, dataset.num_classes)

    _, n_id, adjs = next(iter(loader))
    out1 = model.batch(data.x[n_id], adjs)
    out2 = model.full(data.x, data.edge_index)[batch]
    assert torch.allclose(out1, out2, atol=1e-7)
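
The `x_target = x[:size[1]]` slicing in both models relies on NeighborSampler's bipartite convention: target nodes are always placed first in `n_id`, so the target features are a prefix of the source features. A minimal sketch of one sampled hop on a toy graph; the import path is an assumption for the pre-2.0 torch_geometric releases this test targets:

import torch
from torch_geometric.nn import SAGEConv
# Import path varies by release; older versions expose it from
# torch_geometric.data instead.
from torch_geometric.loader import NeighborSampler

# Toy path graph 0-1-2-3 with edges in both directions.
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                           [1, 0, 2, 1, 3, 2]])
x = torch.randn(4, 8)

loader = NeighborSampler(edge_index, sizes=[-1],
                         node_idx=torch.arange(2), batch_size=2)
conv = SAGEConv(8, 16)

batch_size, n_id, adj = next(iter(loader))
edge_index_hop, e_id, size = adj  # a single hop yields a single tuple
x_src = x[n_id]               # features of all sampled nodes
x_target = x_src[:size[1]]    # target nodes are always placed first
out = conv((x_src, x_target), edge_index_hop)
assert out.shape == (size[1], 16)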
Example n. 50
0
    def test_condition_on_observations(self):
        for (train_iteration_fidelity, train_data_fidelity) in [
            (False, True),
            (True, False),
            (True, True),
        ]:
            for batch_shape in (torch.Size(), torch.Size([2])):
                for num_outputs in (1, 2):
                    for double in (False, True):
                        num_dim = 1 + train_iteration_fidelity + train_data_fidelity
                        tkwargs = {
                            "device": self.device,
                            "dtype": torch.double if double else torch.float,
                        }
                        model, model_kwargs = self._get_model_and_data(
                            batch_shape=batch_shape,
                            num_outputs=num_outputs,
                            train_iteration_fidelity=train_iteration_fidelity,
                            train_data_fidelity=train_data_fidelity,
                            **tkwargs,
                        )
                        # evaluate model
                        model.posterior(
                            torch.rand(torch.Size([4, num_dim]), **tkwargs))
                        # test condition_on_observations
                        fant_shape = torch.Size([2])
                        # fantasize at different input points
                        X_fant, Y_fant = _get_random_data_with_fidelity(
                            fant_shape + batch_shape,
                            num_outputs,
                            n=3,
                            train_iteration_fidelity=train_iteration_fidelity,
                            train_data_fidelity=train_data_fidelity,
                            **tkwargs,
                        )
                        c_kwargs = ({
                            "noise": torch.full_like(Y_fant, 0.01)
                        } if isinstance(model, FixedNoiseGP) else {})
                        cm = model.condition_on_observations(
                            X_fant, Y_fant, **c_kwargs)
                        # fantasize at the same input points (check
                        # proper broadcasting)
                        c_kwargs_same_inputs = ({
                            "noise":
                            torch.full_like(Y_fant[0], 0.01)
                        } if isinstance(model, FixedNoiseGP) else {})
                        cm_same_inputs = model.condition_on_observations(
                            X_fant[0], Y_fant, **c_kwargs_same_inputs)

                        test_Xs = [
                            # test broadcasting single input across fantasy and
                            # model batches
                            torch.rand(4, num_dim, **tkwargs),
                            # separate input for each model batch and broadcast across
                            # fantasy batches
                            torch.rand(batch_shape + torch.Size([4, num_dim]),
                                       **tkwargs),
                            # separate input for each model and fantasy batch
                            torch.rand(
                                fant_shape + batch_shape +
                                torch.Size([4, num_dim]),
                                **tkwargs,
                            ),
                        ]
                        for test_X in test_Xs:
                            posterior = cm.posterior(test_X)
                            self.assertEqual(
                                posterior.mean.shape,
                                fant_shape + batch_shape +
                                torch.Size([4, num_outputs]),
                            )
                            posterior_same_inputs = cm_same_inputs.posterior(
                                test_X)
                            self.assertEqual(
                                posterior_same_inputs.mean.shape,
                                fant_shape + batch_shape +
                                torch.Size([4, num_outputs]),
                            )

                            # check that fantasies of batched model are correct
                            if len(batch_shape) > 0 and test_X.dim() == 2:
                                state_dict_non_batch = {
                                    key: (val[0] if val.numel() > 1 else val)
                                    for key, val in model.state_dict().items()
                                }
                                model_kwargs_non_batch = {
                                    "train_X":
                                    model_kwargs["train_X"][0],
                                    "train_Y":
                                    model_kwargs["train_Y"][0],
                                    "train_iteration_fidelity":
                                    model_kwargs["train_iteration_fidelity"],
                                    "train_data_fidelity":
                                    model_kwargs["train_data_fidelity"],
                                }
                                if "train_Yvar" in model_kwargs:
                                    model_kwargs_non_batch[
                                        "train_Yvar"] = model_kwargs[
                                            "train_Yvar"][0]
                                model_non_batch = type(model)(
                                    **model_kwargs_non_batch)
                                model_non_batch.load_state_dict(
                                    state_dict_non_batch)
                                model_non_batch.eval()
                                model_non_batch.likelihood.eval()
                                model_non_batch.posterior(
                                    torch.rand(torch.Size([4, num_dim]),
                                               **tkwargs))
                                c_kwargs = ({
                                    "noise":
                                    torch.full_like(Y_fant[0, 0, :], 0.01)
                                } if isinstance(model, FixedNoiseGP) else {})
                                mnb = model_non_batch
                                cm_non_batch = mnb.condition_on_observations(
                                    X_fant[0][0], Y_fant[:, 0, :], **c_kwargs)
                                non_batch_posterior = cm_non_batch.posterior(
                                    test_X)
                                self.assertTrue(
                                    torch.allclose(
                                        posterior_same_inputs.mean[:, 0, ...],
                                        non_batch_posterior.mean,
                                        atol=1e-3,
                                    ))
                                self.assertTrue(
                                    torch.allclose(
                                        posterior_same_inputs.mvn.
                                        covariance_matrix[:, 0, :, :],
                                        non_batch_posterior.mvn.
                                        covariance_matrix,
                                        atol=1e-3,
                                    ))
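
The core fantasy-conditioning pattern above reduces to a few lines; here is a minimal sketch using a plain SingleTaskGP, on the assumption that the fidelity-aware models in the test share the same condition_on_observations contract:

import torch
from botorch.models import SingleTaskGP

train_X = torch.rand(8, 2)
train_Y = torch.rand(8, 1)
model = SingleTaskGP(train_X, train_Y)

# Evaluate the model once before conditioning, as the test does.
model.posterior(torch.rand(4, 2))

# Condition on three fantasy observations and query the new posterior.
X_fant = torch.rand(3, 2)
Y_fant = torch.rand(3, 1)
cm = model.condition_on_observations(X_fant, Y_fant)
posterior = cm.posterior(torch.rand(4, 2))
print(posterior.mean.shape)  # torch.Size([4, 1])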
Example n. 51
0
def test_label_is_same(transform):
    a = torch.arange(20).reshape(1, 1, 4, 5).float()
    for p in transform.params:
        aug = transform.apply_aug_image(a, **{transform.pname: p})
        deaug = transform.apply_deaug_label(aug, **{transform.pname: p})
        assert torch.allclose(aug, deaug)
Example n. 52
0
def hindsight_relabel_fn(buffer,
                         result,
                         info,
                         her_proportion,
                         achieved_goal_field="observation.achieved_goal",
                         desired_goal_field="observation.desired_goal",
                         reward_fn=l2_dist_close_reward_fn):
    """Randomly get `batch_size` hindsight relabeled trajectories.

    Note: The environments the samples come from are ordered in the
        returned batch.

    Args:
        buffer (ReplayBuffer): for access to future achieved goals.
        result (nest): of tensors of the sampled exp
        info (BatchInfo): of the sampled result
        her_proportion (float): proportion of hindsight relabeled experience.
        achieved_goal_field (str): path to the achieved_goal field in
            exp nest.
        desired_goal_field (str): path to the desired_goal field in the
            exp nest.
        reward_fn (Callable): function to recompute the reward based on
            achieved_goal and desired_goal.  The default gives reward 0 when
            the L2 distance is less than 0.05 and -1 otherwise, the same as
            in the suite_robotics environments.
    Returns:
        tuple:
            - nested Tensors: The samples. Its shapes are [batch_size, batch_length, ...]
            - BatchInfo: Information about the batch. Its shapes are [batch_size].
                - env_ids: environment id for each sequence
                - positions: starting position in the replay buffer for each sequence.
                - importance_weights: priority divided by the average of all
                    non-zero priorities in the buffer.
    """
    if her_proportion == 0:
        return result, info

    env_ids = info.env_ids
    start_pos = info.positions
    shape = alf.nest.get_nest_shape(result)
    batch_size, batch_length = shape[:2]
    # TODO: add support for batch_length > 2.
    assert batch_length == 2, shape

    # relabel only these sampled indices
    her_cond = torch.rand(batch_size) < her_proportion
    (her_indices, ) = torch.where(her_cond)
    (non_her_indices, ) = torch.where(torch.logical_not(her_cond))

    last_step_pos = start_pos[her_indices] + batch_length - 1
    last_env_ids = env_ids[her_indices]
    # Get x, y indices of LAST steps
    dist = buffer.steps_to_episode_end(last_step_pos, last_env_ids)
    if alf.summary.should_record_summaries():
        alf.summary.scalar(
            "replayer/" + buffer._name + ".mean_steps_to_episode_end",
            torch.mean(dist.type(torch.float32)))

    # get random future state
    future_idx = buffer.circular(last_step_pos + (torch.rand(*dist.shape) *
                                                  (dist + 1)).to(torch.int64))
    achieved_goals = alf.nest.get_field(buffer._buffer, achieved_goal_field)
    future_ag = achieved_goals[(last_env_ids, future_idx)].unsqueeze(1)

    # relabel desired goal
    result_desired_goal = alf.nest.get_field(result, desired_goal_field)
    relabeled_goal = result_desired_goal.clone()
    her_batch_index_tuple = (her_indices.unsqueeze(1),
                             torch.arange(batch_length).unsqueeze(0))
    relabeled_goal[her_batch_index_tuple] = future_ag

    # recompute rewards
    result_ag = alf.nest.get_field(result, achieved_goal_field)
    relabeled_rewards = reward_fn(result_ag,
                                  relabeled_goal,
                                  device=buffer._device)
    if alf.summary.should_record_summaries():
        alf.summary.scalar(
            "replayer/" + buffer._name + ".reward_mean_before_relabel",
            torch.mean(result.reward[her_indices][:-1]))
        alf.summary.scalar(
            "replayer/" + buffer._name + ".reward_mean_after_relabel",
            torch.mean(relabeled_rewards[her_indices][:-1]))
    # assert reward function is the same as used by the environment.
    if not torch.allclose(relabeled_rewards[non_her_indices],
                          result.reward[non_her_indices]):
        msg = ("hindsight_relabel_fn:\nrelabeled_reward\n{}\n!=\n" +
               "env_reward\n{}\nag:\n{}\ndg:\n{}\nenv_ids:\n{}\nstart_pos:\n{}"
               ).format(relabeled_rewards[non_her_indices],
                        result.reward[non_her_indices],
                        result_ag[non_her_indices],
                        result_desired_goal[non_her_indices],
                        env_ids[non_her_indices], start_pos[non_her_indices])
        logging.warning(msg)
        # assert False, msg
        relabeled_rewards[non_her_indices] = result.reward[non_her_indices]

    result = alf.nest.transform_nest(result, desired_goal_field,
                                     lambda _: relabeled_goal)
    result = result._replace(reward=relabeled_rewards)
    return result, info
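
For reference, the contract of the default `reward_fn` described in the docstring (reward 0 within an L2 distance of 0.05, -1 otherwise) can be sketched as follows; this is a hypothetical reimplementation for illustration, not ALF's actual `l2_dist_close_reward_fn`:

import torch

def l2_dist_close_reward_fn_sketch(achieved_goal, desired_goal,
                                   threshold=0.05, device="cpu"):
    # Reward 0 when the goals are within `threshold` in L2 distance,
    # -1 otherwise, per the docstring above.
    dist = torch.norm(achieved_goal - desired_goal, dim=-1)
    reward = torch.where(dist < threshold,
                         torch.zeros_like(dist),
                         -torch.ones_like(dist))
    return reward.to(device)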
Example n. 53
0
    def test_mean(self):
        single_mean = [mean(c, self.A, self.b) for c in self.batch_mean]
        batch_mean = mean(self.batch_mean, self.A, self.b)
        close = [torch.allclose(r, c) for r, c in zip(batch_mean, single_mean)]
        self.assertNotIn(False, close)
        return single_mean, batch_mean
Example n. 54
0
def test_add_transform():
    transform = tta.Add(values=[-1, 0, 1])
    a = torch.arange(20).reshape(1, 1, 4, 5).float()
    for p in transform.params:
        aug = transform.apply_aug_image(a, **{transform.pname: p})
        assert torch.allclose(aug, a + p)
Example n. 55
0
    def test_variance(self):
        single_cov = [variance(c, self.A) for c in self.batch_cov]
        batch_cov = variance(self.batch_cov, self.A)
        close = [torch.allclose(r, c) for r, c in zip(batch_cov, single_cov)]
        self.assertNotIn(False, close)
        return single_cov, batch_cov
Example n. 56
0
def test_multiply_transform():
    transform = tta.Multiply(factors=[-1, 0, 1])
    a = torch.arange(20).reshape(1, 1, 4, 5).float()
    for p in transform.params:
        aug = transform.apply_aug_image(a, **{transform.pname: p})
        assert torch.allclose(aug, a * p)
Example n. 57
0
    def test_variance(self):
        var = variance(self.cov, self.A)
        self.assertTrue(torch.allclose(var, self.mc_var, rtol=1e-1))
        return self.mc_var, var
Example n. 58
0
    def test_condition_on_observations(self):
        for batch_shape, dtype in itertools.product(
            (torch.Size(), torch.Size([2])), (torch.float, torch.double)):
            tkwargs = {"device": self.device, "dtype": dtype}
            X_dim = 2

            model, model_kwargs = self._get_model_and_data(
                batch_shape=batch_shape, X_dim=X_dim, **tkwargs)
            train_X = model_kwargs["datapoints"]
            train_comp = model_kwargs["comparisons"]

            # evaluate model
            model.posterior(torch.rand(torch.Size([4, X_dim]), **tkwargs))

            # test condition_on_observations with prior mode
            prior_m = PairwiseGP(None, None)
            cond_m = prior_m.condition_on_observations(train_X, train_comp)
            self.assertTrue(cond_m.datapoints is train_X)
            self.assertTrue(cond_m.comparisons is train_comp)

            # fantasize at different input points
            fant_shape = torch.Size([2])
            X_fant, Y_fant, comp_fant = self._make_rand_mini_data(
                batch_shape=fant_shape + batch_shape, X_dim=X_dim, **tkwargs)

            # cannot condition on non-pairwise Ys
            with self.assertRaises(RuntimeError):
                model.condition_on_observations(X_fant, comp_fant[..., 0])
            cm = model.condition_on_observations(X_fant, comp_fant)
            # make sure it's a deep copy
            self.assertTrue(model is not cm)

            # fantasize at same input points (check proper broadcasting)
            cm_same_inputs = model.condition_on_observations(
                X_fant[0], comp_fant)

            test_Xs = [
                # test broadcasting single input across fantasy and model batches
                torch.rand(4, X_dim, **tkwargs),
                # separate input for each model batch and broadcast across
                # fantasy batches
                torch.rand(batch_shape + torch.Size([4, X_dim]), **tkwargs),
                # separate input for each model and fantasy batch
                torch.rand(fant_shape + batch_shape + torch.Size([4, X_dim]),
                           **tkwargs),
            ]
            for test_X in test_Xs:
                posterior = cm.posterior(test_X)
                self.assertEqual(posterior.mean.shape,
                                 fant_shape + batch_shape + torch.Size([4, 1]))
                posterior_same_inputs = cm_same_inputs.posterior(test_X)
                self.assertEqual(
                    posterior_same_inputs.mean.shape,
                    fant_shape + batch_shape + torch.Size([4, 1]),
                )

                # check that fantasies of batched model are correct
                if len(batch_shape) > 0 and test_X.dim() == 2:
                    state_dict_non_batch = {
                        key: (val[0] if val.numel() > 1 else val)
                        for key, val in model.state_dict().items()
                    }
                    model_kwargs_non_batch = {
                        "datapoints": model_kwargs["datapoints"][0],
                        "comparisons": model_kwargs["comparisons"][0],
                    }
                    model_non_batch = type(model)(**model_kwargs_non_batch)
                    model_non_batch.load_state_dict(state_dict_non_batch)
                    model_non_batch.eval()
                    model_non_batch.posterior(
                        torch.rand(torch.Size([4, X_dim]), **tkwargs))
                    cm_non_batch = model_non_batch.condition_on_observations(
                        X_fant[0][0], comp_fant[:, 0, :])
                    non_batch_posterior = cm_non_batch.posterior(test_X)
                    self.assertTrue(
                        torch.allclose(
                            posterior_same_inputs.mean[:, 0, ...],
                            non_batch_posterior.mean,
                            atol=1e-3,
                        ))
                    self.assertTrue(
                        torch.allclose(
                            posterior_same_inputs.mvn.
                            covariance_matrix[:, 0, :, :],
                            non_batch_posterior.mvn.covariance_matrix,
                            atol=1e-3,
                        ))
Example n. 59
0
    def _test_onebitlamb_checkpointing(mask1, mask2, args, model, hidden_dim):
        model_1, optimizer_1, _, _ = deepspeed.initialize(
            args=args,
            model=model,
            model_parameters=optimizer_grouped_parameters_1)
        data_loader = random_dataloader(model=model_1,
                                        total_samples=10,
                                        hidden_dim=hidden_dim,
                                        device=model_1.device)
        for n, batch in enumerate(data_loader):
            loss = model_1(batch[0], batch[1])
            model_1.backward(loss)
            model_1.step()
        # Test whether the momentum mask still exists after saving a checkpoint
        assert optimizer_1.optimizer.lamb_freeze_key is True
        mask1 = mask1.to(
            device=optimizer_1.param_groups[0]['exp_avg_mask'].device)
        assert torch.allclose(optimizer_1.param_groups[0]['exp_avg_mask'],
                              mask1,
                              atol=1e-07), f"Incorrect momentum mask"
        scaling_coeff_1 = []
        for v in optimizer_1.state.values():
            assert 'scaling_coeff' in v, f"Incorrect scaling_coeff"
            scaling_coeff_1.append(v['scaling_coeff'])
        save_folder = os.path.join(tmpdir, 'saved_checkpoint')
        model_1.save_checkpoint(save_folder, tag=None)
        assert torch.allclose(
            optimizer_1.param_groups[0]['exp_avg_mask'], mask1, atol=1e-07
        ), f"Momentum mask should not change after saving checkpoint"

        model_2, optimizer_2, _, _ = deepspeed.initialize(
            args=args,
            model=model,
            model_parameters=optimizer_grouped_parameters_2)
        # Test whether momentum mask stays the same after loading checkpoint
        mask2 = mask2.to(
            device=optimizer_2.param_groups[0]['exp_avg_mask'].device)
        assert torch.allclose(optimizer_2.param_groups[0]['exp_avg_mask'],
                              mask2,
                              atol=1e-07), f"Incorrect momentum mask"
        model_2.load_checkpoint(save_folder,
                                tag=None,
                                load_optimizer_states=True,
                                load_lr_scheduler_states=True)
        assert torch.allclose(
            optimizer_2.param_groups[0]['exp_avg_mask'], mask2, atol=1e-07
        ), f"Momentum mask should not change after loading checkpoint"
        # Test whether the worker & server errors are reset
        assert len(optimizer_2.optimizer.worker_errors
                   ) == 0, f"Incorrect worker error"
        assert len(optimizer_2.optimizer.server_errors
                   ) == 0, f"Incorrect server error"
        # Test whether scaling_coeffs is loaded correctly
        scaling_coeff_2 = []
        for v in optimizer_2.state.values():
            assert 'scaling_coeff' in v, f"Incorrect scaling_coeff"
            scaling_coeff_2.append(v['scaling_coeff'])
        assert list(sorted(scaling_coeff_2)) == list(
            sorted(scaling_coeff_1)), f"Incorrect scaling_coeffs"
        assert optimizer_2.optimizer.lamb_freeze_key is True

        model_3, optimizer_3, _, _ = deepspeed.initialize(
            args=args,
            model=model,
            model_parameters=optimizer_grouped_parameters_3)
        optimizer_3.optimizer.freeze_step = 20
        data_loader = random_dataloader(model=model_3,
                                        total_samples=50,
                                        hidden_dim=hidden_dim,
                                        device=model_3.device)
        for n, batch in enumerate(data_loader):
            loss = model_3(batch[0], batch[1])
            model_3.backward(loss)
            model_3.step()
        assert optimizer_3.optimizer.lamb_freeze_key is True
        # Test that no momentum mask appears for this group, before or
        # after loading the checkpoint
        assert 'exp_avg_mask' not in optimizer_3.param_groups[
            0], f"Incorrect momentum mask"
        model_3.load_checkpoint(save_folder,
                                tag=None,
                                load_optimizer_states=True,
                                load_lr_scheduler_states=True)
        assert 'exp_avg_mask' not in optimizer_3.param_groups[
            0], f"Momentum mask should not change after loading checkpoint"
        # Test whether the worker & server errors are reset
        assert len(optimizer_3.optimizer.worker_errors
                   ) == 0, f"Incorrect worker error"
        assert len(optimizer_3.optimizer.server_errors
                   ) == 0, f"Incorrect server error"
        # Test whether scaling_coeffs, lamb_coeff_freeze, last_factor are reset
        for v in optimizer_3.state.values():
            assert v[
                'lamb_coeff_freeze'] == 0.0, f"Incorrect lamb_coeff_freeze"
            assert v['last_factor'] == 1.0, f"Incorrect last_factor"
            assert 'scaling_coeff' not in v, f"Incorrect scaling_coeff"
        assert optimizer_3.optimizer.lamb_freeze_key is False
Example n. 60
0
    def test_pairwise_gp(self):
        for batch_shape, dtype in itertools.product(
            (torch.Size(), torch.Size([2])), (torch.float, torch.double)):
            tkwargs = {"device": self.device, "dtype": dtype}
            X_dim = 2

            model, model_kwargs = self._get_model_and_data(
                batch_shape=batch_shape, X_dim=X_dim, **tkwargs)
            train_X = model_kwargs["datapoints"]
            train_comp = model_kwargs["comparisons"]

            # test training
            # regular training
            mll = PairwiseLaplaceMarginalLogLikelihood(model).to(**tkwargs)
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", category=OptimizationWarning)
                fit_gpytorch_model(mll, options={"maxiter": 2}, max_retries=1)
            # prior training
            prior_m = PairwiseGP(None, None)
            with self.assertRaises(RuntimeError):
                prior_m(train_X)
            # forward in training mode with non-training data
            custom_m = PairwiseGP(**model_kwargs)
            other_X = torch.rand(batch_shape + torch.Size([3, X_dim]),
                                 **tkwargs)
            other_comp = train_comp.clone()
            with self.assertRaises(RuntimeError):
                custom_m(other_X)
            custom_mll = PairwiseLaplaceMarginalLogLikelihood(custom_m).to(
                **tkwargs)
            post = custom_m(train_X)
            with self.assertRaises(RuntimeError):
                custom_mll(post, other_comp)

            # setting jitter = 0 with a singular covar will raise error
            sing_train_X = torch.ones(batch_shape + torch.Size([10, X_dim]),
                                      **tkwargs)
            with self.assertRaises(RuntimeError):
                with warnings.catch_warnings():
                    warnings.filterwarnings("ignore", category=RuntimeWarning)
                    custom_m = PairwiseGP(sing_train_X, train_comp, jitter=0)
                    custom_m.posterior(sing_train_X)

            # test init
            self.assertIsInstance(model.mean_module, ConstantMean)
            self.assertIsInstance(model.covar_module, RBFKernel)
            self.assertIsInstance(model.covar_module.lengthscale_prior,
                                  GammaPrior)
            self.assertEqual(model.num_outputs, 1)

            # test custom models
            custom_m = PairwiseGP(**model_kwargs, covar_module=LinearKernel())
            self.assertIsInstance(custom_m.covar_module, LinearKernel)
            # std_noise setter
            custom_m.std_noise = 123
            self.assertTrue(torch.all(custom_m.std_noise == 123))
            # prior prediction
            prior_m = PairwiseGP(None, None)
            prior_m.eval()
            post = prior_m.posterior(train_X)
            self.assertIsInstance(post, GPyTorchPosterior)

            # test methods that are not commonly or explicitly used
            # _calc_covar with observation noise
            no_noise_cov = model._calc_covar(train_X,
                                             train_X,
                                             observation_noise=False)
            noise_cov = model._calc_covar(train_X,
                                          train_X,
                                          observation_noise=True)
            diag_diff = (noise_cov - no_noise_cov).diagonal(dim1=-2, dim2=-1)
            self.assertTrue(
                torch.allclose(
                    diag_diff,
                    model.std_noise.expand(diag_diff.shape),
                    rtol=1e-4,
                    atol=1e-5,
                ))
            # test trying adding jitter
            pd_mat = torch.eye(2, 2)
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", category=RuntimeWarning)
                jittered_pd_mat = model._add_jitter(pd_mat)
            diag_diff = (jittered_pd_mat - pd_mat).diagonal(dim1=-2, dim2=-1)
            self.assertTrue(
                torch.allclose(
                    diag_diff,
                    torch.full_like(diag_diff, model._jitter),
                    atol=model._jitter / 10,
                ))

            # test initial utility val
            util_comp = torch.topk(model.utility, k=2,
                                   dim=-1).indices.unsqueeze(-2)
            self.assertTrue(torch.all(util_comp == train_comp))

            # test posterior
            # test non batch evaluation
            X = torch.rand(batch_shape + torch.Size([3, X_dim]), **tkwargs)
            expected_shape = batch_shape + torch.Size([3, 1])
            posterior = model.posterior(X)
            self.assertIsInstance(posterior, GPyTorchPosterior)
            self.assertEqual(posterior.mean.shape, expected_shape)
            self.assertEqual(posterior.variance.shape, expected_shape)

            # expect to raise error when output_indices is not None
            with self.assertRaises(RuntimeError):
                model.posterior(X, output_indices=[0])

            # test re-evaluating utility when it's None
            model.utility = None
            posterior = model.posterior(X)
            self.assertIsInstance(posterior, GPyTorchPosterior)

            # test adding observation noise
            posterior_pred = model.posterior(X, observation_noise=True)
            self.assertIsInstance(posterior_pred, GPyTorchPosterior)
            self.assertEqual(posterior_pred.mean.shape, expected_shape)
            self.assertEqual(posterior_pred.variance.shape, expected_shape)
            pvar = posterior_pred.variance
            reshaped_noise = model.std_noise.unsqueeze(-2).expand(
                posterior.variance.shape)
            pvar_exp = posterior.variance + reshaped_noise
            self.assertTrue(
                torch.allclose(pvar, pvar_exp, rtol=1e-4, atol=1e-5))

            # test batch evaluation
            X = torch.rand(2, *batch_shape, 3, X_dim, **tkwargs)
            expected_shape = torch.Size([2]) + batch_shape + torch.Size([3, 1])

            posterior = model.posterior(X)
            self.assertIsInstance(posterior, GPyTorchPosterior)
            self.assertEqual(posterior.mean.shape, expected_shape)
            # test adding observation noise in batch mode
            posterior_pred = model.posterior(X, observation_noise=True)
            self.assertIsInstance(posterior_pred, GPyTorchPosterior)
            self.assertEqual(posterior_pred.mean.shape, expected_shape)
            pvar = posterior_pred.variance
            reshaped_noise = model.std_noise.unsqueeze(-2).expand(
                posterior.variance.shape)
            pvar_exp = posterior.variance + reshaped_noise
            self.assertTrue(
                torch.allclose(pvar, pvar_exp, rtol=1e-4, atol=1e-5))
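
Condensing the happy path of this test into a standalone snippet; a minimal sketch that reuses only names already appearing above, with import paths assumed to match the BoTorch release this test targets:

import torch
from botorch.fit import fit_gpytorch_model
from botorch.models.pairwise_gp import (
    PairwiseGP, PairwiseLaplaceMarginalLogLikelihood)

# Toy preference data: 6 points in 2-D and three comparisons, where each
# row [i, j] means "point i is preferred over point j".
train_X = torch.rand(6, 2)
train_comp = torch.tensor([[0, 1], [2, 3], [4, 5]])

model = PairwiseGP(train_X, train_comp)
mll = PairwiseLaplaceMarginalLogLikelihood(model)
fit_gpytorch_model(mll, options={"maxiter": 2}, max_retries=1)

posterior = model.posterior(torch.rand(4, 2))
print(posterior.mean.shape)  # torch.Size([4, 1])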