def visualize_model_prediction_single_timestep(
    vis,
    config,
    z_pred,  # [z_dim]
    display_idx,
    name_prefix=None,
    color=None,
    display_robot_state=True,
):

    if color is None:
        color = [0, 0, 255]

    idx_dict = get_object_and_robot_state_indices(config)

    z_object = z_pred[idx_dict['object_indices']].reshape(
        config['dataset']['object_state_shape'])
    z_robot = z_pred[idx_dict['robot_indices']].reshape(
        config['dataset']['robot_state_shape'])

    name = "z_object/%d" % (display_idx)
    if name_prefix is not None:
        name = name_prefix + "/" + name
    meshcat_utils.visualize_points(
        vis,
        name,
        torch_utils.cast_to_numpy(z_object),
        color=color,
        size=0.01,
    )

    if display_robot_state:
        name = "z_robot/%d" % (display_idx)
        if name_prefix is not None:
            name = name_prefix + "/" + name
        meshcat_utils.visualize_points(
            vis,
            name,
            torch_utils.cast_to_numpy(z_robot),
            color=color,
            size=0.01,
        )
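
# Example usage (a sketch): given a predicted latent state z_pred of shape
# [z_dim] taken from a model rollout, one might call
#   visualize_model_prediction_single_timestep(vis, config, z_pred,
#                                               display_idx=i,
#                                               name_prefix="rollout")
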
def main():
    # load dynamics model
    model_dict = load_model_state_dict()
    model = model_dict['model_dy']
    model_dd = model_dict['model_dd']
    config = model.config

    env_config = load_yaml(os.path.join(get_project_root(), 'experiments/exp_20_mugs/config.yaml'))
    env_config['env']['observation']['depth_int16'] = True
    n_history = config['train']['n_history']

    initial_cond = generate_initial_condition(env_config, push_length=PUSH_LENGTH)
    env_config = initial_cond['config']

    # enable the right observations

    camera_name = model_dict['metadata']['camera_name']
    spatial_descriptor_data = model_dict['spatial_descriptor_data']
    ref_descriptors = spatial_descriptor_data['spatial_descriptors']
    K = ref_descriptors.shape[0]  # number of reference descriptor keypoints

    ref_descriptors = torch.Tensor(ref_descriptors).cuda()  # put them on the GPU

    print("ref_descriptors\n", ref_descriptors)
    print("ref_descriptors.shape", ref_descriptors.shape)

    # create the environment
    env = DrakeMugsEnv(env_config)
    env.reset()

    T_world_camera = env.camera_pose(camera_name)
    camera_K_matrix = env.camera_K_matrix(camera_name)

    # create another environment for doing rollouts
    env2 = DrakeMugsEnv(env_config, visualize=False)
    env2.reset()

    action_function = ActionFunctionFactory.function_from_config(config)
    observation_function = ObservationFunctionFactory.drake_pusher_position_3D(config)
    visual_observation_function = \
        VisualObservationFunctionFactory.descriptor_keypoints_3D(config=config,
                                                                 camera_name=camera_name,
                                                                 model_dd=model_dd,
                                                                 ref_descriptors=ref_descriptors,
                                                                 K_matrix=camera_K_matrix,
                                                                 T_world_camera=T_world_camera,
                                                                 )

    episode = OnlineEpisodeReader()
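    # The input builder assembles dynamics-model inputs from raw observations:
    # the descriptor keypoints provide the visual (object) state, the observation
    # function provides the pusher state, and the action function encodes actions.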
    mpc_input_builder = DynamicsModelInputBuilder(observation_function=observation_function,
                                                  visual_observation_function=visual_observation_function,
                                                  action_function=action_function,
                                                  episode=episode)

    vis = meshcat_utils.make_default_visualizer_object()
    vis.delete()

    reset_environment(env, initial_cond['q_pusher'], initial_cond['q_slider'])
    obs_init = env.get_observation()

    #### GROUND TRUTH SIMULATOR ROLLOUT + GOAL KEYPOINTS ############
    reset_environment(env, initial_cond['q_pusher'], initial_cond['q_slider'])
    # seed the episode with n_history zero-action observations
    episode.clear()
    for i in range(n_history):
        action_zero = np.zeros(2)
        obs_tmp = env.get_observation()
        episode.add_observation_action(obs_tmp, action_zero)

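    # goal_func maps a raw observation to the model's latent object state
    # z_object; it is used to compute both the initial and the goal keypoints.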
    def goal_func(obs_tmp):
        state_tmp = mpc_input_builder.get_state_input_single_timestep({'observation': obs_tmp})['state']
        return model.compute_z_state(state_tmp.unsqueeze(0))['z_object'].flatten()


    # compute the initial object keypoints from the latest (zero-action) observation
    idx = episode.get_latest_idx()
    obs_raw = episode.get_observation(idx)
    z_object_goal = goal_func(obs_raw)
    z_keypoints_init_W = keypoints_3D_from_dynamics_model_output(z_object_goal, K)
    z_keypoints_init_W = torch_utils.cast_to_numpy(z_keypoints_init_W)

    z_keypoints_obj = keypoints_world_frame_to_object_frame(z_keypoints_init_W,
                                                          T_W_obj=slider_pose_from_observation(obs_init))

    color = [1, 0, 0]
    meshcat_utils.visualize_points(vis=vis,
                                   name="keypoints_W",
                                   pts=z_keypoints_init_W,
                                   color=color,
                                   size=0.02,
                                   )

    # input("press Enter to continue")

    # rollout single action sequence using the simulator
    action_sequence_np = torch_utils.cast_to_numpy(initial_cond['action_sequence'])
    N = action_sequence_np.shape[0]
    obs_rollout_gt = env_utils.rollout_action_sequence(env, action_sequence_np)[
        'observations']

    # using the vision model to get "goal" keypoints
    z_object_goal = goal_func(obs_rollout_gt[-1])
    z_object_goal_np = torch_utils.cast_to_numpy(z_object_goal)
    z_keypoints_goal = keypoints_3D_from_dynamics_model_output(z_object_goal, K)
    z_keypoints_goal = torch_utils.cast_to_numpy(z_keypoints_goal)

    # visualize goal keypoints
    color = [0, 1, 0]
    meshcat_utils.visualize_points(vis=vis,
                                   name="goal_keypoints",
                                   pts=z_keypoints_goal,
                                   color=color,
                                   size=0.02,
                                   )

    # input("press Enter to continue")

    #### ROLLOUT USING LEARNED MODEL + GROUND TRUTH ACTIONS ############
    reset_environment(env, initial_cond['q_pusher'], initial_cond['q_slider'])
    # seed the episode with n_history zero-action observations
    episode.clear()
    for i in range(n_history):
        action_zero = np.zeros(2)
        obs_tmp = env.get_observation()
        episode.add_observation_action(obs_tmp, action_zero)

    # [n_history, state_dim]
    idx = episode.get_latest_idx()

    dyna_net_input = mpc_input_builder.get_dynamics_model_input(idx, n_history=n_history)
    state_init = dyna_net_input['states'].cuda() # [n_history, state_dim]
    action_init = dyna_net_input['actions'] # [n_history, action_dim]


    print("state_init.shape", state_init.shape)
    print("action_init.shape", action_init.shape)


    action_seq_gt_torch = torch_utils.cast_to_torch(initial_cond['action_sequence'])
    action_input = torch.cat((action_init[:(n_history-1)], action_seq_gt_torch), dim=0).cuda()
    print("action_input.shape", action_input.shape)


    # rollout using the ground truth actions and learned model
    # need to add the batch dim to do that
    z_init = model.compute_z_state(state_init)['z']
    rollout_pred = rollout_model(state_init=z_init.unsqueeze(0),
                                 action_seq=action_input.unsqueeze(0),
                                 dynamics_net=model,
                                 compute_debug_data=True)

    state_pred_rollout = rollout_pred['state_pred']

    print("state_pred_rollout.shape", state_pred_rollout.shape)

    for i in range(N):
        # vis GT for now
        name = "GT_3D/%d" % (i)
        T_W_obj = slider_pose_from_observation(obs_rollout_gt[i])
        # print("T_W_obj", T_W_obj)

        # green
        color = np.array([0, 1, 0]) * get_color_intensity(i, N)
        meshcat_utils.visualize_points(vis=vis,
                                       name=name,
                                       pts=z_keypoints_obj,
                                       color=color,
                                       size=0.01,
                                       T=T_W_obj)

        # blue
        color = np.array([0, 0, 1]) * get_color_intensity(i, N)
        state_pred = state_pred_rollout[:, i, :]
        pts_pred = keypoints_3D_from_dynamics_model_output(state_pred, K).squeeze()
        pts_pred = pts_pred.detach().cpu().numpy()
        name = "pred_3D/%d" % (i)
        meshcat_utils.visualize_points(vis=vis,
                                       name=name,
                                       pts=pts_pred,
                                       color=color,
                                       size=0.01,
                                       )

    # input("finished visualizing GT rollout\npress Enter to continue")
    index_dict = get_object_and_robot_state_indices(config)
    object_indices = index_dict['object_indices']

    # reset the environment and use the MPC controller to stabilize this
    # now setup the MPC to try to stabilize this . . . .
    reset_environment(env, initial_cond['q_pusher'], initial_cond['q_slider'])
    episode.clear()

    # seed the episode with n_history zero-action observations
    for i in range(n_history):
        action_zero = np.zeros(2)
        obs_tmp = env.get_observation()
        episode.add_observation_action(obs_tmp, action_zero)

    # input("press Enter to continue")

    # make a planner config
    planner_config = copy.copy(config)
    config_tmp = load_yaml(os.path.join(get_project_root(), 'experiments/drake_pusher_slider/eval_config.yaml'))
    planner_config['mpc'] = config_tmp['mpc']
    planner = None
    if PLANNER_TYPE == "random_shooting":
        planner = RandomShootingPlanner(planner_config)
    elif PLANNER_TYPE == "mppi":
        planner = PlannerMPPI(planner_config)
    else:
        raise ValueError("unknown planner type: %s" % (PLANNER_TYPE))

    mpc_out = None
    action_seq_mpc = None
    state_pred_mpc = None
    counter = -1
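
    # Receding-horizon control loop: replan with the learned model (every step if
    # REPLAN is set, otherwise only on the first step), execute the first planned
    # action in the simulator, record the new observation, and shrink the lookahead
    # each step (unless a fixed MPC horizon is used) until the horizon is exhausted.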
    while True:
        counter += 1
        print("\n\n-----Running MPC Optimization: Counter (%d)-------" % (counter))

        obs_cur = env.get_observation()
        episode.add_observation_only(obs_cur)

        if counter == 0 or REPLAN:
            print("replanning")
            ####### Run the MPC ##########

            # [1, state_dim]

            n_look_ahead = N - counter
            if USE_FIXED_MPC_HORIZON:
                n_look_ahead = MPC_HORIZON
            if n_look_ahead == 0:
                break

            # start_time = time.time()
            # idx of current observation
            idx = episode.get_latest_idx()
            mpc_start_time = time.time()
            mpc_input_data = mpc_input_builder.get_dynamics_model_input(idx, n_history=n_history)
            state_cur = mpc_input_data['states']    # [n_history, state_dim]
            action_his = mpc_input_data['actions']  # [n_history, action_dim]

            if mpc_out is not None:
                action_seq_rollout_init = mpc_out['action_seq'][1:]
            else:
                action_seq_rollout_init = None

            # encode the current state and run the planner (MPPI or random shooting)
            z_cur = None
            with torch.no_grad():
                z_cur = model.compute_z_state(state_cur.unsqueeze(0).cuda())['z'].squeeze(0)



            mpc_out = planner.trajectory_optimization(state_cur=z_cur,
                                                      action_his=action_his,
                                                      obs_goal=z_object_goal_np,
                                                      model_dy=model,
                                                      action_seq_rollout_init=action_seq_rollout_init,
                                                      n_look_ahead=n_look_ahead,
                                                      eval_indices=object_indices,
                                                      rollout_best_action_sequence=True,
                                                      verbose=True,
                                                      )

            print("MPC step took %.4f seconds" %(time.time() - mpc_start_time))
            action_seq_mpc = mpc_out['action_seq'].cpu().numpy()


        # Rollout with ground truth simulator dynamics
        action_seq_mpc = torch_utils.cast_to_numpy(mpc_out['action_seq'])
        env2.set_simulator_state_from_observation_dict(env2.get_mutable_context(), obs_cur)
        obs_mpc_gt = env_utils.rollout_action_sequence(env2, action_seq_mpc)['observations']
        state_pred_mpc = torch_utils.cast_to_numpy(mpc_out['state_pred'])
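
        # Visualize the planner's predicted keypoints ("mpc_3D") next to the
        # keypoints obtained by rolling the same action sequence out in the
        # ground-truth simulator ("mpc_GT_3D").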

        vis['mpc_3D'].delete()
        vis['mpc_GT_3D'].delete()

        L = len(obs_mpc_gt)
        print("L", L)
        if L == 0:
            break
        for i in range(L):
            # red
            color = np.array([1, 0, 0]) * get_color_intensity(i, L)
            state_pred = state_pred_mpc[i, :]
            state_pred = np.expand_dims(state_pred, 0)  # add batch dim expected by keypoints_3D_from_dynamics_model_output
            pts_pred = keypoints_3D_from_dynamics_model_output(state_pred, K).squeeze()
            name = "mpc_3D/%d" % (i)
            meshcat_utils.visualize_points(vis=vis,
                                           name=name,
                                           pts=pts_pred,
                                           color=color,
                                           size=0.01,
                                           )

            # ground truth rollout of the MPC action_seq
            name = "mpc_GT_3D/%d" % (i)
            T_W_obj = slider_pose_from_observation(obs_mpc_gt[i])

            # yellow
            color = np.array([1, 1, 0]) * get_color_intensity(i, L)
            meshcat_utils.visualize_points(vis=vis,
                                           name=name,
                                           pts=z_keypoints_obj,
                                           color=color,
                                           size=0.01,
                                           T=T_W_obj)

        action_cur = action_seq_mpc[0]

        print("action_cur", action_cur)
        # print("action_GT", initial_cond['action'])
        input("press Enter to continue")

        # add observation actions to the episode
        obs_cur = env.get_observation()
        episode.replace_observation_action(obs_cur, action_cur)

        # step the simulator
        env.step(action_cur)

        # visualize current keypoint positions
        obs_cur = env.get_observation()
        T_W_obj = slider_pose_from_observation(obs_cur)

        # yellow
        color = np.array([1, 1, 0])
        meshcat_utils.visualize_points(vis=vis,
                                       name="keypoint_cur",
                                       pts=z_keypoints_obj,
                                       color=color,
                                       size=0.02,
                                       T=T_W_obj)

        action_seq_mpc = action_seq_mpc[1:]
        state_pred_mpc = state_pred_mpc[1:]

    obs_final = env.get_observation()

    pose_error = compute_pose_error(obs_rollout_gt[-1],
                                    obs_final)

    print("position_error: %.3f"  %(pose_error['position_error']))
    print("angle error degrees: %.3f" %(pose_error['angle_error_degrees']))
# data_file is assumed to be defined elsewhere (path to a saved plan pickle)
data = load_pickle(data_file)
pts = data['plan']['plan_data'][-1]['dynamics_model_input_data'][
    'visual_observation']['pts_W']
print("pts\n", pts)

centroid = np.mean(pts, axis=0)
pts_centered = pts - centroid
save_data = {'object_points': pts_centered.tolist()}
save_file = "object_points_master.yaml"
save_yaml(save_data, save_file)

# do some meshcat debug
vis = meshcat_utils.make_default_visualizer_object()
meshcat_utils.visualize_points(vis,
                               "object_points_centered",
                               pts_centered,
                               color=[0, 0, 255],
                               size=0.01)

meshcat_utils.visualize_points(vis,
                               "object_points_world",
                               pts,
                               color=[0, 255, 0],
                               size=0.01)

# do Procrustes alignment
T = transform_utils.procrustes_alignment(pts_centered, pts)
pts_transformed = transform_utils.transform_points_3D(T, pts_centered)

meshcat_utils.visualize_points(vis,
                               "object_points_aligned",
                               pts_transformed,
                               # the original snippet is truncated at this call;
                               # the remaining arguments are assumed
                               color=[255, 0, 0],
                               size=0.01)


def visualize_episode_data_single_timestep(
    vis,
    dataset,
    camera_name,
    episode,
    episode_idx,
    display_idx,
):

    image_episode_idx = episode.image_episode_idx_from_query_idx(episode_idx)
    image_data = episode.image_episode.get_image_data(camera_name=camera_name,
                                                      idx=image_episode_idx)

    depth = image_data['depth_int16'] / DEPTH_IM_SCALE

    # pointcloud
    name = "pointclouds/%d" % (display_idx)
    pdc_meshcat_utils.visualize_pointcloud(
        vis,
        name,
        depth=depth,
        K=image_data['K'],
        rgb=image_data['rgb'],
        T_world_camera=image_data['T_world_camera'])

    data = dataset._getitem(episode,
                            episode_idx,
                            n_history=1,
                            rollout_length=0)
    # action position
    # name = "actions/%d" %(display_idx)
    # actions = torch_utils.cast_to_numpy(data['actions'])
    #
    # meshcat_utils.visualize_points(vis,
    #                                name,
    #                                actions,
    #                                color=[255,0,0],
    #                                size=0.01,
    #                                )

    # observation position
    name = "observations/%d" % (display_idx)
    observations = torch_utils.cast_to_numpy(data['observations'])
    # print("observations.shape", observations.shape)
    # print("observations", observations)
    meshcat_utils.visualize_points(
        vis,
        name,
        observations,
        color=[0, 255, 0],
        size=0.01,
    )

    # keypoints
    if True:
        name = "keypoints/%d" % (display_idx)
        keypoints = torch_utils.cast_to_numpy(
            data['visual_observation_func_collated']['keypoints_3d'][0])
        # print("keypoints.shape", keypoints.shape)
        meshcat_utils.visualize_points(
            vis,
            name,
            keypoints,
            color=[0, 255, 0],
            size=0.01,
        )