def test_obs_not_image(self):
        """Check that non-image observations bypass pixel normalization.

        The environment is wrapped with ``is_image=False``; after one
        ``fit`` and one ``predict`` the patched ``normalize_pixel_batch``
        must not have been called at all.
        """
        pixel_env = MetaRLEnv(DummyDiscretePixelEnv(), is_image=False)

        # Swap the regressor used by the baseline module for the simple one.
        regressor_patch = mock.patch(
            'metarl.tf.baselines.'
            'gaussian_cnn_baseline.'
            'GaussianCNNRegressor',
            new=SimpleGaussianCNNRegressor)
        # side_effect forwards every call to normalize_pixel_batch as
        # imported in this module, so the mock only records invocations.
        normalizer_patch = mock.patch(
            'metarl.tf.baselines.'
            'gaussian_cnn_baseline.'
            'normalize_pixel_batch',
            side_effect=normalize_pixel_batch)

        with regressor_patch, normalizer_patch as npb:
            baseline = GaussianCNNBaseline(env_spec=pixel_env.spec)
            obs_shape = pixel_env.spec.observation_space.shape

            # Two single-step paths, filled with 1 and 2 respectively.
            paths = [{
                'observations': [np.full(obs_shape, value)],
                'returns': [value]
            } for value in (1, 2)]
            baseline.fit(paths)

            batch = {
                'observations': [
                    np.full(obs_shape, value) for value in (1, 2)
                ]
            }
            baseline.predict(batch)

            assert not npb.called
    def test_obs_is_image(self):
        """Check that image observations are normalized in fit and predict.

        The environment is wrapped with ``is_image=True``.
        ``normalize_pixel_batch`` is patched with a mock whose side effect
        is the function of the same name imported in this module, so calls
        are recorded while still being forwarded. The test verifies that
        ``fit`` triggers exactly one call with the concatenated path
        observations, and that ``predict`` triggers a further call with the
        observations it was given.
        """
        env = MetaRLEnv(DummyDiscretePixelEnv(), is_image=True)
        with mock.patch(('metarl.tf.baselines.'
                         'gaussian_cnn_baseline.'
                         'GaussianCNNRegressor'),
                        new=SimpleGaussianCNNRegressor):
            with mock.patch(
                    'metarl.tf.baselines.'
                    'gaussian_cnn_baseline.'
                    'normalize_pixel_batch',
                    side_effect=normalize_pixel_batch) as npb:

                gcb = GaussianCNNBaseline(env_spec=env.spec)

                obs_dim = env.spec.observation_space.shape
                paths = [{
                    'observations': [np.full(obs_dim, 1)],
                    'returns': [1]
                }, {
                    'observations': [np.full(obs_dim, 2)],
                    'returns': [2]
                }]

                gcb.fit(paths)
                observations = np.concatenate(
                    [p['observations'] for p in paths])
                assert npb.call_count == 1, (
                    "Expected '%s' to have been called once. Called %s times."
                    % (npb._mock_name or 'mock', npb.call_count))
                # np.array_equal returns a plain bool and copes with shape
                # mismatches, unlike ``(a == b).all()``.
                assert np.array_equal(npb.call_args_list[0][0][0],
                                      observations)

                obs = {
                    'observations': [np.full(obs_dim, 1),
                                     np.full(obs_dim, 2)]
                }
                gcb.predict(obs)
                # Fail with a readable message (not an IndexError below) if
                # predict() never reached the normalizer.
                assert npb.call_count >= 2, (
                    'Expected predict() to call normalize_pixel_batch. '
                    'Total calls: %s.' % npb.call_count)
                # Compare by value: ``==`` on lists of ndarrays only works
                # when the very same objects are passed through, and raises
                # ValueError for equal-but-distinct multi-element arrays.
                assert np.array_equal(npb.call_args_list[1][0][0],
                                      obs['observations'])
    def test_fit(self, obs_dim):
        """Fit on two one-step paths and expect their returns as predictions.

        Args:
            obs_dim: Observation shape passed to ``DummyBoxEnv`` and
                ``np.full``. It is supplied from outside this method
                (fixture or parametrization not visible here).

        """
        env = MetaRLEnv(DummyBoxEnv(obs_dim=obs_dim))
        regressor_patch = mock.patch(
            'metarl.tf.baselines.'
            'gaussian_cnn_baseline.'
            'GaussianCNNRegressor',
            new=SimpleGaussianCNNRegressor)
        # Only construction runs under the patch; fit and predict are
        # called after the patch has been undone.
        with regressor_patch:
            baseline = GaussianCNNBaseline(env_spec=env.spec)

        fill_values = (1, 2)
        baseline.fit([{
            'observations': [np.full(obs_dim, value)],
            'returns': [value]
        } for value in fill_values])

        batch = {
            'observations': [np.full(obs_dim, value) for value in fill_values]
        }
        predicted = baseline.predict(batch)
        assert np.array_equal(predicted, [1, 2])
    def test_is_pickleable(self):
        """Round-trip the baseline through pickle and compare predictions.

        A prediction is taken from the original baseline, the baseline is
        pickled, then unpickled inside a session on a fresh graph, and the
        restored object's prediction on the same input must be identical.
        """
        env = MetaRLEnv(DummyBoxEnv(obs_dim=(1, 1)))
        regressor_patch = mock.patch(
            'metarl.tf.baselines.'
            'gaussian_cnn_baseline.'
            'GaussianCNNRegressor',
            new=SimpleGaussianCNNRegressor)
        with regressor_patch:
            baseline = GaussianCNNBaseline(env_spec=env.spec)
        batch = {'observations': [np.full((1, 1), 1) for _ in range(2)]}

        # Overwrite the model's return variable before predicting.
        # NOTE(review): presumably this makes the state differ from its
        # initial value so the round trip is meaningful; load() also needs
        # a default session, presumably provided by the test fixture.
        with tf.compat.v1.variable_scope('GaussianCNNBaseline', reuse=True):
            return_var = tf.compat.v1.get_variable(
                'SimpleGaussianCNNModel/return_var')
        return_var.load(1.0)

        expected = baseline.predict(batch)

        serialized = pickle.dumps(baseline)

        # Restore inside a session bound to a brand-new graph.
        with tf.compat.v1.Session(graph=tf.Graph()):
            restored = pickle.loads(serialized)
            actual = restored.predict(batch)

            assert np.array_equal(expected, actual)