def main():
    """Train ``myModel.complex_cnn(9)`` on random TrackML events, checkpointing the best model.

    For each of ``n_events`` training events (hits joined with truth), run 100
    steps. Each step:

    * calls ``get_feature`` with a freshly drawn ``theta`` in ``[0, 2*pi)`` and a
      random boolean ``flip``;
    * calls ``train_nn`` against a permuted copy of the target;
    * keeps the ``(loss, model)`` pair that ``train_nn`` returns.

    Whenever a step's loss is lower than the best loss seen so far, in any
    event, the model is saved to ``./checkpoint/aaronmao/mymodel.h5``.

    NOTE(review): this file defines a second ``main()`` later on, which rebinds
    the name at import time. This version is therefore unreachable via ``main``
    unless that later definition is renamed or removed.
    """
    import os  # local import: only needed to create the checkpoint directory

    checkpoint_path = "./checkpoint/aaronmao/mymodel.h5"

    print("start running basic neural network")
    np.random.seed(2)  # restart random number generator
    s1 = Session(parent_dir="/rscratch/xuanyu/data/")
    n_events = 5000
    nn_list_basic = myModel.complex_cnn(9)

    # model.save does not create missing parent directories; do it once up front.
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

    # Best loss over the whole run. It is initialised once (not per event) and
    # updated on every save, so the checkpoint only ever holds an improvement.
    best_loss = float("inf")

    events = s1.get_train_events(n=n_events, content=[s1.HITS, s1.TRUTH], randomness=True)[1]
    for count, (hits_train, truth_train) in enumerate(events, start=1):
        print(f"{count}/{n_events}")
        hits_train = join_hits_truth(hits_train, truth_train)
        fy = get_target(hits_train)

        for i in range(100):
            print("Step: " + str(i))
            # Per-step augmentation parameters, drawn in the same order as
            # before (theta first, then flip) so the seeded RNG stream is unchanged.
            # Presumably theta is a rotation angle and flip a mirror toggle — confirm in get_feature.
            theta = np.random.rand() * 2 * np.pi
            flip = np.random.rand() < 0.5
            fx = get_feature(hits_train, theta=theta, flip=flip, quadratic=True)
            loss, model = train_nn(nn_list_basic, fx, permute_target(fy),
                                   basic_trainable=True, epochs=5,
                                   batch_size=2048, verbose=1)

            if loss < best_loss:
                best_loss = loss
                print("Epoch result better than the best, saving model")
                model.save(checkpoint_path, overwrite=True)
def main():
    """Train the basic network on a small random sample of TrackML events.

    For every sampled event (hits joined with truth), three phases go through
    ``train_nn``, each using 5 epochs, batch size 128 and verbose output:

    1. "validation check": un-augmented features (``theta=0``, ``flip=False``)
       with ``basic_trainable=False``.
    2. "actual training": the same un-augmented features with
       ``basic_trainable=True``.
    3. Five passes on augmented features (``theta`` drawn from ``[0, 2*pi)``,
       random ``flip``) against a permuted copy of the target.

    A ``KeyboardInterrupt`` raised inside phase 1, or inside any single pass of
    phase 3, abandons only that call. An interrupt during phase 2 is not caught.
    """
    print("start running basic neural network")
    np.random.seed(1)  # restart random number generator
    s1 = Session(parent_dir="E:/TrackMLData/")
    n_events = 50
    nn_list_basic = get_basic_nn(9)

    # Fit settings shared by every train_nn call below.
    fit_settings = {"epochs": 5, "batch_size": 128, "verbose": 1}

    train_events = s1.get_train_events(n=n_events,
                                       content=[s1.HITS, s1.TRUTH],
                                       randomness=True)[1]
    for count, (event_hits, event_truth) in enumerate(train_events, start=1):
        print(f"{count}/{n_events}")
        event_hits = join_hits_truth(event_hits, event_truth)
        fy = get_target(event_hits)

        # Phase 1 runs for every event (count is always >= 1 here).
        print("validation check")
        try:
            plain_features = get_feature(event_hits, theta=0, flip=False,
                                         quadratic=True)
            train_nn(nn_list_basic, plain_features, fy,
                     basic_trainable=False, **fit_settings)
        except KeyboardInterrupt:
            pass  # Ctrl-C skips just the validation pass

        # Phase 2: features are recomputed rather than reused from phase 1,
        # which may have been interrupted before producing them.
        print("start actual training")
        plain_features = get_feature(event_hits, theta=0, flip=False,
                                     quadratic=True)
        train_nn(nn_list_basic, plain_features, fy,
                 basic_trainable=True, **fit_settings)

        # Phase 3: randomly augmented passes (theta drawn before flip).
        for _ in range(5):
            try:
                angle = np.random.rand() * 2 * np.pi
                mirrored = np.random.rand() < 0.5
                augmented = get_feature(event_hits, theta=angle,
                                        flip=mirrored, quadratic=True)
                train_nn(nn_list_basic, augmented, permute_target(fy),
                         basic_trainable=True, **fit_settings)
            except KeyboardInterrupt:
                pass  # Ctrl-C abandons this pass and moves on to the next