def test_jacobian_times_vectorfield_adjoint_gradcheck(bs, dim):
    defsh = tuple([bs, dim] + [res] * dim)
    v = torch.randn(defsh, dtype=torch.float64, requires_grad=True).cuda()
    m = torch.randn_like(v)
    m.requires_grad = True
    catch_gradcheck("Failed jacobian_times_vectorfield_adjoint gradcheck",
                    lm.jacobian_times_vectorfield_adjoint, (v, m))
def test_affine_interp_gradcheck(bs, dim, c, testI, testA, testT):
    """Gradcheck lm.affine_interp w.r.t. any combination of image, matrix, and translation."""
    if not (testI or testA or testT):
        return  # nothing to test
    imsh = tuple([bs, c] + [res] * dim)
    I = torch.randn(imsh, dtype=torch.float64, requires_grad=testI).cuda()
    A = torch.randn((bs, dim, dim), dtype=I.dtype, requires_grad=testA).to(I.device)
    T = torch.randn((bs, dim), dtype=I.dtype, requires_grad=testT).to(I.device)
    catch_gradcheck(f"Failed affine interp gradcheck with batch size {bs} dim {dim} channels {c}",
                    lm.affine_interp, (I, A, T))
def test_fluid_flat_gradcheck(bs, dim):
    fluid_params = [.1, .01, .001]
    defsh = tuple([bs, dim] + [res] * dim)
    v = torch.randn(defsh, dtype=torch.float64, requires_grad=True).cuda()
    metric = lm.FluidMetric(fluid_params)
    catch_gradcheck(f"Failed fluid flat gradcheck with batch size {bs} dim {dim}",
                    metric.flat, (v,))
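# A natural companion check (a sketch, not taken from the original test suite):
# assuming FluidMetric also exposes a `sharp` method that inverts `flat`,
# sharp(flat(v)) should recover v up to numerical error. The method name, its
# signature, and the tolerance below are assumptions; the leading underscore keeps
# pytest from collecting this sketch as a test.
def _fluid_flat_sharp_roundtrip_sketch(bs, dim):
    fluid_params = [.1, .01, .001]
    defsh = tuple([bs, dim] + [res] * dim)
    v = torch.randn(defsh, dtype=torch.float64).cuda()
    metric = lm.FluidMetric(fluid_params)
    vrt = metric.sharp(metric.flat(v))  # apply the operator, then its assumed inverse
    assert torch.allclose(v, vrt, atol=1e-4), "flat/sharp roundtrip mismatch"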
def test_interp_gradcheck(bs, nc, dim, testI, testu, broadcastI):
    """Gradcheck lm.interp w.r.t. the image and/or the deformation field."""
    if not (testI or testu):
        return  # nothing to test
    if broadcastI:
        # a single image (batch size 1) broadcast against a batch of deformations
        imsh = tuple([1, nc] + [res] * dim)
    else:
        imsh = tuple([bs, nc] + [res] * dim)
    defsh = tuple([bs, dim] + [res] * dim)
    I = torch.randn(imsh, dtype=torch.float64, requires_grad=testI).cuda()
    u = torch.randn(defsh, dtype=I.dtype, requires_grad=testu).to(I.device)
    catch_gradcheck("Failed interp gradcheck", lm.interp, (I, u))
def test_jacobian_times_vectorfield_gradcheck(bs, dim, disp, trans, testphi, testm):
    """Gradcheck lm.jacobian_times_vectorfield for each displacement/transpose flag combination."""
    if not (testphi or testm):
        return  # nothing to test
    defsh = tuple([bs, dim] + [res] * dim)
    phiinv = torch.randn(defsh, dtype=torch.float64, requires_grad=testphi).cuda()
    m = torch.randn_like(phiinv)
    m.requires_grad = testm
    foo = lambda v, w: lm.jacobian_times_vectorfield(
        v, w, displacement=disp, transpose=trans)
    catch_gradcheck("Failed jacobian_times_vectorfield gradcheck", foo, (phiinv, m))
def test_Ad_star_gradcheck(bs, dim):
    defsh = tuple([bs, dim] + [res] * dim)
    phiinv = torch.randn(defsh, dtype=torch.float64, requires_grad=True).cuda()
    m = torch.randn_like(phiinv)
    catch_gradcheck(f"Failed Ad_star gradcheck with batch size {bs} dim {dim}",
                    lm.Ad_star, (phiinv, m))
def test_regrid_displacement_gradcheck(bs, dim):
    imsh = tuple([bs, dim] + [res] * dim)
    I = torch.randn(imsh, dtype=torch.float64, requires_grad=True).cuda()
    outshape = [res + 1] * dim
    foo = lambda J: lm.regrid(J, shape=outshape, displacement=True)
    catch_gradcheck("Failed regrid displacement gradcheck", foo, (I,))
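# For reference: `res`, `lm`, and `catch_gradcheck` used throughout these tests are
# assumed to be defined earlier in the module. A minimal sketch of a helper like
# catch_gradcheck, wrapping torch.autograd.gradcheck so that a failure carries a
# descriptive message, is shown below; the _sketch suffix marks it as an assumption
# rather than the module's actual implementation.
def _catch_gradcheck_sketch(msg, func, inputs):
    """Run torch.autograd.gradcheck on (func, inputs) and attach msg to any failure."""
    try:
        torch.autograd.gradcheck(func, inputs, raise_exception=True)
    except RuntimeError as e:
        raise AssertionError(f"{msg}: {e}") from e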