Example #1

import os
import random
import time

import numpy as np
import torch
import torch.nn.utils
import torch.utils.data
from visdom import Visdom

# GameModel, HearthStoneGod, HearthStoneDataset, Game, GameOver and CardClass are
# assumed to be provided by the surrounding project (the API matches the fireplace
# Hearthstone simulator) and would normally be imported here as well.


class Trainer:
    def __init__(self,
                 model_config,
                 game_config,
                 player1_model=None,
                 player1_mode='model_training',
                 player2_model=None,
                 player2_mode='random',
                 model_dir='.'):

        self.model_config = model_config
        self.game_config = game_config
        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')

        self.player1_mode = player1_mode
        self.player2_mode = player2_mode

        self.player1_model = player1_model
        self.player2_model = player2_model

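        # build a default GameModel for any player that runs in a model-driven mode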
        if player1_mode.startswith('model'):
            if player1_model is None:
                self.player1_model = GameModel(game_config)
            self.player1_model.to(self.device)

        if player2_mode.startswith('model'):
            if player2_model is None:
                self.player2_model = GameModel(game_config)
            self.player2_model.to(self.device)

        self.model_dir = model_dir
        self.vis = Visdom(env=self.model_config.model_name)

    def collect_game_data(self):
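        """Self-play round_num games; return (replay, reward) pairs and Player1's win rate."""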
        win_num = 0
        game_data = []
        start_time = time.time()

        for game_round in range(self.model_config.round_num):
            # initialize game
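            # deck: four copies of each card id (c[1]) from the card config, skipping its first entry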
            deck1 = [
                c[1] for c in self.game_config.card_config.cards_list[1:] * 4
            ]
            player1 = HearthStoneGod('Player1',
                                     deck1,
                                     CardClass.MAGE.default_hero,
                                     self.game_config,
                                     game_model=self.player1_model,
                                     mode=self.player1_mode,
                                     device=self.device)

            deck2 = [
                c[1] for c in self.game_config.card_config.cards_list[1:] * 4
            ]
            player2 = HearthStoneGod('Player2',
                                     deck2,
                                     CardClass.MAGE.default_hero,
                                     self.game_config,
                                     game_model=self.player2_model,
                                     mode=self.player2_mode,
                                     device=self.device)

            game = Game(players=[player1, player2])
            game.start()

            # play game
            # mulligan
            player1.mulligan(skip=True)
            player2.mulligan(skip=True)

            try:
                while True:
                    player = game.current_player
                    player.play_turn(game)

            except GameOver:
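                # reward: the winner's remaining hero health, negated when Player1 loses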
                if player2.hero.dead:
                    win_num += 1
                    game_data.append((player1.replay, player1.hero.health))
                else:
                    game_data.append((player1.replay, -player2.hero.health))

        end_time = time.time()
        win_rate = win_num / self.model_config.round_num

        return game_data, win_rate

    def train(self):
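        """Alternate between collecting self-play games and policy-gradient updates on Player1's model."""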
        # set seed
        random.seed(self.model_config.seed)
        torch.manual_seed(self.model_config.seed)
        np.random.seed(self.model_config.seed)

        device = self.device

        if self.model_config.optim == 'Adam':
            optimizer = torch.optim.Adam(self.player1_model.parameters(),
                                         self.model_config.learning_rate)
        elif self.model_config.optim == 'SGD':
            optimizer = torch.optim.SGD(self.player1_model.parameters(),
                                        self.model_config.learning_rate)
        else:
            raise NotImplementedError(self.model_config.optim +
                                      " is not implemented!")

        best_win_rate = 0

        win_rates = []

        for epoch in range(self.model_config.epoch):
            game_data, win_rate = self.collect_game_data()

            # plot win rates
            win_rates.append(win_rate)
            self.vis.line(win_rates,
                          np.arange(epoch + 1),
                          win='win rates',
                          opts={'title': 'win rates'})

            # checkpoint on a new best win rate, and every 10 epochs regardless
            if win_rate > best_win_rate or epoch % 10 == 0:
                self.save(epoch + 1, win_rate)
                best_win_rate = max(best_win_rate, win_rate)

            print(f'epoch {epoch} win rate:', win_rate)

            # dataloader
            dataset = HearthStoneDataset(game_data)
            dataloader = torch.utils.data.DataLoader(
                dataset, batch_size=self.model_config.batch_size, shuffle=True)

            # train
            for step, data in enumerate(dataloader):
                optimizer.zero_grad()

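                # one batch: encoded board state for both players, legal-move masks,
                # the chosen action and target, and the final game reward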
                hand, hero, current_minions, opponent, opponent_minions, \
                    action_mask, targets_mask, action, target, reward = data

                hand, hero = hand.to(device), hero.to(device)
                current_minions = current_minions.to(device)
                opponent = opponent.to(device)
                opponent_minions = opponent_minions.to(device)
                action_mask = action_mask.to(device)
                targets_mask = targets_mask.to(device)
                action, target = action.to(device), target.to(device)
                reward = reward.to(device)

                # action head: policy-gradient (REINFORCE-style) loss on the chosen action
                action_policy, action_logits, game_state = self.player1_model.get_action(
                    hand, hero, current_minions, opponent, opponent_minions,
                    action_mask)
                action_loss = -(reward * action_policy.log_prob(action)).mean()

                # target head: the same reward-weighted log-prob loss on the chosen target
                targets_policy = self.player1_model.get_target(
                    action_logits, game_state, targets_mask)
                target_loss = -(reward *
                                targets_policy.log_prob(target)).mean()

                # total loss
                loss = action_loss + target_loss

                if step % 100 == 0:
                    print(f'step {step} loss: {loss.item():.4f}')

                loss.backward()
                torch.nn.utils.clip_grad_norm_(
                    self.player1_model.parameters(),
                    self.model_config.grad_clip,
                    norm_type=self.model_config.grad_norm_type)
                optimizer.step()

    def save(self, epoch, win_rate):
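        """Checkpoint the configs, model weights and win rate under model_dir/model_name."""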
        save_data = {
            'game_config': self.game_config,
            'model_config': self.model_config,
            'state_dict': self.player1_model.state_dict(),
            'epoch': epoch,
            'win_rate': win_rate
        }

        save_dir = os.path.join(self.model_dir, self.model_config.model_name)
        os.makedirs(save_dir, exist_ok=True)

        save_path = os.path.join(save_dir, f'{epoch}_{win_rate}.pth')
        torch.save(save_data, save_path)
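

# Minimal usage sketch (not part of the original source): the config objects below
# are SimpleNamespace stand-ins for the project's own model/game config classes,
# and every value is only an illustrative guess at sensible defaults.
from types import SimpleNamespace

if __name__ == '__main__':
    model_config = SimpleNamespace(model_name='mage_selfplay', seed=42,
                                   optim='Adam', learning_rate=1e-3,
                                   batch_size=64, epoch=100, round_num=50,
                                   grad_clip=5.0, grad_norm_type=2)
    game_config = ...  # the project's game/card configuration object goes here

    trainer = Trainer(model_config,
                      game_config,
                      player1_mode='model_training',
                      player2_mode='random',
                      model_dir='./checkpoints')
    trainer.train()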