def update_existing_schema(current_schema: pa.Schema, new_schema: pa.Schema) -> pa.Schema:
    """
    Takes the current schema and updates any fields in the current schema with
    fields from the new_schema.

    If current_schema has fields that do not exist in new_schema then they are
    unchanged. If current_schema has fields that also exist in new_schema then
    the field in new_schema is chosen. If fields exist in new_schema but not in
    current_schema, these will be ignored.

    Args:
        current_schema (pa.Schema): Schema to update
        new_schema (pa.Schema): Schema with fields that you wish to be used to
            update current_schema

    Returns:
        pa.Schema: A schema with the same column order as current_schema but
        with the fields updated for any fields that matched new_schema.
    """
    updated_schema = pa.schema([])
    for field in current_schema:
        if field.name in new_schema.names:
            updated_schema = updated_schema.append(new_schema.field(field.name))
        else:
            updated_schema = updated_schema.append(field)
    return updated_schema

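# A hedged usage sketch for update_existing_schema above (names and types are
# illustrative). The "id" field keeps its original type, "value" is replaced by
# the definition from the new schema, and the extra field in the new schema is
# ignored.
import pyarrow as pa

current = pa.schema([pa.field("id", pa.int64()), pa.field("value", pa.int32())])
new = pa.schema([pa.field("value", pa.float64()), pa.field("extra", pa.string())])

updated = update_existing_schema(current, new)
# updated: id: int64, value: double  (same column order as current)
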
def _DEPRECATED_overwrite_to_fix_arrow_table_schema(
    path: Path, fallback_schema: pa.Schema
) -> None:
    if not path.stat().st_size:
        return

    table = load_trusted_arrow_file(path)

    untyped_schema = table.schema
    fields = [
        __DEPRECATED_fix_field(
            untyped_schema.field(i),
            (
                None
                if fallback_schema.get_field_index(name) == -1
                else fallback_schema.field(fallback_schema.get_field_index(name))
            ),
        )
        for i, name in enumerate(untyped_schema.names)
    ]
    schema = pa.schema(fields)

    # Overwrite with new data
    #
    # We don't short-circuit by comparing schemas: two pa.Schema values with
    # different number formats evaluate as equal.
    #
    # We write a separate file to /var/tmp and then copy it: our sandbox won't
    # let us `rename(2)` in `path`'s directory.
    with tempfile_context(dir="/var/tmp") as rewrite_path:
        with pa.ipc.RecordBatchFileWriter(rewrite_path, schema) as writer:
            writer.write_table(pa.table(table.columns, schema=schema))
        shutil.copyfile(rewrite_path, path)

def _EnumerateTypesAlongPath(arrow_schema: pa.Schema,
                             column_path: path.ColumnPath) -> pa.DataType:
    """Enumerates nested types along a column_path.

    A nested type is either a list-like type or a struct type.

    It uses `column_path`[0] to first address a field in the schema, and
    enumerates its type. If that type is nested, it enumerates its child and
    continues recursively until the column_path reaches an end. The child of a
    list-like type is its value type. The child of a struct type is the type of
    the child field of the name given by the corresponding step in the
    column_path.

    Args:
        arrow_schema: The arrow schema to traverse.
        column_path: A path of field names.

    Yields:
        The arrow type of each level in the schema.

    Raises:
        ValueError: If a step does not exist in the arrow schema.
        ValueError: If arrow_schema has no more struct fields, but we did not
            iterate through every field in column_path.
    """
    field_name = column_path.initial_step()
    column_path = column_path.suffix(1)

    arrow_field = arrow_schema.field(field_name)
    arrow_type = arrow_field.type
    yield arrow_type

    while True:
        if pa.types.is_struct(arrow_type):
            # Get the field from the StructType.
            if not column_path:
                break
            curr_field_name = column_path.initial_step()
            column_path = column_path.suffix(1)
            try:
                arrow_field = arrow_type[curr_field_name]
            except KeyError:
                raise ValueError(
                    "The field: {} could not be found in the current Struct: {}"
                    .format(curr_field_name, arrow_type))
            arrow_type = arrow_field.type
        elif _IsListLike(arrow_type):
            arrow_type = arrow_type.value_type
        else:
            yield arrow_type
            if column_path:
                raise ValueError(
                    "The arrow_schema fields are exhausted, but there are remaining "
                    "fields in the column_path: {}".format(column_path))
            break
        yield arrow_type

def schemas_equal(a: pa.Schema, b: pa.Schema, check_order: bool = True,
                  check_metadata: bool = True) -> bool:
    """Check whether two schemas are equal.

    :param a: first pyarrow schema
    :param b: second pyarrow schema
    :param check_order: whether to compare field order
    :param check_metadata: whether to compare metadata
    :return: whether the two schemas are equal
    """
    if check_order:
        return a.equals(b, check_metadata=check_metadata)
    if check_metadata and a.metadata != b.metadata:
        return False
    da = {k: a.field(k) for k in a.names}
    db = {k: b.field(k) for k in b.names}
    return da == db

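# A hedged usage sketch for schemas_equal above: two schemas with the same
# fields in a different order compare equal only when check_order=False.
import pyarrow as pa

a = pa.schema([pa.field("x", pa.int64()), pa.field("y", pa.string())])
b = pa.schema([pa.field("y", pa.string()), pa.field("x", pa.int64())])

assert not schemas_equal(a, b)                 # field order differs
assert schemas_equal(a, b, check_order=False)  # same fields, order ignored
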
def _EnumerateTypesAlongPath(arrow_schema: pa.Schema,
                             path: List[Text]) -> pa.DataType:
    """Enumerates nested types along a path.

    A nested type is either a list-like type or a struct type.

    It uses `path`[0] to first address a field in the schema, and enumerates its
    type. If that type is nested, it enumerates its child and continues
    recursively until the path reaches an end. The child of a list-like type is
    its value type. The child of a struct type is the type of the child field of
    the name given by the corresponding step in the path.

    Args:
        arrow_schema: The arrow schema to traverse.
        path: A path of field names.

    Yields:
        The arrow type of each level in the schema.

    Raises:
        ValueError: If a step does not exist in the arrow schema.
    """
    path = collections.deque(path)
    field_name = path.popleft()
    arrow_field = arrow_schema.field(field_name)
    arrow_type = arrow_field.type
    yield arrow_type
    while True:
        if pa.types.is_struct(arrow_type):
            # Get the field from the StructType.
            if not path:  # path is empty
                break
            curr_field_name = path.popleft()
            try:
                arrow_field = arrow_type[curr_field_name]
            except KeyError:
                raise ValueError(
                    "The field: {} could not be found in the current Struct: {}"
                    .format(curr_field_name, arrow_type))
            arrow_type = arrow_field.type
        elif _IsListLike(arrow_type):
            arrow_type = arrow_type.value_type
        else:
            yield arrow_type
            break
        yield arrow_type

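# An illustrative sketch for the List[Text] variant of _EnumerateTypesAlongPath
# above. It assumes the module's _IsListLike helper (not shown here; roughly
# pa.types.is_list plus its large-list variant). Walking ["parent", "values"]
# yields the struct type, the list type, and then the leaf value type (which
# this implementation yields once more as the walk terminates).
import pyarrow as pa

schema = pa.schema([
    pa.field("parent", pa.struct([pa.field("values", pa.list_(pa.int64()))]))
])

for arrow_type in _EnumerateTypesAlongPath(schema, ["parent", "values"]):
    print(arrow_type)
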
def _read_table(
    table_as_folder: "Table",
    columns: t.List[str],
    filter_expression: pds.Expression,
    partitioning: pds.Partitioning,
    table_schema: pa.Schema,
) -> pa.Table:
    """
    Refer: https://arrow.apache.org/docs/python/dataset.html#dataset

    todo: need to find a way to preserve indexes while writing, or else find a
      way to read with sort with pyarrow ... then there will be no need to use
      to_pandas() and also no need for casting
    """
    if bool(columns):
        table_schema = pa.schema(
            fields=[table_schema.field(_c) for _c in columns],
            metadata=table_schema.metadata
        )
    # noinspection PyProtectedMember
    _path = table_as_folder.path
    _table = pa.Table.from_batches(
        batches=pds.dataset(
            source=_path.full_path,
            filesystem=_path.fs,
            format=_FILE_FORMAT,
            schema=table_schema,
            partitioning=partitioning,
        ).to_batches(
            # todo: verify below claim and test if this will remain generally correct
            #   applying filters like columns and filter_expression here is more
            #   efficient, as they apply per loaded batch rather than loading the
            #   entire table and then filtering
            columns=columns,
            filter=filter_expression,
        ),
        # if columns are specified the table_schema will change as some columns
        # will disappear ... so we set to None
        # todo: check how to drop remaining columns from table_schema
        schema=table_schema,
    )
    # todo: should we reconsider sort overhead ???
    # return self.file_type.deserialize(
    #     _table
    # ).sort_index(axis=0)
    return _table

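# The column-pruning step used above, shown in isolation as a hedged sketch:
# build a reduced schema that keeps only the requested columns while preserving
# the original schema metadata (field names here are illustrative).
import pyarrow as pa

full_schema = pa.schema(
    [pa.field("a", pa.int64()), pa.field("b", pa.string()), pa.field("c", pa.float64())],
    metadata={"origin": "example"},
)
columns = ["a", "c"]

pruned = pa.schema(
    fields=[full_schema.field(_c) for _c in columns],
    metadata=full_schema.metadata,
)
# pruned: a: int64, c: double, with the original metadata carried over
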
def _GetNestDepthAndValueType(arrow_schema: pa.Schema,
                              path: List[Text]) -> Tuple[int, pa.DataType]:
    """Returns the depth of a leaf field and its innermost value type.

    The depth is the number of nested list-like types along the leaf field's path.

    Args:
        arrow_schema: The arrow schema to traverse.
        path: A path of field names. The path must describe a leaf field.

    Returns:
        A tuple of (depth, arrow type).
    """
    arrow_type = arrow_schema.field(path[0]).type
    depth = 0
    for arrow_type in _EnumerateTypesAlongPath(arrow_schema, path):
        if _IsListLike(arrow_type):
            depth += 1

    return depth, arrow_type

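# A hedged sketch for _GetNestDepthAndValueType above (relies on the
# _EnumerateTypesAlongPath and _IsListLike helpers defined alongside it): for a
# column typed list<list<int64>>, the depth is 2 and the innermost value type
# is int64.
import pyarrow as pa

schema = pa.schema([pa.field("ragged", pa.list_(pa.list_(pa.int64())))])

depth, value_type = _GetNestDepthAndValueType(schema, ["ragged"])
# depth == 2, value_type == int64
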
def arrow_schema_to_render_columns(schema: pa.Schema) -> Dict[str, RenderColumn]:
    return {
        name: _arrow_field_to_render_column(schema.field(i))
        for i, name in enumerate(schema.names)
    }

def CanHandle(arrow_schema: pa.Schema,
              tensor_representation: schema_pb2.TensorRepresentation) -> bool:
    """Returns whether `tensor_representation` can be handled.

    The tensor_representation cannot be handled when:
      1. A wrong column name / field name is requested.
      2. A non-leaf field is requested (for StructTypes).
      3. There is no ListType along the path.
      4. A requested partition path is not integer-valued or does not exist.

    Args:
        arrow_schema: The pyarrow schema.
        tensor_representation: The TensorRepresentation proto.
    """
    ragged_tensor = tensor_representation.ragged_tensor
    if len(ragged_tensor.feature_path.step) < 1:
        return False

    value_path = path.ColumnPath.from_proto(ragged_tensor.feature_path)

    # Check the outer dimensions represented by the value feature path.
    contains_list = False
    try:
        arrow_type = None
        for arrow_type in _EnumerateTypesAlongPath(arrow_schema, value_path):
            if _IsListLike(arrow_type):
                contains_list = True
        if pa.types.is_struct(arrow_type):
            # The path is depleted, but the last arrow_type is a struct. This
            # means the path is a non-leaf field.
            return False
    except ValueError:
        # ValueError signifies a wrong column name / field name was requested.
        return False
    if not contains_list:
        return False

    # Check the auxiliary features that need to be accessed to form the inner
    # dimensions' partitions.
    parent_path = value_path.parent()

    # Check that the columns exist and have the correct depth and type.
    for partition in ragged_tensor.partition:
        if partition.HasField("row_length"):
            try:
                field_path = parent_path.child(partition.row_length)
                # To avoid a loop-undefined-variable lint error.
                partition_type = arrow_schema.field(field_path.initial_step()).type
                for partition_type in _EnumerateTypesAlongPath(
                        arrow_schema, field_path, stop_at_path_end=True):
                    # Iterate through them all; only the last type is of interest.
                    pass
                if not _IsListLike(partition_type) or not pa.types.is_integer(
                        partition_type.value_type):
                    return False
            except ValueError:
                # ValueError signifies a wrong column name / field name was requested.
                return False
        elif partition.HasField("uniform_row_length"):
            if partition.uniform_row_length <= 0:
                return False
        else:
            return False

    # All checks passed successfully.
    return True

def conform_to_schema(
        cls, table: pa.Table, schema: pa.Schema,
        pandas_types=None, warn_extra_columns=True) \
        -> pa.Table:
    """
    Align an Arrow table to an Arrow schema.

    Columns will be matched using case-insensitive matching and columns not in
    the schema will be dropped. The resulting table will have the field order
    and case defined in the schema.

    Where column types do not match exactly, type coercion will be applied if
    possible. In some cases type coercion may result in overflows; for example,
    casting int64 -> int32 will fail if any values are greater than the maximum
    int32 value.

    If the incoming data has been converted from Pandas, there are some
    conversions that can be applied if the original Pandas dtype is known. These
    dtypes can be supplied via the pandas_types parameter and should line up
    with the data in the table (i.e. dtypes are for the source data, not the
    target schema).

    The method will return a dataset whose schema exactly matches the requested
    schema. If it is not possible to make the data conform to the schema for any
    reason, EDataConformance will be raised.

    :param table: The data to be conformed
    :param schema: The schema to conform to
    :param pandas_types: Pandas dtypes for the table, if the table has been
        converted from Pandas
    :param warn_extra_columns: Whether to log warnings if the table contains
        columns not in the schema
    :return: The conformed data, whose schema will exactly match the supplied
        schema parameter
    :raises: _ex.EDataConformance if conformance is not possible for any reason
    """

    # If Pandas types are supplied they must match the table,
    # i.e. the table has been converted from Pandas
    if pandas_types is not None and len(pandas_types) != len(table.schema.types):
        raise _ex.EUnexpected()

    cls._check_duplicate_fields(schema, True)
    cls._check_duplicate_fields(table.schema, False)

    table_indices = {f.lower(): i for (i, f) in enumerate(table.schema.names)}

    conformed_data = []
    conformance_errors = []

    # Coerce types to match the expected schema where possible
    for schema_index in range(len(schema.names)):

        try:
            schema_field = schema.field(schema_index)
            table_index = table_indices.get(schema_field.name.lower())

            if table_index is None:
                message = cls.__E_FIELD_MISSING.format(field_name=schema_field.name)
                cls.__log.error(message)
                raise _ex.EDataConformance(message)

            table_column: pa.Array = table.column(table_index)

            pandas_type = pandas_types[table_index] \
                if pandas_types is not None \
                else None

            if table_column.type == schema_field.type:
                conformed_column = table_column
            else:
                conformed_column = cls._coerce_vector(table_column, schema_field, pandas_type)

            if not schema_field.nullable and table_column.null_count > 0:
                message = f"Null values present in non-null field [{schema_field.name}]"
                cls.__log.error(message)
                raise _ex.EDataConformance(message)

            conformed_data.append(conformed_column)

        except _ex.EDataConformance as e:
            conformance_errors.append(e)

    # Columns not defined in the schema will not be included in the conformed output
    if warn_extra_columns and table.num_columns > len(schema.types):

        schema_columns = set(map(str.lower, schema.names))
        extra_columns = [
            f"[{col}]"
            for col in table.schema.names
            if col.lower() not in schema_columns
        ]

        message = f"Columns not defined in the schema will be dropped: {', '.join(extra_columns)}"
        cls.__log.warning(message)

    if any(conformance_errors):
        if len(conformance_errors) == 1:
            raise conformance_errors[0]
        else:
            cls.__log.error("There were multiple data conformance errors")
            raise _ex.EDataConformance(
                "There were multiple data conformance errors", conformance_errors)

    return pa.Table.from_arrays(conformed_data, schema=schema)  # noqa

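# A hedged usage sketch for conform_to_schema above. The class name
# DataConformance is hypothetical (the method is a classmethod on whichever
# helper class defines it, together with _coerce_vector and
# _check_duplicate_fields), and the column names are illustrative. Matching is
# case-insensitive, so "ID" is aligned to the schema's "id" field and, per the
# docstring, coerced from int64 to int32.
import pyarrow as pa

table = pa.table({"ID": pa.array([1, 2, 3], type=pa.int64())})
target = pa.schema([pa.field("id", pa.int32(), nullable=False)])

conformed = DataConformance.conform_to_schema(table, target)
# conformed.schema == target, with values coerced to int32
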