コード例 #1
0
def catboost_ensemble(config, is_train):
    """Assemble the CatBoost ensemble pipeline.

    A step named ``catboost_ensemble`` runs a
    ``CatboostClassifierMultilabel`` (built from ``config.catboost_ensemble``)
    on the ``'input'`` data. A terminal ``output`` step, backed by a ``Dummy``
    transformer, exposes that step's ``prediction_probability`` under the key
    ``y_pred``.

    Args:
        config: Pipeline configuration. ``config.catboost_ensemble`` supplies
            the classifier's keyword arguments, and ``config.env.cache_dirpath``
            is passed to both steps as their cache directory.
        is_train: If truthy, the ensemble step's ``overwrite_transformer``
            attribute is set to ``True``.

    Returns:
        The terminal ``output`` Step.
    """
    classifier = CatboostClassifierMultilabel(**config.catboost_ensemble)
    cache_dir = config.env.cache_dirpath

    ensemble_step = Step(name='catboost_ensemble',
                         transformer=classifier,
                         input_data=['input'],
                         cache_dirpath=cache_dir)

    # The adapter addresses the upstream step by its *name*, so the string
    # below must match the ``name=`` given to ``ensemble_step``.
    prediction_adapter = {
        'y_pred': [('catboost_ensemble', 'prediction_probability')],
    }
    output = Step(name='output',
                  transformer=Dummy(),
                  input_steps=[ensemble_step],
                  adapter=prediction_adapter,
                  cache_dirpath=cache_dir)

    if is_train:
        ensemble_step.overwrite_transformer = True

    return output
コード例 #2
0
def gru_stacker_ensemble(config, is_train):
    """Build the GRU-stacker ensemble pipeline.

    A step named ``gru_stacker_ensemble`` feeds ``X`` and ``y`` from the
    ``'input'`` data into a ``StackerGru`` transformer (built from
    ``config.gru_stacker``). A terminal ``output`` step, backed by a ``Dummy``
    transformer, exposes that step's ``prediction_probability`` as ``y_pred``.

    Args:
        config: Pipeline configuration. ``config.gru_stacker`` supplies the
            ``StackerGru`` keyword arguments, and ``config.env.cache_dirpath``
            is passed to both steps as their cache directory.
        is_train: If truthy, the input's ``X_valid`` and ``y_valid`` are also
            routed to the stacker as ``validation_data`` (combined via
            ``to_tuple_inputs``). The stacker step's ``overwrite_transformer``
            attribute is then set to ``True``.

    Returns:
        The terminal ``output`` Step.
    """
    # Inputs shared by training and inference.
    adapter = {
        'X': [('input', 'X')],
        'y': [('input', 'y')],
    }
    if is_train:
        # Validation inputs are wired in only when training.
        adapter['validation_data'] = ([('input', 'X_valid'),
                                       ('input', 'y_valid')], to_tuple_inputs)

    gru_stacker_ensemble = Step(
        name='gru_stacker_ensemble',
        transformer=StackerGru(**config.gru_stacker),
        input_data=['input'],
        adapter=adapter,
        cache_dirpath=config.env.cache_dirpath)

    # The adapter addresses the upstream step by its *name*, so the string
    # below must match the ``name=`` given above.
    output = Step(name='output',
                  transformer=Dummy(),
                  input_steps=[gru_stacker_ensemble],
                  adapter={
                      'y_pred':
                      [('gru_stacker_ensemble', 'prediction_probability')]
                  },
                  cache_dirpath=config.env.cache_dirpath)

    if is_train:
        gru_stacker_ensemble.overwrite_transformer = True

    return output