示例#1
0
def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule,
                           named_shape_rule, *avals, **kwargs):
    """Compute the output abstract value of a single-result primitive.

    The kind of abstract value returned depends on the aval type that
    ``_max`` selects among ``avals`` when keyed on the class attribute
    ``array_abstraction_level`` (presumably the least specialized input
    type; ``_max`` is defined outside this block).

    Args:
      prim: the primitive; ``prim.multiple_results`` must be falsy. Its
        ``impl`` is invoked only when the selected type is ``ConcreteArray``.
      shape_rule: callable ``(*avals, **kwargs)`` giving the output shape.
      dtype_rule: callable ``(*avals, **kwargs)`` giving the output dtype.
      weak_type_rule: callable ``(*avals, **kwargs)`` giving the output
        weak-type flag; it is always called, before the type dispatch.
      named_shape_rule: callable ``(*avals, **kwargs)`` giving the output
        named shape; only consulted in the ``ShapedArray`` case.
      *avals: input abstract values, each a ``core.UnshapedArray`` instance.
      **kwargs: primitive parameters, forwarded to every rule and to
        ``prim.impl``.

    Returns:
      A ``core.ConcreteArray``, ``core.ShapedArray``, ``core.DShapedArray``
      or ``core.UnshapedArray`` describing the output.

    Raises:
      TypeError: if the selected aval type is none of the four handled ones.
    """
    assert all(isinstance(aval, core.UnshapedArray) for aval in avals), avals
    assert not prim.multiple_results

    out_weak_type = weak_type_rule(*avals, **kwargs)
    abstraction_level = operator.attrgetter('array_abstraction_level')
    selected = _max(map(type, avals), key=abstraction_level)

    if selected is core.ConcreteArray:
        # Every input carries a value, so evaluate the primitive for real.
        result = prim.impl(*[aval.val for aval in avals], **kwargs)
        return core.ConcreteArray(result.dtype, result,
                                  weak_type=out_weak_type)

    if selected is core.ShapedArray:
        out_shape = shape_rule(*avals, **kwargs)
        out_dtype = dtype_rule(*avals, **kwargs)
        out_named_shape = named_shape_rule(*avals, **kwargs)
        return core.ShapedArray(out_shape, out_dtype,
                                weak_type=out_weak_type,
                                named_shape=out_named_shape)

    if selected is core.DShapedArray:
        out_shape = shape_rule(*avals, **kwargs)
        # A shape whose dimensions are all exactly `int` is described by a
        # ShapedArray; anything else stays a DShapedArray.
        if all(type(dim) is int for dim in out_shape):
            out_cls = core.ShapedArray
        else:
            out_cls = core.DShapedArray
        return out_cls(out_shape, dtype_rule(*avals, **kwargs), out_weak_type)

    if selected is core.UnshapedArray:
        # No shape information is available: only dtype and weak type.
        return core.UnshapedArray(dtype_rule(*avals, **kwargs),
                                  weak_type=out_weak_type)

    raise TypeError(avals, selected)
示例#2
0
def standard_multi_result_abstract_eval(prim, shape_rule, dtype_rule,
                                        weak_type_rule, named_shape_rule,
                                        *avals, **kwargs):
    """Compute the output abstract values of a multiple-result primitive.

    Each rule is called as ``rule(*avals, **kwargs)`` and is expected to
    return one entry per output; the per-output entries are combined with
    ``safe_zip``. The kind of abstract value produced depends on the aval
    type that ``_max`` selects among ``avals`` when keyed on the class
    attribute ``array_abstraction_level`` (presumably the least specialized
    input type; ``_max`` is defined outside this block).

    Args:
      prim: the primitive; ``prim.multiple_results`` must be truthy. Its
        ``impl`` is invoked only when the selected type is ``ConcreteArray``.
      shape_rule: returns the output shapes (``ShapedArray`` case only).
      dtype_rule: returns the output dtypes.
      weak_type_rule: returns the output weak-type flags; always called.
      named_shape_rule: returns the output named shapes (``ShapedArray``
        case only).
      *avals: input abstract values, each a ``core.UnshapedArray`` instance.
      **kwargs: primitive parameters, forwarded to every rule and to
        ``prim.impl``.

    Returns:
      A list of ``core.ConcreteArray``, ``core.ShapedArray`` or
      ``core.UnshapedArray`` objects, one per output.

    Raises:
      TypeError: if the selected aval type is none of the three handled
        ones (``core.DShapedArray`` is not handled here).
    """
    assert prim.multiple_results
    assert all(isinstance(aval, core.UnshapedArray) for aval in avals), avals

    abstraction_level = operator.attrgetter('array_abstraction_level')
    selected = _max(map(type, avals), key=abstraction_level)
    out_weak_types = weak_type_rule(*avals, **kwargs)

    if selected is core.ConcreteArray:
        # Every input carries a value, so evaluate the primitive for real.
        results = prim.impl(*[aval.val for aval in avals], **kwargs)
        return [
            core.ConcreteArray(result.dtype, result, weak_type=weak)
            for result, weak in safe_zip(results, out_weak_types)
        ]

    if selected is core.ShapedArray:
        shapes = shape_rule(*avals, **kwargs)
        dtypes = dtype_rule(*avals, **kwargs)
        named_shapes = named_shape_rule(*avals, **kwargs)
        per_output = safe_zip(shapes, dtypes, out_weak_types, named_shapes)
        return [
            core.ShapedArray(shape, dtype, weak_type=weak, named_shape=named)
            for shape, dtype, weak, named in per_output
        ]

    if selected is core.UnshapedArray:
        # No shape information is available: only dtypes and weak types.
        dtypes = dtype_rule(*avals, **kwargs)
        return [
            core.UnshapedArray(dtype, weak_type=weak)
            for dtype, weak in safe_zip(dtypes, out_weak_types)
        ]

    raise TypeError(avals, selected)
示例#3
0
 def test_concrete_array_string_representation(self):
     """Check the ``str()`` form of a ``core.ConcreteArray`` of one int32."""
     # Regression test for https://github.com/google/jax/issues/5364
     int32_dtype = np.dtype(np.int32)
     value = np.array([1], dtype=np.int32)
     aval = core.ConcreteArray(int32_dtype, value)
     expected = 'ConcreteArray([1], dtype=int32)'
     self.assertEqual(str(aval), expected)