示例#1
0
def double_switch_local_search(initial_split, impurity_fn):
    """Does a local search, switching sides of two values at a time, one in each.

    Starting from a new Split built from ``initial_split``'s fields, every
    (left value, right value) pair is tried. The left value moves to the
    right side and the right value moves to the left side. The first swap
    whose split ``is_better_than`` the current best is adopted, and the scan
    restarts from that split. The search ends when no swap improves the split.

    Args:
        initial_split: Starting split. Its value sets are not modified;
            candidates are built with set difference/union, which return
            new sets.
        impurity_fn: Callable that receives ``(left_values, right_values)``
            and returns the impurity of that partition.

    Returns:
        The best split found (a local optimum under pairwise swaps).
    """
    def evaluate(left_values, right_values):
        # Build a candidate split scored by the caller's impurity function.
        return split.Split(left_values=left_values,
                           right_values=right_values,
                           impurity=impurity_fn(left_values, right_values))

    best_split = split.Split(left_values=initial_split.left_values,
                             right_values=initial_split.right_values,
                             impurity=initial_split.impurity)
    found_improvement = True
    while found_improvement:
        found_improvement = False
        for left_value, right_value in itertools.product(
                best_split.left_values, best_split.right_values):
            # Sets have no ``+`` operator; union (``|``) adds the moved value.
            # Operand order matches the original intent: remove then add on
            # the left, add then remove on the right.
            curr_left_values = ((best_split.left_values - {left_value}) |
                                {right_value})
            curr_right_values = ((best_split.right_values | {left_value}) -
                                 {right_value})
            curr_split = evaluate(curr_left_values, curr_right_values)
            if curr_split.is_better_than(best_split):
                found_improvement = True
                best_split = curr_split
                # best_split changed: restart the scan over the new sides.
                break
    return best_split
示例#2
0
    def test_is_better_than(self):
        """Check the ordering that is_better_than defines between splits.

        An explicit split with impurity 1.0 beats one with impurity 2.0.
        Either explicit split beats a default-constructed one. Two default
        splits are not better than each other.
        """
        blank_a = split.Split()
        blank_b = split.Split()
        low = split.Split(left_values={1}, right_values={2, 3}, impurity=1.0)
        high = split.Split(left_values={1, 2}, right_values={3}, impurity=2.0)

        # Lower impurity wins, and the relation is not symmetric.
        self.assertTrue(low.is_better_than(high))
        self.assertFalse(high.is_better_than(low))
        # Any explicit split beats a default-constructed split.
        self.assertTrue(low.is_better_than(blank_a))
        self.assertTrue(high.is_better_than(blank_a))
        # Neither default split is better than the other.
        self.assertFalse(blank_a.is_better_than(blank_b))
        self.assertFalse(blank_b.is_better_than(blank_a))
示例#3
0
    def test_eq(self):
        """Pin the Split equality relations asserted below.

        Per these assertions:
        - Splits with identical sides compare equal even when their
          impurity differs (low_impurity vs high_impurity).
        - Splits with the same impurity compare equal even when their sides
          differ (high_impurity vs wide_left / wide_right).
        - An explicit split never equals a default-constructed one.

        NOTE(review): as asserted, equality is not transitive
        (low_impurity == high_impurity and high_impurity == wide_left, yet
        low_impurity != wide_left). Confirm this is the intended
        Split.__eq__ contract.
        """
        blank = split.Split()
        low_impurity = split.Split(left_values={1}, right_values={2},
                                   impurity=1.0)
        high_impurity = split.Split(left_values={1}, right_values={2},
                                    impurity=2.0)
        wide_left = split.Split(left_values={1, 2}, right_values={2},
                                impurity=2.0)
        wide_right = split.Split(left_values={1}, right_values={2, 3},
                                 impurity=2.0)

        self.assertFalse(low_impurity == blank)
        # Same sides, different impurity.
        self.assertTrue(low_impurity == high_impurity)
        # Both sides-or-impurity differ.
        self.assertFalse(low_impurity == wide_left)
        self.assertFalse(low_impurity == wide_right)
        # Same impurity, different sides.
        self.assertTrue(high_impurity == wide_left)
        self.assertTrue(high_impurity == wide_right)
示例#4
0
def single_switch_local_search(initial_split, impurity_fn):
    """Does a local search, switching the side of one value at a time.

    Starting from a new Split built from ``initial_split``'s fields, each
    value on the left is first tried on the right. Only if none of those
    moves helps is each value on the right tried on the left. The first move
    whose split ``is_better_than`` the current best is adopted, and the scan
    restarts from that split. The search ends when no single move improves
    the split.

    Args:
        initial_split: Starting split. Its value sets are not modified;
            candidates are built with set difference/union, which return
            new sets.
        impurity_fn: Callable that receives ``(left_values, right_values)``
            and returns the impurity of that partition.

    Returns:
        The best split found (a local optimum under single-value moves).
    """
    def evaluate(left_values, right_values):
        # Build a candidate split scored by the caller's impurity function.
        return split.Split(left_values=left_values,
                           right_values=right_values,
                           impurity=impurity_fn(left_values, right_values))

    best_split = split.Split(left_values=initial_split.left_values,
                             right_values=initial_split.right_values,
                             impurity=initial_split.impurity)
    found_improvement = True
    while found_improvement:
        found_improvement = False
        # Sets have no ``+`` operator; union (``|``) adds the moved value.
        for value in best_split.left_values:
            curr_split = evaluate(best_split.left_values - {value},
                                  best_split.right_values | {value})
            if curr_split.is_better_than(best_split):
                found_improvement = True
                best_split = curr_split
                break
        if found_improvement:
            # Restart from the new split before trying right-to-left moves.
            continue
        for value in best_split.right_values:
            curr_split = evaluate(best_split.left_values | {value},
                                  best_split.right_values - {value})
            if curr_split.is_better_than(best_split):
                found_improvement = True
                best_split = curr_split
                break
    return best_split
示例#5
0
    def test_is_valid(self):
        """A split with one value on each side and an impurity is valid."""
        candidate = split.Split(impurity=1.0, left_values={1},
                                right_values={2})
        self.assertTrue(candidate.is_valid())
示例#6
0
    def test_default_is_not_valid(self):
        """A default-constructed split must report itself as invalid."""
        self.assertFalse(split.Split().is_valid())
示例#7
0
    def test_iteration_number(self):
        """set_iteration_number stores the value read by iteration_number."""
        expected = 10
        candidate = split.Split()
        candidate.set_iteration_number(expected)
        self.assertEqual(expected, candidate.iteration_number)
示例#8
0
def _period_rates(true_rates, rate_changes, deviation, period):
    """Return this period's noisy success rate per option, clipped to [0, 1].

    Option ``i`` has mean ``rate * rate_changes[i]**period`` and standard
    deviation ``mean * deviation``. Its noise comes from a RandomState seeded
    with ``(i + 1) * (period + 1)``, so the draw is deterministic for a
    given option and period.

    NOTE(review): that seed collides for swapped pairs, e.g. (i=0, period=1)
    and (i=1, period=0) both use seed 2, so those draws share the same
    underlying normal variate. Confirm this is acceptable.
    """
    rates = []
    for i, rate in enumerate(true_rates):
        mean = rate * rate_changes[i]**period
        noisy = np.random.RandomState((i + 1) * (period + 1)).normal(
            loc=mean, scale=mean * deviation)
        rates.append(min(max(noisy, 0), 1))
    return rates


def _benchmark_successes(trials, rates, num_options, rounding):
    """Return ``(max, base)`` successes for a single period.

    ``max`` sends all trials to the best option. ``base`` splits trials
    evenly over ``num_options`` options. With ``rounding``, each figure is
    rounded (per option for ``base``) before summing.
    """
    if rounding:
        best = round(trials * max(rates))
        base = np.sum([round(trials / num_options * rate) for rate in rates])
    else:
        best = trials * max(rates)
        base = np.sum([trials / num_options * rate for rate in rates])
    return best, base


def _append_cumulative(totals, value):
    """Append a running total: the previous total plus ``value``."""
    totals.append(totals[-1] + value if totals else value)


def simulate(method,
             periods,
             true_rates,
             deviation,
             change,
             trials,
             max_p=None,
             rounding=True,
             accelerate=True,
             memory=True,
             shape='linear',
             cutoff=28,
             cut_level=0.5):
    """
    Simulate option choosing and result adding for ``periods`` periods.

    Each option gets one multiplicative drift factor, drawn once from
    ``random.uniform(1 - change, 1 + change)`` and compounded every period.
    Each period, that drifted rate is perturbed by normal noise with
    relative scale ``deviation`` and clipped to [0, 1]. The chosen method's
    results are collected alongside two benchmarks: sending every trial to
    the best option (max) and splitting trials evenly across options (base).

    Args:
        method: 'split' (uses ``spl.Split``) or 'bandit' (uses ``ban.Bandit``).
        periods: Number of periods to simulate.
        true_rates: Starting success rate of each option.
        deviation: Relative standard deviation of the per-period noise.
        change: Half-width of the range the per-option drift factor is
            drawn from.
        trials: Trials per period.
        max_p: Forwarded to ``add_split_results`` (split only).
        rounding: Round successes to whole numbers.
        accelerate: Forwarded to ``add_bandit_results`` (bandit only).
        memory, shape, cutoff, cut_level: Forwarded to ``ban.Bandit``.
            With ``memory``, the bandit also gets ``add_period()`` each period.

    Returns:
        ``[successes, max_successes, base_successes]``. ``successes`` holds
        one entry per period, as returned by ``add_split_results`` /
        ``add_bandit_results``. The two benchmark lists hold cumulative
        totals over periods.

    Raises:
        ValueError: If ``method`` is neither 'split' nor 'bandit'. Previously
            such input silently produced an empty ``successes`` list.
    """
    if method not in ('split', 'bandit'):
        raise ValueError(
            "method must be 'split' or 'bandit', got {!r}".format(method))

    num_options = len(true_rates)

    # One drift factor per option, compounded every period.
    rate_changes = [random.uniform(1 - change, 1 + change) for _ in true_rates]

    # Initialize Split or Bandit instances
    if method == 'split':
        chooser = spl.Split(num_options=num_options)
    else:
        chooser = ban.Bandit(num_options=num_options,
                             memory=memory,
                             shape=shape,
                             cutoff=cutoff,
                             cut_level=cut_level)

    successes = []
    max_successes = []
    base_successes = []
    for period in range(periods):
        rates = _period_rates(true_rates, rate_changes, deviation, period)

        # Add results to Split or Bandit
        if method == 'split':
            successes.append(
                add_split_results(trials, max_p, rates, chooser, period,
                                  rounding))
        else:
            if memory:
                chooser.add_period()
            successes.append(
                add_bandit_results(num_options, trials, rates, chooser, period,
                                   rounding, accelerate))

        # Benchmarks are cumulative: each period adds to the previous total.
        period_max, period_base = _benchmark_successes(trials, rates,
                                                       num_options, rounding)
        _append_cumulative(max_successes, period_max)
        _append_cumulative(base_successes, period_base)

    return [successes, max_successes, base_successes]