コード例 #1
0
def test_trainer_utils_put_model_to_train_mode_should_call_the_model_train_method(
):
    """`_put_model_to_train_mode` should call `model.train()` exactly once."""
    # Given
    trainer = TrainerUtils()
    model_mock = MagicMock()

    # When
    trainer._put_model_to_train_mode(model_mock)

    # Then
    # assert_called_once (rather than assert_called) also fails when the
    # method is invoked repeatedly; arguments are deliberately not pinned.
    model_mock.train.assert_called_once()
コード例 #2
0
def test_trainer_utils_compute_gradients_should_call_the_backward_method_of_the_loss_tensor(
):
    """`_compute_gradients` should call `backward()` on the loss exactly once."""
    # Given
    trainer = TrainerUtils()
    loss_tensor_mock = MagicMock()

    # When
    trainer._compute_gradients(loss_tensor_mock)

    # Then
    # assert_called_once also catches an accidental second backward pass,
    # which assert_called would silently accept.
    loss_tensor_mock.backward.assert_called_once()
コード例 #3
0
def test_trainer_utils_clean_gradients_should_call_the_model_zero_grad_method(
):
    """`_clean_gradients` should call `model.zero_grad()` exactly once."""
    # Given
    trainer = TrainerUtils()
    model_mock = MagicMock()

    # When
    trainer._clean_gradients(model_mock)

    # Then
    # Pin the call count; arguments are deliberately not constrained.
    model_mock.zero_grad.assert_called_once()
コード例 #4
0
def test_trainer_utils_apply_gradient_descent_should_call_the_optimizer_step_method(
):
    """`_apply_gradient_descent` should call `optimizer.step()` exactly once."""
    # Given
    trainer = TrainerUtils()
    optimizer_mock = MagicMock()

    # When
    trainer._apply_gradient_descent(optimizer_mock)

    # Then
    # assert_called_once also fails if the optimizer is stepped twice.
    optimizer_mock.step.assert_called_once()
コード例 #5
0
def test_trainer_utils_compute_loss_should_call_the_forward_method_of_the_loss_module(
):
    """`_compute_loss` should call the criterion once with its two inputs.

    Despite the test name, what is asserted is that the criterion object
    itself is called (its ``__call__``), not ``criterion.forward`` directly.
    """
    # Given
    trainer = TrainerUtils()
    criterion_mock = MagicMock()

    # When
    _ = trainer._compute_loss(criterion_mock, "a", "b")

    # Then
    # assert_called_with only inspects the most recent call; the _once
    # variant also rejects extra calls made with other arguments.
    criterion_mock.assert_called_once_with("a", "b")
コード例 #6
0
def test_trainer_utils_transpose_decoder_output_matrix_should_call_the_transpose_method_with_correct_parameters(
):
    """`_transpose_decoder_output_matrix` calls `transpose(2, 1)` exactly once."""
    # Given: a stand-in for the decoder output and a fresh utilities object.
    decoder_output = MagicMock()
    utils = TrainerUtils()

    # When
    utils._transpose_decoder_output_matrix(decoder_output)

    # Then: both the call count and the argument order (2, 1) are pinned.
    decoder_output.transpose.assert_called_once_with(2, 1)
コード例 #7
0
def test_trainer_utils_put_model_on_the_device_should_call_the_to_method_with_correct_device(
):
    """`_put_model_on_the_device` should call `model.to(DEVICE)` exactly once."""
    # Given
    trainer = TrainerUtils()
    lstm_model_mock = MagicMock()

    # When
    trainer._put_model_on_the_device(lstm_model_mock)

    # Then
    # assert_called_with only checks the last call; the _once variant also
    # rejects an earlier move to a different device.
    lstm_model_mock.to.assert_called_once_with(DEVICE)
コード例 #8
0
def test_trainer_utils_put_tensors_on_the_should_call_the_to_method_with_correct_device(
):
    """`_put_tensors_on_the_device` should call `.to(DEVICE)` once per tensor.

    NOTE: the function name is missing the word "device"; it is kept as-is so
    that existing test selection by name keeps working.
    """
    # Given
    trainer = TrainerUtils()
    tensor_1 = MagicMock()
    tensor_2 = MagicMock()
    # A tuple literal; wrapping it in tuple(...) again was redundant.
    tensors = (tensor_1, tensor_2)

    # When
    _ = trainer._put_tensors_on_the_device(tensors)

    # Then
    # Each tensor must be moved exactly once, and to DEVICE.
    tensor_1.to.assert_called_once_with(DEVICE)
    tensor_2.to.assert_called_once_with(DEVICE)
コード例 #9
0
def test_trainer_utils_detach_hidden_states_should_call_the_detach_method_for_a_tuple_of_tensor(
):
    """`_detach_hidden_states` should call `detach()` once on each element."""
    # Given
    trainer = TrainerUtils()
    hidden_state_1 = MagicMock()
    hidden_state_2 = MagicMock()
    # A tuple literal; wrapping it in tuple(...) again was redundant.
    hidden_states = (hidden_state_1, hidden_state_2)

    # When
    _ = trainer._detach_hidden_states(hidden_states)

    # Then
    hidden_state_1.detach.assert_called_once()
    hidden_state_2.detach.assert_called_once()
コード例 #10
0
def test_trainer_utils_get_model_output_should_call_the_model_forward_pass():
    """`_get_model_output` should call the model once with ids and hidden states."""
    # Given
    trainer = TrainerUtils()
    model_mock = MagicMock()
    # NOTE(review): a 2-element return value presumably lets the method unpack
    # (output, hidden_states) from the model call; confirm against TrainerUtils.
    model_mock.return_value = [1, 2]

    sequence_of_ids = [3, 4]
    hidden_states = [5, 6]

    # When
    _ = trainer._get_model_output(model_mock, sequence_of_ids, hidden_states)

    # Then
    # assert_called_once_with also rejects additional forward passes.
    model_mock.assert_called_once_with(sequence_of_ids, hidden_states)