def __init__(self,
                 network: discrete_networks.DiscreteFilteredQNetwork,
                 dataset: tf.data.Dataset,
                 learning_rate: float,
                 counter: counting.Counter = None,
                 bc_logger: loggers.Logger = None,
                 bcq_logger: loggers.Logger = None,
                 **bcq_learner_kwargs):
        """Builds the two sub-learners wrapped by this learner.

        Args:
          network: filtered Q-network. Its `g_network` attribute is handed to
            the BC learner; the network itself is handed to the BCQ learner.
          dataset: dataset passed to both sub-learners.
          learning_rate: learning rate passed to both sub-learners.
          counter: parent counter; a new `counting.Counter` is created when
            none (or a falsy value) is given.
          bc_logger: logger for the BC learner; defaults to a terminal logger
            labelled 'bc_learner'.
          bcq_logger: logger for the BCQ learner; defaults to a terminal logger
            labelled 'bcq_learner'.
          **bcq_learner_kwargs: forwarded unchanged to `_InternalBCQLearner`.
        """
        # Resolve the optional arguments to concrete objects first.
        if not counter:
            counter = counting.Counter()
        if not bc_logger:
            bc_logger = loggers.TerminalLogger('bc_learner', time_delta=1.)
        self._bc_logger = bc_logger
        if not bcq_logger:
            bcq_logger = loggers.TerminalLogger('bcq_learner', time_delta=1.)
        self._bcq_logger = bcq_logger

        # Each sub-learner counts under its own prefix of the shared counter.
        bc_counter = counting.Counter(counter, 'bc')
        self._bc_learner = bc.BCLearner(
            network=network.g_network,
            learning_rate=learning_rate,
            dataset=dataset,
            counter=bc_counter,
            logger=self._bc_logger,
            checkpoint=False,
        )

        bcq_counter = counting.Counter(counter, 'bcq')
        self._bcq_learner = _InternalBCQLearner(
            network=network,
            learning_rate=learning_rate,
            dataset=dataset,
            counter=bcq_counter,
            logger=self._bcq_logger,
            **bcq_learner_kwargs,
        )
def run_dqn(experiment_name):
    """Train a DQN agent on the `Winter_is_coming` game and run an example.

    Logging goes to Neptune when the module-level `neptune_enabled` flag is
    set, and to the terminal otherwise. After training the agent is saved; if
    `neptune_upload_checkpoint` is set and `agent.save()` returned a path,
    every entry of that directory is logged as a Neptune artifact.

    Args:
        experiment_name: name given to the save paths and, when Neptune is
            enabled, to the Neptune experiment.
    """
    working_dir = pathlib.Path().absolute()
    directories = Save_paths(data_dir=f'{working_dir}/data',
                             experiment_name=experiment_name)

    # Environment and its spec.
    game = Winter_is_coming(setup=PARAMS['setup'])
    environment = wrappers.SinglePrecisionWrapper(game)
    spec = specs.make_environment_spec(environment)

    def _make_network(spec) -> snt.Module:
        """Flatten the observation and feed it through a 50-50-actions MLP."""
        layers = [
            snt.Flatten(),
            snt.nets.MLP([50, 50, spec.actions.num_values]),
        ]
        q_network = snt.Sequential(layers)
        tf2_utils.create_variables(q_network, [spec.observations])
        return q_network

    network = _make_network(spec)

    # Pick the logging backend.
    if not neptune_enabled:
        agent_logger = loggers.TerminalLogger('DQN agent', time_delta=1.)
        loop_logger = loggers.TerminalLogger('Environment loop',
                                             time_delta=1.)
    else:
        agent_logger = NeptuneLogger(label='DQN agent', time_delta=0.1)
        loop_logger = NeptuneLogger(label='Environment loop', time_delta=0.1)
        # The network description is recorded alongside the other parameters.
        PARAMS['network'] = f'{network}'
        neptune.init('cvasquez/sandbox')
        neptune.create_experiment(name=experiment_name, params=PARAMS)

    agent = DQN(environment_spec=spec,
                network=network,
                params=PARAMS,
                checkpoint=True,
                paths=directories,
                logger=agent_logger)

    # Train for the configured number of episodes.
    training_loop = acme.EnvironmentLoop(environment, agent,
                                         logger=loop_logger)
    training_loop.run(num_episodes=PARAMS['num_episodes'])

    last_checkpoint_path = agent.save()

    # Upload every entry of the last checkpoint directory.
    if neptune_upload_checkpoint and last_checkpoint_path:
        for entry in os.listdir(last_checkpoint_path):
            neptune.log_artifact(os.path.join(last_checkpoint_path, entry))

    if neptune_enabled:
        neptune.stop()

    do_example_run(game,agent)
Exemple #3
0
def _build_custom_loggers(wb_client):
  """Return the (learner, eval-loop) logger pair.

  Args:
    wb_client: client passed to `WBLogger`, or None to log to the terminal
      only.

  Returns:
    A pair of terminal loggers when `wb_client` is None; otherwise a pair of
    dispatchers, each fanning out to a terminal logger and a `WBLogger`.
  """
  learner_terminal = loggers.TerminalLogger(label='Learner', time_delta=10)
  eval_terminal = loggers.TerminalLogger(label='EvalLoop', time_delta=10)

  # Without a client there is nothing to dispatch to besides the terminal.
  if wb_client is None:
    return learner_terminal, eval_terminal

  learner_wb = WBLogger(wb_client, label='Learner')
  eval_wb = WBLogger(wb_client, label='EvalLoop')
  learner_dispatcher = loggers.Dispatcher([learner_terminal, learner_wb])
  eval_dispatcher = loggers.Dispatcher([eval_terminal, eval_wb])
  return learner_dispatcher, eval_dispatcher
Exemple #4
0
def main(_):
    """Train TD3 from a demonstrations dataset, evaluating periodically.

    Runs forever, alternating `FLAGS.evaluate_every` learner steps with
    `FLAGS.evaluation_episodes` evaluation episodes.
    """
    root_key = jax.random.PRNGKey(FLAGS.seed)
    key_demonstrations, key_learner = jax.random.split(root_key, 2)

    # Environment and its spec.
    environment = gym_helpers.make_environment(task=FLAGS.env_name)
    environment_spec = specs.make_environment_spec(environment)

    # Demonstrations: window the transitions in overlapping pairs (size 2,
    # shift 1) so `_add_next_action_extras` can be mapped over each pair.
    raw_transitions = tfds.get_tfds_dataset(FLAGS.dataset_name,
                                            FLAGS.num_demonstrations)
    paired_transitions = rlds.transformations.batch(
        raw_transitions, size=2, shift=1, drop_remainder=True)
    transitions = paired_transitions.map(_add_next_action_extras)
    demonstrations = tfds.JaxInMemoryRandomSampleIterator(
        transitions, key=key_demonstrations, batch_size=FLAGS.batch_size)

    # Networks and learner.
    networks = td3.make_networks(environment_spec)
    learner = td3.TD3Learner(
        networks=networks,
        random_key=key_learner,
        discount=FLAGS.discount,
        iterator=demonstrations,
        policy_optimizer=optax.adam(FLAGS.policy_learning_rate),
        critic_optimizer=optax.adam(FLAGS.critic_learning_rate),
        twin_critic_optimizer=optax.adam(FLAGS.critic_learning_rate),
        use_sarsa_target=FLAGS.use_sarsa_target,
        bc_alpha=FLAGS.bc_alpha,
        num_sgd_steps_per_step=1,
    )

    def evaluator_network(params: hk.Params, key: jnp.DeviceArray,
                          observation: jnp.DeviceArray) -> jnp.DeviceArray:
        # The evaluation policy ignores the random key.
        del key
        return networks.policy_network.apply(params, observation)

    # Evaluation actor that reads the 'policy' variables from the learner.
    actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
        evaluator_network)
    variable_client = variable_utils.VariableClient(
        learner, 'policy', device='cpu')
    evaluator = actors.GenericActor(
        actor_core, root_key, variable_client, backend='cpu')

    eval_logger = loggers.TerminalLogger('evaluation', time_delta=0.)
    eval_loop = acme.EnvironmentLoop(
        environment=environment, actor=evaluator, logger=eval_logger)

    # Alternate training and evaluation indefinitely.
    while True:
        for _ in range(FLAGS.evaluate_every):
            learner.step()
        eval_loop.run(FLAGS.evaluation_episodes)
Exemple #5
0
    def make_logger(
        self,
        to_terminal: bool,
        to_csv: bool,
        to_tensorboard: bool,
        time_delta: float,
        print_fn: Callable[[str], None],
        external_logger: Optional[base.Logger],
        **external_logger_kwargs: Any,
    ) -> loggers.Logger:
        """Build a Mava logger from the requested sinks.

        The label comes from `self._label` and output locations from
        `self._path(...)`.

        Args:
            to_terminal: add a terminal logger that prints via `print_fn`.
            to_csv: add a CSV logger writing to `self._path("csv")`.
            to_tensorboard: add a TF summary logger writing to
                `self._path("tensorboard")`.
            time_delta: passed to the `TimeFilter` wrapping the sinks.
            print_fn: function to call which acts like print.
            external_logger: optional callable; when truthy it is called with
                `label` and `external_logger_kwargs` to create an extra sink.
            external_logger_kwargs: keyword arguments for `external_logger`.

        Returns:
            A dispatcher over the selected sinks wrapped in `NoneFilter` and
            `TimeFilter`, or a `NoOpLogger` when no sink was selected.
        """
        sinks = []

        if to_terminal:
            sinks.append(
                loggers.TerminalLogger(label=self._label, print_fn=print_fn))

        if to_csv:
            sinks.append(
                loggers.CSVLogger(directory_or_file=self._path("csv"),
                                  label=self._label))

        if to_tensorboard:
            sinks.append(
                TFSummaryLogger(logdir=self._path("tensorboard"),
                                label=self._label))

        if external_logger:
            sinks.append(
                external_logger(label=self._label, **external_logger_kwargs))

        # Nothing selected: hand back a logger that does nothing.
        if not sinks:
            return loggers.NoOpLogger()

        dispatcher = loggers.Dispatcher(sinks)
        none_filtered = loggers.NoneFilter(dispatcher)
        return loggers.TimeFilter(none_filtered, time_delta)
Exemple #6
0
    def __init__(self,
                 network: snt.Module,
                 learning_rate: float,
                 dataset: tf.data.Dataset,
                 counter: counting.Counter = None,
                 logger: loggers.Logger = None):
        """Sets up the optimizer, data iterator and bookkeeping.

        Args:
          network: the network being optimized; it is also the object handed
            to the snapshotter.
          learning_rate: learning rate for the Adam optimizer.
          dataset: dataset to learn from.
          counter: counter to use; a new `counting.Counter` is created when
            none (or a falsy value) is given.
          logger: logger to use; defaults to a terminal logger labelled
            'learner'.
        """
        if not counter:
            counter = counting.Counter()
        self._counter = counter

        if not logger:
            logger = loggers.TerminalLogger('learner', time_delta=1.)
        self._logger = logger

        # TODO(b/155086959): Fix type stubs and remove the pytype suppression.
        self._iterator = iter(dataset)  # pytype: disable=wrong-arg-types

        self._network = network
        self._optimizer = snt.optimizers.Adam(learning_rate)

        # Variables are exposed as a single-element list of variable lists.
        self._variables: List[List[tf.Tensor]] = [network.trainable_variables]
        self._num_steps = tf.Variable(0, dtype=tf.int32)

        self._snapshotter = tf2_savers.Snapshotter(
            objects_to_save={'network': network},
            time_delta_minutes=60.,
        )
Exemple #7
0
def main(_):
    """Train discrete BCQ on an Atari dataset, evaluating periodically.

    Runs forever, alternating `FLAGS.evaluate_every` learner steps with
    `FLAGS.evaluation_episodes` evaluation episodes.
    """
    # Environment and its spec.
    environment = atari.environment(FLAGS.game)
    environment_spec = specs.make_environment_spec(environment)
    action_spec = environment_spec.actions

    # Dataset: keep only the first five data fields, then batch and prefetch.
    dataset = atari.dataset(
        path=FLAGS.dataset_path,
        game=FLAGS.game,
        run=FLAGS.run,
        num_shards=FLAGS.num_shards,
    )
    dataset = dataset.map(
        lambda sample: sample._replace(data=sample.data[:5]))
    dataset = dataset.batch(FLAGS.batch_size, drop_remainder=True)
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

    # Two networks built by the same factory, combined into the filtered
    # Q-network.
    g_network = make_network(action_spec)
    q_network = make_network(action_spec)
    network = networks.DiscreteFilteredQNetwork(
        g_network=g_network,
        q_network=q_network,
        threshold=FLAGS.bcq_threshold,
    )
    tf2_utils.create_variables(network, [environment_spec.observations])

    # Evaluation policy: epsilon-greedy over the Q-network's outputs.
    evaluator_network = snt.Sequential([
        q_network,
        lambda q_values: trfl.epsilon_greedy(
            q_values, epsilon=FLAGS.epsilon).sample(),
    ])

    # Shared counter plus a 'learner'-prefixed child.
    counter = counting.Counter()
    learner_counter = counting.Counter(counter, prefix='learner')

    # Actor and loop used for evaluation.
    evaluation_actor = actors.FeedForwardActor(evaluator_network)
    eval_logger = loggers.TerminalLogger('evaluation', time_delta=1.)
    eval_loop = acme.EnvironmentLoop(
        environment=environment,
        actor=evaluation_actor,
        counter=counter,
        logger=eval_logger,
    )

    # The learner updates the parameters (and initializes them).
    learner = bcq.DiscreteBCQLearner(
        network=network,
        dataset=dataset,
        learning_rate=FLAGS.learning_rate,
        discount=FLAGS.discount,
        importance_sampling_exponent=FLAGS.importance_sampling_exponent,
        target_update_period=FLAGS.target_update_period,
        counter=counter,
    )

    # Alternate training and evaluation indefinitely.
    while True:
        for _ in range(FLAGS.evaluate_every):
            learner.step()
        learner_counter.increment(learner_steps=FLAGS.evaluate_every)
        eval_loop.run(FLAGS.evaluation_episodes)
Exemple #8
0
def main(_):
    """Train a behavioural-cloning policy on bsuite demonstrations.

    Runs forever, alternating `FLAGS.evaluate_every` learner steps with
    `FLAGS.evaluation_episodes` evaluation episodes.
    """
    # Environment (wrapped in SinglePrecisionWrapper) and its spec.
    raw_environment = bsuite.load_from_id(FLAGS.bsuite_id)
    environment = single_precision.SinglePrecisionWrapper(raw_environment)
    environment_spec = specs.make_environment_spec(environment)

    # The demonstrations are built from the `raw_env` attribute when the
    # loaded environment has one, and from the loaded environment otherwise.
    raw_environment = getattr(raw_environment, 'raw_env', raw_environment)
    demonstrations = bsuite_demonstrations.make_dataset(raw_environment)

    # Map the transition builder (n_step=1, no additional discount) over the
    # demonstrations, then batch and prefetch.
    to_transition = functools.partial(_n_step_transition_from_episode,
                                      n_step=1,
                                      additional_discount=1.)
    dataset = demonstrations.map(to_transition)
    dataset = dataset.batch(FLAGS.batch_size, drop_remainder=True)
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

    # Network to optimize.
    policy_network = make_policy_network(environment_spec.actions)

    # Evaluation policy: epsilon-greedy over the policy network's outputs
    # (epsilon=0 gives a greedy policy).
    evaluator_network = snt.Sequential([
        policy_network,
        lambda values: trfl.epsilon_greedy(
            values, epsilon=FLAGS.epsilon).sample(),
    ])

    # Ensure that we create the variables before proceeding (maybe not needed).
    tf2_utils.create_variables(policy_network, [environment_spec.observations])

    # Shared counter plus a 'learner'-prefixed child.
    counter = counting.Counter()
    learner_counter = counting.Counter(counter, prefix='learner')

    # Actor and loop used for evaluation.
    evaluation_actor = actors_tf2.FeedForwardActor(evaluator_network)
    eval_logger = loggers.TerminalLogger('evaluation', time_delta=1.)
    eval_loop = acme.EnvironmentLoop(
        environment=environment,
        actor=evaluation_actor,
        counter=counter,
        logger=eval_logger,
    )

    # The learner updates the parameters (and initializes them).
    learner = learning.BCLearner(
        network=policy_network,
        learning_rate=FLAGS.learning_rate,
        dataset=dataset,
        counter=learner_counter,
    )

    # Alternate training and evaluation indefinitely.
    while True:
        for _ in range(FLAGS.evaluate_every):
            learner.step()
        learner_counter.increment(learner_steps=FLAGS.evaluate_every)
        eval_loop.run(FLAGS.evaluation_episodes)
Exemple #9
0
def main(_):
    """Train a JAX behavioural-cloning learner, evaluating periodically.

    Runs forever, alternating `FLAGS.evaluate_every` learner steps with
    `FLAGS.evaluation_episodes` evaluation episodes.
    """
    # Environment and its spec.
    environment = bc_utils.make_environment()
    environment_spec = specs.make_environment_spec(environment)

    # Demonstrations come from the wrapped `environment.environment`.
    demonstrations = bc_utils.make_demonstrations(environment.environment,
                                                  FLAGS.batch_size)
    demonstrations = demonstrations.as_numpy_iterator()

    # Network to optimize.
    network = bc_utils.make_network(environment_spec)

    seed_key = jax.random.PRNGKey(FLAGS.seed)
    eval_key, learner_key = jax.random.split(seed_key, 2)

    def logp_fn(logits, actions):
        """Return logits[action] - logsumexp(logits) along the last axis."""
        one_hot_actions = jax.nn.one_hot(actions, logits.shape[-1])
        chosen_logits = jnp.sum(one_hot_actions * logits, axis=-1)
        return chosen_logits - special.logsumexp(logits, axis=-1)

    loss_fn = bc.logp(logp_fn=logp_fn)

    learner = bc.BCLearner(
        network=network,
        random_key=learner_key,
        loss_fn=loss_fn,
        optimizer=optax.adam(FLAGS.learning_rate),
        demonstrations=demonstrations,
        num_sgd_steps_per_step=1,
    )

    def evaluator_network(params: hk.Params, key: jnp.DeviceArray,
                          observation: jnp.DeviceArray) -> jnp.DeviceArray:
        # Epsilon-greedy sample over the network outputs.
        outputs = network.apply(params, observation)
        sampler = rlax.epsilon_greedy(FLAGS.evaluation_epsilon)
        return sampler.sample(key, outputs)

    # Evaluation actor that reads the 'policy' variables from the learner.
    actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
        evaluator_network)
    variable_client = variable_utils.VariableClient(
        learner, 'policy', device='cpu')
    evaluator = actors.GenericActor(
        actor_core, eval_key, variable_client, backend='cpu')

    eval_logger = loggers.TerminalLogger('evaluation', time_delta=0.)
    eval_loop = acme.EnvironmentLoop(
        environment=environment, actor=evaluator, logger=eval_logger)

    # Alternate training and evaluation indefinitely.
    while True:
        for _ in range(FLAGS.evaluate_every):
            learner.step()
        eval_loop.run(FLAGS.evaluation_episodes)
Exemple #10
0
def main(_):
    """Train CQL from a demonstrations dataset, evaluating periodically.

    Runs forever, alternating `FLAGS.evaluate_every` learner steps with
    `FLAGS.evaluation_episodes` evaluation episodes.
    """
    root_key = jax.random.PRNGKey(FLAGS.seed)
    key_demonstrations, key_learner = jax.random.split(root_key, 2)

    # Environment and its spec.
    environment = gym_helpers.make_environment(task=FLAGS.env_name)
    environment_spec = specs.make_environment_spec(environment)

    # Demonstrations, sampled in batches of `FLAGS.batch_size`.
    transitions = tfds.get_tfds_dataset(FLAGS.dataset_name,
                                        FLAGS.num_demonstrations)
    demonstrations = tfds.JaxInMemoryRandomSampleIterator(
        transitions, key=key_demonstrations, batch_size=FLAGS.batch_size)

    # Networks and learner.
    networks = cql.make_networks(environment_spec)
    learner = cql.CQLLearner(
        batch_size=FLAGS.batch_size,
        networks=networks,
        random_key=key_learner,
        policy_optimizer=optax.adam(FLAGS.policy_learning_rate),
        critic_optimizer=optax.adam(FLAGS.critic_learning_rate),
        fixed_cql_coefficient=FLAGS.fixed_cql_coefficient,
        cql_lagrange_threshold=FLAGS.cql_lagrange_threshold,
        demonstrations=demonstrations,
        num_sgd_steps_per_step=1,
    )

    def evaluator_network(params: hk.Params, key: jnp.DeviceArray,
                          observation: jnp.DeviceArray) -> jnp.DeviceArray:
        # Apply the policy, then draw an action with the networks' eval sampler.
        policy_output = networks.policy_network.apply(params, observation)
        return networks.sample_eval(policy_output, key)

    # Evaluation actor that reads the 'policy' variables from the learner.
    actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
        evaluator_network)
    variable_client = variable_utils.VariableClient(
        learner, 'policy', device='cpu')
    evaluator = actors.GenericActor(
        actor_core, root_key, variable_client, backend='cpu')

    eval_logger = loggers.TerminalLogger('evaluation', time_delta=0.)
    eval_loop = acme.EnvironmentLoop(
        environment=environment, actor=evaluator, logger=eval_logger)

    # Alternate training and evaluation indefinitely.
    while True:
        for _ in range(FLAGS.evaluate_every):
            learner.step()
        eval_loop.run(FLAGS.evaluation_episodes)
Exemple #11
0
    def __init__(
        self,
        policy_network: snt.Module,
        critic_network: snt.Module,
        discount: float,
        dataset: tf.data.Dataset,
        critic_lr: float = 1e-4,
        checkpoint_interval_minutes: float = 10.0,
        clipping: bool = True,
        counter: counting.Counter = None,
        logger: loggers.Logger = None,
        checkpoint: bool = True,
        init_observations: Any = None,
    ):
        """Initializes the learner.

        Args:
            policy_network: stored on the instance; not otherwise used here.
            critic_network: the network optimized by the Adam optimizer
                created here; its variables are exposed under 'critic'.
            discount: stored on the instance.
            dataset: dataset to learn from; an iterator over it is created.
            critic_lr: learning rate for the critic's Adam optimizer.
            checkpoint_interval_minutes: passed to the checkpointer as
                `time_delta_minutes`.
            clipping: stored on the instance.
            counter: counter to use; a new `counting.Counter` is created when
                none (or a falsy value) is given.
            logger: logger to use; defaults to a terminal logger labelled
                'learner'.
            checkpoint: whether to create a checkpointer and a snapshotter.
            init_observations: stored on the instance.
        """

        self._policy_network = policy_network
        self._critic_network = critic_network
        self._discount = discount
        self._clipping = clipping
        self._init_observations = init_observations

        # General learner book-keeping and loggers.
        self._counter = counter or counting.Counter()
        self._logger = logger or loggers.TerminalLogger('learner',
                                                        time_delta=1.)

        self._num_steps = tf.Variable(0, dtype=tf.int32)

        # Create an iterator over the dataset.
        self._iterator = iter(dataset)

        self._critic_optimizer = snt.optimizers.Adam(critic_lr)

        # Expose the variables.
        self._variables = {
            'critic': self._critic_network.variables,
        }
        # We remove the trailing dimension to keep the same output dimension
        # as the existing FQE based on D4PG, i.e.: (batch_size,).
        critic_mean = snt.Sequential(
            [self._critic_network, lambda t: tf.squeeze(t, -1)])
        self._critic_mean = critic_mean

        # Create a checkpointer object.
        self._checkpointer = None
        self._snapshotter = None

        if checkpoint:
            # NOTE(review): `self.state` is defined outside this method; it
            # presumably reads attributes assigned above, so this must stay
            # after those assignments — confirm against the class.
            self._checkpointer = tf2_savers.Checkpointer(
                objects_to_save=self.state,
                time_delta_minutes=checkpoint_interval_minutes,
                checkpoint_ttl_seconds=_CHECKPOINT_TTL)
            self._snapshotter = tf2_savers.Snapshotter(objects_to_save={
                'critic': critic_mean,
            },
                                                       time_delta_minutes=60.)
Exemple #12
0
  def __init__(
      self,
      environment_spec: specs.EnvironmentSpec,
      network: snt.RNNCore,
      dataset: tf.data.Dataset,
      learning_rate: float,
      discount: float = 0.99,
      decay: float = 0.99,
      epsilon: float = 1e-5,
      entropy_cost: float = 0.,
      baseline_cost: float = 1.,
      max_abs_reward: Optional[float] = None,
      max_gradient_norm: Optional[float] = None,
      counter: counting.Counter = None,
      logger: loggers.Logger = None,
  ):
    """Stores the learner's components and hyperparameters.

    Args:
      environment_spec: stored on the instance.
      network: the network to train; its variables are exposed and it is the
        object handed to the snapshotter.
      dataset: stored as given; no iterator is created here.
      learning_rate: learning rate for the RMSProp optimizer.
      discount: stored on the instance.
      decay: decay for the RMSProp optimizer.
      epsilon: epsilon for the RMSProp optimizer.
      entropy_cost: stored on the instance.
      baseline_cost: stored on the instance.
      max_abs_reward: reward clipping limit; infinity when None.
      max_gradient_norm: gradient clipping limit; 1e10 when None.
      counter: counter to use; a new `counting.Counter` is created when none
        (or a falsy value) is given.
      logger: logger to use; defaults to a terminal logger labelled
        'learner'.
    """
    # Components.
    self._env_spec = environment_spec
    self._optimizer = snt.optimizers.RMSProp(
        learning_rate=learning_rate, decay=decay, epsilon=epsilon)
    self._network = network
    self._variables = network.variables
    self._dataset = dataset

    # Hyperparameters.
    self._discount = discount
    self._entropy_cost = entropy_cost
    self._baseline_cost = baseline_cost

    # Clipping limits. The gradient limit falls back to a very large finite
    # number because infinity results in NaNs.
    reward_limit = np.inf if max_abs_reward is None else max_abs_reward
    gradient_limit = 1e10 if max_gradient_norm is None else max_gradient_norm
    self._max_abs_reward = tf.convert_to_tensor(reward_limit)
    self._max_gradient_norm = tf.convert_to_tensor(gradient_limit)

    # Counting and logging.
    if not counter:
      counter = counting.Counter()
    self._counter = counter
    if not logger:
      logger = loggers.TerminalLogger('learner', time_delta=1.)
    self._logger = logger

    self._snapshotter = tf2_savers.Snapshotter(
        objects_to_save={'network': network},
        time_delta_minutes=60.,
    )

    # No timestamp is recorded until the first learning step is done, so the
    # time actors take to come online and fill the replay buffer is excluded.
    self._timestamp = None
Exemple #13
0
  def __init__(self,
               network: hk.Transformed,
               obs_spec: specs.Array,
               optimizer: optax.GradientTransformation,
               rng: hk.PRNGSequence,
               dataset: tf.data.Dataset,
               loss_fn: LossFn = _sparse_categorical_cross_entropy,
               counter: counting.Counter = None,
               logger: loggers.Logger = None):
    """Initializes the learner.

    Args:
      network: transformed network; `init` creates the parameters and `apply`
        is called on the first element of each sample's data.
      obs_spec: observation spec used to build a zero input for `init`.
      optimizer: gradient transformation used for the updates.
      rng: key sequence; one key is drawn for parameter initialisation.
      dataset: dataset to learn from; an iterator over it is created.
      loss_fn: called as `loss_fn(actions, logits)`; its mean is the loss.
      counter: counter to use; a new `counting.Counter` is created when none
        (or a falsy value) is given.
      logger: logger to use; defaults to a terminal logger labelled
        'learner'.
    """

    def loss(params: hk.Params, sample: reverb.ReplaySample) -> jnp.DeviceArray:
      # Only the first two of the five data fields are used.
      observations, actions, _, _, _ = sample.data
      return jnp.mean(loss_fn(actions, network.apply(params, observations)))

    def sgd_step(
        state: TrainingState, sample: reverb.ReplaySample
    ) -> Tuple[TrainingState, Dict[str, jnp.DeviceArray]]:
      """Apply one optimizer update; report loss and gradient norm."""
      loss_value, grads = jax.value_and_grad(loss)(state.params, sample)
      updates, opt_state = optimizer.update(grads, state.opt_state)
      next_state = TrainingState(
          params=optax.apply_updates(state.params, updates),
          opt_state=opt_state,
          steps=state.steps + 1)

      # The global gradient norm is reported for logging.
      fetches = {
          'loss': loss_value,
          'gradient_norm': optax.global_norm(grads),
      }
      return next_state, fetches

    if not counter:
      counter = counting.Counter()
    self._counter = counter
    if not logger:
      logger = loggers.TerminalLogger('learner', time_delta=1.)
    self._logger = logger

    # TODO(b/155086959): Fix type stubs and remove the pytype suppression.
    self._iterator = iter(dataset)  # pytype: disable=wrong-arg-types

    # Initialise parameters from a zero observation with a batch dimension,
    # then the optimiser state from those parameters.
    dummy_observation = utils.add_batch_dim(utils.zeros_like(obs_spec))
    params = network.init(next(rng), dummy_observation)
    opt_state = optimizer.init(params)
    self._state = TrainingState(params=params, opt_state=opt_state, steps=0)

    self._sgd_step = jax.jit(sgd_step)
Exemple #14
0
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        network: snt.RNNCore,
        queue: adder.Adder,
        counter: counting.Counter = None,
        logger: loggers.Logger = None,
        discount: float = 0.99,
        n_step_horizon: int = 16,
        learning_rate: float = 1e-3,
        entropy_cost: float = 0.01,
        baseline_cost: float = 0.5,
        max_abs_reward: Optional[float] = None,
        max_gradient_norm: Optional[float] = None,
        verbose_level: Optional[int] = 0,
    ):
        """Builds the A2C actor and learner and hands them to the base agent.

        Args:
            environment_spec: spec passed to the actor and learner; its
                observations are used to create the network variables.
            network: recurrent network shared by the actor and the learner.
            queue: passed to the actor as `queue` and to the learner as
                `dataset`.
            counter: forwarded to the learner.
            logger: forwarded to the learner; also used for `self._logger`,
                which defaults to a terminal logger labelled 'agent'.
            discount: forwarded to the learner.
            n_step_horizon: passed to the base class as
                `observations_per_step`.
            learning_rate: forwarded to the learner.
            entropy_cost: forwarded to the learner.
            baseline_cost: forwarded to the learner.
            max_abs_reward: forwarded to the learner.
            max_gradient_norm: forwarded to the learner.
            verbose_level: forwarded to the actor.
        """
        num_actions = environment_spec.actions.num_values
        self._logger = logger or loggers.TerminalLogger('agent')

        # NOTE(review): `extra_spec` is built and squeezed below but never
        # used afterwards in this method — confirm whether it can be removed.
        extra_spec = {
            'core_state': network.initial_state(1),
            'logits': tf.ones(shape=(1, num_actions), dtype=tf.float32)
        }
        # Remove batch dimensions.
        extra_spec = tf2_utils.squeeze_batch_dim(extra_spec)
        tf2_utils.create_variables(network, [environment_spec.observations])

        actor = acting.A2CActor(environment_spec=environment_spec,
                                verbose_level=verbose_level,
                                network=network,
                                queue=queue)
        # NOTE(review): the learner receives the raw `logger` argument (which
        # may be None), not the `self._logger` default created above.
        learner = learning.A2CLearner(
            environment_spec=environment_spec,
            network=network,
            dataset=queue,
            counter=counter,
            logger=logger,
            discount=discount,
            learning_rate=learning_rate,
            entropy_cost=entropy_cost,
            baseline_cost=baseline_cost,
            max_gradient_norm=max_gradient_norm,
            max_abs_reward=max_abs_reward,
        )

        super().__init__(actor=actor,
                         learner=learner,
                         min_observations=0,
                         observations_per_step=n_step_horizon)
Exemple #15
0
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        counter: counting.Counter = None,
        logger: loggers.Logger = None,
    ):
        """Stores the environment spec and sets up counting and logging.

        Args:
            environment_spec: stored on the instance.
            counter: counter to use; a new `counting.Counter` is created when
                none (or a falsy value) is given.
            logger: logger to use; defaults to a terminal logger labelled
                'learner'.
        """
        self._env_spec = environment_spec

        # Counting and logging.
        if not counter:
            counter = counting.Counter()
        self._counter = counter
        if not logger:
            logger = loggers.TerminalLogger('learner', time_delta=1.)
        self._logger = logger

        # No timestamp is recorded until the first learning step is done, so
        # the time actors take to come online and fill the replay buffer is
        # excluded.
        self._timestamp = None
Exemple #16
0
  def __init__(self,
               network: snt.Module,
               learning_rate: float,
               dataset: tf.data.Dataset,
               counter: counting.Counter = None,
               logger: loggers.Logger = None,
               checkpoint_subpath: str = '~/acme/'
               ):
    """Initializes the learner.

    Args:
      network: the network being optimized; also the object snapshotted.
      learning_rate: learning rate for the Adam optimizer.
      dataset: dataset to learn from.
      counter: Counter object for (potentially distributed) counting.
      logger: Logger object for writing logs to.
      checkpoint_subpath: directory passed to the checkpointer; checkpoints
        use the 'bc_learner' subdirectory.
    """

    self._counter = counter or counting.Counter()
    self._logger = logger or loggers.TerminalLogger('learner', time_delta=1.)

    # Get an iterator over the dataset.
    self._iterator = iter(dataset)  # pytype: disable=wrong-arg-types
    # TODO(b/155086959): Fix type stubs and remove.

    self._network = network
    self._optimizer = snt.optimizers.Adam(learning_rate)

    self._variables: List[List[tf.Tensor]] = [network.trainable_variables]

    # Create a checkpointer and snapshotter object.
    # NOTE(review): `self.state` is defined outside this method; it presumably
    # reads attributes assigned above, so this must stay after them — confirm.
    self._checkpointer = tf2_savers.Checkpointer(
      objects_to_save=self.state,
      time_delta_minutes=10.,
      directory=checkpoint_subpath,
      subdirectory='bc_learner'
    )

    self._snapshotter = tf2_savers.Snapshotter(
      objects_to_save={'network': network}, time_delta_minutes=60.)
Exemple #17
0
  def __init__(
      self,
      network: snt.Module,
      optimizer: snt.Optimizer,
      dataset: tf.data.Dataset,
      discount: float,
      logger: Optional[loggers.Logger] = None,
      counter: Optional[counting.Counter] = None,
  ):
    """Stores the learner's components.

    Args:
      network: the network to train; its trainable variables are exposed.
      optimizer: optimizer stored for the updates.
      dataset: dataset to learn from; an iterator over it is created.
      discount: discount, stored as a `np.float32`.
      logger: logger to use; defaults to a terminal logger labelled
        'learner'.
      counter: parent for the 'learner'-prefixed counter created here.
    """
    # Counting and logging.
    self._counter = counting.Counter(counter, 'learner')
    if not logger:
      logger = loggers.TerminalLogger('learner', time_delta=30.)
    self._logger = logger

    # TODO(b/155086959): Fix type stubs and remove the pytype suppression.
    self._iterator = iter(dataset)  # pytype: disable=wrong-arg-types

    # Training components.
    self._optimizer = optimizer
    self._network = network
    self._variables = network.trainable_variables
    self._discount = np.float32(discount)
Exemple #18
0
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        network: networks.PolicyValueRNN,
        initial_state_fn: Callable[[], networks.RNNState],
        sequence_length: int,
        sequence_period: int,
        counter: counting.Counter = None,
        logger: loggers.Logger = None,
        discount: float = 0.99,
        max_queue_size: int = 100000,
        batch_size: int = 16,
        learning_rate: float = 1e-3,
        entropy_cost: float = 0.01,
        baseline_cost: float = 0.5,
        seed: int = 0,
        max_abs_reward: float = np.inf,
        max_gradient_norm: float = np.inf,
    ):
        """Wires up a reverb queue, the IMPALA learner and the IMPALA actor.

        Args:
            environment_spec: spec of the environment; its action count and
                observation spec are used here.
            network: network shared by the learner and the actor.
            initial_state_fn: function returning the initial recurrent state.
            sequence_length: length of the sequences added to and sampled
                from the queue.
            sequence_period: period passed to the sequence adder.
            counter: forwarded to the learner.
            logger: forwarded to the learner; also used for `self._logger`,
                which defaults to a terminal logger labelled 'agent'.
            discount: forwarded to the learner.
            max_queue_size: maximum size of the reverb queue.
            batch_size: batch size of the dataset and of the `can_sample`
                check.
            learning_rate: learning rate for the Adam step of the optimizer.
            entropy_cost: forwarded to the learner.
            baseline_cost: forwarded to the learner.
            seed: seed for the `hk.PRNGSequence` shared by learner and actor.
            max_abs_reward: forwarded to the learner.
            max_gradient_norm: global-norm clipping threshold of the
                optimizer.
        """

        num_actions = environment_spec.actions.num_values
        self._logger = logger or loggers.TerminalLogger('agent')
        # In-process reverb server holding a single queue table.
        queue = reverb.Table.queue(name=adders.DEFAULT_PRIORITY_TABLE,
                                   max_size=max_queue_size)
        self._server = reverb.Server([queue], port=None)
        self._can_sample = lambda: queue.can_sample(batch_size)
        address = f'localhost:{self._server.port}'

        # Component to add things into replay.
        adder = adders.SequenceAdder(
            client=reverb.Client(address),
            period=sequence_period,
            sequence_length=sequence_length,
        )

        # The dataset object to learn from.
        extra_spec = {
            'core_state': hk.transform(initial_state_fn).apply(None),
            'logits': np.ones(shape=(num_actions, ), dtype=np.float32)
        }
        # Build the dataset that reads sequences from the reverb server.
        dataset = datasets.make_reverb_dataset(
            client=reverb.TFClient(address),
            environment_spec=environment_spec,
            batch_size=batch_size,
            extra_spec=extra_spec,
            sequence_length=sequence_length)

        rng = hk.PRNGSequence(seed)

        # Clip gradients by global norm, then apply Adam.
        optimizer = optix.chain(
            optix.clip_by_global_norm(max_gradient_norm),
            optix.adam(learning_rate),
        )
        # NOTE(review): the learner receives the raw `logger` argument (which
        # may be None), not the `self._logger` default created above.
        self._learner = learning.IMPALALearner(
            obs_spec=environment_spec.observations,
            network=network,
            initial_state_fn=initial_state_fn,
            iterator=dataset.as_numpy_iterator(),
            rng=rng,
            counter=counter,
            logger=logger,
            optimizer=optimizer,
            discount=discount,
            entropy_cost=entropy_cost,
            baseline_cost=baseline_cost,
            max_abs_reward=max_abs_reward,
        )

        # The actor pulls the 'policy' variables from the learner.
        variable_client = jax_variable_utils.VariableClient(self._learner,
                                                            key='policy')
        self._actor = acting.IMPALAActor(
            network=network,
            initial_state_fn=initial_state_fn,
            rng=rng,
            adder=adder,
            variable_client=variable_client,
        )
from acme.utils import loggers
from acme.wrappers import gym_wrapper

from agents.dqn_agent import DQNAgent
from networks.models import Models

from tensorflow.python.client import device_lib

# Print the local devices reported by TensorFlow's device_lib when this
# script is loaded.
print(device_lib.list_local_devices())


def render(env):
    """Render the environment wrapped by `env` in 'rgb_array' mode.

    Args:
      env: wrapper object exposing the underlying environment as
        `env.environment`.

    Returns:
      Whatever the underlying environment's `render(mode='rgb_array')` returns.
    """
    wrapped = env.environment
    frame = wrapped.render(mode='rgb_array')
    return frame


# Build the LunarLander-v2 gym environment, wrap it with Acme's GymWrapper and
# SinglePrecisionWrapper, and derive the environment spec from the result.
environment = gym_wrapper.GymWrapper(gym.make('LunarLander-v2'))
environment = wrappers.SinglePrecisionWrapper(environment)
environment_spec = specs.make_environment_spec(environment)

# Network handed to the DQN agent: its input shape and number of outputs are
# taken from the observation shape and the number of discrete actions.
model = Models.sequential_model(
    input_shape=environment_spec.observations.shape,
    num_outputs=environment_spec.actions.num_values,
    hidden_layers=3,
    layer_size=300)

agent = DQNAgent(environment_spec=environment_spec, network=model)

# Terminal logger for the environment loop. It must be passed to the loop
# explicitly; constructing it alone has no effect on what the loop logs.
logger = loggers.TerminalLogger(time_delta=10.)
loop = acme.EnvironmentLoop(environment=environment,
                            actor=agent,
                            logger=logger)
# No episode/step limit is given to run().
loop.run()
Exemple #20
0
# NOTE: `num_dimensions`, `environment_spec`, `observation_network` and
# `environment` are not assigned in this snippet; they must already be defined
# by the time these statements run.

# Create the policy network:
# LayerNormMLP -> NearZeroInitializedLinear(num_dimensions) -> TanhToSpec
# applied with the environment's action spec.
policy_network = snt.Sequential([
    networks.LayerNormMLP((256, 256, 256), activate_final=True),
    networks.NearZeroInitializedLinear(num_dimensions),
    networks.TanhToSpec(environment_spec.actions),
])

# Create the distributional critic network.
critic_network = snt.Sequential([
    # The multiplexer concatenates the observations/actions.
    networks.CriticMultiplexer(),
    networks.LayerNormMLP((512, 512, 256), activate_final=True),
    # Discrete-valued head configured with 51 atoms, vmin=-150 and vmax=150.
    networks.DiscreteValuedHead(vmin=-150., vmax=150., num_atoms=51),
])

# Create separate terminal loggers for the agent and the environment loop.
agent_logger = loggers.TerminalLogger(label='agent', time_delta=10.)
env_loop_logger = loggers.TerminalLogger(label='env_loop', time_delta=10.)

# Create the D4PG agent (checkpointing is switched off).
agent = d4pg.D4PG(environment_spec=environment_spec,
                  policy_network=policy_network,
                  critic_network=critic_network,
                  observation_network=observation_network,
                  sigma=1.0,
                  logger=agent_logger,
                  checkpoint=False)

# Create a loop connecting this agent to the environment. The loop is only
# constructed here; it is not run in this snippet.
env_loop = environment_loop.EnvironmentLoop(environment,
                                            agent,
                                            logger=env_loop_logger)
Exemple #21
0
    def __init__(self,
                 value_func: snt.Module,
                 instrumental_feature: snt.Module,
                 policy_net: snt.Module,
                 discount: float,
                 value_learning_rate: float,
                 instrumental_learning_rate: float,
                 value_reg: float,
                 instrumental_reg: float,
                 stage1_reg: float,
                 stage2_reg: float,
                 instrumental_iter: int,
                 value_iter: int,
                 dataset: tf.data.Dataset,
                 d_tm1_weight: float = 1.0,
                 counter: counting.Counter = None,
                 logger: loggers.Logger = None,
                 checkpoint: bool = True,
                 checkpoint_interval_minutes: float = 10.0):
        """Initializes the learner.

        Stores the networks and hyperparameters, creates one Adam optimizer
        per network, allocates the stage-1 weight matrix and, when requested,
        a checkpointer and a snapshotter.

        Args:
          value_func: value function network. Its `_feature` attribute is
            stored as `self.value_feature`, and it must provide
            `feature_dim()`.
          instrumental_feature: dual function network; must provide
            `feature_dim()`.
          policy_net: policy network.
          discount: global discount.
          value_learning_rate: learning rate of the Adam optimizer created for
            the value function.
          instrumental_learning_rate: learning rate of the Adam optimizer
            created for the instrumental feature network.
          value_reg: L2 regularizer for value net.
          instrumental_reg: L2 regularizer for instrumental net.
          stage1_reg: ridge regularizer for stage 1 regression
          stage2_reg: ridge regularizer for stage 2 regression
          instrumental_iter: number of iteration for instrumental net
          value_iter: number of iteration for value function,
          dataset: dataset to learn from.
          d_tm1_weight: weights for terminal state transitions. Ignored in this
            variant (the argument is deleted without being read).
          counter: Counter object for (potentially distributed) counting. A
            fresh `counting.Counter` is used when omitted.
          logger: Logger object for writing logs to. A `TerminalLogger`
            labelled 'learner' is used when omitted.
          checkpoint: boolean indicating whether to checkpoint the learner.
          checkpoint_interval_minutes: checkpoint interval in minutes.
        """

        self._counter = counter or counting.Counter()
        self._logger = logger or loggers.TerminalLogger('learner',
                                                        time_delta=1.)

        # Keep the hyperparameters on the instance.
        self.stage1_reg = stage1_reg
        self.stage2_reg = stage2_reg
        self.instrumental_iter = instrumental_iter
        self.value_iter = value_iter
        self.discount = discount
        self.value_reg = value_reg
        self.instrumental_reg = instrumental_reg
        # Accepted in the signature but deliberately unused by this learner.
        del d_tm1_weight

        # Get an iterator over the dataset.
        self._iterator = iter(dataset)  # pytype: disable=wrong-arg-types

        self.value_func = value_func
        # Reaches into the private `_feature` attribute of the value function.
        self.value_feature = value_func._feature
        self.instrumental_feature = instrumental_feature
        self.policy = policy_net
        # Both optimizers are Adam with beta1=0.5 and beta2=0.9.
        self._value_func_optimizer = snt.optimizers.Adam(value_learning_rate,
                                                         beta1=0.5,
                                                         beta2=0.9)
        self._instrumental_func_optimizer = snt.optimizers.Adam(
            instrumental_learning_rate, beta1=0.5, beta2=0.9)

        # Define additional variables.
        # Zero-initialised float32 matrix of shape
        # (instrumental feature_dim, value feature_dim).
        self.stage1_weight = tf.Variable(
            tf.zeros(
                (instrumental_feature.feature_dim(), value_func.feature_dim()),
                dtype=tf.float32))
        self._num_steps = tf.Variable(0, dtype=tf.int32)

        # Note the mixed structure: the first two entries are the networks'
        # `trainable_variables` collections, the last is a single variable.
        self._variables = [
            self.value_func.trainable_variables,
            self.instrumental_feature.trainable_variables,
            self.stage1_weight,
        ]

        # Create a checkpointer object.
        # Both stay None when `checkpoint` is False.
        self._checkpointer = None
        self._snapshotter = None

        if checkpoint:
            # `self.state` and `_CHECKPOINT_TTL` are defined outside this
            # method; `self.state` is read only after the attributes above
            # have been assigned.
            self._checkpointer = tf2_savers.Checkpointer(
                objects_to_save=self.state,
                time_delta_minutes=checkpoint_interval_minutes,
                checkpoint_ttl_seconds=_CHECKPOINT_TTL)
            # The snapshotter interval is fixed at 60 minutes and does not
            # follow `checkpoint_interval_minutes`.
            self._snapshotter = tf2_savers.Snapshotter(objects_to_save={
                'value_func':
                self.value_func,
                'instrumental_feature':
                self.instrumental_feature,
            },
                                                       time_delta_minutes=60.)
Exemple #22
0
def main(_):
  """Trains a DFIV learner on offline data with periodic OPE evaluation.

  The learner is stepped until its `num_steps` state reaches
  `FLAGS.max_steps`; every `FLAGS.evaluate_every` steps the evaluation
  metrics are written to the 'eval' terminal logger.
  """
  config = FLAGS.problem_config
  task_name = config['task_name']

  # Arguments shared by the dataset loader and the policy loader.
  source_kwargs = dict(
      task_name=task_name,
      noise_level=config['noise_level'],
      near_policy_dataset=config['near_policy_dataset'],
      dataset_path=FLAGS.dataset_path)

  # Offline dataset, held-out dataset (may be None) and environment.
  dataset, dev_dataset, environment = utils.load_data_and_env(
      batch_size=FLAGS.batch_size,
      max_dev_size=FLAGS.max_dev_size,
      **source_kwargs)
  environment_spec = specs.make_environment_spec(environment)

  # Networks that the learner optimizes.
  value_func, instrumental_feature = dfiv.make_ope_networks(
      task_name,
      environment_spec,
      value_layer_sizes=FLAGS.value_layer_sizes,
      instrumental_layer_sizes=FLAGS.instrumental_layer_sizes)

  # Pretrained target policy network.
  target_policy_net = utils.load_policy_net(
      environment_spec=environment_spec, **source_kwargs)

  # The learner and the evaluation share one parent counter.
  counter = counting.Counter()
  learner_counter = counting.Counter(counter, prefix='learner')
  logger = loggers.TerminalLogger('learner')

  # FLAGS.learner2 switches to the alternative learner implementation.
  learner_cls = dfiv.DFIV2Learner if FLAGS.learner2 else dfiv.DFIVLearner
  discount = config['discount']
  learner = learner_cls(
      value_func=value_func,
      instrumental_feature=instrumental_feature,
      policy_net=target_policy_net,
      discount=discount,
      value_learning_rate=FLAGS.value_learning_rate,
      instrumental_learning_rate=FLAGS.instrumental_learning_rate,
      stage1_reg=FLAGS.stage1_reg,
      stage2_reg=FLAGS.stage2_reg,
      value_reg=FLAGS.value_reg,
      instrumental_reg=FLAGS.instrumental_reg,
      instrumental_iter=FLAGS.instrumental_iter,
      value_iter=FLAGS.value_iter,
      dataset=dataset,
      d_tm1_weight=FLAGS.d_tm1_weight,
      counter=learner_counter,
      logger=logger)

  eval_counter = counting.Counter(counter, 'eval')
  eval_logger = loggers.TerminalLogger('eval')

  # At least one learner step is always taken before the stop check.
  finished = False
  while not finished:
    learner.step()
    steps = learner.state['num_steps'].numpy()

    if steps % FLAGS.evaluate_every == 0:
      metrics = {}
      if dev_dataset is not None:
        metrics['dev_mse'] = learner.cal_validation_err(dev_dataset)
      ope_metrics = utils.ope_evaluation(
          value_func=value_func,
          policy_net=target_policy_net,
          environment=environment,
          num_init_samples=FLAGS.evaluate_init_samples,
          discount=discount,
          counter=eval_counter)
      metrics.update(ope_metrics)
      eval_logger.write(metrics)

    finished = steps >= FLAGS.max_steps
Exemple #23
0
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        forward_fn: networks.PolicyValueRNN,
        unroll_fn: networks.PolicyValueRNN,
        initial_state_fn: Callable[[], hk.LSTMState],
        sequence_length: int,
        sequence_period: int,
        counter: counting.Counter = None,
        logger: loggers.Logger = None,
        discount: float = 0.99,
        max_queue_size: int = 100000,
        batch_size: int = 16,
        learning_rate: float = 1e-3,
        entropy_cost: float = 0.01,
        baseline_cost: float = 0.5,
        seed: int = 0,
        max_abs_reward: float = np.inf,
        max_gradient_norm: float = np.inf,
    ):
        """Builds the reverb online queue, the learner and the actor.

        The queue supplies the server, the `can_sample` callable, the data
        iterator consumed by the learner and the adder used by the actor.
        `seed` is split into one random key for the learner and one for the
        actor.
        """
        action_count = environment_spec.actions.num_values
        self._logger = logger or loggers.TerminalLogger('agent')

        # Extra fields handed to the queue as `extra_spec`: the value returned
        # by the transformed `initial_state_fn` and a float32 vector of ones
        # with one entry per action.
        state_fn = hk.without_apply_rng(hk.transform(initial_state_fn))
        initial_core_state = state_fn.apply(None)
        logits_example = np.ones(shape=(action_count, ), dtype=np.float32)
        extras = {
            'core_state': initial_core_state,
            'logits': logits_example,
        }

        # Data is handled by the reverb replay queue.
        queue = replay.make_reverb_online_queue(
            environment_spec=environment_spec,
            extra_spec=extras,
            max_queue_size=max_queue_size,
            sequence_length=sequence_length,
            sequence_period=sequence_period,
            batch_size=batch_size,
        )
        self._server = queue.server
        self._can_sample = queue.can_sample

        # Optimizer: global-norm gradient clipping followed by Adam.
        clip_transform = optax.clip_by_global_norm(max_gradient_norm)
        adam_transform = optax.adam(learning_rate)
        optimizer = optax.chain(clip_transform, adam_transform)

        learner_key, actor_key = jax.random.split(jax.random.PRNGKey(seed))

        # The learner receives the caller's `logger` as given (possibly None),
        # not the defaulted `self._logger`.
        self._learner = learning.IMPALALearner(
            obs_spec=environment_spec.observations,
            unroll_fn=unroll_fn,
            initial_state_fn=initial_state_fn,
            iterator=queue.data_iterator,
            random_key=learner_key,
            counter=counter,
            logger=logger,
            optimizer=optimizer,
            discount=discount,
            entropy_cost=entropy_cost,
            baseline_cost=baseline_cost,
            max_abs_reward=max_abs_reward,
        )

        # The actor fetches the 'policy' variables from the learner and runs
        # the jitted forward function on the CPU backend.
        policy_client = variable_utils.VariableClient(self._learner,
                                                      key='policy')
        forward = hk.without_apply_rng(hk.transform(forward_fn))
        jitted_forward = jax.jit(forward.apply, backend='cpu')
        self._actor = acting.IMPALAActor(
            forward_fn=jitted_forward,
            initial_state_fn=initial_state_fn,
            rng=hk.PRNGSequence(actor_key),
            adder=queue.adder,
            variable_client=policy_client,
        )
Exemple #24
0
def main(_):
    """Fits a KIV learner on a fixed offline dataset and logs OPE results.

    The dataset is split by batch count into a stage-1 half and a stage-2
    remainder. Each learner step is followed by an OPE evaluation, and the
    loop stops once the learner's `num_steps` reaches `FLAGS.max_steps`.
    """
    config = FLAGS.problem_config
    task_name = config['task_name']

    # Arguments shared by the dataset loader and the policy loader.
    source_kwargs = dict(
        task_name=task_name,
        noise_level=config['noise_level'],
        near_policy_dataset=config['near_policy_dataset'],
        dataset_path=FLAGS.dataset_path)

    # Load the offline dataset and environment, without shuffling or
    # repeating.
    dataset, dev_dataset, environment = utils.load_data_and_env(
        batch_size=FLAGS.batch_size,
        max_dev_size=FLAGS.max_dev_size,
        shuffle=False,
        repeat=False,
        **source_kwargs)
    environment_spec = specs.make_environment_spec(environment)

    # A disabled alternative used a hard-coded per-task table
    # (bsuite_catch: 0.25, bsuite_mountain_car: 0.5, bsuite_cartpole: 0.44),
    # overridable through FLAGS.gamma. Gamma now comes from utils.get_median.
    gamma = utils.get_median(task_name, environment_spec, dataset)

    # Create the networks to optimize.
    value_func, instrumental_feature = kiv_batch.make_ope_networks(
        task_name,
        environment_spec,
        n_component=FLAGS.n_component,
        gamma=gamma)

    # Load pretrained target policy network.
    target_policy_net = utils.load_policy_net(
        environment_spec=environment_spec, **source_kwargs)

    counter = counting.Counter()
    learner_counter = counting.Counter(counter, prefix='learner')
    logger = loggers.TerminalLogger('learner')

    # Count the batches by exhausting the dataset once, then give the first
    # half (rounded down) to stage 1 and the rest to stage 2.
    num_batches = sum(1 for _ in dataset)
    stage1_batch = num_batches // 2
    stage2_batch = num_batches - stage1_batch

    discount = config['discount']
    learner = kiv_batch.KIVLearner(
        value_func=value_func,
        instrumental_feature=instrumental_feature,
        policy_net=target_policy_net,
        discount=discount,
        stage1_reg=FLAGS.stage1_reg,
        stage2_reg=FLAGS.stage2_reg,
        stage1_batch=stage1_batch,
        stage2_batch=stage2_batch,
        dataset=dataset,
        valid_dataset=dev_dataset,
        counter=learner_counter,
        logger=logger,
        checkpoint=False)

    eval_counter = counting.Counter(counter, 'eval')
    eval_logger = loggers.TerminalLogger('eval')

    while True:
        # Run settings come first; learner and evaluation results are merged
        # on top so everything ends up in one record for analysis.
        results = {
            'gamma': gamma,
            'stage1_batch': stage1_batch,
            'stage2_batch': stage2_batch,
        }
        step_results = learner.step()
        results.update(step_results)
        ope_results = utils.ope_evaluation(
            value_func=value_func,
            policy_net=target_policy_net,
            environment=environment,
            num_init_samples=FLAGS.evaluate_init_samples,
            discount=discount,
            counter=eval_counter)
        results.update(ope_results)
        eval_logger.write(results)

        if learner.state['num_steps'] >= FLAGS.max_steps:
            break
Exemple #25
0
    def __init__(self,
                 value_func: snt.Module,
                 instrumental_feature: snt.Module,
                 policy_net: snt.Module,
                 discount: float,
                 value_learning_rate: float,
                 instrumental_learning_rate: float,
                 value_l2_reg: float,
                 instrumental_l2_reg: float,
                 stage1_reg: float,
                 stage2_reg: float,
                 instrumental_iter: int,
                 value_iter: int,
                 dataset: tf.data.Dataset,
                 counter: counting.Counter = None,
                 logger: loggers.Logger = None,
                 checkpoint: bool = True):
        """Sets up networks, optimizers and bookkeeping for the learner.

        Args:
          value_func: value function network; its `_feature` attribute is
            stored as `self.value_feature`.
          instrumental_feature: dual function network.
          policy_net: policy network.
          discount: global discount.
          value_learning_rate: learning rate of the value function's Adam
            optimizer.
          instrumental_learning_rate: learning rate of the instrumental
            network's Adam optimizer.
          value_l2_reg: l2 reg for value feature.
          instrumental_l2_reg: l2 reg for instrumental; stored as
            `self.instrumental_reg`.
          stage1_reg: ridge regularizer for stage 1 regression.
          stage2_reg: ridge regularizer for stage 2 regression.
          instrumental_iter: number of iterations for the instrumental net.
          value_iter: number of iterations for the value function.
          dataset: dataset to learn from.
          counter: Counter object for (potentially distributed) counting.
          logger: Logger object for writing logs to.
          checkpoint: whether to create a snapshotter for the two networks.
        """
        # Fall back to default counting/logging objects when none are given.
        default_logger_args = ('learner', )
        self._counter = counter or counting.Counter()
        self._logger = logger or loggers.TerminalLogger(*default_logger_args,
                                                        time_delta=1.)

        # Hyperparameters.
        self.stage1_reg, self.stage2_reg = stage1_reg, stage2_reg
        self.instrumental_iter, self.value_iter = instrumental_iter, value_iter
        self.discount = discount
        self.value_l2_reg, self.instrumental_reg = (value_l2_reg,
                                                    instrumental_l2_reg)

        # Iterator over the training data.
        self._iterator = iter(dataset)  # pytype: disable=wrong-arg-types

        # Networks, plus the value function's private feature extractor.
        self.value_func = value_func
        self.value_feature = value_func._feature
        self.instrumental_feature = instrumental_feature
        self.policy = policy_net

        # One Adam optimizer per trained network.
        self._value_func_optimizer = snt.optimizers.Adam(value_learning_rate)
        self._instrumental_func_optimizer = snt.optimizers.Adam(
            instrumental_learning_rate)

        # Trainable variables of both networks, then the step counter.
        self._variables = [
            net.trainable_variables
            for net in (value_func, instrumental_feature)
        ]
        self._num_steps = tf.Variable(0, dtype=tf.int32)

        self.data = None

        # Snapshot both networks every 60 minutes, or not at all.
        self._snapshotter = None
        if checkpoint:
            snapshot_targets = {
                'value_func': value_func,
                'instrumental_feature': instrumental_feature,
            }
            self._snapshotter = tf2_savers.Snapshotter(
                objects_to_save=snapshot_targets, time_delta_minutes=60.)
Exemple #26
0
def main(_):
    """Trains MBOP from a TFDS dataset, evaluating the planner periodically.

    Runs forever: `_EVALUATE_EVERY` learner steps alternate with
    `_EVALUATION_EPISODES` evaluation episodes.
    """
    # Environment and its spec.
    environment = gym_helpers.make_environment(task=_ENV_NAME.value)
    spec = specs.make_environment_spec(environment)

    # Separate random keys for the learner, the data sampler and the
    # evaluator.
    learner_key, dataset_key, evaluator_key = jax.random.split(
        jax.random.PRNGKey(_SEED.value), 3)

    # Load the 'train' split, turn the episodes into timestep-batched
    # transitions and wrap them in a random-sample iterator.
    episodes = tensorflow_datasets.load(_DATASET_NAME.value)['train']
    transitions = mbop.episodes_to_timestep_batched_transitions(
        episodes, return_horizon=10)
    sampler = tfds.JaxInMemoryRandomSampleIterator(
        transitions, key=dataset_key, batch_size=_BATCH_SIZE.value)

    # Compute normalization statistics from the sampler, then normalize every
    # subsequent sample lazily with a jitted function.
    mean_std = mbop.get_normalization_stats(sampler,
                                            _NUM_NORMALIZATION_BATCHES.value)
    normalize = jax.jit(
        functools.partial(running_statistics.normalize, mean_std=mean_std))
    normalized_samples = (normalize(sample) for sample in sampler)

    # Networks and default losses.
    hidden_sizes = tuple(_HIDDEN_LAYER_SIZES.value)
    mbop_networks = mbop.make_networks(spec, hidden_layer_sizes=hidden_sizes)
    mbop_losses = mbop.MBOPLosses()

    def logger_fn(label: str, steps_key: str):
        """Returns the default logger for `label`, keyed on `steps_key`."""
        return loggers.make_default_logger(label, steps_key=steps_key)

    def make_learner(name, logger_fn, counter, rng_key, dataset, network,
                     loss):
        """Builds one ensemble regressor learner named `name`."""
        sgd_optimizer = optax.adam(_LEARNING_RATE.value)
        return mbop.make_ensemble_regressor_learner(
            name,
            _NUM_NETWORKS.value,
            logger_fn,
            counter,
            rng_key,
            dataset,
            network,
            loss,
            sgd_optimizer,
            _NUM_SGD_STEPS_PER_STEP.value,
        )

    # One learner factory per MBOP component, differing only in name.
    world_model_factory = functools.partial(make_learner, 'world_model')
    policy_prior_factory = functools.partial(make_learner, 'policy_prior')
    n_step_return_factory = functools.partial(make_learner, 'n_step_return')
    learner = mbop.MBOPLearner(mbop_networks, mbop_losses, normalized_samples,
                               learner_key, logger_fn, world_model_factory,
                               policy_prior_factory, n_step_return_factory)

    planning_config = mbop.MPPIConfig()

    assert planning_config.n_trajectories % _NUM_NETWORKS.value == 0, (
        'Number of trajectories must be a multiple of the number of networks.')

    # Evaluation actor built on the learner and the planner configuration.
    actor_core = mbop.make_ensemble_actor_core(mbop_networks,
                                               planning_config,
                                               spec,
                                               mean_std,
                                               use_round_robin=False)
    evaluator = mbop.make_actor(actor_core, evaluator_key, learner)

    evaluation_logger = loggers.TerminalLogger('evaluation', time_delta=0.)
    eval_loop = acme.EnvironmentLoop(environment=environment,
                                     actor=evaluator,
                                     logger=evaluation_logger)

    # Train the agent.
    while True:
        for _ in range(_EVALUATE_EVERY.value):
            learner.step()
        eval_loop.run(_EVALUATION_EPISODES.value)
Exemple #27
0
    def __init__(self,
                 policy_network: snt.Module,
                 critic_network: snt.Module,
                 dataset: tf.data.Dataset,
                 discount: float,
                 behavior_network: Optional[snt.Module] = None,
                 cwp_network: Optional[snt.Module] = None,
                 policy_optimizer: Optional[snt.Optimizer] = None,
                 critic_optimizer: Optional[snt.Optimizer] = None,
                 target_update_period: int = 100,
                 policy_improvement_modes: str = 'exp',
                 ratio_upper_bound: float = 20.,
                 beta: float = 1.0,
                 cql_alpha: float = 0.0,
                 translate_lse: float = 100.,
                 empirical_policy: Optional[dict] = None,
                 counter: Optional[counting.Counter] = None,
                 logger: Optional[loggers.Logger] = None,
                 checkpoint_subpath: str = '~/acme/'):
        """Initializes the learner.

        Args:
          policy_network: the online policy network (the one being optimized).
          critic_network: the online critic network (the one being
            optimized).
          dataset: dataset to learn from; an iterator is created over it.
          discount: discount to use for TD updates.
          behavior_network: accepted for interface compatibility; not used in
            this constructor.
          cwp_network: accepted for interface compatibility; not used in this
            constructor.
          policy_optimizer: optimizer for the policy update. When None, a new
            `snt.optimizers.Adam(1e-4)` is created for this learner.
          critic_optimizer: optimizer for the critic update. When None, a new
            `snt.optimizers.Adam(1e-4)` is created for this learner.
          target_update_period: number of learner steps to perform before
            updating the target networks.
          policy_improvement_modes: one of 'exp', 'binary' or 'all'.
          ratio_upper_bound: stored as `self._ratio_upper_bound`.
          beta: stored as `self._beta`.
          cql_alpha: stored as a float32 constant in `self._alpha`. A non-zero
            value requires `empirical_policy`.
          translate_lse: stored as a float32 constant in `self._tr`.
          empirical_policy: empirical behavioural policy; required when
            `cql_alpha` is non-zero.
          counter: Counter object for (potentially distributed) counting.
          logger: Logger object for writing logs to.
          checkpoint_subpath: directory handed to the checkpointer.

        Raises:
          AssertionError: if `policy_improvement_modes` is not a supported
            mode, or if `cql_alpha` is non-zero and `empirical_policy` is
            None.
        """
        # Default optimizers are created here, once per learner. A default
        # instance in the signature would be built a single time at definition
        # and then shared by every learner that relies on the default.
        if policy_optimizer is None:
            policy_optimizer = snt.optimizers.Adam(1e-4)
        if critic_optimizer is None:
            critic_optimizer = snt.optimizers.Adam(1e-4)

        self._iterator = iter(dataset)  # pytype: disable=wrong-arg-types
        # Store online and target networks.
        self._policy_network = policy_network
        self._critic_network = critic_network
        # Create the target networks as deep copies of the online networks.
        self._target_policy_network = copy.deepcopy(policy_network)
        self._target_critic_network = copy.deepcopy(critic_network)
        self._critic_optimizer = critic_optimizer
        self._policy_optimizer = policy_optimizer

        # Internalise the hyperparameters.
        self._discount = discount
        self._target_update_period = target_update_period
        # crr specific
        assert policy_improvement_modes in [
            'exp', 'binary', 'all'
        ], 'Policy imp. mode must be one of {exp, binary, all}'
        self._policy_improvement_modes = policy_improvement_modes
        self._beta = beta
        self._ratio_upper_bound = ratio_upper_bound
        # cql specific
        self._alpha = tf.constant(cql_alpha, dtype=tf.float32)
        self._tr = tf.constant(translate_lse, dtype=tf.float32)
        if cql_alpha:
            assert empirical_policy is not None, 'Empirical behavioural policy must be specified with non-zero cql_alpha.'
        self._emp_policy = empirical_policy

        # Learner state.
        # Expose the variables; note these are the TARGET networks' variables.
        self._variables = {
            'critic': self._target_critic_network.variables,
            'policy': self._target_policy_network.variables,
        }

        # Internalise logging/counting objects.
        self._counter = counter or counting.Counter()
        self._counter.increment(learner_steps=0)
        self._logger = logger or loggers.TerminalLogger('learner',
                                                        time_delta=1.)

        # Create a checkpointer and snapshotter object. `self.state` is
        # defined outside this method and is read only after the attributes
        # above have been assigned.
        self._checkpointer = tf2_savers.Checkpointer(
            objects_to_save=self.state,
            time_delta_minutes=10.,
            directory=checkpoint_subpath,
            subdirectory='crr_learner')

        objects_to_save = {
            'raw_policy': policy_network,
            'critic': critic_network,
        }
        self._snapshotter = tf2_savers.Snapshotter(
            objects_to_save=objects_to_save, time_delta_minutes=10)
        # Timestamp to keep track of the wall time.
        self._walltime_timestamp = time.time()
Exemple #28
0
    def __init__(self,
                 network: networks_lib.FeedForwardNetwork,
                 loss_fn: LossFn,
                 optimizer: optax.GradientTransformation,
                 data_iterator: Iterator[reverb.ReplaySample],
                 target_update_period: int,
                 random_key: networks_lib.PRNGKey,
                 replay_client: Optional[reverb.Client] = None,
                 replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE,
                 counter: Optional[counting.Counter] = None,
                 logger: Optional[loggers.Logger] = None,
                 num_sgd_steps_per_step: int = 1):
        """Initialize the SGD learner.

        Args:
          network: network providing `init`; it is bound as the first argument
            of `loss_fn`.
          loss_fn: loss called as
            `loss_fn(network, params, target_params, batch, rng_key)` and
            returning `(loss, extra)`.
          optimizer: optax gradient transformation used for the updates.
          data_iterator: iterator over replay samples; wrapped with
            `utils.prefetch`.
          target_update_period: period, in SGD steps, of the target parameter
            update.
          random_key: key split three ways to initialise the online
            parameters, the target parameters and the training-state RNG.
          replay_client: client used to write priorities back; when None,
            priority updates are skipped.
          replay_table_name: table whose priorities are mutated.
          counter: Counter object for counting; a new one is created if None.
          logger: Logger object; a 'learner' TerminalLogger is used if None.
          num_sgd_steps_per_step: number of SGD steps folded into one call of
            the jitted step function.
        """
        self.network = network

        # Internalize the loss_fn with network.
        self._loss = jax.jit(functools.partial(loss_fn, self.network))

        # SGD performs the loss, optimizer update and periodic target net update.
        def sgd_step(
                state: TrainingState,
                batch: reverb.ReplaySample) -> Tuple[TrainingState, LossExtra]:
            next_rng_key, rng_key = jax.random.split(state.rng_key)
            # Implements one SGD step of the loss and updates training state
            (loss, extra), grads = jax.value_and_grad(
                self._loss, has_aux=True)(state.params, state.target_params,
                                          batch, rng_key)
            extra.metrics.update({'total_loss': loss})

            # Apply the optimizer updates
            updates, new_opt_state = optimizer.update(grads, state.opt_state)
            new_params = optax.apply_updates(state.params, updates)

            # Periodically update target networks.
            steps = state.steps + 1
            target_params = rlax.periodic_update(new_params,
                                                 state.target_params, steps,
                                                 target_update_period)
            new_training_state = TrainingState(new_params, target_params,
                                               new_opt_state, steps,
                                               next_rng_key)
            return new_training_state, extra

        def postprocess_aux(extra: LossExtra) -> LossExtra:
            # `jax.tree_util.tree_map` is used rather than the deprecated
            # top-level `jax.tree_map` alias, which newer JAX releases no
            # longer provide.
            # Merge the two leading axes of every reverb-update leaf into one.
            reverb_update = jax.tree_util.tree_map(
                lambda a: jnp.reshape(a, (-1, *a.shape[2:])),
                extra.reverb_update)
            # Average each metric over everything that was accumulated.
            return extra._replace(metrics=jax.tree_util.tree_map(
                jnp.mean, extra.metrics),
                                  reverb_update=reverb_update)

        self._num_sgd_steps_per_step = num_sgd_steps_per_step
        sgd_step = utils.process_multiple_batches(sgd_step,
                                                  num_sgd_steps_per_step,
                                                  postprocess_aux)
        self._sgd_step = jax.jit(sgd_step)

        # Internalise agent components
        self._data_iterator = utils.prefetch(data_iterator)
        self._target_update_period = target_update_period
        self._counter = counter or counting.Counter()
        self._logger = logger or loggers.TerminalLogger('learner',
                                                        time_delta=1.)

        # Do not record timestamps until after the first learning step is done.
        # This is to avoid including the time it takes for actors to come online and
        # fill the replay buffer.
        self._timestamp = None

        # Initialize the network parameters. Online and target parameters are
        # initialised from different keys.
        key_params, key_target, key_state = jax.random.split(random_key, 3)
        initial_params = self.network.init(key_params)
        initial_target_params = self.network.init(key_target)
        self._state = TrainingState(
            params=initial_params,
            target_params=initial_target_params,
            opt_state=optimizer.init(initial_params),
            steps=0,
            rng_key=key_state,
        )

        # Update replay priorities
        def update_priorities(reverb_update: ReverbUpdate) -> None:
            # Without a replay client there is nowhere to send priorities.
            if replay_client is None:
                return
            keys, priorities = tree.map_structure(
                utils.fetch_devicearray,
                (reverb_update.keys, reverb_update.priorities))
            replay_client.mutate_priorities(table=replay_table_name,
                                            updates=dict(zip(keys,
                                                             priorities)))

        self._replay_client = replay_client
        # Priority updates are dispatched through an AsyncExecutor.
        self._async_priority_updater = async_utils.AsyncExecutor(
            update_priorities)
Exemple #29
0
    def __init__(self,
                 network: networks.QNetwork,
                 obs_spec: specs.Array,
                 discount: float,
                 importance_sampling_exponent: float,
                 target_update_period: int,
                 iterator: Iterator[reverb.ReplaySample],
                 optimizer: optix.InitUpdate,
                 rng: hk.PRNGSequence,
                 max_abs_reward: float = 1.,
                 huber_loss_parameter: float = 1.,
                 replay_client: Optional[reverb.Client] = None,
                 counter: Optional[counting.Counter] = None,
                 logger: Optional[loggers.Logger] = None):
        """Initializes the learner.

        Builds the loss and SGD-step functions as closures over the
        hyperparameters, initialises online/target parameters and the
        optimiser state, and sets up an asynchronous replay-priority updater.

        Args:
          network: Q-network function; it is wrapped with `hk.transform` here.
          obs_spec: observation spec; a zero-filled, batched instance of it is
            used to initialise the network parameters.
          discount: factor multiplied into the sampled `d_t` term before the
            TD error is computed.
          importance_sampling_exponent: exponent applied to the inverse
            sampling probabilities when forming importance weights.
          target_update_period: stored on the learner; not otherwise used by
            this constructor.
          iterator: iterator over `reverb.ReplaySample`s; wrapped with
            `utils.prefetch`.
          optimizer: optix init/update pair used to build and advance the
            optimiser state.
          rng: key sequence supplying the keys for parameter initialisation.
          max_abs_reward: rewards are clipped to
            `[-max_abs_reward, max_abs_reward]`.
          huber_loss_parameter: parameter passed to `rlax.huber_loss`.
          replay_client: client used to write new priorities to
            `adders.DEFAULT_PRIORITY_TABLE`. If None, priority updates are
            skipped.
          counter: step counter; a fresh `counting.Counter` is created if None.
          logger: logger; a `loggers.TerminalLogger` is created if None.
        """

        # Transform network into a pure function.
        network = hk.transform(network)

        def loss(params: hk.Params, target_params: hk.Params,
                 sample: reverb.ReplaySample):
            """Returns the weighted mean loss and `(keys, priorities)` aux."""
            o_tm1, a_tm1, r_t, d_t, o_t = sample.data
            keys, probs = sample.info[:2]

            # Forward pass. The online parameters select the action at `o_t`
            # while the target parameters evaluate it (double Q-learning).
            q_tm1 = network.apply(params, o_tm1)
            q_t_value = network.apply(target_params, o_t)
            q_t_selector = network.apply(params, o_t)

            # Cast and clip rewards.
            d_t = (d_t * discount).astype(jnp.float32)
            r_t = jnp.clip(r_t, -max_abs_reward,
                           max_abs_reward).astype(jnp.float32)

            # Compute double Q-learning n-step TD-error.
            batch_error = jax.vmap(rlax.double_q_learning)
            td_error = batch_error(q_tm1, a_tm1, r_t, d_t, q_t_value,
                                   q_t_selector)
            batch_loss = rlax.huber_loss(td_error, huber_loss_parameter)

            # Importance weighting, normalised so the largest weight is 1.
            importance_weights = (1. / probs).astype(jnp.float32)
            importance_weights **= importance_sampling_exponent
            importance_weights /= jnp.max(importance_weights)

            # Reweight.
            mean_loss = jnp.mean(importance_weights * batch_loss)  # []

            # NOTE(review): JAX only honours float64 when x64 mode is enabled;
            # otherwise this cast yields float32 -- confirm the intended dtype.
            priorities = jnp.abs(td_error).astype(jnp.float64)

            return mean_loss, (keys, priorities)

        def sgd_step(
            state: TrainingState, samples: reverb.ReplaySample
        ) -> Tuple[TrainingState, LearnerOutputs]:
            """Applies one optimiser update; target params are left as-is."""
            grad_fn = jax.grad(loss, has_aux=True)
            gradients, (keys, priorities) = grad_fn(state.params,
                                                    state.target_params,
                                                    samples)
            updates, new_opt_state = optimizer.update(gradients,
                                                      state.opt_state)
            new_params = optix.apply_updates(state.params, updates)

            new_state = TrainingState(params=new_params,
                                      target_params=state.target_params,
                                      opt_state=new_opt_state,
                                      step=state.step + 1)

            outputs = LearnerOutputs(keys=keys, priorities=priorities)

            return new_state, outputs

        def update_priorities(outputs: LearnerOutputs) -> None:
            """Writes the new priorities of a batch back to replay."""
            # `replay_client` is optional; without it there is nowhere to send
            # the updates.
            if replay_client is None:
                return
            # Fetch both arrays to host memory once, then send a single
            # batched request instead of one request per sampled item.
            keys, priorities = jax.device_get(
                (outputs.keys, outputs.priorities))
            updates = dict(zip(keys, priorities))
            if updates:
                replay_client.mutate_priorities(
                    table=adders.DEFAULT_PRIORITY_TABLE, updates=updates)

        # Internalise agent components (replay buffer, networks, optimizer).
        self._replay_client = replay_client
        self._iterator = utils.prefetch(iterator)

        # Internalise the hyperparameters.
        self._target_update_period = target_update_period

        # Internalise logging/counting objects.
        self._counter = counter or counting.Counter()
        self._logger = logger or loggers.TerminalLogger('learner',
                                                        time_delta=1.)

        # Initialise parameters and optimiser state. The online and target
        # parameters are initialised with separate keys drawn from `rng`.
        dummy_observation = utils.add_batch_dim(utils.zeros_like(obs_spec))
        initial_params = network.init(next(rng), dummy_observation)
        initial_target_params = network.init(next(rng), dummy_observation)
        initial_opt_state = optimizer.init(initial_params)

        self._state = TrainingState(params=initial_params,
                                    target_params=initial_target_params,
                                    opt_state=initial_opt_state,
                                    step=0)

        self._forward = jax.jit(network.apply)
        self._sgd_step = jax.jit(sgd_step)
        self._async_priority_updater = async_utils.AsyncExecutor(
            update_priorities)
Exemple #30
0
    def __init__(self,
                 policy_network: snt.RNNCore,
                 critic_network: networks.CriticDeepRNN,
                 target_policy_network: snt.RNNCore,
                 target_critic_network: networks.CriticDeepRNN,
                 dataset: tf.data.Dataset,
                 accelerator_strategy: Optional[tf.distribute.Strategy] = None,
                 behavior_network: Optional[snt.Module] = None,
                 cwp_network: Optional[snt.Module] = None,
                 policy_optimizer: Optional[snt.Optimizer] = None,
                 critic_optimizer: Optional[snt.Optimizer] = None,
                 discount: float = 0.99,
                 target_update_period: int = 100,
                 num_action_samples_td_learning: int = 1,
                 num_action_samples_policy_weight: int = 4,
                 baseline_reduce_function: str = 'mean',
                 clipping: bool = True,
                 policy_improvement_modes: str = 'exp',
                 ratio_upper_bound: float = 20.,
                 beta: float = 1.0,
                 counter: Optional[counting.Counter] = None,
                 logger: Optional[loggers.Logger] = None,
                 checkpoint: bool = False):
        """Sets up networks, optimizers, the data iterator and savers.

        Args:
          policy_network: online policy, i.e. the one being optimized.
          critic_network: online critic.
          target_policy_network: target policy, lagging behind the online one.
          target_critic_network: target critic.
          dataset: data to learn from, either fixed or backed by a replay
            buffer (see `acme.datasets.reverb.make_reverb_dataset`). One
            element is drawn from it here to read the sequence length.
          accelerator_strategy: tf.distribute strategy used to spread
            computation over one or several GPUs/TPUs. Defaults to
            `snt.distribute.Replicator()`.
          behavior_network: if given, saved in snapshots under the `policy`
            name; otherwise no `policy` entry is snapshotted.
          cwp_network: if given, saved in snapshots under the `cwp_policy`
            name. It samples actions from the policy, weighs them with the
            critic and samples from the softmax over critic values; it is only
            snapshotted, never trained.
          policy_optimizer: optimizer for the policy loss. Defaults to
            `snt.optimizers.Adam(1e-4)`.
          critic_optimizer: optimizer for the distributional Bellman loss.
            Defaults to `snt.optimizers.Adam(1e-4)`.
          discount: discount applied in TD updates.
          target_update_period: learner steps between target network updates.
          num_action_samples_td_learning: how many action samples estimate the
            expectation of the critic loss under the stochastic policy.
          num_action_samples_policy_weight: how many action samples estimate
            the advantage used for the CRR weighting of the policy loss.
          baseline_reduce_function: 'mean', 'max' or 'min'; how the value
            estimates from the `num_action_samples` draws are aggregated.
          clipping: whether gradients are clipped by global norm.
          policy_improvement_modes: 'exp', 'binary' or 'all'; the CRR mode
            deciding how the advantage is transformed before it multiplies the
            policy loss.
          ratio_upper_bound: in 'exp' mode, the cap on the weight, which is
            min(exp(advantage / beta), upper_bound).
          beta: in 'exp' mode, the beta of the expression above.
          counter: keeps track of steps; a new `counting.Counter` if None.
          logger: learner logger; a `loggers.TerminalLogger` if None.
          checkpoint: whether to create a checkpointer and a snapshotter.
        """

        # Plain hyperparameters, stored as given.
        self._discount = discount
        self._clipping = clipping
        self._beta = beta
        self._ratio_upper_bound = ratio_upper_bound
        self._policy_improvement_modes = policy_improvement_modes
        self._baseline_reduce_function = baseline_reduce_function
        self._target_update_period = target_update_period
        self._num_action_samples_td_learning = num_action_samples_td_learning
        self._num_action_samples_policy_weight = (
            num_action_samples_policy_weight)

        # Use a Sonnet replicator when the caller supplies no strategy.
        self._accelerator_strategy = (
            snt.distribute.Replicator()
            if accelerator_strategy is None else accelerator_strategy)

        # On TPUs the memory footprint (hence the sequence length) must be
        # known when the graph is compiled. The dataset carries no metadata,
        # so for now the length is read off one sampled element; b/160672927
        # tracks the feature that would make this unnecessary.
        first_element = next(dataset.as_numpy_iterator())
        self._sequence_length = first_element.action.shape[1]

        self._counter = counter or counting.Counter()
        self._logger = logger or loggers.TerminalLogger('learner',
                                                        time_delta=1.)

        with self._accelerator_strategy.scope():
            # Step count, needed to decide when target networks are updated.
            self._num_steps = tf.Variable(0, dtype=tf.int32)

            # (Maybe) spread the dataset over several accelerators.
            self._iterator = iter(
                self._accelerator_strategy.experimental_distribute_dataset(
                    dataset))

            # Optimizers, falling back to Adam with a 1e-4 learning rate.
            self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(
                1e-4)
            self._policy_optimizer = policy_optimizer or snt.optimizers.Adam(
                1e-4)

        # Keep references to the online and target networks.
        self._policy_network = policy_network
        self._critic_network = critic_network
        self._target_policy_network = target_policy_network
        self._target_critic_network = target_critic_network

        # Variables made available by the learner (taken from the targets).
        self._variables = {
            'critic': self._target_critic_network.variables,
            'policy': self._target_policy_network.variables,
        }

        # Savers exist only when checkpointing is requested.
        self._checkpointer = None
        self._snapshotter = None
        if checkpoint:
            checkpointed_state = {
                'counter': self._counter,
                'policy': self._policy_network,
                'critic': self._critic_network,
                'target_policy': self._target_policy_network,
                'target_critic': self._target_critic_network,
                'policy_optimizer': self._policy_optimizer,
                'critic_optimizer': self._critic_optimizer,
                'num_steps': self._num_steps,
            }
            self._checkpointer = tf2_savers.Checkpointer(
                objects_to_save=checkpointed_state, time_delta_minutes=30.)

            # Snapshots always hold the sampling policy and the mean critic;
            # the behavior and CWP networks are added only when provided.
            snapshot_modules = {
                'raw_policy': snt.DeepRNN(
                    [policy_network,
                     networks.StochasticSamplingHead()]),
                'critic': networks.CriticDeepRNN(
                    [critic_network, networks.StochasticMeanHead()]),
            }
            for snapshot_name, module in (('policy', behavior_network),
                                          ('cwp_policy', cwp_network)):
                if module is not None:
                    snapshot_modules[snapshot_name] = module
            self._snapshotter = tf2_savers.Snapshotter(
                objects_to_save=snapshot_modules, time_delta_minutes=30)

        # Wall-clock reference point for timing.
        self._walltime_timestamp = time.time()