def test_evaluate():
    """Evaluate an element-wise Add over two float32 parameters.

    Two parameters named "data1" and "data2", both of shape [2, 1], are fed
    into ``ops.add``. The resulting function is evaluated on concrete inputs,
    and the pre-allocated output tensor must then hold their element-wise sum.
    """
    lhs = ops.parameter(Shape([2, 1]), dtype=np.float32, name="data1")
    rhs = ops.parameter(Shape([2, 1]), dtype=np.float32, name="data2")
    model = Function(ops.add(lhs, rhs), [lhs, rhs], "TestFunction")

    # Column vectors matching the declared [2, 1] parameter shape.
    lhs_value = np.array([[2], [1]], dtype=np.float32)
    rhs_value = np.array([[3], [7]], dtype=np.float32)

    # evaluate() writes into this caller-provided output tensor.
    result = Tensor("float32", Shape([2, 1]))
    succeeded = model.evaluate([result], [Tensor(lhs_value), Tensor(rhs_value)])
    assert succeeded

    # 2 + 3 == 5 and 1 + 7 == 8
    expected = np.array([[5], [8]])
    assert np.allclose(result.data, expected)
def test_evaluate_invalid_input_shape():
    """Evaluating with inputs of the wrong shape must raise RuntimeError.

    The function's parameters are declared with shape [2, 1]. The test passes
    [3, 1] input tensors and expects ``RuntimeError``, whose message must
    mention the expected partial shape ``{2,1}``.
    """
    param1 = ops.parameter(Shape([2, 1]), dtype=np.float32, name="data1")
    param2 = ops.parameter(Shape([2, 1]), dtype=np.float32, name="data2")
    add = ops.add(param1, param2)
    func = Function(add, [param1, param2], "TestFunction")

    # Build the tensors outside the raises-context so that only evaluate()
    # can satisfy pytest.raises; an unrelated RuntimeError from tensor
    # construction would otherwise make the test pass for the wrong reason.
    outputs = [Tensor("float32", Shape([2, 1]))]
    mismatched_inputs = [
        Tensor("float32", Shape([3, 1])),
        Tensor("float32", Shape([3, 1])),
    ]

    with pytest.raises(RuntimeError) as e:
        # Expected to raise; its return value is irrelevant, so it is not
        # wrapped in an `assert`.
        func.evaluate(outputs, mismatched_inputs)

    # Must follow the `with` block: code after the raising call inside the
    # block would never execute.
    assert "must be compatible with the partial shape: {2,1}" in str(e.value)