コード例 #1
0
ファイル: mcts.py プロジェクト: halflearned/organ-matching-rl
def rollout(env, t_begin, t_end, taken, gamma=0.97):
    """Score taking ``taken`` now against leaving it, on one sampled future.

    The environment is snapshotted at ``t_begin`` and the snapshot is
    populated for ``t_begin + 1`` .. ``t_end`` with a seed from
    ``clock_seed()``. ``compare_optimal`` then returns two solutions over that
    horizon, given ``set(taken)``: one for the "take" branch and one for the
    "leave" branch. Each branch's per-period match counts (from
    ``get_n_matched``) are discounted with ``disc_mean``. In the take branch
    the first count is overwritten with ``len(taken)``.

    Args:
        env: Environment to snapshot.
        t_begin: Decision period.
        t_end: End of the simulated horizon, as interpreted by ``populate``
            and ``get_n_matched``.
        taken: Collection of items chosen at ``t_begin``. Its length is
            credited as the take branch's first-period match count.
        gamma: Discount factor forwarded to ``disc_mean``.

    Returns:
        ``value_take / value_leave``.

        NOTE(review): nothing guards against a zero leave value. Whether that
        raises or yields inf/nan depends on the type ``disc_mean`` returns;
        confirm.
    """
    horizon_start = t_begin + 1

    sim = snapshot(env, t_begin)
    sim.populate(horizon_start, t_end, seed=clock_seed())

    take_sol, leave_sol = compare_optimal(sim, horizon_start, t_end, set(taken))

    take_counts = get_n_matched(take_sol["matched"], t_begin, t_end)
    leave_counts = get_n_matched(leave_sol["matched"], t_begin, t_end)
    # The solver starts at t_begin + 1, so the decision made at t_begin
    # itself is credited by hand in the take branch.
    take_counts[0] = len(taken)

    leave_score = disc_mean(leave_counts, gamma)
    take_score = disc_mean(take_counts, gamma)
    return take_score / leave_score
コード例 #2
0
def rollout(env, t_begin, t_end, taken, gamma=0.97):
    """Return the discounted greedy-policy value of one sampled future.

    A snapshot of ``env`` at ``t_begin`` is populated for ``t_begin + 1`` ..
    ``t_end`` with a seed from ``clock_seed()``. ``taken`` is added to the
    snapshot's ``removed_container`` entry for ``t_begin``, and ``greedy`` is
    run over ``t_begin + 1`` .. ``t_end``. The resulting match counts are
    discounted with ``disc_mean``, after their first entry is replaced by
    ``len(taken)``.

    Args:
        env: Environment to snapshot.
        t_begin: Decision period.
        t_end: End of the simulated horizon, as interpreted by ``populate``,
            ``greedy`` and ``get_n_matched``.
        taken: Collection of items chosen at ``t_begin``.
        gamma: Discount factor forwarded to ``disc_mean``.

    Returns:
        Whatever ``disc_mean`` returns for the adjusted counts.
    """
    horizon_start = t_begin + 1

    sim = snapshot(env, t_begin)
    sim.populate(horizon_start, t_end, seed=clock_seed())
    # Record the chosen items as removed at the decision period. Presumably
    # this keeps the solver from using them again; confirm against snapshot().
    sim.removed_container[t_begin].update(taken)

    result = greedy(sim, horizon_start, t_end)
    counts = get_n_matched(result["matched"], t_begin, t_end)
    counts[0] = len(taken)  # credit the decision made at t_begin itself
    return disc_mean(counts, gamma)
コード例 #3
0
def rollout(env, t_begin, t_end, taken, gamma=0.97):
    """Return the discounted optimal-policy value of one sampled future.

    A snapshot of ``env`` at ``t_begin`` is populated for ``t_begin + 1`` ..
    ``t_end`` with a seed from ``clock_seed()``. ``taken`` is added to the
    snapshot's ``removed_container`` entry for ``t_begin``, and ``optimal`` is
    solved over ``t_begin + 1`` .. ``t_end``. The resulting match counts are
    discounted with ``disc_mean``, after their first entry is replaced by
    ``len(taken)``.

    Args:
        env: Environment to snapshot.
        t_begin: Decision period.
        t_end: End of the simulated horizon, as interpreted by ``populate``,
            ``optimal`` and ``get_n_matched``.
        taken: Collection of items chosen at ``t_begin``.
        gamma: Discount factor forwarded to ``disc_mean``. Defaults to 0.97,
            matching the other ``rollout`` variants. Existing callers that
            pass it explicitly are unaffected.

    Returns:
        Whatever ``disc_mean`` returns for the adjusted counts.
    """
    horizon_start = t_begin + 1

    snap = snapshot(env, t_begin)
    snap.populate(horizon_start, t_end, seed=clock_seed())
    # Record the chosen items as removed at the decision period. Presumably
    # this keeps the solver from using them again; confirm against snapshot().
    snap.removed_container[t_begin].update(taken)

    opt = optimal(snap, horizon_start, t_end)
    opt_matched = get_n_matched(opt["matched"], t_begin, t_end)
    # The solver starts at t_begin + 1, so the decision made at t_begin
    # itself is credited by hand.
    opt_matched[0] = len(taken)

    # NOTE: an earlier revision returned (optimal - greedy) as an advantage;
    # that baseline was commented out and has been removed.
    return disc_mean(opt_matched, gamma)
コード例 #4
0
    # Script fragment: load a saved policy network and step it through an
    # ABO kidney-exchange environment. The enclosing block's header is not
    # visible here. entry_rate, death_rate, time_length, np, torch and the
    # solver/helper functions are defined elsewhere.

    # NOTE(review): `disc` is not used anywhere in the visible fragment.
    disc = 0.1

    # torch.load unpickles the checkpoint, so only load trusted files.
    # NOTE(review): checkpoint path is hard-coded.
    net = torch.load("results/RNN_50-1-abo_4386504")

    #%%

    # Only k == 2 is evaluated; k doubles as the environment seed.
    for k in [2]:

        print("Creating environment")
        env = ABOKidneyExchange(entry_rate, death_rate, time_length, seed=k)

        # Benchmark solutions for the whole episode from the `optimal` and
        # `greedy` solvers.
        print("Solving environment")
        opt = optimal(env)
        gre = greedy(env)

        # Match counts from period 0 to env.time_length for each benchmark,
        # as computed by get_n_matched.
        o = get_n_matched(opt["matched"], 0, env.time_length)
        g = get_n_matched(gre["matched"], 0, env.time_length)

        # Accumulators for the policy rollout. Not used in the visible part
        # of this fragment; presumably filled later.
        rewards = []
        actions = []
        # NOTE(review): overwritten by the loop below. It only survives
        # (as -1) when env.time_length == 0.
        t = -1
        print("Beginning")
        #%%
        for t in range(env.time_length):

            # Skip periods where get_living returns exactly one element;
            # presumably nothing can be matched then.
            # NOTE(review): an empty `living` is not skipped.
            living = np.array(env.get_living(t))
            if len(living) == 1:
                continue

            # Policy outputs for period t, plus the cycles available among
            # the living elements. The first get_cycles result is discarded.
            probs, counts = evaluate_policy(net, env, t)
            _, cycles = get_cycles(env, living)
コード例 #5
0
            seed = 123456  #clock_seed()

        print("Opening file", file)
        try:
            net = torch.load("results/" + file)
        except Exception as e:
            print(str(e))
            continue

        env_type = "abo"
        env = ABOKidneyExchange(entry_rate, death_rate, max_time, seed=seed)

        opt = optimal(env)
        gre = greedy(env)

        o = get_n_matched(opt["matched"], 0, env.time_length)
        g = get_n_matched(gre["matched"], 0, env.time_length)

        rewards = np.zeros(env.time_length)

        #%%
        np.random.seed(clock_seed())
        for t in range(env.time_length):
            probs, count = evaluate_policy(net, env, t, dtype="numpy")

            for i in range(count):
                probs, _ = evaluate_policy(net, env, t, dtype="numpy")
                cycles = two_cycles(env, t)
                if len(cycles) == 0:
                    break
                elif len(cycles) == 1: