コード例 #1
0
    def setup(self) -> None:
        """Prepares an instance for combining."""
        if self._schema is not None:
            for k, v in schema_util.get_all_leaf_features(self._schema):
                if v.WhichOneof('domain_info') == _NL_DOMAIN:
                    nld = v.natural_language_domain
                    self._nld_vocabularies[k] = nld.vocabulary
                    coverage_constraints = nld.coverage
                    self._nld_excluded_string_tokens[k] = set(
                        coverage_constraints.excluded_string_tokens)
                    self._nld_excluded_int_tokens[k] = set(
                        coverage_constraints.excluded_int_tokens)
                    self._nld_oov_string_tokens[k] = set(
                        coverage_constraints.oov_string_tokens)
                    if (self._nld_vocabularies[k]
                            or self._nld_excluded_string_tokens[k]
                            or self._nld_excluded_int_tokens[k]
                            or self._nld_oov_string_tokens[k]):
                        self._valid_feature_paths.add(k)
                    for t in nld.token_constraints:
                        if t.WhichOneof('value') == _INT_VALUE:
                            self._nld_specified_int_tokens[k].add(t.int_value)
                        else:
                            self._nld_specified_str_tokens[k].add(
                                t.string_value)

        if self._vocab_paths is not None:
            for k, v in self._vocab_paths.items():
                self._vocabs[k], self._rvocabs[k] = vocab_util.load_vocab(v)
コード例 #2
0
    def test_text_file(self):
        """Checks `load_vocab` on a plain text file with one token per line.

        The two tokens are expected to be indexed by line order, and the
        second returned mapping is expected to be the inverse of the first.
        """
        expected_vocab = {'Foo': 0, 'Bar': 1}
        expected_reverse_vocab = {0: 'Foo', 1: 'Bar'}

        with tempfile.NamedTemporaryFile() as vocab_file:
            vocab_file.write(b'Foo\nBar\n')
            # Push the bytes to disk so the loader sees them via the path.
            vocab_file.flush()

            token_to_index, index_to_token = vocab_util.load_vocab(
                vocab_file.name)
            self.assertEqual(token_to_index, expected_vocab)
            self.assertEqual(index_to_token, expected_reverse_vocab)
コード例 #3
0
    def test_gz_recordio_file(self):
        """Checks `load_vocab` on a GZIP-compressed TFRecord file.

        Two records, b'Foo' and b'Bar', are written in that order. The loaded
        vocabulary is expected to index them by record order, and the second
        returned mapping is expected to be the inverse of the first.
        """
        with tempfile.NamedTemporaryFile(suffix='.tfrecord.gz') as f:
            # Use the writer as a context manager so it is closed (and its
            # compressed output finalized) before the file is read back,
            # instead of leaving the writer open for the rest of the test.
            with tf.io.TFRecordWriter(f.name, options='GZIP') as writer:
                for element in [b'Foo', b'Bar']:
                    writer.write(element)

            vocab, reverse_vocab = vocab_util.load_vocab(f.name)
            self.assertEqual(vocab, {'Foo': 0, 'Bar': 1})
            self.assertEqual(reverse_vocab, {0: 'Foo', 1: 'Bar'})
コード例 #4
0
    def test_gz_recordio_file(self):
        """Checks `load_vocab` on a GZIP TFRecord file written via tf.data.

        A two-element dataset ('Foo', 'Bar') is written to a temporary
        `.tfrecord.gz` file. The loaded vocabulary is expected to index the
        tokens by their order, with the second mapping being its inverse.
        """
        expected_vocab = {'Foo': 0, 'Bar': 1}
        expected_reverse_vocab = {0: 'Foo', 1: 'Bar'}

        with tempfile.NamedTemporaryFile(suffix='.tfrecord.gz') as record_file:
            tokens = tf.data.Dataset.from_tensor_slices(['Foo', 'Bar'])
            record_writer = tf.data.experimental.TFRecordWriter(
                record_file.name, compression_type='GZIP')
            record_writer.write(tokens)
            record_file.flush()

            token_to_index, index_to_token = vocab_util.load_vocab(
                record_file.name)
            self.assertEqual(token_to_index, expected_vocab)
            self.assertEqual(index_to_token, expected_reverse_vocab)