Example 1
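The constructor of a baseline visual navigation network. It sizes the goal input from whichever goal sensor the observation space provides (falling back to a SimpleCNN embedding for image goals), encodes observations with a SimpleCNN, and feeds the concatenated features to an RNN state encoder.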
    def __init__(
        self,
        observation_space: spaces.Dict,
        hidden_size: int,
    ):
        super().__init__()

        # Size the goal input from whichever goal sensor is present,
        # checked in priority order.
        if (IntegratedPointGoalGPSAndCompassSensor.cls_uuid
                in observation_space.spaces):
            self._n_input_goal = observation_space.spaces[
                IntegratedPointGoalGPSAndCompassSensor.cls_uuid].shape[0]
        elif PointGoalSensor.cls_uuid in observation_space.spaces:
            self._n_input_goal = observation_space.spaces[
                PointGoalSensor.cls_uuid].shape[0]
        elif ImageGoalSensor.cls_uuid in observation_space.spaces:
            goal_observation_space = spaces.Dict(
                {"rgb": observation_space.spaces[ImageGoalSensor.cls_uuid]})
            self.goal_visual_encoder = SimpleCNN(goal_observation_space,
                                                 hidden_size)
            self._n_input_goal = hidden_size

        self._hidden_size = hidden_size

        self.visual_encoder = SimpleCNN(observation_space, hidden_size)

        # The RNN consumes visual features (unless blind) concatenated
        # with the goal input.
        self.state_encoder = build_rnn_state_encoder(
            (0 if self.is_blind else self._hidden_size) + self._n_input_goal,
            self._hidden_size,
        )

        self.train()
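A minimal construction sketch for the network above. The enclosing class name (PointNavBaselineNet, as in habitat-baselines) and the "pointgoal" observation key (the value of PointGoalSensor.cls_uuid in habitat-lab) are assumptions about code not shown in the snippet; the shapes are illustrative.

# Hypothetical usage sketch: the class name and the observation key are
# assumptions, not shown in the snippet above.
import numpy as np
from gym import spaces

obs_space = spaces.Dict({
    "rgb": spaces.Box(low=0, high=255, shape=(256, 256, 3), dtype=np.uint8),
    # 2-D (rho, phi) goal vector; exercises the PointGoalSensor branch above.
    "pointgoal": spaces.Box(
        low=-np.inf, high=np.inf, shape=(2,), dtype=np.float32),
})
net = PointNavBaselineNet(observation_space=obs_space, hidden_size=512)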
Example 2
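A variant of the same constructor with additional options: an optional two-layer MLP goal encoder, optional running-mean-and-variance normalization of the goal input, a force-blind flag for the visual encoder, and a fallback that fuses the listed low-dimensional sensors when no goal sensor is present.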
    def __init__(
        self,
        observation_space: spaces.Dict,
        hidden_size: int,
        goal_hidden_size,
        fuse_states,
        force_blind,
        normalize_vis,
        normalize_state,
    ):
        super().__init__()

        if TargetPointGoalGPSAndCompassSensor.cls_uuid in observation_space.spaces:
            self._n_input_goal = observation_space.spaces[
                TargetPointGoalGPSAndCompassSensor.cls_uuid].shape[0]
        elif (IntegratedPointGoalGPSAndCompassSensor.cls_uuid
              in observation_space.spaces):
            self._n_input_goal = observation_space.spaces[
                IntegratedPointGoalGPSAndCompassSensor.cls_uuid].shape[0]
        elif PointGoalSensor.cls_uuid in observation_space.spaces:
            self._n_input_goal = observation_space.spaces[
                PointGoalSensor.cls_uuid].shape[0]
        elif ImageGoalSensor.cls_uuid in observation_space.spaces:
            goal_observation_space = spaces.Dict(
                {"rgb": observation_space.spaces[ImageGoalSensor.cls_uuid]})
            self.goal_visual_encoder = SimpleCNN(goal_observation_space,
                                                 hidden_size, False, False)
            self._n_input_goal = hidden_size
        else:
            # No goal sensor present: fuse the listed low-dimensional state
            # sensors into a single input vector.
            self.fuse_states = fuse_states
            self._n_input_goal = sum(
                observation_space.spaces[n].shape[0] for n in self.fuse_states)

        self._hidden_size = hidden_size
        self._goal_hidden_size = goal_hidden_size

        self.visual_encoder = SimpleCNN(observation_space, hidden_size,
                                        force_blind, normalize_vis)
        state_dim = self._n_input_goal
        if self._goal_hidden_size != 0:
            self.goal_encoder = nn.Sequential(
                nn.Linear(self._n_input_goal, self._goal_hidden_size),
                nn.ReLU(),
                nn.Linear(self._goal_hidden_size, self._goal_hidden_size),
                nn.ReLU())
            state_dim = self._goal_hidden_size

        if normalize_state:
            self.state_norm = RunningMeanAndVar(self._n_input_goal, [])
        else:
            # Identity pass-through when state normalization is disabled.
            self.state_norm = nn.Sequential()

        self.state_encoder = build_rnn_state_encoder(
            (0 if self.is_blind else self._hidden_size) + state_dim,
            self._hidden_size,
        )

        self.train()
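A standalone sketch of the two-layer goal encoder built above, checking its output dimensions; the concrete sizes (a 2-D pointgoal, a 128-unit goal hidden size) are illustrative assumptions.

# Self-contained dimension check; n_input_goal and goal_hidden_size are
# assumed illustrative values, not taken from the snippet.
import torch
import torch.nn as nn

n_input_goal, goal_hidden_size = 2, 128
goal_encoder = nn.Sequential(
    nn.Linear(n_input_goal, goal_hidden_size), nn.ReLU(),
    nn.Linear(goal_hidden_size, goal_hidden_size), nn.ReLU())
# With the encoder active, state_dim becomes goal_hidden_size in the RNN input.
print(goal_encoder(torch.randn(8, n_input_goal)).shape)  # torch.Size([8, 128])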
Example 3
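A unit test for the batched RNN state encoder. For a range of sequence lengths T and batch sizes N, it compares the packed, batched forward pass against a step-by-step reference rollout that zeroes the hidden state wherever the mask marks an episode start, asserting that outputs and final hidden states agree.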
def test_rnn_state_encoder():
    from habitat_baselines.rl.models.rnn_state_encoder import (
        build_rnn_state_encoder,
    )

    device = (torch.device("cuda")
              if torch.cuda.is_available() else torch.device("cpu"))
    rnn_state_encoder = build_rnn_state_encoder(32, 32,
                                                num_layers=2).to(device=device)
    rnn = rnn_state_encoder.rnn
    with torch.no_grad():
        for T in [1, 2, 4, 8, 16, 32, 64, 3, 13, 31]:
            for N in [1, 2, 4, 8, 3, 5]:
                masks = torch.randint(0,
                                      2,
                                      size=(T, N, 1),
                                      dtype=torch.bool,
                                      device=device)
                inputs = torch.randn((T, N, 32), device=device)
                hidden_states = torch.randn(
                    rnn_state_encoder.num_recurrent_layers,
                    N,
                    32,
                    device=device,
                )

                outputs, out_hiddens = rnn_state_encoder(
                    inputs.flatten(0, 1),
                    hidden_states.permute(1, 0, 2),
                    masks.flatten(0, 1),
                )
                out_hiddens = out_hiddens.permute(1, 0, 2)

                reference_outputs = []
                reference_hiddens = hidden_states.clone()
                for t in range(T):
                    # Zero the hidden state for environments whose mask is
                    # False (an episode start), then step the raw RNN once.
                    reference_hiddens = torch.where(
                        masks[t].view(1, -1, 1),
                        reference_hiddens,
                        reference_hiddens.new_zeros(()),
                    )

                    x, reference_hiddens = rnn(inputs[t:t + 1],
                                               reference_hiddens)

                    reference_outputs.append(x.squeeze(0))

                reference_outputs = torch.stack(reference_outputs,
                                                0).flatten(0, 1)

                assert (torch.norm(reference_outputs - outputs).item() <
                        1e-3), "Failed on (T={}, N={})".format(T, N)
                assert (torch.norm(reference_hiddens - out_hiddens).item() <
                        1e-3), "Failed on (T={}, N={})".format(T, N)
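For reference, a single-step usage sketch of the same interface the test exercises. It assumes, as the test does, that hidden states are batch-first at the module boundary and that a False mask resets that environment's memory.

import torch
from habitat_baselines.rl.models.rnn_state_encoder import (
    build_rnn_state_encoder,
)

encoder = build_rnn_state_encoder(32, 32, num_layers=2)
x = torch.randn(4, 32)                      # one step for 4 environments
hidden = torch.zeros(4, encoder.num_recurrent_layers, 32)
masks = torch.ones(4, 1, dtype=torch.bool)  # False marks an episode start
out, hidden = encoder(x, hidden, masks)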
Example 4
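The constructor of a ResNet-based navigation network. It embeds the previous action and each available goal/pose sensor into fixed 32-dimensional slots, optionally encodes an image goal with a dedicated ResNetEncoder, and concatenates everything with the visual features as input to the RNN state encoder.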
    def __init__(
        self,
        observation_space: spaces.Dict,
        action_space,
        hidden_size: int,
        num_recurrent_layers: int,
        rnn_type: str,
        backbone,
        resnet_baseplanes,
        normalize_visual_inputs: bool,
        force_blind_policy: bool = False,
    ):
        super().__init__()

        # One extra embedding row serves as the "no previous action" token
        # at episode starts.
        self.prev_action_embedding = nn.Embedding(action_space.n + 1, 32)
        self._n_prev_action = 32
        rnn_input_size = self._n_prev_action

        if (
            IntegratedPointGoalGPSAndCompassSensor.cls_uuid
            in observation_space.spaces
        ):
            n_input_goal = (
                observation_space.spaces[
                    IntegratedPointGoalGPSAndCompassSensor.cls_uuid
                ].shape[0]
                + 1
            )
            self.tgt_embedding = nn.Linear(n_input_goal, 32)
            rnn_input_size += 32

        if ObjectGoalSensor.cls_uuid in observation_space.spaces:
            self._n_object_categories = (
                int(
                    observation_space.spaces[ObjectGoalSensor.cls_uuid].high[0]
                )
                + 1
            )
            self.obj_categories_embedding = nn.Embedding(
                self._n_object_categories, 32
            )
            rnn_input_size += 32

        if EpisodicGPSSensor.cls_uuid in observation_space.spaces:
            input_gps_dim = observation_space.spaces[
                EpisodicGPSSensor.cls_uuid
            ].shape[0]
            self.gps_embedding = nn.Linear(input_gps_dim, 32)
            rnn_input_size += 32

        if PointGoalSensor.cls_uuid in observation_space.spaces:
            input_pointgoal_dim = observation_space.spaces[
                PointGoalSensor.cls_uuid
            ].shape[0]
            self.pointgoal_embedding = nn.Linear(input_pointgoal_dim, 32)
            rnn_input_size += 32

        if HeadingSensor.cls_uuid in observation_space.spaces:
            input_heading_dim = (
                observation_space.spaces[HeadingSensor.cls_uuid].shape[0] + 1
            )
            assert input_heading_dim == 2, "Expected heading with 2D rotation."
            self.heading_embedding = nn.Linear(input_heading_dim, 32)
            rnn_input_size += 32

        if ProximitySensor.cls_uuid in observation_space.spaces:
            input_proximity_dim = observation_space.spaces[
                ProximitySensor.cls_uuid
            ].shape[0]
            self.proximity_embedding = nn.Linear(input_proximity_dim, 32)
            rnn_input_size += 32

        if EpisodicCompassSensor.cls_uuid in observation_space.spaces:
            assert (
                observation_space.spaces[EpisodicCompassSensor.cls_uuid].shape[
                    0
                ]
                == 1
            ), "Expected compass with 2D rotation."
            input_compass_dim = 2  # cos and sin of the angle
            self.compass_embedding = nn.Linear(input_compass_dim, 32)
            rnn_input_size += 32

        if ImageGoalSensor.cls_uuid in observation_space.spaces:
            goal_observation_space = spaces.Dict(
                {"rgb": observation_space.spaces[ImageGoalSensor.cls_uuid]}
            )
            self.goal_visual_encoder = ResNetEncoder(
                goal_observation_space,
                baseplanes=resnet_baseplanes,
                ngroups=resnet_baseplanes // 2,
                make_backbone=getattr(resnet, backbone),
                normalize_visual_inputs=normalize_visual_inputs,
            )

            self.goal_visual_fc = nn.Sequential(
                nn.Flatten(),
                nn.Linear(
                    np.prod(self.goal_visual_encoder.output_shape), hidden_size
                ),
                nn.ReLU(True),
            )

            rnn_input_size += hidden_size

        self._hidden_size = hidden_size

        self.visual_encoder = ResNetEncoder(
            # An empty observation space yields a blind visual encoder.
            observation_space if not force_blind_policy else spaces.Dict({}),
            baseplanes=resnet_baseplanes,
            ngroups=resnet_baseplanes // 2,
            make_backbone=getattr(resnet, backbone),
            normalize_visual_inputs=normalize_visual_inputs,
        )

        if not self.visual_encoder.is_blind:
            self.visual_fc = nn.Sequential(
                nn.Flatten(),
                nn.Linear(
                    np.prod(self.visual_encoder.output_shape), hidden_size
                ),
                nn.ReLU(True),
            )

        self.state_encoder = build_rnn_state_encoder(
            (0 if self.is_blind else self._hidden_size) + rnn_input_size,
            self._hidden_size,
            rnn_type=rnn_type,
            num_layers=num_recurrent_layers,
        )

        self.train()
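A short sketch of the compass preprocessing the constructor sizes for: the 1-D compass reading is expanded to its cosine and sine before the linear embedding. The expansion happens in the forward pass, which is not shown, so treat this as an assumption consistent with the input_compass_dim comment above.

import torch

compass = torch.randn(8, 1)  # episodic compass angle in radians, batch of 8
# Expand the scalar angle to (cos, sin) so rotation is represented continuously.
compass_features = torch.cat([compass.cos(), compass.sin()], dim=-1)
assert compass_features.shape == (8, 2)  # matches input_compass_dim = 2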