Example #1
0
    def base_net_forward(self,
                         available_action_ids,
                         minimap_numeric,
                         player_relative_minimap,
                         player_relative_screen,
                         screen_numeric,
                         screen_unit_type,
                         last_action_reward,
                         lstm_state=None):
        """Run the full base network: conv encoder -> LSTM -> value/policy heads.

        Returns a tuple ``(pi_out, sp_pi_out, value, lstm_state)`` — the
        action-id policy, the spatial policy, the state value, and the
        updated LSTM state.
        """
        # One-hot-encode the categorical player_relative planes before
        # feeding them to the convolutional encoder.
        resolution = (FLAGS.resolution, FLAGS.resolution)
        screen_rel_onehot = one_hot_encoding(
            player_relative_screen,
            SCREEN_FEATURES.player_relative.scale,
            resolution)
        minimap_rel_onehot = one_hot_encoding(
            player_relative_minimap,
            MINIMAP_FEATURES.player_relative.scale,
            resolution)

        conv_features = self.base_conv_net(
            available_action_ids, minimap_numeric, minimap_rel_onehot,
            screen_rel_onehot, screen_numeric, screen_unit_type)

        # Recurrent core; lstm_state defaults to None so the LSTM net can
        # initialize its own state on the first step.
        lstm_out, lstm_state = self.base_lstm_net(
            conv_features, last_action_reward, lstm_state)

        # Value head also yields the shared fc1 activations used by the
        # policy head.
        fc1, value = self.base_value_net(lstm_out)
        pi_out, sp_pi_out = self.base_policy_net(
            available_action_ids, lstm_out, fc1)
        return pi_out, sp_pi_out, value, lstm_state
Example #2
0
    def forward(self, available_action_ids, minimap_numeric,
                player_relative_minimap, player_relative_screen,
                screen_numeric, screen_unit_type):
        """Conv feature extractor plus policy and value heads.

        Returns ``(action_id_probs, spatial_action_probs, value)`` where the
        action probabilities are masked by ``available_action_ids`` and
        renormalized to sum to one per batch row.
        """
        # Embed categorical unit types; the embedding puts channels last,
        # so move them into NCHW before concatenating with numeric planes.
        unit_emb = self.emb(screen_unit_type).permute(0, 3, 1, 2)

        # Let's not one-hot zero which is background
        screen_rel_onehot = one_hot_encoding(
            player_relative_screen, SCREEN_FEATURES.player_relative.scale,
            (32, 32))
        minimap_rel_onehot = one_hot_encoding(
            player_relative_minimap, MINIMAP_FEATURES.player_relative.scale,
            (32, 32))

        # Channel layout: 13 numeric + 5 embedded + 4 one-hot.
        screen_input = torch.cat(
            (screen_numeric, unit_emb, screen_rel_onehot), dim=1)
        # Channel layout: 5 numeric + 4 one-hot.
        minimap_input = torch.cat(
            (minimap_numeric, minimap_rel_onehot), dim=1)

        screen_feat = self.screen_conv2(self.screen_conv1(screen_input))
        minimap_feat = self.minimap_conv2(self.minimap_conv1(minimap_input))
        map_output = torch.cat([screen_feat, minimap_feat], dim=1)

        # Spatial policy head: per-pixel logits flattened into a single
        # distribution over all screen positions.
        spatial_logits = self.spatial_action_conv(map_output)
        spatial_action_probs = F.softmax(
            spatial_logits.view(spatial_logits.size(0), -1), dim=1)

        # Non-spatial heads share one fully-connected layer.
        map_flat = map_output.view(map_output.size(0), -1)
        fc_1 = F.relu(self.fc1(map_flat))
        value = torch.squeeze(self.value_fc(fc_1))

        # Zero out unavailable actions, then renormalize so the masked
        # probabilities form a valid distribution.
        masked = F.softmax(self.action_fc(fc_1), dim=1) * available_action_ids
        action_id_probs = masked / torch.sum(masked, dim=1, keepdim=True)
        return action_id_probs, spatial_action_probs, value