示例#1
0
def visualize_autoencoder_result(num_samples=10):
    """Plot autoencoder predictions for randomly chosen dataset samples.

    Loads the model and dataset via ``load_model()``. For each of
    ``num_samples`` random samples it runs the model on
    ``data['input_tensor']`` and shows three images side by side:

      * the input image with the predicted keypoints
        (``out['expected_xy']``) drawn as reticles,
      * the target image (``data['target_tensor']``),
      * the predicted image (``out['output']``).

    Shapes, dtype and value range of the prediction are printed for
    debugging.

    :param num_samples: number of random samples to show (default 10)
    :type num_samples: int
    """
    import torch  # needed for torch.no_grad below

    d = load_model()
    dataset = d['dataset']
    model = d['model']

    # NOTE(review): the model is kept in train mode. If it contains dropout /
    # batch-norm layers this changes its outputs; eval() may be intended.
    model = model.train()
    model = model.cuda()

    for _ in range(num_samples):
        idx = np.random.randint(0, len(dataset))
        data = dataset[idx]

        # add a batch dimension; named so it does not shadow builtin `input`
        input_tensor = data['input_tensor'].unsqueeze(0).cuda()

        # inference only: don't record an autograd graph
        with torch.no_grad():
            out = model(input_tensor)

        target_tensor = data['target_tensor'].squeeze()
        target_pred = out['output'].squeeze()
        keypoints_xy = out['expected_xy'].squeeze()

        print("keypoints_xy[0]", keypoints_xy[0])

        print("target_tensor.shape", target_tensor.shape)
        print("target_pred.shape", target_pred.shape)

        print("target_pred.dtype", target_pred.dtype)
        print("target_pred.max()", target_pred.max())
        print("target_pred.min()", target_pred.min())

        target_tensor_np = convert_float_image_to_uint8(torch_utils.cast_to_numpy(target_tensor))
        target_pred_np = convert_float_image_to_uint8(torch_utils.cast_to_numpy(target_pred))

        print(target_pred_np.shape)
        print(target_tensor_np.shape)

        # draw reticles for the predicted keypoints on a copy of the input image
        H, W, _ = data['input'].shape
        keypoints_uv = torch_utils.convert_xy_to_uv_coordinates(keypoints_xy, H=H, W=W)

        input_wr = np.copy(data['input'])
        draw_reticles(input_wr,
                      u_vec=keypoints_uv[:, 0],
                      v_vec=keypoints_uv[:, 1],
                      label_color=[0, 255, 0])

        print("type(input_wr)", type(input_wr))

        figsize = 2 * np.array([4.8, 6.4])
        fig = plt.figure(figsize=figsize)
        ax = fig.subplots(1, 3)
        ax[0].imshow(input_wr)
        ax[1].imshow(target_tensor_np, cmap='gray', vmin=0, vmax=255)
        ax[2].imshow(target_pred_np, cmap='gray', vmin=0, vmax=255)
        plt.show()
def visualize_model_prediction_single_timestep(
    vis,
    config,
    z_pred,  # [z_dim]
    display_idx,
    name_prefix=None,
    color=None,
    display_robot_state=True,
):
    """Draw the object (and optionally robot) part of a latent state in meshcat.

    ``z_pred`` is split using the indices returned by
    ``get_object_and_robot_state_indices(config)``. Each part is reshaped to
    ``config['dataset']['object_state_shape']`` /
    ``config['dataset']['robot_state_shape']`` and drawn as points of size
    0.01 under ``[name_prefix/]z_object/<display_idx>`` and
    ``[name_prefix/]z_robot/<display_idx>``.

    :param vis: meshcat visualizer
    :param config: config dict (needs the ``dataset`` shape entries)
    :param z_pred: flat latent state
    :param display_idx: integer used to build unique meshcat names
    :param name_prefix: optional prefix prepended to the meshcat names
    :param color: point color, ``[0, 0, 255]`` when not given
    :param display_robot_state: also draw the robot part when True
    """
    point_color = [0, 0, 255] if color is None else color

    state_indices = get_object_and_robot_state_indices(config)
    dataset_config = config['dataset']
    object_points = z_pred[state_indices['object_indices']].reshape(
        dataset_config['object_state_shape'])
    robot_points = z_pred[state_indices['robot_indices']].reshape(
        dataset_config['robot_state_shape'])

    def _draw(name_template, points):
        # build the meshcat path, optionally nested under name_prefix
        path = name_template % (display_idx)
        if name_prefix is not None:
            path = name_prefix + "/" + path
        meshcat_utils.visualize_points(
            vis,
            path,
            torch_utils.cast_to_numpy(points),
            color=point_color,
            size=0.01,
        )

    _draw("z_object/%d", object_points)
    if display_robot_state:
        _draw("z_robot/%d", robot_points)
示例#3
0
def _reset_and_prime_episode(env, episode, initial_cond, n_history):
    """Reset ``env`` to ``initial_cond`` and refill ``episode`` with a history.

    The episode is cleared, then ``n_history`` entries are added, each pairing
    the current observation (queried without stepping the simulator) with a
    zero action ``np.zeros(2)``.
    """
    reset_environment(env, initial_cond['q_pusher'], initial_cond['q_slider'])
    episode.clear()
    for _ in range(n_history):
        obs_tmp = env.get_observation()
        episode.add_observation_action(obs_tmp, np.zeros(2))


def main():
    """Interactive MPC demo on the Drake mugs environment.

    1. Load the dynamics model (``model_dy``) and descriptor model
       (``model_dd``) and build the environments / input builder.
    2. Compute keypoints of the initial state, and of the goal state obtained
       by rolling out the ground-truth action sequence in the simulator;
       draw both in meshcat.
    3. Roll out the learned model with the ground-truth actions and draw the
       predicted keypoints next to the simulated ones.
    4. Run MPC (planner chosen by ``PLANNER_TYPE``), pausing for user input
       after each step, then print the final pose error w.r.t. the goal.
       With ``REPLAN`` False the first plan is executed step by step and the
       loop ends once that plan has been consumed.
    """
    # load dynamics model
    model_dict = load_model_state_dict()
    model = model_dict['model_dy']
    model_dd = model_dict['model_dd']
    config = model.config

    env_config = load_yaml(os.path.join(get_project_root(), 'experiments/exp_20_mugs/config.yaml'))
    env_config['env']['observation']['depth_int16'] = True
    n_history = config['train']['n_history']

    initial_cond = generate_initial_condition(env_config, push_length=PUSH_LENGTH)
    env_config = initial_cond['config']

    # reference descriptors used by the visual observation function
    camera_name = model_dict['metadata']['camera_name']
    spatial_descriptor_data = model_dict['spatial_descriptor_data']
    ref_descriptors = spatial_descriptor_data['spatial_descriptors']
    K = ref_descriptors.shape[0]  # number of reference descriptors (presumably one per keypoint)

    ref_descriptors = torch.Tensor(ref_descriptors).cuda()  # put them on the GPU

    print("ref_descriptors\n", ref_descriptors)
    print("ref_descriptors.shape", ref_descriptors.shape)

    # create the environment
    env = DrakeMugsEnv(env_config)
    env.reset()

    T_world_camera = env.camera_pose(camera_name)
    camera_K_matrix = env.camera_K_matrix(camera_name)

    # create another environment for doing rollouts
    env2 = DrakeMugsEnv(env_config, visualize=False)
    env2.reset()

    action_function = ActionFunctionFactory.function_from_config(config)
    observation_function = ObservationFunctionFactory.drake_pusher_position_3D(config)
    visual_observation_function = \
        VisualObservationFunctionFactory.descriptor_keypoints_3D(config=config,
                                                                 camera_name=camera_name,
                                                                 model_dd=model_dd,
                                                                 ref_descriptors=ref_descriptors,
                                                                 K_matrix=camera_K_matrix,
                                                                 T_world_camera=T_world_camera,
                                                                 )

    episode = OnlineEpisodeReader()
    mpc_input_builder = DynamicsModelInputBuilder(observation_function=observation_function,
                                                  visual_observation_function=visual_observation_function,
                                                  action_function=action_function,
                                                  episode=episode)

    vis = meshcat_utils.make_default_visualizer_object()
    vis.delete()

    reset_environment(env, initial_cond['q_pusher'], initial_cond['q_slider'])
    obs_init = env.get_observation()

    #### KEYPOINTS OF THE INITIAL STATE ############
    _reset_and_prime_episode(env, episode, initial_cond, n_history)

    def goal_func(obs_tmp):
        # observation -> flattened object part of the latent state
        state_tmp = mpc_input_builder.get_state_input_single_timestep({'observation': obs_tmp})['state']
        return model.compute_z_state(state_tmp.unsqueeze(0))['z_object'].flatten()

    idx = episode.get_latest_idx()
    obs_raw = episode.get_observation(idx)
    z_object_init = goal_func(obs_raw)
    z_keypoints_init_W = keypoints_3D_from_dynamics_model_output(z_object_init, K)
    z_keypoints_init_W = torch_utils.cast_to_numpy(z_keypoints_init_W)

    # initial keypoints expressed in the slider frame, so they can be drawn at
    # any simulated slider pose below
    z_keypoints_obj = keypoints_world_frame_to_object_frame(z_keypoints_init_W,
                                                          T_W_obj=slider_pose_from_observation(obs_init))

    color = [1, 0, 0]
    meshcat_utils.visualize_points(vis=vis,
                                   name="keypoints_W",
                                   pts=z_keypoints_init_W,
                                   color=color,
                                   size=0.02,
                                   )

    #### GOAL FROM GROUND-TRUTH SIMULATOR ROLLOUT ############
    action_sequence_np = torch_utils.cast_to_numpy(initial_cond['action_sequence'])
    N = action_sequence_np.shape[0]
    obs_rollout_gt = env_utils.rollout_action_sequence(env, action_sequence_np)[
        'observations']

    # using the vision model to get "goal" keypoints
    z_object_goal = goal_func(obs_rollout_gt[-1])
    z_object_goal_np = torch_utils.cast_to_numpy(z_object_goal)
    z_keypoints_goal = keypoints_3D_from_dynamics_model_output(z_object_goal, K)
    z_keypoints_goal = torch_utils.cast_to_numpy(z_keypoints_goal)

    # visualize goal keypoints
    color = [0, 1, 0]
    meshcat_utils.visualize_points(vis=vis,
                                   name="goal_keypoints",
                                   pts=z_keypoints_goal,
                                   color=color,
                                   size=0.02,
                                   )

    #### ROLLOUT USING LEARNED MODEL + GROUND TRUTH ACTIONS ############
    _reset_and_prime_episode(env, episode, initial_cond, n_history)

    idx = episode.get_latest_idx()
    dyna_net_input = mpc_input_builder.get_dynamics_model_input(idx, n_history=n_history)
    state_init = dyna_net_input['states'].cuda()  # [n_history, state_dim]
    action_init = dyna_net_input['actions']  # [n_history, action_dim]

    print("state_init.shape", state_init.shape)
    print("action_init.shape", action_init.shape)

    action_seq_gt_torch = torch_utils.cast_to_torch(initial_cond['action_sequence'])
    action_input = torch.cat((action_init[:(n_history - 1)], action_seq_gt_torch), dim=0).cuda()
    print("action_input.shape", action_input.shape)

    # rollout using the ground truth actions and learned model
    # need to add the batch dim to do that
    z_init = model.compute_z_state(state_init)['z']
    rollout_pred = rollout_model(state_init=z_init.unsqueeze(0),
                                 action_seq=action_input.unsqueeze(0),
                                 dynamics_net=model,
                                 compute_debug_data=True)

    state_pred_rollout = rollout_pred['state_pred']

    print("state_pred_rollout.shape", state_pred_rollout.shape)

    for i in range(N):
        # ground truth: object-frame keypoints placed at the simulated pose
        name = "GT_3D/%d" % (i)
        T_W_obj = slider_pose_from_observation(obs_rollout_gt[i])
        color = np.array([0, 1, 0]) * get_color_intensity(i, N)
        meshcat_utils.visualize_points(vis=vis,
                                       name=name,
                                       pts=z_keypoints_obj,
                                       color=color,
                                       size=0.01,
                                       T=T_W_obj)

        # keypoints predicted by the learned model
        color = np.array([0, 0, 1]) * get_color_intensity(i, N)
        state_pred = state_pred_rollout[:, i, :]
        pts_pred = keypoints_3D_from_dynamics_model_output(state_pred, K).squeeze()
        pts_pred = pts_pred.detach().cpu().numpy()
        name = "pred_3D/%d" % (i)
        meshcat_utils.visualize_points(vis=vis,
                                       name=name,
                                       pts=pts_pred,
                                       color=color,
                                       size=0.01,
                                       )

    index_dict = get_object_and_robot_state_indices(config)
    object_indices = index_dict['object_indices']

    #### MPC TOWARDS THE GOAL ############
    _reset_and_prime_episode(env, episode, initial_cond, n_history)

    # make a planner config
    planner_config = copy.copy(config)
    config_tmp = load_yaml(os.path.join(get_project_root(), 'experiments/drake_pusher_slider/eval_config.yaml'))
    planner_config['mpc'] = config_tmp['mpc']
    if PLANNER_TYPE == "random_shooting":
        planner = RandomShootingPlanner(planner_config)
    elif PLANNER_TYPE == "mppi":
        planner = PlannerMPPI(planner_config)
    else:
        raise ValueError("unknown planner type: %s" % (PLANNER_TYPE))

    mpc_out = None
    action_seq_mpc = None  # remaining planned actions (numpy)
    state_pred_mpc = None  # matching predicted latent states (numpy)
    counter = -1
    while True:
        counter += 1
        print("\n\n-----Running MPC Optimization: Counter (%d)-------" % (counter))

        obs_cur = env.get_observation()
        episode.add_observation_only(obs_cur)

        if counter == 0 or REPLAN:
            print("replanning")
            ####### Run the MPC ##########
            n_look_ahead = N - counter
            if USE_FIXED_MPC_HORIZON:
                n_look_ahead = MPC_HORIZON
            if n_look_ahead == 0:
                break

            # idx of current observation
            idx = episode.get_latest_idx()
            mpc_start_time = time.time()
            mpc_input_data = mpc_input_builder.get_dynamics_model_input(idx, n_history=n_history)
            state_cur = mpc_input_data['states']
            action_his = mpc_input_data['actions']

            # warm start from the previous plan minus the action already applied
            if mpc_out is not None:
                action_seq_rollout_init = mpc_out['action_seq'][1:]
            else:
                action_seq_rollout_init = None

            with torch.no_grad():
                z_cur = model.compute_z_state(state_cur.unsqueeze(0).cuda())['z'].squeeze(0)

            mpc_out = planner.trajectory_optimization(state_cur=z_cur,
                                                      action_his=action_his,
                                                      obs_goal=z_object_goal_np,
                                                      model_dy=model,
                                                      action_seq_rollout_init=action_seq_rollout_init,
                                                      n_look_ahead=n_look_ahead,
                                                      eval_indices=object_indices,
                                                      rollout_best_action_sequence=True,
                                                      verbose=True,
                                                      )

            print("MPC step took %.4f seconds" % (time.time() - mpc_start_time))

            # Refresh the plan only after replanning. Without replanning we
            # keep consuming the remainder (sliced at the end of the loop);
            # re-reading mpc_out here would re-apply the first action forever.
            action_seq_mpc = torch_utils.cast_to_numpy(mpc_out['action_seq'])
            state_pred_mpc = torch_utils.cast_to_numpy(mpc_out['state_pred'])

        # the current plan has been fully executed (only when not replanning)
        if len(action_seq_mpc) == 0:
            break

        # Rollout with ground truth simulator dynamics
        env2.set_simulator_state_from_observation_dict(env2.get_mutable_context(), obs_cur)
        obs_mpc_gt = env_utils.rollout_action_sequence(env2, action_seq_mpc)['observations']

        vis['mpc_3D'].delete()
        vis['mpc_GT_3D'].delete()

        L = len(obs_mpc_gt)
        print("L", L)
        if L == 0:
            break
        for i in range(L):
            # keypoints predicted by the learned model along the plan
            color = np.array([1, 0, 0]) * get_color_intensity(i, L)
            state_pred = state_pred_mpc[i, :]
            state_pred = np.expand_dims(state_pred, 0)  # add a batch dimension
            pts_pred = keypoints_3D_from_dynamics_model_output(state_pred, K).squeeze()
            name = "mpc_3D/%d" % (i)
            meshcat_utils.visualize_points(vis=vis,
                                           name=name,
                                           pts=pts_pred,
                                           color=color,
                                           size=0.01,
                                           )

            # ground truth rollout of the MPC action_seq
            name = "mpc_GT_3D/%d" % (i)
            T_W_obj = slider_pose_from_observation(obs_mpc_gt[i])
            color = np.array([1, 1, 0]) * get_color_intensity(i, L)
            meshcat_utils.visualize_points(vis=vis,
                                           name=name,
                                           pts=z_keypoints_obj,
                                           color=color,
                                           size=0.01,
                                           T=T_W_obj)

        action_cur = action_seq_mpc[0]

        print("action_cur", action_cur)
        input("press Enter to continue")

        # record the action about to be applied in the episode
        obs_cur = env.get_observation()
        episode.replace_observation_action(obs_cur, action_cur)

        # step the simulator
        env.step(action_cur)

        # visualize current keypoint positions
        obs_cur = env.get_observation()
        T_W_obj = slider_pose_from_observation(obs_cur)
        color = np.array([1, 1, 0])
        meshcat_utils.visualize_points(vis=vis,
                                       name="keypoint_cur",
                                       pts=z_keypoints_obj,
                                       color=color,
                                       size=0.02,
                                       T=T_W_obj)

        # drop the action just applied (and its predicted state)
        action_seq_mpc = action_seq_mpc[1:]
        state_pred_mpc = state_pred_mpc[1:]

    obs_final = env.get_observation()

    pose_error = compute_pose_error(obs_rollout_gt[-1],
                                    obs_final)

    print("position_error: %.3f" % (pose_error['position_error']))
    print("angle error degrees: %.3f" % (pose_error['angle_error_degrees']))
def visualize_episode_data_single_timestep(
    vis,
    dataset,
    camera_name,
    episode,
    episode_idx,
    display_idx,
    display_keypoints=True,
):
    """Draw the data of a single episode timestep in meshcat.

    Draws, with names suffixed by ``display_idx``:

      * ``pointclouds/<i>``: point cloud from ``camera_name``'s depth image,
        colored with its RGB image,
      * ``observations/<i>``: the dataset item's ``observations``,
      * ``keypoints/<i>`` (only if ``display_keypoints``): the first entry of
        ``data['visual_observation_func_collated']['keypoints_3d']``.

    :param vis: meshcat visualizer
    :param dataset: dataset whose ``_getitem`` builds the item for this index
    :param camera_name: camera whose images are used for the point cloud
    :param episode: episode object providing the image data
    :param episode_idx: query index into ``episode``
    :param display_idx: integer used to build unique meshcat names
    :param display_keypoints: draw the 3D keypoints as well (default True)
    """
    image_episode_idx = episode.image_episode_idx_from_query_idx(episode_idx)
    image_data = episode.image_episode.get_image_data(camera_name=camera_name,
                                                      idx=image_episode_idx)

    # NOTE(review): DEPTH_IM_SCALE presumably converts int16 depth to meters - confirm
    depth = image_data['depth_int16'] / DEPTH_IM_SCALE

    # pointcloud
    name = "pointclouds/%d" % (display_idx)
    pdc_meshcat_utils.visualize_pointcloud(
        vis,
        name,
        depth=depth,
        K=image_data['K'],
        rgb=image_data['rgb'],
        T_world_camera=image_data['T_world_camera'])

    # dataset item for this index with n_history=1 and rollout_length=0
    data = dataset._getitem(episode,
                            episode_idx,
                            n_history=1,
                            rollout_length=0)

    # observation position
    name = "observations/%d" % (display_idx)
    observations = torch_utils.cast_to_numpy(data['observations'])
    meshcat_utils.visualize_points(
        vis,
        name,
        observations,
        color=[0, 255, 0],
        size=0.01,
    )

    # keypoints
    if display_keypoints:
        name = "keypoints/%d" % (display_idx)
        keypoints = torch_utils.cast_to_numpy(
            data['visual_observation_func_collated']['keypoints_3d'][0])
        meshcat_utils.visualize_points(
            vis,
            name,
            keypoints,
            color=[0, 255, 0],
            size=0.01,
        )
def _log_summary_statistics(writer, pandas_data_list, episode_length):
    """Log mean/median angle and position errors for one episode length.

    Each value is written at steps 0..9 (presumably so TensorBoard draws a
    line rather than a single point). Nothing is logged when no results have
    been recorded yet.
    """
    df_tmp = pd.DataFrame(pandas_data_list)
    if df_tmp.empty:
        # no rows -> no columns to aggregate over
        return
    df_len = df_tmp[df_tmp.episode_length == episode_length]
    for key in ["angle_error_degrees", "position_error"]:
        # loop-invariant: compute once, log repeatedly
        mean = df_len[key].mean()
        median = df_len[key].median()
        plot_name_mean = "mean/%s/episode_len_%d" % (key, episode_length)
        plot_name_median = "median/%s/episode_len_%d" % (key, episode_length)
        for i in range(10):
            writer.add_scalar(plot_name_mean, mean, i)
            writer.add_scalar(plot_name_median, median, i)


def evaluate_mpc(
    model_dy,  # dynamics model
    env,  # the environment
    episode,  # OnlineEpisodeReader
    mpc_input_builder,  # DynamicsModelInputBuilder
    planner,  # RandomShooting planner
    eval_indices=None,
    goal_func=None,  # function that gets goal from observation
    config=None,
    wait_for_user_input=False,
    save_dir=None,
    model_name="",
    experiment_name="",
    generate_initial_condition_func=None,
    # (required) function to generate initial condition, takes episode length N as parameter
):
    """Evaluate the MPC controller over many generated episodes.

    For every length in ``config['eval']['episode_length']``, initial
    conditions are generated with
    ``generate_initial_condition_func(N=episode_length)``, seeding with
    1, 2, 3, ... for repeatability. Each is run through
    ``mpc_single_episode`` until ``config['eval']['num_episodes']`` valid
    episodes have been collected; invalid episodes are skipped and do not
    count.

    Writes ``config.yaml``, TensorBoard scalars (per-episode pose errors and
    per-length mean/median of angle and position errors) and ``data.csv``
    with one row per valid episode into ``save_dir``.

    NOTE: the loop for one episode length only ends after ``num_episodes``
    valid episodes; if every generated episode is invalid it never ends.

    :raises AssertionError: if ``generate_initial_condition_func`` is None
    """
    # must specify initial condition distribution; checked before touching disk
    assert generate_initial_condition_func is not None

    os.makedirs(save_dir, exist_ok=True)

    save_yaml(config, os.path.join(save_dir, 'config.yaml'))
    writer = SummaryWriter(log_dir=save_dir)

    pandas_data_list = []
    try:
        for episode_length in config['eval']['episode_length']:
            counter = 0
            seed = 0
            while counter < config['eval']['num_episodes']:

                start_time = time.time()
                seed += 1
                set_seed(seed)  # make it repeatable
                initial_cond = generate_initial_condition_func(N=episode_length)

                env.set_initial_condition_from_dict(initial_cond)

                action_sequence_np = torch_utils.cast_to_numpy(
                    initial_cond['action_sequence'])
                episode_data = mpc_single_episode(
                    model_dy=model_dy,
                    env=env,
                    action_sequence=action_sequence_np,
                    action_zero=np.zeros(2),
                    episode=episode,
                    mpc_input_builder=mpc_input_builder,
                    planner=planner,
                    eval_indices=eval_indices,
                    goal_func=goal_func,
                    config=config,
                    wait_for_user_input=wait_for_user_input,
                )

                # continue if invalid
                if not episode_data['valid']:
                    print("invalid episode, skipping")
                    continue

                pose_error = compute_pose_error(
                    obs=episode_data['obs_mpc_final'],
                    obs_goal=episode_data['obs_goal'],
                )

                object_delta = compute_pose_error(
                    obs=episode_data['obs_init'],
                    obs_goal=episode_data['obs_goal'])

                print("object_delta\n", object_delta)

                if wait_for_user_input:
                    print("pose_error\n", pose_error)

                pandas_data = {
                    'episode_length': episode_length,
                    # the value passed to set_seed, so the episode can be reproduced
                    'seed': seed,
                    # index among the valid episodes of this length
                    'episode_idx': counter,
                    'model_name': model_name,
                    'experiment_name': experiment_name,
                    'object_pos_delta': object_delta['position_error'],
                    'object_angle_delta': object_delta['angle_error'],
                    'object_angle_delta_degrees':
                    object_delta['angle_error_degrees'],
                }

                pandas_data.update(pose_error)
                pandas_data_list.append(pandas_data)

                # log to tensorboard
                for key, val in pose_error.items():
                    plot_name = "%s/episode_len_%d" % (key, episode_length)
                    writer.add_scalar(plot_name, val, counter)

                writer.flush()

                print("episode [%d/%d], episode_length %d, duration %.2f" %
                      (counter, config['eval']['num_episodes'], episode_length,
                       time.time() - start_time))
                counter += 1

            _log_summary_statistics(writer, pandas_data_list, episode_length)
    finally:
        # release the event file even if an episode raises
        writer.close()

    # save some data
    df = pd.DataFrame(pandas_data_list)
    df.to_csv(os.path.join(save_dir, "data.csv"))
示例#6
0
    def _on_compute_control_action(
        self,
        msg,
        visualize=True,
    ):
        """
        Computes the current control action

        :param msg:
        :type msg:
        :return:
        :rtype:
        """
        print("\n\n----------")

        start_time = time.time()

        assert self._state in [
            ControllerState.PLAN_READY, ControllerState.RUNNING
        ]

        # allow setting the visualize flag from the message
        try:
            visualize = bool(int(msg['debug']))
        except KeyError:
            pass

        # add data to the OnlineEpisodeReader
        episode = self._state_dict['episode']
        input_builder = self._state_dict['input_builder']
        plan = self._state_dict['plan']

        observation = msg['data']['observations']
        action_dict = msg['data']['actions']

        if self._state == ControllerState.PLAN_READY:
            # this is the first time through the loop
            # then add a few to populate n_his
            for i in range(self._n_history):
                # episode.add_observation_action(copy.deepcopy(observation), copy.deepcopy(action_dict))
                episode.add_data(copy.deepcopy(msg['data']))

        episode.add_data(msg['data'])

        # run the planner

        # seed with previous actions
        mpc_out = self._state_dict['mpc_out']
        n_look_ahead = None  # this is the MPC horizon
        mpc_horizon_type = self._planner.config['mpc']['hardware'][
            'mpc_horizon']["type"]
        mpc_hardware_config = self._planner.config['mpc']['hardware']

        # previous actions to seed with
        # compute the MPC horizon
        action_seq_rollout_init = None
        if mpc_out is not None:
            # this means it's not our first time through the loop
            # the plan is already running
            action_seq_rollout_init = mpc_out['action_seq'][1:]

            if mpc_horizon_type == "PLAN_LENGTH":
                n_look_ahead = action_seq_rollout_init.shape[0]
                # this means we are at the end of the plan
                # so send the stop message
                if n_look_ahead == 0:
                    print("Plan finished, sending STOP message")
                    return self.make_stop_message()
            if mpc_horizon_type == "MIN_HORIZON":
                n_look_ahead = action_seq_rollout_init.shape[0]
                H_min = mpc_hardware_config['mpc_horizon']['H_min']
                n_look_ahead = max(H_min, n_look_ahead)
                # maybe set overwrite the config to be traj cost . . .
            elif mpc_horizon_type == "FIXED":
                n_look_ahead = mpc_hardware_config['mpc_horizon']["H_fixed"]
            else:
                raise ValueError("unknow mpc_horizon_type")

            # add zeros to the end of the action trajectory as a good
            # starting point
            # extend the previous action, either with zeros or by repeating last action
            if action_seq_rollout_init.shape[0] < n_look_ahead:
                if mpc_hardware_config['action_extension_type'] == "CONSTANT":
                    num_steps = n_look_ahead - action_seq_rollout_init.shape[0]
                    # [action_dim]
                    action_extend = action_seq_rollout_init[-1]
                    action_extend = action_extend.unsqueeze(0).expand(
                        [num_steps, -1])

                    # print("action_seq_rollout_init.shape", action_seq_rollout_init.shape)
                    # print("action_extend.shape", action_extend.shape)
                    action_seq_rollout_init = torch.cat(
                        (action_seq_rollout_init, action_extend), dim=0)
                elif mpc_hardware_config['action_extension_type'] == "ZERO":
                    num_steps = n_look_ahead - action_seq_rollout_init.shape[0]
                    action_seq_zero = torch.zeros([num_steps, 2]).to(
                        action_seq_rollout_init.device)
                    action_seq_rollout_init = torch.cat(
                        (action_seq_rollout_init, action_seq_zero), dim=0)
        else:
            if mpc_horizon_type == "FIXED":
                n_look_ahead = mpc_hardware_config['mpc_horizon']["H_fixed"]
            else:
                n_look_ahead = mpc_hardware_config['mpc_horizon']["H_init"]

            action_seq_rollout_init = None

        print("n_look_ahead", n_look_ahead)
        print("plan_counter", plan.counter)
        start_time_tmp = time.time()
        idx = episode.get_latest_idx()
        mpc_input_data = input_builder.get_dynamics_model_input(
            idx, n_history=self._n_history)
        print("computing dynamics model input took",
              time.time() - start_time_tmp)
        start_time_tmp = time.time()
        state_cur = mpc_input_data['states']
        action_his = mpc_input_data['actions']

        current_reward_data = None

        # run the planner
        with torch.no_grad():
            # z_goal_dict = plan.data[-1]['dynamics_model_input_data']['z']
            # obs_goal = plan.data[-1]['dynamics_model_input_data']['z']['z_object_flat']

            # convert state_cur to z_cur
            z_cur_dict = self._model_dy.compute_z_state(state_cur)
            z_cur = z_cur_dict['z']
            z_cur_no_his = z_cur[-1]
            print("z_cur.shape")

            # for computing current cost
            z_batch = None

            # for now we support just final state rather than trajectory costs
            z_goal = None
            if self._planner.config['mpc']['reward'][
                    "goal_type"] == "TRAJECTORY":
                z_goal = plan.get_trajectory_goal(
                    counter=self._state_dict['action_counter'],
                    n_look_ahead=n_look_ahead)

                # [1, 1, z_dim]
                z_batch = z_cur_no_his.unsqueeze(0).unsqueeze(0)

                # [1, n_look_aheda, z_dim]
                z_batch = z_batch.expand([-1, n_look_ahead, -1])
            elif self._planner.config['mpc']['reward'][
                    "goal_type"] == "FINAL_STATE":
                z_goal = plan.get_final_state_goal()

                # [1, 1, z_dim]
                z_batch = z_cur_no_his.unsqueeze(0).unsqueeze(0)
            else:
                raise ValueError("unknown goal type")

            print("z_batch.shape", z_batch.shape)

            obs_goal = torch.index_select(z_goal,
                                          dim=-1,
                                          index=self._object_indices)
            print("obs_goal.shape", obs_goal.shape)

            # compute the cost of stopping now so we can compare it to the cost
            # of running the trajectory optimization

            # note this is on the CPU
            # [1, 1, z_dim]
            z_batch = z_cur_no_his.unsqueeze(0).unsqueeze(0)
            # [goal_dim]
            obs_goal_final = torch.index_select(plan.get_final_state_goal(),
                                                dim=-1,
                                                index=self._object_indices)
            current_reward_data = \
                planner_utils.evaluate_model_rollout(state_pred=z_batch,
                                                     obs_goal=obs_goal_final,
                                                     eval_indices=self._object_indices,

                                                     **self._planner.config['mpc']['reward'])

            print("current reward:", current_reward_data['best_reward'])
            if current_reward_data['best_reward'] > mpc_hardware_config[
                    'goal_reward_threshold']:
                print("Below goal reward threshold, stopping")
                return self.make_stop_message()

            # run the planner
            mpc_out = self._planner.trajectory_optimization(
                state_cur=z_cur,
                action_his=action_his,
                obs_goal=obs_goal,
                model_dy=self._model_dy,
                action_seq_rollout_init=action_seq_rollout_init,
                n_look_ahead=n_look_ahead,
                eval_indices=self._object_indices,
                rollout_best_action_sequence=self._debug,
                verbose=self._debug,
                add_grid_action_samples=False,
            )

        # update the action that was actually applied
        # equivalent to
        action_seq_mpc = mpc_out['action_seq']
        action_dict['ee_setpoint']['setpoint_linear_velocity']['x'] = float(
            action_seq_mpc[0][0])
        action_dict['ee_setpoint']['setpoint_linear_velocity']['x'] = float(
            action_seq_mpc[0][1])

        state_pred = mpc_out['state_pred']

        if mpc_out['reward'].cpu() < current_reward_data['best_reward'].cpu(
        ) + mpc_hardware_config['reward_improvement_tol']:

            if mpc_hardware_config['terminate_if_no_improvement']:
                print("Traj opt didn't yield successive improvement, STOPPING")
                return self.make_stop_message()
            else:
                print(
                    "Traj opt didn't yield successive improvement, HOLDING STILL"
                )
                return self.make_zero_action_message()

        if self._debug:
            # pass
            print("action_seq_mpc\n", action_seq_mpc)

        if visualize:
            vis = meshcat_utils.make_default_visualizer_object()
            vis["mpc"].delete()

            visualize_model_prediction_single_timestep(vis,
                                                       self._config,
                                                       z_cur[0],
                                                       display_idx=0,
                                                       name_prefix="start",
                                                       color=[0, 0, 255])

            # visualize pointcloud of current position
            start_data = episode.get_data(0)
            depth = start_data['observations']['images'][self._camera_info[
                'camera_name']]['depth_int16'] / DEPTH_IM_SCALE
            rgb = start_data['observations']['images'][
                self._camera_info['camera_name']]['rgb']
            # pointcloud
            name = "mpc/pointclouds/start"
            pdc_meshcat_utils.visualize_pointcloud(
                vis,
                name,
                depth=depth,
                K=self._camera_info['K'],
                rgb=rgb,
                T_world_camera=self._camera_info['T_world_camera'])

            #
            # visualize_model_prediction_single_timestep(vis,
            #                                            self._config,
            #                                            z_goal_dict['z'],
            #                                            display_idx=0,
            #                                            name_prefix="goal",
            #                                            color=[0, 255, 0])
            #
            # goal_data = plan.data[-1]
            # depth = goal_data['observations']['images'][self._camera_info['camera_name']][
            #             'depth_int16'] / DEPTH_IM_SCALE
            # rgb = goal_data['observations']['images'][self._camera_info['camera_name']]['rgb']
            # # pointcloud
            # name = "pointclouds/goal"
            # pdc_meshcat_utils.visualize_pointcloud(vis,
            #                                        name,
            #                                        depth=depth,
            #                                        K=self._camera_info['K'],
            #                                        rgb=rgb,
            #                                        T_world_camera=self._camera_info['T_world_camera'])

            for i in range(state_pred.shape[0]):
                visualize_model_prediction_single_timestep(
                    vis,
                    self._config,
                    state_pred[i],
                    display_idx=(i + 1),
                    name_prefix="mpc",
                    color=[255, 0, 0],
                )
            # show pointclouds . . .

        # store the results of this computation
        self._state_dict['action_counter'] += 1
        self._state_dict['timestamp_system'].append(time.time())
        self._state_dict['mpc_out'] = mpc_out
        self._state = ControllerState.RUNNING  # or maybe also STOPPED/FINISHED

        # data we want to save out from the MPC
        mpc_save_data = {
            'n_look_ahead': n_look_ahead,
            'action_seq': mpc_out['action_seq'].cpu(),
            'reward': mpc_out['reward'].cpu(),
            'state_pred': mpc_out['state_pred'].cpu(),
            'current_reward': current_reward_data,
        }

        # add some data for saving out later
        idx = episode.get_latest_idx()
        episode_data = episode.get_data(idx)
        episode_data['mpc'] = mpc_save_data

        plan.increment_counter()
        action_seq_mpc_np = torch_utils.cast_to_numpy(action_seq_mpc)

        # todo: update the action in the episode reader. Currently
        # we have an off by one error
        if False:
            action = action_seq_mpc[0]
            idx = episode.get_latest_idx()
            episode_data = episode.get_data(idx)
            episode_data['actions']['mpc']  # would be pretty hacky . . . .
            episode_data['actions']['setpoint_linear_velocity']['x'] = action[
                0]
            episode_data['actions']['setpoint_linear_velocity']['y'] = action[
                1]

        resp_data = {
            'action': action_seq_mpc_np[0],
            'action_seq': action_seq_mpc_np,
        }

        resp = {'type': 'CONTROL_ACTION', 'data': resp_data}

        # we are seeing averages around 0.12 seconds (total) for length 10 plan
        print("Compute Control Action Took %.3f" % (time.time() - start_time))

        return resp
示例#7
0
def main():
    """Closed-loop MPC experiment on the Drake pusher-slider environment.

    Steps:
      1. Load the learned dynamics model (and its visual observation
         function) and build two simulator instances: ``env`` is stepped by
         the controller, ``env2`` (not visualized) is used only to roll out
         planned action sequences for display.
      2. Roll out ``initial_cond['action_sequence']`` in the simulator to get
         a ground-truth trajectory. Encode its states with the learned model
         to obtain latent object goals (``z_object_rollout_gt``).
      3. Debug: roll the same actions through the learned model and print
         the predicted vs. ground-truth final state.
      4. MPC loop: plan with random shooting or MPPI (``PLANNER_TYPE``),
         draw the simulator rollout of the plan in meshcat, and apply the
         first action. Then print the pose error relative to the final
         ground-truth pose.

    Behavior is driven by module-level names not defined in this function
    (``N``, ``MPC_HORIZON``, ``REPLAN``, ``TRAJECTORY_GOAL``,
    ``USE_FIXED_MPC_HORIZON``, ``USE_SHORT_HORIZON_MPC``,
    ``SEED_WITH_NOMINAL_ACTIONS``, ``TERMINAL_COST_ONLY``, ``PLANNER_TYPE``).
    The loop blocks on ``input()`` after every step.
    """
    # load dynamics model
    model_dict = load_autoencoder_model()
    model = model_dict['model_dy']
    model_ae = model_dict['model_ae']
    visual_observation_function = model_dict['visual_observation_function']

    config = model.config

    env_config = load_yaml(
        os.path.join(get_project_root(),
                     'experiments/exp_18_box_on_side/config.yaml'))
    env_config['env']['observation']['depth_int16'] = True
    n_history = config['train']['n_history']

    # create the environment
    env = DrakePusherSliderEnv(env_config)
    env.reset()

    # create another environment for doing rollouts
    env2 = DrakePusherSliderEnv(env_config, visualize=False)
    env2.reset()

    action_function = ActionFunctionFactory.function_from_config(config)
    observation_function = ObservationFunctionFactory.drake_pusher_position_3D(
        config)

    episode = OnlineEpisodeReader()
    mpc_input_builder = DynamicsModelInputBuilder(
        observation_function=observation_function,
        visual_observation_function=visual_observation_function,
        action_function=action_function,
        episode=episode)

    vis = meshcat_utils.make_default_visualizer_object()
    vis.delete()

    initial_cond = get_initial_state()

    def reset_and_fill_history():
        # Reset the simulator to the initial condition, clear the episode
        # and seed it with n_history zero-action steps, so that
        # get_dynamics_model_input has a full history window to read.
        reset_environment(env, initial_cond['q_pusher'],
                          initial_cond['q_slider'])
        episode.clear()
        for _ in range(n_history):
            action_zero = np.zeros(2)
            obs_tmp = env.get_observation()
            episode.add_observation_action(obs_tmp, action_zero)

    reset_environment(env, initial_cond['q_pusher'], initial_cond['q_slider'])
    obs_init = env.get_observation()

    # visualize starting position of the object
    print("obs_init.keys()", obs_init.keys())
    print("obs_init['slider']['position']", obs_init['slider']['position'])
    T = DrakePusherSliderEnv.object_position_from_observation(obs_init)
    vis['start_pose'].set_object(triad(scale=0.1))
    # The transform must be set on the same node as the triad; otherwise
    # the start triad is never placed at T.
    vis['start_pose'].set_transform(T)

    #### ROLLOUT THE ACTION SEQUENCE USING THE SIMULATOR ##########
    reset_and_fill_history()
    # rollout single action sequence using the simulator
    gt_rollout_data = env_utils.rollout_action_sequence(
        env, initial_cond['action_sequence'].cpu().numpy())
    env_obs_rollout_gt = gt_rollout_data['observations']
    gt_rollout_episode = gt_rollout_data['episode_reader']

    for i, env_obs in enumerate(gt_rollout_data['observations']):
        T = DrakePusherSliderEnv.object_position_from_observation(env_obs)
        vis_name = "GT_trajectory/%d" % (i)
        vis[vis_name].set_object(triad(scale=0.1))
        vis[vis_name].set_transform(T)

    action_state_gt = mpc_input_builder.get_action_state_tensors(
        start_idx=0, num_timesteps=N, episode=gt_rollout_episode)

    state_rollout_gt = action_state_gt['states']
    action_rollout_gt = action_state_gt['actions']
    z_object_rollout_gt = model.compute_z_state(
        state_rollout_gt)['z_object_flat']
    print('state_rollout_gt.shape', state_rollout_gt.shape)
    print("z_object_rollout_gt.shape", z_object_rollout_gt.shape)

    def goal_func(obs_tmp):
        # encode a single observation into the flattened latent object state
        state_tmp = mpc_input_builder.get_state_input_single_timestep(
            {'observation': obs_tmp})['state']
        return model.compute_z_state(
            state_tmp.unsqueeze(0))['z_object'].flatten()

    # using the vision model to get "goal" keypoints
    z_object_goal = goal_func(env_obs_rollout_gt[-1])
    z_object_goal_np = torch_utils.cast_to_numpy(z_object_goal)

    #### ROLLOUT USING LEARNED MODEL + GROUND TRUTH ACTIONS ############
    reset_and_fill_history()

    # [n_history, state_dim]
    idx = episode.get_latest_idx()

    dyna_net_input = mpc_input_builder.get_dynamics_model_input(
        idx, n_history=n_history)
    state_init = dyna_net_input['states'].cuda()  # [n_history, state_dim]
    action_init = dyna_net_input['actions']  # [n_history, action_dim]

    print("state_init.shape", state_init.shape)
    print("action_init.shape", action_init.shape)
    print("n_history", n_history)

    action_seq_gt_torch = initial_cond['action_sequence']
    # history actions (minus the last one) followed by the nominal sequence
    action_input = torch.cat(
        (action_init[:(n_history - 1)], action_seq_gt_torch), dim=0).cuda()
    print("action_input.shape", action_input.shape)

    # rollout using the ground truth actions and learned model
    # need to add the batch dim to do that
    z_init = model.compute_z_state(state_init)['z']
    rollout_pred = rollout_model(state_init=z_init.unsqueeze(0),
                                 action_seq=action_input.unsqueeze(0),
                                 dynamics_net=model,
                                 compute_debug_data=True)

    state_pred_rollout = rollout_pred['state_pred'].squeeze(0)

    print("state_pred_rollout.shape", state_pred_rollout.shape)

    # check L2 distance between predicted and actual
    # basically comparing state_pred_rollout and state_rollout_gt
    print("state_rollout_gt[-1]\n", state_rollout_gt[-1])
    print("state_pred_rollout[-1]\n", state_pred_rollout[-1])

    index_dict = get_object_and_robot_state_indices(config)
    object_indices = index_dict['object_indices']

    # reset the environment and use the MPC controller to stabilize this
    reset_and_fill_history()

    # make a planner config
    planner_config = copy.copy(config)
    config_tmp = load_yaml(
        os.path.join(get_project_root(),
                     'experiments/drake_pusher_slider/eval_config.yaml'))
    planner_config['mpc'] = config_tmp['mpc']

    planner_config['mpc']['mppi']['terminal_cost_only'] = TERMINAL_COST_ONLY
    planner_config['mpc']['random_shooting'][
        'terminal_cost_only'] = TERMINAL_COST_ONLY

    if PLANNER_TYPE == "random_shooting":
        planner = RandomShootingPlanner(planner_config)
    elif PLANNER_TYPE == "mppi":
        planner = PlannerMPPI(planner_config)
    else:
        raise ValueError("unknown planner type: %s" % (PLANNER_TYPE))

    mpc_out = None
    action_seq_mpc = None
    state_pred_mpc = None
    counter = -1
    while True:
        counter += 1
        print("\n\n-----Running MPC Optimization: Counter (%d)-------" %
              (counter))

        obs_cur = env.get_observation()
        episode.add_observation_only(obs_cur)

        if counter == 0 or REPLAN:
            print("replanning")
            ####### Run the MPC ##########

            n_look_ahead = N - counter
            if USE_FIXED_MPC_HORIZON:
                n_look_ahead = MPC_HORIZON
            elif USE_SHORT_HORIZON_MPC:
                n_look_ahead = min(MPC_HORIZON, N - counter)
            if n_look_ahead == 0:
                break

            start_idx = counter
            end_idx = counter + n_look_ahead

            print("start_idx:", start_idx)
            print("end_idx:", end_idx)
            print("n_look_ahead", n_look_ahead)

            # idx of current observation
            idx = episode.get_latest_idx()
            mpc_start_time = time.time()
            mpc_input_data = mpc_input_builder.get_dynamics_model_input(
                idx, n_history=n_history)
            state_cur = mpc_input_data['states']
            action_his = mpc_input_data['actions']

            if SEED_WITH_NOMINAL_ACTIONS:
                action_seq_rollout_init = action_seq_gt_torch[
                    start_idx:end_idx]
            elif mpc_out is not None:
                # warm start: shift the previous plan by one step
                action_seq_rollout_init = mpc_out['action_seq'][1:]
                print("action_seq_rollout_init.shape",
                      action_seq_rollout_init.shape)

                if action_seq_rollout_init.shape[0] < n_look_ahead:
                    num_steps = n_look_ahead - action_seq_rollout_init.shape[0]
                    # Pad with zero actions. Match dtype/device/action-dim of
                    # the previous plan so torch.cat works even when the
                    # planner returns a CUDA tensor.
                    action_seq_zero = torch.zeros(
                        [num_steps, action_seq_rollout_init.shape[1]],
                        dtype=action_seq_rollout_init.dtype,
                        device=action_seq_rollout_init.device)

                    action_seq_rollout_init = torch.cat(
                        (action_seq_rollout_init, action_seq_zero), dim=0)
                    print("action_seq_rollout_init.shape",
                          action_seq_rollout_init.shape)
            else:
                action_seq_rollout_init = None

            # encode current state into the latent space
            with torch.no_grad():
                z_cur = model.compute_z_state(
                    state_cur.unsqueeze(0).cuda())['z'].squeeze(0)

            if action_seq_rollout_init is not None:
                n_look_ahead = action_seq_rollout_init.shape[0]

            print("z_object_rollout_gt.shape", z_object_rollout_gt.shape)
            if TRAJECTORY_GOAL:
                obs_goal = z_object_rollout_gt[start_idx:end_idx]

                print("n_look_ahead", n_look_ahead)
                print("obs_goal.shape", obs_goal.shape)

                # add the final goal state on as needed
                if obs_goal.shape[0] < n_look_ahead:
                    obs_goal_final = z_object_rollout_gt[-1].unsqueeze(0)
                    num_repeat = n_look_ahead - obs_goal.shape[0]
                    obs_goal_final_expand = obs_goal_final.expand(
                        [num_repeat, -1])
                    obs_goal = torch.cat((obs_goal, obs_goal_final_expand),
                                         dim=0)
            else:
                obs_goal = z_object_rollout_gt[-1]

            obs_goal = torch_utils.cast_to_numpy(obs_goal)
            print("obs_goal.shape", obs_goal.shape)

            seed = 1

            set_seed(seed)
            mpc_out = planner.trajectory_optimization(
                state_cur=z_cur,
                action_his=action_his,
                obs_goal=obs_goal,
                model_dy=model,
                action_seq_rollout_init=action_seq_rollout_init,
                n_look_ahead=n_look_ahead,
                eval_indices=object_indices,
                rollout_best_action_sequence=True,
                verbose=True,
                add_grid_action_samples=True,
            )

            print("MPC step took %.4f seconds" %
                  (time.time() - mpc_start_time))
            action_seq_mpc = torch_utils.cast_to_numpy(mpc_out['action_seq'])
            state_pred_mpc = torch_utils.cast_to_numpy(mpc_out['state_pred'])

        # Rollout with ground truth simulator dynamics
        env2.set_simulator_state_from_observation_dict(
            env2.get_mutable_context(), obs_cur)
        obs_mpc_gt = env_utils.rollout_action_sequence(
            env2, action_seq_mpc)['observations']

        vis['mpc_3D'].delete()
        vis['mpc_GT_3D'].delete()

        L = len(obs_mpc_gt)
        print("L", L)
        if L == 0:
            break
        for i in range(L):
            # ground truth rollout of the MPC action_seq
            name = "mpc_GT_3D/%d" % (i)
            T_W_obj = DrakePusherSliderEnv.object_position_from_observation(
                obs_mpc_gt[i])
            vis[name].set_object(triad(scale=0.1))
            vis[name].set_transform(T_W_obj)

        action_cur = action_seq_mpc[0]

        print("action_cur", action_cur)
        print("action_GT", initial_cond['action'])
        input("press Enter to continue")

        # add observation actions to the episode
        obs_cur = env.get_observation()
        episode.replace_observation_action(obs_cur, action_cur)

        # step the simulator
        env.step(action_cur)

        # update the trajectories, in case we aren't replanning
        action_seq_mpc = action_seq_mpc[1:]
        state_pred_mpc = state_pred_mpc[1:]

        pose_error = compute_pose_error(env_obs_rollout_gt[-1], obs_cur)

        print("position_error: %.3f" % (pose_error['position_error']))
        print("angle error degrees: %.3f" %
              (pose_error['angle_error_degrees']))

    obs_final = env.get_observation()

    pose_error = compute_pose_error(env_obs_rollout_gt[-1], obs_final)

    print("position_error: %.3f" % (pose_error['position_error']))
    print("angle error degrees: %.3f" % (pose_error['angle_error_degrees']))
示例#8
0
def localize_transporter_keypoints(
    model_kp,
    rgb=None,  # np.array [H, W, 3]
    mask=None,  # np.array [H, W]
    depth=None,  # np.array [H, W] in meters
    K=None,  # camera intrinsics matrix [3,3]
    T_world_camera=None,
):
    """Predict transporter keypoints for a single RGB image.

    The image is preprocessed (and possibly cropped) via ``process_image``
    and run through ``model_kp.predict_keypoint``. The predicted keypoints
    are then mapped back to the full image with
    ``map_transporter_keypoints_to_full_image``. When ``depth`` is given,
    keypoints are unprojected with ``K`` into the camera frame. They are
    additionally transformed into the world frame when ``T_world_camera``
    is given.

    Args:
        model_kp: keypoint model exposing ``config``, ``parameters()`` and
            ``predict_keypoint``.
        rgb: RGB image, np.array [H, W, 3].
        mask: optional mask forwarded to ``process_image``.
        depth: optional depth image [H, W] in meters.
        K: camera intrinsics [3, 3]; required when ``depth`` is given.
        T_world_camera: optional camera pose as an np.array accepted by
            ``torch.from_numpy``.

    Returns:
        dict with keys ``'uv'``, ``'uv_int'``, ``'xy'`` (numpy, batch dim
        squeezed), ``'z'``, ``'pts_camera_frame'``, ``'pts_world_frame'``,
        ``'pts_W'`` (alias of ``'pts_world_frame'``), ``'kp_pred'`` (raw
        model output, CPU tensor) and ``'rgb_input'`` (preprocessed image).
        ``'z'`` and the point entries are None when ``depth`` is None. The
        world-frame entries are None when ``T_world_camera`` is None.
    """
    processed_image_dict = process_image(rgb=rgb,
                                         config=model_kp.config,
                                         mask=mask)

    rgb_input = processed_image_dict['image']
    crop_param = processed_image_dict['crop_param']

    H, W, _ = rgb.shape

    # cast crop params to torch (with a batch dim) for the mapping helper
    crop_param_torch = None
    if crop_param is not None:
        crop_param_torch = {
            key: torch.Tensor([val]).unsqueeze(0)
            for key, val in crop_param.items()
        }

    # run on whatever device the model lives on instead of assuming CUDA
    device = next(model_kp.parameters()).device
    rgb_to_tensor = ImageTupleDataset.make_rgb_image_to_tensor_transform()
    rgb_input_tensor = rgb_to_tensor(rgb_input).unsqueeze(0).to(
        device)  # make it batch

    # [1, n_kp, 2]
    kp_pred = model_kp.predict_keypoint(rgb_input_tensor).cpu()

    full_image_coords = map_transporter_keypoints_to_full_image(
        kp_pred, crop_params=crop_param_torch, full_image_size=(H, W))

    uv = full_image_coords['uv']
    uv_int = full_image_coords['uv_int']
    xy = full_image_coords['xy']

    z_np = None
    pts_world_frame_np = None
    pts_camera_frame_np = None
    if depth is not None:
        # only touch depth once we know it exists (default is None)
        depth_batch = torch.from_numpy(depth).unsqueeze(0)
        z = pdc_utils.index_into_batch_image_tensor(depth_batch.unsqueeze(1),
                                                    uv_int.transpose(1, 2))
        z = z.squeeze(1)
        z_np = torch_utils.cast_to_numpy(z.squeeze())

        K_inv_torch = torch.Tensor(np.linalg.inv(K)).unsqueeze(
            0)  # add batch dim
        pts_camera_frame = pdc_torch_utils.pinhole_unprojection(
            uv, z, K_inv_torch)
        pts_camera_frame_np = torch_utils.cast_to_numpy(
            pts_camera_frame.squeeze())

        if T_world_camera is not None:
            pts_world_frame = pdc_torch_utils.transform_points_3D(
                torch.from_numpy(T_world_camera), pts_camera_frame)
            pts_world_frame_np = torch_utils.cast_to_numpy(
                pts_world_frame.squeeze())

    uv_np = torch_utils.cast_to_numpy(uv.squeeze())
    uv_int_np = torch_utils.cast_to_numpy(uv_int.squeeze())

    return {
        'uv': uv_np,
        'uv_int': uv_int_np,
        'xy': torch_utils.cast_to_numpy(xy.squeeze()),
        'z': z_np,
        'pts_world_frame': pts_world_frame_np,
        'pts_camera_frame': pts_camera_frame_np,
        'pts_W': pts_world_frame_np,  # just for backwards compatibility
        'kp_pred': kp_pred,
        'rgb_input': rgb_input,
    }
def precompute_transporter_keypoints(
    multi_episode_dict,
    model_kp,
    output_dir,  # str
    batch_size=10,
    num_workers=10,
    camera_names=None,
    model_file=None,
):
    """Run the transporter keypoint model over every episode and save results.

    For each episode, an ``ImageTupleDataset`` (tuple_size=1) is built and
    iterated in batches. Keypoints are predicted and mapped to full-image
    pixels, and depth is sampled at those pixels. The points are then
    unprojected into camera and world frames.

    Per image, the fields ``xy``, ``uv``, ``uv_int``, ``z``,
    ``pos_world_frame`` and ``pos_camera_frame`` are stored under
    ``<key_tree_joined>/transporter_keypoints/<field>``. They are written
    both to ``<output_dir>/<episode>.h5`` and to ``<output_dir>/<episode>.p``
    (pickle). Pre-existing files with those names are deleted first.
    ``metadata.yaml`` recording ``model_file`` is also written to
    ``output_dir``, which is created if missing.

    Args:
        multi_episode_dict: dict mapping episode name -> episode object.
        model_kp: keypoint model exposing ``config``, ``parameters()`` and
            ``predict_keypoint``. It is put into eval mode.
        output_dir: directory to write results to.
        batch_size: DataLoader batch size.
        num_workers: DataLoader worker count.
        camera_names: forwarded to ``ImageTupleDataset``.
        model_file: path of the model, recorded in the metadata (required).

    Raises:
        AssertionError: if ``model_file`` is None.
    """
    assert model_file is not None

    # make sure the output directory exists before writing anything into it
    if not os.path.isdir(output_dir):
        os.makedirs(output_dir)

    metadata = dict()
    metadata['model_file'] = model_file
    save_yaml(metadata, os.path.join(output_dir, 'metadata.yaml'))

    log_freq = 10

    device = next(model_kp.parameters()).device
    model_kp = model_kp.eval()  # make sure model is in eval mode
    config = model_kp.config

    image_data_config = {
        'rgb': True,
        'mask': True,
        'depth_int16': True,
    }

    # build one dataset/dataloader per episode
    datasets = {}
    dataloaders = {}
    for episode_name, episode in multi_episode_dict.items():
        single_episode_dict = {episode_name: episode}

        # need to do this since transporter type data sampling only works
        # with tuple_size = 1
        dataset_config = copy.deepcopy(config)
        dataset_config['dataset']['use_transporter_type_data_sampling'] = False

        datasets[episode_name] = ImageTupleDataset(
            dataset_config,
            single_episode_dict,
            phase="all",
            image_data_config=image_data_config,
            tuple_size=1,
            compute_K_inv=True,
            camera_names=camera_names)

        dataloaders[episode_name] = DataLoader(datasets[episode_name],
                                               batch_size=batch_size,
                                               num_workers=num_workers,
                                               shuffle=False)

    num_episodes = len(multi_episode_dict)

    for episode_counter, episode_name in enumerate(datasets, start=1):
        print("\n\n")

        episode = multi_episode_dict[episode_name]
        try:
            hdf5_file = os.path.basename(episode.image_data_file)
        except AttributeError:
            hdf5_file = "%s.h5" % (episode.name)

        hdf5_file_fullpath = os.path.join(output_dir, hdf5_file)
        # splitext only strips the final extension, so dots elsewhere in the
        # path (e.g. "./out" or "~/.cache") are handled correctly
        pickle_file_fullpath = os.path.splitext(hdf5_file_fullpath)[0] + ".p"

        if os.path.isfile(hdf5_file_fullpath):
            os.remove(hdf5_file_fullpath)

        if os.path.isfile(pickle_file_fullpath):
            os.remove(pickle_file_fullpath)

        episode_keypoint_data = dict()

        episode_start_time = time.time()
        with h5py.File(hdf5_file_fullpath, 'w') as hf:
            for i, data in enumerate(dataloaders[episode_name]):
                # tuple_size=1, so each sample is a length-1 tuple
                data = data[0]
                rgb_crop_tensor = data['rgb_crop_tensor'].to(device)
                crop_params = data['crop_param']
                depth_int16 = data['depth_int16']
                key_tree_joined = data['key_tree_joined']

                depth = depth_int16.float() * 1.0 / DEPTH_IM_SCALE

                if (i % log_freq) == 0:
                    log_msg = "computing [%d/%d][%d/%d]" % (
                        episode_counter, num_episodes, i + 1,
                        len(dataloaders[episode_name]))
                    print(log_msg)

                B = rgb_crop_tensor.shape[0]

                _, H, W, _ = data['rgb'].shape

                with torch.no_grad():
                    kp_pred = model_kp.predict_keypoint(rgb_crop_tensor)

                    # [B, n_kp, 2]
                    # NOTE(review): kp_pred lives on `device` while crop_params,
                    # depth, K_inv and T_W_C come from the DataLoader (CPU);
                    # this relies on the helpers tolerating that — confirm.
                    kp_pred_full_pixels = transporter_utils.map_cropped_pixels_to_full_pixels_torch(
                        kp_pred, crop_params)

                    # normalize pixel coords to [-1, 1]
                    xy = kp_pred_full_pixels.clone()
                    xy[:, :, 0] = (xy[:, :, 0]) * 2.0 / W - 1.0
                    xy[:, :, 1] = (xy[:, :, 1]) * 2.0 / H - 1.0

                    # get depth values
                    kp_pred_full_pixels_int = kp_pred_full_pixels.type(
                        torch.LongTensor)

                    z = pdc_utils.index_into_batch_image_tensor(
                        depth.unsqueeze(1),
                        kp_pred_full_pixels_int.transpose(1, 2))

                    z = z.squeeze(1)
                    K_inv = data['K_inv']
                    pts_camera_frame = pdc_torch_utils.pinhole_unprojection(
                        kp_pred_full_pixels, z, K_inv)

                    pts_world_frame = pdc_torch_utils.transform_points_3D(
                        data['T_W_C'], pts_camera_frame)

                for j in range(B):
                    keypoint_data = {
                        # this goes from [-1,1]
                        'xy': torch_utils.cast_to_numpy(xy[j]),
                        'uv': torch_utils.cast_to_numpy(kp_pred_full_pixels[j]),
                        'uv_int': torch_utils.cast_to_numpy(
                            kp_pred_full_pixels_int[j]),
                        'z': torch_utils.cast_to_numpy(z[j]),
                        'pos_world_frame': torch_utils.cast_to_numpy(
                            pts_world_frame[j]),
                        'pos_camera_frame': torch_utils.cast_to_numpy(
                            pts_camera_frame[j]),
                    }

                    # save out some data in both hdf5 and pickle format
                    for key, val in keypoint_data.items():
                        save_key = key_tree_joined[
                            j] + "/transporter_keypoints/%s" % (key)
                        hf.create_dataset(save_key, data=val)
                        episode_keypoint_data[save_key] = val

            save_pickle(episode_keypoint_data, pickle_file_fullpath)
            print("duration: %.3f seconds" %
                  (time.time() - episode_start_time))
示例#10
0
def precompute_descriptor_keypoints(multi_episode_dict,
                                    model,
                                    output_dir,  # str
                                    ref_descriptors_metadata,
                                    batch_size=10,
                                    num_workers=10,
                                    localization_type="spatial_expectation",  # ['spatial_expectation', 'argmax']
                                    compute_3D=True,  # in world frame
                                    camera_names=None,
                                    ):
    """Localize reference descriptors in every image and save them to disk.

    For each episode in ``multi_episode_dict`` the images are run through
    ``model`` to get descriptor images, the reference descriptors from
    ``ref_descriptors_metadata['ref_descriptors']`` are localized in them, and
    the per-image results are written both to an hdf5 file and to a pickle
    (same base name, ``.p`` extension) in ``output_dir`` under the key
    ``<key_tree_joined>/descriptor_keypoints/<field>``.
    ``ref_descriptors_metadata`` itself is pickled to ``output_dir/metadata.p``.

    Args:
        multi_episode_dict: mapping episode name -> episode. The output file
            name is the basename of ``episode.image_data_file``, or
            ``"<episode.name>.h5"`` when that attribute does not exist.
        model: network whose ``forward`` returns a dict containing
            ``'descriptor_image'``; inference runs on the device of its
            parameters, with gradients disabled and the model in eval mode.
        output_dir: output directory, created if missing. Existing output
            files for an episode are removed before it is processed.
        ref_descriptors_metadata: dict whose ``'ref_descriptors'`` entry is a
            2-D array; its first dimension K is the number of reference
            descriptors.
        batch_size: forwarded to ``DataLoader``.
        num_workers: forwarded to ``DataLoader``.
        localization_type: ``"spatial_expectation"`` or ``"argmax"``.
        compute_3D: only used with ``"spatial_expectation"``; additionally
            computes depth-based keypoints in camera and world frame. Per-image
            fields ('xy', 'uv', 'z', 'pos_world_frame', 'pos_camera_frame') are
            only written on this 3D path; otherwise the episode files contain
            no keypoint datasets.
        camera_names: forwarded to ``ImageDataset``.

    Raises:
        ValueError: if ``localization_type`` is not one of the values above.
    """
    # fail fast, before anything is written to disk
    if localization_type not in ("spatial_expectation", "argmax"):
        raise ValueError("unknown localization type: %s" % (localization_type))

    if not os.path.isdir(output_dir):
        os.makedirs(output_dir)

    start_time = time.time()

    log_freq = 10

    device = next(model.parameters()).device
    model = model.eval()  # make sure model is in eval mode

    # Heatmap width for the spatial expectation. It does not change between
    # batches, so read it once; fall back to the default when the model has no
    # config or the config has no such entry.
    sigma_descriptor_heatmap = 5  # default
    if localization_type == "spatial_expectation":
        try:
            sigma_descriptor_heatmap = model.config['network']['sigma_descriptor_heatmap']
        except (AttributeError, KeyError, IndexError, TypeError):
            pass

    # build one dataset / dataloader per episode
    datasets = {}
    dataloaders = {}
    for episode_name, episode in iteritems(multi_episode_dict):
        single_episode_dict = {episode_name: episode}
        config = None
        datasets[episode_name] = ImageDataset(config,
                                              single_episode_dict,
                                              phase="all",
                                              camera_names=camera_names)
        dataloaders[episode_name] = DataLoader(datasets[episode_name],
                                               batch_size=batch_size,
                                               num_workers=num_workers,
                                               shuffle=False)

    # K = num_ref_descriptors. Put the reference descriptors on the model's
    # device (rather than unconditionally on the default CUDA device) so they
    # live on the same device as the descriptor images the model produces.
    metadata = ref_descriptors_metadata
    ref_descriptors = torch.Tensor(metadata['ref_descriptors']).to(device)
    K, _ = ref_descriptors.shape

    metadata_file = os.path.join(output_dir, 'metadata.p')
    save_pickle(metadata, metadata_file)

    num_episodes = len(multi_episode_dict)

    for episode_counter, episode_name in enumerate(datasets, start=1):
        print("\n\n")

        episode = multi_episode_dict[episode_name]
        try:
            hdf5_file = os.path.basename(episode.image_data_file)
        except AttributeError:
            hdf5_file = "%s.h5" % (episode.name)

        hdf5_file_fullpath = os.path.join(output_dir, hdf5_file)

        # splitext strips only the final extension, so dots elsewhere in the
        # directory or file name don't break the pickle path
        pickle_file_fullpath = os.path.splitext(hdf5_file_fullpath)[0] + ".p"

        # remove outputs left over from a previous run
        if os.path.isfile(hdf5_file_fullpath):
            os.remove(hdf5_file_fullpath)

        if os.path.isfile(pickle_file_fullpath):
            os.remove(pickle_file_fullpath)

        dataloader = dataloaders[episode_name]

        episode_keypoint_data = dict()

        episode_start_time = time.time()
        with h5py.File(hdf5_file_fullpath, 'w') as hf:
            for i, data in enumerate(dataloader):
                rgb_tensor = data['rgb_tensor'].to(device)
                key_tree_joined = data['key_tree_joined']

                if (i % log_freq) == 0:
                    log_msg = "computing [%d/%d][%d/%d]" % (episode_counter, num_episodes, i + 1, len(dataloader))
                    print(log_msg)

                # don't use gradients
                tmp_time = time.time()
                with torch.no_grad():
                    out = model.forward(rgb_tensor)

                    # [B, D, H, W]
                    des_img = out['descriptor_image']

                    B, _, H, W = rgb_tensor.shape

                    # Reset on every batch: pred_3d must exist even when the 3D
                    # path below doesn't run, otherwise the per-sample loop
                    # raises NameError or reuses a previous batch's result.
                    batch_indices = None  # [B, N, 2], N = num_ref_descriptors; not saved currently
                    pred_3d = None
                    if localization_type == "spatial_expectation":
                        d = get_spatial_expectation(ref_descriptors,
                                                    des_img,
                                                    sigma=sigma_descriptor_heatmap,
                                                    type='exp',
                                                    return_heatmap=True,
                                                    )

                        batch_indices = d['uv']

                        if compute_3D:
                            # [B*K, H, W]
                            heatmaps_no_batch = d['heatmap_no_batch']

                            # [B, H, W]
                            depth = data['depth_int16'].to(device)

                            # Repeat each depth image K times so it lines up with
                            # the per-keypoint heatmaps, then scale by
                            # DEPTH_IM_SCALE (mm -> meters). .float() keeps the
                            # tensor on its device instead of round-tripping
                            # through the CPU like .type(torch.FloatTensor).
                            depth_expand = depth.unsqueeze(1).expand([B, K, H, W]).reshape([B * K, H, W])
                            depth_expand = depth_expand.float() / constants.DEPTH_IM_SCALE
                            depth_expand = depth_expand.to(heatmaps_no_batch.device)

                            pred_3d = get_integral_preds_3d(heatmaps_no_batch,
                                                            depth_images=depth_expand,
                                                            compute_uv=True)

                            pred_3d['uv'] = pred_3d['uv'].reshape([B, K, 2])
                            pred_3d['xy'] = pred_3d['xy'].reshape([B, K, 2])
                            pred_3d['z'] = pred_3d['z'].reshape([B, K])

                    else:  # "argmax" (validated at the top of the function)
                        # localize descriptors
                        best_match_dict = get_argmax_l2(ref_descriptors,
                                                        des_img)
                        batch_indices = best_match_dict['indices']

                    print("computing keypoints took", time.time() - tmp_time)

                    tmp_time = time.time()
                    # iterate over elements in the batch
                    for j in range(B):
                        keypoint_data = {}  # dict that stores information to save out

                        # store 3D keypoint locations (in both camera and world frame)
                        if pred_3d is not None:
                            T_W_C = torch_utils.cast_to_numpy(data['T_world_camera'][j])
                            K_matrix = torch_utils.cast_to_numpy(data['K'][j])

                            uv = torch_utils.cast_to_numpy(pred_3d['uv'][j])
                            xy = torch_utils.cast_to_numpy(pred_3d['xy'][j])
                            z = torch_utils.cast_to_numpy(pred_3d['z'][j])

                            # [K, 3], camera frame
                            pts_3d_C = pdc_utils.pinhole_unprojection(uv, z, K_matrix)

                            # project into world frame
                            pts_3d_W = transform_utils.transform_points_3D(transform=T_W_C,
                                                                           points=pts_3d_C)

                            keypoint_data['xy'] = torch_utils.cast_to_numpy(xy)
                            keypoint_data['uv'] = torch_utils.cast_to_numpy(uv)
                            keypoint_data['z'] = torch_utils.cast_to_numpy(z)
                            keypoint_data['pos_world_frame'] = torch_utils.cast_to_numpy(pts_3d_W)
                            keypoint_data['pos_camera_frame'] = torch_utils.cast_to_numpy(pts_3d_C)

                        # save out some data in both hdf5 and pickle format
                        for key, val in keypoint_data.items():
                            save_key = key_tree_joined[j] + "/descriptor_keypoints/%s" % (key)
                            hf.create_dataset(save_key, data=val)
                            episode_keypoint_data[save_key] = val

                    print("saving to disk took", time.time() - tmp_time)

        save_pickle(episode_keypoint_data, pickle_file_fullpath)
        print("duration: %.3f seconds" % (time.time() - episode_start_time))

    print("total time to compute descriptors: %.3f seconds" % (time.time() - start_time))