def main(_):
    """Build the distributed SVG0 program and launch it through Launchpad.

    The single positional argument is ignored. The environment is selected
    by the `domain` and `task` flags.
    """
    # Hyperparameters shared between the replay settings and the agent.
    num_actors = 1
    batch_size = 32
    sequence_length = 20
    gradient_steps_per_actor_step = 1.0
    # Product of gradient steps per actor step, batch size and sequence length.
    samples_per_insert = (
        gradient_steps_per_actor_step * batch_size * sequence_length
    )

    environment_factory = lp_utils.partial_kwargs(
        helpers.make_environment,
        domain_name=FLAGS.domain,
        task_name=FLAGS.task,
    )
    network_factory = lp_utils.partial_kwargs(svg0_prior.make_default_networks)

    agent = svg0_prior.DistributedSVG0(
        environment_factory=environment_factory,
        network_factory=network_factory,
        batch_size=batch_size,
        sequence_length=sequence_length,
        samples_per_insert=samples_per_insert,
        entropy_regularizer_cost=1e-4,
        max_replay_size=int(2e6),
        target_update_period=250,
        num_actors=num_actors,
    )
    program = agent.build()

    # The XM docker resources are derived from the built program itself.
    xm_resources = lp_utils.make_xm_docker_resources(program)
    lp.launch(program, xm_resources=xm_resources)
def main(_):
    """Build a two-actor distributed D4PG program and launch it.

    The single positional argument is ignored. The environment is selected
    by the `task` flag, and the program is launched with XM docker resources
    derived from the built program.
    """
    environment_factory = lp_utils.partial_kwargs(
        helpers.make_environment,
        task=FLAGS.task,
    )
    network_factory = lp_utils.partial_kwargs(helpers.make_networks)

    agent = d4pg.DistributedD4PG(
        environment_factory=environment_factory,
        network_factory=network_factory,
        num_actors=2,
    )
    program = agent.build()

    xm_resources = lp_utils.make_xm_docker_resources(program)
    lp.launch(program, xm_resources=xm_resources)
def main(_):
    """Build a two-actor distributed D4PG program and launch it locally.

    The single positional argument is ignored. The environment is selected
    by the `task` flag, and the program is launched with the
    `LOCAL_MULTI_PROCESSING` launch type.
    """
    environment_factory = lp_utils.partial_kwargs(
        helpers.make_environment,
        task=FLAGS.task,
    )
    network_factory = lp_utils.partial_kwargs(helpers.make_networks)

    agent = d4pg.DistributedD4PG(
        environment_factory=environment_factory,
        network_factory=network_factory,
        num_actors=2,
    )
    program = agent.build()

    lp.launch(program, lp.LaunchType.LOCAL_MULTI_PROCESSING)
def test_agent(self, distributional_critic):
    """Launch a distributed multi-objective MPO agent and step its learner.

    Args:
        distributional_critic: Forwarded to `make_networks` as the keyword
            argument of the same name.
    """
    # Objectives and network factory used to construct the agent.
    reward_objectives, qvalue_objectives = make_objectives()
    network_factory = lp_utils.partial_kwargs(
        make_networks,
        distributional_critic=distributional_critic,
    )

    agent = mompo.DistributedMultiObjectiveMPO(
        reward_objectives,
        qvalue_objectives,
        environment_factory=make_environment,
        network_factory=network_factory,
        num_actors=2,
        batch_size=32,
        min_replay_size=32,
        max_replay_size=1000,
    )
    program = agent.build()

    # The 'learner' group must hold exactly one node; unpacking enforces it.
    [learner_node] = program.groups['learner']
    # NOTE(review): presumably this stops the launcher from running the
    # learner itself, so the test can step it by hand below - confirm.
    learner_node.disable_run()

    lp.launch(program, launch_type='test_mt')

    learner: acme.Learner = learner_node.create_handle().dereference()

    # Take a handful of learner steps manually.
    num_steps = 5
    for _ in range(num_steps):
        learner.step()
def test_partial_kwargs(self):
    """Check overriding, no-op and error cases of `lp_utils.partial_kwargs`."""

    def foo(a, b, c=2):
        return a, b, c

    def bar(a, b):
        return a, b

    # Override the default value of `c`; the other two wrappers pass no
    # overrides and should therefore be no-ops.
    foo_c_overridden = lp_utils.partial_kwargs(foo, c=1)
    foo_unchanged = lp_utils.partial_kwargs(foo)
    bar_unchanged = lp_utils.partial_kwargs(bar)

    # Both of these overrides must be rejected with a ValueError:
    #   * `a` exists but has no default value;
    #   * `d` is not a parameter of `foo` at all.
    for invalid_override in ({'a': 2}, {'d': 2}):
        with self.assertRaises(ValueError):
            lp_utils.partial_kwargs(foo, **invalid_override)

    # Make sure we get back the correct values.
    expectations = (
        (foo_c_overridden, (1, 2, 1)),
        (foo_unchanged, (1, 2, 2)),
        (bar_unchanged, (1, 2)),
    )
    for wrapped, expected in expectations:
        self.assertEqual(wrapped(1, 2), expected)