def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    mean_payoffs = get_example_2x2_payoffs()
    game = response_graph_ucb_utils.BernoulliGameSampler(
        [2, 2], mean_payoffs, payoff_bounds=[-1., 1.])
    game.p_max = mean_payoffs
    game.means = mean_payoffs
    print('Game means:\n', game.means)

    exploration_strategy = 'uniform-exhaustive'
    confidence_method = 'ucb-standard'
    r_ucb = response_graph_ucb.ResponseGraphUCB(
        game,
        exploration_strategy=exploration_strategy,
        confidence_method=confidence_method,
        delta=0.1)
    results = r_ucb.run()

    # Plotting
    print('Number of total samples: {}'.format(np.sum(r_ucb.count[0])))
    r_ucb.visualise_2x2x2(real_values=game.means, graph=results['graph'])
    r_ucb.visualise_count_history(figsize=(5, 3))
    plt.gca().xaxis.label.set_fontsize(15)
    plt.gca().yaxis.label.set_fontsize(15)

    # Compare to ground truth graph
    real_graph = r_ucb.construct_real_graph()
    r_ucb.plot_graph(real_graph)
    plt.show()
Beispiel #2
0
    def test_sampler(self):
        mean_payoffs = self.get_example_2x2_payoffs()
        game = response_graph_ucb_utils.BernoulliGameSampler(
            [2, 2], mean_payoffs, payoff_bounds=[-1., 1.])
        game.p_max = mean_payoffs
        game.means = mean_payoffs

        # Parameters to run
        sampling_methods = [
            'uniform-exhaustive', 'uniform', 'valence-weighted',
            'count-weighted'
        ]
        conf_methods = [
            'ucb-standard', 'ucb-standard-relaxed', 'clopper-pearson-ucb',
            'clopper-pearson-ucb-relaxed'
        ]
        per_payoff_confidence = [True, False]
        time_dependent_delta = [True, False]

        methods = list(
            itertools.product(sampling_methods, conf_methods,
                              per_payoff_confidence, time_dependent_delta))
        max_total_interactions = 50

        for m in methods:
            r_ucb = response_graph_ucb.ResponseGraphUCB(
                game,
                exploration_strategy=m[0],
                confidence_method=m[1],
                delta=0.1,
                ucb_eps=1e-1,
                per_payoff_confidence=m[2],
                time_dependent_delta=m[3])
            _ = r_ucb.run(max_total_iterations=max_total_interactions)
Beispiel #3
0
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    # Parameters to run
    deltas = [0.01, 0.025, 0.05, 0.1, 0.25, 0.5]
    sampling_methods = [
        'uniform-exhaustive', 'uniform', 'valence-weighted', 'count-weighted'
    ]
    conf_methods = [
        'ucb-standard', 'ucb-standard-relaxed', 'clopper-pearson-ucb',
        'clopper-pearson-ucb-relaxed'
    ]

    methods = list(itertools.product(sampling_methods, conf_methods))
    mean_counts = {m: [[] for _ in range(len(deltas))] for m in methods}
    edge_errs = {m: [[] for _ in range(len(deltas))] for m in methods}

    if FLAGS.game_name == 'bernoulli':
        max_total_interactions = 50000
        repetitions = 20
    elif FLAGS.game_name == 'soccer':
        max_total_interactions = 100000
        repetitions = 5
    elif FLAGS.game_name == 'kuhn_poker_3p':
        max_total_interactions = 100000
        repetitions = 5
    else:
        raise ValueError(
            'game_name must be "bernoulli", "soccer", or "kuhn_poker_3p".')

    for r in range(repetitions):
        print('Iteration {}'.format(r + 1))
        G = utils.get_game_for_sampler(FLAGS.game_name)  # pylint: disable=invalid-name

        for m in methods:
            print('  Method: {}'.format(m))
            for ix, d in enumerate(deltas):
                print('    Delta: {}'.format(d))
                r_ucb = response_graph_ucb.ResponseGraphUCB(
                    G,
                    exploration_strategy=m[0],
                    confidence_method=m[1],
                    delta=d,
                    ucb_eps=1e-1)
                results = r_ucb.run(
                    max_total_iterations=max_total_interactions)

                # Updated
                mean_counts[m][ix].append(results['interactions'])
                real_graph = r_ucb.construct_real_graph()
                edge_errs[m][ix].append(
                    utils.digraph_edge_hamming_dist(real_graph,
                                                    results['graph']))

    # Plotting
    _, axes = plt.subplots(1, 2, figsize=(10, 4))
    max_mean_count = 0
    for m in methods:
        utils.plot_timeseries(axes,
                              id_ax=0,
                              data=np.asarray(mean_counts[m]).T,
                              xticks=deltas,
                              xlabel=r'$\delta$',
                              ylabel='Interactions required',
                              label=utils.get_method_tuple_acronym(m),
                              logx=True,
                              logy=True,
                              linespecs=utils.get_method_tuple_linespecs(m))
        if np.max(mean_counts[m]) > max_mean_count:
            max_mean_count = np.max(mean_counts[m])
    plt.xlim(left=np.min(deltas), right=np.max(deltas))
    plt.ylim(top=max_mean_count * 1.05)

    max_error = 0
    for m in methods:
        utils.plot_timeseries(axes,
                              id_ax=1,
                              data=np.asarray(edge_errs[m]).T,
                              xticks=deltas,
                              xlabel=r'$\delta$',
                              ylabel='Response graph errors',
                              label=utils.get_method_tuple_acronym(m),
                              logx=True,
                              logy=False,
                              linespecs=utils.get_method_tuple_linespecs(m))
        if np.max(edge_errs[m]) > max_error:
            max_error = np.max(edge_errs[m])
    plt.xlim(left=np.min(deltas), right=np.max(deltas))
    plt.ylim(bottom=0, top=max_error * 1.05)

    # Shared legend
    plt.figure(figsize=(1, 6))
    plt.figlegend(*axes[0].get_legend_handles_labels(),
                  loc='center right',
                  bbox_to_anchor=(0.8, 0.5),
                  bbox_transform=plt.gcf().transFigure,
                  ncol=1,
                  handlelength=1.7)
    plt.tight_layout()
    plt.show()