def test_build_graph():
    """Check that build_graph returns what model_fn yields for the given inputs.

    model_fn is stubbed with a lambda that echoes both of its arguments back
    as a tuple, so the values build_graph returns are compared against the
    inputs it was given. This relies on build_graph handing (x, is_training)
    to model_fn and returning the result (assumed from how the stub is used).
    """
    net = base_model.Model(TestAssistant.mock_dataset())
    # Echo stub: parameter names kept as-is in case they are passed by keyword.
    net.model_fn = lambda x, y: (x, y)

    # Arrange the inputs.
    input_array = TestAssistant.zero_array()
    training_flag = True

    # Act.
    returned = net.build_graph(input_array, training_flag)
    out_array, out_flag = returned

    # Assert both echoed values match what went in.
    np.testing.assert_array_equal(out_array, input_array)
    assert out_flag == training_flag
def test_test():
    """Smoke test: Model.test() completes without raising.

    model_fn is overridden with a stub returning its first argument. The
    base model_fn raises NotImplementedError when not overridden (see the
    *_raises tests); presumably test() reaches model_fn, hence the stub.
    """
    dataset = TestAssistant.mock_dataset()
    subject = base_model.Model(dataset)
    subject.model_fn = lambda x, y: x
    subject.test()
def test_build_graph_raises():
    """Tests that the non-overridden model_fn raises on a build_graph call.

    Without a model_fn override, build_graph must propagate
    NotImplementedError.
    """
    model = base_model.Model(TestAssistant.mock_dataset())
    # Build the input before entering pytest.raises so that a
    # NotImplementedError coming from fixture setup cannot make this test
    # pass spuriously; only build_graph itself should satisfy the check.
    inputs = TestAssistant.zero_tensor()
    with pytest.raises(NotImplementedError):
        model.build_graph(inputs, True)
def test_model_fn_raises():
    """Tests that the non-overridden model_fn raises NotImplementedError."""
    model = base_model.Model(TestAssistant.mock_dataset())
    # Construct the input outside pytest.raises so only the model_fn call
    # under test can satisfy the expected exception.
    inputs = TestAssistant.zero_tensor()
    with pytest.raises(NotImplementedError):
        model.model_fn(inputs, True)