예제 #1
0
def train_model():
    """Train the recommender, evaluate it, and hand the results to ``save_obj``.

    Steps, in order:

    1. Obtain the S3 object, load the raw data through it and read the
       hyper-parameters. Each of ``lower_limit`` (default 13),
       ``upper_limit`` (default 15) and ``latent_factor`` (default 300)
       is coerced to ``int``.
    2. Preprocess the ``package_dict`` payload into package/manifest id
       dictionaries and build a user-item DataFrame from them together with
       the payload's ``user_input_stack`` entries.
    3. Split the DataFrame with ``train_test_split``, train the recommender on
       the training part and compute precision/recall at 30 and at 50 on the
       testing part.
    4. Pass the model, the formatted id dictionaries, the metrics and the
       hyper-parameters to ``save_obj``. When ``GITHUB_TOKEN`` is truthy,
       also call ``create_git_pr``.

    Raises:
        Exception: any error raised while saving or creating the PR is logged
            (with its traceback) and then re-raised unchanged.
    """
    s3_obj = load_S3()
    data = load_data(s3_obj)
    hyper_params = load_hyper_params() or {}
    lower_limit = int(hyper_params.get('lower_limit', 13))
    upper_limit = int(hyper_params.get('upper_limit', 15))
    latent_factor = int(hyper_params.get('latent_factor', 300))
    # Lazy %-style args: the string is only rendered if INFO is enabled.
    logger.info(
        "Lower limit %s, Upper limit %s and latent factor %s are used.",
        lower_limit, upper_limit, latent_factor)

    # Look the payload up once; both the preprocessing step and the
    # user-input stacks are read from it.
    package_dict = data.get('package_dict', {})
    package_id_dict, manifest_id_dict = preprocess_raw_data(
        package_dict, lower_limit, upper_limit)
    user_input_stacks = package_dict.get('user_input_stack', [])
    user_item_list = make_user_item_df(manifest_id_dict, package_id_dict,
                                       user_input_stacks)
    user_item_df = pd.DataFrame(user_item_list)
    training_df, testing_df = train_test_split(user_item_df)
    format_pkg_id_dict, format_mnf_id_dict = format_dict(
        package_id_dict, manifest_id_dict)
    # Only the formatted copies are used from here on; drop the raw ones
    # (presumably to release memory before training).
    del package_id_dict, manifest_id_dict

    trained_recommender = run_recommender(training_df, latent_factor)
    precision_at_30, recall_at_30 = precision_recall_at_m(
        30, testing_df, trained_recommender, user_item_df)
    precision_at_50, recall_at_50 = precision_recall_at_m(
        50, testing_df, trained_recommender, user_item_df)
    try:
        save_obj(s3_obj, trained_recommender, precision_at_30, recall_at_30,
                 format_pkg_id_dict, format_mnf_id_dict, precision_at_50,
                 recall_at_50, lower_limit, upper_limit, latent_factor)
        if GITHUB_TOKEN:
            create_git_pr(s3_client=s3_obj,
                          model_version=MODEL_VERSION,
                          recall_at_30=recall_at_30)
    except Exception as error:
        # logger.exception logs at ERROR level *and* records the traceback,
        # which logger.error(error) alone would discard.
        logger.exception(error)
        raise
예제 #2
0
def test_load_hyper_params():
    """Check that ``helper.load_hyper_params`` reads a JSON object from argv.

    The command line is simulated by overriding ``helper.argv``. The previous
    value (or its absence) is restored afterwards, so the fake arguments do
    not leak into other tests that import ``helper``.
    """
    missing = object()
    saved_argv = getattr(helper, 'argv', missing)
    # mock command line args: script name followed by a JSON hyper-param blob
    helper.argv = ['helper.py', '{"a": 111, "b": "some text"}']
    try:
        hyper_params = helper.load_hyper_params()
    finally:
        # Undo the override even if load_hyper_params raises.
        if saved_argv is missing:
            del helper.argv
        else:
            helper.argv = saved_argv
    assert hyper_params.get('a') == 111
    assert hyper_params.get('b') == "some text"