Example #1
    def set_worker_params(self,
                          value_residual_state_dict,
                          feature_norm_dict,
                          kdtrees_serialized=None,
                          residual_dynamics_state_dict=None):
        # Load all worker parameters
        self.state_value_residual.load_state_dict(value_residual_state_dict)
        self.features_normalizer.load_state_dict(feature_norm_dict)
        # Reconfigure controller heuristic function
        self.controller.reconfigure_heuristic(
            lambda obs: get_state_value_residual(
                obs, self.preproc_inputs, self.state_value_residual))
        if kdtrees_serialized:
            self.kdtrees_set = True
            self.kdtrees = pickle.loads(kdtrees_serialized)
            self.controller.reconfigure_discrepancy(
                lambda obs, ac: get_discrepancy_neighbors(
                    obs, ac, self.construct_4d_point, self.kdtrees,
                    self.args.neighbor_radius))
        if residual_dynamics_state_dict:
            self.residual_dynamics_set = True
            if self.args.agent == 'mbpo':
                self.residual_dynamics.load_state_dict(
                    residual_dynamics_state_dict)
            elif self.args.agent in ('mbpo_knn', 'mbpo_gp'):
                self.residual_dynamics = pickle.loads(
                    residual_dynamics_state_dict)
            else:
                raise NotImplementedError
            self.controller.reconfigure_residual_dynamics(
                self.get_residual_dynamics)
        return
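
The heuristic lambda above delegates to a `get_state_value_residual` helper that is not shown in this example. A minimal sketch of what such a helper might look like, assuming `preproc_inputs(obs)` returns a feature tensor and the residual network is a small torch module; the body below is an assumption, not the repository's implementation:

import torch

def get_state_value_residual(obs, preproc_inputs, state_value_residual):
    # Sketch only: evaluate the learned state-value residual for one observation.
    # Assumes preproc_inputs(obs) returns a torch feature tensor and
    # state_value_residual maps features to a scalar correction.
    with torch.no_grad():
        features = preproc_inputs(obs)
        residual = state_value_residual(features)
    return residual.item()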
Example #2
    def _update_state_value_residual(self):
        transitions = self.memory.sample_internal_world_memory(
            self.args.batch_size)

        obs, g, ag = transitions['obs'], transitions['g'], transitions['ag']
        features, heuristic = transitions['features'], transitions['heuristic']
        targets = []

        for i in range(self.args.batch_size):
            observation = {}
            observation['observation'] = obs[i].copy()
            observation['desired_goal'] = g[i].copy()
            observation['achieved_goal'] = ag[i].copy()

            _, info = self.controller.act(observation)
            targets.append(info['best_node_f'])
        targets = np.array(targets).reshape(-1, 1)
        features_norm = self.features_normalizer.normalize(features)

        inputs_norm = torch.as_tensor(features_norm, dtype=torch.float32)
        targets = torch.as_tensor(targets, dtype=torch.float32)

        h_tensor = torch.as_tensor(heuristic,
                                   dtype=torch.float32).unsqueeze(-1)
        # Compute target residuals
        target_residual_tensor = targets - h_tensor
        # Clip target residual tensor to avoid value function less than zero
        target_residual_tensor = torch.max(target_residual_tensor, -h_tensor)
        # Clip target residual tensor to avoid value function greater than horizon
        if self.args.offline:
            target_residual_tensor = torch.min(
                target_residual_tensor,
                self.env_params['offline_max_timesteps'] - h_tensor)

        # Compute predictions
        residual_tensor = self.state_value_residual(inputs_norm)
        # Compute loss
        state_value_residual_loss = (residual_tensor -
                                     target_residual_tensor).pow(2).mean()

        # Backprop and step
        self.state_value_residual_optim.zero_grad()
        state_value_residual_loss.backward()
        self.state_value_residual_optim.step()

        # Configure heuristic for controller
        self.controller.reconfigure_heuristic(
            lambda obs: get_state_value_residual(
                obs, self.preproc_inputs, self.state_value_residual))

        return state_value_residual_loss
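
The two clipping steps above bound the implied value estimate V = h + r to [0, horizon]. A self-contained check of that equivalence; the numbers and the horizon constant are made up for illustration:

import torch

h = torch.tensor([3.0, 5.0, 2.0])        # base heuristic values (made up)
target = torch.tensor([1.0, 40.0, 2.5])  # search-derived targets (made up)
H = 30.0                                 # assumed horizon ('offline_max_timesteps')

r = target - h
r = torch.max(r, -h)      # keep h + r >= 0
r = torch.min(r, H - h)   # keep h + r <= H
assert torch.allclose(h + r, torch.clamp(target, 0.0, H))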
Example #3
    def learn_online_in_real_world(self, max_timesteps=None):
        # If any pre-existing model is given, load it
        if self.args.load_dir:
            self.load()

        # Reset the environment
        observation = self.env.reset()
        # Configure heuristic for controller
        self.controller.reconfigure_heuristic(
            lambda obs: get_state_value_residual(
                obs, self.preproc_inputs, self.state_value_residual))
        # Configure discrepancy for controller
        if self.args.agent == 'rts':
            self.controller.reconfigure_discrepancy(
                lambda obs, ac: get_discrepancy_neighbors(
                    obs, ac, self.construct_4d_point, self.kdtrees,
                    self.args.neighbor_radius))

        # Configure dynamics for controller
        if self.args.agent in ('mbpo', 'mbpo_knn', 'mbpo_gp'):
            self.controller.reconfigure_residual_dynamics(
                self.get_residual_dynamics)
        # Count of total number of steps
        total_n_steps = 0
        while True:
            obs = observation['observation']
            g = observation['desired_goal']
            qpos = observation['sim_state'].qpos
            qvel = observation['sim_state'].qvel
            # Get action from the controller
            ac, info = self.controller.act(observation)
            if self.args.agent == 'rts':
                assert self.controller.residual_dynamics_fn is None
            if self.args.agent in ('mbpo', 'mbpo_knn', 'mbpo_gp'):
                assert self.controller.discrepancy_fn is None
            # Get discrete action index
            ac_ind = self.env.discrete_actions[tuple(ac)]
            # Get the next observation
            next_observation, rew, _, _ = self.env.step(ac)
            # if np.array_equal(obs, next_observation['observation']):
            #     import ipdb
            #     ipdb.set_trace()
            # print('ACTION', ac)
            # print('VALUE PREDICTED', info['start_node_h'])
            # print('COST', -rew)
            if self.args.render:
                self.env.render()
            total_n_steps += 1
            # Check if we reached the goal
            if self.env.env._is_success(next_observation['achieved_goal'], g):
                print('REACHED GOAL!')
                break
            # Get the next obs
            obs_next = next_observation['observation']
            # Get the sim next obs
            set_sim_state_and_goal(self.planning_env, qpos.copy(), qvel.copy(),
                                   g.copy())
            next_observation_sim, _, _, _ = self.planning_env.step(ac)
            obs_sim_next = next_observation_sim['observation']
            # Store transition
            transition = [
                obs.copy(),
                g.copy(), ac_ind,
                qpos.copy(),
                qvel.copy(),
                obs_next.copy(),
                obs_sim_next.copy()
            ]
            dynamics_losses = []
            # RTS
            if self.args.agent == 'rts' and self._check_dynamics_transition(
                    transition):
                # print('DISCREPANCY IN DYNAMICS')
                self.memory.store_real_world_transition(transition)
                # Fit model
                self._update_discrepancy_model()
            # MBPO
            elif self.args.agent in ('mbpo', 'mbpo_knn', 'mbpo_gp'):
                self.memory.store_real_world_transition(transition)
                # Update the dynamics
                if self.args.agent == 'mbpo':
                    for _ in range(self.args.n_online_planning_updates):
                        # Update dynamics
                        loss = self._update_batch_residual_dynamics()
                        dynamics_losses.append(loss.item())
                else:
                    loss = self._update_residual_dynamics()
                    dynamics_losses.append(loss)
            # Plan in the model
            value_loss = self.plan_online_in_model(
                n_planning_updates=self.args.n_online_planning_updates,
                initial_observation=copy.deepcopy(observation))

            # Log
            logger.record_tabular('n_steps', total_n_steps)
            if self.args.agent in ('mbpo', 'mbpo_knn', 'mbpo_gp'):
                logger.record_tabular('dynamics_loss',
                                      np.mean(dynamics_losses))
            logger.record_tabular('residual_loss', value_loss)
            # logger.dump_tabular()
            # Move to next iteration
            observation = copy.deepcopy(next_observation)

            if max_timesteps and total_n_steps >= max_timesteps:
                break

        return total_n_steps
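
`_check_dynamics_transition` is used above but not shown. One plausible reading, sketched under the assumption that it flags transitions where the real next observation deviates from the simulated one; the threshold name and value are invented:

import numpy as np

def _check_dynamics_transition(self, transition, threshold=1e-3):
    # Sketch only (written as a method of the agent class): transition follows
    # the layout used above, [obs, g, ac_ind, qpos, qvel, obs_next, obs_sim_next].
    obs_next, obs_sim_next = transition[5], transition[6]
    return np.linalg.norm(obs_next - obs_sim_next) > threshold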
Example #4
    def _update_state_value_residual(self):
        # Sample transitions
        transitions = self.memory.sample_internal_world_memory(
            self.args.batch_size)
        qpos, qvel = transitions['qpos'], transitions['qvel']
        obs, g, ag = transitions['obs'], transitions['g'], transitions['ag']
        # features, heuristic = transitions['features'], transitions['heuristic']

        # Compute target by restarting search from the sampled states
        num_workers = self.args.n_rts_workers
        if self.args.batch_size < self.args.n_rts_workers:
            num_workers = self.args.batch_size
        num_per_worker = self.args.batch_size // num_workers
        # Put residual in object store
        value_target_residual_state_dict_id = ray.put(
            self.state_value_target_residual.state_dict())
        # Put normalizer in object store
        feature_norm_dict_id = ray.put(self.features_normalizer.state_dict())
        # Put serialized KD-trees (for discrepancy detection) in object store
        if self.args.agent == 'rts':
            kdtrees_serialized_id = ray.put(pickle.dumps(self.kdtrees))
        else:
            kdtrees_serialized_id = None
        # Put residual dynamics in object store
        if self.args.agent == 'mbpo':
            residual_dynamics_state_dict_id = ray.put(
                self.residual_dynamics.state_dict())
        elif self.args.agent in ('mbpo_knn', 'mbpo_gp'):
            residual_dynamics_state_dict_id = ray.put(
                pickle.dumps(self.residual_dynamics))
        else:
            residual_dynamics_state_dict_id = None
        results, count = [], 0
        # Set all workers num expansions
        set_workers_num_expansions(self.internal_rollout_workers,
                                   self.args.n_offline_expansions)
        for worker_id in range(num_workers):
            if worker_id == num_workers - 1:
                # last worker takes the remaining load
                num_per_worker = self.args.batch_size - count
            worker = self.internal_rollout_workers[worker_id]
            # Set parameters
            ray.get(worker.set_worker_params.remote(
                value_residual_state_dict=value_target_residual_state_dict_id,
                feature_norm_dict=feature_norm_dict_id,
                kdtrees_serialized=kdtrees_serialized_id,
                residual_dynamics_state_dict=residual_dynamics_state_dict_id))
            # Send job
            results.append(worker.lookahead_batch.remote(
                obs[count:count + num_per_worker],
                g[count:count + num_per_worker],
                ag[count:count + num_per_worker],
                qpos[count:count + num_per_worker],
                qvel[count:count + num_per_worker]))
            count += num_per_worker
        # Check if all transitions have targets
        assert count == self.args.batch_size
        # Get all targets
        results = ray.get(results)
        target_infos = [item for sublist in results for item in sublist]

        # Extract the states, their features and their corresponding targets
        obs_closed = [
            k.obs['observation'].copy() for info in target_infos
            for k in info['closed']
        ]
        goals_closed = [
            k.obs['desired_goal'].copy() for info in target_infos
            for k in info['closed']
        ]
        heuristic_closed = [
            self.controller.heuristic_obs_g(obs_closed[i], goals_closed[i])
            for i in range(len(obs_closed))
        ]
        features_closed = [
            self.env.extract_features(obs_closed[i], goals_closed[i])
            for i in range(len(obs_closed))
        ]
        targets_closed = [
            info['best_node_f'] - k._g for info in target_infos
            for k in info['closed']
        ]

        targets_closed = np.array(targets_closed).reshape(-1, 1)
        targets_tensor = torch.as_tensor(targets_closed, dtype=torch.float32)
        # Set all workers num expansions
        set_workers_num_expansions(self.internal_rollout_workers,
                                   self.args.n_expansions)
        # Normalize features
        inputs_norm = torch.as_tensor(
            self.features_normalizer.normalize(features_closed),
            dtype=torch.float32)
        heuristic_tensor = torch.as_tensor(heuristic_closed,
                                           dtype=torch.float32).view(-1, 1)

        # Compute target residuals
        target_residual_tensor = targets_tensor - heuristic_tensor
        # Clip target residual tensor to avoid value function less than zero
        target_residual_tensor = torch.max(target_residual_tensor,
                                           -heuristic_tensor)
        # Clip target residual tensor to avoid value function greater than horizon
        if self.args.offline:
            target_residual_tensor = torch.min(
                target_residual_tensor,
                self.env_params['offline_max_timesteps'] - heuristic_tensor)

        # Compute predictions
        residual_tensor = self.state_value_residual(inputs_norm)
        # Compute loss
        state_value_residual_loss = (residual_tensor -
                                     target_residual_tensor).pow(2).mean()

        # Backprop and step
        self.state_value_residual_optim.zero_grad()
        state_value_residual_loss.backward()
        self.state_value_residual_optim.step()

        # Configure heuristic for controller
        self.controller.reconfigure_heuristic(
            lambda obs: get_state_value_residual(
                obs, self.preproc_inputs, self.state_value_residual))

        return state_value_residual_loss
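
`set_workers_num_expansions` is called above but not shown. A sketch of how such a fan-out helper could look with Ray, assuming each worker actor exposes a `set_num_expansions` remote method; the method name is an assumption, not the repository's API:

import ray

def set_workers_num_expansions(workers, n_expansions):
    # Sketch only: block until every worker has applied the new expansion budget.
    # Assumes each actor defines a set_num_expansions(n) method; the real
    # method name may differ.
    ray.get([w.set_num_expansions.remote(n_expansions) for w in workers])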
Example #5
    def learn_online_in_real_world(self, max_timesteps=None):
        # If any pre-existing model is given, load it
        if self.args.load_dir:
            self.load()

        # Reset the environment
        observation = self.env.reset()
        # Configure heuristic for controller
        self.controller.reconfigure_heuristic(
            lambda obs: get_state_value_residual(
                obs, self.preproc_inputs, self.state_value_residual))

        # Configure dynamics for controller
        self.controller.reconfigure_residual_dynamics(
            self.get_residual_dynamics)

        # Count total number of steps
        total_n_steps = 0
        while True:
            obs = observation['observation']
            g = observation['desired_goal']
            qpos = observation['sim_state'].qpos
            qvel = observation['sim_state'].qvel

            # Get action from controller
            ac, info = self.controller.act(observation)
            # Get discrete action index
            ac_ind = self.env.discrete_actions[tuple(ac)]
            # Get next observation
            next_observation, rew, _, _ = self.env.step(ac)
            # Increment counter
            total_n_steps += 1
            if self.env.env._is_success(next_observation['achieved_goal'], g):
                print('REACHED GOAL!')
                break
            if self.args.render:
                self.env.render()
            # Get next obs
            obs_next = next_observation['observation']
            # Get sim next obs
            set_sim_state_and_goal(self.planning_env, qpos.copy(), qvel.copy(),
                                   g.copy())
            next_observation_sim, _, _, _ = self.planning_env.step(ac)
            obs_sim_next = next_observation_sim['observation']
            # Store transition in real world memory
            transition = [
                obs.copy(),
                g.copy(), ac_ind,
                qpos.copy(),
                qvel.copy(),
                obs_next.copy(),
                obs_sim_next.copy()
            ]
            self.memory.store_real_world_transition(transition)

            # Update the dynamics
            dynamics_losses = []
            for _ in range(self.args.n_online_planning_updates):
                # Update dynamics
                loss = self._update_residual_dynamics()
                dynamics_losses.append(loss.item())

            # Update state value residual
            value_loss = self.plan_online_in_model(
                self.args.n_online_planning_updates,
                initial_observation=copy.deepcopy(observation))
            # log
            logger.record_tabular('n_steps', total_n_steps)
            logger.record_tabular('dynamics_loss', np.mean(dynamics_losses))
            logger.record_tabular('residual_loss', value_loss)
            logger.dump_tabular()

            # Move to next iteration
            observation = copy.deepcopy(next_observation)

            if max_timesteps and total_n_steps >= max_timesteps:
                break

        return total_n_steps
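
`self.get_residual_dynamics` is handed to the controller in several of these examples but never shown. A hedged sketch of the kind of callable it might be, assuming a torch residual model and a hypothetical `preproc_dynamics_inputs` feature builder; both the helper name and the return convention are assumptions:

import torch

def get_residual_dynamics(self, obs, ac):
    # Sketch only (written as a method of the agent class): predict a learned
    # correction to the model's next-state prediction for (obs, ac).
    # preproc_dynamics_inputs is hypothetical.
    with torch.no_grad():
        inputs = self.preproc_dynamics_inputs(obs, ac)
        correction = self.residual_dynamics(inputs)
    return correction.squeeze(0).numpy()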