def fit_batch(self, estimator, train_batch, batch_axis=0):
    """Trains the estimator model on a batch of training data.

    Parameters
    ----------
    estimator : Estimator
        Reference to the estimator.
    train_batch : tuple
        Data and label of a batch from the training data loader.
    batch_axis : int, default 0
        Batch axis along which the training data is split across devices.

    Returns
    -------
    data : list of NDArray
        Sharded data from the batch. Data is sharded with `gluon.split_and_load`.
    fixed_targets : list of NDArray
        Sharded training targets from the batch, sharded with `gluon.split_and_load`.
    preds : list of NDArray
        Prediction on each of the sharded inputs.
    loss : list of NDArray
        Loss on each of the sharded inputs.
    """
    # data = split_and_load(train_batch[0], ctx_list=estimator.context, batch_axis=0, even_split=False)
    # label = split_and_load(train_batch[1], ctx_list=estimator.context, batch_axis=0, even_split=False)
    # targets = list(zip(*[split_and_load(t, ctx_list=estimator.context, batch_axis=0, even_split=False)
    #                      for t in estimator.net.extract_training_targets(*train_batch)]))
    data, fixed_targets, gt_bboxes = self._get_data_and_label(train_batch, estimator.context)
    # fixed_targets = [split_and_load(train_batch[it], ctx_list=estimator.context, batch_axis=0)
    #                  for it in range(1, 7)]
    # gt_bboxes = split_and_load(train_batch[7], ctx_list=estimator.context, batch_axis=0)
    with autograd.record():
        # each prediction: bbox, raw_box_centers, raw_box_scales, objness, class_pred
        preds = [estimator.net(x) for x in data]
        loss = [estimator.loss(*pred, *target, gt_bbox)
                for pred, target, gt_bbox in zip(preds, fixed_targets, gt_bboxes)]
    if amp._amp_initialized:
        with amp.scale_loss(loss, estimator.trainer) as scaled_loss:
            autograd.backward(scaled_loss)
    else:
        autograd.backward(loss)
    return data, fixed_targets, preds, loss
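
# A minimal sketch of what the `_get_data_and_label` helper called in `fit_batch` is assumed
# to do, based on the commented-out `split_and_load` calls above: shard the input images,
# the fixed training targets (batch indices 1-6), and the ground-truth boxes (index 7) across
# the estimator's contexts, and transpose the fixed targets so that each element is the tuple
# of targets for one device (to match `zip(preds, fixed_targets, gt_bboxes)` in `fit_batch`).
# The helper name and batch layout come from the surrounding code; the actual implementation
# in the estimator may differ.
from mxnet.gluon.utils import split_and_load


def _get_data_and_label(self, batch, ctx_list, batch_axis=0):
    # Shard the input images across devices.
    data = split_and_load(batch[0], ctx_list=ctx_list, batch_axis=batch_axis, even_split=False)
    # Shard each fixed target tensor, then regroup them into one tuple per device.
    fixed_targets = list(zip(*[split_and_load(batch[it], ctx_list=ctx_list,
                                              batch_axis=batch_axis, even_split=False)
                               for it in range(1, 7)]))
    # Shard the ground-truth bounding boxes across devices.
    gt_bboxes = split_and_load(batch[7], ctx_list=ctx_list,
                               batch_axis=batch_axis, even_split=False)
    return data, fixed_targets, gt_bboxes
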
def train(metric):
    """Training function."""
    if not only_inference:
        logging.info('Now we are doing BERT classification training on %s!', ctx)

    all_model_params = model.collect_params()
    optimizer_params = {'learning_rate': lr, 'epsilon': epsilon, 'wd': 0.01}
    trainer = gluon.Trainer(all_model_params, args.optimizer, optimizer_params,
                            update_on_kvstore=False)
    if args.dtype == 'float16':
        amp.init_trainer(trainer)

    epoch_number = args.epochs
    step_size = batch_size * accumulate if accumulate else batch_size
    num_train_steps = int(num_train_examples / step_size * args.epochs)
    if args.training_steps:
        num_train_steps = args.training_steps
        epoch_number = 9999

    logging.info('training steps=%d', num_train_steps)
    warmup_ratio = args.warmup_ratio
    num_warmup_steps = int(num_train_steps * warmup_ratio)
    step_num = 0

    # Do not apply weight decay on LayerNorm and bias terms
    for _, v in model.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0
    # Collect differentiable parameters
    params = [p for p in all_model_params.values() if p.grad_req != 'null']

    # Set grad_req if gradient accumulation is required
    if accumulate and accumulate > 1:
        for p in params:
            p.grad_req = 'add'
    # track best eval score
    metric_history = []
    best_metric = None
    patience = args.early_stop

    tic = time.time()
    finish_flag = False
    for epoch_id in range(epoch_number):
        if args.early_stop and patience == 0:
            logging.info('Early stopping at epoch %d', epoch_id)
            break
        if finish_flag:
            break
        if not only_inference:
            metric.reset()
            step_loss = 0
            tic = time.time()
            all_model_params.zero_grad()

            for batch_id, seqs in enumerate(train_data):
                # learning rate schedule
                if step_num < num_warmup_steps:
                    new_lr = lr * step_num / num_warmup_steps
                else:
                    non_warmup_steps = step_num - num_warmup_steps
                    offset = non_warmup_steps / (num_train_steps - num_warmup_steps)
                    new_lr = lr - offset * lr
                trainer.set_learning_rate(new_lr)

                # forward and backward
                with mx.autograd.record():
                    input_ids, segment_ids, valid_length, label = seqs
                    input_ids = input_ids.as_in_context(ctx)
                    valid_length = valid_length.as_in_context(ctx).astype('float32')
                    label = label.as_in_context(ctx)
                    if use_roberta:
                        out = model(input_ids, valid_length)
                    else:
                        out = model(input_ids, segment_ids.as_in_context(ctx), valid_length)
                    ls = loss_function(out, label).mean()
                    if args.dtype == 'float16':
                        with amp.scale_loss(ls, trainer) as scaled_loss:
                            mx.autograd.backward(scaled_loss)
                    else:
                        ls.backward()

                # update
                if not accumulate or (batch_id + 1) % accumulate == 0:
                    trainer.allreduce_grads()
                    nlp.utils.clip_grad_global_norm(params, 1)
                    trainer.update(accumulate if accumulate else 1)
                    step_num += 1
                    if accumulate and accumulate > 1:
                        # set grad to zero for gradient accumulation
                        all_model_params.zero_grad()

                step_loss += ls.asscalar()
                if not do_regression:
                    label = label.reshape((-1))
                metric.update([label], [out])
                if (batch_id + 1) % (args.log_interval) == 0:
                    log_train(batch_id, len(train_data), metric, step_loss,
                              args.log_interval, epoch_id, trainer.learning_rate)
                    step_loss = 0
                if step_num >= num_train_steps:
                    logging.info('Finish training step: %d', step_num)
                    finish_flag = True
                    break
            mx.nd.waitall()

        # inference on dev data
        for segment, dev_data in dev_data_list:
            metric_nm, metric_val = evaluate(dev_data, metric, segment)
            if best_metric is None or metric_val >= best_metric:
                best_metric = metric_val
                patience = args.early_stop
            else:
                if args.early_stop is not None:
                    patience -= 1
            metric_history.append((epoch_id, metric_nm, metric_val))

        if not only_inference:
            # save params
            ckpt_name = 'model_bert_{0}_{1}.params'.format(task_name, epoch_id)
            params_saved = os.path.join(output_dir, ckpt_name)
            nlp.utils.save_parameters(model, params_saved)
            logging.info('params saved in: %s', params_saved)
            toc = time.time()
            logging.info('Time cost=%.2fs', toc - tic)
            tic = toc

    if not only_inference:
        # we choose the best model based on metric[0],
        # assuming higher score stands for better model quality
        metric_history.sort(key=lambda x: x[2][0], reverse=True)
        epoch_id, metric_nm, metric_val = metric_history[0]
        ckpt_name = 'model_bert_{0}_{1}.params'.format(task_name, epoch_id)
        params_saved = os.path.join(output_dir, ckpt_name)
        nlp.utils.load_parameters(model, params_saved)
        metric_str = 'Best model at epoch {}. Validation metrics:'.format(epoch_id)
        metric_str += ','.join([i + ':%.4f' for i in metric_nm])
        logging.info(metric_str, *metric_val)

    # inference on test data
    for segment, test_data in test_data_list:
        test(test_data, segment)
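
# The learning rate computed inline in `train` is a linear warmup followed by a linear decay
# to zero. Below is a minimal standalone sketch of the same arithmetic, shown only for
# illustration; the helper name `warmup_linear_decay` is not part of the script and is
# introduced here as an example.
def warmup_linear_decay(base_lr, step_num, num_warmup_steps, num_train_steps):
    """Return the learning rate at `step_num` under linear warmup then linear decay."""
    if step_num < num_warmup_steps:
        # Ramp up linearly from 0 to base_lr over the warmup steps.
        return base_lr * step_num / num_warmup_steps
    # Decay linearly from base_lr down to 0 over the remaining steps.
    offset = (step_num - num_warmup_steps) / (num_train_steps - num_warmup_steps)
    return base_lr - offset * base_lr

# Example: with base_lr=3e-5, 100 warmup steps and 1000 total steps,
# warmup_linear_decay(3e-5, 50, 100, 1000)  -> 1.5e-05 (halfway through warmup)
# warmup_linear_decay(3e-5, 550, 100, 1000) -> 1.5e-05 (halfway through decay)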