コード例 #1
0
 def test_dm_control_thread(self):
     """Smoke-test `train.main` on the `cup_catch` task with thread isolation.

     Builds a single-run configuration on top of the 'dreamer' and 'debug'
     defaults, with isolate_envs='thread' and max_steps=30. There are no
     explicit assertions: the test passes as long as `train.main` returns
     without raising.
     """
     # Fetch the log dir first, matching the original argument-evaluation order.
     logdir = self.get_temp_dir()
     params = tools.AttrDict(
         defaults=['dreamer', 'debug'],
         tasks=['cup_catch'],
         isolate_envs='thread',
         max_steps=30)
     args = tools.AttrDict(
         logdir=logdir,
         num_runs=1,
         params=params,
         ping_every=0,
         resume_runs=False)
     train.main(args)
コード例 #2
0
 def test_planet(self):
     """Smoke-test `train.main` with the 'planet' defaults on the dummy task.

     The environment is not isolated (isolate_envs='none'), max_steps=30 and
     planner_horizon=3. The short horizon is presumably there to keep the
     debug run cheap (confirm). The test passes if `train.main` completes
     without raising.
     """
     # Fetch the log dir first, matching the original argument-evaluation order.
     logdir = self.get_temp_dir()
     params = tools.AttrDict(
         defaults=['planet', 'debug'],
         tasks=['dummy'],
         isolate_envs='none',
         max_steps=30,
         planner_horizon=3)
     args = tools.AttrDict(
         logdir=logdir,
         num_runs=1,
         params=params,
         ping_every=0,
         resume_runs=False)
     train.main(args)
コード例 #3
0
 def test_no_value(self):
     """Smoke-test `train.main` with the 'actor' defaults on the dummy task.

     Uses isolate_envs='none', max_steps=30 and imagination_horizon=3. The
     test name suggests the 'actor' defaults train without a value function.
     That configuration lives elsewhere, so confirm it there. The test passes
     if `train.main` completes without raising.
     """
     # Fetch the log dir first, matching the original argument-evaluation order.
     logdir = self.get_temp_dir()
     params = tools.AttrDict(
         defaults=['actor', 'debug'],
         tasks=['dummy'],
         isolate_envs='none',
         max_steps=30,
         imagination_horizon=3)
     args = tools.AttrDict(
         logdir=logdir,
         num_runs=1,
         params=params,
         ping_every=0,
         resume_runs=False)
     train.main(args)
コード例 #4
0
 def test_atari_thread(self):
     """Smoke-test `train.main` on `atari_pong` with thread isolation.

     Starts from the 'dreamer' and 'debug' defaults and overrides
     action_head_dist='onehot_score' and action_noise_type='epsilon_greedy'.
     These overrides are presumably needed for a discrete action space
     (confirm). Also sets isolate_envs='thread' and max_steps=30. The test
     passes if `train.main` completes without raising.
     """
     # Fetch the log dir first, matching the original argument-evaluation order.
     logdir = self.get_temp_dir()
     params = tools.AttrDict(
         defaults=['dreamer', 'debug'],
         tasks=['atari_pong'],
         isolate_envs='thread',
         action_head_dist='onehot_score',
         action_noise_type='epsilon_greedy',
         max_steps=30)
     args = tools.AttrDict(
         logdir=logdir,
         num_runs=1,
         params=params,
         ping_every=0,
         resume_runs=False)
     train.main(args)
コード例 #5
0
 def test_dreamer(self):
     """Smoke-test `train.main` with explicit Dreamer planner settings.

     Starts from the 'dreamer' and 'debug' defaults on the dummy task
     (isolate_envs='none', max_steps=30) and explicitly sets:
     train_planner='policy_sample', test_planner='policy_mode',
     planner_objective='reward_value', action_head=True, value_head=True and
     imagination_horizon=3. The test passes if `train.main` completes
     without raising.
     """
     # Fetch the log dir first, matching the original argument-evaluation order.
     logdir = self.get_temp_dir()
     params = tools.AttrDict(
         defaults=['dreamer', 'debug'],
         tasks=['dummy'],
         isolate_envs='none',
         max_steps=30,
         train_planner='policy_sample',
         test_planner='policy_mode',
         planner_objective='reward_value',
         action_head=True,
         value_head=True,
         imagination_horizon=3)
     args = tools.AttrDict(
         logdir=logdir,
         num_runs=1,
         params=params,
         ping_every=0,
         resume_runs=False)
     train.main(args)