Code example #1
    def actor_evaluator(
        random_key: networks_lib.PRNGKey,
        variable_source: core.VariableSource,
        counter: counting.Counter,
    ):
        """The evaluation process."""
        # Create the actor loading the weights from variable source.
        actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
            evaluator_network)
        # Inference happens on CPU, so it's better to move variables there too.
        variable_client = variable_utils.VariableClient(variable_source,
                                                        'policy',
                                                        device='cpu')
        actor = actors.GenericActor(actor_core,
                                    random_key,
                                    variable_client,
                                    backend='cpu')

        # Logger.
        logger = loggers.make_default_logger('evaluator',
                                             steps_key='evaluator_steps')

        # Create the evaluation environment.
        environment = environment_factory(False)

        # Create the evaluator's counter, scoped under the parent counter.
        counter = counting.Counter(counter, 'evaluator')

        # Create the run loop and return it.
        return environment_loop.EnvironmentLoop(
            environment,
            actor,
            counter,
            logger,
        )
Code example #2
File: run_offline_td3_jax.py  Project: deepmind/acme
def main(_):
    key = jax.random.PRNGKey(FLAGS.seed)
    key_demonstrations, key_learner = jax.random.split(key, 2)

    # Create an environment and grab the spec.
    environment = gym_helpers.make_environment(task=FLAGS.env_name)
    environment_spec = specs.make_environment_spec(environment)

    # Get a demonstrations dataset with next_actions extra.
    transitions = tfds.get_tfds_dataset(FLAGS.dataset_name,
                                        FLAGS.num_demonstrations)
    double_transitions = rlds.transformations.batch(transitions,
                                                    size=2,
                                                    shift=1,
                                                    drop_remainder=True)
    transitions = double_transitions.map(_add_next_action_extras)
    demonstrations = tfds.JaxInMemoryRandomSampleIterator(
        transitions, key=key_demonstrations, batch_size=FLAGS.batch_size)

    # Create the networks to optimize.
    networks = td3.make_networks(environment_spec)

    # Create the learner.
    learner = td3.TD3Learner(
        networks=networks,
        random_key=key_learner,
        discount=FLAGS.discount,
        iterator=demonstrations,
        policy_optimizer=optax.adam(FLAGS.policy_learning_rate),
        critic_optimizer=optax.adam(FLAGS.critic_learning_rate),
        twin_critic_optimizer=optax.adam(FLAGS.critic_learning_rate),
        use_sarsa_target=FLAGS.use_sarsa_target,
        bc_alpha=FLAGS.bc_alpha,
        num_sgd_steps_per_step=1)

    def evaluator_network(params: hk.Params, key: jnp.DeviceArray,
                          observation: jnp.DeviceArray) -> jnp.DeviceArray:
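        # TD3's evaluation policy is deterministic, so the RNG key is unused here.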
        del key
        return networks.policy_network.apply(params, observation)

    actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
        evaluator_network)
    variable_client = variable_utils.VariableClient(learner,
                                                    'policy',
                                                    device='cpu')
    evaluator = actors.GenericActor(actor_core,
                                    key,
                                    variable_client,
                                    backend='cpu')

    eval_loop = acme.EnvironmentLoop(environment=environment,
                                     actor=evaluator,
                                     logger=loggers.TerminalLogger(
                                         'evaluation', time_delta=0.))

    # Run the environment loop.
    while True:
        for _ in range(FLAGS.evaluate_every):
            learner.step()
        eval_loop.run(FLAGS.evaluation_episodes)
Code example #3
def main(_):
    # Create an environment and grab the spec.
    environment = bc_utils.make_environment()
    environment_spec = specs.make_environment_spec(environment)

    # Unwrap the environment to get the demonstrations.
    dataset = bc_utils.make_demonstrations(environment.environment,
                                           FLAGS.batch_size)
    dataset = dataset.as_numpy_iterator()

    # Create the networks to optimize.
    network = bc_utils.make_network(environment_spec)

    key = jax.random.PRNGKey(FLAGS.seed)
    key, key1 = jax.random.split(key, 2)

    def logp_fn(logits, actions):
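        # Log-probability of the taken actions under the policy: pick out each
        # action's logit with a one-hot mask, then subtract logsumexp over the
        # logits (i.e. a log-softmax).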
        logits_actions = jnp.sum(jax.nn.one_hot(actions, logits.shape[-1]) *
                                 logits,
                                 axis=-1)
        logits_actions = logits_actions - special.logsumexp(logits, axis=-1)
        return logits_actions

    loss_fn = bc.logp(logp_fn=logp_fn)

    learner = bc.BCLearner(network=network,
                           random_key=key1,
                           loss_fn=loss_fn,
                           optimizer=optax.adam(FLAGS.learning_rate),
                           demonstrations=dataset,
                           num_sgd_steps_per_step=1)

    def evaluator_network(params: hk.Params, key: jnp.DeviceArray,
                          observation: jnp.DeviceArray) -> jnp.DeviceArray:
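        # Act epsilon-greedily with respect to the BC policy's action values at
        # evaluation time.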
        dist_params = network.apply(params, observation)
        return rlax.epsilon_greedy(FLAGS.evaluation_epsilon).sample(
            key, dist_params)

    actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
        evaluator_network)
    variable_client = variable_utils.VariableClient(learner,
                                                    'policy',
                                                    device='cpu')
    evaluator = actors.GenericActor(actor_core,
                                    key,
                                    variable_client,
                                    backend='cpu')

    eval_loop = acme.EnvironmentLoop(environment=environment,
                                     actor=evaluator,
                                     logger=loggers.TerminalLogger(
                                         'evaluation', time_delta=0.))

    # Run the environment loop.
    while True:
        for _ in range(FLAGS.evaluate_every):
            learner.step()
        eval_loop.run(FLAGS.evaluation_episodes)
Code example #4
File: run_cql_jax.py  Project: vishalbelsare/acme
def main(_):
    key = jax.random.PRNGKey(FLAGS.seed)
    key_demonstrations, key_learner = jax.random.split(key, 2)

    # Create an environment and grab the spec.
    environment = gym_helpers.make_environment(task=FLAGS.env_name)
    environment_spec = specs.make_environment_spec(environment)

    # Get a demonstrations dataset.
    transitions_iterator = tfds.get_tfds_dataset(FLAGS.dataset_name,
                                                 FLAGS.num_demonstrations)
    demonstrations = tfds.JaxInMemoryRandomSampleIterator(
        transitions_iterator,
        key=key_demonstrations,
        batch_size=FLAGS.batch_size)

    # Create the networks to optimize.
    networks = cql.make_networks(environment_spec)

    # Create the learner.
    learner = cql.CQLLearner(
        batch_size=FLAGS.batch_size,
        networks=networks,
        random_key=key_learner,
        policy_optimizer=optax.adam(FLAGS.policy_learning_rate),
        critic_optimizer=optax.adam(FLAGS.critic_learning_rate),
        fixed_cql_coefficient=FLAGS.fixed_cql_coefficient,
        cql_lagrange_threshold=FLAGS.cql_lagrange_threshold,
        demonstrations=demonstrations,
        num_sgd_steps_per_step=1)

    def evaluator_network(params: hk.Params, key: jnp.DeviceArray,
                          observation: jnp.DeviceArray) -> jnp.DeviceArray:
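        # Sample an evaluation action from the distribution produced by the
        # CQL policy network.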
        dist_params = networks.policy_network.apply(params, observation)
        return networks.sample_eval(dist_params, key)

    actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
        evaluator_network)
    variable_client = variable_utils.VariableClient(learner,
                                                    'policy',
                                                    device='cpu')
    evaluator = actors.GenericActor(actor_core,
                                    key,
                                    variable_client,
                                    backend='cpu')

    eval_loop = acme.EnvironmentLoop(environment=environment,
                                     actor=evaluator,
                                     logger=loggers.TerminalLogger(
                                         'evaluation', time_delta=0.))

    # Run the environment loop.
    while True:
        for _ in range(FLAGS.evaluate_every):
            learner.step()
        eval_loop.run(FLAGS.evaluation_episodes)
Code example #5
  def __init__(
      self,
      environment_spec: specs.EnvironmentSpec,
      network: networks_lib.FeedForwardNetwork,
      config: dqn_config.DQNConfig,
  ):
    """Initialize the agent."""
    # Data is communicated via reverb replay.
    reverb_replay = replay.make_reverb_prioritized_nstep_replay(
        environment_spec=environment_spec,
        n_step=config.n_step,
        batch_size=config.batch_size,
        max_replay_size=config.max_replay_size,
        min_replay_size=config.min_replay_size,
        priority_exponent=config.priority_exponent,
        discount=config.discount,
    )
    self._server = reverb_replay.server

    optimizer = optax.chain(
        optax.clip_by_global_norm(config.max_gradient_norm),
        optax.adam(config.learning_rate),
    )
    key_learner, key_actor = jax.random.split(jax.random.PRNGKey(config.seed))
    # The learner updates the parameters (and initializes them).
    loss_fn = losses.PrioritizedDoubleQLearning(
        discount=config.discount,
        importance_sampling_exponent=config.importance_sampling_exponent,
    )
    learner = learning_lib.SGDLearner(
        network=network,
        loss_fn=loss_fn,
        data_iterator=reverb_replay.data_iterator,
        optimizer=optimizer,
        target_update_period=config.target_update_period,
        random_key=key_learner,
        replay_client=reverb_replay.client,
    )

    # The actor selects actions according to the policy.
    # Epsilon must be a single float here, not a sequence/schedule of values.
    assert not isinstance(config.epsilon, Sequence)
    def policy(params: networks_lib.Params, key: jnp.ndarray,
               observation: jnp.ndarray) -> jnp.ndarray:
      action_values = network.apply(params, observation)
      return rlax.epsilon_greedy(config.epsilon).sample(key, action_values)
    actor_core = actor_core_lib.batched_feed_forward_to_actor_core(policy)
    variable_client = variable_utils.VariableClient(learner, '')
    actor = actors.GenericActor(
        actor_core, key_actor, variable_client, reverb_replay.adder)

    super().__init__(
        actor=actor,
        learner=learner,
        min_observations=max(config.batch_size, config.min_replay_size),
        observations_per_step=config.batch_size / config.samples_per_insert,
    )
Code example #6
def apply_policy_and_sample(networks, eval_mode=False):
  """Returns a function that computes actions."""
  sample_fn = networks.sample if not eval_mode else networks.sample_eval
  if not sample_fn:
    raise ValueError('sample function is not provided')

  def apply_and_sample(params, key, obs):
    return sample_fn(networks.policy_network.apply(params, obs), key)
  return actor_core.batched_feed_forward_to_actor_core(apply_and_sample)
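For reference, a minimal sketch of how the ActorCore returned above is typically consumed, assuming `networks`, `learner`, and `environment` objects like those used in the other examples on this page:

# Hypothetical wiring (the `networks`, `learner` and `environment` names are
# assumptions), following the GenericActor / VariableClient pattern shown in
# the make_actor examples below.
eval_actor_core = apply_policy_and_sample(networks, eval_mode=True)
variable_client = variable_utils.VariableClient(learner, 'policy', device='cpu')
evaluator = actors.GenericActor(
    eval_actor_core, jax.random.PRNGKey(0), variable_client, backend='cpu')
eval_loop = acme.EnvironmentLoop(environment, evaluator)
eval_loop.run(10)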
Code example #7
File: networks.py  Project: kokizzu/google-research
def apply_policy_and_sample_with_img_encoder(networks, eval_mode=False):
    """Returns a function that computes actions."""
    sample_fn = networks.sample if not eval_mode else networks.sample_eval
    if not sample_fn:
        raise ValueError('sample function is not provided')

    def apply_and_sample(params, key, obs):
        img = obs['state_image']
        img_embedding = networks.img_encoder.apply(params[1], img)
        x = dict(state_image=img_embedding, state_dense=obs['state_dense'])
        return sample_fn(networks.policy_network.apply(params[0], x), key)

    return actor_core.batched_feed_forward_to_actor_core(apply_and_sample)
Code example #8
File: builder.py  Project: vishalbelsare/acme
  def make_actor(
      self,
      random_key: networks_lib.PRNGKey,
      policy_network,
      adder: Optional[adders.Adder] = None,
      variable_source: Optional[core.VariableSource] = None) -> acme.Actor:
    assert variable_source is not None
    actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
        policy_network)
    variable_client = variable_utils.VariableClient(variable_source, 'policy',
                                                    device='cpu')
    return actors.GenericActor(
        actor_core, random_key, variable_client, adder, backend='cpu')
Code example #9
File: builder.py  Project: deepmind/acme
  def make_actor(
      self,
      random_key: networks_lib.PRNGKey,
      policy: actor_core_lib.FeedForwardPolicy,
      environment_spec: specs.EnvironmentSpec,
      variable_source: Optional[core.VariableSource] = None,
  ) -> core.Actor:
    del environment_spec
    assert variable_source is not None
    actor_core = actor_core_lib.batched_feed_forward_to_actor_core(policy)
    variable_client = variable_utils.VariableClient(
        variable_source, 'policy', device='cpu')
    return actors.GenericActor(
        actor_core, random_key, variable_client, backend='cpu')
Code example #10
File: networks.py  Project: kokizzu/google-research
def build_q_filtered_actor(
    networks,
    num_samples,
    with_uniform=True,
):
    def select_action(
        params,
        key,
        obs,
    ):
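        # Q-filtered action selection: sample `num_samples` candidate actions
        # from the policy (optionally mixed with uniform proposals), score each
        # candidate with the pessimistic (min over critics) Q-value, and return
        # the best-scoring action.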
        key, sub_key = jax.random.split(key)
        policy_params = params[0]
        q_params = params[1]

        dist = networks.policy_network.apply(policy_params, obs)
        acts = dist._sample_n(num_samples, sub_key)
        acts = acts[:, 0, :]  # N x act_dim

        if with_uniform:
            key, sub_key = jax.random.split(sub_key)
            unif_acts = jax.random.uniform(sub_key,
                                           acts.shape,
                                           dtype=acts.dtype,
                                           minval=-1.,
                                           maxval=1.)
            acts = jnp.concatenate([acts, unif_acts], axis=0)

        def obs_tile_fn(t):
            # t = jnp.expand_dims(t, axis=0)
            tile_shape = [1] * t.ndim
            # tile_shape[0] = num_samples
            tile_shape[0] = acts.shape[0]
            return jnp.tile(t, tile_shape)

        tiled_obs = jax.tree_map(obs_tile_fn, obs)

        # batch_size x num_critics
        all_q = networks.q_network.apply(q_params, tiled_obs, acts)
        # Pessimistic value per candidate action: min over critics, shape (batch_size,).
        q_score = jnp.min(all_q, axis=-1)
        best_idx = jnp.argmax(q_score)
        # return acts[best_idx], key
        return acts[best_idx][None, :]

    # return actor_core.ActorCore(
    #     init=lambda key: key,
    #     select_action=select_action,
    #     get_extras=lambda x: ())
    return actor_core.batched_feed_forward_to_actor_core(select_action)
Code example #11
def apply_policy_and_sample(networks, eval_mode=False, use_img_encoder=False):
    """Returns a function that computes actions."""
    sample_fn = networks.sample if not eval_mode else networks.sample_eval
    if not sample_fn:
        raise ValueError('sample function is not provided')

    def apply_and_sample(params, key, obs):
        if use_img_encoder:
            params, encoder_params = params[0], params[1]
            obs = {
                'state_image':
                networks.img_encoder.apply(encoder_params, obs['state_image']),
                'state_dense':
                obs['state_dense']
            }
        return sample_fn(networks.policy_network.apply(params, obs), key)

    return actor_core.batched_feed_forward_to_actor_core(apply_and_sample)
Code example #12
 def make_actor(
     self,
     random_key: networks_lib.PRNGKey,
     policy_network,
     adder: Optional[adders.Adder] = None,
     variable_source: Optional[core.VariableSource] = None,
 ) -> core.Actor:
     assert variable_source is not None
     actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
         policy_network)
     # Inference happens on CPU, so it's better to move variables there too.
     variable_client = variable_utils.VariableClient(variable_source,
                                                     'policy',
                                                     device='cpu')
     return actors.GenericActor(actor_core,
                                random_key,
                                variable_client,
                                adder,
                                backend='cpu')
Code example #13
File: builder.py  Project: kokizzu/google-research
 def make_actor(self,
                random_key,
                policy_network,
                adder=None,
                variable_source=None):
     assert variable_source is not None
     actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
         policy_network)
     variable_client = variable_utils.VariableClient(variable_source,
                                                     'policy',
                                                     device='cpu')
     if self._config.use_random_actor:
         ACTOR = contrastive_utils.InitiallyRandomActor  # pylint: disable=invalid-name
     else:
         ACTOR = actors.GenericActor  # pylint: disable=invalid-name
     return ACTOR(actor_core,
                  random_key,
                  variable_client,
                  adder,
                  backend='cpu')
Code example #14
File: actors_test.py  Project: vishalbelsare/acme
    def test_feedforward(self, has_extras):
        environment = _make_fake_env()
        env_spec = specs.make_environment_spec(environment)

        def policy(inputs: jnp.ndarray):
            action_values = hk.Sequential([
                hk.Flatten(),
                hk.Linear(env_spec.actions.num_values),
            ])(inputs)
            action = jnp.argmax(action_values, axis=-1)
            if has_extras:
                return action, (action_values, )
            else:
                return action

        policy = hk.transform(policy)

        rng = hk.PRNGSequence(1)
        dummy_obs = utils.add_batch_dim(utils.zeros_like(
            env_spec.observations))
        params = policy.init(next(rng), dummy_obs)

        variable_source = fakes.VariableSource(params)
        variable_client = variable_utils.VariableClient(
            variable_source, 'policy')

        if has_extras:
            actor_core = actor_core_lib.batched_feed_forward_with_extras_to_actor_core(
                policy.apply)
        else:
            actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
                policy.apply)
        actor = actors.GenericActor(actor_core,
                                    random_key=jax.random.PRNGKey(1),
                                    variable_client=variable_client)

        loop = environment_loop.EnvironmentLoop(environment, actor)
        loop.run(20)
Code example #15
def build_q_filtered_actor(
    networks,
    beta,
    num_samples,
    use_img_encoder=False,
    with_uniform=True,
    ensemble_method='deep_ensembles',
    ensemble_size=None,  # not used for deep ensembles
    mimo_using_obs_tile=False,
    mimo_using_act_tile=False,
):
    if ensemble_method not in [
            'deep_ensembles',
            'mimo',
    ]:
        raise NotImplementedError()

    def select_action(
        params,
        key,
        obs,
    ):
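        # Ensemble-filtered action selection: sample candidate actions from the
        # policy (optionally mixed with uniform proposals), score each candidate
        # with an ensemble estimate q_mean + beta * q_std of the pessimistic
        # Q-value, and return the best-scoring action.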
        key, sub_key = jax.random.split(key)
        policy_params = params[0]
        all_q_params = params[1]
        if use_img_encoder:
            img_encoder_params = params[2]
            obs = {
                'state_image':
                networks.img_encoder.apply(img_encoder_params,
                                           obs['state_image']),
                'state_dense':
                obs['state_dense']
            }

        dist = networks.policy_network.apply(policy_params, obs)
        acts = dist._sample_n(num_samples, sub_key)
        acts = acts[:, 0, :]  # N x act_dim

        if with_uniform:
            key, sub_key = jax.random.split(sub_key)
            unif_acts = jax.random.uniform(sub_key,
                                           acts.shape,
                                           dtype=acts.dtype,
                                           minval=-1.,
                                           maxval=1.)
            acts = jnp.concatenate([acts, unif_acts], axis=0)

        if ensemble_method == 'deep_ensembles':
            get_all_q_values = jax.pmap(jax.vmap(networks.q_network.apply,
                                                 in_axes=(0, None, None),
                                                 out_axes=0),
                                        in_axes=(0, None, None),
                                        out_axes=0)
        elif ensemble_method == 'mimo':
            get_all_q_values = jax.pmap(jax.vmap(networks.q_network.apply,
                                                 in_axes=(0, None, None),
                                                 out_axes=0),
                                        in_axes=(0, None, None),
                                        out_axes=0)
        else:
            raise NotImplementedError()

        def obs_tile_fn(t):
            # t = jnp.expand_dims(t, axis=0)
            tile_shape = [1] * t.ndim
            # tile_shape[0] = num_samples
            tile_shape[0] = acts.shape[0]
            return jnp.tile(t, tile_shape)

        tiled_obs = jax.tree_map(obs_tile_fn, obs)

        if ensemble_method == 'deep_ensembles':
            # num_devices x num_per_device x batch_size x 2(because of double-Q)
            all_q = get_all_q_values(all_q_params, tiled_obs, acts)
            # num_devices x num_per_device x batch_size
            all_q = jnp.min(all_q, axis=-1)

            q_mean = jnp.mean(all_q, axis=(0, 1))
            q_std = jnp.std(all_q, axis=(0, 1))
            q_score = q_mean + beta * q_std  # batch_size
            best_idx = jnp.argmax(q_score)
        elif ensemble_method == 'mimo':
            if mimo_using_obs_tile:
                # if using the version where we also tile the obs
                tile_shape = [1] * tiled_obs.ndim
                tile_shape[-1] = ensemble_size
                tiled_obs = jnp.tile(tiled_obs, tile_shape)

            if mimo_using_act_tile:
                # if using the version where we are tiling the acts
                tile_shape = [1] * acts.ndim
                tile_shape[-1] = ensemble_size
                tiled_acts = jnp.tile(acts, tile_shape)
            else:
                # otherwise
                tiled_acts = acts

            all_q = get_all_q_values(
                all_q_params, tiled_obs, tiled_acts
            )  # 1 x 1 x batch_size x ensemble_size x (num_qs_per_member)
            all_q = jnp.min(all_q,
                            axis=-1)  # 1 x 1 x batch_size x ensemble_size

            q_mean = jnp.mean(all_q, axis=(0, 1, 3))
            q_std = jnp.std(all_q, axis=(0, 1, 3))
            q_score = q_mean + beta * q_std  # batch_size
            best_idx = jnp.argmax(q_score)
        else:
            raise NotImplementedError()

        # return acts[best_idx], key
        return acts[best_idx][None, :]

    # return actor_core.ActorCore(
    #     init=lambda key: key,
    #     select_action=select_action,
    #     get_extras=lambda x: ())
    # return select_action
    return actor_core.batched_feed_forward_to_actor_core(select_action)