예제 #1
0
def main(args):
    """Show every state side by side with its closest state under a loaded metric.

    Builds a ``BlockPushingMetric`` environment from ``args.num_objects``,
    ``args.num_goals`` and ``args.all_goals``, loads the metric stored at
    ``args.load_path`` and then, for each state index, renders that state next
    to the state with the smallest entry in the same metric row.
    """
    env = BlockPushingMetric(
        render_type="shapes",
        background=BlockPushingMetric.BACKGROUND_DETERMINISTIC,
        num_objects=args.num_objects,
        reward_num_goals=args.num_goals,
        all_goals=args.all_goals,
    )
    env.load_metric(args.load_path)

    # Overwrite the diagonal with 1.0, presumably so a state is not selected as
    # its own nearest neighbour (assumes off-diagonal values are below 1.0 --
    # TODO confirm the metric's value range).
    diagonal = list(range(env.num_states))
    env.metric[diagonal, diagonal] = 1.0

    for query in range(env.num_states):
        nearest = np.argmin(env.metric[query, :])

        # Both states are looked up before any of them is loaded into the env.
        frames = []
        for state in (env.all_states[query], env.all_states[nearest]):
            env.load_state_new_(state)
            frames.append(env.render())

        # NOTE(review): the same entry is printed twice; the second value may
        # have been meant to be metric[nearest, query] -- confirm with author.
        distance = env.metric[query, nearest]
        print(distance, distance)

        # Left panel: query state, right panel: its nearest state.
        for column, frame in enumerate(frames, start=1):
            plt.subplot(1, 2, column)
            plt.imshow(utils.css_to_ssc(frame))
        plt.show()
예제 #2
0
def main():
    """Drive a ``BlockPushing`` environment interactively from the keyboard.

    Each round renders the environment, shows the image, and then prompts for
    an integer until a valid one is entered:

    * ``0``-``7`` is forwarded to ``env.step``;
    * ``8`` resets the environment.

    Anything else (non-numeric text, negative numbers, numbers above 8) is
    rejected and the prompt is repeated. The loop never terminates on its own.
    """
    env = BlockPushing(
        render_type="cubes",
        background=BlockPushing.BACKGROUND_DETERMINISTIC,
    )
    env.reset()

    reset_code = 8

    while True:

        img = env.render()
        plt.imshow(utils.css_to_ssc(img))
        plt.show()

        while True:
            raw = input("action: ")
            try:
                x = int(raw)
            except ValueError:
                # Not an integer: ask again.
                continue
            # Enforce the lower bound as well, so that negative numbers are
            # not passed on to env.step.
            if 0 <= x <= reset_code:
                break

        if x == reset_code:
            env.reset()
            continue

        env.step(x)
예제 #3
0
def main():
    """Drive a ``BlockPushingCursor`` environment interactively from the keyboard.

    Each round renders the environment, shows the image, and then prompts for
    an integer until one in ``0``-``7`` is entered; that value is forwarded to
    ``env.step``. Non-numeric text, negative numbers and numbers above 7 are
    rejected and the prompt is repeated. The loop never terminates on its own.
    """
    env = BlockPushingCursor(render_type="cubes")
    env.reset()

    num_actions = 8

    while True:

        img = env.render()
        plt.imshow(utils.css_to_ssc(img))
        plt.show()

        while True:
            raw = input("action: ")
            try:
                x = int(raw)
            except ValueError:
                # Not an integer: ask again.
                continue
            # Enforce the lower bound as well, so that negative numbers are
            # not passed on to env.step.
            if 0 <= x < num_actions:
                break

        env.step(x)
예제 #4
0
        # NumPy copy of the state (moved through .cpu()) for the scatter plots below.
        state_np = state.cpu().numpy()
        print("object embeddings:", state[0])
        # Quotient and remainder of the action index by 4; presumably
        # (object index, direction) -- TODO confirm against the action encoding.
        print("action", actions[0][0] // 4, actions[0][0] % 4)

        # Predicted next state = current state plus the transition model's output.
        pred_trans = model.transition_model(state, actions[0])
        pred_state = state + pred_trans
        pred_state_np = pred_state.cpu().numpy()

        # Number of per-object maps, read from the second axis of state_ext.
        num_objects = state_ext.shape[1]

        # current obs | obj 1 | obj 2 | ...
        # next obs | obj 1 | obj 2 | ...
        plt.figure(figsize=(12, 7))

        # NOTE(review): the grid is hard-coded to 3 rows x 7 columns, so the
        # per-object panels only stay within their row for up to 6 objects.
        # Panel 1 (row 1, column 1): the first observation.
        plt.subplot(3, 7, 1)
        plt.imshow(utils.css_to_ssc(utils.to_np(obs[0])))
        plt.axis("off")

        # Row 1, columns 2..: one image per object taken from state_ext.
        for i in range(num_objects):

            plt.subplot(3, 7, 2 + i)
            plt.imshow(utils.to_np(state_ext[0, i]))
            plt.axis("off")

        # Row 2, columns 2..: first two components of object i over all_states,
        # with the current state's point for that object drawn on top.
        for i in range(num_objects):

            plt.subplot(3, 7, 9 + i)
            plt.scatter(all_states[:, i, 0], all_states[:, i, 1])
            plt.scatter(state_np[0, i, 0], state_np[0, i, 1])

        # Row 3, column 1 (the rest of this panel is outside the visible excerpt).
        plt.subplot(3, 7, 15)
예제 #5
0
파일: vis4.py 프로젝트: ondrejba/c-swm
            print("object embeddings:", state[idx])
            print("raw action", actions[idx])
            # Quotient and remainder of the action index by 4; presumably
            # (object index, direction) -- TODO confirm against the action encoding.
            print("action", actions[idx] // 4, actions[idx] % 4)

            # The transition is evaluated on the full state/actions, not only
            # on entry idx.
            pred_state = model.forward_transition(state, actions)
            pred_state_np = pred_state.cpu().numpy()

            # Number of per-object maps, read from the second axis of state_ext.
            num_objects = state_ext.shape[1]

            # current obs | obj 1 | obj 2 | ...
            # next obs | obj 1 | obj 2 | ...
            plt.figure(figsize=(12, 4))

            # Grid of 2 rows x (num_objects + 1) columns.
            # Panel 1: observation idx, restricted to its first three entries
            # along the leading axis (presumably the colour channels of one
            # frame -- TODO confirm the observation layout).
            plt.subplot(2, num_objects + 1, 1)
            plt.imshow(utils.css_to_ssc(utils.to_np(obs[idx][:3, :, :])))
            plt.axis("off")

            # Row 1, columns 2..: one image per object taken from state_ext.
            for i in range(num_objects):

                plt.subplot(2, num_objects + 1, 2 + i)
                plt.imshow(utils.to_np(state_ext[idx, i]))
                plt.axis("off")

            # Row 2 starts at index num_objects + 2, so num_objects + 3 + i is
            # the panel directly below object i's image. It shows the first two
            # components of object i over all_states, with the current state's
            # point drawn on top.
            for i in range(num_objects):

                plt.subplot(2, num_objects + 1, num_objects + 3 + i)
                plt.scatter(all_states[:, i, 0], all_states[:, i, 1])
                plt.scatter(state_np[idx, i, 0], state_np[idx, i, 1])
                plt.axis("off")