Example #1
  def _get_session_with_reader(self, enable_tracing):
    """Utility to create ComputeSession management ops.

    Creates a new ComputeSession handle and provides the following
    named nodes:

    ComputeSession/InputBatch -- a placeholder for attaching a string
      specification for AttachReader.
    ComputeSession/AttachReader -- the AttachReader op.

    Args:
      enable_tracing: bool, whether to enable tracing before attaching the data.

    Returns:
      handle: handle to a new ComputeSession returned by the AttachReader op.
      input_batch: InputBatch placeholder.
    """
    with tf.name_scope('ComputeSession'):
      input_batch = tf.placeholder(
          dtype=tf.string, shape=[None], name='InputBatch')

      # Get the ComputeSession and chain some essential ops.
      handle = self._get_compute_session()
      if enable_tracing:
        handle = dragnn_ops.set_tracing(handle, True)
      handle = dragnn_ops.attach_data_reader(
          handle, input_batch, name='AttachReader')

    return handle, input_batch
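For orientation, the docstring's claim about node names follows directly from `tf.name_scope`: every op created inside the scope is prefixed with `ComputeSession/`. A minimal, self-contained check of that naming behavior (a sketch assuming the TF1-style graph API used above, reachable via `tf.compat.v1` on modern TensorFlow):

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

with tf.Graph().as_default():
  with tf.name_scope('ComputeSession'):
    input_batch = tf.placeholder(
        dtype=tf.string, shape=[None], name='InputBatch')

  # The placeholder is addressable by the scoped name the docstring promises.
  print(input_batch.name)  # ComputeSession/InputBatch:0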
Example #2
    def _get_session_with_reader(self, enable_tracing):
        """Utility to create ComputeSession management ops.

        Creates a new ComputeSession handle and provides the following
        named nodes:

        ComputeSession/InputBatch -- a placeholder for attaching a string
          specification for AttachReader.
        ComputeSession/AttachReader -- the AttachReader op.

        Args:
          enable_tracing: bool, whether to enable tracing before attaching the data.

        Returns:
          handle: handle to a new ComputeSession returned by the AttachReader op.
          input_batch: InputBatch placeholder.
        """
        with tf.name_scope('ComputeSession'):
            input_batch = tf.placeholder(dtype=tf.string,
                                         shape=[None],
                                         name='InputBatch')

            # Get the ComputeSession and chain some essential ops.
            handle = self._get_compute_session()
            if enable_tracing:
                handle = dragnn_ops.set_tracing(handle, True)
            handle = dragnn_ops.attach_data_reader(handle,
                                                   input_batch,
                                                   name='AttachReader')

        return handle, input_batch
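Both variants share the handle-threading pattern: each `dragnn_ops` call consumes the previous handle tensor and returns a new one, so the stateful steps are ordered purely by data dependencies. A toy, self-contained illustration of that pattern, with plain TF ops standing in for `dragnn_ops` (the op names here are illustrative, not DRAGNN's):

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

with tf.Graph().as_default():
    # A string tensor stands in for the ComputeSession handle.
    handle = tf.constant('session-0')
    # Each step consumes the previous handle and yields a new one, which
    # forces the ops to run in sequence.
    handle = tf.identity(handle, name='SetTracing')
    handle = tf.identity(handle, name='AttachReader')

    with tf.Session() as sess:
        print(sess.run(handle))  # b'session-0'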
Example #3
  def testBulkFeatureIdExtractorExtractFocusWithOffset(self):
    path = os.path.join(tf.test.get_temp_dir(), 'label-map')
    with open(path, 'w') as label_map_file:
      label_map_file.write('0\n')

    master_spec = spec_pb2.MasterSpec()
    text_format.Parse("""
        component {
          name: "test"
          transition_system {
            registered_name: "shift-only"
          }
          resource {
            name: "label-map"
            part {
              file_pattern: "%s"
              file_format: "text"
            }
          }
          network_unit {
            registered_name: "ExportFixedFeaturesNetwork"
          }
          backend {
            registered_name: "SyntaxNetComponent"
          }
          fixed_feature {
            name: "focus1" embedding_dim: -1 size: 1 fml: "input.focus"
            predicate_map: "none"
          }
          fixed_feature {
            name: "focus2" embedding_dim: -1 size: 1 fml: "input(1).focus"
            predicate_map: "none"
          }
          fixed_feature {
            name: "focus3" embedding_dim: -1 size: 1 fml: "input(2).focus"
            predicate_map: "none"
          }
        }
        """ % path, master_spec)

    with tf.Graph().as_default():
      # Build the op chain: create a session, attach the fake corpus, and
      # initialize the 'test' component.
      corpus = _create_fake_corpus()
      corpus = tf.constant(corpus, shape=[len(corpus)])
      handle = dragnn_ops.get_session(
          container='test',
          master_spec=master_spec.SerializeToString(),
          grid_point='')
      handle = dragnn_ops.attach_data_reader(handle, corpus)
      handle = dragnn_ops.init_component_data(
          handle, beam_size=1, component='test')
      batch_size = dragnn_ops.batch_size(handle, component='test')
      master_state = component.MasterState(handle, batch_size)

      # Run bulk feature extraction and fetch the three focus feature tensors.
      extractor = bulk_component.BulkFeatureIdExtractorComponentBuilder(
          self.master, master_spec.component[0])
      network_state = component.NetworkState()
      self.network_states['test'] = network_state
      handle = extractor.build_greedy_inference(master_state,
                                                self.network_states)
      focus1 = network_state.activations['focus1'].bulk_tensor
      focus2 = network_state.activations['focus2'].bulk_tensor
      focus3 = network_state.activations['focus3'].bulk_tensor

      with self.test_session() as sess:
        focus1, focus2, focus3 = sess.run([focus1, focus2, focus3])
        tf.logging.info('focus1=\n%s', focus1)
        tf.logging.info('focus2=\n%s', focus2)
        tf.logging.info('focus3=\n%s', focus3)

        # Rows are grouped per sentence (lengths 1..4), each unrolled to the
        # longest length (4 steps); -1 marks a focus past the sentence end.
        self.assertAllEqual(
            focus1,
            [[0], [-1], [-1], [-1],
             [0], [1], [-1], [-1],
             [0], [1], [2], [-1],
             [0], [1], [2], [3]])

        self.assertAllEqual(
            focus2,
            [[-1], [-1], [-1], [-1],
             [1], [-1], [-1], [-1],
             [1], [2], [-1], [-1],
             [1], [2], [3], [-1]])

        self.assertAllEqual(
            focus3,
            [[-1], [-1], [-1], [-1],
             [-1], [-1], [-1], [-1],
             [2], [-1], [-1], [-1],
             [2], [3], [-1], [-1]])
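The helper `_create_fake_corpus` is not shown here. The expected matrices imply a corpus of four sentences with 1, 2, 3, and 4 tokens: each sentence is unrolled for four steps (the length of the longest), and a focus feature that falls past the end of its sentence comes back as -1. A hypothetical reconstruction of the helper, assuming syntaxnet's `sentence_pb2` Sentence proto with `word`/`start`/`end` token fields:

from syntaxnet import sentence_pb2  # assumed import path

def _create_fake_corpus():
  """Returns serialized sentences with 1..4 tokens (hypothetical sketch)."""
  corpus = []
  for num_tokens in range(1, 5):
    sentence = sentence_pb2.Sentence()
    sentence.text = 'x ' * num_tokens
    for i in range(num_tokens):
      # One single-character token per position; start/end are byte offsets.
      token = sentence.token.add()
      token.word = 'x'
      token.start = 2 * i
      token.end = 2 * i
    corpus.append(sentence.SerializeToString())
  return corpus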