from engine.mpm_solver import MPMSolver

write_to_disk = False

# Run on the CUDA backend; device_memory_GB sizes Taichi's GPU memory pool.
ti.init(arch=ti.cuda, device_memory_GB=3.0)

gui = ti.GUI("Taichi Elements", res=512, background_color=0x112F41)

mpm = MPMSolver(res=(64, 64, 64), size=1)

# NOTE(review): np.fromfile reads raw bytes and does no header parsing. If
# suzanne.npy is a real NumPy .npy file (written by np.save), its header would
# be decoded as float data and np.load should be used instead -- confirm how
# the file was produced.
triangles = np.fromfile('suzanne.npy', dtype=np.float32)
# One row of 9 floats per triangle (presumably 3 vertices x (x, y, z)); the
# scale/offset presumably places the mesh inside the unit-size domain.
triangles = np.reshape(triangles, (len(triangles) // 9, 9)) * 0.306 + 0.501

mpm.add_mesh(triangles=triangles,
             material=MPMSolver.material_elastic,
             color=0xFFFF00)

mpm.set_gravity((0, -20, 0))

# First scene: simulate the single yellow elastic mesh for 1500 frames and
# draw its particles in the GUI window.
for frame in range(1500):
    mpm.step(4e-3)
    particles = mpm.particle_info()
    np_x = particles['position'] / 1.0

    # Simple camera transform: view along the x/z diagonal, so screen x is
    # (x + z) / sqrt(2) shifted left by 0.2 and screen y is the world y.
    screen_x = ((np_x[:, 0] + np_x[:, 2]) / 2**0.5) - 0.2
    screen_y = (np_x[:, 1])

    screen_pos = np.stack([screen_x, screen_y], axis=-1)

    # Previously the projected positions were computed and then discarded;
    # actually render them so the window shows the simulation.
    gui.circles(screen_pos, radius=1.1, color=0xFFFF00)
    gui.show()
counter = 0

start_t = time.time()

# Second scene: for 15000 frames, keep dropping coloured copies of the mesh
# while the particle count is below the budget, stepping the solver each frame.
# NOTE(review): max_num_particles, with_gui, visualize and output_dir are not
# defined in the code visible in this file -- confirm where they come from.
for frame in range(15000):
    print(f'frame {frame}')
    t = time.time()

    if mpm.n_particles[None] < max_num_particles:
        # Horizontal spawn offset: i cycles through -2, -1, 0, 1.
        i = frame % 4 - 2
        # NOTE(review): true division makes j a float that sweeps [-1, 3)
        # continuously with a 16-frame period; if a 4x4 spawn grid was
        # intended this should probably be `frame // 4 % 4 - 1`.
        j = (frame / 4) % 4 - 1

        # Exactly one channel is bright (255) -- the one matching frame % 3.
        r, g, b = (255 if frame % 3 == channel else 128
                   for channel in range(3))
        color = (r << 16) | (g << 8) | b
        mpm.add_mesh(triangles=triangles,
                     material=MPMSolver.material_elastic,
                     color=color,
                     velocity=(0, -2, 0),
                     translation=((i + 0.5) * 0.25, 0, (2 - j) * 0.1))

    mpm.step(2e-3, print_stat=True)

    # Refresh the visualisation only every third frame.
    if with_gui:
        if frame % 3 == 0:
            particles = mpm.particle_info()
            visualize(particles)

    if write_to_disk:
        mpm.write_particles(f'{output_dir}/{frame:05d}.npz')
    print(f'Frame total time {time.time() - t:.3f}')
    print(f'Total running time {time.time() - start_t:.3f}')
# NOTE(review): counter is not read anywhere in the visible code.
counter = 0

start_t = time.time()

# Third scene: a 15000-frame loop that spawns two kinds of meshes while the
# particle count is below max_num_particles.
# NOTE(review): max_num_particles and triangles_small are not defined in the
# code visible in this file -- confirm where they come from.
for frame in range(15000):
    print(f'frame {frame}')
    t = time.time()

    # Every 50 frames drop a full-size mesh. F counts these drops: F % 3 picks
    # which colour channel is bright, and F % 2 alternates the z offset
    # between 0.0 and 0.4.
    if frame % 50 == 0 and mpm.n_particles[None] < max_num_particles:
        F = frame // 50
        r = 255 if F % 3 == 0 else 128
        g = 255 if F % 3 == 1 else 128
        b = 255 if F % 3 == 2 else 128
        mpm.add_mesh(triangles=triangles,
                     material=MPMSolver.material_elastic,
                     color=r * 65536 + g * 256 + b,
                     velocity=(0, -6, 0),
                     translation=(0.0, 0.16, (F % 2) * 0.4))

    # After frame 60, also drop a small mesh every frame, cycling through
    # three colours; all three material slots are currently elastic.
    if frame > 60 and mpm.n_particles[None] < max_num_particles:
        i = frame % 3 - 1.5
        j = 0  # fixed at 0; the commented alternative was frame / 4 % 4 - 1
        colors = [0xFF8888, 0xEEEEFF, 0xFFFF55]
        materials = [
            MPMSolver.material_elastic, MPMSolver.material_elastic,
            MPMSolver.material_elastic
        ]
        mpm.add_mesh(triangles=triangles_small,
                     material=materials[frame % 3],
                     color=colors[frame % 3],
                     velocity=(0, -6, 0),
                     # NOTE(review): this call is truncated here -- the closing
                     # parenthesis (and presumably a translation argument that
                     # uses i and j) is missing, so the file does not parse
                     # as-is. Restore it from the original script.
    # NOTE(review): this fragment is the body of a loop over `l` whose header
    # is missing from this file as shown; l, layers, bb_count, bb_size,
    # h_start, bunnies and total_bunnies are all defined outside the visible
    # code -- confirm against the original script.
    # Colour for mesh group l: one channel bright (255), the others dim (128).
    r = 255 if l % 3 == 0 else 128
    g = 255 if l % 3 == 1 else 128
    b = 255 if l % 3 == 2 else 128
    color = r * 65536 + g * 256 + b

    # Stack `layers` layers of copies of bunnies[l]. Each layer is a
    # bb_count x bb_count grid of cells of size bb_size, starting at
    # x = z = -1.1.
    for k in range(layers):
        print(f"  Generating layer {k}, h_start {h_start}")
        for i in range(bb_count):
            for j in range(bb_count):
                # Cell centre in x/z; layer k sits bb_size * k * 1.1 above h_start.
                x, y, z = -1.1 + (
                    i + 0.5) * bb_size, h_start + bb_size * k * 1.1, -1.1 + (
                        j + 0.5) * bb_size
                # Cells are skipped (not the loop aborted) once the particle
                # budget is reached.
                if mpm.n_particles[None] < max_num_particles:
                    mpm.add_mesh(triangles=bunnies[l],
                                 material=MPMSolver.material_elastic,
                                 color=color,
                                 velocity=(0, -5, 0),
                                 translation=(x, y, z))
                    print(
                        f'Total particles: {mpm.n_particles[None] / 1e6:.4f} M'
                    )
                    total_bunnies += 1
    # Move the base height up for the next group.
    # NOTE(review): layers within a group are 1.1 * bb_size apart, but h_start
    # only advances by bb_size * layers. The gap to the next group is therefore
    # smaller than the intra-group spacing -- confirm this is intended.
    h_start += bb_size * layers
    h_start -= 0.05 * max(0, 2 - l)  # adjustments: lower the next group by 0.05 * (2 - l) while l < 2

# Gravity is set after the mesh-generation code above has filled the scene.
mpm.set_gravity((0, -25, 0))

# Report per-particle storage size and scene totals.
# NOTE(review): total_bunnies is not initialised in the code visible in this
# file; if it is undefined, the second print raises NameError after the first
# line has already been printed.
print(f'Per particle space: {mpm.particle.cell_size_bytes} B')
print(f'Total bunnies: {total_bunnies}')
print(f'Total particles: {mpm.n_particles[None] / 1e6:.4f} M')