def test_objective_no_target():
    """Check get_loss when it is called with only an input.

    The layer's output must be combined with the objective's own
    `target_var`, and the loss function's return value must be returned.
    """
    from nntools.objectives import Objective

    layer = mock.Mock()
    loss_fn = mock.Mock()
    sentinel_input = object()

    objective = Objective(layer, loss_fn)
    loss = objective.get_loss(sentinel_input)

    layer.get_output.assert_called_with(sentinel_input)
    # With no target supplied, the objective's own target variable is used.
    loss_fn.assert_called_with(layer.get_output.return_value,
                               objective.target_var)
    assert loss == loss_fn.return_value
def test_objective():
    """Check that get_loss forwards extra arguments and uses the given target.

    Every argument is a distinct sentinel object, so the call assertions can
    detect an argument passed in the wrong position or to the wrong callee.
    """
    from nntools.objectives import Objective

    input_layer = mock.Mock()
    loss_function = mock.Mock()
    # Four *distinct* sentinels: `(object(),) * 4` would bind all four names
    # to one shared object, and the assertions below could not tell the
    # arguments apart.
    input_data, target, arg1, kwarg1 = (object() for _ in range(4))

    objective = Objective(input_layer, loss_function)
    result = objective.get_loss(input_data, target, arg1, kwarg1=kwarg1)

    # We expect that the input layer's `get_output` was called with
    # the input we provided, plus the extra positional and keyword
    # arguments (but not the target).
    input_layer.get_output.assert_called_with(input_data, arg1, kwarg1=kwarg1)
    network_output = input_layer.get_output.return_value

    # The `network_output` and `target` are fed into the loss function:
    loss_function.assert_called_with(network_output, target)
    assert result == loss_function.return_value