def func_true(a, b):
    """Build a graph-level conditional choosing between ``a - b`` and ``a + b``.

    The predicate is ``math_ops.less(a, b)``. The true branch computes
    ``a + b * (-1)`` and the false branch computes ``a + b``. The value
    returned by ``control_flow_ops.cond`` is passed back unchanged (it is
    not indexed here).

    NOTE(review): ``self`` is not a parameter of this function, so it relies
    on an enclosing scope that provides ``self.dtype``. Presumably this is a
    helper nested inside a test method; confirm.
    """
    neg_one = Tensor(
        data_layout=types_pb2.N,
        tensor_data=np.array([-1], dtype=self.dtype))

    def subtract_branch():
        # a - b, expressed as a + (b * -1).
        return math_ops.add(a, math_ops.mul(b, neg_one))

    def add_branch():
        return math_ops.add(a, b)

    a_less_than_b = math_ops.less(a, b)
    return control_flow_ops.cond(a_less_than_b, subtract_branch, add_branch)
def func_true(a, b):
    """Square the first output of a nested ``a < b`` conditional.

    The inner ``control_flow_ops.cond`` uses ``math_ops.less(a, b)`` as its
    predicate. Its true branch computes ``a + b * (-1)`` and its false branch
    computes ``a + b``. Element ``[0]`` of the cond result is then multiplied
    by itself, so the nested result is consumed by another op before
    returning.

    NOTE(review): this function reads ``self.dtype`` without taking ``self``
    as a parameter. It presumably lives inside a test method; confirm.
    """
    neg_one = Tensor(
        data_layout=types_pb2.N,
        tensor_data=np.array([-1], dtype=self.dtype))
    predicate = math_ops.less(a, b)
    branch_outputs = control_flow_ops.cond(
        predicate,
        lambda: math_ops.add(a, math_ops.mul(b, neg_one)),
        lambda: math_ops.add(a, b))
    # Consume the cond's output with a further op instead of returning it.
    selected = branch_outputs[0]
    return math_ops.mul(selected, selected)
def test_cond_op_simple_func(self):
    """Check cond with inline branches: ``y + z if x0 < x1 else y * z``.

    The inputs are x0=2, x1=5, y=10 and z=20, and the expected output passed
    to ``runAndValidate`` is 30, which is the ``y + z`` branch.
    """
    def scalar(value):
        # One-element tensor with data_layout N and the test's dtype.
        return Tensor(
            data_layout=types_pb2.N,
            tensor_data=np.array([value], dtype=self.dtype))

    with Graph(name=self.graph_name, backend=self.backend) as graph:
        x0 = scalar(2)
        x1 = scalar(5)
        y = scalar(10)
        z = scalar(20)
        expected_res = scalar(30)
        # The binding is not read afterwards; the cond is built for its
        # effect on the active graph (presumably registering the ops).
        outcome = control_flow_ops.cond(
            math_ops.less(x0, x1),
            lambda: math_ops.add(y, z),
            lambda: math_ops.mul(y, z))
    self.runAndValidate(graph, expected_res.tensor_data)
def test_use_nested_op_result(self):
    """Check a nested cond where one branch consumes its inner cond's result.

    The graph encodes:

        if x0 < x1:
            if y < z: res = (y - z) ** 2
            else:     res = y + z
        else:
            if y > z: res = 2 * y
            else:     res = 2 * z

    With x0=2, x1=5, y=10 and z=20, the taken path is (10 - 20) ** 2 = 100.
    That value is the expected output handed to ``runAndValidate``.
    """
    def scalar(value):
        # One-element tensor with data_layout N and the test's dtype.
        return Tensor(
            data_layout=types_pb2.N,
            tensor_data=np.array([value], dtype=self.dtype))

    def func_true(a, b):
        # Inner cond: a + b * (-1) if a < b, else a + b. Its first output
        # is squared, so the nested result is used before returning.
        neg_one = scalar(-1)
        inner = control_flow_ops.cond(
            math_ops.less(a, b),
            lambda: math_ops.add(a, math_ops.mul(b, neg_one)),
            lambda: math_ops.add(a, b))
        picked = inner[0]
        return math_ops.mul(picked, picked)

    def func_false(a, b):
        # Inner cond: a * 2 if a > b, else b * 2; returned without indexing.
        two = scalar(2)
        return control_flow_ops.cond(
            math_ops.greater(a, b),
            lambda: math_ops.mul(a, two),
            lambda: math_ops.mul(b, two))

    with Graph(name=self.graph_name, backend=self.backend) as graph:
        x0 = scalar(2)
        x1 = scalar(5)
        y = scalar(10)
        z = scalar(20)
        expected_res = scalar(100)
        # Not read afterwards; the cond is built for its effect on the graph.
        outer = control_flow_ops.cond(
            math_ops.less(x0, x1),
            lambda: func_true(y, z),
            lambda: func_false(y, z))
    self.runAndValidate(graph, expected_res.tensor_data)
def test_cond_op_func_call(self):
    """Check cond whose true branch calls an ordinary Python helper.

    The graph computes ``y - 3z if x0 < x1 else y * z``. With x0=2, x1=5,
    y=10 and z=20, the expected output passed to ``runAndValidate`` is
    10 - 60 = -50.
    """
    def scalar(value):
        # One-element tensor with data_layout N and the test's dtype.
        return Tensor(
            data_layout=types_pb2.N,
            tensor_data=np.array([value], dtype=self.dtype))

    def a_minus_3b(a, b):
        # a - 3b, written as a + b * (-3).
        minus_three = scalar(-3)
        scaled = math_ops.mul(b, minus_three)
        return math_ops.add(a, scaled)

    with Graph(name=self.graph_name, backend=self.backend) as graph:
        x0 = scalar(2)
        x1 = scalar(5)
        y = scalar(10)
        z = scalar(20)
        expected_res = scalar(-50)
        # Not read afterwards; the cond is built for its effect on the graph.
        outcome = control_flow_ops.cond(
            math_ops.less(x0, x1),
            lambda: a_minus_3b(y, z),
            lambda: math_ops.mul(y, z))
    self.runAndValidate(graph, expected_res.tensor_data)