def _remove_unsupported_feature_columns(examples_table: pa.Table,
                                        schema: schema_pb2.Schema) -> pa.Table:
  """Removes feature columns that contain unsupported values.

  All feature columns that are multivalent are dropped since they are
  not supported by sk-learn.

  All columns of STRUCT type are also dropped.

  Multivalent features that the schema lists but that are not columns of
  `examples_table` are ignored, so a batch that lacks some schema features
  does not make the drop fail.

  Args:
    examples_table: Arrow table containing a batch of examples.
    schema: The schema for the data.

  Returns:
    Arrow table without the unsupported columns.
  """
  present_columns = set(examples_table.schema.names)
  # A multivalent feature path is matched to a column by its first step;
  # only names that are actually columns of this table may be dropped.
  unsupported_columns = {
      f.steps()[0] for f in schema_util.get_multivalent_features(schema)
  } & present_columns
  for column_name, column in zip(examples_table.schema.names,
                                 examples_table.itercolumns()):
    if (stats_util.get_feature_type_from_arrow_type(
        types.FeaturePath([column_name]),
        column.type) == statistics_pb2.FeatureNameStatistics.STRUCT):
      unsupported_columns.add(column_name)
  return examples_table.drop(unsupported_columns)
def _remove_unsupported_feature_columns(examples, schema):
    """Removes feature columns that contain unsupported values.

    All feature columns that are multivalent are dropped since they are
    not supported by sk-learn. `examples` is modified in place and nothing
    is returned.

    Multivalent features that the schema lists but that are missing from
    `examples` are skipped instead of raising KeyError.

    Args:
      examples: ExampleBatch containing the values of each example per feature.
      schema: The schema for the data.
    """
    unsupported_features = schema_util.get_multivalent_features(schema)
    for feature_name in unsupported_features:
        # The schema may describe features this batch does not carry.
        if feature_name in examples:
            del examples[feature_name]
    def _remove_unsupported_feature_columns(
            self, examples: pa.RecordBatch,
            schema: schema_pb2.Schema) -> pa.RecordBatch:
        """Builds a RecordBatch that keeps only the supported feature columns.

        A column is left out when any of the following holds:
          * it is the first step of a multivalent feature of `schema`
            (multivalent features are not supported by sk-learn);
          * its Arrow type does not have a nest level of exactly 1;
          * its feature type is STRUCT;
          * its feature path is not in `self._schema_features`.

        Args:
          examples: Arrow RecordBatch containing a batch of examples.
          schema: The schema for the data.

        Returns:
          Arrow RecordBatch with the remaining columns, in their original
          order.
        """
        present_names = set(examples.schema.names)

        # Multivalent features only matter when they are columns of this batch.
        rejected = {
            path.steps()[0]
            for path in schema_util.get_multivalent_features(schema)
            if path.steps()[0] in present_names
        }

        for name, array in zip(examples.schema.names, examples.columns):
            arrow_type = array.type
            # Only 1-nested, non-struct arrays are supported.
            if arrow_util.get_nest_level(arrow_type) != 1:
                rejected.add(name)
            elif (stats_util.get_feature_type_from_arrow_type(
                    types.FeaturePath([name]), arrow_type)
                  == statistics_pb2.FeatureNameStatistics.STRUCT):
                rejected.add(name)
            # Columns the schema does not know about are rejected as well.
            if types.FeaturePath([name]) not in self._schema_features:
                rejected.add(name)

        kept = [(array, name)
                for name, array in zip(examples.schema.names,
                                       examples.columns)
                if name not in rejected]
        kept_arrays = [array for array, _ in kept]
        kept_names = [name for _, name in kept]
        return pa.RecordBatch.from_arrays(kept_arrays, kept_names)
示例#4
0
 def test_get_multivalent_features(self):
     """Checks the feature paths returned by get_multivalent_features.

     The schema below declares features through either `shape` or
     `value_count`, plus a STRUCT feature with two nested features. The
     test expects fc, fe, ff, fg, fh and the nested fi.fi_fb to be
     returned, and fa, fb, fd and fi.fi_fa to be left out.
     """
     schema_text = """
       feature {
         name: "fa"
         shape {
           dim {
             size: 1
           }
         }
       }
       feature {
         name: "fb"
         type: BYTES
         value_count {
           min: 0
           max: 1
         }
       }
       feature {
         name: "fc"
         value_count {
           min: 1
           max: 18
         }
       }
       feature {
         name: "fd"
         value_count {
           min: 1
           max: 1
         }
       }
       feature {
         name: "fe"
         shape {
           dim {
             size: 2
           }
         }
       }
       feature {
         name: "ff"
         shape {
           dim {
             size: 1
           }
           dim {
             size: 1
           }
         }
       }
       feature {
         name: "fg"
         value_count {
           min: 2
         }
       }
       feature {
         name: "fh"
         value_count {
           min: 0
           max: 2
         }
       }
       feature {
         name: "fi"
         type: STRUCT
         struct_domain {
           feature {
             name: "fi_fa"
             value_count {
               min: 0
               max: 1
             }
           }
           feature {
             name: "fi_fb"
             value_count {
               min: 0
               max: 2
             }
           }
         }
       }
       """
     schema = text_format.Parse(schema_text, schema_pb2.Schema())
     expected = {
         types.FeaturePath(['fc']),
         types.FeaturePath(['fe']),
         types.FeaturePath(['ff']),
         types.FeaturePath(['fg']),
         types.FeaturePath(['fh']),
         types.FeaturePath(['fi', 'fi_fb']),
     }
     actual = schema_util.get_multivalent_features(schema)
     self.assertEqual(actual, expected)
def test_get_multivalent_features(self):
    """Checks the feature names returned by get_multivalent_features.

    The schema below declares features through either `shape` or
    `value_count`. The test expects fc, fe, ff, fg and fh to be returned,
    and fa, fb and fd to be left out.
    """
    schema_text = """
        feature {
          name: "fa"
          shape {
            dim {
              size: 1
            }
          }
        }
        feature {
          name: "fb"
          type: BYTES
          value_count {
            min: 0
            max: 1
          }
        }
        feature {
          name: "fc"
          value_count {
            min: 1
            max: 18
          }
        }
        feature {
          name: "fd"
          value_count {
            min: 1
            max: 1
          }
        }
        feature {
          name: "fe"
          shape {
            dim {
              size: 2
            }
          }
        }
        feature {
          name: "ff"
          shape {
            dim {
              size: 1
            }
            dim {
              size: 1
            }
          }
        }
        feature {
          name: "fg"
          value_count {
            min: 2
          }
        }
        feature {
          name: "fh"
          value_count {
            min: 0
            max: 2
          }
        }"""
    schema = text_format.Parse(schema_text, schema_pb2.Schema())
    expected = {'fc', 'fe', 'ff', 'fg', 'fh'}
    actual = schema_util.get_multivalent_features(schema)
    self.assertEqual(actual, expected)
示例#6
0
    def __init__(self,
                 label_feature: types.FeaturePath,
                 schema: Optional[schema_pb2.Schema] = None,
                 max_encoding_length: int = 512,
                 seed: int = 12345,
                 multivalent_features: Optional[Set[types.FeaturePath]] = None,
                 categorical_features: Optional[Set[types.FeaturePath]] = None,
                 features_to_ignore: Optional[Set[types.FeaturePath]] = None,
                 normalize_by_max: bool = False,
                 allow_invalid_partitions: bool = False,
                 custom_stats_key: str = _ADJUSTED_MUTUAL_INFORMATION_KEY,
                 column_partitions: int = 1):
        """Initializes MutualInformation.

        Args:
          label_feature: The key used to identify labels in the ExampleBatch.
          schema: An optional schema describing the dataset. Either a schema
            or a list of categorical and multivalent features must be
            provided.
          max_encoding_length: An int value to specify the maximum length of
            encoding to represent a feature value.
          seed: An int value to seed the RNG used in MI computation.
          multivalent_features: An optional set of features that are
            multivalent. When given, it takes precedence over the schema.
          categorical_features: An optional set of the features that are
            categorical. When given, it takes precedence over the schema.
          features_to_ignore: An optional set of features that should be
            ignored by the mutual information calculation.
          normalize_by_max: If True, AMI values are normalized to a range 0
            to 1 by dividing by the maximum possible information AMI(Y, Y).
          allow_invalid_partitions: If True, generator tolerates input
            partitions that are invalid (e.g. size of partition is < the k
            for the KNN), where invalid partitions return no stats. The
            min_partitions_stat_presence arg to PartitionedStatisticsAnalyzer
            controls how many partitions may be invalid while still
            reporting the metric.
          custom_stats_key: A string that determines the key used in the
            custom statistic. This defaults to
            `_ADJUSTED_MUTUAL_INFORMATION_KEY`.
          column_partitions: If > 1, self.partitioner returns a PTransform
            that partitions input RecordBatches by column (feature), in
            addition to the normal row partitioning (by batch). The total
            number of effective partitions is
            column_partitions * row_partitions, where row_partitions is
            passed to self.partitioner.

        Raises:
          ValueError: If label_feature does not exist in the schema, or if
            neither the schema nor the multivalent / categorical feature set
            is provided.
        """
        self._label_feature = label_feature
        self._schema = schema
        self._normalize_by_max = normalize_by_max
        if multivalent_features is not None:
            self._multivalent_features = multivalent_features
        elif self._schema is not None:
            self._multivalent_features = schema_util.get_multivalent_features(
                self._schema)
        else:
            raise ValueError(
                "Either multivalent feature set or schema must be provided")
        if categorical_features is not None:
            self._categorical_features = categorical_features
        elif self._schema is not None:
            self._categorical_features = schema_util.get_categorical_features(
                self._schema)
        else:
            raise ValueError(
                "Either categorical feature set or schema must be provided")
        # Validate the label with an explicit check rather than `assert`:
        # assert statements are removed under `python -O`, which would
        # silently skip the lookup, and a failed assert raises AssertionError
        # instead of the documented ValueError.
        if schema and not schema_util.get_feature(self._schema,
                                                  self._label_feature):
            raise ValueError("label_feature does not exist in the schema")
        self._label_feature_is_categorical = (self._label_feature
                                              in self._categorical_features)
        self._max_encoding_length = max_encoding_length
        self._seed = seed
        self._features_to_ignore = features_to_ignore
        self._allow_invalid_partitions = allow_invalid_partitions
        self._custom_stats_key = custom_stats_key
        self._column_partitions = column_partitions