Example #1
])

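# Each agent's origin (x, y) and facing direction; shape (batch=2, agents=2, 2).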
origin = torch.tensor([
    [[0., 0.], [0., 1.]],
    [[0., 0.], [100., 10.]],
])
direction = torch.tensor([
    [[0., 1.], [0., -1.]],
    [[0., 1.], [0., -1.]],
])

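# Item positions expressed relative to each agent's origin and facing direction.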
relpos = relative_positions(origin, direction, positions)
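# Scatter item features onto a polar map with nring rings of width inner_radius and nray angular sectors.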
map = spatial_scatter(
    items,
    relpos,
    nray=8,
    nring=5,
    inner_radius=1,
)
print(map.size())
print(map)

assert ((map == torch.tensor([[[[[0., 0., 0., 0., 1., 0., 0., 0.],
                                 [0., 0., 5., 0., 0., 0., 0., 0.],
                                 [0., 0., 0., 0., 0., 0., 0., 0.],
                                 [0., 0., 0., 0., 0., 0., 0., 0.],
                                 [0., 0., 0., 3., 0., 0., 0., 0.]],
                                [[0., 0., 0., 0., -1., 0., 0., 0.],
                                 [0., 0., -5., 0., 0., 0., 0., 0.],
                                 [0., 0., 0., 0., 0., 0., 0., 0.],
                                 [0., 0., 0., 0., 0., 0., 0., 0.],
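
For orientation, here is a minimal sketch of the kind of polar binning spatial_scatter appears to perform in this example. The binning rule (rings of uniform width inner_radius, evenly spaced rays, item features summed per cell) is an assumption inferred from the arguments and the expected output above, not the library's actual implementation; the real function in the later examples also takes an embed_offsets flag, which this sketch ignores.

import math
import torch

def polar_scatter_sketch(items, relpos, nray, nring, inner_radius):
    # items:  (batch, agents, nitem, channels) item features
    # relpos: (batch, agents, nitem, 2) item positions relative to each agent
    b, a, n, c = items.size()
    out = torch.zeros(b, a, c, nring, nray)
    radius = relpos.norm(dim=-1)
    angle = torch.atan2(relpos[..., 1], relpos[..., 0]) % (2 * math.pi)
    ring = (radius / inner_radius).long().clamp(0, nring - 1)   # assumed uniform ring width
    ray = (angle * nray / (2 * math.pi)).long().clamp(0, nray - 1)
    for bi in range(b):
        for ai in range(a):
            for ni in range(n):
                out[bi, ai, :, ring[bi, ai, ni], ray[bi, ai, ni]] += items[bi, ai, ni]
    return out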
Example #2
    def latents(self, x, x_privileged):
        if self.fp16:
            # The normalization layers convert x to fp16 after normalization, so only x_privileged needs an explicit conversion here
            x_privileged = x_privileged.half()

        batch_size = x.size()[0]

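        # Offsets of each observation block within the flat observation vector x.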
        endglobals = self.obs_config.endglobals()
        endallies = self.obs_config.endallies()
        endenemies = self.obs_config.endenemies()
        endmins = self.obs_config.endmins()
        endtiles = self.obs_config.endtiles()
        endallenemies = self.obs_config.endallenemies()

        globals = x[:, :endglobals]

        # properties of the drone controlled by this network
        xagent = x[:, endglobals:endallies]\
            .view(batch_size, self.obs_config.allies, self.obs_config.dstride())[:, :self.agents, :]
        globals = globals.view(batch_size, 1, self.obs_config.global_features()) \
            .expand(batch_size, self.agents, self.obs_config.global_features())
        xagent = torch.cat([xagent, globals], dim=2)
        agents, _, mask_agent = self.agent_embedding(xagent)

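        # Columns 0:2 and 2:4 of the agent features are used as each drone's origin and facing direction
        # for computing relative item positions below.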
        origin = xagent[:, :, 0:2].clone()
        direction = xagent[:, :, 2:4].clone()

        if self.hps.ally_enemy_same:
            xdrone = x[:, endglobals:endenemies].view(batch_size, self.obs_config.drones, self.obs_config.dstride())
            items, relpos, mask = self.drone_net(xdrone, origin, direction)
        else:
            xally = x[:, endglobals:endallies].view(batch_size, self.obs_config.allies, self.obs_config.dstride())
            items, relpos, mask = self.ally_net(xally, origin, direction)
        # Ensure that at least one item is not masked out to prevent NaN in transformer softmax
        mask[:, :, 0] = 0

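        # Embed enemy drones (only when allies and enemies are processed by separate networks).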
        if self.nenemy > 0 and not self.hps.ally_enemy_same:
            eobs = self.obs_config.drones - self.obs_config.allies
            xe = x[:, endallies:endenemies].view(batch_size, eobs, self.obs_config.dstride())

            items_e, relpos_e, mask_e = self.enemy_net(xe, origin, direction)
            items = torch.cat([items, items_e], dim=2)
            mask = torch.cat([mask, mask_e], dim=2)
            relpos = torch.cat([relpos, relpos_e], dim=2)

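        # Embed mineral observations.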
        if self.nmineral > 0:
            xm = x[:, endenemies:endmins].view(batch_size, self.obs_config.minerals, self.obs_config.mstride())

            items_m, relpos_m, mask_m = self.mineral_net(xm, origin, direction)
            items = torch.cat([items, items_m], dim=2)
            mask = torch.cat([mask, mask_m], dim=2)
            relpos = torch.cat([relpos, relpos_m], dim=2)

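        # Embed tile observations.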
        if self.ntile > 0:
            xt = x[:, endmins:endtiles].view(batch_size, self.obs_config.tiles, self.obs_config.tstride())

            items_t, relpos_t, mask_t = self.tile_net(xt, origin, direction)
            items = torch.cat([items, items_t], dim=2)
            mask = torch.cat([mask, mask_t], dim=2)
            relpos = torch.cat([relpos, relpos_t], dim=2)

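        # Append learned constant items to every agent's item set; these are never masked out.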
        if self.nconstant > 0:
            items_c = self.constant_items\
                .view(1, 1, self.nconstant, self.hps.d_item)\
                .repeat((batch_size, self.agents, 1, 1))
            mask_c = torch.zeros(batch_size, self.agents, self.nconstant).bool().to(x.device)
            items = torch.cat([items, items_c], dim=2)
            mask = torch.cat([mask, mask_c], dim=2)

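        # Build privileged items directly from the raw ally, enemy, and mineral observations.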
        if self.hps.use_privileged:
            xally = x[:, endglobals:endallies].view(batch_size, self.obs_config.allies, self.obs_config.dstride())
            eobs = self.obs_config.drones - self.obs_config.allies
            xenemy = x[:, endtiles:endallenemies].view(batch_size, eobs, self.obs_config.dstride())
            if self.hps.ally_enemy_same:
                xdrone = torch.cat([xally, xenemy], dim=1)
                pitems, _, pmask = self.pdrone_net(xdrone)
            else:
                pitems, _, pmask = self.pally_net(xally)
                pitems_e, _, pmask_e = self.penemy_net(xenemy)
                pitems = torch.cat([pitems, pitems_e], dim=1)
                pmask = torch.cat([pmask, pmask_e], dim=1)
            xm = x[:, endenemies:endmins].view(batch_size, self.obs_config.minerals, self.obs_config.mstride())
            pitems_m, _, pmask_m = self.pmineral_net(xm)
            pitems = torch.cat([pitems, pitems_m], dim=1)
            pmask = torch.cat([pmask, pmask_m], dim=1)
            if self.item_item_attn:
                pmask_nonzero = pmask.clone()
                pmask_nonzero[:, 0] = False
                pitems = self.item_item_attn(
                    pitems.permute(1, 0, 2),
                    src_key_padding_mask=pmask_nonzero,
                ).permute(1, 0, 2)
                if (pitems != pitems).sum() > 0:
                    print(pmask)
                    print(pitems)
                    raise Exception("NaN!")
        else:
            pitems = None
            pmask = None

        # Transformer input dimensions are: Sequence length, Batch size, Embedding size
        source = items.view(batch_size * self.agents, self.nitem, self.d_item).permute(1, 0, 2)
        target = agents.view(1, batch_size * self.agents, self.d_agent)
        x, attn_weights = self.multihead_attention(
            query=target,
            key=source,
            value=source,
            key_padding_mask=mask.view(batch_size * self.agents, self.nitem),
        )
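        # Residual connection, layer norm, and position-wise feed-forward block.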
        x = self.norm1(x + target)
        x2 = self.linear2(F.relu(self.linear1(x)))
        x = self.norm2(x + x2)
        x = x.view(batch_size, self.agents, self.d_agent)

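        # Scatter item features onto a polar "nearby map" around each agent and fuse it with the attention output.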
        if self.hps.nearby_map:
            items = self.norm_map(F.relu(self.downscale(items)))
            items = items * (1 - mask.float().unsqueeze(-1))
            nearby_map = spatial.spatial_scatter(
                items=items[:, :, :(self.nitem - self.nconstant - self.ntile), :],
                positions=relpos[:, :, :self.nitem - self.nconstant - self.ntile],
                nray=self.hps.nm_nrays,
                nring=self.hps.nm_nrings,
                inner_radius=self.hps.nm_ring_width,
                embed_offsets=self.hps.map_embed_offset,
            ).view(batch_size * self.agents, self.map_channels, self.hps.nm_nrings, self.hps.nm_nrays)
            if self.hps.map_conv:
                nearby_map2 = self.conv2(F.relu(self.conv1(nearby_map)))
                nearby_map2 = nearby_map2.permute(0, 3, 2, 1)
                nearby_map = nearby_map.permute(0, 3, 2, 1)
                nearby_map = self.norm_conv(nearby_map + nearby_map2)
            nearby_map = nearby_map.reshape(batch_size, self.agents, self.d_agent)
            x = torch.cat([x, nearby_map], dim=2)

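        # Final projection; zero out the outputs of masked agents.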
        x = self.final_layer(x)
        x = x.view(batch_size, self.agents, self.d_agent * self.hps.dff_ratio)
        x = x * (~mask_agent).float().unsqueeze(-1)

        return x, (pitems, pmask)
Example #3
    def latents(self, x, x_privileged):
        if self.fp16:
            # The normalization layers convert x to fp16 after normalization, so only x_privileged needs an explicit conversion here
            x_privileged = x_privileged.half()

        batch_size = x.size()[0]

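        # Offsets of each observation block within the flat observation vector x, computed from the V2 feature constants.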
        endglobals = GLOBAL_FEATURES_V2
        endallies = GLOBAL_FEATURES_V2 + DSTRIDE_V2 * self.obs_config.allies
        endenemies = GLOBAL_FEATURES_V2 + DSTRIDE_V2 * self.obs_config.drones
        endmins = endenemies + MSTRIDE_V2 * self.obs_config.minerals
        endallenemies = endmins + DSTRIDE_V2 * (self.obs_config.drones - self.obs_config.allies)

        globals = x[:, :endglobals]

        # properties of the drone controlled by this network
        xagent = x[:, endglobals:endallies].view(batch_size, self.obs_config.allies, DSTRIDE_V2)[:, :self.agents, :]
        globals = globals.view(batch_size, 1, GLOBAL_FEATURES_V2) \
            .expand(batch_size, self.agents, GLOBAL_FEATURES_V2)
        xagent = torch.cat([xagent, globals], dim=2)
        agents, _, mask_agent = self.agent_embedding(xagent)

        origin = xagent[:, :, 0:2].clone()
        direction = xagent[:, :, 2:4].clone()

        if self.ally_enemy_same:
            xdrone = x[:, endglobals:endenemies].view(batch_size, self.obs_config.drones, DSTRIDE_V2)
            items, relpos, mask = self.drone_net(xdrone, origin, direction)
        else:
            xally = x[:, endglobals:endallies].view(batch_size, self.obs_config.allies, DSTRIDE_V2)
            items, relpos, mask = self.ally_net(xally, origin, direction)
        # Ensure that at least one item is not masked out to prevent NaN in transformer softmax
        mask[:, :, 0] = 0

        if self.nenemy > 0 and not self.ally_enemy_same:
            eobs = self.obs_config.drones - self.obs_config.allies
            xe = x[:, endallies:endenemies].view(batch_size, eobs, DSTRIDE_V2)

            items_e, relpos_e, mask_e = self.enemy_net(xe, origin, direction)
            items = torch.cat([items, items_e], dim=2)
            mask = torch.cat([mask, mask_e], dim=2)
            relpos = torch.cat([relpos, relpos_e], dim=2)

        if self.nmineral > 0:
            xm = x[:, endenemies:endmins].view(batch_size, self.obs_config.minerals, MSTRIDE_V2)

            items_m, relpos_m, mask_m = self.mineral_net(xm, origin, direction)
            items = torch.cat([items, items_m], dim=2)
            mask = torch.cat([mask, mask_m], dim=2)
            relpos = torch.cat([relpos, relpos_m], dim=2)

        if self.use_privileged:
            # TODO: use hidden enemies
            xally = x[:, endglobals:endallies].view(batch_size, self.obs_config.allies, DSTRIDE_V2)
            eobs = self.obs_config.drones - self.obs_config.allies
            xenemy = x[:, endmins:endallenemies].view(batch_size, eobs, DSTRIDE_V2)
            if self.ally_enemy_same:
                xdrone = torch.cat([xally, xenemy], dim=1)
                pitems, _, pmask = self.pdrone_net(xdrone)
            else:
                pitems, _, pmask = self.pally_net(xally)
                pitems_e, _, pmask_e = self.penemy_net(xenemy)
                pitems = torch.cat([pitems, pitems_e], dim=1)
                pmask = torch.cat([pmask, pmask_e], dim=1)
            xm = x[:, endenemies:endmins].view(batch_size, self.obs_config.minerals, MSTRIDE_V2)
            pitems_m, _, pmask_m = self.pmineral_net(xm)
            pitems = torch.cat([pitems, pitems_m], dim=1)
            pmask = torch.cat([pmask, pmask_m], dim=1)
        else:
            pitems = None
            pmask = None

        # Transformer input dimensions are: Sequence length, Batch size, Embedding size
        source = items.view(batch_size * self.agents, self.nitem, self.d_item).permute(1, 0, 2)
        target = agents.view(1, batch_size * self.agents, self.d_agent)
        x, attn_weights = self.multihead_attention(
            query=target,
            key=source,
            value=source,
            key_padding_mask=mask.view(batch_size * self.agents, self.nitem),
        )
        x = self.norm1(x + target)
        x2 = self.linear2(F.relu(self.linear1(x)))
        x = self.norm2(x + x2)
        x = x.view(batch_size, self.agents, self.d_agent)

        if self.nearby_map:
            items = self.norm_map(F.relu(self.downscale(items)))
            items = items * (1 - mask.float().unsqueeze(-1))
            nearby_map = spatial.spatial_scatter(
                items=items,
                positions=relpos,
                nray=self.nrays,
                nring=self.nrings,
                inner_radius=self.ring_width,
                embed_offsets=self.map_embed_offset,
            ).view(batch_size * self.agents, self.map_channels, self.nrings, self.nrays)
            if self.map_conv:
                nearby_map2 = self.conv2(F.relu(self.conv1(nearby_map)))
                nearby_map2 = nearby_map2.permute(0, 3, 2, 1)
                nearby_map = nearby_map.permute(0, 3, 2, 1)
                nearby_map = self.norm_conv(nearby_map + nearby_map2)
            nearby_map = nearby_map.reshape(batch_size, self.agents, self.d_agent)
            x = torch.cat([x, nearby_map], dim=2)

        x = self.final_layer(x)
        x = x.view(batch_size, self.agents, self.d_agent * self.dff_ratio)
        x = x * (~mask_agent).float().unsqueeze(-1)

        return x, (pitems, pmask)