Esempio n. 1
0
def test_fit_loop_calls_fit_with_data(n_batches, n_epochs):
    """Check that every (X, y) batch is passed to ``model.fit`` on every epoch.

    The supplied ``n_batches`` and ``n_epochs`` are used as given; they were
    previously overwritten with hard-coded values, which silently turned any
    parametrization of this test into a single repeated case.

    Args:
        n_batches: number of mock (X, y) batches handed to the fit loop.
        n_epochs: number of epochs configured on the training loop.
    """
    config = fv3fit.TrainingLoopConfig(epochs=n_epochs)
    mock_model = mock.MagicMock(spec=tf.keras.Model)
    mock_Xy = [
        (mock.MagicMock(spec=np.ndarray), mock.MagicMock(spec=np.ndarray))
        for _ in range(n_batches)
    ]
    validation_data = (mock.MagicMock(spec=np.ndarray), mock.MagicMock(spec=np.ndarray))
    config.fit_loop(mock_model, mock_Xy, validation_data)
    # fit is expected once per batch per epoch, so the total scales with n_epochs
    assert mock_model.fit.call_count == n_epochs * len(mock_Xy)
    all_calls = mock_model.fit.call_args_list
    # consecutive runs of n_batches calls are treated as one epoch each
    epoch_calls = [
        all_calls[i * n_batches : (i + 1) * n_batches] for i in range(n_epochs)
    ]
    # check all data are passed as args on each epoch (order within an epoch
    # is deliberately not checked)
    for call_args_list in epoch_calls:
        for X, y in mock_Xy:
            expected_call = mock.call(
                X, y, batch_size=config.batch_size, validation_data=validation_data,
            )
            assert expected_call in call_args_list
Esempio n. 2
0
def test_fit_loop_calls_fit(n_epochs, n_batches):
    """Check that ``model.fit`` is called once per batch per epoch.

    The supplied ``n_epochs`` and ``n_batches`` are used as given; they were
    previously overwritten with hard-coded values, which made any
    parametrization of this test ineffective.

    Args:
        n_epochs: number of epochs configured on the training loop.
        n_batches: number of mock (X, y) batches handed to the fit loop.
    """
    config = fv3fit.TrainingLoopConfig(epochs=n_epochs)
    mock_model = mock.MagicMock(spec=tf.keras.Model)
    mock_Xy = [
        (mock.MagicMock(spec=np.ndarray), mock.MagicMock(spec=np.ndarray))
        for _ in range(n_batches)
    ]
    validation_data = (mock.MagicMock(spec=np.ndarray), mock.MagicMock(spec=np.ndarray))
    config.fit_loop(mock_model, mock_Xy, validation_data)
    assert mock_model.fit.call_count == n_epochs * n_batches
Esempio n. 3
0
def test_fit_loop():
    """Smoke test: ``fit_loop`` runs with a default config and no callbacks.

    Nothing is asserted; the test passes as long as ``fit_loop`` does not raise.
    """
    n_batches = 5
    n_callbacks = 0
    config = fv3fit.TrainingLoopConfig()
    mock_model = mock.MagicMock(spec=tf.keras.Model)
    mock_Xy = [
        (mock.MagicMock(spec=np.ndarray), mock.MagicMock(spec=np.ndarray))
        for _ in range(n_batches)
    ]
    validation_data = (mock.MagicMock(spec=np.ndarray), mock.MagicMock(spec=np.ndarray))
    # Instantiate each callback mock: the bare class ``mock.MagicMock`` would
    # put the class itself in the tuple as soon as n_callbacks is raised above 0.
    callbacks = tuple(mock.MagicMock() for _ in range(n_callbacks))
    config.fit_loop(mock_model, mock_Xy, validation_data, callbacks)
Esempio n. 4
0
def test_fit_loop_calls_in_reproducible_order():
    """Two fit loops started from the same random seed issue identical fit calls."""
    batch_count = 5
    epoch_count = 2

    def array_mock():
        return mock.MagicMock(spec=np.ndarray)

    config = fv3fit.TrainingLoopConfig(epochs=epoch_count)
    first_model = mock.MagicMock(spec=tf.keras.Model)
    second_model = mock.MagicMock(spec=tf.keras.Model)
    batches = [(array_mock(), array_mock()) for _ in range(batch_count)]
    validation_data = (array_mock(), array_mock())

    # Reset the seed to the same value immediately before each run so both
    # loops start from the same random state.
    for model in (first_model, second_model):
        fv3fit.set_random_seed(0)
        config.fit_loop(model, batches, validation_data)

    assert first_model.fit.call_args_list == second_model.fit.call_args_list
Esempio n. 5
0
def test_fit_loop_calls_in_different_order_on_two_epochs():
    """The sequence of fit calls in the first epoch differs from the second.

    NOTE(review): no random seed is set here. If fit_loop orders batches
    randomly, the two epochs could coincide by chance -- confirm whether a
    seeding fixture applies to this test.
    """
    batch_count = 5
    epoch_count = 2

    def array_mock():
        return mock.MagicMock(spec=np.ndarray)

    config = fv3fit.TrainingLoopConfig(epochs=epoch_count)
    model = mock.MagicMock(spec=tf.keras.Model)
    batches = [(array_mock(), array_mock()) for _ in range(batch_count)]
    validation_data = (array_mock(), array_mock())

    config.fit_loop(model, batches, validation_data)

    assert model.fit.call_count == len(batches) * epoch_count
    recorded = model.fit.call_args_list
    # The first batch_count calls are taken as epoch one, the next as epoch two.
    first_epoch = recorded[:batch_count]
    second_epoch = recorded[batch_count : 2 * batch_count]
    assert first_epoch != second_epoch
Esempio n. 6
0
def test_fit_loop_calls_callbacks(n_callbacks, n_epochs):
    """Each callback is called once per epoch with that epoch's ``EpochResult``.

    Args:
        n_callbacks: number of mock callbacks passed to the fit loop.
        n_epochs: number of epochs configured on the training loop.
    """
    batch_count = 5

    def array_mock():
        return mock.MagicMock(spec=np.ndarray)

    config = fv3fit.TrainingLoopConfig(epochs=n_epochs)
    model = mock.MagicMock(spec=tf.keras.Model)
    # Every fit call returns the same history object, so the expected
    # per-epoch history is that object repeated once per batch.
    history = mock.MagicMock(spec=tf.keras.callbacks.History)
    model.fit.return_value = history
    batches = [(array_mock(), array_mock()) for _ in range(batch_count)]
    validation_data = (array_mock(), array_mock())
    callbacks = tuple(mock.MagicMock() for _ in range(n_callbacks))

    config.fit_loop(model, batches, validation_data, callbacks)

    expected_history = (history,) * batch_count
    for callback in callbacks:
        assert callback.call_count == n_epochs
        for epoch, observed in enumerate(callback.call_args_list):
            expected = mock.call(
                fv3fit.EpochResult(epoch=epoch, history=expected_history)
            )
            assert observed == expected