def main():
    """Print predictions and cost for a dipole network and its tree variant.

    Builds a two-layer ``ParticleDipoleNetwork`` (2 -> 5 -> 3, sigmoid, MSE
    cost) and a ``ParticleDipoleTreeNetwork`` with the same layer sizes,
    copies the input positions and the per-layer ``q``, ``b`` and positions
    from the first network into the second, and prints ``predict`` and
    ``cost`` on a fixed three-sample data set for each network.
    """

    train_X = np.asarray([[0.2, -0.3], [0.1, -0.9], [0.3, 0.5]])
    train_Y = np.asarray([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]])

    # For comparison
    net = ParticleDipoleNetwork(cost="mse", particle_input=ParticleDipoleInput(2))
    net.append(ParticleDipole(2, 5, activation="sigmoid"))
    net.append(ParticleDipole(5, 3, activation="sigmoid"))

    print(net.predict(train_X))
    print(net.cost(train_X, train_Y))

    net2 = ParticleDipoleTreeNetwork(cost="mse", particle_input=ParticleDipoleTreeInput(2))
    net2.append(ParticleDipoleTree(2, 5, activation="sigmoid"))
    net2.append(ParticleDipoleTree(5, 3, activation="sigmoid"))
    # Make sure we have the same coordinates and charges
    net2.particle_input.copy_pos_neg_positions(net.particle_input.rx_pos, net.particle_input.ry_pos, net.particle_input.rz_pos,
                                               net.particle_input.rx_neg, net.particle_input.ry_neg, net.particle_input.rz_neg)
    # Walk both layer lists in lockstep instead of indexing by position.
    for ref_layer, tree_layer in zip(net.layers, net2.layers):
        tree_layer.copy_pos_neg_positions(ref_layer.q, ref_layer.b,
                                          ref_layer.rx_pos, ref_layer.ry_pos, ref_layer.rz_pos,
                                          ref_layer.rx_neg, ref_layer.ry_neg, ref_layer.rz_neg)

    print(net2.predict(train_X))
    print(net2.cost(train_X, train_Y))
def main():
    """Run the same tiny data set through both network implementations.

    A ``ParticleDipoleNetwork`` and a ``ParticleDipoleTreeNetwork`` are built
    with identical layer sizes (2 -> 5 -> 3, sigmoid activations, MSE cost).
    The tree network then receives the first network's input positions and
    each layer's ``q``, ``b`` and positive/negative positions, and the
    prediction and cost of each network are printed.
    """

    train_X = np.asarray([[0.2, -0.3], [0.1, -0.9], [0.3, 0.5]])
    train_Y = np.asarray([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]])

    # For comparison
    net = ParticleDipoleNetwork(cost="mse",
                                particle_input=ParticleDipoleInput(2))
    net.append(ParticleDipole(2, 5, activation="sigmoid"))
    net.append(ParticleDipole(5, 3, activation="sigmoid"))

    print(net.predict(train_X))
    print(net.cost(train_X, train_Y))

    net2 = ParticleDipoleTreeNetwork(cost="mse",
                                     particle_input=ParticleDipoleTreeInput(2))
    net2.append(ParticleDipoleTree(2, 5, activation="sigmoid"))
    net2.append(ParticleDipoleTree(5, 3, activation="sigmoid"))

    # Make sure we have the same coordinates and charges
    src_input = net.particle_input
    net2.particle_input.copy_pos_neg_positions(
        src_input.rx_pos, src_input.ry_pos, src_input.rz_pos,
        src_input.rx_neg, src_input.ry_neg, src_input.rz_neg)
    # Pair each reference layer with its tree counterpart rather than
    # indexing both lists with a shared counter.
    for src_layer, dst_layer in zip(net.layers, net2.layers):
        dst_layer.copy_pos_neg_positions(
            src_layer.q, src_layer.b,
            src_layer.rx_pos, src_layer.ry_pos, src_layer.rz_pos,
            src_layer.rx_neg, src_layer.ry_neg, src_layer.rz_neg)

    print(net2.predict(train_X))
    print(net2.cost(train_X, train_Y))
# Time `nt` consecutive evaluations of net.cost(X_sub, Y_sub), printing the
# cost and elapsed seconds of every run followed by the mean elapsed time.
times = []
nt = 3
for _ in range(nt):
    # perf_counter() is the monotonic, high-resolution clock meant for
    # measuring intervals; time.time() follows the wall clock and can jump.
    ts = time.perf_counter()
    c = net.cost(X_sub, Y_sub)
    # Alternative workloads, left disabled (presumably toggled by hand when
    # benchmarking the gradient instead of the cost):
    # c = 1.0
    # net.cost_gradient(X_sub, Y_sub)
    t = time.perf_counter() - ts
    print("Cost: {} time: {}".format(c, t))
    times.append(t)
print("Mean: " + str(sum(times) / nt))

# Tree-based network: build its input stage first, then the network itself,
# then attach the first ParticleDipoleTree layer (784, n) with a sigmoid
# activation. The input stage and the layer use the same s / cut /
# max_level / mac settings.
tree_input_stage = ParticleDipoleTreeInput(
    784, s=s, cut=cut, max_level=max_level, mac=mac)
net2 = ParticleDipoleTreeNetwork(
    cost="categorical_cross_entropy", particle_input=tree_input_stage)

tree_first_layer = ParticleDipoleTree(
    784, n, activation="sigmoid",
    s=s, cut=cut, max_level=max_level, mac=mac, n_particle_min=n_min)
net2.append(tree_first_layer)
net2.append(
    ParticleDipoleTree(n,
                       10,
print("starting predict")
# Time `nt` consecutive evaluations of net.cost(X_sub, Y_sub), printing the
# cost and elapsed seconds of every run followed by the mean elapsed time.
times = []
nt = 3
for _ in range(nt):
    # perf_counter() is the monotonic, high-resolution clock meant for
    # measuring intervals; time.time() follows the wall clock and can jump.
    ts = time.perf_counter()
    c = net.cost(X_sub, Y_sub)
    # Alternative workloads, left disabled (presumably toggled by hand when
    # benchmarking the gradient instead of the cost):
    # c = 1.0
    # net.cost_gradient(X_sub, Y_sub)
    t = time.perf_counter() - ts
    print("Cost: {} time: {}".format(c, t))
    times.append(t)
print("Mean: " + str(sum(times) / nt))

# Keyword settings passed to every tree component; the layers additionally
# take n_particle_min.
tree_settings = dict(s=s, cut=cut, max_level=max_level, mac=mac)
tree_layer_settings = dict(tree_settings, n_particle_min=n_min)

# Tree-based network with a categorical cross-entropy cost: a 784-wide input
# stage followed by a sigmoid layer (784, n) and a softmax layer (n, 10).
net2 = ParticleDipoleTreeNetwork(
    cost="categorical_cross_entropy",
    particle_input=ParticleDipoleTreeInput(784, **tree_settings),
)
net2.append(ParticleDipoleTree(784, n, activation="sigmoid", **tree_layer_settings))
net2.append(ParticleDipoleTree(n, 10, activation="softmax", **tree_layer_settings))

# Make sure we have the same coordinates and charges
net2.particle_input.copy_pos_neg_positions(
    net.particle_input.rx_pos,
    net.particle_input.ry_pos,
    net.particle_input.rz_pos,
    net.particle_input.rx_neg,
    net.particle_input.ry_neg,