Code Example #1
def leapfrog(q, p, potential_fn, inverse_mass_matrix, step_size):
    r"""
    Second order symplectic integrator that uses the velocity Verlet (leapfrog) algorithm.

    :param dict q: dictionary of sample site names and their current values
        (type :class:`~torch.Tensor`).
    :param dict p: dictionary of sample site names and corresponding momenta
        (type :class:`~torch.Tensor`).
    :param callable potential_fn: function that returns potential energy given q
        for each sample site. The negative gradient of the function with respect
        to ``q`` determines the rate of change of the corresponding sites'
        momenta ``p``.
    :param torch.Tensor inverse_mass_matrix: a tensor :math:`M^{-1}` which is used
        to calculate kinetic energy: :math:`E_{kinetic} = \frac{1}{2} p^T M^{-1} p`.
        Here :math:`M` can be a 1D tensor (diagonal matrix) or a 2D tensor (dense matrix).
    :param float step_size: step size for each time step iteration.
    :return tuple (q_next, p_next, q_grads, potential_energy): next position and momenta,
        together with the potential energy and its gradient w.r.t. ``q_next``.
    """
    q_grads = grad(potential_fn)(q)

    p = p + 0.5 * step_size * (-q_grads)  # p(n+1/2)

    p_grads = _kinetic_grad(inverse_mass_matrix, p)
    q = q + step_size * p_grads  # q(n+1)

    q_grads = grad(potential_fn)(q)
    potential_energy = potential_fn(q)
    # alternatively, compute both in a single call:
    # q_grads, potential_energy = grad_and_value(potential_fn)(q)
    p = p + 0.5 * step_size * (-q_grads)  # p(n+1)

    return q, p, q_grads, potential_energy
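
The excerpt calls a ``_kinetic_grad`` helper that is not shown here. A minimal sketch of what such a helper might look like (a hypothetical stand-in, assuming ``p`` is a flat tensor of momenta and, per the docstring, ``inverse_mass_matrix`` is either a 1D diagonal or a 2D dense tensor):

def _kinetic_grad(inverse_mass_matrix, p):
    # gradient of E_kinetic = 0.5 * p^T M^{-1} p with respect to p, i.e. M^{-1} p
    if inverse_mass_matrix.dim() == 1:
        return inverse_mass_matrix * p  # diagonal mass matrix (elementwise)
    return inverse_mass_matrix @ p      # dense mass matrix (matrix-vector product)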
Code Example #2
    def _test_attributes(self, get_attr_lambda, device):
        x = torch.randn(2, 3, 5, dtype=torch.double, device=device)
        expected = get_attr_lambda(x)

        def foo(x):
            self.assertEqual(get_attr_lambda(x), expected)
            return x.sum()

        grad(foo)(x)
Code Example #3
    def test_grad_vjp(self, device):
        x = torch.randn(3, device=device)

        def foo(x):
            _, vjp_fn = vjp(torch.sin, x)
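            # vjp_fn maps a cotangent v to v * cos(x), so foo(x) equals (x * x.cos()).sum()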
            return vjp_fn(x)[0].sum()

        y = grad(foo)(x)
        expected = grad(lambda x: (x * x.cos()).sum())(x)
        self.assertEqual(y, expected)
Code Example #4
    def test_per_sample_grads_simple(self, device):
        def compute_loss(weight, x, t):
            y = x @ weight
            return ((y - t) ** 2).sum()

        weight = torch.randn(16, 2, device=device)
        x = torch.randn(64, 16, device=device)
        t = torch.randn(64, 2, device=device)
        result = vmap(partial(grad(compute_loss), weight))(x, t)
        expected = [grad(compute_loss)(weight, x[i], t[i]) for i in range(64)]
        expected = torch.stack(expected)
        # TODO: Check if the rtol is a problem
        self.assertEqual(result, expected, atol=0, rtol=5e-4)
Code Example #5
    def test_argnums(self, device):
        x = torch.randn([])
        y = torch.randn([])
        gx = grad(torch.mul, argnums=0)(x, y)
        self.assertEqual(gx, y)

        gy = grad(torch.mul, argnums=1)(x, y)
        self.assertEqual(gy, x)

        gx, = grad(torch.mul, argnums=(0,))(x, y)
        self.assertEqual(gx, y)

        gx, gy = grad(torch.mul, argnums=(0, 1))(x, y)
        self.assertEqual(gx, y)
        self.assertEqual(gy, x)
Code Example #6
    def test_make_fx_no_decompose(self, device):
        # FIXME
        return self.skipTest("error: maximum recursion reached")

        def f(x):
            return torch.tanh(x).sum()

        fx_f = make_fx(grad(f))(torch.randn(5))
        ops = set([i.target for i in fx_f.graph.nodes])

        self.assertEqual(torch.ops.aten.tanh_backward in ops, True)

        fx_f = make_fx(grad(f), decomposition_table)(torch.randn(5))
        ops = set([i.target for i in fx_f.graph.nodes])
        self.assertEqual(torch.ops.aten.tanh_backward in ops, False)
Code Example #7
    def test_vjp_grad(self, device):
        x = torch.randn([], device=device)
        y, vjp_fn = vjp(grad(torch.sin), x)
        self.assertEqual(y, x.cos())

        v = torch.randn([])
        self.assertEqual(vjp_fn(v)[0], -x.sin() * v)
Code Example #8
    def test_zero_grad(self, device):
        def f(x):
            return (x['a'] ** 2.0).sum()

        inps = {'a': torch.randn(10, device=device) + 3,
                'b': torch.randn(10, device=device)}
        grads = grad(f)(inps)
        self.assertNotEqual(grads['a'].sum(), 0.0)
        self.assertEqual(grads['b'].sum(), 0.0)
Code Example #9
    def test_per_sample_grads_embeddingnet(self, device):
        class SampleNet(nn.Module):
            def __init__(self, vocab_size: int):
                super().__init__()
                self.emb = nn.Embedding(vocab_size, 16)
                self.fc1 = nn.Linear(16, 16)
                self.fc2 = nn.Linear(16, 2)

            def forward(self, x):
                x = self.emb(x)
                x = torch.transpose(x, -1, -2)
                x = torch.mean(x, -1)
                x = self.fc1(x)
                x = F.relu(x)
                x = self.fc2(x)
                return x

            def name(self):
                return "SampleNet"

        # Create our inputs...
        vocab_size = 1000
        batch_shape = [64]
        words_per_sentence = 5
        data = torch.randint(0, vocab_size, (*batch_shape, words_per_sentence), device=device)
        targets = torch.randint(0, 1, (*batch_shape,), device=device)

        # Construct our module
        net = SampleNet(vocab_size).to(device=device)
        criterion = nn.CrossEntropyLoss()

        params = dict(net.named_parameters())
        weights, net_func, _ = make_functional(net)

        def compute_loss(weights, data, target):
            output = net_func(weights, (data,))
            result = criterion(output, target)
            return result

        expected = [grad(compute_loss)(weights, data[i], targets[i]) for i in range(64)]
        expected = zip(*expected)
        expected = tuple(torch.stack(shards) for shards in expected)

        result = vmap(partial(grad(compute_loss), weights))(data, targets)
        for r, e in zip(result, expected):
            # TODO: Check if the rtol is a problem
            self.assertEqual(r, e, atol=0, rtol=1e-4)
Code Example #10
    def test_inplace(self, device):
        x = torch.randn([], device=device)

        def foo(x):
            return x.clone().sin_()

        result = grad(foo)(x)
        self.assertEqual(result, x.cos())
Code Example #11
    def test_grad_vmap(self, device):
        def foo(x):
            y = vmap(torch.sin)(x)
            return y.sum()

        x = torch.randn(3)
        y = grad(foo)(x)
        self.assertEqual(y, x.cos())
Code Example #12
    def test_unrelated_grad(self, device):
        x = torch.tensor(1., device=device)
        y = torch.tensor(2., device=device)

        def unrelated(x):
            return y

        result = grad(unrelated)(x)
        self.assertEqual(result, torch.zeros_like(x))
Code Example #13
    def test_view_inplace_simple(self, device):
        def foo(x):
            x = x.clone()
            x.view([]).sin_()
            return x

        x = torch.randn([], requires_grad=True, device=device)
        result = grad(foo)(x)
        self.assertEqual(result, x.cos())
Code Example #14
    def test_make_fx_grad(self, device):
        def f(x):
            return torch.sin(x).sum()
        inp = torch.randn(3)
        f = grad(f)
        fx_f = make_fx(f)(inp)

        new_inp = torch.randn(3)
        self.assertEqual(fx_f(new_inp), f(new_inp))
Code Example #15
    def test_invalid_argnums(self, device):
        x = torch.randn([])
        y = torch.randn([])
        with self.assertRaisesRegex(RuntimeError, 'but only'):
            grad(torch.mul, argnums=-1)(x, y)
        with self.assertRaisesRegex(RuntimeError, 'but only'):
            grad(torch.mul, argnums=2)(x, y)
        with self.assertRaisesRegex(RuntimeError, 'int or Tuple'):
            grad(torch.mul, argnums=[0])(x, y)
        with self.assertRaisesRegex(RuntimeError, 'must be int'):
            grad(torch.mul, argnums=('0',))(x, y)
Code Example #16
File: per_sample_grads.py  Project: mikekgfb/pytorch
def functorch_per_sample_grad():
    compute_grad = grad(compute_loss)
    compute_per_sample_grad = vmap(compute_grad, (None, 0, 0))

    start = time.time()
    result = compute_per_sample_grad(weights, images, targets)
    torch.cuda.synchronize()
    end = time.time()

    return result, end - start  # end - start in seconds
Code Example #17
    def test_escaped_wrappers_are_marked_as_dead(self, device):
        x = torch.randn([], device=device)
        escaped = []
        def foo(x):
            y = x.sin()
            escaped.append(y)
            return y

        result = grad(foo)(x)
        self.assertEqual(functorch._C.dlevel(escaped[0]), -1)
Code Example #18
    def get_loss_for_task(x1, y1, x2, y2):
        def inner_loss(params, x1, y1):
            f = net(params, x1)
            loss = mse_loss(f, y1)
            return loss

        grads = grad(inner_loss)(tuple(params), x1, y1)
        new_params = [(params[i] - alpha * grads[i]) for i in range(len(params))]

        v_f = net(new_params, x2)
        return mse_loss(v_f, y2)
Code Example #19
    def test_escaped_wrappers_are_ignored(self, device):
        x = torch.randn([], device=device)
        escaped = []
        def foo(x):
            y = x.sin()
            escaped.append(y)
            return y

        result = grad(foo)(x)

        something = escaped[0].sum()
        self.assertEqual(functorch._C.dlevel(something), 0)
        self.assertEqual(something, x.sin().sum())
Code Example #20
    def test_new_empty_materializes_tensor(self, device):
        N = 3
        C = 5

        def foo(y, x):
            result = x.new_empty((C,))
            result.copy_(y)
            return result.sum()

        x = torch.randn(N, device=device)
        y = torch.randn(N, C, device=device)
        result = vmap(grad(foo))(y, x)
        self.assertEqual(result, torch.ones_like(y))
Code Example #21
    def test_composite_two_ops(self, device):
        N, C = 2, 5
        y = torch.randn(N, C, device=device)
        targets = torch.randint(0, C, (N,), device=device)

        def foo(y, targets):
            return F.cross_entropy(y, targets)

        result = grad(foo)(y, targets)

        y.requires_grad_()
        expected, = torch.autograd.grad(foo(y, targets), y)

        self.assertEqual(result, expected)
Code Example #22
        def get_loss_for_task(use_transform, x1, y1, x2, y2):
            def inner_loss(params, x1, y1):
                f = net(params, (x1,))
                loss = mse_loss(f, y1)
                return loss

            if use_transform:
                grads = grad(inner_loss)(params, x1, y1)
            else:
                loss = inner_loss(params, x1, y1)
                grads = torch.autograd.grad(loss, params, create_graph=True)
            new_params = [(params[i] - alpha*grads[i]) for i in range(len(params))]

            v_f = net(new_params, (x2,))
            return mse_loss(v_f, y2)
Code Example #23
    def test_composite_complicated(self, device):
        x = torch.randn(3, device=device)
        y = torch.randn(3, 5, device=device)

        def foo(x, y):
            result = x @ y
            return result.sum()

        result = grad(foo)(x, y)

        x.requires_grad_()
        out = foo(x, y)
        expected, = torch.autograd.grad(out, x)

        self.assertEqual(result, expected)
Code Example #24
    def test_inplace_on_view_base(self, device):
        x = torch.randn(3, device=device)

        def foo(x):
            y = x.clone()
            y0 = y[0]
            y.sin_()
            return y0

        result = grad(foo)(x)

        x.requires_grad_()
        out = foo(x)
        expected, = torch.autograd.grad(out, x)

        self.assertEqual(result, expected)
Code Example #25
    def update_fn(step_size, inverse_mass_matrix, state):
        """
        :param float step_size: Size of a single step.
        :param inverse_mass_matrix: Inverse of mass matrix, which is used to
            calculate kinetic energy.
        :param state: Current state of the integrator.
        :return: new state for the integrator.
        """
        q, p, _, q_grad = state
        # maps a function over a pytree, returning a new pytree
        p = tree_multimap(lambda p, q_grad: p - 0.5 * step_size * q_grad, p,
                          q_grad)  # p(n+1/2)
        p_grad = grad(kinetic_fn, argnums=1)(inverse_mass_matrix, p)
        q = tree_multimap(lambda q, p_grad: q + step_size * p_grad, q,
                          p_grad)  # q(n+1)
        potential_energy, q_grad = value_and_grad(potential_fn)(q)
        p = tree_multimap(lambda p, q_grad: p - 0.5 * step_size * q_grad, p,
                          q_grad)  # p(n+1)
        return IntegratorState(q, p, potential_energy, q_grad)
Code Example #26
def loss_for_task(net, n_inner_iter, x_spt, y_spt, x_qry, y_qry):
    params, buffers, fnet = net
    querysz = x_qry.size(0)

    def compute_loss(new_params, buffers, x, y):
        logits = fnet(new_params, buffers, (x, ))
        loss = F.cross_entropy(logits, y)
        return loss

    new_params = params
    for _ in range(n_inner_iter):
        grads = grad(compute_loss)(new_params, buffers, x_spt, y_spt)
        new_params = [p - g * 1e-1 for p, g, in zip(new_params, grads)]

    # The final set of adapted parameters will induce some
    # final loss and accuracy on the query dataset.
    # These will be used to update the model's meta-parameters.
    qry_logits = fnet(new_params, buffers, (x_qry, ))
    qry_loss = F.cross_entropy(qry_logits, y_qry)
    qry_acc = (qry_logits.argmax(dim=1) == y_qry).sum() / querysz

    return qry_loss, qry_acc
Code Example #27
        def loss_for_task(net, n_inner_iter, use_transform, x_spt, y_spt, x_qry, y_qry):
            params, buffers, fnet = net
            querysz = x_qry.size(0)

            def compute_loss(new_params, buffers, x, y):
                logits = fnet(new_params, buffers, (x,))
                loss = F.cross_entropy(logits, y)
                return loss

            new_params = params
            for _ in range(n_inner_iter):
                if use_transform:
                    grads = grad(compute_loss)(new_params, buffers, x_spt, y_spt)
                else:
                    res = compute_loss(new_params, buffers, x_spt, y_spt)
                    grads = torch.autograd.grad(res, new_params, create_graph=True)
                new_params = [p - g * 1e-1 for p, g, in zip(new_params, grads)]

            qry_logits = fnet(new_params, buffers, (x_qry,))
            qry_loss = F.cross_entropy(qry_logits, y_qry)
            qry_acc = (qry_logits.argmax(dim=1) == y_qry).sum() / querysz

            return qry_loss, qry_acc
Code Example #28
# Next, let's define a function to compute the loss of the model given a single
# input rather than a batch of inputs. It is important that this function accepts the
# parameters, the buffers, the input, and the target, because we will be transforming over them.
# Because the model was originally written to handle batches, we'll use
# ``torch.unsqueeze`` to add a batch dimension.
def compute_loss(params, buffers, sample, target):
    batch = sample.unsqueeze(0)
    targets = target.unsqueeze(0)
    predictions = fmodel(params, buffers, batch)
    loss = loss_fn(predictions, targets)
    return loss


# Now, let's use ``grad`` to create a new function that computes the gradient
# with respect to the first argument of compute_loss (i.e. the params).
ft_compute_grad = grad(compute_loss)

# ``ft_compute_grad`` computes the gradient for a single (sample, target) pair.
# We can use ``vmap`` to get it to compute the gradient over an entire batch
# of samples and targets. Note that in_dims=(None, None, 0, 0) because we wish
# to map ``ft_compute_grad`` over the 0th dimension of the data and targets
# and use the same params and buffers for each.
ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, None, 0, 0))

# Finally, let's use our transformed function to compute per-sample gradients:
ft_per_sample_grads = ft_compute_sample_grad(params, buffers, data, targets)
for per_sample_grad, ft_per_sample_grad in zip(per_sample_grads,
                                               ft_per_sample_grads):
    assert torch.allclose(per_sample_grad,
                          ft_per_sample_grad,
                          atol=1e-6)
Code Example #29
File: simple_function.py  Project: zeta1999/functorch
from functorch import grad, nnc_jit, make_fx, make_nnc
import torch
import time


def f(x):
    return torch.sin(x).sum()


inp = torch.randn(100)
grad_pt = grad(f)
grad_fx = make_fx(grad_pt)(inp)
grad_nnc = nnc_jit(grad_pt)
loopnest = make_nnc(grad_pt)(inp)
print(loopnest)


def bench(name, f, iters=10000, warmup=3):
    for _ in range(warmup):
        f()
    begin = time.time()
    for _ in range(iters):
        f()
    print(f"{name}: ", time.time() - begin)


bench("Pytorch: ", lambda: grad_pt(inp))
bench("FX: ", lambda: grad_fx(inp))
bench("NNC: ", lambda: grad_nnc(inp))
Code Example #30
    def test_resnet18_per_sample_grads(self, device):
        # Straight out of opacus
        def _replace_child(
            root: nn.Module, child_name: str, converter: Callable[[nn.Module], nn.Module]
        ) -> None:
            # find the immediate parent
            parent = root
            nameList = child_name.split(".")
            for name in nameList[:-1]:
                parent = parent._modules[name]
            # replace the child module with its converted version
            parent._modules[nameList[-1]] = converter(parent._modules[nameList[-1]])

        def replace_all_modules(
            root: nn.Module,
            target_class: Type[nn.Module],
            converter: Callable[[nn.Module], nn.Module],
        ) -> nn.Module:
            # base case
            if isinstance(root, target_class):
                return converter(root)

            for name, obj in root.named_modules():
                if isinstance(obj, target_class):
                    _replace_child(root, name, converter)
            return root

        def _batchnorm_to_groupnorm(module: nn.modules.batchnorm._BatchNorm) -> nn.Module:
            return nn.GroupNorm(min(32, module.num_features), module.num_features, affine=True)

        def convert_batchnorm_modules(
            model: nn.Module,
            converter: Callable[
                [nn.modules.batchnorm._BatchNorm], nn.Module
            ] = _batchnorm_to_groupnorm,
        ) -> nn.Module:
            return replace_all_modules(model, nn.modules.batchnorm._BatchNorm, converter)

        import torchvision.models as models
        model = convert_batchnorm_modules(models.resnet18(num_classes=10)).to(device)
        criterion = nn.CrossEntropyLoss()

        weights, func_model, descriptors = make_functional(model)

        def compute_loss(weights, image, target):
            images = image.unsqueeze(0)
            targets = target.unsqueeze(0)
            output = func_model(weights, (images,))
            loss = criterion(output, targets)
            return loss

        batch_size = 3
        images = torch.randn(batch_size, 3, 32, 32, device=device)
        targets = torch.randint(0, 10, (batch_size,), device=device)

        result_grads = vmap(grad(compute_loss), in_dims=(None, 0, 0))(weights, images, targets)

        expected_grads = [
            torch.autograd.grad(compute_loss(weights, images[i], targets[i]), weights)
            for i in range(batch_size)
        ]
        expected_grads = [torch.stack(shards) for shards in zip(*expected_grads)]

        self.assertEqual(result_grads, expected_grads)