Example #1
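This test checks that `newton_step` honors its optional trust region: the mode of a random quadratic loss and the starting points are offset to opposite sides of the origin, so the unconstrained update is typically large and must travel farther than a small `trust_radius` allows. The `trust_radius` and `dims` arguments of these tests come from pytest parametrize decorators that the excerpts omit. A plausible shared preamble for all four examples, assuming a Pyro-style test layout (the helper's body and the `tests.common` import are assumptions, not shown in the originals):

import itertools
import logging

import pytest
import torch
from torch.autograd import grad

from pyro.ops.newton import newton_step
from tests.common import assert_equal  # assumed test utility

logger = logging.getLogger(__name__)


def random_inside_unit_circle(shape, requires_grad=False):
    # Assumed helper (not shown in the excerpts): sample points whose
    # Euclidean norm along the last dimension is strictly less than 1.
    x = torch.randn(shape)
    x = x / (1 + x.norm(dim=-1, keepdim=True))  # ||x|| / (1 + ||x||) < 1
    x.requires_grad_(requires_grad)
    return x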
def test_newton_step_trust(trust_radius, dims):
    batch_size = 100
    batch_shape = torch.Size((batch_size, ))
    mode = random_inside_unit_circle(batch_shape + (dims, ),
                                     requires_grad=True) + 1
    x = random_inside_unit_circle(batch_shape + (dims, ),
                                  requires_grad=True) - 1

    # create a quadratic loss function
    noise = torch.randn(batch_size, dims, dims)
    hessian = noise + noise.transpose(-1, -2)
    diff = (x - mode).unsqueeze(-2)
    loss = 0.5 * diff.bmm(hessian).bmm(diff.transpose(-1, -2)).sum()

    # run method under test
    x_updated, cov = newton_step(loss, x, trust_radius=trust_radius)

    # check shapes
    assert x_updated.shape == x.shape
    assert cov.shape == hessian.shape

    # check values
    if trust_radius is None:
        assert ((x - x_updated).pow(2).sum(-1) > 1.0).any(), 'test is too weak'
    else:
        assert ((x - x_updated).pow(2).sum(-1) <=
                1e-8 + trust_radius**2).all(), 'trust region violated'
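The `trust_radius=None` branch is a sanity check that the unconstrained step exceeds unit length for at least some points, so the constrained case is genuinely exercised; with a finite radius, every point's squared displacement must stay within `trust_radius**2`, up to a 1e-8 tolerance.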
Example #2
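This test iterates `newton_step` on a strictly convex quadratic, built as a rank-one outer product plus `0.01 * torch.eye(dims)` to keep the Hessian positive definite, and fails unless every point in the batch comes within 1e-2 of the mode in at most 100 steps.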
def test_newton_step_converges(trust_radius, dims):
    batch_size = 100
    batch_shape = torch.Size((batch_size, ))
    mode = random_inside_unit_circle(batch_shape + (dims, ),
                                     requires_grad=True) - 1
    x = random_inside_unit_circle(batch_shape + (dims, ),
                                  requires_grad=True) + 1

    # create a quadratic loss function
    noise = torch.randn(batch_size, dims, 1)
    hessian = noise.matmul(noise.transpose(-1, -2)) + 0.01 * torch.eye(dims)

    def loss_fn(x):
        diff = (x - mode).unsqueeze(-2)
        return 0.5 * diff.bmm(hessian).bmm(diff.transpose(-1, -2)).sum()

    # check convergence
    for i in range(100):
        x = x.detach()
        x.requires_grad = True
        loss = loss_fn(x)
        x, cov = newton_step(loss, x, trust_radius=trust_radius)
        if ((x - mode).pow(2).sum(-1) < 1e-4).all():
            logger.debug(
                'Newton iteration converged after {} steps'.format(1 + i))
            return
    pytest.fail('Newton iteration did not converge')
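Detaching `x` and re-enabling `requires_grad` at the top of each iteration discards the previous step's autograd graph, so every call to `newton_step` differentiates a fresh loss with respect to the current iterate instead of a graph that grows across iterations.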
Example #3
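This snippet is an optimizer method, apparently from a Newton-style multi-parameter optimizer, that applies `newton_step` independently to each named parameter and looks up an optional per-parameter trust region in `self.trust_radii`.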
def get_step(self, loss, params):
    updated_values = {}
    for name, value in params.items():
        # look up this parameter's optional trust region radius
        trust_radius = self.trust_radii.get(name)
        # newton_step also returns a covariance estimate, unused here
        updated_value, cov = newton_step(loss, value, trust_radius)
        updated_values[name] = updated_value
    return updated_values
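A minimal driver sketch for such a method, with the wrapper class, parameter name, and toy loss below being illustrative assumptions; `get_step` is restated from the excerpt so the sketch is self-contained:

import torch

from pyro.ops.newton import newton_step


class NewtonOptimizer:
    # Hypothetical wrapper around the get_step method shown above.
    def __init__(self, trust_radii=None):
        self.trust_radii = trust_radii or {}

    def get_step(self, loss, params):
        updated_values = {}
        for name, value in params.items():
            trust_radius = self.trust_radii.get(name)
            updated_value, cov = newton_step(loss, value, trust_radius)
            updated_values[name] = updated_value
        return updated_values


optim = NewtonOptimizer(trust_radii={'loc': 1.0})
loc = torch.randn(10, 2, requires_grad=True)
loss = 0.5 * (loc ** 2).sum()  # toy quadratic with its mode at zero
params = optim.get_step(loss, {'loc': loc})  # one trust-constrained step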
Example #4
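This test makes the Hessian strictly positive definite (a rank-one term plus the identity) and keeps both `x` and the mode inside half the unit ball, so one `newton_step` should land exactly on the mode. It then checks that the returned covariance is symmetric and inverts the Hessian, and that the update is differentiable with d(x_updated)/d(mode) equal to the identity.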
def test_newton_step(batch_shape, trust_radius, dims):
    batch_shape = torch.Size(batch_shape)
    mode = 0.5 * random_inside_unit_circle(batch_shape + (dims, ),
                                           requires_grad=True)
    x = 0.5 * random_inside_unit_circle(batch_shape + (dims, ),
                                        requires_grad=True)
    if trust_radius is not None:
        assert trust_radius >= 2, '(x, mode) may be farther apart than trust_radius'

    # create a quadratic loss function
    flat_x = x.reshape(-1, dims)
    flat_mode = mode.reshape(-1, dims)
    noise = torch.randn(flat_x.shape[0], dims, 1)
    flat_hessian = noise.matmul(noise.transpose(-1, -2)) + torch.eye(dims)
    hessian = flat_hessian.reshape(batch_shape + (dims, dims))
    diff = (flat_x - flat_mode).unsqueeze(-2)
    loss = 0.5 * diff.bmm(flat_hessian).bmm(diff.transpose(-1, -2)).sum()

    # run method under test
    x_updated, cov = newton_step(loss, x, trust_radius=trust_radius)

    # check shapes
    assert x_updated.shape == x.shape
    assert cov.shape == hessian.shape

    # check values
    assert_equal(x_updated,
                 mode,
                 prec=1e-6,
                 msg='{} vs {}'.format(x_updated, mode))
    flat_cov = cov.reshape(flat_hessian.shape)
    assert_equal(flat_cov,
                 flat_cov.transpose(-1, -2),
                 msg='covariance is not symmetric: {}'.format(flat_cov))
    actual_eye = torch.bmm(flat_cov, flat_hessian)
    expected_eye = torch.eye(dims).expand(actual_eye.shape)
    assert_equal(actual_eye,
                 expected_eye,
                 prec=1e-4,
                 msg='bad covariance {}'.format(actual_eye))

    # check gradients
    for i in itertools.product(*map(range, mode.shape)):
        expected_grad = torch.zeros(mode.shape)
        expected_grad[i] = 1
        actual_grad = grad(x_updated[i], [mode], create_graph=True)[0]
        assert_equal(actual_grad,
                     expected_grad,
                     prec=1e-5,
                     msg='\n'.join([
                         'bad gradient at index {}'.format(i),
                         'expected {}'.format(expected_grad),
                         'actual   {}'.format(actual_grad),
                     ]))
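Taken together, the examples show the basic calling pattern: build a scalar loss from a tensor created with `requires_grad=True`, then take one step. A minimal sketch on an identity-Hessian quadratic (the shapes and target here are assumptions for illustration):

import torch

from pyro.ops.newton import newton_step

x = torch.randn(100, 2, requires_grad=True)  # batch of 2D points
target = torch.zeros(100, 2)
loss = 0.5 * ((x - target) ** 2).sum()  # Hessian is the identity

# One unconstrained Newton step lands on the target, and cov approximates
# the inverse Hessian; passing a finite trust_radius would instead clip
# each point's displacement to that radius.
x_new, cov = newton_step(loss, x, trust_radius=None)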