def test_cell_list_emplace_2d(self, dtype):
  """Check that four hand-placed 2D particles land in the expected cells.

  The expected buffer indices asserted below have the form
  [y_cell, x_cell, slot] (e.g. the particle at (3.7, 7.9) is expected at
  [7, 3, 1]). After the position and id checks, the buffers are scattered
  back by particle id. The test then checks that the original positions are
  recovered.
  """
  box_size = np.array([8.65, 8.0], f32)
  cell_size = f32(1.0)
  R = np.array([[0.25, 0.25], [8.5, 1.95], [8.1, 1.5], [3.7, 7.9]],
               dtype=dtype)

  cell_fn = partition.cell_list(box_size, cell_size, R)
  cell_list = cell_fn(R)

  self.assertAllClose(R[0], cell_list.R_buffer[0, 0, 0])
  self.assertAllClose(R[1], cell_list.R_buffer[1, 8, 1])
  self.assertAllClose(R[2], cell_list.R_buffer[1, 8, 0])
  self.assertAllClose(R[3], cell_list.R_buffer[7, 3, 1])

  self.assertEqual(0, cell_list.id_buffer[0, 0, 0])
  self.assertEqual(1, cell_list.id_buffer[1, 8, 1])
  self.assertEqual(2, cell_list.id_buffer[1, 8, 0])
  self.assertEqual(3, cell_list.id_buffer[7, 3, 1])

  id_flat = np.reshape(cell_list.id_buffer, (-1, ))
  R_flat = np.reshape(cell_list.R_buffer, (-1, 2))
  # Allocate one row more than there are particles; it is sliced off below.
  # NOTE(review): presumably empty slots carry id == len(R) and so write into
  # this extra row -- confirm against partition.cell_list.
  R_out = np.zeros((R.shape[0] + 1, 2), dtype)
  # `jax.ops.index_update` has been removed from JAX; `.at[...].set(...)` is
  # the equivalent functional scatter.
  R_out = R_out.at[id_flat].set(R_flat)[:-1]
  self.assertAllClose(R_out, R)
def test_cell_list_random_emplace_side_data(self, dtype, dim):
  """Round-trip random positions and per-particle side data via a cell list.

  `side_data` is passed to the cell-list function as a keyword argument and
  read back from `cell_list.kwarg_buffers['side_data']`. Both the position
  buffer and the side-data buffer are scattered back by particle id. They
  must reproduce the inputs.
  """
  key = random.PRNGKey(1)
  # Compare with `==`, not `is`: identity comparison of ints is
  # implementation-defined and triggers a SyntaxWarning on Python 3.8+.
  box_size = (np.array([9.0, 4.0, 7.25], f32) if dim == 3
              else np.array([9.0, 4.25], f32))
  cell_size = f32(1.23)
  R = box_size * random.uniform(key, (PARTICLE_COUNT, dim), dtype=dtype)

  side_data_dim = 2
  side_data = random.normal(
      key, (PARTICLE_COUNT, side_data_dim), dtype=dtype)

  cell_fn = partition.cell_list(box_size, cell_size, R)
  cell_list = cell_fn(R, side_data=side_data)

  id_flat = np.reshape(cell_list.id_buffer, (-1, ))

  # One extra row beyond PARTICLE_COUNT is allocated and then dropped.
  # NOTE(review): presumably empty slots carry id == PARTICLE_COUNT -- confirm.
  R_flat = np.reshape(cell_list.R_buffer, (-1, dim))
  R_out = np.zeros((PARTICLE_COUNT + 1, dim), dtype)
  R_out = R_out.at[id_flat].set(R_flat)[:-1]

  side_data_flat = np.reshape(
      cell_list.kwarg_buffers['side_data'], (-1, side_data_dim))
  side_data_out = np.zeros((PARTICLE_COUNT + 1, side_data_dim), dtype)
  side_data_out = side_data_out.at[id_flat].set(side_data_flat)[:-1]

  self.assertAllClose(R_out, R)
  self.assertAllClose(side_data_out, side_data)
def test_cell_list_random_emplace_rect(self, dtype, dim):
  """Round-trip random positions through a cell list in a rectangular box.

  The box sides differ per dimension. The 2D and 3D boxes each have one side
  that is not a whole multiple of the unit cell size. The position buffer is
  scattered back by particle id and compared to the input.

  NOTE(review): `dtype` is accepted but not used by this test.
  """
  key = random.PRNGKey(1)
  # `==`, not `is`: identity comparison of ints is implementation-defined.
  box_size = (np.array([9.0, 3.0, 7.25]) if dim == 3
              else np.array([9.0, 3.25]))
  cell_size = f32(1.0)
  R = box_size * random.uniform(key, (PARTICLE_COUNT, dim))

  cell_fn = partition.cell_list(box_size, cell_size, R)
  cell_list = cell_fn(R)

  id_flat = np.reshape(cell_list.id_buffer, (-1,))
  R_flat = np.reshape(cell_list.R_buffer, (-1, dim))
  # One extra row beyond PARTICLE_COUNT is allocated and then dropped.
  R_out = np.zeros((PARTICLE_COUNT + 1, dim))
  R_out = R_out.at[id_flat].set(R_flat)[:-1]
  # Pass check_dtypes by keyword; it is not a positional parameter of
  # assertAllClose in the API the rest of this file uses.
  self.assertAllClose(R_out, R, check_dtypes=True)
def test_cell_list_random_emplace(self, dtype, dim):
  """Allocate a cell list for random particles and invert it by particle id.

  PARTICLE_COUNT particles are drawn uniformly in a box of side 9.0 along
  every dimension and binned into unit cells via `allocate`. Every entry of
  the position buffer is then written back to the row named by the matching
  entry of the id buffer. After the trailing extra row is discarded, the
  result must equal the input positions.

  NOTE(review): `dtype` is accepted but not used by this test.
  """
  side = f32(9.0)
  positions = side * random.uniform(random.PRNGKey(1), (PARTICLE_COUNT, dim))

  cells = partition.cell_list(side, f32(1.0)).allocate(positions)

  ids = np.reshape(cells.id_buffer, (-1, ))
  flat_positions = np.reshape(cells.position_buffer, (-1, dim))

  # Room for one row past the last particle; it is discarded before comparing.
  recovered = np.zeros((PARTICLE_COUNT + 1, dim))
  recovered = recovered.at[ids].set(flat_positions)
  recovered = recovered[:-1]

  self.assertAllClose(recovered, positions)