コード例 #1
0
ファイル: dataset_slicing.py プロジェクト: coraljain/privacy
def get_single_slice_specs(slicing_spec: SlicingSpec,
                           num_classes: int = None) -> List[SingleSliceSpec]:
  """Returns slices of data according to slicing_spec.

  Slices are appended in a fixed order: the entire dataset (if requested),
  then the per-class slices, then ten percentile buckets
  (0, 10), (10, 20), ..., (90, 100), and finally the correctly- and
  incorrectly-classified slices.

  Args:
    slicing_spec: which kinds of slices to produce. Its `by_class` field may
      be a bool (`True` slices by every class in `range(num_classes)`), an int
      (slice by that single class), or an iterable of class values (one slice
      per element). Any other value produces no per-class slices.
    num_classes: number of classes. Must be given (non-zero) when
      `slicing_spec.by_class is True`; otherwise it is not used.

  Returns:
    A list of SingleSliceSpec objects, one per requested slice.

  Raises:
    AssertionError: if `by_class` is `True` and `num_classes` is falsy or not
      in [0, 1000]. Note these checks are `assert`s and are skipped when
      Python runs with `-O`.
  """
  # `collections.Iterable` was removed in Python 3.10; the ABC lives in
  # `collections.abc` (available since Python 3.3).
  from collections.abc import Iterable  # pylint: disable=g-import-not-at-top

  result = []

  if slicing_spec.entire_dataset:
    result.append(SingleSliceSpec())

  # Create slices by class.
  by_class = slicing_spec.by_class
  # `bool` must be tested before `int`: bool is a subclass of int, so `True`
  # would otherwise be treated as the single class label 1.
  if isinstance(by_class, bool):
    if by_class:
      assert num_classes, "When by_class == True, num_classes should be given."
      assert 0 <= num_classes <= 1000, (
          f"Too much classes for slicing by classes. "
          f"Found {num_classes}.")
      result.extend(
          SingleSliceSpec(SlicingFeature.CLASS, c) for c in range(num_classes))
  elif isinstance(by_class, int):
    result.append(SingleSliceSpec(SlicingFeature.CLASS, by_class))
  elif isinstance(by_class, Iterable):
    result.extend(SingleSliceSpec(SlicingFeature.CLASS, c) for c in by_class)

  # Create slices by percentiles: ten buckets of width 10 covering 0..100.
  if slicing_spec.by_percentiles:
    for percent in range(0, 100, 10):
      result.append(
          SingleSliceSpec(SlicingFeature.PERCENTILE, (percent, percent + 10)))

  # Create slices by correctness of the classifications.
  if slicing_spec.by_classification_correctness:
    result.append(SingleSliceSpec(SlicingFeature.CORRECTLY_CLASSIFIED, True))
    result.append(SingleSliceSpec(SlicingFeature.CORRECTLY_CLASSIFIED, False))

  return result
コード例 #2
0
 def testStr(self, feature, value, expected_str):
     """Checks str() of a SingleSliceSpec built from (feature, value)."""
     spec = SingleSliceSpec(feature, value)
     rendered = str(spec)
     self.assertEqual(rendered, expected_str)
コード例 #3
0
 def testStrEntireDataset(self):
     """A default-constructed SingleSliceSpec renders as 'Entire dataset'."""
     default_spec = SingleSliceSpec()
     self.assertEqual(str(default_spec), 'Entire dataset')
コード例 #4
0
def _get_slice_spec(data: AttackInputData) -> SingleSliceSpec:
    """Return `data.slice_spec` when the attribute exists.

    Falls back to a freshly constructed default `SingleSliceSpec()` when
    `data` has no `slice_spec` attribute.
    """
    return (data.slice_spec
            if hasattr(data, 'slice_spec')
            else SingleSliceSpec())
コード例 #5
0
 def test_slice_entire_dataset(self):
     """Slicing by the entire dataset yields the input with slice_spec set."""
     import copy  # pylint: disable=g-import-not-at-top

     entire_dataset_slice = SingleSliceSpec()
     output = get_slice(self.input_data, entire_dataset_slice)
     # Work on a shallow copy so that setting slice_spec on the expectation
     # does not mutate the shared fixture (plain assignment would alias it).
     expected = copy.copy(self.input_data)
     expected.slice_spec = entire_dataset_slice
     self.assertTrue(_are_all_fields_equal(output, expected))