# Example #1
  def testAddCellSubgraphSpecHook(self):
    """Checks that the CellSubgraphSpec hook serializes the exported cell."""
    component = MockComponent()

    # Build a cell spec with two inputs (one feature, one recurrent) and two
    # outputs.
    spec = export_pb2.CellSubgraphSpec()
    spec.input.add(
        name='feature',
        tensor='feature_tensor',
        type=export_pb2.CellSubgraphSpec.Input.TYPE_FEATURE)
    spec.input.add(
        name='recurrent',
        tensor='recurrent_tensor',
        type=export_pb2.CellSubgraphSpec.Input.TYPE_RECURRENT)
    spec.output.add(name='layer_0', tensor='layer_0_tensor')
    spec.output.add(name='logits', tensor='logits_tensor')

    with self.test_session() as session:
      # Add hooks for the cell constructed above.
      with tf.variable_scope(component.name, reuse=True):
        runtime_support.add_hooks(component, spec)

      # Fetch the hook tensor that holds the wire-format proto.
      hook_name = '{}/EXPORT/CellSubgraphSpec:0'.format(component.name)
      cell_wire_format = session.graph.get_tensor_by_name(hook_name)

      # The hook must evaluate to exactly the serialized cell spec.
      tf.global_variables_initializer().run()
      self.assertEqual(cell_wire_format.eval(), spec.SerializeToString())
# Example #2
  def testAddDerivedParamHooks(self):
    """Checks the hooks added for derived parameters.

    Verifies the derived vector hook's static shape and that the blocked
    matrix hook (and its bfloat16 variant) have matching runtime shapes.
    """
    component = MockComponent()
    derived_name = 'derived'

    with self.test_session() as session:
      graph = session.graph

      # Add hooks.
      with tf.variable_scope(component.name, reuse=True):
        runtime_support.add_hooks(component, export_pb2.CellSubgraphSpec())

      session.run(tf.global_variables_initializer())

      # Get hooks for the derived vector.  The hooks are created inside the
      # component's variable scope above, so — consistently with every other
      # hook lookup in these tests — the tensor name must carry the
      # component.name prefix (the previous hard-coded 'derived/vector:0'
      # omitted it).
      vector = graph.get_tensor_by_name(
          '{}/{}/vector:0'.format(component.name, derived_name))
      self.assertEqual(vector.shape, (3,))

      # Get the hooks for the derived variable.
      matrix = graph.get_tensor_by_name(
          '{}/{}/matrix/blocked32:0'.format(component.name, derived_name))
      self.assertAllEqual(tf.shape(matrix).eval(), [4, 128, 32])

      # Check the bfloat16 version.  It should have the same shape.
      bfloat16_matrix = graph.get_tensor_by_name(
          '{}/{}/matrix/blocked32/bfloat16:0'.format(component.name,
                                                     derived_name))
      self.assertAllEqual(tf.shape(bfloat16_matrix).eval(), [4, 128, 32])
# Example #3
  def testAddFixedHooks(self):
    """Checks hooks for fixed features; non-embedded channels get none."""
    component = MockComponent()

    # Channel 0 is non-embedded (dim=-1); channel 1 has a 32-dim embedding.
    channel0 = component.spec.fixed_feature.add()
    channel1 = component.spec.fixed_feature.add()
    channel0.embedding_dim = -1
    channel0.vocabulary_size = 100
    channel1.embedding_dim = 32
    channel1.vocabulary_size = 1000
    matrix_name0 = network_units.fixed_embeddings_name(0)
    matrix_name1 = network_units.fixed_embeddings_name(1)

    with self.test_session() as session:
      graph = session.graph

      # Create fixed embedding matrices.  Only channel 1 uses one.
      with tf.variable_scope(component.name):
        tf.get_variable(
            matrix_name1, shape=[1000 + 1, 32], dtype=tf.float32)

      # Add hooks.  This should ignore channel 0 and add hooks for channel 1.
      with tf.variable_scope(component.name, reuse=True):
        runtime_support.add_hooks(component, export_pb2.CellSubgraphSpec())

      # No "trimmed" hook should exist for the non-embedded channel 0.
      with self.assertRaises(KeyError):
        graph.get_tensor_by_name(
            '{}/{}/trimmed:0'.format(component.name, matrix_name0))

      # Channel 1 should have a "trimmed" hook.
      trimmed = graph.get_tensor_by_name(
          '{}/{}/trimmed:0'.format(component.name, matrix_name1))

      # The trimmed matrix drops the extra row (1001 -> 1000).
      tf.global_variables_initializer().run()
      self.assertAllEqual(tf.shape(trimmed).eval(), [1000, 32])
# Example #4
    def _add_runtime_hooks(self):
        """Adds "hook" nodes to the graph for use by the runtime, if enabled.

    Does nothing if master.build_runtime_graph is False.  Subclasses should call
    this at the end of build_*_inference().  For details on the runtime hooks,
    see runtime_support.py.
    """
        # Guard clause: runtime hooks are opt-in via the master config.
        if not self.master.build_runtime_graph:
            return
        with tf.variable_scope(self.name, reuse=True):
            runtime_support.add_hooks(self, self._cell_subgraph_spec)
        # Clear the spec so the cell cannot be exported a second time.
        self._cell_subgraph_spec = None
# Example #5
  def testAddLinkedHooks(self):
    """Checks hooks for linked features; direct links get no hooks."""
    component = MockComponent()

    # Channel 0 is a direct link; channel 1 is a transformed link.
    channel0 = component.spec.linked_feature.add()
    channel1 = component.spec.linked_feature.add()
    channel0.embedding_dim = -1  # direct link
    channel1.embedding_dim = 32  # transformed link
    matrix_name0 = network_units.linked_embeddings_name(0)
    matrix_name1 = network_units.linked_embeddings_name(1)

    # Tensor-name suffixes of the hooks a transformed link should receive.
    hook_suffixes = [
        'weights',
        'weights/transposed',
        'weights/transposed/shape',
        'weights/transposed/blocked32',
        'weights/transposed/blocked48',
        'out_of_bounds',
    ]

    def _hook_name(matrix_name, suffix):
      return '{}/{}/{}:0'.format(component.name, matrix_name, suffix)

    with self.test_session() as session:
      graph = session.graph

      # Create linked embedding matrices.  Only channel 1 uses one.
      with tf.variable_scope(component.name):
        tf.get_variable(
            matrix_name1, shape=[64 + 1, 32], dtype=tf.float32)

      # Add hooks.  This should ignore channel 0 and add hooks for channel 1.
      with tf.variable_scope(component.name, reuse=True):
        runtime_support.add_hooks(component, export_pb2.CellSubgraphSpec())

      # Check that no hooks were added for channel 0.
      for suffix in hook_suffixes:
        with self.assertRaises(KeyError):
          graph.get_tensor_by_name(_hook_name(matrix_name0, suffix))

      # Get the hooks added for channel 1, keyed by suffix.
      hooks = {
          suffix: graph.get_tensor_by_name(_hook_name(matrix_name1, suffix))
          for suffix in hook_suffixes
      }

      # Check dimensions of the hooks.
      tf.global_variables_initializer().run()
      self.assertAllEqual(tf.shape(hooks['weights']).eval(), [64, 32])
      self.assertAllEqual(
          tf.shape(hooks['weights/transposed']).eval(), [32, 64])
      self.assertAllEqual(hooks['weights/transposed/shape'].eval(), [32, 64])
      self.assertAllEqual(
          tf.shape(hooks['weights/transposed/blocked32']).eval(), [2, 32, 32])
      self.assertAllEqual(
          tf.shape(hooks['weights/transposed/blocked48']).eval(), [2, 32, 48])
      self.assertAllEqual(tf.shape(hooks['out_of_bounds']).eval(), [1, 32])
# Example #6
  def testAddParamsHooks(self):
    """Checks hooks for rank-2 params; rank-3 params get no hooks."""
    component = MockComponent()
    rank2_name = 'rank2'
    rank3_name = 'rank3'

    # Tensor-name suffixes of the hooks a rank-2 parameter should receive.
    hook_suffixes = [
        'matrix',
        'transposed',
        'matrix/blocked32',
        'matrix/blocked48',
        'transposed/blocked32',
        'transposed/blocked48',
        'matrix/shape',
        'transposed/shape',
    ]

    def _hook_name(param_name, suffix):
      return '{}/{}/{}:0'.format(component.name, param_name, suffix)

    with self.test_session() as session:
      graph = session.graph

      # Add hooks.  This should add hooks for all rank-2 params.
      with tf.variable_scope(component.name, reuse=True):
        runtime_support.add_hooks(component, export_pb2.CellSubgraphSpec())

      # Check that no hooks were added for the rank-3 params.
      for suffix in hook_suffixes:
        with self.assertRaises(KeyError):
          graph.get_tensor_by_name(_hook_name(rank3_name, suffix))

      # Get the hooks added for each variable, keyed by suffix.
      hooks = {
          suffix: graph.get_tensor_by_name(_hook_name(rank2_name, suffix))
          for suffix in hook_suffixes
      }

      # Check dimensions of the hooks.
      tf.global_variables_initializer().run()
      self.assertAllEqual(tf.shape(hooks['matrix']).eval(), [64, 127])
      self.assertAllEqual(tf.shape(hooks['transposed']).eval(), [127, 64])
      self.assertAllEqual(hooks['matrix/shape'].eval(), [64, 127])
      self.assertAllEqual(hooks['transposed/shape'].eval(), [127, 64])
      self.assertAllEqual(
          tf.shape(hooks['matrix/blocked32']).eval(), [4, 64, 32])
      self.assertAllEqual(
          tf.shape(hooks['matrix/blocked48']).eval(), [3, 64, 48])
      self.assertAllEqual(
          tf.shape(hooks['transposed/blocked32']).eval(), [2, 127, 32])
      self.assertAllEqual(
          tf.shape(hooks['transposed/blocked48']).eval(), [2, 127, 48])