Example #1
    def test_categorical_lstm_policy(self):
        categorical_lstm_policy = CategoricalLSTMPolicy(
            env_spec=self.env, hidden_dim=1, state_include_action=False)
        categorical_lstm_policy.reset()

        obs = self.env.observation_space.high
        assert categorical_lstm_policy.get_action(obs)
Example #2
    def test_is_pickleable(self):
        env = GarageEnv(DummyDiscreteEnv(obs_dim=(1, ), action_dim=1))
        policy = CategoricalLSTMPolicy(env_spec=env.spec,
                                       state_include_action=False)

        policy.reset()
        obs = env.reset()

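        # Load ones into the first LSTM cell weight tensor so the output is
        # deterministic and comparable after unpickling.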
        policy.model._lstm_cell.weights[0].load(
            tf.ones_like(policy.model._lstm_cell.weights[0]).eval())

        output1 = self.sess.run(
            [policy.distribution.probs],
            feed_dict={policy.model.input: [[obs.flatten()], [obs.flatten()]]})

        p = pickle.dumps(policy)

        with tf.compat.v1.Session(graph=tf.Graph()) as sess:
            policy_pickled = pickle.loads(p)
            output2 = sess.run([policy_pickled.distribution.probs],
                               feed_dict={
                                   policy_pickled.model.input:
                                   [[obs.flatten()], [obs.flatten()]]
                               })  # noqa: E126
            assert np.array_equal(output1, output2)
Example #3
    def test_is_pickleable(self):
        env = GarageEnv(DummyDiscreteEnv(obs_dim=(1, ), action_dim=1))
        policy = CategoricalLSTMPolicy(env_spec=env.spec,
                                       state_include_action=False)

        policy.reset()
        obs = env.reset()

        state_input = tf.compat.v1.placeholder(tf.float32,
                                               shape=(None, None,
                                                      policy.input_dim))
        dist_sym = policy.build(state_input, name='dist_sym').dist
        policy._lstm_cell.weights[0].load(
            tf.ones_like(policy._lstm_cell.weights[0]).eval())

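        # state_input is shaped (batch, time, obs_dim); the feeds below
        # provide two single-step sequences.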
        output1 = self.sess.run(
            [dist_sym.probs],
            feed_dict={state_input: [[obs.flatten()], [obs.flatten()]]})

        p = pickle.dumps(policy)

        with tf.compat.v1.Session(graph=tf.Graph()) as sess:
            policy_pickled = pickle.loads(p)
            state_input = tf.compat.v1.placeholder(
                tf.float32, shape=(None, None, policy_pickled.input_dim))
            dist_sym = policy_pickled.build(state_input, name='dist_sym').dist
            output2 = sess.run(
                [dist_sym.probs],
                feed_dict={state_input: [[obs.flatten()],
                                         [obs.flatten()]]})  # noqa: E126
            assert np.array_equal(output1, output2)
Example #4
    def test_clone(self):
        env = GymEnv(DummyDiscreteEnv(obs_dim=(10, ), action_dim=4))
        policy = CategoricalLSTMPolicy(env_spec=env.spec)
        policy_clone = policy.clone('CategoricalLSTMPolicyClone')
        assert policy.env_spec == policy_clone.env_spec
        for cloned_param, param in zip(policy_clone.parameters.values(),
                                       policy.parameters.values()):
            assert np.array_equal(cloned_param, param)
Example #5
    def test_categorical_lstm_policy(self):
        categorical_lstm_policy = CategoricalLSTMPolicy(env_spec=self.env,
                                                        hidden_dim=1)
        self.sess.run(tf.global_variables_initializer())

        categorical_lstm_policy.reset()

        obs = self.env.observation_space.high
        assert categorical_lstm_policy.get_action(obs)
Example #6
    def test_clone(self):
        env = TfEnv(DummyDiscreteEnv(obs_dim=(1, ), action_dim=1))
        with mock.patch(('garage.tf.policies.'
                         'categorical_lstm_policy.LSTMModel'),
                        new=SimpleLSTMModel):
            policy = CategoricalLSTMPolicy(env_spec=env.spec,
                                           state_include_action=False)
            policy_cloned = policy.clone('cloned_policy')
            assert policy_cloned.name == 'cloned_policy'
            assert np.array_equal(policy.get_param_values(),
                                  policy_cloned.get_param_values())
Example #7
    def test_trpo_lstm_cartpole(self):
        with TFTrainer(snapshot_config, sess=self.sess) as trainer:
            env = normalize(GymEnv('CartPole-v1', max_episode_length=100))

            policy = CategoricalLSTMPolicy(name='policy', env_spec=env.spec)

            baseline = LinearFeatureBaseline(env_spec=env.spec)

            sampler = LocalSampler(
                agents=policy,
                envs=env,
                max_episode_length=env.spec.max_episode_length,
                is_tf_worker=True)

            algo = TRPO(env_spec=env.spec,
                        policy=policy,
                        baseline=baseline,
                        sampler=sampler,
                        discount=0.99,
                        max_kl_step=0.01,
                        optimizer_args=dict(hvp_approach=FiniteDifferenceHVP(
                            base_eps=1e-5)))

            snapshotter.snapshot_dir = './'
            trainer.setup(algo, env)
            last_avg_ret = trainer.train(n_epochs=10, batch_size=2048)
            assert last_avg_ret > 60

            env.close()
Example #8
def run_task(*_):
    with LocalRunner() as runner:
        env = TfEnv(env_name='CartPole-v1')

        policy = CategoricalLSTMPolicy(
            name='policy',
            env_spec=env.spec,
            lstm_layer_cls=L.TfBasicLSTMLayer,
            # gru_layer_cls=L.GRULayer,
        )

        baseline = LinearFeatureBaseline(env_spec=env.spec)

        algo = TRPO(
            env_spec=env.spec,
            policy=policy,
            baseline=baseline,
            max_path_length=100,
            discount=0.99,
            max_kl_step=0.01,
            optimizer=ConjugateGradientOptimizer,
            optimizer_args=dict(
                hvp_approach=FiniteDifferenceHvp(base_eps=1e-5)))

        runner.setup(algo, env)
        runner.train(n_epochs=100, batch_size=4000)
Example #9
    def test_invalid_env(self):
        env = TfEnv(DummyBoxEnv())
        with mock.patch(('garage.tf.policies.'
                         'categorical_lstm_policy.LSTMModel'),
                        new=SimpleLSTMModel):
            with pytest.raises(ValueError):
                CategoricalLSTMPolicy(env_spec=env.spec)
Example #10
    def test_process_samples_discrete_recurrent(self):
        env = TfEnv(DummyDiscreteEnv())
        policy = CategoricalLSTMPolicy(env_spec=env.spec)
        baseline = LinearFeatureBaseline(env_spec=env.spec)
        max_path_length = 100
        with LocalTFRunner(snapshot_config, sess=self.sess) as runner:
            algo = BatchPolopt2(env_spec=env.spec,
                                policy=policy,
                                baseline=baseline,
                                max_path_length=max_path_length,
                                flatten_input=True)
            runner.setup(algo, env, sampler_args=dict(n_envs=1))
            runner.train(n_epochs=1, batch_size=max_path_length)
            paths = runner.obtain_samples(0)
            samples = algo.process_samples(0, paths)
            # Since there is only 1 vec_env in the sampler and DummyDiscreteEnv
            # always terminates, the number of paths and the batch size both
            # equal max_path_length, i.e. 100
            assert samples['observations'].shape == (
                max_path_length, env.observation_space.flat_dim)
            assert samples['actions'].shape == (max_path_length,
                                                env.action_space.n)
            assert samples['rewards'].shape == (max_path_length, )
            assert samples['baselines'].shape == (max_path_length, )
            assert samples['returns'].shape == (max_path_length, )
            # there are 100 paths
            assert samples['lengths'].shape == (max_path_length, )
            # the recurrent policy stores its state info (prev_action) in
            # agent_infos
            for key, shape in policy.state_info_specs:
                assert samples['agent_infos'][key].shape == (max_path_length,
                                                             np.prod(shape))
            assert isinstance(samples['average_return'], float)
Example #11
    def test_is_pickleable(self):
        env = TfEnv(DummyDiscreteEnv(obs_dim=(1, ), action_dim=1))
        with mock.patch(('garage.tf.policies.'
                         'categorical_lstm_policy.LSTMModel'),
                        new=SimpleLSTMModel):
            policy = CategoricalLSTMPolicy(env_spec=env.spec,
                                           state_include_action=False)

        env.reset()
        obs = env.reset()

        with tf.compat.v1.variable_scope('CategoricalLSTMPolicy/prob_network',
                                         reuse=True):
            return_var = tf.compat.v1.get_variable('return_var')
        # assign all ones to the variable
        return_var.load(tf.ones_like(return_var).eval())

        output1 = self.sess.run(
            policy.model.outputs[0],
            feed_dict={policy.model.input: [[obs.flatten()], [obs.flatten()]]})

        p = pickle.dumps(policy)

        with tf.compat.v1.Session(graph=tf.Graph()) as sess:
            policy_pickled = pickle.loads(p)
            output2 = sess.run(policy_pickled.model.outputs[0],
                               feed_dict={
                                   policy_pickled.model.input:
                                   [[obs.flatten()], [obs.flatten()]]
                               })  # noqa: E126
            assert np.array_equal(output1, output2)
Example #12
def trpo_cartpole_recurrent(ctxt, seed, n_epochs, batch_size, plot):
    """Train TRPO with a recurrent policy on CartPole.

    Args:
        ctxt (garage.experiment.ExperimentContext): The experiment
            configuration used by LocalRunner to create the snapshotter.
        seed (int): Used to seed the random number generator to produce
            determinism.
        n_epochs (int): Number of epochs for training.
        batch_size (int): Batch size used for training.
        plot (bool): Whether to plot or not.

    """
    set_seed(seed)
    with LocalTFRunner(snapshot_config=ctxt) as runner:
        env = GymEnv('CartPole-v1')

        policy = CategoricalLSTMPolicy(name='policy', env_spec=env.spec)

        baseline = LinearFeatureBaseline(env_spec=env.spec)

        algo = TRPO(env_spec=env.spec,
                    policy=policy,
                    baseline=baseline,
                    max_episode_length=100,
                    discount=0.99,
                    max_kl_step=0.01,
                    optimizer=ConjugateGradientOptimizer,
                    optimizer_args=dict(hvp_approach=FiniteDifferenceHvp(
                        base_eps=1e-5)))

        runner.setup(algo, env)
        runner.train(n_epochs=n_epochs, batch_size=batch_size, plot=plot)
Example #13
def run_task(snapshot_config, *_):
    """Defines the main experiment routine.

    Args:
        snapshot_config (garage.experiment.SnapshotConfig): Configuration
            values for snapshotting.
        *_ (object): Hyperparameters (unused).

    """
    with LocalTFRunner(snapshot_config=snapshot_config) as runner:
        env = TfEnv(env_name='CartPole-v1')

        policy = CategoricalLSTMPolicy(name='policy', env_spec=env.spec)

        baseline = LinearFeatureBaseline(env_spec=env.spec)

        algo = TRPO(env_spec=env.spec,
                    policy=policy,
                    baseline=baseline,
                    max_path_length=100,
                    discount=0.99,
                    max_kl_step=0.01,
                    optimizer=ConjugateGradientOptimizer,
                    optimizer_args=dict(hvp_approach=FiniteDifferenceHvp(
                        base_eps=1e-5)))

        runner.setup(algo, env)
        runner.train(n_epochs=100, batch_size=4000)
Example #14
    def test_categorical_lstm_policy(self):
        categorical_lstm_policy = CategoricalLSTMPolicy(
            env_spec=self.env, hidden_dim=1, state_include_action=False)
        self.sess.run(tf.compat.v1.global_variables_initializer())
        categorical_lstm_policy.build(self.obs_var)
        categorical_lstm_policy.reset()

        obs = self.env.observation_space.high
        assert categorical_lstm_policy.get_action(obs)
Example #15
    def test_build_state_not_include_action(self, obs_dim, action_dim,
                                            hidden_dim):
        env = GarageEnv(
            DummyDiscreteEnv(obs_dim=obs_dim, action_dim=action_dim))
        policy = CategoricalLSTMPolicy(env_spec=env.spec,
                                       hidden_dim=hidden_dim,
                                       state_include_action=False)
        policy.reset(do_resets=None)
        obs = env.reset()

        state_input = tf.compat.v1.placeholder(tf.float32,
                                               shape=(None, None,
                                                      policy.input_dim))
        dist_sym = policy.build(state_input, name='dist_sym').dist
        output1 = self.sess.run(
            [policy.distribution.probs],
            feed_dict={policy.model.input: [[obs.flatten()], [obs.flatten()]]})
        output2 = self.sess.run(
            [dist_sym.probs],
            feed_dict={state_input: [[obs.flatten()], [obs.flatten()]]})
        assert np.array_equal(output1, output2)
Example #16
    def test_get_action(self, mock_rand, obs_dim, action_dim):
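        # The patched random call always returns 0, so sampling from the
        # categorical distribution deterministically picks action 0.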
        mock_rand.return_value = 0

        env = TfEnv(DummyDiscreteEnv(obs_dim=obs_dim, action_dim=action_dim))

        with mock.patch(('garage.tf.policies.'
                         'categorical_lstm_policy.LSTMModel'),
                        new=SimpleLSTMModel):
            policy = CategoricalLSTMPolicy(env_spec=env.spec,
                                           state_include_action=False)

        policy.reset()
        obs = env.reset()

        expected_prob = np.full(action_dim, 0.5)

        action, agent_info = policy.get_action(obs)
        assert env.action_space.contains(action)
        assert action == 0
        assert np.array_equal(agent_info['prob'], expected_prob)

        actions, agent_infos = policy.get_actions([obs])
        for action, prob in zip(actions, agent_infos['prob']):
            assert env.action_space.contains(action)
            assert action == 0
            assert np.array_equal(prob, expected_prob)
Example #17
def run_garage(env, seed, log_dir):
    '''
    Create the garage model and run training.
    Replace PPO with the algorithm you want to run.
    :param env: Environment of the task.
    :param seed: Random seed for the trial.
    :param log_dir: Log dir path.
    :return:
    '''
    deterministic.set_seed(seed)
    config = tf.compat.v1.ConfigProto(allow_soft_placement=True,
                                      intra_op_parallelism_threads=12,
                                      inter_op_parallelism_threads=12)
    sess = tf.compat.v1.Session(config=config)
    with LocalTFRunner(snapshot_config, sess=sess, max_cpus=12) as runner:
        env = TfEnv(normalize(env))

        policy = CategoricalLSTMPolicy(
            env_spec=env.spec,
            hidden_dim=32,
            hidden_nonlinearity=tf.nn.tanh,
        )

        baseline = LinearFeatureBaseline(env_spec=env.spec)

        algo = PPO(
            env_spec=env.spec,
            policy=policy,
            baseline=baseline,
            max_path_length=100,
            discount=0.99,
            gae_lambda=0.95,
            lr_clip_range=0.2,
            policy_ent_coeff=0.0,
            optimizer_args=dict(
                batch_size=32,
                max_epochs=10,
                tf_optimizer_args=dict(learning_rate=1e-3),
            ),
        )

        # Set up logger since we are not using run_experiment
        tabular_log_file = osp.join(log_dir, 'progress.csv')
        dowel_logger.add_output(dowel.StdOutput())
        dowel_logger.add_output(dowel.CsvOutput(tabular_log_file))
        dowel_logger.add_output(dowel.TensorBoardOutput(log_dir))

        runner.setup(algo, env, sampler_args=dict(n_envs=12))
        runner.train(n_epochs=488, batch_size=2048)
        dowel_logger.remove_all()

        return tabular_log_file
Example #18
    def test_dist_info_sym(self, obs_dim, action_dim):
        env = TfEnv(DummyDiscreteEnv(obs_dim=obs_dim, action_dim=action_dim))

        obs_ph = tf.compat.v1.placeholder(
            tf.float32, shape=(None, None, env.observation_space.flat_dim))

        with mock.patch(('garage.tf.policies.'
                         'categorical_lstm_policy.LSTMModel'),
                        new=SimpleLSTMModel):
            policy = CategoricalLSTMPolicy(env_spec=env.spec,
                                           state_include_action=False)

        policy.reset()
        obs = env.reset()

        dist_sym = policy.dist_info_sym(obs_var=obs_ph,
                                        state_info_vars=None,
                                        name='p2_sym')
        dist = self.sess.run(
            dist_sym, feed_dict={obs_ph: [[obs.flatten()], [obs.flatten()]]})

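        # The mocked SimpleLSTMModel produces a constant output, so every
        # action probability is 0.5 for both single-step inputs.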
        assert np.array_equal(dist['prob'], np.full((2, 1, action_dim), 0.5))
Example #19
def categorical_lstm_policy(ctxt, env_id, seed):
    """Create Categorical LSTM Policy on TF-PPO.

    Args:
        ctxt (garage.experiment.ExperimentContext): The experiment
            configuration used by Trainer to create the
            snapshotter.
        env_id (str): Environment id of the task.
        seed (int): Random positive integer for the trial.

    """
    deterministic.set_seed(seed)

    with TFTrainer(ctxt) as trainer:
        env = normalize(GymEnv(env_id))

        policy = CategoricalLSTMPolicy(
            env_spec=env.spec,
            hidden_dim=32,
            hidden_nonlinearity=tf.nn.tanh,
        )

        baseline = LinearFeatureBaseline(env_spec=env.spec)

        sampler = RaySampler(agents=policy,
                             envs=env,
                             max_episode_length=env.spec.max_episode_length,
                             is_tf_worker=True)

        algo = PPO(
            env_spec=env.spec,
            policy=policy,
            baseline=baseline,
            sampler=sampler,
            discount=0.99,
            gae_lambda=0.95,
            lr_clip_range=0.2,
            policy_ent_coeff=0.0,
            optimizer_args=dict(
                batch_size=32,
                max_optimization_epochs=10,
                learning_rate=1e-3,
            ),
        )

        trainer.setup(algo, env)
        trainer.train(n_epochs=488, batch_size=2048)
Example #20
    def test_get_action(self, obs_dim, action_dim, hidden_dim):
        env = TfEnv(DummyDiscreteEnv(obs_dim=obs_dim, action_dim=action_dim))
        obs_var = tf.compat.v1.placeholder(
            tf.float32,
            shape=[None, None, env.observation_space.flat_dim],
            name='obs')
        policy = CategoricalLSTMPolicy(env_spec=env.spec,
                                       hidden_dim=hidden_dim,
                                       state_include_action=False)

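        # Build the policy's computation graph on the placeholder before
        # querying actions.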
        policy.build(obs_var)
        policy.reset()
        obs = env.reset()

        action, _ = policy.get_action(obs.flatten())
        assert env.action_space.contains(action)

        actions, _ = policy.get_actions([obs.flatten()])
        for action in actions:
            assert env.action_space.contains(action)
Example #21
    def test_policies(self):
        """Test the policies initialization."""
        box_env = TfEnv(DummyBoxEnv())
        discrete_env = TfEnv(DummyDiscreteEnv())
        categorical_gru_policy = CategoricalGRUPolicy(env_spec=discrete_env,
                                                      hidden_dim=1)
        categorical_lstm_policy = CategoricalLSTMPolicy(env_spec=discrete_env,
                                                        hidden_dim=1)
        categorical_mlp_policy = CategoricalMLPPolicy(env_spec=discrete_env,
                                                      hidden_sizes=(1, ))
        continuous_mlp_policy = ContinuousMLPPolicy(env_spec=box_env,
                                                    hidden_sizes=(1, ))
        deterministic_mlp_policy = DeterministicMLPPolicy(env_spec=box_env,
                                                          hidden_sizes=(1, ))
        gaussian_gru_policy = GaussianGRUPolicy(env_spec=box_env, hidden_dim=1)
        gaussian_lstm_policy = GaussianLSTMPolicy(env_spec=box_env,
                                                  hidden_dim=1)
        gaussian_mlp_policy = GaussianMLPPolicy(env_spec=box_env,
                                                hidden_sizes=(1, ))
Example #22
def categorical_lstm_policy(ctxt, env_id, seed):
    """Create Categorical LSTM Policy on TF-PPO.

    Args:
        ctxt (garage.experiment.ExperimentContext): The experiment
            configuration used by LocalRunner to create the
            snapshotter.
        env_id (str): Environment id of the task.
        seed (int): Random positive integer for the trial.

    """
    deterministic.set_seed(seed)

    with LocalTFRunner(ctxt, max_cpus=12) as runner:
        env = TfEnv(normalize(gym.make(env_id)))

        policy = CategoricalLSTMPolicy(
            env_spec=env.spec,
            hidden_dim=32,
            hidden_nonlinearity=tf.nn.tanh,
        )

        baseline = LinearFeatureBaseline(env_spec=env.spec)

        algo = PPO(
            env_spec=env.spec,
            policy=policy,
            baseline=baseline,
            max_path_length=100,
            discount=0.99,
            gae_lambda=0.95,
            lr_clip_range=0.2,
            policy_ent_coeff=0.0,
            optimizer_args=dict(
                batch_size=32,
                max_epochs=10,
                tf_optimizer_args=dict(learning_rate=1e-3),
            ),
        )

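        # Use 12 parallel sampling environments, matching max_cpus above.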
        runner.setup(algo, env, sampler_args=dict(n_envs=12))
        runner.train(n_epochs=488, batch_size=2048)
Example #23
    def test_get_action(self, obs_dim, action_dim, hidden_dim):
        env = GymEnv(DummyDiscreteEnv(obs_dim=obs_dim, action_dim=action_dim))
        policy = CategoricalLSTMPolicy(env_spec=env.spec,
                                       hidden_dim=hidden_dim,
                                       state_include_action=False)

        policy.reset()
        obs = env.reset()[0]

        action, _ = policy.get_action(obs.flatten())
        assert env.action_space.contains(action)

        actions, _ = policy.get_actions([obs.flatten()])
        for action in actions:
            assert env.action_space.contains(action)
Example #24
File: test_trpo.py, Project: gagkhan/garage
    def test_trpo_lstm_cartpole(self):
        with LocalTFRunner(snapshot_config, sess=self.sess) as runner:
            env = TfEnv(normalize(gym.make('CartPole-v1')))

            policy = CategoricalLSTMPolicy(name='policy', env_spec=env.spec)

            baseline = LinearFeatureBaseline(env_spec=env.spec)

            algo = TRPO(env_spec=env.spec,
                        policy=policy,
                        baseline=baseline,
                        max_path_length=100,
                        discount=0.99,
                        max_kl_step=0.01,
                        optimizer_args=dict(hvp_approach=FiniteDifferenceHvp(
                            base_eps=1e-5)))

            snapshotter.snapshot_dir = './'
            runner.setup(algo, env)
            last_avg_ret = runner.train(n_epochs=10, batch_size=2048)
            assert last_avg_ret > 80

            env.close()
Example #25
    def test_build_state_include_action(self, obs_dim, action_dim, hidden_dim):
        env = GymEnv(DummyDiscreteEnv(obs_dim=obs_dim, action_dim=action_dim))
        policy = CategoricalLSTMPolicy(env_spec=env.spec,
                                       hidden_dim=hidden_dim,
                                       state_include_action=True)
        policy.reset(do_resets=None)
        obs = env.reset()[0]

        state_input = tf.compat.v1.placeholder(tf.float32,
                                               shape=(None, None,
                                                      policy.input_dim))
        dist_sym = policy.build(state_input, name='dist_sym').dist
        dist_sym2 = policy.build(state_input, name='dist_sym2').dist

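        # Both graphs are built from the same policy parameters, so identical
        # inputs must yield identical probabilities.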
        concat_obs = np.concatenate([obs.flatten(), np.zeros(action_dim)])
        output1 = self.sess.run(
            [dist_sym.probs],
            feed_dict={state_input: [[concat_obs], [concat_obs]]})
        output2 = self.sess.run(
            [dist_sym2.probs],
            feed_dict={state_input: [[concat_obs], [concat_obs]]})
        assert np.array_equal(output1, output2)
Example #26
    def test_get_action_state_include_action(self, obs_dim, action_dim,
                                             hidden_dim, obs_type):
        assert obs_type in ['discrete', 'dict']
        if obs_type == 'discrete':
            env = GymEnv(
                DummyDiscreteEnv(obs_dim=obs_dim, action_dim=action_dim))
        else:
            env = GymEnv(
                DummyDictEnv(obs_space_type='box', act_space_type='discrete'))
        policy = CategoricalLSTMPolicy(env_spec=env.spec,
                                       hidden_dim=hidden_dim,
                                       state_include_action=True)

        policy.reset()
        obs = env.reset()[0]
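        # Only the array observation is flattened; the dict observation is
        # passed to the policy unchanged.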
        if obs_type == 'discrete':
            obs = obs.flatten()

        action, _ = policy.get_action(obs)
        assert env.action_space.contains(action)

        actions, _ = policy.get_actions([obs])
        for action in actions:
            assert env.action_space.contains(action)
Example #27
    def test_dist_info_sym_wrong_input(self):
        env = TfEnv(DummyDiscreteEnv(obs_dim=(1, ), action_dim=1))

        obs_ph = tf.compat.v1.placeholder(
            tf.float32, shape=(None, None, env.observation_space.flat_dim))

        with mock.patch(('garage.tf.policies.'
                         'categorical_lstm_policy.LSTMModel'),
                        new=SimpleLSTMModel):
            policy = CategoricalLSTMPolicy(env_spec=env.spec,
                                           state_include_action=True)

        policy.reset()
        obs = env.reset()

        policy.dist_info_sym(
            obs_var=obs_ph,
            state_info_vars={'prev_action': np.zeros((3, 1, 1))},
            name='p2_sym')
        # observation batch size = 2 but prev_action batch size = 3
        with pytest.raises(tf.errors.InvalidArgumentError):
            self.sess.run(
                policy.model.networks['p2_sym'].input,
                feed_dict={obs_ph: [[obs.flatten()], [obs.flatten()]]})
Example #28
    def test_clone(self):
        env = GarageEnv(DummyDiscreteEnv(obs_dim=(10, ), action_dim=4))
        policy = CategoricalLSTMPolicy(env_spec=env.spec)
        policy_clone = policy.clone('CategoricalLSTMPolicyClone')
        assert policy.env_spec == policy_clone.env_spec
Example #29
    def test_state_info_specs_with_state_include_action(self):
        env = GarageEnv(DummyDiscreteEnv(obs_dim=(10, ), action_dim=4))
        policy = CategoricalLSTMPolicy(env_spec=env.spec,
                                       state_include_action=True)
        assert policy.state_info_specs == [('prev_action', (4, ))]
Example #30
    def test_state_info_specs(self):
        env = GarageEnv(DummyDiscreteEnv(obs_dim=(10, ), action_dim=4))
        policy = CategoricalLSTMPolicy(env_spec=env.spec,
                                       state_include_action=False)
        assert policy.state_info_specs == []