Example #1
    def test_dist_info(self, obs_dim, action_dim, filter_dims, filter_sizes,
                       strides, padding, hidden_sizes):
        env = TfEnv(DummyDiscreteEnv(obs_dim=obs_dim, action_dim=action_dim))
        with mock.patch(('garage.tf.policies.'
                         'categorical_conv_policy_with_model.MLPModel'),
                        new=SimpleMLPModel):
            with mock.patch(('garage.tf.policies.'
                             'categorical_conv_policy_with_model.CNNModel'),
                            new=SimpleCNNModel):
                policy = CategoricalConvPolicyWithModel(
                    env_spec=env.spec,
                    conv_filters=filter_dims,
                    conv_filter_sizes=filter_sizes,
                    conv_strides=strides,
                    conv_pad=padding,
                    hidden_sizes=hidden_sizes)

        env.reset()
        obs, _, _, _ = env.step(1)

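        # The mocked SimpleMLPModel emits a constant output (0.5 in garage's
        # test fixtures), so each action probability is expected to be 0.5.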
        expected_prob = np.full(action_dim, 0.5)

        policy_probs = policy.dist_info([obs])
        assert np.array_equal(policy_probs['prob'][0], expected_prob)
Example #2
    def test_is_pickleable(self):
        env = MetaRLEnv(DummyDiscreteEnv(obs_dim=(1, ), action_dim=1))
        obs_var = tf.compat.v1.placeholder(
            tf.float32,
            shape=[None, None, env.observation_space.flat_dim],
            name='obs')
        policy = CategoricalGRUPolicy(env_spec=env.spec,
                                      state_include_action=False)
        policy.build(obs_var)

        obs = env.reset()
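        # Set the GRU cell's input weights to all ones so the pickle
        # round-trip below is checked against known, non-default parameters.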
        policy.model._gru_cell.weights[0].load(
            tf.ones_like(policy.model._gru_cell.weights[0]).eval())

        output1 = self.sess.run(
            [policy.distribution.probs],
            feed_dict={policy.model.input: [[obs.flatten()], [obs.flatten()]]})

        p = pickle.dumps(policy)

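        # Rebuild the unpickled policy in a fresh graph; it should produce
        # the same output for the same observations.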
        with tf.compat.v1.Session(graph=tf.Graph()) as sess:
            policy_pickled = pickle.loads(p)
            obs_var = tf.compat.v1.placeholder(
                tf.float32,
                shape=[None, None, env.observation_space.flat_dim],
                name='obs')
            policy_pickled.build(obs_var)
            # yapf: disable
            output2 = sess.run(
                [policy_pickled.distribution.probs],
                feed_dict={
                    policy_pickled.model.input: [[obs.flatten()],
                                                 [obs.flatten()]]
                })
            # yapf: enable
            assert np.array_equal(output1, output2)
Example #3
    def test_get_action_state_include_action(self, obs_dim, action_dim,
                                             hidden_dim, obs_type):
        assert obs_type in ['discrete', 'dict']
        if obs_type == 'discrete':
            env = GymEnv(
                DummyDiscreteEnv(obs_dim=obs_dim, action_dim=action_dim))
        else:
            env = GymEnv(
                DummyDictEnv(obs_space_type='box', act_space_type='discrete'))
        policy = CategoricalLSTMPolicy(env_spec=env.spec,
                                       hidden_dim=hidden_dim,
                                       state_include_action=True)

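        # reset() clears the recurrent state; with state_include_action=True
        # the previous action is tracked internally and appended to the input.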
        policy.reset()
        obs = env.reset()[0]
        if obs_type == 'discrete':
            obs = obs.flatten()

        action, _ = policy.get_action(obs)
        assert env.action_space.contains(action)

        actions, _ = policy.get_actions([obs])
        for action in actions:
            assert env.action_space.contains(action)
Example #4
    def test_dist_info_sym_wrong_input(self):
        env = TfEnv(DummyDiscreteEnv(obs_dim=(1, ), action_dim=1))

        obs_ph = tf.compat.v1.placeholder(
            tf.float32, shape=(None, None, env.observation_space.flat_dim))

        with mock.patch(('metarl.tf.policies.'
                         'categorical_lstm_policy.LSTMModel'),
                        new=SimpleLSTMModel):
            policy = CategoricalLSTMPolicy(env_spec=env.spec,
                                           state_include_action=True)

        policy.reset()
        obs = env.reset()

        policy.dist_info_sym(
            obs_var=obs_ph,
            state_info_vars={'prev_action': np.zeros((3, 1, 1))},
            name='p2_sym')
        # observation batch size = 2 but prev_action batch size = 3
        with pytest.raises(tf.errors.InvalidArgumentError):
            self.sess.run(
                policy.model.networks['p2_sym'].input,
                feed_dict={obs_ph: [[obs.flatten()], [obs.flatten()]]})
Example #5
    def test_is_pickleable(self, obs_dim, action_dim):
        env = TfEnv(DummyDiscreteEnv(obs_dim=obs_dim, action_dim=action_dim))
        with mock.patch(('garage.tf.q_functions.'
                         'discrete_mlp_q_function.MLPModel'),
                        new=SimpleMLPModel):
            qf = DiscreteMLPQFunction(env_spec=env.spec)
        env.reset()
        obs, _, _, _ = env.step(1)

        with tf.variable_scope('DiscreteMLPQFunction/SimpleMLPModel',
                               reuse=True):
            return_var = tf.get_variable('return_var')
        # assign it to all one
        return_var.load(tf.ones_like(return_var).eval())

        output1 = self.sess.run(qf.q_vals, feed_dict={qf.input: [obs]})

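        # The pickle round-trip should preserve the values assigned above,
        # so the restored Q-function must produce identical Q-values.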
        h_data = pickle.dumps(qf)
        with tf.Session(graph=tf.Graph()) as sess:
            qf_pickled = pickle.loads(h_data)
            output2 = sess.run(qf_pickled.q_vals,
                               feed_dict={qf_pickled.input: [obs]})

        assert np.array_equal(output1, output2)
Example #6
    def setUp(self):
        super().setUp()
        self.env = TfEnv(DummyDiscreteEnv())
Example #7
    def setUp(self):
        super().setUp()
        env = TfEnv(DummyDiscreteEnv(obs_dim=(1, ), action_dim=1))
        self.default_initializer = tf.constant_initializer(1)
        self.default_hidden_nonlinearity = tf.nn.tanh
        self.default_recurrent_nonlinearity = tf.nn.sigmoid
        self.default_output_nonlinearity = None
        self.time_step = 1

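        # policy1/policy2 and policy3/policy4 use identical constant
        # initializers, so the plain and Model-based implementations can be
        # compared pairwise (P1 vs P3, P2 vs P4).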
        self.policy1 = CategoricalGRUPolicy(
            env_spec=env.spec,
            hidden_dim=4,
            hidden_nonlinearity=self.default_hidden_nonlinearity,
            recurrent_nonlinearity=self.default_recurrent_nonlinearity,
            recurrent_w_x_init=self.default_initializer,
            recurrent_w_h_init=self.default_initializer,
            output_nonlinearity=self.default_output_nonlinearity,
            output_w_init=self.default_initializer,
            state_include_action=True,
            name='P1')
        self.policy2 = CategoricalGRUPolicy(
            env_spec=env.spec,
            hidden_dim=4,
            hidden_nonlinearity=self.default_hidden_nonlinearity,
            recurrent_nonlinearity=self.default_recurrent_nonlinearity,
            recurrent_w_x_init=self.default_initializer,
            recurrent_w_h_init=self.default_initializer,
            output_nonlinearity=self.default_output_nonlinearity,
            output_w_init=tf.constant_initializer(2),
            state_include_action=True,
            name='P2')

        self.sess.run(tf.global_variables_initializer())

        self.policy3 = CategoricalGRUPolicyWithModel(
            env_spec=env.spec,
            hidden_dim=4,
            hidden_nonlinearity=self.default_hidden_nonlinearity,
            hidden_w_init=self.default_initializer,
            recurrent_nonlinearity=self.default_recurrent_nonlinearity,
            recurrent_w_init=self.default_initializer,
            output_nonlinearity=self.default_output_nonlinearity,
            output_w_init=self.default_initializer,
            state_include_action=True,
            name='P3')
        self.policy4 = CategoricalGRUPolicyWithModel(
            env_spec=env.spec,
            hidden_dim=4,
            hidden_nonlinearity=self.default_hidden_nonlinearity,
            hidden_w_init=self.default_initializer,
            recurrent_nonlinearity=self.default_recurrent_nonlinearity,
            recurrent_w_init=self.default_initializer,
            output_nonlinearity=self.default_output_nonlinearity,
            output_w_init=tf.constant_initializer(2),
            state_include_action=True,
            name='P4')

        self.policy1.reset()
        self.policy2.reset()
        self.policy3.reset()
        self.policy4.reset()
        self.obs = [env.reset()]
        self.obs = np.concatenate([self.obs for _ in range(self.time_step)],
                                  axis=0)

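        # Recurrent policies expect inputs shaped (batch, time, flat_dim).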
        self.obs_ph = tf.placeholder(tf.float32,
                                     shape=(None, None,
                                            env.observation_space.flat_dim))
        self.action_ph = tf.placeholder(tf.float32,
                                        shape=(None, None,
                                               env.action_space.flat_dim))

        self.dist1_sym = self.policy1.dist_info_sym(
            obs_var=self.obs_ph,
            state_info_vars={'prev_action': np.zeros((2, self.time_step, 1))},
            name='p1_sym')
        self.dist2_sym = self.policy2.dist_info_sym(
            obs_var=self.obs_ph,
            state_info_vars={'prev_action': np.zeros((2, self.time_step, 1))},
            name='p2_sym')
        self.dist3_sym = self.policy3.dist_info_sym(
            obs_var=self.obs_ph,
            state_info_vars={'prev_action': np.zeros((2, self.time_step, 1))},
            name='p3_sym')
        self.dist4_sym = self.policy4.dist_info_sym(
            obs_var=self.obs_ph,
            state_info_vars={'prev_action': np.zeros((2, self.time_step, 1))},
            name='p4_sym')
Example #8
    def test_clone(self, obs_dim, action_dim, hidden_sizes):
        env = GarageEnv(
            DummyDiscreteEnv(obs_dim=obs_dim, action_dim=action_dim))
        qf = DiscreteMLPQFunction(env_spec=env.spec, hidden_sizes=hidden_sizes)
        qf_clone = qf.clone('another_qf')
        assert qf_clone._hidden_sizes == qf._hidden_sizes
Example #9
    def test_invalid_obs_dim(self, obs_dim):
        with pytest.raises(ValueError):
            env = GarageEnv(DummyDiscreteEnv(obs_dim=obs_dim))
            ContinuousCNNQFunction(env_spec=env.spec,
                                   filters=((5, (3, 3)), ),
                                   strides=(1, ))
Example #10
    def test_invalid_env(self):
        env = TfEnv(DummyDiscreteEnv())
        with self.assertRaises(ValueError):
            GaussianGRUPolicyWithModel(env_spec=env.spec)
Example #11
    def test_state_info_specs_with_state_include_action(self):
        env = GarageEnv(DummyDiscreteEnv(obs_dim=(10, ), action_dim=4))
        policy = CategoricalLSTMPolicy(env_spec=env.spec,
                                       state_include_action=True)
        assert policy.state_info_specs == [('prev_action', (4, ))]
Example #12
    def setup_method(self):
        super().setup_method()
        self.env = GarageEnv(DummyDiscreteEnv())
Example #13
    def setUp(self):
        self.env = TfEnv(DummyDiscreteEnv(random=False))
        self.env_r = TfEnv(
            RepeatAction(DummyDiscreteEnv(random=False), n_frame_to_repeat=4))
Example #14
    def test_invalid_env(self):
        env = TfEnv(DummyDiscreteEnv())
        with pytest.raises(ValueError):
            GaussianGRUPolicy2(env_spec=env.spec)
Example #15
    def test_invalid_env(self):
        env = GarageEnv(DummyDiscreteEnv())
        with pytest.raises(ValueError):
            GaussianLSTMPolicy(env_spec=env.spec)
Example #16
    def test_state_info_specs(self):
        env = GarageEnv(DummyDiscreteEnv(obs_dim=(10, ), action_dim=4))
        policy = CategoricalLSTMPolicy(env_spec=env.spec,
                                       state_include_action=False)
        assert policy.state_info_specs == []
Example #17
    def setUp(self):
        super().setUp()
        self.data = np.ones((2, 1))
        self.env = TfEnv(DummyDiscreteEnv())
        self.qf = DiscreteMLPQFunction(self.env.spec)
Example #18
    def test_clone(self):
        env = GarageEnv(DummyDiscreteEnv(obs_dim=(10, ), action_dim=4))
        policy = CategoricalLSTMPolicy(env_spec=env.spec)
        policy_clone = policy.clone('CategoricalLSTMPolicyClone')
        assert policy.env_spec == policy_clone.env_spec
Example #19
    def test_invalid_env(self):
        env = MetaRLEnv(DummyDiscreteEnv())
        with pytest.raises(ValueError):
            GaussianMLPPolicy(env_spec=env.spec)