def test_index_converts_field_correctly(self):
    """Indexing a TagField replaces each tag with its id from the '*tags' namespace.

    Also checks that the field reports three distinct tags after indexing.
    """
    # NOTE(review): another method named test_index_converts_field_correctly
    # appears later in this chunk. If both live in the same class, Python keeps
    # only the later definition and this one is never collected -- confirm,
    # then rename or delete one of them.
    vocab = Vocabulary()
    b_index, i_index, o_index = [
        vocab.add_token_to_namespace(label, namespace='*tags')
        for label in ("B", "I", "O")
    ]
    id_for = {"B": b_index, "I": i_index, "O": o_index}
    labels = ["B", "I", "O", "O", "O"]

    field = TagField(labels, self.text, tag_namespace="*tags")
    field.index(vocab)

    expected_ids = [id_for[label] for label in labels]
    # pylint: disable=protected-access
    assert field._indexed_tags == expected_ids
    assert field._num_tags == 3
def test_pad_produces_one_hot_targets(self):
    """as_array on an indexed TagField yields one one-hot row per tag.

    The columns follow the order in which B, I and O were added to '*tags'.
    """
    # NOTE(review): another method named test_pad_produces_one_hot_targets
    # appears later in this chunk. If both live in the same class, only the
    # later definition is collected and this one never runs -- confirm, then
    # rename or delete one of them.
    vocab = Vocabulary()
    for label in ("B", "I", "O"):
        vocab.add_token_to_namespace(label, namespace='*tags')

    field = TagField(["B", "I", "O", "O", "O"], self.text, tag_namespace="*tags")
    field.index(vocab)

    one_hot = field.as_array(field.get_padding_lengths())

    expected = numpy.array([
        [1, 0, 0],  # B
        [0, 1, 0],  # I
        [0, 0, 1],  # O
        [0, 0, 1],  # O
        [0, 0, 1],  # O
    ])
    numpy.testing.assert_array_almost_equal(one_hot, expected)
def test_pad_produces_one_hot_targets(self):
    """pad() on an indexed TagField yields one one-hot row per tag.

    Builds its own five-token TextField rather than relying on fixture state.
    """
    # NOTE(review): an earlier method in this chunk shares this name; if both
    # are in one class, this later definition is the one that gets collected.
    vocab = Vocabulary()
    for label in ("B", "I", "O"):
        vocab.add_token_to_namespace(label, namespace='*tags')

    words = "here are some words .".split()
    indexer = token_indexers["single id"]("words")
    sentence = TextField(words, [indexer])

    field = TagField(["B", "I", "O", "O", "O"], sentence, tag_namespace="*tags")
    field.index(vocab)

    one_hot = field.pad(field.get_padding_lengths())

    expected = numpy.array([
        [1, 0, 0],  # B
        [0, 1, 0],  # I
        [0, 0, 1],  # O
        [0, 0, 1],  # O
        [0, 0, 1],  # O
    ])
    numpy.testing.assert_array_almost_equal(one_hot, expected)
def test_index_converts_field_correctly(self):
    """Indexing a TagField replaces each tag with its id from the '*tags' namespace.

    Builds its own five-token TextField and also checks that three distinct
    tags are reported after indexing.
    """
    # NOTE(review): an earlier method in this chunk shares this name; if both
    # are in one class, this later definition is the one that gets collected.
    vocab = Vocabulary()
    b_index, i_index, o_index = [
        vocab.add_token_to_namespace(label, namespace='*tags')
        for label in ("B", "I", "O")
    ]
    id_for = {"B": b_index, "I": i_index, "O": o_index}

    words = "here are some words .".split()
    indexer = token_indexers["single id"]("words")
    sentence = TextField(words, [indexer])

    labels = ["B", "I", "O", "O", "O"]
    field = TagField(labels, sentence, tag_namespace="*tags")
    field.index(vocab)

    expected_ids = [id_for[label] for label in labels]
    # pylint: disable=protected-access
    assert field._indexed_tags == expected_ids
    assert field._num_tags == 3