def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule,
                           named_shape_rule, *avals, **kwargs):
    """Compute the output abstract value of a single-result primitive.

    The input aval type with the largest ``array_abstraction_level`` decides
    which kind of abstract value is produced.

    Args:
      prim: The primitive; must have a falsy ``multiple_results``. Its
        ``impl`` is only invoked when the selected type is
        ``core.ConcreteArray``.
      shape_rule: Called as ``shape_rule(*avals, **kwargs)`` for the output
        shape.
      dtype_rule: Called as ``dtype_rule(*avals, **kwargs)`` for the output
        dtype.
      weak_type_rule: Called as ``weak_type_rule(*avals, **kwargs)`` for the
        output weak-type flag.
      named_shape_rule: Called as ``named_shape_rule(*avals, **kwargs)``;
        only used when the selected type is ``core.ShapedArray``.
      *avals: Input abstract values; each must be a ``core.UnshapedArray``
        instance.
      **kwargs: Parameters forwarded unchanged to every rule and to
        ``prim.impl``.

    Returns:
      A ``core.ConcreteArray``, ``core.ShapedArray``, ``core.DShapedArray``
      or ``core.UnshapedArray``.

    Raises:
      TypeError: If the selected input type is none of the four handled
        abstract value classes.
    """
    assert all(isinstance(aval, core.UnshapedArray) for aval in avals), avals
    assert not prim.multiple_results

    out_weak_type = weak_type_rule(*avals, **kwargs)
    abstraction_level = operator.attrgetter('array_abstraction_level')
    aval_kind = _max(map(type, avals), key=abstraction_level)

    if aval_kind is core.ConcreteArray:
        # Every input carries a concrete value, so run the primitive itself.
        concrete_inputs = [aval.val for aval in avals]
        result = prim.impl(*concrete_inputs, **kwargs)
        return core.ConcreteArray(result.dtype, result,
                                  weak_type=out_weak_type)

    if aval_kind is core.ShapedArray:
        out_shape = shape_rule(*avals, **kwargs)
        out_dtype = dtype_rule(*avals, **kwargs)
        out_named_shape = named_shape_rule(*avals, **kwargs)
        return core.ShapedArray(out_shape, out_dtype,
                                weak_type=out_weak_type,
                                named_shape=out_named_shape)

    if aval_kind is core.DShapedArray:
        out_shape = shape_rule(*avals, **kwargs)
        # Use ShapedArray only when every dimension is exactly a Python int.
        has_non_int_dim = any(type(dim) is not int for dim in out_shape)
        out_cls = core.DShapedArray if has_non_int_dim else core.ShapedArray
        out_dtype = dtype_rule(*avals, **kwargs)
        return out_cls(out_shape, out_dtype, out_weak_type)

    if aval_kind is core.UnshapedArray:
        out_dtype = dtype_rule(*avals, **kwargs)
        return core.UnshapedArray(out_dtype, weak_type=out_weak_type)

    raise TypeError(avals, aval_kind)
def standard_multi_result_abstract_eval(prim, shape_rule, dtype_rule,
                                        weak_type_rule, named_shape_rule,
                                        *avals, **kwargs):
    """Compute the output abstract values of a multiple-result primitive.

    The input aval type with the largest ``array_abstraction_level`` decides
    which kind of abstract values are produced. Each rule is called once as
    ``rule(*avals, **kwargs)`` and its result is zipped per output with
    ``safe_zip``.

    Args:
      prim: The primitive; must have a truthy ``multiple_results``. Its
        ``impl`` is only invoked when the selected type is
        ``core.ConcreteArray``.
      shape_rule: Supplies the output shapes (``core.ShapedArray`` case).
      dtype_rule: Supplies the output dtypes (``core.ShapedArray`` and
        ``core.UnshapedArray`` cases).
      weak_type_rule: Supplies the output weak-type flags.
      named_shape_rule: Supplies the output named shapes
        (``core.ShapedArray`` case).
      *avals: Input abstract values; each must be a ``core.UnshapedArray``
        instance.
      **kwargs: Parameters forwarded unchanged to every rule and to
        ``prim.impl``.

    Returns:
      A list of ``core.ConcreteArray``, ``core.ShapedArray`` or
      ``core.UnshapedArray`` objects, one per zipped output.

    Raises:
      TypeError: If the selected input type is not one of the three classes
        handled here.
    """
    assert prim.multiple_results
    assert all(isinstance(aval, core.UnshapedArray) for aval in avals), avals

    abstraction_level = operator.attrgetter('array_abstraction_level')
    aval_kind = _max(map(type, avals), key=abstraction_level)
    out_weak_types = weak_type_rule(*avals, **kwargs)

    if aval_kind is core.ConcreteArray:
        # Every input carries a concrete value, so run the primitive itself.
        concrete_inputs = [aval.val for aval in avals]
        results = prim.impl(*concrete_inputs, **kwargs)
        return [
            core.ConcreteArray(result.dtype, result, weak_type=is_weak)
            for result, is_weak in safe_zip(results, out_weak_types)
        ]

    if aval_kind is core.ShapedArray:
        shapes = shape_rule(*avals, **kwargs)
        dtypes = dtype_rule(*avals, **kwargs)
        named_shapes = named_shape_rule(*avals, **kwargs)
        per_output = safe_zip(shapes, dtypes, out_weak_types, named_shapes)
        return [
            core.ShapedArray(shape, dtype, weak_type=is_weak,
                             named_shape=named)
            for shape, dtype, is_weak, named in per_output
        ]

    if aval_kind is core.UnshapedArray:
        dtypes = dtype_rule(*avals, **kwargs)
        return [
            core.UnshapedArray(dtype, weak_type=is_weak)
            for dtype, is_weak in safe_zip(dtypes, out_weak_types)
        ]

    raise TypeError(avals, aval_kind)
def test_concrete_array_string_representation(self):
    """Check ``str()`` of a ``core.ConcreteArray`` holding one int32 value."""
    # https://github.com/google/jax/issues/5364
    int32_dtype = np.dtype(np.int32)
    value = np.array([1], dtype=np.int32)
    concrete = core.ConcreteArray(int32_dtype, value)
    self.assertEqual(str(concrete), 'ConcreteArray([1], dtype=int32)')