def test_top_two():
    """Check ``PDDLState.get_top_two`` after stacking blocks.

    Two scenarios run back to back:

    1. ``blocks-domain.pddl`` with problem 1 from the ``testing`` directory,
       where the action takes two arguments and ``get_top_two`` is called
       without a tower.
    2. ``blocks-domain-updated.pddl`` with problem 1 from the ``multitower``
       directory, where the action takes a third (tower) argument and
       ``get_top_two`` is queried per tower.

    In both cases the first action of the parsed domain is used.
    """
    # --- Scenario 1: 'testing' problem, two-argument action -----------------
    data_root = Path(get_config()['data_location'])
    problem_index = 1
    domain, problem = pddl_functions.parse(
        data_root / 'domain' / 'blocks-domain.pddl',
        data_root / 'testing' / f'problem{problem_index}.pddl',
    )

    state = pddl_functions.PDDLState.from_problem(problem)
    first_action = pddl_functions.Action.from_pddl(domain.actions[0])

    for move in (['b1', 't0'], ['b2', 'b1']):
        first_action.apply_action(state, move)
    assert state.get_top_two() == ('b2', 'b1')

    # --- Scenario 2: 'multitower' problem, three-argument action ------------
    data_root = Path(get_config()['data_location'])
    problem_index = 1
    domain, problem = pddl_functions.parse(
        data_root / 'domain' / 'blocks-domain-updated.pddl',
        data_root / 'multitower' / f'problem{problem_index}.pddl',
    )

    state = pddl_functions.PDDLState.from_problem(problem)
    first_action = pddl_functions.Action.from_pddl(domain.actions[0])

    moves = (
        ['b1', 't0', 'tower0'],
        ['b2', 'b1', 'tower0'],
        ['b3', 't1', 'tower1'],
        ['b4', 'b3', 'tower1'],
    )
    for move in moves:
        first_action.apply_action(state, move)

    assert state.get_top_two(tower='tower0') == ('b2', 'b1')
    assert state.get_top_two(tower='tower1') == ('b4', 'b3')
def test_pddl_state():
    """Check basic queries on a freshly loaded multitower state.

    Loads problem 1 from the ``multitower`` directory and verifies the
    predicates reported for ``b7``, the list of clear objects, and the result
    of ``_predicate_holds`` for three predicates on ``b7``.
    """
    data_root = Path(get_config()['data_location'])
    problem_index = 1
    domain, problem = pddl_functions.parse(
        data_root / 'domain' / 'blocks-domain-updated.pddl',
        data_root / 'multitower' / f'problem{problem_index}.pddl',
    )

    state = pddl_functions.PDDLState.from_problem(problem)

    b7_predicates = list(map(str, state.get_predicates('b7')))
    assert b7_predicates == [
        "(on-table b7)", "(clear b7)", "(lightyellow b7)", "(yellow b7)"
    ]

    expected_clear = [
        'b0', 'b1', 'b2', 'b3', 'b4', 'b5', 'b6', 'b7', 'b8', 'b9', 't0', 't1'
    ]
    assert state.get_clear_objs() == expected_clear

    # (predicate name, expected truth value) for block b7 in the initial state.
    expectations = (('on-table', True), ('red', False), ('yellow', True))
    for name, expected in expectations:
        holds = state._predicate_holds(pddl_functions.Predicate(name, ['b7']))
        assert holds is expected
def test_get_objects_in_tower():
    """Check ``get_objects_in_tower`` after building two towers.

    Applies the domain's first action six times (two blocks onto ``tower0``,
    four onto ``tower1``) and then verifies membership of blocks in each
    tower's object collection, plus that the last entry for ``tower1`` is the
    most recently placed block.
    """
    data_root = Path(get_config()['data_location'])
    problem_index = 1
    domain, problem = pddl_functions.parse(
        data_root / 'domain' / 'blocks-domain-updated.pddl',
        data_root / 'multitower' / f'problem{problem_index}.pddl',
    )

    state = pddl_functions.PDDLState.from_problem(problem)
    first_action = pddl_functions.Action.from_pddl(domain.actions[0])

    moves = (
        ['b1', 't0', 'tower0'],
        ['b2', 'b1', 'tower0'],
        ['b3', 't1', 'tower1'],
        ['b4', 'b3', 'tower1'],
        ['b7', 'b4', 'tower1'],
        ['b5', 'b7', 'tower1'],
    )
    for move in moves:
        first_action.apply_action(state, move)

    in_tower0 = state.get_objects_in_tower('tower0')
    assert 'b1' in in_tower0
    assert 'b2' in in_tower0
    assert 'b3' not in in_tower0

    in_tower1 = state.get_objects_in_tower('tower1')
    assert 'b3' in in_tower1
    assert 'b4' in in_tower1
    assert 'b1' not in in_tower1
    assert 'b5' in in_tower1

    # Neither an unplaced block nor the table marker t0 is reported for tower0.
    assert 'b6' not in in_tower0
    assert 't0' not in in_tower0

    # b5 was the final block placed on tower1.
    assert in_tower1[-1] == 'b5'
def test_pddl_state_apply_effect():
    """Check that ``apply_effect`` adds/removes predicates and bumps counts.

    Verifies three predicates on ``b7`` before and after applying two negated
    effects and one positive effect, then applies an ``Increase`` effect and
    checks both the first function expression's number and the colour count.
    """
    data_root = Path(get_config()['data_location'])
    problem_index = 1
    domain, problem = pddl_functions.parse(
        data_root / 'domain' / 'blocks-domain-updated.pddl',
        data_root / 'multitower' / f'problem{problem_index}.pddl',
    )

    state = pddl_functions.PDDLState.from_problem(problem)

    # Truth values of each predicate on b7 before the effects are applied.
    before = (('on-table', True), ('red', False), ('yellow', True))
    for name, expected in before:
        holds = state._predicate_holds(pddl_functions.Predicate(name, ['b7']))
        assert holds is expected

    state.apply_effect(Predicate('on-table', ['b7'], op='not'))
    state.apply_effect(Predicate('red', ['b7']))
    state.apply_effect(Predicate('yellow', ['b7'], op='not'))

    # Every predicate should have flipped.
    after = (('on-table', False), ('red', True), ('yellow', False))
    for name, expected in after:
        holds = state._predicate_holds(pddl_functions.Predicate(name, ['b7']))
        assert holds is expected

    assert state.fexpressions[0].number == 0.0
    state.apply_effect(Increase('blue', 'tower1', 1))
    assert state.fexpressions[0].number == 1.0
    assert state.get_colour_count('blue', 'tower1') == 1
def test_update_state():
    """Check colour counts after assigning colours and rebuilding fexpressions.

    Builds two towers, applies a colour predicate to each block listed in
    ``block_colours``, tallies colours per tower for blocks that are
    ``in-tower``, replaces the state's function expressions with fresh
    ``ColourCount`` objects built from that tally, and finally checks the
    counts reported by ``get_colour_count``.
    """
    data_root = Path(get_config()['data_location'])
    problem_index = 1
    domain, problem = pddl_functions.parse(
        data_root / 'domain' / 'blocks-domain-updated.pddl',
        data_root / 'multitower' / f'problem{problem_index}.pddl',
    )

    pddl_state = pddl_functions.PDDLState.from_problem(problem)
    first_action = pddl_functions.Action.from_pddl(domain.actions[0])

    block_colours = [('b1', 'red'), ('b2', 'blue'), ('b3', 'blue'),
                     ('b4', 'purple'), ('b5', 'blue')]

    moves = (
        ['b1', 't0', 'tower0'],
        ['b2', 'b1', 'tower0'],
        ['b3', 't1', 'tower1'],
        ['b4', 'b3', 'tower1'],
        ['b7', 'b4', 'tower1'],
        ['b5', 'b7', 'tower1'],
    )
    for move in moves:
        first_action.apply_action(pddl_state, move)

    # Tally (colour, tower) pairs for every coloured block that is in a tower.
    colour_counts = defaultdict(int)
    for block, colour in block_colours:
        pddl_state.apply_effect(Predicate(colour, [block]))
        for table in pddl_state.towers:
            # Entries of ``towers`` are mapped to tower names via 't' -> 'tower'.
            tower_name = table.replace('t', 'tower')
            if pddl_state.predicate_holds("in-tower", [block, tower_name]):
                colour_counts[(colour, tower_name)] += 1

    # Rebuild every function expression with the freshly tallied count.
    pddl_state.fexpressions = [
        ColourCount(fexp.colour, fexp.tower,
                    colour_counts[(fexp.colour, fexp.tower)])
        for fexp in pddl_state.fexpressions
    ]

    assert pddl_state.predicate_holds("in-tower", ['b2', 'tower0'])

    assert pddl_state.get_colour_count('red', 'tower0') == 1
    assert pddl_state.get_colour_count('red', 'tower1') == 0
    assert colour_counts[('blue', 'tower1')] == 2
    assert colour_counts[('blue', 'tower0')] == 1
    assert pddl_state.get_colour_count('blue', 'tower0') == 1

    assert pddl_state.get_colour_count('purple', 'tower1') == 1

    assert pddl_state.get_colour_count('blue', 'tower1') == 2
def test_get_colour():
    """Check ``get_colour_name`` for two blocks and for the table object.

    Uses problem 1 from the ``testing`` directory with the updated domain;
    ``t0`` is expected to have no colour (``None``).
    """
    data_root = Path(get_config()['data_location'])
    problem_index = 1
    domain, problem = pddl_functions.parse(
        data_root / 'domain' / 'blocks-domain-updated.pddl',
        data_root / 'testing' / f'problem{problem_index}.pddl',
    )

    state = pddl_functions.PDDLState.from_problem(problem)

    assert state.get_colour_name('b0') == 'blue'
    assert state.get_colour_name('b8') == 'pink'
    assert state.get_colour_name('t0') is None
def test_objects_on_table():
    """Check that every object starts on the table in the multitower problem.

    The state's ``objects`` must be ``b0`` through ``b9`` in order, and
    ``get_objects_on_table`` must return an equal collection.
    """
    data_root = Path(get_config()['data_location'])
    problem_index = 1
    domain, problem = pddl_functions.parse(
        data_root / 'domain' / 'blocks-domain-updated.pddl',
        data_root / 'multitower' / f'problem{problem_index}.pddl',
    )

    state = pddl_functions.PDDLState.from_problem(problem)

    expected_blocks = [f'b{i}' for i in range(10)]
    assert state.objects == expected_blocks

    on_table = state.get_objects_on_table()
    assert on_table == state.objects
def test_action():
    """Check ``Action.preconditions_hold`` for valid and invalid arguments.

    Against the initial multitower state, the domain's first action should be
    applicable when the block is placed on a table object with its matching
    tower, and inapplicable for the mismatched-tower and block-on-block
    argument lists exercised below.
    """
    data_root = Path(get_config()['data_location'])
    problem_index = 1
    domain, problem = pddl_functions.parse(
        data_root / 'domain' / 'blocks-domain-updated.pddl',
        data_root / 'multitower' / f'problem{problem_index}.pddl',
    )

    state = pddl_functions.PDDLState.from_problem(problem)
    first_action = pddl_functions.Action.from_pddl(domain.actions[0])

    # (action arguments, expected result of preconditions_hold)
    cases = (
        (['b1', 't1', 'tower1'], True),
        (['b1', 't0', 'tower0'], True),
        (['b1', 't1', 'tower0'], False),
        (['b1', 'b2', 'tower1'], False),
    )
    for arguments, expected in cases:
        assert first_action.preconditions_hold(state, arguments) is expected
def test_apply_action():
    """Check that applying an action records ``on`` and updates colour counts.

    Places ``b1`` on ``t1`` and ``b0`` on ``t0`` using the domain's first
    action, asserting the ``on`` predicate after each, and then that the blue
    count for ``tower0`` is 1.
    """
    data_root = Path(get_config()['data_location'])
    problem_index = 1
    domain, problem = pddl_functions.parse(
        data_root / 'domain' / 'blocks-domain-updated.pddl',
        data_root / 'multitower' / f'problem{problem_index}.pddl',
    )

    state = pddl_functions.PDDLState.from_problem(problem)
    first_action = pddl_functions.Action.from_pddl(domain.actions[0])

    first_action.apply_action(state, ['b1', 't1', 'tower1'])
    b1_on_t1 = Predicate('on', ['b1', 't1'])
    assert state._predicate_holds(b1_on_t1)

    first_action.apply_action(state, ['b0', 't0', 'tower0'])
    b0_on_t0 = Predicate('on', ['b0', 't0'])
    assert state._predicate_holds(b0_on_t0)
    assert state.get_colour_count('blue', 'tower0') == 1
def test_state_equality_after_deepcopy():
    """Check ``PDDLState`` equality against a deep copy, before and after change.

    A deep copy of a freshly loaded state must compare equal to the original;
    once an action is applied to the original, the two must compare unequal.

    This test was previously named ``test_update_state``, the same name as an
    earlier test in this module. The later definition rebound the name, so the
    earlier test was never collected. The new name describes what is checked
    here and lets both tests run.
    """
    config = get_config()
    data_dir = Path(config['data_location'])
    domain_file = data_dir / 'domain' / 'blocks-domain-updated.pddl'
    problem_directory = 'multitower'
    problem_number = 1
    problem_file = data_dir / problem_directory / f'problem{problem_number}.pddl'
    domain, problem = pddl_functions.parse(domain_file, problem_file)

    pddl_state = pddl_functions.PDDLState.from_problem(problem)
    action = pddl_functions.Action.from_pddl(domain.actions[0])

    # Snapshot the untouched state; the copy must be equal but independent.
    snapshot = copy.deepcopy(pddl_state)
    assert snapshot == pddl_state

    # Mutating the original must not affect the snapshot.
    action.apply_action(pddl_state, ['b1', 't0', 'tower0'])
    assert snapshot != pddl_state