def test_return_bullet_env():
    env = GymEnv(env=gym.make('CartPoleBulletEnv-v1'))
    assert isinstance(env, BulletEnv)
    env = GymEnv(env='CartPoleBulletEnv-v1')
    assert isinstance(env, BulletEnv)
    env = GymEnv(gym.make('CartPoleBulletEnv-v1'))
    assert isinstance(env, BulletEnv)
    env = GymEnv('CartPoleBulletEnv-v1')
    assert isinstance(env, BulletEnv)
def rl2_ppo_metaworld_ml1_push(ctxt, seed, max_episode_length, meta_batch_size,
                               n_epochs, episode_per_task):
    """Train PPO with ML1 environment.

    Args:
        ctxt (ExperimentContext): The experiment configuration used by
            :class:`~LocalRunner` to create the :class:`~Snapshotter`.
        seed (int): Used to seed the random number generator to produce
            determinism.
        max_episode_length (int): Maximum length of a single episode.
        meta_batch_size (int): Meta batch size.
        n_epochs (int): Total number of epochs for training.
        episode_per_task (int): Number of training episodes per task.

    """
    set_seed(seed)
    with LocalTFRunner(snapshot_config=ctxt) as runner:
        tasks = task_sampler.SetTaskSampler(
            lambda: RL2Env(GymEnv(mwb.ML1.get_train_tasks('push-v1'))))

        env_spec = RL2Env(GymEnv(mwb.ML1.get_train_tasks('push-v1'))).spec
        policy = GaussianGRUPolicy(name='policy',
                                   hidden_dim=64,
                                   env_spec=env_spec,
                                   state_include_action=False)

        baseline = LinearFeatureBaseline(env_spec=env_spec)

        algo = RL2PPO(rl2_max_episode_length=max_episode_length,
                      meta_batch_size=meta_batch_size,
                      task_sampler=tasks,
                      env_spec=env_spec,
                      policy=policy,
                      baseline=baseline,
                      discount=0.99,
                      gae_lambda=0.95,
                      lr_clip_range=0.2,
                      optimizer_args=dict(
                          batch_size=32,
                          max_episode_length=10,
                      ),
                      stop_entropy_gradient=True,
                      entropy_method='max',
                      policy_ent_coeff=0.02,
                      center_adv=False,
                      max_episode_length=max_episode_length * episode_per_task)

        runner.setup(algo,
                     tasks.sample(meta_batch_size),
                     sampler_cls=LocalSampler,
                     n_workers=meta_batch_size,
                     worker_class=RL2Worker,
                     worker_args=dict(n_episodes_per_trial=episode_per_task))

        runner.train(n_epochs=n_epochs,
                     batch_size=episode_per_task * max_episode_length *
                     meta_batch_size)
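# ---------------------------------------------------------------------------
# Usage sketch (not part of the launcher above): garage experiment functions
# like this one are typically wired to the command line with click options
# plus @wrap_experiment, which builds and injects the ExperimentContext passed
# in as `ctxt`. The decorator stack, wrapper name, and default hyperparameter
# values below are assumptions for illustration only.
import click

from garage import wrap_experiment


@click.command()
@click.option('--seed', default=1)
@click.option('--max_episode_length', default=150)
@click.option('--meta_batch_size', default=50)
@click.option('--n_epochs', default=10)
@click.option('--episode_per_task', default=10)
@wrap_experiment
def launch_rl2_ppo_metaworld_ml1_push(ctxt, seed, max_episode_length,
                                      meta_batch_size, n_epochs,
                                      episode_per_task):
    """Hypothetical CLI wrapper that forwards to the launcher above."""
    rl2_ppo_metaworld_ml1_push(ctxt, seed, max_episode_length,
                               meta_batch_size, n_epochs, episode_per_task)


if __name__ == '__main__':
    launch_rl2_ppo_metaworld_ml1_push()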
def test_visualization(self):
    envs = ['CartPole-v0', 'CartPole-v1']
    mt_env = self._init_multi_env_wrapper(envs)
    mt_env.visualize()

    gym_env = GymEnv('CartPole-v0')
    assert gym_env.render_modes == mt_env.render_modes
    mode = gym_env.render_modes[0]
    assert gym_env.render(mode) == mt_env.render(mode)
def maml_trpo_metaworld_ml45(ctxt, seed, epochs, episodes_per_task,
                             meta_batch_size):
    """Set up environment and algorithm and run the task.

    Args:
        ctxt (ExperimentContext): The experiment configuration used by
            :class:`~LocalRunner` to create the :class:`~Snapshotter`.
        seed (int): Used to seed the random number generator to produce
            determinism.
        epochs (int): Number of training epochs.
        episodes_per_task (int): Number of episodes per epoch per task
            for training.
        meta_batch_size (int): Number of tasks sampled per batch.

    """
    set_seed(seed)
    env = normalize(GymEnv(mwb.ML45.get_train_tasks()),
                    expected_action_scale=10.)

    policy = GaussianMLPPolicy(
        env_spec=env.spec,
        hidden_sizes=(100, 100),
        hidden_nonlinearity=torch.tanh,
        output_nonlinearity=None,
    )

    value_function = LinearFeatureBaseline(env_spec=env.spec)

    max_episode_length = 100

    test_task_names = mwb.ML45.get_test_tasks().all_task_names
    test_tasks = [
        normalize(GymEnv(mwb.ML45.from_task(task)),
                  expected_action_scale=10.) for task in test_task_names
    ]
    test_sampler = EnvPoolSampler(test_tasks)

    meta_evaluator = MetaEvaluator(test_task_sampler=test_sampler,
                                   max_episode_length=max_episode_length,
                                   n_test_tasks=len(test_task_names))

    runner = LocalRunner(ctxt)
    algo = MAMLTRPO(env=env,
                    policy=policy,
                    value_function=value_function,
                    max_episode_length=max_episode_length,
                    meta_batch_size=meta_batch_size,
                    discount=0.99,
                    gae_lambda=1.,
                    inner_lr=0.1,
                    num_grad_updates=1,
                    meta_evaluator=meta_evaluator)

    runner.setup(algo, env)
    runner.train(n_epochs=epochs,
                 batch_size=episodes_per_task * max_episode_length)
def test_done_resets_step_cnt():
    env = GymEnv('MountainCar-v0')
    max_episode_length = env.spec.max_episode_length

    env.reset()
    for _ in range(max_episode_length):
        es = env.step(env.action_space.sample())
        if es.last:
            break
    assert env._step_cnt is None
def test_output_shape(self, obs_dim, action_dim):
    env = GymEnv(DummyBoxEnv(obs_dim=obs_dim, action_dim=action_dim))
    qf = ContinuousMLPQFunction(env_spec=env.spec)
    env.reset()
    obs = env.step(1).observation
    obs = obs.flatten()
    act = np.full(action_dim, 0.5).flatten()

    outputs = qf.get_qval([obs], [act])
    assert outputs.shape == (1, 1)
def rl2_trpo_halfcheetah(ctxt, seed, max_episode_length, meta_batch_size,
                         n_epochs, episode_per_task):
    """Train TRPO with HalfCheetah environment.

    Args:
        ctxt (ExperimentContext): The experiment configuration used by
            :class:`~Trainer` to create the :class:`~Snapshotter`.
        seed (int): Used to seed the random number generator to produce
            determinism.
        max_episode_length (int): Maximum length of a single episode.
        meta_batch_size (int): Meta batch size.
        n_epochs (int): Total number of epochs for training.
        episode_per_task (int): Number of training episodes per task.

    """
    set_seed(seed)
    with TFTrainer(snapshot_config=ctxt) as trainer:
        tasks = task_sampler.SetTaskSampler(
            HalfCheetahVelEnv,
            wrapper=lambda env, _: RL2Env(
                GymEnv(env, max_episode_length=max_episode_length)))

        env_spec = RL2Env(
            GymEnv(HalfCheetahVelEnv(),
                   max_episode_length=max_episode_length)).spec
        policy = GaussianGRUPolicy(name='policy',
                                   hidden_dim=64,
                                   env_spec=env_spec,
                                   state_include_action=False)

        baseline = LinearFeatureBaseline(env_spec=env_spec)

        algo = RL2TRPO(meta_batch_size=meta_batch_size,
                       task_sampler=tasks,
                       env_spec=env_spec,
                       policy=policy,
                       baseline=baseline,
                       episodes_per_trial=episode_per_task,
                       discount=0.99,
                       max_kl_step=0.01,
                       optimizer=ConjugateGradientOptimizer,
                       optimizer_args=dict(hvp_approach=FiniteDifferenceHVP(
                           base_eps=1e-5)))

        trainer.setup(algo,
                      tasks.sample(meta_batch_size),
                      sampler_cls=LocalSampler,
                      n_workers=meta_batch_size,
                      worker_class=RL2Worker,
                      worker_args=dict(n_episodes_per_trial=episode_per_task))

        trainer.train(n_epochs=n_epochs,
                      batch_size=episode_per_task * max_episode_length *
                      meta_batch_size)
def maml_ppo_half_cheetah_dir(ctxt, seed, epochs, episodes_per_task,
                              meta_batch_size):
    """Set up environment and algorithm and run the task.

    Args:
        ctxt (ExperimentContext): The experiment configuration used by
            :class:`~LocalRunner` to create the :class:`~Snapshotter`.
        seed (int): Used to seed the random number generator to produce
            determinism.
        epochs (int): Number of training epochs.
        episodes_per_task (int): Number of episodes per epoch per task
            for training.
        meta_batch_size (int): Number of tasks sampled per batch.

    """
    set_seed(seed)
    env = normalize(GymEnv(HalfCheetahDirEnv()), expected_action_scale=10.)

    policy = GaussianMLPPolicy(
        env_spec=env.spec,
        hidden_sizes=(64, 64),
        hidden_nonlinearity=torch.tanh,
        output_nonlinearity=None,
    )

    value_function = GaussianMLPValueFunction(env_spec=env.spec,
                                              hidden_sizes=(32, 32),
                                              hidden_nonlinearity=torch.tanh,
                                              output_nonlinearity=None)

    max_episode_length = 100

    task_sampler = SetTaskSampler(lambda: normalize(
        GymEnv(HalfCheetahDirEnv()), expected_action_scale=10.))

    meta_evaluator = MetaEvaluator(test_task_sampler=task_sampler,
                                   max_episode_length=max_episode_length,
                                   n_test_tasks=1,
                                   n_test_episodes=10)

    runner = LocalRunner(ctxt)
    algo = MAMLPPO(env=env,
                   policy=policy,
                   value_function=value_function,
                   max_episode_length=max_episode_length,
                   meta_batch_size=meta_batch_size,
                   discount=0.99,
                   gae_lambda=1.,
                   inner_lr=0.1,
                   num_grad_updates=1,
                   meta_evaluator=meta_evaluator)

    runner.setup(algo, env)
    runner.train(n_epochs=epochs,
                 batch_size=episodes_per_task * max_episode_length)
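# ---------------------------------------------------------------------------
# Minimal invocation sketch (an assumption, not taken from the example above):
# wrap_experiment supplies the ExperimentContext expected as `ctxt`, and the
# wrapper name and hyperparameter values below are illustrative only.
from garage import wrap_experiment


@wrap_experiment
def run_maml_ppo_half_cheetah_dir(ctxt=None):
    """Hypothetical wrapper that fixes one set of hyperparameters."""
    maml_ppo_half_cheetah_dir(ctxt,
                              seed=1,
                              epochs=300,
                              episodes_per_task=40,
                              meta_batch_size=20)


run_maml_ppo_half_cheetah_dir()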
def test_get_action_dict_space(self):
    env = GymEnv(DummyDictEnv(obs_space_type='box', act_space_type='box'))
    policy = GaussianMLPPolicy(env_spec=env.spec)
    obs = env.reset()[0]

    action, _ = policy.get_action(obs)
    assert env.action_space.contains(action)

    actions, _ = policy.get_actions([obs, obs])
    for action in actions:
        assert env.action_space.contains(action)
def test_baseline(self):
    """Test the baseline initialization."""
    box_env = GymEnv(DummyBoxEnv())
    deterministic_mlp_baseline = ContinuousMLPBaseline(env_spec=box_env.spec)
    gaussian_mlp_baseline = GaussianMLPBaseline(env_spec=box_env.spec)
    self.sess.run(tf.compat.v1.global_variables_initializer())
    deterministic_mlp_baseline.get_param_values()
    gaussian_mlp_baseline.get_param_values()

    box_env.close()
def test_unflattened_input(self):
    env = GymEnv(DummyBoxEnv(obs_dim=(2, 2)))
    cmb = ContinuousMLPBaseline(env_spec=env.spec)
    env.reset()
    es = env.step(1)
    obs, rewards = es.observation, es.reward
    train_paths = [{'observations': [obs], 'returns': [rewards]}]
    cmb.fit(train_paths)
    paths = {'observations': [obs]}
    prediction = cmb.predict(paths)
    assert np.allclose(0., prediction)
def maml_trpo_metaworld_ml1_push(ctxt, seed, epochs, episodes_per_task,
                                 meta_batch_size):
    """Set up environment and algorithm and run the task.

    Args:
        ctxt (garage.experiment.ExperimentContext): The experiment
            configuration used by LocalRunner to create the snapshotter.
        seed (int): Used to seed the random number generator to produce
            determinism.
        epochs (int): Number of training epochs.
        episodes_per_task (int): Number of episodes per epoch per task
            for training.
        meta_batch_size (int): Number of tasks sampled per batch.

    """
    set_seed(seed)
    env = normalize(GymEnv(mwb.ML1.get_train_tasks('push-v1'),
                           max_episode_length=150),
                    expected_action_scale=10.)

    policy = GaussianMLPPolicy(
        env_spec=env.spec,
        hidden_sizes=(100, 100),
        hidden_nonlinearity=torch.tanh,
        output_nonlinearity=None,
    )

    value_function = GaussianMLPValueFunction(env_spec=env.spec,
                                              hidden_sizes=[32, 32],
                                              hidden_nonlinearity=torch.tanh,
                                              output_nonlinearity=None)

    max_episode_length = env.spec.max_episode_length

    test_sampler = SetTaskSampler(
        lambda: normalize(GymEnv(mwb.ML1.get_test_tasks('push-v1'))))

    meta_evaluator = MetaEvaluator(test_task_sampler=test_sampler,
                                   max_episode_length=max_episode_length)

    runner = LocalRunner(ctxt)
    algo = MAMLTRPO(env=env,
                    policy=policy,
                    value_function=value_function,
                    meta_batch_size=meta_batch_size,
                    discount=0.99,
                    gae_lambda=1.,
                    inner_lr=0.1,
                    num_grad_updates=1,
                    meta_evaluator=meta_evaluator)

    runner.setup(algo, env)
    runner.train(n_epochs=epochs,
                 batch_size=episodes_per_task * max_episode_length)
def setup_method(self):
    super().setup_method()
    self.max_episode_length = 100
    self.meta_batch_size = 10
    self.episode_per_task = 4
    self.tasks = task_sampler.SetTaskSampler(
        lambda: RL2Env(normalize(GymEnv(HalfCheetahDirEnv()))))
    self.env_spec = RL2Env(normalize(GymEnv(HalfCheetahDirEnv()))).spec
    self.policy = GaussianGRUPolicy(env_spec=self.env_spec,
                                    hidden_dim=64,
                                    state_include_action=False)
    self.baseline = LinearFeatureBaseline(env_spec=self.env_spec)
def setup_method(self):
    super().setup_method()
    self.env = GymEnv(AtariEnv(DummyDiscretePixelEnvBaselines()),
                      is_image=True)
    self.qf = DiscreteCNNQFunction(env_spec=self.env.spec,
                                   filters=((1, (1, 1)), ),
                                   strides=(1, ),
                                   dueling=False)
    self.policy = DiscreteQFArgmaxPolicy(env_spec=self.env.spec, qf=self.qf)
    self.sess.run(tf.compat.v1.global_variables_initializer())
    self.env.reset()
class TestQfDerivedPolicy(TfGraphTestCase):

    def setup_method(self):
        super().setup_method()
        self.env = GymEnv(DummyDiscreteEnv())
        self.qf = SimpleQFunction(self.env.spec)
        self.policy = DiscreteQFArgmaxPolicy(env_spec=self.env.spec,
                                             qf=self.qf)
        self.sess.run(tf.compat.v1.global_variables_initializer())
        self.env.reset()

    def test_discrete_qf_argmax_policy(self):
        obs = self.env.step(1).observation
        action, _ = self.policy.get_action(obs)
        assert self.env.action_space.contains(action)
        actions, _ = self.policy.get_actions([obs])
        for action in actions:
            assert self.env.action_space.contains(action)

    def test_get_param(self):
        with tf.compat.v1.variable_scope('SimpleQFunction', reuse=True):
            return_var = tf.compat.v1.get_variable('return_var')
        assert self.policy.get_param_values() == return_var.eval()

    def test_is_pickleable(self):
        with tf.compat.v1.variable_scope('SimpleQFunction', reuse=True):
            return_var = tf.compat.v1.get_variable('return_var')
        # assign it to all one
        return_var.load(tf.ones_like(return_var).eval())
        obs = self.env.step(1).observation

        action1, _ = self.policy.get_action(obs)

        p = pickle.dumps(self.policy)
        with tf.compat.v1.Session(graph=tf.Graph()):
            policy_pickled = pickle.loads(p)
            action2, _ = policy_pickled.get_action(obs)
            assert action1 == action2

    def test_does_not_support_dict_obs_space(self):
        """Test that policy raises error if passed a dict obs space."""
        env = GymEnv(DummyDictEnv(act_space_type='discrete'))
        with pytest.raises(ValueError):
            qf = SimpleQFunction(env.spec,
                                 name='does_not_support_dict_obs_space')
            DiscreteQFArgmaxPolicy(env_spec=env.spec, qf=qf)

    def test_invalid_action_spaces(self):
        """Test that policy raises error if passed a box action space."""
        env = GymEnv(DummyDictEnv(act_space_type='box'))
        with pytest.raises(ValueError):
            qf = SimpleQFunction(env.spec)
            DiscreteQFArgmaxPolicy(env_spec=env.spec, qf=qf)
def test_dqn_cartpole_pickle(self):
    """Test DQN with CartPole environment."""
    deterministic.set_seed(100)
    with LocalTFRunner(snapshot_config, sess=self.sess) as runner:
        n_epochs = 10
        steps_per_epoch = 10
        sampler_batch_size = 500
        num_timesteps = n_epochs * steps_per_epoch * sampler_batch_size
        env = GymEnv('CartPole-v0')
        replay_buffer = PathBuffer(capacity_in_transitions=int(1e4))
        qf = DiscreteMLPQFunction(env_spec=env.spec, hidden_sizes=(64, 64))
        policy = DiscreteQfDerivedPolicy(env_spec=env.spec, qf=qf)
        epsilon_greedy_policy = EpsilonGreedyPolicy(
            env_spec=env.spec,
            policy=policy,
            total_timesteps=num_timesteps,
            max_epsilon=1.0,
            min_epsilon=0.02,
            decay_ratio=0.1)
        algo = DQN(env_spec=env.spec,
                   policy=policy,
                   qf=qf,
                   exploration_policy=epsilon_greedy_policy,
                   replay_buffer=replay_buffer,
                   max_episode_length=100,
                   qf_lr=1e-4,
                   discount=1.0,
                   min_buffer_size=int(1e3),
                   double_q=False,
                   n_train_steps=500,
                   grad_norm_clipping=5.0,
                   steps_per_epoch=steps_per_epoch,
                   target_network_update_freq=1,
                   buffer_batch_size=32)
        runner.setup(algo, env)
        with tf.compat.v1.variable_scope(
                'DiscreteMLPQFunction/mlp/hidden_0', reuse=True):
            bias = tf.compat.v1.get_variable('bias')
            # assign it to all one
            old_bias = tf.ones_like(bias).eval()
            bias.load(old_bias)
        h = pickle.dumps(algo)
        with tf.compat.v1.Session(graph=tf.Graph()):
            pickle.loads(h)
            with tf.compat.v1.variable_scope(
                    'DiscreteMLPQFunction/mlp/hidden_0', reuse=True):
                new_bias = tf.compat.v1.get_variable('bias')
                new_bias = new_bias.eval()
                assert np.array_equal(old_bias, new_bias)

        env.close()
def test_get_action_dict_space(self):
    env = GymEnv(DummyDictEnv(obs_space_type='box', act_space_type='box'))
    policy = GaussianGRUPolicy(env_spec=env.spec,
                               hidden_dim=4,
                               state_include_action=False)
    policy.reset(do_resets=None)
    obs = env.reset()[0]

    action, _ = policy.get_action(obs)
    assert env.action_space.contains(action)

    actions, _ = policy.get_actions([obs, obs])
    for action in actions:
        assert env.action_space.contains(action)
def test_get_action(self, obs_dim, action_dim, hidden_dim):
    env = GymEnv(DummyBoxEnv(obs_dim=obs_dim, action_dim=action_dim))
    policy = GaussianGRUPolicy(env_spec=env.spec,
                               hidden_dim=hidden_dim,
                               state_include_action=False)
    policy.reset(do_resets=None)
    obs = env.reset()[0]

    action, _ = policy.get_action(obs.flatten())
    assert env.action_space.contains(action)

    actions, _ = policy.get_actions([obs.flatten()])
    for action in actions:
        assert env.action_space.contains(action)
def test_get_action(self, obs_dim, action_dim):
    env = GymEnv(DummyBoxEnv(obs_dim=obs_dim, action_dim=action_dim))
    policy = GaussianMLPPolicy(env_spec=env.spec)

    env.reset()
    obs = env.step(1).observation

    action, _ = policy.get_action(obs.flatten())
    assert env.action_space.contains(action)

    actions, _ = policy.get_actions(
        [obs.flatten(), obs.flatten(), obs.flatten()])
    for action in actions:
        assert env.action_space.contains(action)
def test_build(self, obs_dim, action_dim):
    env = GymEnv(DummyDiscreteEnv(obs_dim=obs_dim, action_dim=action_dim))
    qf = DiscreteMLPDuelingQFunction(env_spec=env.spec)
    env.reset()
    obs = env.step(1).observation

    output1 = self.sess.run(qf.q_vals, feed_dict={qf.input: [obs]})

    input_var = tf.compat.v1.placeholder(tf.float32,
                                         shape=(None, ) + obs_dim)
    q_vals = qf.build(input_var, 'another')
    output2 = self.sess.run(q_vals, feed_dict={input_var: [obs]})

    assert np.array_equal(output1, output2)
def test_get_action(self, hidden_channels, kernel_sizes, strides,
                    hidden_sizes):
    """Test get_action function."""
    env = GymEnv(DummyDiscretePixelEnv(), is_image=True)
    policy = CategoricalCNNPolicy(env_spec=env.spec,
                                  image_format='NHWC',
                                  kernel_sizes=kernel_sizes,
                                  hidden_channels=hidden_channels,
                                  strides=strides,
                                  hidden_sizes=hidden_sizes)

    env.reset()
    obs = env.step(1).observation

    action, _ = policy.get_action(obs)
    assert env.action_space.contains(action)
def test_get_action(self, obs_dim, action_dim, hidden_dim):
    env = GymEnv(DummyDiscreteEnv(obs_dim=obs_dim, action_dim=action_dim))
    policy = CategoricalLSTMPolicy(env_spec=env.spec,
                                   hidden_dim=hidden_dim,
                                   state_include_action=False)
    policy.reset()
    obs = env.reset()[0]

    action, _ = policy.get_action(obs.flatten())
    assert env.action_space.contains(action)

    actions, _ = policy.get_actions([obs.flatten()])
    for action in actions:
        assert env.action_space.contains(action)
def test_get_action_img_obs(self, hidden_channels, kernel_sizes, strides,
                            hidden_sizes):
    """Test get_action function with akro.Image observation space."""
    env = GymEnv(self._initialize_obs_env(DummyDiscretePixelEnv()),
                 is_image=True)
    policy = CategoricalCNNPolicy(env=env,
                                  kernel_sizes=kernel_sizes,
                                  hidden_channels=hidden_channels,
                                  strides=strides,
                                  hidden_sizes=hidden_sizes)
    env.reset()
    obs = env.step(1).observation

    action, _ = policy.get_action(obs)
    assert env.action_space.contains(action)
def test_flattened_image_input(self):
    env = GymEnv(DummyDiscretePixelEnv(), is_image=True)
    gcb = GaussianCNNBaseline(env_spec=env.spec,
                              filters=((3, (3, 3)), (6, (3, 3))),
                              strides=(1, 1),
                              padding='SAME',
                              hidden_sizes=(32, ))
    env.reset()
    es = env.step(1)
    obs, rewards = es.observation, es.reward
    train_paths = [{'observations': [obs.flatten()], 'returns': [rewards]}]
    gcb.fit(train_paths)
    paths = {'observations': [obs.flatten()]}
    prediction = gcb.predict(paths)
    assert np.allclose(0., prediction)
def test_visualization():
    inner_env = gym.make('MountainCar-v0')
    env = GymEnv(inner_env)
    env.reset()
    env.visualize()
    assert inner_env.metadata['render.modes'] == env.render_modes
    env.step(env.action_space.sample())
def test_build(self, obs_dim, action_dim):
    env = GymEnv(DummyBoxEnv(obs_dim=obs_dim, action_dim=action_dim))
    policy = GaussianMLPPolicy(env_spec=env.spec)

    obs = env.reset()[0]

    state_input = tf.compat.v1.placeholder(tf.float32,
                                           shape=(None, None,
                                                  policy.input_dim))
    dist_sym = policy.build(state_input, name='dist_sym').dist
    dist_sym2 = policy.build(state_input, name='dist_sym2').dist
    output1 = self.sess.run([dist_sym.loc],
                            feed_dict={state_input: [[obs.flatten()]]})
    output2 = self.sess.run([dist_sym2.loc],
                            feed_dict={state_input: [[obs.flatten()]]})
    assert np.array_equal(output1, output2)
def dqn_cartpole(ctxt=None, seed=24):
    """Train DQN with CartPole-v0 environment.

    Args:
        ctxt (garage.experiment.ExperimentContext): The experiment
            configuration used by Trainer to create the snapshotter.
        seed (int): Used to seed the random number generator to produce
            determinism.

    """
    set_seed(seed)
    runner = Trainer(ctxt)

    n_epochs = 100
    steps_per_epoch = 10
    sampler_batch_size = 512
    num_timesteps = n_epochs * steps_per_epoch * sampler_batch_size
    env = GymEnv('CartPole-v0')
    replay_buffer = PathBuffer(capacity_in_transitions=int(1e6))
    qf = DiscreteMLPQFunction(env_spec=env.spec, hidden_sizes=(8, 5))
    policy = DiscreteQFArgmaxPolicy(env_spec=env.spec, qf=qf)
    exploration_policy = EpsilonGreedyPolicy(env_spec=env.spec,
                                             policy=policy,
                                             total_timesteps=num_timesteps,
                                             max_epsilon=1.0,
                                             min_epsilon=0.01,
                                             decay_ratio=0.4)
    sampler = LocalSampler(agents=exploration_policy,
                           envs=env,
                           max_episode_length=env.spec.max_episode_length,
                           worker_class=FragmentWorker)
    algo = DQN(env_spec=env.spec,
               policy=policy,
               qf=qf,
               exploration_policy=exploration_policy,
               replay_buffer=replay_buffer,
               sampler=sampler,
               steps_per_epoch=steps_per_epoch,
               qf_lr=5e-5,
               discount=0.9,
               min_buffer_size=int(1e4),
               n_train_steps=500,
               target_update_freq=30,
               buffer_batch_size=64)

    runner.setup(algo, env)
    runner.train(n_epochs=n_epochs, batch_size=sampler_batch_size)
    env.close()
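# ---------------------------------------------------------------------------
# Worked numbers for the exploration schedule above (derived from the
# hyperparameters in dqn_cartpole, assuming EpsilonGreedyPolicy anneals
# epsilon linearly over decay_ratio * total_timesteps steps):
#   num_timesteps = 100 epochs * 10 steps_per_epoch * 512 batch = 512,000
#   decay window  = 0.4 * 512,000 = 204,800 environment steps
# so epsilon falls from 1.0 to 0.01 over the first ~205k steps and stays at
# 0.01 afterwards.
n_epochs, steps_per_epoch, sampler_batch_size = 100, 10, 512
num_timesteps = n_epochs * steps_per_epoch * sampler_batch_size
decay_steps = int(0.4 * num_timesteps)
assert num_timesteps == 512000 and decay_steps == 204800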
def trpo_swimmer(ctxt=None, seed=1, batch_size=4000):
    """Train TRPO with Swimmer-v2 environment.

    Args:
        ctxt (garage.experiment.ExperimentContext): The experiment
            configuration used by LocalRunner to create the snapshotter.
        seed (int): Used to seed the random number generator to produce
            determinism.
        batch_size (int): Number of timesteps to use in each training step.

    """
    set_seed(seed)
    with LocalTFRunner(ctxt) as runner:
        env = GymEnv('Swimmer-v2')

        policy = GaussianMLPPolicy(env_spec=env.spec, hidden_sizes=(32, 32))

        baseline = LinearFeatureBaseline(env_spec=env.spec)

        algo = TRPO(env_spec=env.spec,
                    policy=policy,
                    baseline=baseline,
                    max_episode_length=500,
                    discount=0.99,
                    max_kl_step=0.01)

        runner.setup(algo, env)
        runner.train(n_epochs=40, batch_size=batch_size)
def test_categorical_policies(self, policy_cls):
    with TFTrainer(snapshot_config, sess=self.sess) as trainer:
        env = normalize(GymEnv('CartPole-v0', max_episode_length=100))

        policy = policy_cls(name='policy', env_spec=env.spec)

        baseline = LinearFeatureBaseline(env_spec=env.spec)

        sampler = LocalSampler(
            agents=policy,
            envs=env,
            max_episode_length=env.spec.max_episode_length,
            is_tf_worker=True)

        algo = TRPO(
            env_spec=env.spec,
            policy=policy,
            baseline=baseline,
            sampler=sampler,
            discount=0.99,
            max_kl_step=0.01,
            optimizer=ConjugateGradientOptimizer,
            optimizer_args=dict(hvp_approach=FiniteDifferenceHVP(
                base_eps=1e-5)),
        )

        trainer.setup(algo, env)
        trainer.train(n_epochs=1, batch_size=4000)

        env.close()