def test_curvilinear_grids(mode):
    """Sample the speed of a linear U field on a rotated curvilinear grid.

    A regular 7 x 5 lattice spanning [0, 1000] x [0, 1000] is converted to
    polar form, rotated by pi/6 and mapped back to Cartesian coordinates,
    giving 2-D ``lon``/``lat`` arrays for a ``CurvilinearZGrid`` with two
    time levels (0 and 86400).

    On the first time level U = lon + lat and V = 0 everywhere. A particle
    released at (lon, lat) = (400, 600) is therefore expected to record a
    speed of |400 + 600| = 1000. The second time level holds U = 1.

    Args:
        mode: key into ``ptype`` selecting the particle class to subclass
            (``'scipy'`` is one such key, used by the other test here).
    """
    # NOTE(review): this kernel is passed to ``pset.Kernel`` as a Python
    # function; in a compiled mode its source is presumably translated, so
    # it is deliberately left in its original simple form (argument names,
    # ``math.sqrt``) -- TODO confirm which constructs the translator allows.
    def sampleSpeed(particle, fieldset, time, dt):
        u = fieldset.U[time, particle.lon, particle.lat, particle.depth]
        v = fieldset.V[time, particle.lon, particle.lat, particle.depth]
        particle.speed = math.sqrt(u * u + v * v)

    # Particle type carrying the sampled speed, starting at zero.
    class MyParticle(ptype[mode]):
        speed = Variable('speed', dtype=np.float32, initial=0.)

    # Build the rotated lattice: (x, y) -> (r, angle + pi/6) -> (lon, lat).
    xs = np.linspace(0, 1e3, 7, dtype=np.float32)
    ys = np.linspace(0, 1e3, 5, dtype=np.float32)
    xx, yy = np.meshgrid(xs, ys)
    radius = np.sqrt(xx * xx + yy * yy)
    angle = np.arctan2(yy, xx) + np.pi / 6.
    lon, lat = radius * np.cos(angle), radius * np.sin(angle)
    times = np.array([0, 86400], dtype=np.float64)
    grid = CurvilinearZGrid(lon, lat, time=times)

    # Data laid out as (time, y, x); only the t=0 slice of U is non-trivial.
    shape = (times.size, ys.size, xs.size)
    u_vals = np.ones(shape, dtype=np.float32)
    u_vals[0] = lon + lat
    v_vals = np.zeros(shape, dtype=np.float32)
    fset = FieldSet(Field('U', u_vals, grid=grid, transpose=False),
                    Field('V', v_vals, grid=grid, transpose=False))

    # Zero runtime / zero dt: the kernel is only used to sample the fields.
    pset = ParticleSet.from_list(fset, MyParticle, lon=[400], lat=[600])
    pset.execute(pset.Kernel(sampleSpeed), runtime=0, dt=0)
    assert np.allclose(pset[0].speed, 1000)
def test_pset_create_field_curvi(npart=100):
    """Check that particles seeded from a curvilinear start field land in its support.

    The grid is a quarter annulus centred on (-1, -1), with radii in
    [0.25, 2] (20 levels) and angles in [0, pi/2] (200 levels). V is 1
    only where pi/4 < theta < pi/3 and 0 elsewhere. ``ParticleSet.from_field``
    is called with ``start_field=fieldset.V``. The test then checks that every
    resulting particle lies inside that angular wedge of the annulus, allowing
    one grid spacing of slack in both angle and radius.

    Args:
        npart: number of particles requested from ``from_field``.
    """
    # Fixed seed so the seeding positions are reproducible between runs.
    np.random.seed(123456)
    r_v = np.linspace(.25, 2, 20)
    theta_v = np.linspace(0, np.pi / 2, 200)
    # Grid spacings, used as tolerance around the wedge boundaries.
    dtheta = theta_v[1] - theta_v[0]
    dr = r_v[1] - r_v[0]
    (r, theta) = np.meshgrid(r_v, theta_v)

    x = -1 + r * np.cos(theta)
    y = -1 + r * np.sin(theta)
    grid = CurvilinearZGrid(x, y)

    u = np.ones(x.shape)
    v = np.where(np.logical_and(theta > np.pi / 4, theta < np.pi / 3), 1, 0)
    ufield = Field('U', u, grid=grid)
    vfield = Field('V', v, grid=grid)
    fieldset = FieldSet(ufield, vfield)

    pset = ParticleSet.from_field(fieldset, size=npart,
                                  pclass=ptype['scipy'],
                                  start_field=fieldset.V)

    # Recover polar coordinates relative to the annulus centre (-1, -1).
    lons = np.array([p.lon + 1 for p in pset])
    lats = np.array([p.lat + 1 for p in pset])
    thetas = np.arctan2(lats, lons)
    rs = np.sqrt(lons * lons + lats * lats)

    # Combine the boolean masks with logical AND rather than bool arithmetic.
    in_wedge = ((thetas > np.pi / 4 - dtheta)
                & (thetas < np.pi / 3 + dtheta)
                & (rs > .25 - dr)
                & (rs < 2 + dr))
    assert np.all(in_wedge), (
        '%d of %d particles were seeded outside the region where V != 0'
        % (np.count_nonzero(~in_wedge), in_wedge.size))