def test_transition_with_zero_trans_mat_is_never_used(self):
        """A transition whose `trans_mat` entry is zero never shows up.

        A first run with the default transition matrix finds the most frequent
        transition between consecutive winning models. A second run uses a
        transition matrix in which exactly that entry is zero, and the test
        checks that the transition no longer occurs.
        """
        n_states = self.n_models

        def tally_transitions(responses):
            # The winner at each step is the model with the largest response.
            winners = np.argmax(responses, axis=1)
            tally = np.zeros((n_states, n_states), dtype=int)
            # Unbuffered in-place add, so repeated (from, to) pairs accumulate.
            np.add.at(tally, (winners[:-1], winners[1:]), 1)
            return tally

        # Baseline run: locate the most common transition.
        baseline = BioWTARegressor(n_states, self.n_features)
        baseline_tally = tally_transitions(
            baseline.transform(self.predictors, self.dependent)
        )
        src, dst = np.unravel_index(baseline_tally.argmax(), baseline_tally.shape)
        self.assertGreater(baseline_tally[src, dst], 0)

        # Forbid that transition; the rest of row `src` shares the mass evenly
        # (the forbidden entry is overwritten with zero right after).
        trans_mat = np.ones((n_states, n_states)) / n_states
        trans_mat[src] = np.ones(n_states) / (n_states - 1)
        trans_mat[src, dst] = 0

        constrained = BioWTARegressor(n_states, self.n_features, trans_mat=trans_mat)
        constrained_tally = tally_transitions(
            constrained.transform(self.predictors, self.dependent)
        )

        self.assertEqual(constrained_tally[src, dst], 0)
    def test_switching_rate_to_zero_fixes_weights(self):
        """Final weights match `self.wta_partial` when the rate drops to zero.

        The learning-rate schedule equals `self.rate` for the first
        `self.n_partial` samples and is zero for all later ones.
        """
        # NOTE(review): presumably `self.wta_partial` was trained on only the
        # first `n_partial` samples at constant rate -- confirm against setUp.
        rate_schedule = np.zeros(self.n_samples)
        rate_schedule[: self.n_partial] = self.rate

        regressor = BioWTARegressor(
            self.n_models, self.n_features, rate=rate_schedule
        )
        regressor.transform(self.predictors, self.dependent)

        np.testing.assert_allclose(
            regressor.weights_, self.wta_partial.weights_
        )
    def test_constructor_copies_weight_schedule(self):
        """Mutating the schedule array after construction has no effect.

        The array passed as `rate` is zeroed after the regressor is built; the
        final weights must still match `self.wta_full.weights_`.
        """
        # NOTE(review): presumably `self.wta_full` was trained on all samples
        # at constant `self.rate` -- confirm against setUp.
        rate_schedule = np.ones(self.n_samples) * self.rate
        regressor = BioWTARegressor(
            self.n_models, self.n_features, rate=rate_schedule
        )

        # Overwrite the caller's array in place; the regressor should not see it.
        rate_schedule.fill(0)
        regressor.transform(self.predictors, self.dependent)

        np.testing.assert_allclose(regressor.weights_, self.wta_full.weights_)
class TestBioWTARegressorLatentPriorTemperatureOne(unittest.TestCase):
    """Tests of `BioWTARegressor` against a baseline built at temperature 1.

    `self.wta` is constructed with an explicit `temperature=self.temperature`.
    Several tests compare it with regressors that omit `temperature`, so they
    rely on the constructor default matching `self.temperature`.
    NOTE(review): confirm that default against the `BioWTARegressor` signature.
    """

    def setUp(self):
        """Build the baseline regressor, the data set and the shared fixtures."""
        self.n_models = 3
        self.n_features = 4
        self.temperature = 1.0
        # Shape-only constructor arguments; tests add their own options on top.
        self.kwargs = {"n_models": self.n_models, "n_features": self.n_features}
        self.wta = BioWTARegressor(
            self.n_models, self.n_features, temperature=self.temperature
        )

        self.rng = np.random.default_rng(10)
        self.n_samples = 79
        self.predictors = self.rng.normal(size=(self.n_samples, self.n_features))
        self.dependent = self.rng.normal(size=self.n_samples)

        # The fixtures below use their own generator so that the values tests
        # draw from `self.rng` do not depend on them.
        fixture_rng = np.random.default_rng(11)

        def normalize_v(v: np.ndarray) -> np.ndarray:
            return v / np.sum(v)

        # Non-uniform latent-state priors for tests that need arbitrary ones.
        self.start_prob = normalize_v(fixture_rng.uniform(size=self.n_models))
        self.trans_mat = [
            normalize_v(fixture_rng.uniform(size=self.n_models))
            for _ in range(self.n_models)
        ]

        # Constant learning rate, and the responses of a regressor that only
        # sees the first `n_partial` samples at that rate.
        self.rate = 0.005
        self.n_partial = self.n_samples // 2
        wta_partial = BioWTARegressor(self.n_models, self.n_features, rate=self.rate)
        self.r_partial = wta_partial.transform(
            self.predictors[: self.n_partial], self.dependent[: self.n_partial]
        )

    def test_initial_r_value_changes_by_correct_amount_according_to_start_prob(self):
        """Changing `start_prob` shifts the first log-response by its log."""
        r1 = self.wta.transform(self.predictors[[0]], self.dependent[[0]])

        # change initial state distribution
        start_prob = (lambda v: v / np.sum(v))(self.rng.uniform(size=self.n_models))
        wta2 = BioWTARegressor(
            self.n_models,
            self.n_features,
            temperature=self.temperature,
            start_prob=start_prob,
        )

        r2 = wta2.transform(self.predictors[[0]], self.dependent[[0]])

        # Compare relative to model 0 so the normalization constant drops out.
        diff_log_r = np.log(r2[0]) - np.log(r1[0])
        np.testing.assert_allclose(
            diff_log_r - diff_log_r[0], np.log(start_prob) - np.log(start_prob[0]),
        )

    def test_second_r_value_changes_by_correct_amount_according_to_trans_mat(self):
        """Changing `trans_mat` leaves step 0 alone and shifts step 1 as expected."""
        r1 = self.wta.transform(self.predictors[:2], self.dependent[:2])

        # change transition matrix
        trans_mat = [
            (lambda v: v / np.sum(v))(self.rng.uniform(size=self.n_models))
            for _ in range(self.n_models)
        ]
        wta2 = BioWTARegressor(
            self.n_models,
            self.n_features,
            temperature=self.temperature,
            trans_mat=trans_mat,
        )

        r2 = wta2.transform(self.predictors[:2], self.dependent[:2])
        np.testing.assert_allclose(r1[0], r2[0])

        # Compare relative to model 0 so the normalization constant drops out.
        diff_log_r = np.log(r2[1]) - np.log(r1[1])
        expected_diff_log_r = r2[0] @ np.log(trans_mat)
        np.testing.assert_allclose(
            diff_log_r - diff_log_r[0], expected_diff_log_r - expected_diff_log_r[0]
        )

    def test_default_error_timescale_is_one(self):
        """Omitting `error_timescale` behaves like passing 1.0."""
        T = 1.0
        wta0 = BioWTARegressor(**self.kwargs, error_timescale=1.0, temperature=T)
        wta1 = BioWTARegressor(**self.kwargs, temperature=T)

        r0 = wta0.transform(self.predictors, self.dependent)
        r1 = wta1.transform(self.predictors, self.dependent)

        np.testing.assert_allclose(r0, r1)
        np.testing.assert_allclose(wta0.weights_, wta1.weights_)

    def test_weight_change_from_repeated_transform_calls_equivalent_to_one_call(self):
        """Three consecutive `transform` calls give the weights of one call."""
        n1 = 2 * self.n_samples // 7
        n2 = 3 * self.n_samples // 7
        self.wta.transform(self.predictors[:n1], self.dependent[:n1])
        self.wta.transform(self.predictors[n1 : n1 + n2], self.dependent[n1 : n1 + n2])
        self.wta.transform(self.predictors[n1 + n2 :], self.dependent[n1 + n2 :])

        wta_again = BioWTARegressor(**self.kwargs)
        wta_again.transform(self.predictors, self.dependent)

        np.testing.assert_allclose(self.wta.weights_, wta_again.weights_)

    def test_weight_updates_use_instantaneous_error(self):
        """The first weight update does not depend on `error_timescale`."""
        wta0 = BioWTARegressor(**self.kwargs, error_timescale=1.0)
        wta1 = BioWTARegressor(**self.kwargs, error_timescale=2.3)

        weights0 = np.copy(wta0.weights_)
        np.testing.assert_allclose(wta1.weights_, weights0)

        wta0.transform(self.predictors[:1], self.dependent[:1])
        wta1.transform(self.predictors[:1], self.dependent[:1])

        # Make sure the weights actually moved, so the comparison is meaningful.
        self.assertGreater(np.max(np.abs(wta0.weights_ - weights0)), 1e-5)
        np.testing.assert_allclose(wta0.weights_, wta1.weights_)

    def test_output_of_repeated_calls_equivalent_to_single_call_at_other_temperature(
        self,
    ):
        """Split and single `transform` calls agree with explicit tau and T.

        This test used to share its name with a later method of this class and
        was therefore silently replaced by it; it now has a name of its own.
        """
        n1 = 4 * self.n_samples // 7
        tau = 1.0
        T = 0.5
        wta = BioWTARegressor(**self.kwargs, error_timescale=tau, temperature=T)
        r1 = wta.transform(self.predictors[:n1], self.dependent[:n1])
        r2 = wta.transform(self.predictors[n1:], self.dependent[n1:])

        wta_again = BioWTARegressor(**self.kwargs, error_timescale=tau, temperature=T)
        r = wta_again.transform(self.predictors, self.dependent)

        np.testing.assert_allclose(r, np.vstack((r1, r2)))

    def test_no_warning_when_some_start_prob_are_zero(self):
        """Zero entries in `start_prob` do not trigger any warning."""
        start_prob = np.ones(self.n_models) / (self.n_models - 2)

        # Two states are disallowed; the remaining mass is split over the rest.
        start_prob[[0, 2]] = 0
        wta = BioWTARegressor(self.n_models, self.n_features, start_prob=start_prob)

        with warnings.catch_warnings(record=True) as warn_list:
            # ensure warnings haven't been disabled
            warnings.simplefilter("always")
            wta.transform(self.predictors, self.dependent)

            # ensure no warnings have been triggered
            self.assertEqual(len(warn_list), 0)

    def test_last_value_of_rate_is_used_if_more_samples_than_len_rate(self):
        """A short rate schedule acts as if padded with its last value."""
        n = 3 * self.n_samples // 4
        schedule_short = self.rng.uniform(0, self.rate, size=n)
        schedule = np.hstack(
            (schedule_short, (self.n_samples - n) * [schedule_short[-1]])
        )

        wta1 = BioWTARegressor(self.n_models, self.n_features, rate=schedule_short)
        wta2 = BioWTARegressor(self.n_models, self.n_features, rate=schedule)

        wta1.transform(self.predictors, self.dependent)
        wta2.transform(self.predictors, self.dependent)

        np.testing.assert_allclose(wta1.weights_, wta2.weights_)

    def test_no_warning_when_some_trans_mat_are_zero(self):
        """A zero entry in `trans_mat` does not trigger any warning."""
        # disallow the system from performing some transition
        trans_mat = np.ones((self.n_models, self.n_models)) / self.n_models
        trans_mat[0] = np.ones(self.n_models) / (self.n_models - 1)
        trans_mat[0, 1] = 0
        wta = BioWTARegressor(self.n_models, self.n_features, trans_mat=trans_mat)

        with warnings.catch_warnings(record=True) as warn_list:
            # ensure warnings haven't been disabled
            warnings.simplefilter("always")
            wta.transform(self.predictors, self.dependent)

            # ensure no warnings have been triggered
            self.assertEqual(len(warn_list), 0)

    def test_state_with_zero_start_prob_is_not_used_at_first_step(self):
        """A state with zero `start_prob` gets a negligible first response."""
        # first find the most likely state with uniform start_prob
        wta1 = BioWTARegressor(self.n_models, self.n_features)
        r1 = wta1.transform(self.predictors, self.dependent)
        k0 = np.argmax(r1[0])

        # next disallow the system from starting in that state
        start_prob = np.ones(self.n_models) / (self.n_models - 1)
        # we never start in state k0
        start_prob[k0] = 0

        wta2 = BioWTARegressor(self.n_models, self.n_features, start_prob=start_prob)
        r2 = wta2.transform(self.predictors, self.dependent)

        self.assertLess(r2[0, k0], 1e-6)

    def test_resulting_r_same_if_rate_is_constant_then_switches(self):
        """Responses before the rate drops to zero match the constant-rate run."""
        schedule = np.zeros(self.n_samples)
        schedule[: self.n_partial] = self.rate
        wta = BioWTARegressor(self.n_models, self.n_features, rate=schedule)

        r = wta.transform(self.predictors, self.dependent)
        np.testing.assert_allclose(r[: self.n_partial], self.r_partial)

    def test_small_chunk_same_as_no_chunk(self):
        """Passing `chunk_hint=12` does not change the responses."""
        wta1 = BioWTARegressor(
            self.n_models,
            self.n_features,
            start_prob=self.start_prob,
            trans_mat=self.trans_mat,
        )
        r1 = wta1.transform(self.predictors, self.dependent)

        wta2 = BioWTARegressor(
            self.n_models,
            self.n_features,
            start_prob=self.start_prob,
            trans_mat=self.trans_mat,
        )
        r2 = wta2.transform(self.predictors, self.dependent, chunk_hint=12)

        np.testing.assert_allclose(r1, r2)

    def test_output_of_repeated_calls_to_transform_equivalent_to_single_call(self):
        """Two consecutive `transform` calls give the output of one call."""
        n1 = 4 * self.n_samples // 7
        r1 = self.wta.transform(self.predictors[:n1], self.dependent[:n1])
        r2 = self.wta.transform(self.predictors[n1:], self.dependent[n1:])

        wta_again = BioWTARegressor(**self.kwargs)
        r = wta_again.transform(self.predictors, self.dependent)

        np.testing.assert_allclose(r, np.vstack((r1, r2)))

    def test_float_trans_mat(self):
        """A scalar `trans_mat` of `1 / n_models` matches the default."""
        r1 = self.wta.transform(self.predictors, self.dependent)

        wta2 = BioWTARegressor(
            self.n_models, self.n_features, trans_mat=1 / self.n_models
        )
        r2 = wta2.transform(self.predictors, self.dependent)

        np.testing.assert_allclose(r1, r2)

    def test_changing_latent_state_transition_matrix_changes_output(self):
        """A random `trans_mat` visibly changes the responses."""
        trans_mat = [
            (lambda v: v / np.sum(v))(self.rng.uniform(size=self.n_models))
            for _ in range(self.n_models)
        ]
        wta2 = BioWTARegressor(self.n_models, self.n_features, trans_mat=trans_mat)

        r1 = self.wta.transform(self.predictors, self.dependent)
        r2 = wta2.transform(self.predictors, self.dependent)

        self.assertGreater(np.max(np.abs(r1 - r2)), 1e-3)

    def test_monitor_as_object(self):
        """An `AttributeMonitor` records the same history as a list of names."""
        names = ["weights_"]
        monitor = AttributeMonitor(names)
        self.wta.transform(self.predictors, self.dependent, monitor=monitor)

        wta_alt = BioWTARegressor(self.n_models, self.n_features)
        _, history_alt = wta_alt.transform(
            self.predictors, self.dependent, monitor=names
        )

        np.testing.assert_allclose(monitor.history_.weights_, history_alt.weights_)

    def test_default_initial_latent_state_distribution_is_uniform(self):
        """Omitting `start_prob` behaves like passing a uniform distribution."""
        wta2 = BioWTARegressor(
            self.n_models,
            self.n_features,
            start_prob=np.ones(self.n_models) / self.n_models,
        )

        r1 = self.wta.transform(self.predictors, self.dependent)
        r2 = wta2.transform(self.predictors, self.dependent)

        np.testing.assert_allclose(r1, r2)

    def test_default_latent_state_transition_matrix_is_uniform(self):
        """Omitting `trans_mat` behaves like passing a uniform matrix."""
        wta2 = BioWTARegressor(
            self.n_models,
            self.n_features,
            trans_mat=np.ones((self.n_models, self.n_models)) / self.n_models,
        )

        r1 = self.wta.transform(self.predictors, self.dependent)
        r2 = wta2.transform(self.predictors, self.dependent)

        np.testing.assert_allclose(r1, r2)

    def test_callable_rate_works_like_constant(self):
        """A callable returning a constant acts like that constant rate."""
        n_models = 3
        n_features = 4

        rng = np.random.default_rng(2)
        n_samples = 55
        predictors = rng.normal(size=(n_samples, n_features))
        dependent = rng.normal(size=n_samples)

        rate = 1e-4

        def rate_fct(_):
            return rate

        wta1 = BioWTARegressor(n_models, n_features, rate=rate_fct)
        wta2 = BioWTARegressor(n_models, n_features, rate=rate)

        wta1.transform(predictors, dependent)
        wta2.transform(predictors, dependent)

        np.testing.assert_allclose(wta1.weights_, wta2.weights_)

    def test_callable_rate_works_like_vector(self):
        """A callable rate acts like the schedule of its values at 0..n-1."""
        n_models = 3
        n_features = 4

        rng = np.random.default_rng(1)
        n_samples = 55
        predictors = rng.normal(size=(n_samples, n_features))
        dependent = rng.normal(size=n_samples)

        def rate_fct(i):
            return 1 / (1 + 0.5 * i)

        wta1 = BioWTARegressor(n_models, n_features, rate=rate_fct)

        schedule = [rate_fct(_) for _ in range(n_samples)]
        wta2 = BioWTARegressor(n_models, n_features, rate=schedule)

        wta1.transform(predictors, dependent)
        wta2.transform(predictors, dependent)

        np.testing.assert_allclose(wta1.weights_, wta2.weights_)

    def test_changing_initial_latent_state_distribution_changes_output(self):
        """Penalizing the initially winning state visibly changes the responses."""
        r1 = self.wta.transform(self.predictors, self.dependent)
        k = np.argmax(r1[0])

        start_prob = self.rng.uniform(size=self.n_models)
        # Strongly down-weight the state that wins the first step by default.
        start_prob[k] /= 50

        start_prob = start_prob / np.sum(start_prob)
        wta2 = BioWTARegressor(self.n_models, self.n_features, start_prob=start_prob)

        r2 = wta2.transform(self.predictors, self.dependent)

        self.assertGreater(np.max(np.abs(r1 - r2)), 1e-3)

    def test_transform_retval_uses_recent_loss_not_instantaneous_error(self):
        """With a long `error_timescale`, one outlier sample does not flip the winner.

        All samples but the last are generated exactly by model 0 and the last
        one by model 1; learning is switched off with `rate=0`.
        """
        weights = self.rng.normal(size=(self.n_models, self.n_features))

        x = self.predictors
        y0 = x[: self.n_samples - 1] @ weights[0]
        y = np.hstack((y0, x[-1] @ weights[1]))

        wta0 = BioWTARegressor(
            **self.kwargs, error_timescale=1.0, weights=weights, rate=0
        )
        wta1 = BioWTARegressor(
            **self.kwargs, error_timescale=1000.0, weights=weights, rate=0
        )

        r0 = wta0.transform(x, y)
        r1 = wta1.transform(x, y)

        k0 = r0.argmax(axis=1)
        k1 = r1.argmax(axis=1)

        self.assertEqual(k0[-1], 1)
        self.assertEqual(k1[-1], 0)

    def test_recent_loss_correct_if_timescale_is_not_zero(self):
        """`recent_loss_` is the squared error smoothed with time constant tau."""
        tau = 3.5
        wta = BioWTARegressor(**self.kwargs, error_timescale=tau)
        _, history = wta.transform(
            self.predictors, self.dependent, monitor=["error_", "recent_loss_"]
        )

        # Recompute the exponential moving average of the squared error by hand.
        loss_exp = np.zeros((self.n_samples, self.n_models))
        crt_loss = np.zeros(self.n_models)
        for i in range(self.n_samples):
            crt_loss += (history.error_[i] ** 2 - crt_loss) / tau
            loss_exp[i] = crt_loss

        np.testing.assert_allclose(history.recent_loss_, loss_exp)

    def test_log_r_prop_temperature(self):
        """Log-response differences scale inversely with the temperature."""
        r = self.wta.transform(self.predictors[[0]], self.dependent[[0]])

        temperature_again = 3.2
        wta_again = BioWTARegressor(
            self.n_models, self.n_features, temperature=temperature_again
        )
        r_again = wta_again.transform(self.predictors[[0]], self.dependent[[0]])

        # Differences relative to model 0 are free of the normalization constant.
        d_logr = np.log(r[0]) - np.log(r[0, 0])
        d_logr_again = np.log(r_again[0]) - np.log(r_again[0, 0])

        np.testing.assert_allclose(
            d_logr_again, d_logr * self.temperature / temperature_again
        )

    def test_when_start_prob_large_for_a_state_then_that_state_gets_high_r(self):
        """A state with `start_prob` close to 1 wins the first step."""
        r1 = self.wta.transform(self.predictors[[0]], self.dependent[[0]])
        k1 = np.argmax(r1[0])

        # make another state win out
        k = 0 if k1 != 0 else 1
        p_large = 0.999
        start_prob = (1 - p_large) * np.ones(self.n_models) / (self.n_models - 1)
        start_prob[k] = p_large
        wta2 = BioWTARegressor(self.n_models, self.n_features, start_prob=start_prob)

        r2 = wta2.transform(self.predictors[[0]], self.dependent[[0]])
        k2 = np.argmax(r2[0])

        self.assertNotEqual(k1, k2)
        self.assertEqual(k2, k)
class TestBioWTARegressorArbitraryStartProbAndTransMat(unittest.TestCase):
    """Tests of `BioWTARegressor` with non-uniform `start_prob` and `trans_mat`.

    Every regressor compared against `self.wta` is built from `self.kwargs`, so
    that both sides share the same latent-state priors.
    """

    def setUp(self):
        """Draw random priors and data, and build the regressor under test."""
        self.n_models = 3
        self.n_features = 4
        self.rng = np.random.default_rng(9)

        def normalize_v(v: np.ndarray) -> np.ndarray:
            return v / np.sum(v)

        self.start_prob = normalize_v(self.rng.uniform(size=self.n_models))
        self.trans_mat = [
            normalize_v(self.rng.uniform(size=self.n_models))
            for _ in range(self.n_models)
        ]

        self.kwargs = {
            "n_models": self.n_models,
            "n_features": self.n_features,
            "start_prob": self.start_prob,
            "trans_mat": self.trans_mat,
        }
        self.wta = BioWTARegressor(**self.kwargs)

        self.n_samples = 79
        self.predictors = self.rng.normal(size=(self.n_samples, self.n_features))
        self.dependent = self.rng.normal(size=self.n_samples)

    def test_output_of_repeated_calls_to_transform_equivalent_to_single_call(self):
        """Two consecutive `transform` calls give the output of one call."""
        n1 = 4 * self.n_samples // 7
        r1 = self.wta.transform(self.predictors[:n1], self.dependent[:n1])
        r2 = self.wta.transform(self.predictors[n1:], self.dependent[n1:])

        wta_again = BioWTARegressor(**self.kwargs)
        r = wta_again.transform(self.predictors, self.dependent)

        np.testing.assert_allclose(r, np.vstack((r1, r2)))

    def test_weight_change_from_repeated_transform_calls_equivalent_to_one_call(self):
        """Three consecutive `transform` calls give the weights of one call."""
        n1 = 2 * self.n_samples // 7
        n2 = 3 * self.n_samples // 7
        self.wta.transform(self.predictors[:n1], self.dependent[:n1])
        self.wta.transform(self.predictors[n1 : n1 + n2], self.dependent[n1 : n1 + n2])
        self.wta.transform(self.predictors[n1 + n2 :], self.dependent[n1 + n2 :])

        wta_again = BioWTARegressor(**self.kwargs)
        wta_again.transform(self.predictors, self.dependent)

        np.testing.assert_allclose(self.wta.weights_, wta_again.weights_)

    def test_initial_r_value_changes_by_correct_amount_according_to_start_prob(self):
        """Swapping `start_prob` shifts the first log-response by the log-ratio.

        Both regressors share every setting except `start_prob`, and both are
        pinned to temperature 1. The baseline prior here is the non-uniform
        `self.start_prob`, so the expected shift is the log of the new prior
        minus the log of the old one.
        """
        temperature = 1.0
        wta1 = BioWTARegressor(**self.kwargs, temperature=temperature)
        r1 = wta1.transform(self.predictors[[0]], self.dependent[[0]])

        # change initial state distribution, keeping everything else fixed
        start_prob = (lambda v: v / np.sum(v))(self.rng.uniform(size=self.n_models))
        kwargs2 = dict(self.kwargs, start_prob=start_prob)
        wta2 = BioWTARegressor(**kwargs2, temperature=temperature)

        r2 = wta2.transform(self.predictors[[0]], self.dependent[[0]])

        # Compare relative to model 0 so the normalization constant drops out.
        diff_log_r = np.log(r2[0]) - np.log(r1[0])
        expected_diff_log_r = np.log(start_prob) - np.log(self.start_prob)
        np.testing.assert_allclose(
            diff_log_r - diff_log_r[0], expected_diff_log_r - expected_diff_log_r[0],
        )

    def test_history_same_when_chunk_hint_changes(self):
        """The monitored `prediction_` history does not depend on `chunk_hint`."""
        names = ["prediction_"]
        _, history = self.wta.transform(
            self.predictors,
            self.dependent,
            monitor=names,
            chunk_hint=1000,
            return_history=True,
        )

        # Same priors as `self.wta`, so that only `chunk_hint` differs.
        wta_alt = BioWTARegressor(**self.kwargs)
        _, history_alt = wta_alt.transform(
            self.predictors,
            self.dependent,
            monitor=names,
            chunk_hint=1,
            return_history=True,
        )

        np.testing.assert_allclose(history.prediction_, history_alt.prediction_)