Esempio n. 1
0
def test_colour_count_CPD_generation2():
    """Colour-count correction over five blocks with identical observations.

    Every block is observed with the same feature vector ``[1, 1, 1]``, so
    the violation variable is expected to be (near) certain and every
    ``blue(b*)`` variable is expected to sit at ``2 / len(objects)``.
    """
    pgm_model = CorrectionPGMModel()
    time = 0
    rule = ColourCountRule('blue', 1)
    cm = KDEColourModel('blue')
    objects = ['b1', 'b2', 'b3', 'b4', 'b5']

    violations = pgm_model.add_colour_count_correction(rule, cm, objects, time)

    observations = {f'corr_{time}': 1}
    observations.update({f'F({obj})': [1, 1, 1] for obj in objects})
    pgm_model.observe(observations)

    colours = [f'blue({obj})' for obj in objects]
    q = pgm_model.query(colours + violations)

    assert abs(q[violations[0]] - 1.0) < 0.001

    expected_blue = 2 / len(objects)
    for colour in colours:
        assert abs(q[colour] - expected_blue) < 0.000001
Esempio n. 2
0
def test_get_correction_colour_count():
    """The agent learns from the teacher's correction after putting b5 on b2.

    After the correction, b5's feature vector must be stored in the 'blue'
    colour model and the ``ColourCountRule('blue', 1)`` rule must have
    probability above 0.5.
    """
    w = RandomColoursWorld('blocks-domain-updated.pddl',
                           problem_directory='multitower',
                           problem_number=1)

    teacher = ExtendedTeacherAgent()
    agent = PGMCorrectingAgent(w, teacher=teacher)

    w.update('put', ['b0', 't1', 'tower1'])
    w.update('put', ['b2', 'b0', 'tower1'])
    w.update('put', ['b5', 'b2', 'tower1'])

    rule = ColourCountRule('blue', 1)

    # Feature vector the world is expected to report for b5.
    b5_data = [0.8805041887685969, 0.9819704146691624, 0.9679969438475052]

    correction = teacher.correction(w)

    agent.get_correction(correction, 'put', ['b5', 'b2', 'tower1'])

    assert any(array_equal(datum, b5_data)
               for datum in agent.colour_models['blue'].data)
    rule_probs = agent.pgm_model.get_rule_probs()
    assert rule_probs[rule] > 0.5
Esempio n. 3
0
def test_update_model():
    """update_model turns a colour-count correction into violations and data.

    Checks that the returned violations contain ``V_0(<rule>)``, that the
    parsed message carries the colour-count fields, and that the returned
    observation dict holds the blocks' feature vectors plus ``corr_0 == 1``.
    """
    w = RandomColoursWorld('blocks-domain-updated.pddl',
                           problem_directory='multitower',
                           problem_number=1)

    w.update('put', ['b0', 't1', 'tower1'])
    w.update('put', ['b2', 'b0', 'tower1'])
    w.update('put', ['b5', 'b2', 'tower1'])

    agent = PGMCorrectingAgent(w)

    rule = ColourCountRule('blue', 1)

    # Feature vectors the world is expected to report for the tower's blocks.
    b0_data = [0.34910434076796404, 0.9939351750475346, 0.9020967700201346]
    b2_data = [0.15738804617607374, 0.9835023958587732, 0.9541478602571746]
    b5_data = [0.8805041887685969, 0.9819704146691624, 0.9679969438475052]

    violations, data, message = agent.update_model(
        "no, you cannot put more than 1 blue blocks in a tower",
        ['b5', 'b2', 'tower1'])

    assert f'V_0({rule})' in violations

    assert message.T == 'colour count'
    assert message.o1 == 'blue'
    assert message.o2 == 1

    assert array_equal(data["F(b0)"], b0_data)
    assert array_equal(data["F(b5)"], b5_data)
    assert array_equal(data["F(b2)"], b2_data)
    assert data["corr_0"] == 1
Esempio n. 4
0
def test_get_correction_colour_count2():
    """A combined colour-count + red-on-blue correction teaches 'red' for b1.

    After putting b1 on b2, the teacher's correction must mention both the
    colour-count and the blue-on-red constraint. The agent must then predict
    ``red(b1)`` with certainty and store exactly one datum (b1's) in the
    'red' colour model. ``ColourCountRule('blue', 1)`` must have probability
    above 0.5.
    """
    w = RandomColoursWorld('blocks-domain-updated.pddl',
                           problem_directory='multitower',
                           problem_number=2)

    teacher = ExtendedTeacherAgent()
    agent = PGMCorrectingAgent(w, teacher=teacher)

    w.update('put', ['b0', 't1', 'tower1'])
    w.update('put', ['b2', 'b0', 'tower1'])
    w.update('put', ['b1', 'b2', 'tower1'])

    rule = ColourCountRule('blue', 1)

    # Feature vectors the world is expected to report for these blocks.
    b0_data = [0.34910434076796404, 0.9939351750475346, 0.9020967700201346]
    b1_data = [0.8872649368468363, 0.9880529065134138, 0.9322439031222336]
    b2_data = [0.15738804617607374, 0.9835023958587732, 0.9541478602571746]

    correction = teacher.correction(w)
    assert (
        correction ==
        "no, you cannot put more than 1 blue blocks in a tower and you must put blue blocks on red blocks"
    )

    agent.get_correction(correction, 'put', ['b1', 'b2', 'tower1'])

    colour_predictions = agent.pgm_model.get_colour_predictions()
    assert abs(colour_predictions['red(b1)'] - 1.0) < 0.0001

    assert array_equal(agent.get_colour_data(['b1'])['F(b1)'], b1_data)

    red_data = agent.colour_models['red'].data
    # len() works whether the model stores a list or an ndarray; the former
    # `red_data != []` check was redundant with this and ambiguous for arrays.
    assert len(red_data) == 1

    # any() is well defined on an empty model, unlike reduce() without an
    # initial value, which raises TypeError.
    test_b1_data_in_cm = any(array_equal(datum, np.array(b1_data))
                             for datum in red_data)
    test_b2_data_in_cm = any(array_equal(datum, b2_data) for datum in red_data)
    test_b0_data_in_cm = any(array_equal(datum, np.array(b0_data))
                             for datum in red_data)

    # NOTE(review): this expects the stored datum to equal hsv2rgb(b1_data),
    # while test_b1_data_in_cm expects the raw b1_data, and only one datum is
    # stored. Both can only hold if the conversion leaves b1_data unchanged.
    # Confirm which representation the colour model is meant to store.
    assert np.all(np.array(hsv2rgb([[b1_data]])[0][0]) == red_data[0])
    assert not test_b0_data_in_cm
    assert not test_b2_data_in_cm
    assert test_b1_data_in_cm
    rule_probs = agent.pgm_model.get_rule_probs()
    assert rule_probs[rule] > 0.5
Esempio n. 5
0
def generate_colour_count(max_num=4, exact_num=None):
    """Return a randomly generated ``ColourCountRule``.

    The colour is drawn uniformly from the keys of ``fruit_dict``. The count
    is ``exact_num`` when it is given. Otherwise it is drawn uniformly from
    ``1..max_num`` inclusive.
    """
    colour = random.choice(list(fruit_dict))
    number = (exact_num if exact_num is not None
              else random.choice(range(1, max_num + 1)))
    return ColourCountRule(colour, number)
Esempio n. 6
0
def test_extended_blocks_world_problem():
    """Goal and initial state of a 10-block, 2-tower colour-count problem.

    The goal must be "every block done" AND the colour-count rule. Initially
    every block must have its assigned colour, be clear and be on the table,
    and both towers must contain no blocks of any listed colour.
    """
    cc = ColourCountRule('blue', 2)
    colours = ['red', 'blue', 'green', 'yellow', 'pink', 'purple', 'red', 'blue', 'blue', 'blue']
    problem = problem_def.ExtendedBlocksWorldProblem(
        num_blocks=10, num_towers=2, rules=[cc.to_formula()], colours=colours)

    subgoals = problem.goal.subformulas
    assert subgoals[0].asPDDL() == "(forall (?x) (done ?x))"
    assert subgoals[1].asPDDL() == cc.asPDDL()

    state = PDDLState.from_problem(problem)

    for index, colour in enumerate(colours):
        block = f"b{index}"
        assert state.predicate_holds(colour, [block])
        assert state.predicate_holds('clear', [block])
        assert state.predicate_holds('on-table', [block])

    for colour in ('red', 'blue', 'green', 'yellow', 'pink', 'purple', 'orange'):
        for tower in ('tower0', 'tower1'):
            assert state.get_colour_count(colour, tower) == 0
Esempio n. 7
0
    def add_joint_violations(self, colour_count: ColourCountRule,
                             red_on_blue_options: list, time: int,
                             colours_in_tower: list, model: FactorGraph):
        """Add one joint colour-count / red-on-blue violation factor per option.

        For each rule in ``red_on_blue_options``, a binary variable
        ``V_{time}({colour_count} && {rule})`` is created. Its TabularCPD is
        conditioned on ``colours_in_tower`` plus that rule and
        ``colour_count``, every evidence variable having cardinality 2. The
        resulting factor is registered on ``model`` through ``self.add_factor``.

        Args:
            colour_count: the colour-count rule being corrected.
            red_on_blue_options: candidate red-on-blue rules. The previous
                implementation only accepted exactly two; any number now works.
            time: correction time step, embedded in the variable names.
            colours_in_tower: colour variable names for the blocks in the tower.
            model: factor graph passed through to ``self.add_factor``.

        Returns:
            The created violation variable names, one per option, in input
            order.
        """
        p = self.get_p(CorrectionType.TABLE)

        # generateCPD is called with the same arguments for every option, so
        # the table is built once and shared. The previous code built one copy
        # per option and then used the first copy for both factors anyway.
        cpd = colour_count.generateCPD(num_blocks_in_tower=len(colours_in_tower),
                                       correction_type=CorrectionType.TABLE,
                                       p=p)

        # Build every factor before registering any, as the original did.
        names_and_factors = []
        for red_on_blue in red_on_blue_options:
            name = f"V_{time}({colour_count} && {red_on_blue})"
            evidence = colours_in_tower + [red_on_blue, colour_count]
            factor = TabularCPD(name,
                                2,
                                cpd,
                                evidence=evidence,
                                evidence_card=[2] * len(evidence)).to_factor()
            names_and_factors.append((name, red_on_blue, factor))

        for name, red_on_blue, factor in names_and_factors:
            # NOTE(review): the first argument mixes variable names with the
            # factor object itself. This is kept as in the original; confirm
            # what add_factor expects here.
            self.add_factor([name, factor, red_on_blue, colour_count], factor,
                            model)
        return [name for name, _, _ in names_and_factors]
Esempio n. 8
0
def test_colour_count_CPD_generation3():
    """Joint colour-count + red-on-blue correction on a five-block tower.

    The top block is expected to be red with certainty. Each of the other
    four blocks is expected to be blue with probability
    ``1 / (len(colours_in_tower) - 1)``.
    """
    pgm_model = CorrectionPGMModel()
    time = 0
    colour_count = ColourCountRule('blue', 1)
    red_on_blue_options = Rule.generate_red_on_blue_options('blue', 'red')

    blue_cm = KDEColourModel('blue')
    red_cm = KDEColourModel('red')

    objects = [f"b{o}" for o in range(10)]
    objects_in_tower = objects[:5]
    top_object = objects[4]

    violations = pgm_model.add_cc_and_rob(colour_count, red_on_blue_options,
                                          blue_cm, red_cm, objects_in_tower,
                                          top_object, time)

    data = {f'corr_{time}': 1}
    for obj in objects:
        data[f'F({obj})'] = [1, 1, 1]

    pgm_model.observe(data)

    colours_in_tower = [f'blue({obj})' for obj in objects_in_tower[:-1]
                        ] + [f'red({top_object})']

    evidence = colours_in_tower + violations + red_on_blue_options + [
        colour_count
    ]
    q = pgm_model.query(evidence)

    assert abs(q[violations[0]] - 2 / 3) < 0.001
    assert abs(q[violations[1]] - 2 / 3) < 0.0001

    assert abs(q[colour_count] - 1.0) < 0.00001
    assert abs(q[red_on_blue_options[0]] - 2 / 3) < 0.00001

    assert pgm_model.colours['red'] is not None
    assert abs(q[colours_in_tower[-1]] - 1.0) < 0.00001
    expected_blue = 1 / (len(colours_in_tower) - 1)
    for colour in colours_in_tower[:-1]:
        # The tolerance comparison must sit outside abs(). The old
        # `abs(x - c < tol)` took abs() of a bool and only checked x < c + tol.
        assert abs(q[colour] - expected_blue) < 0.0001
Esempio n. 9
0
def test_colour_count_CPD_generation():
    """Colour-count correction over two blocks in a freshly added sub-model.

    After observing the correction and both feature vectors, the violation
    and both ``blue(...)`` variables are expected to be 1.0.
    """
    pgm_model = CorrectionPGMModel()
    time = 0
    rule = ColourCountRule('blue', 1)

    cm = KDEColourModel('blue')

    new_model = pgm_model.add_new_model()

    violations = pgm_model.add_colour_count_correction(rule, cm, ['b1', 'b2'],
                                                       time, new_model)

    pgm_model.observe({
        'F(b1)': [1, 1, 1],
        'F(b2)': [0, 0, 0],
        f'corr_{time}': 1
    })

    q = pgm_model.query(['blue(b1)', 'blue(b2)', violations[0]])

    # Inference yields floats; compare with a tolerance (as the sibling
    # tests do) instead of exact equality.
    tolerance = 0.000001
    assert abs(q[violations[0]] - 1.0) < tolerance
    assert abs(q["blue(b1)"] - 1.0) < tolerance
    assert abs(q["blue(b2)"] - 1.0) < tolerance
Esempio n. 10
0
    def build_model(self, message, args):
        """Learn colour models and rules from a parsed correction message.

        ``message.T`` selects the correction kind: 'tower', 'table',
        'colour count' or 'colour count+tower'. Any other kind learns nothing.
        ``args`` are the arguments of the corrected action; ``args[1]`` and
        ``args[-1]`` are read below.

        Learned colour models go into ``self.colour_models`` and learned rules
        into ``self.rule_models``. A 'table' correction may instead queue an
        ``ActiveLearningTest`` on ``self.active_tests``. Always returns None.
        """
        data = self.get_data(message, args)
        if data['o2'] is None:
            print("o2 is bottom position")
            return

        # To reduce the number of individual colour names, reuse colour models
        # where the colour is the same. This only works because we have a set
        # number of colour terms. First try to find colour models that match
        # the particular data points.
        # NOTE(review): this lookup also runs (and can abort learning) for
        # 'colour count' corrections, which do not use c1/c2. It is kept as-is.
        try:
            c1, c1_model = self.find_matching_cm(data['o1'])
            c2, c2_model = self.find_matching_cm(data['o2'])
        except ValueError:
            print("too many colour models, stopped learning")
            return

        if message.T == 'tower':
            # Create the negated rule over the two colours.
            rule = correctingagent.world.rules.NotRedOnBlueRule(c1, c2).to_formula()

            self.colour_models.update({c1: c1_model, c2: c2_model})
            self.rule_models['not {} and {}'.format(c1, c2)] = rule
            return

        if message.T == 'table':
            c3, c3_model = self.find_matching_cm(data['o3'])

            rule1 = RedOnBlueRule(c3, c2, 1)
            rule2 = RedOnBlueRule(c1, c3, 2)

            # data['o2'] is non-None here (checked at the top), so the former
            # `data['o2'] is None` branch was unreachable and has been removed.
            self.colour_models.update({c1: c1_model, c2: c2_model, c3: c3_model})

            obs = self.world.sense()
            try:
                test = search.ActiveLearningTest(rule1.to_formula(), rule2.to_formula(),
                                                 obs.colours, c3_model, args[1],
                                                 self.world)
                self.active_tests.append(test)
            except TestFailed:
                self.rule_models[rule2] = rule2.to_formula()
            return

        if message.T == 'colour count':
            tower = args[-1]

            # NOTE(review): the colour is hard-coded to 'blue' rather than
            # taken from the message; confirm this is intended.
            try:
                cm = self.colour_models['blue']
            except KeyError:
                cm = prob_model.KDEColourModel('blue', data=np.array([data['o1']]),
                                               weights=np.array([1]), **self.model_config)

            # Count the tower's blocks the model considers blue (p > 0.5),
            # capped at 3.
            n = 0
            for block in self.world.state.get_objects_in_tower(tower):
                datum = self.world.observe_object(block)
                if cm.p(1, datum) > 0.5:
                    n = min(3, n + 1)
            if n > 0:
                rule = ColourCountRule('blue', n)
                self.colour_models.update({'blue': cm})
                self.rule_models["colour_count"] = rule.to_formula()
            return

        if message.T == 'colour count+tower':
            if 'blue' in self.colour_models.keys():
                rule = RedOnBlueRule('blue', c1, 2)
                self.colour_models.update({c1: c1_model})
                self.rule_models[str(rule)] = rule.to_formula()
            return
Esempio n. 11
0
    def update_model(self, user_input, args):
        """Add the factors implied by a teacher correction to the PGM.

        ``user_input`` is parsed with ``read_sentence``. ``message.T`` selects
        which ``self.pgm_model`` / ``self.build_*`` helper creates the
        violation factors. ``args`` are the arguments of the corrected action;
        ``args[-1]`` is used as the tower name in several branches.

        Returns:
            ``(violations, data, message)``:
            - the violations returned by the model-building helper;
            - the observation dict (colour features, plus a ``corr_<time>``
              flag or fixed colour variables depending on the branch);
            - the message actually modelled. For 'same reason' this is the
              earlier message being referred to.

        Raises:
            ValueError: if ``message.T`` is not a recognised correction type.
                Previously this fell through, recorded the correction, and
                then failed with UnboundLocalError.

        Side effects:
            Appends ``(user_input, self.time)`` to
            ``self.previous_corrections`` and stores ``args`` in
            ``self.previous_args[self.time]``. For 'same reason' and
            'partial.neg' the recorded ``user_input`` is the earlier
            correction being referred to.
        """
        message = read_sentence(user_input)
        if message.T == 'recover':
            red = message.o1[0]
            blue = message.o2[0]
            rules = Rule.generate_red_on_blue_options(red, blue)

            red_cm = self.add_cm(red)
            blue_cm = self.add_cm(blue)

            violations = self.pgm_model.add_recovery(message.o3, self.time,
                                                     rules, red_cm, blue_cm)

            data = self.get_colour_data(message.o3)
            corr = corr_variable_name(self.time)
            data[corr] = 1

        elif message.T == 'same reason':
            # Use the most recent correction that was not itself 'same reason'.
            # NOTE(review): if every earlier correction is 'same reason', the
            # oldest one is used. With no earlier corrections, prev_corr is
            # unbound. Confirm callers never reach either state.
            for prev_corr, prev_time in reversed(self.previous_corrections):
                if 'same reason' not in prev_corr:
                    break

            prev_message = read_sentence(prev_corr, use_dmrs=False)
            user_input = prev_corr
            violations = self.build_same_reason(prev_message, args, prev_time)
            data = self.get_relevant_data(args, prev_message)
            message = prev_message

        elif message.T in ['tower', 'table']:

            if isinstance(
                    self.teacher, FaultyTeacherAgent
            ) and message.T == 'table' and self.teacher.recover_prob > 0:
                red = message.o1[0]
                blue = message.o2[0]
                rules = Rule.generate_red_on_blue_options(red, blue)

                tower_name = args[-1] if 't' in args[-1] else None

                red_cm = self.add_cm(red)
                blue_cm = self.add_cm(blue)

                objects = self.world.state.get_objects_in_tower(tower_name)

                violations = self.pgm_model.create_uncertain_table_model(
                    rules, red_cm, blue_cm, args + [message.o3], objects,
                    self.time)

                data = self.get_colour_data(objects + [message.o3])
                corr = corr_variable_name(self.time)
                data[corr] = 1
            else:
                data = self.get_relevant_data(args, message)
                violations = self.build_pgm_model(message, args)
                corr = corr_variable_name(self.time)
                data[corr] = 1

        elif message.T == 'partial.neg':
            # Use the most recent correction that mentions the negated colour.
            # NOTE(review): if none matches, the oldest correction is used.
            for prev_corr, prev_time in reversed(self.previous_corrections):
                if message.o1 in prev_corr:
                    break

            user_input = prev_corr
            prev_message = read_sentence(prev_corr, use_dmrs=False)

            violations = self.build_pgm_model(prev_message, args)
            prev_args = self.previous_args[prev_time]

            # The negated colour applies to the first object if it names the
            # earlier message's first colour, otherwise to the second.
            if message.o1 == prev_message.o1[0]:
                curr_negation = f'{message.o1}({args[0]})'
                prev_negation = f'{message.o1}({prev_args[0]})'
            else:
                curr_negation = f'{message.o1}({args[1]})'
                prev_negation = f'{message.o1}({prev_args[1]})'

            data = self.get_relevant_data(args, prev_message)
            data[curr_negation] = 0
            data[prev_negation] = 0

        elif message.T == 'colour count':
            colour_name = message.o1
            number = message.o2
            rule = ColourCountRule(colour_name, number)
            cm = self.add_cm(colour_name)
            tower_name = args[-1]
            objects = self.world.state.get_objects_in_tower(tower_name)
            top, _ = self.world.state.get_top_two(tower_name)
            if self.simplified_colour_count:
                objects = [top]

            violations = self.pgm_model.add_colour_count_correction(
                rule,
                cm,
                objects,
                self.time,
                faulty_teacher=isinstance(self.teacher, FaultyTeacherAgent))
            data = self.get_colour_data(objects)
            corr = f"corr_{self.time}"
            data[corr] = 1
            # The block just placed on top is observed to have the counted colour.
            top_colour_var = f"{colour_name}({top})"
            data[top_colour_var] = 1

        elif message.T == 'colour count+tower':
            colour_name = message.o1
            number = message.o2
            colour_count = ColourCountRule(colour_name, number)
            red_cm = self.add_cm(message.o3.o1[0])
            blue_cm = self.add_cm(message.o3.o2[0])
            red_on_blue = Rule.generate_red_on_blue_options(
                message.o3.o1[0], message.o3.o2[0])
            tower_name = args[-1]
            top, _ = self.world.state.get_top_two(tower_name)
            objects = self.world.state.get_objects_in_tower(tower_name)
            if self.simplified_colour_count:
                objects = [top]
            violations = self.pgm_model.add_cc_and_rob(colour_count,
                                                       red_on_blue, red_cm,
                                                       blue_cm, objects, top,
                                                       self.time)
            data = self.get_colour_data(objects)
            corr = f"corr_{self.time}"
            data[corr] = 1

        else:
            raise ValueError(f"unsupported correction type: {message.T!r}")

        self.previous_corrections.append((user_input, self.time))
        self.previous_args[self.time] = args

        return violations, data, message