def get_dataloaders(batch_size, vocab, train_dataset_size, val_dataset_size):
    """Build train/dev DataLoaders (plus the raw dev feature dataset) for SQuAD 2.0.

    Parameters
    ----------
    batch_size : int
        Batch size for both the train and dev loaders.
    vocab : Vocab
        BERT vocabulary; supplies the padding token used by the batchify Pads.
    train_dataset_size : int
        Number of raw training records to keep (dataset is sliced to this size).
    val_dataset_size : int
        Number of raw dev records to keep.

    Returns
    -------
    tuple
        ``(train_dataloader, dev_dataloader, dev_dataset)`` where ``dev_dataset``
        holds the per-example SQuAD features needed to decode predicted spans.
    """
    pad_val = vocab[vocab.padding_token]
    # Batch layout: (example_id, input_ids, token_types, valid_len, start, end).
    batchify_fn = nlp.data.batchify.Tuple(
        nlp.data.batchify.Stack(),
        nlp.data.batchify.Pad(axis=0, pad_val=pad_val),
        nlp.data.batchify.Pad(axis=0, pad_val=pad_val),
        nlp.data.batchify.Stack('float32'),
        nlp.data.batchify.Stack('float32'),
        nlp.data.batchify.Stack(),
    )

    # --- training pipeline ---
    train_data = SQuAD("train", version='2.0')[:train_dataset_size]
    train_transform = SQuADTransform(
        nlp.data.BERTTokenizer(vocab=vocab, lower=True),
        max_seq_length=384,
        doc_stride=128,
        max_query_length=64,
        is_pad=True,
        is_training=True)
    train_data_transform, _ = preprocess_dataset(train_data, train_transform)
    train_dataloader = mx.gluon.data.DataLoader(
        train_data_transform,
        batchify_fn=batchify_fn,
        batch_size=batch_size,
        num_workers=4,
        shuffle=True)

    # --- dev pipeline (sliced to val_dataset_size records) ---
    dev_data = SQuAD("dev", version='2.0')[:val_dataset_size]
    dev_data = mx.gluon.data.SimpleDataset(dev_data)
    # Eagerly materialized features (tokens, offset maps) used later to map
    # predicted span logits back to answer text.
    dev_dataset = dev_data.transform(
        SQuADTransform(
            nlp.data.BERTTokenizer(vocab=vocab, lower=True),
            max_seq_length=384,
            doc_stride=128,
            max_query_length=64,
            is_pad=False,
            is_training=False)._transform,
        lazy=False)
    dev_data_transform, _ = preprocess_dataset(
        dev_data,
        SQuADTransform(
            nlp.data.BERTTokenizer(vocab=vocab, lower=True),
            max_seq_length=384,
            doc_stride=128,
            max_query_length=64,
            is_pad=False,
            is_training=False))
    dev_dataloader = mx.gluon.data.DataLoader(
        dev_data_transform,
        batchify_fn=batchify_fn,
        num_workers=1,
        batch_size=batch_size,
        shuffle=False,
        last_batch='keep')

    return train_dataloader, dev_dataloader, dev_dataset
def evaluate():
    """Evaluate the model on the SQuAD dev set and write ``predictions.json``.

    Relies on module-level globals: ``version_2``, ``args``, ``log``,
    ``tokenizer``, ``batchify_fn``, ``net``, ``ctx``, ``test_batch_size``,
    the decoding knobs (``max_answer_length``, ``null_score_diff_threshold``,
    ``n_best_size``, ``lower``) and ``output_dir``.
    """
    log.info('Loading dev data...')
    if version_2:
        dev_data = SQuAD('dev', version='2.0')
    else:
        dev_data = SQuAD('dev', version='1.1')
    if args.debug:
        # Tiny 3-record subset for quick smoke runs.
        sampled_data = [dev_data[0], dev_data[1], dev_data[2]]
        dev_data = mx.gluon.data.SimpleDataset(sampled_data)
    log.info('Number of records in dev data:{}'.format(len(dev_data)))

    # Eagerly materialized per-example features (tokens, offset maps) needed
    # below to turn span logits back into answer strings.
    dev_dataset = dev_data.transform(SQuADTransform(
        copy.copy(tokenizer),
        max_seq_length=max_seq_length,
        doc_stride=doc_stride,
        max_query_length=max_query_length,
        is_pad=False,
        is_training=False)._transform, lazy=False)

    dev_data_transform, _ = preprocess_dataset(
        dev_data,
        SQuADTransform(copy.copy(tokenizer),
                       max_seq_length=max_seq_length,
                       doc_stride=doc_stride,
                       max_query_length=max_query_length,
                       is_pad=False,
                       is_training=False))
    log.info('The number of examples after preprocessing:{}'.format(
        len(dev_data_transform)))

    dev_dataloader = mx.gluon.data.DataLoader(dev_data_transform,
                                              batchify_fn=batchify_fn,
                                              num_workers=4,
                                              batch_size=test_batch_size,
                                              shuffle=False,
                                              last_batch='keep')

    log.info('start prediction')
    # example_id -> list of PredResult, one per sliding-window feature.
    all_results = collections.defaultdict(list)
    epoch_tic = time.time()
    total_num = 0
    for data in dev_dataloader:
        example_ids, inputs, token_types, valid_length, _, _ = data
        total_num += len(inputs)
        out = net(
            inputs.astype('float32').as_in_context(ctx),
            token_types.astype('float32').as_in_context(ctx),
            valid_length.astype('float32').as_in_context(ctx))
        # Split the last axis into start-logit and end-logit planes.
        output = mx.nd.split(out, axis=2, num_outputs=2)
        example_ids = example_ids.asnumpy().tolist()
        pred_start = output[0].reshape((0, -3)).asnumpy()
        pred_end = output[1].reshape((0, -3)).asnumpy()

        for example_id, start, end in zip(example_ids, pred_start, pred_end):
            all_results[example_id].append(PredResult(start=start, end=end))
    epoch_toc = time.time()
    # FIX: corrected 'Thoughput' typo in the log message.
    log.info('Time cost={:.2f} s, Throughput={:.2f} samples/s'.format(
        epoch_toc - epoch_tic, total_num / (epoch_toc - epoch_tic)))

    log.info('Get prediction results...')
    # PERF: construct the detokenizer once instead of once per example.
    basic_tokenizer = nlp.data.BERTBasicTokenizer(lower=lower)
    all_predictions = collections.OrderedDict()
    for features in dev_dataset:
        results = all_results[features[0].example_id]
        example_qas_id = features[0].qas_id
        prediction, _ = predict(
            features=features,
            results=results,
            tokenizer=basic_tokenizer,
            max_answer_length=max_answer_length,
            null_score_diff_threshold=null_score_diff_threshold,
            n_best_size=n_best_size,
            version_2=version_2)
        all_predictions[example_qas_id] = prediction

    with io.open(os.path.join(output_dir, 'predictions.json'),
                 'w', encoding='utf-8') as fout:
        data = json.dumps(all_predictions, ensure_ascii=False)
        fout.write(data)

    if version_2:
        log.info(
            'Please run evaluate-v2.0.py to get evaluation results for SQuAD 2.0'
        )
    else:
        F1_EM = get_F1_EM(dev_data, all_predictions)
        log.info(F1_EM)
def evaluate():
    """Evaluate the model on the dev set and dump prediction JSON files.

    Writes ``predictions.json`` and ``nbest_predictions.json`` (plus
    ``null_odds.json`` for SQuAD 2.0) into ``output_dir``; for SQuAD 1.1 it
    also logs F1/EM. Relies on module-level globals: ``version_2``, ``log``,
    ``berttoken``, ``batchify_fn``, ``net``, ``ctx``, ``nd``,
    ``test_batch_size``, ``args``, the decoding knobs and ``output_dir``.
    """
    log.info('Loading dev data...')
    if version_2:
        dev_data = SQuAD('dev', version='2.0')
    else:
        dev_data = SQuAD('dev', version='1.1')
    # FIX: message previously said 'Train data' while logging the dev set.
    log.info('Number of records in dev data:{}'.format(len(dev_data)))

    # Per-example features (tokens, offset maps) needed to decode spans.
    dev_dataset = dev_data.transform(
        SQuADTransform(berttoken,
                       max_seq_length=max_seq_length,
                       doc_stride=doc_stride,
                       max_query_length=max_query_length,
                       is_pad=False,
                       is_training=False)._transform)
    dev_data_transform, _ = preprocess_dataset(
        dev_data,
        SQuADTransform(berttoken,
                       max_seq_length=max_seq_length,
                       doc_stride=doc_stride,
                       max_query_length=max_query_length,
                       is_pad=False,
                       is_training=False))
    log.info('The number of examples after preprocessing:{}'.format(
        len(dev_data_transform)))

    dev_dataloader = mx.gluon.data.DataLoader(dev_data_transform,
                                              batchify_fn=batchify_fn,
                                              num_workers=4,
                                              batch_size=test_batch_size,
                                              shuffle=False,
                                              last_batch='keep')

    log.info('Start predict')
    _Result = collections.namedtuple(
        '_Result', ['example_id', 'start_logits', 'end_logits'])
    # example_id -> list of _Result, one per sliding-window feature.
    all_results = {}

    epoch_tic = time.time()
    total_num = 0
    for data in dev_dataloader:
        example_ids, inputs, token_types, valid_length, _, _ = data
        total_num += len(inputs)
        out = net(
            inputs.astype('float32').as_in_context(ctx),
            token_types.astype('float32').as_in_context(ctx),
            valid_length.astype('float32').as_in_context(ctx))
        # Split the last axis into start-logit and end-logit planes.
        output = nd.split(out, axis=2, num_outputs=2)
        start_logits = output[0].reshape((0, -3)).asnumpy()
        end_logits = output[1].reshape((0, -3)).asnumpy()

        for example_id, start, end in zip(example_ids, start_logits,
                                          end_logits):
            example_id = example_id.asscalar()
            if example_id not in all_results:
                all_results[example_id] = []
            all_results[example_id].append(
                _Result(example_id, start.tolist(), end.tolist()))
        if args.test_mode:
            log.info('Exit early in test mode')
            break
    epoch_toc = time.time()
    # FIX: corrected 'Thoughput' typo in the log message.
    log.info('Time cost={:.2f} s, Throughput={:.2f} samples/s'.format(
        epoch_toc - epoch_tic, total_num / (epoch_toc - epoch_tic)))

    log.info('Get prediction results...')
    all_predictions, all_nbest_json, scores_diff_json = predictions(
        dev_dataset=dev_dataset,
        all_results=all_results,
        tokenizer=nlp.data.BERTBasicTokenizer(lower=lower),
        max_answer_length=max_answer_length,
        null_score_diff_threshold=null_score_diff_threshold,
        n_best_size=n_best_size,
        version_2=version_2,
        test_mode=args.test_mode)

    with open(os.path.join(output_dir, 'predictions.json'),
              'w', encoding='utf-8') as all_predictions_write:
        all_predictions_write.write(json.dumps(all_predictions))
    with open(os.path.join(output_dir, 'nbest_predictions.json'),
              'w', encoding='utf-8') as all_predictions_write:
        all_predictions_write.write(json.dumps(all_nbest_json))
    if version_2:
        with open(os.path.join(output_dir, 'null_odds.json'),
                  'w', encoding='utf-8') as all_predictions_write:
            all_predictions_write.write(json.dumps(scores_diff_json))
    else:
        log.info(get_F1_EM(dev_data, all_predictions))
# --- training dataset preparation ---
# In --debug mode the (smaller) dev split stands in for the train split.
segment = 'train' if not args.debug else 'dev'
log.info('Loading %s data...', segment)
train_data = SQuAD(segment, version='2.0' if version_2 else '1.1')
if args.debug:
    # Keep only the first 120 records for fast iteration.
    sampled_data = [train_data[i] for i in range(120)]
    train_data = mx.gluon.data.SimpleDataset(sampled_data)
log.info('Number of records in Train data:{}'.format(len(train_data)))

# Eagerly materialized per-example features (tokens, offset maps).
train_dataset = train_data.transform(
    SQuADTransform(copy.copy(tokenizer),
                   max_seq_length=max_seq_length,
                   doc_stride=doc_stride,
                   max_query_length=max_query_length,
                   is_pad=True,
                   is_training=True)._transform,
    lazy=False)

# Flat tuples suitable for batchify/DataLoader consumption.
train_data_transform, _ = preprocess_dataset(
    train_data,
    SQuADTransform(copy.copy(tokenizer),
                   max_seq_length=max_seq_length,
                   doc_stride=doc_stride,
                   max_query_length=max_query_length,
                   is_pad=True,
                   is_training=True))
log.info('The number of examples after preprocessing:{}'.format(
    len(train_data_transform)))
def evaluate():
    """Evaluate the model (with optional answer verifier) on the dev set.

    Writes ``predictions.json`` into ``output_dir``; for SQuAD 1.1 also logs
    F1/EM. Relies on module-level globals: ``version_2``, ``args``, ``log``,
    ``tokenizer``, ``batchify_fn``, ``net``, ``ctx``, ``verifier``,
    ``VERIFIER_ID``, ``test_batch_size``, the decoding knobs and
    ``output_dir``.
    """
    log.info('Loading dev data...')
    if version_2:
        dev_data = SQuAD('dev', version='2.0')
    else:
        dev_data = SQuAD('dev', version='1.1')
    if args.debug:
        # Small 10-record subset for quick smoke runs.
        sampled_data = dev_data[:10]
        dev_data = mx.gluon.data.SimpleDataset(sampled_data)
    log.info('Number of records in dev data:{}'.format(len(dev_data)))

    # NOTE(review): is_pad=True/is_training=True on a dev-set transform mirrors
    # the training preprocessing (presumably so the verifier sees identically
    # shaped inputs) — confirm this is intentional.
    dev_dataset = dev_data.transform(SQuADTransform(
        copy.copy(tokenizer),
        max_seq_length=max_seq_length,
        doc_stride=doc_stride,
        max_query_length=max_query_length,
        is_pad=True,
        is_training=True)._transform, lazy=False)

    dev_data_transform, _ = preprocess_dataset(
        dev_data,
        SQuADTransform(copy.copy(tokenizer),
                       max_seq_length=max_seq_length,
                       doc_stride=doc_stride,
                       max_query_length=max_query_length,
                       is_pad=True,
                       is_training=True))

    # example_id -> features; the verifier looks examples up by id.
    dev_features = {
        features[0].example_id: features
        for features in dev_dataset
    }

    # NOTE(review): shuffle=True is unusual for an evaluation loader (results
    # are keyed by example_id so correctness is unaffected) — confirm intent.
    dev_dataloader = mx.gluon.data.DataLoader(dev_data_transform,
                                              batchify_fn=batchify_fn,
                                              batch_size=test_batch_size,
                                              num_workers=4,
                                              shuffle=True)

    log.info('start prediction')
    # example_id -> list of PredResult, one per sliding-window feature.
    all_results = collections.defaultdict(list)
    # Verifier types 2/3 also produce a per-feature has-answer probability.
    if args.verify and VERIFIER_ID in [2, 3]:
        all_pre_na_prob = collections.defaultdict(list)
    else:
        all_pre_na_prob = None

    epoch_tic = time.time()
    total_num = 0
    for data in dev_dataloader:
        example_ids, inputs, token_types, valid_length, _, _ = data
        total_num += len(inputs)
        out = net(
            inputs.astype('float32').as_in_context(ctx),
            token_types.astype('float32').as_in_context(ctx),
            valid_length.astype('float32').as_in_context(ctx))
        if all_pre_na_prob is not None:
            has_answer_tmp = verifier.evaluate(dev_features, example_ids,
                                               out).asnumpy().tolist()
        # Split the last axis into start-logit and end-logit planes.
        output = mx.nd.split(out, axis=2, num_outputs=2)
        example_ids = example_ids.asnumpy().tolist()
        pred_start = output[0].reshape((0, -3)).asnumpy()
        pred_end = output[1].reshape((0, -3)).asnumpy()

        for example_id, start, end in zip(example_ids, pred_start, pred_end):
            all_results[example_id].append(PredResult(start=start, end=end))
        if all_pre_na_prob is not None:
            for example_id, has_ans_prob in zip(example_ids, has_answer_tmp):
                all_pre_na_prob[example_id].append(has_ans_prob)
    epoch_toc = time.time()
    # FIX: corrected 'Thoughput' typo in the log message.
    log.info('Time cost={:.2f} s, Throughput={:.2f} samples/s'.format(
        epoch_toc - epoch_tic, total_num / (epoch_toc - epoch_tic)))

    log.info('Get prediction results...')
    # PERF: construct the detokenizer once instead of once per example.
    basic_tokenizer = nlp.data.BERTBasicTokenizer(lower=lower)
    all_predictions = collections.OrderedDict()
    for features in dev_dataset:
        results = all_results[features[0].example_id]
        example_qas_id = features[0].qas_id
        if all_pre_na_prob is not None:
            # Average has-answer probability over this example's features;
            # below 0.5 the example is treated as unanswerable.
            has_ans_prob_list = all_pre_na_prob[features[0].example_id]
            has_ans_prob = sum(has_ans_prob_list) / max(
                len(has_ans_prob_list), 1)
            if has_ans_prob < 0.5:
                all_predictions[example_qas_id] = ""
                continue
        prediction, _ = predict(
            features=features,
            results=results,
            tokenizer=basic_tokenizer,
            max_answer_length=max_answer_length,
            null_score_diff_threshold=null_score_diff_threshold,
            n_best_size=n_best_size,
            version_2=version_2)
        if args.verify and VERIFIER_ID == 1:
            # Verifier type 1 vetoes non-empty predictions post hoc.
            if len(prediction) > 0:
                has_answer = verifier.evaluate(features, prediction)
                if not has_answer:
                    prediction = ""
        # Mapping is qas_id (hash key) -> answer string.
        all_predictions[example_qas_id] = prediction

    with io.open(os.path.join(output_dir, 'predictions.json'),
                 'w', encoding='utf-8') as fout:
        data = json.dumps(all_predictions, ensure_ascii=False)
        fout.write(data)

    if version_2:
        log.info(
            'Please run evaluate-v2.0.py to get evaluation results for SQuAD 2.0'
        )
    else:
        F1_EM = get_F1_EM(dev_data, all_predictions)
        log.info(F1_EM)
def train():
    """Fine-tune the network on SQuAD (optionally training the verifier).

    Relies on module-level globals: ``args``, ``version_2``, ``log``,
    ``tokenizer``, ``batchify_fn``, ``net``, ``ctx``, ``loss_function``,
    ``optimizer``, ``lr``, ``batch_size``, ``accumulate``, ``epochs``,
    ``warmup_ratio``, ``log_interval`` and (when ``args.verify``)
    ``verifier``.
    """
    segment = 'train' if not args.debug else 'dev'
    log.info('Loading %s data...', segment)
    if version_2:
        train_data = SQuAD(segment, version='2.0')
    else:
        train_data = SQuAD(segment, version='1.1')
    if args.debug:
        # Small 120-record subset for fast debug iterations.
        sampled_data = [train_data[i] for i in range(120)]
        train_data = mx.gluon.data.SimpleDataset(sampled_data)
    log.info('Number of records in Train data:{}'.format(len(train_data)))

    # Eagerly materialized per-example features (tokens, offset maps).
    train_dataset = train_data.transform(SQuADTransform(
        copy.copy(tokenizer),
        max_seq_length=max_seq_length,
        doc_stride=doc_stride,
        max_query_length=max_query_length,
        is_pad=True,
        is_training=True)._transform, lazy=False)

    train_data_transform, _ = preprocess_dataset(
        train_data,
        SQuADTransform(copy.copy(tokenizer),
                       max_seq_length=max_seq_length,
                       doc_stride=doc_stride,
                       max_query_length=max_query_length,
                       is_pad=True,
                       is_training=True))
    log.info('The number of examples after preprocessing:{}'.format(
        len(train_data_transform)))

    # example_id -> features; handed to the verifier during training.
    train_features = {
        features[0].example_id: features
        for features in train_dataset
    }

    train_dataloader = mx.gluon.data.DataLoader(train_data_transform,
                                                batchify_fn=batchify_fn,
                                                batch_size=batch_size,
                                                num_workers=4,
                                                shuffle=True)

    log.info('Start Training')

    optimizer_params = {'learning_rate': lr}
    try:
        trainer = mx.gluon.Trainer(net.collect_params(), optimizer,
                                   optimizer_params, update_on_kvstore=False)
    except ValueError as e:
        # Older MXNet builds lack the AdamW optimizer; fall back to Adam.
        print(e)
        warnings.warn('AdamW optimizer is not found. Please consider '
                      'upgrading to mxnet>=1.5.0. Now the original Adam '
                      'optimizer is used instead.')
        trainer = mx.gluon.Trainer(net.collect_params(), 'adam',
                                   optimizer_params, update_on_kvstore=False)

    num_train_examples = len(train_data_transform)
    step_size = batch_size * accumulate if accumulate else batch_size
    num_train_steps = int(num_train_examples / step_size * epochs)
    num_warmup_steps = int(num_train_steps * warmup_ratio)
    step_num = 0

    def set_new_lr(step_num, batch_id):
        """Advance the step counter and apply warmup/linear-decay LR.

        With gradient accumulation, gradients are zeroed (and the step
        counter advanced) only at the start of each accumulation window.
        """
        if accumulate:
            if batch_id % accumulate == 0:
                net.collect_params().zero_grad()
                step_num += 1
        else:
            step_num += 1
        # Linear warmup for the first num_warmup_steps, then linear decay
        # towards zero at num_train_steps.
        if step_num < num_warmup_steps:
            new_lr = lr * step_num / num_warmup_steps
        else:
            offset = (step_num - num_warmup_steps) * lr / \
                (num_train_steps - num_warmup_steps)
            new_lr = lr - offset
        trainer.set_learning_rate(new_lr)
        return step_num

    # Do not apply weight decay on LayerNorm and bias terms.
    for _, v in net.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0

    # Collect differentiable parameters.
    params = [p for p in net.collect_params().values()
              if p.grad_req != 'null']
    # Accumulate gradients across micro-batches when requested.
    if accumulate:
        for p in params:
            p.grad_req = 'add'

    epoch_tic = time.time()
    total_num = 0
    log_num = 0
    for epoch_id in range(epochs):
        step_loss = 0.0
        tic = time.time()
        for batch_id, data in enumerate(train_dataloader):
            step_num = set_new_lr(step_num, batch_id)

            with mx.autograd.record():
                example_ids, inputs, token_types, valid_length, \
                    start_label, end_label = data
                log_num += len(inputs)
                total_num += len(inputs)
                out = net(
                    inputs.astype('float32').as_in_context(ctx),
                    token_types.astype('float32').as_in_context(ctx),
                    valid_length.astype('float32').as_in_context(ctx))
                ls = loss_function(out, [
                    start_label.astype('float32').as_in_context(ctx),
                    end_label.astype('float32').as_in_context(ctx)
                ]).mean()
                if accumulate:
                    ls = ls / accumulate
            ls.backward()

            # Apply the update at the end of each accumulation window.
            if not accumulate or (batch_id + 1) % accumulate == 0:
                trainer.allreduce_grads()
                nlp.utils.clip_grad_global_norm(params, 1)
                trainer.update(1)

            # Train the verifier on this batch's features/outputs.
            # NOTE(review): placement relative to the accumulation branch is
            # ambiguous in the original formatting — confirm it should run
            # every batch rather than only on update steps.
            if args.verify:
                verifier.train(train_features, example_ids, out)

            step_loss += ls.asscalar()
            if (batch_id + 1) % log_interval == 0:
                toc = time.time()
                log.info(
                    'Epoch: {}, Batch: {}/{}, Loss={:.4f}, lr={:.7f} Time cost={:.1f} Throughput={:.2f} samples/s'  # pylint: disable=line-too-long
                    .format(epoch_id, batch_id, len(train_dataloader),
                            step_loss / log_interval, trainer.learning_rate,
                            toc - tic, log_num / (toc - tic)))
                tic = time.time()
                step_loss = 0.0
                log_num = 0
    epoch_toc = time.time()
    # FIX: corrected 'Thoughput' typo in the log message.
    log.info('Time cost={:.2f} s, Throughput={:.2f} samples/s'.format(
        epoch_toc - epoch_tic, total_num / (epoch_toc - epoch_tic)))