def get_successors(self, node, mprim):
    """Apply a motion primitive to ``node`` and sum the cost-map cost along it.

    Each entry of ``mprim.discrete_states`` is unpacked as ``(xd, yd, thetad)``.
    The x/y offsets are added to the node's own observation (not to the
    previous step), clamped to ``[0, X_DISCRETIZATION - 1]`` and
    ``[0, Y_DISCRETIZATION - 1]`` respectively, and the ``self.cost_map``
    entry at the clamped cell is added to the running cost.

    Args:
        node: Node whose ``obs['observation']`` supplies the starting x/y.
        mprim: Motion primitive exposing ``discrete_states``.

    Returns:
        tuple: ``(next_node, cost)`` where ``next_node`` is a ``QNode`` wrapping
        ``{'observation': [x, y, thetad]}`` for the last discrete state and
        ``cost`` is the sum of the cost-map values of all visited cells.
    """
    start_obs = node.obs
    total_cost = 0
    for xd, yd, thetad in mprim.discrete_states:
        cell_x = max(min(start_obs['observation'][0] + xd, X_DISCRETIZATION - 1), 0)
        cell_y = max(min(start_obs['observation'][1] + yd, Y_DISCRETIZATION - 1), 0)
        current_observation = np.array([cell_x, cell_y, thetad], dtype=int)
        total_cost += self.cost_map[current_observation[0], current_observation[1]]
    # NOTE(review): if mprim.discrete_states is empty, current_observation is
    # never bound and the line below raises UnboundLocalError.
    successor = QNode({'observation': current_observation})
    return successor, total_cost
def get_successors(self, node, action):
    """Step the model from ``node`` with ``action`` and wrap the outcome.

    Copies of the node's ``'observation'`` and ``'desired_goal'`` entries are
    handed to ``set_gridworld_state_and_goal`` together with ``self.model``
    before ``self.model.step(action)`` is called.

    Args:
        node: Node whose ``obs`` dict holds ``'observation'`` and
            ``'desired_goal'``.
        action: Action forwarded unchanged to ``self.model.step``.

    Returns:
        tuple: ``(next_node, cost)`` where ``next_node`` is a non-dummy
        ``QNode`` around the first value returned by ``step`` and ``cost``
        is the second value; the remaining two values are discarded.
    """
    current = node.obs
    set_gridworld_state_and_goal(
        self.model,
        current['observation'].copy(),
        current['desired_goal'].copy(),
    )
    successor_obs, step_cost, _, _ = self.model.step(action)
    successor = QNode(obs=successor_obs, dummy=False)
    return successor, step_cost
def get_successors(self, node, ac):
    """Step the model from ``node`` with action ``ac`` and wrap the outcome.

    The model is first set to the node's observation, then stepped. When
    ``self.discrepancy_fn`` is not ``None`` the step cost is replaced by
    ``self.discrepancy_fn(obs['observation'], ac, cost)``.

    Args:
        node: Node whose ``obs`` is loaded into ``self.model``.
        ac: Action forwarded to ``self.model.step``.

    Returns:
        tuple: ``(next_node, cost)`` with ``next_node`` a ``QNode`` around the
        observation returned by the model.

    Raises:
        AssertionError: If ``self.check_discrepancy_fn(obs, ac)`` is truthy;
            successors are not expected to be queried for such a transition.
    """
    current_obs = node.obs

    # Put the model into the node's state, then advance it by one action
    self.model.set_observation(current_obs)
    successor_obs, step_cost = self.model.step(ac)

    # Let the discrepancy function (if any) rewrite the cost
    if self.discrepancy_fn is not None:
        step_cost = self.discrepancy_fn(current_obs['observation'], ac, step_cost)

    # A transition already known to be incorrect should never get this far
    assert not self.check_discrepancy_fn(current_obs, ac)

    successor = QNode(successor_obs)
    return successor, step_cost
def act(self, start_node):
    """Run a bounded best-first search from ``start_node`` and pick an action.

    At most ``self.num_expansions`` nodes are popped from an open list
    ordered by ``[f, h, insertion_count]`` with ``f = g + h``. The search
    stops early when the popped node is a dummy node or satisfies
    ``self.check_goal_fn``. If neither happens within the budget, the head of
    the open list is popped once more and used as the best node.

    For every action yielded by ``self.mprims_fn(node)``:

    * if ``self.discrepancy_fn(node.obs, mprim)`` is truthy, a dummy ``QNode``
      is queued with the parent's ``g`` and
      ``h = self.qvalue_fn(node.obs, mprim)``;
    * otherwise ``self.successors_fn(node, mprim)`` supplies the successor and
      step cost; the successor is queued, or has its ``g`` lowered if it is
      already open and the new path is cheaper.

    Args:
        start_node: Root of the search. Its ``_g`` and ``_h`` attributes are
            overwritten and an existing ``_came_from`` attribute is deleted.

    Returns:
        tuple: ``(best_mprim, info)``. ``best_mprim`` and ``info['path']`` come
        from ``self.get_best_mprim(start_node, best_node)``. ``info`` also
        holds ``'best_node_f'``, ``'start_node_h'``, ``'best_node'``,
        ``'closed'``, ``'open'`` (a live view of the open-list keys),
        ``'dummy'`` and ``'successor_obs'`` (``path[1]``).

    Raises:
        IndexError: If the open list is empty when a node has to be popped,
            i.e. the search ran out of nodes before finding a goal or dummy.
    """
    closed_set = set()
    # Heap of mutable [f, h, insertion_count, node] entries. The insertion
    # count is unique, so list comparison never reaches the node itself.
    frontier = []

    # Drop any back-pointer already stored on the start node (presumably left
    # over from an earlier search) before this search begins.
    if hasattr(start_node, '_came_from'):
        del start_node._came_from

    reached_goal = False
    popped_dummy = False
    best_node = None

    start_node._g = 0
    start_node._h = self.heuristic_fn(start_node)
    count = 0
    start_entry = [
        start_node._g + start_node._h, start_node._h, count, start_node
    ]
    heapq.heappush(frontier, start_entry)
    count += 1
    # Maps each queued node to its heap entry so f can be edited in place
    open_d = {start_node: start_entry}

    for _ in range(self.num_expansions):
        if not frontier:
            raise IndexError(
                'open list exhausted before a goal or dummy node was found'
            )
        node = heapq.heappop(frontier)[3]
        del open_d[node]

        # A dummy node stands for a known-incorrect transition: stop here
        if node.dummy:
            popped_dummy = True
            best_node = node
            break

        closed_set.add(node)

        if self.check_goal_fn(node):
            reached_goal = True
            best_node = node
            break

        for mprim in self.mprims_fn(node):
            if self.discrepancy_fn(node.obs, mprim):
                # Known discrepancy: queue a dummy node valued by the
                # Q-function instead of simulating the transition.
                dummy_node = QNode(obs=None, dummy=True)
                dummy_node._g = node._g
                dummy_node._h = self.qvalue_fn(node.obs, mprim)
                dummy_node._came_from = node
                dummy_node._action = mprim
                entry = open_d[dummy_node] = [
                    dummy_node._g + dummy_node._h,
                    dummy_node._h,
                    count,
                    dummy_node,
                ]
                heapq.heappush(frontier, entry)
                count += 1
                continue

            # No known discrepancy: expand through the successor function
            neighbor, cost = self.successors_fn(node, mprim)
            if neighbor in closed_set:
                continue

            tentative_g = node._g + cost
            if neighbor not in open_d:
                neighbor._came_from = node
                neighbor._action = mprim
                neighbor._g = tentative_g
                neighbor_h = neighbor._h = self.heuristic_fn(neighbor)
                entry = open_d[neighbor] = [
                    tentative_g + neighbor_h, neighbor_h, count, neighbor
                ]
                heapq.heappush(frontier, entry)
                count += 1
            else:
                # Work on the instance that is actually stored in the heap
                neighbor = open_d[neighbor][3]
                if tentative_g < neighbor._g:
                    neighbor._came_from = node
                    neighbor._action = mprim
                    neighbor._g = tentative_g
                    open_d[neighbor][0] = tentative_g + neighbor._h
                    # The entry was mutated in place; restore heap order
                    heapq.heapify(frontier)

    if (not reached_goal) and (not popped_dummy):
        # Budget spent without a goal or dummy: take the open-list head.
        # Its open_d entry is left in place, so it still shows up in
        # info['open'].
        if not frontier:
            raise IndexError(
                'open list exhausted before a goal or dummy node was found'
            )
        best_node = heapq.heappop(frontier)[3]

    info = {
        'best_node_f': best_node._g + best_node._h,
        'start_node_h': start_node._h,
        'best_node': best_node,
        'closed': closed_set,
        'open': open_d.keys(),
        'dummy': popped_dummy,
    }
    best_mprim, path = self.get_best_mprim(start_node, best_node)
    info['path'] = path
    info['successor_obs'] = path[1]
    return best_mprim, info
def act(self, obs):
    """Plan from ``obs`` by delegating to ``self.astar``.

    Args:
        obs: Observation wrapped in a ``QNode`` to form the search root.

    Returns:
        tuple: The ``(best_action, info)`` pair produced by
        ``self.astar.act`` for that root node.
    """
    root = QNode(obs)
    chosen_action, search_info = self.astar.act(root)
    return chosen_action, search_info
def act(self, obs):
    """Plan from ``obs`` by delegating to ``self.qastar``.

    Args:
        obs: Observation wrapped in a non-dummy ``QNode`` to form the
            search root.

    Returns:
        tuple: The ``(best_action, info)`` pair produced by
        ``self.qastar.act`` for that root node.
    """
    root = QNode(obs=obs, dummy=False)
    chosen_action, search_info = self.qastar.act(root)
    return chosen_action, search_info
def get_successors_obs(self, obs, mprim):
    """Observation-level wrapper around ``self.get_successors``.

    Args:
        obs: Observation wrapped in a ``QNode`` before the call.
        mprim: Action / motion primitive passed through unchanged.

    Returns:
        tuple: ``(next_obs, cost)`` where ``next_obs`` is the ``obs``
        attribute of the successor node returned by ``get_successors``.
    """
    source = QNode(obs)
    successor, step_cost = self.get_successors(source, mprim)
    return successor.obs, step_cost