def test_obs_not_image(self):
    """Check that non-image observations skip pixel normalization.

    The environment is wrapped with ``is_image=False``. After one ``fit``
    and one ``predict`` call, the patched ``normalize_pixel_batch`` must
    not have been invoked at all.
    """
    module = 'metarl.tf.baselines.gaussian_cnn_baseline.'
    pixel_env = MetaRLEnv(DummyDiscretePixelEnv(), is_image=False)

    regressor_patch = mock.patch(module + 'GaussianCNNRegressor',
                                 new=SimpleGaussianCNNRegressor)
    # side_effect keeps the real function running while calls are recorded.
    normalizer_patch = mock.patch(module + 'normalize_pixel_batch',
                                  side_effect=normalize_pixel_batch)

    with regressor_patch, normalizer_patch as npb:
        baseline = GaussianCNNBaseline(env_spec=pixel_env.spec)
        shape = pixel_env.spec.observation_space.shape

        # One single-step path per constant value.
        baseline.fit([{
            'observations': [np.full(shape, value)],
            'returns': [value]
        } for value in (1, 2)])

        baseline.predict(
            {'observations': [np.full(shape, value) for value in (1, 2)]})

        assert not npb.called
def test_obs_is_image(self):
    """Check that image observations are passed to normalize_pixel_batch.

    The environment is wrapped with ``is_image=True``. The patched
    ``normalize_pixel_batch`` is expected to be called once by ``fit``
    with the concatenated path observations, and once more by ``predict``
    with the observations handed to it.
    """
    env = MetaRLEnv(DummyDiscretePixelEnv(), is_image=True)
    with mock.patch(('metarl.tf.baselines.'
                     'gaussian_cnn_baseline.'
                     'GaussianCNNRegressor'),
                    new=SimpleGaussianCNNRegressor):
        # side_effect keeps the real function running while calls are
        # recorded on the mock.
        with mock.patch(
                'metarl.tf.baselines.'
                'gaussian_cnn_baseline.'
                'normalize_pixel_batch',
                side_effect=normalize_pixel_batch) as npb:
            gcb = GaussianCNNBaseline(env_spec=env.spec)

            obs_dim = env.spec.observation_space.shape
            paths = [{
                'observations': [np.full(obs_dim, 1)],
                'returns': [1]
            }, {
                'observations': [np.full(obs_dim, 2)],
                'returns': [2]
            }]

            gcb.fit(paths)
            observations = np.concatenate(
                [p['observations'] for p in paths])
            assert npb.call_count == 1, (
                "Expected '%s' to have been called once. Called %s times."
                % (npb._mock_name or 'mock', npb.call_count))
            assert (npb.call_args_list[0][0][0] == observations).all()

            obs = {
                'observations': [np.full(obs_dim, 1),
                                 np.full(obs_dim, 2)]
            }
            observations = obs['observations']
            gcb.predict(obs)
            # Check the count before indexing so a missing call shows up as
            # an assertion failure instead of an IndexError.
            assert npb.call_count == 2, (
                "Expected '%s' to have been called twice. Called %s times."
                % (npb._mock_name or 'mock', npb.call_count))
            # Compare element-wise: a plain ``==`` between sequences of
            # arrays only gives a bool when the very same objects were
            # passed, and raises ValueError otherwise.
            assert np.array_equal(npb.call_args_list[1][0][0],
                                  observations)
def test_fit(self, obs_dim):
    """Fit the baseline on two one-step paths and check its predictions.

    Args:
        obs_dim: Observation shape used for the dummy box environment and
            for the constant-valued observations.
    """
    env = MetaRLEnv(DummyBoxEnv(obs_dim=obs_dim))
    regressor_target = ('metarl.tf.baselines.gaussian_cnn_baseline.'
                        'GaussianCNNRegressor')
    # Only construction happens under the patch.
    with mock.patch(regressor_target, new=SimpleGaussianCNNRegressor):
        baseline = GaussianCNNBaseline(env_spec=env.spec)

    targets = (1, 2)
    # Each path holds one observation filled with its own return value.
    baseline.fit([{
        'observations': [np.full(obs_dim, value)],
        'returns': [value]
    } for value in targets])

    batch = {'observations': [np.full(obs_dim, value) for value in targets]}
    predicted = baseline.predict(batch)
    assert np.array_equal(predicted, [1, 2])
def test_is_pickleable(self):
    """Check that a pickled baseline predicts the same as the original.

    The model's ``return_var`` is overwritten with 1.0 before predicting.
    The baseline is then pickled, restored inside a fresh graph and
    session, and both predictions are compared.
    """
    env = MetaRLEnv(DummyBoxEnv(obs_dim=(1, 1)))
    regressor_target = ('metarl.tf.baselines.gaussian_cnn_baseline.'
                        'GaussianCNNRegressor')
    with mock.patch(regressor_target, new=SimpleGaussianCNNRegressor):
        baseline = GaussianCNNBaseline(env_spec=env.spec)

    batch = {'observations': [np.full((1, 1), 1) for _ in range(2)]}

    # Look up the already-created variable (reuse=True) and set it to 1.0.
    scope = tf.compat.v1.variable_scope('GaussianCNNBaseline', reuse=True)
    with scope:
        variable = tf.compat.v1.get_variable(
            'SimpleGaussianCNNModel/return_var')
    variable.load(1.0)

    expected = baseline.predict(batch)
    serialized = pickle.dumps(baseline)

    # Restore in a separate graph and session from the original.
    fresh_session = tf.compat.v1.Session(graph=tf.Graph())
    with fresh_session:
        restored = pickle.loads(serialized)
        actual = restored.predict(batch)

    assert np.array_equal(expected, actual)