# Third-party imports used in this excerpt; project-local modules (AlexNet,
# LeNet5, lenet, hopnet, CNN_ANN, Intermediate_Capture, PopularityANN,
# OrthogonalHebbsANN, create_storage_set, evaluate_model_recall) are assumed
# to be imported elsewhere in the source file.
import numpy as np
import torch
from torchvision import transforms
from torchvision.datasets import MNIST


def run_alexnet_ann_recall_test_simulation_trial3():
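    """Load the MNIST-trained AlexNet, hook its fc1 layer, and run the
    recall simulation with a 1024-node associative net."""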
    # instantiate AlexNet from the MNIST-trained checkpoint
    alex_cnn = AlexNet()
    alex_cnn.load_state_dict(torch.load("trained_models/alexnet.pt", map_location=torch.device("cpu")))
    alex_cnn.eval()
    alex_capture = Intermediate_Capture(alex_cnn.fc1)  # capture fc1 activations (paired with num_nodes=1024 below)
    return run_alexnet_ann_recall_simulation(alex_cnn=alex_cnn, alex_capture=alex_capture, output_name="alexnet_recall_task_trial3.txt", num_nodes=1024)

data_raw = MNIST('./data/mnist',
                 download=True,
                 transform=transforms.Compose(
                     [transforms.Resize((32, 32)),
                      transforms.ToTensor()]))

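# Collect one example image per label from MNIST, stopping after three
# distinct labels have been gathered.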
unique_train_images = []
unique_train_labels = []
for i in range(len(data_raw)):
    image, label = data_raw[i]
    if label not in unique_train_labels:
        unique_train_images.append(image.reshape((1, 1, 32, 32)))
        unique_train_labels.append(label)
    if len(unique_train_labels) > 2:
        break

lenet_model = lenet.LeNet5()
lenet_model.load_state_dict(torch.load("trained_models/lenet5_1.pt"))
lenet_model.eval()
lenet_capture = Intermediate_Capture(lenet_model.f5)
hopfield_net = hopnet(10)
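# Couple the CNN to the Hopfield net; capture_process_fn binarizes the
# captured activations to +/-1 by thresholding exp(x) at its mean.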
test_b = CNN_ANN(
    lenet_model,
    hopfield_net,
    lenet_capture,
    capture_process_fn=lambda x: np.sign(np.exp(x) - np.exp(x).mean()))

test_b.learn(unique_train_images, unique_train_labels, verbose=True)

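# Sanity check: each stored image should be recalled to its own label.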
for train_image in unique_train_images:
    print(test_b.predict(train_image))


def run_lenet_ann_recall_test_simulation_trial1():
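    """Sweep stored-set sizes 1..10 and measure Hopfield recall on activations
    captured at three LeNet-5 layers (c2_2, c3, f4); write results as CSV."""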
    output_name = "lenet_recall_task_trial1.txt"

    lenet_cnn2 = LeNet5()
    lenet_cnn2.load_state_dict(torch.load("trained_models/lenet5_1.pt", map_location=torch.device("cpu")))
    lenet_cnn2.eval()
    lenet_capture2 = Intermediate_Capture(lenet_cnn2.c2_2)  # hook c2_2 (paired with hopnet(400) below)

    lenet_cnn3 = LeNet5()
    lenet_cnn3.load_state_dict(torch.load("trained_models/lenet5_1.pt", map_location=torch.device("cpu")))
    lenet_cnn3.eval()
    lenet_capture3 = Intermediate_Capture(lenet_cnn3.c3)  # hook c3 (paired with hopnet(120) below)

    lenet_cnn4 = LeNet5()
    lenet_cnn4.load_state_dict(torch.load("trained_models/lenet5_1.pt", map_location=torch.device("cpu")))
    lenet_cnn4.eval()
    lenet_capture4 = Intermediate_Capture(lenet_cnn4.f4)  # hook f4 (paired with hopnet(84) below)

    # LeNet-5 expects 32x32 inputs, so resize MNIST from its native 28x28
    data_raw = MNIST('./data/mnist',
                     download=True,
                     transform=transforms.Compose([
                         transforms.Resize((32, 32)),
                         transforms.ToTensor()]))

    # creating a toy dataset for simple probing: ten examples of each digit class
    mnist_subset = {k: [] for k in range(10)}
    per_class_sizes = {k: 10 for k in range(10)}
    for i in range(len(data_raw)):
        image, label = data_raw[i]
        if len(mnist_subset[label]) < per_class_sizes[label]:
            mnist_subset[label].append(torch.reshape(image, (1, 1, 32, 32)))
        # stop once every class has reached its quota
        if all(len(mnist_subset[k]) >= per_class_sizes[k] for k in mnist_subset):
            break


    # flatten mnist_subset into parallel pattern/label lists for model input
    full_pattern_set = []
    full_label_set = []
    for k in mnist_subset:
        for v in mnist_subset[k]:
            full_pattern_set.append(v)
            full_label_set.append(k)

    # for each stored-set size, choose one example of each desired label from
    # the MNIST subset to store, then evaluate recall
    stored_size_vs_performance = []  # tuples of per-layer recall counts (c2_2, c3, f4)
    for desired_label_size in range(10):
        desired_labels = list(range(desired_label_size+1))
        full_stored_set, full_stored_labels = create_storage_set(desired_labels, mnist_subset, reshape=False, make_numpy=False)
        print("Num Stored: ", len(desired_labels))

        # evaluate Hopfield recall on activations captured at c2_2
        ann_model = hopnet(400)
        model = CNN_ANN(lenet_cnn2, ann_model, lenet_capture2, capture_process_fn=lambda x: np.sign(np.exp(x) - np.exp(x).mean()))
        num_succ, num_fail = evaluate_model_recall(model, full_stored_set, full_stored_labels, full_stored_set, full_stored_labels, verbose=False)
        print("LeNet c2_2:", num_succ, ":", num_fail)
        c2_2_perf = int(num_succ)

        # evaluate Hopfield recall on activations captured at c3
        ann_model = hopnet(120)
        model = CNN_ANN(lenet_cnn3, ann_model, lenet_capture3, capture_process_fn=lambda x: np.sign(np.exp(x) - np.exp(x).mean()))
        num_succ, num_fail = evaluate_model_recall(model, full_stored_set, full_stored_labels, full_stored_set, full_stored_labels, verbose=False)
        print("LeNet c3:", num_succ, ":", num_fail)
        c3_perf = int(num_succ)

        # evaluate Hopfield recall on activations captured at f4
        ann_model = hopnet(84)
        model = CNN_ANN(lenet_cnn4, ann_model, lenet_capture4, capture_process_fn=lambda x: np.sign(np.exp(x) - np.exp(x).mean()))
        num_succ, num_fail = evaluate_model_recall(model, full_stored_set, full_stored_labels, full_stored_set, full_stored_labels, verbose=False)
        print("LeNet f4:", num_succ, ":", num_fail)
        f4_perf = int(num_succ)

        stored_size_vs_performance.append((c2_2_perf, c3_perf, f4_perf))

    # write per-layer performance as CSV
    with open("data/graph_sources/" + output_name, "w") as fh:
        for perf in stored_size_vs_performance:
            fh.write(f"{perf[0]},{perf[1]},{perf[2]}\n")
    return stored_size_vs_performance


def run_alexnet_ann_recall_test_simulation_trial7():
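    """Sweep stored-set sizes 1..10 and compare recall across three
    associative memories (Hopfield, PopularityANN, OrthogonalHebbsANN)
    attached to AlexNet's fc3 output; write results as CSV."""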
    output_name = "alexnet_recall_task_trial7.txt"
    num_nodes = 10
    # all-to-all connectivity with no self-connections (ones minus identity)
    full_connection_mat = np.ones(shape=(num_nodes, num_nodes)) - np.eye(num_nodes)
    alex_cnn = AlexNet()
    alex_cnn.load_state_dict(torch.load("trained_models/alexnet.pt", map_location=torch.device("cpu")))
    alex_cnn.eval()
    alex_capture = Intermediate_Capture(alex_cnn.fc3)  # capture the final fc3 output (paired with num_nodes=10)

    transform = transforms.ToTensor()
    data_raw = MNIST(
        root='./data/mnist',
        train=True,
        download=True,
        transform=transform)
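    # unlike the LeNet trials, AlexNet here is probed at MNIST's native 28x28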

    # creating a toy dataset for simple probing: ten examples of each digit class
    mnist_subset = {k: [] for k in range(10)}
    per_class_sizes = {k: 10 for k in range(10)}
    for i in range(len(data_raw)):
        image, label = data_raw[i]
        if len(mnist_subset[label]) < per_class_sizes[label]:
            mnist_subset[label].append(torch.reshape(image, (1, 1, 28, 28)))
        # stop once every class has reached its quota
        if all(len(mnist_subset[k]) >= per_class_sizes[k] for k in mnist_subset):
            break


    # flatten mnist_subset into parallel pattern/label lists for model input
    full_pattern_set = []
    full_label_set = []
    for k in mnist_subset:
        for v in mnist_subset[k]:
            full_pattern_set.append(v)
            full_label_set.append(k)

    # for each stored-set size, choose one example of each desired label from
    # the MNIST subset to store, then evaluate recall
    stored_size_vs_performance = []  # tuples of (hopfield perf, popularity perf, ortho perf)
    for desired_label_size in range(10):

        # rebuild the probe set each iteration: for desired label size k,
        # it holds 10 instances each of labels 0..k-1
        desired_labels = list(range(desired_label_size+1))
        sub_probe_set = []
        sub_probe_labels = []
        for des in desired_labels:
            # add 10 instances of des
            for inst in mnist_subset[des]:
                sub_probe_set.append(inst)
                sub_probe_labels.append(des)
        full_stored_set, full_stored_labels = create_storage_set(desired_labels, mnist_subset, reshape=False, make_numpy=False)
        print("Num Stored: ", len(desired_labels))

        # evaluate hopnet performance
        ann_model = hopnet(num_nodes) 
        model = CNN_ANN(alex_cnn, ann_model, alex_capture, capture_process_fn=lambda x: np.sign(np.exp(x)-np.exp(x).mean()))
        num_succ, num_fail = evaluate_model_recall(model, full_stored_set, full_stored_labels, sub_probe_set, sub_probe_labels, verbose=False)
        print("Hopfield:", num_succ, ":", num_fail)
        hopfield_perf = int(num_succ)

        # evaluate popularity ANN performance
        # hyperparams: c = N-1, with the fully connected matrix built above
        ann_model = PopularityANN(N=num_nodes, c=num_nodes-1, connectivity_matrix=full_connection_mat)
        model = CNN_ANN(alex_cnn, ann_model, alex_capture, capture_process_fn=lambda x: np.sign(np.exp(x)-np.exp(x).mean()))
        num_succ, num_fail = evaluate_model_recall(model, full_stored_set, full_stored_labels, sub_probe_set, sub_probe_labels, verbose=False)
        print("PopularityANN:", num_succ, ":", num_fail)
        popularity_perf = int(num_succ)

        # evaluate orthogonal hebbs ANN performance
        ann_model = OrthogonalHebbsANN(N=num_nodes)
        model = CNN_ANN(alex_cnn, ann_model, alex_capture, capture_process_fn=lambda x: np.sign(np.exp(x)-np.exp(x).mean()))
        num_succ, num_fail = evaluate_model_recall(model, full_stored_set, full_stored_labels, sub_probe_set, sub_probe_labels, verbose=False)
        print("OrthogonalHebbsANN:", num_succ, ":", num_fail)
        ortho_perf = int(num_succ)

        stored_size_vs_performance.append((hopfield_perf, popularity_perf, ortho_perf))

    # write performance as CSV
    with open("data/graph_sources/" + output_name, "w") as fh:
        for perf in stored_size_vs_performance:
            fh.write(f"{perf[0]},{perf[1]},{perf[2]}\n")
    return stored_size_vs_performance
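

# A minimal driver sketch (an assumption, not part of the original trials):
# run one sweep and print the per-size recall tuples returned above.
if __name__ == "__main__":
    results = run_lenet_ann_recall_test_simulation_trial1()
    for num_stored, perf in enumerate(results, start=1):
        print(num_stored, perf)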