def test_colour_count_CPD_generation2():
    """Colour-count correction over five blocks that all look the same.

    Every block is observed with the same feature vector, so the test
    expects the violation variable to be (almost) certain and each block's
    'blue' belief to equal 2 / 5.
    """
    step = 0
    tower_blocks = ['b1', 'b2', 'b3', 'b4', 'b5']

    model = CorrectionPGMModel()
    violated = model.add_colour_count_correction(
        ColourCountRule('blue', 1), KDEColourModel('blue'), tower_blocks,
        step)

    # The correction flag first, then one identical feature vector per block.
    observations = {f'corr_{step}': 1}
    observations.update({f'F({block})': [1, 1, 1] for block in tower_blocks})
    model.observe(observations)

    blue_variables = [f'blue({block})' for block in tower_blocks]
    posterior = model.query(blue_variables + violated)

    assert abs(posterior[violated[0]] - 1.0) < 0.001
    expected_share = 2 / len(tower_blocks)
    for variable in blue_variables:
        assert abs(posterior[variable] - expected_share) < 0.000001
def test_get_correction_colour_count():
    """A colour-count correction stores the top block's data and backs the rule.

    After three 'put' moves build tower1 (b0, b2, b5), the teacher's
    correction is passed to the agent.  The test then checks that b5's
    feature vector is among the 'blue' colour model's data and that the
    colour-count rule has probability above 0.5.
    """
    w = RandomColoursWorld('blocks-domain-updated.pddl',
                           problem_directory='multitower',
                           problem_number=1)
    teacher = ExtendedTeacherAgent()
    agent = PGMCorrectingAgent(w, teacher=teacher)

    w.update('put', ['b0', 't1', 'tower1'])
    w.update('put', ['b2', 'b0', 'tower1'])
    w.update('put', ['b5', 'b2', 'tower1'])

    rule = ColourCountRule('blue', 1)
    # Feature vector the test expects the world to report for b5.
    b5_data = [0.8805041887685969, 0.9819704146691624, 0.9679969438475052]

    correction = teacher.correction(w)
    agent.get_correction(correction, 'put', ['b5', 'b2', 'tower1'])

    # Generator form: stops at the first match and builds no temporary list.
    assert any(
        array_equal(datum, b5_data)
        for datum in agent.colour_models['blue'].data)

    rule_probs = agent.pgm_model.get_rule_probs()
    assert rule_probs[rule] > 0.5
def test_update_model():
    """update_model parses a colour-count correction and collects tower data.

    The test checks the returned violation names, the parsed message fields
    (type, colour, count), the feature vectors gathered for the blocks in
    tower1, and that the correction flag for time step 0 is set.
    """
    w = RandomColoursWorld('blocks-domain-updated.pddl',
                           problem_directory='multitower',
                           problem_number=1)
    w.update('put', ['b0', 't1', 'tower1'])
    w.update('put', ['b2', 'b0', 'tower1'])
    w.update('put', ['b5', 'b2', 'tower1'])

    # The agent is created only after the tower has been built.
    agent = PGMCorrectingAgent(w)
    rule = ColourCountRule('blue', 1)

    # Feature vectors the test expects for the blocks in tower1.
    b0_data = [0.34910434076796404, 0.9939351750475346, 0.9020967700201346]
    b2_data = [0.15738804617607374, 0.9835023958587732, 0.9541478602571746]
    b5_data = [0.8805041887685969, 0.9819704146691624, 0.9679969438475052]

    # Plain string: the sentence has no placeholders, so no f-prefix is needed.
    violations, data, message = agent.update_model(
        "no, you cannot put more than 1 blue blocks in a tower",
        ['b5', 'b2', 'tower1'])

    assert f'V_0({rule})' in violations
    assert message.T == 'colour count'
    assert message.o1 == 'blue'
    assert message.o2 == 1
    assert array_equal(data["F(b0)"], b0_data)
    assert array_equal(data["F(b5)"], b5_data)
    assert array_equal(data["F(b2)"], b2_data)
    assert data["corr_0"] == 1
def test_get_correction_colour_count2():
    """A combined colour-count + tower correction updates only the 'red' model with b1.

    The teacher is expected to emit the combined correction sentence.  After
    the agent processes it, the test checks that b1 is predicted red, that
    the 'red' colour model holds exactly one datum matching b1 (and not b0
    or b2), and that the colour-count rule has probability above 0.5.
    """
    w = RandomColoursWorld('blocks-domain-updated.pddl',
                           problem_directory='multitower',
                           problem_number=2)
    teacher = ExtendedTeacherAgent()
    agent = PGMCorrectingAgent(w, teacher=teacher)

    w.update('put', ['b0', 't1', 'tower1'])
    w.update('put', ['b2', 'b0', 'tower1'])
    w.update('put', ['b1', 'b2', 'tower1'])

    rule = ColourCountRule('blue', 1)
    # Feature vectors the test expects for the blocks in tower1.
    b0_data = [0.34910434076796404, 0.9939351750475346, 0.9020967700201346]
    b1_data = [0.8872649368468363, 0.9880529065134138, 0.9322439031222336]
    b2_data = [0.15738804617607374, 0.9835023958587732, 0.9541478602571746]

    correction = teacher.correction(w)
    assert correction == "no, you cannot put more than 1 blue blocks in a tower and you must put blue blocks on red blocks"

    agent.get_correction(correction, 'put', ['b1', 'b2', 'tower1'])

    colour_predictions = agent.pgm_model.get_colour_predictions()
    assert abs(colour_predictions['red(b1)'] - 1.0) < 0.0001
    assert array_equal(agent.get_colour_data(['b1'])['F(b1)'], b1_data)
    assert agent.colour_models['red'].data != []
    assert len(agent.colour_models['red'].data) == 1

    # any() replaces the hand-rolled reduce(lambda x, y: x or y, ...) and
    # always yields a plain bool, even for a single-element collection.
    red_data = agent.colour_models['red'].data
    test_b1_data_in_cm = any(
        array_equal(datum, np.array(b1_data)) for datum in red_data)
    test_b2_data_in_cm = any(
        array_equal(datum, b2_data) for datum in red_data)
    test_b0_data_in_cm = any(
        array_equal(datum, np.array(b0_data)) for datum in red_data)

    assert np.all(
        np.array(hsv2rgb([[b1_data]])[0][0]) ==
        agent.colour_models['red'].data[0])
    assert not test_b0_data_in_cm
    assert not test_b2_data_in_cm
    assert test_b1_data_in_cm

    rule_probs = agent.pgm_model.get_rule_probs()
    assert rule_probs[rule] > 0.5
def generate_colour_count(max_num=4, exact_num=None):
    """Build a ColourCountRule for a randomly chosen colour.

    The colour is drawn at random from the keys of ``fruit_dict``.  The
    count is ``exact_num`` when one is given; otherwise it is drawn at
    random from 1 to ``max_num`` inclusive.
    """
    chosen_colour = random.choice(list(fruit_dict.keys()))
    # The colour is drawn before the count, so the order of random draws
    # is fixed regardless of which branch supplies the count.
    count = (exact_num if exact_num is not None
             else random.choice(range(1, max_num + 1)))
    return ColourCountRule(chosen_colour, count)
def test_extended_blocks_world_problem():
    """An ExtendedBlocksWorldProblem carries the rule in its goal and starts flat.

    Checks the two goal subformulas, then that every block b0..b9 has its
    assigned colour and is both clear and on the table, and finally that
    both towers start with a zero count for every colour queried.
    """
    cc = ColourCountRule('blue', 2)
    colours = ['red', 'blue', 'green', 'yellow', 'pink', 'purple',
               'red', 'blue', 'blue', 'blue']
    problem = problem_def.ExtendedBlocksWorldProblem(
        num_blocks=10,
        num_towers=2,
        rules=[cc.to_formula()],
        colours=colours)

    assert problem.goal.subformulas[0].asPDDL() == "(forall (?x) (done ?x))"
    assert problem.goal.subformulas[1].asPDDL() == cc.asPDDL()

    state = PDDLState.from_problem(problem)

    for index, colour in enumerate(colours):
        # Same three checks per block: its colour, then clear, then on-table.
        for predicate in (colour, 'clear', 'on-table'):
            assert state.predicate_holds(predicate, [f"b{index}"])

    queried_colours = ['red', 'blue', 'green', 'yellow', 'pink', 'purple',
                       'orange']
    for colour in queried_colours:
        for tower in ('tower0', 'tower1'):
            assert state.get_colour_count(colour, tower) == 0
def add_joint_violations(self, colour_count: ColourCountRule,
                         red_on_blue_options: list, time: int,
                         colours_in_tower: list, model: FactorGraph):
    """Add one joint violation factor per red-on-blue option to ``model``.

    For each of the two options a variable named
    ``V_{time}({colour_count} && {option})`` is created.  Its CPD comes from
    ``colour_count.generateCPD`` with the TABLE correction type, and its
    evidence is the tower's colour variables followed by the option and the
    colour-count rule.

    Args:
        colour_count: the colour-count rule taking part in the violation.
        red_on_blue_options: exactly two candidate rules.
        time: time step used in the violation variable names.
        colours_in_tower: colour variables of the blocks in the tower.
        model: factor graph the new factors are added to.

    Returns:
        The two violation variable names, in option order.

    Raises:
        IndexError: if fewer than two options are supplied.
        ValueError: if more than two options are supplied.
    """
    p = self.get_p(CorrectionType.TABLE)
    # One CPD per option; generateCPD's arguments do not depend on the option.
    cpds = [
        colour_count.generateCPD(num_blocks_in_tower=len(colours_in_tower),
                                 correction_type=CorrectionType.TABLE,
                                 p=p)
        for _ in red_on_blue_options
    ]

    first_option = red_on_blue_options[0]
    second_option = red_on_blue_options[1]
    cpd1, cpd2 = cpds

    built = []
    # Each option is paired with its own CPD.  Previously the second factor
    # was built from cpd1 as well.
    for option, cpd in ((first_option, cpd1), (second_option, cpd2)):
        factor_name = f"V_{time}({colour_count} && {option})"
        evidence = colours_in_tower + [option, colour_count]
        factor = TabularCPD(factor_name, 2, cpd,
                            evidence=evidence,
                            evidence_card=[2] * len(evidence)).to_factor()
        built.append((factor_name, option, factor))

    # Both factors are built before either is added to the model.
    for factor_name, option, factor in built:
        # NOTE(review): the node list contains the factor object itself and
        # omits colours_in_tower; kept as in the original -- confirm against
        # add_factor's contract.
        self.add_factor([factor_name, factor, option, colour_count],
                        factor, model)

    return [factor_name for factor_name, _, _ in built]
def test_colour_count_CPD_generation3():
    """Joint colour-count + red-on-blue correction over a five-block tower.

    All ten blocks are observed with the same feature vector.  The test
    checks the posterior of both joint violation variables, the colour-count
    rule, the first red-on-blue option, the top block being red, and the
    'blue' belief of the remaining blocks in the tower.
    """
    pgm_model = CorrectionPGMModel()
    time = 0
    colour_count = ColourCountRule('blue', 1)
    red_on_blue_options = Rule.generate_red_on_blue_options('blue', 'red')
    blue_cm = KDEColourModel('blue')
    red_cm = KDEColourModel('red')

    objects = [f"b{o}" for o in range(10)]
    objects_in_tower = objects[:5]
    top_object = objects[4]

    violations = pgm_model.add_cc_and_rob(colour_count, red_on_blue_options,
                                          blue_cm, red_cm, objects_in_tower,
                                          top_object, time)

    data = {f'corr_{time}': 1}
    for obj in objects:
        data[f'F({obj})'] = [1, 1, 1]
    pgm_model.observe(data)

    colours_in_tower = [f'blue({obj})' for obj in objects_in_tower[:-1]
                        ] + [f'red({top_object})']
    evidence = (colours_in_tower + violations + red_on_blue_options +
                [colour_count])
    q = pgm_model.query(evidence)

    assert abs(q[violations[0]] - 2 / 3) < 0.001
    assert abs(q[violations[1]] - 2 / 3) < 0.0001
    assert abs(q[colour_count] - 1.0) < 0.00001
    assert abs(q[red_on_blue_options[0]] - 2 / 3) < 0.00001
    assert pgm_model.colours['red'] is not None
    assert abs(q[colours_in_tower[-1]] - 1.0) < 0.00001

    expected_blue = 1 / (len(colours_in_tower) - 1)
    for colour in colours_in_tower[:-1]:
        # The comparison now sits outside abs().  Previously abs() wrapped
        # the boolean result, so any value below the target passed.
        assert abs(q[colour] - expected_blue) < 0.0001
def test_colour_count_CPD_generation():
    """Colour-count correction over two blocks in a freshly added model.

    With the correction observed, the test expects the violation variable
    and both blocks' 'blue' variables to have a posterior of exactly 1.0.
    """
    model = CorrectionPGMModel()
    step = 0
    blue_rule = ColourCountRule('blue', 1)
    blue_cm = KDEColourModel('blue')
    graph = model.add_new_model()

    violated = model.add_colour_count_correction(blue_rule, blue_cm,
                                                 ['b1', 'b2'], step, graph)

    observations = {
        'F(b1)': [1, 1, 1],
        'F(b2)': [0, 0, 0],
        f'corr_{step}': 1,
    }
    model.observe(observations)

    posterior = model.query(['blue(b1)', 'blue(b2)', violated[0]])
    # Exact equality is intentional here, unlike the tolerance-based
    # checks used in other tests.
    for variable in (violated[0], "blue(b1)", "blue(b2)"):
        assert posterior[variable] == 1.0
def build_model(self, message, args):
    """Update colour and rule models from a single correction message.

    Colour models are reused where the colour matches, to keep the number
    of distinct colour names small.  This only works because there is a set
    number of colour terms.

    Args:
        message: parsed correction; its ``T`` attribute selects the branch
            ('tower', 'table', 'colour count' or 'colour count+tower').
        args: arguments of the corrected action.  ``args[1]`` is used by the
            'table' branch and ``args[-1]`` (the tower) by 'colour count'.

    Returns:
        None.  Results are stored on ``self.colour_models``,
        ``self.rule_models`` and ``self.active_tests``.
    """
    data = self.get_data(message, args)

    # Nothing is learned when there is no data for the second object.  This
    # guard also makes any later ``data['o2'] is None`` check unreachable,
    # which is why no branch below repeats it.
    if data['o2'] is None:
        print("o2 is bottom position")
        return

    # First try to find colour models that match the particular data points.
    try:
        c1, c1_model = self.find_matching_cm(data['o1'])
        c2, c2_model = self.find_matching_cm(data['o2'])
    except ValueError:
        print("too many colour models, stopped learning")
        return

    if message.T == 'tower':
        # Create the negated rule and register the two colour models.
        rule = correctingagent.world.rules.NotRedOnBlueRule(c1, c2).to_formula()
        self.colour_models.update({c1: c1_model, c2: c2_model})
        self.rule_models[f'not {c1} and {c2}'] = rule

    elif message.T == 'table':
        # NOTE(review): this lookup sits outside the ValueError guard above,
        # so a ValueError raised here propagates to the caller.
        c3, c3_model = self.find_matching_cm(data['o3'])
        rule1 = RedOnBlueRule(c3, c2, 1)
        rule2 = RedOnBlueRule(c1, c3, 2)
        self.colour_models.update({c1: c1_model, c2: c2_model, c3: c3_model})
        obs = self.world.sense()
        try:
            test = search.ActiveLearningTest(rule1.to_formula(),
                                             rule2.to_formula(), obs.colours,
                                             c3_model, args[1], self.world)
            self.active_tests.append(test)
        except TestFailed:
            # When the test cannot be set up, rule2 is stored directly.
            self.rule_models[rule2] = rule2.to_formula()

    elif message.T == 'colour count':
        # NOTE(review): the colour is hard-coded to 'blue' rather than taken
        # from the message -- confirm this is intended.
        tower = args[-1]
        try:
            cm = self.colour_models['blue']
        except KeyError:
            cm = prob_model.KDEColourModel('blue',
                                           data=np.array([data['o1']]),
                                           weights=np.array([1]),
                                           **self.model_config)
        blocks_in_tower = self.world.state.get_objects_in_tower(tower)
        # Count blocks the model considers blue, capped at 3.
        blue_hits = sum(
            1 for block in blocks_in_tower
            if cm.p(1, self.world.observe_object(block)) > 0.5)
        n = min(3, blue_hits)
        if n > 0:
            rule = ColourCountRule('blue', n)
            self.colour_models.update({'blue': cm})
            self.rule_models["colour_count"] = rule.to_formula()

    elif message.T == 'colour count+tower':
        # Only learn the stacking rule once a 'blue' model already exists.
        if 'blue' in self.colour_models:
            rule = RedOnBlueRule('blue', c1, 2)
            self.colour_models.update({c1: c1_model})
            self.rule_models[str(rule)] = rule.to_formula()
def update_model(self, user_input, args):
    """Extend the PGM with a new correction and gather the data to observe.

    The sentence is parsed with ``read_sentence`` and handled according to
    its type ``message.T``.  Each branch adds the matching structure to
    ``self.pgm_model`` and builds the observation dictionary.  The
    correction and its arguments are then recorded in
    ``self.previous_corrections`` and ``self.previous_args``.

    Args:
        user_input: the teacher's correction sentence.
        args: arguments of the corrected action; ``args[-1]`` is read as
            the tower name by the branches that need one.

    Returns:
        A tuple ``(violations, data, message)``.  For a 'same reason'
        correction, ``message`` is the parsed earlier correction it refers
        to.

    Raises:
        ValueError: if the correction type is not supported, or if a
            'same reason' / 'partial.neg' correction arrives with no earlier
            correction on record.  Both cases previously surfaced as an
            UnboundLocalError; the unsupported-type case did so only after
            the history had been mutated.
    """
    message = read_sentence(user_input)

    if message.T == 'recover':
        red = message.o1[0]
        blue = message.o2[0]
        rules = Rule.generate_red_on_blue_options(red, blue)
        red_cm = self.add_cm(red)
        blue_cm = self.add_cm(blue)
        violations = self.pgm_model.add_recovery(message.o3, self.time, rules,
                                                 red_cm, blue_cm)
        data = self.get_colour_data(message.o3)
        corr = corr_variable_name(self.time)
        data[corr] = 1

    elif message.T == 'same reason':
        if not self.previous_corrections:
            raise ValueError(
                "'same reason' correction received with no previous correction")
        # Walk back to the most recent correction that is not itself a
        # 'same reason'.  If every entry is one, the oldest entry is used.
        for prev_corr, prev_time in reversed(self.previous_corrections):
            if 'same reason' not in prev_corr:
                break
        prev_message = read_sentence(prev_corr, use_dmrs=False)
        user_input = prev_corr
        violations = self.build_same_reason(prev_message, args, prev_time)
        data = self.get_relevant_data(args, prev_message)
        message = prev_message

    elif message.T in ('tower', 'table'):
        uncertain_table = (isinstance(self.teacher, FaultyTeacherAgent)
                           and message.T == 'table'
                           and self.teacher.recover_prob > 0)
        if uncertain_table:
            red = message.o1[0]
            blue = message.o2[0]
            rules = Rule.generate_red_on_blue_options(red, blue)
            tower_name = args[-1] if 't' in args[-1] else None
            red_cm = self.add_cm(red)
            blue_cm = self.add_cm(blue)
            objects = self.world.state.get_objects_in_tower(tower_name)
            violations = self.pgm_model.create_uncertain_table_model(
                rules, red_cm, blue_cm, args + [message.o3], objects,
                self.time)
            data = self.get_colour_data(objects + [message.o3])
        else:
            data = self.get_relevant_data(args, message)
            violations = self.build_pgm_model(message, args)
        corr = corr_variable_name(self.time)
        data[corr] = 1

    elif message.T == 'partial.neg':
        if not self.previous_corrections:
            raise ValueError(
                "'partial.neg' correction received with no previous correction")
        # Find the most recent correction mentioning the negated colour.
        # NOTE(review): if none mentions it, the oldest correction is used.
        for prev_corr, prev_time in reversed(self.previous_corrections):
            if message.o1 in prev_corr:
                break
        user_input = prev_corr
        prev_message = read_sentence(prev_corr, use_dmrs=False)
        violations = self.build_pgm_model(prev_message, args)
        prev_args = self.previous_args[prev_time]
        # The negated colour applies to the first action argument when it
        # matches the earlier message's first colour, else to the second.
        position = 0 if message.o1 == prev_message.o1[0] else 1
        curr_negation = f'{message.o1}({args[position]})'
        prev_negation = f'{message.o1}({prev_args[position]})'
        data = self.get_relevant_data(args, prev_message)
        data[curr_negation] = 0
        data[prev_negation] = 0

    elif message.T == 'colour count':
        colour_name = message.o1
        number = message.o2
        rule = ColourCountRule(colour_name, number)
        cm = self.add_cm(colour_name)
        tower_name = args[-1]
        objects = self.world.state.get_objects_in_tower(tower_name)
        top, _ = self.world.state.get_top_two(tower_name)
        if self.simplified_colour_count:
            # Simplified mode considers only the top block.
            objects = [top]
        violations = self.pgm_model.add_colour_count_correction(
            rule, cm, objects, self.time,
            faulty_teacher=isinstance(self.teacher, FaultyTeacherAgent))
        data = self.get_colour_data(objects)
        corr = f"corr_{self.time}"
        data[corr] = 1
        # The top block is observed as having the counted colour.
        blue_top_obj = f"{colour_name}({top})"
        data[blue_top_obj] = 1

    elif message.T == 'colour count+tower':
        colour_name = message.o1
        number = message.o2
        colour_count = ColourCountRule(colour_name, number)
        red_cm = self.add_cm(message.o3.o1[0])
        blue_cm = self.add_cm(message.o3.o2[0])
        red_on_blue = Rule.generate_red_on_blue_options(
            message.o3.o1[0], message.o3.o2[0])
        tower_name = args[-1]
        top, _ = self.world.state.get_top_two(tower_name)
        objects = self.world.state.get_objects_in_tower(tower_name)
        if self.simplified_colour_count:
            objects = [top]
        violations = self.pgm_model.add_cc_and_rob(colour_count, red_on_blue,
                                                   red_cm, blue_cm, objects,
                                                   top, self.time)
        data = self.get_colour_data(objects)
        corr = f"corr_{self.time}"
        data[corr] = 1

    else:
        # Fail before touching the history, so an unsupported correction
        # leaves previous_corrections and previous_args unchanged.
        raise ValueError(f"unsupported correction type: {message.T!r}")

    self.previous_corrections.append((user_input, self.time))
    self.previous_args[self.time] = args
    return violations, data, message