def setUp(self):
        """Build a scaled NonRecurrent circuit and a batch of random inputs."""
        self.rng = np.random.default_rng(3)

        self.input_dim, self.output_dim = 4, 2

        # random draws happen in a fixed order: w0, sqrt_m0, scalings, x
        self.w0 = self.rng.normal(size=(self.output_dim, self.input_dim))
        self.sqrt_m0 = self.rng.normal(size=(self.output_dim, self.output_dim))
        # product with its own transpose gives a symmetric PSD lateral matrix
        self.m0 = self.sqrt_m0 @ self.sqrt_m0.T

        self.alpha, self.tau = 0.0011, 0.34
        self.pc_scalings = self.rng.uniform(size=self.output_dim)

        self.kwargs = {
            "weights": self.w0,
            "lateral": self.m0,
            "rate": self.alpha,
            "tau": self.tau,
            "scalings": self.pc_scalings,
        }
        self.circuit = NonRecurrent(**self.kwargs)

        self.n_samples = 85
        self.x = self.rng.normal(size=(self.n_samples, self.input_dim))
    def test_monitor_fit(self):
        """Monitored history equals states snapshotted before each single step."""
        names = ["weights_", "lateral_", "output_"]
        monitor = AttributeMonitor(names)

        n_samples = 100
        x = np.random.default_rng(1).normal(size=(n_samples, self.n_features))

        self.circuit.transform(x, monitor=monitor)

        # replay the same data one sample at a time on a fresh circuit,
        # recording each attribute *before* the update is applied
        reference = NonRecurrent(rng=self.seed, **self.kwargs)
        expected = {name: [] for name in names}
        for sample in x:
            for name in names:
                expected[name].append(np.copy(getattr(reference, name)))
            reference.transform([sample])

        for name in names:
            np.testing.assert_allclose(getattr(monitor.history_, name),
                                       expected[name])
    def setUp(self):
        """Non-negative circuit with PC scalings and whitening disabled."""
        self.rng = np.random.default_rng(2)

        self.input_dim, self.output_dim = 4, 2

        w_shape = (self.output_dim, self.input_dim)
        m_shape = (self.output_dim, self.output_dim)
        self.w0 = self.rng.normal(size=w_shape)
        self.sqrt_m0 = self.rng.normal(size=m_shape)
        self.m0 = self.sqrt_m0 @ self.sqrt_m0.T

        self.alpha, self.tau = 0.001, 0.5
        self.pc_scalings = self.rng.uniform(size=self.output_dim)

        self.kwargs = {
            "weights": self.w0,
            "lateral": self.m0,
            "rate": self.alpha,
            "tau": self.tau,
            "non_negative": True,
            "scalings": self.pc_scalings,
            "whiten": False,
        }

        self.circuit = NonRecurrent(**self.kwargs)
 def setUp(self):
     """Create a randomly initialized circuit from its dimensions and a seed."""
     self.n_features, self.n_components = 4, 3
     self.seed = 4
     self.rng = np.random.default_rng(self.seed)
     self.kwargs = {"n_features": self.n_features,
                    "n_components": self.n_components}
     self.circuit = NonRecurrent(rng=self.rng, **self.kwargs)
    def test_switching_rate_to_zero_fixes_weights(self):
        """A rate that drops to zero after n_partial samples freezes the weights."""
        active = np.arange(self.n_samples) < self.n_partial
        schedule = np.where(active, self.rate, 0.0)

        circuit = NonRecurrent(rate=schedule, **self.kwargs)
        circuit.transform(self.x)

        np.testing.assert_allclose(circuit.weights_,
                                   self.circuit_partial.weights_)
    def test_constructor_copies_weight_schedule(self):
        """Mutating the schedule after construction must not affect the circuit."""
        schedule = np.full(self.n_samples, self.rate)
        circuit = NonRecurrent(rate=schedule, **self.kwargs)

        # clobber the caller's array; a copying constructor is unaffected
        schedule.fill(0)
        circuit.transform(self.x)

        np.testing.assert_allclose(circuit.weights_,
                                   self.circuit_full.weights_)
    def test_output_history_same_if_rate_is_constant_then_switches(self):
        """Outputs before the rate switch match the constant-rate partial run."""
        active = np.arange(self.n_samples) < self.n_partial
        schedule = np.where(active, self.rate, 0.0)
        circuit = NonRecurrent(rate=schedule, **self.kwargs)

        monitor = AttributeMonitor(["output_"])
        circuit.transform(self.x, monitor=monitor)

        head = monitor.history_.output_[:self.n_partial]
        np.testing.assert_allclose(head, self.monitor_partial.history_.output_)
class TestNonRecurrentFitWithPCScaling(unittest.TestCase):
    def setUp(self):
        """Circuit with explicit initial weights and per-component scalings."""
        self.rng = np.random.default_rng(1)

        self.input_dim, self.output_dim = 4, 3

        # draw order (w0, sqrt_m0, scalings) fixes the random stream
        self.w0 = self.rng.normal(size=(self.output_dim, self.input_dim))
        self.sqrt_m0 = self.rng.normal(size=(self.output_dim, self.output_dim))
        self.m0 = self.sqrt_m0 @ self.sqrt_m0.T

        self.alpha, self.tau = 0.0005, 0.4
        self.pc_scalings = self.rng.uniform(size=self.output_dim)

        circuit_kws = {
            "weights": self.w0,
            "lateral": self.m0,
            "tau": self.tau,
            "rate": self.alpha,
            "scalings": self.pc_scalings,
        }
        self.circuit = NonRecurrent(**circuit_kws)

    def test_lateral_weights_evolution_no_whitening(self):
        """Without whitening, M += (alpha/tau) * (y y^T - L M L) with L = diag(scalings)."""
        self.circuit.whiten = False
        lbd = np.diag(self.pc_scalings)
        step = self.alpha / self.tau

        for _ in range(10):
            x = self.rng.normal(size=self.input_dim)
            old_m = np.array(self.circuit.lateral_)

            self.circuit.transform([x])

            y = self.circuit.output_
            expected_m = old_m + step * (np.outer(y, y) - lbd @ old_m @ lbd)
            np.testing.assert_allclose(self.circuit.lateral_, expected_m)

    def test_lateral_weights_evolution_with_whitening(self):
        """With whitening, M += (alpha/tau) * (y y^T - L^2) with L = diag(scalings)."""
        self.circuit.whiten = True

        lbd = np.diag(self.pc_scalings)
        lbd_sq = lbd @ lbd  # independent of the step, so computed once
        step = self.alpha / self.tau

        for _ in range(10):
            x = self.rng.normal(size=self.input_dim)
            old_m = np.array(self.circuit.lateral_)

            self.circuit.transform([x])

            y = self.circuit.output_
            expected_m = old_m + step * (np.outer(y, y) - lbd_sq)
            np.testing.assert_allclose(self.circuit.lateral_, expected_m)
    def test_callable_rate_works_like_constant(self):
        """A constant rate callable matches the equivalent explicit schedule."""
        n_features, n_components, seed = 5, 3, 3
        kwargs = {"n_features": n_features,
                  "n_components": n_components,
                  "rng": seed}

        n_samples = 55
        x = np.random.default_rng(0).normal(size=(n_samples, n_features))

        rate = 1e-4

        def rate_fct(_):
            return rate

        circuit1 = NonRecurrent(rate=rate_fct, **kwargs)
        circuit1.transform(x)

        schedule = [rate_fct(i) for i in range(n_samples)]
        circuit2 = NonRecurrent(rate=schedule, **kwargs)
        circuit2.transform(x)

        for attr in ("weights_", "lateral_"):
            np.testing.assert_allclose(getattr(circuit1, attr),
                                       getattr(circuit2, attr))
    def test_output_identically_zero_if_input_is_in_ker_weights(self):
        """Inputs drawn from the kernel of the initial weights give zero output."""
        rng = np.random.default_rng(0)

        # output_dim < input_dim so w0 has a non-trivial kernel
        input_dim, output_dim = 4, 2
        w0 = rng.normal(size=(output_dim, input_dim))
        circuit = NonRecurrent(weights=w0)

        for _ in range(10):
            sample = generate_random_from_kernel(w0, 1, rng).ravel()
            circuit.transform([sample])
            np.testing.assert_allclose(circuit.output_, 0, atol=1e-10)
class TestNonRecurrentTransform(unittest.TestCase):
    def setUp(self):
        """Scaled circuit plus an 85-sample input batch."""
        self.rng = np.random.default_rng(3)

        self.input_dim, self.output_dim = 4, 2

        w_shape = (self.output_dim, self.input_dim)
        self.w0 = self.rng.normal(size=w_shape)
        self.sqrt_m0 = self.rng.normal(size=(self.output_dim, self.output_dim))
        self.m0 = self.sqrt_m0 @ self.sqrt_m0.T

        self.alpha, self.tau = 0.0011, 0.34
        self.pc_scalings = self.rng.uniform(size=self.output_dim)

        self.kwargs = {
            "weights": self.w0,
            "lateral": self.m0,
            "rate": self.alpha,
            "tau": self.tau,
            "scalings": self.pc_scalings,
        }
        self.circuit = NonRecurrent(**self.kwargs)

        # input batch is drawn last so the weight draws above are unaffected
        self.n_samples = 85
        self.x = self.rng.normal(size=(self.n_samples, self.input_dim))

    def test_fit_infer_returns_same_as_monitor_output(self):
        """The array returned by transform equals the monitored output_ history."""
        monitor = AttributeMonitor(["output_"])
        returned = self.circuit.transform(self.x, monitor=monitor)
        np.testing.assert_allclose(returned, monitor.history_.output_)
    def setUp(self):
        """Unscaled circuit with explicit initial feedforward and lateral weights."""
        self.rng = np.random.default_rng(1)

        self.input_dim, self.output_dim = 5, 3

        self.w0 = self.rng.normal(size=(self.output_dim, self.input_dim))
        self.sqrt_m0 = self.rng.normal(size=(self.output_dim, self.output_dim))
        self.m0 = self.sqrt_m0 @ self.sqrt_m0.T

        self.alpha, self.tau = 0.001, 0.6

        self.circuit = NonRecurrent(
            weights=self.w0, lateral=self.m0, tau=self.tau, rate=self.alpha
        )
    def test_small_chunk_same_as_no_chunk(self):
        """Processing with a small chunk_hint yields the same final state."""
        default_run = NonRecurrent(**self.kwargs)
        default_run.transform(self.x)

        chunked_run = NonRecurrent(**self.kwargs)
        chunked_run.transform(self.x, chunk_hint=12)

        for attr in ("weights_", "lateral_", "output_"):
            np.testing.assert_allclose(getattr(default_run, attr),
                                       getattr(chunked_run, attr))
    def setUp(self):
        """Non-negative circuit with explicit initial weights (3 inputs, 2 outputs)."""
        self.rng = np.random.default_rng(2)

        self.input_dim, self.output_dim = 3, 2

        self.w0 = self.rng.normal(size=(self.output_dim, self.input_dim))
        self.sqrt_m0 = self.rng.normal(size=(self.output_dim, self.output_dim))
        self.m0 = self.sqrt_m0 @ self.sqrt_m0.T

        self.alpha, self.tau = 0.0008, 0.7

        self.circuit = NonRecurrent(
            weights=self.w0,
            lateral=self.m0,
            tau=self.tau,
            rate=self.alpha,
            non_negative=True,
        )
    def test_weights_just_decaying_if_input_is_in_ker_weights(self):
        """For inputs in ker(w0), W shrinks by a factor (1 - rate) each step."""
        rng = np.random.default_rng(1)

        # output_dim < input_dim so w0 has a non-trivial kernel
        input_dim, output_dim = 5, 4
        w0 = rng.normal(size=(output_dim, input_dim))
        circuit = NonRecurrent(weights=w0, tau=0.4, rate=0.0009)

        gamma_w = 1 - circuit.rate

        for step in range(1, 11):
            x = generate_random_from_kernel(w0, 1, rng).ravel()
            circuit.transform([x])
            np.testing.assert_allclose(circuit.weights_,
                                       (gamma_w**step) * w0)
    def test_constructor_leaves_lateral_unchanged_if_it_is_positive_definite(
            self):
        """A lateral matrix built as F @ F.T is stored by the constructor as-is."""
        rng = np.random.default_rng(0)
        output_dim = 3

        factor = rng.normal(size=(output_dim, output_dim))
        m0 = factor @ factor.T

        circuit = NonRecurrent(n_features=5, lateral=m0)
        np.testing.assert_allclose(circuit.lateral_, m0)
class TestNonRecurrentFitNonnegativeWithPCScalings(unittest.TestCase):
    def setUp(self):
        """Non-negative, scaled circuit with whitening explicitly turned off."""
        self.rng = np.random.default_rng(2)

        self.input_dim, self.output_dim = 4, 2

        # keep the draw order: w0, sqrt_m0, then scalings
        self.w0 = self.rng.normal(size=(self.output_dim, self.input_dim))
        self.sqrt_m0 = self.rng.normal(size=(self.output_dim, self.output_dim))
        self.m0 = self.sqrt_m0 @ self.sqrt_m0.T

        self.alpha, self.tau = 0.001, 0.5
        self.pc_scalings = self.rng.uniform(size=self.output_dim)

        self.kwargs = {
            "weights": self.w0,
            "lateral": self.m0,
            "rate": self.alpha,
            "tau": self.tau,
            "non_negative": True,
            "scalings": self.pc_scalings,
            "whiten": False,
        }

        self.circuit = NonRecurrent(**self.kwargs)

    def test_constraint_implemented_before_lateral_evolution(self):
        """The lateral update uses the already-constrained (non-negative) output."""
        lbd = np.diag(self.pc_scalings)
        step = self.alpha / self.tau

        for _ in range(10):
            x = self.rng.normal(size=self.input_dim)
            old_m = np.array(self.circuit.lateral_)

            self.circuit.transform([x])

            y = self.circuit.output_
            expected_m = old_m + step * (np.outer(y, y) - lbd @ old_m @ lbd)
            np.testing.assert_allclose(self.circuit.lateral_, expected_m)
    def test_schedule_used_in_sequence_for_multiple_calls_to_fit(self):
        """Splitting input over two transform calls continues the rate schedule."""
        schedule = self.rng.uniform(0, self.rate, size=self.n_samples)

        one_shot = NonRecurrent(rate=schedule, **self.kwargs)
        two_shot = NonRecurrent(rate=schedule, **self.kwargs)

        one_shot.transform(self.x)

        half = self.n_samples // 2
        two_shot.transform(self.x[:half])
        two_shot.transform(self.x[half:])

        np.testing.assert_allclose(one_shot.weights_, two_shot.weights_)
    def test_last_value_of_rate_is_used_if_more_samples_than_len_rate(self):
        """A schedule shorter than the input behaves as if padded with its last entry."""
        n = 3 * self.n_samples // 4
        schedule_short = self.rng.uniform(0, self.rate, size=n)
        # explicit padding: repeat the final entry to cover every sample
        tail = np.full(self.n_samples - n, schedule_short[-1])
        schedule = np.concatenate((schedule_short, tail))

        short_circuit = NonRecurrent(rate=schedule_short, **self.kwargs)
        padded_circuit = NonRecurrent(rate=schedule, **self.kwargs)

        short_circuit.transform(self.x)
        padded_circuit.transform(self.x)

        np.testing.assert_allclose(short_circuit.weights_,
                                   padded_circuit.weights_)
    def test_history_same_when_chunk_hint_changes(self):
        """Monitored histories do not depend on the chunk_hint used."""
        names = ["weights_", "lateral_", "output_"]
        monitor = AttributeMonitor(names)

        x = np.random.default_rng(1).normal(size=(100, self.n_features))

        self.circuit.transform(x, monitor=monitor, chunk_hint=13)

        alt_circuit = NonRecurrent(rng=self.seed, **self.kwargs)
        alt_monitor = AttributeMonitor(names)
        alt_circuit.transform(x, monitor=alt_monitor, chunk_hint=2)

        for name in names:
            np.testing.assert_allclose(getattr(monitor.history_, name),
                                       getattr(alt_monitor.history_, name))
class TestNonRecurrentFitNonnegative(unittest.TestCase):
    def setUp(self):
        """Non-negative circuit (3 inputs, 2 outputs) without scalings."""
        self.rng = np.random.default_rng(2)

        self.input_dim, self.output_dim = 3, 2

        w_shape = (self.output_dim, self.input_dim)
        m_shape = (self.output_dim, self.output_dim)
        self.w0 = self.rng.normal(size=w_shape)
        self.sqrt_m0 = self.rng.normal(size=m_shape)
        self.m0 = self.sqrt_m0 @ self.sqrt_m0.T

        self.alpha, self.tau = 0.0008, 0.7

        self.circuit = NonRecurrent(
            weights=self.w0,
            lateral=self.m0,
            tau=self.tau,
            rate=self.alpha,
            non_negative=True,
        )

    def test_outputs_stay_non_negative(self):
        """Every output produced by a non-negative circuit is >= 0."""
        for _ in range(10):
            sample = self.rng.normal(size=self.input_dim)
            self.circuit.transform([sample])
            self.assertGreaterEqual(np.min(self.circuit.output_), 0)

    def test_constraint_implemented_before_weights_evolution(self):
        """The feedforward update uses the already-constrained output."""
        for _ in range(10):
            x = self.rng.normal(size=self.input_dim)
            old_w = np.array(self.circuit.weights_)

            self.circuit.transform([x])

            y = self.circuit.output_
            expected_w = old_w + self.alpha * (np.outer(y, x) - old_w)
            np.testing.assert_allclose(self.circuit.weights_, expected_w)
    def test_lateral_just_decaying_if_input_is_in_ker_weights(self):
        """For inputs in ker(w0), M shrinks by (1 - rate / tau) each step."""
        # assumes the default is the PSP problem, not PSW (whiten == False)
        rng = np.random.default_rng(2)

        # output_dim < input_dim so w0 has a non-trivial kernel
        input_dim, output_dim = 3, 2
        w0 = rng.normal(size=(output_dim, input_dim))
        sqrt_m0 = rng.normal(size=(output_dim, output_dim))
        m0 = sqrt_m0 @ sqrt_m0.T
        circuit = NonRecurrent(weights=w0, lateral=m0, tau=0.7, rate=0.0015)

        gamma_m = 1 - circuit.rate / circuit.tau

        for step in range(1, 11):
            x = generate_random_from_kernel(w0, 1, rng).ravel()
            circuit.transform([x])
            np.testing.assert_allclose(circuit.lateral_,
                                       (gamma_m**step) * m0)
class TestNonRecurrentClone(unittest.TestCase):
    def setUp(self):
        """Scaled circuit (5 inputs, 3 outputs) used as the clone source."""
        self.rng = np.random.default_rng(4)

        self.input_dim, self.output_dim = 5, 3

        self.w0 = self.rng.normal(size=(self.output_dim, self.input_dim))
        self.sqrt_m0 = self.rng.normal(size=(self.output_dim, self.output_dim))
        self.m0 = self.sqrt_m0 @ self.sqrt_m0.T

        self.alpha, self.tau = 0.0012, 0.35
        self.pc_scalings = self.rng.uniform(size=self.output_dim)

        self.kwargs = {
            "weights": self.w0,
            "lateral": self.m0,
            "rate": self.alpha,
            "tau": self.tau,
            "scalings": self.pc_scalings,
        }
        self.circuit = NonRecurrent(**self.kwargs)

    def test_clone_copies_meta_parameters(self):
        """clone() reproduces every hyper-parameter of the original circuit."""
        original = self.circuit
        duplicate = original.clone()

        for attr in ("n_components", "rate", "tau"):
            self.assertEqual(getattr(duplicate, attr), getattr(original, attr))
        np.testing.assert_allclose(duplicate.scalings, original.scalings)
        for attr in ("non_negative", "whiten"):
            self.assertEqual(getattr(duplicate, attr), getattr(original, attr))

    def test_clone_copies_last_output(self):
        """clone() carries over the output_ produced by the latest transform."""
        batch = self.rng.normal(size=(1, self.input_dim))
        self.circuit.transform(batch)

        duplicate = self.circuit.clone()
        np.testing.assert_allclose(self.circuit.output_, duplicate.output_)
    def test_constructor_ensures_lateral_is_positive_definite(self):
        """An arbitrary lateral matrix passed to the constructor gets repaired.

        The stored ``lateral_`` must be square, symmetric, and have no negative
        eigenvalues.
        """
        rng = np.random.default_rng(0)
        output_dim = 3

        # a generic Gaussian matrix is, in general, not symmetric
        m0 = rng.normal(size=(output_dim, output_dim))
        circuit = NonRecurrent(n_features=5, lateral=m0)
        lateral = circuit.lateral_

        # square...
        self.assertEqual(lateral.shape[0], lateral.shape[1])

        # ...symmetric: this must be checked explicitly, because the
        # eigenvalue routine below only reads one triangle of its input and
        # would otherwise hide an asymmetric result...
        np.testing.assert_allclose(lateral, lateral.T, atol=1e-10)

        # ...and with a non-negative spectrum (note: >= 0 checks
        # semi-definiteness; a strict bound may be too tight numerically)
        evals = np.linalg.eigvalsh(lateral)
        self.assertGreaterEqual(np.min(evals), 0)
    def setUp(self):
        """Run full-length and half-length constant-rate reference circuits."""
        self.n_features, self.n_components, self.seed = 5, 3, 3
        self.kwargs = {"n_features": self.n_features,
                       "n_components": self.n_components,
                       "rng": self.seed}

        self.rng = np.random.default_rng(0)
        self.n_samples = 53
        self.x = self.rng.normal(size=(self.n_samples, self.n_features))

        self.rate = 0.005

        # reference run over every sample
        self.monitor_full = AttributeMonitor(["output_"])
        self.circuit_full = NonRecurrent(rate=self.rate, **self.kwargs)
        self.circuit_full.transform(self.x, monitor=self.monitor_full)

        # reference run over the first half of the samples only
        self.n_partial = self.n_samples // 2
        self.monitor_partial = AttributeMonitor(["output_"])
        self.circuit_partial = NonRecurrent(rate=self.rate, **self.kwargs)
        self.circuit_partial.transform(self.x[:self.n_partial],
                                       monitor=self.monitor_partial)
    def __init__(
        self,
        n_models: int,
        n_features: int,
        nsm_rate: Union[float, Sequence, Callable[[float], float]] = 1e-3,
        xcorr_rate: float = 0.05,
        rng: Union[int, np.random.RandomState, np.random.Generator] = 0,
        nsm_kws: Optional[dict] = None,
    ):
        """ Initialize the segmentation model.

        Parameters
        ----------
        n_models
            Number of models in the mixture. This is also the number of NSM
            output components.
        n_features
            Number of predictor variables (features).
        nsm_rate
            Learning rate or learning schedule for the non-negative similarity
            matching (NSM) circuit; see `bioslds.nsm.NonRecurrent`. Ignored if
            `nsm_kws` already contains a `"rate"` entry.
        xcorr_rate
            Learning rate for the cross-correlation calculator; see
            `bioslds.xcorr.OnlineCrosscorrelation`.
        rng
            Random number generator or seed for the initial NSM weights. It
            is forwarded to `bioslds.nsm.NonRecurrent` unless `nsm_kws`
            already contains an `"rng"` entry.
        nsm_kws
            Extra keyword arguments for `bioslds.nsm.NonRecurrent.__init__`.
            Its entries override the defaults above. The dict itself is never
            modified.
        """
        self.n_models = n_models
        self.n_features = n_features
        # one NSM output component per mixture model
        self.n_components = self.n_models

        self.xcorr = OnlineCrosscorrelation(self.n_features, rate=xcorr_rate)

        # caller-supplied entries win over these defaults; building a new dict
        # leaves the caller's `nsm_kws` untouched
        defaults = {"rng": rng, "non_negative": True, "rate": nsm_rate}
        user_kws = {} if nsm_kws is None else nsm_kws
        merged_kws = {**defaults, **user_kws}
        self.nsm = NonRecurrent(self.n_features, self.n_models, **merged_kws)

        super().__init__(["xcorr", "nsm"])
    def test_schedule_used_in_sequence_for_multiple_calls_to_fit(self):
        """A callable rate keeps counting samples across separate transform calls."""
        n_features, n_components, seed = 6, 4, 2
        kwargs = {"n_features": n_features,
                  "n_components": n_components,
                  "rng": seed}

        n_samples = 50
        x = np.random.default_rng(0).normal(size=(n_samples, n_features))

        def rate_fct(i):
            return 1 / (100 + 5 * i)

        one_shot = NonRecurrent(rate=rate_fct, **kwargs)
        two_shot = NonRecurrent(rate=rate_fct, **kwargs)

        one_shot.transform(x)

        half = n_samples // 2
        two_shot.transform(x[:half])
        two_shot.transform(x[half:])

        np.testing.assert_allclose(one_shot.weights_, two_shot.weights_)
 def test_whiten_is_false_by_default(self):
     """Whitening is off unless explicitly requested."""
     default_circuit = NonRecurrent(n_features=5, n_components=4)
     self.assertFalse(default_circuit.whiten)
class TestNonRecurrentVectorLearningRate(unittest.TestCase):
    def setUp(self):
        """Build full-run and half-run constant-rate references for schedule tests."""
        self.n_features, self.n_components, self.seed = 5, 3, 3
        self.kwargs = {
            "n_features": self.n_features,
            "n_components": self.n_components,
            "rng": self.seed,
        }

        self.rng = np.random.default_rng(0)
        self.n_samples = 53
        self.x = self.rng.normal(size=(self.n_samples, self.n_features))

        self.rate = 0.005

        # full-length reference
        self.monitor_full = AttributeMonitor(["output_"])
        self.circuit_full = NonRecurrent(rate=self.rate, **self.kwargs)
        self.circuit_full.transform(self.x, monitor=self.monitor_full)

        # reference trained on the first half only
        self.n_partial = self.n_samples // 2
        first_half = self.x[:self.n_partial]
        self.monitor_partial = AttributeMonitor(["output_"])
        self.circuit_partial = NonRecurrent(rate=self.rate, **self.kwargs)
        self.circuit_partial.transform(first_half,
                                       monitor=self.monitor_partial)

    def test_final_weights_different_in_partial_and_full_run(self):
        """Sanity check: the half-length run really ends in a different state."""
        for attr in ("weights_", "lateral_"):
            gap = np.abs(getattr(self.circuit_partial, attr) -
                         getattr(self.circuit_full, attr))
            self.assertGreater(np.max(gap), 1e-3)

    def test_switching_rate_to_zero_fixes_weights(self):
        """Setting the rate to zero after n_partial samples freezes the weights."""
        schedule = np.concatenate((
            np.full(self.n_partial, self.rate),
            np.zeros(self.n_samples - self.n_partial),
        ))

        circuit = NonRecurrent(rate=schedule, **self.kwargs)
        circuit.transform(self.x)

        np.testing.assert_allclose(circuit.weights_,
                                   self.circuit_partial.weights_)

    def test_output_history_same_if_rate_is_constant_then_switches(self):
        """Before the switch, outputs coincide with the constant-rate half run."""
        schedule = np.concatenate((
            np.full(self.n_partial, self.rate),
            np.zeros(self.n_samples - self.n_partial),
        ))
        circuit = NonRecurrent(rate=schedule, **self.kwargs)

        monitor = AttributeMonitor(["output_"])
        circuit.transform(self.x, monitor=monitor)

        before_switch = monitor.history_.output_[:self.n_partial]
        np.testing.assert_allclose(before_switch,
                                   self.monitor_partial.history_.output_)

    def test_constructor_copies_weight_schedule(self):
        """The constructor keeps its own copy of an array-valued rate schedule."""
        schedule = np.full(self.n_samples, self.rate)
        circuit = NonRecurrent(rate=schedule, **self.kwargs)

        # overwrite the caller's array after construction
        schedule.fill(0)
        circuit.transform(self.x)

        np.testing.assert_allclose(circuit.weights_,
                                   self.circuit_full.weights_)
 def test_n_samples_starts_at_0(self):
     """A freshly constructed circuit reports zero processed samples."""
     fresh = NonRecurrent(n_features=3, n_components=4)
     self.assertEqual(0, fresh.n_samples_)