def test_aggregation_from_distributed_zero_megatron_full_precision_lamb(
    device="cuda",
    checkpoint_dir="checkpoint_dir/distributed_zero_megatron/full_precision/lamb/",
):
    """Check checkpoint aggregation for a ZeRO + Megatron, full-precision, LAMB run.

    Loads the checkpoint in ``checkpoint_dir`` through
    ``create_orttrainer_and_load_checkpoint_bart`` (with ``use_lamb=True``),
    then passes the loaded and expected state dictionaries to
    ``test_megatron_aggregation`` with ``is_mixedprecision=False``.

    Args:
        device: Device identifier stored in the trainer options as ``device.id``.
        checkpoint_dir: Directory containing the checkpoint files to aggregate.
    """
    # Deterministic compute is requested, presumably so the loaded state is
    # reproducible enough to compare against the expected state.
    opts = {"device": {"id": device}, "debug": {"deterministic_compute": True}}

    # Extract the state dictionaries to compare; the third returned value
    # (the model) is not used by this test.
    loaded_state_dict, expected_state_dict, _ = create_orttrainer_and_load_checkpoint_bart(
        device, opts, checkpoint_dir, use_lamb=True
    )
    test_megatron_aggregation(
        checkpoint_dir,
        loaded_state_dict,
        expected_state_dict,
        is_mixedprecision=False,
    )


def test_aggregation_from_distributed_zero_megatron_full_precision_adam(
    device="cuda",
    checkpoint_dir="checkpoint_dir/distributed_zero_megatron/full_precision/adam/",
):
    """Check checkpoint aggregation for a ZeRO + Megatron, full-precision, Adam run.

    Loads the checkpoint in ``checkpoint_dir`` through
    ``create_orttrainer_and_load_checkpoint_bart`` (with ``use_lamb=False``),
    then passes the loaded and expected state dictionaries to
    ``test_megatron_aggregation`` with ``is_mixedprecision=False``.

    Args:
        device: Device identifier stored in the trainer options as ``device.id``.
        checkpoint_dir: Directory containing the checkpoint files to aggregate.
    """
    # Deterministic compute is requested, presumably so the loaded state is
    # reproducible enough to compare against the expected state.
    opts = {"device": {"id": device}, "debug": {"deterministic_compute": True}}

    # Extract the state dictionaries to compare; the third returned value
    # (the model) is not used by this test. use_lamb=False selects the
    # non-LAMB (Adam, per the checkpoint path) optimizer configuration.
    loaded_state_dict, expected_state_dict, _ = create_orttrainer_and_load_checkpoint_bart(
        device, opts, checkpoint_dir, use_lamb=False
    )
    test_megatron_aggregation(
        checkpoint_dir,
        loaded_state_dict,
        expected_state_dict,
        is_mixedprecision=False,
    )