示例#1
0
def test_jacobian_times_vectorfield_adjoint_gradcheck(bs, dim):
    """Gradcheck ``lm.jacobian_times_vectorfield_adjoint`` w.r.t. both inputs.

    ``v`` and ``m`` are random float64 tensors of shape
    ``(bs, dim, res, ..., res)`` (``dim`` trailing axes of size ``res``) on the
    current CUDA device, and both require grad.

    ``res`` and ``catch_gradcheck`` are module-level names defined elsewhere;
    ``catch_gradcheck`` presumably runs ``torch.autograd.gradcheck`` and
    reports the given message on failure -- confirm against its definition.
    """
    defsh = tuple([bs, dim] + [res] * dim)
    # Allocate directly on the GPU so the inputs are leaf tensors. Creating
    # them on the CPU with requires_grad=True and then calling .cuda()
    # passes a non-leaf copy to gradcheck and wastes a host allocation.
    v = torch.randn(defsh, dtype=torch.float64, device="cuda",
                    requires_grad=True)
    m = torch.randn_like(v, requires_grad=True)
    catch_gradcheck("Failed jacobian_times_vectorfield_adjoint gradcheck",
                    lm.jacobian_times_vectorfield_adjoint, (v, m))
示例#2
0
def test_affine_interp_gradcheck(bs, dim, c, testI, testA, testT):
    """Gradcheck ``lm.affine_interp`` w.r.t. the selected inputs.

    The flags choose which inputs require grad:

    * ``testI`` -- image ``I`` of shape ``(bs, c, res, ..., res)``
      (``dim`` trailing axes of size ``res``)
    * ``testA`` -- matrix ``A`` of shape ``(bs, dim, dim)``
    * ``testT`` -- vector ``T`` of shape ``(bs, dim)``

    If no flag is set there is nothing to check, and the test returns
    immediately. All tensors are float64 on the current CUDA device.
    ``res`` and ``catch_gradcheck`` are module-level names defined elsewhere.
    """
    if not (testI or testA or testT):
        return  # nothing to test
    imsh = tuple([bs, c] + [res] * dim)
    # Create every input directly on its target device so each one is a
    # leaf tensor (``.cuda()`` / ``.to()`` after requires_grad=True yields
    # non-leaf copies).
    I = torch.randn(imsh, dtype=torch.float64, device="cuda",
                    requires_grad=testI)
    A = torch.randn((bs, dim, dim), dtype=I.dtype, device=I.device,
                    requires_grad=testA)
    T = torch.randn((bs, dim), dtype=I.dtype, device=I.device,
                    requires_grad=testT)
    catch_gradcheck(
        f"Failed affine interp gradcheck with batch size {bs} dim {dim} channels {c}",
        lm.affine_interp, (I, A, T))
示例#3
0
def test_fluid_flat_gradcheck(bs, dim):
    """Gradcheck ``lm.FluidMetric(...).flat`` on a random field ``v``.

    ``v`` is a float64 tensor of shape ``(bs, dim, res, ..., res)`` on the
    current CUDA device, and it requires grad. ``res`` and ``catch_gradcheck``
    are module-level names defined elsewhere.
    """
    # NOTE(review): presumably the metric's (alpha, beta, gamma) parameters;
    # confirm against lm.FluidMetric.
    fluid_params = [.1, .01, .001]
    defsh = tuple([bs, dim] + [res] * dim)
    # Allocate on the GPU directly so ``v`` is a leaf tensor rather than the
    # non-leaf result of ``.cuda()``.
    v = torch.randn(defsh, dtype=torch.float64, device="cuda",
                    requires_grad=True)
    metric = lm.FluidMetric(fluid_params)
    catch_gradcheck(
        f"Failed fluid flat gradcheck with batch size {bs} dim {dim}",
        metric.flat, (v, ))
示例#4
0
def test_interp_gradcheck(bs, nc, dim, testI, testu, broadcastI):
    """Gradcheck ``lm.interp(I, u)`` w.r.t. the selected inputs.

    The inputs are:

    * ``I`` -- image of shape ``(N, nc, res, ..., res)``. ``N`` is 1 when
      ``broadcastI`` is set and ``bs`` otherwise; the ``N == 1`` case
      presumably exercises broadcasting of ``I`` over the batch of ``u``.
    * ``u`` -- tensor of shape ``(bs, dim, res, ..., res)``.

    ``testI``/``testu`` choose which inputs require grad. If neither is set,
    the test returns immediately. All tensors are float64 on the current CUDA
    device. ``res`` and ``catch_gradcheck`` are module-level names defined
    elsewhere.
    """
    if not (testI or testu):
        return  # nothing to test
    imbs = 1 if broadcastI else bs
    imsh = tuple([imbs, nc] + [res] * dim)
    defsh = tuple([bs, dim] + [res] * dim)
    # Create inputs on their target device so they are leaf tensors.
    I = torch.randn(imsh, dtype=torch.float64, device="cuda",
                    requires_grad=testI)
    u = torch.randn(defsh, dtype=I.dtype, device=I.device,
                    requires_grad=testu)
    catch_gradcheck("Failed interp gradcheck", lm.interp, (I, u))
示例#5
0
def test_jacobian_times_vectorfield_gradcheck(bs, dim, disp, trans, testphi,
                                              testm):
    """Gradcheck ``lm.jacobian_times_vectorfield`` w.r.t. the selected inputs.

    ``phiinv`` and ``m`` are float64 tensors of shape
    ``(bs, dim, res, ..., res)`` on the current CUDA device.
    ``testphi``/``testm`` choose which of them require grad. If neither is
    set, the test returns immediately. ``disp`` and ``trans`` are forwarded as
    the ``displacement=`` and ``transpose=`` keyword arguments. ``res`` and
    ``catch_gradcheck`` are module-level names defined elsewhere.
    """
    if not (testphi or testm):
        return  # nothing to test
    defsh = tuple([bs, dim] + [res] * dim)
    # Allocate on the GPU directly so the inputs are leaf tensors.
    phiinv = torch.randn(defsh, dtype=torch.float64, device="cuda",
                         requires_grad=testphi)
    m = torch.randn_like(phiinv, requires_grad=testm)
    # Bind the keyword options so gradcheck only sees the tensor arguments.
    foo = lambda v, w: lm.jacobian_times_vectorfield(
        v, w, displacement=disp, transpose=trans)
    catch_gradcheck("Failed jacobian_times_vectorfield gradcheck", foo,
                    (phiinv, m))
示例#6
0
def test_Ad_star_gradcheck(bs, dim):
    """Gradcheck ``lm.Ad_star(phiinv, m)`` w.r.t. ``phiinv``.

    Both inputs are float64 tensors of shape ``(bs, dim, res, ..., res)`` on
    the current CUDA device. ``res`` and ``catch_gradcheck`` are module-level
    names defined elsewhere.
    """
    defsh = tuple([bs, dim] + [res] * dim)
    # Allocate on the GPU directly so ``phiinv`` is a leaf tensor.
    phiinv = torch.randn(defsh, dtype=torch.float64, device="cuda",
                         requires_grad=True)
    # NOTE(review): ``m`` does not require grad, so only the gradient w.r.t.
    # ``phiinv`` is checked. The adjoint test sets m.requires_grad = True;
    # confirm whether skipping ``m`` here is intentional.
    m = torch.randn_like(phiinv)
    catch_gradcheck(f"Failed Ad_star gradcheck with batch size {bs} dim {dim}",
                    lm.Ad_star, (phiinv, m))
示例#7
0
def test_regrid_displacement_gradcheck(bs, dim):
    """Gradcheck ``lm.regrid(..., displacement=True)`` onto a larger grid.

    The input is a float64 tensor of shape ``(bs, dim, res, ..., res)`` on the
    current CUDA device. It is regridded to spatial shape ``[res + 1] * dim``.
    ``res`` and ``catch_gradcheck`` are module-level names defined elsewhere.
    """
    # The input has ``dim`` channels, one per spatial axis, so it is a vector
    # field rather than an image; name it ``u`` like the other tests do.
    defsh = tuple([bs, dim] + [res] * dim)
    # Allocate on the GPU directly so ``u`` is a leaf tensor.
    u = torch.randn(defsh, dtype=torch.float64, device="cuda",
                    requires_grad=True)
    outshape = [res + 1] * dim
    foo = lambda J: lm.regrid(J, shape=outshape, displacement=True)
    catch_gradcheck("Failed regrid displacement gradcheck", foo, (u, ))