예제 #1
0
  def check(self, fun, in_shapes, out_shape, logical_env, padded_in_shapes,
            dtypes, rng, rtol=None, atol=None):
    """Compare the masked version of `fun` against `fun` on unpadded data.

    The masked function is run on padded inputs; its outputs are sliced down
    to their logical shapes and compared with the outputs of calling `fun`
    directly on the correspondingly sliced inputs. The masked function is
    then run again under `jit` and compared with the un-jitted padded outputs.

    Args:
      fun: Function under test; called as `fun(*logical_args)`.
      in_shapes: Input shape specs, one per argument, handed to `shapecheck`,
        `mask` and `parse_spec`.
      out_shape: Output shape spec(s); flattened with `tree_flatten`.
      logical_env: Passed to the masked function and to `eval_poly_shape`
        (presumably maps shape-variable names to logical sizes -- confirm
        against callers).
      padded_in_shapes: Concrete shapes of the padded inputs to generate.
      dtypes: Dtypes of the padded inputs, paired with `padded_in_shapes`.
      rng: Callable invoked as `rng(shape, dtype)` to build each padded input.
      rtol: Relative tolerance forwarded to `assertAllClose`.
      atol: Absolute tolerance forwarded to `assertAllClose`.
    """
    def logical_view(specs, padded_shapes, padded_arrays):
      # Slice every padded array down to the shape its spec evaluates to
      # under `logical_env`; shared by the input and the output side.
      specs = map(parse_spec, specs)
      specs = map(finalize_spec, specs, padded_shapes)
      logical_shapes = [eval_poly_shape(s, logical_env) for s in specs]
      logical_slices = [tuple(map(slice, s)) for s in logical_shapes]
      return [a[s] for a, s in zip(padded_arrays, logical_slices)]

    shapecheck(in_shapes, out_shape)(fun)
    masked_fun = mask(fun, in_shapes, out_shape)
    padded_args = [rng(shape, dtype)
                   for shape, dtype in zip(padded_in_shapes, dtypes)]
    padded_outs, outs_tree = tree_flatten(masked_fun(padded_args, logical_env))

    # Logical (unpadded) part of what the masked function produced.
    out_specs, _ = tree_flatten(out_shape)
    logical_outs = logical_view(out_specs, map(np.shape, padded_outs),
                                padded_outs)

    # Reference result: the plain function applied to the unpadded inputs.
    logical_args = logical_view(in_shapes, padded_in_shapes, padded_args)
    logical_outs_expected, logical_outs_tree = tree_flatten(fun(*logical_args))

    # A unittest assertion instead of a bare `assert`: it is not stripped
    # under `python -O` and reports both tree structures when they differ.
    self.assertEqual(outs_tree, logical_outs_tree)
    self.assertAllClose(logical_outs, logical_outs_expected, check_dtypes=True,
                        atol=atol, rtol=rtol)

    # Check that abstract evaluation works
    padded_outs_jit, _ = tree_flatten(jit(masked_fun)(padded_args, logical_env))
    self.assertAllClose(padded_outs_jit, padded_outs, check_dtypes=True,
                        atol=atol, rtol=rtol)
예제 #2
0
 def thunk():
   """Run `shapecheck` with a tuple output spec on a list-returning function.

   NOTE(review): presumably handed to an assert-raises helper so the container
   mismatch (tuple spec vs. list result) is reported -- confirm at the caller.
   """
   checker = shapecheck(['n'], ('n', 'n'))
   checker(lambda x: [x, x])
예제 #3
0
 def thunk():
   """Run `shapecheck` with output spec ['7*n', 'n'] on `lambda x: (x, x)`.

   NOTE(review): presumably handed to an assert-raises helper; the spec is a
   list while the function returns a tuple -- confirm the expected error at
   the caller.
   """
   checker = shapecheck(['n'], ['7*n', 'n'])
   checker(lambda x: (x, x))
예제 #4
0
 def thunk():
   """Run `shapecheck` with output spec 'n+-1' on the identity function.

   NOTE(review): presumably handed to an assert-raises helper, since the
   input spec is 'n' while the declared output is 'n+-1' -- confirm the
   expected error at the caller.
   """
   checker = shapecheck(['n'], 'n+-1')
   checker(lambda x: x)