def _grad_test(fn, obj, args, sens_type=f64,
               pipeline=grad_pipeline, rel_error=1e-3, argspec=None):
    """Compile the gradient of ``obj`` and check it against ``fn``.

    A ``grad_wrap`` step is inserted right after parsing so the pipeline
    produces a gradient function, which is then compared to the reference
    ``fn`` by ``GradTester`` using finite differences.

    Arguments:
        fn: Reference Python callable used for numerical comparison.
        obj: Either a plain Python function (parsed by the pipeline) or a
            pre-built graph (pipeline runs with parsing disabled).
        args: Concrete call arguments for the gradient check.
        sens_type: Type of the output sensitivity (default ``f64``).
        pipeline: Base pipeline to extend with ``grad_wrap``.
        rel_error: Relative tolerance forwarded to ``GradTester``.
        argspec: Optional explicit abstract types for ``args``; when absent
            they are inferred (broadened) from the values themselves.
    """
    pipeline = pipeline.insert_after('parse', grad_wrap=grad_wrap)
    # Build the abstract signature: one entry per cleaned argument, plus the
    # sensitivity type appended at the end.
    if argspec is None:
        specs = tuple(from_value(a, broaden=True) for a in clean_args(args))
    else:
        specs = tuple(to_abstract_test(spec) for spec in argspec)
    full_spec = [*specs, to_abstract_test(sens_type)]
    if isinstance(obj, FunctionType):
        result = pipeline.run(input=obj, argspec=full_spec)
    else:
        # ``obj`` is already a graph: skip the parse step entirely.
        result = pipeline.configure(parse=False).run(graph=obj,
                                                     argspec=full_spec)
    tester = GradTester(
        fn=fn,
        gfn=result['output'],
        args=args,
        argnames=[f'in{i}' for i in range(len(args))],
        outnames=None,
        rel_error=rel_error,
    )
    tester.assert_match()
def _grad_test(
    fn,
    obj,
    args,
    sens_type=f64,
    pipeline=grad_pipeline,
    rel_error=1e-3,
    argspec=None,
):
    """Check the compiled gradient of ``obj`` against reference ``fn``.

    Extends ``pipeline`` with ``grad_wrap`` after the parse step, runs it on
    ``obj`` with an abstract signature derived from ``args`` (or the explicit
    ``argspec``) plus the sensitivity type, then delegates the numerical
    comparison to ``GradTester``.

    Arguments:
        fn: Reference callable for finite-difference comparison.
        obj: A Python function, or a graph (parse step is dropped).
        args: Concrete arguments for the gradient evaluation.
        sens_type: Output sensitivity type (default ``f64``).
        pipeline: Base pipeline to extend.
        rel_error: Relative tolerance for ``GradTester``.
        argspec: Optional explicit abstract argument types; inferred from
            ``args`` (broadened) when ``None``.
    """
    pipeline = pipeline.insert_after(steps.step_parse, grad_wrap)
    if argspec is None:
        spec = tuple(from_value(a, broaden=True) for a in clean_args(args))
    else:
        spec = tuple(to_abstract_test(s) for s in argspec)
    combined = [*spec, to_abstract_test(sens_type)]
    if isinstance(obj, FunctionType):
        result = pipeline(input=obj, argspec=combined)
    else:
        # A pre-built graph needs no parsing; strip that step before running.
        no_parse = pipeline.without_step(steps.step_parse)
        result = no_parse(graph=obj, argspec=combined)
    GradTester(
        fn=fn,
        gfn=result["output"],
        args=args,
        argnames=[f"in{i}" for i in range(len(args))],
        outnames=None,
        rel_error=rel_error,
    ).assert_match()
def make_argspec(args, broad_specs):
    """Return a tuple of abstract specs for ``args``.

    Arguments:
        args: Concrete arguments to derive abstract types from.
        broad_specs: Per-argument booleans controlling broadening; ``None``
            means broaden every argument.
    """
    if broad_specs is None:
        broad_specs = (True,) * len(args)
    specs = [
        from_value(arg, broaden=broad)
        for broad, arg in zip(broad_specs, clean_args(args))
    ]
    return tuple(specs)
def _grad_test(fn, obj, args, sens_type,
               pipeline=grad_pipeline, rel_error=1e-3):
    """Compare Myia-compiled gradients of ``obj`` against PyTorch autograd.

    Computes reference gradients via ``pt_fn_grads(fn, *args)``, compiles
    the gradient of ``obj`` through the (grad-wrapped) standard pipeline,
    evaluates it with a matching sensitivity tensor, and asserts elementwise
    closeness with ``torch.allclose``.

    Arguments:
        fn: Reference callable whose PyTorch gradients are the ground truth.
        obj: A Python function (parsed) or a graph (parse step disabled).
        args: Concrete tensor arguments.
        sens_type: Output shape as a tuple — ``()`` and ``(1,)`` are mapped
            to special scalar/loss abstract types; anything else becomes an
            f32 ``AbstractArray`` of that shape.
        pipeline: NOTE(review): unused — overwritten by ``standard_pipeline``
            below, so passing a custom pipeline has no effect. Confirm
            whether that override is intentional.
        rel_error: NOTE(review): unused in this variant (tolerances are
            hard-coded in the ``allclose`` call below).
    """
    # Ground-truth gradients from PyTorch autograd.
    pytorch_grads = pt_fn_grads(fn, *args)
    # Keep the raw shape tuple; ``sens_type`` itself is rebound below.
    sens_type_shape = sens_type
    if sens_type == ():
        sens_type = APT_0d_loss
    elif sens_type == (1, ):
        sens_type = APT_loss
    else:
        sens_type = AbstractArray(AbstractScalar({
            TYPE: f32, VALUE: ANYTHING
        }), {
            SHAPE: sens_type, TYPE: PyTorchTensor
        })
    # NOTE(review): this discards the ``pipeline`` argument entirely.
    pipeline = standard_pipeline
    pipeline = pipeline.insert_after('parse', grad_wrap=grad_wrap)
    argspec = tuple(from_value(arg, broaden=True) for arg in clean_args(args))
    sens_type = to_abstract_test(sens_type)
    if isinstance(obj, FunctionType):
        res = pipeline.run(input=obj, argspec=[*argspec, sens_type])
    else:
        # ``obj`` is already a graph: run the pipeline without parsing.
        pip = pipeline.configure(parse=False)
        res = pip.run(graph=obj, argspec=[*argspec, sens_type])
    # Build the sensitivity value matching the abstract type chosen above.
    # NOTE(review): ``sens_type`` was passed through ``to_abstract_test``
    # before these comparisons — presumably that is the identity for
    # APT_loss / APT_0d_loss; confirm, otherwise the first two branches
    # can never match and the ``torch.ones`` fallback always runs.
    if sens_type == APT_loss:
        sens = torch.Tensor([1.0])
    elif sens_type == APT_0d_loss:
        sens = torch.Tensor([1.0]).reshape(())
    else:
        sens = torch.ones(sens_type_shape)
    myia_grads = res['output'](*args, sens)
    # Elementwise comparison against the PyTorch reference gradients.
    for pt_g, my_g in zip(pytorch_grads, myia_grads):
        # print("pytorch_grad", pt_g)
        # print("myia_grad", my_g)
        assert torch.allclose(pt_g, my_g, rtol=1e-05, atol=1e-06,
                              equal_nan=True)
def _grad_test(fn, obj, args, sens_type=f64,
               pipeline=grad_pipeline, rel_error=1e-3):
    """Finite-difference check of the compiled gradient of ``obj``.

    Inserts ``grad_wrap`` after the parse step, runs the pipeline with a
    value-based argspec (``{'value': ...}`` per argument plus a
    ``{'type': ...}`` entry for the sensitivity), and hands the resulting
    gradient function to ``GradTester``.

    Arguments:
        fn: Reference callable for the numerical comparison.
        obj: A Python function, or a graph (pipeline runs without parsing).
        args: Concrete call arguments.
        sens_type: Sensitivity type (default ``f64``).
        pipeline: Base pipeline to extend with ``grad_wrap``.
        rel_error: Relative tolerance forwarded to ``GradTester``.
    """
    pipeline = pipeline.insert_after('parse', grad_wrap=grad_wrap)
    # Value-style specs for each argument, then the sensitivity type entry.
    spec = [{'value': a} for a in clean_args(args)]
    spec.append({'type': sens_type})
    if isinstance(obj, FunctionType):
        result = pipeline.run(input=obj, argspec=spec)
    else:
        # Already a graph — drop the parse step before running.
        result = pipeline.configure(parse=False).run(graph=obj, argspec=spec)
    GradTester(fn=fn,
               gfn=result['output'],
               args=args,
               argnames=[f'in{i}' for i in range(len(args))],
               outnames=None,
               rel_error=rel_error).assert_match()