def _gather_for_scalar_indexing(args: GatherArgs):
  """Lowers the 'scalar indexing into arrays' cases of lax.gather.

  The lowering is a single tf.strided_slice over the operand.
  E.g., op[2], op[:, :5, :], jnp.take(op, 0, axis=0).
  """
  dnums = args.dnums
  scatter_targets = tf.expand_dims(dnums.start_index_map, 1)

  # `start_index_map` tells which operand axis each entry of `start_indices`
  # addresses, whereas tf.strided_slice takes one begin value per operand
  # axis. A scatter places every start index at the position of its axis.
  operand_dims = jax2tf._eval_shape(args.op_shape)
  sizes = jax2tf._eval_shape(args.slice_sizes)
  # TODO(marcvanzee): Consider transposing `operand`, which is probably more
  # optimization friendly.
  starts = tf.scatter_nd(scatter_targets, args.start_indices,
                         [len(operand_dims)])
  starts = _clip(operand_dims, starts, sizes)
  stops = sizes + starts

  # lax lists the axes to collapse as a tuple, e.g. (0, 2), while
  # tf.strided_slice wants them as a bit mask with bit `i` set for axis `i`;
  # (0, 2) becomes binary 101, i.e. 5.
  shrink_axes = sum(1 << axis for axis in dnums.collapsed_slice_dims)
  sliced = tf.strided_slice(args.operand, starts, stops,
                            shrink_axis_mask=shrink_axes)
  # Shape inference doesn't work for tf.strided_slice, so pin the static
  # shape derived from the expected output aval.
  sliced.set_shape(jax2tf._aval_to_tf_shape(args.out_aval))
  return sliced
def _dynamic_update_slice(operand, update, *start_indices,
                          _in_avals: Sequence[core.ShapedArray],
                          _out_aval: core.ShapedArray):
  """Returns `operand` with the window at `start_indices` set to `update`.

  The result is built from scatter_nd and tf.where: every cell of `operand`
  gets a flat id, the ids covered by the update window are sliced out, and
  the update values plus a boolean mask are scattered to those ids.

  Args:
    operand: the tensor to update.
    update: the values to write; its shape is read from `_in_avals[1]`.
    *start_indices: one scalar start index per operand axis.
    _in_avals: avals of `operand` and `update` (in that order); only their
      shapes are used.
    _out_aval: unused here.

  Returns:
    A tensor shaped like `operand`, equal to `update` inside the window and to
    `operand` everywhere else.
  """
  start_indices = tf.stack(start_indices)

  op_shape = jax2tf._eval_shape(_in_avals[0].shape)
  op_size = tf.size(operand)
  update_shape_tf = jax2tf._eval_shape(_in_avals[1].shape)

  start_indices = _clip(op_shape, start_indices, update_shape_tf)
  end_indices = tf.add(start_indices, update_shape_tf)

  # Get the cells to update in `operand` as an array of ids.
  id_tensor = tf.reshape(tf.range(op_size), op_shape)
  scattered_indices = tf.strided_slice(id_tensor, start_indices, end_indices)

  # Flatten with a plain tf.reshape(..., [-1]); this keeps the lowering on
  # core TF ops instead of going through the Keras backend helper.
  flat_indices = tf.expand_dims(tf.reshape(scattered_indices, [-1]), -1)
  flat_update = tf.reshape(update, [-1])

  # Create an array containing updates at scattered_indices and zeros otherwise.
  update = tf.scatter_nd(flat_indices, flat_update, (op_size,))
  update = tf.reshape(update, op_shape)

  # Create a bool mask that is True only where `operand` should be updated.
  update_mask = tf.ones_like(flat_update, dtype=tf.bool)
  update_mask = tf.scatter_nd(flat_indices, update_mask, (op_size,))
  update_mask = tf.reshape(update_mask, op_shape)

  # Use the mask to only update `operand` with `update`.
  return tf.where(update_mask, update, operand)
def _dynamic_slice(operand, *start_indices, slice_sizes: core.Shape,
                   _in_avals: Sequence[core.ShapedArray],
                   _out_aval: core.ShapedArray):
  """Extracts a window of `slice_sizes` from `operand` using tf.slice.

  Args:
    operand: the tensor to slice.
    *start_indices: one scalar start index per operand axis.
    slice_sizes: extent of the window along each axis.
    _in_avals: input avals; only the shape of the first (the operand) is read.
    _out_aval: unused here.

  Returns:
    The result of tf.slice at the start indices returned by `_clip`.
  """
  begin = tf.stack(start_indices)
  sizes = jax2tf._eval_shape(slice_sizes)
  operand_dims = jax2tf._eval_shape(_in_avals[0].shape)
  # NOTE(review): `_clip` presumably clamps the starts so that the window
  # stays inside the operand -- confirm against its definition.
  begin = _clip(operand_dims, begin, sizes)
  return tf.slice(operand, begin, size=sizes)
def _pre_gather_with_batch_dims(args: GatherArgs):
  """Checks that a gather with batch dimensions has a supported form.

  Such gathers are for instance produced by jax.vmap(lax.dynamic_slice).
  Nothing is returned; the function only validates.

  Raises:
    ValueError: if the output does not have exactly one batch dimension, if
      `start_index_map` is not the identity mapping, or if the leading
      dimensions of the operand and of `start_indices` differ.
  """
  dnums = args.dnums

  # Every output dimension that is not an offset dimension is a batch
  # dimension.
  out_rank = len(args.out_aval.shape)
  batch_dims = tuple(d for d in range(out_rank) if d not in dnums.offset_dims)

  # Only a single batch dimension (next to one or more non-batch dimensions)
  # is handled.
  if len(batch_dims) != 1:
    raise ValueError(f"batch_dims is {len(batch_dims)} but should be 1")

  # `start_index_map` maps positions in `start_indices` to operand axes. For
  # simplicity only the identity mapping is accepted for now, i.e. [2, 3] in
  # `start_indices` addresses `operand[2, 3]`.
  identity_map = tuple(range(args.start_indices_shape[-1]))
  if dnums.start_index_map != identity_map:
    raise ValueError("unsupported start_index_map")

  # The batch dims in `start_indices` and `operand` should agree.
  operand_batch = jax2tf._eval_shape(args.op_shape)[0]
  if operand_batch != args.start_indices_shape[0]:
    raise ValueError("Batch dimensions in operand and start_indices don't "
                     "agree")
def _gather_for_multidim_indexing(args: GatherArgs):
  """Lowers 'multi-dimensional indexing into arrays' cases of lax.gather.

  The lowering is a single tf.gather along one axis.
  E.g., jnp.take(op, [[0], [1]], axis=0).
  """
  # The gather axis is guessed to be the first collapsed slice dimension.
  gather_axis = args.dnums.collapsed_slice_dims[0]
  # Drop the trailing dimension of `start_indices` before gathering.
  indices = tf.squeeze(args.start_indices, axis=-1)
  operand_dims = jax2tf._eval_shape(args.op_shape)
  # Clip against the extent of the gathered axis only, with a slice size of 1.
  indices = _clip((operand_dims[gather_axis],), indices, (1,))
  return tf.gather(args.operand, indices, axis=gather_axis, batch_dims=0)
def _gather_with_batch_dims(args: GatherArgs):
  """Implements call to gather with non-empty batch dimensions.

  E.g., when doing `jax.vmap(lax.dynamic_slice)`.

  Each row of the clipped `start_indices` is used as the `begin` of a
  tf.slice over the whole operand (via tf.map_fn), and axis 1 of the stacked
  result is then squeezed away.
  """
  op_shape = jax2tf._eval_shape(args.op_shape)
  # Run the slice sizes through jax2tf._eval_shape before they reach `_clip`
  # and tf.slice, as _gather_for_scalar_indexing and _dynamic_slice do,
  # instead of handing the raw `slice_sizes` to TF.
  slice_sizes_tf = jax2tf._eval_shape(args.slice_sizes)
  start_indices = _clip(op_shape, args.start_indices, slice_sizes_tf)
  result = tf.map_fn(
      lambda idxs: tf.slice(args.operand, begin=idxs, size=slice_sizes_tf),
      start_indices,
      fn_output_signature=jax2tf._to_tf_dtype(args.operand.dtype))
  # NOTE(review): assumes the size-1 dimension left by each per-row slice ends
  # up at axis 1 of the stacked result -- confirm against the caller's checks.
  result = tf.squeeze(result, axis=1)
  return result
def _pad(operand, padding_value, *, padding_config,
         _in_avals: Sequence[core.ShapedArray],
         _out_aval: core.ShapedArray):
  """Pads `operand` with `padding_value` as described by `padding_config`.

  The work is split into three steps: interior padding (delegated to
  `_interior_padding`), non-negative edge padding via tf.pad, and finally
  negative edge padding, which is realised as a tf.slice down to the expected
  output shape.

  Args:
    operand: the tensor to pad.
    padding_value: the fill value handed to tf.pad and `_interior_padding`.
    padding_config: a sequence of (low, high, interior) triples.
    _in_avals: input avals; only the shape of the first (the operand) is read.
    _out_aval: output aval; its shape is the slice size used when cropping.

  Returns:
    The padded (and, for negative padding, cropped) tensor.
  """
  low, high, interior = util.unzip3(padding_config)

  # Do only the interior padding first. This is rarely needed.
  if any(i != 0 for i in interior):
    operand = _interior_padding(operand, padding_value, padding_config,
                                jax2tf._eval_shape(_in_avals[0].shape))

  # Now do the non-negative edge padding. This is the common case, use tf.pad.
  # Negative amounts are treated as 0 here and handled by the slice below.
  non_negative_padding = [((lo if lo >= 0 else 0), (hi if hi >= 0 else 0))
                          for lo, hi in zip(low, high)]
  operand = tf.pad(operand, non_negative_padding, mode="CONSTANT",
                   constant_values=padding_value)

  # Now the negative edge padding (this is also rare). A negative `lo` moves
  # the slice start forward by -lo; slicing to the output shape then also
  # drops whatever a negative `hi` removes from the end.
  if any(lo < 0 or hi < 0 for lo, hi in zip(low, high)):
    output_shape = jax2tf._eval_shape(_out_aval.shape)
    begins = [(-lo if lo < 0 else 0) for lo in low]
    operand = tf.slice(operand, begins, output_shape)

  return operand