Ejemplo n.º 1
0
    def _answer_question(
            self, actions, ob, q
    ):  #### Change the agent's interpretation of the answers here ####
        scan = ob['scan']
        instr = ob['instruction']
        current_viewpoint = ob['viewpoint']
        start_viewpoint = ob['init_viewpoint']
        goal_viewpoints = ob['goal_viewpoints']

        panos_to_region = utils.load_panos_to_region(scan,
                                                     None,
                                                     include_region_id=True)
        current_region_id, current_region = panos_to_region[current_viewpoint]
        goal_region_ids = []
        for viewpoint in goal_viewpoints:
            id, region = panos_to_region[viewpoint]
            goal_region_ids.append(id)
            goal_region = region

        d, goal_point = self.nav_oracle._find_nearest_point(
            scan, current_viewpoint, ob['goal_viewpoints'])

        actions_names = [self._make_action_name(action) for action in actions]

        # answer for 'do I arrive?'
        if self.agent_ask_actions[q] == 'arrive':
            if current_viewpoint in goal_viewpoints:
                return 'stop .', 'replace'
            else:
                return 'go, ', 'prepend'

        # answer for 'am I in the right room?'
        if self.agent_ask_actions[q] == 'room':
            if current_region == goal_region and current_region_id in goal_region_ids:
                if ('find' in instr) and (' in ' in instr):
                    return instr[instr.index('find'):instr.
                                 index(' in ')] + ', ', 'prepend'
                else:
                    return instr, 'replace'
            else:
                return 'exit room, ', 'prepend'

        # answer for 'am I on the right direction?'
        elif self.agent_ask_actions[q] == 'direction':
            if 'turn' in actions_names[0]:
                return 'turn around, ', 'prepend'
            else:
                return 'go straight, ', 'prepend'

        # answer for 'is the goal far from me?'
        elif self.agent_ask_actions[q] == 'distance':
            if d >= 10:
                return 'far, ', 'prepend'
            elif d >= 5:
                return 'middle, ', 'prepend'
            else:
                return 'close, ', 'prepend'
Ejemplo n.º 2
0
    def _should_ask(self, ob, nav_oracle=None):

        if self.rule_a_e:
            return self._should_ask_rule_a_e(ob, nav_oracle=nav_oracle)

        if self.rule_b_d:
            return self._should_ask_rule_b_d(ob, nav_oracle=nav_oracle)

        if ob['queries_unused'] <= 0:
            return self.agent_ask_actions.index('dont_ask'), 'exceed'

        # Find nearest point on the current shortest path
        scan = ob['scan']
        current_point = ob['viewpoint']
        # Find nearest goal to current point
        d, goal_point = nav_oracle._find_nearest_point(scan, current_point,
                                                       ob['goal_viewpoints'])

        panos_to_region = utils.load_panos_to_region(scan,
                                                     None,
                                                     include_region_id=False)

        # Rule (e): ask if the goal has been reached
        agent_decision = int(np.argmax(ob['nav_dist']))
        if current_point == goal_point or agent_decision == nav_oracle.agent_nav_actions.index(
                '<end>'):
            return self.agent_ask_actions.index('arrive'), 'arrive'

        # Rule (a): ask if the agent deviates too far from the optimal path
        if d > self.deviate_threshold:
            return self.agent_ask_actions.index('direction'), 'deviate'

        # Rule (c): ask if not moving for too long
        if len(ob['agent_path']) >= self.unmoved_threshold:
            last_nodes = [t[0]
                          for t in ob['agent_path']][-self.unmoved_threshold:]
            if all(node == last_nodes[0] for node in last_nodes):
                return self.agent_ask_actions.index('distance'), 'unmoved'

        # Rule (f): ask if staying in the same room for too long
        if len(ob['agent_path']) >= self.same_room_threshold:
            last_ask = [a for a in ob['agent_ask']][-self.same_room_threshold:]
            last_nodes = [t[0] for t in ob['agent_path']
                          ][-self.same_room_threshold:]
            if all(panos_to_region[node] == panos_to_region[last_nodes[0]] for node in last_nodes) and \
               self.agent_ask_actions.index('room') not in last_ask:
                return self.agent_ask_actions.index('room'), 'same_room'

        # Rule (b): ask if uncertain
        agent_dist = ob['nav_dist']
        uniform = [1. / len(agent_dist)] * len(agent_dist)
        entropy_gap = scipy.stats.entropy(uniform) - scipy.stats.entropy(
            agent_dist)
        if entropy_gap < self.uncertain_threshold - 1e-9:
            return self.agent_ask_actions.index('direction'), 'uncertain'

        return self.agent_ask_actions.index('dont_ask'), 'pass'
Ejemplo n.º 3
0
    def __init__(self, hparams, splits, data_path):
        """Initialise per-environment state and load the requested splits.

        Args:
            hparams: Hyper-parameter object. ``success_radius`` is required;
                ``no_room`` is optional and treated as False when absent.
            splits: Dataset split names. When falsy, no dataset is loaded.
            data_path: Location forwarded to ``load_datasets``.
        """
        self.success_radius = hparams.success_radius
        self.splits = splits

        # Per-scan containers; presumably populated by ``self.load_data``
        # (TODO confirm).
        self.scans = set()
        self.graphs = {}
        self.distances = {}

        # Missing attribute means "rooms enabled" (False).
        self.no_room = getattr(hparams, 'no_room', False)
        if splits:
            prefix = 'noroom' if self.no_room else 'asknav'
            datasets = load_datasets(splits, data_path, prefix=prefix)
            self.load_data(datasets)

        self.region_label_to_name = load_region_label_to_name()
        label_map = self.region_label_to_name
        # One pano->region mapping per known scan.
        self.panos_to_region = {
            scan: load_panos_to_region(scan, label_map)
            for scan in self.scans
        }
Ejemplo n.º 4
0
def _count_bad_questions(t, nav_oracle, region_cache):
    """Return how many steps of trajectory ``t`` hold a "bad" question.

    A step is marked bad (at most once) when any rule fires:

    1. ``direction`` is asked on two consecutive steps while the path entry
       is unchanged.
    2. ``distance`` is asked again while the distance to the nearest point on
       the start->goal path changed by at most 3 since the previous
       ``distance`` question.
    3. ``room`` is asked again from the same (region_id, region) as the
       previous ``room`` question.
    4. ``arrive`` is asked while the nearest goal is at distance >= 4.

    ``region_cache`` maps scan -> pano->region dict and is filled lazily.
    This assumes the mapping is static per scan.
    """
    ask_actions = AskAgent.ask_actions
    direction_id = ask_actions.index('direction')
    distance_id = ask_actions.index('distance')
    room_id = ask_actions.index('room')
    arrive_id = ask_actions.index('arrive')

    path = t['agent_path']
    pred = t['agent_ask'][:len(path)]
    scan = t['scan']
    goal_viewpoints = t['goal_viewpoints']
    start_point = path[0][0] if path else None

    bad_question_marks = [0] * len(path)

    # Rule 1: repeated 'direction' without changing the path entry.
    for index in range(len(pred) - 1):
        if pred[index] == pred[index + 1] == direction_id and \
                path[index] == path[index + 1]:
            bad_question_marks[index + 1] = 1

    # Rule 2: repeated 'distance' without meaningful progress.
    distance_indices = [i for i, q in enumerate(pred) if q == distance_id]
    for prev, curr in zip(distance_indices, distance_indices[1:]):
        _, goal_point = nav_oracle._find_nearest_point(
            scan, path[prev][0], goal_viewpoints)
        d1, _ = nav_oracle._find_nearest_point_on_a_path(
            scan, path[prev][0], start_point, goal_point)
        d2, _ = nav_oracle._find_nearest_point_on_a_path(
            scan, path[curr][0], start_point, goal_point)
        if abs(d1 - d2) <= 3:
            bad_question_marks[curr] = 1

    # Rule 3: repeated 'room' from the same region.
    room_indices = [i for i, q in enumerate(pred) if q == room_id]
    if len(room_indices) > 1:
        if scan not in region_cache:
            region_cache[scan] = load_panos_to_region(scan,
                                                      None,
                                                      include_region_id=True)
        panos_to_region = region_cache[scan]
        for prev, curr in zip(room_indices, room_indices[1:]):
            region_id_1, region_1 = panos_to_region[path[prev][0]]
            region_id_2, region_2 = panos_to_region[path[curr][0]]
            if region_id_1 == region_id_2 and region_1 == region_2:
                bad_question_marks[curr] = 1

    # Rule 4: 'arrive' asked far from any goal. This is a per-step check, so
    # it covers every step including the last (the old pairwise-style range
    # skipped it).
    for index, question in enumerate(pred):
        if question == arrive_id:
            d, _ = nav_oracle._find_nearest_point(
                scan, path[index][0], goal_viewpoints)
            if d >= 4:
                bad_question_marks[index] = 1

    return sum(bad_question_marks)


def compute_ask_stats(traj, agent):
    """Summarise the agent's and teacher's asking behaviour over ``traj``.

    Args:
        traj: Trajectory dicts. Each must have ``agent_path``, ``agent_ask``,
            ``teacher_ask`` (same length as ``agent_ask``),
            ``teacher_ask_reason``, ``scan`` and ``goal_viewpoints``.
        agent: Object exposing ``agent.advisor.nav_oracle``.

    Returns:
        A multi-line report string: per-episode query count, agent/teacher
        ask ratios, accuracy and macro precision/recall/F1 of agent vs.
        teacher asks, bad questions per episode, and the teacher's reason
        breakdown.

    Raises:
        ZeroDivisionError: if ``traj`` is empty or contains no steps.
    """
    # Tuple (not set) so membership uses ``==`` exactly like the original
    # ``any(x == ...)`` even for unhashable/array-like ask values.
    question_ids = tuple(AskAgent.ask_actions.index(question)
                         for question in AskAgent.question_pool)
    nav_oracle = agent.advisor.nav_oracle
    region_cache = {}

    total_steps = 0
    total_agent_ask = 0
    total_teacher_ask = 0
    queries_per_ep = []
    ask_pred = []
    ask_true = []
    bad_questions = []
    all_reasons = []

    for t in traj:
        assert len(t['agent_ask']) == len(t['teacher_ask'])

        end_step = len(t['agent_path'])
        pred = t['agent_ask'][:end_step]
        true = t['teacher_ask'][:end_step]

        bad_questions.append(_count_bad_questions(t, nav_oracle, region_cache))

        agent_queries = sum(x in question_ids for x in pred)
        total_steps += len(true)
        total_agent_ask += agent_queries
        total_teacher_ask += sum(x in question_ids for x in true)
        queries_per_ep.append(agent_queries)
        ask_pred.extend(pred)
        ask_true.extend(true)
        all_reasons.extend(t['teacher_ask_reason'][:end_step])

    loss_str = '\n *** ASK:'
    loss_str += ' queries_per_ep %.1f' % (sum(queries_per_ep) /
                                          len(queries_per_ep))
    loss_str += ', agent_ratio %.3f' % (total_agent_ask / total_steps)
    loss_str += ', teacher_ratio %.3f' % (total_teacher_ask / total_steps)
    loss_str += ', A/P/R/F %.3f / %.3f / %.3f / %.3f' % (
        accuracy_score(ask_true, ask_pred),
        precision_score(ask_true, ask_pred, average='macro'),
        recall_score(ask_true, ask_pred, average='macro'),
        f1_score(ask_true, ask_pred, average='macro'))
    loss_str += ', bad_questions_per_ep %.1f' % (sum(bad_questions) /
                                                 len(bad_questions))

    loss_str += '\n *** TEACHER ASK:'
    reason_counter = Counter(all_reasons)
    total_asks = sum(x != 'pass' and x != 'exceed' for x in all_reasons)
    loss_str += ' ask %.3f, dont_ask %.3f, ' % (
        total_asks / len(all_reasons),
        (len(all_reasons) - total_asks) / len(all_reasons))
    # Only asking reasons are listed; if there are none, the list is empty and
    # no division by ``total_asks`` happens.
    loss_str += ', '.join([
        '%s %.3f' % (k, reason_counter[k] / total_asks)
        for k in reason_counter.keys() if k not in ['pass', 'exceed']
    ])

    return loss_str
    return loss_str