Example #1
    def fit(cls,
            train_data,
            label,
            tuning_data=None,
            time_limits=None,
            output_directory='./ag_text',
            feature_columns=None,
            holdout_frac=None,
            eval_metric=None,
            stopping_metric=None,
            nthreads_per_trial=None,
            ngpus_per_trial=None,
            dist_ip_addrs=None,
            scheduler=None,
            num_trials=None,
            search_strategy=None,
            search_options=None,
            hyperparameters=None,
            plot_results=None,
            seed=None,
            verbosity=2):
        """Fit models to make predictions based on text inputs.

        Parameters
        ----------
        train_data : :class:`autogluon.task.tabular_prediction.TabularDataset` or `pandas.DataFrame`
            Training dataset where rows = individual training examples, columns = features.
        label : str
            Name of the label column. It may be specified as a string (column name) or an int (column index).
        tuning_data : :class:`autogluon.task.tabular_prediction.TabularDataset` or `pandas.DataFrame`, default = None
            Another dataset containing validation data reserved for hyperparameter tuning (in same format as training data).
            If `tuning_data = None`, `fit()` will automatically hold out random examples from `train_data` for validation.
        time_limits : int or str, default = None
            Approximately how long `fit()` should run for (wallclock time in seconds if int).
            String values may instead be used to specify time in different units such as: '1min' or '1hour'.
            Longer `time_limits` will usually improve predictive accuracy.
            If not specified, `fit()` will run until all models it tries by default have completed training.
        output_directory : str, default = './ag_text'
            Path to directory where models and intermediate outputs should be saved.
        feature_columns : List[str], default = None
            Which columns of the table to consider as predictive features (other columns will be ignored, except for the label column).
            If None (by default), all columns of the table are considered predictive features.
        holdout_frac : float, default = None
            Fraction of train_data to holdout as tuning data for optimizing hyperparameters (ignored unless `tuning_data = None`).
            If None, default value is selected based on the number of training examples.
        eval_metric : str, default = None
            The evaluation metric that will be used to evaluate the model's predictive performance.
            If None, an appropriate default metric will be selected (accuracy for classification, mean-squared-error for regression).
            Options for classification include: 'acc' (accuracy), 'nll' (negative log-likelihood).
            Additional options for binary classification include: 'f1' (F1 score), 'mcc' (Matthews coefficient), 'auc' (area under ROC curve).
            Options for regression include: 'mse' (mean squared error), 'rmse' (root mean squared error), 'mae' (mean absolute error).
        stopping_metric : str, default = None
            Metric that iteratively-trained models use for early stopping, in order to avoid overfitting.
            Defaults to `eval_metric` value (if None).
            Options are identical to options for `eval_metric`.
        nthreads_per_trial : int, default = None
            The number of threads per individual model training run. By default, all available CPUs are used.
        ngpus_per_trial : int, default = None
            The number of GPUs to use per individual model training run. If unspecified, a default value is chosen based on total number of GPUs available.
        dist_ip_addrs : list, default = None
            List of IP addresses corresponding to remote workers, in order to leverage distributed computation.
        scheduler : str, default = None
            Controls scheduling of model training runs during HPO.
            Options include: 'fifo' (first in first out) or 'hyperband'.
            If unspecified, the default is 'fifo'.
        num_trials : int, default = None
            The number of trials to run during the hyperparameter search.
        search_strategy : str, default = None
            Which hyperparameter search algorithm to use.
            Options include: 'random' (random search), 'bayesopt' (Gaussian process Bayesian optimization), 'skopt' (SKopt Bayesian optimization), 'grid' (grid search).
        search_options : dict, default = None
            Auxiliary keyword arguments to pass to the searcher that performs the hyperparameter search.
        hyperparameters : dict, default = None
            Determines the hyperparameters used by the models. Each hyperparameter may be either a fixed value or a search space over many values.
            For an example of the default hyperparameters, see: `autogluon.task.text_prediction.text_prediction.default()`
        plot_results : bool, default = None
            Whether or not to plot intermediate training results during `fit()`.
        seed : int, default = None
            Seed value for random state used inside `fit()`. 
        verbosity : int, default = 2
            Verbosity levels range from 0 to 4 and control how much information is printed
            during fit().
            Higher levels correspond to more detailed print statements
            (you can set verbosity = 0 to suppress warnings).
            If using logging, you can alternatively control the amount of information printed
            via `logger.setLevel(L)`, where `L` ranges from 0 to 50
            (Note: higher values of `L` correspond to fewer print statements, opposite of verbosity levels).

        Returns
        -------
        model
            A `BertForTextPredictionBasic` object that can be used for making predictions on new data.
        """
        assert dist_ip_addrs is None, 'Training on remote machine is currently not supported.'
        # Version check of MXNet
        if version.parse(mxnet.__version__) < version.parse('1.7.0') \
                or version.parse(mxnet.__version__) >= version.parse('2.0.0'):
            raise ImportError(
                'You will need to ensure that you have mxnet>=1.7.0, <2.0.0. '
                'For more information about how to install mxnet, you can refer to '
                'https://sxjscience.github.io/KDD2020/ .')

        if verbosity < 0:
            verbosity = 0
        elif verbosity > 4:
            verbosity = 4
        console_log = verbosity >= 2
        logging_config(folder=output_directory,
                       name='ag_text_prediction',
                       logger=logger,
                       level=verbosity2loglevel(verbosity),
                       console=console_log)
        # Parse the hyper-parameters
        if hyperparameters is None:
            hyperparameters = ag_text_prediction_params.create('default')
        elif isinstance(hyperparameters, str):
            hyperparameters = ag_text_prediction_params.create(hyperparameters)
        else:
            base_params = ag_text_prediction_params.create('default')
            hyperparameters = merge_params(base_params, hyperparameters)
        np.random.seed(seed)
        if not isinstance(train_data, pd.DataFrame):
            train_data = load_pd.load(train_data)
        # Infer the label column(s)
        if not isinstance(label, list):
            label = [label]
        label_columns = []
        for ele in label:
            if isinstance(ele, int):
                label_columns.append(train_data.columns[ele])
            else:
                label_columns.append(ele)
        if feature_columns is None:
            all_columns = list(train_data.columns)
            feature_columns = [
                ele for ele in all_columns if ele not in label_columns
            ]
        else:
            if isinstance(feature_columns, str):
                feature_columns = [feature_columns]
            for col in feature_columns:
                assert col not in label_columns, 'Feature columns and label columns cannot overlap.'
                assert col in train_data.columns,\
                    'Feature columns must be in the pandas dataframe! Received col = "{}", ' \
                    'all columns = "{}"'.format(col, train_data.columns)
            all_columns = feature_columns + label_columns
            all_columns = [
                ele for ele in train_data.columns if ele in all_columns
            ]
        if tuning_data is None:
            if holdout_frac is None:
                holdout_frac = default_holdout_frac(len(train_data), True)
            train_data, tuning_data = random_split_train_val(
                train_data, valid_ratio=holdout_frac)
        else:
            if not isinstance(tuning_data, pd.DataFrame):
                tuning_data = load_pd.load(tuning_data)
        train_data = train_data[all_columns]
        tuning_data = tuning_data[all_columns]
        column_properties = get_column_properties(
            pd.concat([train_data, tuning_data]),
            metadata=None,
            label_columns=label_columns,
            provided_column_properties=None,
            categorical_default_handle_missing_value=True)
        train_data = TabularDataset(train_data,
                                    column_properties=column_properties,
                                    label_columns=label_columns)
        tuning_data = TabularDataset(
            tuning_data,
            column_properties=train_data.column_properties,
            label_columns=label_columns)

        logger.info('Train Dataset:')
        logger.info(train_data)
        logger.info('Tuning Dataset:')
        logger.info(tuning_data)
        logger.debug('Hyperparameters:')
        logger.debug(hyperparameters)
        has_text_column = False
        for k, v in column_properties.items():
            if v.type == _C.TEXT:
                has_text_column = True
                break
        if not has_text_column:
            raise NotImplementedError('No text column is found! This is currently not supported by '
                                      'the TextPrediction task. You may try to use '
                                      'TabularPrediction.fit().\n'
                                      'The inferred column properties of the training data are {}'
                                      .format(train_data))
        problem_types = []
        label_shapes = []
        for label_col_name in label_columns:
            problem_type, label_shape = infer_problem_type(
                column_properties=column_properties,
                label_col_name=label_col_name)
            problem_types.append(problem_type)
            label_shapes.append(label_shape)
        logging.info(
            'Label columns={}, Feature columns={}, Problem types={}, Label shapes={}'
            .format(label_columns, feature_columns, problem_types,
                    label_shapes))
        eval_metric, stopping_metric, log_metrics =\
            infer_eval_stop_log_metrics(problem_types[0],
                                        label_shapes[0],
                                        eval_metric=eval_metric,
                                        stopping_metric=stopping_metric)
        logging.info('Eval Metric={}, Stop Metric={}, Log Metrics={}'.format(
            eval_metric, stopping_metric, log_metrics))
        model_candidates = []
        for model_type, kwargs in hyperparameters['models'].items():
            search_space = kwargs['search_space']
            if model_type == 'BertForTextPredictionBasic':
                model = BertForTextPredictionBasic(
                    column_properties=column_properties,
                    label_columns=label_columns,
                    feature_columns=feature_columns,
                    label_shapes=label_shapes,
                    problem_types=problem_types,
                    stopping_metric=stopping_metric,
                    log_metrics=log_metrics,
                    base_config=None,
                    search_space=search_space,
                    output_directory=output_directory,
                    logger=logger)
                model_candidates.append(model)
            else:
                raise ValueError(
                    'model_type = "{}" is not supported. You can try to use '
                    'model_type = "BertForTextPredictionBasic"'.format(
                        model_type))
        assert len(
            model_candidates) == 1, 'Only one model is supported currently'
        recommended_resource = get_recommended_resource(
            nthreads_per_trial=nthreads_per_trial,
            ngpus_per_trial=ngpus_per_trial)
        if scheduler is None:
            scheduler = hyperparameters['hpo_params']['scheduler']
        if search_strategy is None:
            search_strategy = hyperparameters['hpo_params']['search_strategy']
        if time_limits is None:
            time_limits = hyperparameters['hpo_params']['time_limits']
        else:
            if isinstance(time_limits, str):
                if time_limits.endswith('min'):
                    time_limits = int(float(time_limits[:-3]) * 60)
                elif time_limits.endswith('hour'):
                    time_limits = int(float(time_limits[:-4]) * 60 * 60)
                else:
                    raise ValueError(
                        'The given time_limits="{}" cannot be parsed!'.format(
                            time_limits))
        if num_trials is None:
            num_trials = hyperparameters['hpo_params']['num_trials']

        # Setting the HPO-specific parameters.
        reduction_factor = hyperparameters['hpo_params']['reduction_factor']
        grace_period = hyperparameters['hpo_params']['grace_period']
        max_t = hyperparameters['hpo_params']['max_t']
        if recommended_resource['num_gpus'] == 0:
            warnings.warn(
                'It is recommended to use a GPU to run the TextPrediction task!')
        model = model_candidates[0]
        if plot_results is None:
            if in_ipynb():
                plot_results = True
            else:
                plot_results = False
        model.train(train_data=train_data,
                    tuning_data=tuning_data,
                    resource=recommended_resource,
                    time_limits=time_limits,
                    scheduler=scheduler,
                    searcher=search_strategy,
                    num_trials=num_trials,
                    reduction_factor=reduction_factor,
                    grace_period=grace_period,
                    max_t=max_t,
                    plot_results=plot_results,
                    console_log=verbosity > 2,
                    ignore_warning=verbosity <= 2)
        return model
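
The snippet above is the public entry point of the TextPrediction task. Below is a minimal usage sketch (not part of the original examples); the import path, file name, and column name are assumptions and should be adapted to your setup.

# Hypothetical usage sketch of TextPrediction.fit(); import path, 'train.csv', and 'label' are assumptions.
import pandas as pd
from autogluon import TextPrediction as task   # assumed import path for the TextPrediction task

train_df = pd.read_csv('train.csv')            # placeholder dataset with at least one text column + 'label'
predictor = task.fit(train_data=train_df,
                     label='label',             # may also be given as a column index (int)
                     time_limits='1hour',       # parsed into seconds inside fit()
                     output_directory='./ag_text',
                     ngpus_per_trial=1,
                     seed=123)
# The returned BertForTextPredictionBasic object can then be used to make predictions on new data.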
Example #2
 def train(self,
           train_data,
           tuning_data,
           resource,
           time_limits=None,
           search_strategy='random',
           search_options=None,
           scheduler_options=None,
           num_trials=None,
           plot_results=False,
           console_log=True,
           ignore_warning=True,
           verbosity=2):
     force_forkserver()
     start_tick = time.time()
     logging_config(folder=self._output_directory,
                    name='main',
                    console=console_log,
                    logger=self._logger)
     assert len(self._label_columns) == 1
     # TODO(sxjscience) Try to support S3
     os.makedirs(self._output_directory, exist_ok=True)
     search_space_reg = args(search_space=space.Dict(**self.search_space))
     # Scheduler and searcher for HPO
     if scheduler_options is None:
         scheduler_options = dict()
     scheduler_options = compile_scheduler_options(
         scheduler_options=scheduler_options,
         search_strategy=search_strategy,
         search_options=search_options,
         nthreads_per_trial=resource['num_cpus'],
         ngpus_per_trial=resource['num_gpus'],
         checkpoint=os.path.join(self._output_directory, 'checkpoint.ag'),
         num_trials=num_trials,
         time_out=time_limits,
         resume=False,
         visualizer=scheduler_options.get('visualizer'),
         time_attr='report_idx',
         reward_attr='reward_attr',
         dist_ip_addrs=scheduler_options.get('dist_ip_addrs'))
     # Create a temporary cache file and then ask the inner function to load the
     # temporary cache.
     train_df_path = os.path.join(self._output_directory,
                                  'cache_train_dataframe.pq')
     tuning_df_path = os.path.join(self._output_directory,
                                   'cache_tuning_dataframe.pq')
     train_data.table.to_parquet(train_df_path)
     tuning_data.table.to_parquet(tuning_df_path)
     train_fn = search_space_reg(
         functools.partial(train_function,
                           train_df_path=train_df_path,
                           time_limits=time_limits,
                           time_start=start_tick,
                           tuning_df_path=tuning_df_path,
                           base_config=self.base_config,
                           problem_types=self.problem_types,
                           column_properties=self._column_properties,
                           label_columns=self._label_columns,
                           label_shapes=self._label_shapes,
                           log_metrics=self._log_metrics,
                           stopping_metric=self._stopping_metric,
                           console_log=console_log,
                           ignore_warning=ignore_warning))
     scheduler_cls = schedulers[search_strategy.lower()]
     # Create scheduler, run HPO experiment
     scheduler = scheduler_cls(train_fn, **scheduler_options)
     scheduler.run()
     scheduler.join_jobs()
     if len(scheduler.config_history) == 0:
         raise RuntimeError(
             'No training job has been completed! '
             'There are two possibilities: '
             '1) The time_limits is too small, '
             'or 2) There are some internal errors in AutoGluon. '
             'For the first case, you can increase the time_limits or set it to '
              'None, e.g., setting "TextPrediction.fit(..., time_limits=None)". To '
             'further investigate the root cause, you can also try to train with '
             '"verbosity=3", i.e., TextPrediction.fit(..., verbosity=3).')
     best_config = scheduler.get_best_config()
     if verbosity >= 2:
         self._logger.info('Results={}'.format(scheduler.searcher._results))
         self._logger.info('Best_config={}'.format(best_config))
     best_task_id = scheduler.get_best_task_id()
     best_model_saved_dir_path = os.path.join(self._output_directory,
                                              'task{}'.format(best_task_id))
     best_cfg_path = os.path.join(best_model_saved_dir_path, 'cfg.yml')
     cfg = self.base_config.clone_merge(best_cfg_path)
     self._results = dict()
     self._results.update(best_reward=scheduler.get_best_reward(),
                          best_config=scheduler.get_best_config(),
                          total_time=time.time() - start_tick,
                          metadata=scheduler.metadata,
                          training_history=scheduler.training_history,
                          config_history=scheduler.config_history,
                          reward_attr=scheduler._reward_attr,
                          config=cfg)
     if plot_results:
         plot_training_curves = os.path.join(self._output_directory,
                                             'plot_training_curves.png')
         scheduler.get_training_curves(filename=plot_training_curves,
                                       plot=plot_results,
                                       use_legend=True)
      # Consider moving this to a separate predictor
     self._config = cfg
     backbone_model_cls, backbone_cfg, tokenizer, backbone_params_path, _ \
         = get_backbone(cfg.model.backbone.name)
     text_backbone = backbone_model_cls.from_cfg(backbone_cfg)
     preprocessor = TabularBasicBERTPreprocessor(
         tokenizer=tokenizer,
         column_properties=self._column_properties,
         label_columns=self._label_columns,
         max_length=cfg.model.preprocess.max_length,
         merge_text=cfg.model.preprocess.merge_text)
     self._preprocessor = preprocessor
     net = BERTForTabularBasicV1(
         text_backbone=text_backbone,
         feature_field_info=preprocessor.feature_field_info(),
         label_shape=self._label_shapes[0],
         cfg=cfg.model.network)
     net.hybridize()
     ctx_l = get_mxnet_available_ctx()
     net.load_parameters(os.path.join(best_model_saved_dir_path,
                                      'best_model.params'),
                         ctx=ctx_l)
     self._net = net
     mx.npx.waitall()
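
Note how this version of train() hands data to the trial function: the TabularDataset tables are cached as parquet files, and each trial reloads them with pandas (see train_function in Example #4). A small, self-contained sketch of that hand-off is below; the directory, file contents, and column names are placeholders, not values from the source.

# Runnable sketch of the parquet caching hand-off used above (placeholder data; requires pyarrow or fastparquet).
import os
import pandas as pd

output_directory = './ag_text_demo'            # placeholder output directory
os.makedirs(output_directory, exist_ok=True)

train_df = pd.DataFrame({'text': ['good movie', 'terrible plot'], 'label': [1, 0]})
train_df_path = os.path.join(output_directory, 'cache_train_dataframe.pq')
train_df.to_parquet(train_df_path)             # written once by train() before scheduling trials
reloaded = pd.read_parquet(train_df_path)      # reloaded inside each trial by train_function()
assert reloaded['label'].tolist() == [1, 0]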
Example #3
 def train(self,
           train_data,
           tuning_data,
           resource,
           time_limits=None,
           scheduler='fifo',
           searcher=None,
           num_trials=10,
           grace_period=None,
           max_t=None,
           reduction_factor=4,
           brackets=1,
           plot_results=False,
           console_log=True,
           ignore_warning=True):
     start_tick = time.time()
     logging_config(folder=self._output_directory,
                    name='main',
                    console=console_log,
                    logger=self._logger)
     assert len(self._label_columns) == 1
     # TODO(sxjscience) Try to support S3
     os.makedirs(self._output_directory, exist_ok=True)
     search_space_reg = args(search_space=space.Dict(**self.search_space))
     if scheduler == 'hyperband' and time_limits is None:
          time_limits = 5 * 60 * 60  # 5 hours
     train_fn = search_space_reg(
         functools.partial(train_function,
                           train_data=train_data,
                           time_limits=time_limits,
                           tuning_data=tuning_data,
                           base_config=self.base_config,
                           problem_types=self.problem_types,
                           column_properties=self._column_properties,
                           label_columns=self._label_columns,
                           label_shapes=self._label_shapes,
                           log_metrics=self._log_metrics,
                           stopping_metric=self._stopping_metric,
                           console_log=console_log,
                           ignore_warning=ignore_warning))
     if scheduler == 'fifo':
         if searcher is None:
             searcher = 'random'
         scheduler = FIFOScheduler(train_fn,
                                   time_out=time_limits,
                                   num_trials=num_trials,
                                   resource=resource,
                                   searcher=searcher,
                                   checkpoint=None,
                                   reward_attr='reward',
                                   time_attr='time_spent')
     elif scheduler == 'hyperband':
         if searcher is None:
             searcher = 'random'
         if grace_period is None:
             grace_period = 1
         if max_t is None:
             max_t = 5
         scheduler = HyperbandScheduler(train_fn,
                                        time_out=time_limits,
                                        max_t=max_t,
                                        resource=resource,
                                        searcher=searcher,
                                        grace_period=grace_period,
                                        reduction_factor=reduction_factor,
                                        brackets=brackets,
                                        checkpoint=None,
                                        reward_attr='reward',
                                        time_attr='report_idx')
     else:
         raise NotImplementedError
     scheduler.run()
     scheduler.join_jobs()
     self._logger.info('Best_config={}'.format(scheduler.get_best_config()))
     best_task_id = scheduler.get_best_task_id()
     best_model_saved_dir_path = os.path.join(self._output_directory,
                                              'task{}'.format(best_task_id))
     best_cfg_path = os.path.join(best_model_saved_dir_path, 'cfg.yml')
     cfg = self.base_config.clone_merge(best_cfg_path)
     self._results = dict()
     self._results.update(best_reward=scheduler.get_best_reward(),
                          best_config=scheduler.get_best_config(),
                          total_time=time.time() - start_tick,
                          metadata=scheduler.metadata,
                          training_history=scheduler.training_history,
                          config_history=scheduler.config_history,
                          reward_attr=scheduler._reward_attr,
                          config=cfg)
     if plot_results:
         plot_training_curves = os.path.join(self._output_directory,
                                             'plot_training_curves.png')
         scheduler.get_training_curves(filename=plot_training_curves,
                                       plot=plot_results,
                                       use_legend=True)
      # Consider moving this to a separate predictor
     self._config = cfg
     backbone_model_cls, backbone_cfg, tokenizer, backbone_params_path, _ \
         = get_backbone(cfg.model.backbone.name)
     text_backbone = backbone_model_cls.from_cfg(backbone_cfg)
     preprocessor = TabularBasicBERTPreprocessor(
         tokenizer=tokenizer,
         column_properties=self._column_properties,
         label_columns=self._label_columns,
         max_length=cfg.model.preprocess.max_length,
         merge_text=cfg.model.preprocess.merge_text)
     self._preprocessor = preprocessor
     net = BERTForTabularBasicV1(
         text_backbone=text_backbone,
         feature_field_info=preprocessor.feature_field_info(),
         label_shape=self._label_shapes[0],
         cfg=cfg.model.network)
     # Here, we cannot use GPU due to https://github.com/awslabs/autogluon/issues/602
     net.load_parameters(os.path.join(best_model_saved_dir_path,
                                      'best_model.params'),
                         ctx=mx.cpu())
     self._net = net
     mx.npx.waitall()
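
In Example #1, fit() reads its HPO defaults from hyperparameters['hpo_params'] and the per-model search space from hyperparameters['models'], then forwards them to a train() method like the ones above. The sketch below shows what a custom dict with that layout might look like; the key structure mirrors the lookups in fit(), while the concrete values and the 'optimization.lr' search-space entry are illustrative assumptions.

# Hedged sketch of a custom `hyperparameters` dict for fit(); values are illustrative only.
import autogluon as ag                  # assumed import; the space helpers may live elsewhere in newer versions

hyperparameters = {
    'models': {
        'BertForTextPredictionBasic': {
            'search_space': {
                'optimization.lr': ag.space.Real(1e-5, 1e-4),   # hypothetical search-space entry
            },
        },
    },
    'hpo_params': {
        'scheduler': 'fifo',            # or 'hyperband'
        'search_strategy': 'random',
        'time_limits': None,
        'num_trials': 4,
        'reduction_factor': 4,          # the last three keys only matter for the hyperband scheduler
        'grace_period': 1,
        'max_t': 5,
    },
}
# This dict is merged with the defaults via merge_params() before being used by fit().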
Example #4
def train_function(args,
                   reporter,
                   train_df_path,
                   tuning_df_path,
                   time_limits,
                   time_start,
                   base_config,
                   problem_types,
                   column_properties,
                   label_columns,
                   label_shapes,
                   log_metrics,
                   stopping_metric,
                   console_log,
                   ignore_warning=False):
    if time_limits is not None:
        start_train_tick = time.time()
        time_left = time_limits - (start_train_tick - time_start)
        if time_left <= 0:
            reporter.terminate()
            return
    import os
    # Get the log metric scorers
    if isinstance(log_metrics, str):
        log_metrics = [log_metrics]
    # Load the training and tuning data from the parquet file
    train_data = pd.read_parquet(train_df_path)
    tuning_data = pd.read_parquet(tuning_df_path)
    log_metric_scorers = [get_metric(ele) for ele in log_metrics]
    stopping_metric_scorer = get_metric(stopping_metric)
    greater_is_better = stopping_metric_scorer.greater_is_better
    os.environ['MKL_NUM_THREADS'] = '1'
    os.environ['OMP_NUM_THREADS'] = '1'
    os.environ['MKL_DYNAMIC'] = 'FALSE'
    if ignore_warning:
        import warnings
        warnings.filterwarnings("ignore")
    search_space = args['search_space']
    cfg = base_config.clone()
    specified_values = []
    for key in search_space:
        specified_values.append(key)
        specified_values.append(search_space[key])
    cfg.merge_from_list(specified_values)
    exp_dir = cfg.misc.exp_dir
    if reporter is not None:
        # When the reporter is not None,
        # we create the saved directory based on the task_id + time
        task_id = args.task_id
        exp_dir = os.path.join(exp_dir, 'task{}'.format(task_id))
        os.makedirs(exp_dir, exist_ok=True)
        cfg.defrost()
        cfg.misc.exp_dir = exp_dir
        cfg.freeze()
    logger = logging.getLogger()
    logging_config(folder=exp_dir,
                   name='training',
                   logger=logger,
                   console=console_log)
    logger.info(cfg)
    # Load backbone model
    backbone_model_cls, backbone_cfg, tokenizer, backbone_params_path, _ \
        = get_backbone(cfg.model.backbone.name)
    with open(os.path.join(exp_dir, 'cfg.yml'), 'w') as f:
        f.write(str(cfg))
    text_backbone = backbone_model_cls.from_cfg(backbone_cfg)
    # Build the preprocessor, preprocess the training dataset, and infer the problem type
    # TODO Move preprocessor + Dataloader to outer loop to better cache the dataloader
    preprocessor = TabularBasicBERTPreprocessor(
        tokenizer=tokenizer,
        column_properties=column_properties,
        label_columns=label_columns,
        max_length=cfg.model.preprocess.max_length,
        merge_text=cfg.model.preprocess.merge_text)
    logger.info('Process training set...')
    processed_train = preprocessor.process_train(train_data)
    logger.info('Done!')
    logger.info('Process dev set...')
    processed_dev = preprocessor.process_test(tuning_data)
    logger.info('Done!')
    label = label_columns[0]
    # Get the ground-truth dev labels
    gt_dev_labels = np.array(tuning_data[label].apply(
        column_properties[label].transform))
    ctx_l = get_mxnet_available_ctx()
    base_batch_size = cfg.optimization.per_device_batch_size
    num_accumulated = int(
        np.ceil(cfg.optimization.batch_size / base_batch_size))
    inference_base_batch_size = base_batch_size * cfg.optimization.val_batch_size_mult
    train_dataloader = DataLoader(
        processed_train,
        batch_size=base_batch_size,
        shuffle=True,
        batchify_fn=preprocessor.batchify(is_test=False))
    dev_dataloader = DataLoader(
        processed_dev,
        batch_size=inference_base_batch_size,
        shuffle=False,
        batchify_fn=preprocessor.batchify(is_test=True))
    net = BERTForTabularBasicV1(
        text_backbone=text_backbone,
        feature_field_info=preprocessor.feature_field_info(),
        label_shape=label_shapes[0],
        cfg=cfg.model.network)
    net.initialize_with_pretrained_backbone(backbone_params_path, ctx=ctx_l)
    net.hybridize()
    num_total_params, num_total_fixed_params = count_parameters(
        net.collect_params())
    logger.info('#Total Params/Fixed Params={}/{}'.format(
        num_total_params, num_total_fixed_params))
    # Initialize the optimizer
    updates_per_epoch = int(
        len(train_dataloader) / (num_accumulated * len(ctx_l)))
    optimizer, optimizer_params, max_update \
        = get_optimizer(cfg.optimization,
                        updates_per_epoch=updates_per_epoch)
    valid_interval = math.ceil(cfg.optimization.valid_frequency *
                               updates_per_epoch)
    train_log_interval = math.ceil(cfg.optimization.log_frequency *
                                   updates_per_epoch)
    trainer = mx.gluon.Trainer(net.collect_params(),
                               optimizer,
                               optimizer_params,
                               update_on_kvstore=False)
    if 0 < cfg.optimization.layerwise_lr_decay < 1:
        apply_layerwise_decay(net.text_backbone,
                              cfg.optimization.layerwise_lr_decay,
                              backbone_name=cfg.model.backbone.name)
    # Do not apply weight decay to all the LayerNorm and bias
    for _, v in net.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0
    params = [p for p in net.collect_params().values() if p.grad_req != 'null']

    # Set grad_req if gradient accumulation is required
    if num_accumulated > 1:
        logger.info('Using gradient accumulation.'
                    ' Global batch size = {}'.format(
                        cfg.optimization.batch_size))
        for p in params:
            p.grad_req = 'add'
        net.collect_params().zero_grad()
    train_loop_dataloader = grouper(repeat(train_dataloader), len(ctx_l))
    log_loss_l = [mx.np.array(0.0, dtype=np.float32, ctx=ctx) for ctx in ctx_l]
    log_num_samples_l = [0 for _ in ctx_l]
    logging_start_tick = time.time()
    best_performance_score = None
    mx.npx.waitall()
    no_better_rounds = 0
    report_idx = 0
    start_tick = time.time()
    if time_limits is not None:
        time_limits -= start_tick - time_start
        if time_limits <= 0:
            reporter.terminate()
            return
    best_report_items = None
    for update_idx in tqdm.tqdm(range(max_update), disable=None):
        num_samples_per_update_l = [0 for _ in ctx_l]
        for accum_idx in range(num_accumulated):
            sample_l = next(train_loop_dataloader)
            loss_l = []
            num_samples_l = [0 for _ in ctx_l]
            for i, (sample, ctx) in enumerate(zip(sample_l, ctx_l)):
                feature_batch, label_batch = sample
                feature_batch = move_to_ctx(feature_batch, ctx)
                label_batch = move_to_ctx(label_batch, ctx)
                with mx.autograd.record():
                    pred = net(feature_batch)
                    if problem_types[0] == _C.CLASSIFICATION:
                        logits = mx.npx.log_softmax(pred, axis=-1)
                        loss = -mx.npx.pick(logits, label_batch[0])
                    elif problem_types[0] == _C.REGRESSION:
                        loss = mx.np.square(pred - label_batch[0])
                    loss_l.append(loss.mean() / len(ctx_l))
                    num_samples_l[i] = loss.shape[0]
                    num_samples_per_update_l[i] += loss.shape[0]
            for loss in loss_l:
                loss.backward()
            for i in range(len(ctx_l)):
                log_loss_l[i] += loss_l[i] * len(ctx_l) * num_samples_l[i]
                log_num_samples_l[i] += num_samples_per_update_l[i]
        # Begin to update
        trainer.allreduce_grads()
        num_samples_per_update = sum(num_samples_per_update_l)
        total_norm, ratio, is_finite = \
            clip_grad_global_norm(params, cfg.optimization.max_grad_norm * num_accumulated)
        total_norm = total_norm / num_accumulated
        trainer.update(num_samples_per_update)

        # Clear after update
        if num_accumulated > 1:
            net.collect_params().zero_grad()
        if (update_idx + 1) % train_log_interval == 0:
            log_loss = sum([ele.as_in_ctx(ctx_l[0])
                            for ele in log_loss_l]).asnumpy()
            log_num_samples = sum(log_num_samples_l)
            logger.info(
                '[Iter {}/{}, Epoch {}] train loss={:0.4e}, gnorm={:0.4e}, lr={:0.4e}, #samples processed={},'
                ' #sample per second={:.2f}'.format(
                    update_idx + 1, max_update,
                    int(update_idx / updates_per_epoch),
                    log_loss / log_num_samples, total_norm,
                    trainer.learning_rate, log_num_samples,
                    log_num_samples / (time.time() - logging_start_tick)))
            logging_start_tick = time.time()
            log_loss_l = [
                mx.np.array(0.0, dtype=np.float32, ctx=ctx) for ctx in ctx_l
            ]
            log_num_samples_l = [0 for _ in ctx_l]
        if (update_idx + 1) % valid_interval == 0 or (update_idx +
                                                      1) == max_update:
            valid_start_tick = time.time()
            dev_predictions = \
                _classification_regression_predict(net, dataloader=dev_dataloader,
                                                   problem_type=problem_types[0],
                                                   has_label=False)
            log_scores = [
                calculate_metric(scorer, gt_dev_labels, dev_predictions,
                                 problem_types[0])
                for scorer in log_metric_scorers
            ]
            dev_score = calculate_metric(stopping_metric_scorer, gt_dev_labels,
                                         dev_predictions, problem_types[0])
            valid_time_spent = time.time() - valid_start_tick

            if best_performance_score is None or \
                    (greater_is_better and dev_score >= best_performance_score) or \
                    (not greater_is_better and dev_score <= best_performance_score):
                find_better = True
                no_better_rounds = 0
                best_performance_score = dev_score
                net.save_parameters(os.path.join(exp_dir, 'best_model.params'))
            else:
                find_better = False
                no_better_rounds += 1
            mx.npx.waitall()
            loss_string = ', '.join([
                '{}={:0.4e}'.format(metric.name, score)
                for score, metric in zip(log_scores, log_metric_scorers)
            ])
            logger.info('[Iter {}/{}, Epoch {}] valid {}, time spent={:.3f}s,'
                        ' total_time={:.2f}min'.format(
                            update_idx + 1, max_update,
                            int(update_idx / updates_per_epoch), loss_string,
                            valid_time_spent, (time.time() - start_tick) / 60))
            report_items = [('iteration', update_idx + 1),
                            ('report_idx', report_idx + 1),
                            ('epoch', int(update_idx / updates_per_epoch))] +\
                           [(metric.name, score)
                            for score, metric in zip(log_scores, log_metric_scorers)] + \
                           [('find_better', find_better),
                            ('time_spent', int(time.time() - start_tick))]
            total_time_spent = time.time() - start_tick

            if stopping_metric_scorer._sign < 0:
                report_items.append(('reward_attr', -dev_score))
            else:
                report_items.append(('reward_attr', dev_score))
            report_items.append(('eval_metric', stopping_metric_scorer.name))
            report_items.append(('exp_dir', exp_dir))
            if find_better:
                best_report_items = report_items
            reporter(**dict(report_items))
            report_idx += 1
            if no_better_rounds >= cfg.learning.early_stopping_patience:
                logger.info('Early stopping patience reached!')
                break
            if time_limits is not None and total_time_spent > time_limits:
                break

    best_report_items_dict = dict(best_report_items)
    best_report_items_dict['report_idx'] = report_idx + 1
    reporter(**best_report_items_dict)
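
train_function above accumulates gradients over num_accumulated mini-batches before each parameter update (grad_req='add', allreduce_grads() + update(), then zero_grad()). The following is a minimal, self-contained MXNet Gluon sketch of that same pattern on toy data; the network, data, and sizes are placeholders, not values from the source.

# Minimal sketch of the gradient-accumulation pattern used in train_function (toy data, assumed sizes).
import mxnet as mx
from mxnet import gluon

mx.npx.set_np()

x = mx.np.random.uniform(size=(16, 8))
y = mx.np.random.uniform(size=(16, 1))

net = gluon.nn.Dense(1)
net.initialize()
net(x)                                             # trigger deferred parameter allocation
for p in net.collect_params().values():
    p.grad_req = 'add'                             # accumulate gradients instead of overwriting them
net.collect_params().zero_grad()

trainer = gluon.Trainer(net.collect_params(), 'adam',
                        {'learning_rate': 1e-3},
                        update_on_kvstore=False)   # needed for allreduce_grads()/update()

num_accumulated = 4                                # e.g. ceil(batch_size / per_device_batch_size)
for update_idx in range(2):                        # two parameter updates
    num_samples = 0
    for _ in range(num_accumulated):               # accumulate gradients over several mini-batches
        with mx.autograd.record():
            loss = mx.np.square(net(x) - y).mean() # same squared-error form as the regression branch above
        loss.backward()
        num_samples += x.shape[0]
    trainer.allreduce_grads()
    trainer.update(num_samples)                    # one optimizer step per accumulation window
    net.collect_params().zero_grad()               # clear accumulated gradients after the step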
Example #5
    def fit(cls,
            train_data,
            label,
            tuning_data=None,
            time_limits=None,
            output_directory='./ag_text',
            feature_columns=None,
            holdout_frac=None,
            eval_metric=None,
            stopping_metric=None,
            nthreads_per_trial=None,
            ngpus_per_trial=None,
            dist_ip_addrs=None,
            scheduler=None,
            num_trials=None,
            search_strategy=None,
            search_options=None,
            hyperparameters=None,
            plot_results=None,
            seed=None,
            verbosity=2):
        """

        Parameters
        ----------
        train_data
            Training dataset
        label
            Name of the label column. It may be specified as a string (column name) or an int (column index).
        tuning_data
            The tuning dataset, used for validation during hyperparameter tuning.
        time_limits
            The time limits. By default, there won't be any time limit and we will try to
            find the best model.
        output_directory
            The output directory
        feature_columns
            The feature columns
        holdout_frac
            Ratio of the training data that will be held out as the tuning data.
            By default, we will choose the appropriate holdout_frac based on the number of
            training samples.
        eval_metric
            The evaluation metric, i.e., how you will finally evaluate the model.
        stopping_metric
            The intrinsic metric used for early stopping.
            Defaults to the `eval_metric` value (if None).
        nthreads_per_trial
            The number of threads per trial. By default, we will use all available CPUs.
        ngpus_per_trial
            The number of GPUs to use for the fit job. By default, we decide the usage
            based on the total number of GPUs available.
        dist_ip_addrs
            List of IP addresses of remote workers, used to leverage distributed computation.
        scheduler
            The scheduler of HPO
        num_trials
            The number of trials in the HPO search
        search_strategy
            The search strategy
        search_options
            The search options
        hyperparameters
            Determines the hyperparameters used by the models; each may be a fixed value or a search space.
        plot_results
            Whether to plot the fitting results
        seed
            The seed of the random state
        verbosity
            Verbosity levels range from 0 to 4 and control how much information is printed
            during fit().
            Higher levels correspond to more detailed print statements
            (you can set verbosity = 0 to suppress warnings).
            If using logging, you can alternatively control the amount of information printed
            via `logger.setLevel(L)`, where `L` ranges from 0 to 50
            (Note: higher values of `L` correspond to fewer print statements, opposite of verbosity levels).

        Returns
        -------
        model
            A `BertForTextPredictionBasic` object that can be used for making predictions on new data.
        """
        assert dist_ip_addrs is None, 'Training on remote machine is currently not supported.'
        if verbosity < 0:
            verbosity = 0
        elif verbosity > 4:
            verbosity = 4
        console_log = verbosity >= 2
        logging_config(folder=output_directory,
                       name='ag_text_prediction',
                       logger=logger,
                       level=verbosity2loglevel(verbosity),
                       console=console_log)
        # Parse the hyper-parameters
        if hyperparameters is None:
            hyperparameters = ag_text_prediction_params.create('default')
        elif isinstance(hyperparameters, str):
            hyperparameters = ag_text_prediction_params.create(hyperparameters)
        else:
            base_params = ag_text_prediction_params.create('default')
            hyperparameters = merge_params(base_params, hyperparameters)
        np.random.seed(seed)
        if not isinstance(train_data, pd.DataFrame):
            train_data = load_pd.load(train_data)
        # Infer the label column(s)
        if not isinstance(label, list):
            label = [label]
        label_columns = []
        for ele in label:
            if isinstance(ele, int):
                label_columns.append(train_data.columns[ele])
            else:
                label_columns.append(ele)
        if feature_columns is None:
            all_columns = list(train_data.columns)
            feature_columns = [
                ele for ele in all_columns if ele not in label_columns
            ]
        else:
            if isinstance(feature_columns, str):
                feature_columns = [feature_columns]
            for col in feature_columns:
                assert col not in label_columns, 'Feature columns and label columns cannot overlap.'
            all_columns = feature_columns + label_columns
            all_columns = [
                ele for ele in train_data.columns if ele in all_columns
            ]
        if tuning_data is None:
            if holdout_frac is None:
                holdout_frac = default_holdout_frac(len(train_data), True)
            train_data, tuning_data = random_split_train_val(
                train_data, valid_ratio=holdout_frac)
        else:
            if not isinstance(tuning_data, pd.DataFrame):
                tuning_data = load_pd.load(tuning_data)
        train_data = TabularDataset(train_data,
                                    columns=all_columns,
                                    label_columns=label_columns)
        tuning_data = TabularDataset(
            tuning_data, column_properties=train_data.column_properties)

        logger.info('Train Dataset:')
        logger.info(train_data)
        logger.info('Tuning Dataset:')
        logger.info(tuning_data)
        logger.debug('Hyperparameters:')
        logger.debug(hyperparameters)
        column_properties = train_data.column_properties

        problem_types = []
        label_shapes = []
        for label_col_name in label_columns:
            problem_type, label_shape = infer_problem_type(
                column_properties=column_properties,
                label_col_name=label_col_name)
            problem_types.append(problem_type)
            label_shapes.append(label_shape)
        logging.info(
            'Label columns={}, Problem types={}, Label shapes={}'.format(
                label_columns, problem_types, label_shapes))
        eval_metric, stopping_metric, log_metrics =\
            infer_eval_stop_log_metrics(problem_types[0],
                                        label_shapes[0],
                                        eval_metric=eval_metric,
                                        stopping_metric=stopping_metric)
        logging.info('Eval Metric={}, Stop Metric={}, Log Metrics={}'.format(
            eval_metric, stopping_metric, log_metrics))
        model_candidates = []
        for model_type, kwargs in hyperparameters['models'].items():
            search_space = kwargs['search_space']
            if model_type == 'BertForTextPredictionBasic':
                model = BertForTextPredictionBasic(
                    column_properties=column_properties,
                    label_columns=label_columns,
                    feature_columns=feature_columns,
                    label_shapes=label_shapes,
                    problem_types=problem_types,
                    stopping_metric=stopping_metric,
                    log_metrics=log_metrics,
                    base_config=None,
                    search_space=search_space,
                    output_directory=output_directory,
                    logger=logger)
                model_candidates.append(model)
            else:
                raise ValueError(
                    'model_type = "{}" is not supported. You can try to use '
                    'model_type = "BertForTextPredictionBasic"'.format(
                        model_type))
        assert len(
            model_candidates) == 1, 'Only one model is supported currently'
        recommended_resource = get_recommended_resource(
            nthreads_per_trial=nthreads_per_trial,
            ngpus_per_trial=ngpus_per_trial)
        if scheduler is None:
            scheduler = hyperparameters['hpo_params']['scheduler']
        if search_strategy is None:
            search_strategy = hyperparameters['hpo_params']['search_strategy']
        if time_limits is None:
            time_limits = hyperparameters['hpo_params']['time_limits']
        else:
            if isinstance(time_limits, str):
                if time_limits.endswith('min'):
                    time_limits = int(float(time_limits[:-3]) * 60)
                elif time_limits.endswith('hour'):
                    time_limits = int(float(time_limits[:-4]) * 60 * 60)
                else:
                    raise ValueError(
                        'The given time_limits="{}" cannot be parsed!'.format(
                            time_limits))
        if num_trials is None:
            num_trials = hyperparameters['hpo_params']['num_trials']

        # Setting the HPO-specific parameters.
        reduction_factor = hyperparameters['hpo_params']['reduction_factor']
        grace_period = hyperparameters['hpo_params']['grace_period']
        max_t = hyperparameters['hpo_params']['max_t']
        if recommended_resource['num_gpus'] == 0:
            warnings.warn(
                'It is recommended to use a GPU to run the TextPrediction task!')
        model = model_candidates[0]
        if plot_results is None:
            if in_ipynb():
                plot_results = True
            else:
                plot_results = False
        model.train(train_data=train_data,
                    tuning_data=tuning_data,
                    resource=recommended_resource,
                    time_limits=time_limits,
                    scheduler=scheduler,
                    searcher=search_strategy,
                    num_trials=num_trials,
                    reduction_factor=reduction_factor,
                    grace_period=grace_period,
                    max_t=max_t,
                    plot_results=plot_results,
                    console_log=verbosity > 2,
                    ignore_warning=verbosity <= 2)
        return model
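
A small aside on the integer-label handling in the snippet above: when `label` is given as an int, fit() resolves it to a column name by position. A tiny runnable illustration with placeholder data:

# Tiny illustration of integer label handling: an int selects the label column by position.
import pandas as pd

df = pd.DataFrame({'review': ['great', 'awful'], 'sentiment': [1, 0]})
label = 1                        # hypothetical value; fit(label=1) would target df.columns[1]
print(df.columns[label])         # -> 'sentiment'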