Example #1
0
    def get_successors(self, node, mprim):
        """Return the successor node and summed cost of applying ``mprim``.

        Every entry of ``mprim.discrete_states`` is unpacked as an
        ``(x, y, theta)`` triple. The x and y parts are added to the first two
        components of the node's current observation and clamped to
        ``[0, X_DISCRETIZATION - 1]`` and ``[0, Y_DISCRETIZATION - 1]``; theta
        is taken from the triple as is. ``self.cost_map`` is read at each
        resulting cell and the values are summed.

        Returns:
            A ``(next_node, cost)`` tuple. ``next_node`` wraps the last cell
            visited and ``cost`` is the sum of the cost-map values.
        """
        obs = node.obs
        total_cost = 0

        for x_offset, y_offset, theta in mprim.discrete_states:
            # Offsets are relative to the node's observation, not to the
            # previously visited cell; clamp so the indices stay on the grid.
            cell_x = max(
                min(obs['observation'][0] + x_offset, X_DISCRETIZATION - 1), 0)
            cell_y = max(
                min(obs['observation'][1] + y_offset, Y_DISCRETIZATION - 1), 0)
            visited_cell = np.array([cell_x, cell_y, theta], dtype=int)
            total_cost += self.cost_map[visited_cell[0], visited_cell[1]]

        # NOTE(review): visited_cell is only bound inside the loop, so a
        # primitive with no discrete states fails here - confirm that such
        # primitives cannot occur.
        successor = QNode({'observation': visited_cell})

        return successor, total_cost
Example #2
0
    def get_successors(self, node, action):
        """Step the model from ``node`` with ``action`` and wrap the result.

        The model is first reset to copies of the node's ``'observation'`` and
        ``'desired_goal'`` entries, then stepped once.

        Returns:
            A ``(next_node, cost)`` tuple, where ``cost`` is the second value
            returned by ``self.model.step``.
        """
        current = node.obs
        # Pass copies so the model cannot mutate the arrays held by the node.
        state = current['observation'].copy()
        goal = current['desired_goal'].copy()
        set_gridworld_state_and_goal(self.model, state, goal)

        successor_obs, step_cost, _, _ = self.model.step(action)

        return QNode(obs=successor_obs, dummy=False), step_cost
Example #3
0
    def get_successors(self, node, ac):
        """Simulate action ``ac`` from ``node`` and return the successor.

        The model is set to the node's observation and stepped once. When a
        discrepancy function is configured, the step cost is passed through it
        and the transition is asserted not to be a known discrepancy.

        Returns:
            A ``(next_node, cost)`` tuple.
        """
        current_obs = node.obs

        # Put the model in the node's state before simulating the action.
        self.model.set_observation(current_obs)
        successor_obs, step_cost = self.model.step(ac)

        if self.discrepancy_fn is not None:
            step_cost = self.discrepancy_fn(
                current_obs['observation'], ac, step_cost)
            # Successors must never be requested for a transition that is
            # already known to be incorrect.
            assert not self.check_discrepancy_fn(current_obs, ac)

        return QNode(successor_obs), step_cost
Example #4
0
    def act(self, start_node):
        """Run a bounded best-first search from ``start_node`` and pick an action.

        Nodes are popped in order of ``f = g + h`` (ties broken by ``h``, then
        by insertion order) for at most ``self.num_expansions`` iterations.
        A transition flagged by ``self.discrepancy_fn`` is not simulated;
        instead a dummy node is queued whose ``g`` equals the parent's ``g``
        and whose ``h`` comes from ``self.qvalue_fn``. The search stops early
        when a goal node or a dummy node is popped; otherwise the best node
        left on the open list is used.

        Args:
            start_node: Node to search from. Its ``_g`` and ``_h`` attributes
                are overwritten and any existing ``_came_from`` is deleted.

        Returns:
            A ``(best_mprim, info)`` tuple. ``best_mprim`` and the path come
            from ``self.get_best_mprim``; ``info`` holds ``best_node_f``,
            ``start_node_h``, ``best_node``, ``closed``, ``open``, ``dummy``,
            ``path`` and ``successor_obs``.

        Raises:
            IndexError: If the open list is empty when a node has to be popped.
        """
        closed_set = set()
        # Collected for bookkeeping only; it is not part of the returned info.
        inconsistent_set = set()
        # Named open_heap (not `open`) so the builtin is not shadowed.
        open_heap = []

        if hasattr(start_node, '_came_from'):
            del start_node._came_from

        reached_goal = False
        popped_dummy = False

        start_node._g = 0
        start_h = start_node._h = self.heuristic_fn(start_node)
        start_f = start_node._g + start_node._h

        # Heap entries are mutable lists [f, h, count, node]; the unique count
        # guarantees that two nodes are never compared with each other.
        count = 0
        start_entry = [start_f, start_h, count, start_node]
        heapq.heappush(open_heap, start_entry)
        count += 1
        open_d = {start_node: start_entry}

        for _ in range(self.num_expansions):
            node = heapq.heappop(open_heap)[3]
            del open_d[node]

            # A dummy node stands for a known-incorrect transition: stop here.
            if node.dummy:
                popped_dummy = True
                best_node = node
                break

            closed_set.add(node)

            if self.check_goal_fn(node):
                reached_goal = True
                best_node = node
                break

            for mprim in self.mprims_fn(node):
                if self.discrepancy_fn(node.obs, mprim):
                    # Known-incorrect transition: queue a dummy node valued by
                    # the learned Q-value instead of simulating the model.
                    new_node = QNode(obs=None, dummy=True)
                    new_node._g = node._g
                    new_node._h = self.qvalue_fn(node.obs, mprim)
                    new_node._came_from = node
                    new_node._action = mprim
                    new_f = new_node._g + new_node._h

                    entry = open_d[new_node] = [
                        new_f, new_node._h, count, new_node
                    ]
                    heapq.heappush(open_heap, entry)
                    count += 1
                    inconsistent_set.add(new_node)
                else:
                    # No known discrepancy: expand through the model.
                    neighbor, cost = self.successors_fn(node, mprim)
                    if neighbor in closed_set:
                        continue

                    tentative_g = node._g + cost
                    if neighbor not in open_d:
                        neighbor._came_from = node
                        neighbor._action = mprim
                        neighbor._g = tentative_g
                        h = neighbor._h = self.heuristic_fn(neighbor)
                        entry = open_d[neighbor] = [
                            tentative_g + h, h, count, neighbor
                        ]
                        heapq.heappush(open_heap, entry)
                        count += 1
                    else:
                        # Use the instance already queued so its bookkeeping
                        # attributes are the ones that get updated.
                        neighbor = open_d[neighbor][3]
                        if tentative_g < neighbor._g:
                            neighbor._came_from = node
                            neighbor._action = mprim
                            neighbor._g = tentative_g
                            # The entry is mutated in place, so the heap
                            # invariant has to be restored afterwards.
                            open_d[neighbor][0] = tentative_g + neighbor._h
                            heapq.heapify(open_heap)

        # Expansion budget exhausted without a goal or dummy: fall back to the
        # best node still on the open list (it is left in open_d on purpose,
        # matching the previous behaviour).
        if (not reached_goal) and (not popped_dummy):
            best_node = heapq.heappop(open_heap)[3]

        info = {
            'best_node_f': best_node._g + best_node._h,
            'start_node_h': start_node._h,
            'best_node': best_node,
            'closed': closed_set,
            'open': open_d.keys(),
            'dummy': popped_dummy
        }

        best_mprim, path = self.get_best_mprim(start_node, best_node)
        info['path'] = path
        # NOTE(review): assumes the path always has at least two entries -
        # confirm for the case where best_node is start_node.
        info['successor_obs'] = path[1]
        return best_mprim, info
Example #5
0
 def act(self, obs):
     """Wrap ``obs`` in a ``QNode`` and delegate to ``self.astar.act``.

     Returns:
         The ``(best_action, info)`` pair produced by the planner.
     """
     root = QNode(obs)
     chosen_action, search_info = self.astar.act(root)
     return chosen_action, search_info
Example #6
0
 def act(self, obs):
     """Wrap ``obs`` in a non-dummy ``QNode`` and delegate to ``self.qastar.act``.

     Returns:
         The ``(best_action, info)`` pair produced by the planner.
     """
     root = QNode(obs=obs, dummy=False)
     chosen_action, search_info = self.qastar.act(root)
     return chosen_action, search_info
Example #7
0
 def get_successors_obs(self, obs, mprim):
     """Observation-level variant of ``get_successors``.

     Wraps ``obs`` in a ``QNode``, calls ``self.get_successors`` and unwraps
     the successor node.

     Returns:
         A ``(next_obs, cost)`` tuple.
     """
     source_node = QNode(obs)
     successor, step_cost = self.get_successors(source_node, mprim)
     return successor.obs, step_cost