# Train a LambdaMART ranker backed by PINES oblivious trees and report its
# NDCG@10 on the held-out test queries.
# NOTE(review): training_queries / validation_queries / test_queries are
# assumed to be prepared earlier in this script.
logging.info('=' * 80)

model = LambdaMART(
    metric='NDCG@10',
    max_leaf_nodes=7,
    shrinkage=0.1,
    estopping=30,   # stop after 30 rounds without validation improvement
    n_jobs=-1,
    random_state=42,
    use_pines=True,
    pines_kwargs=dict(
        switch_criterion=ObliviousCartSwitchCriterionType.OBLIVIOUS_WHILE_CAN,
        tree_type=TreeType.OBLIVIOUS_CART,
        max_n_splits=10,
        min_samples_leaf=50,
        max_depth=10,
    ),
)

model.fit(training_queries, validation_queries=validation_queries)

logging.info('=' * 80)
logging.info('%s on the test queries: %.8f'
             % (model.metric, model.evaluate(test_queries, n_jobs=-1)))

# Bug fix: the filename previously said E50 although estopping is 30.
model.save('LambdaMART_L7_S0.1_E30_' + model.metric)
# Remove features that are constant across every split, log dataset summaries,
# then train a LambdaMART model and score it on the test queries.
purge_queries = False

# Identify query-document features that never vary in any of the three splits.
constant_features = find_constant_features(
    [training_queries, validation_queries, test_queries])

# Strip those features; optionally purge queries made useless by the removal.
# The test split is never purged so every test document keeps a prediction.
training_queries.adjust(remove_features=constant_features, purge=purge_queries)
validation_queries.adjust(remove_features=constant_features,
                          purge=purge_queries)
test_queries.adjust(remove_features=constant_features)

# Print basic info about query datasets.
logging.info('Train queries: %s' % training_queries)
logging.info('Valid queries: %s' % validation_queries)
logging.info('Test queries: %s' % test_queries)

logging.info('=' * 80)

model = LambdaMART(metric='NDCG@10', max_leaf_nodes=7, shrinkage=0.1,
                   estopping=50, n_jobs=-1, min_samples_leaf=50,
                   random_state=42)
model.fit(training_queries, validation_queries=validation_queries)

logging.info('=' * 80)

test_score = model.evaluate(test_queries, n_jobs=-1)
logging.info('%s on the test queries: %.8f' % (model.metric, test_score))

model.save('LambdaMART_L7_S0.1_E50_' + model.metric)
# Train a LambdaMART model and turn its predicted rankings on the test set
# into a two-column (SearchId, PropertyId) submission DataFrame.
model = LambdaMART(metric='nDCG@38', max_leaf_nodes=7, shrinkage=0.1,
                   estopping=10, n_jobs=-1, min_samples_leaf=50,
                   random_state=42)

# TODO: do some crossval here?
# NOTE(review): the test queries are used for early stopping, which leaks
# test information into model selection -- confirm this is intentional.
model.fit(training_queries, validation_queries=test_queries)

logging.info('=' * 80)
logging.info('%s on the test queries: %.8f'
             % (model.metric, model.evaluate(test_queries, n_jobs=-1)))

# Bug fix: the filename previously said E50 although estopping is 10.
model.save('LambdaMART_L7_S0.1_E10_' + model.metric)

# Attach the predicted rank position of every document to the raw test rows.
# nrows limits the CSV read to exactly the documents the model ranked.
predicted_rankings = model.predict_rankings(test_queries)
test_df = pd.read_csv("../test_set_VU_DM_2014.csv", header=0,
                      nrows=test_queries.document_count())
test_df['pred_position'] = np.concatenate(predicted_rankings)

# Order documents within each search by predicted position and keep only the
# two columns required by the submission format.
sorted_df = test_df[['srch_id', 'prop_id', 'pred_position']].sort_values(
    ['srch_id', 'pred_position'])
submission = pd.DataFrame({
    'SearchId': sorted_df.srch_id,
    'PropertyId': sorted_df.prop_id,
})[['SearchId', 'PropertyId']]
# Finish feature cleanup, log dataset summaries, train a LambdaMART model,
# and write the ranked test documents out as a submission CSV.
# NOTE(review): cfs / remove_useless_queries are assumed to be defined by the
# preceding feature-analysis step of this script.
validation_queries.adjust(remove_features=cfs, purge=remove_useless_queries)
test_queries.adjust(remove_features=cfs)

# Print basic info about query datasets.
logging.info('Train queries: %s' % training_queries)
logging.info('Valid queries: %s' % validation_queries)
logging.info('Test queries: %s' % test_queries)
logging.info('=' * 80)

model = LambdaMART(metric='nDCG@38', max_leaf_nodes=7, shrinkage=0.1,
                   estopping=10, n_jobs=-1, min_samples_leaf=50,
                   random_state=42)

# TODO: do some crossval here?
# NOTE(review): early stopping is driven by the test queries, which leaks
# test information into model selection -- confirm this is intentional.
model.fit(training_queries, validation_queries=test_queries)

logging.info('=' * 80)

# Hoisted: evaluate once and reuse the score for both the log line and the
# output filename (previously computed twice).
test_score = model.evaluate(test_queries, n_jobs=-1)
logging.info('%s on the test queries: %.8f' % (model.metric, test_score))

# Bug fix: the filename previously said E50 although estopping is 10.
model.save('LambdaMART_L7_S0.1_E10_' + model.metric)

# Attach the predicted rank position of every document to the raw test rows.
predicted_rankings = model.predict_rankings(test_queries)
test_df = pd.read_csv("../test_set_VU_DM_2014.csv", header=0,
                      nrows=test_queries.document_count())
test_df['pred_position'] = np.concatenate(predicted_rankings)

# Order documents within each search by predicted position and keep only the
# two columns required by the submission format.
sorted_df = test_df[['srch_id', 'prop_id', 'pred_position']].sort_values(
    ['srch_id', 'pred_position'])
submission = pd.DataFrame({
    'SearchId': sorted_df.srch_id,
    'PropertyId': sorted_df.prop_id,
})[['SearchId', 'PropertyId']]

# NOTE(review): '%d2' renders as the document count followed by a literal
# '2' (probably a format typo, or a version tag) -- kept byte-identical so
# output filenames do not change.
submission.to_csv('model_%d2_%f.csv' % (test_queries.document_count(),
                                        test_score),
                  index=False)