예제 #1
0
def ConvertNumpyArrayToLiteral(value):
    """Build an XLA ``LiteralProto`` from a Numpy array or a nested tuple of them.

    A tuple is handled recursively: the resulting literal carries the shape
    computed for the whole tuple, plus one entry in ``tuple_literals`` per
    component, in the same order as the tuple. Any non-tuple value is delegated
    to ``_ConvertNumpyArrayToLiteral``.

    Args:
      value: a Numpy array, or a (possibly nested) tuple of such arrays.

    Returns:
      An ``xla_data_pb2.LiteralProto`` describing ``value``.
    """
    if not isinstance(value, tuple):
        return _ConvertNumpyArrayToLiteral(value)

    tuple_literal = xla_data_pb2.LiteralProto()
    tuple_shape = xla_shape.CreateShapeFromNumpy(value)
    tuple_literal.shape.CopyFrom(tuple_shape.message)
    for element in value:
        # Convert the component (recursing into nested tuples), then copy the
        # result into a freshly appended slot of the repeated field.
        converted = ConvertNumpyArrayToLiteral(element)
        tuple_literal.tuple_literals.add().CopyFrom(converted)
    return tuple_literal
예제 #2
0
def _ConvertNumpyArrayToLiteral(ndarray):
  """Converts a Numpy array to a XLA literal.

  The literal's shape message comes from ``xla_shape.CreateShapeFromNumpy``.
  The array's values are written into the repeated field of the literal named
  by the dtype's type record (``type_record.literal_field_name``).

  Args:
    ndarray: a Numpy array whose ``str(dtype)`` is a key of
      ``types.MAP_DTYPE_TO_RECORD``.

  Returns:
    An ``xla_data_pb2.LiteralProto`` holding the shape and the array values.

  Raises:
    KeyError: if ``str(ndarray.dtype)`` has no entry in
      ``types.MAP_DTYPE_TO_RECORD``.
  """
  type_record = types.MAP_DTYPE_TO_RECORD[str(ndarray.dtype)]
  literal = xla_data_pb2.LiteralProto()
  literal.shape.CopyFrom(xla_shape.CreateShapeFromNumpy(ndarray).message)
  # Look up the destination repeated field once instead of once per element.
  values = getattr(literal, type_record.literal_field_name)

  if ndarray.ndim == 0:
    values.append(ndarray.astype(type_record.literal_field_type).item())
    return literal

  # ravel(order='A') reads elements in Fortran order when the array is
  # Fortran-contiguous and in C order otherwise. Both branches below use this
  # same element order. (Previously the bool branch used _np.nditer, whose
  # default order='K' follows memory layout and can differ for non-contiguous
  # views, and which raises ValueError on zero-size arrays.)
  # NOTE(review): presumably this order must match the layout chosen by
  # xla_shape.CreateShapeFromNumpy — confirm there.
  ndarray_flat = ndarray.ravel(order='A')
  if ndarray.dtype == _np.bool_:
    # Ndarrays with boolean dtypes need special type conversion with protobufs.
    values.extend(
        [type_record.literal_field_type(element) for element in ndarray_flat])
  else:
    values.extend(ndarray_flat)
  return literal