import pytest
from numpy.testing import assert_allclose

# Imports required by these tests (assuming the legacy onnxruntime.training ORTTrainer API)
from onnxruntime.training import amp, optim, orttrainer


def testTrainStepInfoInvalidInput(invalid_input):
    '''Test invalid initialization of TrainStepInfo'''
    # invalid_input is expected to be supplied through pytest parametrization
    optimizer_config = optim.LambConfig()
    with pytest.raises(AssertionError):
        orttrainer.TrainStepInfo(optimizer_config=invalid_input)
    with pytest.raises(AssertionError):
        orttrainer.TrainStepInfo(optimizer_config, all_finite=invalid_input)
    with pytest.raises(AssertionError):
        orttrainer.TrainStepInfo(optimizer_config, fetches=invalid_input)
    with pytest.raises(AssertionError):
        orttrainer.TrainStepInfo(optimizer_config, optimization_step=invalid_input)
    with pytest.raises(AssertionError):
        orttrainer.TrainStepInfo(optimizer_config, step=invalid_input)

def testTrainStepInfo():
    '''Test valid initializations of TrainStepInfo'''
    optimizer_config = optim.LambConfig()
    fetches = ['out1', 'out2']

    # Explicit values for every field
    step_info = orttrainer.TrainStepInfo(optimizer_config=optimizer_config,
                                         all_finite=False,
                                         fetches=fetches,
                                         optimization_step=123,
                                         step=456)
    assert step_info.optimizer_config == optimizer_config
    assert step_info.all_finite is False
    assert step_info.fetches == fetches
    assert step_info.optimization_step == 123
    assert step_info.step == 456

    # Default values
    step_info = orttrainer.TrainStepInfo(optimizer_config)
    assert step_info.optimizer_config == optimizer_config
    assert step_info.all_finite is True
    assert step_info.fetches == []
    assert step_info.optimization_step == 0
    assert step_info.step == 0

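# Illustrative sketch only (not part of onnxruntime): TrainStepInfo is the per-step
# state object that ORTTrainer passes to components such as the AMP loss scaler
# (see testDynamicLossScaler below). The hypothetical helper below simply shows a
# consumer that relies on the typed fields validated above.
def _example_warmup_lr(step_info, base_lr=1e-3, warmup_steps=1000):
    '''Hypothetical linear-warmup learning rate driven by TrainStepInfo.optimization_step.'''
    warmup = min(1.0, (step_info.optimization_step + 1) / warmup_steps)
    return base_lr * warmup
    # e.g. _example_warmup_lr(orttrainer.TrainStepInfo(optim.LambConfig())) == 1e-6
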
def testDynamicLossScaler():
    rtol = 1e-5
    default_scaler = amp.loss_scaler.DynamicLossScaler()

    # Initial state
    train_step_info = orttrainer.TrainStepInfo(optim.LambConfig())
    assert_allclose(default_scaler.loss_scale, float(1 << 16),
                    rtol=rtol, err_msg="loss scale mismatch")
    assert default_scaler.up_scale_window == 2000
    assert_allclose(default_scaler.min_loss_scale, 1.0,
                    rtol=rtol, err_msg="min loss scale mismatch")
    assert_allclose(default_scaler.max_loss_scale, float(1 << 24),
                    rtol=rtol, err_msg="max loss scale mismatch")

    # Perform 9*2000 updates with train_step_info.all_finite == True to cover
    # all up-scaling branches of LossScaler.update()
    loss_scale = float(1 << 16)
    for cycles in range(1, 10):
        # 1999 consecutive updates without overflow produce 1999 stable steps
        # and leave the loss scale unchanged
        for i in range(1, 2000):
            new_loss_scale = default_scaler.update(train_step_info)
            assert default_scaler._stable_steps_count == i
            assert_allclose(new_loss_scale, loss_scale,
                            rtol=rtol, err_msg=f"loss scale mismatch at update {i}")

        # The 2000th update without overflow doubles the loss scale (until
        # max_loss_scale is reached) and resets the stable step count
        new_loss_scale = default_scaler.update(train_step_info)
        if cycles <= 8:
            loss_scale *= 2
        assert default_scaler._stable_steps_count == 0
        assert_allclose(new_loss_scale, loss_scale,
                        rtol=rtol, err_msg="loss scale mismatch")

    # After 8 doublings, the loss scale is float(1 << 16) * (2**8), i.e. max_loss_scale
    assert_allclose(new_loss_scale, float(1 << 16) * (2**8),
                    rtol=rtol, err_msg="loss scale mismatch")

    # max_loss_scale has been reached; additional overflow-free updates no longer double the loss scale
    loss_scale = float(1 << 16) * (2**8)
    for count in range(1, 2050):
        new_loss_scale = default_scaler.update(train_step_info)
        assert default_scaler._stable_steps_count == (count % 2000)
        assert_allclose(new_loss_scale, loss_scale,
                        rtol=rtol, err_msg="loss scale mismatch")

    # Set train_step_info.all_finite = False to test down-scaling
    train_step_info.all_finite = False

    # Perform 24 updates, each of which halves the loss scale
    loss_scale = float(1 << 16) * (2**8)
    for count in range(1, 25):
        new_loss_scale = default_scaler.update(train_step_info)
        loss_scale /= 2
        assert default_scaler._stable_steps_count == 0
        assert_allclose(new_loss_scale, loss_scale,
                        rtol=rtol, err_msg="loss scale mismatch")

    # After 24 updates with gradient overflow, the loss scale is 1.0
    assert_allclose(new_loss_scale, 1.,
                    rtol=rtol, err_msg="loss scale mismatch")

    # min_loss_scale has been reached, so further overflow updates do not halve the loss scale
    for count in range(1, 5):
        new_loss_scale = default_scaler.update(train_step_info)
        assert default_scaler._stable_steps_count == 0
        assert_allclose(new_loss_scale, loss_scale,
                        rtol=rtol, err_msg="loss scale mismatch")

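# Illustrative sketch only (not the onnxruntime implementation): a minimal loss-scaling
# policy that mirrors the behavior asserted by testDynamicLossScaler above, assuming the
# same defaults (initial scale 2**16, up_scale_window=2000, min_loss_scale=1.0,
# max_loss_scale=2**24). The class name and structure are hypothetical.
class _ReferenceDynamicLossScaler:
    def __init__(self, loss_scale=float(1 << 16), up_scale_window=2000,
                 min_loss_scale=1.0, max_loss_scale=float(1 << 24)):
        self.loss_scale = loss_scale
        self.up_scale_window = up_scale_window
        self.min_loss_scale = min_loss_scale
        self.max_loss_scale = max_loss_scale
        self._stable_steps_count = 0

    def update(self, train_step_info):
        if train_step_info.all_finite:
            # Count consecutive overflow-free steps; once the window is filled,
            # double the scale (capped at max_loss_scale) and restart the count.
            self._stable_steps_count += 1
            if self._stable_steps_count >= self.up_scale_window:
                self.loss_scale = min(self.max_loss_scale, self.loss_scale * 2)
                self._stable_steps_count = 0
        else:
            # Any overflow halves the scale (floored at min_loss_scale) and
            # resets the stable step counter.
            self.loss_scale = max(self.min_loss_scale, self.loss_scale / 2)
            self._stable_steps_count = 0
        return self.loss_scale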