Example #1
    def testWrapModelLossFnStateDict(self):
        torch.manual_seed(1)
        device = torch.device("cuda")
        class LinearModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(2, 4)
            def forward(self, y=None, x=None):
                if y is not None:
                    return self.linear(x) + y
                else:
                    return self.linear(x) + torch.ones(2, 4)

        pt_model = LinearModel()
        data = torch.randn(2, 2)
        label = torch.tensor([0, 1], dtype=torch.int64)
        input_desc = IODescription('x', [2, 2], torch.float32)
        label_desc = IODescription('label', [2, ], torch.int64, num_classes=4)
        output_desc = IODescription('output', [2, 4], torch.float32)
        loss_desc = IODescription('loss', [], torch.float32)
        model_desc = ModelDescription([input_desc, label_desc], [loss_desc, output_desc])
        def loss_fn(x, label):
            return F.nll_loss(F.log_softmax(x, dim=1), label)

        def get_lr_this_step(global_step):
            learningRate = 0.02
            return torch.tensor([learningRate])

        ort_trainer = ORTTrainer(
            pt_model, loss_fn, model_desc, "SGDOptimizer", None,
            IODescription('Learning_Rate', [1, ], torch.float32), device,
            get_lr_this_step=get_lr_this_step)
        ort_trainer.train_step(x=data, label=label)
        state_dict = ort_trainer.state_dict()
        assert state_dict.keys() == {'linear.bias', 'linear.weight'}
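
A note on the pattern above: the legacy ORTTrainer API describes every graph input and output with an IODescription (name, shape, dtype, and num_classes for integer label inputs) and groups them into a ModelDescription. The sketch below is an illustrative stand-in for that record, not onnxruntime's own class; the field names are assumptions made for clarity.

from dataclasses import dataclass
from typing import List, Optional, Union

import torch


@dataclass
class IODescSketch:
    """Illustrative stand-in for the legacy IODescription: one graph input or output."""
    name: str                          # e.g. 'x', 'label', 'loss'
    shape: List[Union[int, str]]       # ints for fixed dims, strings for symbolic dims
    dtype: torch.dtype = torch.float32
    num_classes: Optional[int] = None  # only used for integer label inputs


# The descriptions from the test above, re-expressed with the sketch:
x_desc = IODescSketch('x', [2, 2], torch.float32)
label_desc = IODescSketch('label', [2], torch.int64, num_classes=4)
loss_desc = IODescSketch('loss', [], torch.float32)
print(x_desc, label_desc, loss_desc, sep='\n')
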
Example #2
def bart_model_description(args):
    vocab_size = 50349
    batch = 3
    max_tokens_valid = 1023
    max_tokens = 3069
    #'''
    # allow variable input sizes:
    src_tokens_desc = IODescription('src_tokens', ['batch', 'max_src_tokens'],
                                    torch.int64,
                                    num_classes=vocab_size)
    src_lengths_desc = IODescription('src_lengths', ['batch'],
                                     torch.int64,
                                     num_classes=args.max_tokens_valid)
    prev_output_tokens_desc = IODescription('prev_output_tokens',
                                            ['batch', 'max_out_tokens'],
                                            torch.int64,
                                            num_classes=vocab_size)
    target_desc = IODescription('target', ['max_tgt_tokens'],
                                torch.int64,
                                num_classes=vocab_size)
    #'''
    '''
    # set concrete input sizes to permit optimization
    src_tokens_desc = IODescription('src_tokens', [batch, max_tokens_valid], torch.int64, num_classes = vocab_size)
    src_lengths_desc = IODescription('src_lengths', [batch], torch.int64, num_classes = args.max_tokens_valid)
    prev_output_tokens_desc = IODescription('prev_output_tokens', [batch, max_tokens_valid], torch.int64, num_classes = vocab_size)
    target_desc = IODescription('target', [max_tokens], torch.int64, num_classes = vocab_size)
    '''
    loss_desc = IODescription('loss', [], torch.float32)
    #return ModelDescription([src_tokens_desc, src_lengths_desc, prev_output_tokens_desc, target_desc], [loss_desc])
    return ModelDescription(
        [src_tokens_desc, prev_output_tokens_desc, target_desc], [loss_desc])
    def test_layer_norm(self):
        class LayerNormNet(nn.Module):
            def __init__(self, target):
                super(LayerNormNet, self).__init__()
                self.ln_1 = nn.LayerNorm(10)
                self.loss = nn.CrossEntropyLoss()
                self.target = target

            def forward(self, x):
                output1 = self.ln_1(x)
                loss = self.loss(output1, self.target)
                return loss, output1

        device = torch.device("cpu")
        target = torch.ones(20, 10, 10, dtype=torch.int64).to(device)
        model = LayerNormNet(target)
        input = torch.randn(20, 5, 10, 10, dtype=torch.float32).to(device)

        input_desc = IODescription('input', [], "float32")
        output0_desc = IODescription('output0', [], "float32")
        output1_desc = IODescription('output1', [20, 5, 10, 10], "float32")
        model_desc = ModelDescription([input_desc], [output0_desc, output1_desc])

        learning_rate = torch.tensor([1.0000000e+00]).to(device)
        input_args=[input, learning_rate]

        onnx_model = self.get_onnx_model(model, model_desc, input_args, device)

        count_layer_norm = self.count_nodes(onnx_model, "LayerNormalization")
        count_nodes = self.count_all_nodes(onnx_model)

        assert count_layer_norm == 1
        assert count_nodes == 3
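
The helpers used above (count_nodes, count_all_nodes, and find_nodes further down) walk the exported ONNX graph and filter nodes by op type. A minimal sketch of such helpers over an onnx.ModelProto is shown below; the implementations are assumptions that mirror how the assertions use them, not the project's own code.

import onnx


def find_nodes(onnx_model: onnx.ModelProto, op_type: str):
    """Return every node in the main graph whose op type matches."""
    return [node for node in onnx_model.graph.node if node.op_type == op_type]


def count_nodes(onnx_model: onnx.ModelProto, op_type: str) -> int:
    return len(find_nodes(onnx_model, op_type))


def count_all_nodes(onnx_model: onnx.ModelProto) -> int:
    return len(onnx_model.graph.node)
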
def model_description():
    input_desc = IODescription('src', [bptt, batch_size], torch.float32)
    label_desc = IODescription('label', [bptt, batch_size, ntokens],
                               torch.int64)
    loss_desc = IODescription('loss', [], torch.float32)
    output_desc = IODescription('output', [bptt, batch_size, ntokens],
                                torch.float32)
    return ModelDescription([input_desc, label_desc], [loss_desc, output_desc])
def transformer_model_description():
    input_desc = IODescription('input1', [bptt, batch_size], torch.float32)
    label_desc = IODescription('label', [bptt, batch_size, ntokens],
                               torch.int64)
    loss_desc = IODescription('loss', [], torch.float32)
    return ModelDescription([input_desc, label_desc],
                            [loss_desc]), IODescription(
                                'Learning_Rate', [
                                    lr,
                                ], torch.float32)
Example #6
def mnist_model_description():
    input_desc = IODescription('input1', ['batch', 784], torch.float32)
    label_desc = IODescription('label', [
        'batch',
    ],
                               torch.int64,
                               num_classes=10)
    loss_desc = IODescription('loss', [], torch.float32)
    probability_desc = IODescription('probability', ['batch', 10],
                                     torch.float32)
    return ModelDescription([input_desc, label_desc],
                            [loss_desc, probability_desc])
def mnist_model_description():
    input_desc = IODescription("input1", ["batch", 784], torch.float32)
    label_desc = IODescription(
        "label",
        [
            "batch",
        ],
        torch.int64,
        num_classes=10,
    )
    loss_desc = IODescription("loss", [], torch.float32)
    probability_desc = IODescription("probability", ["batch", 10],
                                     torch.float32)
    return ModelDescription([input_desc, label_desc],
                            [loss_desc, probability_desc])
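
Both MNIST descriptions above declare a symbolic 'batch' axis, a flattened 784-feature input, and a 10-class integer label. The short check below (plain PyTorch, no ONNX Runtime needed) builds a batch that matches those descriptions; the 28x28-to-784 flattening is the usual MNIST convention assumed here.

import torch

batch_size = 64
images = torch.randn(batch_size, 1, 28, 28)      # fake MNIST-shaped batch
input1 = images.view(batch_size, -1)             # ['batch', 784], float32
label = torch.randint(0, 10, (batch_size,))      # ['batch'], int64, num_classes=10

assert input1.shape == (batch_size, 784)
assert label.dtype == torch.int64 and int(label.max()) < 10
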
Example #8
    def testTrainingAndEvalDropout(self):
        # Temporarily disable this test.
        # The graph below triggers ORT to sort the backward graph before
        # the forward graph, which gives an incorrect result.
        # TODO Re-enable when that is fixed.
        return
        class TwoDropoutNet(nn.Module):
            def __init__(self, drop_prb_1, drop_prb_2, dim_size):
                super(TwoDropoutNet, self).__init__()
                self.drop_1 = nn.Dropout(drop_prb_1)
                self.drop_2 = nn.Dropout(drop_prb_2)
                self.weight_1 = torch.nn.Parameter(torch.zeros(dim_size, dtype=torch.float32))
            def forward(self, x):
                x = x + self.weight_1
                x = self.drop_1(x)
                x = self.drop_2(x)
                output = x
                return output[0]
        dim_size = 3
        device = torch.device("cuda", 0)
        # This will drop all values, so the output tensor is expected to be all zeros
        model = TwoDropoutNet(0.999, 0.999, dim_size)
        input_desc = IODescription('input', [dim_size], torch.float32)
        output_desc = IODescription('output', [], torch.float32)
        model_desc = ModelDescription([input_desc], [output_desc])
        lr_desc = ort_trainer_learning_rate_description()
        model = ORTTrainer(model, None, model_desc, "LambOptimizer",
                        map_optimizer_attributes,
                        lr_desc,
                        device,
                        postprocess_model=process_dropout,
                        world_rank=0, world_size=1)
        input = torch.ones(dim_size, dtype=torch.float32).to(device)
        expected_training_output = [0.0]
        expected_eval_output = [1.0]
        learning_rate = torch.tensor([1.0000000e+00]).to(device)
        input_args=[input, learning_rate]
        train_output = model.train_step(*input_args)

        rtol = 1e-04
        assert_allclose(expected_training_output, train_output.item(), rtol=rtol, err_msg="dropout training loss mismatch")

        eval_output = model.eval_step(input)
        assert_allclose(expected_eval_output, eval_output.item(), rtol=rtol, err_msg="dropout eval loss mismatch")
 
        # Do another train step to make sure it's using original ratios
        train_output_2 = model.train_step(*input_args)
        assert_allclose(expected_training_output, train_output_2.item(), rtol=rtol, err_msg="dropout training loss 2 mismatch")
    def get_trainer(
        self,
        model,
        model_desc,
        device,
        onnx_opset_ver=12,
        frozen_weights=[],
        internal_loss_fn=False,
        get_lr_this_step=None,
        optimizer="SGDOptimizer",
    ):
        loss_fn = MNISTWrapper.my_loss if not internal_loss_fn else None
        return ORTTrainer(
            model,
            loss_fn,
            model_desc,
            optimizer,
            None,
            IODescription(
                "Learning_Rate",
                [
                    1,
                ],
                torch.float32,
            ),
            device,
            _opset_version=onnx_opset_ver,
            frozen_weights=frozen_weights,
            get_lr_this_step=get_lr_this_step,
        )
Example #10
def create_ort_trainer(args, device, model):
    # set GPU memory limitation
    from onnxruntime.capi._pybind_state import set_cuda_mem_limit
    ort_cuda_mem_limit_in_gbs = 1
    set_cuda_mem_limit(int(ort_cuda_mem_limit_in_gbs * 1024 * 1024 * 1024))

    # BertLAMB default initial settings: b1=0.9, b2=0.999, e=1e-6
    def map_optimizer_attributes(name):
        no_decay_keys = ["bias", "gamma", "beta", "LayerNorm"]
        no_decay = False
        for no_decay_key in no_decay_keys:
            if no_decay_key in name:
                no_decay = True
                break
        if no_decay:
            return {"alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6}
        else:
            return {"alpha": 0.9, "beta": 0.999, "lambda": 0.01, "epsilon": 1e-6}

    # We ask ORTTrainer to create a LambOptimizer with the given optimizer_attributes.
    # train_step performs the forward pass, backward pass, and optimizer step.
    model = ORTTrainer(model, None, bert_model_description(args), "LambOptimizer", 
        map_optimizer_attributes,
        IODescription('Learning_Rate', [1,], torch.float32),
        device,
        _opset_version = 10)

    if args.fp16:
        setattr(args, 'ort_loss_scale', LossScaler(model.loss_scale_input_name, True, up_scale_window=2000))

    return model
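
The map_optimizer_attributes callback above decides, per parameter name, whether LAMB applies weight decay (the "lambda" attribute): bias, gamma, beta and LayerNorm parameters get no decay. The standalone sketch below restates the same rule and runs it on a couple of hypothetical parameter names.

def map_optimizer_attributes(name):
    # parameters whose name contains one of these keys are excluded from weight decay
    no_decay_keys = ["bias", "gamma", "beta", "LayerNorm"]
    no_decay = any(key in name for key in no_decay_keys)
    return {"alpha": 0.9, "beta": 0.999,
            "lambda": 0.0 if no_decay else 0.01, "epsilon": 1e-6}


print(map_optimizer_attributes("encoder.layer.0.output.LayerNorm.weight"))  # lambda == 0.0
print(map_optimizer_attributes("encoder.layer.0.output.dense.weight"))      # lambda == 0.01
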
    def get_onnx_model(self,
                       model,
                       model_desc,
                       inputs,
                       device,
                       _enable_internal_postprocess=True,
                       _extra_postprocess=None):
        lr_desc = IODescription('Learning_Rate', [
            1,
        ], torch.float32)
        model = ORTTrainer(
            model,
            None,
            model_desc,
            "LambOptimizer",
            map_optimizer_attributes,
            lr_desc,
            device,
            world_rank=0,
            world_size=1,
            _opset_version=12,
            _enable_internal_postprocess=_enable_internal_postprocess,
            _extra_postprocess=_extra_postprocess)

        train_output = model.train_step(*inputs)
        return model.onnx_model_
def ort_trainer_learning_rate_description():
    return IODescription(
        "Learning_Rate",
        [
            1,
        ],
        torch.float32,
    )
def create_ort_trainer(args, device, model):

    # set GPU memory limitation (per card!)
    from onnxruntime.capi._pybind_state import set_cuda_mem_limit
    ort_cuda_mem_limit_in_gbs = args.gpu_memory_limit_gb
    set_cuda_mem_limit(int(ort_cuda_mem_limit_in_gbs * 1024 * 1024 * 1024))

    # BertLAMB default initial settings: b1=0.9, b2=0.999, e=1e-6
    def map_optimizer_attributes(name):
        no_decay_keys = ["bias", "gamma", "beta", "LayerNorm"]
        no_decay = False
        for no_decay_key in no_decay_keys:
            if no_decay_key in name:
                no_decay = True
                break
        if no_decay:
            return {
                "alpha": 0.9,
                "beta": 0.999,
                "lambda": 0.0,
                "epsilon": 1e-6
            }
        else:
            return {
                "alpha": 0.9,
                "beta": 0.999,
                "lambda": 0.01,
                "epsilon": 1e-6
            }

    # We ask ORTTrainer to create a LambOptimizer with the given optimizer_attributes.
    # train_step performs the forward pass, backward pass, and optimizer step.
    model = ORTTrainer(
        model,
        None,
        bert_model_description(args),
        "LambOptimizer",
        map_optimizer_attributes,
        IODescription('Learning_Rate', [
            1,
        ], torch.float32),
        device,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        world_rank=args.world_rank,
        world_size=args.world_size,
        use_mixed_precision=True if args.fp16 else False,
        allreduce_post_accumulation=True
        if args.allreduce_post_accumulation else False,
        deepspeed_zero_stage=1 if args.deepspeed_zero_stage else 0,
        _opset_version=12)

    if args.fp16:
        setattr(
            args, 'ort_loss_scale',
            LossScaler(model.loss_scale_input_name, True,
                       up_scale_window=2000))

    return model
        def create_and_check_bert_model(self, config, input_ids,
                                        token_type_ids, input_mask,
                                        sequence_labels, token_labels,
                                        choice_labels):
            model = BertModel(config=config)
            model.to(input_ids.device)
            model.eval()

            sequence_output, pooled_output = model(
                input_ids,
                attention_mask=input_mask,
                token_type_ids=token_type_ids)

            # fails because there is no loss output
            model_desc = ModelDescription([
                self.input_ids_desc, self.attention_mask_desc,
                self.token_type_ids_desc
            ], [self.last_hidden_state_desc, self.pooler_output_desc])
            args_gradient_accumulation_steps = 8
            args_local_rank = 0
            args_world_size = 1
            args_fp16 = True
            args_allreduce_post_accumulation = True

            model = ORTTrainer(
                model,
                None,
                model_desc,
                "LambOptimizer",
                map_optimizer_attributes=map_optimizer_attributes,
                learning_rate_description=IODescription(
                    'Learning_Rate', [
                        1,
                    ], torch.float32),
                device=self.device,
                postprocess_model=postprocess_model,
                gradient_accumulation_steps=args_gradient_accumulation_steps,
                world_rank=args_local_rank,
                world_size=args_world_size,
                use_mixed_precision=True if args_fp16 else False,
                allreduce_post_accumulation=True
                if args_allreduce_post_accumulation else False)

            sequence_output, pooled_output = model(
                input_ids, token_type_ids=token_type_ids)
            sequence_output, pooled_output = model(input_ids)

            result = {
                "sequence_output": sequence_output,
                "pooled_output": pooled_output,
            }
            self.parent.assertListEqual(
                list(result["sequence_output"].size()),
                [self.batch_size, self.seq_length, self.hidden_size])
            self.parent.assertListEqual(list(result["pooled_output"].size()),
                                        [self.batch_size, self.hidden_size])
    def gpt2_model_description(self, n_head, vocab_size, n_hidden, n_layer,
                               n_ctx, batch_size):

        logger.info("****num of head is: {}".format(n_head))
        logger.info("****vocab size is: {}".format(vocab_size))
        logger.info("****num of hidden layer is: {}".format(n_hidden))
        logger.info("****num of layer is: {}".format(n_layer))
        logger.info("****seq length is: {}".format(n_ctx))

        input_ids_desc = IODescription('input_ids', [batch_size, n_ctx],
                                       torch.int64,
                                       num_classes=vocab_size)
        labels_desc = IODescription('labels', [batch_size, n_ctx],
                                    torch.int64,
                                    num_classes=vocab_size)

        loss_desc = IODescription('loss', [], torch.float32)

        return ModelDescription([input_ids_desc, labels_desc], [loss_desc])
    def get_trainer(self, model, model_desc, device, onnx_opset_ver=12):
        return ORTTrainer(model,
                          MNISTWrapper.my_loss,
                          model_desc,
                          "SGDOptimizer",
                          None,
                          IODescription('Learning_Rate', [
                              1,
                          ], torch.float32),
                          device,
                          _opset_version=onnx_opset_ver)
    def test_expand(self):
        class ExpandNet(nn.Module):
            def __init__(self, target):
                super(ExpandNet, self).__init__()
                self.loss = nn.CrossEntropyLoss()
                self.target = target
                self.linear = torch.nn.Linear(2, 2)

            def forward(self, x, x1):
                output = x.expand_as(x1)
                output = self.linear(output)
                output = output + output
                loss = self.loss(output, self.target)
                return loss, output

        device = torch.device("cpu")
        target = torch.ones(5, 5, 2, dtype=torch.int64).to(device)
        model = ExpandNet(target).to(device)

        x = torch.randn(5, 3, 1, 2, dtype=torch.float32).to(device)
        x1 = torch.randn(5, 3, 5, 2, dtype=torch.float32).to(device)

        input0_desc = IODescription('x', [5, 3, 1, 2], "float32")
        input1_desc = IODescription('x1', [5, 3, 5, 2], "float32")
        output0_desc = IODescription('output0', [], "float32")
        output1_desc = IODescription('output1', [5, 3, 5, 2], "float32")
        model_desc = ModelDescription([input0_desc, input1_desc],
                                      [output0_desc, output1_desc])

        learning_rate = torch.tensor([1.0000000e+00]).to(device)
        input_args = [x, x1, learning_rate]

        onnx_model = self.get_onnx_model(model, model_desc, input_args, device)

        # check that the Expand node's output has shape and type information
        expand_nodes = self.find_nodes(onnx_model, "Expand")
        assert len(expand_nodes) == 1

        model_info = onnx_model.graph.value_info
        assert model_info[0].name == expand_nodes[0].output[0]
        assert model_info[0].type == onnx_model.graph.input[1].type
Example #18
def bert_model_description():
    vocab_size = 30528
    input_ids_desc = IODescription('input_ids',
                                   ['batch', 'max_seq_len_in_batch'],
                                   torch.int64,
                                   num_classes=vocab_size)
    segment_ids_desc = IODescription('segment_ids',
                                     ['batch', 'max_seq_len_in_batch'],
                                     torch.int64,
                                     num_classes=2)
    input_mask_desc = IODescription('input_mask',
                                    ['batch', 'max_seq_len_in_batch'],
                                    torch.int64,
                                    num_classes=2)
    masked_lm_labels_desc = IODescription('masked_lm_labels',
                                          ['batch', 'max_seq_len_in_batch'],
                                          torch.int64,
                                          num_classes=vocab_size)
    next_sentence_labels_desc = IODescription('next_sentence_labels', [
        'batch',
    ],
                                              torch.int64,
                                              num_classes=2)
    loss_desc = IODescription('loss', [], torch.float32)

    return ModelDescription([
        input_ids_desc, segment_ids_desc, input_mask_desc,
        masked_lm_labels_desc, next_sentence_labels_desc
    ], [loss_desc])
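
The BERT description above keeps the 'batch' and 'max_seq_len_in_batch' axes symbolic and records the valid id range of each integer input through num_classes. The sketch below builds one fake pretraining batch consistent with those descriptions (plain PyTorch; the concrete batch and sequence sizes are arbitrary assumptions).

import torch

vocab_size = 30528
batch, seq_len = 4, 128                                      # arbitrary concrete sizes

input_ids = torch.randint(0, vocab_size, (batch, seq_len))   # ['batch', 'max_seq_len_in_batch']
segment_ids = torch.randint(0, 2, (batch, seq_len))          # 2 classes
input_mask = torch.ones(batch, seq_len, dtype=torch.int64)   # 2 classes
masked_lm_labels = torch.randint(0, vocab_size, (batch, seq_len))
next_sentence_labels = torch.randint(0, 2, (batch,))         # ['batch'], 2 classes

for t in (input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels):
    assert t.dtype == torch.int64
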
        def create_and_check_bert_for_masked_lm(self, config, input_ids,
                                                token_type_ids, input_mask,
                                                sequence_labels, token_labels,
                                                choice_labels):
            model = BertForMaskedLM(config=config)
            model.eval()
            loss, prediction_scores = model(input_ids,
                                            attention_mask=input_mask,
                                            token_type_ids=token_type_ids,
                                            masked_lm_labels=token_labels)

            #####
            model_desc = ModelDescription([
                self.input_ids_desc, self.attention_mask_desc,
                self.token_type_ids_desc, self.masked_lm_labels_desc
            ], [self.loss_desc, self.prediction_scores_desc])
            args_gradient_accumulation_steps = 8
            args_local_rank = 0
            args_world_size = 1
            args_fp16 = True
            args_allreduce_post_accumulation = True

            model = ORTTrainer(
                model,
                None,
                model_desc,
                "LambOptimizer",
                map_optimizer_attributes=map_optimizer_attributes,
                learning_rate_description=IODescription(
                    'Learning_Rate', [
                        1,
                    ], torch.float32),
                device=self.device,
                postprocess_model=postprocess_model,
                gradient_accumulation_steps=args_gradient_accumulation_steps,
                world_rank=args_local_rank,
                world_size=args_world_size,
                use_mixed_precision=True if args_fp16 else False,
                allreduce_post_accumulation=True
                if args_allreduce_post_accumulation else False)
            model(input_ids,
                  attention_mask=input_mask,
                  token_type_ids=token_type_ids,
                  masked_lm_labels=token_labels)
Example #20
def bert_model_description(args):
    vocab_size = 30528

    # allow variable input sizes:
    # input_ids_desc = IODescription('input_ids', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes = vocab_size)
    # segment_ids_desc = IODescription('segment_ids', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes = 2)
    # input_mask_desc = IODescription('input_mask', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes = 2)
    # masked_lm_labels_desc = IODescription('masked_lm_labels', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes = vocab_size)
    # next_sentence_labels_desc = IODescription('next_sentence_labels', ['batch',], torch.int64, num_classes = 2)

    # set concrete input sizes to permit optimization
    input_ids_desc = IODescription('input_ids', [args.train_batch_size, args.max_seq_length], torch.int64, num_classes = vocab_size)
    segment_ids_desc = IODescription('segment_ids', [args.train_batch_size, args.max_seq_length], torch.int64, num_classes = 2)
    input_mask_desc = IODescription('input_mask', [args.train_batch_size, args.max_seq_length], torch.int64, num_classes = 2)
    masked_lm_labels_desc = IODescription('masked_lm_labels', [args.train_batch_size, args.max_seq_length], torch.int64, num_classes = vocab_size)
    next_sentence_labels_desc = IODescription('next_sentence_labels', [args.train_batch_size,2], torch.int64, num_classes = 2)

    loss_desc = IODescription('loss', [], torch.float32)
    return ModelDescription([input_ids_desc, segment_ids_desc, input_mask_desc, masked_lm_labels_desc, next_sentence_labels_desc], [loss_desc])
def bert_model_description():
    vocab_size = 30528
    input_ids_desc = IODescription(
        "input_ids",
        ["batch", "max_seq_len_in_batch"],
        torch.int64,
        num_classes=vocab_size,
    )
    segment_ids_desc = IODescription("segment_ids",
                                     ["batch", "max_seq_len_in_batch"],
                                     torch.int64,
                                     num_classes=2)
    input_mask_desc = IODescription("input_mask",
                                    ["batch", "max_seq_len_in_batch"],
                                    torch.int64,
                                    num_classes=2)
    masked_lm_labels_desc = IODescription(
        "masked_lm_labels",
        ["batch", "max_seq_len_in_batch"],
        torch.int64,
        num_classes=vocab_size,
    )
    next_sentence_labels_desc = IODescription(
        "next_sentence_labels",
        [
            "batch",
        ],
        torch.int64,
        num_classes=2,
    )
    loss_desc = IODescription("loss", [], torch.float32)

    return ModelDescription(
        [
            input_ids_desc,
            segment_ids_desc,
            input_mask_desc,
            masked_lm_labels_desc,
            next_sentence_labels_desc,
        ],
        [loss_desc],
    )
    def model_to_desc(self, model_name, model):
        if model_name.startswith('bert') or model_name.startswith('xlnet'):
            new_model_desc = {
                'inputs': [(
                    'input_ids',
                    ['batch', 'max_seq_len_in_batch'],
                ), (
                    'attention_mask',
                    ['batch', 'max_seq_len_in_batch'],
                ), (
                    'token_type_ids',
                    ['batch', 'max_seq_len_in_batch'],
                ), (
                    'labels',
                    [
                        'batch',
                    ],
                )],
                'outputs': [('loss', [], True), ('logits', ['batch', 2])]
            }
            model_desc = ModelDescription([
                IODescription('input_ids', ['batch', 'max_seq_len_in_batch']),
                IODescription('attention_mask',
                              ['batch', 'max_seq_len_in_batch']),
                IODescription('token_type_ids',
                              ['batch', 'max_seq_len_in_batch']),
                IODescription('labels', [
                    'batch',
                ])
            ], [
                IODescription('loss', []),
                IODescription('logits', ['batch', 2])
            ])
        elif model_name.startswith('roberta'):
            new_model_desc = {
                'inputs': [(
                    'input_ids',
                    ['batch', 'max_seq_len_in_batch'],
                ), (
                    'attention_mask',
                    ['batch', 'max_seq_len_in_batch'],
                ), (
                    'labels',
                    [
                        'batch',
                    ],
                )],
                'outputs': [('loss', [], True), ('logits', ['batch', 2])]
            }
            model_desc = ModelDescription([
                IODescription('input_ids', ['batch', 'max_seq_len_in_batch']),
                IODescription('attention_mask',
                              ['batch', 'max_seq_len_in_batch']),
                IODescription('labels', [
                    'batch',
                ])
            ], [
                IODescription('loss', []),
                IODescription('logits', ['batch', 2])
            ])
        else:
            raise RuntimeError(
                "unsupported base model name {}.".format(model_name))

        return model_desc, new_model_desc
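
model_to_desc returns both description flavors: the legacy ModelDescription built from IODescription objects, and the dict-based form used by the newer orttrainer front end, where inputs are (name, shape) pairs and the boolean on an output appears to mark it as the loss. As a small illustration, a hypothetical single-input regression model could be described in the dict form like this:

# Hypothetical model, dict-based description only (names are illustrative):
new_model_desc = {
    'inputs': [('features', ['batch', 16])],
    'outputs': [('loss', [], True),              # True flags the loss output
                ('prediction', ['batch', 1])]
}
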
    def run_multiple_choice(self, model_name, task_name, fp16):
        model_args = ModelArguments(model_name_or_path=model_name,
                                    cache_dir=self.cache_dir)
        data_args = DataTrainingArguments(task_name=task_name,
                                          data_dir=self.data_dir,
                                          max_seq_length=self.max_seq_length)

        training_args = TrainingArguments(
            output_dir=os.path.join(self.output_dir, task_name),
            do_train=True,
            do_eval=True,
            per_gpu_train_batch_size=self.train_batch_size,
            per_gpu_eval_batch_size=self.eval_batch_size,
            learning_rate=self.learning_rate,
            num_train_epochs=self.num_train_epochs,
            local_rank=self.local_rank,
            overwrite_output_dir=self.overwrite_output_dir,
            gradient_accumulation_steps=self.gradient_accumulation_steps,
            fp16=fp16,
            logging_steps=self.logging_steps)

        # Setup logging
        logging.basicConfig(
            format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
            datefmt="%m/%d/%Y %H:%M:%S",
            level=logging.INFO
            if training_args.local_rank in [-1, 0] else logging.WARN,
        )
        logger.warning(
            "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
            training_args.local_rank,
            training_args.device,
            training_args.n_gpu,
            bool(training_args.local_rank != -1),
            training_args.fp16,
        )
        logger.info("Training/evaluation parameters %s", training_args)

        set_seed(training_args.seed)
        onnxruntime.set_seed(training_args.seed)

        try:
            processor = SwagProcessor()
            label_list = processor.get_labels()
            num_labels = len(label_list)
        except KeyError:
            raise ValueError("Task not found: %s" % (data_args.task_name))

        config = AutoConfig.from_pretrained(
            model_args.config_name
            if model_args.config_name else model_args.model_name_or_path,
            num_labels=num_labels,
            finetuning_task=data_args.task_name,
            cache_dir=model_args.cache_dir,
        )
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.tokenizer_name
            if model_args.tokenizer_name else model_args.model_name_or_path,
            cache_dir=model_args.cache_dir,
        )

        model = AutoModelForMultipleChoice.from_pretrained(
            model_args.model_name_or_path,
            from_tf=bool(".ckpt" in model_args.model_name_or_path),
            config=config,
            cache_dir=model_args.cache_dir,
        )

        # Get datasets
        train_dataset = (MultipleChoiceDataset(
            data_dir=data_args.data_dir,
            tokenizer=tokenizer,
            task=data_args.task_name,
            processor=processor,
            max_seq_length=data_args.max_seq_length,
            overwrite_cache=data_args.overwrite_cache,
            mode=Split.train,
        ) if training_args.do_train else None)
        eval_dataset = (MultipleChoiceDataset(
            data_dir=data_args.data_dir,
            tokenizer=tokenizer,
            task=data_args.task_name,
            processor=processor,
            max_seq_length=data_args.max_seq_length,
            overwrite_cache=data_args.overwrite_cache,
            mode=Split.dev,
        ) if training_args.do_eval else None)

        def compute_metrics(p: EvalPrediction) -> Dict:
            preds = np.argmax(p.predictions, axis=1)
            return {"acc": simple_accuracy(preds, p.label_ids)}

        if model_name.startswith('bert'):
            model_desc = ModelDescription([
                IODescription('input_ids', [
                    self.train_batch_size, num_labels, data_args.max_seq_length
                ],
                              torch.int64,
                              num_classes=model.config.vocab_size),
                IODescription('attention_mask', [
                    self.train_batch_size, num_labels, data_args.max_seq_length
                ],
                              torch.int64,
                              num_classes=2),
                IODescription('token_type_ids', [
                    self.train_batch_size, num_labels, data_args.max_seq_length
                ],
                              torch.int64,
                              num_classes=2),
                IODescription('labels', [self.train_batch_size, num_labels],
                              torch.int64,
                              num_classes=num_labels)
            ], [
                IODescription('loss', [], torch.float32),
                IODescription('reshaped_logits',
                              [self.train_batch_size, num_labels],
                              torch.float32)
            ])
        else:
            model_desc = ModelDescription([
                IODescription('input_ids',
                              ['batch', num_labels, 'max_seq_len_in_batch'],
                              torch.int64,
                              num_classes=model.config.vocab_size),
                IODescription('attention_mask',
                              ['batch', num_labels, 'max_seq_len_in_batch'],
                              torch.int64,
                              num_classes=2),
                IODescription('labels', ['batch', num_labels],
                              torch.int64,
                              num_classes=num_labels)
            ], [
                IODescription('loss', [], torch.float32),
                IODescription('reshaped_logits', ['batch', num_labels],
                              torch.float32)
            ])

        # Initialize the ORTTrainer within ORTTransformerTrainer
        trainer = ORTTransformerTrainer(
            model=model,
            model_desc=model_desc,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            compute_metrics=compute_metrics,
        )

        # Training
        if training_args.do_train:
            trainer.train()
            trainer.save_model()

        # Evaluation
        results = {}
        if training_args.do_eval and training_args.local_rank in [-1, 0]:
            logger.info("*** Evaluate ***")

            result = trainer.evaluate()

            logger.info("***** Eval results {} *****".format(
                data_args.task_name))
            for key, value in result.items():
                logger.info("  %s = %s", key, value)

            results.update(result)

        return results
    def test_extra_postpass(self):
        def postpass_replace_first_add_with_sub(model):
            # this post pass replaces the first Add node with Sub in the model.
            # Previous graph
            #   (subgraph 1)        (subgraph 2)
            #        |                   |
            #        |                   |
            #        |________   ________|
            #                 | |
            #                 Add
            #                  |
            #             (subgraph 3)
            #
            # Post graph
            #   (subgraph 1)        (subgraph 2)
            #        |                   |
            #        |                   |
            #        |________   ________|
            #                 | |
            #                 Sub
            #                  |
            #             (subgraph 3)
            add_nodes = [n for n in model.graph.node if n.op_type == 'Add']
            add_nodes[0].op_type = "Sub"

        class MultiAdd(nn.Module):
            def __init__(self, target):
                super(MultiAdd, self).__init__()
                self.loss = nn.CrossEntropyLoss()
                self.target = target
                self.linear = torch.nn.Linear(2, 2, bias=False)

            def forward(self, x, x1):
                output = x + x1
                output = output + x
                output = output + x1
                output = self.linear(output)
                loss = self.loss(output, self.target)
                return loss, output

        device = torch.device("cpu")
        target = torch.ones(5, 2, dtype=torch.int64).to(device)
        model = MultiAdd(target).to(device)

        x = torch.randn(5, 5, 2, dtype=torch.float32).to(device)
        x1 = torch.randn(5, 5, 2, dtype=torch.float32).to(device)

        input0_desc = IODescription('x', [5, 5, 2], "float32")
        input1_desc = IODescription('x1', [5, 5, 2], "float32")
        output0_desc = IODescription('output0', [], "float32")
        output1_desc = IODescription('output1', [5, 5, 2], "float32")
        model_desc = ModelDescription([input0_desc, input1_desc], [output0_desc, output1_desc])

        learning_rate = torch.tensor([1.0000000e+00]).to(device)
        input_args = [x, x1, learning_rate]

        onnx_model = self.get_onnx_model(model, model_desc, input_args, device,
                _extra_postprocess=postpass_replace_first_add_with_sub)

        # check that extra postpass is called, and called only once.
        add_nodes = self.find_nodes(onnx_model, "Add")
        sub_nodes = self.find_nodes(onnx_model, "Sub")
        assert len(add_nodes) == 2
        assert len(sub_nodes) == 1


        unprocessed_onnx_model = self.get_onnx_model(model, model_desc, input_args, device,
                _extra_postprocess=None, _enable_internal_postprocess=False)
        # check that the model is unchanged.
        add_nodes = self.find_nodes(unprocessed_onnx_model, "Add")
        sub_nodes = self.find_nodes(unprocessed_onnx_model, "Sub")
        assert len(add_nodes) == 3
        assert len(sub_nodes) == 0

        processed_onnx_model = self.get_onnx_model(unprocessed_onnx_model, model_desc, input_args, device,
                _extra_postprocess=postpass_replace_first_add_with_sub)
        # check that extra postpass is called, and called only once.
        add_nodes = self.find_nodes(processed_onnx_model, "Add")
        sub_nodes = self.find_nodes(processed_onnx_model, "Sub")
        assert len(add_nodes) == 2
        assert len(sub_nodes) == 1
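
The _extra_postprocess hook receives the exported ONNX model and may mutate it in place before training starts. The self-contained sketch below applies the same first-Add-to-Sub rewrite to a tiny hand-built graph (constructed with onnx.helper) just to show the mechanics; it is not part of the test above.

import onnx
from onnx import TensorProto, helper

# toy graph: z = (a + b) + a
a = helper.make_tensor_value_info('a', TensorProto.FLOAT, [2])
b = helper.make_tensor_value_info('b', TensorProto.FLOAT, [2])
z = helper.make_tensor_value_info('z', TensorProto.FLOAT, [2])
add1 = helper.make_node('Add', ['a', 'b'], ['t'])
add2 = helper.make_node('Add', ['t', 'a'], ['z'])
model = helper.make_model(helper.make_graph([add1, add2], 'toy', [a, b], [z]))


def postpass_replace_first_add_with_sub(model):
    add_nodes = [n for n in model.graph.node if n.op_type == 'Add']
    add_nodes[0].op_type = "Sub"


postpass_replace_first_add_with_sub(model)
assert [n.op_type for n in model.graph.node] == ['Sub', 'Add']
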
Example #25
    def train(self):
        """
        Main training entry point.
        """
        train_dataloader = self.get_train_dataloader()

        if self.args.max_steps > 0:
            t_total = self.args.max_steps
            num_train_epochs = (self.args.max_steps //
                                (len(train_dataloader) //
                                 self.args.gradient_accumulation_steps) + 1)
        else:
            t_total = int(
                len(train_dataloader) //
                self.args.gradient_accumulation_steps *
                self.args.num_train_epochs)
            num_train_epochs = self.args.num_train_epochs

        get_lr_this_step = get_linear_schedule_with_warmup(
            self.args.warmup_steps, t_total, self.args.learning_rate)
        loss_scaler = LossScaler('loss_scale_input_name',
                                 True,
                                 up_scale_window=2000)

        def map_optimizer_attributes(name):
            # no_decay_keys = ["bias", "LayerNorm.weight"]
            no_decay = "bias" in name or "LayerNorm.weight" in name
            if no_decay:
                return {"weight_decay": 0.0, "weight_decay_mode": 1}
            else:
                return {
                    "weight_decay": self.args.weight_decay,
                    "weight_decay_mode": 1
                }

        self.model = ORTTrainer(
            self.model,
            None,
            self.model_desc,
            "AdamOptimizer",
            map_optimizer_attributes=map_optimizer_attributes,
            learning_rate_description=IODescription('Learning_Rate', [
                1,
            ], torch.float32),
            device=self.args.device,
            gradient_accumulation_steps=self.args.gradient_accumulation_steps,
            world_rank=0,
            world_size=1,  # only support single GPU cases
            use_mixed_precision=self.args.fp16,
            allreduce_post_accumulation=True,
            get_lr_this_step=get_lr_this_step,
            loss_scaler=loss_scaler,
            enable_grad_norm_clip=False,
            _opset_version=12,
            _use_deterministic_compute=True)

        # Train!
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_dataloader.dataset))
        logger.info("  Num Epochs = %d", num_train_epochs)
        logger.info("  Instantaneous batch size per GPU = %d",
                    self.args.per_gpu_train_batch_size)
        logger.info(
            "  Total train batch size (w. parallel, distributed & accumulation) = %d",
            self.args.train_batch_size *
            self.args.gradient_accumulation_steps *
            (torch.distributed.get_world_size()
             if self.args.local_rank != -1 else 1),
        )
        logger.info("  Gradient Accumulation steps = %d",
                    self.args.gradient_accumulation_steps)
        logger.info("  Total optimization steps = %d", t_total)

        global_step = 0
        epochs_trained = 0
        steps_trained_in_current_epoch = 0

        tr_loss = 0.0
        logging_loss = 0.0
        train_iterator = trange(
            epochs_trained,
            int(num_train_epochs),
            desc="Epoch",
            disable=self.args.local_rank not in [-1, 0],
        )

        for epoch in train_iterator:
            epoch_iterator = tqdm(train_dataloader,
                                  desc="Iteration",
                                  disable=self.args.local_rank not in [-1, 0])
            for step, inputs in enumerate(epoch_iterator):

                # Skip past any already trained steps if resuming training
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    continue

                tr_loss += self._training_step(self.model, inputs)

                if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                        len(epoch_iterator) <=
                        self.args.gradient_accumulation_steps and
                    (step + 1) == len(epoch_iterator)):
                    global_step += 1

                    if self.args.local_rank in [-1, 0]:
                        if (self.args.logging_steps > 0
                                and global_step % self.args.logging_steps
                                == 0) or (global_step == 1
                                          and self.args.logging_first_step):
                            logs = {}
                            if self.args.evaluate_during_training:
                                results = self.evaluate()
                                for key, value in results.items():
                                    eval_key = "eval_{}".format(key)
                                    logs[eval_key] = value

                            loss_scalar = (tr_loss - logging_loss
                                           ) / self.args.logging_steps
                            learning_rate_scalar = get_lr_this_step(
                                global_step)
                            logs["learning_rate"] = learning_rate_scalar
                            logs["loss"] = loss_scalar
                            logging_loss = tr_loss

                            epoch_iterator.write(
                                json.dumps({
                                    **logs,
                                    **{
                                        "step": global_step
                                    }
                                }))

                if self.args.max_steps > 0 and global_step > self.args.max_steps:
                    epoch_iterator.close()
                    break
            if self.args.max_steps > 0 and global_step > self.args.max_steps:
                train_iterator.close()
                break

        logger.info("\n\nTraining completed. \n\n")
        return TrainOutput(global_step, tr_loss / global_step)
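
train() builds its per-step learning rate with get_linear_schedule_with_warmup(warmup_steps, t_total, base_lr) and hands the resulting callable to ORTTrainer as get_lr_this_step. The exact helper is not shown in this snippet, so the sketch below is only an assumed linear warmup-then-linear-decay schedule with that calling convention.

def get_linear_schedule_with_warmup(warmup_steps, t_total, base_lr):
    """Return f(global_step) -> lr: linear warmup to base_lr, then linear decay to 0."""
    def lr_this_step(global_step):
        if global_step < warmup_steps:
            return base_lr * (global_step + 1) / max(1, warmup_steps)
        remaining = max(0, t_total - global_step)
        return base_lr * remaining / max(1, t_total - warmup_steps)
    return lr_this_step


get_lr_this_step = get_linear_schedule_with_warmup(warmup_steps=100, t_total=1000, base_lr=5e-5)
print(get_lr_this_step(0), get_lr_this_step(100), get_lr_this_step(1000))  # ramp up, peak, decay to 0
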
Example #26
def main():
    # Training settings
    parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
    parser.add_argument("--batch-size",
                        type=int,
                        default=64,
                        metavar="N",
                        help="input batch size for training (default: 64)")
    parser.add_argument("--test-batch-size",
                        type=int,
                        default=1000,
                        metavar="N",
                        help="input batch size for testing (default: 1000)")
    parser.add_argument("--epochs",
                        type=int,
                        default=10,
                        metavar="N",
                        help="number of epochs to train (default: 10)")
    parser.add_argument("--lr",
                        type=float,
                        default=0.01,
                        metavar="LR",
                        help="learning rate (default: 0.01)")
    parser.add_argument("--no-cuda",
                        action="store_true",
                        default=False,
                        help="disables CUDA training")
    parser.add_argument("--seed",
                        type=int,
                        default=1,
                        metavar="S",
                        help="random seed (default: 1)")
    parser.add_argument(
        "--log-interval",
        type=int,
        default=10,
        metavar="N",
        help="how many batches to wait before logging training status",
    )

    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    kwargs = {"num_workers": 0, "pin_memory": True}
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST(
            "../data",
            train=True,
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307, ), (0.3081, ))
            ]),
        ),
        batch_size=args.batch_size,
        shuffle=True,
        **kwargs,
    )
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(
            "../data",
            train=False,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307, ), (0.3081, ))
            ]),
        ),
        batch_size=args.test_batch_size,
        shuffle=True,
        **kwargs,
    )

    comm = MPI.COMM_WORLD
    args.local_rank = (int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]) if
                       ("OMPI_COMM_WORLD_LOCAL_RANK" in os.environ) else 0)
    args.world_rank = int(os.environ["OMPI_COMM_WORLD_RANK"]) if (
        "OMPI_COMM_WORLD_RANK" in os.environ) else 0
    args.world_size = comm.Get_size()
    if use_cuda:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        args.n_gpu = 1
        set_cuda_device_id(args.local_rank)
    else:
        device = torch.device("cpu")

    input_size = 784
    hidden_size = 500
    num_classes = 10
    model = NeuralNet(input_size, hidden_size, num_classes)

    model_desc = mnist_model_description()
    # use log_interval as gradient accumulation steps
    trainer = ORTTrainer(
        model,
        my_loss,
        model_desc,
        "SGDOptimizer",
        None,
        IODescription(
            "Learning_Rate",
            [
                1,
            ],
            torch.float32,
        ),
        device,
        1,
        args.world_rank,
        args.world_size,
        use_mixed_precision=False,
        allreduce_post_accumulation=True,
    )
    print("\nBuild ort model done.")

    for epoch in range(1, args.epochs + 1):
        train_with_trainer(args, trainer, device, train_loader, epoch)
        import pdb

        test_with_trainer(args, trainer, device, test_loader)
Example #27
        def __init__(
            self,
            parent,
            batch_size=13,
            seq_length=7,
            is_training=True,
            use_input_mask=True,
            use_token_type_ids=True,
            use_labels=True,
            vocab_size=99,
            hidden_size=32,
            num_hidden_layers=5,
            num_attention_heads=4,
            intermediate_size=37,
            hidden_act="gelu",
            hidden_dropout_prob=0.1,
            attention_probs_dropout_prob=0.1,
            max_position_embeddings=512,
            type_vocab_size=16,
            type_sequence_label_size=2,
            initializer_range=0.02,
            num_labels=3,
            num_choices=4,
            scope=None,
            device='cpu',
        ):
            self.parent = parent
            self.batch_size = batch_size
            self.seq_length = seq_length
            self.is_training = is_training
            self.use_input_mask = use_input_mask
            self.use_token_type_ids = use_token_type_ids
            self.use_labels = use_labels
            self.vocab_size = vocab_size
            self.hidden_size = hidden_size
            self.num_hidden_layers = num_hidden_layers
            self.num_attention_heads = num_attention_heads
            self.intermediate_size = intermediate_size
            self.hidden_act = hidden_act
            self.hidden_dropout_prob = hidden_dropout_prob
            self.attention_probs_dropout_prob = attention_probs_dropout_prob
            self.max_position_embeddings = max_position_embeddings
            self.type_vocab_size = type_vocab_size
            self.type_sequence_label_size = type_sequence_label_size
            self.initializer_range = initializer_range
            self.num_labels = num_labels
            self.num_choices = num_choices
            self.scope = scope
            self.device = device

            # 1. superset of bert input/output descs
            # see BertPreTrainedModel doc
            self.input_ids_desc = IODescription(
                'input_ids', ['batch', 'max_seq_len_in_batch'],
                torch.int64,
                num_classes=self.vocab_size)
            self.attention_mask_desc = IODescription(
                'attention_mask', ['batch', 'max_seq_len_in_batch'],
                torch.int64,
                num_classes=2)
            self.token_type_ids_desc = IODescription(
                'token_type_ids', ['batch', 'max_seq_len_in_batch'],
                torch.int64,
                num_classes=2)
            self.position_ids_desc = IODescription(
                'position_ids', ['batch', 'max_seq_len_in_batch'],
                torch.int64,
                num_classes=self.max_position_embeddings)
            self.head_mask_desc = IODescription(
                'head_mask',
                [self.num_hidden_layers, self.num_attention_heads],
                torch.int64,
                num_classes=2)
            self.inputs_embeds_desc = IODescription(
                'inputs_embeds',
                ['batch', 'max_seq_len_in_batch', self.hidden_size],
                torch.float32)

            self.encoder_hidden_states_desc = IODescription(
                'encoder_hidden_states',
                ['batch', 'max_seq_len_in_batch', self.hidden_size],
                torch.float32)
            self.encoder_attention_mask_desc = IODescription(
                'encoder_attention_mask', ['batch', 'max_seq_len_in_batch'],
                torch.float32)

            # see BertForPreTraining doc
            self.masked_lm_labels_desc = IODescription(
                'masked_lm_labels', ['batch', 'max_seq_len_in_batch'],
                torch.int64,
                num_classes=self.vocab_size)
            self.next_sentence_label_desc = IODescription(
                'next_sentence_label', [
                    'batch',
                ], torch.int64, num_classes=2)

            # outputs
            self.loss_desc = IODescription('loss', [
                1,
            ], torch.float32)
            self.prediction_scores_desc = IODescription(
                'prediction_scores',
                ['batch', 'max_seq_len_in_batch', self.vocab_size],
                torch.float32)

            self.seq_relationship_scores_desc = IODescription(
                'seq_relationship_scores', ['batch', 2], torch.float32
            )  # IODescription('seq_relationship_scores', ['batch', 'max_seq_len_in_batch', 2], torch.float32)
            self.hidden_states_desc = IODescription('hidden_states', [
                self.num_hidden_layers, 'batch', 'max_seq_len_in_batch',
                self.hidden_size
            ], torch.float32)
            self.attentions_desc = IODescription('attentions', [
                self.num_hidden_layers, 'batch', self.num_attention_heads,
                'max_seq_len_in_batch', 'max_seq_len_in_batch'
            ], torch.float32)
            self.last_hidden_state_desc = IODescription(
                'last_hidden_state',
                ['batch', 'max_seq_len_in_batch', self.hidden_size],
                torch.float32)
            self.pooler_output_desc = IODescription(
                'pooler_output', ['batch', self.hidden_size], torch.float32)
Example #28
def run_test(model, model_desc, device, args, gradient_accumulation_steps,
             fp16, allreduce_post_accumulation, get_lr_this_step,
             use_internal_get_lr_this_step, loss_scaler,
             use_internal_loss_scaler, batch_args_option, dataset_len, epochs,
             use_new_api):
    dataloader = create_ort_test_dataloader(model_desc.inputs_,
                                            args.batch_size, args.seq_len,
                                            dataset_len, device)

    if use_new_api:
        assert use_internal_loss_scaler, 'new api should always use internal loss scaler'

        new_api_lr_scheduler = WrapLRScheduler(get_lr_this_step)

        new_api_loss_scaler = amp.DynamicLossScaler() if fp16 else None
        options = orttrainer.ORTTrainerOptions({
            'batch': {
                'gradient_accumulation_steps': gradient_accumulation_steps
            },
            'device': {
                'id': device
            },
            'mixed_precision': {
                'enabled': fp16,
                'loss_scaler': new_api_loss_scaler
            },
            'debug': {
                'deterministic_compute': True,
            },
            'utils': {
                'grad_norm_clip': True
            },
            'distributed': {
                'allreduce_post_accumulation': True
            },
            'lr_scheduler':
            new_api_lr_scheduler
        })

        param_optimizer = list(model.named_parameters())
        params = [{
            'params': [
                n for n, p in param_optimizer
                if "bias" in n or "LayerNorm.weight" in n
            ],
            "alpha":
            0.9,
            "beta":
            0.999,
            "lambda":
            0.0,
            "epsilon":
            1e-6
        }, {
            'params': [
                n for n, p in param_optimizer
                if not ("bias" in n or "LayerNorm.weight" in n)
            ],
            "alpha":
            0.9,
            "beta":
            0.999,
            "lambda":
            0.0,
            "epsilon":
            1e-6
        }]

        vocab_size = 99
        new_model_desc = {
            'inputs': [
                ('input_ids', ['batch', 'max_seq_len_in_batch']),
                ('attention_mask', ['batch', 'max_seq_len_in_batch']),
                ('token_type_ids', ['batch', 'max_seq_len_in_batch']),
                ('masked_lm_labels', ['batch', 'max_seq_len_in_batch']),
                ('next_sentence_label', ['batch'])
            ],
            'outputs': [
                ('loss', [1], True),
                ('prediction_scores', ['batch', 'max_seq_len_in_batch', vocab_size]),
                ('seq_relationship_scores', ['batch', 2])
            ]
        }

        optim_config = optim.LambConfig(params=params, lr=2e-5)
        model = orttrainer.ORTTrainer(model,
                                      new_model_desc,
                                      optim_config,
                                      options=options)
        print("running with new frontend API")
    else:
        model = ORTTrainer(
            model,
            None,
            model_desc,
            "LambOptimizer",
            map_optimizer_attributes=map_optimizer_attributes,
            learning_rate_description=IODescription('Learning_Rate', [1],
                                                    torch.float32),
            device=device,
            _enable_internal_postprocess=True,
            gradient_accumulation_steps=gradient_accumulation_steps,
            # BertLAMB default initial settings: b1=0.9, b2=0.999, e=1e-6
            world_rank=args.local_rank,
            world_size=args.world_size,
            use_mixed_precision=fp16,
            allreduce_post_accumulation=allreduce_post_accumulation,
            get_lr_this_step=get_lr_this_step
            if use_internal_get_lr_this_step else None,
            loss_scaler=loss_scaler if use_internal_loss_scaler else None,
            _opset_version=14,
            _use_deterministic_compute=True)
        print("running with old frontend API")

    # training loop
    eval_batch = None
    if not use_new_api:
        model.train()
    for epoch in range(epochs):
        for step, batch in enumerate(dataloader):
            if eval_batch is None:
                eval_batch = batch

            if not use_internal_get_lr_this_step:
                lr = get_lr_this_step(step)
                learning_rate = torch.tensor([lr])

            if not use_internal_loss_scaler and fp16:
                loss_scale = torch.tensor([loss_scaler.loss_scale_])

            if batch_args_option == BatchArgsOption.List:
                if not use_internal_get_lr_this_step:
                    batch = batch + [learning_rate]
                if not use_internal_loss_scaler and fp16:
                    batch = batch + [loss_scale]
                outputs = model.train_step(*batch)
            elif batch_args_option == BatchArgsOption.Dict:
                args, kwargs = split_batch(batch, model_desc.inputs_, 0)
                if not use_internal_get_lr_this_step:
                    kwargs['Learning_Rate'] = learning_rate
                if not use_internal_loss_scaler and fp16:
                    kwargs[model.loss_scale_input_name] = loss_scale
                outputs = model.train_step(*args, **kwargs)
            else:
                args_count = len(model_desc.inputs_) // 2  # roughly half args, half kwargs
                args, kwargs = split_batch(batch, model_desc.inputs_,
                                           args_count)
                if not use_internal_get_lr_this_step:
                    kwargs['Learning_Rate'] = learning_rate
                if not use_internal_loss_scaler and fp16:
                    kwargs[model.loss_scale_input_name] = loss_scale
                outputs = model.train_step(*args, **kwargs)

    # eval
    if batch_args_option == BatchArgsOption.List:
        outputs = model.eval_step(*batch)
    elif batch_args_option == BatchArgsOption.Dict:
        args, kwargs = split_batch(batch, model_desc.inputs_, 0)
        outputs = model.eval_step(*args, **kwargs)
    else:
        args_count = len(model_desc.inputs_) // 2  # roughly half args, half kwargs
        args, kwargs = split_batch(batch, model_desc.inputs_, args_count)
        outputs = model.eval_step(*args, **kwargs)

    return (output.cpu().numpy() for output in outputs)
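
A hypothetical invocation sketch, not taken from the original source: the `model`, `model_desc`, the argparse-style `args` namespace (with batch_size, seq_len, local_rank, world_size), and the constant learning-rate lambda are all placeholders.

# Hypothetical call; every argument value below is illustrative only.
outputs = run_test(model, model_desc, device='cuda', args=args,
                   gradient_accumulation_steps=1, fp16=False,
                   allreduce_post_accumulation=False,
                   get_lr_this_step=lambda step: 2e-5,
                   use_internal_get_lr_this_step=True,
                   loss_scaler=None, use_internal_loss_scaler=True,
                   batch_args_option=BatchArgsOption.List,
                   dataset_len=16, epochs=1, use_new_api=True)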
Exemple #29
0
    def model_to_desc(self, model_name, model):
        if model_name.startswith('bert') or model_name.startswith('xlnet'):
            model_desc = ModelDescription([
                IODescription('input_ids', ['batch', 'max_seq_len_in_batch'],
                              torch.int64,
                              num_classes=model.config.vocab_size),
                IODescription('attention_mask',
                              ['batch', 'max_seq_len_in_batch'],
                              torch.int64,
                              num_classes=2),
                IODescription('token_type_ids',
                              ['batch', 'max_seq_len_in_batch'],
                              torch.int64,
                              num_classes=2),
                IODescription('labels', ['batch'], torch.int64, num_classes=2)
            ], [
                IODescription('loss', [], torch.float32),
                IODescription('logits', ['batch', 2], torch.float32)
            ])
        elif model_name.startswith('roberta'):
            model_desc = ModelDescription([
                IODescription('input_ids', ['batch', 'max_seq_len_in_batch'],
                              torch.int64,
                              num_classes=model.config.vocab_size),
                IODescription('attention_mask',
                              ['batch', 'max_seq_len_in_batch'],
                              torch.int64,
                              num_classes=2),
                IODescription('labels', ['batch'], torch.int64, num_classes=2)
            ], [
                IODescription('loss', [], torch.float32),
                IODescription('logits', ['batch', 2], torch.float32)
            ])
        else:
            raise RuntimeError(
                "unsupported base model name {}.".format(model_name))

        return model_desc
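
A brief usage sketch, assumed rather than taken from the source: model_to_desc appears to be a method on a test or benchmark class, so the `runner` instance and the Hugging Face checkpoint name below are illustrative only.

# Hypothetical usage; `runner` is an instance of the class defining model_to_desc.
from transformers import BertForSequenceClassification

model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
model_desc = runner.model_to_desc('bert-base-uncased', model)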
Exemple #30
0
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=10, metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')

    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')

    parser.add_argument('--use-ort', action='store_true', default=False,
                        help='to use onnxruntime as training backend')

    parser.add_argument('--use-ort-trainer', action='store_true', default=False,
                        help='to use onnxruntime as training backend')

    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    kwargs = {'num_workers': 0, 'pin_memory': True}
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=args.batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=False, transform=transforms.Compose([
            transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])),
        batch_size=args.test_batch_size, shuffle=True, **kwargs)


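    # Distributed setup: ranks are read from Open MPI environment variables when present.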
    comm = MPI.COMM_WORLD
    args.local_rank = int(os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK', 0))
    args.world_rank = int(os.environ.get('OMPI_COMM_WORLD_RANK', 0))
    args.world_size = comm.Get_size()
    if use_cuda:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
    else:
        device = torch.device("cpu")
    args.n_gpu = 1
    set_cuda_device_id(args.local_rank)

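    # Simple fully-connected MNIST classifier: 784 inputs -> 500 hidden units -> 10 classes.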
    input_size = 784
    hidden_size = 500
    num_classes = 10
    model = NeuralNet(input_size, hidden_size, num_classes)

    model_desc = mnist_model_description()
    if args.use_ort_trainer:
        # use log_interval as gradient accumulate steps
        trainer = ORTTrainer(model, my_loss, model_desc, "LambOptimizer", None,
                             IODescription('Learning_Rate', [1], torch.float32),
                             device, 1, None, args.world_rank, args.world_size,
                             use_mixed_precision=False,
                             allreduce_post_accumulation=True)
        print('\nBuild ort model done.')

        for epoch in range(1, args.epochs + 1):
            train_with_trainer(args, trainer, device, train_loader, epoch)
            test_with_trainer(args, trainer, device, test_loader)
    else:
        model = ORTModel(model, my_loss, model_desc, device, None, args.world_rank, args.world_size)
        print('\nBuild ort model done.')

        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)

        for epoch in range(1, args.epochs + 1):
            train_with_model(args, model, device, train_loader, optimizer, epoch)