def main(args):
    """Show every state side by side with its nearest *other* state under a loaded metric.

    Builds a ``BlockPushingMetric`` environment from ``args`` (reads
    ``args.num_objects``, ``args.num_goals``, ``args.all_goals`` and
    ``args.load_path``) and loads the precomputed metric from ``load_path``.
    Then, for each state index ``idx1``:

    * picks ``idx2 = argmin(env.metric[idx1, :])``;
    * renders both states;
    * prints the metric in both directions;
    * shows the two renders in a 1x2 matplotlib figure.

    Side effect: the metric's diagonal is overwritten in place with 1.0, so a
    state is not selected as its own nearest neighbour.
    """
    env = BlockPushingMetric(
        render_type="shapes",
        background=BlockPushingMetric.BACKGROUND_DETERMINISTIC,
        num_objects=args.num_objects,
        reward_num_goals=args.num_goals,
        all_goals=args.all_goals,
    )
    env.load_metric(args.load_path)

    # Mask self-distances so argmin below selects a different state.
    # NOTE(review): this only excludes the state itself when every off-diagonal
    # value in the row is below 1.0 (on ties argmin returns the first index) --
    # confirm the metric's range.
    diagonal = list(range(env.num_states))
    env.metric[diagonal, diagonal] = 1.0

    for idx1 in range(env.num_states):
        idx2 = np.argmin(env.metric[idx1, :])
        s1 = env.all_states[idx1]
        s2 = env.all_states[idx2]

        env.load_state_new_(s1)
        i1 = env.render()
        env.load_state_new_(s2)
        i2 = env.render()

        # Print both directions so an asymmetric metric is visible.
        print(env.metric[idx1, idx2], env.metric[idx2, idx1])

        plt.subplot(1, 2, 1)
        plt.imshow(utils.css_to_ssc(i1))
        plt.subplot(1, 2, 2)
        plt.imshow(utils.css_to_ssc(i2))
        plt.show()
def main(): env = BlockPushing(render_type="cubes", background=BlockPushing.BACKGROUND_DETERMINISTIC) env.reset() while True: img = env.render() plt.imshow(utils.css_to_ssc(img)) plt.show() while True: x = input("action: ") try: x = int(x) except Exception: continue if x < 9: break if x == 8: env.reset() continue env.step(x)
def main(): env = BlockPushingCursor(render_type="cubes") env.reset() while True: img = env.render() plt.imshow(utils.css_to_ssc(img)) plt.show() while True: x = input("action: ") try: x = int(x) except Exception: continue if x < 8: break env.step(x)
state_np = state.cpu().numpy() print("object embeddings:", state[0]) print("action", actions[0][0] // 4, actions[0][0] % 4) pred_trans = model.transition_model(state, actions[0]) pred_state = state + pred_trans pred_state_np = pred_state.cpu().numpy() num_objects = state_ext.shape[1] # current obs | obj 1 | obj 2 | ... # next obs | obj 1 | obj 2 | ... plt.figure(figsize=(12, 7)) plt.subplot(3, 7, 1) plt.imshow(utils.css_to_ssc(utils.to_np(obs[0]))) plt.axis("off") for i in range(num_objects): plt.subplot(3, 7, 2 + i) plt.imshow(utils.to_np(state_ext[0, i])) plt.axis("off") for i in range(num_objects): plt.subplot(3, 7, 9 + i) plt.scatter(all_states[:, i, 0], all_states[:, i, 1]) plt.scatter(state_np[0, i, 0], state_np[0, i, 1]) plt.subplot(3, 7, 15)
print("object embeddings:", state[idx]) print("raw action", actions[idx]) print("action", actions[idx] // 4, actions[idx] % 4) pred_state = model.forward_transition(state, actions) pred_state_np = pred_state.cpu().numpy() num_objects = state_ext.shape[1] # current obs | obj 1 | obj 2 | ... # next obs | obj 1 | obj 2 | ... plt.figure(figsize=(12, 4)) plt.subplot(2, num_objects + 1, 1) plt.imshow(utils.css_to_ssc(utils.to_np(obs[idx][:3, :, :]))) plt.axis("off") for i in range(num_objects): plt.subplot(2, num_objects + 1, 2 + i) plt.imshow(utils.to_np(state_ext[idx, i])) plt.axis("off") for i in range(num_objects): plt.subplot(2, num_objects + 1, num_objects + 3 + i) plt.scatter(all_states[:, i, 0], all_states[:, i, 1]) plt.scatter(state_np[idx, i, 0], state_np[idx, i, 1]) plt.axis("off")