Code Example #1
    def test_bert(self):
        device = torch.device("cpu")

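        # create a BertModelTester and generate a config plus synthetic inputs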
        model_tester = BertModelTest.BertModelTester(self)
        (config, input_ids, token_type_ids, input_mask, sequence_labels,
         token_labels, choice_labels) = model_tester.prepare_config_and_inputs()

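        # build the PyTorch BertForPreTraining reference model and exercise one forward pass in eval mode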
        model = BertForPreTraining(config=config)
        model.eval()

        loss, prediction_scores, seq_relationship_score = model(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            masked_lm_labels=token_labels,
            next_sentence_label=sequence_labels)

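        # describe the ONNX model's inputs and outputs for the ORT training helpers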
        model_desc = ModelDescription([
            model_tester.input_ids_desc, model_tester.attention_mask_desc,
            model_tester.token_type_ids_desc,
            model_tester.masked_lm_labels_desc,
            model_tester.next_sentence_label_desc
        ], [
            model_tester.loss_desc, model_tester.prediction_scores_desc,
            model_tester.seq_relationship_scores_desc
        ])

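        # lightweight argument container consumed by the test helpers below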
        from collections import namedtuple
        MyArgs = namedtuple(
            "MyArgs",
            "local_rank world_size max_steps learning_rate warmup_proportion batch_size seq_len"
        )
        args = MyArgs(local_rank=0,
                      world_size=1,
                      max_steps=100,
                      learning_rate=0.00001,
                      warmup_proportion=0.01,
                      batch_size=13,
                      seq_len=7)

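        # synthetic dataloader that matches the declared model inputs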
        dataset_len = 100
        dataloader = create_ort_test_dataloader(model_desc.inputs_,
                                                args.batch_size, args.seq_len,
                                                dataset_len, device)
        # take one batch from the dataloader and append the learning-rate input
        batch = next(iter(dataloader))
        learning_rate = torch.tensor([1.00e+00]).to(device)
        inputs = batch + [learning_rate]

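        # convert the model to ONNX, applying the extra postprocess step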
        onnx_model = self.get_onnx_model(model,
                                         model_desc,
                                         inputs,
                                         device,
                                         _extra_postprocess=postprocess_model)

        self._bert_helper(onnx_model)
Code Example #2
def run_test(model, model_desc, device, args, gradient_accumulation_steps,
             fp16, allreduce_post_accumulation, get_lr_this_step,
             use_internal_get_lr_this_step, loss_scaler,
             use_internal_loss_scaler, batch_args_option, dataset_len, epochs,
             use_new_api):
    dataloader = create_ort_test_dataloader(model_desc.inputs_,
                                            args.batch_size, args.seq_len,
                                            dataset_len, device)

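    # the new frontend API is configured through ORTTrainerOptions; the old one
    # takes the equivalent settings as ORTTrainer constructor arguments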
    if use_new_api:
        assert use_internal_loss_scaler, 'new api should always use internal loss scaler'

        new_api_lr_scheduler = WrapLRScheduler(get_lr_this_step)

        new_api_loss_scaler = amp.DynamicLossScaler() if fp16 else None
        options = orttrainer.ORTTrainerOptions({
            'batch': {
                'gradient_accumulation_steps': gradient_accumulation_steps
            },
            'device': {
                'id': device
            },
            'mixed_precision': {
                'enabled': fp16,
                'loss_scaler': new_api_loss_scaler
            },
            'debug': {
                'deterministic_compute': True,
            },
            'utils': {
                'grad_norm_clip': True
            },
            'distributed': {
                'allreduce_post_accumulation': True
            },
            'lr_scheduler': new_api_lr_scheduler
        })

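        # two parameter groups, selected by parameter name, each with its LAMB hyperparameters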
        param_optimizer = list(model.named_parameters())
        params = [{
            'params': [
                n for n, p in param_optimizer
                if "bias" in n or "LayerNorm.weight" in n
            ],
            "alpha": 0.9,
            "beta": 0.999,
            "lambda": 0.0,
            "epsilon": 1e-6
        }, {
            'params': [
                n for n, p in param_optimizer
                if not ("bias" in n or "LayerNorm.weight" in n)
            ],
            "alpha": 0.9,
            "beta": 0.999,
            "lambda": 0.0,
            "epsilon": 1e-6
        }]

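        # new-API model description: (name, symbolic shape) per input/output;
        # the trailing True marks the loss output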
        vocab_size = 99
        new_model_desc = {
            'inputs': [
                ('input_ids', ['batch', 'max_seq_len_in_batch']),
                ('attention_mask', ['batch', 'max_seq_len_in_batch']),
                ('token_type_ids', ['batch', 'max_seq_len_in_batch']),
                ('masked_lm_labels', ['batch', 'max_seq_len_in_batch']),
                ('next_sentence_label', ['batch'])
            ],
            'outputs': [
                ('loss', [1], True),
                ('prediction_scores',
                 ['batch', 'max_seq_len_in_batch', vocab_size]),
                ('seq_relationship_scores', ['batch', 2])
            ]
        }

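        # LAMB optimizer config plus new-API trainer construction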
        optim_config = optim.LambConfig(params=params, lr=2e-5)
        model = orttrainer.ORTTrainer(model,
                                      new_model_desc,
                                      optim_config,
                                      options=options)
        print("running with new frontend API")
    else:
        model = ORTTrainer(
            model,
            None,
            model_desc,
            "LambOptimizer",
            map_optimizer_attributes=map_optimizer_attributes,
            learning_rate_description=IODescription(
                'Learning_Rate', [1], torch.float32),
            device=device,
            _enable_internal_postprocess=True,
            gradient_accumulation_steps=gradient_accumulation_steps,
            # BertLAMB default initial settings: b1=0.9, b2=0.999, e=1e-6
            world_rank=args.local_rank,
            world_size=args.world_size,
            use_mixed_precision=fp16,
            allreduce_post_accumulation=allreduce_post_accumulation,
            get_lr_this_step=get_lr_this_step
            if use_internal_get_lr_this_step else None,
            loss_scaler=loss_scaler if use_internal_loss_scaler else None,
            _opset_version=14,
            _use_deterministic_compute=True)
        print("running with old frontend API")

    # training loop
    eval_batch = None
    if not use_new_api:
        model.train()
    for epoch in range(epochs):
        for step, batch in enumerate(dataloader):
            if eval_batch is None:
                eval_batch = batch

            if not use_internal_get_lr_this_step:
                lr = get_lr_this_step(step)
                learning_rate = torch.tensor([lr])

            if not use_internal_loss_scaler and fp16:
                loss_scale = torch.tensor([loss_scaler.loss_scale_])

            if batch_args_option == BatchArgsOption.List:
                if not use_internal_get_lr_this_step:
                    batch = batch + [learning_rate]
                if not use_internal_loss_scaler and fp16:
                    batch = batch + [loss_scale]
                outputs = model.train_step(*batch)
            elif batch_args_option == BatchArgsOption.Dict:
                args, kwargs = split_batch(batch, model_desc.inputs_, 0)
                if not use_internal_get_lr_this_step:
                    kwargs['Learning_Rate'] = learning_rate
                if not use_internal_loss_scaler and fp16:
                    kwargs[model.loss_scale_input_name] = loss_scale
                outputs = model.train_step(*args, **kwargs)
            else:
                # approximately half positional args, half kwargs
                args_count = int(len(model_desc.inputs_) / 2)
                args, kwargs = split_batch(batch, model_desc.inputs_,
                                           args_count)
                if not use_internal_get_lr_this_step:
                    kwargs['Learning_Rate'] = learning_rate
                if not use_internal_loss_scaler and fp16:
                    kwargs[model.loss_scale_input_name] = loss_scale
                outputs = model.train_step(*args, **kwargs)

    # eval
    if batch_args_option == BatchArgsOption.List:
        outputs = model.eval_step(*batch)
    elif batch_args_option == BatchArgsOption.Dict:
        args, kwargs = split_batch(batch, model_desc.inputs_, 0)
        outputs = model.eval_step(*args, **kwargs)
    else:
        # approximately half positional args, half kwargs
        args_count = int(len(model_desc.inputs_) / 2)
        args, kwargs = split_batch(batch, model_desc.inputs_, args_count)
        outputs = model.eval_step(*args, **kwargs)

    return (output.cpu().numpy() for output in outputs)
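
The snippet below is a usage sketch, not part of the original test file: it shows one way run_test from Code Example #2 might be driven, reusing the BertModelTester/ModelDescription setup and the MyArgs namedtuple from Code Example #1. The CPU device, the constant learning-rate schedule, and the remaining argument values are illustrative assumptions.

def run_test_example(test_case):
    # Illustrative driver only; helper names are taken from the examples above.
    from collections import namedtuple
    MyArgs = namedtuple(
        "MyArgs",
        "local_rank world_size max_steps learning_rate warmup_proportion batch_size seq_len")

    device = torch.device("cpu")
    model_tester = BertModelTest.BertModelTester(test_case)
    config, *_ = model_tester.prepare_config_and_inputs()
    model = BertForPreTraining(config=config)

    model_desc = ModelDescription([
        model_tester.input_ids_desc, model_tester.attention_mask_desc,
        model_tester.token_type_ids_desc, model_tester.masked_lm_labels_desc,
        model_tester.next_sentence_label_desc
    ], [
        model_tester.loss_desc, model_tester.prediction_scores_desc,
        model_tester.seq_relationship_scores_desc
    ])

    args = MyArgs(local_rank=0, world_size=1, max_steps=100,
                  learning_rate=2e-5, warmup_proportion=0.01,
                  batch_size=13, seq_len=7)

    # fp16 disabled, so no loss scaler is needed; the constant LR stands in
    # for a real warmup schedule.
    return run_test(model, model_desc, device, args,
                    gradient_accumulation_steps=1,
                    fp16=False,
                    allreduce_post_accumulation=True,
                    get_lr_this_step=lambda step: args.learning_rate,
                    use_internal_get_lr_this_step=False,
                    loss_scaler=None,
                    use_internal_loss_scaler=False,
                    batch_args_option=BatchArgsOption.List,
                    dataset_len=100,
                    epochs=2,
                    use_new_api=False)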
Code Example #3
def run_test(model, model_desc, device, args, gradient_accumulation_steps,
             fp16, allreduce_post_accumulation, get_lr_this_step,
             use_internal_get_lr_this_step, loss_scaler,
             use_internal_loss_scaler, batch_args_option):
    dataloader = create_ort_test_dataloader(model_desc.inputs_,
                                            args.batch_size, args.seq_len,
                                            device)

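    # old frontend ORTTrainer: optimizer choice, learning-rate input and
    # mixed-precision settings are passed as constructor arguments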
    model = ORTTrainer(
        model,
        None,
        model_desc,
        "LambOptimizer",
        map_optimizer_attributes=map_optimizer_attributes,
        learning_rate_description=IODescription(
            'Learning_Rate', [1], torch.float32),
        device=device,
        postprocess_model=postprocess_model,
        gradient_accumulation_steps=gradient_accumulation_steps,
        # BertLAMB default initial settings: b1=0.9, b2=0.999, e=1e-6
        world_rank=args.local_rank,
        world_size=args.world_size,
        use_mixed_precision=fp16,
        allreduce_post_accumulation=allreduce_post_accumulation,
        get_lr_this_step=get_lr_this_step
        if use_internal_get_lr_this_step else None,
        loss_scaler=loss_scaler if use_internal_loss_scaler else None,
        _opset_version=12)

    # training loop
    eval_batch = None
    model.train()
    for step, batch in enumerate(dataloader):
        if eval_batch is None:
            eval_batch = batch

        if not use_internal_get_lr_this_step:
            lr = get_lr_this_step(step)
            learning_rate = torch.tensor([lr])

        if not use_internal_loss_scaler and fp16:
            loss_scale = torch.tensor([loss_scaler.loss_scale_])

        if batch_args_option == BatchArgsOption.List:
            if not use_internal_get_lr_this_step:
                batch = batch + [learning_rate]
            if not use_internal_loss_scaler and fp16:
                batch = batch + [loss_scale]
            outputs = model(*batch)
        elif batch_args_option == BatchArgsOption.Dict:
            args, kwargs = split_batch(batch, model_desc.inputs_, 0)
            if not use_internal_get_lr_this_step:
                kwargs['Learning_Rate'] = learning_rate
            if not use_internal_loss_scaler and fp16:
                kwargs[model.loss_scale_input_name] = loss_scale
            outputs = model(*args, **kwargs)
        else:
            # approximately half positional args, half kwargs
            args_count = int(len(model_desc.inputs_) / 2)
            args, kwargs = split_batch(batch, model_desc.inputs_, args_count)
            if not use_internal_get_lr_this_step:
                kwargs['Learning_Rate'] = learning_rate
            if not use_internal_loss_scaler and fp16:
                kwargs[model.loss_scale_input_name] = loss_scale
            outputs = model(*args, **kwargs)

        print(outputs[0])

    # eval
    model.eval()
    if batch_args_option == BatchArgsOption.List:
        outputs = model(*batch)
    elif batch_args_option == BatchArgsOption.Dict:
        args, kwargs = split_batch(batch, model_desc.inputs_, 0)
        outputs = model(*args, **kwargs)
    else:
        # approximately half positional args, half kwargs
        args_count = int(len(model_desc.inputs_) / 2)
        args, kwargs = split_batch(batch, model_desc.inputs_, args_count)
        outputs = model(*args, **kwargs)

    return (output.cpu().numpy() for output in outputs)