def gather(operand, start_indices, slice_sizes=None):
    """Gather whole rows of ``operand`` at the positions in ``start_indices``.

    The dimension numbers below fix the operand's rank at 2: one dimension
    (0) is collapsed and one offset dimension is kept in the output. Each
    gathered slice is therefore a single full row of ``operand``.

    Args:
      operand: the array/tensor to gather from. It must be rank 2 for the
        dimension numbers used here.
      start_indices: row indices into dimension 0 of ``operand``. The index
        vector lives on dimension 1. XLA treats a rank-1 index array as
        having an implicit trailing size-1 index vector, so presumably this
        is either shape ``(N,)`` or ``(N, 1)``. TODO confirm against callers.
      slice_sizes: optional explicit per-dimension slice sizes passed through
        to ``xla.gather``. If omitted, ``[1, operand.shape[1]]`` is used,
        meaning one full row per index.

    Returns:
      Whatever ``xla.gather`` returns for these arguments.
    """
    if slice_sizes is None:
        # Previously `slice_sizes` was an unbound free name (NameError at call
        # time). Collapse dim 0 to size 1 and take the full extent of dim 1.
        slice_sizes = [1, operand.shape[1]]
    dimension_numbers = xla_data_pb2.GatherDimensionNumbers()
    # Output dim 1 holds the slice contents (the row's columns).
    dimension_numbers.offset_dims.extend([1])
    # Operand dim 0 has slice size 1 and is dropped from the output shape.
    dimension_numbers.collapsed_slice_dims.extend([0])
    # The single index component addresses operand dim 0.
    dimension_numbers.start_index_map.extend([0])
    # Index vectors are stored along dim 1 of `start_indices`.
    dimension_numbers.index_vector_dim = 1
    return xla.gather(operand, start_indices, dimension_numbers, slice_sizes)
def _gather_dimensions_proto(indices_shape, dimension_numbers):
    """Build an XLA ``GatherDimensionNumbers`` proto from ``dimension_numbers``.

    Args:
      indices_shape: shape of the start-indices operand. It must be non-empty
        (checked with ``assert``). Its last axis becomes ``index_vector_dim``.
      dimension_numbers: object exposing ``offset_dims``,
        ``collapsed_slice_dims`` and ``start_index_map`` sequences, copied
        verbatim into the proto.

    Returns:
      A populated ``xla_data_pb2.GatherDimensionNumbers`` message.
    """
    result = xla_data_pb2.GatherDimensionNumbers()
    # Copy the three repeated fields, which share a name on both objects.
    for field_name in ("offset_dims", "collapsed_slice_dims", "start_index_map"):
        getattr(result, field_name).extend(getattr(dimension_numbers, field_name))
    # The index vector is taken to be the trailing axis, so an empty
    # (rank-0) shape has no valid index_vector_dim.
    assert indices_shape
    result.index_vector_dim = len(indices_shape) - 1
    return result