def _classification_regression_predict(net, dataloader, problem_type,
                                       has_label=True, extract_embedding=False):
    """Run the network over a dataloader and gather the outputs as numpy.

    Batches are pulled from ``dataloader`` in groups of ``len(contexts)``.
    Each batch in a group is sent to one of the contexts the network's
    parameters live on.

    Parameters
    ----------
    net
        The network. Its parameters' contexts decide how many batches are
        dispatched per round.
    dataloader
        The dataloader. When ``has_label`` is True each sample must unpack
        into exactly ``(features, label)``; otherwise the sample itself is
        treated as the features.
    problem_type
        Type of the label. When it equals ``_C.CLASSIFICATION`` the network
        output goes through a softmax over the last axis.
    has_label
        Whether each sample carries a label that has to be discarded.
    extract_embedding
        If True, the network is expected to return a pair and the second
        element (the embedding) is collected instead of the prediction.

    Returns
    -------
    predictions
        Per-batch outputs converted with ``.asnumpy()`` and concatenated
        along axis 0.
    """
    contexts = net.collect_params().list_ctx()
    chunks = []
    for batch_group in grouper(dataloader, len(contexts)):
        # First dispatch the forward pass on every context, and only then
        # copy results back to host, so no device waits on another's sync.
        outputs = []
        for batch, ctx in zip(batch_group, contexts):
            # Entries may be None (presumably padding from grouper when the
            # final group is short) -- those are skipped.
            if batch is None:
                continue
            if has_label:
                features, _ = batch
            else:
                features = batch
            features = move_to_ctx(features, ctx)
            if extract_embedding:
                _, embedding = net(features)
                outputs.append(embedding)
                continue
            scores = net(features)
            if problem_type == _C.CLASSIFICATION:
                scores = mx.npx.softmax(scores, axis=-1)
            outputs.append(scores)
        chunks.extend(out.asnumpy() for out in outputs)
    return np.concatenate(chunks, axis=0)
def train_function(args, reporter, train_df_path, tuning_df_path,
                   time_limits, time_start, base_config, problem_types,
                   column_properties, label_columns, label_shapes,
                   log_metrics, stopping_metric, console_log,
                   ignore_warning=False):
    """Train a BERT-based tabular model for one search-space configuration.

    Only the first entry of ``problem_types``, ``label_columns`` and
    ``label_shapes`` is used. The model is validated periodically on the
    tuning set. The best parameters are saved to
    ``<exp_dir>/best_model.params`` and every validation round is sent to
    ``reporter``.

    Parameters
    ----------
    args
        Trial arguments. ``args['search_space']`` maps config keys to values
        that are merged into ``base_config``. ``args.task_id`` names the
        per-task sub-directory when ``reporter`` is given.
    reporter
        Called with keyword results after every validation round and once
        more with the best round at the end. ``reporter.terminate()`` is
        called if the time budget is already used up. It is called
        unconditionally, so in practice it must not be None.
    train_df_path, tuning_df_path
        Paths to parquet files holding the training and tuning data.
    time_limits
        Time budget in seconds, measured from ``time_start``, or None for
        no limit.
    time_start
        ``time.time()`` timestamp at which the budget started.
    base_config
        Base config. It is cloned, then the search-space values are merged
        into the clone.
    problem_types
        Problem types; only ``problem_types[0]`` is used. It must be
        ``_C.CLASSIFICATION`` or ``_C.REGRESSION``.
    column_properties
        Mapping from column name to a property whose ``transform`` converts
        raw label values.
    label_columns
        Label column names; only the first is used.
    label_shapes
        Label shapes; only the first is used to build the network head.
    log_metrics
        Metric name or list of metric names to log on validation.
    stopping_metric
        Metric name used for model selection and early stopping.
    console_log
        Forwarded to ``logging_config`` as ``console``.
    ignore_warning
        If True, all Python warnings are filtered out.

    Raises
    ------
    NotImplementedError
        If ``problem_types[0]`` is neither classification nor regression.
    """
    if time_limits is not None:
        start_train_tick = time.time()
        time_left = time_limits - (start_train_tick - time_start)
        if time_left <= 0:
            reporter.terminate()
            return
    import os
    # Get the log metric scorers
    if isinstance(log_metrics, str):
        log_metrics = [log_metrics]
    # Load the training and tuning data from the parquet file
    train_data = pd.read_parquet(train_df_path)
    tuning_data = pd.read_parquet(tuning_df_path)
    log_metric_scorers = [get_metric(ele) for ele in log_metrics]
    stopping_metric_scorer = get_metric(stopping_metric)
    greater_is_better = stopping_metric_scorer.greater_is_better
    os.environ['MKL_NUM_THREADS'] = '1'
    os.environ['OMP_NUM_THREADS'] = '1'
    os.environ['MKL_DYNAMIC'] = 'FALSE'
    if ignore_warning:
        import warnings
        warnings.filterwarnings("ignore")
    problem_type = problem_types[0]

    # Build the trial config: base config + flattened search-space overrides
    search_space = args['search_space']
    cfg = base_config.clone()
    specified_values = []
    for key in search_space:
        specified_values.append(key)
        specified_values.append(search_space[key])
    cfg.merge_from_list(specified_values)
    exp_dir = cfg.misc.exp_dir
    if reporter is not None:
        # When the reporter is not None,
        # we create the saved directory based on the task_id + time
        task_id = args.task_id
        exp_dir = os.path.join(exp_dir, 'task{}'.format(task_id))
        os.makedirs(exp_dir, exist_ok=True)
        cfg.defrost()
        cfg.misc.exp_dir = exp_dir
        cfg.freeze()
    logger = logging.getLogger()
    logging_config(folder=exp_dir, name='training', logger=logger,
                   console=console_log)
    logger.info(cfg)

    # Load backbone model
    backbone_model_cls, backbone_cfg, tokenizer, backbone_params_path, _ \
        = get_backbone(cfg.model.backbone.name)
    with open(os.path.join(exp_dir, 'cfg.yml'), 'w') as f:
        f.write(str(cfg))
    text_backbone = backbone_model_cls.from_cfg(backbone_cfg)

    # Build Preprocessor + Preprocess the training dataset + Inference problem type
    # TODO Move preprocessor + Dataloader to outer loop to better cache the dataloader
    preprocessor = TabularBasicBERTPreprocessor(
        tokenizer=tokenizer,
        column_properties=column_properties,
        label_columns=label_columns,
        max_length=cfg.model.preprocess.max_length,
        merge_text=cfg.model.preprocess.merge_text)
    logger.info('Process training set...')
    processed_train = preprocessor.process_train(train_data)
    logger.info('Done!')
    logger.info('Process dev set...')
    processed_dev = preprocessor.process_test(tuning_data)
    logger.info('Done!')
    label = label_columns[0]
    # Get the ground-truth dev labels
    gt_dev_labels = np.array(tuning_data[label].apply(
        column_properties[label].transform))

    ctx_l = get_mxnet_available_ctx()
    base_batch_size = cfg.optimization.per_device_batch_size
    num_accumulated = int(
        np.ceil(cfg.optimization.batch_size / base_batch_size))
    inference_base_batch_size = \
        base_batch_size * cfg.optimization.val_batch_size_mult
    train_dataloader = DataLoader(
        processed_train,
        batch_size=base_batch_size,
        shuffle=True,
        batchify_fn=preprocessor.batchify(is_test=False))
    dev_dataloader = DataLoader(
        processed_dev,
        batch_size=inference_base_batch_size,
        shuffle=False,
        batchify_fn=preprocessor.batchify(is_test=True))

    net = BERTForTabularBasicV1(
        text_backbone=text_backbone,
        feature_field_info=preprocessor.feature_field_info(),
        label_shape=label_shapes[0],
        cfg=cfg.model.network)
    net.initialize_with_pretrained_backbone(backbone_params_path, ctx=ctx_l)
    net.hybridize()
    num_total_params, num_total_fixed_params = count_parameters(
        net.collect_params())
    logger.info('#Total Params/Fixed Params={}/{}'.format(
        num_total_params, num_total_fixed_params))

    # Initialize the optimizer.
    # Clamp to >= 1: a small training set (fewer batches than
    # num_accumulated * #devices) would otherwise give 0 and make the
    # interval modulo / epoch division below raise ZeroDivisionError.
    updates_per_epoch = max(1, int(
        len(train_dataloader) / (num_accumulated * len(ctx_l))))
    optimizer, optimizer_params, max_update \
        = get_optimizer(cfg.optimization, updates_per_epoch=updates_per_epoch)
    valid_interval = max(1, math.ceil(
        cfg.optimization.valid_frequency * updates_per_epoch))
    train_log_interval = max(1, math.ceil(
        cfg.optimization.log_frequency * updates_per_epoch))
    trainer = mx.gluon.Trainer(net.collect_params(),
                               optimizer, optimizer_params,
                               update_on_kvstore=False)
    if 0 < cfg.optimization.layerwise_lr_decay < 1:
        apply_layerwise_decay(net.text_backbone,
                              cfg.optimization.layerwise_lr_decay,
                              backbone_name=cfg.model.backbone.name)
    # Do not apply weight decay to all the LayerNorm and bias
    for _, v in net.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0
    params = [p for p in net.collect_params().values()
              if p.grad_req != 'null']
    # Set grad_req if gradient accumulation is required
    if num_accumulated > 1:
        logger.info('Using gradient accumulation.'
                    ' Global batch size = {}'.format(
                        cfg.optimization.batch_size))
        for p in params:
            p.grad_req = 'add'
        net.collect_params().zero_grad()

    train_loop_dataloader = grouper(repeat(train_dataloader), len(ctx_l))
    # Running per-device sum of per-sample losses / sample counts since the
    # last training log line.
    log_loss_l = [mx.np.array(0.0, dtype=np.float32, ctx=ctx) for ctx in ctx_l]
    log_num_samples_l = [0 for _ in ctx_l]
    logging_start_tick = time.time()
    best_performance_score = None
    mx.npx.waitall()
    no_better_rounds = 0
    report_idx = 0
    start_tick = time.time()
    if time_limits is not None:
        # From here on, time_limits is the budget remaining from start_tick.
        time_limits -= start_tick - time_start
        if time_limits <= 0:
            reporter.terminate()
            return
    best_report_items = None
    for update_idx in tqdm.tqdm(range(max_update), disable=None):
        num_samples_per_update_l = [0 for _ in ctx_l]
        for _ in range(num_accumulated):
            sample_l = next(train_loop_dataloader)
            loss_l = []
            num_samples_l = [0 for _ in ctx_l]
            for i, (sample, ctx) in enumerate(zip(sample_l, ctx_l)):
                feature_batch, label_batch = sample
                feature_batch = move_to_ctx(feature_batch, ctx)
                label_batch = move_to_ctx(label_batch, ctx)
                with mx.autograd.record():
                    pred = net(feature_batch)
                    if problem_type == _C.CLASSIFICATION:
                        logits = mx.npx.log_softmax(pred, axis=-1)
                        loss = -mx.npx.pick(logits, label_batch[0])
                    elif problem_type == _C.REGRESSION:
                        loss = mx.np.square(pred - label_batch[0])
                    else:
                        raise NotImplementedError(
                            'Unsupported problem type: {}'.format(problem_type))
                    loss_l.append(loss.mean() / len(ctx_l))
                num_samples_l[i] = loss.shape[0]
                num_samples_per_update_l[i] += loss.shape[0]
            for loss in loss_l:
                loss.backward()
            for i in range(len(ctx_l)):
                # loss_l[i] * len(ctx_l) is the device's mean loss; times
                # the batch size it becomes the summed loss of this step.
                log_loss_l[i] += loss_l[i] * len(ctx_l) * num_samples_l[i]
                # Count only this accumulation step's samples. Adding the
                # running per-update total here would over-count by a factor
                # of (k + 1) / 2 for k accumulation steps.
                log_num_samples_l[i] += num_samples_l[i]
        # Begin to update
        trainer.allreduce_grads()
        num_samples_per_update = sum(num_samples_per_update_l)
        # Accumulated ('add') gradients are k times larger, so the clipping
        # threshold is scaled to match and the reported norm scaled back.
        total_norm, _, _ = \
            clip_grad_global_norm(params,
                                  cfg.optimization.max_grad_norm * num_accumulated)
        total_norm = total_norm / num_accumulated
        trainer.update(num_samples_per_update)
        # Clear after update
        if num_accumulated > 1:
            net.collect_params().zero_grad()

        if (update_idx + 1) % train_log_interval == 0:
            log_loss = sum([ele.as_in_ctx(ctx_l[0])
                            for ele in log_loss_l]).asnumpy()
            log_num_samples = sum(log_num_samples_l)
            logger.info(
                '[Iter {}/{}, Epoch {}] train loss={:0.4e}, gnorm={:0.4e}, lr={:0.4e}, #samples processed={},'
                ' #sample per second={:.2f}'.format(
                    update_idx + 1, max_update,
                    int(update_idx / updates_per_epoch),
                    log_loss / log_num_samples, total_norm,
                    trainer.learning_rate, log_num_samples,
                    log_num_samples / (time.time() - logging_start_tick)))
            logging_start_tick = time.time()
            log_loss_l = [mx.np.array(0.0, dtype=np.float32, ctx=ctx)
                          for ctx in ctx_l]
            log_num_samples_l = [0 for _ in ctx_l]

        if (update_idx + 1) % valid_interval == 0 \
                or (update_idx + 1) == max_update:
            valid_start_tick = time.time()
            dev_predictions = \
                _classification_regression_predict(net,
                                                   dataloader=dev_dataloader,
                                                   problem_type=problem_type,
                                                   has_label=False)
            log_scores = [calculate_metric(scorer, gt_dev_labels,
                                           dev_predictions, problem_type)
                          for scorer in log_metric_scorers]
            dev_score = calculate_metric(stopping_metric_scorer, gt_dev_labels,
                                         dev_predictions, problem_type)
            valid_time_spent = time.time() - valid_start_tick

            # Ties count as improvements (>= / <=), so the latest of several
            # equally good checkpoints is kept.
            if best_performance_score is None or \
                    (greater_is_better and dev_score >= best_performance_score) or \
                    (not greater_is_better and dev_score <= best_performance_score):
                find_better = True
                no_better_rounds = 0
                best_performance_score = dev_score
                net.save_parameters(os.path.join(exp_dir, 'best_model.params'))
            else:
                find_better = False
                no_better_rounds += 1
            mx.npx.waitall()
            loss_string = ', '.join(['{}={:0.4e}'.format(metric.name, score)
                                     for score, metric in zip(log_scores, log_metric_scorers)])
            logger.info('[Iter {}/{}, Epoch {}] valid {}, time spent={:.3f}s,'
                        ' total_time={:.2f}min'.format(
                            update_idx + 1, max_update,
                            int(update_idx / updates_per_epoch),
                            loss_string, valid_time_spent,
                            (time.time() - start_tick) / 60))
            report_items = [('iteration', update_idx + 1),
                            ('report_idx', report_idx + 1),
                            ('epoch', int(update_idx / updates_per_epoch))] + \
                [(metric.name, score)
                 for score, metric in zip(log_scores, log_metric_scorers)] + \
                [('find_better', find_better),
                 ('time_spent', int(time.time() - start_tick))]
            total_time_spent = time.time() - start_tick
            # reward_attr is oriented by the scorer's sign so that a larger
            # value is always better for the reporter.
            if stopping_metric_scorer._sign < 0:
                report_items.append(('reward_attr', -dev_score))
            else:
                report_items.append(('reward_attr', dev_score))
            report_items.append(('eval_metric', stopping_metric_scorer.name))
            report_items.append(('exp_dir', exp_dir))
            if find_better:
                best_report_items = report_items
            reporter(**dict(report_items))
            report_idx += 1
            if no_better_rounds >= cfg.learning.early_stopping_patience:
                logger.info('Early stopping patience reached!')
                break
            if time_limits is not None and total_time_spent > time_limits:
                break

    if best_report_items is None:
        # Only reachable when no validation round ran (e.g. max_update == 0);
        # previously dict(None) raised TypeError here.
        logger.warning('No validation round was run; skipping the final report.')
        return
    # Re-send the best round as the final report so it is the last one seen.
    best_report_items_dict = dict(best_report_items)
    best_report_items_dict['report_idx'] = report_idx + 1
    reporter(**best_report_items_dict)