Example #1
0
def tt_trial(args, reporter):
    """ Training and evaluation function used during a single trial of HPO.

    Prepares the model and inputs from ``args`` and fits the model on
    ``util_args.X`` / ``util_args.y``, using ``util_args.X_val`` /
    ``util_args.y_val`` as validation data. The work is done by
    ``model_trial.fit_and_save_model``, whose returned model is kept.

    On success, ``reporter`` is called with
    ``epoch=model.params['epochs'] + 1`` and
    ``validation_performance=model.val_score``.
    If anything inside the ``try`` body raises, the exception is not
    re-raised. Exceptions other than ``TimeLimitExceeded`` are logged with
    their traceback, and ``reporter.terminate()`` is always called.

    Args:
        args: Raw trial arguments, forwarded to ``model_trial.prepare_inputs``.
        reporter: HPO reporter. It is called with keyword metrics on success,
            and its ``terminate()`` method is called on failure.
    """
    try:
        model, args, util_args = model_trial.prepare_inputs(args=args)

        # NOTE(review): util_args supports both attribute access and .get();
        # presumably a dict-like args container -- confirm.
        fit_model_args = dict(X=util_args.X,
                              y=util_args.y,
                              X_val=util_args.X_val,
                              y_val=util_args.y_val,
                              **util_args.get('fit_kwargs', {}))
        predict_proba_args = dict(X=util_args.X_val)
        # Keep the returned model: its params and val_score are reported below.
        model = model_trial.fit_and_save_model(
            model=model,
            params=args,
            fit_args=fit_model_args,
            predict_proba_args=predict_proba_args,
            y_val=util_args.y_val,
            time_start=util_args.time_start,
            time_limit=util_args.get('time_limit', None),
            reporter=reporter)
    except Exception as e:
        # Hitting the time limit is an expected stop; only log other failures.
        if not isinstance(e, TimeLimitExceeded):
            # logger.exception already attaches the active traceback.
            logger.exception(e)
        reporter.terminate()
    else:
        # Only report metrics when fitting completed without an exception.
        reporter(epoch=model.params['epochs'] + 1,
                 validation_performance=model.val_score)
Example #2
0
def tabular_pytorch_trial(args, reporter):
    """Run a single HPO trial for a tabular PyTorch model.

    The prepared utility args supply the paths of the serialized training and
    validation ``TabularPyTorchDataset`` objects. Both datasets are loaded,
    and the validation labels are read from the validation dataset. Fitting
    and scoring are delegated to ``model_trial.fit_and_save_model``.

    Any exception raised while doing so is not re-raised. Exceptions other
    than ``TimeLimitExceeded`` are logged with their traceback, and
    ``reporter.terminate()`` is called in every failure case.

    Args:
        args: Raw trial arguments, forwarded to ``model_trial.prepare_inputs``.
        reporter: HPO reporter. It is forwarded to the fit helper, and its
            ``terminate()`` method is called here on failure.
    """
    try:
        model, args, util_args = model_trial.prepare_inputs(args=args)

        train_data = TabularPyTorchDataset.load(util_args.train_path)
        val_data = TabularPyTorchDataset.load(util_args.val_path)
        val_labels = val_data.get_labels()

        # y is None: labels are presumably carried by the dataset objects
        # (cf. get_labels() on the validation set) -- confirm.
        fit_model_args = dict(X=train_data, y=None, X_val=val_data,
                              **util_args.get('fit_kwargs', {}))
        predict_proba_args = dict(X=val_data)

        # NOTE(review): the return value is discarded and no final
        # reporter(...) call is made here. Presumably progress is reported
        # through ``reporter`` inside fit_and_save_model -- confirm.
        model_trial.fit_and_save_model(
            model=model,
            params=args,
            fit_args=fit_model_args,
            predict_proba_args=predict_proba_args,
            y_val=val_labels,
            time_start=util_args.time_start,
            time_limit=util_args.get('time_limit', None),
            reporter=reporter,
        )
    except Exception as err:
        hit_time_limit = isinstance(err, TimeLimitExceeded)
        if not hit_time_limit:
            # logger.exception attaches the active traceback by default.
            logger.exception(err)
        reporter.terminate()


def lgb_trial(args, reporter):
    """Run a single HPO trial for a LightGBM gradient-boosting model.

    Builds ``lightgbm.Dataset`` objects for the training and validation files
    found under ``util_args.directory``. It then loads the validation features
    and labels through ``load_pkl.load`` and hands everything to
    ``model_trial.fit_and_save_model``.

    Any exception raised while doing so is not re-raised. Exceptions other
    than ``TimeLimitExceeded`` are logged with their traceback, and
    ``reporter.terminate()`` is called in every failure case.

    Args:
        args: Raw trial arguments, forwarded to ``model_trial.prepare_inputs``.
        reporter: HPO reporter. It is forwarded to the fit helper, and its
            ``terminate()`` method is called here on failure.
    """
    try:
        model, args, util_args = model_trial.prepare_inputs(args=args)

        # LightGBM is imported inside the trial, after the availability check.
        try_import_lightgbm()
        import lightgbm as lgb

        # NOTE(review): paths are joined by plain string concatenation, so this
        # presumably relies on `directory` ending with a separator -- confirm.
        folder = util_args.directory
        train_set = lgb.Dataset(folder + util_args.dataset_train_filename)
        val_set = lgb.Dataset(folder + util_args.dataset_val_filename)
        X_val, y_val = load_pkl.load(folder + util_args.dataset_val_pkl_filename)

        fit_model_args = dict(dataset_train=train_set, dataset_val=val_set,
                              **util_args.get('fit_kwargs', {}))
        predict_proba_args = dict(X=X_val)
        model_trial.fit_and_save_model(
            model=model,
            params=args,
            fit_args=fit_model_args,
            predict_proba_args=predict_proba_args,
            y_val=y_val,
            time_start=util_args.time_start,
            time_limit=util_args.get('time_limit', None),
            reporter=reporter,
        )
    except Exception as err:
        hit_time_limit = isinstance(err, TimeLimitExceeded)
        if not hit_time_limit:
            # logger.exception attaches the active traceback by default.
            logger.exception(err)
        reporter.terminate()