Пример #1
0
def test_the_network_can_get_a_projn_by_name() -> None:
    """``_get_projn`` returns the same object stored in ``projns``."""
    network = net.Net()
    for layer_name in ("lr1", "lr2"):
        network.new_layer(layer_name, 1)
    network.new_projn("pr1", "lr1", "lr2")

    fetched = network._get_projn("pr1")
    assert fetched is network.projns["pr1"]
Пример #2
0
def test_a_new_projn_validates_its_spec() -> None:
    """Creating a projection whose spec has ``integ=-1`` raises ValidationError."""
    network = net.Net()
    for layer_name in ("layer1", "layer2"):
        network.new_layer(layer_name, 3)
    with pytest.raises(specs.ValidationError):
        bad_spec = specs.ProjnSpec(integ=-1)
        network.new_projn("projn1", "layer1", "layer2", spec=bad_spec)
Пример #3
0
def test_the_network_can_get_a_oscill_by_name() -> None:
    """``_get_oscill`` returns the same object stored in ``oscills``."""
    network = net.Net()
    layer_names = ["lr1", "lr2"]
    for layer_name in layer_names:
        network.new_layer(layer_name, 1)
    network.new_oscill("theta", ["lr1", "lr2"])

    fetched = network._get_oscill("theta")
    assert fetched is network.oscills["theta"]
Пример #4
0
def test_net_can_pause_and_resume_logging() -> None:
    """Cycles run while logging is paused leave no entries in the cycle log.

    Two logged cycles (times 0, 1) are followed by two paused cycles
    (times 2, 3) and two more logged cycles (times 4, 5). The expected
    values show one "parts" row per unit per logged cycle (two units) and
    one "whole" row per logged cycle.
    """
    network = net.Net()
    layer_spec = specs.LayerSpec(log_on_cycle=("unit_act", "avg_act"))
    network.new_layer("layer1", 2, spec=layer_spec)

    def run_cycles(count: int) -> None:
        # Advance the network by ``count`` cycles.
        for _ in range(count):
            network.cycle()

    run_cycles(2)
    network.pause_logging()
    run_cycles(2)
    network.resume_logging()
    run_cycles(2)

    parts_times = torch.Tensor(network.logs("cycle", "layer1").parts["time"])
    whole_times = torch.Tensor(network.logs("cycle", "layer1").whole["time"])
    assert list(parts_times.size()) == [8]
    assert list(whole_times.size()) == [4]

    expected_parts = torch.Tensor([0, 0, 1, 1, 4, 4, 5, 5])
    expected_whole = torch.Tensor([0, 1, 4, 5])
    assert (parts_times == expected_parts).all()
    assert (whole_times == expected_whole).all()
Пример #5
0
def test_net_trial_log_pausing_and_resuming() -> None:
    """Trials ended while trial logging is paused are omitted from the log.

    Three trials are run; the middle one (time 1) happens while trial
    logging is paused, so only trials 0 and 2 are expected in the log.
    """
    network = net.Net()
    layer_spec = specs.LayerSpec(log_on_trial=("unit_act", "avg_act"))
    network.new_layer("layer1", 2, spec=layer_spec)

    def run_trial() -> None:
        # A five-cycle minus phase, a five-cycle plus phase, then end the trial.
        network.phase_cycle(phase=events.MinusPhase, num_cycles=5)
        network.phase_cycle(phase=events.PlusPhase, num_cycles=5)
        network.end_trial()

    run_trial()
    network.pause_logging("trial")
    run_trial()
    network.resume_logging("trial")
    run_trial()

    parts_times = torch.Tensor(network.logs("trial", "layer1").parts["time"])
    whole_times = torch.Tensor(network.logs("trial", "layer1").whole["time"])

    assert list(parts_times.size()) == [4]
    assert list(whole_times.size()) == [2]

    assert (parts_times == torch.Tensor([0, 0, 2, 2])).all()
    assert (whole_times == torch.Tensor([0, 2])).all()
Пример #6
0
def test_kwta_suppresses_all_but_k_units() -> None:
    """kWTA inhibition with ``k=2`` leaves exactly two of three units active.

    ``lr1`` projects to ``lr2`` twice: ``proj1`` with ``rn.Scalar(0.3)``
    weights, and ``proj2`` with ``rn.Scalar(0.5)`` weights and
    ``post_mask=[0, 1, 1]`` (presumably restricting it to units 1 and 2).
    After 100 cycles with ``lr1`` forced to 1, exactly two ``lr2`` units
    should have activation above 0.8.
    """
    n = net.Net()

    n.new_layer(name="lr1", size=1)

    lr2_spec = sp.LayerSpec(
        inhibition_type="kwta",
        k=2,
        log_on_cycle=("unit_act", ),
        unit_spec=sp.UnitSpec(adapt_dt=0, spike_gain=0))
    n.new_layer(name="lr2", size=3, spec=lr2_spec)

    pr1_spec = sp.ProjnSpec(dist=rn.Scalar(0.3))
    n.new_projn("proj1", "lr1", "lr2", pr1_spec)

    pr2_spec = sp.ProjnSpec(dist=rn.Scalar(0.5), post_mask=[0, 1, 1])
    n.new_projn("proj2", "lr1", "lr2", pr2_spec)

    n.force_layer("lr1", [1])
    for _ in range(100):
        n.cycle()

    # n.logs() returns an object with ``parts`` and ``whole`` frames (as used
    # elsewhere in this file); the per-unit "unit_act" columns ("time",
    # "act") are presumably in ``parts``, not on the container itself.
    parts = n.logs("cycle", "lr2").parts
    acts = parts[parts.time == 99]["act"]
    assert (acts > 0.8).sum() == 2
Пример #7
0
def test_you_can_unclamp_layers() -> None:
    """Unclamped layers become free to be driven by their projections.

    First ``layer2`` is clamped and then unclamped while ``layer1`` stays
    clamped at 0.7, so ``layer2`` should become active. Then both layers are
    unclamped, and both should stay active.
    """
    network = net.Net()
    network.new_layer("layer1", 1)
    network.new_layer("layer2", 2)
    network.new_projn("projn1", pre="layer1", post="layer2")
    network.new_projn("projn2", pre="layer2", post="layer1")

    def run(cycles: int = 50) -> None:
        # Advance the network by ``cycles`` cycles.
        for _ in range(cycles):
            network.cycle()

    network.clamp_layer("layer2", [0])
    network.clamp_layer("layer1", [0.7])
    network.unclamp_layer("layer2")

    # Drive layer 2 so it should spike
    run()
    assert (network.layers["layer2"].units.act > 0).all()

    network.clamp_layer("layer1", [0])
    network.unclamp_layer("layer1", "layer2")
    run()

    for layer_name in ("layer1", "layer2"):
        assert (network.layers[layer_name].units.act > 0).all()
Пример #8
0
def test_running_a_minus_phase_broadcasts_minus_phase_event_markers(
        mocker) -> None:
    """A minus phase starts with BeginMinusPhase and ends with EndMinusPhase."""
    network = net.Net()
    network.new_layer("layer1", 1)
    mocker.spy(network, "handle")
    network.minus_phase_cycle(num_cycles=1)

    recorded = network.handle.call_args_list
    first_event = recorded[0][0][0]
    last_event = recorded[-1][0][0]
    assert isinstance(first_event, events.BeginMinusPhase)
    assert isinstance(last_event, events.EndMinusPhase)
Пример #9
0
def test_a_new_oscill_validates_its_spec() -> None:
    """Creating an oscillator whose spec has ``mid=-1`` raises ValidationError."""
    network = net.Net()
    for layer_name in ("layer1", "layer2"):
        network.new_layer(layer_name, 3)

    with pytest.raises(specs.ValidationError):
        bad_spec = specs.OscillSpec(mid=-1)
        network.new_oscill("theta", ["layer1", "layer2"], spec=bad_spec)
Пример #10
0
def test_net_validates_rem_oscill() -> None:
    """Removing an oscillator that was already removed raises ValueError."""
    network = net.Net()
    for layer_name in ("layer1", "layer2"):
        network.new_layer(layer_name, 3)
    network.new_oscill("theta", ["layer1"])

    network.rem_oscill("theta")  # first removal succeeds
    with pytest.raises(ValueError):
        network.rem_oscill("theta")  # second removal: name no longer exists
Пример #11
0
def test_you_can_hard_clamp_a_layer() -> None:
    """Clamping a 4-unit layer to ``[0, 1]`` yields acts ``[0, .95, 0, .95]``.

    The two-value pattern is expected to repeat across the four units after a
    single cycle.
    """
    network = net.Net()
    network.new_layer("layer1", 4)
    network.clamp_layer("layer1", [0, 1])
    network.cycle()

    acts = network.objs["layer1"].units.act
    for idx, want in enumerate([0, 0.95, 0, 0.95]):
        assert math.isclose(acts[idx], want, abs_tol=1e-6)
Пример #12
0
def test_getting_an_invalid_layer_name_raises_value_error() -> None:
    """``_get_layer`` rejects unknown names and names of non-layer objects."""
    network = net.Net()
    for layer_name in ("layer1", "layer2"):
        network.new_layer(layer_name, 3)
    network.new_projn("proj1", "layer1", "layer2")

    # "whales" does not exist; "proj1" exists but is a projection.
    for bad_name in ("whales", "proj1"):
        with pytest.raises(ValueError):
            network._get_layer(bad_name)
Пример #13
0
def test_running_a_plus_phase_runs_the_correct_number_of_cycles(
        mocker) -> None:
    """A 42-cycle plus phase dispatches 42 ``Cycle`` events through ``handle``.

    Index 0 of the recorded calls is presumably the phase-begin marker, so
    the cycle events are expected at indices 1 through 42.
    """
    n = net.Net()
    n.new_layer("layer1", 1)
    mocker.spy(n, "handle")
    n.plus_phase_cycle(num_cycles=42)

    # Each recorded call's ``call[0]`` is its positional-args tuple, and
    # ``call[0][0]`` is the event passed to handle(). Every call in the slice
    # must be inspected, not just the first one.
    cycle_calls = n.handle.call_args_list[1:43]
    assert len(cycle_calls) == 42
    assert all(isinstance(call[0][0], events.Cycle) for call in cycle_calls)
Пример #14
0
def test_the_network_can_check_if_an_object_is_a_layer() -> None:
    """``_validate_layer_name`` rejects unknown names and non-layer objects."""
    network = net.Net()
    for layer_name in ("layer1", "layer2"):
        network.new_layer(layer_name, 3)
    network.new_projn("proj1", "layer1", "layer2")

    invalid_names = ("whales", "proj1")
    for candidate in invalid_names:
        with pytest.raises(ValueError):
            network._validate_layer_name(candidate)
Пример #15
0
def test_net_oscill_phase_cycle() -> None:
    """Smoke test: phases run with oscillators attached and after removing them.

    There are no assertions; the test passes as long as no call raises.
    """
    network = net.Net()
    for layer_name in ("layer1", "layer2"):
        network.new_layer(layer_name, 3)
    network.new_projn("projn1", "layer1", "layer2")

    network.new_oscill("theta1", ["layer1"])
    network.new_oscill("theta2", ["layer2"], spec=specs.OscillSpec(mid=0.7))
    network.phase_cycle(events.PlusPhase, 10)

    network.rem_oscill("theta1", "theta2")
    network.phase_cycle(events.PlusPhase, 10)
Пример #16
0
def test_projn_checks_if_the_receiving_layer_names_are_valid() -> None:
    """``new_oscill`` rejects layer lists containing an unknown layer name.

    Despite the test's name, this exercises oscillator creation: only
    ``layer1`` exists, so any list mentioning ``layer2`` must be rejected,
    whatever its position.
    """
    network = net.Net()
    network.new_layer("layer1", 3)

    bad_layer_lists = (["layer2"], ["layer1", "layer2"], ["layer2", "layer1"])
    for layer_list in bad_layer_lists:
        with pytest.raises(ValueError):
            network.new_oscill("theta", layer_list)
Пример #17
0
def test_the_network_can_remove_oscill_by_name() -> None:
    """Removing one oscillator invalidates its name but leaves the other."""
    network = net.Net()
    for oscill_name in ("theta1", "theta2"):
        network.new_oscill(oscill_name, [])
    for oscill_name in ("theta1", "theta2"):
        network._validate_oscill_name(oscill_name)

    network.rem_oscill("theta1")

    with pytest.raises(ValueError):
        network._validate_oscill_name("theta1")
    network._validate_oscill_name("theta2")  # still present
Пример #18
0
def test_net_can_uninhibit_projns() -> None:
    """``uninhibit_projns`` accepts several valid projection names at once.

    There are no assertions; the test passes as long as no call raises.
    """
    network = net.Net()
    network.new_layer("lr1", size=1)
    network.new_layer("lr2", size=1)

    projn_names = ("pr1", "pr2", "pr3")
    for projn_name in projn_names:
        network.new_projn(projn_name, "lr1", "lr2")

    network.uninhibit_projns(*projn_names)
Пример #19
0
def test_the_network_can_validate_projn_names() -> None:
    """``_validate_projn_name`` accepts existing names and rejects others."""
    network = net.Net()
    for layer_name in ("lr1", "lr2"):
        network.new_layer(layer_name, 1)
    network.new_projn("pr1", "lr1", "lr2")

    network._validate_projn_name("pr1")  # must not raise

    with pytest.raises(ValueError):
        network._validate_projn_name("whales")
Пример #20
0
def test_the_network_can_validate_oscill_names() -> None:
    """``_validate_oscill_name`` accepts existing names and rejects others."""
    network = net.Net()
    for layer_name in ("lr1", "lr2"):
        network.new_layer(layer_name, 1)
    network.new_oscill("theta", ["lr1", "lr2"])

    network._validate_oscill_name("theta")  # must not raise

    with pytest.raises(ValueError):
        network._validate_oscill_name("whales")
Пример #21
0
def test_you_can_observe_unlogged_attributes() -> None:
    """``observe`` returns a one-row frame for an attribute that is not logged.

    A freshly created projection is expected to report ``cos_diff_avg`` as
    0.0. ``check_like=True`` ignores column/index ordering in the comparison.
    """
    n = net.Net()
    n.new_layer("layer1", 1)
    n.new_layer("layer2", 1)
    n.new_projn("projn1", "layer1", "layer2")
    # ``pd.util.testing`` was deprecated in pandas 1.0 and removed in 2.0;
    # ``pd.testing`` is the public, stable location of this helper.
    pd.testing.assert_frame_equal(
        n.observe("projn1", "cos_diff_avg"),
        pd.DataFrame({
            "cos_diff_avg": (0.0, )
        }),
        check_like=True)
Пример #22
0
def test_network_can_be_retrieved_and_continue_logging() -> None:
    """A saved-then-loaded network keeps its logs and its paused-logging state.

    Two cycles are logged, logging is paused, and the network is saved and
    reloaded into a fresh ``Net``. The reloaded logs must match the originals.
    Two more cycles run while still paused (times 2, 3), then logging resumes
    for two cycles (times 4, 5), so only times 0, 1, 4 and 5 appear.
    """
    # Local imports keep this test self-contained without touching the
    # module's import block.
    import os
    import tempfile

    n = net.Net()
    n.new_layer("layer1",
                2,
                spec=specs.LayerSpec(log_on_cycle=(
                    "unit_act",
                    "avg_act",
                )))
    for _ in range(2):
        n.cycle()
    n.pause_logging()

    # Save into a throwaway directory instead of a cwd-relative "tests/"
    # path. The old path broke when pytest ran from another directory and
    # left the pickle behind after the run.
    with tempfile.TemporaryDirectory() as tmp_dir:
        location = os.path.join(tmp_dir, "mynet.pkl")
        n.save(location)
        m = net.Net()
        m.load(filename=location)

    before_parts_n = n.logs("cycle", "layer1").parts
    before_parts_m = m.logs("cycle", "layer1").parts

    before_whole_n = n.logs("cycle", "layer1").whole
    before_whole_m = m.logs("cycle", "layer1").whole

    assert np.all((before_parts_n == before_parts_m).values)
    assert np.all((before_whole_n == before_whole_m).values)

    for _ in range(2):
        m.cycle()
    m.resume_logging()
    for _ in range(2):
        m.cycle()

    after_parts_time = torch.Tensor(m.logs("cycle", "layer1").parts["time"])
    after_whole_time = torch.Tensor(m.logs("cycle", "layer1").whole["time"])

    assert list(after_parts_time.size()) == [8]
    assert list(after_whole_time.size()) == [4]

    assert (after_parts_time == torch.Tensor([0, 0, 1, 1, 4, 4, 5, 5])).all()
    assert (after_whole_time == torch.Tensor([0, 1, 4, 5])).all()
Пример #23
0
def test_network_passes_non_cycle_events_to_every_object(mocker) -> None:
    """A non-cycle event handled by the net reaches each object exactly once."""
    network = net.Net()
    for layer_name in ("layer1", "layer2"):
        network.new_layer(layer_name, 3)
    network.new_projn("projn1", "layer1", "layer2")

    # Only the objects matter here, so iterate the dict's values directly.
    for member in network.objs.values():
        mocker.spy(member, "handle")

    network.handle(events.BeginPhase(events.PlusPhase))

    for member in network.objs.values():
        assert member.handle.call_count == 1
Пример #24
0
def test_you_can_retrieve_the_logs_for_a_layer() -> None:
    """Logs at every frequency contain the ``avg_act`` column of the layer."""
    n = net.Net()
    n.new_layer(name="layer1",
                size=3,
                spec=specs.LayerSpec(log_on_cycle=("avg_act", ),
                                     log_on_trial=("avg_act", ),
                                     log_on_epoch=("avg_act", ),
                                     log_on_batch=("avg_act", )))
    n.plus_phase_cycle(1)
    # Without an explicit end_trial() this test never emits a trial event,
    # so the "trial" assertion below would not exercise trial logging.
    n.end_trial()
    n.end_epoch()
    n.end_batch()
    for freq in ("cycle", "trial", "epoch", "batch"):
        assert "avg_act" in n.logs(freq, "layer1").whole.columns
Пример #25
0
def test_net_catches_uninhibit_bad_projn_name() -> None:
    """``uninhibit_projns`` raises if any supplied projection name is unknown."""
    network = net.Net()
    network.new_layer("lr1", size=1)
    network.new_layer("lr2", size=1)

    for projn_name in ("pr1", "pr2", "pr3"):
        network.new_projn(projn_name, "lr1", "lr2")

    # A lone unknown name, and an unknown name among valid ones.
    for bad_args in (("pr4", ), ("pr1", "pr5", "pr3")):
        with pytest.raises(ValueError):
            network.uninhibit_projns(*bad_args)
Пример #26
0
def test_running_a_phase_runs_the_correct_number_of_cycles(mocker) -> None:
    """Each real phase dispatches exactly 42 ``Cycle`` events for 42 cycles.

    ``NonePhase`` is expected to be rejected with ``ValueError``.
    """
    for phase in events.Phase.phases():
        n = net.Net()
        mocker.spy(n, "handle")

        if phase == events.NonePhase:
            with pytest.raises(ValueError):
                n.phase_cycle(phase=phase, num_cycles=42)

        else:
            n.phase_cycle(phase=phase, num_cycles=42)

            # Index 0 is presumably the phase-begin marker. ``call[0][0]`` is
            # the event passed to handle(); every call in the slice is
            # checked, not just the first one.
            cycle_calls = n.handle.call_args_list[1:43]
            assert len(cycle_calls) == 42
            assert all(
                isinstance(call[0][0], events.Cycle) for call in cycle_calls)
Пример #27
0
def test_unclamping_layers_validates_its_name() -> None:
    """``unclamp_layer`` raises if any of the given layer names is unknown."""
    network = net.Net()

    # No layers exist yet, so even a single name is unknown.
    with pytest.raises(ValueError):
        network.unclamp_layer("abcd")

    network.new_layer("abcd", size=2)

    # An unknown name anywhere in the argument list must be rejected.
    for names in (("abcd", "g"), ("g", "abcd"), ("abcd", "g", "abcd")):
        with pytest.raises(ValueError):
            network.unclamp_layer(*names)
Пример #28
0
def test_running_a_phase_broadcasts_phase_event_markers(mocker) -> None:
    """Each real phase emits begin marker, one Cycle, then end marker.

    ``NonePhase`` is expected to be rejected with ``ValueError``; every other
    phase returned by ``Phase.phases()`` is checked.
    """
    for phase in events.Phase.phases():
        n = net.Net()
        mocker.spy(n, "handle")

        if phase == events.NonePhase:
            with pytest.raises(ValueError):
                n.phase_cycle(phase=phase, num_cycles=1)
            # Skip to the next phase. A ``return`` here would end the test
            # and silently skip every phase listed after NonePhase.
            continue

        n.phase_cycle(phase=phase, num_cycles=1)

        assert n.handle.call_args_list[0][0][0] == phase.begin_event
        assert isinstance(n.handle.call_args_list[1][0][0], events.Cycle)
        assert n.handle.call_args_list[2][0][0] == phase.end_event
Пример #29
0
def test_net_batch_log_pausing_and_resuming() -> None:
    """Batches ended while batch logging is paused are omitted from the log.

    Three batches are ended; the middle one (time 1) happens while batch
    logging is paused, so only batches 0 and 2 are expected in the log.
    """
    network = net.Net()
    batch_spec = specs.LayerSpec(log_on_batch=("unit_act", "avg_act"))
    network.new_layer("layer1", 2, spec=batch_spec)

    network.end_batch()  # batch 0: logged
    network.pause_logging("batch")
    network.end_batch()  # batch 1: logging paused
    network.resume_logging("batch")
    network.end_batch()  # batch 2: logged

    parts_times = torch.Tensor(network.logs("batch", "layer1").parts["time"])
    whole_times = torch.Tensor(network.logs("batch", "layer1").whole["time"])

    assert list(parts_times.size()) == [4]
    assert list(whole_times.size()) == [2]

    assert (parts_times == torch.Tensor([0, 0, 2, 2])).all()
    assert (whole_times == torch.Tensor([0, 2])).all()
Пример #30
0
def test_you_can_signal_the_end_of_a_batch(mocker) -> None:
    """``end_batch`` dispatches an ``EndBatch`` event through ``handle``."""
    network = net.Net()
    mocker.spy(network, "handle")
    network.end_batch()

    first_event = network.handle.call_args_list[0][0][0]
    assert isinstance(first_event, events.EndBatch)