Example 1
def _gather_for_scalar_indexing(args: GatherArgs):
    """Implements 'scalar indexing into arrays' cases of lax.gather using tf.slice.

  E.g., op[2], op[:, :5, :], jnp.take(op, 0, axis=0).
  """
    indices = tf.expand_dims(args.dnums.start_index_map, 1)
    # lax.gather uses an "index map" which maps `start_indices` to the right axes
    # in `operand`. Since tf.strided_slice uses a single array for specifying the
    # start indices, we use a scatter to map the start indices to the right axes.
    op_shape = jax2tf._eval_shape(args.op_shape)
    slice_sizes_tf = jax2tf._eval_shape(args.slice_sizes)
    # TODO(marcvanzee): Consider transposing `operand`, which is probably more
    # optimization friendly.
    begin = tf.scatter_nd(indices, args.start_indices, [len(op_shape)])
    begin = _clip(op_shape, begin, slice_sizes_tf)
    end = slice_sizes_tf + begin

    # `collapsed_slice_dims` is a tuple of dimensions to collapse, e.g. (0, 2).
    # `tf.strided_slice` expects a binary mask to specify the shrink axes, i.e.,
    # if we want to shrink axis 0 and 2, this corresponds to binary mask 101,
    # which is 5 in decimal. The following line converts the lax representation
    # to the one used by `tf.strided_slice`.
    shrink_mask = sum(2**x for x in args.dnums.collapsed_slice_dims)
    res = tf.strided_slice(args.operand,
                           begin,
                           end,
                           shrink_axis_mask=shrink_mask)
    # Shape inference doesn't work for tf.strided_slice, so set the result
    # shape explicitly.
    res.set_shape(jax2tf._aval_to_tf_shape(args.out_aval))
    return res
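To make the mask conversion concrete, here is a minimal standalone sketch (assuming TensorFlow 2.x; the operand contents are made up). For collapsed_slice_dims == (0, 2) the mask is 0b101 == 5, and tf.strided_slice drops those axes from the result:

import tensorflow as tf

collapsed_slice_dims = (0, 2)
shrink_mask = sum(2**x for x in collapsed_slice_dims)  # 0b101 == 5

op = tf.reshape(tf.range(24), (2, 3, 4))
# Slices op[1, 0:3, 2]; the shrink mask collapses axes 0 and 2.
res = tf.strided_slice(op, begin=[1, 0, 2], end=[2, 3, 3],
                       shrink_axis_mask=shrink_mask)
print(res.shape)  # (3,)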
Example 2
def _dynamic_update_slice(operand, update, *start_indices,
                          _in_avals: Sequence[core.ShapedArray],
                          _out_aval: core.ShapedArray):
    """Implements lax.dynamic_update_slice via scatter and a boolean mask."""
    start_indices = tf.stack(start_indices)

    op_shape = jax2tf._eval_shape(_in_avals[0].shape)
    op_size = tf.size(operand)
    update_shape_tf = jax2tf._eval_shape(_in_avals[1].shape)

    start_indices = _clip(op_shape, start_indices, update_shape_tf)
    end_indices = tf.add(start_indices, update_shape_tf)
    flatten = tf.keras.backend.flatten

    # Get the cells to update in `operand` as an array of ids.
    id_tensor = tf.reshape(tf.range(op_size), op_shape)
    scattered_indices = tf.strided_slice(id_tensor, start_indices, end_indices)

    # Create an array containing updates at scattered_indices and zeros otherwise.
    flat_indices = tf.expand_dims(flatten(scattered_indices), -1)
    flat_update = flatten(update)
    update = tf.scatter_nd(flat_indices, flat_update, (op_size, ))
    update = tf.reshape(update, op_shape)

    # Create a bool mask that is True only where `operand` should be updated.
    update_mask = tf.ones_like(flat_update, dtype=tf.bool)
    update_mask = tf.scatter_nd(flat_indices, update_mask, (op_size, ))
    update_mask = tf.reshape(update_mask, op_shape)

    # Use the mask to only update `operand` with `update`.
    return tf.where(update_mask, update, operand)
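The scatter-and-mask trick can be exercised in isolation. Below is a hedged sketch (assuming TensorFlow 2.x; the shapes and indices are hypothetical) that writes a 2x2 block of ones into a 4x4 zero tensor:

import tensorflow as tf

operand = tf.zeros((4, 4), dtype=tf.int32)
update = tf.ones((2, 2), dtype=tf.int32)
start, end = tf.constant([1, 1]), tf.constant([3, 3])

# Ids of the cells to update, then scatter both the update and a bool mask.
ids = tf.reshape(tf.range(tf.size(operand)), operand.shape)
flat_idx = tf.reshape(tf.strided_slice(ids, start, end), (-1, 1))
flat_upd = tf.reshape(update, (-1,))
scattered = tf.reshape(tf.scatter_nd(flat_idx, flat_upd, (16,)), operand.shape)
mask = tf.reshape(
    tf.scatter_nd(flat_idx, tf.ones_like(flat_upd, tf.bool), (16,)),
    operand.shape)
print(tf.where(mask, scattered, operand))  # ones in the [1:3, 1:3] block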
Example 3
def _dynamic_slice(operand, *start_indices, slice_sizes: core.Shape,
                   _in_avals: Sequence[core.ShapedArray],
                   _out_aval: core.ShapedArray):
    """Implements lax.dynamic_slice via tf.slice with clamped start indices."""
    start_indices = tf.stack(start_indices)
    slice_sizes_tf = jax2tf._eval_shape(slice_sizes)

    operand_shape = jax2tf._eval_shape(_in_avals[0].shape)
    start_indices = _clip(operand_shape, start_indices, slice_sizes_tf)
    return tf.slice(operand, start_indices, size=slice_sizes_tf)
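All of these examples rely on a _clip helper that is not shown here. A plausible sketch, assuming it mirrors XLA's dynamic_slice clamping (start indices are clamped so the slice stays in bounds); the actual jax2tf helper may differ in details:

import tensorflow as tf

def _clip(op_shape, start_indices, slice_sizes):
    # Clamp each start index into [0, dim_size - slice_size], matching
    # lax.dynamic_slice semantics for out-of-bounds starts.
    max_start = tf.subtract(op_shape, slice_sizes)
    return tf.clip_by_value(start_indices, 0, max_start)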
Example 4
def _pre_gather_with_batch_dims(args: GatherArgs):
    """Returns True if this call to gather has non-empty batch dimensions.

  This is for instance triggered when doing jax.vmap(lax.dynamic_slice).
  """
    # All dimensions in the output array and not in offset_dims are batch_dims.
    batch_dims = tuple(x for x in range(len(args.out_aval.shape))
                       if x not in args.dnums.offset_dims)

    # We assume exactly one batch (and one or more non-batch dimensions).
    if len(batch_dims) != 1:
        raise ValueError(f"batch_dims is {len(batch_dims)} but should be 1")

    # `start_index_map` maps indices in `start_indices` to indices in `operand`.
    # For simplicity, we currently only consider the case where this mapping is
    # the identity function, i.e., [2, 3] in `start_indices` maps to
    # `operand[2, 3]`.
    if args.dnums.start_index_map != tuple(range(
            args.start_indices_shape[-1])):
        raise ValueError("unsupported start_index_map")

    # The batch dims in `start_indices` and `operand` should agree.
    if jax2tf._eval_shape(args.op_shape)[0] != args.start_indices_shape[0]:
        raise ValueError("Batch dimensions in operand and start_indices don't "
                         "agree")
Example 5
def _gather_for_multidim_indexing(args: GatherArgs):
    """Implements 'multi-dimensional indexing into arrays' cases of lax.gather using tf.gather.

  E.g., jnp.take(op, [[0], [1]], axis=0).
  """
    # Infer the gather axis from the first collapsed slice dimension.
    axis = args.dnums.collapsed_slice_dims[0]
    squeezed_indices = tf.squeeze(args.start_indices, -1)
    op_shape = jax2tf._eval_shape(args.op_shape)
    start_indices = _clip((op_shape[axis], ), squeezed_indices, (1, ))
    return tf.gather(args.operand, start_indices, axis=axis, batch_dims=0)
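A short sketch of the squeeze-then-gather step (assuming TensorFlow 2.x; the operand is made up). It demonstrates only the indexing, not the full lowering:

import tensorflow as tf

op = tf.reshape(tf.range(12), (3, 4))
start_indices = tf.constant([[0], [1]])   # trailing size-1 index dimension
squeezed = tf.squeeze(start_indices, -1)  # shape (2,)
print(tf.gather(op, squeezed, axis=0))    # rows 0 and 1 of op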
Example 6
def _gather_with_batch_dims(args: GatherArgs):
    """Implements call to gather with non-empty batch dimensions.

  E.g., when doing `jax.vmap(lax.dynamic_slice).
  """
    op_shape = jax2tf._eval_shape(args.op_shape)
    start_indices = _clip(op_shape, args.start_indices, args.slice_sizes)
    result = tf.map_fn(
        lambda idxs: tf.slice(args.operand, begin=idxs, size=args.slice_sizes),
        start_indices,
        fn_output_signature=jax2tf._to_tf_dtype(args.operand.dtype))
    result = tf.squeeze(result, axis=1)
    return result
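The tf.map_fn pattern can be tried standalone. A minimal sketch (assuming TensorFlow 2.x; the batch of start indices is hypothetical) that takes one slice per batch element:

import tensorflow as tf

operand = tf.reshape(tf.range(12), (3, 4))
start_indices = tf.constant([[0, 0], [1, 1], [2, 0]])  # one row per batch item
slice_sizes = [1, 2]

result = tf.map_fn(
    lambda idxs: tf.slice(operand, begin=idxs, size=slice_sizes),
    start_indices,
    fn_output_signature=operand.dtype)
print(tf.squeeze(result, axis=1).shape)  # (3, 2): one slice per start index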
Example 7
def _pad(operand, padding_value, *, padding_config,
         _in_avals: Sequence[core.ShapedArray], _out_aval: core.ShapedArray):
    low, high, interior = util.unzip3(padding_config)

    # First apply only the interior padding; this is rarely needed.
    if any(i != 0 for i in interior):
        operand = _interior_padding(operand, padding_value, padding_config,
                                    jax2tf._eval_shape(_in_avals[0].shape))

    # Now do the non-negative edge padding. This is the common case; use tf.pad.
    non_negative_padding = [(max(lo, 0), max(hi, 0))
                            for lo, hi in zip(low, high)]
    operand = tf.pad(operand,
                     non_negative_padding,
                     mode="CONSTANT",
                     constant_values=padding_value)
    # Finally, handle the negative edge padding (also rare) by slicing off the
    # excess.
    if any(lo < 0 or hi < 0 for lo, hi in zip(low, high)):
        output_shape = jax2tf._eval_shape(_out_aval.shape)
        begins = [max(-lo, 0) for lo in low]
        operand = tf.slice(operand, begins, output_shape)

    return operand
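For the common path (non-negative edge padding, no interior padding), the padding_config triples (lo, hi, interior) reduce to (lo, hi) pairs for tf.pad. A minimal sketch, assuming TensorFlow 2.x and made-up values:

import tensorflow as tf

operand = tf.constant([[1, 2], [3, 4]])
padding_config = [(1, 0, 0), (0, 2, 0)]  # (lo, hi, interior) per axis
tf_padding = [(lo, hi) for lo, hi, _ in padding_config]
print(tf.pad(operand, tf_padding, mode="CONSTANT", constant_values=0))
# Shape (3, 4): one row of zeros on top, two columns on the right.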