コード例 #1
0
    def test_convert_data_csv_invalid_input_filename(self):
        """Checks that filenames without a supported extension are rejected.

        ``util.convert_comments_data`` must raise ``ValueError`` both for a
        path with no extension and for a path with an unsupported extension
        (``.mp4``). The test only builds the paths; it never creates the files.
        """
        error_pattern = r'.*must.*supported.*file.*extension'

        # One self-cleaning directory instead of two ``mkdtemp()`` calls, which
        # left empty directories behind on every run.
        with tempfile.TemporaryDirectory() as tmp_dir:
            filename_no_extension = os.path.join(tmp_dir, 'no_extension')
            filename_unsupported_extension = os.path.join(
                tmp_dir, 'unsupported_extension.mp4')

            with self.assertRaisesRegex(ValueError, error_pattern):
                _ = util.convert_comments_data(filename_no_extension)

            with self.assertRaisesRegex(ValueError, error_pattern):
                _ = util.convert_comments_data(filename_unsupported_extension)
コード例 #2
0
    def test_convert_data_csv(self):
        """Checks that a CSV input is converted into the expected processed row.

        The identity-term columns are read back from the processed CSV as
        stringified lists. After their single quotes are stripped they look
        like ``[atheist, other_religion]``, so the expected lists are rendered
        the same way before the rows are compared.
        """
        input_file = self._write_csv(self._create_example_csv())
        output_file = util.convert_comments_data(input_file)

        # Remove the quotes around the identity terms in the stringified lists.
        df = pd.read_csv(output_file).replace("'", '', regex=True)

        expected_row = {
            'comment_text': 'comment 1',
            'toxicity': 0.0,
            'gender': [],
            'sexual_orientation': ['bisexual'],
            'race': ['other_race_or_ethnicity'],
            'religion': ['atheist', 'other_religion'],
            'disability': [
                'physical_disability',
                'intellectual_or_learning_disability',
                'psychiatric_or_mental_illness', 'other_disability'
            ],
        }
        # Render list values the way they appear in ``df``: "[a, b]" / "[]".
        expected_row = {
            column: ('[' + ', '.join(value) + ']')
            if isinstance(value, list) else value
            for column, value in expected_row.items()
        }

        # Compare actual row contents. ``reset_index(inplace=True)`` returns
        # None, so comparing its results would always pass. Dict equality also
        # ignores column order.
        self.assertEqual(df.to_dict('records'), [expected_row])
コード例 #3
0
    def test_convert_data_tfrecord(self):
        """Checks that a TFRecord input is converted into the expected example.

        The converted file is read back record by record. It must contain
        exactly one ``tf.train.Example``, equal to the text-format proto below.
        """
        source_path = self._write_tf_records(self._create_example_tfrecord())
        converted_path = util.convert_comments_data(source_path)

        def _decode(raw_record):
            # Each dataset element is a scalar string tensor holding one
            # serialized Example.
            example = tf.train.Example()
            example.ParseFromString(raw_record.numpy())
            return example

        parsed_examples = [
            _decode(raw_record)
            for raw_record in tf.data.TFRecordDataset(filenames=[converted_path])
        ]

        self.assertEqual(len(parsed_examples), 1)

        expected_example = text_format.Parse(
            """
        features {
          feature { key: "comment_text"
                    value { bytes_list {value: [ "comment 1" ] }}
                  }
          feature { key: "toxicity" value { float_list { value: [ 0.0 ] }}}
          feature { key: "sexual_orientation"
                    value { bytes_list { value: ["bisexual"] }}
                  }
          feature { key: "gender" value { bytes_list { }}}
          feature { key: "race"
                    value { bytes_list { value: [ "other_race_or_ethnicity" ] }}
                  }
          feature { key: "religion"
                    value { bytes_list {
                      value: [  "atheist", "other_religion" ] }
                    }
                  }
          feature { key: "disability" value { bytes_list {
                    value: [
                      "physical_disability",
                      "intellectual_or_learning_disability",
                      "psychiatric_or_mental_illness",
                      "other_disability"] }}
                  }
        }
        """, tf.train.Example())
        self.assertEqual(parsed_examples[0], expected_example)
コード例 #4
0
# Fetch the Civil Comments TFRecord files with `tf.keras.utils.get_file()`,
# which downloads a URL only when the file is not already in the local cache.
if not download_original_data:
    # Fetch the `_processed` files directly; no local conversion step.
    train_tf_file = tf.keras.utils.get_file(
        fname='train_tf_processed.tfrecord',
        origin='https://storage.googleapis.com/civil_comments_dataset/train_tf_processed.tfrecord',
    )
    validate_tf_file = tf.keras.utils.get_file(
        fname='validate_tf_processed.tfrecord',
        origin='https://storage.googleapis.com/civil_comments_dataset/validate_tf_processed.tfrecord',
    )
else:
    # Fetch the raw files, then convert them locally.
    train_tf_file = tf.keras.utils.get_file(
        fname='train_tf.tfrecord',
        origin='https://storage.googleapis.com/civil_comments_dataset/train_tf.tfrecord',
    )
    validate_tf_file = tf.keras.utils.get_file(
        fname='validate_tf.tfrecord',
        origin='https://storage.googleapis.com/civil_comments_dataset/validate_tf.tfrecord',
    )

    # Group the identity terms by category (see 'IDENTITY_COLUMNS') using a
    # 0.5 threshold. Only the identity-term, text and label columns are kept.
    train_tf_file = util.convert_comments_data(train_tf_file)
    validate_tf_file = util.convert_comments_data(validate_tf_file)

    # TODO 1a

# ### Use TFDV to generate statistics and Facets to visualize the data