def gather(operand, start_indices):
  """Gather slices of `operand` at `start_indices` using `xla.gather`.

  The gather dimension numbers are fixed:
    offset_dims=[1], collapsed_slice_dims=[0], start_index_map=[0],
    index_vector_dim=1.
  Each index addresses operand dimension 0, and that dimension is collapsed
  out of every slice.

  Args:
    operand: tensor to gather from.
    start_indices: tensor of start indices.

  Returns:
    Whatever `xla.gather` returns for these arguments.
  """
  dnums = xla_data_pb2.GatherDimensionNumbers(
      offset_dims=[1],
      collapsed_slice_dims=[0],
      start_index_map=[0],
      index_vector_dim=1,
  )
  # NOTE(review): `slice_sizes` is neither a parameter nor a local of this
  # function; it must come from an enclosing/module scope (presumably a
  # closure variable in the surrounding test) -- confirm at the definition site.
  return xla.gather(operand, start_indices, dnums, slice_sizes)
def _gather(operand, start_indices, dimension_numbers, slice_sizes,
            indices_are_sorted=False):
  """Tensorflow implementation of gather.

  Lowers the gather to `tfxla.gather`, executed under
  `tf.xla.experimental.compile`, then attaches the static output shape
  computed by `_gather_shape`.

  Args:
    operand: tensor to gather from.
    start_indices: tensor of gather indices.
    dimension_numbers: gather dimension numbers; converted into the proto that
      `tfxla.gather` expects by `_gather_dimensions_proto`.
    slice_sizes: slice sizes forwarded to `tfxla.gather`.
    indices_are_sorted: forwarded unchanged to `tfxla.gather`.

  Returns:
    The single output tensor of the compiled gather.
  """
  static_shape = _gather_shape(
      operand, start_indices, dimension_numbers, slice_sizes)
  gather_proto = _gather_dimensions_proto(start_indices.shape,
                                          dimension_numbers)

  def _xla_gather(op, idx):
    return tfxla.gather(op, idx, gather_proto, slice_sizes, indices_are_sorted)

  # `compile` yields a sequence of outputs; this computation has exactly one.
  (result,) = tf.xla.experimental.compile(_xla_gather,
                                          [operand, start_indices])
  # Re-attach the statically known shape to the compiled output.
  result.set_shape(static_shape)
  return result
def _gather(operand, start_indices, dimension_numbers, slice_sizes,
            indices_are_sorted=False):
  """Tensorflow implementation of gather.

  First tries `_try_tf_gather`. If that returns a non-None result, the result
  is returned as-is. Otherwise the gather is lowered to `tfxla.gather`,
  executed under `tf.xla.experimental.compile`, and the static output shape
  from `_gather_shape` is attached.

  Args:
    operand: tensor to gather from.
    start_indices: tensor of gather indices.
    dimension_numbers: gather dimension numbers; converted into the proto that
      `tfxla.gather` expects by `_gather_dimensions_proto`.
    slice_sizes: slice sizes forwarded to `_try_tf_gather` and `tfxla.gather`.
    indices_are_sorted: forwarded to `tfxla.gather`. It defaults to False,
      which was the previously hard-coded value, so existing callers are
      unaffected. It is not consulted when the `_try_tf_gather` path succeeds.

  Returns:
    The gathered tensor.
  """
  # Fast path: `_try_tf_gather` signals "not handled" by returning None.
  res = _try_tf_gather(operand, start_indices, dimension_numbers, slice_sizes)
  if res is not None:
    return res
  out_shape = _gather_shape(
      operand, start_indices, dimension_numbers, slice_sizes)
  proto = _gather_dimensions_proto(start_indices.shape, dimension_numbers)
  # `compile` yields a sequence of outputs; this computation has exactly one.
  out, = tf.xla.experimental.compile(
      lambda o, s: tfxla.gather(o, s, proto, slice_sizes, indices_are_sorted),
      [operand, start_indices])
  # Re-attach the statically known shape to the compiled output.
  out.set_shape(out_shape)
  return out
def _gather(operand, start_indices, dimension_numbers, slice_sizes,
            indices_are_sorted=False):
  """Tensorflow implementation of gather.

  First tries `_try_tf_gather`. If that returns a non-None result, the result
  is returned as-is. Otherwise the gather is lowered to `tfxla.gather`,
  executed through `_xla_compile`, and the static output shape from
  `_gather_shape` is attached.

  Args:
    operand: tensor to gather from.
    start_indices: tensor of gather indices.
    dimension_numbers: gather dimension numbers; converted into the proto that
      `tfxla.gather` expects by `_gather_dimensions_proto`.
    slice_sizes: slice sizes forwarded to `_try_tf_gather` and `tfxla.gather`.
    indices_are_sorted: forwarded to `tfxla.gather`. It defaults to False,
      which was the previously hard-coded value, so existing callers are
      unaffected. It is not consulted when the `_try_tf_gather` path succeeds.

  Returns:
    The gathered tensor.
  """
  # Fast path: `_try_tf_gather` signals "not handled" by returning None.
  res = _try_tf_gather(operand, start_indices, dimension_numbers, slice_sizes)
  if res is not None:
    return res
  out_shape = _gather_shape(
      operand, start_indices, dimension_numbers, slice_sizes)
  proto = _gather_dimensions_proto(start_indices.shape, dimension_numbers)
  # We compile because without it we run into a TF bug:
  # tfxla.gather fails on constant inputs with "must be a compile-time constant"
  # b/153556869
  out = _xla_compile(
      lambda o, s: tfxla.gather(o, s, proto, slice_sizes, indices_are_sorted),
      operand, start_indices)
  # Re-attach the statically known shape to the compiled output.
  out.set_shape(out_shape)
  return out