Example #1
def gen_dataset():
    """generate dataset for lpot."""
    log.info('Loading dev data...')
    if version_2:
        dev_data = SQuAD('dev', version='2.0')
    else:
        dev_data = SQuAD('dev', version='1.1')
    if args.debug:
        sampled_data = [dev_data[0], dev_data[1], dev_data[2]]
        dev_data = mx.gluon.data.SimpleDataset(sampled_data)
    log.info('Number of records in dev data:{}'.format(len(dev_data)))

    batchify_fn_calib = nlp.data.batchify.Tuple(
        nlp.data.batchify.Pad(axis=0, pad_val=vocab[vocab.padding_token], round_to=args.round_to),
        nlp.data.batchify.Pad(axis=0, pad_val=vocab[vocab.padding_token], round_to=args.round_to),
        nlp.data.batchify.Stack('float32'),
        nlp.data.batchify.Stack('float32'))

    dev_data_transform = preprocess_dataset(tokenizer,
                                            dev_data,
                                            max_seq_length=max_seq_length,
                                            doc_stride=doc_stride,
                                            max_query_length=max_query_length,
                                            input_features=True,
                                            for_calibration=True)

    dev_dataloader = mx.gluon.data.DataLoader(
        dev_data_transform,
        batchify_fn=batchify_fn_calib,
        num_workers=4, batch_size=test_batch_size,
        shuffle=False, last_batch='keep')
    
    return dev_dataloader
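The calibration loader above leans on GluonNLP's batchify helpers. As a minimal, self-contained sketch (toy token-id lists only, no SQuAD download), the same Tuple/Pad/Stack combination behaves roughly like this:

import gluonnlp as nlp

toy_samples = [
    ([1, 2, 3], [0, 0, 1], 3.0, 0.0),   # (token ids, segment ids, valid length, label)
    ([4, 5],    [0, 1],    2.0, 1.0),
]
batchify_fn = nlp.data.batchify.Tuple(
    nlp.data.batchify.Pad(axis=0, pad_val=0),   # right-pad token ids to a common length
    nlp.data.batchify.Pad(axis=0, pad_val=0),   # right-pad segment ids the same way
    nlp.data.batchify.Stack('float32'),         # stack scalars into a batch array
    nlp.data.batchify.Stack('float32'))
inputs, segments, valid_lengths, labels = batchify_fn(toy_samples)
print(inputs.shape, segments.shape)   # both padded to (2, 3)
print(valid_lengths, labels)          # NDArrays [3. 2.] and [0. 1.]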
Example #2
def calibration(net, num_calib_batches, quantized_dtype, calib_mode):
    """calibration function on the dev dataset."""
    log.info('Loading dev data...')
    if version_2:
        dev_data = SQuAD('dev', version='2.0')
    else:
        dev_data = SQuAD('dev', version='1.1')
    if args.debug:
        sampled_data = [dev_data[0], dev_data[1], dev_data[2]]
        dev_data = mx.gluon.data.SimpleDataset(sampled_data)
    log.info('Number of records in dev data:{}'.format(len(dev_data)))

    batchify_fn_calib = nlp.data.batchify.Tuple(
        nlp.data.batchify.Pad(axis=0,
                              pad_val=vocab[vocab.padding_token],
                              round_to=args.round_to),
        nlp.data.batchify.Pad(axis=0,
                              pad_val=vocab[vocab.padding_token],
                              round_to=args.round_to),
        nlp.data.batchify.Stack('float32'), nlp.data.batchify.Stack('float32'))

    dev_data_transform = preprocess_dataset(tokenizer,
                                            dev_data,
                                            max_seq_length=max_seq_length,
                                            doc_stride=doc_stride,
                                            max_query_length=max_query_length,
                                            input_features=True,
                                            for_calibration=True)

    dev_dataloader = mx.gluon.data.DataLoader(dev_data_transform,
                                              batchify_fn=batchify_fn_calib,
                                              num_workers=4,
                                              batch_size=test_batch_size,
                                              shuffle=False,
                                              last_batch='keep')

    assert ctx == mx.cpu(), \
        'Currently only supports CPU with MKL-DNN backend.'
    log.info('Now we are doing calibration on dev with %s.', ctx)
    collector = BertLayerCollector(clip_min=-50, clip_max=10, logger=log)
    num_calib_examples = test_batch_size * num_calib_batches
    net = mx.contrib.quantization.quantize_net_v2(
        net,
        quantized_dtype=quantized_dtype,
        exclude_layers=[],
        quantize_mode='smart',
        quantize_granularity='channel-wise',
        calib_data=dev_dataloader,
        calib_mode=calib_mode,
        num_calib_examples=num_calib_examples,
        ctx=ctx,
        LayerOutputCollector=collector,
        logger=log)
    # save params
    ckpt_name = 'model_bert_squad_quantized_{0}'.format(calib_mode)
    params_saved = os.path.join(output_dir, ckpt_name)
    net.export(params_saved, epoch=0)
    log.info('Saving quantized model at %s', output_dir)
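Once the calibrated network has been exported, it can presumably be reloaded as a SymbolBlock for inference. A sketch continuing from the example above (the three input names are an assumption, check the generated -symbol.json file):

import mxnet as mx

# net.export(params_saved, epoch=0) above writes <params_saved>-symbol.json and
# <params_saved>-0000.params; the input names below are assumptions.
quantized_net = mx.gluon.SymbolBlock.imports(
    params_saved + '-symbol.json',
    ['data0', 'data1', 'data2'],
    params_saved + '-0000.params',
    ctx=mx.cpu())
quantized_net.hybridize(static_alloc=True, static_shape=True)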
Example #3
def get_dataloaders(batch_size, vocab, train_dataset_size, val_dataset_size):

    batchify_fn = nlp.data.batchify.Tuple(
        nlp.data.batchify.Stack(),
        nlp.data.batchify.Pad(axis=0, pad_val=vocab[vocab.padding_token]),
        nlp.data.batchify.Pad(axis=0, pad_val=vocab[vocab.padding_token]),
        nlp.data.batchify.Stack('float32'),
        nlp.data.batchify.Stack('float32'),
        nlp.data.batchify.Stack(),
    )

    train_data = SQuAD("train", version='2.0')[:train_dataset_size]

    train_data_transform, _ = preprocess_dataset(
        train_data,
        SQuADTransform(nlp.data.BERTTokenizer(vocab=vocab, lower=True),
                       max_seq_length=384,
                       doc_stride=128,
                       max_query_length=64,
                       is_pad=True,
                       is_training=True))

    train_dataloader = mx.gluon.data.DataLoader(train_data_transform,
                                                batchify_fn=batchify_fn,
                                                batch_size=batch_size,
                                                num_workers=4,
                                                shuffle=True)

    # we only keep the first val_dataset_size validation samples
    dev_data = SQuAD("dev", version='2.0')[:val_dataset_size]
    dev_data = mx.gluon.data.SimpleDataset(dev_data)

    dev_dataset = dev_data.transform(SQuADTransform(
        nlp.data.BERTTokenizer(vocab=vocab, lower=True),
        max_seq_length=384,
        doc_stride=128,
        max_query_length=64,
        is_pad=False,
        is_training=False)._transform,
                                     lazy=False)

    dev_data_transform, _ = preprocess_dataset(
        dev_data,
        SQuADTransform(nlp.data.BERTTokenizer(vocab=vocab, lower=True),
                       max_seq_length=384,
                       doc_stride=128,
                       max_query_length=64,
                       is_pad=False,
                       is_training=False))

    dev_dataloader = mx.gluon.data.DataLoader(dev_data_transform,
                                              batchify_fn=batchify_fn,
                                              num_workers=1,
                                              batch_size=batch_size,
                                              shuffle=False,
                                              last_batch='keep')

    return train_dataloader, dev_dataloader, dev_dataset
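A hypothetical call of the function above (the BERT vocabulary object and the sizes are assumptions, not part of the original snippet):

# `bert_vocab` is assumed to be the gluonnlp vocabulary that accompanies the
# pretrained BERT model used elsewhere in the script.
train_loader, dev_loader, dev_dataset = get_dataloaders(
    batch_size=8, vocab=bert_vocab, train_dataset_size=32, val_dataset_size=4)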
Example #4
    def __init__(self, mxnet_vocab=None, perf_count=None, logger=None):
        self.logger = logger
        if self.logger:
            self.logger.info("Constructing QSL...")
        test_batch_size = 1
        eval_features = []

        if self.logger:
            self.logger.info("Creating tokenizer...")
        with open(mxnet_vocab, 'r') as f:
            vocab = nlp.vocab.BERTVocab.from_json(f.read())
        tokenizer = nlp.data.BERTTokenizer(vocab=vocab, lower=True)

        round_to = None
        if self.logger:
            self.logger.info("Reading examples...")
        dev_path = os.path.join(os.getcwd(), 'build/data')
        dev_data = SQuAD('dev', version='1.1', root=dev_path)
        dev_data_transform = preprocess_dataset(
            tokenizer,
            dev_data,
            max_seq_length=max_seq_length,
            doc_stride=doc_stride,
            max_query_length=max_query_length,
            input_features=True)

        self.eval_features = dev_data_transform
        self.count = len(self.eval_features)
        self.perf_count = perf_count if perf_count is not None else self.count
        self.qsl = lg.ConstructQSL(self.count, self.perf_count,
                                   self.load_query_samples,
                                   self.unload_query_samples)
        if self.logger:
            self.logger.info("Finished constructing QSL.")
Example #5
    def __init__(self, mx_vocab, perf_count):
        import gluonnlp as nlp
        from preprocessing_utils import preprocess_dataset, max_seq_length, max_query_length, doc_stride
        from gluonnlp.data import SQuAD

        eval_features = []
        with open(mx_vocab, 'r') as f:
            vocab = nlp.vocab.BERTVocab.from_json(f.read())
        log.info("Creating tokenizer...")
        tokenizer = nlp.data.BERTTokenizer(vocab=vocab, lower=True)

        round_to = None
        log.info("Reading examples...")
        dev_path = os.path.join(os.getcwd(), 'build/data')
        dev_data = SQuAD('dev', version='1.1', root=dev_path)
        dev_data_transform = preprocess_dataset(
            tokenizer,
            dev_data,
            max_seq_length=max_seq_length,
            doc_stride=doc_stride,
            max_query_length=max_query_length,
            input_features=True)

        self.eval_features = dev_data_transform
        self.count = len(self.eval_features)
        self.perf_count = perf_count if perf_count is not None else self.count
Example #6
def test_transform_to_nd_array():
    dataset = SQuAD(segment='dev', version='1.1', root='tests/data/squad')
    vocab_provider = VocabProvider(dataset)
    transformer = SQuADTransform(vocab_provider, question_max_length,
                                 context_max_length)
    record = dataset[0]

    transformed_record = transformer(*record)
    assert transformed_record is not None
    assert len(transformed_record) == 7
Example #7
def test_data_loader_able_to_read():
    dataset = SQuAD(segment='dev', root='tests/data/squad')
    vocab_provider = VocabProvider(dataset)
    transformer = SQuADTransform(vocab_provider, question_max_length,
                                 context_max_length)
    record = dataset[0]

    processed_dataset = SimpleDataset([transformer(*record)])
    loadable_data = SimpleDataset([(r[0], r[2], r[3], r[4], r[5], r[6])
                                   for r in processed_dataset])
    dataloader = DataLoader(loadable_data, batch_size=1)

    for data in dataloader:
        record_index, question_words, context_words, question_chars, context_chars, answers = data

        assert record_index is not None
        assert question_words is not None
        assert context_words is not None
        assert question_chars is not None
        assert context_chars is not None
        assert answers is not None
Example #8
    def get_processed_data(self,
                           use_spacy=True,
                           shrink_word_vocab=True,
                           squad_data_root=None):
        """Main method to start data processing

        Parameters
        ----------
        use_spacy : bool, default True
            Shall use Spacy as a tokenizer. If not, uses NLTK
        shrink_word_vocab : bool, default True
            When True, only tokens that have embeddings in the embedding file are remained in the
            word_vocab. Otherwise tokens with no embedding also stay
        squad_data_root : str, default None
            Data path to store downloaded original SQuAD data
        Returns
        -------
        train_json_data : dict
            Train JSON data of SQuAD dataset as is to run official evaluation script
        dev_json_data : dict
            Dev JSON data of SQuAD dataset as is to run official evaluation script
        train_examples : SQuADQADataset
            Processed examples to be used for training
        dev_examples : SQuADQADataset
            Processed examples to be used for evaluation
        word_vocab : Vocab
            Word vocabulary
        char_vocab : Vocab
            Char vocabulary

        """
        if self._save_load_data and self._has_processed_data():
            return self._load_processed_data()

        train_dataset = SQuAD(segment='train', root=squad_data_root) \
            if squad_data_root else SQuAD(segment='train')
        dev_dataset = SQuAD(segment='dev', root=squad_data_root) \
            if squad_data_root else SQuAD(segment='dev')

        with contextlib.closing(mp.Pool(processes=self._num_workers)) as pool:
            train_examples, dev_examples = SQuADDataPipeline._tokenize_data(
                train_dataset, dev_dataset, use_spacy, pool)
            word_vocab, char_vocab = SQuADDataPipeline._get_vocabs(
                train_examples, dev_examples, self._emb_file_name,
                self._is_cased_embedding, shrink_word_vocab, pool)

        filter_provider = SQuADDataFilter(self._train_para_limit,
                                          self._train_ques_limit,
                                          self._ans_limit)
        train_examples = list(filter(filter_provider.filter, train_examples))

        train_featurizer = SQuADDataFeaturizer(word_vocab, char_vocab,
                                               self._train_para_limit,
                                               self._train_ques_limit,
                                               self._char_limit,
                                               self._is_cased_embedding)

        dev_featuarizer = SQuADDataFeaturizer(word_vocab, char_vocab,
                                              self._dev_para_limit,
                                              self._dev_ques_limit,
                                              self._char_limit,
                                              self._is_cased_embedding)

        train_examples, dev_examples = SQuADDataPipeline._featurize_data(
            train_examples, dev_examples, train_featurizer, dev_featuarizer)

        if self._save_load_data:
            self._save_processed_data(train_examples, dev_examples, word_vocab,
                                      char_vocab)

        return train_dataset._read_data(), dev_dataset._read_data(), \
               SQuADQADataset(train_examples), SQuADQADataset(dev_examples), word_vocab, char_vocab
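A hypothetical usage of the method above, assuming a SQuADDataPipeline instance named `pipeline` has already been constructed elsewhere with its paragraph/question limits and embedding file configured:

train_json, dev_json, train_examples, dev_examples, word_vocab, char_vocab = \
    pipeline.get_processed_data(use_spacy=True, shrink_word_vocab=True)
print(len(train_examples), len(dev_examples), len(word_vocab), len(char_vocab))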
Example #9
def evaluate():
    """Evaluate the model on validation dataset.
    """
    log.info('Loading dev data...')
    if version_2:
        dev_data = SQuAD('dev', version='2.0')
    else:
        dev_data = SQuAD('dev', version='1.1')
    if args.debug:
        sampled_data = [dev_data[0], dev_data[1], dev_data[2]]
        dev_data = mx.gluon.data.SimpleDataset(sampled_data)
    log.info('Number of records in dev data:{}'.format(len(dev_data)))

    dev_dataset = dev_data.transform(SQuADTransform(
        copy.copy(tokenizer),
        max_seq_length=max_seq_length,
        doc_stride=doc_stride,
        max_query_length=max_query_length,
        is_pad=False,
        is_training=False)._transform,
                                     lazy=False)

    dev_data_transform, _ = preprocess_dataset(
        dev_data,
        SQuADTransform(copy.copy(tokenizer),
                       max_seq_length=max_seq_length,
                       doc_stride=doc_stride,
                       max_query_length=max_query_length,
                       is_pad=False,
                       is_training=False))
    log.info('The number of examples after preprocessing:{}'.format(
        len(dev_data_transform)))

    dev_dataloader = mx.gluon.data.DataLoader(dev_data_transform,
                                              batchify_fn=batchify_fn,
                                              num_workers=4,
                                              batch_size=test_batch_size,
                                              shuffle=False,
                                              last_batch='keep')

    log.info('start prediction')

    all_results = collections.defaultdict(list)

    epoch_tic = time.time()
    total_num = 0
    for data in dev_dataloader:
        example_ids, inputs, token_types, valid_length, _, _ = data
        total_num += len(inputs)
        out = net(
            inputs.astype('float32').as_in_context(ctx),
            token_types.astype('float32').as_in_context(ctx),
            valid_length.astype('float32').as_in_context(ctx))

        output = mx.nd.split(out, axis=2, num_outputs=2)
        example_ids = example_ids.asnumpy().tolist()
        pred_start = output[0].reshape((0, -3)).asnumpy()
        pred_end = output[1].reshape((0, -3)).asnumpy()

        for example_id, start, end in zip(example_ids, pred_start, pred_end):
            all_results[example_id].append(PredResult(start=start, end=end))

    epoch_toc = time.time()
    log.info('Time cost={:.2f} s, Throughput={:.2f} samples/s'.format(
        epoch_toc - epoch_tic, total_num / (epoch_toc - epoch_tic)))

    log.info('Get prediction results...')

    all_predictions = collections.OrderedDict()

    for features in dev_dataset:
        results = all_results[features[0].example_id]
        example_qas_id = features[0].qas_id

        prediction, _ = predict(
            features=features,
            results=results,
            tokenizer=nlp.data.BERTBasicTokenizer(lower=lower),
            max_answer_length=max_answer_length,
            null_score_diff_threshold=null_score_diff_threshold,
            n_best_size=n_best_size,
            version_2=version_2)

        all_predictions[example_qas_id] = prediction

    with io.open(os.path.join(output_dir, 'predictions.json'),
                 'w',
                 encoding='utf-8') as fout:
        data = json.dumps(all_predictions, ensure_ascii=False)
        fout.write(data)

    if version_2:
        log.info(
            'Please run evaluate-v2.0.py to get evaluation results for SQuAD 2.0'
        )
    else:
        F1_EM = get_F1_EM(dev_data, all_predictions)
        log.info(F1_EM)
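The prediction loop above splits the QA head output of shape (batch, seq_len, 2) into start and end logits with split + reshape((0, -3)). A self-contained sketch of that reshape trick (MXNet's reshape code 0 keeps a dimension, -3 merges the next two):

import mxnet as mx

out = mx.nd.random.normal(shape=(4, 384, 2))          # stand-in for the QA head output
start, end = mx.nd.split(out, axis=2, num_outputs=2)  # each of shape (4, 384, 1)
print(start.reshape((0, -3)).shape)                   # (4, 384): per-token start logits
print(end.reshape((0, -3)).shape)                     # (4, 384): per-token end logits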
Example #10
def train():
    """Training function."""
    segment = 'train' if not args.debug else 'dev'
    log.info('Loading %s data...', segment)
    if version_2:
        train_data = SQuAD(segment, version='2.0')
    else:
        train_data = SQuAD(segment, version='1.1')
    if args.debug:
        sampled_data = [train_data[i] for i in range(1000)]
        train_data = mx.gluon.data.SimpleDataset(sampled_data)
    log.info('Number of records in Train data:{}'.format(len(train_data)))

    train_data_transform, _ = preprocess_dataset(
        train_data,
        SQuADTransform(copy.copy(tokenizer),
                       max_seq_length=max_seq_length,
                       doc_stride=doc_stride,
                       max_query_length=max_query_length,
                       is_pad=True,
                       is_training=True))
    log.info('The number of examples after preprocessing:{}'.format(
        len(train_data_transform)))

    train_dataloader = mx.gluon.data.DataLoader(train_data_transform,
                                                batchify_fn=batchify_fn,
                                                batch_size=batch_size,
                                                num_workers=4,
                                                shuffle=True)

    log.info('Start Training')

    optimizer_params = {'learning_rate': lr}
    try:
        trainer = mx.gluon.Trainer(net.collect_params(),
                                   optimizer,
                                   optimizer_params,
                                   update_on_kvstore=False)
    except ValueError as e:
        print(e)
        warnings.warn(
            'AdamW optimizer is not found. Please consider upgrading to '
            'mxnet>=1.5.0. Now the original Adam optimizer is used instead.')
        trainer = mx.gluon.Trainer(net.collect_params(),
                                   'adam',
                                   optimizer_params,
                                   update_on_kvstore=False)

    num_train_examples = len(train_data_transform)
    step_size = batch_size * accumulate if accumulate else batch_size
    num_train_steps = int(num_train_examples / step_size * epochs)
    num_warmup_steps = int(num_train_steps * warmup_ratio)
    step_num = 0

    def set_new_lr(step_num, batch_id):
        """set new learning rate"""
        # set grad to zero for gradient accumulation
        if accumulate:
            if batch_id % accumulate == 0:
                net.collect_params().zero_grad()
                step_num += 1
        else:
            step_num += 1
        # learning rate schedule: linear warmup followed by linear decay.
        # Once step_num >= num_warmup_steps, the learning rate decreases linearly
        # from lr down to 0 at num_train_steps.
        if step_num < num_warmup_steps:
            new_lr = lr * step_num / num_warmup_steps
        else:
            offset = (step_num - num_warmup_steps) * lr / \
                (num_train_steps - num_warmup_steps)
            new_lr = lr - offset
        trainer.set_learning_rate(new_lr)
        return step_num

    # Do not apply weight decay on LayerNorm and bias terms
    for _, v in net.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0
    # Collect differentiable parameters
    params = [p for p in net.collect_params().values() if p.grad_req != 'null']
    # Set grad_req if gradient accumulation is required
    if accumulate:
        for p in params:
            p.grad_req = 'add'

    epoch_tic = time.time()
    total_num = 0
    log_num = 0
    for epoch_id in range(epochs):
        step_loss = 0.0
        tic = time.time()
        for batch_id, data in enumerate(train_dataloader):
            # set new lr
            step_num = set_new_lr(step_num, batch_id)
            # forward and backward
            with mx.autograd.record():
                _, inputs, token_types, valid_length, start_label, end_label = data

                log_num += len(inputs)
                total_num += len(inputs)

                out = net(
                    inputs.astype('float32').as_in_context(ctx),
                    token_types.astype('float32').as_in_context(ctx),
                    valid_length.astype('float32').as_in_context(ctx))

                ls = loss_function(out, [
                    start_label.astype('float32').as_in_context(ctx),
                    end_label.astype('float32').as_in_context(ctx)
                ]).mean()

                if accumulate:
                    ls = ls / accumulate
            ls.backward()
            # update
            if not accumulate or (batch_id + 1) % accumulate == 0:
                trainer.allreduce_grads()
                nlp.utils.clip_grad_global_norm(params, 1)
                trainer.update(1)

            step_loss += ls.asscalar()

            if (batch_id + 1) % log_interval == 0:
                toc = time.time()
                log.info(
                    'Epoch: {}, Batch: {}/{}, Loss={:.4f}, lr={:.7f} Time cost={:.1f} Throughput={:.2f} samples/s'  # pylint: disable=line-too-long
                    .format(epoch_id, batch_id, len(train_dataloader),
                            step_loss / log_interval, trainer.learning_rate,
                            toc - tic, log_num / (toc - tic)))
                tic = time.time()
                step_loss = 0.0
                log_num = 0
        epoch_toc = time.time()
        log.info('Time cost={:.2f} s, Throughput={:.2f} samples/s'.format(
            epoch_toc - epoch_tic, total_num / (epoch_toc - epoch_tic)))

    net.save_parameters(os.path.join(output_dir, 'net.params'))
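The set_new_lr helper above implements a linear warmup followed by linear decay. A pure-Python sketch of the same schedule with toy numbers:

def linear_warmup_decay(step, base_lr, warmup_steps, total_steps):
    """Linear warmup to base_lr, then linear decay to 0 at total_steps."""
    if step < warmup_steps:
        return base_lr * step / warmup_steps
    return base_lr - (step - warmup_steps) * base_lr / (total_steps - warmup_steps)

for step in (0, 50, 100, 550, 1000):
    print(step, linear_warmup_decay(step, base_lr=3e-5, warmup_steps=100, total_steps=1000))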
Example #11
def evaluate():
    """Evaluate the model on validation dataset.
    """
    log.info('Loading dev data...')
    if version_2:
        dev_data = SQuAD('dev', version='2.0')
    else:
        dev_data = SQuAD('dev', version='1.1')
    if args.debug:
        sampled_data = dev_data[:10]  # [dev_data[0], dev_data[1], dev_data[2]]
        dev_data = mx.gluon.data.SimpleDataset(sampled_data)
    log.info('Number of records in dev data:{}'.format(len(dev_data)))

    dev_dataset = dev_data.transform(SQuADTransform(
        copy.copy(tokenizer),
        max_seq_length=max_seq_length,
        doc_stride=doc_stride,
        max_query_length=max_query_length,
        is_pad=True,
        is_training=True)._transform,
                                     lazy=False)

    dev_data_transform, _ = preprocess_dataset(
        dev_data,
        SQuADTransform(copy.copy(tokenizer),
                       max_seq_length=max_seq_length,
                       doc_stride=doc_stride,
                       max_query_length=max_query_length,
                       is_pad=True,
                       is_training=True))

    # refer to evaluation process
    # for feat in train_dataset:
    #     print(feat[0].example_id)
    #     print(feat[0].tokens)
    #     print(feat[0].token_to_orig_map)
    #     input()
    # exit(0)

    dev_features = {
        features[0].example_id: features
        for features in dev_dataset
    }

    #for line in train_data_transform:
    #    print(line)
    #    input()

    dev_dataloader = mx.gluon.data.DataLoader(dev_data_transform,
                                              batchify_fn=batchify_fn,
                                              batch_size=test_batch_size,
                                              num_workers=4,
                                              shuffle=True)
    '''

    dev_dataset = dev_data.transform(
        SQuADTransform(
            copy.copy(tokenizer),
            max_seq_length=max_seq_length,
            doc_stride=doc_stride,
            max_query_length=max_query_length,
            is_pad=False,
            is_training=False)._transform, lazy=False)

    # for feat in dev_dataset:
    #     print(feat[0].example_id)
    #     print(feat[0].tokens)
    #     print(feat[0].token_to_orig_map)
    #     input()
    # exit(0)

    dev_features = {features[0].example_id: features for features in dev_dataset}

    dev_data_transform, _ = preprocess_dataset(
        dev_data, SQuADTransform(
            copy.copy(tokenizer),
            max_seq_length=max_seq_length,
            doc_stride=doc_stride,
            max_query_length=max_query_length,
            is_pad=False,
            is_training=False))
    log.info('The number of examples after preprocessing:{}'.format(
        len(dev_data_transform)))

    dev_dataloader = mx.gluon.data.DataLoader(
        dev_data_transform,
        batchify_fn=batchify_fn,
        num_workers=4, batch_size=test_batch_size,
        shuffle=False, last_batch='keep')
    '''
    log.info('start prediction')

    all_results = collections.defaultdict(list)

    if args.verify and VERIFIER_ID in [2, 3]:
        all_pre_na_prob = collections.defaultdict(list)
    else:
        all_pre_na_prob = None

    epoch_tic = time.time()
    total_num = 0
    for data in dev_dataloader:
        example_ids, inputs, token_types, valid_length, _, _ = data
        total_num += len(inputs)
        out = net(
            inputs.astype('float32').as_in_context(ctx),
            token_types.astype('float32').as_in_context(ctx),
            valid_length.astype('float32').as_in_context(ctx))

        if all_pre_na_prob is not None:
            has_answer_tmp = verifier.evaluate(dev_features, example_ids,
                                               out).asnumpy().tolist()

        output = mx.nd.split(out, axis=2, num_outputs=2)
        example_ids = example_ids.asnumpy().tolist()
        pred_start = output[0].reshape((0, -3)).asnumpy()
        pred_end = output[1].reshape((0, -3)).asnumpy()

        for example_id, start, end in zip(example_ids, pred_start, pred_end):
            all_results[example_id].append(PredResult(start=start, end=end))
        if all_pre_na_prob is not None:
            for example_id, has_ans_prob in zip(example_ids, has_answer_tmp):
                all_pre_na_prob[example_id].append(has_ans_prob)

    epoch_toc = time.time()
    log.info('Time cost={:.2f} s, Throughput={:.2f} samples/s'.format(
        epoch_toc - epoch_tic, total_num / (epoch_toc - epoch_tic)))

    log.info('Get prediction results...')

    all_predictions = collections.OrderedDict()

    for features in dev_dataset:
        results = all_results[features[0].example_id]
        example_qas_id = features[0].qas_id

        if all_pre_na_prob is not None:
            has_ans_prob_list = all_pre_na_prob[features[0].example_id]
            has_ans_prob = sum(has_ans_prob_list) / max(
                len(has_ans_prob_list), 1)
            if has_ans_prob < 0.5:
                prediction = ""
                all_predictions[example_qas_id] = prediction
                continue

        prediction, _ = predict(
            features=features,
            results=results,
            tokenizer=nlp.data.BERTBasicTokenizer(lower=lower),
            max_answer_length=max_answer_length,
            null_score_diff_threshold=null_score_diff_threshold,
            n_best_size=n_best_size,
            version_2=version_2)

        if args.verify and VERIFIER_ID == 1:
            if len(prediction) > 0:
                has_answer = verifier.evaluate(features, prediction)
                if not has_answer:
                    prediction = ""

        all_predictions[example_qas_id] = prediction
        # the form of hashkey - answer string

    with io.open(os.path.join(output_dir, 'predictions.json'),
                 'w',
                 encoding='utf-8') as fout:
        data = json.dumps(all_predictions, ensure_ascii=False)
        fout.write(data)

    if version_2:
        log.info(
            'Please run evaluate-v2.0.py to get evaluation results for SQuAD 2.0'
        )
    else:
        F1_EM = get_F1_EM(dev_data, all_predictions)
        log.info(F1_EM)
Example #12
def calibration(net, num_calib_batches, quantized_dtype, calib_mode):
    """calibration function on the dev dataset."""
    log.info('Loading dev data...')
    if version_2:
        dev_data = SQuAD('dev', version='2.0')
    else:
        dev_data = SQuAD('dev', version='1.1')
    if args.debug:
        sampled_data = [dev_data[0], dev_data[1], dev_data[2]]
        dev_data = mx.gluon.data.SimpleDataset(sampled_data)
    log.info('Number of records in dev data:{}'.format(len(dev_data)))
    origin_dev_data_len = len(dev_data)
    num_calib_examples = test_batch_size * num_calib_batches
    # randomly select the calib data from the full dataset
    random_indices = np.random.choice(origin_dev_data_len, num_calib_examples)
    print('random_indices: ', random_indices)
    dev_data = [dev_data[i] for i in random_indices]
    log.info('Number of records in dev data:{}'.format(len(dev_data)))

    batchify_fn_calib = nlp.data.batchify.Tuple(
        nlp.data.batchify.Pad(axis=0, pad_val=vocab[vocab.padding_token], round_to=args.round_to),
        nlp.data.batchify.Pad(axis=0, pad_val=vocab[vocab.padding_token], round_to=args.round_to),
        nlp.data.batchify.Stack('float32'),
        nlp.data.batchify.Stack('float32'))

    dev_data_transform = preprocess_dataset(tokenizer,
                                            dev_data,
                                            max_seq_length=max_seq_length,
                                            doc_stride=doc_stride,
                                            max_query_length=max_query_length,
                                            input_features=True,
                                            for_calibration=True)

    dev_dataloader = mx.gluon.data.DataLoader(
        dev_data_transform,
        batchify_fn=batchify_fn_calib,
        num_workers=4, batch_size=test_batch_size,
        shuffle=True, last_batch='keep')

    net = run_pass(net, 'custom_pass')
    assert ctx == mx.cpu(), \
        'Currently only supports CPU with MKL-DNN backend.'
    log.info('Now we are doing calibration on dev with %s.', ctx)
    collector = BertLayerCollector(clip_min=-50, clip_max=10, logger=log)
    net = mx.contrib.quantization.quantize_net_v2(net, quantized_dtype=quantized_dtype,
                                                  exclude_layers=[],
                                                  quantize_mode='smart',
                                                  quantize_granularity='tensor-wise',
                                                  calib_data=dev_dataloader,
                                                  calib_mode=calib_mode,
                                                  num_calib_examples=num_calib_examples,
                                                  ctx=ctx,
                                                  LayerOutputCollector=collector,
                                                  logger=log)
    if scenario == "offline":
        net = run_pass(net, 'softmax_mask')
    else:
        net = run_pass(net, 'normal_softmax')

    net = run_pass(net, 'bias_to_s32')

    # # save params
    ckpt_name = 'model_bert_squad_quantized_{0}'.format(calib_mode)
    params_saved = os.path.join(output_dir, ckpt_name)
    net.hybridize(static_alloc=True, static_shape=True)
    
    a = mx.nd.ones((test_batch_size, max_seq_length), dtype='float32')
    b = mx.nd.ones((test_batch_size, max_seq_length), dtype='float32')
    c = mx.nd.ones((test_batch_size, ), dtype='float32')
    net(a, b, c)
    mx.nd.waitall()
    net.export(params_saved, epoch=0)
    log.info('Saving quantized model at %s', output_dir)
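One caveat about the random calibration subset above: np.random.choice samples with replacement by default, so duplicate records can end up among the calibration examples. If distinct records are preferred (and num_calib_examples does not exceed the dataset size), a variant with replace=False could be used instead:

# Variant of the selection above; only valid when
# num_calib_examples <= len(dev_data).
random_indices = np.random.choice(origin_dev_data_len, num_calib_examples, replace=False)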
Example #13
def evaluate():
    """Evaluate the model on validation dataset.
    """
    log.info('Loading dev data...')
    if version_2:
        dev_data = SQuAD('dev', version='2.0')
    else:
        dev_data = SQuAD('dev', version='1.1')
    log.info('Number of records in dev data:{}'.format(len(dev_data)))

    dev_dataset = dev_data.transform(
        SQuADTransform(berttoken,
                       max_seq_length=max_seq_length,
                       doc_stride=doc_stride,
                       max_query_length=max_query_length,
                       is_pad=False,
                       is_training=False)._transform)

    dev_data_transform, _ = preprocess_dataset(
        dev_data,
        SQuADTransform(berttoken,
                       max_seq_length=max_seq_length,
                       doc_stride=doc_stride,
                       max_query_length=max_query_length,
                       is_pad=False,
                       is_training=False))
    log.info('The number of examples after preprocessing:{}'.format(
        len(dev_data_transform)))

    dev_dataloader = mx.gluon.data.DataLoader(dev_data_transform,
                                              batchify_fn=batchify_fn,
                                              num_workers=4,
                                              batch_size=test_batch_size,
                                              shuffle=False,
                                              last_batch='keep')

    log.info('Start predict')

    _Result = collections.namedtuple(
        '_Result', ['example_id', 'start_logits', 'end_logits'])
    all_results = {}

    epoch_tic = time.time()
    total_num = 0
    for data in dev_dataloader:
        example_ids, inputs, token_types, valid_length, _, _ = data
        total_num += len(inputs)
        out = net(
            inputs.astype('float32').as_in_context(ctx),
            token_types.astype('float32').as_in_context(ctx),
            valid_length.astype('float32').as_in_context(ctx))

        output = nd.split(out, axis=2, num_outputs=2)
        start_logits = output[0].reshape((0, -3)).asnumpy()
        end_logits = output[1].reshape((0, -3)).asnumpy()

        for example_id, start, end in zip(example_ids, start_logits,
                                          end_logits):
            example_id = example_id.asscalar()
            if example_id not in all_results:
                all_results[example_id] = []
            all_results[example_id].append(
                _Result(example_id, start.tolist(), end.tolist()))
        if args.test_mode:
            log.info('Exit early in test mode')
            break
    epoch_toc = time.time()
    log.info('Time cost={:.2f} s, Throughput={:.2f} samples/s'.format(
        epoch_toc - epoch_tic, total_num / (epoch_toc - epoch_tic)))
    log.info('Get prediction results...')

    all_predictions, all_nbest_json, scores_diff_json = predictions(
        dev_dataset=dev_dataset,
        all_results=all_results,
        tokenizer=nlp.data.BERTBasicTokenizer(lower=lower),
        max_answer_length=max_answer_length,
        null_score_diff_threshold=null_score_diff_threshold,
        n_best_size=n_best_size,
        version_2=version_2,
        test_mode=args.test_mode)

    with open(os.path.join(output_dir, 'predictions.json'),
              'w',
              encoding='utf-8') as all_predictions_write:
        all_predictions_write.write(json.dumps(all_predictions))

    with open(os.path.join(output_dir, 'nbest_predictions.json'),
              'w',
              encoding='utf-8') as all_predictions_write:
        all_predictions_write.write(json.dumps(all_nbest_json))

    if version_2:
        with open(os.path.join(output_dir, 'null_odds.json'),
                  'w',
                  encoding='utf-8') as all_predictions_write:
            all_predictions_write.write(json.dumps(scores_diff_json))
    else:
        log.info(get_F1_EM(dev_data, all_predictions))
Example #14
    else:
        print("ERROR: verifier with id {0} unknown to the model.".format(
            VERIFIER_ID))
        exit(0)
else:
    VERIFIER_ID = None

########################################
#         Prepare the data - Begin
########################################

# train dataset
segment = 'train' if not args.debug else 'dev'
log.info('Loading %s data...', segment)
if version_2:
    train_data = SQuAD(segment, version='2.0')
else:
    train_data = SQuAD(segment, version='1.1')
if args.debug:
    sampled_data = [train_data[i] for i in range(120)]  # 1000 # 120 # 60
    train_data = mx.gluon.data.SimpleDataset(sampled_data)
log.info('Number of records in Train data:{}'.format(len(train_data)))

train_dataset = train_data.transform(SQuADTransform(
    copy.copy(tokenizer),
    max_seq_length=max_seq_length,
    doc_stride=doc_stride,
    max_query_length=max_query_length,
    is_pad=True,
    is_training=True)._transform,
                                     lazy=False)
Example #15
def evaluate():
    """Evaluate the model on validation dataset.
    """
    log.info('Loading dev data...')
    if args.version_2:
        dev_data = SQuAD('dev', version='2.0')
    else:
        dev_data = SQuAD('dev', version='1.1')
    (_, _), (data_file_name, _) \
        = dev_data._data_file[dev_data._version][dev_data._segment]
    dev_data_path = os.path.join(dev_data._root, data_file_name)

    if args.debug:
        sampled_data = [dev_data[0], dev_data[1], dev_data[2]]
        dev_data = mx.gluon.data.SimpleDataset(sampled_data)
    log.info('Number of records in dev data: %d', len(dev_data))

    dev_data_features = preprocess_dataset(
        tokenizer,
        dev_data,
        vocab=vocab,
        max_seq_length=args.max_seq_length,
        doc_stride=args.doc_stride,
        num_workers=args.num_workers,
        max_query_length=args.max_query_length,
        load_from_pickle=args.load_pickle,
        feature_file=args.dev_dataset_file)

    dev_data_input = convert_full_features_to_input_features(dev_data_features)
    log.info('The number of examples after preprocessing: %d',
             len(dev_data_input))

    dev_dataloader = mx.gluon.data.DataLoader(dev_data_input,
                                              batchify_fn=batchify_fn,
                                              num_workers=4,
                                              batch_size=args.test_batch_size,
                                              shuffle=False,
                                              last_batch='keep')

    log.info('start prediction')

    all_results = collections.defaultdict(list)

    epoch_tic = time.time()
    total_num = 0
    for (batch_id, data) in enumerate(dev_dataloader):
        data_list = list(split_and_load(data, ctx))
        for splited_data in data_list:
            example_ids, inputs, token_types, valid_length, p_mask, _, _, _ = splited_data
            total_num += len(inputs)
            outputs = net_eval(inputs,
                               token_types,
                               valid_length,
                               p_mask=p_mask)
            example_ids = example_ids.asnumpy().tolist()
            for c, example_id in enumerate(example_ids):
                result = RawResultExtended(
                    start_top_log_probs=outputs[0][c].asnumpy().tolist(),
                    start_top_index=outputs[1][c].asnumpy().tolist(),
                    end_top_log_probs=outputs[2][c].asnumpy().tolist(),
                    end_top_index=outputs[3][c].asnumpy().tolist(),
                    cls_logits=outputs[4][c].asnumpy().tolist())
                all_results[example_id].append(result)
        if batch_id % args.log_interval == 0:
            log.info('Batch: %d/%d', batch_id + 1, len(dev_dataloader))

    epoch_toc = time.time()
    log.info('Time cost=%.2f s, Throughput=%.2f samples/s',
             epoch_toc - epoch_tic, total_num / (epoch_toc - epoch_tic))

    log.info('Get prediction results...')

    all_predictions = collections.OrderedDict()
    all_nbest_json = collections.OrderedDict()
    scores_diff_json = collections.OrderedDict()
    for features in dev_data_features:
        results = all_results[features[0].example_id]
        example_qas_id = features[0].qas_id
        score_diff, best_non_null_entry, nbest_json = predict_extended(
            features=features,
            results=results,
            n_best_size=args.n_best_size,
            max_answer_length=args.max_answer_length,
            start_n_top=args.start_top_n,
            end_n_top=args.end_top_n)
        scores_diff_json[example_qas_id] = score_diff
        all_predictions[example_qas_id] = best_non_null_entry
        all_nbest_json[example_qas_id] = nbest_json

    output_prediction_file = os.path.join(args.output_dir, 'predictions.json')
    output_nbest_file = os.path.join(args.output_dir, 'nbest_predictions.json')
    output_null_log_odds_file = os.path.join(args.output_dir, 'null_odds.json')

    with open(output_prediction_file, 'w') as writer:
        writer.write(json.dumps(all_predictions, indent=4) + '\n')
    with open(output_nbest_file, 'w') as writer:
        writer.write(json.dumps(all_nbest_json, indent=4) + '\n')
    with open(output_null_log_odds_file, 'w') as writer:
        writer.write(json.dumps(scores_diff_json, indent=4) + '\n')

    if os.path.exists(sys.path[0] + '/evaluate-v2.0.py'):
        arguments = [
            dev_data_path, output_prediction_file, '--na-prob-thresh',
            str(args.null_score_diff_threshold)
        ]
        if args.version_2:
            arguments += ['--na-prob-file', output_null_log_odds_file]
        subprocess.call([sys.executable, sys.path[0] + '/evaluate-v2.0.py'] +
                        arguments)
    else:
        log.info(
            'Please download evaluate-v2.0.py to get evaluation results for SQuAD. '
            'Check index.rst for the detail.')
Example #16
def train():
    """Training function."""
    segment = 'train'
    log.info('Loading %s data...', segment)
    # Note that for XLNet, the authors always use squad2 dataset for training
    train_data = SQuAD(segment, version='2.0')
    if args.debug:
        sampled_data = [train_data[i] for i in range(100)]
        train_data = mx.gluon.data.SimpleDataset(sampled_data)
    log.info('Number of records in Train data: %s', len(train_data))

    train_data_features = preprocess_dataset(
        tokenizer,
        train_data,
        vocab=vocab,
        max_seq_length=args.max_seq_length,
        doc_stride=args.doc_stride,
        num_workers=args.num_workers,
        max_query_length=args.max_query_length,
        load_from_pickle=args.load_pickle,
        feature_file=args.train_dataset_file)

    train_data_input = convert_full_features_to_input_features(
        train_data_features)
    log.info('The number of examples after preprocessing: %s',
             len(train_data_input))

    train_dataloader = mx.gluon.data.DataLoader(train_data_input,
                                                batchify_fn=batchify_fn,
                                                batch_size=args.batch_size,
                                                num_workers=4,
                                                shuffle=True)

    optimizer_params = {'learning_rate': args.lr, 'wd': args.wd}
    try:
        trainer = mx.gluon.Trainer(net.collect_params(),
                                   args.optimizer,
                                   optimizer_params,
                                   update_on_kvstore=False)
    except ValueError as _:
        warnings.warn(
            'AdamW optimizer is not found. Please consider upgrading to '
            'mxnet>=1.5.0. Now the original Adam optimizer is used instead.')
        trainer = mx.gluon.Trainer(net.collect_params(),
                                   'bertadam',
                                   optimizer_params,
                                   update_on_kvstore=False)

    num_train_examples = len(train_data_input)
    step_size = args.batch_size * args.accumulate if args.accumulate else args.batch_size
    num_train_steps = int(num_train_examples / step_size * args.epochs)
    epoch_number = args.epochs
    if args.training_steps:
        num_train_steps = args.training_steps
        epoch_number = 100000

    log.info('training steps=%d', num_train_steps)
    num_warmup_steps = int(num_train_steps * args.warmup_ratio)
    step_num = 0

    def set_new_lr(step_num, batch_id):
        """set new learning rate"""
        # set grad to zero for gradient accumulation
        if args.accumulate:
            if batch_id % args.accumulate == 0:
                net.collect_params().zero_grad()
                step_num += 1
        else:
            step_num += 1
        # learning rate schedule: linear warmup followed by linear decay.
        # Once step_num >= num_warmup_steps, the learning rate decreases linearly
        # from args.lr down to 0 at num_train_steps.
        if step_num < num_warmup_steps:
            new_lr = args.lr * step_num / num_warmup_steps
        else:
            offset = (step_num - num_warmup_steps) * args.lr / \
                (num_train_steps - num_warmup_steps)
            new_lr = args.lr - offset
        trainer.set_learning_rate(new_lr)
        return step_num

    # Do not apply weight decay on LayerNorm and bias terms
    for _, v in net.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0
    # Collect differentiable parameters
    params = [p for p in net.collect_params().values() if p.grad_req != 'null']
    # Set grad_req if gradient accumulation is required
    if args.accumulate:
        for p in params:
            p.grad_req = 'add'

    epoch_tic = time.time()
    total_num = 0
    log_num = 0
    finish_flag = False
    for epoch_id in range(epoch_number):
        step_loss = 0.0
        step_loss_span = 0
        step_loss_cls = 0
        tic = time.time()
        if finish_flag:
            break
        for batch_id, data in enumerate(train_dataloader):
            # set new lr
            step_num = set_new_lr(step_num, batch_id)
            data_list = list(split_and_load(data, ctx))
            # forward and backward
            batch_loss = []
            batch_loss_sep = []
            with mx.autograd.record():
                for splited_data in data_list:
                    _, inputs, token_types, valid_length, p_mask, start_label, end_label, is_impossible = splited_data  # pylint: disable=line-too-long
                    valid_length = valid_length.astype('float32')
                    log_num += len(inputs)
                    total_num += len(inputs)
                    out_sep, out = net(
                        inputs,
                        token_types,
                        valid_length,
                        [start_label, end_label],
                        p_mask=p_mask,  # pylint: disable=line-too-long
                        is_impossible=is_impossible)
                    ls = out.mean() / len(ctx)
                    batch_loss_sep.append(out_sep)
                    batch_loss.append(ls)
                    if args.accumulate:
                        ls = ls / args.accumulate
                    ls.backward()
            # update
            if not args.accumulate or (batch_id + 1) % args.accumulate == 0:
                trainer.allreduce_grads()
                nlp.utils.clip_grad_global_norm(params, 1)
                _apply_gradient_decay()
                trainer.update(1, ignore_stale_grad=True)

                step_loss_sep_tmp = np.array(
                    [[span_ls.mean().asscalar(),
                      cls_ls.mean().asscalar()]
                     for span_ls, cls_ls in batch_loss_sep])
                step_loss_sep_tmp = list(np.sum(step_loss_sep_tmp, axis=0))
                step_loss_span += step_loss_sep_tmp[0] / len(ctx)
                step_loss_cls += step_loss_sep_tmp[1] / len(ctx)

            step_loss += sum([ls.asscalar() for ls in batch_loss])
            if (batch_id + 1) % log_interval == 0:
                toc = time.time()
                log.info(
                    'Epoch: %d, Batch: %d/%d, Loss=%.4f, lr=%.7f '
                    'Time cost=%.1f Throughput=%.2f samples/s', epoch_id + 1,
                    batch_id + 1, len(train_dataloader),
                    step_loss / log_interval, trainer.learning_rate, toc - tic,
                    log_num / (toc - tic))
                log.info('span_loss: %.4f, cls_loss: %.4f',
                         step_loss_span / log_interval,
                         step_loss_cls / log_interval)

                tic = time.time()
                step_loss = 0.0
                step_loss_span = 0
                step_loss_cls = 0
                log_num = 0
            if step_num >= num_train_steps:
                logging.info('Finish training step: %d', step_num)
                finish_flag = True
                break
        epoch_toc = time.time()
        log.info('Time cost=%.2f s, Throughput=%.2f samples/s',
                 epoch_toc - epoch_tic, total_num / (epoch_toc - epoch_tic))
        version_prefix = 'squad2' if args.version_2 else 'squad1'
        ckpt_name = 'model_{}_{}_{}.params'.format(args.model, version_prefix,
                                                   epoch_id + 1)
        params_saved = os.path.join(args.output_dir, ckpt_name)
        nlp.utils.save_parameters(net, params_saved)
        log.info('params saved in: %s', params_saved)
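The XLNet training loop above shards each batch across devices with a split_and_load helper. A self-contained sketch of the same idea using the standard mx.gluon.utils.split_and_load (the script's own helper may differ):

import mxnet as mx

batch = mx.nd.arange(12).reshape((6, 2))
ctx_list = [mx.cpu(0), mx.cpu(1)]                   # stand-ins for multiple GPUs
shards = mx.gluon.utils.split_and_load(batch, ctx_list)
print([shard.shape for shard in shards])            # [(3, 2), (3, 2)]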
Example #17
def test_load_vocabs():
    dataset = SQuAD(segment='dev', root='tests/data/squad')
    vocab_provider = VocabProvider(dataset)

    assert vocab_provider.get_word_level_vocab() is not None
    assert vocab_provider.get_char_level_vocab() is not None
Example #18
def train():
    """Training function."""
    segment = 'train'  #if not args.debug else 'dev'
    log.info('Loading %s data...', segment)
    if version_2:
        train_data = SQuAD(segment, version='2.0')
    else:
        train_data = SQuAD(segment, version='1.1')
    if args.debug:
        sampled_data = [train_data[i] for i in range(0, 10000)]
        train_data = mx.gluon.data.SimpleDataset(sampled_data)
    log.info('Number of records in Train data:{}'.format(len(train_data)))
    train_data_transform = preprocess_dataset(
        tokenizer,
        train_data,
        max_seq_length=max_seq_length,
        doc_stride=doc_stride,
        max_query_length=max_query_length,
        input_features=True)

    log.info('The number of examples after preprocessing:{}'.format(
        len(train_data_transform)))

    sampler = nlp.data.SplitSampler(len(train_data_transform),
                                    num_parts=size,
                                    part_index=rank,
                                    even_size=True)
    num_train_examples = len(sampler)
    train_dataloader = mx.gluon.data.DataLoader(train_data_transform,
                                                batchify_fn=batchify_fn,
                                                batch_size=batch_size,
                                                num_workers=4,
                                                sampler=sampler)

    log.info('Start Training')

    optimizer_params = {'learning_rate': lr}
    param_dict = net.collect_params()
    if args.comm_backend == 'horovod':
        trainer = hvd.DistributedTrainer(param_dict, optimizer,
                                         optimizer_params)
    else:
        trainer = mx.gluon.Trainer(param_dict,
                                   optimizer,
                                   optimizer_params,
                                   update_on_kvstore=False)
    if args.dtype == 'float16':
        amp.init_trainer(trainer)

    step_size = batch_size * accumulate if accumulate else batch_size
    num_train_steps = int(num_train_examples / step_size * args.epochs)
    if args.training_steps:
        num_train_steps = args.training_steps

    num_warmup_steps = int(num_train_steps * warmup_ratio)

    def set_new_lr(step_num, batch_id):
        """set new learning rate"""
        # set grad to zero for gradient accumulation
        if accumulate:
            if batch_id % accumulate == 0:
                step_num += 1
        else:
            step_num += 1
        # learning rate schedule: linear warmup followed by linear decay.
        # Once step_num >= num_warmup_steps, the learning rate decreases linearly
        # from lr down to 0 at num_train_steps.
        if step_num < num_warmup_steps:
            new_lr = lr * step_num / num_warmup_steps
        else:
            offset = (step_num - num_warmup_steps) * lr / \
                (num_train_steps - num_warmup_steps)
            new_lr = lr - offset
        trainer.set_learning_rate(new_lr)
        return step_num

    # Do not apply weight decay on LayerNorm and bias terms
    for _, v in net.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0
    # Collect differentiable parameters
    params = [p for p in param_dict.values() if p.grad_req != 'null']

    # Set grad_req if gradient accumulation is required
    if accumulate:
        for p in params:
            p.grad_req = 'add'
    net.collect_params().zero_grad()

    epoch_tic = time.time()

    total_num = 0
    log_num = 0
    batch_id = 0
    step_loss = 0.0
    tic = time.time()
    step_num = 0

    tic = time.time()
    while step_num < num_train_steps:
        for _, data in enumerate(train_dataloader):
            # set new lr
            step_num = set_new_lr(step_num, batch_id)
            # forward and backward
            _, inputs, token_types, valid_length, start_label, end_label = data
            num_labels = len(inputs)
            log_num += num_labels
            total_num += num_labels

            with mx.autograd.record():
                out = net(inputs.as_in_context(ctx),
                          token_types.as_in_context(ctx),
                          valid_length.as_in_context(ctx).astype('float32'))

                loss = loss_function(out, [
                    start_label.as_in_context(ctx).astype('float32'),
                    end_label.as_in_context(ctx).astype('float32')
                ]).sum() / num_labels

                if accumulate:
                    loss = loss / accumulate
                if args.dtype == 'float16':
                    with amp.scale_loss(loss, trainer) as l:
                        mx.autograd.backward(l)
                        norm_clip = 1.0 * size * trainer._amp_loss_scaler.loss_scale
                else:
                    mx.autograd.backward(loss)
                    norm_clip = 1.0 * size

            # update
            if not accumulate or (batch_id + 1) % accumulate == 0:
                trainer.allreduce_grads()
                nlp.utils.clip_grad_global_norm(params, norm_clip)
                trainer.update(1)
                if accumulate:
                    param_dict.zero_grad()

            if args.comm_backend == 'horovod':
                step_loss += hvd.allreduce(loss, average=True).asscalar()
            else:
                step_loss += loss.asscalar()

            if (batch_id + 1) % log_interval == 0:
                toc = time.time()
                log.info('Batch: {}/{}, Loss={:.4f}, lr={:.7f} '
                         'Throughput={:.2f} samples/s'.format(
                             batch_id % len(train_dataloader),
                             len(train_dataloader), step_loss / log_interval,
                             trainer.learning_rate, log_num / (toc - tic)))
                tic = time.time()
                step_loss = 0.0
                log_num = 0

            if step_num >= num_train_steps:
                break
            batch_id += 1

        log.info('Finish training step: %d', step_num)
        epoch_toc = time.time()
        log.info('Time cost={:.2f} s, Throughput={:.2f} samples/s'.format(
            epoch_toc - epoch_tic, total_num / (epoch_toc - epoch_tic)))

    if rank == 0:
        net.save_parameters(os.path.join(output_dir, 'net.params'))
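Several of the training loops above rely on the gradient-accumulation pattern with grad_req='add': gradients from successive backward passes are summed until zero_grad() is called. A small self-contained sketch of that behaviour:

import mxnet as mx

w = mx.gluon.Parameter('w', grad_req='add', shape=(1,), init=mx.init.One())
w.initialize()
for x in (2.0, 3.0):
    with mx.autograd.record():
        y = w.data() * x
    y.backward()
print(w.grad())   # [5.]: gradients 2 and 3 accumulated
w.zero_grad()
print(w.grad())   # [0.]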