Example #1
# These tests appear to come from xitorch's test suite: jac is assumed to be
# xitorch.grad.jac, and getfnparams, getnnparams, func1, func2 and dtype are
# helpers defined in the surrounding test module.
import torch
from torch.autograd import gradcheck, gradgradcheck
from xitorch.grad import jac


def test_jac_method_grad():
    na = 3
    params = getfnparams(na)
    nnparams = getnnparams(na)
    num_nnparams = len(nnparams)
    jacs = jac(func2(*nnparams), params)

    def fcnr(i, v, *allparams):
        # VJP through the i-th Jacobian: J_i^T v
        nnparams = allparams[:num_nnparams]
        params = allparams[num_nnparams:]
        jacs = jac(func2(*nnparams), params)
        return jacs[i].rmv(v.view(-1))

    def fcnl(i, v, *allparams):
        # JVP through the i-th Jacobian: J_i v
        nnparams = allparams[:num_nnparams]
        params = allparams[num_nnparams:]
        jacs = jac(func2(*nnparams), params)
        return jacs[i].mv(v.view(-1))

    v = torch.rand((na, ), dtype=dtype, requires_grad=True)
    w = [torch.rand_like(p).requires_grad_() for p in params]
    for i in range(len(jacs)):
        gradcheck(fcnr, (i, v, *nnparams, *params))
        gradgradcheck(fcnr, (i, v, *nnparams, *params))
        gradcheck(fcnl, (i, w[i], *nnparams, *params))
        gradgradcheck(fcnl, (i, w[i], *nnparams, *params))
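
For reference, `jac` in these tests appears to be `xitorch.grad.jac`, which returns one LinearOperator per differentiated parameter; the convention the tests rely on is that `mv(w)` computes the Jacobian-vector product J w and `rmv(v)` computes the transposed product J^T v. A minimal plain-PyTorch sketch of that convention on a toy function (illustrative only, not part of the test suite):

import torch

def f(x):                                        # f: R^3 -> R^3
    return torch.stack([x[0] * x[1], x[1] ** 2, x[0] + x[2]])

x = torch.rand(3, dtype=torch.float64, requires_grad=True)
J = torch.autograd.functional.jacobian(f, x)     # dense (nout, nin) Jacobian

v = torch.rand(3, dtype=torch.float64)           # output-space vector
w = torch.rand(3, dtype=torch.float64)           # input-space vector

# what rmv computes: the vector-Jacobian product J^T v
vjp = torch.autograd.grad(f(x), x, grad_outputs=v)[0]
assert torch.allclose(vjp, J.t() @ v)

# what mv computes: the Jacobian-vector product J w
_, jvp = torch.autograd.functional.jvp(f, (x,), (w,))
assert torch.allclose(jvp, J @ w)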
Example #2
def test_jac_func():
    # same test-module helpers as in Example #1
    na = 3
    params = getfnparams(na)
    nnparams = getnnparams(na)
    all_idxs = [None, (0, ), (1, ), (0, 1), (0, 1, 2)]
    funcs = [func1, func2(*nnparams)]

    for func in funcs:
        for idxs in all_idxs:
            if idxs is None:
                gradparams = params
            else:
                gradparams = [params[i] for i in idxs]

            jacs = jac(func, params, idxs=idxs)
            assert len(jacs) == len(gradparams)

            y = func(*params)
            nout = torch.numel(y)
            nins = [torch.numel(p) for p in gradparams]
            v = torch.rand_like(y).requires_grad_()
            for i in range(len(jacs)):
                assert list(jacs[i].shape) == [nout, nins[i]]

            # rmv must match the vector-Jacobian product from autograd
            jacs_rmv = torch.autograd.grad(y,
                                           gradparams,
                                           grad_outputs=v,
                                           create_graph=True)
            # the jac LinearOperator has shape (nout, nin), so v must be flattened
            jacs_rmv0 = [jc.rmv(v.view(-1)) for jc in jacs]

            # mv must match the Jacobian-vector product, obtained here by
            # differentiating the VJP above with respect to v
            w = [torch.rand_like(p) for p in gradparams]
            jacs_lmv = [
                torch.autograd.grad(jacs_rmv[i], (v, ),
                                    grad_outputs=w[i],
                                    retain_graph=True)[0]
                for i in range(len(jacs))
            ]
            jacs_lmv0 = [jacs[i].mv(w[i].view(-1)) for i in range(len(jacs))]

            for i in range(len(jacs)):
                assert torch.allclose(jacs_rmv[i].view(-1),
                                      jacs_rmv0[i].view(-1))
                assert torch.allclose(jacs_lmv[i].view(-1),
                                      jacs_lmv0[i].view(-1))
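
The mv check above uses the standard double-backward trick: because the VJP J^T v is linear in v, differentiating it with respect to v and contracting with w yields the JVP J w. The trick in isolation (toy function, illustrative names only):

import torch

def f(x):                                        # f: R^3 -> R^2, so J is (2, 3)
    return torch.stack([x[0] * x[1], x[2] ** 2])

x = torch.rand(3, dtype=torch.float64, requires_grad=True)
v = torch.rand(2, dtype=torch.float64, requires_grad=True)   # must require grad
w = torch.rand(3, dtype=torch.float64)

vjp = torch.autograd.grad(f(x), x, grad_outputs=v, create_graph=True)[0]  # J^T v
jvp = torch.autograd.grad(vjp, v, grad_outputs=w)[0]                      # J w

J = torch.autograd.functional.jacobian(f, x)
assert torch.allclose(jvp, J @ w)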
Example #3
    @staticmethod
    def backward(ctx, grad_yout):
        # backward of a custom torch.autograd.Function: grad_yout is dL/dyout
        param_sep = ctx.param_sep
        yout = ctx.saved_tensors[0]
        nparams = ctx.nparams
        fcn = ctx.fcn

        # merge the tensor and nontensor parameters
        tensor_params = ctx.saved_tensors[1:]
        allparams = param_sep.reconstruct_params(tensor_params)
        params = allparams[:nparams]
        objparams = allparams[nparams:]

        # dL/df: push the incoming gradient through the adjoint of df/dy
        with ctx.fcn.useobjparams(objparams):

            # Jacobian of f with respect to the output y at the solution
            jac_dfdy = jac(fcn, params=(yout, *params), idxs=[0])[0]
            # adjoint solve: (df/dy)^H gyfcn = -grad_yout; the same options
            # dict configures both the solver and its own backward pass
            gyfcn = solve(A=jac_dfdy.H,
                          B=-grad_yout.reshape(-1, 1),
                          bck_options=ctx.bck_options,
                          **ctx.bck_options)
            gyfcn = gyfcn.reshape(grad_yout.shape)

            # get the grad for the params
            with torch.enable_grad():
                tensor_params_copy = [
                    p.clone().requires_grad_() for p in tensor_params
                ]
                allparams_copy = param_sep.reconstruct_params(
                    tensor_params_copy)
                params_copy = allparams_copy[:nparams]
                objparams_copy = allparams_copy[nparams:]
                with ctx.fcn.useobjparams(objparams_copy):
                    yfcn = fcn(yout, *params_copy)

            grad_tensor_params = torch.autograd.grad(
                yfcn,
                tensor_params_copy,
                grad_outputs=gyfcn,
                create_graph=torch.is_grad_enabled(),
                allow_unused=True)
            grad_nontensor_params = [
                None for _ in range(param_sep.nnontensors())
            ]
            grad_params = param_sep.reconstruct_params(grad_tensor_params,
                                                       grad_nontensor_params)

        # one None for each non-differentiable argument of the matching forward()
        return (None, None, None, None, None, None, None, *grad_params)
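
This backward is the implicit-differentiation step of what looks like a root-finder (the structure matches xitorch's rootfinder): if y* satisfies f(y*, theta) = 0, the implicit function theorem gives dy*/dtheta = -(df/dy)^{-1} (df/dtheta), so the incoming gradient is first pushed through the adjoint solve (the `solve` call above) and then through df/dtheta with one ordinary VJP (the `torch.autograd.grad` call). The same recipe on a scalar toy problem, with hypothetical names throughout:

import torch

# y* solves f(y, theta) = 0; we want dL/dtheta given dL/dy* (here L = y*).
def f(y, theta):
    return y**3 + theta * y - 1.0

theta = torch.tensor(2.0, dtype=torch.float64, requires_grad=True)

# find the root without tracking gradients (plain Newton iteration)
with torch.no_grad():
    y = torch.tensor(0.5, dtype=torch.float64)
    for _ in range(50):
        y = y - f(y, theta) / (3 * y**2 + theta)

# implicit-gradient route: solve (df/dy)^T g = -dL/dy*, then one VJP
y = y.requires_grad_()
fy = f(y, theta)
dfdy = torch.autograd.grad(fy, y, create_graph=True)[0]
g = -1.0 / dfdy                     # scalar case: the adjoint solve is a division
grad_theta = torch.autograd.grad(fy, theta, grad_outputs=g)[0]

# closed form of the implicit derivative: dy/dtheta = -y / (3 y^2 + theta)
expected = -y.detach() / (3 * y.detach()**2 + theta.detach())
assert torch.allclose(grad_theta, expected)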
Example #4
def test_jac_grad():
    # same test-module helpers as in Example #1; get_pure_function is
    # assumed to be xitorch's pure-function wrapper
    na = 3
    params = getfnparams(na)
    params2 = [torch.rand(1, dtype=dtype).requires_grad_() for _ in params]
    jacs = jac(func1, params)

    def fcnr(i, v, *params):
        # VJP through the i-th Jacobian: J_i^T v
        jacs = jac(func1, params)
        return jacs[i].rmv(v.view(-1))

    def fcnl(i, w, *params):
        # JVP through the i-th Jacobian: J_i w
        jacs = jac(func1, params)
        return jacs[i].mv(w.view(-1))

    def fcnr2(i, v, *params2):
        # differentiate through the operator's captured (object) parameters
        # by substituting scaled copies via useobjparams
        fmv = get_pure_function(jacs[i].rmv)
        params0 = v.view(-1)
        params1 = fmv.objparams()
        params12 = [p1 * p2 for p1, p2 in zip(params1, params2)]
        with fmv.useobjparams(params12):
            return fmv(params0)

    def fcnl2(i, w, *params2):
        # same pattern for the forward product mv
        fmv = get_pure_function(jacs[i].mv)
        params0 = w.view(-1)
        params1 = fmv.objparams()
        params12 = [p1 * p2 for p1, p2 in zip(params1, params2)]
        with fmv.useobjparams(params12):
            return fmv(params0)

    v = torch.rand((na, ), dtype=dtype, requires_grad=True)
    w = [torch.rand_like(p).requires_grad_() for p in params]
    for i in range(len(jacs)):
        gradcheck(fcnr, (i, v, *params))
        gradgradcheck(fcnr, (i, v, *params))
        gradcheck(fcnl, (i, w[i], *params))
        gradgradcheck(fcnl, (i, w[i], *params))
        gradcheck(fcnr2, (i, v, *params2))
        gradgradcheck(fcnr2, (i, v, *params2))
        gradcheck(fcnl2, (i, w[i], *params2))
        gradgradcheck(fcnl2, (i, w[i], *params2))
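
fcnr2 and fcnl2 above exercise what appears to be xitorch's pure-function machinery: `get_pure_function` wraps a bound method, `objparams()` exposes the tensors captured by the underlying object, and `useobjparams(...)` temporarily substitutes them, which lets gradcheck differentiate through the captured state. A toy stand-in for that pattern (a sketch, not xitorch's implementation):

import contextlib
import torch

class ScaleOp:
    # a method that closes over object state (self.a), like a LinearOperator
    def __init__(self, a):
        self.a = a

    def mv(self, x):
        return self.a * x

    def objparams(self):
        # expose the captured tensors
        return [self.a]

    @contextlib.contextmanager
    def useobjparams(self, params):
        # temporarily swap in substitute tensors
        old, self.a = self.a, params[0]
        try:
            yield
        finally:
            self.a = old

op = ScaleOp(torch.rand(3, dtype=torch.float64))
x = torch.rand(3, dtype=torch.float64)
a2 = torch.rand(3, dtype=torch.float64, requires_grad=True)

with op.useobjparams([a2]):
    y = op.mv(x)                    # differentiable w.r.t. the substituted a2
assert y.requires_grad and torch.allclose(y, a2 * x)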