def test_the_network_can_get_a_projn_by_name() -> None:
    """_get_projn returns the very object stored in ``projns`` under that name."""
    network = net.Net()
    for layer_name in ("lr1", "lr2"):
        network.new_layer(layer_name, 1)
    network.new_projn("pr1", "lr1", "lr2")
    fetched = network._get_projn("pr1")
    assert fetched is network.projns["pr1"]
def test_a_new_projn_validates_its_spec() -> None:
    """Creating a projection with an invalid spec raises ValidationError."""
    network = net.Net()
    for layer_name in ("layer1", "layer2"):
        network.new_layer(layer_name, 3)
    # A negative ``integ`` value must be rejected.
    with pytest.raises(specs.ValidationError):
        network.new_projn("projn1",
                          "layer1",
                          "layer2",
                          spec=specs.ProjnSpec(integ=-1))
def test_the_network_can_get_a_oscill_by_name() -> None:
    """_get_oscill returns the very object stored in ``oscills`` under that name."""
    network = net.Net()
    member_layers = ["lr1", "lr2"]
    for layer_name in member_layers:
        network.new_layer(layer_name, 1)
    network.new_oscill("theta", member_layers)
    fetched = network._get_oscill("theta")
    assert fetched is network.oscills["theta"]
def test_net_can_pause_and_resume_logging() -> None:
    """Cycles run while logging is paused leave no rows in the cycle logs.

    The sequence is two logged cycles, two paused cycles, then two more logged
    cycles. Time steps 2 and 3 must be absent from the "parts" log
    (2 units x 4 logged cycles = 8 rows) and from the "whole" log (4 rows).
    """
    network = net.Net()
    network.new_layer(
        "layer1",
        2,
        spec=specs.LayerSpec(log_on_cycle=("unit_act", "avg_act")))

    def run_two_cycles() -> None:
        for _ in range(2):
            network.cycle()

    run_two_cycles()
    network.pause_logging()
    run_two_cycles()
    network.resume_logging()
    run_two_cycles()

    parts_time = torch.Tensor(network.logs("cycle", "layer1").parts["time"])
    whole_time = torch.Tensor(network.logs("cycle", "layer1").whole["time"])

    assert tuple(parts_time.size()) == (8, )
    assert tuple(whole_time.size()) == (4, )
    expected_parts = torch.Tensor([0, 0, 1, 1, 4, 4, 5, 5])
    expected_whole = torch.Tensor([0, 1, 4, 5])
    assert (parts_time == expected_parts).all()
    assert (whole_time == expected_whole).all()
def test_net_trial_log_pausing_and_resuming() -> None:
    """Trials completed while trial logging is paused are not recorded.

    Three trials are run, each consisting of a 5-cycle minus phase, a 5-cycle
    plus phase and end_trial(). Trial logging is paused for the second trial,
    so only trial times 0 and 2 appear in the logs.
    """
    network = net.Net()
    network.new_layer(
        "layer1",
        2,
        spec=specs.LayerSpec(log_on_trial=("unit_act", "avg_act")))

    def run_trial() -> None:
        network.phase_cycle(phase=events.MinusPhase, num_cycles=5)
        network.phase_cycle(phase=events.PlusPhase, num_cycles=5)
        network.end_trial()

    run_trial()
    network.pause_logging("trial")
    run_trial()
    network.resume_logging("trial")
    run_trial()

    parts_time = torch.Tensor(network.logs("trial", "layer1").parts["time"])
    whole_time = torch.Tensor(network.logs("trial", "layer1").whole["time"])

    assert tuple(parts_time.size()) == (4, )
    assert tuple(whole_time.size()) == (2, )
    assert (parts_time == torch.Tensor([0, 0, 2, 2])).all()
    assert (whole_time == torch.Tensor([0, 2])).all()
def test_kwta_suppresses_all_but_k_units() -> None:
    """kWTA inhibition with k=2 leaves exactly two of three units active.

    lr1 (one unit, forced to 1) drives lr2 (three units) through two
    projections. proj2 carries a post_mask of [0, 1, 1], so the three lr2
    units do not all receive the same input. After 100 cycles exactly two lr2
    units should have an activation above 0.8.
    """
    n = net.Net()
    n.new_layer(name="lr1", size=1)
    lr2_spec = sp.LayerSpec(
        inhibition_type="kwta",
        k=2,
        log_on_cycle=("unit_act", ),
        unit_spec=sp.UnitSpec(adapt_dt=0, spike_gain=0))
    n.new_layer(name="lr2", size=3, spec=lr2_spec)

    pr1_spec = sp.ProjnSpec(dist=rn.Scalar(0.3))
    n.new_projn("proj1", "lr1", "lr2", pr1_spec)
    pr2_spec = sp.ProjnSpec(dist=rn.Scalar(0.5), post_mask=[0, 1, 1])
    n.new_projn("proj2", "lr1", "lr2", pr2_spec)

    n.force_layer("lr1", [1])
    for _ in range(100):
        n.cycle()

    # n.logs() returns an object exposing .parts / .whole frames; the
    # per-unit activations logged via "unit_act" live in .parts. The object
    # itself cannot be boolean-indexed like a DataFrame.
    parts = n.logs("cycle", "lr2").parts
    acts = parts[parts.time == 99]["act"]
    assert (acts > 0.8).sum() == 2
def test_you_can_unclamp_layers() -> None:
    """Unclamped layers resume integrating input from their projections."""
    network = net.Net()
    network.new_layer("layer1", 1)
    network.new_layer("layer2", 2)
    network.new_projn("projn1", pre="layer1", post="layer2")
    network.new_projn("projn2", pre="layer2", post="layer1")

    def settle() -> None:
        for _ in range(50):
            network.cycle()

    # Clamp layer2 silent and layer1 to a constant drive, then release layer2
    # so it is driven by layer1 through projn1 and should spike.
    network.clamp_layer("layer2", [0])
    network.clamp_layer("layer1", [0.7])
    network.unclamp_layer("layer2")
    settle()
    assert (network.layers["layer2"].units.act > 0).all()

    # Release both layers. layer1 now depends on input from layer2 (via
    # projn2), and both are expected to stay active.
    network.clamp_layer("layer1", [0])
    network.unclamp_layer("layer1", "layer2")
    settle()
    assert (network.layers["layer1"].units.act > 0).all()
    assert (network.layers["layer2"].units.act > 0).all()
def test_running_a_minus_phase_broadcasts_minus_phase_event_markers(
        mocker) -> None:
    """A minus phase is bracketed by BeginMinusPhase and EndMinusPhase events."""
    network = net.Net()
    network.new_layer("layer1", 1)
    mocker.spy(network, "handle")
    network.minus_phase_cycle(num_cycles=1)

    recorded = network.handle.call_args_list
    first_event = recorded[0][0][0]
    last_event = recorded[-1][0][0]
    assert isinstance(first_event, events.BeginMinusPhase)
    assert isinstance(last_event, events.EndMinusPhase)
def test_a_new_oscill_validates_its_spec() -> None:
    """Creating an oscillator with an invalid spec raises ValidationError."""
    network = net.Net()
    member_layers = ["layer1", "layer2"]
    for layer_name in member_layers:
        network.new_layer(layer_name, 3)
    # A negative ``mid`` value must be rejected.
    with pytest.raises(specs.ValidationError):
        network.new_oscill("theta",
                           member_layers,
                           spec=specs.OscillSpec(mid=-1))
def test_net_validates_rem_oscill() -> None:
    """Removing an oscillator that was already removed raises ValueError."""
    network = net.Net()
    for layer_name in ("layer1", "layer2"):
        network.new_layer(layer_name, 3)
    network.new_oscill("theta", ["layer1"])

    network.rem_oscill("theta")  # first removal succeeds
    with pytest.raises(ValueError):
        network.rem_oscill("theta")  # second one must fail
def test_you_can_hard_clamp_a_layer() -> None:
    """Clamping repeats the pattern across the layer, visible after one cycle.

    The 2-value pattern [0, 1] is clamped onto 4 units. Each clamped 1 is
    expected to show up as an activation of 0.95.
    """
    network = net.Net()
    network.new_layer("layer1", 4)
    network.clamp_layer("layer1", [0, 1])
    network.cycle()

    wanted = (0, 0.95, 0, 0.95)
    for idx, target in enumerate(wanted):
        actual = network.objs["layer1"].units.act[idx]
        assert math.isclose(actual, target, abs_tol=1e-6)
def test_getting_an_invalid_layer_name_raises_value_error() -> None:
    """_get_layer rejects unknown names and names of non-layer objects."""
    network = net.Net()
    network.new_layer("layer1", 3)
    network.new_layer("layer2", 3)
    network.new_projn("proj1", "layer1", "layer2")
    # "whales" names nothing; "proj1" names a projection, not a layer.
    for bad_name in ("whales", "proj1"):
        with pytest.raises(ValueError):
            network._get_layer(bad_name)
def test_running_a_plus_phase_runs_the_correct_number_of_cycles(
        mocker) -> None:
    """A 42-cycle plus phase emits a Cycle event for each of its 42 cycles.

    The spied ``handle`` calls are expected to be the phase-begin event at
    index 0, followed by one Cycle per cycle at indices 1..42.
    """
    n = net.Net()
    n.new_layer("layer1", 1)
    mocker.spy(n, "handle")
    n.plus_phase_cycle(num_cycles=42)

    # Inspect the first positional argument of *every* call in the window.
    # Writing ``call_args_list[1:43][0][0]`` would pick the args tuple of a
    # single call and check only that one event.
    cycle_calls = n.handle.call_args_list[1:43]
    assert len(cycle_calls) == 42
    assert all(isinstance(call[0][0], events.Cycle) for call in cycle_calls)
def test_the_network_can_check_if_an_object_is_a_layer() -> None:
    """_validate_layer_name rejects unknown names and non-layer objects."""
    network = net.Net()
    network.new_layer("layer1", 3)
    network.new_layer("layer2", 3)
    network.new_projn("proj1", "layer1", "layer2")
    # Neither an unknown name nor a projection name is a valid layer name.
    for bad_name in ("whales", "proj1"):
        with pytest.raises(ValueError):
            network._validate_layer_name(bad_name)
def test_net_oscill_phase_cycle() -> None:
    """Phases run with oscillators attached and after removing them in bulk.

    Besides checking that neither phase_cycle call raises, this verifies
    that rem_oscill with several names removes every one of them.
    """
    n = net.Net()
    n.new_layer("layer1", 3)
    n.new_layer("layer2", 3)
    n.new_projn("projn1", "layer1", "layer2")
    n.new_oscill("theta1", ["layer1"])
    n.new_oscill("theta2", ["layer2"], spec=specs.OscillSpec(mid=0.7))
    n.phase_cycle(events.PlusPhase, 10)

    n.rem_oscill("theta1", "theta2")
    # Both oscillators must be gone after a multi-name removal.
    for removed in ("theta1", "theta2"):
        with pytest.raises(ValueError):
            n._validate_oscill_name(removed)

    n.phase_cycle(events.PlusPhase, 10)
def test_projn_checks_if_the_receiving_layer_names_are_valid() -> None:
    """new_oscill raises ValueError if any of its layer names is unknown.

    NOTE(review): despite its name, this test exercises ``new_oscill``, not
    projections. Consider renaming it, and make sure it does not share a name
    with a projection test defined elsewhere in this module (a later
    definition would silently shadow an earlier one).
    """
    network = net.Net()
    network.new_layer("layer1", 3)
    # "layer2" was never created, so any list containing it is invalid.
    for layer_names in (["layer2"], ["layer1", "layer2"], ["layer2",
                                                           "layer1"]):
        with pytest.raises(ValueError):
            network.new_oscill("theta", layer_names)
def test_the_network_can_remove_oscill_by_name() -> None:
    """rem_oscill removes only the oscillator it names."""
    network = net.Net()
    both = ("theta1", "theta2")
    for oscill_name in both:
        network.new_oscill(oscill_name, [])
    for oscill_name in both:
        network._validate_oscill_name(oscill_name)

    network.rem_oscill("theta1")
    with pytest.raises(ValueError):
        network._validate_oscill_name("theta1")
    # The other oscillator is untouched.
    network._validate_oscill_name("theta2")
def test_net_can_uninhibit_projns() -> None:
    """uninhibit_projns accepts several existing projection names at once.

    NOTE(review): this is a smoke test; it only checks that the call does
    not raise.
    """
    network = net.Net()
    network.new_layer("lr1", size=1)
    network.new_layer("lr2", size=1)
    projn_names = ("pr1", "pr2", "pr3")
    for projn_name in projn_names:
        network.new_projn(projn_name, "lr1", "lr2")
    network.uninhibit_projns(*projn_names)
def test_the_network_can_validate_projn_names() -> None:
    """_validate_projn_name accepts known projections and rejects others."""
    network = net.Net()
    for layer_name in ("lr1", "lr2"):
        network.new_layer(layer_name, 1)
    network.new_projn("pr1", "lr1", "lr2")

    network._validate_projn_name("pr1")  # must not raise
    with pytest.raises(ValueError):
        network._validate_projn_name("whales")
def test_the_network_can_validate_oscill_names() -> None:
    """_validate_oscill_name accepts known oscillators and rejects others."""
    network = net.Net()
    member_layers = ["lr1", "lr2"]
    for layer_name in member_layers:
        network.new_layer(layer_name, 1)
    network.new_oscill("theta", member_layers)

    network._validate_oscill_name("theta")  # must not raise
    with pytest.raises(ValueError):
        network._validate_oscill_name("whales")
def test_you_can_observe_unlogged_attributes() -> None:
    """observe() returns an attribute's current value as a one-row DataFrame.

    The projection here has no logging configured, yet its ``cos_diff_avg``
    can still be read and is expected to be 0.0 on a fresh network.
    """
    n = net.Net()
    n.new_layer("layer1", 1)
    n.new_layer("layer2", 1)
    n.new_projn("projn1", "layer1", "layer2")
    # pd.util.testing was deprecated in pandas 1.0 and removed in 2.0;
    # pd.testing is the public home of assert_frame_equal.
    pd.testing.assert_frame_equal(
        n.observe("projn1", "cos_diff_avg"),
        pd.DataFrame({"cos_diff_avg": (0.0, )}),
        check_like=True)
def test_network_can_be_retrieved_and_continue_logging(tmp_path) -> None:
    """A network saved while logging is paused can be reloaded and continue.

    The reloaded net must carry the same logs as the original, must still be
    paused (cycles 2-3 are not logged) and must log again after
    resume_logging().

    Args:
        tmp_path: pytest fixture providing a fresh per-test directory. The
            saved file therefore neither depends on the current working
            directory nor is left behind in the source tree.
    """
    n = net.Net()
    n.new_layer(
        "layer1",
        2,
        spec=specs.LayerSpec(log_on_cycle=("unit_act", "avg_act")))
    for _ in range(2):
        n.cycle()
    n.pause_logging()

    location = str(tmp_path / "mynet.pkl")
    n.save(location)
    m = net.Net()
    m.load(filename=location)

    # The reloaded network must reproduce the original's logs exactly.
    before_parts_n = n.logs("cycle", "layer1").parts
    before_parts_m = m.logs("cycle", "layer1").parts
    before_whole_n = n.logs("cycle", "layer1").whole
    before_whole_m = m.logs("cycle", "layer1").whole
    assert np.all((before_parts_n == before_parts_m).values)
    assert np.all((before_whole_n == before_whole_m).values)

    # Still paused: these two cycles must not be logged.
    for _ in range(2):
        m.cycle()
    m.resume_logging()
    for _ in range(2):
        m.cycle()

    after_parts_time = torch.Tensor(m.logs("cycle", "layer1").parts["time"])
    after_whole_time = torch.Tensor(m.logs("cycle", "layer1").whole["time"])
    assert list(after_parts_time.size()) == [8]
    assert list(after_whole_time.size()) == [4]
    assert (after_parts_time == torch.Tensor([0, 0, 1, 1, 4, 4, 5, 5])).all()
    assert (after_whole_time == torch.Tensor([0, 1, 4, 5])).all()
def test_network_passes_non_cycle_events_to_every_object(mocker) -> None:
    """A non-Cycle event handled by the net reaches each object exactly once."""
    network = net.Net()
    network.new_layer("layer1", 3)
    network.new_layer("layer2", 3)
    network.new_projn("projn1", "layer1", "layer2")

    for obj in network.objs.values():
        mocker.spy(obj, "handle")
    network.handle(events.BeginPhase(events.PlusPhase))

    assert all(obj.handle.call_count == 1 for obj in network.objs.values())
def test_you_can_retrieve_the_logs_for_a_layer() -> None:
    """avg_act, when logged at every frequency, shows up in each frequency's log."""
    frequencies = ("cycle", "trial", "epoch", "batch")
    network = net.Net()
    layer_spec = specs.LayerSpec(
        **{f"log_on_{freq}": ("avg_act", )
           for freq in frequencies})
    network.new_layer(name="layer1", size=3, spec=layer_spec)

    network.plus_phase_cycle(1)
    network.end_epoch()
    network.end_batch()

    for freq in frequencies:
        columns = network.logs(freq, "layer1").whole.columns
        assert "avg_act" in columns
def test_net_catches_uninhibit_bad_projn_name() -> None:
    """uninhibit_projns raises ValueError if any given name is not a projection."""
    network = net.Net()
    network.new_layer("lr1", size=1)
    network.new_layer("lr2", size=1)
    for projn_name in ("pr1", "pr2", "pr3"):
        network.new_projn(projn_name, "lr1", "lr2")

    # A single unknown name, and an unknown name among valid ones.
    for names in (("pr4", ), ("pr1", "pr5", "pr3")):
        with pytest.raises(ValueError):
            network.uninhibit_projns(*names)
def test_running_a_phase_runs_the_correct_number_of_cycles(mocker) -> None:
    """For every runnable phase, a 42-cycle phase emits 42 Cycle events.

    NonePhase is not runnable, and phase_cycle must raise ValueError for it.
    For other phases, the spied ``handle`` calls are expected to be the
    phase-begin event at index 0, then one Cycle per cycle at indices 1..42.
    """
    for phase in events.Phase.phases():
        n = net.Net()
        mocker.spy(n, "handle")
        if phase == events.NonePhase:
            with pytest.raises(ValueError):
                n.phase_cycle(phase=phase, num_cycles=42)
        else:
            n.phase_cycle(phase=phase, num_cycles=42)
            # Check the first positional argument of *every* call in the
            # window; ``call_args_list[1:43][0][0]`` would inspect only one.
            cycle_calls = n.handle.call_args_list[1:43]
            assert len(cycle_calls) == 42
            assert all(
                isinstance(call[0][0], events.Cycle) for call in cycle_calls)
def test_unclamping_layers_validates_its_name() -> None:
    """unclamp_layer raises ValueError if any requested name is not a layer."""
    network = net.Net()
    # No layers exist yet, so any name is unknown.
    with pytest.raises(ValueError):
        network.unclamp_layer("abcd")

    network.new_layer("abcd", size=2)
    # Mixing the valid name with an unknown one still fails, in any position.
    for names in (("abcd", "g"), ("g", "abcd"), ("abcd", "g", "abcd")):
        with pytest.raises(ValueError):
            network.unclamp_layer(*names)
def test_running_a_phase_broadcasts_phase_event_markers(mocker) -> None:
    """A one-cycle phase emits its begin marker, one Cycle, then its end marker.

    Every phase from events.Phase.phases() is checked. For NonePhase,
    phase_cycle must raise ValueError instead.
    """
    for phase in events.Phase.phases():
        n = net.Net()
        mocker.spy(n, "handle")
        if phase == events.NonePhase:
            with pytest.raises(ValueError):
                n.phase_cycle(phase=phase, num_cycles=1)
            # `continue`, not `return`: returning here would end the test and
            # silently skip every phase that follows NonePhase.
            continue
        n.phase_cycle(phase=phase, num_cycles=1)
        calls = n.handle.call_args_list
        assert calls[0][0][0] == phase.begin_event
        assert isinstance(calls[1][0][0], events.Cycle)
        assert calls[2][0][0] == phase.end_event
def test_net_batch_log_pausing_and_resuming() -> None:
    """Batches ended while batch logging is paused are not recorded.

    Three batches are ended, with logging paused for the second, so only
    batch times 0 and 2 appear (2 units x 2 batches = 4 "parts" rows).
    """
    network = net.Net()
    network.new_layer(
        "layer1",
        2,
        spec=specs.LayerSpec(log_on_batch=("unit_act", "avg_act")))

    network.end_batch()
    network.pause_logging("batch")
    network.end_batch()  # not logged
    network.resume_logging("batch")
    network.end_batch()

    parts_time = torch.Tensor(network.logs("batch", "layer1").parts["time"])
    whole_time = torch.Tensor(network.logs("batch", "layer1").whole["time"])

    assert tuple(parts_time.size()) == (4, )
    assert tuple(whole_time.size()) == (2, )
    assert (parts_time == torch.Tensor([0, 0, 2, 2])).all()
    assert (whole_time == torch.Tensor([0, 2])).all()
def test_you_can_signal_the_end_of_a_batch(mocker) -> None:
    """end_batch() dispatches an EndBatch event through Net.handle."""
    network = net.Net()
    mocker.spy(network, "handle")
    network.end_batch()
    first_event = network.handle.call_args_list[0][0][0]
    assert isinstance(first_event, events.EndBatch)