Beispiel #1
0
def plot_multi(p_params, filename=None):
    """Draw the smoothed curves of several experiments on one figure.

    Every element of ``p_params`` must expose ``curve_to_draw`` (the series
    to plot) and ``label`` (its legend entry). The figure is produced by
    ``plotting.plot_multi_test`` with a smoothing window of 30, and
    ``filename`` is forwarded to it unchanged.
    """
    curves = [experiment.curve_to_draw for experiment in p_params]
    legend = [experiment.label for experiment in p_params]
    plotting.plot_multi_test(
        smoothing_window=30,
        x_label="episode",
        y_label="smoothed rewards",
        curve_to_draw=curves,
        labels=legend,
        filename=filename,
    )
Beispiel #2
0
def grid_maze_env():
    """Compare tabular Q-learning with a random policy on a generated maze.

    A fresh maze from ``generate_maze_please()`` is wrapped in
    ``MazeWorldEpisodeLength``. ``q_learning`` is run on it, the learned
    table goes through ``check_policy`` (whose two results are printed), and
    ``random_policy`` is run with the same budget. Both statistics objects
    are then handed to ``plot_multi_test``.
    """
    maze_env = MazeWorldEpisodeLength(maze=generate_maze_please())
    # Forwarded to both learners; presumably an episode count -- confirm
    # against q_learning / random_policy.
    budget = 2000

    learned_stats, learned_table = q_learning(maze_env, budget)
    check_a, check_b = check_policy(maze_env, learned_table)
    print(check_a, check_b)

    baseline_stats = random_policy(maze_env, budget)

    # NOTE(review): other calls to plot_multi_test pass keyword arguments
    # (curve_to_draw=..., labels=...); confirm a positional list of stats
    # objects is what it expects here.
    plot_multi_test([learned_stats, baseline_stats])
Beispiel #3
0
def arm_env():
    """Compare tabular Q-learning with a random policy on an ``ArmEnv`` task.

    The environment is a 10x10 ``ArmEnv`` with 5 cubes and
    ``tower_target_size=5``, configured with ``action_minus_reward=-1``,
    ``finish_reward=1000`` and ``episode_max_length=500``. ``q_learning`` is
    run with budget ``ep_length``, the learned table goes through
    ``check_policy`` (its two results are printed), ``random_policy`` is run
    with the same budget, and both statistics objects are plotted.
    """
    env = ArmEnv(episode_max_length=500,
                 size_x=10,
                 size_y=10,
                 cubes_cnt=5,
                 action_minus_reward=-1,
                 finish_reward=1000,
                 tower_target_size=5)

    # Passed to both learners; presumably an episode count -- confirm
    # against q_learning / random_policy.
    ep_length = 10000
    # q-learning policy on ArmEnv
    q_stats, q_table = q_learning(env, ep_length)
    s, t = check_policy(env, q_table)
    print(s, t)

    # random policy on ArmEnv (baseline)
    r_stats = random_policy(env, ep_length)

    # NOTE(review): other calls to plot_multi_test pass keyword arguments
    # (curve_to_draw=..., labels=...); confirm a positional list of stats
    # objects is what it expects here.
    plot_multi_test([q_stats, r_stats])
Beispiel #4
0
# Plot one hard-coded curve of reward values (y axis "Reward") against
# episode number (x axis "Episodes number"). The values are inlined
# literals; where they were produced is not recorded here.
plot_multi_test(
    # NOTE(review): "LABEL" looks like a placeholder legend entry -- confirm.
    labels=["LABEL"],
    x_label="Episodes number",
    y_label="Reward",
    # A list holding a single curve.
    curve_to_draw=[[
        -0.050000000000000405, -0.050000000000000405, -0.050000000000000405,
        0.46520000000000006, -0.050000000000000405, -0.050000000000000405,
        -0.050000000000000405, 0.46219999999999994, 0.48030000000000006,
        0.4878, -0.050000000000000405, 0.47300000000000014, 0.4670000000000001,
        0.4939, 0.48140000000000005, -0.050000000000000405,
        -0.050000000000000405, 0.4775000000000001, 0.4533999999999997,
        0.4776000000000001, -0.050000000000000405, 0.46430000000000005,
        0.4669000000000001, 0.46930000000000016, -0.050000000000000405,
        0.46490000000000004, 0.48630000000000007, -0.050000000000000405,
        -0.050000000000000405, -0.050000000000000405, 0.4673000000000001,
        0.47980000000000006, 0.4540999999999997, 0.46900000000000014,
        0.4745000000000001, -0.050000000000000405, 0.48530000000000006, 0.4925,
        -0.050000000000000405, 0.45479999999999976, -0.050000000000000405,
        0.4916, 0.46209999999999996, 0.4765000000000001, 0.46530000000000005,
        0.4922, 0.4902, 0.48610000000000003, 0.4591999999999999,
        0.4679000000000001, 0.46169999999999994, 0.4857, 0.49420000000000003,
        0.47050000000000014, 0.47460000000000013, 0.4785000000000001,
        0.46830000000000016, 0.4927, 0.48630000000000007,
        -0.050000000000000405, 0.4502999999999996, -0.050000000000000405,
        0.4892, -0.050000000000000405, -0.050000000000000405,
        0.4608999999999999, 0.491, 0.48090000000000005, 0.48250000000000004,
        0.4951, 0.48350000000000004, 0.4888, 0.48560000000000003,
        0.4697000000000001, 0.47090000000000015, 0.45499999999999974, 0.4915,
        0.47200000000000014, 0.4837000000000001, 0.4645, 0.4757000000000001,
        0.4779000000000001, 0.4811000000000001, 0.49010000000000004, 0.494,
        0.48230000000000006, 0.48480000000000006, 0.4907, 0.48760000000000003,
        0.4921, 0.48410000000000003, 0.47260000000000013, 0.4857,
        0.48430000000000006, 0.4939, 0.48440000000000005, 0.47820000000000007,
        0.4857, -0.050000000000000405, 0.48760000000000003,
        0.48130000000000006, 0.495, 0.4857, 0.4921, 0.4624, 0.4899, 0.4862,
        0.4868, 0.4718000000000001, 0.4934, 0.4929, 0.48700000000000004,
        0.4941, 0.49010000000000004, 0.4637, 0.4867, 0.48180000000000006,
        0.47830000000000006, 0.4763000000000001, 0.4953, 0.4919, 0.4883,
        0.48080000000000006, 0.4941, 0.4916, 0.48340000000000005, 0.4918,
        0.4795000000000001, 0.48700000000000004, 0.4954, 0.4889,
        0.48610000000000003, 0.48860000000000003, 0.48600000000000004, 0.4954,
        0.4928, 0.48030000000000006, 0.4724000000000001, 0.47870000000000007,
        0.4927, 0.4774000000000001, 0.4941, 0.47930000000000006,
        0.48560000000000003, 0.48240000000000005, 0.4945, 0.49520000000000003,
        0.48530000000000006, 0.4868, 0.4941, 0.47030000000000016, 0.4953,
        0.4922, 0.4826000000000001, 0.49470000000000003, 0.4767000000000001,
        0.48510000000000003, 0.4964, 0.4899, 0.4954, 0.4955, 0.4852, 0.4955,
        0.48690000000000005, 0.493, 0.4743000000000001, 0.4925,
        0.4800000000000001, 0.4965, 0.46960000000000013, 0.48090000000000005,
        0.4933, 0.4873, 0.4964, 0.4909, 0.4888, 0.4963, 0.4917, 0.4963,
        0.49520000000000003, 0.4919, 0.4953, 0.4965, 0.4925, 0.4879, 0.4889,
        0.4946, 0.4964, 0.4955, 0.4894, 0.49620000000000003, 0.4954, 0.4929,
        0.47930000000000006, 0.4965, 0.4963, 0.49560000000000004, 0.4965,
        0.4965, 0.4963, 0.4965, 0.4965, 0.4965, 0.4965, 0.4965, 0.4953, 0.4965,
        0.4953, 0.4946, 0.4965, 0.4965, 0.49520000000000003, 0.4965, 0.4965,
        0.4965, 0.4949, 0.4965, 0.4965, 0.4953, 0.4963, 0.4955, 0.4963, 0.4964,
        0.4965, 0.4965, 0.4963, 0.4955, 0.49560000000000004, 0.4965, 0.4964,
        0.4963, 0.4964, 0.49470000000000003, 0.49560000000000004, 0.4964,
        0.4955, 0.4965, 0.4965, 0.4965, 0.4965, 0.49560000000000004,
        0.49560000000000004, 0.4964, 0.49610000000000004, 0.4965, 0.4965,
        0.4929, 0.495, 0.4965, 0.4953, 0.4963, 0.49560000000000004, 0.4965,
        0.4965, 0.4965, 0.4964, 0.4965, 0.49620000000000003,
        0.49560000000000004, 0.4955, 0.4945, 0.4965, 0.4965, 0.4965, 0.4964,
        0.4965, 0.4965, 0.4964, 0.4965, 0.4965, 0.4965, 0.4955, 0.4965, 0.4963,
        0.4965, 0.4965, 0.4965, 0.4965, 0.4963, 0.49620000000000003, 0.4965,
        0.4963, 0.4963, 0.4965, 0.4955, 0.4965, 0.4963, 0.4963, 0.4963, 0.4955,
        0.4965, 0.4964, 0.49620000000000003, 0.49560000000000004,
        0.49610000000000004, 0.4963, 0.4946, 0.4964, 0.49620000000000003,
        0.4953, 0.4964, 0.4955, 0.496, 0.4965, 0.4965, 0.4965, 0.4963,
        0.49620000000000003, 0.4965, 0.49560000000000004, 0.4965, 0.4964,
        0.4965, 0.49610000000000004, 0.4944, 0.4963, 0.4965, 0.4963, 0.4964,
        0.49620000000000003, 0.4965, 0.4938, 0.4948, 0.4964,
        0.49560000000000004, 0.4963, 0.4965, 0.4963, 0.4943, 0.4963, 0.4965,
        0.4965, 0.49560000000000004, 0.4965, 0.49620000000000003, 0.4965,
        0.49560000000000004, 0.49610000000000004, 0.4965, 0.4965, 0.4955,
        0.4965, 0.49610000000000004, 0.4965, 0.4946, 0.4965, 0.4965,
        0.49620000000000003, 0.4965, 0.4965, 0.4965, 0.4963, 0.4963, 0.4965,
        0.49560000000000004, 0.4965, 0.49620000000000003, 0.4946,
        0.49520000000000003, 0.49560000000000004, 0.4964, 0.49560000000000004,
        0.4963, 0.49560000000000004, 0.4965, 0.4963, 0.49560000000000004,
        0.4965, 0.49620000000000003, 0.4965, 0.4963, 0.4965, 0.4965, 0.4965,
        0.4965, 0.49610000000000004, 0.4965, 0.4965, 0.4964, 0.4965, 0.4963,
        0.4965, 0.4965, 0.4965, 0.4965, 0.4964, 0.4953, 0.49560000000000004,
        0.4944, 0.4955, 0.4951, 0.4955, 0.4954, 0.4965, 0.49560000000000004,
        0.49610000000000004, 0.4946, 0.4965, 0.4965, 0.4965, 0.4965, 0.4965,
        0.4965, 0.4955, 0.4965, 0.4965, 0.4955, 0.4965, 0.4951,
        0.49610000000000004, 0.4965, 0.4953, 0.4964, 0.49560000000000004,
        0.4954, 0.4955, 0.4965, 0.4944, 0.4965, 0.49560000000000004,
        0.49560000000000004, 0.4955, 0.4965, 0.4955, 0.4965,
        0.49560000000000004, 0.4964, 0.4963, 0.4964, 0.4964,
        0.49620000000000003, 0.4964, 0.4963, 0.4953, 0.4965, 0.4965, 0.4963,
        0.4963, 0.4955, 0.4953, 0.4965, 0.4965, 0.4955, 0.4965, 0.4965, 0.4965,
        0.4965, 0.4955, 0.4965, 0.49620000000000003, 0.4955, 0.4965, 0.4964,
        0.4963, 0.4944, 0.4964, 0.4965, 0.4963, 0.4965, 0.4965, 0.4963, 0.4965,
        0.4963, 0.4945, 0.4945, 0.4965, 0.4965, 0.4955, 0.4965, 0.4953, 0.4955,
        0.4944, 0.4965, 0.4965, 0.4965, 0.4965, 0.4963, 0.4964,
        0.49560000000000004, 0.4965, 0.4963, 0.4965, 0.4965, 0.4965, 0.4963,
        0.4964, 0.4955, 0.4965, 0.4965, 0.4953, 0.49560000000000004,
        0.49620000000000003, 0.4965, 0.4965, 0.4965, 0.4953,
        0.49620000000000003, 0.4965, 0.4965, 0.4964
    ]])
Beispiel #5
0
def main():
    """Train a flat HAM on a small ``ArmEnv`` and plot its episode rewards.

    The hierarchy is ``RootMachine`` -> ``LoopInvokerMachine`` -> one
    ``AbstractMachine`` whose single ``Choice`` selects among the six
    primitive arm actions (LEFT, RIGHT, UP, DOWN, ON, OFF). Every action
    leads straight to ``Stop`` under both label 0 and label 1. The machine is
    run for 1500 episodes, progress is printed every 10 episodes, and the
    rewards accumulated in ``params.logs["ep_rewards"]`` are plotted
    afterwards with a smoothing window of 30.
    """
    env = ArmEnv(
        episode_max_length=300,
        size_x=5,
        size_y=3,
        cubes_cnt=4,
        action_minus_reward=-1,
        finish_reward=100,
        tower_target_size=4,
    )

    params = HAMParams(
        q_value={},
        env=env,
        current_state=None,
        eps=0.1,
        gamma=0.9,
        alpha=0.1,
        string_prefix_of_machine=None,
        accumulated_discount=1,
        accumulated_rewards=0,
        previous_machine_choice_state=None,
        env_is_done=None,
        logs={"reward": 0, "ep_rewards": []},
        # Yields 1 once the env reports done, else 0; presumably this picks
        # which labelled MachineRelation (label=0 / label=1) is followed.
        on_model_transition_id_function=lambda env_: 1 if env_.is_done() else 0,
    )

    # Vertices are created in the same order as before: start, choice,
    # the six primitive actions, then stop.
    start = Start()
    choice_one = Choice()
    primitive_names = ("LEFT", "RIGHT", "UP", "DOWN", "ON", "OFF")
    primitives = [Action(action=env.get_actions_as_dict()[name]) for name in primitive_names]
    stop = Stop()

    # start -> choice, choice -> every action, then every action -> stop
    # for label 0 followed by label 1.
    edges = [MachineRelation(left=start, right=choice_one)]
    edges += [MachineRelation(left=choice_one, right=primitive) for primitive in primitives]
    for edge_label in (0, 1):
        edges += [MachineRelation(left=primitive, right=stop, label=edge_label)
                  for primitive in primitives]
    simple_machine = AbstractMachine(MachineGraph(transitions=tuple(edges)))

    root = RootMachine(machine_to_invoke=LoopInvokerMachine(machine_to_invoke=simple_machine))
    num_episodes = 1500
    for episode in range(num_episodes):
        env.reset()
        root.run(params)
        if episode % 10:
            continue
        # "\r" rewrites the same console line on every report.
        progress = "\r{root} episode {i_episode}/{num_episodes}.".format(
            root=root, i_episode=episode, num_episodes=num_episodes)
        print(progress, end="")
        sys.stdout.flush()

    plotting.plot_multi_test(
        smoothing_window=30,
        x_label="episode",
        y_label="smoothed rewards",
        curve_to_draw=[params.logs["ep_rewards"]],
        labels=["HAM_basic"],
    )
Beispiel #6
0
def _rgb_to_block_maze(rgb, block_size, bright_cutoff=253, dark_cutoff=50, free_cutoff=240):
    """Convert an RGB image into a coarse 0/1 grid indexed ``[row][column]``.

    Each pixel's intensity is ``(R + G + B) // 3``. A pixel counts as free
    (255) when its intensity lies in ``(dark_cutoff, bright_cutoff]``;
    everything else counts as blocked (0). That includes nearly saturated
    pixels above ``bright_cutoff``. The image is then tiled into
    ``block_size`` x ``block_size`` cells, and partial cells at the right or
    bottom edge are dropped. A cell whose mean binarized intensity exceeds
    ``free_cutoff`` becomes 0 (free); otherwise it becomes 1 (wall).

    Args:
        rgb: anything ``np.asarray`` turns into a (height, width, 3) array,
            e.g. a PIL image in "RGB" mode.
        block_size: side length of one cell, in pixels.

    Returns:
        int array of shape (height // block_size, width // block_size).
    """
    # Widen before summing so uint8 channels cannot overflow.
    gray = np.asarray(rgb).astype(int).sum(axis=2) // 3
    binary = np.where((gray > dark_cutoff) & (gray <= bright_cutoff), 255, 0)
    rows, cols = gray.shape[0] // block_size, gray.shape[1] // block_size
    cells = binary[:rows * block_size, :cols * block_size].reshape(
        rows, block_size, cols, block_size)
    return (cells.mean(axis=(1, 3)) <= free_cutoff).astype(int)


def _fill_cell(drawer, row, col, block_size, color):
    """Paint grid cell (row, col) as a solid block_size x block_size square."""
    left, top = col * block_size, row * block_size
    # ImageDraw rectangles include both corner pixels, hence the "- 1".
    drawer.rectangle([left, top, left + block_size - 1, top + block_size - 1], fill=color)


def experiment_slam_input():
    """Build a maze from ``robots_map.jpg``, train HAM on it, and animate the path.

    1. Binarize the map and downsample it into a 0/1 grid of
       ``block_size``-pixel cells. Rows follow image y and columns follow
       image x, so non-square maps work too.
    2. Pass the grid through ``prepare_maze`` / ``place_start_finish`` and
       wrap it in ``MazeWorldEpisodeLength``.
    3. Run ``ham_learning`` with ``L2Interesting`` and plot the episode
       rewards.
    4. Replay the last ``episode_max_length`` entries of ``params["path"]``
       on a fresh copy of the map. Each frame draws the current cell red,
       snapshots it, then repaints the cell yellow so the trail persists.
       The frames are written to ``movie.gif``.
    """
    from PIL import Image, ImageDraw

    # Candidate cell sizes (all divisors of 960); index 3 selects 4 px.
    block_sizes = [
        1, 2, 3, 4, 5, 6, 8, 10, 12, 15, 16, 20, 24, 30, 32, 40, 48, 60, 64,
        80, 96, 120, 160, 192, 240, 320, 480, 960
    ]
    block_size = block_sizes[3]

    with Image.open('robots_map.jpg') as source:
        maze = _rgb_to_block_maze(source.convert("RGB"), block_size)

    # TODO rewrite with new HAM

    maze = place_start_finish(prepare_maze(maze))

    episode_max_length = 1000
    env = MazeWorldEpisodeLength(maze=maze,
                                 finish_reward=1000000,
                                 episode_max_length=episode_max_length)
    env.render()
    params = {
        "env": env,
        "num_episodes": 800,
        "machine": L2Interesting,
        "alpha": 0.1,
        "epsilon": 0.1,
        "discount_factor": 1,
        "path": []
    }
    _, stats = ham_learning(**params)
    plotting.plot_multi_test(curve_to_draw=[stats.episode_rewards],
                             smoothing_window=10)

    with Image.open('robots_map.jpg') as source:
        frame = source.convert("RGB")
    drawer = ImageDraw.Draw(frame)

    # params["path"] starts empty and is read back after training, so
    # ham_learning presumably appends visited (row, col) cells to it --
    # confirm. Only the final episode_max_length steps are animated.
    images = []
    for y, x in params["path"][-episode_max_length:]:
        _fill_cell(drawer, y, x, block_size, (240, 13, 13))
        images.append(pil_to_list(frame))
        _fill_cell(drawer, y, x, block_size, (255, 255, 0))
    imageio.mimsave('movie.gif', images)