コード例 #1
0
ファイル: r2r_dataset.py プロジェクト: erick84mm/habitat-api
    def from_json(self,
                  json_str: str,
                  scenes_dir: Optional[str] = None) -> None:
        """Fill this dataset from its JSON serialization.

        Reads the vocabularies, action tokens, alignments and scene list from
        the payload, loads the connectivity data for those scenes, and appends
        one ``VLNEpisode`` to ``self.episodes`` per entry of ``"episodes"``.

        Args:
            json_str: JSON document containing the ``"train_vocab"``,
                ``"trainval_vocab"``, ``"BERT_vocab"``, ``"mini_alignments"``,
                ``"scenes"`` and ``"episodes"`` keys.
            scenes_dir: when given, each episode's ``scene_id`` is re-rooted
                under this directory (after stripping
                ``DEFAULT_SCENE_PATH_PREFIX`` if present).
        """
        payload = json.loads(json_str)
        # NOTE: this single list object is shared by every goal viewpoint.
        goal_rotation = [0, 0, 0, 1]

        self.train_vocab = VocabDict(
            word_list=payload["train_vocab"]["word_list"])
        self.trainval_vocab = VocabDict(
            word_list=payload["trainval_vocab"]["word_list"])

        self.action_tokens = payload["BERT_vocab"]["action_tokens"]
        self.mini_alignments = payload["mini_alignments"]
        self.scenes = payload["scenes"]
        self.connectivity = load_connectivity(self.config.CONNECTIVITY_PATH,
                                              self.scenes)

        for raw_episode in payload["episodes"]:
            # The starting viewpoint is labelled with the first goal's id.
            first_goal = raw_episode["goals"][0]
            start_state = AgentState(position=raw_episode["start_position"],
                                     rotation=raw_episode["start_rotation"])
            raw_episode["curr_viewpoint"] = ViewpointData(
                image_id=first_goal, view_point=start_state)

            # These two keys are not passed to the VLNEpisode constructor;
            # they are stored in the InstructionData below instead.
            encoding = raw_episode["instruction_encoding"]
            token_mask = raw_episode["mask"]
            del raw_episode["instruction_encoding"], raw_episode["mask"]
            episode = VLNEpisode(**raw_episode)

            if scenes_dir is not None:
                scene = episode.scene_id
                if scene.startswith(DEFAULT_SCENE_PATH_PREFIX):
                    scene = scene[len(DEFAULT_SCENE_PATH_PREFIX):]
                episode.scene_id = os.path.join(scenes_dir, scene)

            episode.instruction = InstructionData(
                instruction=raw_episode["instruction"],
                tokens=encoding,
                tokens_length=sum(token_mask),
                mask=token_mask)

            # Replace every goal index by a full viewpoint record looked up
            # in the connectivity data of the episode's scan.
            scan = episode.scan
            for slot, image_id in enumerate(episode.goals):
                scan_graph = self.connectivity[scan]
                node_id = scan_graph["idxtoid"][image_id]
                position = scan_graph["viewpoints"][node_id]
                episode.goals[slot] = ViewpointData(
                    image_id=image_id,
                    view_point=AgentState(position=position,
                                          rotation=goal_rotation))

            first_viewpoint, last_viewpoint = (episode.goals[0],
                                               episode.goals[-1])
            episode.distance = self.get_distance_to_target(
                scan, first_viewpoint.image_id, last_viewpoint.image_id)
            self.episodes.append(episode)
コード例 #2
0
    def _teacher_actions(self, observations, goal):
        action = ""
        action_args = {}
        navigable_locations = observations["adjacentViewpoints"]

        if goal == navigable_locations[0][1]:  # image_id
            action = "STOP"
        else:
            step_size = np.pi/6.0  # default step in R2R
            goal_location = None
            for location in navigable_locations:
                if location[1] == goal:  # image_id
                    goal_location = location
                    break
            # Check if the goal is visible
            if goal_location:

                rel_heading = goal_location[2]  # rel_heading
                rel_elevation = goal_location[3]  #rel_elevation

                if rel_heading > step_size:
                    action = "TURN_RIGHT"
                elif rel_heading < -step_size:
                    action = "TURN_LEFT"
                elif rel_elevation > step_size:
                    action = "LOOK_UP"
                elif rel_elevation < -step_size:
                    action = "LOOK_DOWN"
                else:
                    if goal_location[0] == 1:  # restricted
                        print("WARNING: The target was not in the" +
                              " Field of view, but the step action " +
                              "is going to be performed")
                    action = "TELEPORT"  # Move forward
                    image_id = goal
                    posB = goal_location[4:7]  # start_position
                    rotA = navigable_locations[0][14:18]  # camera_rotation
                    viewpoint = ViewpointData(
                        image_id=image_id,
                        view_point=AgentState(position=posB, rotation=rotA)
                    )
                    action_args.update({"target": viewpoint})
            else:
                # Episode Failure
                action = 'STOP'
                print("Target position %s not visible, " % goal +
                      "This is an error in the system")
                '''
                for ob in observations["images"]:
                    image = ob
                    image =  image[:,:, [2,1,0]]
                    cv2.imshow("RGB", image)
                    cv2.waitKey(0)
                '''
        return action, action_args
コード例 #3
0
    def _teleport_target(self, observations):
        action = ""
        action_args = {}
        navigable_locations = observations["adjacentViewpoints"]

        for location in navigable_locations[1:]:
            if location[0] == 1:  # location is restricted
                continue
            elif location[0] == 0:  # Non restricted location
                action = "TELEPORT"
                image_id = location[1]
                posB = location[4:7]  # start_position
                rotA = navigable_locations[0][14:18]  # camera_rotation
                viewpoint = ViewpointData(image_id=image_id,
                                          view_point=AgentState(position=posB,
                                                                rotation=rotA))
                action_args = {"target": viewpoint}
                #print("the target is ", location)
                return {"action": action, "action_args": action_args}

        return {"action": action, "action_args": action_args}
コード例 #4
0
    def act(self, observations, elapsed_steps, previous_step_collided):
        """Choose the next action of the random baseline.

        * Step 0: ``TURN_RIGHT`` with a random ``num_steps`` in 1..11 (no
          argument when the draw is 0).
        * ``elapsed_steps >= 5``: ``STOP``.
        * Otherwise ``TELEPORT`` to the first entry whose restricted flag is
          falsy, keeping the camera rotation of entry 0, or ``TURN_RIGHT``
          when no entry is unrestricted.

        Args:
            observations: mapping with the ``"adjacentViewpoints"`` sequence.
            elapsed_steps: number of steps already taken in the episode.
            previous_step_collided: unused.

        Returns:
            ``{"action": <name>, "action_args": <dict>}``.
        """
        neighbours = observations["adjacentViewpoints"]
        # Entries flagged -1 are ignored; flag 0 counts as one open viewpoint.
        open_count = sum(
            [1 - entry[0] for entry in neighbours if entry[0] != -1])

        if elapsed_steps == 0:
            # Initial direction choice.
            turns = random.randint(0, 11)
            turn_args = {"num_steps": turns} if turns > 0 else {}
            return {"action": "TURN_RIGHT", "action_args": turn_args}

        if elapsed_steps >= 5:
            return {"action": "STOP", "action_args": {}}

        if open_count > 0:
            destination = next(
                (entry for entry in neighbours if not entry[0]), None)
            if destination is None:
                return {"action": "", "action_args": {}}
            # Keep the camera rotation of the previous step.
            camera_rotation = neighbours[0][14:18]
            state = AgentState(position=destination[4:7],
                               rotation=camera_rotation)
            viewpoint = ViewpointData(image_id=destination[1],
                                      view_point=state)
            return {"action": "TELEPORT",
                    "action_args": {"target": viewpoint}}

        # Nothing reachable in view: keep turning.
        return {"action": "TURN_RIGHT", "action_args": {}}
コード例 #5
0
    def act(self, observations, episode, goal):
        """Run one decoder step and return the action to execute.

        On the first step of an episode (``self.previous_action`` equal to
        ``"<start>"``) the instruction tokens are padded or truncated to
        ``self.max_tokens`` and encoded. Every step then feeds the current
        image features to the decoder, adds the loss against the teacher
        action to ``self.loss``, and picks the next action according to
        ``self.feedback`` (``'teacher'``, ``'argmax'`` or ``'sample'``).

        Args:
            observations: mapping with an ``"rgb"`` image and the
                ``"adjacentViewpoints"`` sequence.
            episode: current episode; ``episode.instruction.tokens`` and
                ``episode.instruction.tokens_length`` are read.
            goal: image id passed to ``self._teacher_actions``.

        Returns:
            ``{"action": <name>, "action_args": <dict>}``.

        Note:
            Tensors are placed on ``'cuda'``; an unknown ``self.feedback``
            terminates the process through ``sys.exit``.
        """
        batch_size = 1

        if self.previous_action == "<start>":
            # Pad/truncate a COPY of the tokens: extending the episode's own
            # list in place would make it grow every time the same episode is
            # started again.
            tokens = list(episode.instruction.tokens)
            pad = self.max_tokens - episode.instruction.tokens_length
            if pad > 0:
                tokens.extend([0] * pad)
            else:
                tokens = tokens[:self.max_tokens]
            tokens = np.array([tokens])

            # Sequence length = index of the first 0 token; a row without
            # any 0 token is treated as full length.
            seq_lengths = np.argmax(tokens == 0, axis=1)
            seq_lengths[seq_lengths == 0] = self.max_tokens
            seq_tensor = torch.from_numpy(tokens).to('cuda')
            seq_lengths = torch.from_numpy(seq_lengths).to('cuda')
            seq_mask = (seq_tensor == 0)[:, :seq_lengths[0]]
            self.seq_mask = seq_mask.to('cuda')
            tokens = None

            # Forward through encoder, giving initial hidden state and
            # memory cell for decoder
            self.ctx, self.h_t, self.c_t = self.encoder(
                seq_tensor, seq_lengths)
            self.a_t = torch.ones(batch_size, requires_grad=False).long() * \
                    self.model_actions.index(self.previous_action)
            self.a_t = self.a_t.unsqueeze(0).to('cuda')

        # Reverse the channel order of the image before feature extraction.
        im = observations["rgb"][:, :, [2, 1, 0]]
        f_t = self._get_image_features(im)

        im = None

        # One decoder step conditioned on the previous action.
        self.h_t, self.c_t, alpha, logit = self.decoder(
            self.a_t.view(-1, 1), f_t, self.h_t, self.c_t, self.ctx,
            self.seq_mask)
        f_t = None

        # Mask TELEPORT when no unrestricted viewpoint is available
        # (entries flagged -1 are ignored, flag 0 counts as one).
        visible_points = sum([
            1 - ob[0] for ob in observations["adjacentViewpoints"]
            if ob[0] != -1
        ])

        if visible_points == 0:
            logit[0, self.model_actions.index('TELEPORT')] = -float('inf')

        # Supervised training
        target_action, action_args = self._teacher_actions(observations, goal)
        target = torch.LongTensor(1)
        target[0] = self.model_actions.index(target_action)
        target = target.to('cuda')

        self.loss += self.criterion(logit, target)

        # Determine next model inputs
        if self.feedback == 'teacher':
            self.a_t = target  # teacher forcing
            action = target_action
        elif self.feedback == 'argmax':
            _, self.a_t = logit.max(1)  # student forcing - argmax
            self.a_t = self.a_t.detach()
            action = self.model_actions[self.a_t.item()]
            # The teacher's arguments do not apply to the model's choice.
            action_args = {}
        elif self.feedback == 'sample':
            probs = F.softmax(logit, dim=1)
            m = D.Categorical(probs)
            self.a_t = m.sample()  # sampling an action from model
            action = self.model_actions[self.a_t.item()]
            action_args = {}
        else:
            sys.exit('Invalid feedback option')

        # When the model itself chose TELEPORT, go to the unrestricted
        # neighbour with the smallest absolute relative heading.
        if action == "TELEPORT" and self.feedback != 'teacher':
            sorted_obs = sorted(
                observations["adjacentViewpoints"][1:],
                key=lambda x: abs(x[2]))  # sort by relative heading
            for ob in sorted_obs:
                if not ob[0]:  # not restricted
                    image_id = ob[1]
                    pos = ob[4:7]  # agent_position

                    # Keeping the same rotation as the previous step
                    # camera rotation
                    rot = observations["adjacentViewpoints"][0][14:18]
                    viewpoint = ViewpointData(image_id=image_id,
                                              view_point=AgentState(
                                                  position=pos, rotation=rot))
                    action_args = {"target": viewpoint}
                    break

        self.previous_action = action

        return {"action": action, "action_args": action_args}