Code example #1
File: rl_loss.py Project: liuruoze/mini-AlphaStar
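All snippets below are excerpted from mini-AlphaStar and rely on a shared import preamble. A hedged reconstruction (the exact module paths are assumptions inferred from the identifiers in use):

import torch
import torch.nn as nn
import torch.nn.functional as F

# assumed project-local modules: L provides helpers such as tensor_one_hot,
# AHP / SCHP are architecture / StarCraft hyper-parameter namespaces,
# and SU (used in example #10) would be the project's supervised-learning utils
from alphastarmini.lib import utils as L
from alphastarmini.lib.hyper_parameters import Arch_Hyper_Parameters as AHP
from alphastarmini.lib.hyper_parameters import StarCraft_Hyper_Parameters as SCHP

debug = False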
def change_units_and_logits(behavior_logits, gt_units, select_units_num,
                            entity_nums):
    [batch_size, select_size, units_size] = behavior_logits.shape
    padding = torch.zeros(batch_size,
                          1,
                          units_size,
                          dtype=behavior_logits.dtype,
                          device=behavior_logits.device)
    token = torch.tensor(AHP.max_entities - 1,
                         dtype=torch.long,
                         device=padding.device)

    padding[:, 0] = L.tensor_one_hot(token, units_size).reshape(-1).float()
    behavior_logits = torch.cat([behavior_logits, padding], dim=1)
    select_units_num = select_units_num.reshape(-1).long()
    entity_nums = entity_nums.reshape(-1).long()

    behavior_logits[torch.arange(batch_size),
                    -1] = (L.tensor_one_hot(entity_nums, units_size).float() -
                           1) * 1e9


    gt_units = gt_units.long()

    padding = torch.zeros(batch_size,
                          1,
                          1,
                          dtype=gt_units.dtype,
                          device=gt_units.device)
    token = torch.tensor(AHP.max_entities - 1,
                         dtype=padding.dtype,
                         device=padding.device)
    padding[:, 0] = token

    gt_units = torch.cat([gt_units, padding], dim=1)

    gt_units[torch.arange(batch_size), -1] = entity_nums.unsqueeze(dim=1)

    del padding, token, select_units_num, entity_nums

    gt_units = gt_units.long()

    return behavior_logits, gt_units
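L.tensor_one_hot is used throughout these snippets. A minimal sketch of what it plausibly does, assuming it is a thin wrapper over F.one_hot (the real helper may differ):

import torch
import torch.nn.functional as F

def tensor_one_hot(indices, num_classes):
    # one-hot encode integer indices; the output gains a trailing num_classes dim
    return F.one_hot(indices.long(), num_classes=num_classes).float()

# tensor_one_hot(torch.tensor([0, 2]), 4)
# -> tensor([[1., 0., 0., 0.],
#            [0., 0., 1., 0.]])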
Code example #2
    def forward(self, autoregressive_embedding, delay=None):
        # AlphaStar: `autoregressive_embedding` is decoded using a 2-layer (each with size 256) 
        # linear network with ReLUs,
        x = F.relu(self.fc_1(autoregressive_embedding))
        x = F.relu(self.fc_2(x))

        # AlphaStar: before being embedded into `delay_logits` that has size 128 (one for each 
        # possible requested delay in game steps).
        # note: no temperature used here
        delay_logits = self.embed_fc(x)

        # AlphaStar: `delay` is sampled from `delay_logits` using a multinomial, though unlike all other arguments,
        # no temperature is applied to `delay_logits` before sampling.
        if delay is None:
            delay_probs = self.softmax(delay_logits)        
            delay = torch.multinomial(delay_probs, 1)
            del delay_probs

        # AlphaStar: Similar to `action_type`, `delay` is projected to a 1D tensor of size 1024 through 
        # a 2-layer (each with size 256) linear network with ReLUs, and added to `autoregressive_embedding`
        # as with action_type, convert delay to its one-hot version
        delay_one_hot = L.tensor_one_hot(delay, self.max_delay)
        delay_one_hot = delay_one_hot.squeeze(-2)
        z = F.relu(self.fc_3(delay_one_hot))
        z = F.relu(self.fc_4(z))
        t = self.project(z)

        # note: the add may broadcast silently, so the shapes deserve a check
        autoregressive_embedding = autoregressive_embedding + t

        del delay_one_hot, x, z, t

        return delay_logits, delay, autoregressive_embedding
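For orientation, here is a sketch of the layer definitions this forward pass implies. The sizes follow the AlphaStar comments above (two 256-wide linears, 128 delay logits, a 1024-wide projection back to autoregressive_embedding); the constructor itself is an assumption, not the project's actual code:

import torch.nn as nn

class DelayHead(nn.Module):
    # hypothetical reconstruction; mini-AlphaStar's real __init__ may differ
    def __init__(self, autoregressive_embedding_size=1024,
                 original_256=256, max_delay=128):
        super().__init__()
        self.fc_1 = nn.Linear(autoregressive_embedding_size, original_256)
        self.fc_2 = nn.Linear(original_256, original_256)
        self.embed_fc = nn.Linear(original_256, max_delay)  # delay_logits, size 128
        self.fc_3 = nn.Linear(max_delay, original_256)
        self.fc_4 = nn.Linear(original_256, original_256)
        self.project = nn.Linear(original_256, autoregressive_embedding_size)
        self.softmax = nn.Softmax(dim=-1)
        self.max_delay = max_delay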
Code example #3
File: queue_head.py Project: liuruoze/mini-AlphaStar
    def forward(self,
                autoregressive_embedding,
                action_type,
                embedded_entity=None,
                queue=None):
        # AlphaStar: Queued Head is similar to the delay head except a temperature of 0.8
        # AlphaStar: is applied to the logits before sampling,
        x = self.fc_1(autoregressive_embedding)
        x = self.relu(x)
        x = self.fc_2(x)
        x = self.relu(x)

        # note: unlike the delay head, a temperature is applied here
        # AlphaStar: the size of `queued_logits` is 2 (for queueing and not queueing),
        queue_logits = self.embed_fc(x)

        temperature = self.temperature if self.is_rl_training else 1
        queue_logits = queue_logits / temperature

        if queue is None:
            queue_probs = self.softmax(queue_logits)
            queue = torch.multinomial(queue_probs, 1)
            del queue_probs

        # as with action_type, convert queue to its one-hot version
        queue_one_hot = L.tensor_one_hot(queue, self.max_queue)

        # squeeze so queue_one_hot has the same number of dims as queue
        queue_one_hot = queue_one_hot.squeeze(-2)

        z = self.relu(self.fc_3(queue_one_hot))
        z = self.relu(self.fc_4(z))
        t = self.project(z)

        # AlphaStar: and the projected `queued` is not added to `autoregressive_embedding`
        # if queuing is not possible for the chosen `action_type`
        mask = L.action_can_be_queued_mask(action_type).float()
        autoregressive_embedding = autoregressive_embedding + mask * t
        del queue_one_hot, x, z, t, mask, action_type

        return queue_logits, queue, autoregressive_embedding
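The temperature division used in this head (0.8 during RL training) sharpens the softmax before sampling. A self-contained illustration:

import torch

logits = torch.tensor([[1.0, 0.0]])
for temperature in (1.0, 0.8):
    probs = torch.softmax(logits / temperature, dim=-1)
    print(temperature, probs)
# 1.0 -> tensor([[0.7311, 0.2689]])
# 0.8 -> tensor([[0.7773, 0.2227]])   sharper, favouring the argmax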
Code example #4
    def forward(self, scalar_list):
        [agent_statistics, home_race, away_race, upgrades, enemy_upgrades, time, available_actions, unit_counts_bow,
         mmr, units_buildings, effects, upgrade, beginning_build_order, last_delay, last_action_type,
         last_repeat_queued] = scalar_list

        embedded_scalar_list = []
        scalar_context_list = []

        # agent_statistics: Embedded by taking log(agent_statistics + 1) and passing through a linear of size 64 and a ReLU
        the_log_statistics = torch.log(agent_statistics + 1)
        x = F.relu(self.statistics_fc(the_log_statistics))
        del agent_statistics, the_log_statistics
        embedded_scalar_list.append(x)

        # race: Both races are embedded into a one-hot with maximum 5, and embedded through a linear of size 32 and a ReLU.
        x = F.relu(self.home_race_fc(home_race.float()))
        del home_race
        embedded_scalar_list.append(x)
        # The embedding is also added to `scalar_context`.
        scalar_context_list.append(x)

        # race: Both races are embedded into a one-hot with maximum 5, and embedded through a linear of size 32 and a ReLU.
        x = F.relu(self.away_race_fc(away_race.float()))
        del away_race
        # TODO: During training, the opponent's requested race is hidden in 10% of matches, to simulate playing against the Random race.
        embedded_scalar_list.append(x)
        # The embedding is also added to `scalar_context`.
        scalar_context_list.append(x)
        # TODO: If we don't know the opponent's race (either because they are random or it is hidden), 
        # we add their true race to the observation once we observe one of their units.

        # upgrades: The boolean vector of whether an upgrade is present is embedded through a linear of size 128 and a ReLU
        x = F.relu(self.upgrades_fc(upgrades))
        del upgrades
        embedded_scalar_list.append(x)

        # enemy_upgrades: Embedded the same as upgrades
        x = F.relu(self.enemy_upgrades_fc(enemy_upgrades))
        del enemy_upgrades
        embedded_scalar_list.append(x)

        # time: A transformer positional encoder encodes the time into a 1D tensor of size 64.
        # note: this is done during preprocessing
        x = time
        embedded_scalar_list.append(x)

        # available_actions: From `entity_list`, we compute which actions may be available and which can never be available.
        # For example, if the agent controls a Stalker and has researched the Blink upgrade,
        # then the Blink action may be available (even though in practice it may be on cooldown).
        # The boolean vector of action availability is passed through a linear of size 64 and a ReLU.
        x = F.relu(self.available_actions_fc(available_actions))
        del available_actions
        embedded_scalar_list.append(x)
        # The embedding is also added to `scalar_context`
        scalar_context_list.append(x)

        # unit_counts_bow: A bag-of-words unit count from `entity_list`. 
        # The unit count vector is embedded by square rooting, passing through a linear layer, and passing through a ReLU
        # note: make sure unit_counts_bow is all >= 0, otherwise torch.sqrt will produce NaN!
        print('unit_counts_bow', unit_counts_bow) if debug else None
        assert (unit_counts_bow >= 0).all()

        unit_counts_bow = torch.sqrt(unit_counts_bow)
        x = F.relu(self.unit_counts_bow_fc(unit_counts_bow))
        del unit_counts_bow
        embedded_scalar_list.append(x)

        # mmr: During supervised learning, this is the MMR of the player we are trying to imitate. Elsewhere, this is fixed at 6200. 
        # MMR is mapped to a one-hot of min(mmr / 1000, 6) with maximum 6, then passed through a linear of size 64 and a ReLU
        x = F.relu(self.mmr_fc(mmr))
        del mmr
        embedded_scalar_list.append(x)

        # cumulative_statistics: The cumulative statistics (including units, buildings, effects, and upgrades) are preprocessed 
        # into a boolean vector of whether or not each statistic is present in a human game.
        # That vector is split into 3 sub-vectors of units/buildings, effects, and upgrades, 
        # and each subvector is passed through a linear of size 32 and a ReLU, and concatenated together.
        # The embedding is also added to `scalar_context`
        x = F.relu(self.units_buildings_fc(units_buildings))
        del units_buildings
        embedded_scalar_list.append(x)
        scalar_context_list.append(x)

        x = F.relu(self.effects_fc(effects))
        del effects
        embedded_scalar_list.append(x)
        scalar_context_list.append(x)

        x = F.relu(self.upgrade_fc(upgrade))
        del upgrade
        embedded_scalar_list.append(x)
        scalar_context_list.append(x)

        # beginning_build_order: The first 20 constructed entities are converted to a 2D tensor of size 
        # [20, num_entity_types], concatenated with indices and the binary encodings 
        # (as in the Entity Encoder) of where entities were constructed (if applicable). 
        # The concatenation is passed through a transformer similar to the one in the entity encoder, 
        # but with keys, queries, and values of 8 and with a MLP hidden size of 32. 
        # The embedding is also added to `scalar_context`.
        print("beginning_build_order:", beginning_build_order) if debug else None
        print("beginning_build_order.shape:", beginning_build_order.shape) if debug else None

        batch_size = beginning_build_order.shape[0]

        seq = torch.arange(SCHP.count_beginning_build_order)
        seq = L.tensor_one_hot(seq, SCHP.count_beginning_build_order)
        seq = seq.unsqueeze(0).repeat(batch_size, 1, 1).to(beginning_build_order.device)

        bo_sum = beginning_build_order.sum(dim=-1, keepdim=False)
        bo_sum = bo_sum.sum(dim=-1, keepdim=False)
        bo_sum = bo_sum.unsqueeze(1)
        bo_sum = bo_sum.repeat(1, SCHP.count_beginning_build_order)

        mask = torch.arange(SCHP.count_beginning_build_order)
        mask = mask.unsqueeze(0).repeat(batch_size, 1).to(bo_sum.device)
        mask = mask < bo_sum
        mask = mask.unsqueeze(2).repeat(1, 1, SCHP.count_beginning_build_order)

        # add the sequence info, following DI-star's processing
        x = torch.cat([beginning_build_order, seq], dim=2)
        x = self.before_beginning_build_order(x)

        # like in entity encoder, we add a sequence mask
        x = self.beginning_build_order_transformer(x, mask=mask)
        x = x.reshape(x.shape[0], SCHP.count_beginning_build_order * self.build_order_model_size)

        embedded_scalar_list.append(x)
        scalar_context_list.append(x)
        del mask, bo_sum, seq

        # last_delay: The delay between when we last acted and the current observation, in game steps. 
        # This may be different from what we requested due to network latency or APM limits. 
        # It is encoded into a one-hot with maximum 128 and passed through a linear of size 64 and a ReLU
        x = F.relu(self.last_delay_fc(last_delay))
        del last_delay
        embedded_scalar_list.append(x)

        # last_action_type: The last action type is encoded into a one-hot with maximum equal 
        # to the number of possible actions, and passed through a linear of size 128 and a ReLU
        x = F.relu(self.last_action_type_fc(last_action_type))
        del last_action_type
        embedded_scalar_list.append(x)

        # last_repeat_queued: Some other action arguments (queued and repeat) are one-hots with 
        # maximum equal to the number of possible values for those arguments, 
        # and jointly passed through a linear of size 256 and ReLU
        x = F.relu(self.last_repeat_queued_fc(last_repeat_queued))
        del last_repeat_queued
        embedded_scalar_list.append(x)

        embedded_scalar = torch.cat(embedded_scalar_list, dim=1)
        embedded_scalar_out = F.relu(self.fc_1(embedded_scalar))

        scalar_context = torch.cat(scalar_context_list, dim=1)
        scalar_context_out = F.relu(self.fc_2(scalar_context))

        del x, embedded_scalar_list, scalar_context_list, embedded_scalar, scalar_context

        return embedded_scalar_out, scalar_context_out
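The sequence-mask construction for beginning_build_order above is easy to misread, so here is a small standalone demonstration with made-up sizes (3 build-order slots, 4 entity types, batch of 2):

import torch

count = 3                              # stands in for SCHP.count_beginning_build_order
bo = torch.zeros(2, count, 4)          # one-hot build order
bo[0, 0, 1] = bo[0, 1, 2] = 1          # sample 0 filled two slots
bo[1, 0, 3] = 1                        # sample 1 filled one slot

bo_sum = bo.sum(dim=-1).sum(dim=-1)    # slots used per sample: tensor([2., 1.])
bo_sum = bo_sum.unsqueeze(1).repeat(1, count)

mask = torch.arange(count).unsqueeze(0).repeat(2, 1)
mask = mask < bo_sum                   # True exactly at the filled slots
print(mask)
# tensor([[ True,  True, False],
#         [ True, False, False]])

The encoder then repeats this mask along a third dimension so that the build-order transformer attends only to filled slots.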
Code example #5
    def forward(self,
                lstm_output,
                scalar_context,
                action_type_mask=None,
                action_type=None):
        batch_size = lstm_output.shape[0]

        # AlphaStar: The action type head embeds `lstm_output` into a 1D tensor of size 256
        x = self.embed_fc(lstm_output)

        # AlphaStar: passes it through 16 ResBlocks with layer normalization each of size 256, and applies a ReLU.
        # QUESTION: there is no spatial map here, so how do we use ResBlocks?
        # ANSWER: use 1D ResBlocks (resblock1D)
        # input shape is [batch_size x seq_size x embedding_size]
        # note that embedding_size is equal to channel_size in conv1d
        # we change this to [batch_size x embedding_size x seq_size]
        x = x.unsqueeze(-1)
        for resblock in self.resblock_stack:
            x = resblock(x)
        x = F.relu(x)
        x = x.squeeze(-1)

        # AlphaStar: The output is converted to a tensor with one logit for each possible
        # action type through a `GLU` gated by `scalar_context`.
        action_type_logits = self.glu_1(x, scalar_context)

        # inspired by the DI-star project's action_type_head
        if self.is_rl_training and self.use_action_type_mask and action_type_mask is not None:
            action_type_mask = action_type_mask.bool()
            if action_type is not None:
                for i, a in enumerate(action_type):
                    action_type_mask[i, a.item()] = True
            action_type_logits = action_type_logits + (~action_type_mask *
                                                       (-1e9))
            del action_type_mask

        print('action_type_logits:', action_type_logits) if debug else None
        print('action_type_logits.shape:',
              action_type_logits.shape) if debug else None

        # AlphaStar: `action_type` is sampled from these logits using a multinomial with temperature 0.8.
        # Note that during supervised learning, `action_type` will be the ground truth human action
        # type, and temperature is 1.0 (and similarly for all other arguments).
        temperature = self.temperature if self.is_rl_training else 1
        action_type_logits = action_type_logits / temperature
        print('action_type_logits:', action_type_logits) if debug else None
        print('action_type_logits.shape:',
              action_type_logits.shape) if debug else None

        # note: torch.multinomial requires its input to be non-negative, finite, and have a non-zero sum,
        # unlike tf.multinomial, which also accepts unnormalized values such as log(action_type_probs)
        if action_type is None:
            action_type_probs = self.softmax(action_type_logits)
            action_type_probs = action_type_probs.reshape(batch_size, -1)
            print('action_type_probs:', action_type_probs) if debug else None
            print('action_type_probs.shape:',
                  action_type_probs.shape) if debug else None

            action_type = torch.multinomial(action_type_probs, 1)
            action_type = action_type.reshape(batch_size, -1)
            del action_type_probs

        # change action_type to one_hot version
        action_type_one_hot = L.tensor_one_hot(action_type,
                                               self.max_action_num)
        action_type_one_hot = action_type_one_hot.squeeze(-2)

        # AlphaStar: `autoregressive_embedding` is then generated by first applying a ReLU
        # and linear layer of size 256 to the one-hot version of `action_type`
        z = F.relu(self.fc_1(action_type_one_hot))

        # AlphaStar: and projecting it to a 1D tensor of size 1024 through a `GLU` gated by `scalar_context`.
        z = self.glu_2(z, scalar_context)

        # AlphaStar: That projection is added to another projection of `lstm_output` into a 1D tensor of size
        # 1024 gated by `scalar_context` to yield `autoregressive_embedding`.
        t = self.glu_3(lstm_output, scalar_context)

        # note: the add may broadcast silently, so an assert on the shapes is worthwhile
        autoregressive_embedding = z + t

        del action_type_one_hot, lstm_output, scalar_context, x, z, t

        return action_type_logits, action_type, autoregressive_embedding
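self.glu_1/2/3 above are gating units conditioned on scalar_context. A minimal sketch of a context-gated GLU in the AlphaStar sense, i.e. output = Linear(sigmoid(Linear(context)) * input); the project's own GLU implementation may differ in detail:

import torch
import torch.nn as nn

class GLU(nn.Module):
    # hypothetical reconstruction of a context-gated linear unit
    def __init__(self, input_size, context_size, output_size):
        super().__init__()
        self.gate_fc = nn.Linear(context_size, input_size)
        self.out_fc = nn.Linear(input_size, output_size)

    def forward(self, x, context):
        gate = torch.sigmoid(self.gate_fc(context))  # [batch, input_size]
        return self.out_fc(gate * x)                 # [batch, output_size]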
Code example #6
    def pre_forward(self, various_observations):
        [
            agent_statistics, upgrades, unit_counts_bow, units_buildings,
            effects, upgrade, beginning_build_order, cumulative_score
        ] = various_observations

        # These features are all concatenated together to yield `action_type_input`,
        # passed through a linear of size 256, then passed through 16 ResBlocks with 256 hidden units
        # and layer normalization, passed through a ReLU, then passed through
        # a linear with 1 hidden unit.

        device = next(self.parameters()).device
        embedded_scalar_list = []

        # agent_statistics: Embedded by taking log(agent_statistics + 1) and passing through a linear of size 64 and a ReLU
        the_log_statistics = torch.log(agent_statistics + 1)
        x = F.relu(self.statistics_fc(the_log_statistics))
        del agent_statistics, the_log_statistics
        embedded_scalar_list.append(x)

        score_log_statistics = torch.log(cumulative_score + 1)
        x = F.relu(self.cumulatscore_fc(score_log_statistics))
        del cumulative_score, score_log_statistics
        embedded_scalar_list.append(x)

        # upgrades: The boolean vector of whether an upgrade is present is embedded through a linear of size 128 and a ReLU
        x = F.relu(self.upgrades_fc(upgrades))
        del upgrades
        embedded_scalar_list.append(x)

        # unit_counts_bow: A bag-of-words unit count from `entity_list`.
        # The unit count vector is embedded by square rooting, passing through a linear layer, and passing through a ReLU
        x = F.relu(self.unit_counts_bow_fc(unit_counts_bow))
        del unit_counts_bow
        embedded_scalar_list.append(x)

        # cumulative_statistics: The cumulative statistics (including units, buildings, effects, and upgrades) are preprocessed
        # into a boolean vector of whether or not statistic is present in a human game.
        # That vector is split into 3 sub-vectors of units/buildings, effects, and upgrades,
        # and each subvector is passed through a linear of size 32 and a ReLU, and concatenated together.
        # The embedding is also added to `scalar_context`

        cumulative_statistics = []  # its composition differs by baseline type
        if self.baseline_type == "winloss" or self.baseline_type == "build_order" or self.baseline_type == "built_units":
            x = F.relu(self.units_buildings_fc(units_buildings))
            cumulative_statistics.append(x)
        if self.baseline_type == "effects" or self.baseline_type == "build_order":
            x = F.relu(self.effects_fc(effects))
            cumulative_statistics.append(x)
        if self.baseline_type == "upgrades" or self.baseline_type == "build_order":
            x = F.relu(self.upgrade_fc(upgrade))
            cumulative_statistics.append(x)

        embedded_scalar_list.extend(cumulative_statistics)
        del units_buildings, effects, upgrade, cumulative_statistics

        # beginning_build_order: The first 20 constructed entities are converted to a 2D tensor of size
        # [20, num_entity_types], concatenated with indices and the binary encodings
        # (as in the Entity Encoder) of where entities were constructed (if applicable).
        # The concatenation is passed through a transformer similar to the one in the entity encoder,
        # but with keys, queries, and values of 8 and with a MLP hidden size of 32.
        # The embedding is also added to `scalar_context`.
        print("beginning_build_order:",
              beginning_build_order) if debug else None
        print("beginning_build_order.shape:",
              beginning_build_order.shape) if debug else None

        batch_size = beginning_build_order.shape[0]

        seq = torch.arange(SCHP.count_beginning_build_order)
        seq = L.tensor_one_hot(seq, SCHP.count_beginning_build_order)
        seq = seq.unsqueeze(0).repeat(batch_size, 1,
                                      1).to(beginning_build_order.device)

        bo_sum = beginning_build_order.sum(dim=-1, keepdim=False)
        bo_sum = bo_sum.sum(dim=-1, keepdim=False)
        bo_sum = bo_sum.unsqueeze(1)
        bo_sum = bo_sum.repeat(1, SCHP.count_beginning_build_order)

        mask = torch.arange(SCHP.count_beginning_build_order)
        mask = mask.unsqueeze(0).repeat(batch_size, 1).to(bo_sum.device)
        mask = mask < bo_sum
        mask = mask.unsqueeze(2).repeat(1, 1, SCHP.count_beginning_build_order)

        # add the sequence info, following DI-star's processing
        x = torch.cat([beginning_build_order, seq], dim=2)
        x = self.before_beginning_build_order(x)

        # like in entity encoder, we add a sequence mask
        x = self.beginning_build_order_transformer(x, mask=mask)
        x = x.reshape(
            x.shape[0],
            SCHP.count_beginning_build_order * self.build_order_model_size)

        embedded_scalar_list.append(x)
        embedded_scalar = torch.cat(embedded_scalar_list, dim=1)

        del x, mask, bo_sum, seq, embedded_scalar_list

        return embedded_scalar
Code example #7
File: arch_model.py Project: liuruoze/mini-AlphaStar
    def mimic_forward(self,
                      state,
                      gt_action,
                      gt_select_units_num,
                      gt_is_one_hot=True,
                      multi_gpu_supvised_learning=False,
                      batch_size=None,
                      sequence_length=None,
                      hidden_state=None,
                      baseline_state=None,
                      baseline_opponent_state=None,
                      show=False,
                      obs_list=None):
        '''
        Inspired by the DI-star project:
        inject the ground-truth args into the forward calculation
        to make sure the forward pass follows the right direction.
        '''

        # shapes of embedded_entity, embedded_spatial, embedded_scalar are all [batch_size x embedded_size]
        entity_embeddings, embedded_entity, entity_nums, unit_types_one = self.entity_encoder(
            state.entity_state, return_unit_types=True)
        map_skip, embedded_spatial = self.spatial_encoder(
            state.map_state, entity_embeddings)
        embedded_scalar, scalar_context = self.scalar_encoder(
            state.statistical_state)

        available_actions_id = 6  # available_actions is at position 6
        available_actions = state.statistical_state[available_actions_id]

        del state

        lstm_output, hidden_state = self.core(embedded_scalar, embedded_entity,
                                              embedded_spatial, batch_size,
                                              sequence_length, hidden_state)

        del embedded_scalar, embedded_spatial

        if gt_is_one_hot:
            # TODO: remove these nonzero conversions
            gt_action_type = torch.nonzero(gt_action.action_type.long(),
                                           as_tuple=True)[-1].unsqueeze(dim=1)
            print('gt_action_type.shape',
                  gt_action_type.shape) if debug else None
            gt_delay = torch.nonzero(gt_action.delay.long(),
                                     as_tuple=True)[-1].unsqueeze(dim=1)
            print('gt_delay.shape', gt_delay.shape) if debug else None
            gt_queue = torch.nonzero(gt_action.queue.long(),
                                     as_tuple=True)[-1].unsqueeze(dim=1)
            print('gt_queue.shape', gt_queue.shape) if debug else None

            gt_units = gt_action.units.long()
            batch_size = gt_units.shape[0]
            select_size = gt_units.shape[1]
            units_size = gt_units.shape[2]

            padding = torch.zeros(batch_size,
                                  1,
                                  units_size,
                                  dtype=gt_units.dtype,
                                  device=gt_units.device)
            token = torch.tensor(AHP.max_entities - 1,
                                 dtype=padding.dtype,
                                 device=padding.device)
            padding[:, 0] = L.tensor_one_hot(token, units_size).reshape(-1)
            gt_units = torch.cat([gt_units, padding], dim=1)
            print('gt_units', gt_units) if debug else None
            print('gt_units.shape', gt_units.shape) if debug else None
            gt_units[torch.arange(batch_size),
                     gt_select_units_num] = L.tensor_one_hot(
                         entity_nums, units_size).long()

            gt_units = gt_units.reshape(-1, units_size)
            print('gt_units.shape', gt_units.shape) if debug else None

            gt_units = torch.nonzero(gt_units, as_tuple=True)[-1]
            gt_units = gt_units.reshape(batch_size, -1, 1)
            print('gt_units', gt_units) if debug else None
            print('gt_units.shape', gt_units.shape) if debug else None

            gt_target_unit = gt_action.target_unit.long()
            gt_target_unit = gt_target_unit.reshape(-1,
                                                    gt_target_unit.shape[-1])
            gt_target_unit = torch.nonzero(gt_target_unit, as_tuple=True)[-1]
            print('gt_target_unit.shape',
                  gt_target_unit.shape) if debug else None
            gt_target_unit = gt_target_unit.reshape(batch_size, 1, 1)
        else:
            gt_action_type = gt_action.action_type
            gt_delay = gt_action.delay
            gt_queue = gt_action.queue

            gt_units = gt_action.units
            padding = torch.zeros(batch_size,
                                  1,
                                  1,
                                  dtype=gt_units.dtype,
                                  device=gt_units.device)
            token = torch.tensor(AHP.max_entities - 1,
                                 dtype=padding.dtype,
                                 device=padding.device)
            padding[:, 0] = token

            gt_units = torch.cat([gt_units, padding], dim=1)
            del padding, token

            print('gt_select_units_num',
                  gt_select_units_num) if debug else None
            print('gt_units', gt_units) if debug else None
            print('gt_units.shape', gt_units.shape) if debug else None

            gt_units[torch.arange(batch_size),
                     gt_select_units_num] = entity_nums.unsqueeze(dim=1)
            print('gt_units', gt_units) if debug else None
            print('gt_units.shape', gt_units.shape) if debug else None

            gt_target_unit = gt_action.target_unit

        action_type_logits, action_type, autoregressive_embedding = self.action_type_head(
            lstm_output, scalar_context, available_actions, gt_action_type)

        if obs_list is not None:
            unit_type_entity_mask = L.get_batch_unit_type_mask(
                action_type.squeeze(dim=1), obs_list)
            unit_type_entity_mask = torch.tensor(unit_type_entity_mask,
                                                 dtype=torch.bool,
                                                 device=action_type.device)
        else:
            unit_type_entity_mask = None

        delay_logits, _, autoregressive_embedding = self.delay_head(
            autoregressive_embedding, gt_delay)
        queue_logits, _, autoregressive_embedding = self.queue_head(
            autoregressive_embedding, gt_action_type, embedded_entity,
            gt_queue)

        # selected_units_head is special: we use its mimic_forward function
        print('gt_units', gt_units) if show else None
        print('gt_units.shape', gt_units.shape) if show else None

        units_logits, units, autoregressive_embedding, select_units_num = self.selected_units_head.mimic_forward(
            autoregressive_embedding,
            gt_action_type,
            entity_embeddings,
            entity_nums,
            gt_units,
            gt_select_units_num,
            show=show,
            unit_type_entity_mask=unit_type_entity_mask)

        print('units_logits', units_logits) if show else None
        print('units_logits.shape', units_logits.shape) if show else None

        target_unit_logits, target_unit = self.target_unit_head(
            autoregressive_embedding, gt_action_type, entity_embeddings,
            entity_nums, gt_target_unit)
        target_location_logits, target_location = self.location_head(
            autoregressive_embedding, gt_action_type, map_skip)

        action_logits = ArgsActionLogits(
            action_type=action_type_logits,
            delay=delay_logits,
            queue=queue_logits,
            units=units_logits,
            target_unit=target_unit_logits,
            target_location=target_location_logits)

        del action_logits, unit_type_entity_mask
        del gt_action, gt_action_type, gt_delay, gt_queue, entity_embeddings, gt_units, gt_select_units_num, autoregressive_embedding
        del map_skip, gt_target_unit, embedded_entity, scalar_context, available_actions

        if multi_gpu_supvised_learning:
            return action_type, entity_nums, units, target_unit, target_location, action_type_logits, delay_logits, queue_logits, \
                units_logits, target_unit_logits, target_location_logits, select_units_num, hidden_state, unit_types_one

        elif baseline_state is not None:
            winloss_baseline_value = self.winloss_baseline.forward(
                lstm_output, baseline_state, baseline_opponent_state)
            build_order_baseline_value = self.build_order_baseline.forward(
                lstm_output, baseline_state, baseline_opponent_state)
            built_units_baseline_value = self.built_units_baseline.forward(
                lstm_output, baseline_state, baseline_opponent_state)
            upgrades_baseline_value = self.upgrades_baseline.forward(
                lstm_output, baseline_state, baseline_opponent_state)
            effects_baseline_value = self.effects_baseline.forward(
                lstm_output, baseline_state, baseline_opponent_state)

            del lstm_output, baseline_state, baseline_opponent_state

            baseline_value_list = [
                winloss_baseline_value, build_order_baseline_value,
                built_units_baseline_value, upgrades_baseline_value,
                effects_baseline_value
            ]

            del winloss_baseline_value, build_order_baseline_value, built_units_baseline_value
            del upgrades_baseline_value, effects_baseline_value
        else:
            baseline_value_list = []

        return baseline_value_list, action_type, entity_nums, units, target_unit, target_location, action_type_logits, delay_logits, queue_logits, \
            units_logits, target_unit_logits, target_location_logits, select_units_num, hidden_state, unit_types_one
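mimic_forward repeatedly uses torch.nonzero(..., as_tuple=True)[-1] to turn one-hot ground truth back into class indices. A quick demonstration (this only works when every row has exactly one nonzero entry):

import torch

one_hot = torch.tensor([[0, 1, 0, 0],
                        [0, 0, 0, 1]])

# the last tensor returned by as_tuple indexing holds the column (class) indices
indices = torch.nonzero(one_hot, as_tuple=True)[-1].unsqueeze(dim=1)
print(indices)  # tensor([[1], [3]])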
Code example #8
    def forward(self,
                autoregressive_embedding,
                action_type,
                entity_embeddings,
                entity_num,
                unit_type_entity_mask=None):
        '''
        Inputs:
            autoregressive_embedding: [batch_size x autoregressive_embedding_size]
            action_type: [batch_size x 1]
            entity_embeddings: [batch_size x entity_size x embedding_size]
            entity_num: [batch_size]
        Output:
            units_logits: [batch_size x max_selected x entity_size]
            units: [batch_size x max_selected x 1]
            autoregressive_embedding: [batch_size x autoregressive_embedding_size]
        '''
        batch_size = entity_embeddings.shape[0]
        entity_size = entity_embeddings.shape[-2]
        device = next(self.parameters()).device
        key_size = self.new_variable.shape[0]
        original_ae = autoregressive_embedding

        # AlphaStar: If applicable, Selected Units Head first determines which entity types can accept `action_type`,
        # creates a one-hot of that type with maximum equal to the number of unit types,
        # and passes it through a linear of size 256 and a ReLU. This will be referred to in this head as `func_embed`.
        # QUESTION: one unit type or several unit types?
        # ANSWER: several unit types, one-hot each
        # this is one of the places where a lot of human knowledge is introduced
        unit_types_one_hot = L.action_can_apply_to_selected_mask(
            action_type).to(device)

        # the_func_embed shape: [batch_size x 256]
        the_func_embed = F.relu(self.func_embed(unit_types_one_hot))
        del unit_types_one_hot

        # AlphaStar: It also computes a mask of which units can be selected, initialised to allow selecting all entities
        # that exist (including enemy units).
        # generate the length mask for all entities
        mask = torch.arange(entity_size, device=device).float()
        mask = mask.repeat(batch_size, 1)

        # the entity num is incremented by 1 to include the EOF,
        # because the mean over keys should also cover the key of the EOF
        added_entity_num = entity_num + 1

        # mask: [batch_size, entity_size]
        mask = mask < added_entity_num.unsqueeze(dim=1)
        assert mask.dtype == torch.bool

        # AlphaStar: It then computes a key corresponding to each entity by feeding `entity_embeddings`
        # through a 1D convolution with 32 channels and kernel size 1,
        # and creates a new variable corresponding to ending unit selection.
        # input: [batch_size x entity_size x embedding_size]
        # output: [batch_size x entity_size x key_size], note key_size = 32
        key = self.conv_1(entity_embeddings.transpose(-1,
                                                      -2)).transpose(-1, -2)

        # the end index equals entity_num
        end_index = entity_num

        # replace the key at the EOF position with self.new_variable;
        # the in-place alternative would be
        #     key[torch.arange(batch_size), end_index] = self.new_variable
        # but we build it arithmetically to keep the operation out-of-place
        padding_end = torch.zeros(key.shape[0],
                                  1,
                                  key.shape[2],
                                  dtype=key.dtype,
                                  device=key.device)
        key = torch.cat([key[:, :-1, :], padding_end], dim=1)

        # flag is False only at the EOF position of each sample
        flag = torch.ones(key.shape, dtype=torch.bool, device=key.device)
        flag[torch.arange(batch_size), end_index] = False

        # end_embedding: [batch_size, entity_size, key_size], each row is new_variable
        end_embedding = torch.ones(
            key.shape, dtype=key.dtype,
            device=key.device) * self.new_variable.reshape(1, -1)
        key_end_part = end_embedding * ~flag

        # zero out the EOF row of key and write new_variable there instead
        key_main_part = key * flag
        key = key_main_part + key_end_part

        del padding_end, flag, end_embedding, key_main_part, key_end_part

        # calculate the average of keys (consider the entity_num)
        key_mask = mask.unsqueeze(dim=2).repeat(1, 1, key.shape[-1])
        key_avg = torch.sum(key * key_mask, dim=1) / entity_num.reshape(
            batch_size, 1)
        del key_mask

        # creates a new variable corresponding to ending unit selection.
        # QUESTION: how to do that?
        # ANSWER: following the DI-star project; see self.new_variable in the __init__() method
        units_logits = []
        units = []
        hidden = None

        # referenced from DI-star
        # records which samples in the batch have ended their selection
        # note: is_end must be bool so it works correctly as a boolean mask
        is_end = torch.zeros(batch_size, device=device).bool()

        # in the first selection, we should not select the end_index
        mask[torch.arange(batch_size), end_index] = False

        # if selection stops early, record how many items each sample selected
        select_units_num = torch.ones(
            batch_size, dtype=torch.long, device=device) * self.max_selected

        # AlphaStar: repeated for selecting up to 64 units
        for i in range(self.max_selected):
            if i == 1:
                # from the second selection on, the EOF may be selected
                mask[torch.arange(batch_size), end_index] = True
                if self.is_rl_training and unit_type_entity_mask is not None:
                    unit_type_entity_mask[torch.arange(batch_size),
                                          end_index] = True

            x = self.fc_1(autoregressive_embedding)
            x = self.fc_2(F.relu(x + the_func_embed)).unsqueeze(dim=1)

            query, hidden = self.small_lstm(x, hidden)
            y = torch.sum(query * key, dim=-1)

            entity_logits = y.masked_fill(~mask, -1e9)
            if self.is_rl_training and self.use_unit_type_entity_mask and unit_type_entity_mask is not None:
                entity_logits = entity_logits.masked_fill(
                    ~unit_type_entity_mask, -1e9)

            temperature = self.temperature if self.is_rl_training else 1
            entity_logits = entity_logits / temperature
            del x, y, query

            entity_probs = self.softmax(entity_logits)
            entity_id = torch.multinomial(entity_probs, 1)

            units_logits.append(entity_logits.unsqueeze(-2))
            units.append(entity_id.unsqueeze(-2))

            # mask out the selected entity so that it cannot be selected in future iterations
            mask[torch.arange(batch_size), entity_id.squeeze(dim=1)] = False

            last_index = (entity_id.squeeze(dim=1) == end_index)
            is_end[last_index] = 1

            # record how many items this sample selected:
            # i + 1 selections so far, but that count includes the EOF, so the real number is i + 1 - 1 = i
            select_units_num[last_index] = i

            # AlphaStar: The one-hot position of the selected entity is multiplied by the keys,
            # reduced by the mean across the entities, passed through a linear layer of size 1024,
            # and added to `autoregressive_embedding` for subsequent iterations.
            entity_one_hot = L.tensor_one_hot(entity_id,
                                              entity_size).squeeze(-2)
            entity_one_hot_unsqueeze = entity_one_hot.unsqueeze(-2)

            out = torch.bmm(entity_one_hot_unsqueeze, key).squeeze(-2)
            out = out - key_avg
            t = self.project(out)
            autoregressive_embedding = autoregressive_embedding + t * ~is_end.unsqueeze(
                dim=1)

            del temperature, entity_logits, entity_probs, entity_id
            del last_index, entity_one_hot, entity_one_hot_unsqueeze, out, t

            if is_end.all():
                break

        # units_logits: [batch_size x select_units x entity_size]
        units_logits = torch.cat(units_logits, dim=1)

        # units: [batch_size x select_units x 1]
        units = torch.cat(units, dim=1)

        # pad so units_logits reaches the size [batch_size x max_selected x entity_size]
        # (logits are padded with -1e9, units with the None index)
        # TODO: change the padding
        padding_size = self.max_selected - units_logits.shape[1]
        if padding_size > 0:
            pad_units_logits = torch.ones(units_logits.shape[0],
                                          padding_size,
                                          units_logits.shape[2],
                                          dtype=units_logits.dtype,
                                          device=units_logits.device) * (-1e9)
            units_logits = torch.cat([units_logits, pad_units_logits], dim=1)

            pad_units = torch.zeros(units.shape[0],
                                    padding_size,
                                    units.shape[2],
                                    dtype=units.dtype,
                                    device=units.device)
            pad_units[:, :, 0] = entity_size - 1  # None index, the same as -1
            units = torch.cat([units, pad_units], dim=1)
            del pad_units, pad_units_logits

        # AlphaStar: If `action_type` does not involve selecting units, this head is ignored.

        # select_unit_mask: [batch_size x 1]
        # note: select_unit_mask must be bool so it works correctly as a boolean mask
        select_unit_mask = L.action_involve_selecting_units_mask(
            action_type).bool()

        no_select_units_index = ~select_unit_mask.squeeze(dim=1)
        print("no_select_units_index:",
              no_select_units_index) if debug else None

        select_units_num[no_select_units_index] = 0
        autoregressive_embedding[no_select_units_index] = original_ae[
            no_select_units_index]

        units_logits[no_select_units_index] = -1e9  # a magic number
        units[no_select_units_index, :,
              0] = entity_size - 1  # None index, the same as -1

        print("select_units_num:", select_units_num) if debug else None
        print("autoregressive_embedding:",
              autoregressive_embedding) if debug else None

        del select_unit_mask, no_select_units_index, mask, is_end, key, key_avg

        return units_logits, units, autoregressive_embedding, select_units_num
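Both selection heads mask invalid entities by writing a large negative value into the logits instead of multiplying them by zero; a compact illustration of why the zero-multiplication variant is wrong:

import torch

logits = torch.tensor([2.0, 1.0, 0.5])
mask = torch.tensor([True, True, False])   # the third entity may not be selected

wrong = torch.softmax(logits * mask, dim=-1)
right = torch.softmax(logits.masked_fill(~mask, -1e9), dim=-1)
print(wrong)  # tensor([0.6652, 0.2447, 0.0900])   invalid entity still selectable
print(right)  # tensor([0.7311, 0.2689, 0.0000])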
Code example #9
    def mimic_forward(self,
                      autoregressive_embedding,
                      action_type,
                      entity_embeddings,
                      entity_num,
                      units,
                      select_units_num,
                      show=False,
                      unit_type_entity_mask=None):
        '''
        Inputs:
            autoregressive_embedding: [batch_size x autoregressive_embedding_size]
            action_type: [batch_size x 1]
            entity_embeddings: [batch_size x entity_size x embedding_size]
            entity_num: [batch_size]
        Output:
            units_logits: [batch_size x max_selected x entity_size]
            units: [batch_size x max_selected x 1]
            autoregressive_embedding: [batch_size x autoregressive_embedding_size]
        '''
        batch_size = entity_embeddings.shape[0]
        entity_size = entity_embeddings.shape[-2]
        device = next(self.parameters()).device
        key_size = self.new_variable.shape[0]
        original_ae = autoregressive_embedding

        # AlphaStar: If applicable, Selected Units Head first determines which entity types can accept `action_type`,
        # creates a one-hot of that type with maximum equal to the number of unit types,
        # and passes it through a linear of size 256 and a ReLU. This will be referred to in this head as `func_embed`.
        # QUESTION: one unit type or several unit types?
        # ANSWER: several unit types, one-hot each
        # this is one of the places where a lot of human knowledge is introduced
        unit_types_one_hot = L.action_can_apply_to_selected_mask(
            action_type).to(device)

        # the_func_embed shape: [batch_size x 256]
        the_func_embed = F.relu(self.func_embed(unit_types_one_hot))
        del unit_types_one_hot

        # AlphaStar: It also computes a mask of which units can be selected, initialised to allow selecting all entities
        # that exist (including enemy units).
        # generate the length mask for all entities
        mask = torch.arange(entity_size, device=device).float()
        mask = mask.repeat(batch_size, 1)

        # the entity num is incremented by 1 to include the EOF,
        # because the mean over keys should also cover the key of the EOF
        added_entity_num = entity_num + 1

        # mask: [batch_size, entity_size]
        mask = mask < added_entity_num.unsqueeze(dim=1)
        print("mask:", mask) if debug else None
        print("mask.shape:", mask.shape) if debug else None

        assert mask.dtype == torch.bool

        # AlphaStar: It then computes a key corresponding to each entity by feeding `entity_embeddings`
        # through a 1D convolution with 32 channels and kernel size 1,
        # and creates a new variable corresponding to ending unit selection.

        # input: [batch_size x entity_size x embedding_size]
        # output: [batch_size x entity_size x key_size], note key_size = 32
        key = self.conv_1(entity_embeddings.transpose(-1,
                                                      -2)).transpose(-1, -2)

        # the end index equals entity_num
        end_index = entity_num

        # replace the key at the EOF position with self.new_variable;
        # the in-place alternative would be
        #     key[torch.arange(batch_size), end_index] = self.new_variable
        # but we build it arithmetically to keep the operation out-of-place
        padding_end = torch.zeros(key.shape[0],
                                  1,
                                  key.shape[2],
                                  dtype=key.dtype,
                                  device=key.device)
        key = torch.cat([key[:, :-1, :], padding_end], dim=1)

        # flag is False only at the EOF position of each sample
        flag = torch.ones(key.shape, dtype=torch.bool, device=key.device)
        flag[torch.arange(batch_size), end_index] = False

        # end_embedding: [batch_size, entity_size, key_size], each row is new_variable
        end_embedding = torch.ones(
            key.shape, dtype=key.dtype,
            device=key.device) * self.new_variable.reshape(1, -1)
        key_end_part = end_embedding * ~flag

        # zero out the EOF row of key and write new_variable there instead
        key_main_part = key * flag
        key = key_main_part + key_end_part
        del padding_end, flag, end_embedding, key_main_part, key_end_part

        # calculate the average of keys (consider the entity_num)
        key_mask = mask.unsqueeze(dim=2).repeat(1, 1, key.shape[-1])
        key_avg = torch.sum(key * key_mask, dim=1) / entity_num.reshape(
            batch_size, 1)
        del key_mask

        # creates a new variable corresponding to ending unit selection.
        # QUESTION: how to do that?
        # ANSWER: following the DI-star project; see self.new_variable in the __init__() method
        units_logits_list = []
        hidden = None

        # add 1 to include the EOF step
        select_units_num = select_units_num + 1

        # designed with reference to DI-star
        max_seq_len = select_units_num.max()

        # build the per-step selection mask from select_units_num
        selected_mask = torch.arange(max_seq_len, device=device).float()
        selected_mask = selected_mask.repeat(batch_size, 1)

        # mask: [batch_size, max_seq_len]
        selected_mask = selected_mask < select_units_num.unsqueeze(dim=1)
        assert selected_mask.dtype == torch.bool

        # in the first selection, we should not select the end_index
        mask[torch.arange(batch_size), end_index] = False

        is_end = torch.zeros(batch_size, device=device).bool()

        # designed with reference to DI-star
        for i in range(max_seq_len):
            # from the second selection on, the EOF may be selected
            if i == 1:
                mask[torch.arange(batch_size), end_index] = True
                if self.is_rl_training and unit_type_entity_mask is not None:
                    unit_type_entity_mask[torch.arange(batch_size),
                                          end_index] = True

            # AlphaStar: the network passes `autoregressive_embedding` through a linear of size 256,
            x = self.fc_1(autoregressive_embedding)

            # AlphaStar: adds `func_embed`, and passes the combination through a ReLU and a linear of size 32.
            # x shape: [batch_size x seq_len x 32], note seq_len = 1
            x = self.fc_2(F.relu(x + the_func_embed)).unsqueeze(dim=1)

            # AlphaStar: The result is fed into a LSTM with size 32 and zero initial state to get a query.
            query, hidden = self.small_lstm(x, hidden)
            y = torch.sum(query * key, dim=-1)

            # the original mask usage (multiplying logits by 0) was wrong: a zero logit still
            # carries large probability mass after softmax, so we instead replace masked
            # logits with a very large negative value such as -1e9
            # y shape: [batch_size x entity_size]
            entity_logits = y.masked_fill(~mask, -1e9)
            if self.is_rl_training and self.use_unit_type_entity_mask and unit_type_entity_mask is not None:
                entity_logits = entity_logits.masked_fill(
                    ~unit_type_entity_mask, -1e9)

            temperature = self.temperature if self.is_rl_training else 1
            entity_logits = entity_logits / temperature
            del x, y, query, temperature

            # note: add a seq dimension so that the later concat yields [batch_size x max_selected x entity_size]
            units_logits_list.append(entity_logits.unsqueeze(-2))
            del entity_logits

            # the last EOF should not be considered
            if i != max_seq_len - 1:

                entity_id = units[:, i]
                print('entity_id', entity_id[0]) if show else None
                print('entity_id.shape', entity_id[0].shape) if show else None

                last_index = (entity_id.squeeze(dim=1) == end_index)
                is_end[last_index] = 1

                # AlphaStar: That entity is masked out so that it cannot be selected in future iterations.
                mask[torch.arange(batch_size),
                     entity_id.squeeze(dim=1)] = False
                print('mask', mask[0]) if show else None

                # AlphaStar: The one-hot position of the selected entity is multiplied by the keys,
                # reduced by the mean across the entities, passed through a linear layer of size 1024,
                # and added to `autoregressive_embedding` for subsequent iterations.
                entity_one_hot = L.tensor_one_hot(entity_id,
                                                  entity_size).squeeze(-2)
                entity_one_hot_unsqueeze = entity_one_hot.unsqueeze(-2)

                # entity_one_hot_unsqueeze shape: [batch_size x seq_len x entity_size], note seq_len =1
                # key_shape: [batch_size x entity_size x key_size], note key_size = 32
                out = torch.bmm(entity_one_hot_unsqueeze, key).squeeze(-2)

                # AlphaStar: reduced by the mean across the entities,
                out = out - key_avg

                # t shape: [batch_size, autoregressive_embedding_size]
                t = self.project(out)

                # TODO: should this be selected_mask[:, i + 1] or selected_mask[:, i]?
                autoregressive_embedding = autoregressive_embedding + t * selected_mask[:, i + 1].unsqueeze(
                    dim=1)
                del t, out, entity_one_hot_unsqueeze, entity_one_hot, last_index, entity_id

                print("autoregressive_embedding:",
                      autoregressive_embedding) if debug else None

        # in SL, allow one extra selection slot for the EOF, e.g. 12 + 1
        max_selected = self.max_selected + 1
        units_logits_size = len(units_logits_list)

        if units_logits_size >= max_selected:
            # remove the last one
            units_logits = torch.cat(units_logits_list[:max_selected], dim=1)
        elif units_logits_size > 0 and units_logits_size < max_selected:
            units_logits = torch.cat(units_logits_list, dim=1)
            padding_size = max_selected - units_logits.shape[1]
            if padding_size > 0:
                pad_units_logits = torch.ones(
                    units_logits.shape[0],
                    padding_size,
                    units_logits.shape[2],
                    dtype=units_logits.dtype,
                    device=units_logits.device) * (-1e9)
                units_logits = torch.cat([units_logits, pad_units_logits],
                                         dim=1)
        else:
            units_logits = torch.ones(batch_size,
                                      max_selected,
                                      entity_size,
                                      dtype=action_type.dtype,
                                      device=action_type.device) * (-1e9)

        # AlphaStar: If `action_type` does not involve selecting units, this head is ignored.

        # select_unit_mask: [batch_size x 1]
        # note: select_unit_mask must be bool so it works correctly as a boolean mask
        assert len(action_type.shape) == 2

        select_unit_mask = L.action_involve_selecting_units_mask(
            action_type).bool()
        no_select_units_index = ~select_unit_mask.squeeze(dim=1)
        print("no_select_units_index:",
              no_select_units_index) if debug else None

        # note: unlike forward, the reset of autoregressive_embedding is disabled here:
        # autoregressive_embedding[no_select_units_index] = original_ae[no_select_units_index]
        units_logits[no_select_units_index] = (-1e9)  # a magic number

        # remove the EOF
        select_units_num = select_units_num - 1

        del selected_mask, select_unit_mask, no_select_units_index, mask, units_logits_list, key, key_avg

        return units_logits, None, autoregressive_embedding, select_units_num
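In this teacher-forced variant, select_units_num is incremented on entry (to include the EOF step) and decremented before returning. A tiny trace of that bookkeeping, assuming a batch where one sample selected 2 real units and another selected none:

import torch

select_units_num = torch.tensor([2, 0])   # real units selected per sample
with_eof = select_units_num + 1           # loop runs max(with_eof) = 3 steps
selected_mask = torch.arange(with_eof.max()).unsqueeze(0) < with_eof.unsqueeze(1)
print(selected_mask)
# tensor([[ True,  True,  True],
#         [ True, False, False]])
# returned select_units_num = with_eof - 1  -> back to tensor([2, 0])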
Code example #10
def get_masked_classify_loss_for_multi_gpu(action_gt,
                                           action_pred,
                                           entity_nums,
                                           action_type_logits,
                                           delay_logits,
                                           queue_logits,
                                           units_logits,
                                           target_unit_logits,
                                           target_location_logits,
                                           select_units_num,
                                           criterion,
                                           device,
                                           decrease_smart_opertaion=False,
                                           only_consider_small=False,
                                           strict_comparsion=True,
                                           remove_none=True):
    loss = 0.

    # consider using move camera weight
    move_camera_weight = SU.get_move_camera_weight_in_SL(
        action_gt.action_type,
        action_pred,
        device,
        decrease_smart_opertaion=decrease_smart_opertaion,
        only_consider_small=only_consider_small).reshape(-1)
    action_type_loss = criterion(action_gt.action_type,
                                 action_type_logits,
                                 mask=move_camera_weight)
    loss += action_type_loss

    # alternative: mask_tensor = get_one_way_mask_in_SL(action_gt.action_type, device)
    mask_tensor = SU.get_two_way_mask_in_SL(action_gt.action_type,
                                            action_pred,
                                            device,
                                            strict_comparsion=False)

    # we now consider the delay loss
    delay_weight = 0.0
    if AHP.use_predict_step_mul:
        delay_weight = 1.0
    delay_loss = delay_weight * criterion(action_gt.delay, delay_logits)
    loss += delay_loss

    queue_loss = criterion(action_gt.queue,
                           queue_logits,
                           mask=mask_tensor[:, 2].reshape(-1))
    loss += queue_loss

    batch_size = action_gt.units.shape[0]
    select_size = action_gt.units.shape[1]
    units_size = action_gt.units.shape[-1]

    print('entity_nums', entity_nums) if debug else None
    print('entity_nums.shape', entity_nums.shape) if debug else None

    # use an extended select_size in SL to include the EOF
    extended_select_size = select_size + 1

    # the selected-units mask is at position 3 of mask_tensor
    units_mask = mask_tensor[:, 3]
    units_mask = units_mask.unsqueeze(1).repeat(1, extended_select_size)
    units_mask = units_mask.reshape(-1)
    print('units_mask', units_mask) if debug else None
    print('units_mask.shape', units_mask.shape) if debug else None

    selected_mask = torch.arange(extended_select_size, device=device).float()
    selected_mask = selected_mask.repeat(batch_size, 1)

    # note: the select_units_num here is actually gt_select_units_num;
    # we extend it by 1 to include the EOF
    selected_mask = selected_mask < (select_units_num + 1).unsqueeze(dim=1)
    selected_mask = selected_mask.reshape(-1)
    print('selected_mask', selected_mask) if debug else None
    print('selected_mask.shape', selected_mask.shape) if debug else None

    gt_units = action_gt.units.long()
    padding = torch.zeros(batch_size,
                          1,
                          units_size,
                          dtype=gt_units.dtype,
                          device=gt_units.device)
    token = torch.tensor(AHP.max_entities - 1,
                         dtype=padding.dtype,
                         device=padding.device)
    padding[:, 0] = L.tensor_one_hot(token, units_size).reshape(-1)
    gt_units = torch.cat([gt_units, padding], dim=1)
    print('gt_units', gt_units) if debug else None
    print('gt_units.shape', gt_units.shape) if debug else None

    gt_units[torch.arange(batch_size),
             select_units_num] = L.tensor_one_hot(entity_nums,
                                                  units_size).long()
    print('gt_units', gt_units) if debug else None
    print('gt_units.shape', gt_units.shape) if debug else None

    gt_units = gt_units.float()

    all_units_mask = units_mask * selected_mask  # * gt_units_mask

    # TODO: change to a properly proportioned weighting for selected units
    selected_units_weight = 10.
    units_loss = selected_units_weight * criterion(
        gt_units.reshape(-1, units_size),
        units_logits.reshape(-1, units_size),
        mask=all_units_mask,
        debug=False,
        outlier_remove=True,
        entity_nums=entity_nums,
        select_size=extended_select_size,
        select_units_num=select_units_num + 1)
    loss += units_loss

    target_unit_weight = 1.
    target_unit_loss = target_unit_weight * criterion(
        action_gt.target_unit.squeeze(-2),
        target_unit_logits.squeeze(-2),
        mask=mask_tensor[:, 4].reshape(-1),
        debug=False,
        outlier_remove=True,
        entity_nums=entity_nums)
    loss += target_unit_loss

    batch_size = action_gt.target_location.shape[0]
    location_weight = 5.
    target_location_loss = location_weight * criterion(
        action_gt.target_location.reshape(batch_size, -1),
        target_location_logits.reshape(batch_size, -1),
        mask=mask_tensor[:, 5].reshape(-1))
    loss += target_location_loss

    return loss, [
        action_type_loss.item(),
        delay_loss.item(),
        queue_loss.item(),
        units_loss.item(),
        target_unit_loss.item(),
        target_location_loss.item()
    ]
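The criterion passed in above is called as criterion(ground_truth, logits, mask=...). A hedged sketch of a compatible masked cross-entropy; the project's real criterion also accepts kwargs such as outlier_remove, entity_nums, and select_units_num that are not modelled here:

import torch

def masked_cross_entropy(gt, logits, mask=None, eps=1e-9):
    # gt, logits: [N, num_classes] with one-hot (or soft) targets; mask: [N] or None
    log_probs = torch.log_softmax(logits, dim=-1)
    per_sample = -(gt * log_probs).sum(dim=-1)   # [N]
    if mask is None:
        return per_sample.mean()
    mask = mask.float()
    return (per_sample * mask).sum() / (mask.sum() + eps)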