def test_mmd2_various_ranked_inputs(self, device, xx, yy, xy):
    """Check mmd2's value and gradients for every mix of input ranks.

    ``xx``, ``yy`` and ``xy`` are keys ('mat', 'vec' or 'sca') selecting which
    of the ``_mat`` / ``_vec`` / ``_sca`` helpers builds each of mmd2's three
    arguments. The result must be exactly -2.5 when the third argument is the
    matrix input and exactly 0 otherwise. After ``backward()`` each input's
    ``.grad`` must match the hand-written tables below.
    """
    inmap = {'mat': _mat, 'vec': _vec, 'sca': _sca}
    # Expected .grad of the first and second mmd2 arguments, keyed by rank.
    derivs = {'mat': [[0, 0.5], [0.5, 0]], 'vec': [0.5, 0.5], 'sca': 1}
    # Expected .grad of the third mmd2 argument, keyed by rank.
    derivs_xy = {
        'mat': [[-0.5, -0.5], [-0.5, -0.5]], 'vec': [-1, -1], 'sca': -2
    }
    mmd2_in = inmap[xx](device), inmap[yy](device), inmap[xy](device)
    res = mmd2(*mmd2_in)
    if xy == 'mat':
        assert res == -2.5
    else:
        assert res == 0
    res.backward()
    # Compare each input's gradient with its expected table in argument order.
    # NOTE(review): T.testing.assert_allclose is deprecated upstream in favour
    # of assert_close, but assert_close also checks dtype. Confirm the
    # helpers' dtype before migrating.
    expected_grads = (derivs[xx], derivs[yy], derivs_xy[xy])
    for inp, grad in zip(mmd2_in, expected_grads):
        T.testing.assert_allclose(
            inp.grad,
            T.tensor(grad, device=device, dtype=T.float64),
            rtol=1e-06,
            atol=1e-12,
        )
def test_mmd2_gener(self, device):
    """mmd2_gener(a, b) must equal mmd2(None, a, b) and support backward()."""
    generated = mmd2_gener(_mat(device), _mat(device))
    reference = mmd2(None, _mat(device), _mat(device))
    assert generated == reference
    generated.backward()
def test_PQk_is_None(self, device):
    """Passing None as mmd2's third argument must raise AssertionError."""
    with pytest.raises(AssertionError):
        first = _mat(device)
        second = _mat(device)
        mmd2(first, second, None)
def test_PPk_is_None(self, device):
    """mmd2 accepts None as its first argument: value -2.5, backward() works."""
    second = _mat(device)
    third = _vec(device)
    loss = mmd2(None, second, third)
    assert loss == -2.5
    loss.backward()