def testTags(self):
  """Check that hint inputs sharing a tag are grouped under one index.

  `a` and `c` are both added with tag "mytag" (stack aggregation) and are
  expected to share input index 0, with `a` at sort index 0 and `c` at sort
  index 1. The untagged `b` is expected to receive input index 1.
  """
  stack = op_hint.OpHint.AGGREGATE_STACK
  const_a = array_ops.constant([1.])
  const_b = array_ops.constant([2.])
  const_c = array_ops.constant([3.])
  const_d = array_ops.constant([4.])
  hint = op_hint.OpHint("test_tag")
  hinted_a = hint.add_input(const_a, tag="mytag", aggregate=stack)
  (hinted_b,) = hint.add_inputs(const_b)
  hinted_c = hint.add_input(const_c, tag="mytag", aggregate=stack)
  # Registered under a second, distinct tag; its result is not used below.
  hint.add_input(const_d, tag="mytag2", aggregate=stack)
  left_product = math_ops.mul(hinted_a, hinted_b)
  right_product = math_ops.mul(hinted_c, hinted_b)
  hint.add_outputs([math_ops.add(left_product, right_product)])
  expectations = (
      (self._get_input_index, hinted_a, 0),
      (self._get_sort_index, hinted_a, 0),
      (self._get_input_index, hinted_b, 1),
      (self._get_input_index, hinted_c, 0),
      (self._get_sort_index, hinted_c, 1),
  )
  with self.cached_session():
    for getter, tensor, expected in expectations:
      self.assertEqual(getter(tensor), expected)
def testAggregate(self):
  """Check that stacked hint arguments are replaced by a single stub op.

  The two scalar slices of each constant are fed to the "agg" hint under
  stacked tags ("c" and "n"), added pairwise and emitted as stacked "out"
  outputs. After `convert_op_hints_to_stubs`, `_getGraphOpTypes` for the
  FINAL_OUTPUT node must report exactly the "agg", "Const" and "Identity"
  op types.
  """
  stack = op_hint.OpHint.AGGREGATE_STACK
  lhs = array_ops.constant([3., 4.])
  rhs = array_ops.constant([5., 6.])
  hint = op_hint.OpHint("agg")
  lhs_0, lhs_1 = array_ops.unstack(lhs)
  rhs_0, rhs_1 = array_ops.unstack(rhs)
  # Interleave "c" and "n" registrations in the same order as the slices.
  lhs_0 = hint.add_input(lhs_0, tag="c", aggregate=stack)
  rhs_0 = hint.add_input(rhs_0, tag="n", aggregate=stack)
  lhs_1 = hint.add_input(lhs_1, tag="c", aggregate=stack)
  rhs_1 = hint.add_input(rhs_1, tag="n", aggregate=stack)
  pairwise_sums = [
      math_ops.add(lhs_0, rhs_0, name="addleft"),
      math_ops.add(lhs_1, rhs_1, name="addright"),
  ]
  hinted_sums = [
      hint.add_output(total, tag="out", aggregate=stack)
      for total in pairwise_sums
  ]
  stacked = array_ops.stack(hinted_sums)
  output = array_ops.identity(stacked, name="FINAL_OUTPUT")
  with self.cached_session() as sess:
    stubbed_graphdef = op_hint.convert_op_hints_to_stubs(
        graph_def=sess.graph_def)
    final_node = op_hint._tensor_name_base(output.name)
    observed_types = self._getGraphOpTypes(
        stubbed_graphdef, output_nodes=[final_node])
    self.assertEqual(observed_types, {"agg", "Const", "Identity"})
def testOverrideIndex(self):
  """Check that `index_override` pins an input index among auto-assigned ones.

  `b` is added first without an override, `a` is pinned to index 1, and `c`
  is added afterwards without an override. The expected indices are b=0,
  a=1 and c=2.
  """
  const_a, const_b, const_c = (
      array_ops.constant([value]) for value in (1., 2., 3.))
  hint = op_hint.OpHint("test_override")
  hinted_b = hint.add_input(const_b)  # expected to be auto-assigned 0
  hinted_a = hint.add_input(const_a, index_override=1)
  hinted_c = hint.add_input(const_c)  # expected to be auto-assigned 2
  with self.cached_session():
    for tensor, expected in ((hinted_a, 1), (hinted_b, 0), (hinted_c, 2)):
      self.assertEqual(self._get_input_index(tensor), expected)
def _double_values(x):
  """Wrap an element-wise self-multiplication in an "add_test" OpHint.

  NOTE(review): despite the name, this computes `x * x` (a square), not
  `2 * x`. Kept as-is because callers may rely on the resulting graph.

  Args:
    x: the single tensor registered as the hint's input.

  Returns:
    The hinted output tensor.
  """
  hint = op_hint.OpHint("add_test")
  (hinted_x,) = hint.add_inputs(x)
  product = math_ops.multiply(hinted_x, hinted_x)
  (hinted_product,) = hint.add_outputs(product)
  return hinted_product
def _scaled_and_bias_and_identity(a, x, b):
  """Wrap `a * x + b` plus a pass-through of `x` in one OpHint.

  Args:
    a: scale operand.
    x: value operand; also returned unchanged as the second hint output.
    b: bias operand.

  Returns:
    Whatever `add_outputs` returns for the two outputs
    (`a * x + b`, `x`), in that order.
  """
  hint = op_hint.OpHint("scale_and_bias_and_identity")
  hinted_a, hinted_x, hinted_b = hint.add_inputs(a, x, b)
  scaled = hinted_a * hinted_x
  biased = scaled + hinted_b
  return hint.add_outputs(biased, hinted_x)
def _swish(input_tensor, scale):
  """Wrap a scaled swish activation, `sigmoid(t) * t * scale`, in an OpHint.

  Args:
    input_tensor: the activation input `t`.
    scale: multiplier applied after the swish product.

  Returns:
    The hinted output tensor of the "cool_activation" hint.
  """
  hint = op_hint.OpHint("cool_activation")
  hinted_input, hinted_scale = hint.add_inputs(input_tensor, scale)
  gate = math_ops.sigmoid(hinted_input)
  gated = gate * hinted_input
  activated = gated * hinted_scale
  (hinted_output,) = hint.add_outputs(activated)
  return hinted_output