Example #1
    def _update_discrepancy_model(self):
        # For now we rebuild the KDTrees in batches, which is not very efficient
        # TODO: make this incremental
        # Get all transitions with discrepancy in dynamics
        transitions = self.memory.sample_real_world_memory()
        # Extract relevant quantities
        obs, ac_ind = transitions['obs'], transitions['actions']

        # Construct 4D points
        # obs[0:2] is gripper 2D position
        # obs[3:5] is object 2D position
        real_pos = np.concatenate([obs[:, 0:2], obs[:, 3:5]], axis=1)

        # Add it to the respective KDTrees
        for i in range(self.env_params['num_actions']):
            # Get points corresponding to this action
            ac_mask = ac_ind == i
            points = real_pos[ac_mask]

            if points.shape[0] == 0:
                # No data points for this action
                continue

            # Fit the KDTree
            self.kdtrees[i] = KDTree(points)

        # Configure discrepancy model for controller
        assert self.args.agent == 'rts'
        self.controller.reconfigure_discrepancy(
            lambda obs, ac: get_discrepancy_neighbors(
                obs, ac, self.construct_4d_point, self.kdtrees,
                self.args.neighbor_radius))

        return
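The helper get_discrepancy_neighbors itself is not shown in this example. Below is a minimal sketch of what such a radius query against the per-action KDTrees could look like, assuming scipy.spatial.KDTree and that ac is the discrete action index (both assumptions, not confirmed by the source):

# Hypothetical sketch; the real helper may differ.
def get_discrepancy_neighbors(obs, ac, construct_4d_point, kdtrees, radius):
    # Map (obs, ac) to the same 4D gripper/object point used to fit the trees
    query = construct_4d_point(obs, ac)
    tree = kdtrees[ac]
    if tree is None:
        # No discrepancy data recorded for this action yet
        return 0
    # Count stored discrepancy points within `radius` of the query point
    return len(tree.query_ball_point(query, r=radius))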
Example #2
 def set_worker_params(self,
                       value_residual_state_dict,
                       feature_norm_dict,
                       kdtrees_serialized=None,
                       residual_dynamics_state_dict=None):
     # Load all worker parameters
     self.state_value_residual.load_state_dict(value_residual_state_dict)
     self.features_normalizer.load_state_dict(feature_norm_dict)
     # Reconfigure controller heuristic function
     self.controller.reconfigure_heuristic(
         lambda obs: get_state_value_residual(
             obs, self.preproc_inputs, self.state_value_residual))
     if kdtrees_serialized:
         self.kdtrees_set = True
         self.kdtrees = pickle.loads(kdtrees_serialized)
         self.controller.reconfigure_discrepancy(
             lambda obs, ac: get_discrepancy_neighbors(
                 obs, ac, self.construct_4d_point, self.kdtrees,
                 self.args.neighbor_radius))
     if residual_dynamics_state_dict:
         self.residual_dynamics_set = True
         if self.args.agent == 'mbpo':
             self.residual_dynamics.load_state_dict(
                 residual_dynamics_state_dict)
         elif self.args.agent in ('mbpo_knn', 'mbpo_gp'):
             self.residual_dynamics = pickle.loads(
                 residual_dynamics_state_dict)
         else:
             raise NotImplementedError
         self.controller.reconfigure_residual_dynamics(
             self.get_residual_dynamics)
     return
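Note that kdtrees_serialized (and residual_dynamics_state_dict for the knn/gp variants) is expected as a pickled byte string. A minimal sketch of how the sending side might package the parameters, where agent and worker are hypothetical names for the learner and the remote worker:

import pickle

# Hypothetical call site; names on the sending side are assumptions.
worker.set_worker_params(
    value_residual_state_dict=agent.state_value_residual.state_dict(),
    feature_norm_dict=agent.features_normalizer.state_dict(),
    kdtrees_serialized=pickle.dumps(agent.kdtrees),
)

Pickling the trees into bytes keeps them easy to ship across process boundaries regardless of the transport used.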
Example #3
    def learn_online_in_real_world(self, max_timesteps=None):
        # If any pre-existing model is given, load it
        if self.args.load_dir:
            self.load()

        # Reset the environment
        observation = self.env.reset()
        # Configure heuristic for controller
        self.controller.reconfigure_heuristic(
            lambda obs: get_state_value_residual(
                obs, self.preproc_inputs, self.state_value_residual))
        # Configure discrepancy model for controller
        if self.args.agent == 'rts':
            self.controller.reconfigure_discrepancy(
                lambda obs, ac: get_discrepancy_neighbors(
                    obs, ac, self.construct_4d_point, self.kdtrees,
                    self.args.neighbor_radius))

        # Configure residual dynamics for controller
        if self.args.agent in ('mbpo', 'mbpo_knn', 'mbpo_gp'):
            self.controller.reconfigure_residual_dynamics(
                self.get_residual_dynamics)
        # Count of total number of steps
        total_n_steps = 0
        while True:
            obs = observation['observation']
            g = observation['desired_goal']
            qpos = observation['sim_state'].qpos
            qvel = observation['sim_state'].qvel
            # Get action from the controller
            ac, info = self.controller.act(observation)
            if self.args.agent == 'rts':
                assert self.controller.residual_dynamics_fn is None
            if self.args.agent in ('mbpo', 'mbpo_knn', 'mbpo_gp'):
                assert self.controller.discrepancy_fn is None
            # Get discrete action index
            ac_ind = self.env.discrete_actions[tuple(ac)]
            # Get the next observation
            next_observation, rew, _, _ = self.env.step(ac)
            # if np.array_equal(obs, next_observation['observation']):
            #     import ipdb
            #     ipdb.set_trace()
            # print('ACTION', ac)
            # print('VALUE PREDICTED', info['start_node_h'])
            # print('COST', -rew)
            if self.args.render:
                self.env.render()
            total_n_steps += 1
            # Check if we reached the goal
            if self.env.env._is_success(next_observation['achieved_goal'], g):
                print('REACHED GOAL!')
                break
            # Get the next obs
            obs_next = next_observation['observation']
            # Get the sim next obs
            set_sim_state_and_goal(self.planning_env, qpos.copy(), qvel.copy(),
                                   g.copy())
            next_observation_sim, _, _, _ = self.planning_env.step(ac)
            obs_sim_next = next_observation_sim['observation']
            # Store transition
            transition = [
                obs.copy(),
                g.copy(), ac_ind,
                qpos.copy(),
                qvel.copy(),
                obs_next.copy(),
                obs_sim_next.copy()
            ]
            dynamics_losses = []
            # RTS
            if self.args.agent == 'rts' and self._check_dynamics_transition(
                    transition):
                # print('DISCREPANCY IN DYNAMICS')
                self.memory.store_real_world_transition(transition)
                # Fit the discrepancy model
                self._update_discrepancy_model()
            # MBPO
            elif self.args.agent in ('mbpo', 'mbpo_knn', 'mbpo_gp'):
                self.memory.store_real_world_transition(transition)
                # Update the dynamics
                if self.args.agent == 'mbpo':
                    for _ in range(self.args.n_online_planning_updates):
                        # Update dynamics
                        loss = self._update_batch_residual_dynamics()
                        dynamics_losses.append(loss.item())
                else:
                    loss = self._update_residual_dynamics()
                    dynamics_losses.append(loss)
            # Plan in the model
            value_loss = self.plan_online_in_model(
                n_planning_updates=self.args.n_online_planning_updates,
                initial_observation=copy.deepcopy(observation))

            # Log
            logger.record_tabular('n_steps', total_n_steps)
            if self.args.agent in ('mbpo', 'mbpo_knn', 'mbpo_gp'):
                logger.record_tabular('dynamics loss',
                                      np.mean(dynamics_losses))
            logger.record_tabular('residual_loss', value_loss)
            # logger.dump_tabular()
            # Move to next iteration
            observation = copy.deepcopy(next_observation)

            if max_timesteps and total_n_steps >= max_timesteps:
                break

        return total_n_steps
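The loop above delegates the discrepancy check to _check_dynamics_transition, whose body is not included here. A minimal sketch, assuming the transition layout stored above and an assumed tolerance threshold:

import numpy as np

# Hypothetical sketch of the method on the agent class; the criterion and
# threshold are assumptions, not taken from the source.
def _check_dynamics_transition(self, transition, tol=1e-3):
    # transition = [obs, g, ac_ind, qpos, qvel, obs_next, obs_sim_next]
    obs_next, obs_sim_next = transition[5], transition[6]
    # Flag the transition when the real and simulated next observations diverge
    return np.linalg.norm(np.asarray(obs_next) - np.asarray(obs_sim_next)) > tol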