Example #1
import pytest
from onnxruntime.training import optim, orttrainer


# The parametrize values below are illustrative placeholders; the original
# suite supplies its own list of invalid inputs.
@pytest.mark.parametrize("invalid_input", [-1, 'Hello'])
def testTrainStepInfoInvalidInput(invalid_input):
    '''Test invalid initialization of TrainStepInfo'''
    optimizer_config = optim.LambConfig()
    with pytest.raises(AssertionError):
        orttrainer.TrainStepInfo(optimizer_config=invalid_input)

    with pytest.raises(AssertionError):
        orttrainer.TrainStepInfo(optimizer_config, all_finite=invalid_input)

    with pytest.raises(AssertionError):
        orttrainer.TrainStepInfo(optimizer_config, fetches=invalid_input)

    with pytest.raises(AssertionError):
        orttrainer.TrainStepInfo(optimizer_config, optimization_step=invalid_input)

    with pytest.raises(AssertionError):
        orttrainer.TrainStepInfo(optimizer_config, step=invalid_input)
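With the illustrative parametrize values above, pytest collects one test case per invalid input. A typical invocation (the file name here is hypothetical):

pytest test_orttrainer.py -k testTrainStepInfoInvalidInput -v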
Example #2
from onnxruntime.training import optim, orttrainer


def testTrainStepInfo():
    '''Test valid initializations of TrainStepInfo'''

    optimizer_config = optim.LambConfig()
    fetches = ['out1', 'out2']
    step_info = orttrainer.TrainStepInfo(optimizer_config=optimizer_config,
                                         all_finite=False,
                                         fetches=fetches,
                                         optimization_step=123,
                                         step=456)
    assert step_info.optimizer_config == optimizer_config
    assert not step_info.all_finite
    assert step_info.fetches == fetches
    assert step_info.optimization_step == 123
    assert step_info.step == 456

    step_info = orttrainer.TrainStepInfo(optimizer_config)
    assert step_info.optimizer_config == optimizer_config
    assert step_info.all_finite
    assert step_info.fetches == []
    assert step_info.optimization_step == 0
    assert step_info.step == 0
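TrainStepInfo holds plain mutable attributes (Example #3 below reassigns all_finite directly). A minimal sketch of how a training loop might advance the counters; the increment pattern is an assumption for illustration, not part of the test suite:

from onnxruntime.training import optim, orttrainer

step_info = orttrainer.TrainStepInfo(optim.LambConfig())
for _ in range(3):                       # hypothetical mini-loop over 3 batches
    step_info.step += 1                  # one forward/backward pass
    step_info.optimization_step += 1     # assumes no gradient accumulation
assert step_info.step == 3 and step_info.optimization_step == 3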
Example #3
from numpy.testing import assert_allclose
from onnxruntime.training import amp, optim, orttrainer


def testDynamicLossScaler():
    rtol = 1e-5
    default_scaler = amp.loss_scaler.DynamicLossScaler()

    # Initial state
    train_step_info = orttrainer.TrainStepInfo(optim.LambConfig())
    assert_allclose(default_scaler.loss_scale, float(1 << 16),
                    rtol=rtol, err_msg="loss scale mismatch")
    assert default_scaler.up_scale_window == 2000
    assert_allclose(default_scaler.min_loss_scale, 1.0,
                    rtol=rtol, err_msg="min loss scale mismatch")
    assert_allclose(default_scaler.max_loss_scale, float(
        1 << 24), rtol=rtol, err_msg="max loss scale mismatch")

    # Perform 9*2000 updates to cover all branches of LossScaler.update() when train_step_info.all_finite is True
    loss_scale = float(1 << 16)
    for cycles in range(1, 10):

        # 1999 updates without overflow produce 1999 stable steps
        for i in range(1, 2000):
            new_loss_scale = default_scaler.update(train_step_info)
            assert default_scaler._stable_steps_count == i
            assert_allclose(new_loss_scale, loss_scale,
                            rtol=rtol, err_msg=f"loss scale mismatch at update {i}")

        # The 2000th update without overflow doubles the loss scale and resets the stable-step count to zero, until max_loss_scale is reached
        new_loss_scale = default_scaler.update(train_step_info)
        if cycles <= 8:
            loss_scale *= 2
        assert default_scaler._stable_steps_count == 0
        assert_allclose(new_loss_scale, loss_scale,
                        rtol=rtol, err_msg="loss scale mismatch")

    # After 8 doublings the loss scale is float(1 << 16) * (2**8), i.e. max_loss_scale
    assert_allclose(new_loss_scale, float(1 << 16)
                    * (2**8), rtol=rtol, err_msg="loss scale mismatch")

    # The loss scale has reached max_loss_scale and is not doubled from that point on
    loss_scale = float(1 << 16)*(2**8)
    for count in range(1, 2050):
        new_loss_scale = default_scaler.update(train_step_info)
        assert default_scaler._stable_steps_count == (count % 2000)
        assert_allclose(new_loss_scale, loss_scale,
                        rtol=rtol, err_msg="loss scale mismatch")

    # Setting train_step_info.all_finite = False to test down scaling
    train_step_info.all_finite = False

    # Perform 24 updates, each halving the loss scale
    loss_scale = float(1 << 16)*(2**8)
    for count in range(1, 25):
        new_loss_scale = default_scaler.update(train_step_info)
        loss_scale /= 2
        assert default_scaler._stable_steps_count == 0
        assert_allclose(new_loss_scale, loss_scale,
                        rtol=rtol, err_msg="loss scale mismatch")

    # After 24 updates with gradient overflow, loss scale is 1.0
    assert_allclose(new_loss_scale, 1.,
                    rtol=rtol, err_msg="loss scale mismatch")

    # min_loss_scale has been reached, so subsequent overflow updates no longer halve the loss scale
    for count in range(1, 5):
        new_loss_scale = default_scaler.update(train_step_info)
        assert default_scaler._stable_steps_count == 0
        assert_allclose(new_loss_scale, loss_scale,
                        rtol=rtol, err_msg="loss scale mismatch")
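The behavior exercised above can be condensed into a standalone sketch (same imports as the test; the doubling and halving thresholds are taken from the assertions above):

from onnxruntime.training import amp, optim, orttrainer

scaler = amp.loss_scaler.DynamicLossScaler()           # loss scale starts at 2**16
info = orttrainer.TrainStepInfo(optim.LambConfig())    # all_finite defaults to True

# 2000 consecutive finite steps double the scale once (up_scale_window == 2000)
for _ in range(2000):
    new_scale = scaler.update(info)
assert new_scale == float(1 << 17)

# a single overflow step immediately halves it again
info.all_finite = False
assert scaler.update(info) == float(1 << 16)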