Beispiel #1
0
 def test_bad_given_thresholds_not_in_valid_range(self):
     """Out-of-range thresholds must be rejected with InvalidSpecError.

     Every threshold is set to twice THRESHOLD_MAX, which the test treats
     as lying outside the valid threshold range.
     """
     address_bits = 1
     out_of_range = [THRESHOLD_MAX * 2] * calc_total_bits(address_bits)
     # Build the bad spec outside the raises-block so only the env factory
     # is expected to raise.
     with pytest.raises(InvalidSpecError):
         make_real_mux_env(thresholds=out_of_range,
                           num_address_bits=address_bits)
Beispiel #2
0
 def test_bad_given_thresholds_not_float(self):
     """Integer (non-float) thresholds must be rejected with InvalidSpecError."""
     address_bits = 1
     # int(THRESHOLD_MIN) gives an integer-typed value, not a float.
     int_thresholds = [int(THRESHOLD_MIN)] * calc_total_bits(address_bits)
     with pytest.raises(InvalidSpecError):
         make_real_mux_env(thresholds=int_thresholds,
                           num_address_bits=address_bits)
Beispiel #3
0
 def test_bad_given_thresholds_incorrect_len_2d_array(self):
     """A nested (2-D) threshold list must be rejected with InvalidSpecError.

     The spec is two rows, each with one threshold per input bit, instead
     of a single flat row.
     """
     address_bits = 1
     row = [0.5] * calc_total_bits(address_bits)
     # Two independent copies of the row form the bad 2-D spec.
     nested_thresholds = [list(row), list(row)]
     with pytest.raises(InvalidSpecError):
         make_real_mux_env(thresholds=nested_thresholds,
                           num_address_bits=address_bits)
Beispiel #4
0
 def test_labels_values(self):
     """Every generated label must be either 0 or 1."""
     address_bits = 1
     env = make_real_mux_env(
         thresholds=[0.5] * calc_total_bits(address_bits),
         num_address_bits=address_bits)
     labels = env.labels
     is_binary = (labels == 0) | (labels == 1)
     assert np.all(is_binary)
Beispiel #5
0
 def test_data_values(self):
     """Every generated data value must lie in the half-open range [0.0, 1.0)."""
     address_bits = 1
     env = make_real_mux_env(
         thresholds=[0.5] * calc_total_bits(address_bits),
         num_address_bits=address_bits)
     data = env.data
     # Lower bound is inclusive, upper bound exclusive.
     assert np.all(data >= 0.0)
     assert np.all(data < 1.0)
Beispiel #6
0
 def test_labels_dtype(self):
     """Every label element must be a NumPy integer scalar."""
     address_bits = 1
     env = make_real_mux_env(
         thresholds=[0.5] * calc_total_bits(address_bits),
         num_address_bits=address_bits)
     assert all(isinstance(label, np.integer)
                for label in env.labels.flatten())
Beispiel #7
0
 def test_data_dtype(self):
     """Every data element must be a NumPy floating-point scalar."""
     address_bits = 1
     env = make_real_mux_env(
         thresholds=[0.5] * calc_total_bits(address_bits),
         num_address_bits=address_bits)
     assert all(isinstance(value, np.floating)
                for value in env.data.flatten())
Beispiel #8
0
 def test_obs_space_integrity(self):
     """The observation space must have one [0.0, 1.0] dimension per input bit."""
     num_address_bits = 1
     total_feature_dims = calc_total_bits(num_address_bits)
     thresholds = [0.5] * total_feature_dims
     # Pass the local variable (not a repeated literal) so the env and the
     # expected dimension count are always derived from the same value.
     mux = make_real_mux_env(thresholds=thresholds,
                             num_address_bits=num_address_bits)
     assert len(mux.obs_space) == total_feature_dims
     for dim in mux.obs_space:
         assert dim.lower == 0.0
         assert dim.upper == 1.0
Beispiel #9
0
 def test_action_set_integrity(self):
     """The environment's action set must be exactly {0, 1}."""
     num_address_bits = 1
     total_bits = calc_total_bits(num_address_bits)
     thresholds = [0.5] * total_bits
     # Use the local variable rather than a hard-coded 1 so the thresholds
     # length and the env's address-bit count cannot diverge.
     mux = make_real_mux_env(thresholds=thresholds,
                             num_address_bits=num_address_bits)
     assert mux.action_set == {0, 1}
Beispiel #10
0
 def test_bad_num_address_bits_negative(self):
     """A negative address-bit count must be rejected with InvalidSpecError."""
     negative_bits = -1
     with pytest.raises(InvalidSpecError):
         make_real_mux_env(thresholds=[], num_address_bits=negative_bits)
Beispiel #11
0
    def test_sample_training_run(self):
        """Smoke-test a short canonical-XCS training run on a 6-mux.

        Builds a shuffled 6-multiplexer environment (2 address bits, six
        0.5 thresholds, 64 samples), a centre-spread rule representation
        and a canonical XCS with fixed hyperparameters. It then runs 10
        training epochs while monitoring population size, performance,
        error/fitness summaries and per-operation counters, and prints the
        monitor's query result. There are no assertions: the test passes
        if the run completes without raising.
        """
        # 6-mux
        env = make_real_mux_env(thresholds=[0.5] * 6,
                                num_address_bits=2,
                                shuffle_dataset=True,
                                reward_correct=1000,
                                reward_incorrect=0,
                                num_samples=64)

        rule_repr = CentreSpreadRuleRepr(env.obs_space)

        alg_hyperparams = dict(
            N=400,
            beta=0.2,
            alpha=0.1,
            epsilon_nought=0.01,
            nu=5,
            gamma=0.71,
            theta_ga=12,
            chi=0.8,
            mu=0.04,
            theta_del=20,
            delta=0.1,
            theta_sub=20,
            p_wildcard=0.33,
            prediction_I=1e-3,
            epsilon_I=1e-3,
            fitness_I=1e-3,
            p_explore=0.5,
            theta_mna=len(env.action_set),
            do_ga_subsumption=True,
            do_as_subsumption=True,
            seed=0,
            m=0.1,
            s_nought=1.0,
        )

        alg = make_canonical_xcs(env, rule_repr, alg_hyperparams)

        def _operation_count(op_name):
            """Return a monitor callback reading one operations_record entry."""
            # A factory (not an inline lambda in a loop) binds op_name per
            # call, avoiding the late-binding closure pitfall.
            return lambda experiment: (
                experiment.population.operations_record[op_name])

        monitor_items = [
            MonitorItem("num_micros",
                        lambda experiment: experiment.population.num_micros),
            MonitorItem("num_macros",
                        lambda experiment: experiment.population.num_macros),
            MonitorItem("performance",
                        lambda experiment: experiment.calc_performance(
                            strat="accuracy")),
            MonitorItem("mean_error",
                        lambda experiment: calc_summary_stat(
                            experiment.population, "mean", "error")),
            MonitorItem("max_fitness",
                        lambda experiment: calc_summary_stat(
                            experiment.population, "max", "fitness")),
        ]
        # Per-operation counters, in the same order as before.
        operation_names = ("deletion", "covering", "as_subsumption",
                           "ga_subsumption", "discovery", "absorption")
        monitor_items.extend(
            MonitorItem(op_name, _operation_count(op_name))
            for op_name in operation_names)

        experiment = Experiment(env,
                                alg,
                                num_training_epochs=10,
                                monitor_items=monitor_items)
        experiment.run()
        print(experiment._monitor.query())