def testBulkFeatureIdExtractorOkWithMultipleFixedFeatures(self):
    """Builds an ID extractor whose spec has three id-only fixed features.

    Greedy training and then greedy inference are built, each against a
    freshly created NetworkState; the test passes if neither call raises.
    """
    spec_text = """
        name: "test"
        network_unit {
          registered_name: "IdentityNetwork"
        }
        fixed_feature {
          name: "fixed1" embedding_dim: -1 size: 1
        }
        fixed_feature {
          name: "fixed2" embedding_dim: -1 size: 1
        }
        fixed_feature {
          name: "fixed3" embedding_dim: -1 size: 1
        }
        """
    component_spec = spec_pb2.ComponentSpec()
    text_format.Parse(spec_text, component_spec)

    with tf.Graph().as_default():
      comp = bulk_component.BulkFeatureIdExtractorComponentBuilder(
          self.master, component_spec)

      # Training first, then inference; each gets a clean NetworkState under
      # the component's name. Success simply means nothing was raised.
      for build in (comp.build_greedy_training, comp.build_greedy_inference):
        self.network_states[component_spec.name] = component.NetworkState()
        build(self.master_state, self.network_states)
 def testBulkFeatureIdExtractorFailsOnEmbeddedFixedFeature(self):
   """Constructing an ID extractor with an embedded fixed feature fails.

   The spec's only fixed feature requests `embedding_dim: 2` rather than -1,
   and BulkFeatureIdExtractorComponentBuilder is expected to reject it with
   ValueError. Construction runs inside a fresh graph so that any ops or
   variables the constructor may create before failing do not end up in the
   process-wide default graph shared with other tests.
   """
   component_spec = spec_pb2.ComponentSpec()
   text_format.Parse("""
       name: "test"
       network_unit {
         registered_name: "IdentityNetwork"
       }
       fixed_feature {
         name: "fixed" embedding_dim: 2 size: 1
       }
       """, component_spec)
   with tf.Graph().as_default():
     with self.assertRaises(ValueError):
       unused_comp = bulk_component.BulkFeatureIdExtractorComponentBuilder(
           self.master, component_spec)
 def testBulkFeatureIdExtractorFailsOnLinkedFeature(self):
   """Constructing an ID extractor from a spec with a linked feature fails.

   Besides one id-only fixed feature, the spec declares a linked feature
   (sourced from a "mock" component); the builder constructor is expected to
   raise ValueError. Construction happens inside a fresh graph.
   """
   spec_text = """
       name: "test"
       network_unit {
         registered_name: "IdentityNetwork"
       }
       fixed_feature {
         name: "fixed" embedding_dim: -1 size: 1
       }
       linked_feature {
         name: "linked" embedding_dim: -1 size: 1
         source_translator: "identity"
         source_component: "mock"
       }
       """
   component_spec = spec_pb2.ComponentSpec()
   text_format.Parse(spec_text, component_spec)

   # A single `with` over two managers is equivalent to nesting them.
   with tf.Graph().as_default(), self.assertRaises(ValueError):
     bulk_component.BulkFeatureIdExtractorComponentBuilder(
         self.master, component_spec)
  def testBulkFeatureIdExtractorExtractFocusWithOffset(self):
    """Extracts `input(k).focus` ids in bulk for offsets k = 0, 1 and 2.

    A shift-only SyntaxNet component is run over the fake corpus, and the
    bulk id tensors exported for three fixed features (focus1..focus3, whose
    FML reads the focus at increasing input offsets) are compared against
    hand-written expectations.
    """
    # File backing the "label-map" resource declared in the spec below.
    label_map_path = os.path.join(tf.test.get_temp_dir(), 'label-map')
    with open(label_map_path, 'w') as label_map:
      label_map.write('0\n')

    spec_text = """
        component {
          name: "test"
          transition_system {
            registered_name: "shift-only"
          }
          resource {
            name: "label-map"
            part {
              file_pattern: "%s"
              file_format: "text"
            }
          }
          network_unit {
            registered_name: "ExportFixedFeaturesNetwork"
          }
          backend {
            registered_name: "SyntaxNetComponent"
          }
          fixed_feature {
            name: "focus1" embedding_dim: -1 size: 1 fml: "input.focus"
            predicate_map: "none"
          }
          fixed_feature {
            name: "focus2" embedding_dim: -1 size: 1 fml: "input(1).focus"
            predicate_map: "none"
          }
          fixed_feature {
            name: "focus3" embedding_dim: -1 size: 1 fml: "input(2).focus"
            predicate_map: "none"
          }
        }
        """ % label_map_path
    master_spec = spec_pb2.MasterSpec()
    text_format.Parse(spec_text, master_spec)

    with tf.Graph().as_default():
      sentences = _create_fake_corpus()
      corpus_tensor = tf.constant(sentences, shape=[len(sentences)])

      # Session set-up: create, attach the corpus, then initialise the
      # component's data before querying the batch size.
      session_handle = dragnn_ops.get_session(
          container='test',
          master_spec=master_spec.SerializeToString(),
          grid_point='')
      session_handle = dragnn_ops.attach_data_reader(session_handle,
                                                     corpus_tensor)
      session_handle = dragnn_ops.init_component_data(
          session_handle, beam_size=1, component='test')
      master_state = component.MasterState(
          session_handle,
          dragnn_ops.batch_size(session_handle, component='test'))

      extractor = bulk_component.BulkFeatureIdExtractorComponentBuilder(
          self.master, master_spec.component[0])
      network_state = component.NetworkState()
      self.network_states['test'] = network_state
      # Only the activations exported into network_state are inspected; the
      # handle returned by the build call is not needed.
      extractor.build_greedy_inference(master_state, self.network_states)
      bulk_tensors = [
          network_state.activations[feature_name].bulk_tensor
          for feature_name in ('focus1', 'focus2', 'focus3')
      ]

      with self.test_session() as sess:
        focus1, focus2, focus3 = sess.run(bulk_tensors)
        tf.logging.info('focus1=\n%s', focus1)
        tf.logging.info('focus2=\n%s', focus2)
        tf.logging.info('focus3=\n%s', focus3)

        # Rows come in groups of four. Presumably each group is one sentence
        # of the fake corpus (lengths 1..4) and each row one transition step;
        # -1 marks a focus that falls outside the sentence.
        # NOTE(review): confirm this layout against _create_fake_corpus().
        self.assertAllEqual(
            focus1,
            [[0], [-1], [-1], [-1],
             [0], [1], [-1], [-1],
             [0], [1], [2], [-1],
             [0], [1], [2], [3]])

        self.assertAllEqual(
            focus2,
            [[-1], [-1], [-1], [-1],
             [1], [-1], [-1], [-1],
             [1], [2], [-1], [-1],
             [1], [2], [3], [-1]])

        self.assertAllEqual(
            focus3,
            [[-1], [-1], [-1], [-1],
             [-1], [-1], [-1], [-1],
             [2], [-1], [-1], [-1],
             [2], [3], [-1], [-1]])
    def testPreCreateCalledBeforeCreate(self):
        """Checks that bulk builders call network.pre_create() before create().

        BulkFeatureExtractorComponentBuilder,
        BulkFeatureIdExtractorComponentBuilder and BulkAnnotatorComponentBuilder
        are each exercised in greedy-training and greedy-inference mode, with a
        mock network unit whose create() fails the test if pre_create() has not
        already run on it.
        """
        component_spec = spec_pb2.ComponentSpec()
        text_format.Parse(
            """
        name: "test"
        network_unit {
          registered_name: "IdentityNetwork"
        }
        """, component_spec)

        class AssertPreCreateBeforeCreateNetwork(
                network_units.NetworkUnitInterface):
            """Mock that asserts that .pre_create() is called before .create()."""

            def __init__(self, comp, test_fixture):
                super(AssertPreCreateBeforeCreateNetwork, self).__init__(comp)
                self._test_fixture = test_fixture
                self._pre_create_called = False

            def get_logits(self, network_tensors):
                return tf.zeros([2, 1], dtype=tf.float32)

            def pre_create(self, *unused_args):
                self._pre_create_called = True

            def create(self, *unused_args, **unused_kwargs):
                self._test_fixture.assertTrue(self._pre_create_called)
                return []

        # Every (builder class, build method) pair under test, in run order.
        cases = (
            (bulk_component.BulkFeatureExtractorComponentBuilder,
             'build_greedy_training'),
            (bulk_component.BulkFeatureExtractorComponentBuilder,
             'build_greedy_inference'),
            (bulk_component.BulkFeatureIdExtractorComponentBuilder,
             'build_greedy_training'),
            (bulk_component.BulkFeatureIdExtractorComponentBuilder,
             'build_greedy_inference'),
            (bulk_component.BulkAnnotatorComponentBuilder,
             'build_greedy_training'),
            (bulk_component.BulkAnnotatorComponentBuilder,
             'build_greedy_inference'),
        )
        for index, (builder_cls, build_method_name) in enumerate(cases):
            if index > 0:
                # The first case uses the fixture the test runner already set
                # up; later cases re-run setUp() so they start from fresh
                # fixture state (presumably rebuilding self.master and
                # self.network_states — confirm against setUp()).
                self.setUp()
            builder = builder_cls(self.master, component_spec)
            builder.network = AssertPreCreateBeforeCreateNetwork(builder, self)
            build = getattr(builder, build_method_name)
            build(component.MasterState(['foo', 'bar'], 2), self.network_states)