def train_model():
    """Train the recommendation model and save the resulting artifacts."""
    s3_obj = load_S3()
    data = load_data(s3_obj)
    hyper_params = load_hyper_params() or {}
    LOWER_LIMIT = int(hyper_params.get('lower_limit', 13))
    UPPER_LIMIT = int(hyper_params.get('upper_limit', 15))
    LATENT_FACTOR = int(hyper_params.get('latent_factor', 300))
    logger.info(
        "Lower limit {}, upper limit {} and latent factor {} are used.".format(
            LOWER_LIMIT, UPPER_LIMIT, LATENT_FACTOR))
    # Filter the raw manifests by stack size and build the package/manifest id maps.
    package_id_dict, manifest_id_dict = preprocess_raw_data(
        data.get('package_dict', {}), LOWER_LIMIT, UPPER_LIMIT)
    user_input_stacks = data.get('package_dict', {}).get('user_input_stack', [])
    user_item_list = make_user_item_df(manifest_id_dict, package_id_dict,
                                       user_input_stacks)
    user_item_df = pd.DataFrame(user_item_list)
    training_df, testing_df = train_test_split(user_item_df)
    format_pkg_id_dict, format_mnf_id_dict = format_dict(package_id_dict,
                                                         manifest_id_dict)
    del package_id_dict, manifest_id_dict
    trained_recommender = run_recommender(training_df, LATENT_FACTOR)
    # Evaluate the trained model on the held-out manifests.
    precision_at_30, recall_at_30 = precision_recall_at_m(
        30, testing_df, trained_recommender, user_item_df)
    precision_at_50, recall_at_50 = precision_recall_at_m(
        50, testing_df, trained_recommender, user_item_df)
    try:
        # Persist the model, id maps and evaluation metrics back to S3.
        save_obj(s3_obj, trained_recommender, precision_at_30, recall_at_30,
                 format_pkg_id_dict, format_mnf_id_dict, precision_at_50,
                 recall_at_50, LOWER_LIMIT, UPPER_LIMIT, LATENT_FACTOR)
        if GITHUB_TOKEN:
            create_git_pr(s3_client=s3_obj, model_version=MODEL_VERSION,
                          recall_at_30=recall_at_30)
    except Exception as error:
        logger.error(error)
        raise
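
# Illustrative entrypoint sketch (an assumption, not necessarily present in the
# original module): train_model() is assumed to run end-to-end once the S3
# credentials and the GITHUB_TOKEN / MODEL_VERSION globals used above are set.
if __name__ == '__main__':
    train_model()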
def test_load_hyper_params():
    """Test load_hyper_params() with mocked command line arguments."""
    # Mock command line args: argv[1] carries the hyperparameters as a JSON string.
    helper.argv = ['helper.py', '{"a": 111, "b": "some text"}']
    hyper_params = helper.load_hyper_params()
    assert hyper_params.get('a') == 111
    assert hyper_params.get('b') == "some text"
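
# Minimal sketch of a load_hyper_params() implementation consistent with the test
# above -- an illustrative assumption, not the repository's actual code. It assumes
# the hyperparameters arrive as a JSON string in the second CLI argument (which is
# why the test patches helper.argv) and that None is returned when no argument is
# given, matching the `load_hyper_params() or {}` fallback in train_model().
import json
from sys import argv


def load_hyper_params():
    """Parse hyperparameters passed as a JSON string on the command line."""
    if len(argv) > 1:
        return json.loads(argv[1])
    return None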