def test_convert_data_csv_invalid_input_filename(self):
  """Conversion rejects paths that lack a supported file extension."""
  # Neither a missing extension nor an unsupported one (.mp4) is accepted.
  bad_paths = [
      os.path.join(tempfile.mkdtemp(), 'no_extension'),
      os.path.join(tempfile.mkdtemp(), 'unsupported_extension.mp4'),
  ]
  for bad_path in bad_paths:
    with self.assertRaisesRegex(ValueError,
                                '.*must.*supported.*file.*extension'):
      _ = util.convert_comments_data(bad_path)
def test_convert_data_csv(self):
  """Converts an example CSV and checks the converted contents.

  BUG FIX: the original asserted
  `assertEqual(df.reset_index(..., inplace=True), expected.reset_index(..., inplace=True))`.
  With `inplace=True` `reset_index` returns None, so the test compared
  None == None and always passed regardless of the converted data.
  Also replaces `DataFrame.append` (removed in pandas >= 2.0) with a
  direct constructor call.
  """
  input_file = self._write_csv(self._create_example_csv())
  output_file = util.convert_comments_data(input_file)

  expected_df = pd.DataFrame([{
      'comment_text': 'comment 1',
      'toxicity': 0.0,
      'gender': [],
      'sexual_orientation': ['bisexual'],
      'race': ['other_race_or_ethnicity'],
      'religion': ['atheist', 'other_religion'],
      'disability': [
          'physical_disability', 'intellectual_or_learning_disability',
          'psychiatric_or_mental_illness', 'other_disability'
      ]
  }])

  # CSV round-trips list-valued columns as their string repr, so normalize
  # both sides to quote-stripped strings before comparing (the original
  # only stripped the quotes that read_csv injects from the actual df).
  df = pd.read_csv(output_file).astype(str).replace("'", '', regex=True)
  expected_df = expected_df.astype(str).replace("'", '', regex=True)
  self.assertEqual(
      df.reset_index(drop=True).to_dict(),
      expected_df.reset_index(drop=True).to_dict())
def test_convert_data_tfrecord(self):
  """Converts an example TFRecord and checks the converted contents."""
  input_file = self._write_tf_records(self._create_example_tfrecord())
  output_file = util.convert_comments_data(input_file)

  # Parse every serialized record the converter wrote back out.
  parsed_examples = []
  for raw_record in tf.data.TFRecordDataset(filenames=[output_file]):
    example = tf.train.Example()
    example.ParseFromString(raw_record.numpy())
    parsed_examples.append(example)

  self.assertEqual(len(parsed_examples), 1)
  expected_example = text_format.Parse(
      """
      features {
        feature { key: "comment_text" value { bytes_list {value: [ "comment 1" ] }} }
        feature { key: "toxicity" value { float_list { value: [ 0.0 ] }}}
        feature { key: "sexual_orientation" value { bytes_list { value: ["bisexual"] }} }
        feature { key: "gender" value { bytes_list { }}}
        feature { key: "race" value { bytes_list { value: [ "other_race_or_ethnicity" ] }} }
        feature { key: "religion" value { bytes_list { value: [ "atheist", "other_religion" ] } } }
        feature { key: "disability" value { bytes_list { value: [ "physical_disability", "intellectual_or_learning_disability", "psychiatric_or_mental_illness", "other_disability"] }} }
      }
      """, tf.train.Example())
  self.assertEqual(parsed_examples[0], expected_example)
# Downloads a file from a URL if it is not already in the cache using the `tf.keras.utils.get_file()` function. if download_original_data: train_tf_file = tf.keras.utils.get_file( 'train_tf.tfrecord', 'https://storage.googleapis.com/civil_comments_dataset/train_tf.tfrecord' ) validate_tf_file = tf.keras.utils.get_file( 'validate_tf.tfrecord', 'https://storage.googleapis.com/civil_comments_dataset/validate_tf.tfrecord' ) # The identity terms list will be grouped together by their categories # (see 'IDENTITY_COLUMNS') on threshould 0.5. Only the identity term column, # text column and label column will be kept after processing. train_tf_file = util.convert_comments_data(train_tf_file) validate_tf_file = util.convert_comments_data(validate_tf_file) # TODO 1a else: train_tf_file = tf.keras.utils.get_file( 'train_tf_processed.tfrecord', 'https://storage.googleapis.com/civil_comments_dataset/train_tf_processed.tfrecord' ) validate_tf_file = tf.keras.utils.get_file( 'validate_tf_processed.tfrecord', 'https://storage.googleapis.com/civil_comments_dataset/validate_tf_processed.tfrecord' ) # ### Use TFDV to generate statistics and Facets to visualize the data