Code Example #1
def standard_multi_result_abstract_eval(prim, shape_rule, dtype_rule,
                                        weak_type_rule, named_shape_rule,
                                        *avals, **kwargs):
    assert prim.multiple_results
    assert all(isinstance(aval, core.UnshapedArray) for aval in avals), avals
    least_specialized = _max(
        map(type, avals), key=operator.attrgetter('array_abstraction_level'))
    weak_types = weak_type_rule(*avals, **kwargs)
    if least_specialized is core.ConcreteArray:
        out_vals = prim.impl(*[x.val for x in avals], **kwargs)
        return [
            core.ConcreteArray(val.dtype, val, weak_type=weak_type)
            for val, weak_type in safe_zip(out_vals, weak_types)
        ]
    elif least_specialized is core.ShapedArray:
        out_shapes = shape_rule(*avals, **kwargs)
        out_dtypes = dtype_rule(*avals, **kwargs)
        out_named_shapes = named_shape_rule(*avals, **kwargs)
        return [
            core.ShapedArray(s,
                             d,
                             weak_type=weak_type,
                             named_shape=named_shape)
            for s, d, weak_type, named_shape in safe_zip(
                out_shapes, out_dtypes, weak_types, out_named_shapes)
        ]
    elif least_specialized is core.UnshapedArray:
        out_dtypes = dtype_rule(*avals, **kwargs)
        return [
            core.UnshapedArray(dtype, weak_type=weak_type)
            for dtype, weak_type in safe_zip(out_dtypes, weak_types)
        ]
    else:
        raise TypeError(avals, least_specialized)
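Every example on this page pairs sequences with `safe_zip` from `jax.util`, which behaves like the built-in `zip` except that it raises when its arguments have different lengths and returns a list. A minimal sketch of that contract, assuming only the length check is added (the helper name below is illustrative, not JAX's implementation):

def safe_zip_sketch(*args):
    # Assumed contract: all inputs must have the same length.
    n = len(args[0])
    assert all(len(arg) == n for arg in args[1:]), list(map(len, args))
    return list(zip(*args))

assert safe_zip_sketch([1, 2], ["a", "b"]) == [(1, "a"), (2, "b")]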
Code Example #2
def _broadcast_to(arr, shape):
    if hasattr(arr, "broadcast_to"):
        return arr.broadcast_to(shape)
    _check_arraylike("broadcast_to", arr)
    arr = arr if isinstance(arr, ndarray) else _asarray(arr)
    if not isinstance(shape, tuple) and np.ndim(shape) == 0:
        shape = (shape, )
    shape = core.canonicalize_shape(shape)  # check that shape is concrete
    arr_shape = np.shape(arr)
    if core.symbolic_equal_shape(arr_shape, shape):
        return arr
    else:
        nlead = len(shape) - len(arr_shape)
        shape_tail = shape[nlead:]
        compatible = all(
            core.symbolic_equal_one_of_dim(arr_d, [1, shape_d])
            for arr_d, shape_d in safe_zip(arr_shape, shape_tail))
        if nlead < 0 or not compatible:
            msg = "Incompatible shapes for broadcasting: {} and requested shape {}"
            raise ValueError(msg.format(arr_shape, shape))
        diff, = np.where(
            tuple(not core.symbolic_equal_dim(arr_d, shape_d)
                  for arr_d, shape_d in safe_zip(arr_shape, shape_tail)))
        new_dims = tuple(range(nlead)) + tuple(nlead + diff)
        kept_dims = tuple(np.delete(np.arange(len(shape)), new_dims))
        return lax.broadcast_in_dim(lax.squeeze(arr, tuple(diff)), shape,
                                    kept_dims)
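For context, a hedged usage sketch of the public `jax.numpy.broadcast_to` wrapper built on this helper; the shapes below are made up, and requests that fail the compatibility check above raise ValueError:

import jax.numpy as jnp

x = jnp.arange(3)                  # shape (3,)
y = jnp.broadcast_to(x, (2, 3))    # a leading axis is added and broadcast
assert y.shape == (2, 3)
# jnp.broadcast_to(jnp.ones((2,)), (3,))  # would raise ValueError (incompatible shapes)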
Code Example #3
File: sharded_jit.py  Project: John1Tang/jax
def _sharded_jit_translation_rule(ctx, avals_in, avals_out, *in_nodes,
                                  in_parts, out_parts_thunk, nparts,
                                  name, call_jaxpr, local_in_parts,
                                  local_out_parts_thunk, local_nparts):
  subc = xc.XlaBuilder(f"sharded_jit_{name}")

  # We assume any extra leading in_nodes are constants and replicate them.
  num_extra_nodes = len(in_nodes) - len(in_parts)
  assert num_extra_nodes >= 0
  in_parts = (None,) * num_extra_nodes + in_parts

  args = []
  for i, (n, sharding) in enumerate(safe_zip(in_nodes, in_parts)):
    # We use xla.set_sharding instead of xla.with_sharding because inlined calls
    # shouldn't have shardings set directly on the inputs or outputs.
    arg = xla.parameter(subc, i, ctx.builder.GetShape(n))
    args.append(xla.set_sharding(subc, arg, sharding))

  sub_ctx = ctx.replace(
      builder=subc,
      name_stack=new_name_stack(wrap_name(name, "sharded_jit")))
  out_nodes = xla.jaxpr_subcomp(sub_ctx, call_jaxpr, (), *args)
  out_parts = out_parts_thunk()
  assert len(out_parts) == len(out_nodes)
  out_nodes = [xla.set_sharding(subc, out, sharding)
               for out, sharding in safe_zip(out_nodes, out_parts)]

  subc = subc.build(xops.Tuple(subc, out_nodes))
  return xla.xla_destructure(ctx.builder,
                             xops.Call(ctx.builder, subc, list(in_nodes)))
Code Example #4
File: global_device_array.py  Project: jbampton/jax
def _get_shard_indices_replica_ids_uncached(
        global_shape: Shape, global_mesh: pxla.Mesh,
        mesh_axes: MeshAxes) -> Mapping[Device, Tuple[Index, int]]:
    indices = _get_indices(global_shape, global_mesh, mesh_axes)
    index_to_replica: Dict[int, int] = Counter()
    out = {}
    unique_shards = 0
    for device, index in safe_zip(global_mesh.devices.flat, indices):
        h_index = _hashed_index(index)
        replica_id = index_to_replica[h_index]
        if replica_id == 0:
            unique_shards += 1
        index_to_replica[h_index] += 1
        out[device] = (index, replica_id)

    shard_shape = get_shard_shape(global_shape, global_mesh, mesh_axes)
    expected_unique_shards = prod([
        g // s for g, s in safe_zip(global_shape, shard_shape)
        if g != 0 or s != 0
    ])
    if expected_unique_shards != unique_shards:
        raise RuntimeError(
            f'Number of expected unique shards are: {expected_unique_shards} but '
            f'got {unique_shards}. Please file a bug at '
            'https://github.com/google/jax/issues.')
    return out
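The replica-id bookkeeping above can be shown in isolation: devices that receive the same index get increasing replica ids via a Counter. A small illustrative sketch with made-up index values:

from collections import Counter

indices = [(0, 1), (0, 1), (2, 3)]   # hypothetical per-device indices
index_to_replica = Counter()
replica_ids = []
for idx in indices:
    replica_ids.append(index_to_replica[idx])
    index_to_replica[idx] += 1
assert replica_ids == [0, 1, 0]      # the second device sharing (0, 1) is replica 1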
Code Example #5
File: tf_test_util.py  Project: jamestwebber/jax
  def CheckShapePolymorphism(self, f_jax: Callable, *,
                             input_signature: Sequence[tf.TensorSpec],
                             polymorphic_shapes: Optional[Sequence[Any]],
                             expected_output_signature: Optional[tf.TensorSpec] = None,
                             enable_xla: bool = True):
    """Converts a function using polymorphic shapes.

    Args:
      f_jax: a JAX function of `n` arguments
      input_signature: used as the input signature for the tf.function.
      polymorphic_shapes: Specifies input shapes to be treated polymorphically
        during conversion.
      expected_output_signature: if given, this function tests whether the
        actual output signature is equal to this one.
      enable_xla: Whether to enable XLA conversion for jax2tf.convert.
    """
    f_tf = jax2tf.convert(f_jax, polymorphic_shapes=polymorphic_shapes,
                          enable_xla=enable_xla)
    f_tf = tf.function(f_tf, autograph=False, input_signature=input_signature)
    concrete_f_tf = f_tf.get_concrete_function(*input_signature)
    if expected_output_signature:
      # Strangely, output_shapes can be a single shape for a function with a
      # single result, or a list/tuple of shapes.
      concrete_output_tf_shape = concrete_f_tf.output_shapes
      if not isinstance(concrete_output_tf_shape, (tuple, list)):  # Single result
        assert not isinstance(expected_output_signature, (tuple, list))
        expected_output_signature = [expected_output_signature]
        concrete_output_tf_shape = [concrete_output_tf_shape]

      for expected, found in util.safe_zip(expected_output_signature,
                                           concrete_output_tf_shape):
        self.assertEqual(tuple(expected.shape), tuple(found))
    return f_tf
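A hedged sketch of how the converted function under test is typically built, mirroring the convert/tf.function calls above; the shape specification and input signature are made up:

import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf

f_tf = jax2tf.convert(jnp.sin, polymorphic_shapes=["(b, 4)"])
f_tf = tf.function(f_tf, autograph=False,
                   input_signature=[tf.TensorSpec([None, 4], tf.float32)])
concrete_f_tf = f_tf.get_concrete_function(tf.TensorSpec([None, 4], tf.float32))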
Code Example #6
File: sharded_jit.py  Project: John1Tang/jax
def _xla_sharded_args(c, avals, in_parts):
  xla_args = []
  for i, (sharding, aval) in enumerate(safe_zip(in_parts, avals)):
    param = xla.with_sharding(c, sharding, xla.parameter, c, i,
                             *xla.aval_to_xla_shapes(aval))
    xla_args.append(param)
  return xla_args
Code Example #7
File: tf_test_util.py  Project: GregCT/jax
  def CheckShapePolymorphism(self, f_jax: Callable, *,
                             input_signature: Sequence[tf.TensorSpec],
                             polymorphic_shapes: Optional[Sequence[Any]],
                             expected_output_signature: tf.TensorSpec):
    """Convert a function using polymorphic shapes.

    Args:
      f_jax: a JAX function of `n` arguments
      input_signature: used as the input signature for the tf.function.
      polymorphic_shapes: if given, a sequence of `n` shape specifications that
        must match the `input_signature` (see jax2tf.convert).
      expected_output_signature: if given, this function tests whether the
        actual output signature is equal to this one.
    """
    f_tf = tf.function(
        jax2tf.convert(f_jax, polymorphic_shapes=polymorphic_shapes),
        autograph=False,
        input_signature=input_signature)
    concrete_f_tf = f_tf.get_concrete_function(*input_signature)
    if expected_output_signature:
      # Strangely, output_shapes can be a single shape for a function with a
      # single result, or a list/tuple of shapes.
      concrete_output_tf_shape = concrete_f_tf.output_shapes
      if not isinstance(concrete_output_tf_shape, (tuple, list)):  # Single result
        assert not isinstance(expected_output_signature, (tuple, list))
        expected_output_signature = [expected_output_signature]
        concrete_output_tf_shape = [concrete_output_tf_shape]

      for expected, found in util.safe_zip(expected_output_signature,
                                           concrete_output_tf_shape):
        self.assertEqual(tuple(expected.shape), tuple(found))
    return f_tf
Code Example #8
File: fft.py  Project: jamestwebber/jax
def _fft_core(func_name, fft_type, a, s, axes, norm):
    full_name = "jax.numpy.fft." + func_name

    if s is not None:
        s = tuple(map(operator.index, s))
        if np.any(np.less(s, 0)):
            raise ValueError("Shape should be non-negative.")
    if norm is not None:
        raise NotImplementedError("%s only supports norm=None, got %s" %
                                  (full_name, norm))
    if s is not None and axes is not None and len(s) != len(axes):
        # Same error as numpy.
        raise ValueError("Shape and axes have different lengths.")

    orig_axes = axes
    if axes is None:
        if s is None:
            axes = range(a.ndim)
        else:
            axes = range(a.ndim - len(s), a.ndim)

    if len(axes) != len(set(axes)):
        raise ValueError("%s does not support repeated axes. Got axes %s." %
                         (full_name, axes))

    if len(axes) > 3:
        # XLA does not support FFTs over more than 3 dimensions
        raise ValueError("%s only supports 1D, 2D, and 3D FFTs. "
                         "Got axes %s with input rank %s." %
                         (full_name, orig_axes, a.ndim))

    # XLA only supports FFTs over the innermost axes, so rearrange if necessary.
    if orig_axes is not None:
        axes = tuple(range(a.ndim - len(axes), a.ndim))
        a = jnp.moveaxis(a, orig_axes, axes)

    if s is not None:
        a = jnp.asarray(a)
        in_s = list(a.shape)
        for axis, x in safe_zip(axes, s):
            in_s[axis] = x
        if fft_type == xla_client.FftType.IRFFT:
            in_s[-1] = (in_s[-1] // 2 + 1)
        # Cropping
        a = a[tuple(map(slice, in_s))]
        # Padding
        a = jnp.pad(a, [(0, x - y) for x, y in zip(in_s, a.shape)])
    else:
        if fft_type == xla_client.FftType.IRFFT:
            s = [a.shape[axis] for axis in axes[:-1]]
            if axes:
                s += [max(0, 2 * (a.shape[axes[-1]] - 1))]
        else:
            s = [a.shape[axis] for axis in axes]

    transformed = lax.fft(a, fft_type, s)

    if orig_axes is not None:
        transformed = jnp.moveaxis(transformed, axes, orig_axes)
    return transformed
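A brief usage sketch of the public `jax.numpy.fft` wrappers that funnel into `_fft_core`; the shapes are made up and only output shapes are checked here:

import jax.numpy as jnp

x = jnp.ones((4, 8))
assert jnp.fft.fft(x).shape == (4, 8)          # 1D FFT over the last axis
assert jnp.fft.fft2(x).shape == (4, 8)         # 2D FFT over the last two axes
assert jnp.fft.rfft(x, n=16).shape == (4, 9)   # padded to 16, giving 16 // 2 + 1 outputs
# In the version shown above, any norm other than None raises NotImplementedError.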
Code Example #9
File: global_device_array.py  Project: wayfeng/jax
def _get_shard_indices_replica_ids_uncached(
        global_shape: Shape, global_mesh: pxla.Mesh,
        mesh_axes: MeshAxes) -> Mapping[Device, Tuple[Index, int]]:
    indices = _get_indices(global_shape, global_mesh, mesh_axes)
    replica_ids = _calc_replica_ids(global_mesh, mesh_axes)
    return dict((d, (i, r)) for d, i, r in safe_zip(global_mesh.devices.flat,
                                                    indices, replica_ids))
Code Example #10
  def test_gda_2d_shard(self, mesh_axes, expected_index, expected_shard_shape,
                        expected_replica_ids, expected_is_fully_replicated):
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    global_input_data = np.arange(
        prod(global_input_shape)).reshape(global_input_shape)
    def cb(index):
      return global_input_data[index]

    gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
                                          mesh_axes, cb)
    self.assertEqual(gda.ndim, 2)
    self.assertEqual(gda.size, 16)
    self.assertEqual(gda.mesh_axes, mesh_axes)
    self.assertEqual(gda.local_shards[0].index, expected_index[0])
    self.assertArraysEqual(gda.local_data(0),
                           global_input_data[expected_index[0]])
    self.assertEqual(gda.local_shards[1].index, expected_index[1])
    self.assertArraysEqual(gda.local_data(1),
                           global_input_data[expected_index[1]])
    self.assertEqual(gda.local_data(0).shape, expected_shard_shape)
    replica_ids = [i.replica_id for i in gda.local_shards]
    self.assertListEqual(replica_ids, expected_replica_ids)
    self.assertListEqual([i.device.id for i in gda.local_shards],
                         [0, 1, 2, 3, 4, 5, 6, 7])
    self.assertEqual(gda.is_fully_replicated, expected_is_fully_replicated)
    for s in gda.local_shards:
      self.assertEqual(s.data.aval,
                       core.ShapedArray(expected_shard_shape, s.data.dtype))
    for g, l in safe_zip(gda.global_shards, gda.local_shards):
      self.assertEqual(g.device, l.device)
      self.assertEqual(g.index, l.index)
      self.assertEqual(g.replica_id, l.replica_id)
      self.assertEqual(g.data.aval, l.data.aval)
      self.assertArraysEqual(g.data, l.data)
Code Example #11
    def test_gda_subset_devices(self, mesh_axes, expected_index,
                                expected_shard_shape, expected_replica_ids):
        global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
        global_input_shape = (8, 2)
        global_input_data = np.arange(
            prod(global_input_shape)).reshape(global_input_shape)

        def cb(index):
            return global_input_data[index]

        gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
                                              mesh_axes, cb)
        self.assertEqual(gda.local_shards[0].index, expected_index[0])
        self.assertArraysEqual(gda.local_data(0),
                               global_input_data[expected_index[0]])
        self.assertEqual(gda.local_shards[1].index, expected_index[1])
        self.assertArraysEqual(gda.local_data(1),
                               global_input_data[expected_index[1]])
        self.assertEqual(gda.local_data(0).shape, expected_shard_shape)
        replica_ids = [i.replica_id for i in gda.local_shards]
        self.assertListEqual(replica_ids, expected_replica_ids)
        for g, l in safe_zip(gda.global_shards, gda.local_shards):
            self.assertEqual(g.device, l.device)
            self.assertEqual(g.index, l.index)
            self.assertEqual(g.replica_id, l.replica_id)
            self.assertArraysEqual(g.data, l.data)
Code Example #12
def get_shard_indices(global_shape: Shape, global_mesh: pxla.Mesh,
                      mesh_axes: MeshAxes) -> Mapping[Device, Index]:
    indices = _get_indices(global_shape, global_mesh, mesh_axes)
    # The type: ignore is to ignore the type returned by `spec_to_indices`.
    return {d: i
            for d, i in safe_zip(global_mesh.devices.flat, indices)
            }  # type: ignore
Code Example #13
def _xla_callable_args(c,
                       avals,
                       tuple_args,
                       *,
                       replicated=None,
                       partitions=None,
                       partitions_proto: bool = False,
                       donated_invars=None,
                       filter_tokens=True):
    assert partitions is None or len(partitions) == len(avals)
    if not tuple_args:
        if replicated is None:
            replicated = [None] * len(avals)
        if partitions is None:
            parts: List[object] = [None] * len(avals)
        elif partitions_proto:
            parts = partitions
        else:
            parts = [
                _replicated_param if part is None else part
                for part in partitions
            ]
        counts = it.count()
        xla_args = [
            _xla_param(c, next(counts), xla_shape, r, p, partitions_proto,
                       filter_tokens)
            for (a, r, p) in safe_zip(avals, replicated, parts)
            for xla_shape in aval_to_xla_shapes(a)
        ]
        if donated_invars is not None:
            donated_invars = [
                d for (a, _, _,
                       d) in zip(avals, replicated, parts, donated_invars)
                for xla_shape in aval_to_xla_shapes(a)
            ]
        return xla_args, donated_invars
    else:
        if replicated is not None:
            replicated = [
                r for a, r in zip(avals, replicated) if a is not abstract_token
            ]
        if partitions is None:
            tuple_parts = None
        elif partitions_proto:
            tuple_parts = tuple_sharding_proto(partitions)
        else:
            tuple_parts = tuple(partitions)
        tuple_shape = xc.Shape.tuple_shape([
            shape if not (filter_tokens and a is abstract_token) else
            _token_param_shape() for a in avals
            for shape in aval_to_xla_shapes(a)
        ])
        tuple_param = _xla_param(c, 0, tuple_shape, replicated, tuple_parts,
                                 partitions_proto, filter_tokens)
        xla_args = [
            v if not (filter_tokens and a is abstract_token) else
            xops.CreateToken(c)
            for a, v in zip(avals, xla_destructure(c, tuple_param))
        ]
        return xla_args, donated_invars
Code Example #14
File: sharded_jit.py  Project: John1Tang/jax
def _avals_to_results_handler(nrep, npart, partitions, out_avals):
  handlers = [_aval_to_result_handler(npart, parts, out_aval)
              for parts, out_aval in safe_zip(partitions, out_avals)]

  def handler(out_bufs):
    return [h(bufs) for h, bufs in zip(handlers, out_bufs)]

  return handler
Code Example #15
File: batching.py  Project: wayfeng/jax
 def _map_to_tile(*args_flat):
     sizes = (x.shape[i] for x, i in safe_zip(args_flat, in_axes_flat)
              if i is not None)
     tile_size_ = tile_size or next(sizes, None)
     assert tile_size_ is not None, "No mapped arguments?"
     outputs_flat = yield map(tile_axis(tile_size=tile_size_), args_flat,
                              in_axes_flat), {}
     yield map(untile_axis, outputs_flat, out_axes_flat)
Code Example #16
File: convolution.py  Project: frederikwilde/jax
def _conv_general_vjp_lhs_padding(in_shape, window_dimensions, window_strides,
                                  out_shape, padding, lhs_dilation,
                                  rhs_dilation) -> List[Tuple[int, int]]:
    lhs_dilated_shape = lax._dilate_shape(in_shape, lhs_dilation)
    rhs_dilated_shape = lax._dilate_shape(window_dimensions, rhs_dilation)
    out_dilated_shape = lax._dilate_shape(out_shape, window_strides)
    pad_before = np.subtract(rhs_dilated_shape, [lo for lo, _ in padding]) - 1
    pad_after = (np.add(lhs_dilated_shape, rhs_dilated_shape) - 1 -
                 out_dilated_shape - pad_before)
    return safe_zip(pad_before, pad_after)
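A worked numeric check of the padding arithmetic, using plain NumPy and made-up shapes (a valid 1D convolution with no dilation and unit strides):

import numpy as np

in_shape, window, out_shape = (10,), (3,), (8,)   # valid conv: 10 - 3 + 1 = 8
padding = [(0, 0)]
pad_before = np.subtract(window, [lo for lo, _ in padding]) - 1               # -> [2]
pad_after = np.add(in_shape, window) - 1 - np.array(out_shape) - pad_before   # -> [2]
assert list(zip(pad_before, pad_after)) == [(2, 2)]
# Padding the length-8 cotangent to 8 + 2 + 2 = 12 and convolving with the
# flipped length-3 kernel recovers a length-10 gradient (a "full" convolution).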
Code Example #17
def _sharded_jit_lowering(ctx, *in_nodes, in_parts, out_parts_thunk, nparts,
                          name, call_jaxpr, local_in_parts,
                          local_out_parts_thunk, local_nparts):
    # We assume any extra leading in_nodes are constants and replicate them.
    num_extra_nodes = len(in_nodes) - len(in_parts)
    assert num_extra_nodes >= 0
    in_parts = (None, ) * num_extra_nodes + in_parts

    args = []
    for ns, sharding in safe_zip(
            safe_map(mlir.wrap_singleton_ir_values, in_nodes), in_parts):
        if sharding is not None:
            args.append([
                mlir.wrap_with_sharding_op(n, xla.sharding_to_proto(sharding))
                for n in ns
            ])
        else:
            args.append(ns)

    sub_ctx = ctx.module_context.replace(
        name_stack=extend_name_stack(wrap_name(name, "sharded_jit")))
    fn = mlir.lower_jaxpr_to_fun(sub_ctx, f"sharded_jit_{name}",
                                 core.ClosedJaxpr(call_jaxpr, ()))

    output_types = safe_map(mlir.aval_to_ir_types, ctx.avals_out)
    flat_output_types = util.flatten(output_types)
    call = std.CallOp(flat_output_types,
                      ir.FlatSymbolRefAttr.get(fn.name.value),
                      mlir.flatten_lowering_ir_args(args))
    out_nodes = util.unflatten(call.results, safe_map(len, output_types))

    out_parts = out_parts_thunk()
    outputs = []
    for ns, sharding in safe_zip(out_nodes, out_parts):
        if sharding is not None:
            outputs.append([
                mlir.wrap_with_sharding_op(n, xla.sharding_to_proto(sharding))
                for n in ns
            ])
        else:
            outputs.append(ns)
    return outputs
Code Example #18
def get_shard_indices_replica_ids(
        global_shape: Shape, global_mesh: pxla.Mesh,
        mesh_axes: MeshAxes) -> Mapping[Device, Tuple[Index, int]]:
    indices = _get_indices(global_shape, global_mesh, mesh_axes)
    index_to_replica: Dict[_HashableIndex, int] = Counter()
    out = {}
    for device, index in safe_zip(global_mesh.devices.flat, indices):
        h_index = _HashableIndex(index)
        replica_id = index_to_replica[h_index]
        index_to_replica[h_index] += 1
        out[device] = (index, replica_id)
    return out
Code Example #19
def _gsda_shard_arg(x, devices, indices):
    pjit_mesh = maps.thread_resources.env.physical_mesh
    if x._global_mesh != pjit_mesh:
        raise ValueError(
            "Pjit's mesh and GDA's mesh should be equal. Got Pjit "
            f"mesh: {pjit_mesh},\n GDA mesh: {x._global_mesh}")
    assert all(g.index == i for g, i in safe_zip(x.global_shards, indices)), (
        "Indices calculated by GDA and pjit do not match. Please file a bug "
        "on https://github.com/google/jax/issues. "
        f"Got GDA indices: {[g.index for g in x.global_shards]},\n"
        f"pjit indices: {indices}")
    return [s.data for s in x.local_shards]
Code Example #20
def benchmark_suite(prepare: Callable[..., Callable],
                    params_list: List[Dict],
                    name: str,
                    target_total_secs: int = None):
    """Benchmarks a function for several combinations of parameters.

  Prints the summarized results in a table.

  Args:
    prepare: given kwargs returns a benchmark function specialized to the kwargs.
    params_list: a list of kwargs on which to run the benchmark.
    name: the name of this benchmark suite
    target_total_secs: the ``target_total_secs`` to pass to ``benchmark``.
 """
    # Sort parameters alphabetically so benchmark results print consistently.
    params_list = [OrderedDict(sorted(p.items())) for p in params_list]
    assert all(p.keys() == params_list[0].keys() for p in params_list)

    times = []
    for params in params_list:
        f = prepare(**params)
        subname = name + "".join("_%s=%s" % (n, _param_str(p))
                                 for n, p in params.items())
        times.append(
            benchmark(f, name=subname, target_total_secs=target_total_secs))

    param_names = list(params_list[0].keys())
    data_header = param_names + ["mean", "%std", "relative"]
    data = [
        list(map(_param_str, params.values())) +
        [t.mean(), _pstd(t), t.mean() / times[0].mean()]
        for params, t in safe_zip(params_list, times)
    ]

    if FLAGS.baseline_dir:
        mean_idx = len(param_names)
        means = _get_baseline_means(FLAGS.baseline_dir, name)
        assert len(means) == len(data), (means, data)
        data_header.append("mean/baseline")
        for idx, mean in enumerate(means):
            data[idx].append(data[idx][mean_idx] / mean)

    print("---------Benchmark summary for %s---------" % name)
    print(tabulate(data, data_header))
    print()

    if FLAGS.export_dir:
        filename = _export_results(data_header, data, FLAGS.export_dir, name)
        print("Wrote %s results to %s" % (name, filename))
        print()
Code Example #21
File: global_device_array.py  Project: jbampton/jax
    def __init__(self,
                 global_shape: Shape,
                 global_mesh: pxla.Mesh,
                 mesh_axes: MeshAxes,
                 device_buffers: Sequence[DeviceArray],
                 _gda_fast_path_args: Optional[_GdaFastPathArgs] = None,
                 _enable_checks: bool = True):
        self._global_shape = global_shape
        self._global_mesh = global_mesh
        self._mesh_axes = mesh_axes
        self._device_buffers = device_buffers
        # Optionally precomputed for performance.
        self._gda_fast_path_args = _gda_fast_path_args
        self._current_process = xb.process_index()

        if self._gda_fast_path_args is None:
            self._local_devices = self._global_mesh.local_devices
        else:
            self._local_devices = self._gda_fast_path_args.local_devices

        if _enable_checks or config.jax_enable_checks:
            for db, ld in safe_zip(device_buffers, self._local_devices):
                if db.device() != ld:
                    raise ValueError(
                        "The `global_mesh.local_devices` and `device_buffers` device "
                        "order doesn't match. Please use `global_mesh.local_devices` to "
                        "put arrays on devices instead of `jax.local_devices()`"
                    )

        if _enable_checks or config.jax_enable_checks:
            ss = get_shard_shape(self._global_shape, self._global_mesh,
                                 self.mesh_axes)
            assert all(db.shape == ss for db in device_buffers), (
                f"Expected shard shape {ss} doesn't match the device buffer "
                f"shape, got: {[db.shape for db in device_buffers]}")

        dtype = device_buffers[0].dtype
        if _enable_checks or config.jax_enable_checks:
            assert all(db.dtype == dtype for db in device_buffers), (
                "Input arrays to GlobalDeviceArray must have matching dtypes, "
                f"got: {[db.dtype for db in device_buffers]}")
        self.dtype = dtype
Code Example #22
def get_shard_indices(global_shape: Shape, global_mesh: pxla.Mesh,
                      mesh_axes: MeshAxes) -> Mapping[Device, Index]:
    # Import here to avoid cyclic import error when importing gda in pjit.py.
    from jax.experimental.pjit import get_array_mapping, _prepare_axis_resources

    if not isinstance(mesh_axes, PartitionSpec):
        pspec = PartitionSpec(*mesh_axes)
    else:
        pspec = mesh_axes
    parsed_pspec, _, _ = _prepare_axis_resources(pspec, "mesh_axes")
    array_mapping = get_array_mapping(parsed_pspec)
    # The dtype doesn't matter for creating sharding specs.
    aval = core.ShapedArray(global_shape, np.float32)
    sharding_spec = pxla.mesh_sharding_specs(
        global_mesh.shape, global_mesh.axis_names)(aval, array_mapping)
    indices = pxla.spec_to_indices(global_shape, sharding_spec)
    for index in indices:
        assert isinstance(index, tuple)
        for idx in index:
            assert isinstance(idx, slice)
    # The type: ignore is to ignore the type returned by `spec_to_indices`.
    return dict((d, i) for d, i in safe_zip(global_mesh.devices.flat,
                                            indices))  # type: ignore
Code Example #23
File: sharded_jit.py  Project: xueeinstein/jax
def _sharded_callable(
        fun: lu.WrappedFun, nparts: Optional[int],
        in_parts: Tuple[pxla.PartitionsOrReplicated, ...],
        out_parts_thunk: Callable[[], Tuple[pxla.PartitionsOrReplicated, ...]],
        local_in_parts: Optional[Tuple[pxla.PartitionsOrReplicated, ...]],
        local_out_parts_thunk: Callable[[], Optional[Tuple[
            pxla.PartitionsOrReplicated,
            ...]]], local_nparts: Optional[int], name: str, *abstract_args):
    nrep = 1

    if local_in_parts is None:
        local_in_parts = in_parts

    global_abstract_args = [
        pxla.get_global_aval(arg, parts,
                             lparts) for arg, parts, lparts in safe_zip(
                                 abstract_args, in_parts, local_in_parts)
    ]

    if logging.vlog_is_on(2):
        logging.vlog(2, "abstract_args: %s", abstract_args)
        logging.vlog(2, "global_abstract_args: %s", global_abstract_args)
        logging.vlog(2, "in_parts: %s", in_parts)
        logging.vlog(2, "local_in_parts: %s", local_in_parts)

    jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final(
        fun, global_abstract_args)

    platform = xb.get_backend().platform

    nparts = pxla.reconcile_num_partitions(jaxpr, nparts)
    assert nparts is not None
    if nparts > xb.device_count():
        raise ValueError(
            f"sharded_jit computation requires {nparts} devices, "
            f"but only {xb.device_count()} devices are available.")
    if xb.local_device_count() < nparts < xb.device_count():
        raise NotImplementedError(
            f"sharded_jit across multiple hosts must use all available devices. "
            f"Got {nparts} out of {xb.device_count()} requested devices "
            f"(local device count: {xb.local_device_count()})")

    if local_nparts is None:
        if nparts > xb.local_device_count():
            raise ValueError(
                "Specify 'local_nparts' when using cross-process sharded_jit "
                "and all inputs and outputs are replicated.")
        else:
            local_nparts = nparts
    if local_nparts > xb.local_device_count():
        raise ValueError(
            f"sharded_jit computation requires {local_nparts} local devices, "
            f"but only {xb.local_device_count()} local devices are available.")

    if logging.vlog_is_on(2):
        logging.vlog(2, "nparts: %d  local_nparts: %d", nparts, local_nparts)

    out_parts = out_parts_thunk()

    local_out_parts = local_out_parts_thunk()
    if local_out_parts is None:
        local_out_parts = out_parts

    if logging.vlog_is_on(2):
        logging.vlog(2, "out_parts: %s", out_parts)
        logging.vlog(2, "local_out_parts: %s", local_out_parts)

    local_out_avals = [
        pxla.get_local_aval(out, parts,
                            lparts) for out, parts, lparts in safe_zip(
                                global_out_avals, out_parts, local_out_parts)
    ]

    log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
    logging.log(log_priority, "Compiling %s for %d devices with args %s.",
                fun.__name__, nparts, global_abstract_args)

    axis_env = xla.AxisEnv(nrep, (), ())
    unordered_effects = [
        eff for eff in jaxpr.effects if eff not in core.ordered_effects
    ]
    ordered_effects = [
        eff for eff in jaxpr.effects if eff in core.ordered_effects
    ]
    module, _ = mlir.lower_jaxpr_to_module(
        f"spjit_{fun.__name__}",
        core.ClosedJaxpr(jaxpr, consts),
        unordered_effects,
        ordered_effects,
        platform=platform,
        axis_context=mlir.ReplicaAxisContext(axis_env),
        name_stack=new_name_stack(wrap_name(name, "sharded_jit")),
        donated_args=[False] * len(in_parts),
        arg_shardings=safe_map(xla.sharding_to_proto, in_parts),
        result_shardings=safe_map(xla.sharding_to_proto, out_parts))
    built = xc._xla.mlir.mlir_module_to_xla_computation(
        mlir.module_to_string(module), use_tuple_args=False, return_tuple=True)

    if nparts <= xb.local_device_count():
        devices = xb.local_devices()[:nparts]
    else:
        assert nparts == xb.device_count()
        devices = xb.devices()
    device_assignment = np.array([[d for d in devices]])
    device_assignment = np.reshape(device_assignment, (-1, nparts))
    # device_assignment = None  # TODO(skye): replace with default device assignment?

    compiled = dispatch.backend_compile(
        xb.get_backend(), built,
        xb.get_compile_options(nrep, nparts, device_assignment))

    input_specs = [
        pxla.partitioned_sharding_spec(local_nparts, parts, aval)
        for parts, aval in zip(local_in_parts, abstract_args)
    ]
    input_indices = [
        pxla.spec_to_indices(aval.shape, spec) if spec is not None else None
        for aval, spec in zip(abstract_args, input_specs)
    ]

    handle_args = partial(pxla.shard_args, compiled.local_devices(),
                          input_indices)
    handle_outs = _avals_to_results_handler(
        nrep,
        local_nparts,  # type: ignore
        local_out_parts,
        local_out_avals)
    return partial(_execute_spatially_partitioned, compiled, handle_args,
                   handle_outs)
Code Example #24
 def devices_indices_map(
     self, global_shape: Shape) -> Mapping[Device, Optional[Index]]:
   indices = pxla.spec_to_indices(global_shape, self.sharding_spec)
   return {d: i for d, i in safe_zip(self.devices.flat, indices)}  # type: ignore
Code Example #25
def _sharded_callable(
        fun: lu.WrappedFun, nparts: Optional[int],
        in_parts: Tuple[pxla.PartitionsOrReplicated, ...],
        out_parts_thunk: Callable[[], Tuple[pxla.PartitionsOrReplicated, ...]],
        local_in_parts: Optional[Tuple[pxla.PartitionsOrReplicated, ...]],
        local_out_parts_thunk: Callable[[], Optional[Tuple[
            pxla.PartitionsOrReplicated,
            ...]]], local_nparts: Optional[int], name: str, *abstract_args):
    nrep = 1

    if local_in_parts is None:
        local_in_parts = in_parts

    global_abstract_args = [
        pxla.get_global_aval(arg, parts,
                             lparts) for arg, parts, lparts in safe_zip(
                                 abstract_args, in_parts, local_in_parts)
    ]

    if logging.vlog_is_on(2):
        logging.vlog(2, "abstract_args: %s", abstract_args)
        logging.vlog(2, "global_abstract_args: %s", global_abstract_args)
        logging.vlog(2, "in_parts: %s", in_parts)
        logging.vlog(2, "local_in_parts: %s", local_in_parts)

    jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final(
        fun, global_abstract_args)

    platform = xb.get_backend().platform
    if platform not in ["tpu", "gpu"]:
        # TODO(skye): fall back to regular jit?
        raise ValueError(f"sharded_jit not supported for {platform}")

    nparts = pxla.reconcile_num_partitions(jaxpr, nparts)
    assert nparts is not None
    if nparts > xb.device_count():
        raise ValueError(
            f"sharded_jit computation requires {nparts} devices, "
            f"but only {xb.device_count()} devices are available.")
    if xb.local_device_count() < nparts < xb.device_count():
        raise NotImplementedError(
            f"sharded_jit across multiple hosts must use all available devices. "
            f"Got {nparts} out of {xb.device_count()} requested devices "
            f"(local device count: {xb.local_device_count()})")

    if local_nparts is None:
        if nparts > xb.local_device_count():
            raise ValueError(
                "Specify 'local_nparts' when using cross-process sharded_jit "
                "and all inputs and outputs are replicated.")
        else:
            local_nparts = nparts
    if local_nparts > xb.local_device_count():
        raise ValueError(
            f"sharded_jit computation requires {local_nparts} local devices, "
            f"but only {xb.local_device_count()} local devices are available.")

    if logging.vlog_is_on(2):
        logging.vlog(2, "nparts: %d  local_nparts: %d", nparts, local_nparts)

    out_parts = out_parts_thunk()

    local_out_parts = local_out_parts_thunk()
    if local_out_parts is None:
        local_out_parts = out_parts

    if logging.vlog_is_on(2):
        logging.vlog(2, "out_parts: %s", out_parts)
        logging.vlog(2, "local_out_parts: %s", local_out_parts)

    local_out_avals = [
        pxla.get_local_aval(out, parts,
                            lparts) for out, parts, lparts in safe_zip(
                                global_out_avals, out_parts, local_out_parts)
    ]

    log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
    logging.log(log_priority, "Compiling %s for %d devices with args %s.",
                fun.__name__, nparts, global_abstract_args)

    c = xc.XlaBuilder("spjit_{}".format(fun.__name__))
    xla_consts = _map(partial(xla.pyval_to_ir_constant, c), consts)
    xla_args = _xla_sharded_args(c, global_abstract_args, in_parts)
    axis_env = xla.AxisEnv(nrep, (), ())
    ctx = xla.TranslationContext(
        c, platform, axis_env,
        extend_name_stack(wrap_name(name, "sharded_jit")))
    out_nodes = xla.jaxpr_subcomp(ctx, jaxpr, xla_consts, *xla_args)
    out_tuple = xla.with_sharding(c, out_parts, xops.Tuple, c, out_nodes)
    built = c.Build(out_tuple)

    if nparts <= xb.local_device_count():
        devices = xb.local_devices()[:nparts]
    else:
        assert nparts == xb.device_count()
        devices = xb.devices()
    device_assignment = np.array([[d.id for d in devices]])
    device_assignment = np.reshape(device_assignment, (-1, nparts))
    # device_assignment = None  # TODO(skye): replace with default device assignment?

    compiled = dispatch.backend_compile(
        xb.get_backend(), built,
        xb.get_compile_options(nrep, nparts, device_assignment))

    input_specs = [
        pxla.partitioned_sharding_spec(local_nparts, parts, aval)
        for parts, aval in zip(local_in_parts, abstract_args)
    ]
    input_indices = [
        pxla.spec_to_indices(aval.shape, spec) if spec is not None else None
        for aval, spec in zip(abstract_args, input_specs)
    ]

    handle_args = partial(pxla.shard_args, compiled.local_devices(),
                          input_indices)
    handle_outs = _avals_to_results_handler(
        nrep,
        local_nparts,  # type: ignore
        local_out_parts,
        local_out_avals)
    return partial(_execute_spatially_partitioned, compiled, handle_args,
                   handle_outs)
Code Example #26
File: shape_poly_test.py  Project: zhaowilliam/jax
 def solve_shape_vars(shape_spec: str,
                      shape: Sequence[int]) -> Dict[str, int]:
     shape_polys = masking.parse_spec(shape_spec)
     return jax2tf.jax2tf._solve_shape_vars(
         util.safe_zip(shape_polys, shape))
Code Example #27
def tree_reduce(function: Callable[[T, Any], T],
                tree: Any,
                initializer: Any = no_initializer) -> T:
  if initializer is no_initializer:
    return functools.reduce(function, tree_leaves(tree))
  else:
    return functools.reduce(function, tree_leaves(tree), initializer)

def tree_all(tree):
  return all(tree_leaves(tree))

register_pytree_node(
  collections.OrderedDict,
  lambda x: (tuple(x.values()), tuple(x.keys())),
  lambda keys, values: collections.OrderedDict(safe_zip(keys, values)))

register_pytree_node(
  collections.defaultdict,
  lambda x: (tuple(x.values()), (x.default_factory, tuple(x.keys()))),
  lambda s, values: collections.defaultdict(s[0], safe_zip(s[1], values)))  # type: ignore[index]
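# The same registration mechanism extends to user-defined containers. A minimal
# hedged sketch using the public jax.tree_util API; the Box class is made up:
import jax

class Box:
    def __init__(self, value):
        self.value = value

jax.tree_util.register_pytree_node(
    Box,
    lambda box: ((box.value,), None),         # flatten -> (children, aux_data)
    lambda aux, children: Box(children[0]))   # unflatten(aux_data, children)

assert jax.tree_util.tree_leaves([Box(1.0), Box(2.0)]) == [1.0, 2.0]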



class _HashableCallableShim:
  """Object that delegates __call__, __hash__, and __eq__ to another object."""
  def __init__(self, fun):
    self.fun = fun

  def __call__(self, *args, **kw):
    return self.fun(*args, **kw)