コード例 #1
0
ファイル: policy.py プロジェクト: shibei00/tetris_mcts
def policy_clt(nodes, visit, value, variance):
    """Return the element of ``nodes`` with the largest CLT upper bound.

    The score for each position ``k`` is::

        value[k] + norm_quantile(visit.sum()) * sqrt(variance[k] / visit[k])

    and the node at the first position attaining the maximum is returned.

    Args:
        nodes: indexable collection of candidates, aligned with the arrays below.
        visit: per-candidate counts; summed to feed ``norm_quantile``.
        value: per-candidate mean estimates.
        variance: per-candidate variance estimates.

    NOTE(review): assumes ``visit``/``value``/``variance`` are equal-length
    NumPy arrays — TODO confirm. With NumPy, a zero in ``visit`` yields an
    inf (or nan for 0/0) score instead of raising.
    """
    total_visits = visit.sum()
    # norm_quantile is defined elsewhere; presumably a confidence multiplier
    # that depends on the total number of visits.
    multiplier = norm_quantile(total_visits)
    std_error = np.sqrt(variance / visit)
    upper_bound = value + multiplier * std_error
    best_pos = np.argmax(upper_bound)
    return nodes[best_pos]
コード例 #2
0
def policy_gauss(child_nodes, node_stats, curr_reward):
    """Pick the child maximising ``gain + norm_quantile(N) * sqrt(node_stats[c][3])``.

    For each child ``c`` the gain is
    ``node_stats[c][1] + node_stats[c][2] - curr_reward`` and ``N`` is the
    sum of ``node_stats[c][0]`` over all children. Scores are held in
    float32; the first child attaining the maximum score is returned.

    Args:
        child_nodes: candidate node ids, each used as a key into ``node_stats``.
        node_stats: per-node statistics rows (column 0 is presumably the
            visit count — TODO confirm).
        curr_reward: subtracted from every child's gain.
    """
    count = len(child_nodes)
    gain = np.zeros(count, dtype=np.float32)
    spread = np.zeros(count, dtype=np.float32)
    total = 0

    for pos, node in enumerate(child_nodes):
        row = node_stats[node]
        total += row[0]
        gain[pos] = row[1] + row[2] - curr_reward
        spread[pos] = row[3]

    score = gain + norm_quantile(total) * np.sqrt(spread)
    return child_nodes[np.argmax(score)]
コード例 #3
0
def policy_dist(child_nodes, node_stats, node_dist, curr_reward, vmin, vmax):
    """Pick the child with the largest CLT upper-confidence score.

    For each child ``c``::

        gain  = node_stats[c][1] + node_stats[c][2] - curr_reward
        var_c = node_stats[c][3] / (node_stats[c][0] + eps)
        score = gain + norm_quantile(N) * sqrt(var_c)

    where ``N`` is the sum of ``node_stats[c][0]`` over all children. Scores
    are held in float32. The first child attaining the maximum is returned.

    Args:
        child_nodes: candidate node ids, each used as a key into ``node_stats``.
        node_stats: per-node statistics rows (column 0 is presumably the
            visit count — TODO confirm).
        node_dist: unused. Kept so existing callers keep working. It
            previously fed a ``mean_variance`` call whose result was discarded.
        curr_reward: subtracted from every child's gain.
        vmin: unused; see ``node_dist``.
        vmax: unused; see ``node_dist``.
    """
    count = len(child_nodes)
    gain = np.zeros(count, dtype=np.float32)
    spread = np.zeros(count, dtype=np.float32)
    n = 0

    for i, c in enumerate(child_nodes):
        visits = node_stats[c][0]
        n += visits
        gain[i] = node_stats[c][1] + node_stats[c][2] - curr_reward
        # eps is a module-level constant defined elsewhere; presumably small
        # and positive so children with zero visits do not divide by zero.
        spread[i] = node_stats[c][3] / (visits + eps)

    q = gain + norm_quantile(n) * np.sqrt(spread)
    return child_nodes[np.argmax(q)]
コード例 #4
0
ファイル: core.py プロジェクト: ktp-forked-repos/tetris_mcts
def select_index_clt(index, child, node_stats):
    """Descend from ``index`` by a CLT confidence rule and return the path.

    At each node, the distinct nonzero entries among
    ``child[node][0..n_actions-1]`` are the candidate children; a zero entry
    means "no child". The walk:

    * stops when the current node has no candidates;
    * otherwise moves to the first candidate (in set-iteration order) with
      ``node_stats[c][0] == 0``, if any;
    * otherwise moves to the candidate maximising (in float32)::

          node_stats[c][1] + node_stats[c][2] - node_stats[node][2]
          + sqrt(node_stats[c][3] / node_stats[c][0]) * norm_quantile(N)

      where ``N`` is the sum of ``node_stats[c][0]`` over the candidates.

    Args:
        index: id of the starting node.
        child: per-node child table, read with ``child[node][a]`` for
            ``a`` in ``range(n_actions)``; ``n_actions`` is a module-level
            name defined elsewhere.
        node_stats: per-node statistics rows (column 0 is presumably the
            visit count — TODO confirm).

    Returns:
        ``np.ndarray`` of dtype int32 holding every visited node id in
        order, starting with ``index``.
    """
    path = []
    node = index

    while True:
        path.append(node)

        row = child[node]
        # Deduplicate child ids. The set is filled in the same order as a
        # list-then-set build, so iteration order (and argmax tie-breaking)
        # matches.
        candidates = list({row[a] for a in range(n_actions) if row[a] != 0})

        if not candidates:
            break

        unvisited = next(
            (c for c in candidates if node_stats[c][0] == 0), None)
        if unvisited is not None:
            node = unvisited
            continue

        parent_reward = node_stats[node][2]
        gain = np.array(
            [node_stats[c][1] + node_stats[c][2] - parent_reward
             for c in candidates],
            dtype=np.float32,
        )
        spread = np.array(
            [node_stats[c][3] / node_stats[c][0] for c in candidates],
            dtype=np.float32,
        )
        total_visits = sum(node_stats[c][0] for c in candidates)

        bonus = np.sqrt(spread) * norm_quantile(total_visits)
        node = candidates[np.argmax(gain + bonus)]

    return np.array(path, dtype=np.int32)