def testBulkFeatureIdExtractorOkWithMultipleFixedFeatures(self):
    """Builds the ID extractor with three non-embedded fixed features.

    The component spec declares fixed features "fixed1".."fixed3", each with
    embedding_dim -1 and size 1.  Both the greedy-training and the
    greedy-inference graphs are built; the test passes if neither raises.
    """
    spec = spec_pb2.ComponentSpec()
    # Adjacent literals concatenate into a single text-format proto string.
    spec_text = (
        ' name: "test"'
        ' network_unit { registered_name: "IdentityNetwork" }'
        ' fixed_feature { name: "fixed1" embedding_dim: -1 size: 1 }'
        ' fixed_feature { name: "fixed2" embedding_dim: -1 size: 1 }'
        ' fixed_feature { name: "fixed3" embedding_dim: -1 size: 1 }'
        ' '
    )
    text_format.Parse(spec_text, spec)

    with tf.Graph().as_default():
        extractor = bulk_component.BulkFeatureIdExtractorComponentBuilder(
            self.master, spec)

        # Neither build step should raise.  Each one starts from a fresh
        # NetworkState registered under the component's name.
        for build in (extractor.build_greedy_training,
                      extractor.build_greedy_inference):
            self.network_states[spec.name] = component.NetworkState()
            build(self.master_state, self.network_states)
def testBulkFeatureIdExtractorFailsOnEmbeddedFixedFeature(self):
    """Constructing the ID extractor with an embedded fixed feature fails.

    The component spec declares one fixed feature with embedding_dim 2;
    constructing BulkFeatureIdExtractorComponentBuilder from it is expected
    to raise ValueError.
    """
    component_spec = spec_pb2.ComponentSpec()
    # Adjacent literals concatenate into a single text-format proto string.
    text_format.Parse(
        (' name: "test"'
         ' network_unit { registered_name: "IdentityNetwork" }'
         ' fixed_feature { name: "fixed" embedding_dim: 2 size: 1 }'
         ' '),
        component_spec)

    # Construct inside a private graph, like the sibling tests do, so that
    # anything the constructor adds before raising does not end up in the
    # process-wide default graph.
    with tf.Graph().as_default():
        with self.assertRaises(ValueError):
            unused_comp = (
                bulk_component.BulkFeatureIdExtractorComponentBuilder(
                    self.master, component_spec))
def testBulkFeatureIdExtractorFailsOnLinkedFeature(self):
    """Constructing the ID extractor with a linked feature fails.

    The component spec declares one non-embedded fixed feature plus one
    linked feature sourced from component "mock"; constructing
    BulkFeatureIdExtractorComponentBuilder from it is expected to raise
    ValueError.
    """
    spec = spec_pb2.ComponentSpec()
    # Adjacent literals concatenate into a single text-format proto string.
    spec_text = (
        ' name: "test"'
        ' network_unit { registered_name: "IdentityNetwork" }'
        ' fixed_feature { name: "fixed" embedding_dim: -1 size: 1 }'
        ' linked_feature { name: "linked" embedding_dim: -1 size: 1'
        ' source_translator: "identity" source_component: "mock" }'
        ' '
    )
    text_format.Parse(spec_text, spec)

    # The graph scope is entered first, then the exception expectation.
    with tf.Graph().as_default(), self.assertRaises(ValueError):
        unused_builder = (
            bulk_component.BulkFeatureIdExtractorComponentBuilder(
                self.master, spec))
def testBulkFeatureIdExtractorExtractFocusWithOffset(self):
    """Extracts "input.focus" feature IDs at offsets 0, 1 and 2.

    Writes a one-line label map to the test temp dir, builds a master spec
    whose single component exports three fixed features ("focus1",
    "focus2", "focus3"), runs greedy inference over the fake corpus, and
    compares each feature's bulk tensor against a hard-coded matrix.
    """
    label_map_path = os.path.join(tf.test.get_temp_dir(), 'label-map')
    with open(label_map_path, 'w') as label_map_file:
        label_map_file.write('0\n')

    # Adjacent literals concatenate into a single text-format proto string;
    # the one "%s" placeholder receives the label-map path.
    spec_template = (
        ' component {'
        ' name: "test"'
        ' transition_system { registered_name: "shift-only" }'
        ' resource { name: "label-map"'
        ' part { file_pattern: "%s" file_format: "text" } }'
        ' network_unit { registered_name: "ExportFixedFeaturesNetwork" }'
        ' backend { registered_name: "SyntaxNetComponent" }'
        ' fixed_feature { name: "focus1" embedding_dim: -1 size: 1'
        ' fml: "input.focus" predicate_map: "none" }'
        ' fixed_feature { name: "focus2" embedding_dim: -1 size: 1'
        ' fml: "input(1).focus" predicate_map: "none" }'
        ' fixed_feature { name: "focus3" embedding_dim: -1 size: 1'
        ' fml: "input(2).focus" predicate_map: "none" }'
        ' }'
        ' '
    )
    master_spec = spec_pb2.MasterSpec()
    text_format.Parse(spec_template % label_map_path, master_spec)

    with tf.Graph().as_default():
        # Feed the fake corpus through a DRAGNN session handle.
        sentences = _create_fake_corpus()
        corpus = tf.constant(sentences, shape=[len(sentences)])
        handle = dragnn_ops.get_session(
            container='test',
            master_spec=master_spec.SerializeToString(),
            grid_point='')
        handle = dragnn_ops.attach_data_reader(handle, corpus)
        handle = dragnn_ops.init_component_data(
            handle, beam_size=1, component='test')
        batch_size = dragnn_ops.batch_size(handle, component='test')
        master_state = component.MasterState(handle, batch_size)

        # Build the extractor and collect the tensors it publishes.
        extractor = bulk_component.BulkFeatureIdExtractorComponentBuilder(
            self.master, master_spec.component[0])
        network_state = component.NetworkState()
        self.network_states['test'] = network_state
        handle = extractor.build_greedy_inference(master_state,
                                                  self.network_states)
        bulk_tensors = [
            network_state.activations[feature_name].bulk_tensor
            for feature_name in ('focus1', 'focus2', 'focus3')
        ]

        with self.test_session() as sess:
            focus1, focus2, focus3 = sess.run(bulk_tensors)
            tf.logging.info('focus1=\n%s', focus1)
            tf.logging.info('focus2=\n%s', focus2)
            tf.logging.info('focus3=\n%s', focus3)

            # NOTE(review): the 16 expected rows are laid out four per line
            # purely for readability; presumably each line corresponds to
            # one sentence of the fake corpus -- confirm against
            # _create_fake_corpus().
            expected_focus1 = [
                [0], [-1], [-1], [-1],
                [0], [1], [-1], [-1],
                [0], [1], [2], [-1],
                [0], [1], [2], [3],
            ]
            expected_focus2 = [
                [-1], [-1], [-1], [-1],
                [1], [-1], [-1], [-1],
                [1], [2], [-1], [-1],
                [1], [2], [3], [-1],
            ]
            expected_focus3 = [
                [-1], [-1], [-1], [-1],
                [-1], [-1], [-1], [-1],
                [2], [-1], [-1], [-1],
                [2], [3], [-1], [-1],
            ]
            self.assertAllEqual(focus1, expected_focus1)
            self.assertAllEqual(focus2, expected_focus2)
            self.assertAllEqual(focus3, expected_focus3)
def testPreCreateCalledBeforeCreate(self):
    """Checks that each bulk builder calls pre_create() before create().

    For every bulk component builder class, and for both the greedy-training
    and greedy-inference build methods, the builder's network is replaced
    with a mock whose create() asserts that pre_create() has already run.
    """
    component_spec = spec_pb2.ComponentSpec()
    # Adjacent literals concatenate into a single text-format proto string.
    text_format.Parse(
        (' name: "test"'
         ' network_unit { registered_name: "IdentityNetwork" }'
         ' '),
        component_spec)

    class AssertPreCreateBeforeCreateNetwork(
            network_units.NetworkUnitInterface):
        """Mock that asserts that .pre_create() is called before .create()."""

        def __init__(self, comp, test_fixture):
            super(AssertPreCreateBeforeCreateNetwork, self).__init__(comp)
            self._test_fixture = test_fixture
            self._pre_create_called = False

        def get_logits(self, network_tensors):
            return tf.zeros([2, 1], dtype=tf.float32)

        def pre_create(self, *unused_args):
            self._pre_create_called = True

        def create(self, *unused_args, **unused_kwargs):
            self._test_fixture.assertTrue(self._pre_create_called)
            return []

    builder_classes = (
        bulk_component.BulkFeatureExtractorComponentBuilder,
        bulk_component.BulkFeatureIdExtractorComponentBuilder,
        bulk_component.BulkAnnotatorComponentBuilder,
    )
    build_method_names = ('build_greedy_training', 'build_greedy_inference')

    is_first_case = True
    for builder_class in builder_classes:
        for build_method_name in build_method_names:
            # The test runner already ran setUp() before the first case;
            # every later case re-runs it so the fixture state is fresh.
            if not is_first_case:
                self.setUp()
            is_first_case = False

            builder = builder_class(self.master, component_spec)
            builder.network = AssertPreCreateBeforeCreateNetwork(
                builder, self)
            build = getattr(builder, build_method_name)
            build(component.MasterState(['foo', 'bar'], 2),
                  self.network_states)