def testFailsOnNonIdentityTranslator(self):
    """Bulk builders reject a linked feature that uses the "history" translator.

    The same spec is handed to the feature-extractor builder and then to the
    annotator builder, each inside its own fresh graph; for both,
    build_greedy_training() must raise NotImplementedError.
    """
    spec = spec_pb2.ComponentSpec()
    text_format.Parse("""
        name: "test"
        network_unit {
          registered_name: "IdentityNetwork"
        }
        linked_feature {
          name: "features" embedding_dim: -1 size: 1
          source_translator: "history"
          source_component: "mock"
        }
        """, spec)

    # Feature extraction first, then annotation. Constructing the builder sits
    # outside assertRaises: only the build step is expected to raise.
    for builder_cls in (bulk_component.BulkFeatureExtractorComponentBuilder,
                        bulk_component.BulkAnnotatorComponentBuilder):
      with tf.Graph().as_default():
        builder = builder_cls(self.master, spec)
        with self.assertRaises(NotImplementedError):
          builder.build_greedy_training(self.master_state, self.network_states)

  def testFailsOnRecurrentLinkedFeature(self):
    """Bulk builders reject a linked feature that points at its own component.

    The linked feature below reads "layer_0" from source_component "test",
    which is this component's own name, i.e. a recurrent link. Both the
    feature-extractor and the annotator builders are expected to raise
    RuntimeError from build_greedy_training(). The translator is "identity",
    so the expected failure comes from the recurrence, not the translator.
    """
    component_spec = spec_pb2.ComponentSpec()
    text_format.Parse("""
        name: "test"
        network_unit {
          registered_name: "FeedForwardNetwork"
          parameters {
            key: 'hidden_layer_sizes' value: '64'
          }
        }
        linked_feature {
          name: "features" embedding_dim: -1 size: 1
          source_translator: "identity"
          source_component: "test"
          source_layer: "layer_0"
        }
        """, component_spec)

    # Check feature extraction, then annotation, each in a fresh graph.
    # Construction stays outside assertRaises: only the build step should
    # raise, because the linked feature's source is this same component.
    for builder_cls in (bulk_component.BulkFeatureExtractorComponentBuilder,
                        bulk_component.BulkAnnotatorComponentBuilder):
      with tf.Graph().as_default():
        comp = builder_cls(self.master, component_spec)
        with self.assertRaises(RuntimeError):
          comp.build_greedy_training(self.master_state, self.network_states)

  def testFailsOnFixedFeature(self):
    """The bulk annotator rejects a component that declares a fixed feature."""
    component_spec = spec_pb2.ComponentSpec()
    text_format.Parse("""
        name: "annotate"
        network_unit {
          registered_name: "IdentityNetwork"
        }
        fixed_feature {
          name: "fixed" embedding_dim: 32 size: 1
        }
        """, component_spec)
    with tf.Graph().as_default():
      comp = bulk_component.BulkAnnotatorComponentBuilder(
          self.master, component_spec)

      # Expect building the annotator to raise a RuntimeError because of the
      # fixed feature.
      with self.assertRaises(RuntimeError):
        comp.build_greedy_training(self.master_state, self.network_states)

  def testPreCreateCalledBeforeCreate(self):
    """Every bulk builder calls network.pre_create() before network.create().

    A mock network unit records whether pre_create() ran and fails the test
    if create() is invoked first. The check is repeated for the training and
    inference graphs of the feature-extractor, feature-ID-extractor and
    annotator builders.
    """
    component_spec = spec_pb2.ComponentSpec()
    text_format.Parse(
        """
        name: "test"
        network_unit {
          registered_name: "IdentityNetwork"
        }
        """, component_spec)

    class AssertPreCreateBeforeCreateNetwork(
        network_units.NetworkUnitInterface):
      """Mock that asserts .pre_create() is called before .create()."""

      def __init__(self, comp, test_fixture):
        super(AssertPreCreateBeforeCreateNetwork, self).__init__(comp)
        self._test_fixture = test_fixture
        self._pre_create_called = False

      def get_logits(self, network_tensors):
        return tf.zeros([2, 1], dtype=tf.float32)

      def pre_create(self, *unused_args):
        self._pre_create_called = True

      def create(self, *unused_args, **unused_kwargs):
        # Fails the enclosing test if create() runs before pre_create().
        self._test_fixture.assertTrue(
            self._pre_create_called,
            'create() was called before pre_create()')
        return []

    # (builder class, name of the build method to exercise) pairs.
    cases = (
        (bulk_component.BulkFeatureExtractorComponentBuilder,
         'build_greedy_training'),
        (bulk_component.BulkFeatureExtractorComponentBuilder,
         'build_greedy_inference'),
        (bulk_component.BulkFeatureIdExtractorComponentBuilder,
         'build_greedy_training'),
        (bulk_component.BulkFeatureIdExtractorComponentBuilder,
         'build_greedy_inference'),
        (bulk_component.BulkAnnotatorComponentBuilder,
         'build_greedy_training'),
        (bulk_component.BulkAnnotatorComponentBuilder,
         'build_greedy_inference'),
    )
    for index, (builder_cls, build_method) in enumerate(cases):
      if index > 0:
        # Re-run setUp() before every build after the first, so each build
        # starts from a freshly set-up fixture (self.master and
        # self.network_states are read again below).
        self.setUp()
      builder = builder_cls(self.master, component_spec)
      builder.network = AssertPreCreateBeforeCreateNetwork(builder, self)
      getattr(builder, build_method)(
          component.MasterState(['foo', 'bar'], 2), self.network_states)