Example #1
    def search_minibatch(self, count, state_int, player,
                         net, device="cpu"):
        """
        Perform several MCTS searches.
        """
        backup_queue = []
        expand_states = []
        expand_players = []
        expand_queue = []
        planned = set()
        for _ in range(count):
            value, leaf_state, leaf_player, states, actions = \
                self.find_leaf(state_int, player)
            if value is not None:
                backup_queue.append((value, states, actions))
            else:
                if leaf_state not in planned:
                    planned.add(leaf_state)
                    leaf_state_lists = game.decode_binary(
                        leaf_state)
                    expand_states.append(leaf_state_lists)
                    expand_players.append(leaf_player)
                    expand_queue.append((leaf_state, states,
                                         actions))

        # do expansion of nodes
        if expand_queue:
            batch_v = model.state_lists_to_batch(
                expand_states, expand_players, device)
            logits_v, values_v = net(batch_v)
            probs_v = F.softmax(logits_v, dim=1)
            values = values_v.data.cpu().numpy()[:, 0]
            probs = probs_v.data.cpu().numpy()

            # create the nodes
            for (leaf_state, states, actions), value, prob in \
                    zip(expand_queue, values, probs):
                self.visit_count[leaf_state] = [0]*game.GAME_COLS
                self.value[leaf_state] = [0.0]*game.GAME_COLS
                self.value_avg[leaf_state] = [0.0]*game.GAME_COLS
                self.probs[leaf_state] = prob
                backup_queue.append((value, states, actions))

        # perform backup of the searches
        for value, states, actions in backup_queue:
            # the leaf itself is not stored in states/actions; its value
            # is from the leaf player's perspective, so it is negated for
            # the parent (the opponent) and keeps alternating up the path
            cur_value = -value
            for state_int, action in zip(states[::-1],
                                         actions[::-1]):
                self.visit_count[state_int][action] += 1
                self.value[state_int][action] += cur_value
                self.value_avg[state_int][action] = \
                    self.value[state_int][action] / \
                    self.visit_count[state_int][action]
                cur_value = -cur_value
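
Once enough of these searches have been run, the per-action visit counts for the root state are normally turned into a move distribution. The sketch below shows that conversion under the usual AlphaZero convention (temperature tau, greedy at tau=0); the helper name visit_counts_to_policy and the mcts parameter are illustrative stand-ins for a method on the class above, and the root is assumed to have at least one visit.

import numpy as np

def visit_counts_to_policy(mcts, state_int, tau=1.0):
    # Hypothetical helper, not part of the snippet above: convert the visit
    # counts accumulated by search_minibatch() into a move distribution using
    # a temperature parameter tau.
    counts = np.array(mcts.visit_count[state_int], dtype=np.float64)
    if tau == 0:
        # greedy: put all probability mass on the most-visited move
        probs = np.zeros_like(counts)
        probs[np.argmax(counts)] = 1.0
        return probs
    counts = counts ** (1.0 / tau)
    return counts / counts.sum()
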
Example #2
 def test_encoding(self):
     s = [[0, 1, 0], [0], [1, 1, 1], [], [1], [], []]
     batch_v = model.state_lists_to_batch([s, s], [game.PLAYER_BLACK, game.PLAYER_WHITE])
     batch = batch_v.data.numpy()
     np.testing.assert_equal(
         batch,
         [
             # black player's view
             [
                 # player
                 [
                     [0, 0, 0, 0, 0, 0, 0],
                     [0, 0, 0, 0, 0, 0, 0],
                     [0, 0, 0, 0, 0, 0, 0],
                     [0, 0, 1, 0, 0, 0, 0],
                     [1, 0, 1, 0, 0, 0, 0],
                     [0, 0, 1, 0, 1, 0, 0],
                 ],
                 # opponent
                 [
                     [0, 0, 0, 0, 0, 0, 0],
                     [0, 0, 0, 0, 0, 0, 0],
                     [0, 0, 0, 0, 0, 0, 0],
                     [1, 0, 0, 0, 0, 0, 0],
                     [0, 0, 0, 0, 0, 0, 0],
                     [1, 1, 0, 0, 0, 0, 0],
                 ],
             ],
             # white player's view
             [
                 # player
                 [
                     [0, 0, 0, 0, 0, 0, 0],
                     [0, 0, 0, 0, 0, 0, 0],
                     [0, 0, 0, 0, 0, 0, 0],
                     [1, 0, 0, 0, 0, 0, 0],
                     [0, 0, 0, 0, 0, 0, 0],
                     [1, 1, 0, 0, 0, 0, 0],
                 ],
                 # opponent
                 [
                     [0, 0, 0, 0, 0, 0, 0],
                     [0, 0, 0, 0, 0, 0, 0],
                     [0, 0, 0, 0, 0, 0, 0],
                     [0, 0, 1, 0, 0, 0, 0],
                     [1, 0, 1, 0, 0, 0, 0],
                     [0, 0, 1, 0, 1, 0, 0],
                 ],
             ],
         ],
     )
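
For reference, the encoding this test asserts can be produced along the lines of the sketch below: each sample is two 6x7 planes, with the moving player's pieces in plane 0 and the opponent's in plane 1. This is a reconstruction from the assertions above, not the original model.state_lists_to_batch, and it assumes each column list stores pieces bottom-up (index 0 is the lowest cell).

import numpy as np
import torch

def state_lists_to_batch_sketch(state_lists, who_moves, device="cpu"):
    # Sketch of the encoding checked by test_encoding(): plane 0 marks the
    # pieces of the player to move, plane 1 the opponent's pieces.
    rows, cols = 6, 7  # game.GAME_ROWS, game.GAME_COLS in the original code
    batch = np.zeros((len(state_lists), 2, rows, cols), dtype=np.float32)
    for idx, (state, who_move) in enumerate(zip(state_lists, who_moves)):
        for col_idx, col in enumerate(state):
            for height, cell in enumerate(col):  # assumed: index 0 = bottom cell
                row_idx = rows - 1 - height
                plane = 0 if cell == who_move else 1
                batch[idx, plane, row_idx, col_idx] = 1.0
    return torch.tensor(batch).to(device)
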
Example #3
    def search_minibatch(self,
                         count,
                         state_int,
                         player,
                         net,
                         step,
                         device="cpu"):
        backup_queue = []
        expand_states = []
        expand_steps = []
        expand_queue = []
        planned = set()
        for _ in range(count):
            value, leaf_state, leaf_step, states, actions = self.find_leaf(
                state_int, player, step)
            if value is not None:
                backup_queue.append((value, states, actions))
            else:
                if leaf_state not in planned:
                    planned.add(leaf_state)
                    leaf_state_lists = game.decode_binary(leaf_state)
                    expand_states.append(leaf_state_lists)
                    expand_steps.append(leaf_step)
                    expand_queue.append((leaf_state, states, actions))

        # expand the new leaf states with a single batched network evaluation
        if expand_queue:
            batch_v = model.state_lists_to_batch(expand_states, expand_steps,
                                                 device)
            logits_v, values_v = net(batch_v)
            probs_v = F.softmax(logits_v, dim=1)
            values = values_v.data.cpu().numpy()[:, 0]
            probs = probs_v.data.cpu().numpy()

            for (leaf_state, states,
                 actions), value, prob in zip(expand_queue, values, probs):
                self.visit_count[leaf_state] = [0] * actionTable.AllMoveLength
                self.value[leaf_state] = [0.0] * actionTable.AllMoveLength
                self.value_avg[leaf_state] = [0.0] * actionTable.AllMoveLength
                self.probs[leaf_state] = prob
                backup_queue.append((value, states, actions))

        # back up the search results along each stored root-to-leaf path
        for value, states, actions in backup_queue:
            cur_value = -value
            for state_int, action in zip(states[::-1], actions[::-1]):
                self.visit_count[state_int][action] += 1
                self.value[state_int][action] += cur_value
                self.value_avg[state_int][action] =\
                    self.value[state_int][action] / self.visit_count[state_int][action]
                cur_value = -cur_value
Example #4
 def test_encoding(self):
     s = [[0, 1, 0], [0], [1, 1, 1], [], [1], [], []]
     batch_v = model.state_lists_to_batch([s, s], [game.PLAYER_BLACK, game.PLAYER_WHITE])
     batch = batch_v.data.numpy()
     np.testing.assert_equal(batch, [
         # black player's view
         [
             # player
             [
                 [0, 0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0, 0],
                 [0, 0, 1, 0, 0, 0, 0],
                 [1, 0, 1, 0, 0, 0, 0],
                 [0, 0, 1, 0, 1, 0, 0],
             ],
             # opponent
             [
                 [0, 0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0, 0],
                 [1, 0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0, 0],
             ]
         ],
         # white player's view
         [
             # player
             [
                 [0, 0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0, 0],
                 [1, 0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0, 0],
             ],
             # opponent
             [
                 [0, 0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0, 0],
                 [0, 0, 1, 0, 0, 0, 0],
                 [1, 0, 1, 0, 0, 0, 0],
                 [0, 0, 1, 0, 1, 0, 0],
             ]
         ],
     ])
Example #5
    def search_minibatch(self, count, state_int, player, net, device="cpu"):
        """
        Perform several MCTS searches.
        """
        backup_queue = []
        expand_states = []
        expand_players = []
        expand_queue = []
        planned = set()
        for _ in range(count):
            value, leaf_state, leaf_player, states, actions = self.find_leaf(state_int, player)
            if value is not None:
                backup_queue.append((value, states, actions))
            else:
                if leaf_state not in planned:
                    planned.add(leaf_state)
                    leaf_state_lists = game.decode_binary(leaf_state)
                    expand_states.append(leaf_state_lists)
                    expand_players.append(leaf_player)
                    expand_queue.append((leaf_state, states, actions))

        # do expansion of nodes
        if expand_queue:
            batch_v = model.state_lists_to_batch(expand_states, expand_players, device)
            logits_v, values_v = net(batch_v)
            probs_v = F.softmax(logits_v, dim=1)
            values = values_v.data.cpu().numpy()[:, 0]
            probs = probs_v.data.cpu().numpy()

            # create the nodes
            for (leaf_state, states, actions), value, prob in zip(expand_queue, values, probs):
                self.visit_count[leaf_state] = [0] * game.GAME_COLS
                self.value[leaf_state] = [0.0] * game.GAME_COLS
                self.value_avg[leaf_state] = [0.0] * game.GAME_COLS
                self.probs[leaf_state] = prob
                backup_queue.append((value, states, actions))

        # perform backup of the searches
        for value, states, actions in backup_queue:
            # the leaf itself is not stored in states/actions; its value
            # is from the leaf player's perspective, so it is negated for
            # the parent (the opponent) and keeps alternating up the path
            cur_value = -value
            for state_int, action in zip(states[::-1], actions[::-1]):
                self.visit_count[state_int][action] += 1
                self.value[state_int][action] += cur_value
                self.value_avg[state_int][action] = self.value[state_int][action] / self.visit_count[state_int][action]
                cur_value = -cur_value
Example #6
                step_idx, game_steps, game_nodes, speed_steps, speed_nodes, best_idx, len(replay_buffer)))
            step_idx += 1

            if len(replay_buffer) < MIN_REPLAY_TO_TRAIN:
                continue

            # train
            sum_loss = 0.0
            sum_value_loss = 0.0
            sum_policy_loss = 0.0

            for _ in range(TRAIN_ROUNDS):
                batch = random.sample(replay_buffer, BATCH_SIZE)
                batch_states, batch_who_moves, batch_probs, batch_values = zip(*batch)
                batch_states_lists = [game.decode_binary(state) for state in batch_states]
                states_v = model.state_lists_to_batch(batch_states_lists, batch_who_moves, device)

                optimizer.zero_grad()
                probs_v = torch.FloatTensor(batch_probs).to(device)
                values_v = torch.FloatTensor(batch_values).to(device)
                out_logits_v, out_values_v = net(states_v)

                loss_value_v = F.mse_loss(out_values_v.squeeze(-1), values_v)
                loss_policy_v = -F.log_softmax(out_logits_v, dim=1) * probs_v
                loss_policy_v = loss_policy_v.sum(dim=1).mean()

                loss_v = loss_policy_v + loss_value_v
                loss_v.backward()
                optimizer.step()
                sum_loss += loss_v.item()
                sum_value_loss += loss_value_v.item()
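
The policy term above is a cross-entropy between the network's move distribution and the MCTS-derived probabilities. A small self-contained check of that equivalence, with made-up numbers, looks like this:

import torch
import torch.nn.functional as F

# Toy sanity check (not part of the training loop above): the
# -log_softmax(logits) * probs formulation equals an explicit soft-target
# cross-entropy against the MCTS probabilities.
logits = torch.tensor([[2.0, 0.5, -1.0]])  # pretend network output for 3 moves
target = torch.tensor([[0.7, 0.2, 0.1]])   # pretend MCTS visit-count probabilities

loss_a = (-F.log_softmax(logits, dim=1) * target).sum(dim=1).mean()
loss_b = -(target * F.softmax(logits, dim=1).log()).sum(dim=1).mean()
assert torch.allclose(loss_a, loss_b)
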
Example #7
        print(step_idx, end=' ')
        step_idx += 1
        if len(replay_buffer) < MIN_REPLAY_TO_TRAIN:
            continue

        # report the wall-clock time spent on the previous iteration
        ctime = time.time()
        print("%.2f " % (ctime - ptime), end=' ')
        if step_idx % 5 < 1:
            print()
        ptime = ctime

        # sample TRAIN_ROUNDS minibatches from the replay buffer and train
        for _ in range(TRAIN_ROUNDS):
            batch = random.sample(replay_buffer, BATCH_SIZE)
            batch_states, batch_steps, batch_probs, batch_values = zip(*batch)
            batch_states_lists = [game.decode_binary(state) for state in batch_states]
            states_v = model.state_lists_to_batch(batch_states_lists, batch_steps, device)

            optimizer.zero_grad()
            probs_v = torch.FloatTensor(batch_probs).to(device)
            values_v = torch.FloatTensor(batch_values).to(device)
            out_logits_v, out_values_v = net(states_v)
            #traced_script_module = torch.jit.trace(net, states_v)
            #file_name = os.path.join(saves_path, "best_%d.pt" % (best_idx))
            #traced_script_module.save(file_name); sys.exit()
            loss_value_v = F.mse_loss(out_values_v.squeeze(-1), values_v)
            loss_policy_v = -F.log_softmax(out_logits_v, dim=1) * probs_v
            loss_policy_v = loss_policy_v.sum(dim=1).mean()

            loss_v = loss_policy_v + loss_value_v
            loss_v.backward()
            optimizer.step()
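
The commented-out lines above hint at exporting the network with TorchScript via torch.jit.trace. If that route is taken, the saved file can later be reloaded for inference without the Python model class; a minimal sketch (the path argument is a placeholder) is:

import torch

def load_traced(path, device="cpu"):
    # Sketch of the consumer side of a torch.jit.trace export like the one
    # commented out above; `path` would be the file written by .save().
    scripted = torch.jit.load(path, map_location=device)
    scripted.eval()
    return scripted
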
Example #8
                continue

            # train
            sum_loss = 0.0
            sum_value_loss = 0.0
            sum_policy_loss = 0.0

            for _ in range(TRAIN_ROUNDS):
                batch = random.sample(replay_buffer, BATCH_SIZE)
                batch_states, batch_who_moves, batch_probs, batch_values = zip(
                    *batch)
                batch_states_lists = [
                    game.decode_binary(state) for state in batch_states
                ]
                states_v = model.state_lists_to_batch(batch_states_lists,
                                                      batch_who_moves,
                                                      args.cuda)

                optimizer.zero_grad()
                probs_v = Variable(torch.FloatTensor(batch_probs))
                values_v = Variable(torch.FloatTensor(batch_values))
                if args.cuda:
                    probs_v = probs_v.cuda()
                    values_v = values_v.cuda()
                out_logits_v, out_values_v = net(states_v)

                loss_value_v = F.mse_loss(out_values_v.squeeze(-1), values_v)
                loss_policy_v = -F.log_softmax(out_logits_v, dim=1) * probs_v
                loss_policy_v = loss_policy_v.sum(dim=1).mean()

                loss_v = loss_policy_v + loss_value_v
Example #9
def training(tb_tracker,
             net,
             optimizer,
             scheduler,
             replay_buffer,
             probs_queue,
             saves_path,
             step,
             device=torch.device("cpu")):
    tmp_net = net.to(device)

    #    while len(replay_buffer) < MIN_REPLAY_TO_TRAIN:
    #        time.sleep(10)
    #        while not replay_queue.empty():
    #            replay_buffer.append((*replay_queue.get(), probs_queue.get()))
    #        print("replay buffer size: {}, {:.0f} MB".format(
    #                                                    *replay_buffer_size(
    #                                                            replay_buffer)))
    for i in range(TRAIN_STEPS):
        step_idx = TRAIN_STEPS * step + i + 1
        #        while not replay_queue.empty():
        #            if len(replay_buffer) == REPLAY_BUFFER:
        #                replay_buffer.popleft()
        #            replay_buffer.append((*replay_queue.get(), probs_queue.get()))
        #        print("replay buffer size: {}, {:.0f} MB".format(
        #                                                    *replay_buffer_size(
        #                                                            replay_buffer)))
        # train
        sum_loss = 0.0
        sum_value_loss = 0.0
        sum_policy_loss = 0.0
        t_train = time.time()

        for _ in range(TRAIN_ROUNDS):
            batch = random.sample(replay_buffer, BATCH_SIZE)
            batch_states, _, batch_values, batch_probs = zip(*batch)
            batch_states_lists = [
                game_c.decode_binary(state) for state in batch_states
            ]
            states_v = model.state_lists_to_batch(batch_states_lists, device)

            optimizer.zero_grad()
            probs_v = torch.FloatTensor(batch_probs).to(device)
            values_v = torch.FloatTensor(batch_values).to(device)
            out_logits_v, out_values_v = tmp_net(states_v)

            del batch
            del batch_states, batch_probs, batch_values
            del batch_states_lists
            del states_v

            loss_value_v = F.mse_loss(out_values_v.squeeze(-1), values_v)
            # cross entropy loss
            loss_policy_v = -F.log_softmax(out_logits_v, dim=1) * probs_v
            loss_policy_v = loss_policy_v.sum(dim=1).mean()

            loss_v = loss_policy_v + loss_value_v
            loss_v.backward()
            optimizer.step()
            sum_loss += loss_v.item()
            sum_value_loss += loss_value_v.item()
            sum_policy_loss += loss_policy_v.item()

            del probs_v, values_v, out_logits_v, out_values_v
            del loss_value_v, loss_policy_v, loss_v


#        scheduler.step(sum_loss / TRAIN_ROUNDS, step_idx)
        scheduler.step()
        tb_tracker.track("loss_total", sum_loss / TRAIN_ROUNDS, step_idx)
        tb_tracker.track("loss_value", sum_value_loss / TRAIN_ROUNDS, step_idx)
        tb_tracker.track("loss_policy", sum_policy_loss / TRAIN_ROUNDS,
                         step_idx)

        print("Training step #{}: {:.2f} [s]".format(step_idx,
                                                     time.time() - t_train))
        t_train = time.time()

    # save net
    file_name = os.path.join(saves_path, "%06d.dat" % (step_idx))
    print("Model is saved as {}".format(file_name))
    torch.save(net.state_dict(), file_name)
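
The file written here is a plain state_dict, so it can be restored later without the training code, as long as the same network architecture is rebuilt first. A minimal sketch, where make_net is a stand-in for constructing the same model.Net used during training:

import torch

def load_checkpoint(path, make_net, device="cpu"):
    # Sketch only: `make_net` is a placeholder for building the same
    # architecture (e.g. model.Net with the training-time arguments).
    net = make_net()
    net.load_state_dict(torch.load(path, map_location=device))
    net.to(device)
    net.eval()
    return net
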
Example #10
    def search_minibatch(self,
                         count,
                         state_int,
                         player,
                         net,
                         root_mask,
                         device="cpu"):
        """
        Perform several MCTS searches.
        """
        backup_queue = []
        expand_queue = []
        planned = set()
        for i in range(count):
            value, leaf_state, leaf_player, states, actions = self.find_leaf(
                state_int, player, root_mask)
            self.subtrees.append(states)

            # end of the game
            if value is not None:
                backup_queue.append((value, states, actions))
            # encounter leaf node which is not end of the game
            else:
                # avoid duplication of leaf state
                if leaf_state not in planned:
                    planned.add(leaf_state)
                    expand_queue.append((leaf_state, states, actions))
                else:
                    states.clear()
                    self.subtrees.pop()
        del planned

        # do expansion of nodes
        if expand_queue:
            expand_states = []
            keys = self.visited_net_results.keys()
            new_expand_queue = []
            existed_expand_queue = []
            value_list = []
            prob_list = []
            rotate_list = []
            new_rotate_list = []
            for leaf_state, states, actions in expand_queue:
                rotate_num = np.random.randint(8)
                if (leaf_state, rotate_num) in keys:
                    existed_expand_queue.append((leaf_state, states, actions))
                    rotate_list.append(rotate_num)
                    value, prob = self.visited_net_results[(leaf_state,
                                                            rotate_num)]
                    value_list.append(value)
                    prob_list.append(prob)
                else:
                    new_expand_queue.append((leaf_state, states, actions))
                    new_rotate_list.append(rotate_num)
                    leaf_state_lists = game_c.decode_binary(leaf_state)
                    expand_states.append(leaf_state_lists)
            expand_queue = [*existed_expand_queue, *new_expand_queue]
            rotate_list.extend(new_rotate_list)

            if len(new_rotate_list) == 0:
                values = value_list
                probs = prob_list
            else:
                batch_v = model.state_lists_to_batch(expand_states, device,
                                                     new_rotate_list)
                logits_v, values_v = net(batch_v)
                probs_v = F.softmax(logits_v, dim=1)
                values = values_v.data.cpu().numpy()[:, 0]
                probs = probs_v.data.cpu().numpy()

                values = [*value_list, *list(values)]
                probs = [*prob_list, *list(probs)]

            expand_states.clear()

            # create the nodes
            for (leaf_state, states, actions), value, prob, rotate_num in zip(
                    expand_queue, values, probs, rotate_list):
                #            for (leaf_state, states, actions), value, prob in zip(expand_queue, values, probs):
                self.visit_count[leaf_state] = np.zeros(game.BOARD_SIZE**2 + 1,
                                                        dtype=np.int32)
                self.value[leaf_state] = np.zeros(game.BOARD_SIZE**2 + 1,
                                                  dtype=np.float32)
                self.value_avg[leaf_state] = np.zeros(game.BOARD_SIZE**2 + 1,
                                                      dtype=np.float32)
                prob_without_pass = prob[:-1].reshape(
                    [game.BOARD_SIZE, game.BOARD_SIZE])
                prob_without_pass = game.multiple_transform(
                    prob_without_pass, rotate_num, True)
                self.probs[leaf_state] = np.concatenate(
                    [prob_without_pass.flatten(), [prob[-1]]])
                #                self.probs[leaf_state] = prob
                backup_queue.append((value, states, actions))
                self.visited_net_results[(leaf_state, rotate_num)] = (value,
                                                                      prob)
            rotate_list.clear()

        expand_queue.clear()

        # perform backup of the searches
        for value, states, actions in backup_queue:
            # the leaf itself is not stored in states/actions; its value
            # is from the leaf player's perspective, so it is negated for
            # the parent (the opponent) and keeps alternating up the path
            cur_value = -value
            for state_int, action in zip(states[::-1], actions[::-1]):
                self.visit_count[state_int][action] += 1
                self.value[state_int][action] += cur_value
                self.value_avg[state_int][action] = self.value[state_int][
                    action] / self.visit_count[state_int][action]
                cur_value = -cur_value
            actions.clear()
        backup_queue.clear()
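
game.multiple_transform is not shown in this snippet; indexing the eight symmetries of a square board (four rotations, optionally combined with a horizontal flip) is commonly done with np.rot90 and np.fliplr, as in the sketch below. The numbering scheme is an assumption for illustration; the inverse form is what the code above presumably applies to the network's probability plane so that results cached under (leaf_state, rotate_num) map back to the unrotated board.

import numpy as np

def dihedral_transform(plane, num, inverse=False):
    # Hypothetical stand-in for game.multiple_transform(): apply one of the
    # 8 symmetries of a square array. num in [0, 7]: num % 4 selects the
    # rotation, num >= 4 adds a left-right flip. inverse=True undoes it.
    rot, flip = num % 4, num >= 4
    if not inverse:
        out = np.rot90(plane, rot)
        return np.fliplr(out) if flip else out
    out = np.fliplr(plane) if flip else plane
    return np.rot90(out, -rot)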