Example #1
def _classification_regression_predict(net,
                                       dataloader,
                                       problem_type,
                                       has_label=True,
                                       extract_embedding=False):
    """

    Parameters
    ----------
    net
        The network
    dataloader
        The dataloader
    problem_type
        Types of the labels
    has_label
        Whether label is used
    extract_embedding
        Whether to extract the embedding

    Returns
    -------
    predictions
        The predictions
    """
    predictions = []
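    # Use every context (device) that holds the network parameters; grouper()
    # hands one batch to each of them per step and pads with None at the end.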
    ctx_l = net.collect_params().list_ctx()
    for sample_l in grouper(dataloader, len(ctx_l)):
        iter_pred_l = []
        for sample, ctx in zip(sample_l, ctx_l):
            if sample is None:
                continue
            if has_label:
                batch_feature, batch_label = sample
            else:
                batch_feature = sample
            batch_feature = move_to_ctx(batch_feature, ctx)
            if extract_embedding:
                _, embeddings = net(batch_feature)
                iter_pred_l.append(embeddings)
            else:
                pred = net(batch_feature)
                if problem_type == _C.CLASSIFICATION:
                    pred = mx.npx.softmax(pred, axis=-1)
                iter_pred_l.append(pred)
        for pred in iter_pred_l:
            predictions.append(pred.asnumpy())
    predictions = np.concatenate(predictions, axis=0)
    return predictions
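
A minimal usage sketch (not part of the original source); `trained_net` and `dev_dataloader` are placeholders for objects built as in Example #2:

# Hypothetical call: the dataloader was built with a test-mode batchify_fn,
# so each batch contains features only and has_label is therefore False.
dev_probs = _classification_regression_predict(trained_net,
                                               dataloader=dev_dataloader,
                                               problem_type=_C.CLASSIFICATION,
                                               has_label=False)
predicted_classes = dev_probs.argmax(axis=-1)  # most likely class per row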
Example #2
def train_function(args,
                   reporter,
                   train_df_path,
                   tuning_df_path,
                   time_limits,
                   time_start,
                   base_config,
                   problem_types,
                   column_properties,
                   label_columns,
                   label_shapes,
                   log_metrics,
                   stopping_metric,
                   console_log,
                   ignore_warning=False):
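    # Abort the trial immediately if the overall time budget has already been exhausted.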
    if time_limits is not None:
        start_train_tick = time.time()
        time_left = time_limits - (start_train_tick - time_start)
        if time_left <= 0:
            reporter.terminate()
            return
    import os
    # Get the log metric scorers
    if isinstance(log_metrics, str):
        log_metrics = [log_metrics]
    log_metric_scorers = [get_metric(ele) for ele in log_metrics]
    stopping_metric_scorer = get_metric(stopping_metric)
    greater_is_better = stopping_metric_scorer.greater_is_better
    # Load the training and tuning data from the parquet files
    train_data = pd.read_parquet(train_df_path)
    tuning_data = pd.read_parquet(tuning_df_path)
    os.environ['MKL_NUM_THREADS'] = '1'
    os.environ['OMP_NUM_THREADS'] = '1'
    os.environ['MKL_DYNAMIC'] = 'FALSE'
    if ignore_warning:
        import warnings
        warnings.filterwarnings("ignore")
    search_space = args['search_space']
    cfg = base_config.clone()
    specified_values = []
    for key in search_space:
        specified_values.append(key)
        specified_values.append(search_space[key])
    cfg.merge_from_list(specified_values)
    exp_dir = cfg.misc.exp_dir
    if reporter is not None:
        # When the reporter is not None,
        # create a per-task output directory named after the task_id
        task_id = args.task_id
        exp_dir = os.path.join(exp_dir, 'task{}'.format(task_id))
        os.makedirs(exp_dir, exist_ok=True)
        cfg.defrost()
        cfg.misc.exp_dir = exp_dir
        cfg.freeze()
    logger = logging.getLogger()
    logging_config(folder=exp_dir,
                   name='training',
                   logger=logger,
                   console=console_log)
    logger.info(cfg)
    # Load backbone model
    backbone_model_cls, backbone_cfg, tokenizer, backbone_params_path, _ \
        = get_backbone(cfg.model.backbone.name)
    with open(os.path.join(exp_dir, 'cfg.yml'), 'w') as f:
        f.write(str(cfg))
    text_backbone = backbone_model_cls.from_cfg(backbone_cfg)
    # Build the preprocessor, preprocess the training dataset, and infer the problem type
    # TODO Move preprocessor + Dataloader to outer loop to better cache the dataloader
    preprocessor = TabularBasicBERTPreprocessor(
        tokenizer=tokenizer,
        column_properties=column_properties,
        label_columns=label_columns,
        max_length=cfg.model.preprocess.max_length,
        merge_text=cfg.model.preprocess.merge_text)
    logger.info('Process training set...')
    processed_train = preprocessor.process_train(train_data)
    logger.info('Done!')
    logger.info('Process dev set...')
    processed_dev = preprocessor.process_test(tuning_data)
    logger.info('Done!')
    label = label_columns[0]
    # Get the ground-truth dev labels
    gt_dev_labels = np.array(tuning_data[label].apply(
        column_properties[label].transform))
    ctx_l = get_mxnet_available_ctx()
    base_batch_size = cfg.optimization.per_device_batch_size
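    # Accumulate gradients over several micro-batches when the configured batch size
    # does not fit into a single per-device forward/backward pass.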
    num_accumulated = int(
        np.ceil(cfg.optimization.batch_size / base_batch_size))
    inference_base_batch_size = base_batch_size * cfg.optimization.val_batch_size_mult
    train_dataloader = DataLoader(
        processed_train,
        batch_size=base_batch_size,
        shuffle=True,
        batchify_fn=preprocessor.batchify(is_test=False))
    dev_dataloader = DataLoader(
        processed_dev,
        batch_size=inference_base_batch_size,
        shuffle=False,
        batchify_fn=preprocessor.batchify(is_test=True))
    net = BERTForTabularBasicV1(
        text_backbone=text_backbone,
        feature_field_info=preprocessor.feature_field_info(),
        label_shape=label_shapes[0],
        cfg=cfg.model.network)
    net.initialize_with_pretrained_backbone(backbone_params_path, ctx=ctx_l)
    net.hybridize()
    num_total_params, num_total_fixed_params = count_parameters(
        net.collect_params())
    logger.info('#Total Params/Fixed Params={}/{}'.format(
        num_total_params, num_total_fixed_params))
    # Initialize the optimizer
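    # One optimizer update consumes num_accumulated micro-batches on each device.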
    updates_per_epoch = int(
        len(train_dataloader) / (num_accumulated * len(ctx_l)))
    optimizer, optimizer_params, max_update \
        = get_optimizer(cfg.optimization,
                        updates_per_epoch=updates_per_epoch)
    valid_interval = math.ceil(cfg.optimization.valid_frequency *
                               updates_per_epoch)
    train_log_interval = math.ceil(cfg.optimization.log_frequency *
                                   updates_per_epoch)
    trainer = mx.gluon.Trainer(net.collect_params(),
                               optimizer,
                               optimizer_params,
                               update_on_kvstore=False)
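    # Layer-wise learning-rate decay: lower backbone layers receive smaller
    # learning rates than the layers close to the output.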
    if 0 < cfg.optimization.layerwise_lr_decay < 1:
        apply_layerwise_decay(net.text_backbone,
                              cfg.optimization.layerwise_lr_decay,
                              backbone_name=cfg.model.backbone.name)
    # Do not apply weight decay to the LayerNorm parameters and biases
    for _, v in net.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0
    params = [p for p in net.collect_params().values() if p.grad_req != 'null']

    # Set grad_req if gradient accumulation is required
    if num_accumulated > 1:
        logger.info('Using gradient accumulation.'
                    ' Global batch size = {}'.format(
                        cfg.optimization.batch_size))
        for p in params:
            p.grad_req = 'add'
        net.collect_params().zero_grad()
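    # Cycle through the training data indefinitely and draw one batch per device
    # on every micro-step of the training loop.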
    train_loop_dataloader = grouper(repeat(train_dataloader), len(ctx_l))
    log_loss_l = [mx.np.array(0.0, dtype=np.float32, ctx=ctx) for ctx in ctx_l]
    log_num_samples_l = [0 for _ in ctx_l]
    logging_start_tick = time.time()
    best_performance_score = None
    mx.npx.waitall()
    no_better_rounds = 0
    report_idx = 0
    start_tick = time.time()
    if time_limits is not None:
        time_limits -= start_tick - time_start
        if time_limits <= 0:
            reporter.terminate()
            return
    best_report_items = None
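    # Main training loop: each iteration performs one optimizer update built from
    # num_accumulated micro-batches per device.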
    for update_idx in tqdm.tqdm(range(max_update), disable=None):
        num_samples_per_update_l = [0 for _ in ctx_l]
        for accum_idx in range(num_accumulated):
            sample_l = next(train_loop_dataloader)
            loss_l = []
            num_samples_l = [0 for _ in ctx_l]
            for i, (sample, ctx) in enumerate(zip(sample_l, ctx_l)):
                feature_batch, label_batch = sample
                feature_batch = move_to_ctx(feature_batch, ctx)
                label_batch = move_to_ctx(label_batch, ctx)
                with mx.autograd.record():
                    pred = net(feature_batch)
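                    # Classification: negative log-likelihood of the softmax;
                    # regression: squared error against the label.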
                    if problem_types[0] == _C.CLASSIFICATION:
                        logits = mx.npx.log_softmax(pred, axis=-1)
                        loss = -mx.npx.pick(logits, label_batch[0])
                    elif problem_types[0] == _C.REGRESSION:
                        loss = mx.np.square(pred - label_batch[0])
                    loss_l.append(loss.mean() / len(ctx_l))
                    num_samples_l[i] = loss.shape[0]
                    num_samples_per_update_l[i] += loss.shape[0]
            for loss in loss_l:
                loss.backward()
            for i in range(len(ctx_l)):
                log_loss_l[i] += loss_l[i] * len(ctx_l) * num_samples_l[i]
                log_num_samples_l[i] += num_samples_per_update_l[i]
        # Begin to update
        trainer.allreduce_grads()
        num_samples_per_update = sum(num_samples_per_update_l)
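        # Accumulated gradients are summed, so scale the clipping threshold by the
        # number of accumulation steps and rescale the reported gradient norm.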
        total_norm, ratio, is_finite = \
            clip_grad_global_norm(params, cfg.optimization.max_grad_norm * num_accumulated)
        total_norm = total_norm / num_accumulated
        trainer.update(num_samples_per_update)

        # Clear after update
        if num_accumulated > 1:
            net.collect_params().zero_grad()
        if (update_idx + 1) % train_log_interval == 0:
            log_loss = sum([ele.as_in_ctx(ctx_l[0])
                            for ele in log_loss_l]).asnumpy()
            log_num_samples = sum(log_num_samples_l)
            logger.info(
                '[Iter {}/{}, Epoch {}] train loss={:0.4e}, gnorm={:0.4e}, lr={:0.4e}, #samples processed={},'
                ' #samples per second={:.2f}'.format(
                    update_idx + 1, max_update,
                    int(update_idx / updates_per_epoch),
                    log_loss / log_num_samples, total_norm,
                    trainer.learning_rate, log_num_samples,
                    log_num_samples / (time.time() - logging_start_tick)))
            logging_start_tick = time.time()
            log_loss_l = [
                mx.np.array(0.0, dtype=np.float32, ctx=ctx) for ctx in ctx_l
            ]
            log_num_samples_l = [0 for _ in ctx_l]
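        # Periodic validation: predict on the tuning set, compute the metrics, track
        # the best score, and report the result to the scheduler.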
        if (update_idx + 1) % valid_interval == 0 \
                or (update_idx + 1) == max_update:
            valid_start_tick = time.time()
            dev_predictions = \
                _classification_regression_predict(net, dataloader=dev_dataloader,
                                                   problem_type=problem_types[0],
                                                   has_label=False)
            log_scores = [
                calculate_metric(scorer, gt_dev_labels, dev_predictions,
                                 problem_types[0])
                for scorer in log_metric_scorers
            ]
            dev_score = calculate_metric(stopping_metric_scorer, gt_dev_labels,
                                         dev_predictions, problem_types[0])
            valid_time_spent = time.time() - valid_start_tick

            if best_performance_score is None or \
                    (greater_is_better and dev_score >= best_performance_score) or \
                    (not greater_is_better and dev_score <= best_performance_score):
                find_better = True
                no_better_rounds = 0
                best_performance_score = dev_score
                net.save_parameters(os.path.join(exp_dir, 'best_model.params'))
            else:
                find_better = False
                no_better_rounds += 1
            mx.npx.waitall()
            loss_string = ', '.join([
                '{}={:0.4e}'.format(metric.name, score)
                for score, metric in zip(log_scores, log_metric_scorers)
            ])
            logger.info('[Iter {}/{}, Epoch {}] valid {}, time spent={:.3f}s,'
                        ' total_time={:.2f}min'.format(
                            update_idx + 1, max_update,
                            int(update_idx / updates_per_epoch), loss_string,
                            valid_time_spent, (time.time() - start_tick) / 60))
            report_items = [('iteration', update_idx + 1),
                            ('report_idx', report_idx + 1),
                            ('epoch', int(update_idx / updates_per_epoch))] +\
                           [(metric.name, score)
                            for score, metric in zip(log_scores, log_metric_scorers)] + \
                           [('find_better', find_better),
                            ('time_spent', int(time.time() - start_tick))]
            total_time_spent = time.time() - start_tick

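            # The searcher maximizes reward_attr, so negate the score for metrics
            # where lower values are better.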
            if stopping_metric_scorer._sign < 0:
                report_items.append(('reward_attr', -dev_score))
            else:
                report_items.append(('reward_attr', dev_score))
            report_items.append(('eval_metric', stopping_metric_scorer.name))
            report_items.append(('exp_dir', exp_dir))
            if find_better:
                best_report_items = report_items
            reporter(**dict(report_items))
            report_idx += 1
            if no_better_rounds >= cfg.learning.early_stopping_patience:
                logger.info('Early stopping patience reached!')
                break
            if time_limits is not None and total_time_spent > time_limits:
                break

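    # Re-emit the metrics of the best checkpoint as the final report for this trial.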
    best_report_items_dict = dict(best_report_items)
    best_report_items_dict['report_idx'] = report_idx + 1
    reporter(**best_report_items_dict)
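
A rough illustration only (not from the source): the training function above would typically be partially applied before being handed to a hyperparameter searcher. Every name below except train_function itself is a placeholder.

import functools
import time

train_fn = functools.partial(
    train_function,
    train_df_path='train.pq',            # placeholder parquet paths
    tuning_df_path='tuning.pq',
    time_limits=3600,                    # one-hour budget (placeholder)
    time_start=time.time(),
    base_config=base_cfg,                # objects prepared by the caller
    problem_types=[_C.CLASSIFICATION],
    column_properties=column_properties,
    label_columns=['label'],
    label_shapes=[num_classes],
    log_metrics=['acc'],
    stopping_metric='acc',
    console_log=True)
# The search framework then calls train_fn(args=..., reporter=...) once per trial.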