Example #1
    def testWrapModelLossFnStateDict(self):
        torch.manual_seed(1)
        device = torch.device("cuda")
        class LinearModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(2, 4)
            def forward(self, y=None, x=None):
                if y is not None:
                    return self.linear(x) + y
                else:
                    return self.linear(x) + torch.ones(2, 4)

        pt_model = LinearModel()
        data = torch.randn(2, 2)
        label = torch.tensor([0, 1], dtype=torch.int64)
        input_desc = IODescription('x', [2, 2], torch.float32)
        label_desc = IODescription('label', [2, ], torch.int64, num_classes=4)
        output_desc = IODescription('output', [2, 4], torch.float32)
        loss_desc = IODescription('loss', [], torch.float32)
        model_desc = ModelDescription([input_desc, label_desc], [loss_desc, output_desc])
        def loss_fn(x, label):
            return F.nll_loss(F.log_softmax(x, dim=1), label)

        def get_lr_this_step(global_step):
            learningRate = 0.02
            return torch.tensor([learningRate])

        ort_trainer = ORTTrainer(
            pt_model, loss_fn, model_desc, "SGDOptimizer", None,
            IODescription('Learning_Rate', [1, ], torch.float32), device,
            get_lr_this_step=get_lr_this_step)
        ort_trainer.train_step(x=data, label=label)
        state_dict = ort_trainer.state_dict()
        assert state_dict.keys() == {'linear.bias', 'linear.weight'}
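For comparison, a minimal pure-PyTorch sketch of the same state_dict expectation (an illustration, not part of the original test): torch.nn.Module.state_dict() exposes the parameter names that the ORTTrainer state_dict above is checked against.

import torch

class TinyLinearModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(2, 4)

# The wrapper prefixes its submodule name, giving 'linear.weight' / 'linear.bias'.
assert set(TinyLinearModel().state_dict().keys()) == {'linear.weight', 'linear.bias'}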
def testToyBERTModelMixedPrecisionLossScalerLegacyExperimental(
        loss_scaler, legacy_loss_scaler):
    # Common setup
    total_steps = 128
    device = "cuda"
    seed = 1

    # EXPERIMENTAL IMPLEMENTATION
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    optim_config = optim.AdamConfig(lr=0.001)
    opts = orttrainer.ORTTrainerOptions({
        'debug': {
            'deterministic_compute': True
        },
        'device': {
            'id': device,
        },
        'mixed_precision': {
            'enabled': True,
            'loss_scaler': loss_scaler
        }
    })
    trainer = orttrainer.ORTTrainer(model,
                                    model_desc,
                                    optim_config,
                                    options=opts)
    experimental_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        experimental_losses.append(
            trainer.train_step(*sample_input).cpu().item())

    # LEGACY IMPLEMENTATION
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    device = torch.device(device)
    model = load_bert_onnx_model()
    legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(
        optim_config.lr)
    legacy_trainer = Legacy_ORTTrainer(model,
                                       None,
                                       legacy_model_desc,
                                       "AdamOptimizer",
                                       None,
                                       learning_rate_description,
                                       device,
                                       _use_deterministic_compute=True,
                                       use_mixed_precision=True,
                                       loss_scaler=legacy_loss_scaler)
    legacy_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        leg_loss = legacy_trainer.train_step(*sample_input, learning_rate)
        legacy_losses.append(leg_loss.cpu().item())

    # Check results
    _test_helpers.assert_model_outputs(experimental_losses, legacy_losses)
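The test above receives loss_scaler and legacy_loss_scaler as parameters; a hypothetical pytest parametrization (the actual decorator is not part of this snippet) could pair the experimental and legacy dynamic scalers already used in these examples:

import pytest
from onnxruntime.training import amp
# Assumed legacy import path; these examples alias the legacy LossScaler as Legacy_LossScaler.
from onnxruntime.capi.ort_trainer import LossScaler as Legacy_LossScaler

@pytest.mark.parametrize("loss_scaler, legacy_loss_scaler", [
    (amp.DynamicLossScaler(), Legacy_LossScaler('ort_test_input_loss_scalar', True)),
])
def testToyBERTModelMixedPrecisionLossScalerLegacyExperimental(loss_scaler, legacy_loss_scaler):
    ...  # body as defined above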
def testToyBERTModelGradientAccumulationLegacyExperimental(
        gradient_accumulation_steps):
    # Common setup
    total_steps = 10
    device = "cuda"
    seed = 1

    # EXPERIMENTAL IMPLEMENTATION
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({
        'debug': {
            'deterministic_compute': True
        },
        'device': {
            'id': device,
        },
        'batch': {
            'gradient_accumulation_steps': gradient_accumulation_steps
        },
    })
    trainer = orttrainer.ORTTrainer(model,
                                    model_desc,
                                    optim_config,
                                    options=opts)
    experimental_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        loss = trainer.train_step(*sample_input)
        experimental_losses.append(loss.cpu().item())

    # LEGACY IMPLEMENTATION
    device = torch.device(device)
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(
        optim_config.lr)
    legacy_trainer = Legacy_ORTTrainer(
        model,
        None,
        legacy_model_desc,
        "LambOptimizer",
        None,
        learning_rate_description,
        device,
        _use_deterministic_compute=True,
        gradient_accumulation_steps=gradient_accumulation_steps)
    legacy_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        leg_loss = legacy_trainer.train_step(*sample_input, learning_rate)
        legacy_losses.append(leg_loss.cpu().item())

    # Check results
    _test_helpers.assert_model_outputs(experimental_losses,
                                       legacy_losses,
                                       rtol=1e-6)
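For reference, gradient accumulation in plain PyTorch works as in the short sketch below (illustrative only, unrelated to the ORT internals exercised above): gradients from several micro-batches are summed before a single optimizer step, which is what the 'batch' / 'gradient_accumulation_steps' option asks the trainer to do internally.

import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
accumulation_steps = 4

for step in range(8):
    x = torch.randn(2, 4)
    # Scale the loss so the accumulated gradient matches one large batch.
    loss = model(x).sum() / accumulation_steps
    loss.backward()  # gradients accumulate in .grad across micro-batches
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()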
    def get_onnx_model(self,
                       model,
                       model_desc,
                       inputs,
                       device,
                       _enable_internal_postprocess=True,
                       _extra_postprocess=None):
        lr_desc = IODescription('Learning_Rate', [1], torch.float32)
        model = ORTTrainer(
            model,
            None,
            model_desc,
            "LambOptimizer",
            map_optimizer_attributes,
            lr_desc,
            device,
            world_rank=0,
            world_size=1,
            _opset_version=12,
            _enable_internal_postprocess=_enable_internal_postprocess,
            _extra_postprocess=_extra_postprocess)

        train_output = model.train_step(*inputs)
        return model.onnx_model_
def testToyBERTModelLegacyExperimentalBasicTraining(optimizer_config):
    # Common setup
    train_steps = 512

    device = 'cuda'
    seed = 1
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)

    # EXPERIMENTAL API
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    opts = orttrainer.ORTTrainerOptions({
        'debug' : {
            'deterministic_compute': True
        },
        'device': {
            'id': device,
        },
    })
    optim_config = optimizer_config(lr=0.01)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)
    experimental_losses = []
    for i in range(train_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        experimental_losses.append(trainer.train_step(*sample_input).cpu().item())

    # LEGACY IMPLEMENTATION
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)

    if optimizer_config == optim.AdamConfig:
        legacy_optimizer = 'AdamOptimizer'
    elif optimizer_config == optim.LambConfig:
        legacy_optimizer = 'LambOptimizer'
    elif optimizer_config == optim.SGDConfig:
        legacy_optimizer = 'SGDOptimizer'
    else:
        raise RuntimeError("Invalid optimizer_config")

    device = torch.device(device)
    model = load_bert_onnx_model()
    legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(lr=optim_config.lr)
    legacy_trainer = Legacy_ORTTrainer(model, None, legacy_model_desc, legacy_optimizer,
                       None,
                       learning_rate_description,
                       device,
                       _use_deterministic_compute=True)
    legacy_losses = []
    for i in range(train_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        leg_loss = legacy_trainer.train_step(*sample_input, learning_rate)
        legacy_losses.append(leg_loss.cpu().item())

    # Check results
    _test_helpers.assert_model_outputs(experimental_losses, legacy_losses, True)
def testToyBERTModelLegacyExperimentalCustomOptimParameters(params, legacy_optim_map):
    # Common setup
    total_steps = 128
    device = "cuda"
    seed = 1

    # EXPERIMENTAL API
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    model_desc = bert_model_description()
    model = load_bert_onnx_model()

    optim_config = optim.AdamConfig(
        params, alpha=0.9, beta=0.999, lambda_coef=0.01, epsilon=1e-6, do_bias_correction=False
    )
    opts = orttrainer.ORTTrainerOptions(
        {
            "debug": {"deterministic_compute": True},
            "device": {
                "id": device,
            },
        }
    )
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)

    experimental_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        experimental_losses.append(trainer.train_step(*sample_input).cpu().item())

    # LEGACY IMPLEMENTATION
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    device = torch.device(device)
    model = load_bert_onnx_model()
    legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(trainer.optim_config.lr)

    legacy_trainer = Legacy_ORTTrainer(
        model,
        None,
        legacy_model_desc,
        "AdamOptimizer",
        legacy_optim_map,
        learning_rate_description,
        device,
        _use_deterministic_compute=True,
    )
    legacy_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        legacy_sample_input = [*sample_input, learning_rate]
        legacy_losses.append(legacy_trainer.train_step(legacy_sample_input).cpu().item())

    # Check results
    _test_helpers.assert_model_outputs(experimental_losses, legacy_losses)
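The legacy_optim_map argument above fills the slot that other examples in this file pass as map_optimizer_attributes, i.e. a callable mapping a parameter name to its optimizer attributes. Below is a hypothetical map mirroring the AdamConfig values used in this test; the bias/LayerNorm grouping rule is an assumption for illustration.

def legacy_optim_map(name):
    # Hypothetical rule: no weight decay for bias and LayerNorm weights.
    no_decay = "bias" in name or "LayerNorm.weight" in name
    return {
        "alpha": 0.9,
        "beta": 0.999,
        "lambda": 0.0 if no_decay else 0.01,
        "epsilon": 1e-6,
    }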
def testToyBERTModelLegacyExperimentalBasicTraining():
    # Common setup
    train_steps = 10
    device = 'cuda'
    seed = 1
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)

    # EXPERIMENTAL API
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    params = optimizer_parameters(model)
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({
        'debug': {
            'deterministic_compute': True
        },
        'device': {
            'id': device,
        },
    })
    trainer = orttrainer.ORTTrainer(model,
                                    model_desc,
                                    optim_config,
                                    options=opts)
    experimental_losses = []
    for i in range(train_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        experimental_losses.append(
            trainer.train_step(*sample_input).cpu().item())

    # LEGACY IMPLEMENTATION
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    device = torch.device(device)
    legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(
        lr=0.001)
    legacy_trainer = Legacy_ORTTrainer(model, None, legacy_model_desc,
                                       "LambOptimizer", None,
                                       learning_rate_description, device)
    legacy_losses = []
    for i in range(train_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        leg_loss = legacy_trainer.train_step(*sample_input, learning_rate)
        legacy_losses.append(leg_loss.cpu().item())

    # Check results
    _test_helpers.assert_model_outputs(experimental_losses,
                                       legacy_losses,
                                       True,
                                       rtol=1e-5)
Example #8
def testORTTrainerLegacyAndExperimentalPrecisionLossScaler(seed, device):
    # Common data
    total_steps = 5
    bptt = 35

    # Setup experimental API
    torch.manual_seed(seed)
    set_seed(seed)
    loss_scaler = amp.DynamicLossScaler()
    options = orttrainer.ORTTrainerOptions({'device' : {'id' : device},
                                            'mixed_precision' : {
                                                'enabled' : True,
                                                'loss_scaler' : loss_scaler},
                                            'debug' : {'deterministic_compute' : True,}})
    model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _load_pytorch_transformer_model(device)
    optim_config = optim.LambConfig(lr=0.001)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options)
    # Training loop
    experimental_loss = []
    experimental_preds_dtype = []
    for i in range(total_steps):
        data, targets = batcher_fn(train_data, i)
        exp_loss, exp_preds = trainer.train_step(data, targets)
        experimental_loss.append(exp_loss.cpu())
        experimental_preds_dtype.append(exp_preds.dtype)

    # Setup legacy API
    torch.manual_seed(seed)
    set_seed(seed)
    model, (model_desc, lr_desc), _, _, _, _, _ = _load_pytorch_transformer_model(device, legacy_api=True)
    loss_scaler = Legacy_LossScaler('ort_test_input_loss_scalar', True)
    legacy_trainer = Legacy_ORTTrainer(model, my_loss, model_desc, "LambOptimizer",
                                       None, lr_desc, device=device,
                                       _use_deterministic_compute=True,
                                       use_mixed_precision=True,
                                       loss_scaler=loss_scaler)
    # Training loop
    legacy_loss = []
    legacy_preds_dtype = []
    for i in range(total_steps):
        data, targets = batcher_fn(train_data, i)
        leg_loss, leg_preds = legacy_trainer.train_step(data, targets, torch.tensor([optim_config.lr]))
        legacy_loss.append(leg_loss.cpu())
        legacy_preds_dtype.append(leg_preds.dtype)

    # Compare legacy vs experimental APIs
    assert experimental_preds_dtype == legacy_preds_dtype
    _test_helpers.assert_legacy_onnx_weights(trainer, legacy_trainer, rtol=1e-4, atol=1e-2)
    _test_helpers.assert_model_outputs(legacy_loss, experimental_loss, rtol=1e-4)
Example #9
def testORTTrainerLegacyAndExperimentalGradientAccumulation(seed, device, gradient_accumulation_steps, total_steps):
    # Common data
    torch.set_printoptions(precision=10)

    # Setup experimental API
    torch.manual_seed(seed)
    set_seed(seed)
    options = orttrainer.ORTTrainerOptions({'device' : {'id' : device},
                                            'batch' : {'gradient_accumulation_steps' : gradient_accumulation_steps},
                                            'debug' : {'deterministic_compute' : True}})
    model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _load_pytorch_transformer_model(device)
    optim_config = optim.LambConfig(lr=0.001)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options)
    # Training loop
    experimental_loss = []
    for i in range(total_steps):
        data, targets = batcher_fn(train_data, i)
        exp_loss, exp_preds = trainer.train_step(data, targets)
        experimental_loss.append(exp_loss.cpu())

    # Setup legacy API
    torch.manual_seed(seed)
    set_seed(seed)
    model, (model_desc, lr_desc), _, _, _, _, _ = _load_pytorch_transformer_model(device, legacy_api=True)
    legacy_trainer = Legacy_ORTTrainer(model, my_loss, model_desc, "LambOptimizer",
                                       None, lr_desc, device=device,
                                       _use_deterministic_compute=True,
                                       gradient_accumulation_steps=gradient_accumulation_steps)
    # Training loop
    legacy_loss = []
    for i in range(total_steps):
        data, targets = batcher_fn(train_data, i)
        leg_loss, leg_preds = legacy_trainer.train_step(data, targets, torch.tensor([optim_config.lr]))
        legacy_loss.append(leg_loss.cpu())

    # Compare legacy vs experimental APIs
    _test_helpers.assert_model_outputs(legacy_loss, experimental_loss, rtol=1e-6)
Example #10
def testORTTrainerLegacyAndExperimentalWeightsCheck(seed, device):
    # Common data
    total_steps = 5
    bptt = 35

    # Setup for the experimental ORTTrainer run
    torch.manual_seed(seed)
    set_seed(seed)
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({
        'device' : {
            'id' : device
        },
        'debug' : {
            'deterministic_compute': True
        },
    })
    model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _load_pytorch_transformer_model(device)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=opts)
    # Training loop
    for i in range(total_steps):
        data, targets = batcher_fn(train_data, i)
        _ = trainer.train_step(data, targets)

    # Setup for the legacy ORTTrainer run
    torch.manual_seed(seed)
    set_seed(seed)
    model, (model_desc, lr_desc), _, _, _, _, _ = _load_pytorch_transformer_model(device, legacy_api=True)
    legacy_trainer = Legacy_ORTTrainer(model, my_loss, model_desc, "LambOptimizer", None, lr_desc,
                                       device, _use_deterministic_compute=True)
    # Training loop
    for i in range(total_steps):
        data, targets = batcher_fn(train_data, i)
        _, _ = legacy_trainer.train_step(data, targets, torch.tensor([optim_config.lr]))

    # Compare legacy vs experimental APIs
    _test_helpers.assert_legacy_onnx_weights(trainer, legacy_trainer, rtol=1e-4)
def testToyBERTModelLegacyExperimentalLRScheduler(initial_lr, lr_scheduler,
                                                  legacy_lr_scheduler):
    ############################################################################
    # These tests require hard-coded values for 'total_steps' and 'initial_lr' #
    ############################################################################

    # Common setup
    total_steps = 128
    device = 'cuda'
    seed = 1
    warmup = 0.05
    cycles = 0.5
    power = 1.
    lr_end = 1e-7

    # Setup both Experimental and Legacy LR Schedulers before the experimental loop
    if legacy_lr_scheduler == _test_commons.legacy_constant_lr_scheduler or legacy_lr_scheduler == _test_commons.legacy_linear_lr_scheduler:
        legacy_lr_scheduler = partial(legacy_lr_scheduler,
                                      initial_lr=initial_lr,
                                      total_steps=total_steps,
                                      warmup=warmup)
    elif legacy_lr_scheduler == _test_commons.legacy_cosine_lr_scheduler:
        legacy_lr_scheduler = partial(legacy_lr_scheduler,
                                      initial_lr=initial_lr,
                                      total_steps=total_steps,
                                      warmup=warmup,
                                      cycles=cycles)
    elif legacy_lr_scheduler == _test_commons.legacy_poly_lr_scheduler:
        legacy_lr_scheduler = partial(legacy_lr_scheduler,
                                      initial_lr=initial_lr,
                                      total_steps=total_steps,
                                      warmup=warmup,
                                      power=power,
                                      lr_end=lr_end)
    else:
        raise RuntimeError("Invalid legacy_lr_scheduler")
    if lr_scheduler == optim.lr_scheduler.ConstantWarmupLRScheduler or lr_scheduler == optim.lr_scheduler.LinearWarmupLRScheduler:
        lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup)
    elif lr_scheduler == optim.lr_scheduler.CosineWarmupLRScheduler:
        lr_scheduler = lr_scheduler(total_steps=total_steps,
                                    warmup=warmup,
                                    cycles=cycles)
    elif lr_scheduler == optim.lr_scheduler.PolyWarmupLRScheduler:
        lr_scheduler = lr_scheduler(total_steps=total_steps,
                                    warmup=warmup,
                                    power=power,
                                    lr_end=lr_end)
    else:
        raise RuntimeError("Invalid lr_scheduler")

    # EXPERIMENTAL API
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    optim_config = optim.AdamConfig(lr=initial_lr)
    opts = orttrainer.ORTTrainerOptions({
        'debug': {
            'deterministic_compute': True
        },
        'device': {
            'id': device,
        },
        'lr_scheduler': lr_scheduler
    })
    trainer = orttrainer.ORTTrainer(model,
                                    model_desc,
                                    optim_config,
                                    options=opts)
    experimental_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        experimental_losses.append(
            trainer.train_step(*sample_input).cpu().item())
        assert_allclose(trainer.options.lr_scheduler.get_last_lr()[0],
                        legacy_lr_scheduler(i))

    # LEGACY IMPLEMENTATION
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    device = torch.device(device)
    model = load_bert_onnx_model()
    legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(
        initial_lr)
    legacy_trainer = Legacy_ORTTrainer(model,
                                       None,
                                       legacy_model_desc,
                                       "AdamOptimizer",
                                       None,
                                       learning_rate_description,
                                       device,
                                       _use_deterministic_compute=True,
                                       get_lr_this_step=legacy_lr_scheduler)
    legacy_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        leg_loss = legacy_trainer.train_step(*sample_input)
        legacy_losses.append(leg_loss.cpu().item())

    # Check results
    _test_helpers.assert_model_outputs(experimental_losses, legacy_losses)
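For reference, the legacy get_lr_this_step callables above are plain functions of the global step; the sketch below shows a generic linear warmup-then-decay rule of that shape (an illustration, not the _test_commons implementation):

def linear_warmup_lr(step, initial_lr, total_steps, warmup):
    # Warm up linearly to initial_lr over warmup * total_steps steps,
    # then decay linearly toward zero over the remaining steps.
    warmup_steps = warmup * total_steps
    if step < warmup_steps:
        return initial_lr * step / warmup_steps
    return initial_lr * max(0.0, (total_steps - step) / (total_steps - warmup_steps))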
Example #12
def run_test(model, model_desc, device, args, gradient_accumulation_steps,
             fp16, allreduce_post_accumulation, get_lr_this_step,
             use_internal_get_lr_this_step, loss_scaler,
             use_internal_loss_scaler, batch_args_option, dataset_len, epochs,
             use_new_api):
    dataloader = create_ort_test_dataloader(model_desc.inputs_,
                                            args.batch_size, args.seq_len,
                                            dataset_len, device)

    if use_new_api:
        assert use_internal_loss_scaler, 'new api should always use internal loss scaler'

        new_api_lr_scheduler = WrapLRScheduler(get_lr_this_step)

        new_api_loss_scaler = amp.DynamicLossScaler() if fp16 else None
        options = orttrainer.ORTTrainerOptions({
            'batch': {
                'gradient_accumulation_steps': gradient_accumulation_steps
            },
            'device': {
                'id': device
            },
            'mixed_precision': {
                'enabled': fp16,
                'loss_scaler': new_api_loss_scaler
            },
            'debug': {
                'deterministic_compute': True,
            },
            'utils': {
                'grad_norm_clip': True
            },
            'distributed': {
                'allreduce_post_accumulation': True
            },
            'lr_scheduler': new_api_lr_scheduler
        })

        param_optimizer = list(model.named_parameters())
        params = [{
            'params': [n for n, p in param_optimizer
                       if "bias" in n or "LayerNorm.weight" in n],
            "alpha": 0.9,
            "beta": 0.999,
            "lambda": 0.0,
            "epsilon": 1e-6
        }, {
            'params': [n for n, p in param_optimizer
                       if not ("bias" in n or "LayerNorm.weight" in n)],
            "alpha": 0.9,
            "beta": 0.999,
            "lambda": 0.0,
            "epsilon": 1e-6
        }]

        vocab_size = 99
        new_model_desc = {
            'inputs': [
                ('input_ids', ['batch', 'max_seq_len_in_batch']),
                ('attention_mask', ['batch', 'max_seq_len_in_batch']),
                ('token_type_ids', ['batch', 'max_seq_len_in_batch']),
                ('masked_lm_labels', ['batch', 'max_seq_len_in_batch']),
                ('next_sentence_label', ['batch'])
            ],
            'outputs': [
                ('loss', [1], True),
                ('prediction_scores', ['batch', 'max_seq_len_in_batch', vocab_size]),
                ('seq_relationship_scores', ['batch', 2])
            ]
        }

        optim_config = optim.LambConfig(params=params, lr=2e-5)
        model = orttrainer.ORTTrainer(model,
                                      new_model_desc,
                                      optim_config,
                                      options=options)
        print("running with new frontend API")
    else:
        model = ORTTrainer(
            model,
            None,
            model_desc,
            "LambOptimizer",
            map_optimizer_attributes=map_optimizer_attributes,
            learning_rate_description=IODescription('Learning_Rate', [1], torch.float32),
            device=device,
            _enable_internal_postprocess=True,
            gradient_accumulation_steps=gradient_accumulation_steps,
            # BertLAMB default initial settings: b1=0.9, b2=0.999, e=1e-6
            world_rank=args.local_rank,
            world_size=args.world_size,
            use_mixed_precision=fp16,
            allreduce_post_accumulation=allreduce_post_accumulation,
            get_lr_this_step=get_lr_this_step
            if use_internal_get_lr_this_step else None,
            loss_scaler=loss_scaler if use_internal_loss_scaler else None,
            _opset_version=14,
            _use_deterministic_compute=True)
        print("running with old frontend API")

    # training loop
    eval_batch = None
    if not use_new_api:
        model.train()
    for epoch in range(epochs):
        for step, batch in enumerate(dataloader):
            if eval_batch is None:
                eval_batch = batch

            if not use_internal_get_lr_this_step:
                lr = get_lr_this_step(step)
                learning_rate = torch.tensor([lr])

            if not use_internal_loss_scaler and fp16:
                loss_scale = torch.tensor([loss_scaler.loss_scale_])

            if batch_args_option == BatchArgsOption.List:
                if not use_internal_get_lr_this_step:
                    batch = batch + [learning_rate]
                if not use_internal_loss_scaler and fp16:
                    batch = batch + [loss_scale]
                outputs = model.train_step(*batch)
            elif batch_args_option == BatchArgsOption.Dict:
                args, kwargs = split_batch(batch, model_desc.inputs_, 0)
                if not use_internal_get_lr_this_step:
                    kwargs['Learning_Rate'] = learning_rate
                if not use_internal_loss_scaler and fp16:
                    kwargs[model.loss_scale_input_name] = loss_scale
                outputs = model.train_step(*args, **kwargs)
            else:
                args_count = len(model_desc.inputs_) // 2  # approx. half args, half kwargs
                args, kwargs = split_batch(batch, model_desc.inputs_,
                                           args_count)
                if not use_internal_get_lr_this_step:
                    kwargs['Learning_Rate'] = learning_rate
                if not use_internal_loss_scaler and fp16:
                    kwargs[model.loss_scale_input_name] = loss_scale
                outputs = model.train_step(*args, **kwargs)

    # eval
    if batch_args_option == BatchArgsOption.List:
        outputs = model.eval_step(*batch)
    elif batch_args_option == BatchArgsOption.Dict:
        args, kwargs = split_batch(batch, model_desc.inputs_, 0)
        outputs = model.eval_step(*args, **kwargs)
    else:
        args_count = len(model_desc.inputs_) // 2  # approx. half args, half kwargs
        args, kwargs = split_batch(batch, model_desc.inputs_, args_count)
        outputs = model.eval_step(*args, **kwargs)

    return (output.cpu().numpy() for output in outputs)
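BatchArgsOption is referenced throughout run_test but not defined in this snippet; a minimal definition consistent with its three branches (all positional, all keyword, or a mix) would be:

from enum import Enum

class BatchArgsOption(Enum):
    List = 1         # pass the whole batch positionally to train_step/eval_step
    Dict = 2         # pass the whole batch as keyword arguments
    ListAndDict = 3  # split the batch between positional and keyword arguments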
def testORTTrainerLegacyAndExperimentalLRScheduler(seed, device, optimizer_config, lr_scheduler, get_lr_this_step):
    # Common data
    total_steps = 10
    lr = 0.001
    warmup = 0.5
    cycles = 0.5
    power = 1.
    lr_end = 1e-7
    torch.set_printoptions(precision=10)

    # Setup experimental API
    torch.manual_seed(seed)
    set_seed(seed)
    if lr_scheduler == optim.lr_scheduler.ConstantWarmupLRScheduler or lr_scheduler == optim.lr_scheduler.LinearWarmupLRScheduler:
        lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup)
    elif lr_scheduler == optim.lr_scheduler.CosineWarmupLRScheduler:
        lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup, cycles=cycles)
    elif lr_scheduler == optim.lr_scheduler.PolyWarmupLRScheduler:
        lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup, power=power, lr_end=lr_end)
    else:
        raise RuntimeError("Invalid lr_scheduler")

    options = orttrainer.ORTTrainerOptions({'device' : {'id' : device},
                                            'debug' : {'deterministic_compute' : True},
                                            'lr_scheduler' : lr_scheduler})
    model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _load_pytorch_transformer_model(device)
    optim_config = optimizer_config(lr=lr)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options)
    # Training loop
    experimental_loss = []
    for i in range(total_steps):
        data, targets = batcher_fn(train_data, i)
        exp_loss, exp_preds = trainer.train_step(data, targets)
        experimental_loss.append(exp_loss.cpu())

    # Setup legacy API
    torch.manual_seed(seed)
    set_seed(seed)

    if optimizer_config == optim.AdamConfig:
        legacy_optimizer_config = 'AdamOptimizer'
    elif optimizer_config == optim.LambConfig:
        legacy_optimizer_config = 'LambOptimizer'
    elif optimizer_config == optim.SGDConfig:
        legacy_optimizer_config = 'SGDOptimizer'
    else:
        raise RuntimeError("Invalid optimizer_config")

    if get_lr_this_step == _test_commons.legacy_constant_lr_scheduler or get_lr_this_step == _test_commons.legacy_linear_lr_scheduler:
        get_lr_this_step = partial(get_lr_this_step, initial_lr=lr, total_steps=total_steps, warmup=warmup)
    elif get_lr_this_step == _test_commons.legacy_cosine_lr_scheduler:
        get_lr_this_step = partial(get_lr_this_step, initial_lr=lr, total_steps=total_steps, warmup=warmup, cycles=cycles)
    elif get_lr_this_step == _test_commons.legacy_poly_lr_scheduler:
        get_lr_this_step = partial(get_lr_this_step, initial_lr=lr, total_steps=total_steps, warmup=warmup, power=power, lr_end=lr_end)
    else:
        raise RuntimeError("Invalid get_lr_this_step")

    model, (model_desc, lr_desc), _, _, _, _, _ = _load_pytorch_transformer_model(device, legacy_api=True)
    legacy_trainer = Legacy_ORTTrainer(model, my_loss, model_desc, legacy_optimizer_config,
                                       None, lr_desc, device=device,
                                       _use_deterministic_compute=True,
                                       get_lr_this_step=get_lr_this_step)
    # Training loop
    legacy_loss = []
    for i in range(total_steps):
        data, targets = batcher_fn(train_data, i)
        leg_loss, leg_preds = legacy_trainer.train_step(data, targets)
        legacy_loss.append(leg_loss.cpu())

    # Compare legacy vs experimental APIs
    _test_helpers.assert_model_outputs(legacy_loss, experimental_loss)
Example #14
def testToyBERTModelLegacyExperimentalLRScheduler(initial_lr, lr_scheduler,
                                                  legacy_lr_scheduler):
    ############################################################################
    # These tests require hard-coded values for 'total_steps' and 'initial_lr' #
    ############################################################################

    # Common setup
    total_steps = 10
    device = 'cuda'
    seed = 1

    # EXPERIMENTAL API
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    optim_config = optim.AdamConfig(lr=initial_lr)
    opts = orttrainer.ORTTrainerOptions({
        'debug': {
            'deterministic_compute': True
        },
        'device': {
            'id': device,
        },
        'lr_scheduler':
        lr_scheduler(total_steps=total_steps, warmup=0.5)
    })
    trainer = orttrainer.ORTTrainer(model,
                                    model_desc,
                                    optim_config,
                                    options=opts)
    experimental_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        experimental_losses.append(
            trainer.train_step(*sample_input).cpu().item())
        assert_allclose(trainer.options.lr_scheduler.get_last_lr()[0],
                        legacy_lr_scheduler(i))

    # LEGACY IMPLEMENTATION
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    device = torch.device(device)
    legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(
        initial_lr)
    legacy_trainer = Legacy_ORTTrainer(model,
                                       None,
                                       legacy_model_desc,
                                       "AdamOptimizer",
                                       None,
                                       learning_rate_description,
                                       device,
                                       _use_deterministic_compute=True,
                                       get_lr_this_step=legacy_lr_scheduler)
    legacy_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        leg_loss = legacy_trainer.train_step(*sample_input)
        legacy_losses.append(leg_loss.cpu().item())

    # Check results
    _test_helpers.assert_model_outputs(experimental_losses, legacy_losses)
def runBertTrainingTest(gradient_accumulation_steps,
                        use_mixed_precision,
                        allreduce_post_accumulation,
                        use_simple_model_desc=True,
                        use_internel_loss_scale=False):
    model_desc = bert_model_description()
    simple_model_desc = remove_extra_info(
        model_desc) if use_simple_model_desc else model_desc
    learning_rate_description = ort_trainer_learning_rate_description()
    device = torch.device("cuda", 0)

    torch.manual_seed(1)
    onnxruntime.set_seed(1)

    onnx_model = onnx.load(get_name("bert_toy_postprocessed.onnx"))

    loss_scaler = LossScaler("ort_test_input_loss_scalar",
                             True) if use_internel_loss_scale else None

    model = ORTTrainer(onnx_model,
                       None,
                       simple_model_desc,
                       "LambOptimizer",
                       map_optimizer_attributes,
                       learning_rate_description,
                       device,
                       postprocess_model=None,
                       gradient_accumulation_steps=gradient_accumulation_steps,
                       world_rank=0,
                       world_size=1,
                       loss_scaler=loss_scaler,
                       use_mixed_precision=use_mixed_precision,
                       allreduce_post_accumulation=allreduce_post_accumulation)

    if loss_scaler is None:
        loss_scaler = LossScaler(model.loss_scale_input_name, True)

    input_ids_batches = []
    segment_ids_batches = []
    input_mask_batches = []
    masked_lm_labels_batches = []
    next_sentence_labels_batches = []
    batch_size = 16
    num_batches = 8
    for batch in range(num_batches):
        input_ids_batches = [
            *input_ids_batches,
            generate_sample_batch(model_desc.inputs_[0], batch_size, device)
        ]
        segment_ids_batches = [
            *segment_ids_batches,
            generate_sample_batch(model_desc.inputs_[1], batch_size, device)
        ]
        input_mask_batches = [
            *input_mask_batches,
            generate_sample_batch(model_desc.inputs_[2], batch_size, device)
        ]
        masked_lm_labels_batches = [
            *masked_lm_labels_batches,
            generate_sample_batch(model_desc.inputs_[3], batch_size, device)
        ]
        next_sentence_labels_batches = [
            *next_sentence_labels_batches,
            generate_sample_batch(model_desc.inputs_[4], batch_size, device)
        ]

    lr_batch_list = [
        0.0000000e+00, 4.6012269e-07, 9.2024538e-07, 1.3803681e-06,
        1.8404908e-06, 2.3006135e-06, 2.7607362e-06, 3.2208588e-06,
        3.6809815e-06
    ]

    actual_losses = []
    actual_all_finites = []

    for batch_count in range(num_batches):
        input_ids = generate_sample_batch(model_desc.inputs_[0], batch_size,
                                          device)
        segment_ids = generate_sample_batch(model_desc.inputs_[1], batch_size,
                                            device)
        input_mask = generate_sample_batch(model_desc.inputs_[2], batch_size,
                                           device)
        masked_lm_labels = generate_sample_batch(model_desc.inputs_[3],
                                                 batch_size, device)
        next_sentence_labels = generate_sample_batch(model_desc.inputs_[4],
                                                     batch_size, device)
        lr = lr_batch_list[batch_count]

        learning_rate = torch.tensor([lr]).to(device)
        training_args = [
            input_ids, segment_ids, input_mask, masked_lm_labels,
            next_sentence_labels, learning_rate
        ]
        if use_mixed_precision:
            if not use_internel_loss_scale:
                loss_scale = torch.tensor([loss_scaler.loss_scale_]).to(device)
                training_args.append(loss_scale)
            actual_loss = model.train_step(*training_args)
            if isinstance(actual_loss, (list, tuple)):
                assert len(actual_loss) == 2
                actual_loss, actual_all_finite = actual_loss
                if not use_internel_loss_scale:
                    loss_scaler.update_loss_scale(actual_all_finite.item())
                    actual_all_finites = [
                        *actual_all_finites,
                        actual_all_finite.cpu().numpy().item(0)
                    ]

            actual_losses = [*actual_losses, actual_loss.cpu().numpy().item(0)]
        else:
            loss = model(*training_args)
            actual_losses = [*actual_losses, loss.cpu().numpy().item(0)]

        if batch_count == num_batches - 1:
            # Test the eval_step API with fetches at the end of training.
            # Calling eval_step during training would affect the actual training
            # loss, because the training session is stateful.
            eval_loss = model.eval_step(input_ids,
                                        segment_ids,
                                        input_mask,
                                        masked_lm_labels,
                                        next_sentence_labels,
                                        fetches=['loss'])
            eval_loss = eval_loss.cpu().numpy().item(0)

    # If using internal loss scale, all_finites are handled internally too.
    if use_mixed_precision and not use_internel_loss_scale:
        return actual_losses, actual_all_finites, eval_loss
    else:
        return actual_losses, eval_loss
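A hypothetical invocation of the helper above (the expected list length follows from num_batches = 8 inside the function; a CUDA device and the toy BERT ONNX model are required):

# Illustrative only: run the FP32 path and check one loss per batch plus a final eval loss.
losses, eval_loss = runBertTrainingTest(gradient_accumulation_steps=1,
                                        use_mixed_precision=False,
                                        allreduce_post_accumulation=False)
assert len(losses) == 8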