Example #1
    def fit_batch(self, estimator,
                  train_batch,
                  batch_axis=0):
        """Trains the estimator model on a batch of training data.

        Parameters
        ----------
        estimator : Estimator
            Reference to the estimator
        train_batch : tuple
            Data and label of a batch from the training data loader.
        batch_axis : int, default 0
            Batch axis along which to split the training data across devices.

        Returns
        -------
        data : list of NDArray
            Sharded data from the batch. Data is sharded with
            `gluon.split_and_load`.
        fixed_targets : list of NDArray
            Sharded training targets from the batch, one entry per device.
        preds : list of NDArray
            Network predictions on each of the sharded inputs.
        loss : list of NDArray
            Loss on each of the sharded inputs.
        """
        # Split the batch into data, training targets and ground-truth boxes,
        # sharded across the available devices.
        data, fixed_targets, gt_bboxes = self._get_data_and_label(train_batch, estimator.context)

        with autograd.record():
            # forward pass on each device; each prediction is
            # (bbox, raw_box_centers, raw_box_scales, objness, class_pred)
            preds = [estimator.net(x) for x in data]
            loss = [estimator.loss(*pred, *target, gt_bbox) for pred, target, gt_bbox in
                    zip(preds, fixed_targets, gt_bboxes)]

            # Use dynamic loss scaling when automatic mixed precision (AMP)
            # has been initialized; otherwise run a plain backward pass.
            if amp._amp_initialized:
                with amp.scale_loss(loss, estimator.trainer) as scaled_loss:
                    autograd.backward(scaled_loss)
            else:
                autograd.backward(loss)

        return data, fixed_targets, preds, loss
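
The docstring above refers to `gluon.split_and_load` for sharding a batch across devices. The following is a minimal, self-contained sketch of that sharding pattern; the array shapes and the two-CPU context list are arbitrary examples, not values from the snippet above.

import mxnet as mx
from mxnet import gluon

# A toy batch of 8 samples and labels, split across two CPU contexts.
ctx_list = [mx.cpu(0), mx.cpu(1)]
data = mx.nd.random.uniform(shape=(8, 3, 32, 32))
label = mx.nd.arange(8)

data_shards = gluon.utils.split_and_load(data, ctx_list=ctx_list, batch_axis=0)
label_shards = gluon.utils.split_and_load(label, ctx_list=ctx_list, batch_axis=0)
# Each result is a list with one NDArray per context,
# e.g. two data shards of shape (4, 3, 32, 32) each.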
Example #2
def train(metric):
    """Training function."""
    if not only_inference:
        logging.info('Now we are doing BERT classification training on %s!',
                     ctx)

    all_model_params = model.collect_params()
    optimizer_params = {'learning_rate': lr, 'epsilon': epsilon, 'wd': 0.01}
    trainer = gluon.Trainer(all_model_params,
                            args.optimizer,
                            optimizer_params,
                            update_on_kvstore=False)
    if args.dtype == 'float16':
        amp.init_trainer(trainer)

    epoch_number = args.epochs
    step_size = batch_size * accumulate if accumulate else batch_size
    num_train_steps = int(num_train_examples / step_size * args.epochs)
    if args.training_steps:
        num_train_steps = args.training_steps
        epoch_number = 9999

    logging.info('training steps=%d', num_train_steps)
    warmup_ratio = args.warmup_ratio
    num_warmup_steps = int(num_train_steps * warmup_ratio)
    step_num = 0

    # Do not apply weight decay on LayerNorm and bias terms
    for _, v in model.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0
    # Collect differentiable parameters
    params = [p for p in all_model_params.values() if p.grad_req != 'null']

    # Set grad_req if gradient accumulation is required
    if accumulate and accumulate > 1:
        for p in params:
            p.grad_req = 'add'
    # track best eval score
    metric_history = []
    best_metric = None
    patience = args.early_stop

    tic = time.time()
    finish_flag = False
    for epoch_id in range(epoch_number):
        if args.early_stop and patience == 0:
            logging.info('Early stopping at epoch %d', epoch_id)
            break
        if finish_flag:
            break
        if not only_inference:
            metric.reset()
            step_loss = 0
            tic = time.time()
            all_model_params.zero_grad()

            for batch_id, seqs in enumerate(train_data):
                # learning rate schedule: linear warmup, then linear decay to zero
                if step_num < num_warmup_steps:
                    new_lr = lr * step_num / num_warmup_steps
                else:
                    non_warmup_steps = step_num - num_warmup_steps
                    offset = non_warmup_steps / (num_train_steps -
                                                 num_warmup_steps)
                    new_lr = lr - offset * lr
                trainer.set_learning_rate(new_lr)

                # forward and backward
                with mx.autograd.record():
                    input_ids, segment_ids, valid_length, label = seqs
                    input_ids = input_ids.as_in_context(ctx)
                    valid_length = valid_length.as_in_context(ctx).astype(
                        'float32')
                    label = label.as_in_context(ctx)
                    if use_roberta:
                        out = model(input_ids, valid_length)
                    else:
                        out = model(input_ids, segment_ids.as_in_context(ctx),
                                    valid_length)
                    ls = loss_function(out, label).mean()
                    if args.dtype == 'float16':
                        with amp.scale_loss(ls, trainer) as scaled_loss:
                            mx.autograd.backward(scaled_loss)
                    else:
                        ls.backward()

                # update parameters (once every `accumulate` batches when accumulating gradients)
                if not accumulate or (batch_id + 1) % accumulate == 0:
                    trainer.allreduce_grads()
                    nlp.utils.clip_grad_global_norm(params, 1)
                    trainer.update(accumulate if accumulate else 1)
                    step_num += 1
                    if accumulate and accumulate > 1:
                        # set grad to zero for gradient accumulation
                        all_model_params.zero_grad()

                step_loss += ls.asscalar()
                if not do_regression:
                    label = label.reshape((-1))
                metric.update([label], [out])
                if (batch_id + 1) % args.log_interval == 0:
                    log_train(batch_id, len(train_data), metric, step_loss,
                              args.log_interval, epoch_id,
                              trainer.learning_rate)
                    step_loss = 0
                if step_num >= num_train_steps:
                    logging.info('Finish training step: %d', step_num)
                    finish_flag = True
                    break
            mx.nd.waitall()

        # inference on dev data
        for segment, dev_data in dev_data_list:
            metric_nm, metric_val = evaluate(dev_data, metric, segment)
            if best_metric is None or metric_val >= best_metric:
                best_metric = metric_val
                patience = args.early_stop
            else:
                if args.early_stop is not None:
                    patience -= 1
            metric_history.append((epoch_id, metric_nm, metric_val))

        if not only_inference:
            # save params
            ckpt_name = 'model_bert_{0}_{1}.params'.format(task_name, epoch_id)
            params_saved = os.path.join(output_dir, ckpt_name)

            nlp.utils.save_parameters(model, params_saved)
            logging.info('params saved in: %s', params_saved)
            toc = time.time()
            logging.info('Time cost=%.2fs', toc - tic)
            tic = toc

    if not only_inference:
        # choose the best model based on metric[0],
        # assuming a higher score indicates a better model
        metric_history.sort(key=lambda x: x[2][0], reverse=True)
        epoch_id, metric_nm, metric_val = metric_history[0]
        ckpt_name = 'model_bert_{0}_{1}.params'.format(task_name, epoch_id)
        params_saved = os.path.join(output_dir, ckpt_name)
        nlp.utils.load_parameters(model, params_saved)
        metric_str = 'Best model at epoch {}. Validation metrics:'.format(
            epoch_id)
        metric_str += ','.join([i + ':%.4f' for i in metric_nm])
        logging.info(metric_str, *metric_val)

    # inference on test data
    for segment, test_data in test_data_list:
        test(test_data, segment)
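
The learning-rate arithmetic inside the training loop above (linear warmup followed by linear decay to zero) can be read more easily as a standalone helper. This is a minimal sketch that mirrors that arithmetic; the function name and the example numbers are ours, not part of the original script.

def linear_warmup_then_decay(step_num, base_lr, num_warmup_steps, num_train_steps):
    # Warmup phase: scale the learning rate linearly from 0 up to base_lr.
    if step_num < num_warmup_steps:
        return base_lr * step_num / num_warmup_steps
    # Decay phase: scale linearly from base_lr down to 0 at num_train_steps.
    offset = (step_num - num_warmup_steps) / (num_train_steps - num_warmup_steps)
    return base_lr - offset * base_lr

# For example, with base_lr=5e-5, 100 warmup steps and 1000 total steps:
# step 50 -> 2.5e-05, step 550 -> 2.5e-05, step 1000 -> 0.0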