def evaluate(loader_dev, metric, segment):
    """Evaluate the model on the validation dataset."""
    logging.info('Now we are doing evaluation on %s with %s.', segment, ctxs)
    metric.reset()
    step_loss = 0
    tic = time.time()
    out_list = []
    label_list = []
    for batch_id, seqs in enumerate(loader_dev):
        batch_loss = []
        # forward computation only; no backward pass is needed for evaluation
        data_list = list(split_and_load(seqs, ctxs))
        for splited_data in data_list:
            input_ids, valid_length, segment_ids, label = splited_data
            label = label.reshape((-1))
            out = model(input_ids, segment_ids, valid_length=valid_length)
            # gather outputs and labels on CPU so they can be concatenated
            # across devices after the loop
            out_list.append(out.as_in_context(mx.cpu(0)))
            label_list.append(label.as_in_context(mx.cpu(0)))
            batch_loss.append(loss_function(out, label).mean() / len(ctxs))
        batch_loss = sum([ls.asscalar() for ls in batch_loss])
        step_loss += batch_loss
        if (batch_id + 1) % args.log_interval == 0:
            log_eval(batch_id, len(loader_dev), step_loss, args.log_interval)
            step_loss = 0
    label_list = mx.nd.concat(*label_list, dim=0)
    out_list = mx.nd.concat(*out_list, dim=0)
    metric.update([label_list], [out_list])
    metric_nm, metric_val = log_metric(metric, is_training=False)
    mx.nd.waitall()
    toc = time.time()
    logging.info('Time cost=%.2fs, throughput=%.2f samples/s', toc - tic,
                 args.dev_batch_size * len(loader_dev) / (toc - tic))
    return metric_nm, metric_val
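
# `split_and_load` used above is a helper defined elsewhere in this script.
# A minimal sketch of the assumed behavior: shard every array in a batch
# across the available contexts and regroup them per device. The name
# `_split_and_load_sketch` is illustrative only, not the actual helper.
def _split_and_load_sketch(arrs, _ctxs):
    """Shard each array in `arrs` across `_ctxs`; yield per-device tuples."""
    loaded = [mx.gluon.utils.split_and_load(arr, _ctxs, even_split=False)
              for arr in arrs]
    # zip regroups the per-array shards into one tuple per context, e.g.
    # (input_ids, valid_length, segment_ids, label) on each device
    return zip(*loaded)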
def test(loader_test, segment):
    """Inference function on the test dataset."""
    logging.info('Now we are doing testing on %s with %s.', segment, ctxs)
    tic = time.time()
    results = []
    for _, seqs in enumerate(loader_test):
        data_list = list(split_and_load(seqs, ctxs))
        out_list = []
        for splited_data in data_list:
            input_ids, valid_length, segment_ids = splited_data
            out = model(input_ids, segment_ids, valid_length=valid_length)
            out_list.append(out)
        out_list = np.vstack([o.asnumpy() for o in out_list])
        if not task.class_labels:
            # regression task: emit the raw score
            for result in out_list.reshape(-1).tolist():
                results.append('{:.3f}'.format(result))
        else:
            # classification task: emit the label with the highest logit
            out = out_list.reshape(-1, out_list.shape[-1])
            indices = out.argmax(axis=-1)
            for index in indices:
                results.append(task.class_labels[int(index)])
    mx.nd.waitall()
    toc = time.time()
    logging.info('Time cost=%.2fs, throughput=%.2f samples/s', toc - tic,
                 args.dev_batch_size * len(loader_test) / (toc - tic))
    # write results to a file in the GLUE submission format
    segment = segment.replace('_mismatched', '-mm')
    segment = segment.replace('_matched', '-m')
    segment = segment.replace('SST', 'SST-2')
    filename = args.task_name + segment.replace('test', '') + '.tsv'
    test_path = os.path.join(args.output_dir, filename)
    with io.open(test_path, 'w', encoding='utf-8') as f:
        f.write(u'index\tprediction\n')
        for i, pred in enumerate(results):
            f.write(u'%d\t%s\n' % (i, str(pred)))
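
# For illustration, the renaming above produces GLUE-style submission file
# names: with args.task_name == 'MNLI', segment 'test_matched' becomes
# 'MNLI-m.tsv' and 'test_mismatched' becomes 'MNLI-mm.tsv'. The file body
# then looks like (the predictions here are made up):
#
#   index   prediction
#   0       neutral
#   1       entailment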
def train(metric):
    """Training function."""
    if not args.only_inference:
        logging.info('Now we are doing XLNet classification training on %s!',
                     ctxs)

    all_model_params = model.collect_params()
    optimizer_params = {
        'learning_rate': args.lr,
        'epsilon': args.epsilon,
        'wd': 0
    }
    trainer = gluon.Trainer(all_model_params, args.optimizer, optimizer_params,
                            update_on_kvstore=False)

    step_size = args.batch_size * args.accumulate if args.accumulate \
        else args.batch_size
    num_train_steps = int(num_train_examples / step_size * args.epochs)
    epoch_number = args.epochs
    if args.training_steps:
        num_train_steps = args.training_steps
        epoch_number = 9999
    logging.info('training steps=%d', num_train_steps)
    warmup_ratio = args.warmup_ratio
    num_warmup_steps = int(num_train_steps * warmup_ratio)
    step_num = 0

    # Do not apply weight decay on LayerNorm and bias terms
    for _, v in model.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0
    # Collect differentiable parameters
    params = [p for p in all_model_params.values() if p.grad_req != 'null']

    # Set grad_req if gradient accumulation is required
    if args.accumulate and args.accumulate > 1:
        for p in params:
            p.grad_req = 'add'

    # track best eval score
    metric_history = []
    best_metric = None
    patience = args.early_stop

    tic = time.time()
    finish_flag = False
    for epoch_id in range(epoch_number):
        if args.early_stop and patience == 0:
            logging.info('Early stopping at epoch %d', epoch_id)
            break
        if finish_flag:
            break
        if not args.only_inference:
            metric.reset()
            step_loss = 0
            tic = time.time()
            all_model_params.zero_grad()

            for batch_id, seqs in enumerate(train_data):
                new_lr = args.lr
                # learning rate schedule
                if step_num < num_warmup_steps:
                    new_lr = args.lr * step_num / num_warmup_steps
                elif args.lr_decay == 'linear':
                    non_warmup_steps = step_num - num_warmup_steps
                    offset = non_warmup_steps / \
                        (num_train_steps - num_warmup_steps)
                    new_lr = max(0, args.lr - offset * args.lr)
                trainer.set_learning_rate(new_lr)

                batch_loss = []
                # forward and backward
                with mx.autograd.record():
                    data_list = list(split_and_load(seqs, ctxs))
                    for splited_data in data_list:
                        input_ids, valid_length, segment_ids, label = \
                            splited_data
                        out = model(input_ids, segment_ids,
                                    valid_length=valid_length)
                        ls = loss_function(out, label).mean() / len(ctxs)
                        batch_loss.append(ls)
                        if args.accumulate:
                            ls = ls / args.accumulate
                        ls.backward()

                # update
                if not args.accumulate or (batch_id + 1) % args.accumulate == 0:
                    trainer.allreduce_grads()
                    nlp.utils.clip_grad_global_norm(params, 1)
                    trainer.update(args.accumulate if args.accumulate else 1,
                                   ignore_stale_grad=True)
                    step_num += 1
                    if args.accumulate and args.accumulate > 1:
                        # set grad to zero for gradient accumulation
                        all_model_params.zero_grad()

                if batch_id == 0 and epoch_id == 0:
                    toc = time.time()
                    logging.info(
                        'Time cost for the first forward-backward =%.2fs',
                        toc - tic)
                batch_loss = sum([ls.asscalar() for ls in batch_loss])
                step_loss += batch_loss
                if (batch_id + 1) % args.log_interval == 0:
                    log_train(batch_id, len(train_data), step_loss,
                              args.log_interval, epoch_id,
                              trainer.learning_rate)
                    step_loss = 0
                if step_num >= num_train_steps:
                    logging.info('Finish training step: %d', step_num)
                    finish_flag = True
                    break
            mx.nd.waitall()

        # inference on dev data
        for segment, dev_data in dev_data_list:
            metric_nm, metric_val = evaluate(dev_data, metric, segment)
            if best_metric is None or metric_val >= best_metric:
                best_metric = metric_val
                patience = args.early_stop
            else:
                if args.early_stop is not None:
                    patience -= 1
            metric_history.append((epoch_id, metric_nm, metric_val))

        if not args.only_inference:
            # save params
            ckpt_name = 'model_xlnet_{0}_{1}.params'.format(
                args.task_name, epoch_id)
            params_saved = os.path.join(output_dir, ckpt_name)
            nlp.utils.save_parameters(model, params_saved)
            logging.info('params saved in: %s', params_saved)
            toc = time.time()
            logging.info('Time cost=%.2fs', toc - tic)
            tic = toc

    if not args.only_inference:
        # we choose the best model based on metric[0],
        # assuming higher score stands for better model quality
        metric_history.sort(key=lambda x: x[2][0], reverse=True)
        epoch_id, metric_nm, metric_val = metric_history[0]
        ckpt_name = 'model_xlnet_{0}_{1}.params'.format(
            args.task_name, epoch_id)
        params_saved = os.path.join(output_dir, ckpt_name)
        nlp.utils.load_parameters(model, params_saved)
        metric_str = 'Best model at epoch {}. Validation metrics:'.format(
            epoch_id + 1)
        metric_str += ','.join([i + ':%.4f' for i in metric_nm])
        logging.info(metric_str, *metric_val)

    # inference on test data
    for segment, test_data in test_data_list:
        test(test_data, segment)
        print('finish test!')
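
# The learning-rate schedule inlined in `train` above is linear warmup
# followed by optional linear decay to zero. A minimal standalone sketch of
# the same arithmetic (the function name is illustrative, not part of this
# script):
def _scheduled_lr(base_lr, step_num, num_warmup_steps, num_train_steps,
                  lr_decay='linear'):
    """Return the learning rate at `step_num` under warmup + linear decay."""
    if step_num < num_warmup_steps:
        # ramp up linearly from 0 to base_lr over the warmup steps
        return base_lr * step_num / num_warmup_steps
    if lr_decay == 'linear':
        # decay linearly from base_lr to 0 over the remaining steps
        offset = (step_num - num_warmup_steps) / \
            (num_train_steps - num_warmup_steps)
        return max(0, base_lr - offset * base_lr)
    return base_lr
# e.g. with base_lr=3e-5, 100 warmup steps and 1000 total steps:
# _scheduled_lr(3e-5, 50, 100, 1000) == 1.5e-05  (halfway through warmup)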