def __init__(self, *, vector_file: Path, ramen_query_service: GraphRamenQueryService):
    """Assemble the similar-ramen recommendation pipeline.

    Stages run in this order: a candidate generator backed by the ramen
    vector file, a brand filter, a rating scorer, a style scorer, and a
    final ranker.

    :param vector_file: path to the ramen vector file; it is resolved to an
        absolute path before being passed to the candidate generator
    :param ramen_query_service: query service handed to the candidate generator
    """
    generator_stage = SimilarRamenCandidateGenerator(
        ramen_vector_file=vector_file.resolve(),
        ramen_query_service=ramen_query_service,
        generator_explanation=Explanation(
            explanation_string="This ramen is identified as being similar to the target ramen."
        ),
    )
    brand_stage = SameBrandFilter(
        filter_explanation=Explanation(
            explanation_string="This ramen is from a different brand than the target ramen."
        )
    )
    rating_stage = RamenRatingScorer(
        scoring_explanation=Explanation(
            explanation_string="This ramen has a high rating score."
        )
    )
    style_stage = RamenStyleScorer(
        scoring_explanation=Explanation(
            explanation_string="This ramen is the same style as the target ramen."
        )
    )
    Pipeline.__init__(
        self,
        stages=(
            generator_stage,
            brand_stage,
            rating_stage,
            style_stage,
            CandidateRanker(),
        ),
    )
def __init__(self, **kwargs):
    """Populate an in-memory table of example guidelines, then call the parent initializer.

    ``self.temp_guideline_uri_dict`` maps a guideline URI to its ``Guideline``.
    The three entries are hard-coded placeholders standing in for guidelines
    that would otherwise come from a knowledge graph.

    :param kwargs: forwarded unchanged to ``super().__init__``
    """
    sodium_uri = URIRef("exampleGuideline1")
    calorie_uri = URIRef("exampleGuideline2")
    lifestyle_uri = URIRef("exampleGuideline3")
    # Placeholder lifestyle-guideline URI a user must have selected for guideline 3.
    lifestyle_marker = URIRef("http://idea.rpi.edu/heals/kb/placeholder/fakeuri3")

    self.temp_guideline_uri_dict = {
        # Applies to every user (no user conditions): score sodium <= 2300 mg.
        sodium_uri: Guideline(
            uri=sodium_uri,
            user_conditions=frozenset(),
            filter_directives=frozenset(),
            scoring_directives=frozenset({
                GuidelineDirective(
                    target_value=2300,
                    target_attribute="total_nutritional_info.sodium_mg",
                    directive_type=ConstraintType.LEQ,
                )
            }),
            explanation=Explanation(
                explanation_string="As for the general population, people with diabetes should limit sodium consumption to <2,300 mg/day."
            ),
        ),
        # Applies to male users with a known, positive BMI: score energy <= 1800 kcal.
        calorie_uri: Guideline(
            uri=calorie_uri,
            user_conditions=frozenset({
                lambda usr: usr.sex == "male",
                lambda usr: usr.bmi is not None and usr.bmi > 0,
            }),
            filter_directives=frozenset(),
            scoring_directives=frozenset({
                GuidelineDirective(
                    target_value=1800,
                    target_attribute="total_nutritional_info.energ__kcal",
                    directive_type=ConstraintType.LEQ,
                )
            }),
            explanation=Explanation(
                explanation_string="1,500–1,800 kcal/day for men, adjusted for the individuals baseline body weight"
            ),
        ),
        # Applies only when the user's target lifestyle set contains the placeholder URI.
        lifestyle_uri: Guideline(
            uri=lifestyle_uri,
            user_conditions=frozenset({
                lambda usr: lifestyle_marker in usr.target_lifestyle_guideline_set
            }),
            filter_directives=frozenset(),
            scoring_directives=frozenset(),
            explanation=Explanation(
                explanation_string="mediterranean diet. prefer including subclasses of fruit, nuts, fish, vegetable, legume, olive oil, dairy. only based on links existing in recipes-1. this placeholder is not a great example of a real guideline."
            ),
        ),
    }
    super().__init__(**kwargs)
def __init__(
    self,
    *,
    recipe_embedding_service: RecipeEmbeddingService,
    food_kg: FoodKgQueryService,
    guideline_kg: GuidelineKgQueryService
):
    """Assemble the recipe recommendation pipeline.

    Stages, in order:

    - candidates similar to the user's favorites;
    - a prohibited-ingredient filter;
    - guideline application;
    - a calorie scorer;
    - a final ranker.

    :param recipe_embedding_service: embedding service; kept on ``self.res``
    :param food_kg: food knowledge-graph query service; kept on ``self.food_kg``
    :param guideline_kg: guideline query service passed to the guideline stage
    """
    self.res = recipe_embedding_service
    self.food_kg = food_kg

    favorites_stage = SimilarToFavoritesRecipeCandidateGenerator(
        recipe_embedding_service=self.res,
        food_kg_query_service=self.food_kg,
    )
    prohibited_stage = ContainsAnyProhibitedIngredientFilter(
        filter_explanation=Explanation(
            explanation_string="This recipe does not contain any ingredients that are prohibited by you."
        )
    )
    guideline_stage = ApplyGuidelinesToRecipesPipeline(guideline_kg=guideline_kg)
    # The calorie score mainly acts as a tie-breaker between otherwise equal recipes.
    calorie_stage = RecipeCaloriesScorer(
        scoring_explanation=Explanation(
            explanation_string="Scoring based on calories, this is mostly a placeholder to break ties."
        )
    )
    Pipeline.__init__(
        self,
        stages=(
            favorites_stage,
            prohibited_stage,
            guideline_stage,
            calorie_stage,
            CandidateRanker(),
        ),
    )
def likes_country_scorer() -> CandidateBoolScorer:
    """Build a RamenEaterLikesCountryScorer with success and failure explanations.

    The failure message now reads "is not from a country", which fixes the
    ungrammatical "is from not a country". It matches the wording of the
    brand and style scorers.
    """
    return RamenEaterLikesCountryScorer(
        success_scoring_explanation=Explanation(
            explanation_string="This ramen is from a country that the user likes."
        ),
        failure_scoring_explanation=Explanation(
            explanation_string="This ramen is not from a country that the user likes."
        ),
    )
def likes_brand_scorer() -> CandidateBoolScorer:
    """Build a RamenEaterLikesBrandScorer with success and failure explanations."""
    liked = Explanation(
        explanation_string="This ramen is from a brand that the user likes."
    )
    not_liked = Explanation(
        explanation_string="This ramen is not from a brand that the user likes."
    )
    return RamenEaterLikesBrandScorer(
        success_scoring_explanation=liked,
        failure_scoring_explanation=not_liked,
    )
def likes_style_scorer() -> CandidateBoolScorer:
    """Build a RamenEaterLikesStyleScorer with success and failure explanations."""
    liked = Explanation(
        explanation_string="This ramen is a style that the user likes."
    )
    not_liked = Explanation(
        explanation_string="This ramen is not a style that the user likes."
    )
    return RamenEaterLikesStyleScorer(
        success_scoring_explanation=liked,
        failure_scoring_explanation=not_liked,
    )
def __init__(
    self,
    *,
    vector_file: Path,
    ramen_query_service: GraphRamenQueryService,
    num_days: int = 2,
    ramens_per_day: int = 3,
    min_daily_rating: int = 7,
    max_daily_price: int = 7,
    max_total_price: int = 13,
):
    """Assemble a two-stage pipeline: per-eater ramen recommendations, then a meal plan.

    :param vector_file: ramen vector file forwarded to ``RecommendForEaterPipeline``
    :param ramen_query_service: query service forwarded to ``RecommendForEaterPipeline``
    :param num_days: forwarded to ``RamenMealPlanCandidateGenerator``
    :param ramens_per_day: forwarded to ``RamenMealPlanCandidateGenerator``
    :param min_daily_rating: forwarded to ``RamenMealPlanCandidateGenerator``
    :param max_daily_price: forwarded to ``RamenMealPlanCandidateGenerator``
    :param max_total_price: forwarded to ``RamenMealPlanCandidateGenerator``
    """
    recommend_stage = RecommendForEaterPipeline(
        vector_file=vector_file,
        ramen_query_service=ramen_query_service,
    )
    plan_stage = RamenMealPlanCandidateGenerator(
        num_days=num_days,
        ramens_per_day=ramens_per_day,
        min_daily_rating=min_daily_rating,
        max_daily_price=max_daily_price,
        max_total_price=max_total_price,
        generator_explanation=Explanation(
            explanation_string="Based on ramens that you might like, a meal plan was generated."
        ),
    )
    Pipeline.__init__(self, stages=(recommend_stage, plan_stage))
def generate(
    self,
    *,
    candidates: Generator[RecipeCandidate, None, None] = None,
    context: PatientContext,
) -> Generator[MealPlanCandidate, None, None]:
    """Solve for a meal plan over the incoming recipe candidates and yield it.

    The candidates are materialized, passed to ``self.solver`` and solved.
    The first section set of the solution becomes the plan, with one
    ``MealPlanDay`` per section.

    :param candidates: recipe candidates from earlier pipeline stages;
        ``None`` is treated as "no candidates" instead of raising TypeError
    :param context: patient context attached to the produced candidate
    :return: a generator that yields exactly one ``MealPlanCandidate``
    """
    # The solver needs the whole candidate set at once, so consume the generator.
    recipe_candidates = tuple(candidates) if candidates is not None else ()
    soln = self.solver.set_candidates(candidates=recipe_candidates).solve(
        output_uri=URIRef("placeholder.com/placeholder_meal_plan_soln_uri")
    )
    yield MealPlanCandidate(
        context=context,
        applied_scores=[soln.overall_score],
        applied_explanations=[self.generator_explanation],
        domain_object=MealPlanRecommendation(
            explanation=Explanation(
                explanation_string=f"This is a meal plan that was generated for {self.number_of_days} days of meals,"
                f" eating {self.meals_per_day} meals each day."
            ),
            meal_plan_days=tuple(
                MealPlanDay(
                    recipe_recommendations=tuple(
                        RecipeRecommendation(
                            recipe=candidate.domain_object,
                            explanation=RecipeRecommendationExplanation(
                                explanation_contents=tuple(
                                    candidate.applied_explanations
                                )
                            ),
                        )
                        for candidate in section.section_candidates
                    ),
                    explanation=Explanation(
                        explanation_string="This is a set of recommended recipes to eat for this day, "
                        "based on suggesting recipes that you are likely to like in general."
                    ),
                )
                # Only the first section set of the solution is used for the plan.
                for section in soln.solution_section_sets[0].sections
            ),
        ),
    )
def test_filter_applicable_guidelines(test_user, placeholder_guideline, placeholder_guideline2):
    """UserMatchGuidelineFilter keeps only placeholder_guideline for the test user."""
    ctx = PatientContext(target_user=test_user)
    inputs = [
        guideline_candidate_placeholder(guideline, context=ctx)
        for guideline in (placeholder_guideline, placeholder_guideline2)
    ]
    user_filter = UserMatchGuidelineFilter(filter_explanation=Explanation("test1"))
    survivors = list(user_filter(candidates=inputs, context=ctx))

    expected = GuidelineCandidate(
        context=ctx,
        domain_object=placeholder_guideline,
        applied_explanations=[Explanation(explanation_string="test1")],
        applied_scores=[0],
    )
    assert survivors == [expected]
def test_generate_guideline_pipeline(test_user, guideline_kg, placeholder_guideline):
    """The guideline pipeline yields placeholder_guideline carrying both stage explanations."""
    pipeline = GenerateGuidelinesApplicableToUserPipeline(guideline_kg=guideline_kg)
    produced = list(pipeline(context=PatientContext(target_user=test_user)))

    # One explanation from the generator stage and one from the user-match filter.
    stage_explanations = [
        Explanation(explanation_string="This is a guideline that exists in the system."),
        Explanation(explanation_string="User matches the conditions to apply this guideline."),
    ]
    expected = [
        GuidelineCandidate(
            context=PatientContext(target_user=test_user),
            domain_object=placeholder_guideline,
            applied_explanations=stage_explanations,
            applied_scores=[0, 0],
        )
    ]
    assert produced == expected
def ramen_candidate_generator(
    graph_ramen_query_service, test_ramen_101
) -> SimilarRamenCandidateGenerator:
    """Build a SimilarRamenCandidateGenerator targeting ``test_ramen_101``.

    ``vector_file`` is not a parameter here; it is looked up from an
    enclosing scope that is not visible in this block.
    """
    similar_explanation = Explanation(
        explanation_string="This ramen is identified as being similar to the target ramen."
    )
    target_context = RamenContext(target_ramen=test_ramen_101)
    return SimilarRamenCandidateGenerator(
        ramen_vector_file=vector_file,
        ramen_query_service=graph_ramen_query_service,
        generator_explanation=similar_explanation,
        context=target_context,
    )
def __init__(self, *, guideline_kg: GuidelineKgQueryService):
    """Build a pipeline that generates guideline candidates and filters them by user match.

    :param guideline_kg: guideline query service; stored on
        ``self.guideline_kg`` and used as the candidate generator's source
    """
    self.guideline_kg = guideline_kg
    source_stage = AllGuidelinesCandidateGenerator(
        guideline_query_service=self.guideline_kg
    )
    match_stage = UserMatchGuidelineFilter(
        filter_explanation=Explanation(
            "User matches the conditions to apply this guideline."
        )
    )
    Pipeline.__init__(self, stages=(source_stage, match_stage))
def test_sodium_below_target_scorer(food_kg: FoodKgQueryService, test_user: FoodKgUser, test_ingredient_vars):
    """SodiumBelowTargetScorer succeeds for gratin potato and fails for amish soup."""
    ctx = PatientContext(target_user=test_user)
    scorer = SodiumBelowTargetScorer(
        success_scoring_explanation=Explanation(explanation_string="yes test3"),
        failure_scoring_explanation=Explanation(explanation_string="no test3"),
    )
    amish_soup = food_kg.get_recipe_by_uri(
        recipe_uri=test_ingredient_vars.amish_soup_recipe_uri)
    gratin_potato = food_kg.get_recipe_by_uri(
        recipe_uri=test_ingredient_vars.gratin_potato_recipe_uri)

    amish_score = scorer.score(
        candidate=recipe_candidate_placeholder(amish_soup, context=ctx))
    gratin_score = scorer.score(
        candidate=recipe_candidate_placeholder(gratin_potato, context=ctx))

    # Separate asserts so a failure reports which recipe was scored wrongly.
    assert gratin_score == (True, 1)
    assert amish_score == (False, 0)
def test_calorie_scorer(food_kg: FoodKgQueryService, test_user: FoodKgUser, test_ingredient_vars):
    """RecipeCaloriesScorer gives the gratin potato recipe a score of about 0.49534.

    The score is a computed float. It is compared with a tolerance rather
    than with ``==`` so that harmless floating-point rounding differences
    do not fail the test.
    """
    import math

    user_context = PatientContext(target_user=test_user)
    scorer_stage = RecipeCaloriesScorer(
        scoring_explanation=Explanation(explanation_string="test4"),
    )
    gp_rec = food_kg.get_recipe_by_uri(
        recipe_uri=test_ingredient_vars.gratin_potato_recipe_uri)
    gp_score = scorer_stage.score(
        candidate=recipe_candidate_placeholder(gp_rec, context=user_context))
    assert math.isclose(gp_score, 0.49534, abs_tol=1e-9), gp_score
def test_prohibited_ingredient_filter(food_kg: FoodKgQueryService, test_user: FoodKgUser, test_ingredient_vars):
    """The prohibited-ingredient filter returns truthy for onion/garlic and falsy for gratin potato."""
    ctx = PatientContext(target_user=test_user)
    prohibited_filter = ContainsAnyProhibitedIngredientFilter(
        filter_explanation=Explanation(explanation_string="test1"))
    onion_garlic = food_kg.get_recipe_by_uri(
        recipe_uri=test_ingredient_vars.onion_garlic_pot_recipe_uri)
    gratin_potato = food_kg.get_recipe_by_uri(
        recipe_uri=test_ingredient_vars.gratin_potato_recipe_uri)

    onion_garlic_result = prohibited_filter.filter(
        candidate=recipe_candidate_placeholder(onion_garlic, context=ctx))
    gratin_potato_result = prohibited_filter.filter(
        candidate=recipe_candidate_placeholder(gratin_potato, context=ctx))

    assert onion_garlic_result
    assert not gratin_potato_result
def _graph_get_guideline_by_uri(self, *, guideline_uri: URIRef) -> Guideline:
    """
    Retrieve a guideline from the graph with the given URI.

    Lookups are currently served from the static example dictionary
    ``self.temp_guideline_uri_dict``. On a miss, an empty ``Guideline`` is
    returned: no conditions, no directives, and an empty explanation.

    :param guideline_uri: the URI of the guideline to retrieve
    :return: the stored Guideline, or an empty placeholder Guideline for
        unknown URIs
    """
    try:
        return self.temp_guideline_uri_dict[guideline_uri]
    except KeyError:
        pass
    # Build the fallback only on a miss; dict.get(key, default) built it on every call.
    return Guideline(
        uri=guideline_uri,
        user_conditions=frozenset(),
        filter_directives=frozenset(),
        scoring_directives=frozenset(),
        explanation=Explanation(explanation_string=""),
    )
def __init__(
    self,
    number_of_days: int,
    meals_per_day: int,
    **kwargs,
):
    """Configure a constraint solver that lays out ``number_of_days`` days of meals.

    One placeholder ``DomainObject`` is created per day and used as a
    solver section. The constraints are:

    - each day is given exactly ``meals_per_day`` candidates;
    - every unordered pair of days gets an ``AM1`` assignment constraint
      (presumably "at most one", i.e. a recipe is not reused across both
      days; TODO confirm against ``ConstraintType``);
    - the plan as a whole holds ``meals_per_day * number_of_days`` candidates.

    :param number_of_days: number of day sections in the plan
    :param meals_per_day: exact number of candidates assigned to each day
    :param kwargs: forwarded to ``CandidateGenerator.__init__``
    """
    from itertools import combinations

    self.number_of_days = number_of_days
    self.meals_per_day = meals_per_day

    days = tuple(
        DomainObject(uri=URIRef(f"placeholderuri.com/{i}"))
        for i in range(self.number_of_days)
    )
    day_ss = (
        SectionSetConstraint()
        .set_sections(sections=days)
        .add_section_count_constraint(exact_count=self.meals_per_day)
    )
    # combinations() yields each unordered pair once, in the same order as the
    # original nested "for later_day in days[i + 1:]" loop.
    for day, later_day in combinations(days, 2):
        day_ss.add_section_assignment_constraint(
            section_a_uri=day.uri,
            section_b_uri=later_day.uri,
            constraint_type=ConstraintType.AM1,
        )
    self.solver = (
        ConstraintSolver(scaling=100)
        .set_section_set_constraints(section_sets=(day_ss,))
        .add_overall_count_constraint(
            exact_count=self.meals_per_day * self.number_of_days
        )
    )
    generator_explanation = Explanation(
        explanation_string="placeholder hardcoded explanation for a meal plan generation using knapsack problem"
    )
    CandidateGenerator.__init__(
        self, generator_explanation=generator_explanation, **kwargs
    )
def test_get_meal_plan_for_user(
    food_kg_rec: GraphExplainableFoodRecommenderService,
    food_kg: LocalGraphFoodKgQueryService,
    test_user: FoodKgUser,
    placeholder_guideline: Guideline,
    test_ingredient_vars,
):
    """A 3-day, 1-meal-per-day plan contains the expected days (in any order) and plan explanation."""
    mp = food_kg_rec.get_meal_plan_for_user(user=test_user,
                                            number_of_days=3,
                                            meals_per_day=1)

    sodium_guideline_text = (
        "As for the general population, people with "
        "diabetes should limit sodium consumption to <2,300 mg/day."
    )
    day_text = (
        "This is a set of recommended recipes to eat for this day, "
        "based on suggesting recipes that you are likely to like in general."
    )

    def expected_day(recipe_uri, similarity, adheres):
        """Build the single-recipe MealPlanDay expected for *recipe_uri*."""
        verdict = ("Adheres to guideline: "
                   if adheres else "Does not adhere to guideline: ")
        contents = (
            Explanation(
                explanation_string=f"This recipe had a similarity score of {similarity} "
                "to one of your favorite recipes, Lamb Chops au Gratin."
            ),
            Explanation(
                explanation_string="This recipe does not contain any ingredients that are prohibited by you."
            ),
            Explanation(explanation_string=verdict + sodium_guideline_text),
            Explanation(
                explanation_string="Scoring based on calories, this is mostly a placeholder to break ties."
            ),
        )
        recommendation = RecipeRecommendation(
            recipe=food_kg.get_recipe_by_uri(recipe_uri=recipe_uri),
            explanation=RecipeRecommendationExplanation(
                explanation_contents=contents),
        )
        return MealPlanDay(
            recipe_recommendations=(recommendation,),
            explanation=Explanation(explanation_string=day_text),
        )

    # Similarity scores are passed as literal strings so the text matches exactly.
    expected_mp_rec = MealPlanRecommendation(
        meal_plan_days=(
            expected_day(test_ingredient_vars.gratin_potato_recipe_uri,
                         "0.012105363142076218", True),
            expected_day(test_ingredient_vars.layer_din_recipe_uri, "1.0", False),
            expected_day(test_ingredient_vars.amish_soup_recipe_uri, "1.0", False),
        ),
        explanation=Explanation(
            explanation_string="This is a meal plan that was generated for 3 days of meals,"
            " eating 1 meals each day."),
    )
    # Day order is not checked: the days are compared as sets.
    assert (frozenset(mp.domain_object.meal_plan_days)
            == frozenset(expected_mp_rec.meal_plan_days)
            and mp.domain_object.explanation == expected_mp_rec.explanation)
def rating_scorer() -> CandidateScorer:
    """Build a RamenRatingScorer labelled with a high-rating explanation."""
    explanation = Explanation(
        explanation_string="This ramen has a high rating score."
    )
    return RamenRatingScorer(scoring_explanation=explanation)
def style_scorer() -> CandidateScorer:
    """Build a RamenStyleScorer labelled with a same-style explanation."""
    explanation = Explanation(
        explanation_string="This ramen is the same style as the target ramen."
    )
    return RamenStyleScorer(scoring_explanation=explanation)
def same_brand_filterer() -> CandidateFilterer:
    """Build a SameBrandFilter labelled with a different-brand explanation."""
    # NOTE(review): this string has no trailing period; kept as-is in case
    # callers or tests compare against it verbatim.
    explanation = Explanation(
        explanation_string="This ramen is from a different brand than the target ramen"
    )
    return SameBrandFilter(filter_explanation=explanation)
def prohibited_country_filterer() -> CandidateFilterer:
    """Build a RamenEaterProhibitCountryFilter labelled with a not-prohibited explanation."""
    explanation = Explanation(
        explanation_string="This ramen is not from a prohibited country."
    )
    return RamenEaterProhibitCountryFilter(filter_explanation=explanation)
def test_rank_recommend_recipes_pipeline(
    food_kg: FoodKgQueryService,
    embedding_service: RecipeEmbeddingService,
    guideline_kg: GuidelineKgQueryService,
    test_user: FoodKgUser,
    test_ingredient_vars,
):
    """RecommendRecipesPipeline yields the expected candidates in ranked order.

    Each expected candidate carries one explanation and one score per
    pipeline stage, in this order:

    - similarity to a favorite;
    - prohibited-ingredient filter;
    - sodium guideline;
    - calorie tie-breaker.
    """
    test_pipe = RecommendRecipesPipeline(
        recipe_embedding_service=embedding_service,
        food_kg=food_kg,
        guideline_kg=guideline_kg,
    )
    res = list(test_pipe(context=PatientContext(target_user=test_user)))

    sodium_guideline_text = (
        "As for the general population, people with diabetes should limit "
        "sodium consumption to <2,300 mg/day."
    )

    def expected_candidate(recipe_uri, similarity, adheres, scores):
        """Build the RecipeCandidate expected for *recipe_uri*."""
        verdict = ("Adheres to guideline: "
                   if adheres else "Does not adhere to guideline: ")
        return RecipeCandidate(
            context=PatientContext(target_user=test_user),
            domain_object=food_kg.get_recipe_by_uri(recipe_uri=recipe_uri),
            applied_explanations=[
                Explanation(
                    explanation_string=f"This recipe had a similarity score of {similarity} "
                    "to one of your favorite recipes, Lamb Chops au Gratin."),
                Explanation(
                    explanation_string="This recipe does not contain any ingredients that are prohibited by you."
                ),
                Explanation(explanation_string=verdict + sodium_guideline_text),
                Explanation(
                    explanation_string="Scoring based on calories, this is mostly a placeholder to break ties."
                ),
            ],
            applied_scores=scores,
        )

    # Similarity text is passed as literal strings so the explanation matches exactly.
    expected_res = [
        expected_candidate(test_ingredient_vars.gratin_potato_recipe_uri,
                           "0.012105363142076218", True,
                           [0.012105363142076218, 0, 1, 0.49534]),
        expected_candidate(test_ingredient_vars.amish_soup_recipe_uri,
                           "1.0", False,
                           [1.0, 0, 0, -1.9261110339999998]),
        expected_candidate(test_ingredient_vars.layer_din_recipe_uri,
                           "1.0", False,
                           [1.0, 0, 0, -1.9750581350000003]),
    ]
    assert res == expected_res