Beispiel #1
0
    def test_generate_slices_with_kwargs(self):
        """Keyword arguments given to generate_slices reach the slice functions.

        The first slice function builds its keys from the forwarded
        `test_num` value; the second one returns no keys at all. Every key
        that is produced is expected to be paired with the input example.
        """
        def slice_function_1(example, **kwargs):
            # Look up `test_num` before touching the example so a missing
            # kwarg fails immediately.
            test_num = kwargs['test_num']
            prefix_by_column = (
                ('test_column_1', 'test_slice_key_a_'),
                ('test_column_2', 'test_slice_key_b_'),
            )
            return [
                prefix + str(test_num)
                for column, prefix in prefix_by_column
                if example.get(column)
            ]

        def slice_function_2(example, **kwargs):  # pylint: disable=unused-argument
            return []

        input_example = {
            'test_column_1': np.array([1]),
            'test_column_2': np.array([2])
        }
        slice_functions = [slice_function_1, slice_function_2]
        forwarded_kwargs = {'test_num': 1}

        # Only slice_function_1 contributes keys; slice_function_2 adds none.
        expected_result = [
            (slice_key, input_example)
            for slice_key in ('test_slice_key_a_1', 'test_slice_key_b_1')
        ]

        generated = slicing_util.generate_slices(
            input_example, slice_functions, **forwarded_kwargs)
        actual_result = list(generated)
        self.assertCountEqual(expected_result, actual_result)
 def process(
     self, element: Tuple[pa.RecordBatch, anomalies_pb2.Anomalies]
 ) -> Iterable[types.SlicedRecordBatch]:
     """Yields the slices generated for a record batch from its anomalies.

     Args:
       element: A (record batch, anomalies proto) pair.

     Yields:
       Each item produced by `slicing_util.generate_slices` when it is
       given the record batch and the single slice function returned by
       `anomalies_util.get_anomalies_slicer` for the anomalies proto.
     """
     batch, anomalies = element
     anomalies_slicer = anomalies_util.get_anomalies_slicer(anomalies)
     yield from slicing_util.generate_slices(batch, [anomalies_slicer])
Beispiel #3
0
    def test_generate_slices_bad_slice_function(self):
        """Expects a failing slice function to surface as a ValueError.

        The slice function divides by zero. The test asserts that consuming
        the output of generate_slices raises a ValueError whose message
        names the function and the underlying ZeroDivisionError.
        """
        def bad_slice_function(example):  # pylint: disable=unused-argument
            return 1 / 0  # Deliberately raises ZeroDivisionError.

        input_example = {'test_column': np.array([1])}

        # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias,
        # which was removed from unittest.TestCase in Python 3.12.
        with self.assertRaisesRegex(
                ValueError, 'One of the slice_functions '
                'bad_slice_function raised an exception: '
                'ZeroDivisionError.*'):
            # list() drains the result inside the context manager so the
            # slice function still runs here if generate_slices is lazy.
            list(
                slicing_util.generate_slices(input_example,
                                             [bad_slice_function]))
Beispiel #4
0
    def test_generate_slices_without_kwargs(self):
        """Slice keys from several kwarg-free slice functions are all emitted.

        Each slice function returns a single-key list when its column is
        truthy, so the expected output holds one (key, example) pair per
        function.
        """
        def slice_function_1(example):
            # Returns None when the column is missing or falsy.
            if not example.get('test_column_1'):
                return None
            return ['test_slice_key_1']

        def slice_function_2(example):
            # Returns None when the column is missing or falsy.
            if not example.get('test_column_2'):
                return None
            return ['test_slice_key_2']

        input_example = {
            'test_column_1': np.array([1]),
            'test_column_2': np.array(['a'])
        }
        slice_functions = [slice_function_1, slice_function_2]

        expected_result = [
            (slice_key, input_example)
            for slice_key in ('test_slice_key_1', 'test_slice_key_2')
        ]

        generated = slicing_util.generate_slices(input_example,
                                                 slice_functions)
        actual_result = list(generated)
        self.assertCountEqual(expected_result, actual_result)
Beispiel #5
0
 def process(self, element):
     """Yields the slices generated for an example from its anomalies.

     Args:
       element: An (example, anomalies proto) pair.

     Yields:
       Each item produced by `slicing_util.generate_slices` when called
       with the example, `anomalies_util.anomalies_slicer` as the only
       slice function, and the anomalies proto passed as `anomaly_proto`.
     """
     example, anomalies = element
     slicers = [anomalies_util.anomalies_slicer]
     sliced_examples = slicing_util.generate_slices(
         example, slicers, anomaly_proto=anomalies)
     for sliced_example in sliced_examples:
         yield sliced_example
 def process(self, element):
   """Pairs each generated slice key with the incoming record batch.

   Args:
     element: A (record batch, anomalies proto) pair.

   Yields:
     (slice key, record batch) tuples, one per item produced by
     `slicing_util.generate_slices` when called with the record batch,
     `anomalies_util.anomalies_slicer` as the only slice function, and the
     anomalies proto passed as `anomalies`.
   """
   record_batch, anomalies = element
   slicers = [anomalies_util.anomalies_slicer]
   slice_keys = slicing_util.generate_slices(
       record_batch, slicers, anomalies=anomalies)
   for key in slice_keys:
     yield key, record_batch
Beispiel #7
0
 def process(self, element):
     """Pairs each generated slice key with the incoming table.

     Args:
       element: A (table, anomalies proto) pair.

     Yields:
       (slice key, table) tuples, one per item produced by
       `slicing_util.generate_slices` when called with the table,
       `anomalies_util.anomalies_slicer` as the only slice function, and
       the anomalies proto passed as `anomalies`.
     """
     table, anomalies = element
     slicers = [anomalies_util.anomalies_slicer]
     slice_keys = slicing_util.generate_slices(
         table, slicers, anomalies=anomalies)
     for key in slice_keys:
         yield key, table
Beispiel #8
0
 def process(self, element):
     """Yields the slices generated for a record batch from its anomalies.

     Args:
       element: A (record batch, anomalies proto) pair.

     Yields:
       Each item produced by `slicing_util.generate_slices` when it is
       given the record batch and the single slice function returned by
       `anomalies_util.get_anomalies_slicer` for the anomalies proto.
     """
     batch, anomalies = element
     anomalies_slicer = anomalies_util.get_anomalies_slicer(anomalies)
     sliced_batches = slicing_util.generate_slices(batch, [anomalies_slicer])
     for sliced_batch in sliced_batches:
         yield sliced_batch