# Imports assumed by the tests below. The BERT/transformer helpers
# (bert_model_description, load_bert_onnx_model, optimizer_parameters,
# generate_random_input_from_model_desc, legacy_model_params,
# _load_pytorch_transformer_model) and the local _test_commons/_test_helpers
# modules are provided by the surrounding onnxruntime training test suite.
from functools import partial

import torch
from numpy.testing import assert_allclose

import onnxruntime
from onnxruntime import set_seed
from onnxruntime.capi.ort_trainer import LossScaler as Legacy_LossScaler
from onnxruntime.capi.ort_trainer import ORTTrainer as Legacy_ORTTrainer
from onnxruntime.training import amp, optim, orttrainer

import _test_commons
import _test_helpers


def testToyBERTModelMixedPrecisionLossScalerLegacyExperimental(
        loss_scaler, legacy_loss_scaler):
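    """Train a toy BERT model for 128 steps with mixed precision enabled and
    compare per-step losses between the experimental and legacy ORTTrainer APIs.
    The `loss_scaler`/`legacy_loss_scaler` arguments are expected to be supplied
    by the test harness (e.g. pytest parametrization)."""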
    # Common setup
    total_steps = 128
    device = "cuda"
    seed = 1

    # EXPERIMENTAL IMPLEMENTATION
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    optim_config = optim.AdamConfig(lr=0.001)
    opts = orttrainer.ORTTrainerOptions({
        'debug': {
            'deterministic_compute': True
        },
        'device': {
            'id': device,
        },
        'mixed_precision': {
            'enabled': True,
            'loss_scaler': loss_scaler
        }
    })
    trainer = orttrainer.ORTTrainer(model,
                                    model_desc,
                                    optim_config,
                                    options=opts)
    experimental_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        experimental_losses.append(
            trainer.train_step(*sample_input).cpu().item())

    # LEGACY IMPLEMENTATION
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    device = torch.device(device)
    model = load_bert_onnx_model()
    legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(
        optim_config.lr)
    legacy_trainer = Legacy_ORTTrainer(model,
                                       None,
                                       legacy_model_desc,
                                       "AdamOptimizer",
                                       None,
                                       learning_rate_description,
                                       device,
                                       _use_deterministic_compute=True,
                                       use_mixed_precision=True,
                                       loss_scaler=legacy_loss_scaler)
    legacy_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        leg_loss = legacy_trainer.train_step(*sample_input, learning_rate)
        legacy_losses.append(leg_loss.cpu().item())

    # Check results
    _test_helpers.assert_model_outputs(experimental_losses, legacy_losses)


def testToyBERTModelGradientAccumulationLegacyExperimental(
        gradient_accumulation_steps):
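    """Compare experimental vs. legacy ORTTrainer losses on a toy BERT model with
    gradient accumulation enabled; `gradient_accumulation_steps` is expected to be
    supplied by the test harness (e.g. pytest parametrization)."""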
    # Common setup
    total_steps = 10
    device = "cuda"
    seed = 1

    # EXPERIMENTAL IMPLEMENTATION
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({
        'debug': {
            'deterministic_compute': True
        },
        'device': {
            'id': device,
        },
        'batch': {
            'gradient_accumulation_steps': gradient_accumulation_steps
        },
    })
    trainer = orttrainer.ORTTrainer(model,
                                    model_desc,
                                    optim_config,
                                    options=opts)
    experimental_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        loss = trainer.train_step(*sample_input)
        experimental_losses.append(loss.cpu().item())

    # LEGACY IMPLEMENTATION
    device = torch.device(device)
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(
        optim_config.lr)
    legacy_trainer = Legacy_ORTTrainer(
        model,
        None,
        legacy_model_desc,
        "LambOptimizer",
        None,
        learning_rate_description,
        device,
        _use_deterministic_compute=True,
        gradient_accumulation_steps=gradient_accumulation_steps)
    legacy_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        leg_loss = legacy_trainer.train_step(*sample_input, learning_rate)
        legacy_losses.append(leg_loss.cpu().item())

    # Check results
    _test_helpers.assert_model_outputs(experimental_losses,
                                       legacy_losses,
                                       rtol=1e-6)


def testToyBERTModelLegacyExperimentalBasicTraining(optimizer_config):
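    """Run 512 basic training steps on a toy BERT model with the optimizer given
    by `optimizer_config` (Adam, Lamb, or SGD) at lr=0.01 and compare per-step
    losses between the experimental and legacy ORTTrainer APIs."""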
    # Common setup
    train_steps = 512

    device = 'cuda'
    seed = 1
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)

    # EXPERIMENTAL API
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    opts = orttrainer.ORTTrainerOptions({
        'debug': {
            'deterministic_compute': True
        },
        'device': {
            'id': device,
        },
    })
    optim_config = optimizer_config(lr=0.01)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)
    experimental_losses = []
    for i in range(train_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        experimental_losses.append(trainer.train_step(*sample_input).cpu().item())

    # LEGACY IMPLEMENTATION
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)

    if optimizer_config == optim.AdamConfig:
        legacy_optimizer = 'AdamOptimizer'
    elif optimizer_config == optim.LambConfig:
        legacy_optimizer = 'LambOptimizer'
    elif optimizer_config == optim.SGDConfig:
        legacy_optimizer = 'SGDOptimizer'
    else:
        raise RuntimeError("Invalid optimizer_config")

    device = torch.device(device)
    model = load_bert_onnx_model()
    legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(lr=optim_config.lr)
    legacy_trainer = Legacy_ORTTrainer(model, None, legacy_model_desc, legacy_optimizer,
                       None,
                       learning_rate_description,
                       device,
                       _use_deterministic_compute=True)
    legacy_losses = []
    for i in range(train_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        leg_loss = legacy_trainer.train_step(*sample_input, learning_rate)
        legacy_losses.append(leg_loss.cpu().item())

    # Check results
    _test_helpers.assert_model_outputs(experimental_losses, legacy_losses, True)


def testToyBERTModelLegacyExperimentalCustomOptimParameters(params, legacy_optim_map):
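    """Train a toy BERT model with custom per-parameter Adam settings (`params`
    for the experimental API, `legacy_optim_map` for the legacy API) and compare
    per-step losses between the two implementations."""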
    # Common setup
    total_steps = 128
    device = "cuda"
    seed = 1

    # EXPERIMENTAL API
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    model_desc = bert_model_description()
    model = load_bert_onnx_model()

    optim_config = optim.AdamConfig(
        params, alpha=0.9, beta=0.999, lambda_coef=0.01, epsilon=1e-6, do_bias_correction=False
    )
    opts = orttrainer.ORTTrainerOptions(
        {
            "debug": {"deterministic_compute": True},
            "device": {
                "id": device,
            },
        }
    )
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)

    experimental_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        experimental_losses.append(trainer.train_step(*sample_input).cpu().item())

    # LEGACY IMPLEMENTATION
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    device = torch.device(device)
    model = load_bert_onnx_model()
    legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(trainer.optim_config.lr)

    legacy_trainer = Legacy_ORTTrainer(
        model,
        None,
        legacy_model_desc,
        "AdamOptimizer",
        legacy_optim_map,
        learning_rate_description,
        device,
        _use_deterministic_compute=True,
    )
    legacy_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        legacy_sample_input = [*sample_input, learning_rate]
        legacy_losses.append(legacy_trainer.train_step(legacy_sample_input).cpu().item())

    # Check results
    _test_helpers.assert_model_outputs(experimental_losses, legacy_losses)


def testToyBERTModelLegacyExperimentalBasicTraining():
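    """Run 10 basic training steps on a toy BERT model with the Lamb optimizer and
    compare per-step losses between the experimental and legacy ORTTrainer APIs
    (rtol=1e-5)."""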
    # Common setup
    train_steps = 10
    device = 'cuda'
    seed = 1
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)

    # EXPERIMENTAL API
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    params = optimizer_parameters(model)
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({
        'debug': {
            'deterministic_compute': True
        },
        'device': {
            'id': device,
        },
    })
    trainer = orttrainer.ORTTrainer(model,
                                    model_desc,
                                    optim_config,
                                    options=opts)
    experimental_losses = []
    for i in range(train_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        experimental_losses.append(
            trainer.train_step(*sample_input).cpu().item())

    # LEGACY IMPLEMENTATION
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    device = torch.device(device)
    legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(
        lr=0.001)
    legacy_trainer = Legacy_ORTTrainer(model, None, legacy_model_desc,
                                       "LambOptimizer", None,
                                       learning_rate_description, device)
    legacy_losses = []
    for i in range(train_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        leg_loss = legacy_trainer.train_step(*sample_input, learning_rate)
        legacy_losses.append(leg_loss.cpu().item())

    # Check results
    _test_helpers.assert_model_outputs(experimental_losses,
                                       legacy_losses,
                                       True,
                                       rtol=1e-5)


def testORTTrainerLegacyAndExperimentalPrecisionLossScaler(seed, device):
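    """Train a PyTorch transformer model with mixed precision and dynamic loss
    scaling, then compare losses, prediction dtypes, and final ONNX weights
    between the experimental and legacy ORTTrainer APIs."""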
    # Common data
    total_steps = 5
    bptt = 35

    # Setup experimental API
    torch.manual_seed(seed)
    set_seed(seed)
    loss_scaler = amp.DynamicLossScaler()
    options = orttrainer.ORTTrainerOptions({'device' : {'id' : device},
                                            'mixed_precision' : {
                                                'enabled' : True,
                                                'loss_scaler' : loss_scaler},
                                            'debug' : {'deterministic_compute' : True,}})
    model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _load_pytorch_transformer_model(device)
    optim_config = optim.LambConfig(lr=0.001)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options)
    # Training loop
    experimental_loss = []
    experimental_preds_dtype = []
    for i in range(total_steps):
        data, targets = batcher_fn(train_data, i)
        exp_loss, exp_preds = trainer.train_step(data, targets)
        experimental_loss.append(exp_loss.cpu())
        experimental_preds_dtype.append(exp_preds.dtype)

    # Setup legacy API
    torch.manual_seed(seed)
    set_seed(seed)
    model, (model_desc, lr_desc), _, _, _, _, _ = _load_pytorch_transformer_model(device, legacy_api=True)
    loss_scaler = Legacy_LossScaler('ort_test_input_loss_scalar', True)
    legacy_trainer = Legacy_ORTTrainer(model, my_loss, model_desc, "LambOptimizer",
                                       None, lr_desc, device=device,
                                       _use_deterministic_compute=True,
                                       use_mixed_precision=True,
                                       loss_scaler=loss_scaler)
    # Training loop
    legacy_loss = []
    legacy_preds_dtype = []
    for i in range(total_steps):
        data, targets = batcher_fn(train_data, i)
        leg_loss, leg_preds = legacy_trainer.train_step(data, targets, torch.tensor([optim_config.lr]))
        legacy_loss.append(leg_loss.cpu())
        legacy_preds_dtype.append(leg_preds.dtype)

    # Compare legacy vs experimental APIs
    assert experimental_preds_dtype == legacy_preds_dtype
    _test_helpers.assert_legacy_onnx_weights(trainer, legacy_trainer, rtol=1e-4, atol=1e-2)
    _test_helpers.assert_model_outputs(legacy_loss, experimental_loss, rtol=1e-4)


def testORTTrainerLegacyAndExperimentalGradientAccumulation(seed, device, gradient_accumulation_steps, total_steps):
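    """Train a PyTorch transformer model with gradient accumulation and compare
    per-step losses between the experimental and legacy ORTTrainer APIs
    (rtol=1e-6). The arguments are expected to be supplied by the test harness."""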
    # Common data
    torch.set_printoptions(precision=10)

    # Setup experimental API
    torch.manual_seed(seed)
    set_seed(seed)
    options = orttrainer.ORTTrainerOptions({'device' : {'id' : device},
                                            'batch' : {'gradient_accumulation_steps' : gradient_accumulation_steps},
                                            'debug' : {'deterministic_compute' : True}})
    model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _load_pytorch_transformer_model(device)
    optim_config = optim.LambConfig(lr=0.001)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options)
    # Training loop
    experimental_loss = []
    for i in range(total_steps):
        data, targets = batcher_fn(train_data, i)
        exp_loss, exp_preds = trainer.train_step(data, targets)
        experimental_loss.append(exp_loss.cpu())

    # Setup legacy API
    torch.manual_seed(seed)
    set_seed(seed)
    model, (model_desc, lr_desc), _, _, _, _, _ = _load_pytorch_transformer_model(device, legacy_api=True)
    legacy_trainer = Legacy_ORTTrainer(model, my_loss, model_desc, "LambOptimizer",
                                       None, lr_desc, device=device,
                                       _use_deterministic_compute=True,
                                       gradient_accumulation_steps=gradient_accumulation_steps)
    # Training loop
    legacy_loss = []
    for i in range(total_steps):
        data, targets = batcher_fn(train_data, i)
        leg_loss, leg_preds = legacy_trainer.train_step(data, targets, torch.tensor([optim_config.lr]))
        legacy_loss.append(leg_loss.cpu())

    # Compare legacy vs experimental APIs
    _test_helpers.assert_model_outputs(legacy_loss, experimental_loss, rtol=1e-6)


def testORTTrainerLegacyAndExperimentalWeightsCheck(seed, device):
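    """Train a PyTorch transformer model for a few steps with both the
    experimental and legacy ORTTrainer APIs and check that the resulting ONNX
    weights match."""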
    # Common data
    total_steps = 5
    bptt = 35

    # Setup for the experimental ORTTrainer run
    torch.manual_seed(seed)
    set_seed(seed)
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({
        'device' : {
            'id' : device
        },
        'debug' : {
            'deterministic_compute': True
        },
    })
    model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _load_pytorch_transformer_model(device)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=opts)
    # Training loop
    for i in range(total_steps):
        data, targets = batcher_fn(train_data, i)
        _ = trainer.train_step(data, targets)

    # Setup for the legacy ORTTrainer run
    torch.manual_seed(seed)
    set_seed(seed)
    model, (model_desc, lr_desc), _, _, _, _, _ = _load_pytorch_transformer_model(device, legacy_api=True)
    legacy_trainer = Legacy_ORTTrainer(model, my_loss, model_desc, "LambOptimizer", None, lr_desc,
                                       device, _use_deterministic_compute=True)
    # Training loop
    for i in range(total_steps):
        data, targets = batcher_fn(train_data, i)
        _, _ = legacy_trainer.train_step(data, targets, torch.tensor([optim_config.lr]))

    # Compare legacy vs experimental APIs
    _test_helpers.assert_legacy_onnx_weights(trainer, legacy_trainer, rtol=1e-4)


def testToyBERTModelLegacyExperimentalLRScheduler(initial_lr, lr_scheduler,
                                                  legacy_lr_scheduler):
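    """Train a toy BERT model for 128 steps with an LR scheduler, checking at
    every step that the experimental scheduler's learning rate matches the legacy
    scheduler, and that the per-step losses match between the two APIs."""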
    ############################################################################
    # These tests require hard-coded values for 'total_steps' and 'initial_lr' #
    ############################################################################

    # Common setup
    total_steps = 128
    device = 'cuda'
    seed = 1
    warmup = 0.05
    cycles = 0.5
    power = 1.
    lr_end = 1e-7

    # Setup both Experimental and Legacy LR Schedulers before the experimental loop
    if legacy_lr_scheduler == _test_commons.legacy_constant_lr_scheduler or legacy_lr_scheduler == _test_commons.legacy_linear_lr_scheduler:
        legacy_lr_scheduler = partial(legacy_lr_scheduler,
                                      initial_lr=initial_lr,
                                      total_steps=total_steps,
                                      warmup=warmup)
    elif legacy_lr_scheduler == _test_commons.legacy_cosine_lr_scheduler:
        legacy_lr_scheduler = partial(legacy_lr_scheduler,
                                      initial_lr=initial_lr,
                                      total_steps=total_steps,
                                      warmup=warmup,
                                      cycles=cycles)
    elif legacy_lr_scheduler == _test_commons.legacy_poly_lr_scheduler:
        legacy_lr_scheduler = partial(legacy_lr_scheduler,
                                      initial_lr=initial_lr,
                                      total_steps=total_steps,
                                      warmup=warmup,
                                      power=power,
                                      lr_end=lr_end)
    else:
        raise RuntimeError("Invalid legacy_lr_scheduler")
    if lr_scheduler == optim.lr_scheduler.ConstantWarmupLRScheduler or lr_scheduler == optim.lr_scheduler.LinearWarmupLRScheduler:
        lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup)
    elif lr_scheduler == optim.lr_scheduler.CosineWarmupLRScheduler:
        lr_scheduler = lr_scheduler(total_steps=total_steps,
                                    warmup=warmup,
                                    cycles=cycles)
    elif lr_scheduler == optim.lr_scheduler.PolyWarmupLRScheduler:
        lr_scheduler = lr_scheduler(total_steps=total_steps,
                                    warmup=warmup,
                                    power=power,
                                    lr_end=lr_end)
    else:
        raise RuntimeError("Invalid lr_scheduler")

    # EXPERIMENTAL API
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    optim_config = optim.AdamConfig(lr=initial_lr)
    opts = orttrainer.ORTTrainerOptions({
        'debug': {
            'deterministic_compute': True
        },
        'device': {
            'id': device,
        },
        'lr_scheduler': lr_scheduler
    })
    trainer = orttrainer.ORTTrainer(model,
                                    model_desc,
                                    optim_config,
                                    options=opts)
    experimental_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        experimental_losses.append(
            trainer.train_step(*sample_input).cpu().item())
        assert_allclose(trainer.options.lr_scheduler.get_last_lr()[0],
                        legacy_lr_scheduler(i))

    # LEGACY IMPLEMENTATION
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    device = torch.device(device)
    model = load_bert_onnx_model()
    legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(
        initial_lr)
    legacy_trainer = Legacy_ORTTrainer(model,
                                       None,
                                       legacy_model_desc,
                                       "AdamOptimizer",
                                       None,
                                       learning_rate_description,
                                       device,
                                       _use_deterministic_compute=True,
                                       get_lr_this_step=legacy_lr_scheduler)
    legacy_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        leg_loss = legacy_trainer.train_step(*sample_input)
        legacy_losses.append(leg_loss.cpu().item())

    # Check results
    _test_helpers.assert_model_outputs(experimental_losses, legacy_losses)


def testORTTrainerLegacyAndExperimentalLRScheduler(seed, device, optimizer_config, lr_scheduler, get_lr_this_step):
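    """Train a PyTorch transformer model with an LR scheduler and compare per-step
    losses between the experimental and legacy ORTTrainer APIs; the optimizer and
    scheduler combinations are expected to be supplied by the test harness."""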
    # Common data
    total_steps = 10
    lr = 0.001
    warmup = 0.5
    cycles = 0.5
    power = 1.
    lr_end = 1e-7
    torch.set_printoptions(precision=10)

    # Setup experimental API
    torch.manual_seed(seed)
    set_seed(seed)
    if lr_scheduler == optim.lr_scheduler.ConstantWarmupLRScheduler or lr_scheduler == optim.lr_scheduler.LinearWarmupLRScheduler:
        lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup)
    elif lr_scheduler == optim.lr_scheduler.CosineWarmupLRScheduler:
        lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup, cycles=cycles)
    elif lr_scheduler == optim.lr_scheduler.PolyWarmupLRScheduler:
        lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup, power=power, lr_end=lr_end)
    else:
        raise RuntimeError("Invalid lr_scheduler")

    options = orttrainer.ORTTrainerOptions({'device' : {'id' : device},
                                            'debug' : {'deterministic_compute' : True},
                                            'lr_scheduler' : lr_scheduler})
    model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _load_pytorch_transformer_model(device)
    optim_config = optimizer_config(lr=lr)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options)
    # Training loop
    experimental_loss = []
    for i in range(total_steps):
        data, targets = batcher_fn(train_data, i)
        exp_loss, exp_preds = trainer.train_step(data, targets)
        experimental_loss.append(exp_loss.cpu())

    # Setup legacy API
    torch.manual_seed(seed)
    set_seed(seed)

    if optimizer_config == optim.AdamConfig:
        legacy_optimizer_config = 'AdamOptimizer'
    elif optimizer_config == optim.LambConfig:
        legacy_optimizer_config = 'LambOptimizer'
    elif optimizer_config == optim.SGDConfig:
        legacy_optimizer_config = 'SGDOptimizer'
    else:
        raise RuntimeError("Invalid optimizer_config")

    if get_lr_this_step == _test_commons.legacy_constant_lr_scheduler or get_lr_this_step == _test_commons.legacy_linear_lr_scheduler:
        get_lr_this_step = partial(get_lr_this_step, initial_lr=lr, total_steps=total_steps, warmup=warmup)
    elif get_lr_this_step == _test_commons.legacy_cosine_lr_scheduler:
        get_lr_this_step = partial(get_lr_this_step, initial_lr=lr, total_steps=total_steps, warmup=warmup, cycles=cycles)
    elif get_lr_this_step == _test_commons.legacy_poly_lr_scheduler:
        get_lr_this_step = partial(get_lr_this_step, initial_lr=lr, total_steps=total_steps, warmup=warmup, power=power, lr_end=lr_end)
    else:
        raise RuntimeError("Invalid get_lr_this_step")

    model, (model_desc, lr_desc), _, _, _, _, _ = _load_pytorch_transformer_model(device, legacy_api=True)
    legacy_trainer = Legacy_ORTTrainer(model, my_loss, model_desc, legacy_optimizer_config,
                                       None, lr_desc, device=device,
                                       _use_deterministic_compute=True,
                                       get_lr_this_step=get_lr_this_step)
    # Training loop
    legacy_loss = []
    for i in range(total_steps):
        data, targets = batcher_fn(train_data, i)
        leg_loss, leg_preds = legacy_trainer.train_step(data, targets)
        legacy_loss.append(leg_loss.cpu())

    # Compare legacy vs experimental APIs
    _test_helpers.assert_model_outputs(legacy_loss, experimental_loss)


def testToyBERTModelLegacyExperimentalLRScheduler(initial_lr, lr_scheduler,
                                                  legacy_lr_scheduler):
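    """10-step variant of the LR scheduler comparison on the toy BERT model, using
    warmup=0.5 and checking the per-step learning rate against the legacy
    scheduler before comparing losses."""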
    ############################################################################
    # These tests require hard-coded values for 'total_steps' and 'initial_lr' #
    ############################################################################

    # Common setup
    total_steps = 10
    device = 'cuda'
    seed = 1

    # EXPERIMENTAL API
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    optim_config = optim.AdamConfig(lr=initial_lr)
    opts = orttrainer.ORTTrainerOptions({
        'debug': {
            'deterministic_compute': True
        },
        'device': {
            'id': device,
        },
        'lr_scheduler':
        lr_scheduler(total_steps=total_steps, warmup=0.5)
    })
    trainer = orttrainer.ORTTrainer(model,
                                    model_desc,
                                    optim_config,
                                    options=opts)
    experimental_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        experimental_losses.append(
            trainer.train_step(*sample_input).cpu().item())
        assert_allclose(trainer.options.lr_scheduler.get_last_lr()[0],
                        legacy_lr_scheduler(i))

    # LEGACY IMPLEMENTATION
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    device = torch.device(device)
    legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(
        initial_lr)
    legacy_trainer = Legacy_ORTTrainer(model,
                                       None,
                                       legacy_model_desc,
                                       "AdamOptimizer",
                                       None,
                                       learning_rate_description,
                                       device,
                                       _use_deterministic_compute=True,
                                       get_lr_this_step=legacy_lr_scheduler)
    legacy_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        leg_loss = legacy_trainer.train_step(*sample_input)
        legacy_losses.append(leg_loss.cpu().item())

    # Check results
    _test_helpers.assert_model_outputs(experimental_losses, legacy_losses)