def naive():
    """Return a network of three unconnected boolean nodes.

    'A', 'B' and 'C' are passed as bare nodes (no edges), and each one
    receives its own marginal distribution over {True, False}.
    """
    priors = {
        'A': {True: .1, False: .9},
        'B': {True: .3, False: .7},
        'C': {True: .5, False: .5},
    }
    net = hh.BayesNet(*priors)
    for node, dist in priors.items():
        net.P[node] = pd.Series(dist)
    net.prepare()
    return net
def load_sprinkler() -> hh.BayesNet:
    """Return the cloudy / sprinkler / rain / wet-grass network.

    Structure and probabilities follow figure 14.12(a) of Artificial
    Intelligence: A Modern Approach.

    Example:

        >>> import hedgehog as hh

        >>> bn = hh.load_sprinkler()

        >>> bn.query('Rain', event={'Sprinkler': True})
        Rain
        False    0.7
        True     0.3
        Name: P(Rain), dtype: float64

    """
    edges = [
        ('Cloudy', 'Sprinkler'),
        ('Cloudy', 'Rain'),
        ('Sprinkler', 'Wet grass'),
        ('Rain', 'Wet grass'),
    ]
    net = hh.BayesNet(*edges)

    # Each key lists the parent value(s) first and the node's own value last.
    tables = {
        # P(Cloudy)
        'Cloudy': {False: .5, True: .5},
        # P(Sprinkler | Cloudy)
        'Sprinkler': {
            (True, True): .1, (True, False): .9,
            (False, True): .5, (False, False): .5,
        },
        # P(Rain | Cloudy)
        'Rain': {
            (True, True): .8, (True, False): .2,
            (False, True): .2, (False, False): .8,
        },
        # P(Wet grass | Sprinkler, Rain)
        'Wet grass': {
            (True, True, True): .99, (True, True, False): .01,
            (True, False, True): .9, (True, False, False): .1,
            (False, True, True): .9, (False, True, False): .1,
            (False, False, True): 0, (False, False, False): 1,
        },
    }
    for node, table in tables.items():
        net.P[node] = pd.Series(table)

    net.prepare()
    return net
def test_cpt_with_index_names():
    """ From https://github.com/MaxHalford/hedgehog/issues/19

    The CPT for 'C' is given with *named* index levels in the order
    (B, A, C). That order differs from the order the edges are declared in,
    and the query result pins the expected posterior for that layout.
    """
    edge_frame = pd.DataFrame({"parent": ["A", "B"], "child": "C"})
    bn = hh.BayesNet(*edge_frame.itertuples(index=False, name=None))

    bn.P['A'] = pd.Series({True: 0.7, False: 0.3})
    bn.P['B'] = pd.Series({True: 0.4, False: 0.6})

    # Rows enumerate (B, A, C) in product order: TTT, TTF, TFT, TFF, FTT, ...
    levels = pd.MultiIndex.from_product([[True, False]] * 3,
                                        names=["B", "A", "C"])
    bn.P["C"] = pd.Series([1, 0, 0, 1, 0.5, 0.5, 0.001, 0.999],
                          index=levels, name="p")
    bn.prepare()

    result = bn.query("C", event={"B": False, "A": True})
    expected = pd.Series([0.5, 0.5], name="P(C)",
                         index=pd.Index([False, True], name="C"))
    pd.testing.assert_series_equal(result, expected)
def __init__(self):
    """Build the player-emotion / offer Bayesian network.

    Structure:

    * ``player_emotion`` has parents ``player_emotion_face`` and
      ``player_emotion_voice``;
    * ``robot_decision`` and ``robot_offer`` both have parents
      ``player_offer`` and ``player_emotion``.

    Conditional tables are written as rows of probabilities (one entry per
    outcome). They are flattened into the ``(*parent_values, outcome)`` keys
    used for ``bn.P`` entries.

    Sets ``self.bn`` to the prepared network, and ``self.emotionSimplifier``
    to a mapping from detailed emotion labels onto the network's
    ``positive`` / ``negative`` / ``neutral`` values.

    Raises:
        ValueError: if a hand-written CPT row has the wrong number of entries.
    """
    emotions = ["positive", "negative", "neutral"]
    offers = [str(i) for i in range(1, 10)]  # "1" .. "9"

    def flatten(rows, outcomes):
        """Turn ``{parents: [p per outcome]}`` into ``{(*parents, outcome): p}``."""
        table = {}
        for parents, probs in rows.items():
            if len(probs) != len(outcomes):
                raise ValueError(
                    f"CPT row {parents!r} has {len(probs)} entries, "
                    f"expected {len(outcomes)}")
            for outcome, p in zip(outcomes, probs):
                table[(*parents, outcome)] = p
        return table

    def by_offer(per_emotion):
        """Expand ``{emotion: [row for offer "1".."9"]}`` to ``{(emotion, offer): row}``."""
        rows = {}
        for emotion, offer_rows in per_emotion.items():
            if len(offer_rows) != len(offers):
                raise ValueError(
                    f"{emotion!r} has {len(offer_rows)} offer rows, "
                    f"expected {len(offers)}")
            for offer, row in zip(offers, offer_rows):
                rows[(emotion, offer)] = row
        return rows

    bn = hh.BayesNet(
        (["player_emotion_face", "player_emotion_voice"], "player_emotion"),
        (["player_offer", "player_emotion"], "robot_decision"),
        (["player_offer", "player_emotion"], "robot_offer"))

    # P(player_emotion_face). The voice channel uses the same prior, but gets
    # its own copy: prepare() may sort/rename each CPT in place, and a shared
    # Series would end up labelled for whichever node it processed last.
    bn.P["player_emotion_face"] = pd.Series({
        "positive": 0.3, "negative": 0.3, "neutral": 0.4})
    bn.P["player_emotion_voice"] = bn.P["player_emotion_face"].copy()

    # P(player_emotion | player_emotion_face, player_emotion_voice).
    # Keys are (face, voice); columns follow `emotions`.
    bn.P["player_emotion"] = pd.Series(flatten({
        ("positive", "positive"): [1, 0, 0],
        ("positive", "negative"): [0.5, 0.5, 0],
        ("positive", "neutral"): [0.5, 0.0, 0.5],
        ("negative", "positive"): [0.5, 0.5, 0],
        ("negative", "negative"): [0, 1, 0],
        ("negative", "neutral"): [0, 0.5, 0.5],
        ("neutral", "positive"): [0.5, 0, 0.5],
        ("neutral", "negative"): [0, 0.5, 0.5],
        ("neutral", "neutral"): [0, 0, 1],
    }, emotions))

    # P(player_offer)
    bn.P["player_offer"] = pd.Series(dict(zip(offers, [
        0.03, 0.07, 0.1, 0.2, 0.2, 0.2, 0.1, 0.07, 0.03])))

    # NOTE(review): the robot_* tables are keyed (player_emotion, player_offer,
    # outcome), although the edges list player_offer first. This relies on
    # BayesNet ordering parents alphabetically — confirm.

    # P(robot_decision | player_emotion, player_offer); each pair is (yes, no)
    # and the list position is the player offer "1".."9".
    decision_rows = {
        "positive": [(0, 1), (0, 1), (0, 1), (0.2, 0.8), (0.4, 0.6),
                     (1, 0), (1, 0), (1, 0), (1, 0)],
        "negative": [(0.4, 0.6), (0.5, 0.5), (0.6, 0.4), (0.7, 0.3), (0.8, 0.2),
                     (0.9, 0.1), (0.9, 0.1), (1, 0), (1, 0)],
        "neutral": [(0, 1), (0, 1), (0.1, 0.9), (0.3, 0.7), (0.5, 0.5),
                    (0.7, 0.3), (0.9, 0.1), (1, 0), (1, 0)],
    }
    bn.P["robot_decision"] = pd.Series(
        flatten(by_offer(decision_rows), ["yes", "no"]))

    # P(robot_offer | player_emotion, player_offer). Within each emotion, row i
    # is player offer str(i + 1); columns are robot offers "1".."9".
    neutral_row = [0.03, 0.07, 0.1, 0.2, 0.2, 0.2, 0.1, 0.07, 0.03]
    offer_rows = {
        "positive": [
            [0.95, 0.05, 0, 0, 0, 0, 0, 0, 0],
            [0.9, 0.1, 0, 0, 0, 0, 0, 0, 0],
            [0.8, 0.15, 0.05, 0, 0, 0, 0, 0, 0],
            [0.6, 0.2, 0.1, 0.1, 0, 0, 0, 0, 0],
            [0.3, 0.3, 0.2, 0.1, 0.1, 0, 0, 0, 0],
            [0.1, 0.2, 0.2, 0.2, 0.2, 0.1, 0, 0, 0],
            [0.1, 0.1, 0.2, 0.2, 0.2, 0.1, 0.1, 0, 0],
            [0.1, 0.1, 0.1, 0.1, 0.1, 0.2, 0.2, 0.1, 0],
            [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.2, 0.1],
        ],
        "negative": [
            [0, 0, 0, 0, 0.05, 0.15, 0.2, 0.3, 0.3],
            [0, 0, 0, 0.05, 0.15, 0.2, 0.2, 0.2, 0.2],
            [0, 0, 0.05, 0.15, 0.1, 0.2, 0.2, 0.2, 0.2],  # sums to 1.1
            [0, 0, 0, 0.05, 0.15, 0.2, 0.2, 0.2, 0.2],
            [0, 0, 0, 0, 0.05, 0.15, 0.2, 0.3, 0.3],
            [0, 0, 0, 0, 0, 0.2, 0.2, 0.3, 0.3],
            [0, 0, 0, 0, 0, 0, 0.3, 0.4, 0.4],  # sums to 1.1
            [0, 0, 0, 0, 0, 0, 0.1, 0.4, 0.5],
            [0, 0, 0, 0, 0, 0, 0, 0.5, 0.5],
        ],
        # Every neutral row is the same distribution; the ("9", "9") entry had
        # been mistyped as 0.3, which made that row sum to 1.27.
        "neutral": [neutral_row] * len(offers),
    }
    robot_offer = pd.Series(flatten(by_offer(offer_rows), offers))
    # Rescale each (player_emotion, player_offer) row so that it sums to 1.
    # Otherwise the two 1.1-mass rows above inflate their parent
    # configurations whenever robot_offer is summed out during inference.
    # NOTE(review): confirm the intended values for those two rows.
    row_mass = robot_offer.groupby(level=[0, 1]).transform("sum")
    bn.P["robot_offer"] = robot_offer / row_mass

    bn.prepare()
    self.bn = bn
    # Maps detailed emotion labels onto the network's three emotion values.
    self.emotionSimplifier = {
        'happy': 'positive',
        'surprised': 'neutral',
        'calm': 'neutral',
        'disgusted': 'negative',
        'sad': 'negative',
        'fearful': 'negative',
        'angry': 'negative'
    }
def load_asia() -> hh.BayesNet:
    """Return the Asia (chest clinic) network.

    Example:

        >>> import hedgehog as hh

        >>> bn = hh.load_asia()

        >>> bn.query('Lung cancer', event={'Visit to Asia': True, 'Smoker': False})
        Lung cancer
        False    0.99
        True     0.01
        Name: P(Lung cancer), dtype: float64

    """
    structure = [
        ('Visit to Asia', 'Tuberculosis'),
        ('Smoker', ('Lung cancer', 'Bronchitis')),
        (('Tuberculosis', 'Lung cancer'), 'TB or cancer'),
        ('TB or cancer', ('Positive X-ray', 'Dispnea')),
        ('Bronchitis', 'Dispnea'),
    ]
    bn = hh.BayesNet(*structure)

    # Each key lists the parent value(s) first and the node's own value last.
    tables = {
        # P(Visit to Asia)
        'Visit to Asia': {True: .01, False: .99},
        # P(Tuberculosis | Visit to Asia)
        'Tuberculosis': {
            (True, True): .05, (True, False): .95,
            (False, True): .01, (False, False): .99,
        },
        # P(Smoker)
        'Smoker': {True: .5, False: .5},
        # P(Lung cancer | Smoker)
        'Lung cancer': {
            (True, True): .1, (True, False): .9,
            (False, True): .01, (False, False): .99,
        },
        # P(Bronchitis | Smoker)
        'Bronchitis': {
            (True, True): .6, (True, False): .4,
            (False, True): .3, (False, False): .7,
        },
        # P(TB or cancer | Tuberculosis, Lung cancer): a logical OR, so the
        # parent order does not matter here.
        'TB or cancer': {
            (True, True, True): 1, (True, True, False): 0,
            (True, False, True): 1, (True, False, False): 0,
            (False, True, True): 1, (False, True, False): 0,
            (False, False, True): 0, (False, False, False): 1,
        },
        # P(Positive X-ray | TB or cancer)
        'Positive X-ray': {
            (True, True): .98, (True, False): .02,
            (False, True): .05, (False, False): .95,
        },
        # P(Dispnea | TB or cancer, Bronchitis)
        # NOTE(review): the parents are written here as (TB or cancer,
        # Bronchitis), which is not alphabetical. If BayesNet orders parents
        # alphabetically, the .7 and .8 rows are attached to the swapped
        # parent combination — confirm.
        'Dispnea': {
            (True, True, True): .9, (True, True, False): .1,
            (True, False, True): .7, (True, False, False): .3,
            (False, True, True): .8, (False, True, False): .2,
            (False, False, True): .1, (False, False, False): .9,
        },
    }
    for node, table in tables.items():
        bn.P[node] = pd.Series(table)

    bn.prepare()
    return bn
def load_alarm() -> hh.BayesNet:
    """Return Judea Pearl's burglary / earthquake alarm network.

    Pearl lived in California, where earthquakes are quite common, when he
    wrote his seminal paper on Bayesian networks. That is why the alarm here
    can be set off by either a burglary or an earthquake.

    Example:

        >>> import hedgehog as hh

        >>> bn = hh.load_alarm()

        >>> bn.query('John calls', 'Mary calls', event={'Burglary': True, 'Earthquake': False})
        John calls  Mary calls
        False       False         0.08463
                    True          0.06637
        True        False         0.25677
                    True          0.59223
        Name: P(John calls, Mary calls), dtype: float64

    """
    bn = hh.BayesNet(
        ('Burglary', 'Alarm'),
        ('Earthquake', 'Alarm'),
        ('Alarm', 'John calls'),
        ('Alarm', 'Mary calls'),
    )

    # Each key lists the parent value(s) first and the node's own value last.
    tables = {
        # P(Burglary)
        'Burglary': {False: .999, True: .001},
        # P(Earthquake)
        'Earthquake': {False: .998, True: .002},
        # P(Alarm | Burglary, Earthquake)
        'Alarm': {
            (True, True, True): .95, (True, True, False): .05,
            (True, False, True): .94, (True, False, False): .06,
            (False, True, True): .29, (False, True, False): .71,
            (False, False, True): .001, (False, False, False): .999,
        },
        # P(John calls | Alarm)
        'John calls': {
            (True, True): .9, (True, False): .1,
            (False, True): .05, (False, False): .95,
        },
        # P(Mary calls | Alarm)
        'Mary calls': {
            (True, True): .7, (True, False): .3,
            (False, True): .01, (False, False): .99,
        },
    }
    for node, table in tables.items():
        bn.P[node] = pd.Series(table)

    bn.prepare()
    return bn
def load_grades() -> hh.BayesNet:
    """Load the student grades network.

    A student's grade depends on the course difficulty and the student's
    intelligence. Intelligence also drives the SAT result, and the grade
    drives the recommendation letter.

    Example:

        >>> import hedgehog as hh

        >>> bn = hh.load_grades()

        >>> bn.nodes
        ['Difficulty', 'Grade', 'Intelligence', 'Letter', 'SAT']

        >>> bn.children
        {'Difficulty': ['Grade'], 'Intelligence': ['Grade', 'SAT'], 'Grade': ['Letter']}

        >>> bn.parents
        {'Grade': ['Difficulty', 'Intelligence'], 'SAT': ['Intelligence'], 'Letter': ['Grade']}

        >>> bn.query('Letter', 'SAT', event={'Intelligence': 'Smart'})
        Letter  SAT
        Strong  Failure    0.153544
                Success    0.614176
        Weak    Failure    0.046456
                Success    0.185824
        Name: P(Letter, SAT), dtype: float64

    """
    bn = hh.BayesNet(('Difficulty', 'Grade'), ('Intelligence', 'Grade'),
                     ('Intelligence', 'SAT'), ('Grade', 'Letter'))

    # P(Difficulty)
    bn.P['Difficulty'] = pd.Series({'Easy': .6, 'Hard': .4})

    # P(Intelligence)
    bn.P['Intelligence'] = pd.Series({'Average': .7, 'Smart': .3})

    # P(Grade | Difficulty, Intelligence)
    bn.P['Grade'] = pd.Series({
        ('Easy', 'Average', 'A'): .3,
        ('Easy', 'Average', 'B'): .4,
        ('Easy', 'Average', 'C'): .3,
        ('Easy', 'Smart', 'A'): .9,
        ('Easy', 'Smart', 'B'): .08,
        ('Easy', 'Smart', 'C'): .02,
        ('Hard', 'Average', 'A'): .05,
        ('Hard', 'Average', 'B'): .25,
        ('Hard', 'Average', 'C'): .7,
        ('Hard', 'Smart', 'A'): .5,
        ('Hard', 'Smart', 'B'): .3,
        ('Hard', 'Smart', 'C'): .2
    })

    # P(SAT | Intelligence)
    bn.P['SAT'] = pd.Series({
        ('Average', 'Failure'): .95,
        ('Average', 'Success'): .05,
        ('Smart', 'Failure'): .2,
        ('Smart', 'Success'): .8
    })

    # P(Letter | Grade)
    bn.P['Letter'] = pd.Series({
        ('A', 'Weak'): .1,
        ('A', 'Strong'): .9,
        ('B', 'Weak'): .4,
        ('B', 'Strong'): .6,
        ('C', 'Weak'): .99,
        ('C', 'Strong'): .01
    })

    bn.prepare()
    return bn