Exemplo n.º 1
0
def _grad_test(fn,
               obj,
               args,
               sens_type=f64,
               pipeline=grad_pipeline,
               rel_error=1e-3,
               argspec=None):
    """Run ``obj`` through a gradient pipeline and check it with GradTester.

    A ``grad_wrap`` step is inserted after the ``'parse'`` step of
    ``pipeline``. The compiled result is then handed to ``GradTester``
    alongside the reference function ``fn``.

    Arguments:
        fn: Reference function forwarded to ``GradTester`` as ``fn``.
        obj: A Python function (fed to the pipeline as ``input``) or,
            otherwise, a graph (fed as ``graph`` with parsing disabled).
        args: Concrete argument values used by the test.
        sens_type: Type of the sensitivity appended after the argument
            specs; converted with ``to_abstract_test``.
        pipeline: Base pipeline to extend.
        rel_error: Relative tolerance forwarded to ``GradTester``.
        argspec: Optional explicit argument specs. When ``None``, specs
            are derived from ``clean_args(args)`` via
            ``from_value(..., broaden=True)``.
    """
    grad_pip = pipeline.insert_after('parse', grad_wrap=grad_wrap)

    if argspec is not None:
        specs = tuple(to_abstract_test(spec) for spec in argspec)
    else:
        specs = tuple(
            from_value(value, broaden=True) for value in clean_args(args))

    sens_abstract = to_abstract_test(sens_type)
    full_spec = [*specs, sens_abstract]

    if isinstance(obj, FunctionType):
        result = grad_pip.run(input=obj, argspec=full_spec)
    else:
        # A pre-built graph needs no parsing, so that step is switched off.
        result = grad_pip.configure(parse=False).run(graph=obj,
                                                     argspec=full_spec)

    tester = GradTester(fn=fn,
                        gfn=result['output'],
                        args=args,
                        argnames=[f'in{idx}' for idx, _ in enumerate(args)],
                        outnames=None,
                        rel_error=rel_error)
    tester.assert_match()
Exemplo n.º 2
0
def _grad_test(
    fn,
    obj,
    args,
    sens_type=f64,
    pipeline=grad_pipeline,
    rel_error=1e-3,
    argspec=None,
):
    """Run ``obj`` through a gradient pipeline and check it with GradTester.

    ``grad_wrap`` is inserted after ``steps.step_parse``. The resulting
    pipeline is called directly, and its ``"output"`` entry is compared
    against the reference ``fn`` by ``GradTester``.

    Arguments:
        fn: Reference function forwarded to ``GradTester``.
        obj: A Python function (passed as ``input``) or, otherwise, a
            graph (passed as ``graph`` with the parse step removed).
        args: Concrete argument values used by the test.
        sens_type: Type of the sensitivity appended to the specs;
            converted with ``to_abstract_test``.
        pipeline: Base pipeline to extend.
        rel_error: Relative tolerance forwarded to ``GradTester``.
        argspec: Optional explicit argument specs. When ``None``, specs
            come from ``from_value(..., broaden=True)`` on the cleaned
            ``args``.
    """
    grad_pip = pipeline.insert_after(steps.step_parse, grad_wrap)

    if argspec is not None:
        specs = tuple(to_abstract_test(spec) for spec in argspec)
    else:
        specs = tuple(
            from_value(value, broaden=True) for value in clean_args(args)
        )

    sens_abstract = to_abstract_test(sens_type)
    full_spec = [*specs, sens_abstract]

    if isinstance(obj, FunctionType):
        result = grad_pip(input=obj, argspec=full_spec)
    else:
        # Graphs are already built, so the parse step is dropped.
        result = grad_pip.without_step(steps.step_parse)(
            graph=obj, argspec=full_spec
        )

    tester = GradTester(
        fn=fn,
        gfn=result["output"],
        args=args,
        argnames=[f"in{idx}" for idx, _ in enumerate(args)],
        outnames=None,
        rel_error=rel_error,
    )
    tester.assert_match()
Exemplo n.º 3
0
def make_argspec(args, broad_specs):
    """Build a tuple of abstract specs from concrete ``args``.

    ``broad_specs`` supplies, per argument, the ``broaden`` flag passed to
    ``from_value``. ``None`` means every argument is broadened. Flags and
    cleaned arguments are paired with ``zip``, so the result is as long as
    the shorter of the two.
    """
    flags = (True,) * len(args) if broad_specs is None else broad_specs
    cleaned = clean_args(args)
    specs = (from_value(value, broaden=flag)
             for flag, value in zip(flags, cleaned))
    return tuple(specs)
Exemplo n.º 4
0
def _grad_test(fn,
               obj,
               args,
               sens_type,
               pipeline=grad_pipeline,
               rel_error=1e-3):
    """Compare Myia's gradients of ``obj`` with PyTorch's gradients of ``fn``.

    ``sens_type`` is given as a shape tuple:
      * ``()``   -> ``APT_0d_loss`` with a 0-d sensitivity of 1.0,
      * ``(1,)`` -> ``APT_loss`` with a 1-element sensitivity of 1.0,
      * other    -> an f32 ``PyTorchTensor`` abstract array of that shape,
                    with a sensitivity of ``torch.ones(shape)``.

    Each pair of gradients is checked with ``torch.allclose``
    (``rtol=1e-05``, ``atol=1e-06``, ``equal_nan=True``).

    NOTE(review): ``pipeline`` and ``rel_error`` are accepted but unused.
    The body always builds on ``standard_pipeline`` and uses fixed
    tolerances. Confirm whether that is intended.
    """
    # Reference gradients are computed first, before any Myia compilation.
    expected_grads = pt_fn_grads(fn, *args)

    sens_shape = sens_type
    if sens_shape == ():
        sens_abstract = APT_0d_loss
    elif sens_shape == (1, ):
        sens_abstract = APT_loss
    else:
        element = AbstractScalar({TYPE: f32, VALUE: ANYTHING})
        sens_abstract = AbstractArray(element, {
            SHAPE: sens_shape,
            TYPE: PyTorchTensor
        })

    grad_pip = standard_pipeline.insert_after('parse', grad_wrap=grad_wrap)
    specs = tuple(from_value(value, broaden=True)
                  for value in clean_args(args))
    sens_abstract = to_abstract_test(sens_abstract)
    full_spec = [*specs, sens_abstract]

    if isinstance(obj, FunctionType):
        result = grad_pip.run(input=obj, argspec=full_spec)
    else:
        # A graph is already built, so parsing is disabled.
        result = grad_pip.configure(parse=False).run(graph=obj,
                                                     argspec=full_spec)

    if sens_abstract == APT_loss:
        sens = torch.Tensor([1.0])
    elif sens_abstract == APT_0d_loss:
        sens = torch.Tensor([1.0]).reshape(())
    else:
        sens = torch.ones(sens_shape)

    myia_grads = result['output'](*args, sens)

    for expected, actual in zip(expected_grads, myia_grads):
        assert torch.allclose(expected,
                              actual,
                              rtol=1e-05,
                              atol=1e-06,
                              equal_nan=True)
Exemplo n.º 5
0
def _grad_test(fn, obj, args,
               sens_type=f64,
               pipeline=grad_pipeline,
               rel_error=1e-3):
    """Run ``obj`` through a gradient pipeline and check it with GradTester.

    This variant uses dict-style argument specs: each cleaned argument
    becomes ``{'value': arg}``, and the sensitivity is appended as
    ``{'type': sens_type}``.

    Arguments:
        fn: Reference function forwarded to ``GradTester``.
        obj: A Python function (passed as ``input``) or, otherwise, a
            graph (passed as ``graph`` with parsing disabled).
        args: Concrete argument values used by the test.
        sens_type: Type recorded for the sensitivity input.
        pipeline: Base pipeline; ``grad_wrap`` is inserted after ``'parse'``.
        rel_error: Relative tolerance forwarded to ``GradTester``.
    """
    grad_pip = pipeline.insert_after('parse', grad_wrap=grad_wrap)

    full_spec = [{'value': value} for value in clean_args(args)]
    full_spec.append({'type': sens_type})

    if isinstance(obj, FunctionType):
        result = grad_pip.run(input=obj, argspec=full_spec)
    else:
        # Graphs need no parsing, so that step is turned off.
        result = grad_pip.configure(parse=False).run(graph=obj,
                                                     argspec=full_spec)

    tester = GradTester(
        fn=fn,
        gfn=result['output'],
        args=args,
        argnames=[f'in{idx}' for idx, _ in enumerate(args)],
        outnames=None,
        rel_error=rel_error
    )
    tester.assert_match()