Example #1
0
    def __init__(self, args):
        """Build the sub-modules of one encoder layer from ``args``.

        Reads ``encoder_embed_dim``, ``encoder_attention_heads``,
        ``attention_dropout``, ``dropout``, ``relu_dropout``,
        ``encoder_normalize_before`` and ``encoder_ffn_embed_dim`` from
        ``args``.
        """
        super().__init__()

        model_dim = args.encoder_embed_dim
        ffn_dim = args.encoder_ffn_embed_dim

        # Plain (non-module) settings.
        self.embed_dim = model_dim
        self.dropout = args.dropout
        self.relu_dropout = args.relu_dropout
        self.normalize_before = args.encoder_normalize_before

        # Sub-modules are created in a fixed order: attention, the two
        # feed-forward projections, then the pair of layer norms.
        self.self_attn = MultiheadAttention(
            model_dim,
            args.encoder_attention_heads,
            dropout=args.attention_dropout,
        )
        self.fc1 = Linear(model_dim, ffn_dim)
        self.fc2 = Linear(ffn_dim, model_dim)
        self.layer_norms = nn.ModuleList(
            [LayerNorm(model_dim) for _ in range(2)])
Example #2
0
    def __init__(self, hps, obs_config):
        """Construct the ``transformer_v7`` policy network.

        Checks that ``hps`` and ``obs_config`` agree on the item counts,
        copies the relevant hyper-parameters onto the module and creates all
        sub-modules: agent/item embeddings, the optional item-item encoder,
        agent-item attention, the map convolutions and the policy/value heads.

        Args:
            hps: Hyper-parameter object; its fields are read by attribute.
            obs_config: Observation layout. Supplies the counts, the strides
                and the ``end*()`` values that are passed as ``start``/``end``
                to the item networks.

        Raises:
            AssertionError: If a consistency check on ``hps``/``obs_config``
                fails.
            Exception: If ``hps.norm`` is not ``'none'``, ``'batchnorm'`` or
                ``'layernorm'``.
        """
        super(TransformerPolicy7, self).__init__()

        # --- consistency checks -------------------------------------------
        assert obs_config.drones > 0 or obs_config.minerals > 0,\
            'Must have at least one mineral or drones observation'
        assert obs_config.drones >= obs_config.allies
        if hps.use_privileged:
            assert (hps.nmineral > 0 and hps.nally > 0 and
                    (hps.nenemy > 0 or hps.ally_enemy_same))

        assert hps.nally == obs_config.allies
        assert hps.nenemy == obs_config.drones - obs_config.allies
        assert hps.nmineral == obs_config.minerals
        assert hps.ntile == obs_config.tiles

        # --- plain attributes ---------------------------------------------
        self.version = 'transformer_v7'
        self.kwargs = {'hps': hps, 'obs_config': obs_config}

        self.hps = hps
        self.obs_config = obs_config
        self.agents = hps.agents
        self.nally = hps.nally
        self.nenemy = hps.nenemy
        self.nmineral = hps.nmineral
        self.nconstant = hps.nconstant
        self.ntile = hps.ntile
        self.nitem = (hps.nally + hps.nenemy + hps.nmineral +
                      hps.nconstant + hps.ntile)
        self.fp16 = hps.fp16
        self.d_agent = hps.d_agent
        self.d_item = hps.d_item
        self.naction = hps.objective.naction() + obs_config.extra_actions()

        # Fall back to 0 when the observation config has no such attribute.
        self.global_drones = getattr(obs_config, 'global_drones', 0)

        # --- normalization layer factory ----------------------------------
        if hps.norm == 'none':
            def norm_fn(x):
                return nn.Sequential()
        elif hps.norm == 'batchnorm':
            def norm_fn(n):
                return nn.BatchNorm2d(n)
        elif hps.norm == 'layernorm':
            def norm_fn(n):
                return nn.LayerNorm(n)
        else:
            raise Exception(f'Unexpected normalization layer {hps.norm}')

        # --- boundaries of each item type (used as start/end below) --------
        layout = self.obs_config
        endglobals = layout.endglobals()
        endallies = layout.endallies()
        endenemies = layout.endenemies()
        endmins = layout.endmins()
        endtiles = layout.endtiles()
        endallenemies = layout.endallenemies()

        # Widths shared by several sub-modules.
        agent_hidden = hps.d_agent * hps.dff_ratio
        half_item = hps.d_item // 2
        half_item_hidden = half_item * hps.dff_ratio

        # --- embeddings ---------------------------------------------------
        self.agent_embedding = ItemBlock(
            obs_config.dstride() + obs_config.global_features(),
            hps.d_agent,
            agent_hidden,
            norm_fn,
            True,
            mask_feature=7,  # Feature 7 is hitpoints
        )
        self.relpos_net = ItemBlock(3, half_item, half_item_hidden, norm_fn,
                                    hps.item_ff)

        self.item_nets = nn.ModuleList()

        def add_item_net(stride, mask_feature, count, start, end,
                         **privileged):
            # Append one PosItemBlock; only the leading stride and the
            # keyword arguments differ between item types.
            self.item_nets.append(
                PosItemBlock(
                    stride,
                    half_item,
                    half_item_hidden,
                    norm_fn,
                    hps.item_ff,
                    mask_feature=mask_feature,
                    count=count,
                    start=start,
                    end=end,
                    **privileged,
                ))

        if hps.ally_enemy_same:
            # One shared network covers every drone (allies and enemies).
            add_item_net(
                obs_config.dstride(),
                7,  # Feature 7 is hitpoints
                obs_config.drones,
                endglobals,
                endenemies,
            )
        else:
            if self.nally > 0:
                add_item_net(
                    obs_config.dstride(),
                    7,  # Feature 7 is hitpoints
                    obs_config.allies,
                    endglobals,
                    endallies,
                )
            if self.nenemy > 0:
                add_item_net(
                    obs_config.dstride(),
                    7,  # Feature 7 is hitpoints
                    obs_config.drones - self.obs_config.allies,
                    endallies,
                    endenemies,
                    start_privileged=endtiles if hps.use_privileged else None,
                    end_privileged=(endallenemies
                                    if hps.use_privileged else None),
                )
        if hps.nmineral > 0:
            add_item_net(
                obs_config.mstride(),
                2,  # Feature 2 is size
                obs_config.minerals,
                endenemies,
                endmins,
            )
        if hps.ntile > 0:
            add_item_net(
                obs_config.tstride(),
                2,  # Feature is elapsed since last visited time
                obs_config.tiles,
                endmins,
                endtiles,
            )
        if hps.nconstant > 0:
            self.constant_items = nn.Parameter(
                torch.normal(0, 1, (hps.nconstant, hps.d_item)))

        # --- optional item-item self attention ----------------------------
        if hps.item_item_attn_layers > 0:
            item_encoder_layer = nn.TransformerEncoderLayer(
                d_model=hps.d_item, nhead=8)
            self.item_item_attn = nn.TransformerEncoder(
                item_encoder_layer, num_layers=hps.item_item_attn_layers)
        else:
            self.item_item_attn = None

        # --- agent-item attention and feed-forward ------------------------
        self.multihead_attention = MultiheadAttention(
            embed_dim=hps.d_agent,
            kdim=hps.d_item,
            vdim=hps.d_item,
            num_heads=hps.nhead,
            dropout=hps.dropout,
        )
        self.linear1 = nn.Linear(hps.d_agent, agent_hidden)
        self.linear2 = nn.Linear(agent_hidden, hps.d_agent)
        self.norm1 = nn.LayerNorm(hps.d_agent)
        self.norm2 = nn.LayerNorm(hps.d_agent)

        # --- nearby-map branch --------------------------------------------
        map_channels = hps.d_agent // (hps.nm_nrings * hps.nm_nrays)
        self.map_channels = map_channels
        if self.hps.map_embed_offset:
            # Two fewer item channels when map_embed_offset is set.
            map_item_channels = map_channels - 2
        else:
            map_item_channels = map_channels
        conv_hidden = hps.dff_ratio * map_channels
        self.downscale = nn.Linear(hps.d_item, map_item_channels)
        self.norm_map = norm_fn(map_item_channels)
        self.conv1 = spatial.ZeroPaddedCylindricalConv2d(
            map_channels, conv_hidden, kernel_size=3)
        self.conv2 = spatial.ZeroPaddedCylindricalConv2d(
            conv_hidden, map_channels, kernel_size=3)
        self.norm_conv = norm_fn(map_channels)

        # --- output layers ------------------------------------------------
        # The final layer input is twice as wide when nearby_map is set.
        final_width = hps.d_agent
        if hps.nearby_map:
            final_width = final_width + hps.d_agent
        self.final_layer = nn.Sequential(
            nn.Linear(final_width, agent_hidden),
            nn.ReLU(),
        )

        self.policy_head = nn.Linear(agent_hidden, self.naction)
        if hps.small_init_pi:
            self.policy_head.weight.data *= 0.01
            self.policy_head.bias.data.fill_(0.0)

        # With privileged information the value head input is d_item wider.
        value_width = agent_hidden
        if hps.use_privileged:
            value_width = agent_hidden + hps.d_item
        self.value_head = nn.Linear(value_width, 1)
        if hps.zero_init_vf:
            self.value_head.weight.data.fill_(0.0)
            self.value_head.bias.data.fill_(0.0)

        # Larger epsilon when running in fp16.
        if hps.fp16:
            self.epsilon = 1e-4
        else:
            self.epsilon = 1e-8
Example #3
0
    def __init__(self, hps, obs_config):
        """Construct the ``transformer_v3`` policy network.

        Copies the relevant hyper-parameters onto the module and creates all
        sub-modules: the agent embedding, one ``ItemBlock`` per observed item
        type, optional privileged ``ItemBlock``s, the optional item-item
        encoder, agent-item attention, the map convolutions and the
        policy/value heads.

        Args:
            hps: Hyper-parameter object; its fields are read by attribute.
            obs_config: Observation layout; supplies counts and strides.

        Raises:
            AssertionError: If a consistency check on ``hps``/``obs_config``
                fails.
            Exception: If ``hps.norm`` is not ``'none'``, ``'batchnorm'`` or
                ``'layernorm'``.
        """
        super(TransformerPolicy3, self).__init__()

        # --- consistency checks -------------------------------------------
        assert obs_config.drones > 0 or obs_config.minerals > 0,\
            'Must have at least one mineral or drones observation'
        assert obs_config.drones >= obs_config.allies
        if hps.use_privileged:
            assert (hps.nmineral > 0 and hps.nally > 0 and
                    (hps.nenemy > 0 or hps.ally_enemy_same))

        # --- plain attributes ---------------------------------------------
        self.version = 'transformer_v3'
        self.kwargs = {'hps': hps, 'obs_config': obs_config}

        self.hps = hps
        self.obs_config = obs_config
        self.agents = hps.agents
        self.nally = hps.nally
        self.nenemy = hps.nenemy
        self.nmineral = hps.nmineral
        self.nitem = hps.nally + hps.nenemy + hps.nmineral
        self.fp16 = hps.fp16
        self.d_agent = hps.d_agent
        self.d_item = hps.d_item
        self.naction = hps.objective.naction()

        # Fall back to 0 when the observation config has no such attribute.
        self.global_drones = getattr(obs_config, 'global_drones', 0)

        # --- normalization layer factory ----------------------------------
        if hps.norm == 'none':
            def norm_fn(x):
                return nn.Sequential()
        elif hps.norm == 'batchnorm':
            def norm_fn(n):
                return nn.BatchNorm2d(n)
        elif hps.norm == 'layernorm':
            def norm_fn(n):
                return nn.LayerNorm(n)
        else:
            raise Exception(f'Unexpected normalization layer {hps.norm}')

        # Widths shared by several sub-modules.
        agent_hidden = hps.d_agent * hps.dff_ratio
        item_hidden = hps.d_item * hps.dff_ratio

        def observed_net(stride, mask_feature, topk):
            # ItemBlock for items seen from an agent's observation.
            return ItemBlock(
                stride,
                hps.d_item,
                item_hidden,
                norm_fn,
                hps.item_ff,
                keep_abspos=hps.obs_keep_abspos,
                mask_feature=mask_feature,
                topk=topk,
            )

        def privileged_net(stride, mask_feature):
            # ItemBlock for privileged items: absolute positions kept,
            # relative positions disabled, no topk argument.
            return ItemBlock(
                stride,
                hps.d_item,
                item_hidden,
                norm_fn,
                hps.item_ff,
                keep_abspos=True,
                relpos=False,
                mask_feature=mask_feature,
            )

        # --- embeddings ---------------------------------------------------
        self.agent_embedding = ItemBlock(
            obs_config.dstride() + obs_config.global_features(),
            hps.d_agent,
            agent_hidden,
            norm_fn,
            True,
            keep_abspos=True,
            mask_feature=7,  # Feature 7 is hitpoints
            relpos=False,
        )
        if hps.ally_enemy_same:
            # One shared network for allies and enemies.
            self.drone_net = observed_net(
                obs_config.dstride(),
                7,  # Feature 7 is hitpoints
                hps.nally + hps.nenemy,
            )
        else:
            self.ally_net = observed_net(
                obs_config.dstride(),
                7,  # Feature 7 is hitpoints
                hps.nally,
            )
            self.enemy_net = observed_net(
                obs_config.dstride(),
                7,  # Feature 7 is hitpoints
                hps.nenemy,
            )
        self.mineral_net = observed_net(
            obs_config.mstride(),
            2,  # Feature 2 is size
            hps.nmineral,
        )

        # --- privileged item networks -------------------------------------
        if hps.use_privileged:
            self.pmineral_net = privileged_net(obs_config.mstride(), 2)
            if hps.ally_enemy_same:
                self.pdrone_net = privileged_net(obs_config.dstride(), 7)
            else:
                self.pally_net = privileged_net(obs_config.dstride(), 7)
                self.penemy_net = privileged_net(obs_config.dstride(), 7)

        # --- optional item-item self attention ----------------------------
        if hps.item_item_attn_layers > 0:
            item_encoder_layer = nn.TransformerEncoderLayer(
                d_model=hps.d_item, nhead=8)
            self.item_item_attn = nn.TransformerEncoder(
                item_encoder_layer, num_layers=hps.item_item_attn_layers)
        else:
            self.item_item_attn = None

        # --- agent-item attention and feed-forward ------------------------
        self.multihead_attention = MultiheadAttention(
            embed_dim=hps.d_agent,
            kdim=hps.d_item,
            vdim=hps.d_item,
            num_heads=hps.nhead,
            dropout=hps.dropout,
        )
        self.linear1 = nn.Linear(hps.d_agent, agent_hidden)
        self.linear2 = nn.Linear(agent_hidden, hps.d_agent)
        self.norm1 = nn.LayerNorm(hps.d_agent)
        self.norm2 = nn.LayerNorm(hps.d_agent)

        # --- nearby-map branch --------------------------------------------
        map_channels = hps.d_agent // (hps.nm_nrings * hps.nm_nrays)
        self.map_channels = map_channels
        if self.hps.map_embed_offset:
            # Two fewer item channels when map_embed_offset is set.
            map_item_channels = map_channels - 2
        else:
            map_item_channels = map_channels
        conv_hidden = hps.dff_ratio * map_channels
        self.downscale = nn.Linear(hps.d_item, map_item_channels)
        self.norm_map = norm_fn(map_item_channels)
        self.conv1 = spatial.ZeroPaddedCylindricalConv2d(
            map_channels, conv_hidden, kernel_size=3)
        self.conv2 = spatial.ZeroPaddedCylindricalConv2d(
            conv_hidden, map_channels, kernel_size=3)
        self.norm_conv = norm_fn(map_channels)

        # --- output layers ------------------------------------------------
        # The final layer input is twice as wide when nearby_map is set.
        final_width = hps.d_agent
        if hps.nearby_map:
            final_width = final_width + hps.d_agent
        self.final_layer = nn.Sequential(
            nn.Linear(final_width, agent_hidden),
            nn.ReLU(),
        )

        self.policy_head = nn.Linear(agent_hidden, self.naction)
        if hps.small_init_pi:
            self.policy_head.weight.data *= 0.01
            self.policy_head.bias.data.fill_(0.0)

        # With privileged information the value head input is 2 * d_item
        # wider.
        value_width = agent_hidden
        if hps.use_privileged:
            value_width = agent_hidden + 2 * hps.d_item
        self.value_head = nn.Linear(value_width, 1)
        if hps.zero_init_vf:
            self.value_head.weight.data.fill_(0.0)
            self.value_head.bias.data.fill_(0.0)

        # Larger epsilon when running in fp16.
        if hps.fp16:
            self.epsilon = 1e-4
        else:
            self.epsilon = 1e-8
Example #4
0
 def build_multihead_attention(embed_dim, dropout, num_heads, **kwargs):
     """Create a ``MultiheadAttention`` from the three named settings.

     Any extra keyword arguments are accepted and ignored; only
     ``embed_dim``, ``num_heads`` and ``dropout`` are forwarded.
     """
     attention_settings = {
         'embed_dim': embed_dim,
         'num_heads': num_heads,
         'dropout': dropout,
     }
     return MultiheadAttention(**attention_settings)
Example #5
0
 def build_multihead_attention(*args, **kwargs):
     """Create a ``MultiheadAttention``, forwarding every argument as-is."""
     attention = MultiheadAttention(*args, **kwargs)
     return attention
Example #6
0
    def __init__(self,
                 d_agent,
                 d_item,
                 dff_ratio,
                 nhead,
                 dropout,
                 small_init_pi,
                 zero_init_vf,
                 fp16,
                 norm,
                 agents,
                 nally,
                 nenemy,
                 nmineral,
                 obs_config=DEFAULT_OBS_CONFIG,
                 use_privileged=False,
                 nearby_map=False,
                 ring_width=40,
                 nrays=8,
                 nrings=8,
                 map_conv=False,
                 map_conv_kernel_size=3,
                 map_embed_offset=False,
                 item_ff=True,
                 keep_abspos=False,
                 ally_enemy_same=False,
                 naction=8,
                 ):
        """Construct the ``transformer_v2`` policy network.

        Every constructor argument is recorded in ``self.kwargs``. Most are
        also stored as attributes, and all sub-modules are created here.

        Args:
            d_agent: Width of the agent embedding and attention output.
            d_item: Width of the item embeddings.
            dff_ratio: Multiplier giving the hidden width of feed-forward
                layers and of the map convolutions.
            nhead: Number of heads for the agent-item attention.
            dropout: Dropout passed to the agent-item attention.
            small_init_pi: If truthy, scale the policy head weights by 0.01
                and zero its bias.
            zero_init_vf: If truthy, zero the value head weights and bias.
            fp16: Selects ``epsilon`` (1e-4 if truthy, else 1e-8).
            norm: ``'none'``, ``'batchnorm'`` or ``'layernorm'``.
            agents, nally, nenemy, nmineral: Counts stored on the module;
                the last three are also used as ``topk`` for the item nets.
            obs_config: Observation layout checked by the assertions.
            use_privileged: If truthy, create the privileged item networks
                and widen the value head input by ``2 * d_item``.
            nearby_map: If truthy, the final layer input is ``2 * d_agent``.
            ring_width, nrays, nrings, map_conv: Stored on the module;
                ``nrings * nrays`` divides ``d_agent`` to give
                ``map_channels``.
            map_conv_kernel_size: Kernel size of the two map convolutions.
            map_embed_offset: If truthy, the map item channels are
                ``map_channels - 2``.
            item_ff: Passed positionally to each item ``ItemBlock``.
            keep_abspos: Passed as ``keep_abspos`` to the observed item
                ``ItemBlock``s.
            ally_enemy_same: If truthy, use one shared drone network instead
                of separate ally and enemy networks.
            naction: Output width of the policy head.

        Raises:
            AssertionError: If a consistency check fails.
            Exception: If ``norm`` is not a recognised name.
        """
        super(TransformerPolicy2, self).__init__()

        # --- consistency checks -------------------------------------------
        assert obs_config.drones > 0 or obs_config.minerals > 0,\
            'Must have at least one mineral or drones observation'
        assert obs_config.drones >= obs_config.allies
        if use_privileged:
            assert (nmineral > 0 and nally > 0 and
                    (nenemy > 0 or ally_enemy_same))

        self.version = 'transformer_v2'

        # Snapshot of the constructor arguments (key order is significant
        # for anything that iterates this dict).
        self.kwargs = {
            'd_agent': d_agent,
            'd_item': d_item,
            'dff_ratio': dff_ratio,
            'nhead': nhead,
            'dropout': dropout,
            'small_init_pi': small_init_pi,
            'zero_init_vf': zero_init_vf,
            'fp16': fp16,
            'use_privileged': use_privileged,
            'norm': norm,
            'obs_config': obs_config,
            'agents': agents,
            'nally': nally,
            'nenemy': nenemy,
            'nmineral': nmineral,
            'nearby_map': nearby_map,
            'ring_width': ring_width,
            'nrays': nrays,
            'nrings': nrings,
            'map_conv': map_conv,
            'map_conv_kernel_size': map_conv_kernel_size,
            'map_embed_offset': map_embed_offset,
            'item_ff': item_ff,
            'keep_abspos': keep_abspos,
            'ally_enemy_same': ally_enemy_same,
            'naction': naction,
        }

        # --- plain attributes ---------------------------------------------
        self.obs_config = obs_config
        self.agents = agents
        self.nally = nally
        self.nenemy = nenemy
        self.nmineral = nmineral
        self.nitem = nally + nenemy + nmineral
        # Fall back to 0 when the observation config has no such attribute.
        self.global_drones = getattr(obs_config, 'global_drones', 0)

        self.d_agent = d_agent
        self.d_item = d_item
        self.dff_ratio = dff_ratio
        self.nhead = nhead
        self.dropout = dropout
        self.nearby_map = nearby_map
        self.ring_width = ring_width
        self.nrays = nrays
        self.nrings = nrings
        self.map_conv = map_conv
        self.map_conv_kernel_size = map_conv_kernel_size
        self.map_embed_offset = map_embed_offset
        self.item_ff = item_ff
        self.naction = naction

        self.fp16 = fp16
        self.use_privileged = use_privileged
        self.ally_enemy_same = ally_enemy_same

        # --- normalization layer factory ----------------------------------
        if norm == 'none':
            def norm_fn(x):
                return nn.Sequential()
        elif norm == 'batchnorm':
            def norm_fn(n):
                return nn.BatchNorm2d(n)
        elif norm == 'layernorm':
            def norm_fn(n):
                return nn.LayerNorm(n)
        else:
            raise Exception(f'Unexpected normalization layer {norm}')

        # Widths shared by several sub-modules.
        agent_hidden = d_agent * dff_ratio
        item_hidden = d_item * dff_ratio

        def observed_net(stride, mask_feature, topk):
            # ItemBlock for items seen from an agent's observation.
            return ItemBlock(
                stride, d_item, item_hidden, norm_fn, item_ff,
                keep_abspos=keep_abspos,
                mask_feature=mask_feature,
                topk=topk,
            )

        def privileged_net(stride, mask_feature):
            # ItemBlock for privileged items: absolute positions kept,
            # relative positions disabled, no topk argument.
            return ItemBlock(
                stride, d_item, item_hidden, norm_fn, item_ff,
                keep_abspos=True, relpos=False, mask_feature=mask_feature,
            )

        # --- embeddings ---------------------------------------------------
        self.agent_embedding = ItemBlock(
            DSTRIDE_V2 + GLOBAL_FEATURES_V2, d_agent, agent_hidden, norm_fn,
            True,
            keep_abspos=True,
            mask_feature=7,  # Feature 7 is hitpoints
            relpos=False,
        )
        if ally_enemy_same:
            # One shared network for allies and enemies.
            self.drone_net = observed_net(
                DSTRIDE_V2,
                7,  # Feature 7 is hitpoints
                nally + nenemy,
            )
        else:
            self.ally_net = observed_net(
                DSTRIDE_V2,
                7,  # Feature 7 is hitpoints
                nally,
            )
            self.enemy_net = observed_net(
                DSTRIDE_V2,
                7,  # Feature 7 is hitpoints
                nenemy,
            )
        self.mineral_net = observed_net(
            MSTRIDE_V2,
            2,  # Feature 2 is size
            nmineral,
        )

        # --- privileged item networks -------------------------------------
        if use_privileged:
            self.pmineral_net = privileged_net(MSTRIDE_V2, 2)
            if ally_enemy_same:
                self.pdrone_net = privileged_net(DSTRIDE_V2, 7)
            else:
                self.pally_net = privileged_net(DSTRIDE_V2, 7)
                self.penemy_net = privileged_net(DSTRIDE_V2, 7)

        # --- agent-item attention and feed-forward ------------------------
        self.multihead_attention = MultiheadAttention(
            embed_dim=d_agent,
            kdim=d_item,
            vdim=d_item,
            num_heads=nhead,
            dropout=dropout,
        )
        self.linear1 = nn.Linear(d_agent, agent_hidden)
        self.linear2 = nn.Linear(agent_hidden, d_agent)
        self.norm1 = nn.LayerNorm(d_agent)
        self.norm2 = nn.LayerNorm(d_agent)

        # --- nearby-map branch --------------------------------------------
        map_channels = d_agent // (nrings * nrays)
        self.map_channels = map_channels
        if self.map_embed_offset:
            # Two fewer item channels when map_embed_offset is set.
            map_item_channels = map_channels - 2
        else:
            map_item_channels = map_channels
        conv_hidden = dff_ratio * map_channels
        self.downscale = nn.Linear(d_item, map_item_channels)
        self.norm_map = norm_fn(map_item_channels)
        self.conv1 = spatial.ZeroPaddedCylindricalConv2d(
            map_channels, conv_hidden, kernel_size=map_conv_kernel_size)
        self.conv2 = spatial.ZeroPaddedCylindricalConv2d(
            conv_hidden, map_channels, kernel_size=map_conv_kernel_size)
        self.norm_conv = norm_fn(map_channels)

        # --- output layers ------------------------------------------------
        # The final layer input is twice as wide when nearby_map is set.
        final_width = d_agent
        if nearby_map:
            final_width = final_width + d_agent
        self.final_layer = nn.Sequential(
            nn.Linear(final_width, agent_hidden),
            nn.ReLU(),
        )

        self.policy_head = nn.Linear(agent_hidden, naction)
        if small_init_pi:
            self.policy_head.weight.data *= 0.01
            self.policy_head.bias.data.fill_(0.0)

        # With privileged information the value head input is 2 * d_item
        # wider.
        value_width = agent_hidden
        if self.use_privileged:
            value_width = agent_hidden + 2 * d_item
        self.value_head = nn.Linear(value_width, 1)
        if zero_init_vf:
            self.value_head.weight.data.fill_(0.0)
            self.value_head.bias.data.fill_(0.0)

        # Larger epsilon when running in fp16.
        if fp16:
            self.epsilon = 1e-4
        else:
            self.epsilon = 1e-8