class TestQfDerivedPolicy(TfGraphTestCase):

    def setup_method(self):
        super().setup_method()
        self.env = GarageEnv(DummyDiscreteEnv())
        self.qf = SimpleQFunction(self.env.spec)
        self.policy = DiscreteQfDerivedPolicy(env_spec=self.env.spec,
                                              qf=self.qf)
        self.sess.run(tf.compat.v1.global_variables_initializer())
        self.env.reset()

    def test_discrete_qf_derived_policy(self):
        obs, _, _, _ = self.env.step(1)
        action, _ = self.policy.get_action(obs)
        assert self.env.action_space.contains(action)
        actions, _ = self.policy.get_actions([obs])
        for action in actions:
            assert self.env.action_space.contains(action)

    def test_is_pickleable(self):
        with tf.compat.v1.variable_scope('SimpleQFunction/SimpleMLPModel',
                                         reuse=True):
            return_var = tf.compat.v1.get_variable('return_var')
        # assign it to all ones
        return_var.load(tf.ones_like(return_var).eval())
        obs, _, _, _ = self.env.step(1)
        action1, _ = self.policy.get_action(obs)

        p = pickle.dumps(self.policy)
        with tf.compat.v1.Session(graph=tf.Graph()):
            policy_pickled = pickle.loads(p)
            action2, _ = policy_pickled.get_action(obs)
            assert action1 == action2
class TestQfDerivedPolicy(TfGraphTestCase):

    def setup_method(self):
        super().setup_method()
        self.env = GymEnv(DummyDiscreteEnv())
        self.qf = SimpleQFunction(self.env.spec)
        self.policy = DiscreteQfDerivedPolicy(env_spec=self.env.spec,
                                              qf=self.qf)
        self.sess.run(tf.compat.v1.global_variables_initializer())
        self.env.reset()

    def test_discrete_qf_derived_policy(self):
        obs = self.env.step(1).observation
        action, _ = self.policy.get_action(obs)
        assert self.env.action_space.contains(action)
        actions, _ = self.policy.get_actions([obs])
        for action in actions:
            assert self.env.action_space.contains(action)

    def test_get_param(self):
        with tf.compat.v1.variable_scope('SimpleQFunction', reuse=True):
            return_var = tf.compat.v1.get_variable('return_var')
        assert self.policy.get_param_values() == return_var.eval()

    def test_is_pickleable(self):
        with tf.compat.v1.variable_scope('SimpleQFunction', reuse=True):
            return_var = tf.compat.v1.get_variable('return_var')
        # assign it to all ones
        return_var.load(tf.ones_like(return_var).eval())
        obs = self.env.step(1).observation
        action1, _ = self.policy.get_action(obs)

        p = pickle.dumps(self.policy)
        with tf.compat.v1.Session(graph=tf.Graph()):
            policy_pickled = pickle.loads(p)
            action2, _ = policy_pickled.get_action(obs)
            assert action1 == action2

    def test_does_not_support_dict_obs_space(self):
        """Test that policy raises error if passed a dict obs space."""
        env = GymEnv(DummyDictEnv(act_space_type='discrete'))
        with pytest.raises(ValueError):
            qf = SimpleQFunction(env.spec,
                                 name='does_not_support_dict_obs_space')
            DiscreteQfDerivedPolicy(env_spec=env.spec, qf=qf)

    def test_invalid_action_spaces(self):
        """Test that policy raises error if passed a box action space."""
        env = GymEnv(DummyDictEnv(act_space_type='box'))
        with pytest.raises(ValueError):
            qf = SimpleQFunction(env.spec)
            DiscreteQfDerivedPolicy(env_spec=env.spec, qf=qf)
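# For reference, the behavior the tests above exercise is a greedy argmax
# over Q-values: DiscreteQfDerivedPolicy picks the action whose predicted
# Q-value is highest. A minimal NumPy sketch (assumption: `q_values` stands
# in for the Q-network output the real policy computes in the TF session):
import numpy as np


def greedy_action(q_values):
    """Return the index of the highest predicted Q-value."""
    return int(np.argmax(q_values))


assert greedy_action(np.array([0.1, 0.9, 0.3])) == 1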
def run_task(snapshot_config, *_):
    """Run task."""
    with LocalRunner(snapshot_config=snapshot_config) as runner:
        n_epochs = 100
        n_epoch_cycles = 20
        sampler_batch_size = 500
        num_timesteps = n_epochs * n_epoch_cycles * sampler_batch_size

        env = gym.make('PongNoFrameskip-v4')
        env = Noop(env, noop_max=30)
        env = MaxAndSkip(env, skip=4)
        env = EpisodicLife(env)
        if 'FIRE' in env.unwrapped.get_action_meanings():
            env = FireReset(env)
        env = Grayscale(env)
        env = Resize(env, 84, 84)
        env = ClipReward(env)
        env = StackFrames(env, 4)

        env = TfEnv(env)

        replay_buffer = SimpleReplayBuffer(env_spec=env.spec,
                                           size_in_transitions=int(5e4),
                                           time_horizon=1)

        qf = DiscreteCNNQFunction(env_spec=env.spec,
                                  filter_dims=(8, 4, 3),
                                  num_filters=(32, 64, 64),
                                  strides=(4, 2, 1),
                                  dueling=False)

        policy = DiscreteQfDerivedPolicy(env_spec=env.spec, qf=qf)
        epsilon_greedy_strategy = EpsilonGreedyStrategy(
            env_spec=env.spec,
            total_timesteps=num_timesteps,
            max_epsilon=1.0,
            min_epsilon=0.02,
            decay_ratio=0.1)

        algo = DQN(env_spec=env.spec,
                   policy=policy,
                   qf=qf,
                   exploration_strategy=epsilon_greedy_strategy,
                   replay_buffer=replay_buffer,
                   qf_lr=1e-4,
                   discount=0.99,
                   min_buffer_size=int(1e4),
                   double_q=False,
                   n_train_steps=500,
                   n_epoch_cycles=n_epoch_cycles,
                   target_network_update_freq=2,
                   buffer_batch_size=32)

        runner.setup(algo, env)
        runner.train(n_epochs=n_epochs,
                     n_epoch_cycles=n_epoch_cycles,
                     batch_size=sampler_batch_size)
class TestQfDerivedPolicy(TfGraphTestCase):

    def setUp(self):
        super().setUp()
        self.env = TfEnv(DummyDiscreteEnv())
        self.qf = SimpleQFunction(self.env.spec)
        self.policy = DiscreteQfDerivedPolicy(env_spec=self.env.spec,
                                              qf=self.qf)
        self.sess.run(tf.global_variables_initializer())
        self.env.reset()

    def test_discrete_qf_derived_policy(self):
        obs, _, _, _ = self.env.step(1)
        action = self.policy.get_action(obs)
        assert self.env.action_space.contains(action)
        actions = self.policy.get_actions([obs])
        for action in actions:
            assert self.env.action_space.contains(action)
def test_dqn_cartpole_pickle(self):
    """Test DQN with CartPole environment."""
    with LocalRunner(self.sess) as runner:
        n_epochs = 10
        n_epoch_cycles = 10
        sampler_batch_size = 500
        num_timesteps = n_epochs * n_epoch_cycles * sampler_batch_size
        env = TfEnv(gym.make('CartPole-v0'))
        replay_buffer = SimpleReplayBuffer(env_spec=env.spec,
                                           size_in_transitions=int(1e4),
                                           time_horizon=1)
        qf = DiscreteMLPQFunction(env_spec=env.spec, hidden_sizes=(64, 64))
        policy = DiscreteQfDerivedPolicy(env_spec=env.spec, qf=qf)
        epsilon_greedy_strategy = EpsilonGreedyStrategy(
            env_spec=env.spec,
            total_timesteps=num_timesteps,
            max_epsilon=1.0,
            min_epsilon=0.02,
            decay_ratio=0.1)
        algo = DQN(env_spec=env.spec,
                   policy=policy,
                   qf=qf,
                   exploration_strategy=epsilon_greedy_strategy,
                   replay_buffer=replay_buffer,
                   qf_lr=1e-4,
                   discount=1.0,
                   min_buffer_size=int(1e3),
                   double_q=False,
                   n_train_steps=500,
                   grad_norm_clipping=5.0,
                   n_epoch_cycles=n_epoch_cycles,
                   target_network_update_freq=1,
                   buffer_batch_size=32)
        runner.setup(algo, env)
        with tf.variable_scope('DiscreteMLPQFunction/MLPModel/mlp/hidden_0',
                               reuse=True):
            bias = tf.get_variable('bias')
            # assign it to all ones
            old_bias = tf.ones_like(bias).eval()
            bias.load(old_bias)
            h = pickle.dumps(algo)

        with tf.Session(graph=tf.Graph()):
            pickle.loads(h)
            with tf.variable_scope(
                    'DiscreteMLPQFunction/MLPModel/mlp/hidden_0', reuse=True):
                new_bias = tf.get_variable('bias')
                new_bias = new_bias.eval()
                assert np.array_equal(old_bias, new_bias)

        env.close()
def test_dqn_cartpole_pickle(self):
    """Test DQN with CartPole environment."""
    deterministic.set_seed(100)
    with LocalTFRunner(snapshot_config, sess=self.sess) as runner:
        n_epochs = 10
        steps_per_epoch = 10
        sampler_batch_size = 500
        num_timesteps = n_epochs * steps_per_epoch * sampler_batch_size
        env = GarageEnv(gym.make('CartPole-v0'))
        replay_buffer = PathBuffer(capacity_in_transitions=int(1e4))
        qf = DiscreteMLPQFunction(env_spec=env.spec, hidden_sizes=(64, 64))
        policy = DiscreteQfDerivedPolicy(env_spec=env.spec, qf=qf)
        epsilon_greedy_policy = EpsilonGreedyPolicy(
            env_spec=env.spec,
            policy=policy,
            total_timesteps=num_timesteps,
            max_epsilon=1.0,
            min_epsilon=0.02,
            decay_ratio=0.1)
        algo = DQN(env_spec=env.spec,
                   policy=policy,
                   qf=qf,
                   exploration_policy=epsilon_greedy_policy,
                   replay_buffer=replay_buffer,
                   max_path_length=100,
                   qf_lr=1e-4,
                   discount=1.0,
                   min_buffer_size=int(1e3),
                   double_q=False,
                   n_train_steps=500,
                   grad_norm_clipping=5.0,
                   steps_per_epoch=steps_per_epoch,
                   target_network_update_freq=1,
                   buffer_batch_size=32)
        runner.setup(algo, env)
        with tf.compat.v1.variable_scope('DiscreteMLPQFunction/mlp/hidden_0',
                                         reuse=True):
            bias = tf.compat.v1.get_variable('bias')
            # assign it to all ones
            old_bias = tf.ones_like(bias).eval()
            bias.load(old_bias)
            h = pickle.dumps(algo)

        with tf.compat.v1.Session(graph=tf.Graph()):
            pickle.loads(h)
            with tf.compat.v1.variable_scope(
                    'DiscreteMLPQFunction/mlp/hidden_0', reuse=True):
                new_bias = tf.compat.v1.get_variable('bias')
                new_bias = new_bias.eval()
                assert np.array_equal(old_bias, new_bias)

        env.close()
def test_dqn_cartpole_grad_clip(self):
    """Test DQN with CartPole environment."""
    with LocalRunner(self.sess) as runner:
        n_epochs = 10
        n_epoch_cycles = 10
        sampler_batch_size = 500
        num_timesteps = n_epochs * n_epoch_cycles * sampler_batch_size
        env = TfEnv(gym.make('CartPole-v0'))
        replay_buffer = SimpleReplayBuffer(env_spec=env.spec,
                                           size_in_transitions=int(1e4),
                                           time_horizon=1)
        qf = DiscreteMLPQFunction(env_spec=env.spec, hidden_sizes=(64, 64))
        policy = DiscreteQfDerivedPolicy(env_spec=env.spec, qf=qf)
        epsilon_greedy_strategy = EpsilonGreedyStrategy(
            env_spec=env.spec,
            total_timesteps=num_timesteps,
            max_epsilon=1.0,
            min_epsilon=0.02,
            decay_ratio=0.1)
        algo = DQN(env_spec=env.spec,
                   policy=policy,
                   qf=qf,
                   exploration_strategy=epsilon_greedy_strategy,
                   replay_buffer=replay_buffer,
                   qf_lr=1e-4,
                   discount=1.0,
                   min_buffer_size=int(1e3),
                   double_q=False,
                   n_train_steps=500,
                   grad_norm_clipping=5.0,
                   n_epoch_cycles=n_epoch_cycles,
                   target_network_update_freq=1,
                   buffer_batch_size=32)
        runner.setup(algo, env)
        last_avg_ret = runner.train(n_epochs=n_epochs,
                                    n_epoch_cycles=n_epoch_cycles,
                                    batch_size=sampler_batch_size)
        assert last_avg_ret > 20

        env.close()
def dqn_cartpole(ctxt=None, seed=1):
    """Train DQN with CartPole-v0 environment.

    Args:
        ctxt (garage.experiment.ExperimentContext): The experiment
            configuration used by LocalRunner to create the snapshotter.
        seed (int): Used to seed the random number generator to produce
            determinism.

    """
    set_seed(seed)
    with LocalTFRunner(ctxt) as runner:
        n_epochs = 10
        steps_per_epoch = 10
        sampler_batch_size = 500
        num_timesteps = n_epochs * steps_per_epoch * sampler_batch_size
        env = TfEnv(gym.make('CartPole-v0'))
        replay_buffer = SimpleReplayBuffer(env_spec=env.spec,
                                           size_in_transitions=int(1e4),
                                           time_horizon=1)
        qf = DiscreteMLPQFunction(env_spec=env.spec, hidden_sizes=(64, 64))
        policy = DiscreteQfDerivedPolicy(env_spec=env.spec, qf=qf)
        exploration_policy = EpsilonGreedyPolicy(
            env_spec=env.spec,
            policy=policy,
            total_timesteps=num_timesteps,
            max_epsilon=1.0,
            min_epsilon=0.02,
            decay_ratio=0.1)
        algo = DQN(env_spec=env.spec,
                   policy=policy,
                   qf=qf,
                   exploration_policy=exploration_policy,
                   replay_buffer=replay_buffer,
                   steps_per_epoch=steps_per_epoch,
                   qf_lr=1e-4,
                   discount=1.0,
                   min_buffer_size=int(1e3),
                   double_q=True,
                   n_train_steps=500,
                   target_network_update_freq=1,
                   buffer_batch_size=32)

        runner.setup(algo, env)
        runner.train(n_epochs=n_epochs, batch_size=sampler_batch_size)
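# A minimal sketch of the matching entry point, assuming `dqn_cartpole`
# above is decorated with garage's `wrap_experiment` (as in the upstream
# launchers), which supplies the `ctxt` argument automatically:
if __name__ == '__main__':
    dqn_cartpole(seed=1)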
def run_task(snapshot_config, *_):
    """Run task.

    Args:
        snapshot_config (garage.experiment.SnapshotConfig): The snapshot
            configuration used by LocalRunner to create the snapshotter.
        *_ (object): Ignored by this function.

    """
    with LocalTFRunner(snapshot_config=snapshot_config) as runner:
        n_epochs = 10
        steps_per_epoch = 10
        sampler_batch_size = 500
        num_timesteps = n_epochs * steps_per_epoch * sampler_batch_size
        env = TfEnv(gym.make('CartPole-v0'))
        replay_buffer = SimpleReplayBuffer(env_spec=env.spec,
                                           size_in_transitions=int(1e4),
                                           time_horizon=1)
        qf = DiscreteMLPQFunction(env_spec=env.spec, hidden_sizes=(64, 64))
        policy = DiscreteQfDerivedPolicy(env_spec=env.spec, qf=qf)
        epsilon_greedy_strategy = EpsilonGreedyStrategy(
            env_spec=env.spec,
            total_timesteps=num_timesteps,
            max_epsilon=1.0,
            min_epsilon=0.02,
            decay_ratio=0.1)
        algo = DQN(env_spec=env.spec,
                   policy=policy,
                   qf=qf,
                   exploration_strategy=epsilon_greedy_strategy,
                   replay_buffer=replay_buffer,
                   steps_per_epoch=steps_per_epoch,
                   qf_lr=1e-4,
                   discount=1.0,
                   min_buffer_size=int(1e3),
                   double_q=True,
                   n_train_steps=500,
                   target_network_update_freq=1,
                   buffer_batch_size=32)

        runner.setup(algo, env)
        runner.train(n_epochs=n_epochs, batch_size=sampler_batch_size)
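# Sketch of the entry point this snapshot_config-style task expects,
# assuming the `run_experiment` helper from the same garage release; it
# creates the snapshotter and passes `snapshot_config` into `run_task`:
run_experiment(run_task, snapshot_mode='last', seed=1)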
def test_dqn_cartpole_grad_clip(self):
    """Test DQN with CartPole environment."""
    deterministic.set_seed(100)
    with LocalTFRunner(snapshot_config, sess=self.sess) as runner:
        n_epochs = 10
        steps_per_epoch = 10
        sampler_batch_size = 500
        num_timesteps = n_epochs * steps_per_epoch * sampler_batch_size
        env = GarageEnv(gym.make('CartPole-v0'))
        replay_buffer = PathBuffer(capacity_in_transitions=int(1e4))
        qf = DiscreteMLPQFunction(env_spec=env.spec, hidden_sizes=(64, 64))
        policy = DiscreteQfDerivedPolicy(env_spec=env.spec, qf=qf)
        epsilon_greedy_policy = EpsilonGreedyPolicy(
            env_spec=env.spec,
            policy=policy,
            total_timesteps=num_timesteps,
            max_epsilon=1.0,
            min_epsilon=0.02,
            decay_ratio=0.1)
        algo = DQN(env_spec=env.spec,
                   policy=policy,
                   qf=qf,
                   exploration_policy=epsilon_greedy_policy,
                   replay_buffer=replay_buffer,
                   max_path_length=100,
                   qf_lr=1e-4,
                   discount=1.0,
                   min_buffer_size=int(1e3),
                   double_q=False,
                   n_train_steps=500,
                   grad_norm_clipping=5.0,
                   steps_per_epoch=steps_per_epoch,
                   target_network_update_freq=1,
                   buffer_batch_size=32)
        runner.setup(algo, env)
        last_avg_ret = runner.train(n_epochs=n_epochs,
                                    batch_size=sampler_batch_size)
        assert last_avg_ret > 9

        env.close()
class TestQfDerivedPolicyImageObs(TfGraphTestCase):

    def setup_method(self):
        super().setup_method()
        self.env = GymEnv(AtariEnv(DummyDiscretePixelEnvBaselines()),
                          is_image=True)
        self.qf = DiscreteCNNQFunction(env_spec=self.env.spec,
                                       filters=((1, (1, 1)), ),
                                       strides=(1, ),
                                       dueling=False)
        self.policy = DiscreteQfDerivedPolicy(env_spec=self.env.spec,
                                              qf=self.qf)
        self.sess.run(tf.compat.v1.global_variables_initializer())
        self.env.reset()

    def test_obs_unflattened(self):
        """Test that get_action unflattens a flattened image observation."""
        obs = self.env.observation_space.sample()
        action, _ = self.policy.get_action(
            self.env.observation_space.flatten(obs))
        self.env.step(action)
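# The flatten/unflatten round-trip exercised above, in a standalone sketch.
# Assumption: garage's env wrappers expose akro spaces, whose Box carries
# flatten/unflatten methods (plain gym Box does not):
import akro
import numpy as np

space = akro.Box(low=0, high=255, shape=(4, 4, 3), dtype=np.uint8)
obs = space.sample()
flat = space.flatten(obs)  # 1-D vector of length 4 * 4 * 3 = 48
assert np.array_equal(space.unflatten(flat), obs)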
def run_task(*_):
    """Run task."""
    with LocalRunner() as runner:
        n_epochs = 10
        n_epoch_cycles = 10
        sampler_batch_size = 500
        num_timesteps = n_epochs * n_epoch_cycles * sampler_batch_size
        env = TfEnv(gym.make('CartPole-v0'))
        replay_buffer = SimpleReplayBuffer(env_spec=env.spec,
                                           size_in_transitions=int(1e4),
                                           time_horizon=1)
        qf = DiscreteMLPQFunction(env_spec=env.spec, hidden_sizes=(64, 64))
        policy = DiscreteQfDerivedPolicy(env_spec=env.spec, qf=qf)
        epsilon_greedy_strategy = EpsilonGreedyStrategy(
            env_spec=env.spec,
            total_timesteps=num_timesteps,
            max_epsilon=1.0,
            min_epsilon=0.02,
            decay_ratio=0.1)
        algo = DQN(env_spec=env.spec,
                   policy=policy,
                   qf=qf,
                   exploration_strategy=epsilon_greedy_strategy,
                   replay_buffer=replay_buffer,
                   qf_lr=1e-4,
                   discount=1.0,
                   min_buffer_size=int(1e3),
                   double_q=True,
                   n_train_steps=500,
                   n_epoch_cycles=n_epoch_cycles,
                   target_network_update_freq=1,
                   buffer_batch_size=32)

        runner.setup(algo, env)
        runner.train(n_epochs=n_epochs,
                     n_epoch_cycles=n_epoch_cycles,
                     batch_size=sampler_batch_size)
def run_task(snapshot_config, *_):
    """Run task."""
    with LocalTFRunner(snapshot_config=snapshot_config) as runner:
        n_epochs = 500
        n_epoch_cycles = 20
        sampler_batch_size = 100
        num_timesteps = n_epochs * n_epoch_cycles * sampler_batch_size
        env_name = 'MountainCar-v0'
        env = TfEnv(gym.make(env_name))
        replay_buffer = SimpleReplayBuffer(env_spec=env.spec,
                                           size_in_transitions=int(1e4),
                                           time_horizon=1)
        qf = DiscreteMLPQFunction(env_spec=env.spec, hidden_sizes=(20, ))
        policy = DiscreteQfDerivedPolicy(env_spec=env.spec, qf=qf)
        epsilon_greedy_strategy = EpsilonGreedyStrategy(
            env_spec=env.spec,
            total_timesteps=num_timesteps,
            max_epsilon=0.5,
            min_epsilon=0.01,
            decay_ratio=0.1)
        algo = DQN(env_spec=env.spec,
                   policy=policy,
                   qf=qf,
                   exploration_strategy=epsilon_greedy_strategy,
                   replay_buffer=replay_buffer,
                   qf_lr=1e-3,
                   discount=0.99,
                   min_buffer_size=int(1e3),
                   double_q=False,
                   n_train_steps=50,
                   n_epoch_cycles=n_epoch_cycles,
                   target_network_update_freq=5,
                   buffer_batch_size=64)

        runner.setup(algo, env)
        runner.train(n_epochs=n_epochs,
                     n_epoch_cycles=n_epoch_cycles,
                     batch_size=sampler_batch_size)
def run_task(snapshot_config, variant_data, *_):
    """Run task.

    Args:
        snapshot_config (garage.experiment.SnapshotConfig): The snapshot
            configuration used by LocalRunner to create the snapshotter.
        variant_data (dict): Custom arguments for the task.
        *_ (object): Ignored by this function.

    """
    with LocalTFRunner(snapshot_config=snapshot_config) as runner:
        n_epochs = 100
        steps_per_epoch = 20
        sampler_batch_size = 500
        num_timesteps = n_epochs * steps_per_epoch * sampler_batch_size

        env = gym.make('PongNoFrameskip-v4')
        env = Noop(env, noop_max=30)
        env = MaxAndSkip(env, skip=4)
        env = EpisodicLife(env)
        if 'FIRE' in env.unwrapped.get_action_meanings():
            env = FireReset(env)
        env = Grayscale(env)
        env = Resize(env, 84, 84)
        env = ClipReward(env)
        env = StackFrames(env, 4)

        env = TfEnv(env)

        replay_buffer = SimpleReplayBuffer(
            env_spec=env.spec,
            size_in_transitions=variant_data['buffer_size'],
            time_horizon=1)

        qf = DiscreteCNNQFunction(env_spec=env.spec,
                                  filter_dims=(8, 4, 3),
                                  num_filters=(32, 64, 64),
                                  strides=(4, 2, 1),
                                  dueling=False)

        policy = DiscreteQfDerivedPolicy(env_spec=env.spec, qf=qf)
        epsilon_greedy_strategy = EpsilonGreedyStrategy(
            env_spec=env.spec,
            total_timesteps=num_timesteps,
            max_epsilon=1.0,
            min_epsilon=0.02,
            decay_ratio=0.1)

        algo = DQN(env_spec=env.spec,
                   policy=policy,
                   qf=qf,
                   exploration_strategy=epsilon_greedy_strategy,
                   replay_buffer=replay_buffer,
                   qf_lr=1e-4,
                   discount=0.99,
                   min_buffer_size=int(1e4),
                   double_q=False,
                   n_train_steps=500,
                   steps_per_epoch=steps_per_epoch,
                   target_network_update_freq=2,
                   buffer_batch_size=32)

        runner.setup(algo, env)
        runner.train(n_epochs=n_epochs, batch_size=sampler_batch_size)
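# Hypothetical entry point for the variant-style task above, assuming the
# old garage `run_experiment` helper; the `variant` dict is forwarded to
# `run_task` as `variant_data`:
run_experiment(run_task,
               snapshot_mode='last',
               seed=1,
               variant={'buffer_size': int(5e4)})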
def dqn_pong(ctxt=None, seed=1, buffer_size=int(5e4), max_path_length=None):
    """Train DQN on PongNoFrameskip-v4 environment.

    Args:
        ctxt (garage.experiment.ExperimentContext): The experiment
            configuration used by LocalRunner to create the snapshotter.
        seed (int): Used to seed the random number generator to produce
            determinism.
        buffer_size (int): Number of timesteps to store in replay buffer.
        max_path_length (int): Maximum length of a path after which a path
            is considered complete. This is used during testing to minimize
            the memory required to store a single path.

    """
    set_seed(seed)
    with LocalTFRunner(ctxt) as runner:
        n_epochs = 100
        steps_per_epoch = 20
        sampler_batch_size = 500
        num_timesteps = n_epochs * steps_per_epoch * sampler_batch_size

        env = gym.make('PongNoFrameskip-v4')
        env = Noop(env, noop_max=30)
        env = MaxAndSkip(env, skip=4)
        env = EpisodicLife(env)
        if 'FIRE' in env.unwrapped.get_action_meanings():
            env = FireReset(env)
        env = Grayscale(env)
        env = Resize(env, 84, 84)
        env = ClipReward(env)
        env = StackFrames(env, 4)

        env = GarageEnv(env, is_image=True)

        replay_buffer = PathBuffer(capacity_in_transitions=buffer_size)

        qf = DiscreteCNNQFunction(env_spec=env.spec,
                                  filters=(
                                      (32, (8, 8)),
                                      (64, (4, 4)),
                                      (64, (3, 3)),
                                  ),
                                  strides=(4, 2, 1),
                                  dueling=False)  # yapf: disable

        policy = DiscreteQfDerivedPolicy(env_spec=env.spec, qf=qf)
        exploration_policy = EpsilonGreedyPolicy(
            env_spec=env.spec,
            policy=policy,
            total_timesteps=num_timesteps,
            max_epsilon=1.0,
            min_epsilon=0.02,
            decay_ratio=0.1)

        algo = DQN(env_spec=env.spec,
                   policy=policy,
                   qf=qf,
                   exploration_policy=exploration_policy,
                   replay_buffer=replay_buffer,
                   qf_lr=1e-4,
                   discount=0.99,
                   min_buffer_size=int(1e4),
                   max_path_length=max_path_length,
                   double_q=False,
                   n_train_steps=500,
                   steps_per_epoch=steps_per_epoch,
                   target_network_update_freq=2,
                   buffer_batch_size=32)

        runner.setup(algo, env)
        runner.train(n_epochs=n_epochs, batch_size=sampler_batch_size)
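# Replay buffers for pixel observations are memory-hungry, so a quick smoke
# run can shrink the buffer and cap path length. Values are illustrative,
# and this assumes `dqn_pong` is wrapped with garage's `wrap_experiment`:
if __name__ == '__main__':
    dqn_pong(seed=1, buffer_size=int(1e3), max_path_length=50)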
def run_task(snapshot_config, *_):
    """Run task."""
    with LocalTFRunner(snapshot_config=snapshot_config) as runner:
        n_epochs = 500
        n_epoch_cycles = 20
        sampler_batch_size = 100
        num_timesteps = n_epochs * n_epoch_cycles * sampler_batch_size
        env_name = 'MountainCar-v0'
        env = TfEnv(gym.make(env_name))
        replay_buffer = SimpleReplayBuffer(env_spec=env.spec,
                                           size_in_transitions=int(1e4),
                                           time_horizon=1)
        qf = DiscreteMLPQFunction(env_spec=env.spec,
                                  hidden_sizes=(20, ),
                                  hidden_nonlinearity=tf.nn.relu)
        obs_model = DiscreteMLPObsFunction(env_spec=env.spec,
                                           hidden_sizes=(20, ),
                                           hidden_nonlinearity=tf.nn.relu)
        reward_model = DiscreteMLPRewardFunction(
            env_spec=env.spec,
            hidden_sizes=(20, ),
            hidden_nonlinearity=tf.nn.relu)
        # terminal model for predicting the end of an episode
        terminal_model = MLPTerminalFunction(env_spec=env.spec,
                                             hidden_sizes=(20, ),
                                             hidden_nonlinearity=tf.nn.relu)
        policy = DiscreteQfDerivedPolicy(env_spec=env.spec, qf=qf)
        epsilon_greedy_strategy = EpsilonGreedyStrategy(
            env_spec=env.spec,
            total_timesteps=num_timesteps,
            max_epsilon=0.5,
            min_epsilon=0.01,
            decay_ratio=0.1)
        algo = JoleDQN(env_spec=env.spec,
                       policy=policy,
                       qf=qf,
                       obs_model=obs_model,
                       reward_model=reward_model,
                       terminal_model=terminal_model,
                       exploration_strategy=epsilon_greedy_strategy,
                       replay_buffer=replay_buffer,
                       qf_lr=1e-3,
                       discount=0.99,
                       min_buffer_size=int(1e3),
                       double_q=False,
                       n_train_steps=50,
                       n_epoch_cycles=n_epoch_cycles,
                       target_network_update_freq=100,
                       buffer_batch_size=64,
                       env_name=env_name)

        runner.setup(algo, env)
        runner.train(n_epochs=n_epochs,
                     n_epoch_cycles=n_epoch_cycles,
                     batch_size=sampler_batch_size)