def set_worker_params(self,
                      value_residual_state_dict,
                      feature_norm_dict,
                      kdtrees_serialized=None,
                      residual_dynamics_state_dict=None):
    # Load all worker parameters
    self.state_value_residual.load_state_dict(value_residual_state_dict)
    self.features_normalizer.load_state_dict(feature_norm_dict)
    # Reconfigure controller heuristic function
    self.controller.reconfigure_heuristic(
        lambda obs: get_state_value_residual(obs, self.preproc_inputs,
                                             self.state_value_residual))
    if kdtrees_serialized:
        self.kdtrees_set = True
        self.kdtrees = pickle.loads(kdtrees_serialized)
        self.controller.reconfigure_discrepancy(
            lambda obs, ac: get_discrepancy_neighbors(
                obs, ac, self.construct_4d_point, self.kdtrees,
                self.args.neighbor_radius))
    if residual_dynamics_state_dict:
        self.residual_dynamics_set = True
        if self.args.agent == 'mbpo':
            self.residual_dynamics.load_state_dict(
                residual_dynamics_state_dict)
        elif self.args.agent in ('mbpo_knn', 'mbpo_gp'):
            self.residual_dynamics = pickle.loads(
                residual_dynamics_state_dict)
        else:
            raise NotImplementedError
        self.controller.reconfigure_residual_dynamics(
            self.get_residual_dynamics)
    return
def _update_state_value_residual(self):
    # Sample transitions from the internal world memory
    transitions = self.memory.sample_internal_world_memory(
        self.args.batch_size)
    obs, g, ag = transitions['obs'], transitions['g'], transitions['ag']
    features, heuristic = transitions['features'], transitions['heuristic']
    # Compute targets by running the controller from each sampled state
    targets = []
    for i in range(self.args.batch_size):
        observation = {}
        observation['observation'] = obs[i].copy()
        observation['desired_goal'] = g[i].copy()
        observation['achieved_goal'] = ag[i].copy()
        _, info = self.controller.act(observation)
        targets.append(info['best_node_f'])
    targets = np.array(targets).reshape(-1, 1)

    features_norm = self.features_normalizer.normalize(features)
    inputs_norm = torch.as_tensor(features_norm, dtype=torch.float32)
    targets = torch.as_tensor(targets, dtype=torch.float32)
    h_tensor = torch.as_tensor(heuristic, dtype=torch.float32).unsqueeze(-1)
    # Compute target residuals
    target_residual_tensor = targets - h_tensor
    # Clip the target residual so the value estimate (heuristic + residual)
    # stays non-negative
    target_residual_tensor = torch.max(target_residual_tensor, -h_tensor)
    # Clip the target residual so the value estimate does not exceed the horizon
    if self.args.offline:
        target_residual_tensor = torch.min(
            target_residual_tensor,
            self.env_params['offline_max_timesteps'] - h_tensor)
    # Compute predictions
    residual_tensor = self.state_value_residual(inputs_norm)
    # Compute loss
    state_value_residual_loss = (residual_tensor -
                                 target_residual_tensor).pow(2).mean()
    # Backprop and step
    self.state_value_residual_optim.zero_grad()
    state_value_residual_loss.backward()
    self.state_value_residual_optim.step()
    # Configure heuristic for controller
    self.controller.reconfigure_heuristic(
        lambda obs: get_state_value_residual(obs, self.preproc_inputs,
                                             self.state_value_residual))
    return state_value_residual_loss
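# The clipping above keeps the learned value estimate, heuristic + residual,
# inside [0, horizon]. The block below is a minimal standalone sketch of that
# rule; clip_residual_targets is a hypothetical helper written only for
# illustration (it is not part of the agent code) and assumes (batch, 1)
# torch tensors.
import torch


def clip_residual_targets(targets, heuristic, horizon=None):
    """Clamp residual targets so that heuristic + residual stays in [0, horizon]."""
    residual = targets - heuristic
    # Lower bound: the value estimate heuristic + residual must be non-negative
    residual = torch.max(residual, -heuristic)
    # Upper bound: the value estimate must not exceed the horizon, if one is given
    if horizon is not None:
        residual = torch.min(residual, horizon - heuristic)
    return residual


# Example: with heuristic h = 3 and horizon 10, a target of 1 yields a clipped
# residual of -2 (value estimate 1), while a target of 15 yields 7 (value 10).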
def learn_online_in_real_world(self, max_timesteps=None):
    # If a pre-existing model is given, load it
    if self.args.load_dir:
        self.load()
    # Reset the environment
    observation = self.env.reset()
    # Configure heuristic for controller
    self.controller.reconfigure_heuristic(
        lambda obs: get_state_value_residual(obs, self.preproc_inputs,
                                             self.state_value_residual))
    # Configure discrepancy model for controller
    if self.args.agent == 'rts':
        self.controller.reconfigure_discrepancy(
            lambda obs, ac: get_discrepancy_neighbors(
                obs, ac, self.construct_4d_point, self.kdtrees,
                self.args.neighbor_radius))
    # Configure residual dynamics for controller
    if self.args.agent in ('mbpo', 'mbpo_knn', 'mbpo_gp'):
        self.controller.reconfigure_residual_dynamics(
            self.get_residual_dynamics)
    # Count of the total number of steps
    total_n_steps = 0
    while True:
        obs = observation['observation']
        g = observation['desired_goal']
        qpos = observation['sim_state'].qpos
        qvel = observation['sim_state'].qvel
        # Get action from the controller
        ac, info = self.controller.act(observation)
        if self.args.agent == 'rts':
            assert self.controller.residual_dynamics_fn is None
        if self.args.agent in ('mbpo', 'mbpo_knn', 'mbpo_gp'):
            assert self.controller.discrepancy_fn is None
        # Get discrete action index
        ac_ind = self.env.discrete_actions[tuple(ac)]
        # Get the next observation
        next_observation, rew, _, _ = self.env.step(ac)
        if self.args.render:
            self.env.render()
        total_n_steps += 1
        # Check if we reached the goal
        if self.env.env._is_success(next_observation['achieved_goal'], g):
            print('REACHED GOAL!')
            break
        # Get the next observation in the real world
        obs_next = next_observation['observation']
        # Get the next observation predicted by the simulator
        set_sim_state_and_goal(self.planning_env, qpos.copy(), qvel.copy(),
                               g.copy())
        next_observation_sim, _, _, _ = self.planning_env.step(ac)
        obs_sim_next = next_observation_sim['observation']
        # Store transition
        transition = [
            obs.copy(), g.copy(), ac_ind, qpos.copy(), qvel.copy(),
            obs_next.copy(), obs_sim_next.copy()
        ]
        dynamics_losses = []
        # RTS
        if self.args.agent == 'rts' and self._check_dynamics_transition(
                transition):
            self.memory.store_real_world_transition(transition)
            # Fit the discrepancy model
            self._update_discrepancy_model()
        # MBPO
        elif self.args.agent in ('mbpo', 'mbpo_knn', 'mbpo_gp'):
            self.memory.store_real_world_transition(transition)
            # Update the dynamics
            if self.args.agent == 'mbpo':
                for _ in range(self.args.n_online_planning_updates):
                    loss = self._update_batch_residual_dynamics()
                    dynamics_losses.append(loss.item())
            else:
                loss = self._update_residual_dynamics()
                dynamics_losses.append(loss)
        # Plan in the model
        value_loss = self.plan_online_in_model(
            n_planning_updates=self.args.n_online_planning_updates,
            initial_observation=copy.deepcopy(observation))
        # Log
        logger.record_tabular('n_steps', total_n_steps)
        if self.args.agent in ('mbpo', 'mbpo_knn', 'mbpo_gp'):
            logger.record_tabular('dynamics_loss', np.mean(dynamics_losses))
        logger.record_tabular('residual_loss', value_loss)
        # logger.dump_tabular()
        # Move to the next iteration
        observation = copy.deepcopy(next_observation)
        if max_timesteps and total_n_steps >= max_timesteps:
            break
    return total_n_steps
def _update_state_value_residual(self):
    # Sample transitions
    transitions = self.memory.sample_internal_world_memory(
        self.args.batch_size)
    qpos, qvel = transitions['qpos'], transitions['qvel']
    obs, g, ag = transitions['obs'], transitions['g'], transitions['ag']
    # Compute targets by restarting the search from the sampled states
    num_workers = self.args.n_rts_workers
    if self.args.batch_size < self.args.n_rts_workers:
        num_workers = self.args.batch_size
    num_per_worker = self.args.batch_size // num_workers
    # Put the target residual in the object store
    value_target_residual_state_dict_id = ray.put(
        self.state_value_target_residual.state_dict())
    # Put the feature normalizer in the object store
    feature_norm_dict_id = ray.put(self.features_normalizer.state_dict())
    # Put the KNN discrepancy models (kd-trees) in the object store
    if self.args.agent == 'rts':
        kdtrees_serialized_id = ray.put(pickle.dumps(self.kdtrees))
    else:
        kdtrees_serialized_id = None
    # Put the residual dynamics in the object store
    if self.args.agent == 'mbpo':
        residual_dynamics_state_dict_id = ray.put(
            self.residual_dynamics.state_dict())
    elif self.args.agent in ('mbpo_knn', 'mbpo_gp'):
        residual_dynamics_state_dict_id = ray.put(
            pickle.dumps(self.residual_dynamics))
    else:
        residual_dynamics_state_dict_id = None
    results, count = [], 0
    # Set the number of expansions for all workers
    set_workers_num_expansions(self.internal_rollout_workers,
                               self.args.n_offline_expansions)
    for worker_id in range(num_workers):
        if worker_id == num_workers - 1:
            # The last worker takes the remaining load
            num_per_worker = self.args.batch_size - count
        # Set parameters
        ray.get(self.internal_rollout_workers[worker_id].set_worker_params.
                remote(value_residual_state_dict=
                       value_target_residual_state_dict_id,
                       feature_norm_dict=feature_norm_dict_id,
                       kdtrees_serialized=kdtrees_serialized_id,
                       residual_dynamics_state_dict=
                       residual_dynamics_state_dict_id))
        # Send job
        results.append(
            self.internal_rollout_workers[worker_id].lookahead_batch.remote(
                obs[count:count + num_per_worker],
                g[count:count + num_per_worker],
                ag[count:count + num_per_worker],
                qpos[count:count + num_per_worker],
                qvel[count:count + num_per_worker]))
        count += num_per_worker
    # Check that all transitions have been assigned a worker
    assert count == self.args.batch_size
    # Get all targets
    results = ray.get(results)
    target_infos = [item for sublist in results for item in sublist]
    # Extract the closed states, their features, and their corresponding targets
    obs_closed = [
        k.obs['observation'].copy() for info in target_infos
        for k in info['closed']
    ]
    goals_closed = [
        k.obs['desired_goal'].copy() for info in target_infos
        for k in info['closed']
    ]
    heuristic_closed = [
        self.controller.heuristic_obs_g(obs_closed[i], goals_closed[i])
        for i in range(len(obs_closed))
    ]
    features_closed = [
        self.env.extract_features(obs_closed[i], goals_closed[i])
        for i in range(len(obs_closed))
    ]
    # Each closed node's target is the best frontier f-value minus the node's
    # cost-to-come, i.e. its estimated cost-to-go through the best frontier node
    targets_closed = [
        info['best_node_f'] - k._g for info in target_infos
        for k in info['closed']
    ]
    targets_closed = np.array(targets_closed).reshape(-1, 1)
    targets_tensor = torch.as_tensor(targets_closed, dtype=torch.float32)
    # Restore the number of expansions for all workers
    set_workers_num_expansions(self.internal_rollout_workers,
                               self.args.n_expansions)
    # Normalize features
    inputs_norm = torch.as_tensor(
        self.features_normalizer.normalize(features_closed),
        dtype=torch.float32)
    heuristic_tensor = torch.as_tensor(heuristic_closed,
                                       dtype=torch.float32).view(-1, 1)
    # Compute target residuals
    target_residual_tensor = targets_tensor - heuristic_tensor
    # Clip the target residual so the value estimate (heuristic + residual)
    # stays non-negative
    target_residual_tensor = torch.max(target_residual_tensor,
                                       -heuristic_tensor)
    # Clip the target residual so the value estimate does not exceed the horizon
    if self.args.offline:
        target_residual_tensor = torch.min(
            target_residual_tensor,
            self.env_params['offline_max_timesteps'] - heuristic_tensor)
    # Compute predictions
    residual_tensor = self.state_value_residual(inputs_norm)
    # Compute loss
    state_value_residual_loss = (residual_tensor -
                                 target_residual_tensor).pow(2).mean()
    # Backprop and step
    self.state_value_residual_optim.zero_grad()
    state_value_residual_loss.backward()
    self.state_value_residual_optim.step()
    # Configure heuristic for controller
    self.controller.reconfigure_heuristic(
        lambda obs: get_state_value_residual(obs, self.preproc_inputs,
                                             self.state_value_residual))
    return state_value_residual_loss
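# The scatter pattern above gives each worker batch_size // num_workers
# transitions and lets the last worker absorb the remainder. The block below
# is a small standalone sketch of that index arithmetic; split_batch is a
# hypothetical helper for illustration only, not part of the repository.
def split_batch(batch_size, num_workers):
    """Return (start, end) slice bounds, one pair per worker."""
    num_workers = min(num_workers, batch_size)
    num_per_worker = batch_size // num_workers
    bounds, count = [], 0
    for worker_id in range(num_workers):
        # The last worker takes whatever is left over
        n = num_per_worker if worker_id < num_workers - 1 else batch_size - count
        bounds.append((count, count + n))
        count += n
    assert count == batch_size
    return bounds


# Example: split_batch(10, 4) -> [(0, 2), (2, 4), (4, 6), (6, 10)]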
def learn_online_in_real_world(self, max_timesteps=None):
    # If a pre-existing model is given, load it
    if self.args.load_dir:
        self.load()
    # Reset the environment
    observation = self.env.reset()
    # Configure heuristic for controller
    self.controller.reconfigure_heuristic(
        lambda obs: get_state_value_residual(obs, self.preproc_inputs,
                                             self.state_value_residual))
    # Configure residual dynamics for controller
    self.controller.reconfigure_residual_dynamics(self.get_residual_dynamics)
    # Count of the total number of steps
    total_n_steps = 0
    while True:
        obs = observation['observation']
        g = observation['desired_goal']
        qpos = observation['sim_state'].qpos
        qvel = observation['sim_state'].qvel
        # Get action from the controller
        ac, info = self.controller.act(observation)
        # Get discrete action index
        ac_ind = self.env.discrete_actions[tuple(ac)]
        # Get the next observation
        next_observation, rew, _, _ = self.env.step(ac)
        # Increment the step counter
        total_n_steps += 1
        # Check if we reached the goal
        if self.env.env._is_success(next_observation['achieved_goal'], g):
            print('REACHED GOAL!')
            break
        if self.args.render:
            self.env.render()
        # Get the next observation in the real world
        obs_next = next_observation['observation']
        # Get the next observation predicted by the simulator
        set_sim_state_and_goal(self.planning_env, qpos.copy(), qvel.copy(),
                               g.copy())
        next_observation_sim, _, _, _ = self.planning_env.step(ac)
        obs_sim_next = next_observation_sim['observation']
        # Store transition in real world memory
        transition = [
            obs.copy(), g.copy(), ac_ind, qpos.copy(), qvel.copy(),
            obs_next.copy(), obs_sim_next.copy()
        ]
        self.memory.store_real_world_transition(transition)
        # Update the dynamics
        dynamics_losses = []
        for _ in range(self.args.n_online_planning_updates):
            loss = self._update_residual_dynamics()
            dynamics_losses.append(loss.item())
        # Update the state value residual by planning in the model
        value_loss = self.plan_online_in_model(
            self.args.n_online_planning_updates,
            initial_observation=copy.deepcopy(observation))
        # Log
        logger.record_tabular('n_steps', total_n_steps)
        logger.record_tabular('dynamics_loss', np.mean(dynamics_losses))
        logger.record_tabular('residual_loss', value_loss)
        logger.dump_tabular()
        # Move to the next iteration
        observation = copy.deepcopy(next_observation)
        if max_timesteps and total_n_steps >= max_timesteps:
            break
    return total_n_steps