def _ec_out_matches(snapshot: Dict[str, torch.Tensor],
                    target: torch.Tensor) -> bool:
    """Return whether the thresholded "EC_out" activity equals ``target``.

    Activity in ``snapshot["EC_out"]`` is binarized with ``> 0.5`` and
    compared element-wise with ``target.byte()``; the result is ``True`` only
    if every element agrees.
    """
    predicted = snapshot["EC_out"] > 0.5
    # .all() yields a 0-dim tensor; convert so callers get the documented bool.
    return bool((predicted == target.byte()).all())


def test_trial(
    network: lb.Net,
    input_pattern: torch.Tensor,
    output_pattern: torch.Tensor,
) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor], bool, bool]:
    """Clamp ``input_pattern`` and score the network's early and late responses.

    The "Input" layer is clamped to ``input_pattern`` and the network is run
    under ``TestPhase`` for 20 cycles (initial response), then for a further
    60 cycles (final response). Each response is captured with
    ``net_snapshot`` and its "EC_out" activity is checked against
    ``output_pattern``. ``network.end_trial()`` is called before returning.

    Args:
        network: The network under test.
        input_pattern: Activity to clamp onto the "Input" layer.
        output_pattern: Target pattern for the "EC_out" layer; compared after
            conversion with ``.byte()``.

    Returns:
        ``(initial_response, final_response, initial_guess, final_guess)``:
        the snapshots taken after the first 20 and the subsequent 60 cycles,
        and for each one a bool that is ``True`` when every "EC_out" unit
        thresholded at 0.5 matches ``output_pattern.byte()``.
    """
    # Initial response: clamp the input and run a short TestPhase.
    network.clamp_layer("Input", input_pattern)
    network.phase_cycle(TestPhase, num_cycles=20)
    # NOTE(review): assumes net_snapshot copies activity; if it returns live
    # tensors, the further cycling below would overwrite this snapshot.
    initial_response = net_snapshot(network)
    initial_guess = _ec_out_matches(initial_response, output_pattern)

    # Final response: 60 more cycles; the Input clamp is not released here.
    network.phase_cycle(TestPhase, num_cycles=60)
    final_response = net_snapshot(network)
    final_guess = _ec_out_matches(final_response, output_pattern)

    # Reset
    network.end_trial()

    return initial_response, final_response, initial_guess, final_guess
def trial(network: lb.Net, input_pattern: Iterable[float],
          output_pattern: Iterable[float]) -> None:
    """Run one trial: a minus-phase run followed by a plus-phase run.

    The "input" layer is clamped to ``input_pattern`` and the network is
    cycled for 50 cycles under ``lb.MinusPhase``. The "output" layer is then
    also clamped to ``output_pattern`` and the network is cycled for 25
    cycles under ``lb.PlusPhase``. Both layers are unclamped ("input" first)
    and ``network.end_trial()`` is called.

    Args:
        network: The network to run.
        input_pattern: Activity clamped onto the "input" layer.
        output_pattern: Activity clamped onto the "output" layer.
    """
    # Each step clamps one more layer, then cycles under that step's phase.
    schedule = (
        ("input", input_pattern, lb.MinusPhase, 50),
        ("output", output_pattern, lb.PlusPhase, 25),
    )
    for layer_name, pattern, phase, cycles in schedule:
        network.clamp_layer(layer_name, pattern)
        network.phase_cycle(phase=phase, num_cycles=cycles)

    # Release the clamps in the order they were applied, then close out.
    for layer_name, *_ in schedule:
        network.unclamp_layer(layer_name)
    network.end_trial()
# Example #3
# 0
def _cycle_with_inhibited(network: lb.Net, phase, num_cycles: int,
                          *projn_names: str) -> None:
    """Run ``phase`` for ``num_cycles`` with ``projn_names`` inhibited.

    The named projections are uninhibited again once the cycles finish, so
    every inhibit call is paired with its matching uninhibit.
    """
    network.inhibit_projns(*projn_names)
    network.phase_cycle(phase, num_cycles=num_cycles)
    network.uninhibit_projns(*projn_names)


def learn_trial(network: lb.Net, pattern: torch.Tensor) -> None:
    """Run one learning trial through the trough, peak, and plus phases.

    1. ``ThetaTrough`` (20 cycles): "Input" is clamped to ``pattern``, with
       "TSP: CA3 -> CA1" and "Loop: EC_out -> EC_in" inhibited.
    2. ``ThetaPeak`` (20 cycles): "MSP: EC_in -> CA1" and
       "Loop: EC_out -> EC_in" are inhibited.
    3. ``ThetaPlus`` (60 cycles): "EC_out" is also clamped to ``pattern``,
       with "TSP: CA3 -> CA1" inhibited.

    Every projection inhibited for a phase is uninhibited right after that
    phase. Afterwards "Input" and "EC_out" are unclamped, and
    ``network.end_trial()`` then ``network.learn()`` are called.

    Args:
        network: The network to train.
        pattern: Pattern clamped onto "Input" and, in the plus phase, "EC_out".
    """
    # Theta Trough
    network.clamp_layer("Input", pattern)
    _cycle_with_inhibited(network, ThetaTrough, 20,
                          "TSP: CA3 -> CA1", "Loop: EC_out -> EC_in")

    # Theta Peak
    _cycle_with_inhibited(network, ThetaPeak, 20,
                          "MSP: EC_in -> CA1", "Loop: EC_out -> EC_in")

    # Theta Plus: the target pattern is clamped onto EC_out as well.
    network.clamp_layer("EC_out", pattern)
    # Previously inhibited twice here, so TSP stayed inhibited after the
    # trial; the helper now always uninhibits it.
    _cycle_with_inhibited(network, ThetaPlus, 60, "TSP: CA3 -> CA1")

    # Reset
    network.unclamp_layer("Input", "EC_out")
    network.end_trial()
    network.learn()