예제 #1
0
 def gather(operand, start_indices):
   """Gather slices of `operand` at `start_indices` using the XLA gather op.

   The gather dimension numbers are fixed (see XLA gather semantics):
     * index_vector_dim=1: dimension 1 of `start_indices` holds the index
       vectors;
     * start_index_map=[0]: the single index component addresses operand
       dimension 0;
     * collapsed_slice_dims=[0]: operand dimension 0 is dropped from each
       slice;
     * offset_dims=[1]: the remaining slice dimension becomes output
       dimension 1.
   NOTE(review): with one collapsed and one offset dimension this presumably
   expects a rank-2 `operand` -- confirm against callers.

   NOTE(review): `slice_sizes` is neither a parameter nor a local variable
   here; it must be resolved from an enclosing scope. Confirm it is defined
   wherever this function is used.
   """
   # Populate every field through the message constructor rather than
   # mutating the repeated fields one by one; the resulting proto is equal.
   gather_dnums = xla_data_pb2.GatherDimensionNumbers(
       offset_dims=[1],
       collapsed_slice_dims=[0],
       start_index_map=[0],
       index_vector_dim=1,
   )
   return xla.gather(operand, start_indices, gather_dnums, slice_sizes)
예제 #2
0
파일: jax_to_tf.py 프로젝트: PForet/jax
def _gather(operand, start_indices, dimension_numbers, slice_sizes,
            indices_are_sorted=False):
  """Tensorflow implementation of gather.

  Lowers the gather to `tfxla.gather`, run inside
  `tf.xla.experimental.compile`, and then re-attaches the statically known
  output shape to the compiled result.

  Args:
    operand: the tensor to gather from.
    start_indices: the gather indices; only its `.shape` is used to build
      the dimension-numbers proto.
    dimension_numbers: gather dimension numbers, converted into a proto by
      `_gather_dimensions_proto`.
    slice_sizes: per-dimension slice sizes, forwarded unchanged.
    indices_are_sorted: forwarded as the fifth positional argument of
      `tfxla.gather`.

  Returns:
    The single output of the compiled gather, with its static shape set to
    the one computed by `_gather_shape`.
  """
  result_shape = _gather_shape(
      operand, start_indices, dimension_numbers, slice_sizes)
  gather_proto = _gather_dimensions_proto(start_indices.shape,
                                          dimension_numbers)

  def _compiled_gather(o, s):
    return tfxla.gather(o, s, gather_proto, slice_sizes, indices_are_sorted)

  # The compiled computation yields exactly one output; unpack it.
  (gathered,) = tf.xla.experimental.compile(
      _compiled_gather, [operand, start_indices])
  # Restore the statically computed shape on the compiled output.
  gathered.set_shape(result_shape)
  return gathered
예제 #3
0
파일: jax2tf.py 프로젝트: hereismari/jax
def _gather(operand, start_indices, dimension_numbers, slice_sizes):
  """Tensorflow implementation of gather.

  First asks `_try_tf_gather` for a result. If it returns something other
  than None, that result is returned as-is. Otherwise the gather is lowered
  to `tfxla.gather` inside `tf.xla.experimental.compile`, and the statically
  computed output shape is set on the compiled result.

  Args:
    operand: the tensor to gather from.
    start_indices: the gather indices; its `.shape` feeds the
      dimension-numbers proto.
    dimension_numbers: gather dimension numbers.
    slice_sizes: per-dimension slice sizes.

  Returns:
    Either the `_try_tf_gather` result or the output of the compiled gather.
  """
  fast_result = _try_tf_gather(
      operand, start_indices, dimension_numbers, slice_sizes)
  if fast_result is None:
    result_shape = _gather_shape(
        operand, start_indices, dimension_numbers, slice_sizes)
    gather_proto = _gather_dimensions_proto(start_indices.shape,
                                            dimension_numbers)

    def _compiled_gather(o, s):
      # The trailing False is the indices_are_sorted argument.
      return tfxla.gather(o, s, gather_proto, slice_sizes, False)

    # The compiled computation yields exactly one output; unpack it.
    (gathered,) = tf.xla.experimental.compile(
        _compiled_gather, [operand, start_indices])
    gathered.set_shape(result_shape)
    return gathered
  return fast_result
예제 #4
0
def _gather(operand, start_indices, dimension_numbers, slice_sizes):
  """Tensorflow implementation of gather.

  Returns the `_try_tf_gather` result when it is not None. Otherwise it
  runs `tfxla.gather` through `_xla_compile` and sets the statically
  computed output shape on the result.

  Args:
    operand: the tensor to gather from.
    start_indices: the gather indices; its `.shape` feeds the
      dimension-numbers proto.
    dimension_numbers: gather dimension numbers.
    slice_sizes: per-dimension slice sizes.

  Returns:
    Either the `_try_tf_gather` result or the compiled gather's output.
  """
  fast_result = _try_tf_gather(
      operand, start_indices, dimension_numbers, slice_sizes)
  if fast_result is not None:
    return fast_result

  result_shape = _gather_shape(
      operand, start_indices, dimension_numbers, slice_sizes)
  gather_proto = _gather_dimensions_proto(start_indices.shape,
                                          dimension_numbers)

  def _compiled_gather(o, s):
    # The trailing False is the indices_are_sorted argument.
    return tfxla.gather(o, s, gather_proto, slice_sizes, False)

  # Compile on purpose: tfxla.gather called eagerly on constant inputs
  # fails with "must be a compile-time constant" (TF bug b/153556869).
  gathered = _xla_compile(_compiled_gather, operand, start_indices)
  gathered.set_shape(result_shape)
  return gathered