Example 1
from collections.abc import MutableMapping
from typing import Iterable, Optional, Union

import pyarrow as pa

# Formatter, PythonFormatter, and key_to_query_type are defined alongside
# this function in datasets.formatting.formatting.


def format_table(
    pa_table: pa.Table,
    key: Union[int, slice, range, str, Iterable],
    formatter: Formatter,
    format_columns: Optional[list] = None,
    output_all_columns: bool = False,
):
    """
    Format a pyarrow Table depending on the key that was used and a Formatter object.

    Args:
        pa_table (``pyarrow.Table``): The input pyarrow Table to format
        key (``Union[int, slice, range, str, Iterable]``): Depending on the key that was used, the formatter formats
            the table as either a row, a column or a batch.
        formatter (``datasets.formatting.formatting.Formatter``): Any subclass of a Formatter such as
            PythonFormatter, NumpyFormatter, etc.
        format_columns (Optional ``List[str]``): if not None, it defines the columns that will be formatted using the
            given formatter. Other columns are discarded (unless ``output_all_columns`` is True)
        output_all_columns (``bool``, defaults to ``False``): If ``True``, the formatted output is completed using the columns
            that are not in the ``format_columns`` list. For these columns, the PythonFormatter is used.

    Returns:
        A row, column or batch formatted object defined by the Formatter:
        - the PythonFormatter returns a dictionary for a row or a batch, and a list for a column.
        - the NumpyFormatter returns a dictionary for a row or a batch, and a np.array for a column.
        - the PandasFormatter returns a pd.DataFrame for a row or a batch, and a pd.Series for a column.
        - the TorchFormatter returns a dictionary for a row or a batch, and a torch.Tensor for a column.
        - the TFFormatter returns a dictionary for a row or a batch, and a tf.Tensor for a column.
    """
    query_type = key_to_query_type(key)
    python_formatter = PythonFormatter()
    if format_columns is None:
        return formatter(pa_table, query_type=query_type)
    elif query_type == "column":
        if key in format_columns:
            return formatter(pa_table, query_type=query_type)
        else:
            return python_formatter(pa_table, query_type=query_type)
    else:
        pa_table_to_format = pa_table.drop(col for col in pa_table.column_names
                                           if col not in format_columns)
        formatted_output = formatter(pa_table_to_format, query_type=query_type)
        if output_all_columns:
            if isinstance(formatted_output, MutableMapping):  # update() below needs mutability
                pa_table_with_remaining_columns = pa_table.drop(
                    col for col in pa_table.column_names
                    if col in format_columns)
                remaining_columns_dict = python_formatter(
                    pa_table_with_remaining_columns, query_type=query_type)
                formatted_output.update(remaining_columns_dict)
            else:
                raise TypeError(
                    f"Custom formatting function must return a dict to work with output_all_columns=True, but got {formatted_output}"
                )
        return formatted_output
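
A minimal usage sketch: format_table lives in datasets.formatting.formatting (per its docstring), and the NumpyFormatter import path below is an assumption, not confirmed by the snippet.

import pyarrow as pa
from datasets.formatting import NumpyFormatter  # assumed import path

table = pa.table({"a": [1, 2, 3], "b": ["x", "y", "z"]})
# Format row 0 with numpy, restricting formatting to column "a"; column "b"
# is filled back in by the PythonFormatter because output_all_columns=True.
row = format_table(
    table,
    key=0,
    formatter=NumpyFormatter(),
    format_columns=["a"],
    output_all_columns=True,
)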
Example 2
from typing import List

import numpy as np
import pyarrow as pa

# sc, hash_pk_bytes_generator, and pk_digest_to_hash_bucket_index are
# helpers defined in the surrounding module.


def group_by_pk_hash_bucket(table: pa.Table, num_buckets: int,
                            primary_keys: List[str]) -> np.ndarray:
    """Group table records into hash buckets keyed by their primary keys.

    Returns an object ndarray of length num_buckets whose entries are
    pyarrow Tables with a record-index column appended (None for empty
    buckets).
    """
    # generate the primary key digest column
    all_pk_column_fields = []
    for pk_name in primary_keys:
        # casting a primary key column to numpy also ensures no nulls exist
        column_fields = table[pk_name].to_numpy()
        all_pk_column_fields.append(column_fields)
    hash_column_generator = hash_pk_bytes_generator(all_pk_column_fields)
    table = sc.append_pk_hash_column(table, hash_column_generator)

    # drop primary key columns to free up memory
    table = table.drop(primary_keys)

    # group hash bucket record indices
    hash_bucket_to_indices = np.empty([num_buckets], dtype="object")
    for record_index, digest in enumerate(sc.pk_hash_column_np(table)):
        hash_bucket = pk_digest_to_hash_bucket_index(digest, num_buckets)
        if hash_bucket_to_indices[hash_bucket] is None:
            hash_bucket_to_indices[hash_bucket] = []
        hash_bucket_to_indices[hash_bucket].append(record_index)

    # generate the ordered record number column
    hash_bucket_to_table = np.empty([num_buckets], dtype="object")
    for hash_bucket in range(len(hash_bucket_to_indices)):
        indices = hash_bucket_to_indices[hash_bucket]
        if indices:
            hash_bucket_to_table[hash_bucket] = sc.append_record_idx_col(
                table.take(indices),
                indices,
            )
    return hash_bucket_to_table
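
The helpers above (sc, hash_pk_bytes_generator, pk_digest_to_hash_bucket_index) come from the surrounding module, so the function is not runnable in isolation. A self-contained sketch of the same bucket-then-take pattern in plain pyarrow, with Python's built-in hash standing in for the primary-key digest:

import pyarrow as pa

def bucket_rows_by_key(table: pa.Table, key: str, num_buckets: int) -> list:
    # Assign each row index to a bucket by hashing its key value, then
    # materialize one sub-table per bucket with Table.take (None if empty).
    buckets = [None] * num_buckets
    for i, value in enumerate(table[key].to_pylist()):
        bucket = hash(value) % num_buckets
        if buckets[bucket] is None:
            buckets[bucket] = []
        buckets[bucket].append(i)
    return [table.take(idx) if idx else None for idx in buckets]

t = pa.table({"pk": ["a", "b", "a"], "v": [1, 2, 3]})
grouped = bucket_rows_by_key(t, "pk", num_buckets=2)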
Example 3
def _remove_unsupported_feature_columns(examples_table: pa.Table,
                                        schema: schema_pb2.Schema) -> pa.Table:
  """Removes feature columns that contain unsupported values.

  All feature columns that are multivalent are dropped since they are
  not supported by sk-learn.

  All columns of STRUCT type are also dropped.

  Args:
    examples_table: Arrow table containing a batch of examples.
    schema: The schema for the data.

  Returns:
    Arrow table.
  """
  multivalent_features = schema_util.get_multivalent_features(schema)
  unsupported_columns = set()
  for f in multivalent_features:
    unsupported_columns.add(f.steps()[0])
  for column_name, column in zip(examples_table.schema.names,
                                 examples_table.itercolumns()):
    if (stats_util.get_feature_type_from_arrow_type(
        types.FeaturePath([column_name]),
        column.type) == statistics_pb2.FeatureNameStatistics.STRUCT):
      unsupported_columns.add(column_name)
  return examples_table.drop(unsupported_columns)
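
The multivalent and STRUCT checks rely on TFDV's schema_util and stats_util; the STRUCT-dropping half can be reproduced with pyarrow alone. A minimal sketch, assuming nothing beyond pyarrow:

import pyarrow as pa

def drop_struct_columns(table: pa.Table) -> pa.Table:
  # Collect the names of struct-typed columns, then drop them in one call.
  struct_columns = [
      name for name, arrow_type in zip(table.schema.names, table.schema.types)
      if pa.types.is_struct(arrow_type)
  ]
  return table.drop(struct_columns)

t = pa.table({
    "x": [1, 2],
    "s": [{"a": 1}, {"a": 2}],  # inferred as a struct column
})
assert drop_struct_columns(t).column_names == ["x"]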
Example 4
import pyarrow as pa
import pyarrow.compute as pac


def pyarrow_transform(batch: pa.Table) -> pa.Table:
    batch = batch.filter(pac.equal(batch["variety"], "Versicolor"))
    batch = batch.append_column(
        "normalized.sepal.length",
        pac.divide(batch["sepal.length"], pac.max(batch["sepal.length"])),
    )
    return batch.drop(["sepal.length"])
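
The column names suggest a per-batch transform over the iris dataset. A standalone run on a hand-built batch, assuming only the function above:

import pyarrow as pa

batch = pa.table({
    "variety": ["Setosa", "Versicolor"],
    "sepal.length": [5.1, 7.0],
})
out = pyarrow_transform(batch)
# Only the Versicolor row survives; its sepal length is normalized by the
# batch maximum and the raw "sepal.length" column is dropped.
print(out.column_names)  # ['variety', 'normalized.sepal.length']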
Example 5
from collections.abc import Iterable
from typing import Union

import pyarrow as pa

# _is_range_contiguous and _raise_bad_key_type are private helpers defined
# in the same module.


def _query_table(pa_table: pa.Table, key: Union[int, slice, range, str,
                                                Iterable]) -> pa.Table:
    """
    Query a pyarrow Table to extract the subtable that corresponds to the given key.
    """
    if isinstance(key, int):
        return pa_table.slice(key % pa_table.num_rows, 1)
    if isinstance(key, slice):
        key = range(*key.indices(pa_table.num_rows))
    if isinstance(key, range):
        if _is_range_contiguous(key) and key.start >= 0:
            return pa_table.slice(key.start, key.stop - key.start)
        else:
            pass  # treat as an iterable
    if isinstance(key, str):
        return pa_table.drop(column for column in pa_table.column_names
                             if column != key)
    if isinstance(key, Iterable):
        key = list(key)  # materialize so len() works on generic iterables
        if len(key) == 0:
            return pa_table.slice(0, 0)
        # don't use pyarrow.Table.take even for pyarrow >=1.0 (see https://issues.apache.org/jira/browse/ARROW-9773)
        return pa.concat_tables(
            pa_table.slice(int(i) % pa_table.num_rows, 1) for i in key)

    _raise_bad_key_type(key)
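
A quick tour of the supported key types, assuming _query_table and its private helpers (_is_range_contiguous, _raise_bad_key_type) are in scope:

import pyarrow as pa

t = pa.table({"a": [1, 2, 3, 4], "b": ["w", "x", "y", "z"]})
_query_table(t, 1)            # one row, returned as a 1-row table
_query_table(t, slice(0, 2))  # contiguous range -> zero-copy slice
_query_table(t, "b")          # a single column, kept as a table
_query_table(t, [3, 0])       # arbitrary indices -> concatenated 1-row slices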