Example #1
# Launch script for distributed SVG0-with-prior. `helpers` is the example's
# local helper module; the --domain and --task flags are defined in the full
# script, and the acme import paths below are the usual ones for these symbols.
from absl import flags

import helpers
import launchpad as lp

from acme.agents.tf import svg0_prior
from acme.utils import lp_utils

FLAGS = flags.FLAGS


def main(_):
    # Bind the command-line flags into the environment factory.
    environment_factory = lp_utils.partial_kwargs(helpers.make_environment,
                                                  domain_name=FLAGS.domain,
                                                  task_name=FLAGS.task)

    batch_size = 32
    sequence_length = 20
    # One learner step per actor step: 1.0 * 32 * 20 = 640 samples per insert.
    gradient_steps_per_actor_step = 1.0
    samples_per_insert = (gradient_steps_per_actor_step * batch_size *
                          sequence_length)
    num_actors = 1

    program = svg0_prior.DistributedSVG0(
        environment_factory=environment_factory,
        network_factory=lp_utils.partial_kwargs(
            svg0_prior.make_default_networks),
        batch_size=batch_size,
        sequence_length=sequence_length,
        samples_per_insert=samples_per_insert,
        entropy_regularizer_cost=1e-4,
        max_replay_size=int(2e6),
        target_update_period=250,
        num_actors=num_actors).build()

    lp.launch(program, xm_resources=lp_utils.make_xm_docker_resources(program))
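The factory built by partial_kwargs behaves like a validated functools.partial: the flag values are bound as keyword arguments, and only parameters that already have defaults may be overridden (Example #5 below exercises exactly that contract). A minimal sketch of the equivalence, using a hypothetical make_environment whose keywords both carry defaults:

import functools

from acme.utils import lp_utils


# Hypothetical stand-in for helpers.make_environment.
def make_environment(domain_name='cartpole', task_name='swingup'):
    return (domain_name, task_name)


# Both callables forward to make_environment with the two keywords bound.
factory_a = lp_utils.partial_kwargs(make_environment,
                                    domain_name='cheetah', task_name='run')
factory_b = functools.partial(make_environment,
                              domain_name='cheetah', task_name='run')
assert factory_a() == factory_b() == ('cheetah', 'run')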
Example #2
# Launch script for distributed D4PG, building XManager Docker resources for
# the program. Imports mirror Example #1, with the agent swapped for d4pg.
from absl import flags

import helpers
import launchpad as lp

from acme.agents.tf import d4pg
from acme.utils import lp_utils

FLAGS = flags.FLAGS


def main(_):
    environment_factory = lp_utils.partial_kwargs(helpers.make_environment,
                                                  task=FLAGS.task)

    program = d4pg.DistributedD4PG(environment_factory=environment_factory,
                                   network_factory=lp_utils.partial_kwargs(
                                       helpers.make_networks),
                                   num_actors=2).build()

    lp.launch(program, xm_resources=lp_utils.make_xm_docker_resources(program))
Example #3
# Identical setup and imports to Example #2, but launched locally with one
# process per node instead of through XManager.
def main(_):
    environment_factory = lp_utils.partial_kwargs(helpers.make_environment,
                                                  task=FLAGS.task)

    program = d4pg.DistributedD4PG(environment_factory=environment_factory,
                                   network_factory=lp_utils.partial_kwargs(
                                       helpers.make_networks),
                                   num_actors=2).build()

    lp.launch(program, lp.LaunchType.LOCAL_MULTI_PROCESSING)
Example #4
    # A parameterized method from the MoMPO agent's integration test; mompo,
    # lp_utils, lp, acme and the make_* helpers are defined or imported at
    # module level in the test file.
    def test_agent(self, distributional_critic):
        # Create the reward and Q-value objectives for the multi-objective agent.
        reward_objectives, qvalue_objectives = make_objectives()

        network_factory = lp_utils.partial_kwargs(
            make_networks, distributional_critic=distributional_critic)

        agent = mompo.DistributedMultiObjectiveMPO(
            reward_objectives,
            qvalue_objectives,
            environment_factory=make_environment,
            network_factory=network_factory,
            num_actors=2,
            batch_size=32,
            min_replay_size=32,
            max_replay_size=1000,
        )
        program = agent.build()

        # Stop the learner node from starting its own run loop so the test can
        # drive it manually.
        (learner_node,) = program.groups['learner']
        learner_node.disable_run()

        # 'test_mt' runs every node in threads inside the test process.
        lp.launch(program, launch_type='test_mt')

        # Dereference the node's handle to get the learner instance, then step it.
        learner: acme.Learner = learner_node.create_handle().dereference()

        for _ in range(5):
            learner.step()
Example #5
  # A method from the lp_utils unit test; lp_utils is imported at module level
  # in the test file.
  def test_partial_kwargs(self):

    def foo(a, b, c=2):
      return a, b, c

    def bar(a, b):
      return a, b

    # Override the default values. The last two should be no-ops.
    foo1 = lp_utils.partial_kwargs(foo, c=1)
    foo2 = lp_utils.partial_kwargs(foo)
    bar1 = lp_utils.partial_kwargs(bar)

    # Check that we raise errors on overriding kwargs with no default values.
    with self.assertRaises(ValueError):
      lp_utils.partial_kwargs(foo, a=2)

    # Check that we raise if we try to override a kwarg that doesn't exist.
    with self.assertRaises(ValueError):
      lp_utils.partial_kwargs(foo, d=2)

    # Make sure we get back the correct values.
    self.assertEqual(foo1(1, 2), (1, 2, 1))
    self.assertEqual(foo2(1, 2), (1, 2, 2))
    self.assertEqual(bar1(1, 2), (1, 2))
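The test pins down the contract of partial_kwargs: it only overrides parameters that already have default values and rejects unknown names. A minimal sketch of a helper with that contract, reconstructed from the tested behavior rather than from Acme's actual implementation (the name partial_kwargs_sketch is hypothetical):

import functools
import inspect


def partial_kwargs_sketch(function, **overrides):
  """Returns `function` with the given defaulted keyword arguments overridden."""
  params = inspect.signature(function).parameters
  for name in overrides:
    if name not in params:
      # Trying to override a kwarg that doesn't exist.
      raise ValueError(f'{name!r} is not a parameter of {function.__name__}')
    if params[name].default is inspect.Parameter.empty:
      # Trying to override an argument that has no default value.
      raise ValueError(f'{name!r} has no default value to override')
  return functools.partial(function, **overrides)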