Exemplo n.º 1
0
 def testTags(self):
     """Check that inputs sharing a tag end up under the same input index.

     Two inputs are registered with tag "mytag", with one untagged input
     added between them, and a fourth input uses tag "mytag2". Only the
     first three are asserted on; the "mytag2" input is registered but its
     indices are not checked here.
     """
     stack_mode = op_hint.OpHint.AGGREGATE_STACK
     const_a = array_ops.constant([1.])
     const_b = array_ops.constant([2.])
     const_c = array_ops.constant([3.])
     const_d = array_ops.constant([4.])
     hint = op_hint.OpHint("test_tag")

     # NOTE(review): the expected indices below presumably depend on this
     # registration order -- keep it unchanged.
     tagged_a = hint.add_input(const_a, tag="mytag", aggregate=stack_mode)
     (plain_b,) = hint.add_inputs(const_b)
     tagged_c = hint.add_input(const_c, tag="mytag", aggregate=stack_mode)
     _unchecked_d = hint.add_input(
         const_d, tag="mytag2", aggregate=stack_mode)

     left_product = math_ops.mul(tagged_a, plain_b)
     right_product = math_ops.mul(tagged_c, plain_b)
     total = math_ops.add(left_product, right_product)
     hint.add_outputs([total])

     with self.cached_session():
         # Both "mytag" inputs report input index 0 and differ only in
         # their sort index; the untagged input reports input index 1.
         self.assertEqual(self._get_input_index(tagged_a), 0)
         self.assertEqual(self._get_sort_index(tagged_a), 0)
         self.assertEqual(self._get_input_index(plain_b), 1)
         self.assertEqual(self._get_input_index(tagged_c), 0)
         self.assertEqual(self._get_sort_index(tagged_c), 1)
Exemplo n.º 2
0
  def testAggregate(self):
    """Check the op types left after stubbing a hint with stacked inputs/outputs.

    Two 2-element constants are unstacked, their elements are registered as
    inputs under tags "c" and "n", and the two element-wise sums are
    registered as outputs under tag "out". After converting hints to stubs,
    the op types reachable from the final identity must be exactly
    {"agg", "Const", "Identity"}.
    """
    stack_mode = op_hint.OpHint.AGGREGATE_STACK
    first = array_ops.constant([3., 4.])
    second = array_ops.constant([5., 6.])
    agg_hint = op_hint.OpHint("agg")
    first_lo, first_hi = array_ops.unstack(first)
    second_lo, second_hi = array_ops.unstack(second)

    # Inputs are registered interleaved: tag "c", tag "n", tag "c", tag "n".
    first_lo = agg_hint.add_input(first_lo, tag="c", aggregate=stack_mode)
    second_lo = agg_hint.add_input(second_lo, tag="n", aggregate=stack_mode)
    first_hi = agg_hint.add_input(first_hi, tag="c", aggregate=stack_mode)
    second_hi = agg_hint.add_input(second_hi, tag="n", aggregate=stack_mode)

    sum_lo = math_ops.add(first_lo, second_lo, name="addleft")
    sum_hi = math_ops.add(first_hi, second_hi, name="addright")
    sum_lo = agg_hint.add_output(sum_lo, tag="out", aggregate=stack_mode)
    sum_hi = agg_hint.add_output(sum_hi, tag="out", aggregate=stack_mode)

    stacked = array_ops.stack([sum_lo, sum_hi])
    final = array_ops.identity(stacked, name="FINAL_OUTPUT")
    with self.cached_session() as sess:
      stubbed = op_hint.convert_op_hints_to_stubs(graph_def=sess.graph_def)
      final_node_name = op_hint._tensor_name_base(final.name)
      found_op_types = self._getGraphOpTypes(
          stubbed, output_nodes=[final_node_name])
      self.assertEqual(found_op_types, {"agg", "Const", "Identity"})
Exemplo n.º 3
0
 def testOverrideIndex(self):
     """Check input indices when one input uses `index_override`.

     Three inputs are registered in the order b, a, c, where only `a` passes
     `index_override=1`. The expected input indices are b -> 0, a -> 1,
     c -> 2.
     """
     const_a = array_ops.constant([1.])
     const_b = array_ops.constant([2.])
     const_c = array_ops.constant([3.])
     hint = op_hint.OpHint("test_override")

     hinted_b = hint.add_input(const_b)  # should auto assign 0
     hinted_a = hint.add_input(const_a, index_override=1)
     hinted_c = hint.add_input(const_c)  # should auto assign 2

     expected_indices = (
         (hinted_a, 1),
         (hinted_b, 0),
         (hinted_c, 2),
     )
     with self.cached_session():
         for tensor, expected in expected_indices:
             self.assertEqual(self._get_input_index(tensor), expected)
Exemplo n.º 4
0
 def _double_values(x):
     """Wrap `x * x` in an OpHint named "add_test" and return the hinted output.

     Note that, despite the function name, the body multiplies `x` by itself
     rather than by two.
     """
     hint = op_hint.OpHint("add_test")
     (hinted_x,) = hint.add_inputs(x)
     product = math_ops.multiply(hinted_x, hinted_x)
     (hinted_product,) = hint.add_outputs(product)
     return hinted_product
Exemplo n.º 5
0
 def _scaled_and_bias_and_identity(a, x, b):
     """Wrap `a * x + b` and `x` in an OpHint with two outputs.

     All three arguments are registered as inputs of the hint named
     "scale_and_bias_and_identity". The return value is whatever
     `add_outputs` returns for the pair (a * x + b, x).
     """
     hint = op_hint.OpHint("scale_and_bias_and_identity")
     hinted_a, hinted_x, hinted_b = hint.add_inputs(a, x, b)
     scaled = hinted_a * hinted_x
     scaled_and_biased = scaled + hinted_b
     return hint.add_outputs(scaled_and_biased, hinted_x)
Exemplo n.º 6
0
 def _swish(input_tensor, scale):
     """Wrap `sigmoid(input_tensor) * input_tensor * scale` in an OpHint.

     Both arguments are registered as inputs of the hint named
     "cool_activation"; the single hinted output is returned.
     """
     hint = op_hint.OpHint("cool_activation")
     hinted_input, hinted_scale = hint.add_inputs(input_tensor, scale)
     gate = math_ops.sigmoid(hinted_input)
     # Same left-to-right multiplication order as the one-line expression.
     gated = gate * hinted_input
     activation = gated * hinted_scale
     (hinted_activation,) = hint.add_outputs(activation)
     return hinted_activation