def change_units_and_logits(behavior_logits, gt_units, select_units_num, entity_nums):
    """Append one end-of-selection (EOF) step to selected-units logits and labels.

    One extra step is concatenated along dim 1 of both tensors:

    * ``behavior_logits`` gets a row that is ``0`` at index ``entity_nums[b]``
      and ``-1e9`` everywhere else, so the appended step only admits the EOF
      entity.
    * ``gt_units`` gets the label ``entity_nums[b]`` for every sample ``b``.

    Args:
        behavior_logits: tensor unpacked as ``[batch_size, select_size, units_size]``.
        gt_units: ground-truth unit indices; the appended label has trailing
            size 1, so this presumably is ``[batch_size, select_size, 1]``
            (TODO confirm against callers).
        select_units_num: kept for interface compatibility; not used.
        entity_nums: number of entities per sample. It is flattened to
            ``[batch_size]`` and used as the EOF index.

    Returns:
        ``(behavior_logits, gt_units)``, each with one extra step on dim 1.
        ``gt_units`` is returned as ``long``.
    """
    batch_size, _, units_size = behavior_logits.shape
    entity_nums = entity_nums.reshape(-1).long()

    # 0 at the EOF index, -1e9 elsewhere. Only the final values are built;
    # there is no intermediate padding that would be overwritten anyway.
    eof_logits = (L.tensor_one_hot(entity_nums, units_size).float() - 1) * 1e9
    eof_logits = eof_logits.reshape(batch_size, 1, units_size).to(
        device=behavior_logits.device, dtype=behavior_logits.dtype)
    behavior_logits = torch.cat([behavior_logits, eof_logits], dim=1)

    # The matching label for the appended step is the EOF index itself.
    eof_units = entity_nums.reshape(batch_size, 1, 1).to(gt_units.device)
    gt_units = torch.cat([gt_units.long(), eof_units], dim=1)

    return behavior_logits, gt_units
def forward(self, autoregressive_embedding, delay=None):
    """Delay head: produce delay logits and fold the chosen delay into the embedding.

    Args:
        autoregressive_embedding: the running autoregressive embedding.
        delay: optional delay indices (e.g. ground truth). When ``None``, a
            delay is drawn with ``torch.multinomial`` from
            ``softmax(delay_logits)``.

    Returns:
        ``(delay_logits, delay, autoregressive_embedding)``. The embedding is
        the input plus a projection of the one-hot delay.
    """
    # AlphaStar: decode with a 2-layer (256 each) linear network with ReLUs.
    hidden = autoregressive_embedding
    for layer in (self.fc_1, self.fc_2):
        hidden = F.relu(layer(hidden))

    # AlphaStar: one logit per possible requested delay. Unlike the other
    # arguments, no temperature is applied before sampling.
    delay_logits = self.embed_fc(hidden)

    if delay is None:
        delay = torch.multinomial(self.softmax(delay_logits), 1)

    # AlphaStar: like `action_type`, the delay (as a one-hot) goes through a
    # 2-layer ReLU network, is projected, and is added to the embedding.
    delay_code = L.tensor_one_hot(delay, self.max_delay).squeeze(-2)
    for layer in (self.fc_3, self.fc_4):
        delay_code = F.relu(layer(delay_code))

    # NOTE(review): relies on `self.project(...)` broadcasting against the
    # embedding; shapes are not checked here.
    autoregressive_embedding = autoregressive_embedding + self.project(delay_code)
    return delay_logits, delay, autoregressive_embedding
def forward(self, autoregressive_embedding, action_type, embedded_entity=None, queue=None):
    """Queued head: produce `queued` logits and conditionally update the embedding.

    Args:
        autoregressive_embedding: the running autoregressive embedding.
        action_type: chosen action types. They are used only to build the
            "can this action be queued" mask.
        embedded_entity: accepted for interface compatibility; not used.
        queue: optional queue indices (e.g. ground truth). When ``None``, one
            is drawn with ``torch.multinomial`` from ``softmax(queue_logits)``.

    Returns:
        ``(queue_logits, queue, autoregressive_embedding)``. ``queue_logits``
        are already divided by the temperature. The projection is added only
        where the queue mask is nonzero.
    """
    # AlphaStar: similar to the delay head (two linear + ReLU layers)...
    hidden = autoregressive_embedding
    for layer in (self.fc_1, self.fc_2):
        hidden = self.relu(layer(hidden))

    # ...except that a temperature is applied to the logits (only while RL training).
    if self.is_rl_training:
        temperature = self.temperature
    else:
        temperature = 1
    queue_logits = self.embed_fc(hidden) / temperature

    if queue is None:
        queue = torch.multinomial(self.softmax(queue_logits), 1)

    # One-hot the queue choice and embed it like `action_type`.
    queue_code = L.tensor_one_hot(queue, self.max_queue).squeeze(-2)
    for layer in (self.fc_3, self.fc_4):
        queue_code = self.relu(layer(queue_code))
    projected = self.project(queue_code)

    # AlphaStar: the projection is not added when the chosen action type
    # cannot be queued (mask is 0 for those samples).
    can_queue = L.action_can_be_queued_mask(action_type).float()
    autoregressive_embedding = autoregressive_embedding + can_queue * projected
    return queue_logits, queue, autoregressive_embedding
def forward(self, scalar_list):
    """Scalar encoder: embed each scalar feature and build `scalar_context`.

    Args:
        scalar_list: exactly 16 tensors, in this order: agent_statistics,
            home_race, away_race, upgrades, enemy_upgrades, time,
            available_actions, unit_counts_bow, mmr, units_buildings, effects,
            upgrade, beginning_build_order, last_delay, last_action_type,
            last_repeat_queued.

    Returns:
        ``(embedded_scalar_out, scalar_context_out)``:
        ``relu(fc_1(...))`` over every per-feature embedding concatenated on
        dim 1, and ``relu(fc_2(...))`` over the subset that is also routed to
        `scalar_context` (races, available_actions, the three cumulative
        statistics and the beginning build order).
    """
    (agent_statistics, home_race, away_race, upgrades, enemy_upgrades, time_feature,
     available_actions, unit_counts_bow, mmr, units_buildings, effects, upgrade,
     beginning_build_order, last_delay, last_action_type,
     last_repeat_queued) = scalar_list

    embedded_parts = []
    context_parts = []

    def _record(feature, to_context=False):
        # Every feature feeds the embedded scalar; some also feed scalar_context.
        embedded_parts.append(feature)
        if to_context:
            context_parts.append(feature)

    # agent_statistics: log(x + 1), linear (64), ReLU.
    _record(F.relu(self.statistics_fc(torch.log(agent_statistics + 1))))

    # race: home and away one-hots, each linear (32) + ReLU; both go to scalar_context.
    # TODO: during training hide the opponent's requested race in 10% of matches
    # (simulating a Random opponent) and reveal it once one of their units is seen.
    _record(F.relu(self.home_race_fc(home_race.float())), to_context=True)
    _record(F.relu(self.away_race_fc(away_race.float())), to_context=True)

    # upgrades / enemy_upgrades: boolean vectors, linear (128) + ReLU.
    _record(F.relu(self.upgrades_fc(upgrades)))
    _record(F.relu(self.enemy_upgrades_fc(enemy_upgrades)))

    # time: the positional encoding is computed during preprocessing; used as-is.
    _record(time_feature)

    # available_actions: boolean availability vector, linear (64) + ReLU; to scalar_context.
    _record(F.relu(self.available_actions_fc(available_actions)), to_context=True)

    # unit_counts_bow: square root, linear, ReLU. Counts must be >= 0, otherwise
    # torch.sqrt produces NaN.
    print('unit_counts_bow', unit_counts_bow) if debug else None
    assert (unit_counts_bow >= 0).all()
    _record(F.relu(self.unit_counts_bow_fc(torch.sqrt(unit_counts_bow))))

    # mmr: preprocessed upstream (AlphaStar: one-hot of min(mmr / 1000, 6)), linear (64) + ReLU.
    _record(F.relu(self.mmr_fc(mmr)))

    # cumulative statistics: units/buildings, effects and upgrades sub-vectors,
    # each linear (32) + ReLU; all go to scalar_context.
    _record(F.relu(self.units_buildings_fc(units_buildings)), to_context=True)
    _record(F.relu(self.effects_fc(effects)), to_context=True)
    _record(F.relu(self.upgrade_fc(upgrade)), to_context=True)

    # beginning_build_order: append one-hot position indices (following the
    # DI-star processing), project, run through a small masked transformer,
    # then flatten. The result also goes to scalar_context.
    print("beginning_build_order:", beginning_build_order) if debug else None
    print("beginning_build_order.shape:", beginning_build_order.shape) if debug else None
    bo_len = SCHP.count_beginning_build_order
    n_samples = beginning_build_order.shape[0]

    positions = L.tensor_one_hot(torch.arange(bo_len), bo_len)
    positions = positions.unsqueeze(0).repeat(n_samples, 1, 1).to(beginning_build_order.device)

    # Number of filled build-order entries per sample (sum over both trailing dims).
    bo_count = beginning_build_order.sum(dim=-1, keepdim=False).sum(dim=-1, keepdim=False)
    bo_count = bo_count.unsqueeze(1).repeat(1, bo_len)
    slot_ids = torch.arange(bo_len).unsqueeze(0).repeat(n_samples, 1).to(bo_count.device)
    # NOTE(review): mask[:, i, j] depends only on i (the row). Confirm this matches
    # the transformer's mask convention (query rows vs. key columns).
    bo_mask = (slot_ids < bo_count).unsqueeze(2).repeat(1, 1, bo_len)

    bo_hidden = self.before_beginning_build_order(
        torch.cat([beginning_build_order, positions], dim=2))
    bo_hidden = self.beginning_build_order_transformer(bo_hidden, mask=bo_mask)
    _record(bo_hidden.reshape(bo_hidden.shape[0], bo_len * self.build_order_model_size),
            to_context=True)

    # last_delay (one-hot, linear 64), last_action_type (one-hot, linear 128),
    # last_repeat_queued (joint one-hots, linear 256); each followed by ReLU.
    _record(F.relu(self.last_delay_fc(last_delay)))
    _record(F.relu(self.last_action_type_fc(last_action_type)))
    _record(F.relu(self.last_repeat_queued_fc(last_repeat_queued)))

    embedded_scalar_out = F.relu(self.fc_1(torch.cat(embedded_parts, dim=1)))
    scalar_context_out = F.relu(self.fc_2(torch.cat(context_parts, dim=1)))
    return embedded_scalar_out, scalar_context_out
def forward(self, lstm_output, scalar_context, action_type_mask=None, action_type=None):
    """Action type head: predict (or consume) the action type and seed the embedding.

    Args:
        lstm_output: core output; ``shape[0]`` is the batch size.
        scalar_context: gating context for the GLUs.
        action_type_mask: optional availability mask with one column per
            action type. It is applied only when ``self.is_rl_training`` and
            ``self.use_action_type_mask`` are both true. The caller's tensor
            is never modified.
        action_type: optional action-type indices, one per sample (e.g.
            ground truth). When ``None``, one is sampled. When given together
            with a mask, that action type is always left selectable.

    Returns:
        ``(action_type_logits, action_type, autoregressive_embedding)``.
        The logits are already divided by the temperature.
    """
    batch_size = lstm_output.shape[0]

    # AlphaStar: embed `lstm_output` into a 1D tensor of size 256...
    x = self.embed_fc(lstm_output)
    # ...then run the ResBlocks (with layer norm) and a ReLU. The blocks are
    # 1-D, so the embedding is treated as channels of a length-1 sequence.
    x = x.unsqueeze(-1)
    for resblock in self.resblock_stack:
        x = resblock(x)
    x = F.relu(x)
    x = x.squeeze(-1)

    # AlphaStar: one logit per action type through a GLU gated by `scalar_context`.
    action_type_logits = self.glu_1(x, scalar_context)

    # Action-type masking, inspired by DI-star.
    if self.is_rl_training and self.use_action_type_mask and action_type_mask is not None:
        # Always copy. `.bool()` returns the same tensor when it is already
        # bool, so writing into it would silently edit the caller's mask.
        action_type_mask = action_type_mask.to(torch.bool, copy=True)
        if action_type is not None:
            # Keep the given action type selectable. This is one vectorized
            # write instead of a per-sample `.item()` loop.
            chosen = action_type.reshape(-1).long().to(action_type_mask.device)
            rows = torch.arange(chosen.shape[0], device=action_type_mask.device)
            action_type_mask[rows, chosen] = True
        action_type_logits = action_type_logits + (~action_type_mask * (-1e9))
    print('action_type_logits:', action_type_logits) if debug else None
    print('action_type_logits.shape:', action_type_logits.shape) if debug else None

    # AlphaStar: sample with temperature 0.8 during RL. In supervised learning
    # the temperature is 1 and `action_type` is the ground truth.
    temperature = self.temperature if self.is_rl_training else 1
    action_type_logits = action_type_logits / temperature
    print('action_type_logits:', action_type_logits) if debug else None
    print('action_type_logits.shape:', action_type_logits.shape) if debug else None

    # torch.multinomial needs non-negative, finite probabilities with a
    # non-zero sum, so sample from softmax probabilities rather than logits.
    if action_type is None:
        action_type_probs = self.softmax(action_type_logits)
        action_type_probs = action_type_probs.reshape(batch_size, -1)
        print('action_type_probs:', action_type_probs) if debug else None
        print('action_type_probs.shape:', action_type_probs.shape) if debug else None
        action_type = torch.multinomial(action_type_probs, 1)
        action_type = action_type.reshape(batch_size, -1)

    action_type_one_hot = L.tensor_one_hot(action_type, self.max_action_num)
    action_type_one_hot = action_type_one_hot.squeeze(-2)

    # AlphaStar: ReLU + linear (256) on the one-hot action type, then a GLU
    # gated by `scalar_context`...
    z = F.relu(self.fc_1(action_type_one_hot))
    z = self.glu_2(z, scalar_context)
    # ...added to a GLU projection of `lstm_output` to give `autoregressive_embedding`.
    t = self.glu_3(lstm_output, scalar_context)
    autoregressive_embedding = z + t
    return action_type_logits, action_type, autoregressive_embedding
def pre_forward(self, various_observations):
    """Embed the scalar observations this baseline consumes and concatenate them.

    AlphaStar: these features are concatenated into `action_type_input`, then
    go through a linear (256), ResBlocks, a ReLU and a final linear with one
    unit. Only the concatenation step is done here.

    Args:
        various_observations: exactly 8 tensors, in this order:
            agent_statistics, upgrades, unit_counts_bow, units_buildings,
            effects, upgrade, beginning_build_order, cumulative_score.

    Returns:
        All per-feature embeddings concatenated on dim 1. Which cumulative
        statistics are included depends on ``self.baseline_type``.
    """
    [agent_statistics, upgrades, unit_counts_bow, units_buildings, effects, upgrade,
     beginning_build_order, cumulative_score] = various_observations

    embedded_scalar_list = []

    # agent_statistics: log(x + 1), linear (64), ReLU.
    the_log_statistics = torch.log(agent_statistics + 1)
    embedded_scalar_list.append(F.relu(self.statistics_fc(the_log_statistics)))

    # cumulative_score: embedded the same way as agent_statistics.
    score_log_statistics = torch.log(cumulative_score + 1)
    embedded_scalar_list.append(F.relu(self.cumulatscore_fc(score_log_statistics)))

    # upgrades: boolean vector, linear (128), ReLU.
    embedded_scalar_list.append(F.relu(self.upgrades_fc(upgrades)))

    # unit_counts_bow: bag-of-words unit counts, linear, ReLU.
    # NOTE(review): the AlphaStar description square-roots the counts before
    # the linear layer, but no sqrt is applied here. Left unchanged because
    # trained weights depend on it; confirm whether the input is already
    # preprocessed.
    embedded_scalar_list.append(F.relu(self.unit_counts_bow_fc(unit_counts_bow)))

    # cumulative statistics: which sub-vectors are used depends on the baseline type.
    if self.baseline_type in ("winloss", "build_order", "built_units"):
        embedded_scalar_list.append(F.relu(self.units_buildings_fc(units_buildings)))
    if self.baseline_type in ("effects", "build_order"):
        embedded_scalar_list.append(F.relu(self.effects_fc(effects)))
    if self.baseline_type in ("upgrades", "build_order"):
        embedded_scalar_list.append(F.relu(self.upgrade_fc(upgrade)))

    # beginning_build_order: append one-hot position indices (following DI-star),
    # project, run through a masked transformer, then flatten.
    print("beginning_build_order:", beginning_build_order) if debug else None
    print("beginning_build_order.shape:", beginning_build_order.shape) if debug else None
    batch_size = beginning_build_order.shape[0]
    seq = torch.arange(SCHP.count_beginning_build_order)
    seq = L.tensor_one_hot(seq, SCHP.count_beginning_build_order)
    seq = seq.unsqueeze(0).repeat(batch_size, 1, 1).to(beginning_build_order.device)

    # Number of filled build-order entries per sample (sum over both trailing dims).
    bo_sum = beginning_build_order.sum(dim=-1, keepdim=False)
    bo_sum = bo_sum.sum(dim=-1, keepdim=False)
    bo_sum = bo_sum.unsqueeze(1)
    bo_sum = bo_sum.repeat(1, SCHP.count_beginning_build_order)
    mask = torch.arange(SCHP.count_beginning_build_order)
    mask = mask.unsqueeze(0).repeat(batch_size, 1).to(bo_sum.device)
    mask = mask < bo_sum
    mask = mask.unsqueeze(2).repeat(1, 1, SCHP.count_beginning_build_order)

    x = torch.cat([beginning_build_order, seq], dim=2)
    x = self.before_beginning_build_order(x)
    # Sequence mask, as in the entity encoder.
    x = self.beginning_build_order_transformer(x, mask=mask)
    x = x.reshape(x.shape[0], SCHP.count_beginning_build_order * self.build_order_model_size)
    embedded_scalar_list.append(x)

    return torch.cat(embedded_scalar_list, dim=1)
def mimic_forward(self, state, gt_action, gt_select_units_num, gt_is_one_hot=True,
                  multi_gpu_supvised_learning=False, batch_size=None, sequence_length=None,
                  hidden_state=None, baseline_state=None, baseline_opponent_state=None,
                  show=False, obs_list=None):
    '''Run the full model with every action argument teacher-forced to the ground truth.

    Inspired by DI-star: the ground-truth arguments are injected into each
    head so the autoregressive chain follows the recorded action.

    Args:
        state: provides ``entity_state``, ``map_state`` and
            ``statistical_state``. Index 6 of ``statistical_state`` is used as
            the available-actions mask.
        gt_action: ground truth with ``action_type``, ``delay``, ``queue``,
            ``units`` and ``target_unit`` fields.
        gt_select_units_num: per-sample count of selected units. The EOF label
            is written at this position of the unit labels.
        gt_is_one_hot: if True, the ``gt_action`` fields are one-hot and are
            converted to indices here; otherwise they are used as given.
        multi_gpu_supvised_learning: if True, return the 14-tuple without the
            leading baseline list.
        batch_size, sequence_length, hidden_state: forwarded to ``self.core``.
        baseline_state, baseline_opponent_state: when ``baseline_state`` is
            not None, the five baseline heads are evaluated on ``lstm_output``.
        show: print selected-units tensors for inspection.
        obs_list: optional observations used to build a unit-type mask for
            the selected-units head.

    Returns:
        ``(baseline_value_list, action_type, entity_nums, units, target_unit,
        target_location, action_type_logits, delay_logits, queue_logits,
        units_logits, target_unit_logits, target_location_logits,
        select_units_num, hidden_state, unit_types_one)``. The first element
        is omitted when ``multi_gpu_supvised_learning`` is True.
    '''
    # embedded_entity, embedded_spatial and embedded_scalar are [batch_size x embedded_size].
    entity_embeddings, embedded_entity, entity_nums, unit_types_one = self.entity_encoder(
        state.entity_state, return_unit_types=True)
    map_skip, embedded_spatial = self.spatial_encoder(state.map_state, entity_embeddings)
    embedded_scalar, scalar_context = self.scalar_encoder(state.statistical_state)

    available_actions_id = 6  # available_actions is at position 6
    available_actions = state.statistical_state[available_actions_id]
    del state  # free the raw observation before running the core

    lstm_output, hidden_state = self.core(embedded_scalar, embedded_entity, embedded_spatial,
                                          batch_size, sequence_length, hidden_state)
    del embedded_scalar, embedded_spatial

    if gt_is_one_hot:
        # One-hot labels -> indices (the column of the nonzero entry in each row).
        # TODO: remove these unzero
        gt_action_type = torch.nonzero(gt_action.action_type.long(), as_tuple=True)[-1].unsqueeze(dim=1)
        print('gt_action_type.shape', gt_action_type.shape) if debug else None
        gt_delay = torch.nonzero(gt_action.delay.long(), as_tuple=True)[-1].unsqueeze(dim=1)
        print('gt_delay.shape', gt_delay.shape) if debug else None
        gt_queue = torch.nonzero(gt_action.queue.long(), as_tuple=True)[-1].unsqueeze(dim=1)
        print('gt_queue.shape', gt_queue.shape) if debug else None

        gt_units = gt_action.units.long()
        # The label batch size must come from the labels themselves. It may
        # differ from the `batch_size` argument given to the core.
        gt_batch_size = gt_units.shape[0]
        units_size = gt_units.shape[2]
        # Extra step holding a one-hot of the "none" index (AHP.max_entities - 1).
        padding = torch.zeros(gt_batch_size, 1, units_size, dtype=gt_units.dtype, device=gt_units.device)
        token = torch.tensor(AHP.max_entities - 1, dtype=padding.dtype, device=padding.device)
        padding[:, 0] = L.tensor_one_hot(token, units_size).reshape(-1)
        gt_units = torch.cat([gt_units, padding], dim=1)
        print('gt_units', gt_units) if debug else None
        print('gt_units.shape', gt_units.shape) if debug else None
        # Write the EOF (one-hot of entity_nums) right after the last selected unit.
        gt_units[torch.arange(gt_batch_size), gt_select_units_num] = L.tensor_one_hot(
            entity_nums, units_size).long()
        gt_units = gt_units.reshape(-1, units_size)
        print('gt_units.shape', gt_units.shape) if debug else None
        gt_units = torch.nonzero(gt_units, as_tuple=True)[-1]
        gt_units = gt_units.reshape(gt_batch_size, -1, 1)
        print('gt_units', gt_units) if debug else None
        print('gt_units.shape', gt_units.shape) if debug else None

        gt_target_unit = gt_action.target_unit.long()
        gt_target_unit = gt_target_unit.reshape(-1, gt_target_unit.shape[-1])
        gt_target_unit = torch.nonzero(gt_target_unit, as_tuple=True)[-1]
        print('gt_target_unit.shape', gt_target_unit.shape) if debug else None
        gt_target_unit = gt_target_unit.reshape(gt_batch_size, 1, 1)
    else:
        gt_action_type = gt_action.action_type
        gt_delay = gt_action.delay
        gt_queue = gt_action.queue
        gt_units = gt_action.units
        # Previously this used the `batch_size` argument. That argument may be
        # None, or the per-sequence batch, which broke torch.zeros/torch.cat.
        gt_batch_size = gt_units.shape[0]
        padding = torch.zeros(gt_batch_size, 1, 1, dtype=gt_units.dtype, device=gt_units.device)
        token = torch.tensor(AHP.max_entities - 1, dtype=padding.dtype, device=padding.device)
        padding[:, 0] = token
        gt_units = torch.cat([gt_units, padding], dim=1)
        print('gt_select_units_num', gt_select_units_num) if debug else None
        print('gt_units', gt_units) if debug else None
        print('gt_units.shape', gt_units.shape) if debug else None
        # Write the EOF index right after the last selected unit.
        gt_units[torch.arange(gt_batch_size), gt_select_units_num] = entity_nums.unsqueeze(dim=1)
        print('gt_units', gt_units) if debug else None
        print('gt_units.shape', gt_units.shape) if debug else None
        gt_target_unit = gt_action.target_unit

    action_type_logits, action_type, autoregressive_embedding = self.action_type_head(
        lstm_output, scalar_context, available_actions, gt_action_type)

    if obs_list is not None:
        unit_type_entity_mask = L.get_batch_unit_type_mask(action_type.squeeze(dim=1), obs_list)
        unit_type_entity_mask = torch.tensor(unit_type_entity_mask, dtype=torch.bool,
                                             device=action_type.device)
    else:
        unit_type_entity_mask = None

    delay_logits, _, autoregressive_embedding = self.delay_head(autoregressive_embedding, gt_delay)
    queue_logits, _, autoregressive_embedding = self.queue_head(
        autoregressive_embedding, gt_action_type, embedded_entity, gt_queue)

    # The selected-units head has a dedicated teacher-forced path.
    print('gt_units', gt_units) if show else None
    print('gt_units.shape', gt_units.shape) if show else None
    units_logits, units, autoregressive_embedding, select_units_num = self.selected_units_head.mimic_forward(
        autoregressive_embedding, gt_action_type, entity_embeddings, entity_nums, gt_units,
        gt_select_units_num, show=show, unit_type_entity_mask=unit_type_entity_mask)
    print('units_logits', units_logits) if show else None
    print('units_logits.shape', units_logits.shape) if show else None

    target_unit_logits, target_unit = self.target_unit_head(
        autoregressive_embedding, gt_action_type, entity_embeddings, entity_nums, gt_target_unit)
    target_location_logits, target_location = self.location_head(
        autoregressive_embedding, gt_action_type, map_skip)

    if multi_gpu_supvised_learning:
        return action_type, entity_nums, units, target_unit, target_location, action_type_logits, delay_logits, queue_logits, \
            units_logits, target_unit_logits, target_location_logits, select_units_num, hidden_state, unit_types_one

    if baseline_state is not None:
        baseline_value_list = [
            self.winloss_baseline.forward(lstm_output, baseline_state, baseline_opponent_state),
            self.build_order_baseline.forward(lstm_output, baseline_state, baseline_opponent_state),
            self.built_units_baseline.forward(lstm_output, baseline_state, baseline_opponent_state),
            self.upgrades_baseline.forward(lstm_output, baseline_state, baseline_opponent_state),
            self.effects_baseline.forward(lstm_output, baseline_state, baseline_opponent_state),
        ]
    else:
        baseline_value_list = []

    return baseline_value_list, action_type, entity_nums, units, target_unit, target_location, action_type_logits, delay_logits, queue_logits, \
        units_logits, target_unit_logits, target_location_logits, select_units_num, hidden_state, unit_types_one
def forward(self, autoregressive_embedding, action_type, entity_embeddings, entity_num, unit_type_entity_mask=None):
    '''Selected-units head: autoregressively pick up to `max_selected` entities.

    Inputs:
        autoregressive_embedding: [batch_size x autoregressive_embedding_size]
        action_type: [batch_size x 1]
        entity_embeddings: [batch_size x entity_size x embedding_size]
        entity_num: [batch_size]; also used as the EOF (end-of-selection) index
        unit_type_entity_mask: optional bool mask over entities. It is applied
            only in RL training when `self.use_unit_type_entity_mask` is set.
            From the second step onward its EOF column is set to True in place.
    Output:
        units_logits: [batch_size x max_selected x entity_size]
        units: [batch_size x max_selected x 1]
        autoregressive_embedding: [batch_size x autoregressive_embedding_size]
        select_units_num: [batch_size]. This is the step at which each sample
            first chose the EOF. It stays max_selected if EOF was never chosen
            and is 0 when the action type selects no units.
    '''
    batch_size = entity_embeddings.shape[0]
    entity_size = entity_embeddings.shape[-2]
    device = next(self.parameters()).device
    original_ae = autoregressive_embedding
    rows = torch.arange(batch_size)

    # AlphaStar: embed which entity types can accept `action_type` (a multi-hot
    # over unit types) through a linear of size 256 and a ReLU -> `func_embed`.
    # This relies on human knowledge in the helper.
    unit_types_one_hot = L.action_can_apply_to_selected_mask(action_type).to(device)
    the_func_embed = F.relu(self.func_embed(unit_types_one_hot))

    # AlphaStar: the selectable mask starts as "every existing entity".
    # entity_num + 1 also admits the EOF slot, so it takes part in the key mean.
    mask = torch.arange(entity_size, device=device).float()
    mask = mask.repeat(batch_size, 1)
    added_entity_num = entity_num + 1
    mask = mask < added_entity_num.unsqueeze(dim=1)
    assert mask.dtype == torch.bool

    # AlphaStar: keys come from a kernel-size-1 1D convolution over the
    # entity embeddings (output: [batch_size x entity_size x key_size]).
    key = self.conv_1(entity_embeddings.transpose(-1, -2)).transpose(-1, -2)

    # Replace the key at the EOF index with the learned `new_variable`
    # (DI-star style). It is done arithmetically rather than by in-place
    # assignment. The last slot is first zeroed.
    end_index = entity_num
    padding_end = torch.zeros(key.shape[0], 1, key.shape[2], dtype=key.dtype, device=key.device)
    key = torch.cat([key[:, :-1, :], padding_end], dim=1)
    flag = torch.ones(key.shape, dtype=torch.bool, device=key.device)
    flag[rows, end_index] = False
    end_embedding = torch.ones(key.shape, dtype=key.dtype, device=key.device) * self.new_variable.reshape(1, -1)
    key = key * flag + end_embedding * ~flag

    # Masked key sum, divided by entity_num.
    # NOTE(review): the sum covers entity_num + 1 slots (EOF included) but the
    # divisor is entity_num; it is also undefined for entity_num == 0. Left
    # unchanged because trained weights depend on it.
    key_mask = mask.unsqueeze(dim=2).repeat(1, 1, key.shape[-1])
    key_avg = torch.sum(key * key_mask, dim=1) / entity_num.reshape(batch_size, 1)

    units_logits = []
    units = []
    hidden = None
    # Which samples have already chosen the EOF (DI-star).
    is_end = torch.zeros(batch_size, device=device).bool()
    # The EOF may not be chosen at the first step.
    mask[rows, end_index] = False
    select_units_num = torch.ones(batch_size, dtype=torch.long, device=device) * self.max_selected
    temperature = self.temperature if self.is_rl_training else 1

    # AlphaStar: repeated for selecting up to 64 units.
    for i in range(self.max_selected):
        if i == 1:
            # From the second step on, the EOF becomes selectable.
            mask[rows, end_index] = True
            if self.is_rl_training and unit_type_entity_mask is not None:
                unit_type_entity_mask[rows, end_index] = True

        x = self.fc_1(autoregressive_embedding)
        x = self.fc_2(F.relu(x + the_func_embed)).unsqueeze(dim=1)
        query, hidden = self.small_lstm(x, hidden)
        y = torch.sum(query * key, dim=-1)
        entity_logits = y.masked_fill(~mask, -1e9)
        if self.is_rl_training and self.use_unit_type_entity_mask and unit_type_entity_mask is not None:
            entity_logits = entity_logits.masked_fill(~unit_type_entity_mask, -1e9)
        entity_logits = entity_logits / temperature

        entity_probs = self.softmax(entity_logits)
        entity_id = torch.multinomial(entity_probs, 1)
        units_logits.append(entity_logits.unsqueeze(-2))
        units.append(entity_id.unsqueeze(-2))

        selected = entity_id.squeeze(dim=1)
        # A chosen entity cannot be chosen again.
        mask[rows, selected] = False

        # Only samples that have not ended yet may record an end step now. A
        # finished sample can draw the EOF again when all its logits are -1e9
        # (softmax is then uniform). That must not overwrite its recorded count.
        last_index = (selected == end_index) & ~is_end
        is_end[last_index] = 1
        # i + 1 picks including the EOF, i.e. i real units.
        select_units_num[last_index] = i

        # AlphaStar: one-hot of the choice times the keys, minus the key mean,
        # projected and added to the embedding (not for samples that have ended).
        entity_one_hot = L.tensor_one_hot(entity_id, entity_size).squeeze(-2)
        out = torch.bmm(entity_one_hot.unsqueeze(-2), key).squeeze(-2)
        out = out - key_avg
        t = self.project(out)
        autoregressive_embedding = autoregressive_embedding + t * ~is_end.unsqueeze(dim=1)

        if is_end.all():
            break

    # units_logits: [batch_size x select_units x entity_size]
    units_logits = torch.cat(units_logits, dim=1)
    # units: [batch_size x select_units x 1]
    units = torch.cat(units, dim=1)

    # Pad to max_selected steps: -1e9 logits and the "none" index (entity_size - 1).
    # TODO: change the padding
    padding_size = self.max_selected - units_logits.shape[1]
    if padding_size > 0:
        pad_units_logits = torch.ones(units_logits.shape[0], padding_size, units_logits.shape[2],
                                      dtype=units_logits.dtype, device=units_logits.device) * (-1e9)
        units_logits = torch.cat([units_logits, pad_units_logits], dim=1)
        pad_units = torch.zeros(units.shape[0], padding_size, units.shape[2],
                                dtype=units.dtype, device=units.device)
        pad_units[:, :, 0] = entity_size - 1  # None index, the same as -1
        units = torch.cat([units, pad_units], dim=1)

    # AlphaStar: if `action_type` does not involve selecting units, this head
    # is ignored. Restore the input embedding and blank the outputs.
    select_unit_mask = L.action_involve_selecting_units_mask(action_type).bool()
    no_select_units_index = ~select_unit_mask.squeeze(dim=1)
    print("no_select_units_index:", no_select_units_index) if debug else None
    select_units_num[no_select_units_index] = 0
    autoregressive_embedding[no_select_units_index] = original_ae[no_select_units_index]
    units_logits[no_select_units_index] = -1e9  # a magic number
    units[no_select_units_index, :, 0] = entity_size - 1  # None index, the same as -1
    print("select_units_num:", select_units_num) if debug else None
    print("autoregressive_embedding:", autoregressive_embedding) if debug else None

    return units_logits, units, autoregressive_embedding, select_units_num
def mimic_forward(self, autoregressive_embedding, action_type, entity_embeddings, entity_num, units, select_units_num, show=False, unit_type_entity_mask=None):
    '''
    Teacher-forced pass of the selected-units head.

    Instead of sampling entities, the ground-truth `units` are fed back one step
    at a time (entity_id = units[:, i]), so the logits produced at every step are
    conditioned on the ground-truth selections made so far.

    Inputs:
        autoregressive_embedding: [batch_size x autoregressive_embedding_size]
        action_type: [batch_size x 1] (asserted to be 2-D near the end)
        entity_embeddings: [batch_size x entity_size x embedding_size]
        entity_num: [batch_size], also used as the per-sample EOF index
        units: ground-truth selected entity indices, read as units[:, i] and
            squeezed on dim 1, i.e. presumably [batch_size x max_selected x 1]
            -- TODO confirm
        select_units_num: [batch_size], number of ground-truth selected units,
            EOF not included
        show: if True, prints per-step debug info for the first sample
        unit_type_entity_mask: optional bool mask indexed [batch, entity]. It is
            only applied to the logits when self.is_rl_training and
            self.use_unit_type_entity_mask. NOTE: it is modified in place (the
            EOF column is set to True when self.is_rl_training).
    Output:
        units_logits: [batch_size x (self.max_selected + 1) x entity_size]
        None: placeholder in the `units` slot, since nothing is sampled here
        autoregressive_embedding: [batch_size x autoregressive_embedding_size]
        select_units_num: a new tensor with the same values as the input
            select_units_num
    '''
    batch_size = entity_embeddings.shape[0]
    entity_size = entity_embeddings.shape[-2]
    device = next(self.parameters()).device
    # NOTE(review): key_size is not used anywhere below.
    key_size = self.new_variable.shape[0]

    # NOTE(review): original_ae is only referenced by a commented-out line near the end.
    original_ae = autoregressive_embedding

    # AlphaStar: If applicable, Selected Units Head first determines which entity types can accept `action_type`,
    # creates a one-hot of that type with maximum equal to the number of unit types,
    # and passes it through a linear of size 256 and a ReLU. This will be referred to in this head as `func_embed`.
    # QUESTION: one unit type or several unit types?
    # ANSWER: several unit types, each for one-hot
    # This is one of the places that introduce a lot of human knowledge
    unit_types_one_hot = L.action_can_apply_to_selected_mask(
        action_type).to(device)

    # the_func_embed shape: [batch_size x 256]
    the_func_embed = F.relu(self.func_embed(unit_types_one_hot))
    del unit_types_one_hot

    # AlphaStar: It also computes a mask of which units can be selected, initialised to allow selecting all entities
    # that exist (including enemy units).
    # generate the length mask for all entities
    mask = torch.arange(entity_size, device=device).float()
    mask = mask.repeat(batch_size, 1)

    # now the entity nums should be added 1 (including the EOF)
    # this is because we also want to compute the mean including key value of the EOF
    added_entity_num = entity_num + 1

    # mask: [batch_size, entity_size]; True for positions 0..entity_num inclusive
    mask = mask < added_entity_num.unsqueeze(dim=1)
    print("mask:", mask) if debug else None
    print("mask.shape:", mask.shape) if debug else None

    assert mask.dtype == torch.bool

    # AlphaStar: It then computes a key corresponding to each entity by feeding `entity_embeddings`
    # through a 1D convolution with 32 channels and kernel size 1,
    # and creates a new variable corresponding to ending unit selection.
    # input: [batch_size x entity_size x embedding_size]
    # output: [batch_size x entity_size x key_size], note key_size = 32
    key = self.conv_1(entity_embeddings.transpose(-1, -2)).transpose(-1, -2)

    # the EOF slot of each sample sits right after its last real entity
    end_index = entity_num

    # replace the key at the EOF slot with the learned new_variable
    # The in-place assignment is kept disabled; the else branch builds the same
    # result out of place (presumably to stay autograd-friendly -- TODO confirm).
    if False:
        key[torch.arange(batch_size), end_index] = self.new_variable
    else:
        # drop the last entity slot and append a zero row, so entity_size is unchanged
        padding_end = torch.zeros(key.shape[0], 1, key.shape[2], dtype=key.dtype, device=key.device)
        key = torch.cat([key[:, :-1, :], padding_end], dim=1)
        # flag is False only at (b, end_index[b])
        flag = torch.ones(key.shape, dtype=torch.bool, device=key.device)
        flag[torch.arange(batch_size), end_index] = False

        # [batch_size, entity_size, key_size]
        end_embedding = torch.ones(
            key.shape, dtype=key.dtype, device=key.device) * self.new_variable.reshape(1, -1)
        key_end_part = end_embedding * ~flag

        # blend: original keys everywhere except the EOF slot, which gets new_variable
        key_main_part = key * flag
        key = key_main_part + key_end_part
        del padding_end, flag, end_embedding, key_main_part, key_end_part

    # calculate the average of keys (consider the entity_num)
    # NOTE(review): the sum covers entity_num + 1 keys (mask includes the EOF) but
    # is divided by entity_num; confirm this matches the sampling forward() path.
    key_mask = mask.unsqueeze(dim=2).repeat(1, 1, key.shape[-1])
    key_avg = torch.sum(key * key_mask, dim=1) / entity_num.reshape(
        batch_size, 1)
    del key_mask

    # creates a new variable corresponding to ending unit selection.
    # QUESTION: how to do that?
    # ANSWER: referred by the DI-star project, please see self.new_variable in init() method
    units_logits_list = []
    # LSTM state starts at None (zero initial state) and is carried across steps
    hidden = None

    # consider the EOF
    select_units_num = select_units_num + 1

    # designed with reference to DI-star
    max_seq_len = select_units_num.max()

    # for select_units_num
    selected_mask = torch.arange(max_seq_len, device=device).float()
    selected_mask = selected_mask.repeat(batch_size, 1)

    # selected_mask: [batch_size, max_seq_len]; True for steps 0..select_units_num (EOF included)
    selected_mask = selected_mask < select_units_num.unsqueeze(dim=1)
    assert selected_mask.dtype == torch.bool

    # in the first selection, we should not select the end_index
    mask[torch.arange(batch_size), end_index] = False

    # is_end is written below but not read afterwards in this method
    is_end = torch.zeros(batch_size, device=device).bool()

    # designed with reference to DI-star
    for i in range(max_seq_len):
        if i != 0:
            # from the second selection on, the EOF may be selected
            if i == 1:
                mask[torch.arange(batch_size), end_index] = True
                # NOTE(review): this mutates the caller's tensor in place
                if self.is_rl_training and unit_type_entity_mask is not None:
                    unit_type_entity_mask[torch.arange(batch_size), end_index] = True

        # AlphaStar: the network passes `autoregressive_embedding` through a linear of size 256,
        x = self.fc_1(autoregressive_embedding)

        # AlphaStar: adds `func_embed`, and passes the combination through a ReLU and a linear of size 32.
        # x shape: [batch_size x seq_len x 32], note seq_len = 1
        x = self.fc_2(F.relu(x + the_func_embed)).unsqueeze(dim=1)

        # AlphaStar: The result is fed into a LSTM with size 32 and zero initial state to get a query.
        query, hidden = self.small_lstm(x, hidden)

        # dot product of the query with every entity key (query broadcasts over entities)
        y = torch.sum(query * key, dim=-1)

        # Multiplying logits by a 0/1 mask is wrong: a zero logit is still a large
        # logit. Masked positions are filled with a very large negative value (-1e9) instead.
        # y shape: [batch_size x entity_size]
        entity_logits = y.masked_fill(~mask, -1e9)

        if self.is_rl_training and self.use_unit_type_entity_mask and unit_type_entity_mask is not None:
            entity_logits = entity_logits.masked_fill(
                ~unit_type_entity_mask, -1e9)

        # temperature is only applied during RL training
        temperature = self.temperature if self.is_rl_training else 1
        entity_logits = entity_logits / temperature
        del x, y, query, temperature

        # add a step dimension so the per-step logits can be concatenated into
        # [batch_size x max_selected x ?]
        units_logits_list.append(entity_logits.unsqueeze(-2))
        del entity_logits

        # the final step (the EOF) needs no feedback into the embedding
        if i != max_seq_len - 1:
            # teacher forcing: use the ground-truth entity for this step
            entity_id = units[:, i]
            print('entity_id', entity_id[0]) if show else None
            print('entity_id.shape', entity_id[0].shape) if show else None

            last_index = (entity_id.squeeze(dim=1) == end_index)
            is_end[last_index] = 1

            # AlphaStar: That entity is masked out so that it cannot be selected in future iterations.
            mask[torch.arange(batch_size), entity_id.squeeze(dim=1)] = False
            print('mask', mask[0]) if show else None

            # AlphaStar: The one-hot position of the selected entity is multiplied by the keys,
            # reduced by the mean across the entities, passed through a linear layer of size 1024,
            # and added to `autoregressive_embedding` for subsequent iterations.
            entity_one_hot = L.tensor_one_hot(entity_id, entity_size).squeeze(-2)
            entity_one_hot_unsqueeze = entity_one_hot.unsqueeze(-2)

            # entity_one_hot_unsqueeze shape: [batch_size x seq_len x entity_size], note seq_len =1
            # key_shape: [batch_size x entity_size x key_size], note key_size = 32
            out = torch.bmm(entity_one_hot_unsqueeze, key).squeeze(-2)

            # AlphaStar: reduced by the mean across the entities,
            out = out - key_avg

            # t shape: [batch_size, autoregressive_embedding_size]
            t = self.project(out)

            # Only samples whose sequence continues to step i + 1 get the update.
            # TODO, whether should be select_mask[:, i + 1] or select_mask[:, i] ?
            autoregressive_embedding = autoregressive_embedding + t * selected_mask[:, i + 1].unsqueeze(
                dim=1)

            del t, out, entity_one_hot_unsqueeze, entity_one_hot, last_index, entity_id

    print("autoregressive_embedding:", autoregressive_embedding) if debug else None

    # in SL, we make the selected can have 1 more, like 12 + 1
    max_selected = self.max_selected + 1
    units_logits_size = len(units_logits_list)

    # bring units_logits to exactly max_selected steps: truncate, or pad with -1e9 rows
    if units_logits_size >= max_selected:
        # remove the last one
        units_logits = torch.cat(units_logits_list[:max_selected], dim=1)
    elif units_logits_size > 0 and units_logits_size < max_selected:
        units_logits = torch.cat(units_logits_list, dim=1)
        padding_size = max_selected - units_logits.shape[1]
        if padding_size > 0:
            pad_units_logits = torch.ones(
                units_logits.shape[0], padding_size, units_logits.shape[2],
                dtype=units_logits.dtype, device=units_logits.device) * (-1e9)
            units_logits = torch.cat([units_logits, pad_units_logits], dim=1)
    else:
        # only reached when no step ran at all
        units_logits = torch.ones(batch_size, max_selected, entity_size,
                                  dtype=action_type.dtype, device=action_type.device) * (-1e9)

    # AlphaStar: If `action_type` does not involve selecting units, this head is ignored.
    # select_unit_mask: [batch_size x 1]
    # note select_unit_mask should be bool type so that it works as a proper mask
    assert len(action_type.shape) == 2

    select_unit_mask = L.action_involve_selecting_units_mask(
        action_type).bool()

    no_select_units_index = ~select_unit_mask.squeeze(dim=1)
    print("no_select_units_index:", no_select_units_index) if debug else None

    # NOTE(review): unlike a reset to original_ae, autoregressive_embedding is left
    # as-is for samples whose action does not select units (line below is disabled).
    #autoregressive_embedding[no_select_units_index] = original_ae[no_select_units_index]
    units_logits[no_select_units_index] = (-1e9)  # a magic number

    # remove the EOF
    select_units_num = select_units_num - 1

    del selected_mask, select_unit_mask, no_select_units_index, mask, units_logits_list, key, key_avg

    return units_logits, None, autoregressive_embedding, select_units_num
def get_masked_classify_loss_for_multi_gpu(action_gt, action_pred, entity_nums, action_type_logits,
                                           delay_logits, queue_logits, units_logits,
                                           target_unit_logits, target_location_logits,
                                           select_units_num, criterion, device,
                                           decrease_smart_opertaion=False,
                                           only_consider_small=False,
                                           strict_comparsion=True, remove_none=True):
    """Compute the masked supervised-learning loss over every action-argument head.

    Each head's loss comes from ``criterion`` and is scaled by a fixed weight.
    The weighted losses are summed in this order: action_type, delay, queue,
    units, target_unit, target_location.

    Args:
        action_gt: ground-truth action. Its fields action_type, delay, queue,
            units, target_unit and target_location are read.
        action_pred: predicted action. It is only used to build the SL masks
            via the SU helpers.
        entity_nums: per-sample entity count. Its one-hot marks the EOF slot
            appended to the ground-truth units.
        action_type_logits, delay_logits, queue_logits, units_logits,
        target_unit_logits, target_location_logits: logits per head.
        select_units_num: per-sample number of ground-truth selected units
            (EOF not included).
        criterion: callable ``criterion(gt, logits, mask=..., **kw)`` that
            returns a tensor loss.
        device: device for the masks built here.
        decrease_smart_opertaion, only_consider_small: forwarded to
            ``SU.get_move_camera_weight_in_SL``.
        strict_comparsion: currently ignored (see NOTE below).
        remove_none: currently unused.

    Returns:
        ``(loss, [action_type, delay, queue, units, target_unit,
        target_location])``, where the list holds the individual weighted
        losses as Python floats.
    """
    action_type_gt = action_gt.action_type

    # The action type is weighted by the move-camera weight, not by the two-way mask.
    move_camera_weight = SU.get_move_camera_weight_in_SL(
        action_type_gt, action_pred, device,
        decrease_smart_opertaion=decrease_smart_opertaion,
        only_consider_small=only_consider_small).reshape(-1)
    action_type_loss = criterion(action_type_gt, action_type_logits, mask=move_camera_weight)

    # NOTE(review): strict_comparsion is hard-coded to False here, so the parameter
    # of the same name has no effect; confirm whether that is intended.
    mask_tensor = SU.get_two_way_mask_in_SL(action_type_gt, action_pred, device,
                                            strict_comparsion=False)

    # The delay head only contributes when the step multiplier is predicted.
    delay_weight = 1.0 if AHP.use_predict_step_mul else 0.0
    delay_loss = delay_weight * criterion(action_gt.delay, delay_logits)

    # column 2 of mask_tensor gates the queue head
    queue_loss = criterion(action_gt.queue, queue_logits,
                           mask=mask_tensor[:, 2].reshape(-1))

    units_shape = action_gt.units.shape
    batch_size, select_size, units_size = units_shape[0], units_shape[1], units_shape[-1]

    print('entity_nums', entity_nums) if debug else None
    print('entity_nums.shape', entity_nums.shape) if debug else None

    # SL uses one extra slot per sample for the EOF token
    extended_select_size = select_size + 1

    # column 3 of mask_tensor gates the selected-units head; spread it over every slot
    units_mask = mask_tensor[:, 3].unsqueeze(1).repeat(1, extended_select_size).reshape(-1)
    print('units_mask', units_mask) if debug else None
    print('units_mask.shape', units_mask.shape) if debug else None

    # Only the first select_units_num + 1 slots count (the real selections plus the EOF).
    # Here select_units_num is the ground-truth count.
    slot_index = torch.arange(extended_select_size, device=device).float().repeat(batch_size, 1)
    selected_mask = (slot_index < (select_units_num + 1).unsqueeze(dim=1)).reshape(-1)
    print('selected_mask', selected_mask) if debug else None
    print('selected_mask.shape', selected_mask.shape) if debug else None

    # Append one extra row per sample. It defaults to the one-hot of index
    # AHP.max_entities - 1, presumably the "none" entity (TODO confirm).
    gt_units = action_gt.units.long()
    extra_row = torch.zeros(batch_size, 1, units_size, dtype=gt_units.dtype, device=gt_units.device)
    none_token = torch.tensor(AHP.max_entities - 1, dtype=extra_row.dtype, device=extra_row.device)
    extra_row[:, 0] = L.tensor_one_hot(none_token, units_size).reshape(-1)
    gt_units = torch.cat([gt_units, extra_row], dim=1)
    print('gt_units', gt_units) if debug else None
    print('gt_units.shape', gt_units.shape) if debug else None

    # The EOF target (one-hot of entity_nums) goes right after the last real selection.
    gt_units[torch.arange(batch_size), select_units_num] = L.tensor_one_hot(entity_nums, units_size).long()
    print('gt_units', gt_units) if debug else None
    print('gt_units.shape', gt_units.shape) if debug else None
    gt_units = gt_units.float()

    all_units_mask = units_mask * selected_mask

    # TODO: change to a proper calculation of the selected-units weight
    selected_units_weight = 10.
    units_loss = selected_units_weight * criterion(
        gt_units.reshape(-1, units_size), units_logits.reshape(-1, units_size),
        mask=all_units_mask, debug=False, outlier_remove=True, entity_nums=entity_nums,
        select_size=extended_select_size, select_units_num=select_units_num + 1)

    # column 4 of mask_tensor gates the target-unit head
    target_unit_weight = 1.
    target_unit_loss = target_unit_weight * criterion(
        action_gt.target_unit.squeeze(-2), target_unit_logits.squeeze(-2),
        mask=mask_tensor[:, 4].reshape(-1), debug=False, outlier_remove=True,
        entity_nums=entity_nums)

    # column 5 of mask_tensor gates the target-location head
    location_batch_size = action_gt.target_location.shape[0]
    location_weight = 5.
    target_location_loss = location_weight * criterion(
        action_gt.target_location.reshape(location_batch_size, -1),
        target_location_logits.reshape(location_batch_size, -1),
        mask=mask_tensor[:, 5].reshape(-1))

    head_losses = [action_type_loss, delay_loss, queue_loss,
                   units_loss, target_unit_loss, target_location_loss]
    # Start from 0. and add left to right, matching an incremental accumulation.
    total_loss = sum(head_losses, 0.)
    return total_loss, [head_loss.item() for head_loss in head_losses]