Example #1
    def create_memory(
        self,
        spec: Optional[FullMemorySpecType],
        num_samplers: int,
    ) -> Memory:
        if spec is None:
            return Memory()

        memory = Memory()
        for key in spec:
            dims_template, dtype = spec[key]

            dim_names = ["step"] + [d[0] for d in dims_template]
            sampler_dim = dim_names.index("sampler")

            all_dims = [self.num_steps + 1] + [d[1] for d in dims_template]
            all_dims[sampler_dim] = num_samplers

            memory.check_append(
                key=key,
                tensor=torch.zeros(*all_dims, dtype=dtype),
                sampler_dim=sampler_dim,
            )

            self.flattened_to_unflattened["memory"][key] = [key]
            self.unflattened_to_flattened["memory"][(key, )] = key

        return memory
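
The spec argument maps each memory key to a (dims_template, dtype) pair, where dims_template is a sequence of (dimension name, size) pairs and exactly one dimension must be named "sampler" so that its size can be overwritten with num_samplers. Below is a minimal sketch of that shape computation with a hypothetical "rnn" entry; the key, dimension names, and sizes are illustrative, not taken from a real model.

import torch

# Hypothetical spec in the (dims_template, dtype) format expected by create_memory above.
spec = {
    "rnn": (
        (("layer", 2), ("sampler", None), ("hidden", 512)),
        torch.float32,
    ),
}

num_steps, num_samplers = 128, 4
for key, (dims_template, dtype) in spec.items():
    dim_names = ["step"] + [d[0] for d in dims_template]
    sampler_dim = dim_names.index("sampler")
    all_dims = [num_steps + 1] + [d[1] for d in dims_template]
    all_dims[sampler_dim] = num_samplers
    # -> rnn ['step', 'layer', 'sampler', 'hidden'] [129, 2, 4, 512]
    print(key, dim_names, all_dims)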
Example #2
    def forward(  # type:ignore
        self,
        observations: ObservationType,
        memory: Memory,
        prev_actions: torch.Tensor,
        masks: torch.FloatTensor,
    ) -> Tuple[ActorCriticOutput[DistributionType], Optional[Memory]]:
        target_encoding = self.get_target_coordinates_encoding(observations)
        x: Union[torch.Tensor, List[torch.Tensor]]
        x = [target_encoding]


        if not self.is_blind:
            perception_embed = self.visual_encoder(observations)
            if self.sensor_fusion:
                perception_embed = self.sensor_fuser(perception_embed)
            x = [perception_embed] + x

        x = torch.cat(x, dim=-1)
        x, rnn_hidden_states = self.state_encoder(x, memory.tensor("rnn"), masks)

        ac_output = ActorCriticOutput(
            distributions=self.actor(x), values=self.critic(x), extras={}
        )

        return ac_output, memory.set_tensor("rnn", rnn_hidden_states)
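
In these forward passes, masks is 0 for samplers whose episode has just ended and 1 otherwise; the state encoder is expected to multiply the recurrent hidden state by it so that memory does not leak across episode boundaries. A minimal stand-in sketch of that masking, assuming a single-layer GRU (the actual state encoder class is not shown in these examples, and the sizes below are arbitrary):

import torch
import torch.nn as nn

# Stand-in for the recurrent state encoder (assumption: a single-layer GRU).
rnn = nn.GRU(input_size=8, hidden_size=16)

x = torch.randn(1, 4, 8)         # [steps=1, samplers=4, features]
hidden = torch.zeros(1, 4, 16)   # [layers=1, samplers=4, hidden]
masks = torch.zeros(1, 4, 1)     # 0 where an episode just ended, 1 otherwise
masks[:, :2] = 1.0               # pretend samplers 0 and 1 are mid-episode

# Multiplying the hidden state by masks zeroes it for samplers that just reset,
# which is the role `masks` plays in the forward passes above.
out, hidden = rnn(x, hidden * masks.view(1, 4, 1))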
Example #3
    def forward(  # type:ignore
        self,
        observations: Dict[str, Union[torch.FloatTensor, Dict[str, Any]]],
        memory: Memory,
        prev_actions: torch.Tensor,
        masks: torch.FloatTensor,
    ) -> Tuple[ActorCriticOutput[DistributionType], Optional[Memory]]:
        rnn_out, mem_return = self.state_encoder(
            x=observations[self.input_uuid],
            hidden_states=memory.tensor(self.memory_key),
            masks=masks,
        )

        # noinspection PyCallingNonCallable
        out, _ = self.ac_nonrecurrent_head(
            observations={self.head_uuid: rnn_out},
            memory=None,
            prev_actions=prev_actions,
            masks=masks,
        )

        # noinspection PyArgumentList
        return (
            out,
            memory.set_tensor(self.memory_key, mem_return),
        )
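
Example #3 wraps a recurrent state encoder around an otherwise non-recurrent actor-critic head, which is why the inner call receives memory=None. A self-contained sketch of the same pattern; the class, attribute, and size names here are made up for illustration:

import torch
import torch.nn as nn

# Stateless actor-critic head applied to the RNN output (illustrative class).
class TinyHead(nn.Module):
    def __init__(self, hidden_size: int, num_actions: int):
        super().__init__()
        self.actor = nn.Linear(hidden_size, num_actions)
        self.critic = nn.Linear(hidden_size, 1)

    def forward(self, rnn_out: torch.Tensor):
        return self.actor(rnn_out), self.critic(rnn_out)

rnn = nn.GRU(input_size=8, hidden_size=16)
head = TinyHead(hidden_size=16, num_actions=4)

x = torch.randn(1, 4, 8)           # [steps=1, samplers=4, features]
h0 = torch.zeros(1, 4, 16)         # [layers=1, samplers=4, hidden]
rnn_out, h1 = rnn(x, h0)           # recurrent state encoder
logits, values = head(rnn_out)     # non-recurrent head re-used on the RNN output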
Example #4
    def forward(  # type:ignore
        self,
        observations: ObservationType,
        memory: Memory,
        prev_actions: torch.Tensor,
        masks: torch.FloatTensor,
    ) -> Tuple[ActorCriticOutput[DistributionType], Optional[Memory]]:
        target_encoding = self.get_target_coordinates_encoding(observations)
        x: Union[torch.Tensor, List[torch.Tensor]]
        x = [target_encoding]

        if not self.is_blind:
            perception_embed = self.visual_encoder(observations)
            if self.sensor_fusion:
                perception_embed = self.sensor_fuser(perception_embed)
            x = [perception_embed] + x

        x = torch.cat(x, dim=-1)
        x, rnn_hidden_states = self.state_encoder(x, memory.tensor("rnn"),
                                                  masks)

        ac_output = ActorCriticOutput(distributions=self.actor(x),
                                      values=self.critic(x),
                                      extras={})

        return ac_output, memory.set_tensor("rnn", rnn_hidden_states)
Example #5
    def forward(  # type:ignore
        self,
        observations: ObservationType,
        memory: Memory,
        prev_actions: torch.Tensor,
        masks: torch.FloatTensor,
    ) -> Tuple[ActorCriticOutput[DistributionType], Optional[Memory]]:
        x = self.goal_visual_encoder(observations)
        x, rnn_hidden_states = self.state_encoder(
            x, memory.tensor(self.memory_key), masks
        )
        return (
            ActorCriticOutput(
                distributions=self.actor(x), values=self.critic(x), extras={}
            ),
            memory.set_tensor(self.memory_key, rnn_hidden_states),
        )
Example #6
    def __init__(
        self,
        num_steps: int,
        num_samplers: int,
        actor_critic: ActorCriticModel,
        *args,
        **kwargs,
    ):
        self.num_steps = num_steps

        self.flattened_to_unflattened: Dict[str, Dict[str, List[str]]] = {
            "memory": dict(),
            "observations": dict(),
        }
        self.unflattened_to_flattened: Dict[str, Dict[Tuple[str, ...], str]] = {
            "memory": dict(),
            "observations": dict(),
        }

        self.dim_names = ["step", "sampler", "agent", None]

        self.memory: Memory = self.create_memory(
            actor_critic.recurrent_memory_specification, num_samplers
        )
        self.observations: Memory = Memory()

        self.num_agents = getattr(actor_critic, "num_agents", 1)

        self.rewards = torch.zeros(num_steps, num_samplers, self.num_agents, 1,)
        self.value_preds = torch.zeros(num_steps + 1, num_samplers, self.num_agents, 1,)
        self.returns = torch.zeros(num_steps + 1, num_samplers, self.num_agents, 1,)
        self.action_log_probs = torch.zeros(
            num_steps, num_samplers, self.num_agents, 1,
        )

        action_space = actor_critic.action_space

        if action_space.__class__.__name__ == "Discrete":
            action_shape = 1
        else:
            action_shape = action_space.shape[0]

        self.actions = torch.zeros(
            num_steps, num_samplers, self.num_agents, action_shape,
        )
        self.prev_actions = torch.zeros(
            num_steps + 1, num_samplers, self.num_agents, action_shape,
        )

        if action_space.__class__.__name__ == "Discrete":
            self.actions = self.actions.long()
            self.prev_actions = self.prev_actions.long()

        self.masks = torch.ones(num_steps + 1, num_samplers, self.num_agents, 1,)

        self.step = 0

        self.unnarrow_data: DefaultDict[
            str, Union[int, torch.Tensor, Dict]
        ] = defaultdict(dict)
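
The action_shape branch above stores one integer index per agent for Discrete action spaces and one float per action dimension otherwise, switching the action tensors to long dtype in the discrete case. A quick illustration of that logic, assuming the classic gym package is available; the space instances and sizes are just examples:

import torch
from gym.spaces import Box, Discrete

# Mirror of the action_shape branch above with example spaces.
for space in (Discrete(6), Box(low=-1.0, high=1.0, shape=(3,))):
    if space.__class__.__name__ == "Discrete":
        action_shape = 1                # one integer action index per agent
    else:
        action_shape = space.shape[0]   # one float per continuous action dimension
    actions = torch.zeros(128, 4, 1, action_shape)  # [steps, samplers, agents, action]
    if space.__class__.__name__ == "Discrete":
        actions = actions.long()        # integer indices for discrete actions
    print(type(space).__name__, tuple(actions.shape), actions.dtype)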