Example #1
File: agents_test.py Project: srsohn/mtsgi
def test_rlrl(env_id, graph_param, seed, num_envs):
    adapt_envs, test_envs = make_envs(env_id, graph_param, seed, num_envs)
    environment_spec = specs.make_environment_spec(adapt_envs)

    if env_id in {'playground', 'mining'}:
        # spatial observation.
        network = snt_utils.CombinedNN(environment_spec.actions)
    else:
        network = snt_utils.RecurrentNN(environment_spec.actions)

    # Create meta agent.
    meta_agent = agents.RLRL(environment_spec=environment_spec,
                             network=network,
                             n_step_horizon=10,
                             minibatch_size=10)

    # Run meta loop.
    meta_loop = environment_loop.EnvironmentMetaLoop(
        adapt_environment=adapt_envs,
        test_environment=test_envs,
        meta_agent=meta_agent,
        label='meta_train'  # XXX meta train
    )

    meta_loop.run(
        num_trials=3,
        num_adapt_steps=100,  # should be greater than TimeLimit
        num_test_episodes=1,
        num_trial_splits=1)
Example #2
  def test_td3_fd(self):
    # Create a fake environment to test with.
    environment = fakes.ContinuousEnvironment(
        episode_length=10, action_dim=3, observation_dim=5, bounded=True)
    spec = specs.make_environment_spec(environment)

    # Create the networks.
    td3_network = td3.make_networks(spec)

    batch_size = 10
    td3_config = td3.TD3Config(
        batch_size=batch_size,
        min_replay_size=1)
    lfd_config = lfd.LfdConfig(initial_insert_count=0,
                               demonstration_ratio=0.2)
    td3_fd_config = lfd.TD3fDConfig(lfd_config=lfd_config,
                                    td3_config=td3_config)
    counter = counting.Counter()
    agent = lfd.TD3fD(
        spec=spec,
        td3_network=td3_network,
        td3_fd_config=td3_fd_config,
        lfd_iterator_fn=fake_demonstration_iterator,
        seed=0,
        counter=counter)

    # Try running the environment loop. We have no assertions here because all
    # we care about is that the agent runs without raising any errors.
    loop = acme.EnvironmentLoop(environment, agent, counter=counter)
    loop.run(num_episodes=20)
Example #3
  def actor(self, random_key, replay,
            variable_source, counter,
            actor_id):
    """The actor process."""
    adder = self._builder.make_adder(replay)

    environment_key, actor_key = jax.random.split(random_key)
    # Create environment and policy core.

    # Environments normally require uint32 as a seed.
    environment = self._environment_factory(
        utils.sample_uint32(environment_key))

    networks = self._network_factory(specs.make_environment_spec(environment))
    policy_network = self._policy_network(networks)
    actor = self._builder.make_actor(actor_key, policy_network, adder,
                                     variable_source)

    # Create logger and counter.
    counter = counting.Counter(counter, 'actor')
    # Only actor #0 will write to bigtable in order not to spam it too much.
    logger = self._actor_logger_fn(actor_id)
    # Create the loop to connect environment and agent.
    return environment_loop.EnvironmentLoop(environment, actor, counter,
                                            logger, observers=self._observers)
Example #4
    def test_ppo_nest_safety(self):
        # Create a fake environment with nested observations.
        environment = fakes.NestedDiscreteEnvironment(
            num_observations={'lat': 2, 'long': 3},
            num_actions=5,
            obs_shape=(10, 5),
            obs_dtype=np.float32,
            episode_length=10)
        spec = specs.make_environment_spec(environment)

        distribution_value_networks = make_haiku_networks(spec)
        ppo_networks = ppo.make_ppo_networks(distribution_value_networks)
        config = ppo.PPOConfig(unroll_length=4,
                               num_epochs=2,
                               num_minibatches=2)
        workdir = self.create_tempdir()
        # Construct the agent.
        agent = ppo.PPO(
            spec=spec,
            networks=ppo_networks,
            config=config,
            seed=0,
            workdir=workdir.full_path,
            normalize_input=True,
        )

        # Try running the environment loop. We have no assertions here because all
        # we care about is that the agent runs without raising any errors.
        loop = acme.EnvironmentLoop(environment, agent)
        loop.run(num_episodes=20)
Example #5
    def test_r2d2(self):
        # Create a fake environment to test with.
        # TODO(b/152596848): Allow R2D2 to deal with integer observations.
        environment = fakes.DiscreteEnvironment(num_actions=5,
                                                num_observations=10,
                                                obs_shape=(10, 4),
                                                obs_dtype=np.float32,
                                                episode_length=10)
        spec = specs.make_environment_spec(environment)

        # Construct the agent.
        agent = r2d2.R2D2(
            environment_spec=spec,
            network=SimpleNetwork(spec.actions),
            batch_size=10,
            samples_per_insert=2,
            min_replay_size=10,
            burn_in_length=2,
            trace_length=6,
            replay_period=4,
            checkpoint=False,
        )

        # Try running the environment loop. We have no assertions here because all
        # we care about is that the agent runs without raising any errors.
        loop = acme.EnvironmentLoop(environment, agent)
        loop.run(num_episodes=5)
Example #6
    def test_r2d2(self):
        # Create a fake environment to test with.
        environment = fakes.fake_atari_wrapped(oar_wrapper=True)
        spec = specs.make_environment_spec(environment)

        config = r2d2.R2D2Config(batch_size=1,
                                 trace_length=5,
                                 sequence_period=1,
                                 samples_per_insert=0.,
                                 min_replay_size=1,
                                 burn_in_length=1)

        counter = counting.Counter()
        agent = r2d2.R2D2(
            spec=spec,
            networks=r2d2.make_atari_networks(config.batch_size, spec),
            config=config,
            seed=0,
            counter=counter,
        )

        # Try running the environment loop. We have no assertions here because all
        # we care about is that the agent runs without raising any errors.
        loop = acme.EnvironmentLoop(environment, agent, counter=counter)
        loop.run(num_episodes=20)
Example #7
    def test_mcts(self):
        # Create a fake environment to test with.
        num_actions = 5
        environment = fakes.DiscreteEnvironment(num_actions=num_actions,
                                                num_observations=10,
                                                obs_dtype=np.float32,
                                                episode_length=10)
        spec = specs.make_environment_spec(environment)

        network = snt.Sequential([
            snt.Flatten(),
            snt.nets.MLP([50, 50]),
            networks.PolicyValueHead(spec.actions.num_values),
        ])
        model = simulator.Simulator(environment)
        optimizer = snt.optimizers.Adam(1e-3)

        # Construct the agent.
        agent = mcts.MCTS(environment_spec=spec,
                          network=network,
                          model=model,
                          optimizer=optimizer,
                          n_step=1,
                          discount=1.,
                          replay_capacity=100,
                          num_simulations=10,
                          batch_size=10)

        # Try running the environment loop. We have no assertions here because all
        # we care about is that the agent runs without raising any errors.
        loop = acme.EnvironmentLoop(environment, agent)
        loop.run(num_episodes=2)
Example #8
    def test_impala(self):
        # Create a fake environment to test with.
        environment = fakes.DiscreteEnvironment(num_actions=5,
                                                num_observations=10,
                                                obs_shape=(10, 5),
                                                obs_dtype=np.float32,
                                                episode_length=10)
        spec = specs.make_environment_spec(environment)

        def network(x, s):
            model = MyNetwork(spec.actions.num_values)
            return model(x, s)

        def initial_state_fn(batch_size: Optional[int] = None):
            model = MyNetwork(spec.actions.num_values)
            return model.initial_state(batch_size)

        # Construct the agent.
        agent = impala.IMPALA(
            environment_spec=spec,
            network=network,
            initial_state_fn=initial_state_fn,
            sequence_length=3,
            sequence_period=3,
            batch_size=6,
        )

        # Try running the environment loop. We have no assertions here because all
        # we care about is that the agent runs without raising any errors.
        loop = acme.EnvironmentLoop(environment, agent)
        loop.run(num_episodes=20)
Example #9
File: run_sac.py Project: novatig/acme
def main(_):
    # Create an environment, grab the spec, and use it to create networks.
    environment = make_environment(FLAGS.task_name)
    environment_spec = specs.make_environment_spec(environment)
    agent_networks = make_networks(environment_spec)

    # Construct the agent.
    agent = sac.SAC(
        environment_spec=environment_spec,
        policy_network=agent_networks['policy'],
        critic_network=agent_networks['critic'],
        encoder_network=agent_networks['observation'],
        #sigma=0.3,  # pytype: disable=wrong-arg-types
    )

    # Create the environment loop used for training.
    train_loop = acme.EnvironmentLoop(environment, agent, label='train_loop')

    # Create the evaluation policy.
    eval_policy = agent.behavior_network

    # Create the evaluation actor and loop.
    eval_actor = actors.FeedForwardActor(policy_network=eval_policy)
    eval_env = make_environment(FLAGS.task_name)
    eval_loop = acme.EnvironmentLoop(eval_env, eval_actor, label='eval_loop')

    for _ in range(FLAGS.num_episodes // FLAGS.num_episodes_per_eval):
        train_loop.run(num_episodes=FLAGS.num_episodes_per_eval)
        eval_loop.run(num_episodes=1)  # Evaluate after each training segment.
Example #10
def generate_train_data(
    task_name,
    behavior_policy_param,
    dataset_path,
    environment,
    dataset_size,
    batch_size,
    shuffle,
    include_terminal=False,  # Include terminal absorbing state.
    ignore_d_tm1=False  # Set d_tm1 as constant 1.0 if True.
):
    environment_spec = specs.make_environment_spec(environment)
    with tf.device('CPU'):
        behavior_policy_net = load_policy_net(
            task_name=task_name,
            params=behavior_policy_param,
            environment_spec=environment_spec,
            dataset_path=dataset_path)

        logging.info('start generating transitions')
        dataset = _generate_data(behavior_policy_net,
                                 environment,
                                 dataset_size,
                                 batch_size,
                                 shuffle,
                                 include_terminal=include_terminal,
                                 ignore_d_tm1=ignore_d_tm1)
        logging.info('end generating transitions')
    return dataset
Example #11
    def test_step(self):
        for param in TEST_PARAMS:
            env = self._make_parallel_environments(
                env_id=param.env_id,
                num_envs=param.num_envs,
                graph_param=param.graph_param)

            env_spec = specs.make_environment_spec(env)

            # Create random actor.
            actor = agents.RandomActor(env_spec)

            env.reset_task(task_index=0)
            ts = env.reset()

            # Step.
            action = actor.select_action(ts.observation)
            ts = env.step(action)

            self.assertIsInstance(ts, dm_env.TimeStep)
            self.assertFalse(any(ts.first()))

            ob = ts.observation
            self.assertEqual(ob['mask'].shape,
                             (param.num_envs, param.action_dim))
            self.assertEqual(ob['completion'].shape,
                             (param.num_envs, param.action_dim))
            self.assertEqual(ob['eligibility'].shape,
                             (param.num_envs, param.action_dim))
Example #12
    def evaluator(
        random_key: types.PRNGKey,
        variable_source: core.VariableSource,
        counter: counting.Counter,
        make_actor: MakeActorFn,
    ):
        """The evaluation process."""

        # Create environment and evaluator networks
        environment_key, actor_key = jax.random.split(random_key)
        # Environments normally require uint32 as a seed.
        environment = environment_factory(utils.sample_uint32(environment_key))
        environment_spec = specs.make_environment_spec(environment)
        networks = network_factory(environment_spec)
        policy = policy_factory(networks, environment_spec, True)
        actor = make_actor(actor_key, policy, environment_spec,
                           variable_source)

        # Create logger and counter.
        counter = counting.Counter(counter, 'evaluator')
        logger = logger_factory('evaluator', 'actor_steps', 0)

        # Create the run loop and return it.
        return environment_loop.EnvironmentLoop(environment,
                                                actor,
                                                counter,
                                                logger,
                                                observers=observers)
Example #13
    def test_dqn(self):
        # Create a fake environment to test with.
        environment = fakes.DiscreteEnvironment(num_actions=5,
                                                num_observations=10,
                                                obs_shape=(10, 5),
                                                obs_dtype=np.float32,
                                                episode_length=10)
        spec = specs.make_environment_spec(environment)

        def network(x):
            model = hk.Sequential(
                [hk.Flatten(),
                 hk.nets.MLP([50, 50, spec.actions.num_values])])
            return model(x)

        # Construct the agent.
        agent = dqn.DQN(environment_spec=spec,
                        network=network,
                        batch_size=10,
                        samples_per_insert=2,
                        min_replay_size=10)

        # Try running the environment loop. We have no assertions here because all
        # we care about is that the agent runs without raising any errors.
        loop = acme.EnvironmentLoop(environment, agent)
        loop.run(num_episodes=20)
Example #14
    def test_d4pg(self):
        # Create a fake environment to test with.
        environment = fakes.ContinuousEnvironment(episode_length=10,
                                                  action_dim=3,
                                                  observation_dim=5,
                                                  bounded=True)
        spec = specs.make_environment_spec(environment)

        # Create the networks.
        networks = make_networks(spec)

        config = d4pg.D4PGConfig(
            batch_size=10,
            samples_per_insert=2,
            min_replay_size=10,
            samples_per_insert_tolerance_rate=float('inf'))
        counter = counting.Counter()
        agent = d4pg.D4PG(spec,
                          networks,
                          config=config,
                          random_seed=0,
                          counter=counter)

        # Try running the environment loop. We have no assertions here because all
        # we care about is that the agent runs without raising any errors.
        loop = acme.EnvironmentLoop(environment, agent, counter=counter)
        loop.run(num_episodes=2)
Example #15
def main(_):
    key = jax.random.PRNGKey(FLAGS.seed)
    key_demonstrations, key_learner = jax.random.split(key, 2)

    # Create an environment and grab the spec.
    environment = gym_helpers.make_environment(task=FLAGS.env_name)
    environment_spec = specs.make_environment_spec(environment)

    # Get a demonstrations dataset with next_actions extra.
    transitions = tfds.get_tfds_dataset(FLAGS.dataset_name,
                                        FLAGS.num_demonstrations)
    double_transitions = rlds.transformations.batch(transitions,
                                                    size=2,
                                                    shift=1,
                                                    drop_remainder=True)
    transitions = double_transitions.map(_add_next_action_extras)
    demonstrations = tfds.JaxInMemoryRandomSampleIterator(
        transitions, key=key_demonstrations, batch_size=FLAGS.batch_size)

    # Create the networks to optimize.
    networks = td3.make_networks(environment_spec)

    # Create the learner.
    learner = td3.TD3Learner(
        networks=networks,
        random_key=key_learner,
        discount=FLAGS.discount,
        iterator=demonstrations,
        policy_optimizer=optax.adam(FLAGS.policy_learning_rate),
        critic_optimizer=optax.adam(FLAGS.critic_learning_rate),
        twin_critic_optimizer=optax.adam(FLAGS.critic_learning_rate),
        use_sarsa_target=FLAGS.use_sarsa_target,
        bc_alpha=FLAGS.bc_alpha,
        num_sgd_steps_per_step=1)

    def evaluator_network(params: hk.Params, key: jnp.DeviceArray,
                          observation: jnp.DeviceArray) -> jnp.DeviceArray:
        del key
        return networks.policy_network.apply(params, observation)

    actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
        evaluator_network)
    variable_client = variable_utils.VariableClient(learner,
                                                    'policy',
                                                    device='cpu')
    evaluator = actors.GenericActor(actor_core,
                                    key,
                                    variable_client,
                                    backend='cpu')

    eval_loop = acme.EnvironmentLoop(environment=environment,
                                     actor=evaluator,
                                     logger=loggers.TerminalLogger(
                                         'evaluation', time_delta=0.))

    # Run the environment loop.
    while True:
        for _ in range(FLAGS.evaluate_every):
            learner.step()
        eval_loop.run(FLAGS.evaluation_episodes)
Example #16
File: run_sac.py Project: bimec/acme
def main(_):
    # Create an environment, grab the spec, and use it to create networks.
    environment = helpers.make_environment(task=FLAGS.env_name)
    environment_spec = specs.make_environment_spec(environment)
    agent_networks = sac.make_networks(environment_spec)

    # Construct the agent.
    config = sac.SACConfig(
        target_entropy=sac.target_entropy_from_env_spec(environment_spec),
        num_sgd_steps_per_step=FLAGS.num_sgd_steps_per_step,
        seed=FLAGS.seed)
    agent = sac.SAC(environment_spec, agent_networks, config=config)

    # Create the environment loop used for training.
    train_loop = acme.EnvironmentLoop(environment, agent, label='train_loop')
    # Create the evaluation actor and loop.
    eval_actor = agent.builder.make_actor(
        policy_network=sac.apply_policy_and_sample(agent_networks,
                                                   eval_mode=True),
        variable_source=agent)
    eval_env = helpers.make_environment(task=FLAGS.env_name)
    eval_loop = acme.EnvironmentLoop(eval_env, eval_actor, label='eval_loop')

    assert FLAGS.num_steps % FLAGS.eval_every == 0
    for _ in range(FLAGS.num_steps // FLAGS.eval_every):
        eval_loop.run(num_episodes=5)
        train_loop.run(num_steps=FLAGS.eval_every)
    eval_loop.run(num_episodes=5)
Example #17
    def test_dqn(self):
        # Create a fake environment to test with.
        environment = fakes.DiscreteEnvironment(num_actions=5,
                                                num_observations=10,
                                                obs_shape=(10, 5),
                                                obs_dtype=np.float32,
                                                episode_length=10)
        spec = specs.make_environment_spec(environment)

        def network(x):
            model = hk.Sequential(
                [hk.Flatten(),
                 hk.nets.MLP([50, 50, spec.actions.num_values])])
            return model(x)

        # Make network purely functional
        network_hk = hk.without_apply_rng(hk.transform(network,
                                                       apply_rng=True))
        dummy_obs = utils.add_batch_dim(utils.zeros_like(spec.observations))

        network = networks_lib.FeedForwardNetwork(
            init=lambda rng: network_hk.init(rng, dummy_obs),
            apply=network_hk.apply)

        # Construct the agent.
        agent = dqn.DQN(environment_spec=spec,
                        network=network,
                        batch_size=10,
                        samples_per_insert=2,
                        min_replay_size=10)

        # Try running the environment loop. We have no assertions here because all
        # we care about is that the agent runs without raising any errors.
        loop = acme.EnvironmentLoop(environment, agent)
        loop.run(num_episodes=20)
Example #18
File: agent_test.py Project: deepmind/acme
    def test_mompo(self, distributional_critic):
        # Create a fake environment to test with.
        environment = fakes.ContinuousEnvironment(episode_length=10)
        spec = specs.make_environment_spec(environment)

        # Create objectives.
        reward_objectives, qvalue_objectives = make_objectives()
        num_critic_heads = len(reward_objectives)

        # Create networks.
        agent_networks = make_networks(
            spec.actions,
            num_critic_heads=num_critic_heads,
            distributional_critic=distributional_critic)

        # Construct the agent.
        agent = mompo.MultiObjectiveMPO(
            reward_objectives,
            qvalue_objectives,
            spec,
            policy_network=agent_networks['policy'],
            critic_network=agent_networks['critic'],
            batch_size=10,
            samples_per_insert=2,
            min_replay_size=10)

        # Try running the environment loop. We have no assertions here because all
        # we care about is that the agent runs without raising any errors.
        loop = acme.EnvironmentLoop(environment, agent)
        loop.run(num_episodes=2)
Example #19
    def test_dqfd(self):
        # Create a fake environment to test with.
        # TODO(b/152596848): Allow DQN to deal with integer observations.
        environment = fakes.DiscreteEnvironment(num_actions=5,
                                                num_observations=10,
                                                obs_dtype=np.float32,
                                                episode_length=10)
        spec = specs.make_environment_spec(environment)

        # Build demonstrations.
        dummy_action = np.zeros((), dtype=np.int32)
        recorder = bsuite_demonstrations.DemonstrationRecorder()
        timestep = environment.reset()
        while timestep.step_type is not dm_env.StepType.LAST:
            recorder.step(timestep, dummy_action)
            timestep = environment.step(dummy_action)
        recorder.step(timestep, dummy_action)
        recorder.record_episode()

        # Construct the agent.
        agent = dqfd.DQfD(environment_spec=spec,
                          network=_make_network(spec.actions),
                          demonstration_dataset=recorder.make_tf_dataset(),
                          demonstration_ratio=0.5,
                          batch_size=10,
                          samples_per_insert=2,
                          min_replay_size=10)

        # Try running the environment loop. We have no assertions here because all
        # we care about is that the agent runs without raising any errors.
        loop = acme.EnvironmentLoop(environment, agent)
        loop.run(num_episodes=10)
Example #20
    def evaluator(
        random_key: networks_lib.PRNGKey,
        variable_source: core.VariableSource,
        counter: counting.Counter,
        make_actor: MakeActorFn,
    ):
        """The evaluation process."""

        # Create environment and evaluator networks
        environment_key, actor_key = jax.random.split(random_key)
        # Environments normally require uint32 as a seed.
        environment = environment_factory(utils.sample_uint32(environment_key))
        networks = network_factory(specs.make_environment_spec(environment))

        actor = make_actor(actor_key, policy_factory(networks),
                           variable_source)

        # Create logger and counter.
        counter = counting.Counter(counter, 'evaluator')
        if logger_fn is not None:
            logger = logger_fn('evaluator', 'actor_steps')
        else:
            logger = loggers.make_default_logger('evaluator',
                                                 log_to_bigtable,
                                                 steps_key='actor_steps')

        # Create the run loop and return it.
        return environment_loop.EnvironmentLoop(environment,
                                                actor,
                                                counter,
                                                logger,
                                                observers=observers)
Example #21
    def run_ppo_agent(self, make_networks_fn):
        # Create a fake environment to test with.
        environment = fakes.DiscreteEnvironment(num_actions=5,
                                                num_observations=10,
                                                obs_shape=(10, 5),
                                                obs_dtype=np.float32,
                                                episode_length=10)
        spec = specs.make_environment_spec(environment)

        distribution_value_networks = make_networks_fn(spec)
        ppo_networks = ppo.make_ppo_networks(distribution_value_networks)
        config = ppo.PPOConfig(unroll_length=4,
                               num_epochs=2,
                               num_minibatches=2)
        workdir = self.create_tempdir()
        counter = counting.Counter()
        logger = loggers.make_default_logger('learner')
        # Construct the agent.
        agent = ppo.PPO(
            spec=spec,
            networks=ppo_networks,
            config=config,
            seed=0,
            workdir=workdir.full_path,
            normalize_input=True,
            counter=counter,
            logger=logger,
        )

        # Try running the environment loop. We have no assertions here because all
        # we care about is that the agent runs without raising any errors.
        loop = acme.EnvironmentLoop(environment, agent, counter=counter)
        loop.run(num_episodes=20)
Example #22
File: agent_test.py Project: deepmind/acme
    def test_train(self, policy_loss_coeff_fn):
        seed = 0
        num_iterations = 5
        batch_size = 64
        grad_updates_per_batch = 1

        # Create a fake environment to test with.
        environment = fakes.ContinuousEnvironment(episode_length=10,
                                                  bounded=True,
                                                  action_dim=6)
        spec = specs.make_environment_spec(environment)

        # Construct the learner.
        networks = crr.make_networks(spec,
                                     policy_layer_sizes=(8, 8),
                                     critic_layer_sizes=(8, 8))
        key = jax.random.PRNGKey(seed)
        dataset = fakes.transition_iterator(environment)
        learner = crr.CRRLearner(networks,
                                 key,
                                 discount=0.95,
                                 target_update_period=2,
                                 policy_loss_coeff_fn=policy_loss_coeff_fn,
                                 iterator=dataset(batch_size *
                                                  grad_updates_per_batch),
                                 policy_optimizer=optax.adam(1e-4),
                                 critic_optimizer=optax.adam(1e-4),
                                 grad_updates_per_batch=grad_updates_per_batch)

        # Train the learner.
        for _ in range(num_iterations):
            learner.step()
Example #23
File: run_d4pg.py Project: pchtsp/acme
def main(_):
  # Create an environment, grab the spec, and use it to create networks.
  environment = make_environment()
  environment_spec = specs.make_environment_spec(environment)
  agent_networks = make_networks(environment_spec.actions)

  # Construct the agent.
  agent = d4pg.D4PG(
      environment_spec=environment_spec,
      policy_network=agent_networks['policy'],
      critic_network=agent_networks['critic'],
      observation_network=agent_networks['observation'],  # pytype: disable=wrong-arg-types
  )

  # Create the environment loop used for training.
  train_loop = acme.EnvironmentLoop(environment, agent, label='train_loop')

  # Create the evaluation policy.
  eval_policy = snt.Sequential([
      agent_networks['observation'],
      agent_networks['policy'],
  ])

  # Create the evaluation actor and loop.
  eval_actor = actors.FeedForwardActor(policy_network=eval_policy)
  eval_env = make_environment()
  eval_loop = acme.EnvironmentLoop(eval_env, eval_actor, label='eval_loop')

  for _ in range(FLAGS.num_episodes // FLAGS.num_episodes_per_eval):
    train_loop.run(num_episodes=FLAGS.num_episodes_per_eval)
    eval_loop.run(num_episodes=1)
Example #24
  def test_feedforward(self):
    environment = _make_fake_env()
    env_spec = specs.make_environment_spec(environment)

    def policy(inputs: jnp.ndarray):
      return hk.Sequential([
          hk.Flatten(),
          hk.Linear(env_spec.actions.num_values),
          lambda x: jnp.argmax(x, axis=-1),
      ])(inputs)

    policy = hk.transform(policy, apply_rng=True)

    rng = hk.PRNGSequence(1)
    dummy_obs = utils.add_batch_dim(utils.zeros_like(env_spec.observations))
    params = policy.init(next(rng), dummy_obs)

    variable_source = fakes.VariableSource(params)
    variable_client = variable_utils.VariableClient(variable_source, 'policy')

    actor = actors.FeedForwardActor(
        policy.apply, rng=hk.PRNGSequence(1), variable_client=variable_client)

    loop = environment_loop.EnvironmentLoop(environment, actor)
    loop.run(20)
Example #25
def main(_):
    # Create an environment and environment model.
    environment, model = make_env_and_model(
        bsuite_id=FLAGS.bsuite_id,
        results_dir=FLAGS.results_dir,
        overwrite=FLAGS.overwrite,
    )
    environment_spec = specs.make_environment_spec(environment)

    # Create the network and optimizer.
    network = make_network(environment_spec.actions)
    optimizer = snt.optimizers.Adam(learning_rate=1e-3)

    # Construct the agent.
    agent = mcts.MCTS(
        environment_spec=environment_spec,
        model=model,
        network=network,
        optimizer=optimizer,
        discount=0.99,
        replay_capacity=10000,
        n_step=1,
        batch_size=16,
        num_simulations=50,
    )

    # Run the environment loop.
    loop = acme.EnvironmentLoop(environment, agent)
    loop.run(num_episodes=environment.bsuite_num_episodes)  # pytype: disable=attribute-error
Example #26
    def test_train(self):
        seed = 0
        num_iterations = 2
        batch_size = 64

        # Create a fake environment to test with.
        environment = fakes.ContinuousEnvironment(episode_length=10,
                                                  bounded=True,
                                                  action_dim=6)
        spec = specs.make_environment_spec(environment)

        # Construct the agent.
        networks = cql.make_networks(spec, hidden_layer_sizes=(8, 8))
        dataset = fakes.transition_iterator(environment)
        key = jax.random.PRNGKey(seed)
        learner = cql.CQLLearner(batch_size,
                                 networks,
                                 key,
                                 demonstrations=dataset(batch_size),
                                 policy_optimizer=optax.adam(3e-5),
                                 critic_optimizer=optax.adam(3e-4),
                                 cql_lagrange_threshold=1.,
                                 target_entropy=0.1,
                                 num_sgd_steps_per_step=1)

        # Train the agent
        for _ in range(num_iterations):
            learner.step()
Example #27
  def replay(self):
    """The replay storage."""
    dummy_seed = 1
    environment_spec = (
        self._environment_spec or
        specs.make_environment_spec(self._environment_factory(dummy_seed)))
    return self._builder.make_replay_tables(environment_spec)
Example #28
    def test_make_dataset_with_sequence_length_and_batch_size(self):
        sequence_length = 6
        batch_size = 4
        environment = fakes.ContinuousEnvironment()
        environment_spec = specs.make_environment_spec(environment)
        dataset = reverb_dataset.make_dataset(
            client=self.tf_client,
            environment_spec=environment_spec,
            batch_size=batch_size,
            sequence_length=sequence_length)

        def make_tensor_spec(spec):
            return tf.TensorSpec(
                shape=(batch_size, sequence_length) + spec.shape,
                dtype=spec.dtype)

        expected_spec = tree.map_structure(make_tensor_spec, environment_spec)

        expected_spec = adders.Step(observation=expected_spec.observations,
                                    action=expected_spec.actions,
                                    reward=expected_spec.rewards,
                                    discount=expected_spec.discounts,
                                    start_of_episode=specs.Array(
                                        shape=(batch_size, sequence_length),
                                        dtype=bool),
                                    extras=())

        self.assertTrue(_check_specs(expected_spec, dataset.element_spec.data))
Example #29
def save_ilp(environment, filename):
    environment_spec = specs.make_environment_spec(environment)
    # Create actor
    meta_agent = msgi.MSGIRandom(environment_spec=environment_spec,
                                 num_adapt_steps=NUM_ADAPT_STEPS,
                                 visualize=False,
                                 directory=None,
                                 environment_id=ENV_NAME)
    agent = meta_agent.instantiate_adapt_agent()  # MSGIActor() (Fast agent)

    # Set task
    environment.reset_task(task_index=TASK_IDX)
    meta_agent.reset_agent(environment=environment)
    # Run RL loop
    ### Reset
    timestep = environment.reset()
    agent.observe_first(timestep)

    ### Loop
    cumulative_reward = 0.
    for step_count in range(NUM_ADAPT_STEPS):
        action = agent.select_action(timestep.observation)
        print(action)
        timestep = environment.step(action)

        agent.observe(action, next_timestep=timestep)

        # Book-keeping.
        cumulative_reward += sum(timestep.reward)  # TODO: needs check.

    meta_agent._ilp.save(filename)
Example #30
File: agents_test.py Project: srsohn/mtsgi
def test_msgi(env_id, graph_param, seed, num_envs):
    adapt_envs, test_envs = make_envs(env_id, graph_param, seed, num_envs)
    environment_spec = specs.make_environment_spec(adapt_envs)

    # Create meta agent.
    meta_agent = agents.MSGI(environment_spec=environment_spec,
                             num_adapt_steps=20,
                             num_trial_splits=5,
                             environment_id=env_id,
                             branch_neccessary_first=True,
                             exploration='random',
                             temp=200,
                             w_a=3.0,
                             beta_a=8.0,
                             ep_or=0.8,
                             temp_or=2.0)

    # Run meta loop.
    meta_loop = environment_loop.EnvironmentMetaLoop(
        adapt_environment=adapt_envs,
        test_environment=test_envs,
        meta_agent=meta_agent,
        label='meta_eval')

    meta_loop.run(
        num_trials=3,
        num_adapt_steps=20,  # should be greater than TimeLimit
        num_test_episodes=1,
        num_trial_splits=5)
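
All of the examples above follow the same pattern: construct an environment, call specs.make_environment_spec on it to obtain an EnvironmentSpec, build networks and an agent from that spec, and drive everything with an environment loop. Below is a minimal, self-contained sketch of just the spec step; it is not taken from any of the projects above and assumes only that dm-acme with its testing utilities (acme.testing.fakes) is installed. It prints the spec instead of training an agent.

import numpy as np

from acme import specs
from acme.testing import fakes

# Build a (fake) environment; any dm_env.Environment is handled the same way.
environment = fakes.DiscreteEnvironment(num_actions=5,
                                        num_observations=10,
                                        obs_shape=(10, 5),
                                        obs_dtype=np.float32,
                                        episode_length=10)

# Derive the spec. It bundles the observation, action, reward and discount
# specs that the agents and network factories above are constructed from.
spec = specs.make_environment_spec(environment)
print(spec.observations)
print(spec.actions)
print(spec.rewards)
print(spec.discounts)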