Example #1
    def forward(self, commands, args, label=None):
        # commands has shape (S, G, N): sequence length, number of groups, batch size
        S, G, N = commands.shape
        l = self.label_embedding(label).unsqueeze(0).unsqueeze(0).repeat(1, commands.size(1), 1, 1) if self.cfg.label_condition else None

        if self.cfg.encode_stages == 2:
            # per-group visibility masks, computed before packing, for the second (hierarchical) stage below
            visibility_mask, key_visibility_mask = _get_visibility_mask(commands, seq_dim=0), _get_key_visibility_mask(commands, seq_dim=0)

        # fold the group dimension into the batch so every group is encoded independently
        commands, args, l = _pack_group_batch(commands, args, l)
        padding_mask, key_padding_mask = _get_padding_mask(commands, seq_dim=0), _get_key_padding_mask(commands, seq_dim=0)
        group_mask = _get_group_mask(commands, seq_dim=0) if self.use_group else None

        src = self.embedding(commands, args, group_mask)

        if self.cfg.model_type == "transformer":
            memory = self.encoder(src, mask=None, src_key_padding_mask=key_padding_mask, memory2=l)

            # average the encoder outputs over non-padded positions: one latent per group
            z = (memory * padding_mask).sum(dim=0, keepdim=True) / padding_mask.sum(dim=0, keepdim=True)
        else:  # "lstm"
            hidden_cell = (src.new_zeros(2, N, self.cfg.d_model // 2),
                           src.new_zeros(2, N, self.cfg.d_model // 2))
            sequence_lengths = padding_mask.sum(dim=0).squeeze(-1)
            x = pack_padded_sequence(src, sequence_lengths, enforce_sorted=False)

            packed_output, _ = self.encoder(x, hidden_cell)

            memory, _ = pad_packed_sequence(packed_output)
            # gather the hidden state at the last valid timestep of each sequence
            idx = (sequence_lengths - 1).long().view(1, -1, 1).repeat(1, 1, self.cfg.d_model)
            z = memory.gather(dim=0, index=idx)

        z = _unpack_group_batch(N, z)

        if self.cfg.encode_stages == 2:
            # second stage: encode the per-group latents and pool over visible groups into a single latent
            src = z.transpose(0, 1)
            src = _pack_group_batch(src)
            l = self.label_embedding(label).unsqueeze(0) if self.cfg.label_condition else None

            if not self.cfg.self_match:
                src = self.hierarchical_PE(src)

            memory = self.hierarchical_encoder(src, mask=None, src_key_padding_mask=key_visibility_mask, memory2=l)
            z = (memory * visibility_mask).sum(dim=0, keepdim=True) / visibility_mask.sum(dim=0, keepdim=True)
            z = _unpack_group_batch(N, z)

        return z
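
The `_pack_group_batch` / `_unpack_group_batch` helpers used above are not shown in this example. The sketch below is only an assumption of what such reshaping could look like, given the (S, G, N, ...) layout of `commands` in this forward pass; the real helpers may differ in details.

import torch

def _pack_group_batch_sketch(*tensors):
    # Hypothetical: fold the group dimension G into the batch dimension,
    # turning (S, G, N, ...) into (S, G * N, ...) so each group is processed independently.
    packed = []
    for x in tensors:
        if x is None:
            packed.append(None)
        else:
            packed.append(x.reshape(x.size(0), x.size(1) * x.size(2), *x.shape[3:]))
    return tuple(packed) if len(packed) > 1 else packed[0]

def _unpack_group_batch_sketch(N, *tensors):
    # Hypothetical inverse: split the folded batch back into (S, G, N, ...).
    unpacked = [x.reshape(x.size(0), -1, N, *x.shape[2:]) for x in tensors]
    return tuple(unpacked) if len(unpacked) > 1 else unpacked[0]

# Illustrative shapes: commands (S, G, N) and args (S, G, N, n_args) share one folded batch.
S, G, N, n_args = 30, 8, 4, 11
commands = torch.zeros(S, G, N, dtype=torch.long)
args = torch.zeros(S, G, N, n_args, dtype=torch.long)
commands, args = _pack_group_batch_sketch(commands, args)
print(commands.shape, args.shape)  # torch.Size([30, 32]) torch.Size([30, 32, 11])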
Example #2
    def forward(self, z, commands, args, label=None, hierarch_logits=None, return_hierarch=False):
        N = z.size(2)
        l = self.label_embedding(label).unsqueeze(0) if self.cfg.label_condition else None
        if hierarch_logits is None:
            z = _pack_group_batch(z)

        if self.cfg.decode_stages == 2:
            if hierarch_logits is None:
                # first stage: decode per-group latents and hierarchical logits from the global latent z
                src = self.hierarchical_embedding(z)
                out = self.hierarchical_decoder(src, z, tgt_mask=None, tgt_key_padding_mask=None, memory2=l)
                hierarch_logits, z = self.hierarchical_fcn(out)

            if self.cfg.label_condition: l = l.unsqueeze(0).repeat(1, z.size(1), 1, 1)

            hierarch_logits, z, l = _pack_group_batch(hierarch_logits, z, l)

            if return_hierarch:
                return _unpack_group_batch(N, hierarch_logits, z)

        if self.cfg.pred_mode == "autoregressive":
            # autoregressive mode: embed the target commands/args; the transformer branch applies a causal (square subsequent) mask
            S = commands.size(0)
            commands, args = _pack_group_batch(commands, args)

            group_mask = _get_group_mask(commands, seq_dim=0)

            src = self.embedding(commands, args, group_mask)

            if self.cfg.model_type == "transformer":
                key_padding_mask = _get_key_padding_mask(commands, seq_dim=0)
                out = self.decoder(src, z, tgt_mask=self.square_subsequent_mask[:S, :S], tgt_key_padding_mask=key_padding_mask, memory2=l)
            else:  # "lstm"
                hidden_cell = self._get_initial_state(z)  # TODO: reinject intermediate state
                out, _ = self.decoder(src, hidden_cell)

        else:  # "one_shot"
            # one-shot mode: decode the full output in a single pass from the latent
            src = self.embedding(z)
            out = self.decoder(src, z, tgt_mask=None, tgt_key_padding_mask=None, memory2=l)

        command_logits, args_logits = self.fcn(out)

        out_logits = (command_logits, args_logits) + ((hierarch_logits,) if self.cfg.decode_stages == 2 else ())

        return _unpack_group_batch(N, *out_logits)
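
`self.square_subsequent_mask[:S, :S]` above refers to a precomputed causal mask whose construction is not shown here. A minimal sketch of the standard pattern such a mask usually follows (an assumption, not necessarily how the buffer is built in this codebase):

import torch

def square_subsequent_mask(size: int) -> torch.Tensor:
    # Upper-triangular -inf mask: position i may only attend to positions <= i.
    # 0 marks allowed attention; -inf is removed by the softmax inside the decoder.
    return torch.triu(torch.full((size, size), float("-inf")), diagonal=1)

print(square_subsequent_mask(4))
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])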

    # encoder forward variant that additionally consumes history and agent inputs (cf. Example #1)
    def forward(self,
                commands,
                args,
                history,
                agents,
                agents_validity,
                label=None):
        S, G, N = commands.shape

        l = self.label_embedding(label).unsqueeze(0).unsqueeze(0).repeat(
            1, commands.size(1), 1, 1) if self.cfg.label_condition else None

        if self.cfg.encode_stages == 2:
            modified = False
            if self.add_mlp_history > 0 or self.add_mlp_agent > 0:
                modified = True
            visibility_mask, key_visibility_mask = _get_visibility_mask(
                commands,
                seq_dim=0,
                modified=modified,
                agents_validity=agents_validity), _get_key_visibility_mask(
                    commands,
                    seq_dim=0,
                    modified=modified,
                    agents_validity=agents_validity)
        commands, args, l = _pack_group_batch(commands, args, l)
        padding_mask, key_padding_mask = _get_padding_mask(
            commands, seq_dim=0), _get_key_padding_mask(commands, seq_dim=0)

        group_mask = _get_group_mask(commands,
                                     seq_dim=0) if self.use_group else None

        src = self.embedding(commands, args, group_mask)

        if self.cfg.model_type == "transformer":
            memory = self.encoder(src,
                                  mask=None,
                                  src_key_padding_mask=key_padding_mask,
                                  memory2=l)
            z = (memory * padding_mask).sum(
                dim=0, keepdim=True) / padding_mask.sum(dim=0, keepdim=True)
        else:  # "lstm"
            hidden_cell = (src.new_zeros(2, N, self.cfg.d_model // 2),
                           src.new_zeros(2, N, self.cfg.d_model // 2))
            sequence_lengths = padding_mask.sum(dim=0).squeeze(-1)
            x = pack_padded_sequence(src,
                                     sequence_lengths,
                                     enforce_sorted=False)

            packed_output, _ = self.encoder(x, hidden_cell)

            memory, _ = pad_packed_sequence(packed_output)
            idx = (sequence_lengths - 1).long().view(1, -1, 1).repeat(
                1, 1, self.cfg.d_model)
            z = memory.gather(dim=0, index=idx)

        z = _unpack_group_batch(N, z)
        if self.add_mlp_history > 0:
            # encode the flattened history with an MLP and append it as an extra group token
            h = self.history_block(
                self.history_residual(torch.flatten(
                    history, start_dim=1))).unsqueeze(0).unsqueeze(0)
            z = torch.cat((z, h), dim=1)
        if self.add_mlp_agent > 0:
            # encode each agent with an MLP and append one extra group token per agent
            # (note: the cast to torch.cuda.FloatTensor assumes the model runs on GPU)
            agents = agents.permute(1, 0, 2, 3)
            for agent in agents:
                a = self.agent_block(
                    self.agent_residual(
                        torch.flatten(agent.type(torch.cuda.FloatTensor),
                                      start_dim=1))).unsqueeze(0).unsqueeze(0)
                z = torch.cat((z, a), dim=1)

        if self.cfg.encode_stages == 2:
            src = z.transpose(0, 1)
            src = _pack_group_batch(src)
            l = self.label_embedding(label).unsqueeze(
                0) if self.cfg.label_condition else None
            if not self.cfg.self_match:
                src = self.hierarchical_PE(src)
            memory = self.hierarchical_encoder(
                src,
                mask=None,
                src_key_padding_mask=key_visibility_mask,
                memory2=l)
            z = (memory * visibility_mask).sum(
                dim=0, keepdim=True) / visibility_mask.sum(dim=0, keepdim=True)
            z = _unpack_group_batch(N, z)

        return z
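
When the history and agent MLP blocks are enabled, each appended token grows the group dimension of `z` by one before the hierarchical stage. The following shape walk-through is purely illustrative (the dimension sizes and the latent width are assumed, not taken from the source):

import torch

# Illustrative sizes only.
G, N, d = 8, 4, 256
z = torch.randn(1, G, N, d)          # per-group latents after _unpack_group_batch

h = torch.randn(1, 1, N, d)          # history token, shaped like the history_block(...) output
z = torch.cat((z, h), dim=1)         # -> (1, G + 1, N, d)

num_agents = 3
for _ in range(num_agents):          # one extra token per agent, mirroring the loop above
    a = torch.randn(1, 1, N, d)
    z = torch.cat((z, a), dim=1)

print(z.shape)                       # torch.Size([1, 12, 4, 256]) = (1, G + 1 + num_agents, N, d)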