Beispiel #1
0
def test_sampling_inference_no_correction():
    """Run the sampler over a tower correction followed by a silent step.

    Step 0 adds a red-on-blue tower correction for b1/b2; step 1 adds a
    no-correction observation for b3/b2. Nothing is asserted: the test only
    checks that building, observing and querying both steps does not raise.
    """
    model = CorrectionPGMModel(
        inference_type=InferenceType.BayesianModelSampler)

    red_model = KDEColourModel('red')
    blue_model = KDEColourModel('blue')

    step = 0
    candidate_rules = rules.Rule.generate_red_on_blue_options('red', 'blue')

    # Step 0: the teacher corrects the b1/b2 placement.
    violated = model.extend_model(candidate_rules,
                                  red_model,
                                  blue_model, ['b1', 'b2'],
                                  step,
                                  correction_type=CorrectionType.TOWER)
    model.observe({
        'F(b1)': [1, 1, 1],
        'F(b2)': [0, 0, 0],
        f'corr_{step}': 1,
    })
    model.query(violated)

    # Step 1: b3/b2 placement with no correction from the teacher.
    step = 1
    violated = model.add_no_correction(['b3', 'b2'], step, candidate_rules)
    model.observe({
        'F(b3)': [0.5, 0.5, 0.5],
        'F(b2)': [0, 0, 0],
        f'corr_{step}': 0,
    })
    model.query(violated)
Beispiel #2
0
    def __init__(self,
                 world,
                 colour_models=None,
                 rule_beliefs=None,
                 domain_file='blocks-domain.pddl',
                 teacher=None,
                 threshold=0.7,
                 update_negative=True,
                 update_once=True,
                 colour_model_type='default',
                 model_config=None,
                 tracker=None,
                 debug=None,
                 simplified_colour_count=False,
                 inference_type=InferenceType.SearchInference,
                 max_inference_size=-1,
                 max_beam_size=-1,
                 p_direct=1,
                 p_indirect=1):
        """Create a correcting agent backed by a ``CorrectionPGMModel``.

        ``world`` through ``tracker`` are forwarded positionally to the base
        class constructor.

        Args:
            model_config: Forwarded to the base class. ``None`` (the default)
                is replaced by a fresh empty dict for this instance.
            tracker: Forwarded to the base class. ``None`` (the default)
                creates a new ``Tracker()`` for this instance.
            debug: Optional mapping of debug flags merged into
                ``self.debug``; flags that are not given read as ``False``.
            simplified_colour_count: Stored on the agent unchanged.
            inference_type, max_inference_size, max_beam_size: Passed to
                ``CorrectionPGMModel``.
            p_direct, p_indirect: Passed to ``CorrectionPGMModel`` and also
                stored on the agent.
        """
        # Build these defaults per instance. A ``{}`` or ``Tracker()`` written
        # in the signature is evaluated once, when the function is defined,
        # and would be shared by every agent created without the argument.
        if model_config is None:
            model_config = {}
        if tracker is None:
            tracker = Tracker()

        super(PGMCorrectingAgent,
              self).__init__(world, colour_models, rule_beliefs, domain_file,
                             teacher, threshold, update_negative, update_once,
                             colour_model_type, model_config, tracker)

        self.debug = defaultdict(lambda: False)
        if debug is not None:
            self.debug.update(debug)

        self.pgm_model = CorrectionPGMModel(
            inference_type=inference_type,
            max_inference_size=max_inference_size,
            max_beam_size=max_beam_size,
            p_direct=p_direct,
            p_indirect=p_indirect)
        self.time = 0
        self.last_correction = -1
        self.marks = defaultdict(list)
        self.previous_corrections = []
        self.previous_args = {}
        self.simplified_colour_count = simplified_colour_count
        self.inference_times = []
        self.p_direct = p_direct
        self.p_indirect = p_indirect
Beispiel #3
0
def test_sampling_inference_table_likelihoddweighted():
    """Likelihood-weighted sampling on a table correction over b1, b2, b3.

    Before any colour is observed both candidate violations should be close
    to 0.5. After ``red(b1)`` is observed the first violation must be
    impossible and the second certain.
    """
    pgm_model = CorrectionPGMModel(
        inference_type=InferenceType.BayesianModelSampler,
        sampling_type=SamplingType.LikelihoodWeighted)

    red_cm = KDEColourModel('red')
    blue_cm = KDEColourModel('blue')

    time = 0
    red_on_blue_rules = rules.Rule.generate_red_on_blue_options('red', 'blue')

    violations = pgm_model.extend_model(red_on_blue_rules,
                                        red_cm,
                                        blue_cm, ['b1', 'b2', 'b3'],
                                        time,
                                        correction_type=CorrectionType.TABLE)

    pgm_model.observe({
        'F(b1)': [1, 1, 1],
        'F(b2)': [0, 0, 0],
        'F(b3)': [0.5, 0.5, 0.5],
        f'corr_{time}': 1
    })

    q = pgm_model.query(violations)

    # Sampling is approximate, so allow a tolerance on BOTH sides of 0.5.
    # The former one-sided check ``q - 0.5 < 0.2`` accepted any value below
    # 0.7, including 0.0, so it could never detect a too-low estimate.
    assert abs(q[violations[0]] - 0.5) < 0.2
    assert abs(q[violations[1]] - 0.5) < 0.2

    pgm_model.observe({'red(b1)': 1})

    q = pgm_model.query(violations)

    assert q[violations[0]] == 0.0
    assert q[violations[1]] == 1.0
Beispiel #4
0
def test_belief_inference2():
    """Belief propagation on a table correction over b1, b2 and b3.

    With no colour evidence both candidate violations sit at exactly 0.5.
    Observing ``red(b1)`` rules out the first and confirms the second.
    """
    model = CorrectionPGMModel(
        inference_type=InferenceType.BeliefPropagation)

    red_model = KDEColourModel('red')
    blue_model = KDEColourModel('blue')

    step = 0
    candidate_rules = rules.Rule.generate_red_on_blue_options('red', 'blue')

    violated = model.extend_model(candidate_rules,
                                  red_model,
                                  blue_model, ['b1', 'b2', 'b3'],
                                  step,
                                  correction_type=CorrectionType.TABLE)

    model.observe({
        'F(b1)': [1, 1, 1],
        'F(b2)': [0, 0, 0],
        'F(b3)': [0.5, 0.5, 0.5],
        f'corr_{step}': 1,
    })

    posterior = model.query(violated)
    first, second = violated[0], violated[1]
    assert posterior[first] == 0.5
    assert posterior[second] == 0.5

    model.observe({'red(b1)': 1})
    posterior = model.query(violated)
    assert posterior[first] == 0.0
    assert posterior[second] == 1.0
Beispiel #5
0
def test_extend_model():
    """A single tower correction on b1/b2 leaves both violations at 0.5."""
    model = CorrectionPGMModel()

    red_model = KDEColourModel('red')
    blue_model = KDEColourModel('blue')
    step = 0
    candidate_rules = rules.Rule.generate_red_on_blue_options('red', 'blue')

    violated = model.extend_model(candidate_rules,
                                  red_model,
                                  blue_model, ['b1', 'b2'],
                                  step,
                                  correction_type=CorrectionType.TOWER)

    evidence = {
        'F(b1)': [1, 1, 1],
        'F(b2)': [0, 0, 0],
        f'corr_{step}': 1,
    }
    model.observe(evidence)

    posterior = model.query(violated, [1, 1])
    for name in (violated[0], violated[1]):
        assert posterior[name] == 0.5
Beispiel #6
0
def test_gibbs_inference_connected_features2():
    """Sampling over two red/blue tower corrections linked by 'same reason'.

    The second correction (b3/b4) is tied to the first (b1/b2) with
    ``add_same_reason``. Its violations start near 0.5, and observing
    ``red(b3)`` makes the first violation certain.
    """
    model = CorrectionPGMModel(
        inference_type=InferenceType.BayesianModelSampler)

    red_model = KDEColourModel('red')
    blue_model = KDEColourModel('blue')

    step = 0
    candidate_rules = rules.Rule.generate_red_on_blue_options('red', 'blue')

    earlier = model.extend_model(candidate_rules,
                                 red_model,
                                 blue_model, ['b1', 'b2'],
                                 step,
                                 correction_type=CorrectionType.TOWER)
    model.observe({
        'F(b1)': [1, 1, 1],
        'F(b2)': [0, 0, 0],
        f'corr_{step}': 1,
    })

    step = 1
    later = model.extend_model(candidate_rules,
                               red_model,
                               blue_model, ['b3', 'b4'],
                               step,
                               correction_type=CorrectionType.TOWER)
    model.add_same_reason(earlier, later)
    model.test_models()

    model.observe({
        'F(b3)': [1, 1, 1],
        'F(b4)': [0, 0, 0],
        f'corr_{step}': 1,
    })
    posterior = model.query(later)
    model.test_models()
    for name in (later[0], later[1]):
        assert abs(posterior[name] - 0.5) < 0.2

    model.observe({'red(b3)': 1})
    posterior = model.query(later)
    model.test_models()
    assert posterior[later[0]] == 1.0
    assert posterior[later[1]] == 0.0
Beispiel #7
0
def test_gibbs_inference_connected_features():
    """Sampling over a red/blue and a green/orange correction sharing b2.

    The red/blue correction covers b1/b2 and the green/orange one covers
    b2/b4. The latter's violations start near 0.5, and observing
    ``green(b2)`` makes the first of them certain.
    """
    model = CorrectionPGMModel(
        inference_type=InferenceType.BayesianModelSampler)

    red_model = KDEColourModel('red')
    blue_model = KDEColourModel('blue')

    step = 0
    red_blue_rules = rules.Rule.generate_red_on_blue_options('red', 'blue')
    model.extend_model(red_blue_rules,
                       red_model,
                       blue_model, ['b1', 'b2'],
                       step,
                       correction_type=CorrectionType.TOWER)
    model.observe({
        'F(b1)': [1, 1, 1],
        'F(b2)': [0, 0, 0],
        f'corr_{step}': 1,
    })

    green_orange_rules = rules.Rule.generate_red_on_blue_options(
        'green', 'orange')
    green_model = KDEColourModel('green')
    orange_model = KDEColourModel('orange')

    step = 1
    violated = model.extend_model(green_orange_rules,
                                  green_model,
                                  orange_model, ['b2', 'b4'],
                                  step,
                                  correction_type=CorrectionType.TOWER)
    model.test_models()

    model.observe({
        'F(b2)': [1, 1, 1],
        'F(b4)': [0, 0, 0],
        f'corr_{step}': 1,
    })
    posterior = model.query(violated)
    model.test_models()
    for name in (violated[0], violated[1]):
        assert abs(posterior[name] - 0.5) < 0.2

    model.observe({'green(b2)': 1})
    posterior = model.query(violated)
    model.test_models()
    assert posterior[violated[0]] == 1.0
    assert posterior[violated[1]] == 0.0
Beispiel #8
0
def test_belief_inference_separated_models():
    """Belief propagation over two corrections that share no blocks.

    Observing ``green(b3)`` resolves the green/orange correction on b3/b4
    while the red/blue colour beliefs for b1/b2 stay at 0.5. Finally a
    no-correction step on b3/b5 is added and queried.
    """
    model = CorrectionPGMModel(
        inference_type=InferenceType.BeliefPropagation)

    red_model = KDEColourModel('red')
    blue_model = KDEColourModel('blue')

    step = 0
    red_blue_rules = rules.Rule.generate_red_on_blue_options('red', 'blue')
    model.extend_model(red_blue_rules,
                       red_model,
                       blue_model, ['b1', 'b2'],
                       step,
                       correction_type=CorrectionType.TOWER)
    model.observe({
        'F(b1)': [1, 1, 1],
        'F(b2)': [0, 0, 0],
        f'corr_{step}': 1,
    })

    green_orange_rules = rules.Rule.generate_red_on_blue_options(
        'green', 'orange')
    green_model = KDEColourModel('green')
    orange_model = KDEColourModel('orange')

    step = 1
    violated = model.extend_model(green_orange_rules,
                                  green_model,
                                  orange_model, ['b3', 'b4'],
                                  step,
                                  correction_type=CorrectionType.TOWER)
    model.test_models()

    model.observe({
        'F(b3)': [1, 1, 1],
        'F(b4)': [0, 0, 0],
        f'corr_{step}': 1,
    })
    posterior = model.query(violated)
    model.test_models()
    assert posterior[violated[0]] == 0.5
    assert posterior[violated[1]] == 0.5

    model.observe({'green(b3)': 1})
    posterior = model.query(violated)
    model.test_models()
    assert posterior[violated[0]] == 1.0
    assert posterior[violated[1]] == 0.0

    colours = ["red(b1)", "blue(b2)", "green(b3)", "orange(b4)"]
    posterior = model.query(colours)
    model.test_models()
    expected = (
        ("orange(b4)", 0),
        ("green(b3)", 1.0),
        ("red(b1)", 0.5),
        ("blue(b2)", 0.5),
    )
    for name, value in expected:
        assert posterior[name] == value

    violated = model.add_no_correction(['b3', 'b5'], 3, green_orange_rules)
    model.test_models()

    model.observe({
        "F(b3)": [1, 1, 1],
        "F(b5)": [0.2, 0.3, 0.4],
        "corr_3": 0,
    })
    model.test_models()

    model.query(violated)
Beispiel #9
0
class PGMCorrectingAgent(CorrectingAgent):
    def __init__(self,
                 world,
                 colour_models=None,
                 rule_beliefs=None,
                 domain_file='blocks-domain.pddl',
                 teacher=None,
                 threshold=0.7,
                 update_negative=True,
                 update_once=True,
                 colour_model_type='default',
                 model_config=None,
                 tracker=None,
                 debug=None,
                 simplified_colour_count=False,
                 inference_type=InferenceType.SearchInference,
                 max_inference_size=-1,
                 max_beam_size=-1,
                 p_direct=1,
                 p_indirect=1):
        """Create a correcting agent backed by a ``CorrectionPGMModel``.

        ``world`` through ``tracker`` are forwarded positionally to
        ``CorrectingAgent.__init__``.

        Args:
            model_config: Forwarded to the base class; ``add_cm`` later uses
                ``self.model_config`` as keyword arguments for new
                ``KDEColourModel`` instances. ``None`` (the default) is
                replaced by a fresh empty dict for this instance.
            tracker: Forwarded to the base class. ``None`` (the default)
                creates a new ``Tracker()`` for this instance.
            debug: Optional mapping of debug flags (this class reads
                ``'show_rules'``, ``'show_cm_update'`` and ``'evaluate_cms'``);
                flags that are not given read as ``False``.
            simplified_colour_count: When true, colour-count corrections only
                consider the top object of the tower instead of all of it.
            inference_type, max_inference_size, max_beam_size: Passed to
                ``CorrectionPGMModel``.
            p_direct, p_indirect: Passed to ``CorrectionPGMModel`` and also
                stored on the agent.
        """
        # Build these defaults per instance. A ``{}`` or ``Tracker()`` written
        # in the signature is evaluated once, when the class is defined, and
        # would be shared by every agent created without the argument.
        if model_config is None:
            model_config = {}
        if tracker is None:
            tracker = Tracker()

        super(PGMCorrectingAgent,
              self).__init__(world, colour_models, rule_beliefs, domain_file,
                             teacher, threshold, update_negative, update_once,
                             colour_model_type, model_config, tracker)

        self.debug = defaultdict(lambda: False)
        if debug is not None:
            self.debug.update(debug)

        self.pgm_model = CorrectionPGMModel(
            inference_type=inference_type,
            max_inference_size=max_inference_size,
            max_beam_size=max_beam_size,
            p_direct=p_direct,
            p_indirect=p_indirect)
        self.time = 0
        self.last_correction = -1
        # Rules attached to each object after a correction (see mark_block).
        self.marks = defaultdict(list)
        # (correction text, time) pairs and the action args per time step,
        # used to resolve 'same reason' and 'partial.neg' corrections.
        self.previous_corrections = []
        self.previous_args = {}
        self.simplified_colour_count = simplified_colour_count
        self.inference_times = []
        self.p_direct = p_direct
        self.p_indirect = p_indirect

    def __repr__(self):
        """Return the fixed display name used for this agent type."""
        label = "PGMCorrectingAgent"
        return label

    def __str__(self):
        """Return the same text as ``__repr__``."""
        text = self.__repr__()
        return text

    def update_goal(self):
        """Rebuild ``self.goal`` from the rules the PGM currently believes.

        Every rule whose probability from ``get_rule_probs`` exceeds 0.5 is
        parsed with ``Rule.from_string`` and its formula added to the goal.
        If a ``TypeError`` occurs (for example, comparing a non-numeric
        probability), the ordered models, their scopes and the probabilities
        are printed before the error is re-raised.
        """
        probabilities = self.pgm_model.get_rule_probs()
        formulas = []
        for rule_text, prob in probabilities.items():
            try:
                if not prob > 0.5:
                    continue
                parsed = Rule.from_string(rule_text)
                formulas.append(parsed.to_formula())
                if self.debug['show_rules']:
                    print(f'Added rule {parsed} to goal')
            except TypeError as err:
                print(self.pgm_model.ordered_models)
                for factor in self.pgm_model.ordered_models:
                    print(get_scope(factor))
                print(probabilities)
                raise err

        self.goal = goals.goal_from_list(formulas, self.domain_file)

    def no_correction(self, action, args):
        """Record that the teacher stayed silent after ``action(*args)``.

        'unstack' actions are ignored entirely; any other action advances
        ``self.time``. If either object in ``args`` was previously marked with
        red-on-blue rules, a no-correction factor over those rules is added
        and the colour features plus ``corr = 0`` are observed. The colour
        models are then refreshed, and an ``InferenceFailedError`` during
        that refresh is swallowed.
        """
        if action.lower() == 'unstack':
            return

        self.time += 1

        first, second = args[0], args[1]
        if first not in self.marks and second not in self.marks:
            return

        candidate_rules = [
            rule for rule in set(self.marks[first] + self.marks[second])
            if isinstance(rule, RedOnBlueRule)
        ]
        if not candidate_rules:
            return

        self.pgm_model.add_no_correction(args, self.time, candidate_rules)
        evidence = self.get_colour_data(args)
        evidence[corr_variable_name(self.time)] = 0
        self.inference_times.append(self.pgm_model.observe(evidence))

        try:
            self.update_cms()
        except InferenceFailedError:
            return

    def get_relevant_data(self, args, message):
        """Build the evidence dict for a tower/table correction.

        Contains the colour features of ``args`` (plus ``message.o3`` for a
        table correction) and ``corr_<time> = 1``. When ``args[1]`` names a
        tower (contains 't') and ``message.o2`` is set, the colour variable
        for ``message.o2[0]`` on that tower is observed as 0.
        """
        objects = args.copy()
        if message.T == 'table':
            objects.append(message.o3)

        evidence = self.get_colour_data(objects)
        evidence[corr_variable_name(self.time)] = 1

        if 't' in args[1] and message.o2 is not None:
            evidence[colour_variable_name(message.o2[0], args[1])] = 0
        return evidence

    def update_model(self, user_input, args):
        """Translate a teacher correction into PGM structure and evidence.

        Parses ``user_input`` with ``read_sentence`` and, depending on the
        message type, extends ``self.pgm_model`` and builds the observation
        dictionary for the current time step. The correction text that was
        actually used and ``args`` are then recorded in
        ``self.previous_corrections`` / ``self.previous_args``. Later
        'same reason' and 'partial.neg' corrections look these up.

        Args:
            user_input: Raw correction text from the teacher.
            args: Arguments of the corrected action; ``args[0]``/``args[1]``
                are the two objects, and ``args[-1]`` is used as the tower
                name by some branches.

        Returns:
            ``(violations, data, message)``: the violation variables added to
            the model, the evidence to observe, and the parsed message that
            was used. For 'same reason' this is the earlier correction it
            refers to.

        Raises:
            ValueError: If the message type is not handled, or a
                'same reason' / 'partial.neg' correction arrives with no
                earlier correction to refer to. Previously both cases ended
                in an ``UnboundLocalError`` (the former only after the
                history had already been appended to).
        """
        message = read_sentence(user_input)
        if message.T == 'recover':
            red = message.o1[0]
            blue = message.o2[0]
            rules = Rule.generate_red_on_blue_options(red, blue)

            red_cm = self.add_cm(red)
            blue_cm = self.add_cm(blue)

            violations = self.pgm_model.add_recovery(message.o3, self.time,
                                                     rules, red_cm, blue_cm)

            data = self.get_colour_data(message.o3)
            corr = corr_variable_name(self.time)
            data[corr] = 1

        elif message.T == 'same reason':
            if not self.previous_corrections:
                raise ValueError(
                    "Received a 'same reason' correction with no earlier "
                    "correction to refer to")
            # Walk back to the most recent correction that is not itself a
            # 'same reason'; if all of them are, the oldest one is used.
            for prev_corr, prev_time in self.previous_corrections[::-1]:
                if 'same reason' not in prev_corr:
                    break

            prev_message = read_sentence(prev_corr, use_dmrs=False)
            user_input = prev_corr
            violations = self.build_same_reason(prev_message, args, prev_time)
            data = self.get_relevant_data(args, prev_message)
            message = prev_message

        elif message.T in ['tower', 'table']:

            if isinstance(
                    self.teacher, FaultyTeacherAgent
            ) and message.T == 'table' and self.teacher.recover_prob > 0:
                red = message.o1[0]
                blue = message.o2[0]
                rules = Rule.generate_red_on_blue_options(red, blue)

                tower_name = args[-1] if 't' in args[-1] else None

                red_cm = self.add_cm(red)
                blue_cm = self.add_cm(blue)

                objects = self.world.state.get_objects_in_tower(tower_name)

                violations = self.pgm_model.create_uncertain_table_model(
                    rules, red_cm, blue_cm, args + [message.o3], objects,
                    self.time)

                data = self.get_colour_data(objects + [message.o3])
                corr = corr_variable_name(self.time)
                data[corr] = 1
            else:
                data = self.get_relevant_data(args, message)
                violations = self.build_pgm_model(message, args)
                corr = corr_variable_name(self.time)
                data[corr] = 1

        elif 'partial.neg' == message.T:
            if not self.previous_corrections:
                raise ValueError(
                    "Received a 'partial.neg' correction with no earlier "
                    "correction to refer to")
            # Most recent earlier correction mentioning the negated colour;
            # if none does, the oldest one is used.
            for prev_corr, prev_time in self.previous_corrections[::-1]:
                if message.o1 in prev_corr:
                    break

            user_input = prev_corr
            prev_message = read_sentence(prev_corr, use_dmrs=False)

            violations = self.build_pgm_model(prev_message, args)
            prev_args = self.previous_args[prev_time]

            # Negate the colour on the same argument position it occupied in
            # the earlier correction, for both the current and earlier action.
            if message.o1 == prev_message.o1[0]:
                curr_negation = f'{message.o1}({args[0]})'
                prev_negation = f'{message.o1}({prev_args[0]})'
            else:
                curr_negation = f'{message.o1}({args[1]})'
                prev_negation = f'{message.o1}({prev_args[1]})'

            data = self.get_relevant_data(args, prev_message)
            data[curr_negation] = 0
            data[prev_negation] = 0

        elif 'colour count' == message.T:
            colour_name = message.o1
            number = message.o2
            rule = ColourCountRule(colour_name, number)
            cm = self.add_cm(colour_name)
            tower_name = args[-1]
            objects = self.world.state.get_objects_in_tower(tower_name)
            top, _ = self.world.state.get_top_two(tower_name)
            if self.simplified_colour_count:
                objects = [top]

            violations = self.pgm_model.add_colour_count_correction(
                rule,
                cm,
                objects,
                self.time,
                faulty_teacher=isinstance(self.teacher, FaultyTeacherAgent))
            data = self.get_colour_data(objects)
            corr = f"corr_{self.time}"
            data[corr] = 1
            blue_top_obj = f"{colour_name}({top})"
            data[blue_top_obj] = 1

        elif 'colour count+tower' == message.T:
            colour_name = message.o1
            number = message.o2
            colour_count = ColourCountRule(colour_name, number)
            red_cm = self.add_cm(message.o3.o1[0])
            blue_cm = self.add_cm(message.o3.o2[0])
            red_on_blue = Rule.generate_red_on_blue_options(
                message.o3.o1[0], message.o3.o2[0])
            tower_name = args[-1]
            top, _ = self.world.state.get_top_two(tower_name)
            objects = self.world.state.get_objects_in_tower(tower_name)
            if self.simplified_colour_count:
                objects = [top]
            violations = self.pgm_model.add_cc_and_rob(colour_count,
                                                       red_on_blue, red_cm,
                                                       blue_cm, objects, top,
                                                       self.time)
            data = self.get_colour_data(objects)
            corr = f"corr_{self.time}"
            data[corr] = 1

        else:
            raise ValueError(f"Unsupported correction type: {message.T}")

        self.previous_corrections.append((user_input, self.time))
        self.previous_args[self.time] = args

        return violations, data, message

    def ask_question(self, message, args):
        """Ask the teacher whether the top object has the corrected colour.

        Only 'table' and 'tower' corrections trigger a question. The answer
        becomes evidence on the colour variables of ``args[0]`` (and, for a
        "sorry" answer, ``args[1]``). That evidence is observed in the PGM and
        the observe time is appended to ``self.inference_times``.

        Args:
            message: Parsed correction; ``o1``/``o2`` hold the two colour
                names, each either as a bare string or as the first element
                of a list.
            args: Action arguments; ``args[0]``/``args[1]`` are the two
                objects, and a third element, if present, is the tower name.
        """
        if message.T not in ['table', 'tower']:
            return

        # Normalise o1 and o2 independently. Previously a bare-string colour
        # was indexed with [0], yielding only its first character (e.g. 'r'
        # instead of 'red') in the question and in the blue variable name.
        red_c = message.o1[0] if isinstance(message.o1, list) else message.o1
        blue_c = message.o2[0] if isinstance(message.o2, list) else message.o2

        question = f'Is the top object {red_c}?'
        print(question)

        # Both variable names are now well formed; the bare-string branch used
        # to drop the closing parenthesis on ``red``.
        red = f'{red_c}({args[0]})'
        blue = f'{blue_c}({args[1]})'

        tower = args[-1] if len(args) == 3 else None
        answer = self.teacher.answer_question(question, self.world, tower,
                                              [red_c, blue_c])

        if "sorry" in answer:
            # A "sorry ... neither" answer sets both colour variables to 0,
            # any other "sorry" answer sets both to 1.
            value = 0 if "neither" in answer else 1
            data = {red: value, blue: value}
        else:
            data = {red: int(answer.lower() == 'yes')}

        time = self.pgm_model.observe(data)
        self.inference_times.append(time)

    def mark_block(self, most_likely_violation, message, args):
        """Attach the rules implied by a violation to the objects involved.

        The violation is turned back into a rule with
        ``Rule.from_violation``:

        - a ``[colour_count, red_on_blue]`` pair marks ``args[0]`` with all
          red-on-blue options for that colour pair;
        - a ``RedOnBlueRule`` marks one or two objects depending on its
          ``rule_type`` (1 or 2) and whether the correction was 'tower' or
          'table';
        - a ``ColourCountRule`` marks every object in tower ``args[-1]``.

        Raises:
            NotImplementedError: For any other rule type.
        """
        rule = Rule.from_violation(most_likely_violation)

        if isinstance(rule, list):
            _, red_on_blue = rule
            self.marks[args[0]] += Rule.generate_red_on_blue_options(
                red_on_blue.c1, red_on_blue.c2)
            return

        if isinstance(rule, RedOnBlueRule):
            options = Rule.generate_red_on_blue_options(rule.c1, rule.c2)
            if message.T == 'tower':
                if rule.rule_type == 1:
                    self.marks[args[0]] += options
                elif rule.rule_type == 2:
                    self.marks[args[1]] += options
            elif message.T == 'table':
                if rule.rule_type == 1:
                    self.marks[args[1]] += options
                    self.marks[message.o3] += options
                elif rule.rule_type == 2:
                    self.marks[args[0]] += options
                    self.marks[message.o3] += options
            return

        if isinstance(rule, ColourCountRule):
            for obj in self.world.state.get_objects_in_tower(args[-1]):
                self.marks[obj] += [rule]
            return

        raise NotImplementedError("Invalid or non implemented rule type")

    def get_correction(self,
                       user_input,
                       actions,
                       args,
                       test=False,
                       ask_question=None):
        """Process a teacher correction of the last action.

        Advances time, adds the correction to the PGM via ``update_model``,
        observes its evidence and queries the violation variables.
        ``actions`` and ``test`` are accepted but not used here.

        The correction is then handled as follows:

        - If ``ask_question`` is truthy, a question is asked for
          'table'/'tower' corrections.
        - Otherwise, if every violation is below ``1 - self.threshold``, the
          correction is ignored: no temporary goal, no marks, no backtrack.
        - Otherwise, if ``ask_question`` is None and no violation reaches
          ``self.threshold``, a question is asked for 'table'/'tower'
          corrections.

        Finally, the temporary goal, block marks, colour models and goal are
        updated, and the world backtracks unless the correction was ignored.
        An ``InferenceFailedError`` from any query backtracks the world and
        returns None immediately.
        """
        self.time += 1
        self.last_correction = self.time

        violations, data, message = self.update_model(user_input, args)
        pending_goal = None
        # Marking the block and backtracking always go together: both are
        # skipped only when the correction is judged implausible.
        keep_correction = True

        print(user_input)

        if "sorry" in user_input:
            on_xy = pddl_functions.create_formula('on', message.o3)
            self.tmp_goal = goals.update_goal(self.tmp_goal, on_xy)
        else:
            pending_goal = pddl_functions.create_formula('on', args[:2],
                                                         op='not')

        self.inference_times.append(self.pgm_model.observe(data))

        try:
            q = self.pgm_model.query(list(violations))
        except InferenceFailedError:
            self.world.back_track()
            return

        should_ask = False
        if ask_question:
            should_ask = True
        elif all(v < 1 - self.threshold for v in q.values()):
            print("ignoring correction")
            pending_goal = None
            keep_correction = False
        elif ask_question is None and max(q.values()) < self.threshold:
            should_ask = True

        if should_ask and message.T in ['table', 'tower']:
            self.ask_question(message, args)
            try:
                q = self.pgm_model.query(list(violations))
            except InferenceFailedError:
                self.world.back_track()
                return

        if pending_goal is not None:
            self.tmp_goal = goals.update_goal(self.tmp_goal, pending_goal)

        if keep_correction:
            self.mark_block(max(q, key=q.get), message, args)

        self.update_cms()
        self.update_goal()
        if keep_correction:
            self.world.back_track()

    def update_cms(self):
        """Retrain every colour model from the PGM's colour predictions.

        All colour models are reset first. Each predicted ``colour(obj)``
        with probability above 0.7 then feeds ``obj``'s features into that
        colour's model with weight ``p``. Objects whose name contains 't'
        (towers) are skipped. Debug flags control per-update printing and a
        final evaluation of every model.
        """
        predictions = self.pgm_model.get_colour_predictions()
        for model in self.colour_models.values():
            model.reset()

        for predicate, p in predictions.items():
            colour, objects = get_predicate(predicate)
            obj = objects[0]
            if 't' in obj:
                continue
            fx = self.get_colour_data([obj])[f'F({obj})']

            updated = p > 0.7
            if updated:
                self.colour_models[colour].update(fx, p)
            if not self.debug['show_cm_update']:
                continue
            if updated:
                print(
                    f'Updated {colour} model with: {fx} at probability {p}'
                )
            else:
                print(
                    f'Did not update {colour} model with: {fx} at probability {p}'
                )

        if self.debug['evaluate_cms']:
            for colour, model in self.colour_models.items():
                print(f'Evaluating {colour} model')
                evaluate_colour_model(model)

    def add_cm(self, colour_name):
        """Return the colour model for ``colour_name``, creating it if absent.

        New models are ``KDEColourModel(colour_name, **self.model_config)``
        and are stored in ``self.colour_models``.
        """
        try:
            return self.colour_models[colour_name]
        except KeyError:
            pass
        fresh_model = KDEColourModel(colour_name, **self.model_config)
        self.colour_models[colour_name] = fresh_model
        return fresh_model

    def build_same_reason(self, message, args, prev_time):
        """Add a correction that repeats the reason of the one at ``prev_time``.

        The new correction is built from ``message`` via ``build_pgm_model``.
        Its violations are then linked to the ``V_<prev_time>(rule)``
        variables of the same red-on-blue rule options.

        Returns:
            The violation variables of the new correction.
        """
        current = self.build_pgm_model(message, args)
        options = Rule.generate_red_on_blue_options(message.o1[0],
                                                    message.o2[0])
        earlier = [f'V_{prev_time}({str(rule)})' for rule in options]
        self.pgm_model.add_same_reason(current, earlier)
        return current

    def build_pgm_model(self, message, args):
        """Extend the PGM with a red-on-blue 'tower' or 'table' correction.

        With a ``FaultyTeacherAgent`` whose ``recover_prob`` is positive,
        table corrections use ``CorrectionType.UNCERTAIN_TABLE``, and
        ``table_empty`` reflects whether nothing is on the table. For table
        corrections the model arguments are ``args[:2] + [message.o3]``.

        Returns:
            The violation variables returned by ``extend_model``.

        Raises:
            ValueError: If ``message.T`` is neither 'tower' nor 'table'.
        """
        red, blue = message.o1[0], message.o2[0]
        candidate_rules = Rule.generate_red_on_blue_options(red, blue)
        red_cm = self.add_cm(red)
        blue_cm = self.add_cm(blue)

        unreliable_teacher = (isinstance(self.teacher, FaultyTeacherAgent)
                              and self.teacher.recover_prob > 0)
        if unreliable_teacher:
            table_empty = len(self.world.state.get_objects_on_table()) == 0
        else:
            table_empty = False

        if message.T == 'tower':
            correction_type = CorrectionType.TOWER
        elif message.T == 'table':
            correction_type = (CorrectionType.UNCERTAIN_TABLE
                               if unreliable_teacher else CorrectionType.TABLE)
            args = args[:2] + [message.o3]
        else:
            raise ValueError(
                f"Invalid message type, expected tower or table, not {message.T}"
            )

        return self.pgm_model.extend_model(candidate_rules,
                                           red_cm,
                                           blue_cm,
                                           args,
                                           self.time,
                                           correction_type=correction_type,
                                           table_empty=table_empty)

    def get_colour_data(self, args):
        """Map ``F(obj)`` to the world's colour features for each object.

        Names containing 't' (towers) are skipped.
        """
        colour_data = self.world.data
        features = {}
        for name in args:
            if 't' in name:
                continue
            features[f'F({name})'] = colour_data[name]
        return features

    def new_world(self, world):
        """Reset per-scenario state before switching to ``world``.

        Clears block marks and the correction history, restarts time, resets
        the PGM, fixes every colour model via ``cm.fix()`` and resets the
        teacher (if any), then defers to the base class.
        """
        self.marks = defaultdict(list)
        self.time = 0
        self.last_correction = -1
        # The correction history is keyed by time step, which restarts at 0.
        # Keeping it would let a 'same reason'/'partial.neg' correction in the
        # new world resolve to a correction from the old one, whose
        # ``previous_args`` entry may since have been overwritten.
        self.previous_corrections = []
        self.previous_args = {}

        self.pgm_model.reset()
        for cm in self.colour_models.values():
            cm.fix()

        # ``teacher`` defaults to None in __init__; don't crash in that case.
        if self.teacher is not None:
            self.teacher.reset()
        super(PGMCorrectingAgent, self).new_world(world)