def test_incremental_rotation(dev_str, call):
    """Check that repeated pose-only updates change previously fused memory.

    An observation of ones is fused into empty memory. Two further updates
    are then applied using an empty observation carrying a small relative
    rot-vec pose (a single non-zero component of 0.1, presumably a rotation
    given the test name). The resulting memory mean must differ from the mean
    right after the first fusion.
    """
    # convolutions not yet implemented in numpy or jax
    # mxnet is unable to stack or expand zero-dimensional tensors
    unsupported_calls = [helpers.np_call, helpers.jnp_call, helpers.mx_call]
    if call in unsupported_calls:
        pytest.skip()

    batch_size = 1
    num_timesteps = 1
    num_cams = 1
    num_feature_channels = 3
    image_dims = [3, 3]
    # keyword arguments shared by every ESM forward call below
    step_kwargs = dict(batch_size=batch_size, num_timesteps=num_timesteps,
                       num_cams=num_cams, image_dims=image_dims)

    esm = ESM(omni_image_dims=[10, 20], smooth_mean=False)
    initial_memory = esm.empty_memory(batch_size, num_timesteps)

    # empty observation whose only content is the relative pose
    motion_obs = _get_dummy_obs(batch_size, num_timesteps, num_cams, image_dims,
                                num_feature_channels, empty=True)
    pose_delta = ivy.array([[[0., 0., 0., 0., 0.1, 0.]]])
    motion_obs['control_mean'] = pose_delta
    motion_obs['agent_rel_mat'] = ivy_mech.rot_vec_pose_to_mat_pose(pose_delta)

    ones_obs = _get_dummy_obs(batch_size, num_timesteps, num_cams, image_dims,
                              num_feature_channels, ones=True)
    fused_memory = esm(ones_obs, initial_memory, **step_kwargs)

    # apply the pose-only observation twice on top of the fused memory
    rotated_memory = fused_memory
    for _ in range(2):
        rotated_memory = esm(motion_obs, rotated_memory, **step_kwargs)

    assert not np.allclose(fused_memory.mean, rotated_memory.mean)
def test_realtime_speed(dev_str, call):
    """Check that 50 successive ESM updates complete in under 20 seconds.

    Each step fuses a fresh dummy 64x64 observation into the memory and
    converts the memory mean to numpy. It then checks the mean's shape and
    that its first element is 0. The device is fixed to 'cpu' rather than
    taken from ``dev_str``.
    """
    # convolutions not yet implemented in numpy or jax
    # mxnet is unable to stack or expand zero-dimensional tensors
    if call in [helpers.np_call, helpers.jnp_call, helpers.mx_call]:
        pytest.skip()
    ivy.seed(0)

    device = 'cpu'
    batch_size = 1
    num_timesteps = 1
    num_cams = 1
    num_feature_channels = 3
    image_dims = [64, 64]
    omni_img_dims = [90, 180]
    # loop-invariant, so computed once outside the timed region
    expected_shape = tuple([batch_size, num_timesteps] + omni_img_dims + [3 + num_feature_channels])
    esm = ESM(omni_image_dims=omni_img_dims, device=device)
    memory = esm.empty_memory(batch_size, num_timesteps)

    start_time = time.perf_counter()
    for _ in range(50):
        obs = _get_dummy_obs(batch_size, num_timesteps, num_cams, image_dims, num_feature_channels, device)
        memory = esm(obs, memory, batch_size=batch_size, num_timesteps=num_timesteps, num_cams=num_cams,
                     image_dims=image_dims)
        # the numpy conversion sits inside the timed loop, so its cost is part of the measurement
        memory_mean = memory.mean.numpy()
        assert memory_mean.shape == expected_shape
        assert memory_mean[0, 0, 0, 0, 0] == 0.
    time_taken = time.perf_counter() - start_time
    assert time_taken < 20., 'ESM updates took {:.2f}s'.format(time_taken)
def test_inference(with_args, dev_str, call):
    """Smoke-test a single ESM forward pass on dummy observations.

    When ``with_args`` is truthy, an empty memory and explicit batch,
    timestep, camera and image-size arguments are supplied. Otherwise all of
    them are passed as ``None``. The test only checks that the call does not
    raise.
    """
    # convolutions not yet implemented in numpy or jax
    # mxnet is unable to stack or expand zero-dimensional tensors
    if call in [helpers.np_call, helpers.jnp_call, helpers.mx_call]:
        pytest.skip()

    batch_size, num_timesteps, num_cams = 5, 6, 7
    num_feature_channels = 3
    image_dims = [3, 3]

    esm = ESM()
    obs = _get_dummy_obs(batch_size, num_timesteps, num_cams, image_dims, num_feature_channels)
    if with_args:
        memory = esm.empty_memory(batch_size, num_timesteps)
        shape_args = dict(batch_size=batch_size, num_timesteps=num_timesteps,
                          num_cams=num_cams, image_dims=image_dims)
    else:
        memory = None
        # every shape argument explicitly set to None
        shape_args = dict.fromkeys(('batch_size', 'num_timesteps', 'num_cams', 'image_dims'))
    esm(obs, memory, **shape_args)
def test_values(dev_str, call):
    """Regression-test ESM memory against stored reference data.

    The observations ``test_data/obs_0.hdf5`` and ``obs_1.hdf5``, located
    next to this file, are fused in turn into an initially empty memory.
    After each step the memory mean and variance are compared with the
    matching ``test_data/mem_<step>.hdf5``. The mean comparison uses
    ``atol=1e-3``; the variance comparison uses the ``np.allclose`` defaults.
    """
    # convolutions not yet implemented in numpy or jax
    # mxnet is unable to stack or expand zero-dimensional tensors
    if call in [helpers.np_call, helpers.jnp_call, helpers.mx_call]:
        pytest.skip()

    device = 'cpu'
    batch_size = 1
    num_timesteps = 1
    num_cams = 1
    num_feature_channels = 3
    image_dims = [128, 128]
    omni_img_dims = [180, 360]
    esm = ESM(omni_image_dims=omni_img_dims, device=device)
    memory = esm.empty_memory(batch_size, num_timesteps)

    data_dir = os.path.dirname(os.path.realpath(__file__))
    obs_template = 'test_data/obs_{}.hdf5'
    mem_template = 'test_data/mem_{}.hdf5'
    for step in range(2):
        obs = ivy.Container.from_disk(os.path.join(data_dir, obs_template.format(step)))
        memory = esm(obs, memory, batch_size=batch_size, num_timesteps=num_timesteps,
                     num_cams=num_cams, image_dims=image_dims)
        reference = ivy.Container.from_disk(os.path.join(data_dir, mem_template.format(step)))
        assert np.allclose(memory.mean, reference.mean, atol=1e-3)
        assert np.allclose(memory.var, reference.var)
def test_construction(dev_str, call):
    """Smoke-test that ``ESM`` can be instantiated with its default arguments."""
    # mxnet is unable to stack or expand zero-dimensional tensors
    skipped_calls = [helpers.mx_call]
    if call in skipped_calls:
        pytest.skip()
    ESM()