Example #1
def main(_):
    # Create an environment and grab the spec.
    environment, environment_spec = _build_environment(
        FLAGS.environment_name, max_steps=FLAGS.max_steps_per_episode)

    if FLAGS.model_name:
        loaded_network = load_wb_model(FLAGS.model_name, FLAGS.model_tag)

        if FLAGS.stochastic:
            head = networks.StochasticSamplingHead()
        else:
        head = lambda q: trfl.epsilon_greedy(q, epsilon=FLAGS.epsilon).sample()

        policy_network = snt.Sequential([
            loaded_network,
            head,
        ])
        actor = actors.FeedForwardActor(policy_network)

    else:
        actor = RandomActor(environment_spec)

    recorder = DemonstrationRecorder(environment, actor)

    recorder.collect_n_episodes(FLAGS.n_episodes)
    recorder.make_tf_dataset()
    recorder.save(FLAGS.save_dir)
Example #2
def main(_):
    # Create an environment and grab the spec.
    environment = atari.environment(FLAGS.game)
    environment_spec = specs.make_environment_spec(environment)

    # Create dataset.
    dataset = atari.dataset(path=FLAGS.dataset_path,
                            game=FLAGS.game,
                            run=FLAGS.run,
                            num_shards=FLAGS.num_shards)
    # Discard extra inputs
    dataset = dataset.map(lambda x: x._replace(data=x.data[:5]))

    # Batch and prefetch.
    dataset = dataset.batch(FLAGS.batch_size, drop_remainder=True)
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

    # Build network.
    g_network = make_network(environment_spec.actions)
    q_network = make_network(environment_spec.actions)
    network = networks.DiscreteFilteredQNetwork(g_network=g_network,
                                                q_network=q_network,
                                                threshold=FLAGS.bcq_threshold)
    tf2_utils.create_variables(network, [environment_spec.observations])

    evaluator_network = snt.Sequential([
        q_network,
        lambda q: trfl.epsilon_greedy(q, epsilon=FLAGS.epsilon).sample(),
    ])

    # Counters.
    counter = counting.Counter()
    learner_counter = counting.Counter(counter, prefix='learner')

    # Create the actor which defines how we take actions.
    evaluation_network = actors.FeedForwardActor(evaluator_network)

    eval_loop = acme.EnvironmentLoop(environment=environment,
                                     actor=evaluation_network,
                                     counter=counter,
                                     logger=loggers.TerminalLogger(
                                         'evaluation', time_delta=1.))

    # The learner updates the parameters (and initializes them).
    learner = bcq.DiscreteBCQLearner(
        network=network,
        dataset=dataset,
        learning_rate=FLAGS.learning_rate,
        discount=FLAGS.discount,
        importance_sampling_exponent=FLAGS.importance_sampling_exponent,
        target_update_period=FLAGS.target_update_period,
        counter=counter)

    # Run the environment loop.
    while True:
        for _ in range(FLAGS.evaluate_every):
            learner.step()
        learner_counter.increment(learner_steps=FLAGS.evaluate_every)
        eval_loop.run(FLAGS.evaluation_episodes)
Example #3
def main(_):
    # Create an environment and grab the spec.
    raw_environment = bsuite.load_from_id(FLAGS.bsuite_id)
    environment = single_precision.SinglePrecisionWrapper(raw_environment)
    environment_spec = specs.make_environment_spec(environment)

    # Build demonstration dataset.
    if hasattr(raw_environment, 'raw_env'):
        raw_environment = raw_environment.raw_env

    batch_dataset = bsuite_demonstrations.make_dataset(raw_environment)
    # Convert the demonstration episodes into single-step transitions.
    transition = functools.partial(_n_step_transition_from_episode,
                                   n_step=1,
                                   additional_discount=1.)

    dataset = batch_dataset.map(transition)

    # Batch and prefetch.
    dataset = dataset.batch(FLAGS.batch_size, drop_remainder=True)
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

    # Create the networks to optimize.
    policy_network = make_policy_network(environment_spec.actions)

    # If the agent is non-autoregressive use epsilon=0 which will be a greedy
    # policy.
    evaluator_network = snt.Sequential([
        policy_network,
        lambda q: trfl.epsilon_greedy(q, epsilon=FLAGS.epsilon).sample(),
    ])

    # Ensure that we create the variables before proceeding (maybe not needed).
    tf2_utils.create_variables(policy_network, [environment_spec.observations])

    counter = counting.Counter()
    learner_counter = counting.Counter(counter, prefix='learner')

    # Create the actor which defines how we take actions.
    evaluation_network = actors_tf2.FeedForwardActor(evaluator_network)

    eval_loop = acme.EnvironmentLoop(environment=environment,
                                     actor=evaluation_network,
                                     counter=counter,
                                     logger=loggers.TerminalLogger(
                                         'evaluation', time_delta=1.))

    # The learner updates the parameters (and initializes them).
    learner = learning.BCLearner(network=policy_network,
                                 learning_rate=FLAGS.learning_rate,
                                 dataset=dataset,
                                 counter=learner_counter)

    # Run the environment loop.
    while True:
        for _ in range(FLAGS.evaluate_every):
            learner.step()
        learner_counter.increment(learner_steps=FLAGS.evaluate_every)
        eval_loop.run(FLAGS.evaluation_episodes)
Example #4
    def __call__(
            self, observation: Union[Mapping[str, tf.Tensor],
                                     tf.Tensor]) -> tf.Tensor:
        q = self._network(observation)
        return trfl.epsilon_greedy(
            q,
            epsilon=self._epsilon,
            legal_actions_mask=observation['legal_actions_mask']).sample()
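
For reference, a tiny standalone sketch of the legal_actions_mask behaviour used above (the Q-values and mask below are made up, not from the source):

import tensorflow as tf
import trfl

q_values = tf.constant([[1.0, 2.0, 3.0, 4.0]])       # one batch element, four actions
legal_actions = tf.constant([[1.0, 1.0, 0.0, 1.0]])  # action 2 is illegal
distribution = trfl.epsilon_greedy(
    q_values, epsilon=0.1, legal_actions_mask=legal_actions)
action = distribution.sample()  # Illegal actions receive zero probability.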
Example #5
def main(_):
    wb_run = init_or_resume()

    if FLAGS.seed:
        tf.random.set_seed(FLAGS.seed)

    # Create an environment and grab the spec.
    environment, env_spec = _build_environment(FLAGS.environment_name)

    # Load demonstration dataset.
    raw_dataset = load_tf_dataset(directory=FLAGS.dataset_dir)

    dataset = preprocess_dataset(raw_dataset, FLAGS.batch_size,
                                 FLAGS.n_step_returns, FLAGS.discount)

    # Create the policy and critic networks.
    policy_network = networks.get_default_critic(env_spec)

    # Ensure that we create the variables before proceeding (maybe not needed).
    tf2_utils.create_variables(policy_network, [env_spec.observations])

    # If the agent is non-autoregressive use epsilon=0 which will be a greedy
    # policy.
    evaluator_network = snt.Sequential([
        policy_network,
        lambda q: trfl.epsilon_greedy(q, epsilon=FLAGS.epsilon).sample(),
    ])

    # Create the actor which defines how we take actions.
    evaluation_actor = actors.FeedForwardActor(evaluator_network)

    counter = counting.Counter()

    disp, disp_loop = _build_custom_loggers(wb_run)

    eval_loop = EnvironmentLoop(environment=environment,
                                actor=evaluation_actor,
                                counter=counter,
                                logger=disp_loop)

    # The learner updates the parameters (and initializes them).
    learner = BCLearner(network=policy_network,
                        learning_rate=FLAGS.learning_rate,
                        dataset=dataset,
                        counter=counter)

    # Run the environment loop.
    for _ in tqdm(range(FLAGS.epochs)):
        for _ in range(FLAGS.evaluate_every):
            learner.step()
        eval_loop.run(FLAGS.evaluation_episodes)

    learner.save(tag=FLAGS.logs_tag)
Example #6
    def _policy(self, observation: types.NestedTensor,
                mask: types.NestedTensor) -> types.NestedTensor:
        # Add a dummy batch dimension and as a side effect convert numpy to TF.
        batched_observation = tf2_utils.add_batch_dim(observation)

        # Compute the policy, conditioned on the observation.
        qs = self._policy_network(batched_observation)

        qs = qs * tf.cast(mask, dtype=tf.float32)
        # Sample from the policy if it is stochastic.
        action = trfl.epsilon_greedy(qs, epsilon=0.05).sample()

        return action
Example #7
def load_policy_net(
    task_name: str,
    noise_level: float,
    dataset_path: str,
    environment_spec: specs.EnvironmentSpec,
    near_policy_dataset: bool = False,
    ):
    dataset_path = Path(dataset_path)
    if task_name.startswith("bsuite"):
        # BSuite tasks.
        bsuite_id = task_name[len("bsuite_"):] + "/0"
        path = bsuite_policy_path(
            bsuite_id, noise_level, near_policy_dataset, dataset_path)
        logging.info("Policy path: %s", path)
        policy_net = tf.saved_model.load(path)

        policy_noise_level = 0.1  # params["policy_noise_level"]
        observation_network = tf2_utils.to_sonnet_module(functools.partial(
            tf.reshape, shape=(-1,) + environment_spec.observations.shape))
        policy_net = snt.Sequential([
            observation_network,
            policy_net,
            # Add action noise to the target policy.
            lambda q: trfl.epsilon_greedy(q, epsilon=policy_noise_level).sample(),
        ])
    elif task_name.startswith("dm_control"):
        # DM Control tasks.
        if near_policy_dataset:
            raise ValueError(
                "Near-policy dataset is not available for dm_control tasks.")
        dm_control_task = task_name[len("dm_control_"):]
        path = dm_control_policy_path(
            dm_control_task, noise_level, dataset_path)
        logging.info("Policy path: %s", path)
        policy_net = tf.saved_model.load(path)

        policy_noise_level = 0.2  # params["policy_noise_level"]
        observation_network = tf2_utils.to_sonnet_module(tf2_utils.batch_concat)
        policy_net = snt.Sequential([
            observation_network,
            policy_net,
            # Add action noise to the target policy.
            acme_utils.GaussianNoise(policy_noise_level),
            networks.ClipToSpec(environment_spec.actions),
        ])
    else:
        raise ValueError(f"task name {task_name} is unsupported.")
    return policy_net
Example #8
  def actor(
      self,
      replay: reverb.Client,
      variable_source: acme.VariableSource,
      counter: counting.Counter,
      epsilon: float,
  ) -> acme.EnvironmentLoop:
    """The actor process."""
    environment = self._environment_factory(False)
    network = self._network_factory(self._environment_spec.actions)

    tf2_utils.create_variables(network, [self._obs_spec])

    policy_network = snt.DeepRNN([
        network,
        lambda qs: tf.cast(trfl.epsilon_greedy(qs, epsilon).sample(), tf.int32),
    ])

    # Component to add things into replay.
    sequence_length = self._burn_in_length + self._trace_length + 1
    adder = adders.SequenceAdder(
        client=replay,
        period=self._replay_period,
        sequence_length=sequence_length,
        delta_encoded=True,
    )

    variable_client = tf2_variable_utils.VariableClient(
        client=variable_source,
        variables={'policy': policy_network.variables},
        update_period=self._variable_update_period)

    # Make sure not to use a random policy after checkpoint restoration by
    # assigning variables before running the environment loop.
    variable_client.update_and_wait()

    # Create the agent.
    actor = actors.RecurrentActor(
        policy_network=policy_network,
        variable_client=variable_client,
        adder=adder)

    counter = counting.Counter(counter, 'actor')
    logger = loggers.make_default_logger(
        'actor', save_data=False, steps_key='actor_steps')

    # Create the loop to connect environment and agent.
    return acme.EnvironmentLoop(environment, actor, counter, logger)
Example #9
    def _policy(
        self, observation: types.NestedTensor, state: types.NestedTensor,
        mask: types.NestedTensor
    ) -> Tuple[types.NestedTensor, types.NestedTensor]:

        # Add a dummy batch dimension and as a side effect convert numpy to TF.
        batched_observation = tf2_utils.add_batch_dim(observation)

        # Compute the policy, conditioned on the observation.
        qvals, new_state = self._network(batched_observation, state)

        # Sample from the policy if it is stochastic.
        action = trfl.epsilon_greedy(qvals,
                                     epsilon=0.05,
                                     legal_actions_mask=tf.cast(
                                         mask, dtype=tf.float32)).sample()
        return action, new_state
Example #10
    def actor(
        self,
        replay: reverb.Client,
        variable_source: acme.VariableSource,
        counter: counting.Counter,
        epsilon: float,
    ) -> acme.EnvironmentLoop:
        """The actor process."""
        environment = self._environment_factory(False)
        network = self._network_factory(self._env_spec.actions)

        # Just inline the policy network here.
        policy_network = snt.Sequential([
            network,
            lambda q: trfl.epsilon_greedy(q, epsilon=epsilon).sample(),
        ])

        tf2_utils.create_variables(policy_network,
                                   [self._env_spec.observations])
        variable_client = tf2_variable_utils.VariableClient(
            client=variable_source,
            variables={'policy': policy_network.trainable_variables},
            update_period=self._variable_update_period)

        # Make sure not to use a random policy after checkpoint restoration by
        # assigning variables before running the environment loop.
        variable_client.update_and_wait()

        # Component to add things into replay.
        adder = adders.NStepTransitionAdder(
            client=replay,
            n_step=self._n_step,
            discount=self._discount,
        )

        # Create the agent.
        actor = actors.FeedForwardActor(policy_network, adder, variable_client)

        # Create the loop to connect environment and agent.
        counter = counting.Counter(counter, 'actor')
        logger = loggers.make_default_logger('actor',
                                             save_data=False,
                                             steps_key='actor_steps')
        return acme.EnvironmentLoop(environment, actor, counter, logger)
Example #11
    def evaluator(
        self,
        variable_source: acme.VariableSource,
        counter: counting.Counter,
    ):
        """The evaluation process."""
        environment = self._environment_factory(True)
        network = self._network_factory(self._env_spec.actions)

        # Just inline the policy network here.
        policy_network = snt.Sequential([
            network,
            lambda q: trfl.epsilon_greedy(q, self._evaluator_epsilon).sample(),
        ])

        tf2_utils.create_variables(policy_network,
                                   [self._env_spec.observations])

        variable_client = tf2_variable_utils.VariableClient(
            client=variable_source,
            variables={'policy': policy_network.trainable_variables},
            update_period=self._variable_update_period)

        # Make sure not to use a random policy after checkpoint restoration by
        # assigning variables before running the environment loop.
        variable_client.update_and_wait()

        # Create the agent.
        actor = actors.FeedForwardActor(policy_network,
                                        variable_client=variable_client)

        # Create the run loop and return it.
        logger = loggers.make_default_logger('evaluator',
                                             steps_key='evaluator_steps')
        counter = counting.Counter(counter, 'evaluator')
        return acme.EnvironmentLoop(environment,
                                    actor,
                                    counter=counter,
                                    logger=logger)
Example #12
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        network: snt.Module,
        batch_size: int = 256,
        prefetch_size: int = 4,
        target_update_period: int = 100,
        samples_per_insert: float = 32.0,
        min_replay_size: int = 20,
        max_replay_size: int = 1000000,
        importance_sampling_exponent: float = 0.2,
        priority_exponent: float = 0.6,
        n_step: int = 5,
        epsilon_init: float = 1.0,
        epsilon_final: float = 0.01,
        epsilon_schedule_timesteps: int = 20000,
        learning_rate: float = 1e-3,
        discount: float = 0.99,
        max_gradient_norm: Optional[float] = None,
        logger: loggers.Logger = None,
    ):
        """Initialize the agent.

        Args:
        environment_spec: description of the actions, observations, etc.
        network: the online Q network (the one being optimized)
        batch_size: batch size for updates.
        prefetch_size: size to prefetch from replay.
        target_update_period: number of learner steps to perform before updating
            the target networks.
        samples_per_insert: number of samples to take from replay for every insert
            that is made.
        min_replay_size: minimum replay size before updating. This and all
            following arguments are related to dataset construction and will be
            ignored if a dataset argument is passed.
        max_replay_size: maximum replay size.
        importance_sampling_exponent: power to which importance weights are raised
            before normalizing (beta). See https://arxiv.org/pdf/1710.02298.pdf
        priority_exponent: exponent used in prioritized sampling (omega).
            See https://arxiv.org/pdf/1710.02298.pdf
        n_step: number of steps to squash into a single transition.
        epsilon_init: initial epsilon value (probability of taking a random action).
        epsilon_final: final epsilon value (probability of taking a random action).
        epsilon_schedule_timesteps: number of timesteps over which epsilon is
            decayed from 'epsilon_init' to 'epsilon_final'.
        learning_rate: learning rate for the q-network update.
        discount: discount to use for TD updates.
        logger: logger object to be used by learner.
        max_gradient_norm: used for gradient clipping.
        """

        # Create a replay server to add data to. This uses no limiter behavior in
        # order to allow the Agent interface to handle it.
        replay_table = reverb.Table(
            name=adders.DEFAULT_PRIORITY_TABLE,
            sampler=reverb.selectors.Prioritized(priority_exponent),
            remover=reverb.selectors.Fifo(),
            max_size=max_replay_size,
            rate_limiter=reverb.rate_limiters.MinSize(1),
            signature=adders.NStepTransitionAdder.signature(environment_spec))
        self._server = reverb.Server([replay_table], port=None)

        # The adder is used to insert observations into replay.
        address = f'localhost:{self._server.port}'
        self._adder = adders.NStepTransitionAdder(
            client=reverb.Client(address), n_step=n_step, discount=discount)

        # The dataset provides an interface to sample from replay.
        replay_client = reverb.TFClient(address)
        dataset = make_reverb_dataset(server_address=address,
                                      batch_size=batch_size,
                                      prefetch_size=prefetch_size)

        policy_network = snt.Sequential([
            network,
            EpsilonGreedyExploration(
                epsilon_init=epsilon_init,
                epsilon_final=epsilon_final,
                epsilon_schedule_timesteps=epsilon_schedule_timesteps)
        ])

        # Create a target network.
        target_network = copy.deepcopy(network)

        # Ensure that we create the variables before proceeding (maybe not needed).
        tf2_utils.create_variables(network, [environment_spec.observations])
        tf2_utils.create_variables(target_network,
                                   [environment_spec.observations])

        # Create the actor which defines how we take actions.
        actor = actors_tf2.FeedForwardActor(policy_network, self._adder)

        # The learner updates the parameters (and initializes them).
        learner = learning.DQNLearner(
            network=network,
            target_network=target_network,
            discount=discount,
            importance_sampling_exponent=importance_sampling_exponent,
            learning_rate=learning_rate,
            target_update_period=target_update_period,
            dataset=dataset,
            replay_client=replay_client,
            max_gradient_norm=max_gradient_norm,
            logger=logger,
            checkpoint=False)

        self._saver = tf2_savers.Saver(learner.state)

        # Deterministic (max-Q) actor.
        max_Q_network = snt.Sequential([
            network,
            lambda q: trfl.epsilon_greedy(q, epsilon=0.0).sample(),
        ])
        self._deterministic_actor = actors_tf2.FeedForwardActor(max_Q_network)

        super().__init__(actor=actor,
                         learner=learner,
                         min_observations=max(batch_size, min_replay_size),
                         observations_per_step=float(batch_size) /
                         samples_per_insert)
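
EpsilonGreedyExploration above is not part of acme or trfl; purely as a sketch of what such a head could look like (the class name, linear schedule, and call-counting approach here are assumptions, not the project's actual implementation), built on the same trfl.epsilon_greedy primitive:

import sonnet as snt
import tensorflow as tf
import trfl


class LinearEpsilonGreedy(snt.Module):
    """Sketch of an epsilon-greedy head with a linearly decayed epsilon."""

    def __init__(self, epsilon_init, epsilon_final, epsilon_schedule_timesteps,
                 name='linear_epsilon_greedy'):
        super().__init__(name=name)
        self._epsilon_init = epsilon_init
        self._epsilon_final = epsilon_final
        self._schedule_timesteps = float(epsilon_schedule_timesteps)
        self._num_calls = tf.Variable(0.0, trainable=False)

    def __call__(self, q_values):
        # Linearly interpolate epsilon between the initial and final values.
        fraction = tf.minimum(self._num_calls / self._schedule_timesteps, 1.0)
        epsilon = self._epsilon_init + fraction * (
            self._epsilon_final - self._epsilon_init)
        self._num_calls.assign_add(1.0)
        return trfl.epsilon_greedy(q_values, epsilon=epsilon).sample()

A head like this drops into the snt.Sequential exactly like the fixed-epsilon lambdas in the other examples.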
Example #13
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        network: snt.Module,
        batch_size: int = 32,
        prefetch_size: int = 4,
        target_update_period: int = 100,
        samples_per_insert: float = 32.0,
        min_replay_size: int = 1000,
        max_replay_size: int = 100000,
        importance_sampling_exponent: float = 0.2,
        priority_exponent: float = 0.6,
        n_step: int = 5,
        epsilon: Optional[float] = 0.05,
        learning_rate: float = 1e-3,
        discount: float = 0.99,
        logger: loggers.Logger = None,
        max_gradient_norm: Optional[float] = None,
        expert_data: List[Dict] = None,
    ) -> None:
        """ Initialize the agent. """

        # Create a replay server to add data to. This uses no limiter behavior
        # in order to allow the Agent interface to handle it.
        replay_table = reverb.Table(
            name=adders.DEFAULT_PRIORITY_TABLE,
            sampler=reverb.selectors.Prioritized(priority_exponent),
            remover=reverb.selectors.Fifo(),
            max_size=max_replay_size,
            rate_limiter=reverb.rate_limiters.MinSize(1),
            signature=adders.NStepTransitionAdder.signature(environment_spec))
        self._server = reverb.Server([replay_table], port=None)

        # The adder is used to insert observations into replay.
        address = f'localhost:{self._server.port}'
        adder = adders.NStepTransitionAdder(client=reverb.Client(address),
                                            n_step=n_step,
                                            discount=discount)

        # Adding expert data to the replay memory:
        if expert_data is not None:
            for d in expert_data:
                adder.add_first(d["first"])
                for (action, next_ts) in d["mid"]:
                    adder.add(np.int32(action), next_ts)

        # The dataset provides an interface to sample from replay.
        replay_client = reverb.TFClient(address)
        dataset = datasets.make_reverb_dataset(server_address=address,
                                               batch_size=batch_size,
                                               prefetch_size=prefetch_size)

        # Creating the epsilon greedy policy network:
        epsilon = tf.Variable(epsilon)
        policy_network = snt.Sequential([
            network,
            lambda q: trfl.epsilon_greedy(q, epsilon=epsilon).sample(),
        ])

        # Create a target network.
        target_network = copy.deepcopy(network)

        # Ensure that we create the variables before proceeding (maybe not
        # needed).
        tf2_utils.create_variables(network, [environment_spec.observations])
        tf2_utils.create_variables(target_network,
                                   [environment_spec.observations])

        # Create the actor which defines how we take actions.
        actor = actors.FeedForwardActor(policy_network, adder)

        # The learner updates the parameters (and initializes them).
        learner = learning.DQNLearner(
            network=network,
            target_network=target_network,
            discount=discount,
            importance_sampling_exponent=importance_sampling_exponent,
            learning_rate=learning_rate,
            target_update_period=target_update_period,
            dataset=dataset,
            replay_client=replay_client,
            max_gradient_norm=max_gradient_norm,
            logger=logger,
        )

        super().__init__(actor=actor,
                         learner=learner,
                         min_observations=max(batch_size, min_replay_size),
                         observations_per_step=float(batch_size) /
                         samples_per_insert)
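
The expert_data argument consumed above is project-specific; as a hypothetical illustration of the expected structure, each entry holds the first dm_env timestep of an episode under "first" and (action, next_timestep) pairs for the remaining steps under "mid":

import dm_env
import numpy as np

_obs = np.zeros(4, dtype=np.float32)  # Placeholder observation.
expert_data = [{
    "first": dm_env.restart(_obs),
    "mid": [
        (0, dm_env.transition(reward=1.0, observation=_obs)),
        (1, dm_env.termination(reward=0.0, observation=_obs)),
    ],
}]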
Example #14
File: agent.py Project: dzorlu/acme
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        network: snt.RNNCore,
        target_network: snt.RNNCore,
        burn_in_length: int,
        trace_length: int,
        replay_period: int,
        demonstration_generator: iter,
        demonstration_ratio: float,
        model_directory: str,
        counter: counting.Counter = None,
        logger: loggers.Logger = None,
        discount: float = 0.99,
        batch_size: int = 32,
        target_update_period: int = 100,
        importance_sampling_exponent: float = 0.2,
        epsilon: float = 0.01,
        learning_rate: float = 1e-3,
        log_to_bigtable: bool = False,
        log_name: str = 'agent',
        checkpoint: bool = True,
        min_replay_size: int = 1000,
        max_replay_size: int = 1000000,
        samples_per_insert: float = 32.0,
    ):

        extra_spec = {
            'core_state': network.initial_state(1),
        }
        # Remove batch dimensions.
        extra_spec = tf2_utils.squeeze_batch_dim(extra_spec)
        # Replay table.
        replay_table = reverb.Table(
            name=adders.DEFAULT_PRIORITY_TABLE,
            sampler=reverb.selectors.Prioritized(0.8),
            remover=reverb.selectors.Fifo(),
            max_size=max_replay_size,
            rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1),
            signature=adders.SequenceAdder.signature(environment_spec,
                                                     extra_spec))
        # Demonstration table.
        demonstration_table = reverb.Table(
            name='demonstration_table',
            sampler=reverb.selectors.Prioritized(0.8),
            remover=reverb.selectors.Fifo(),
            max_size=max_replay_size,
            rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1),
            signature=adders.SequenceAdder.signature(environment_spec,
                                                     extra_spec))

        # launch server
        self._server = reverb.Server([replay_table, demonstration_table],
                                     port=None)
        address = f'localhost:{self._server.port}'

        sequence_length = burn_in_length + trace_length + 1

        # Components to add things into the replay and demonstration tables.
        sequence_kwargs = dict(
            period=replay_period,
            sequence_length=sequence_length,
        )
        adder = adders.SequenceAdder(client=reverb.Client(address),
                                     **sequence_kwargs)
        priority_function = {demonstration_table.name: lambda x: 1.}
        demo_adder = adders.SequenceAdder(client=reverb.Client(address),
                                          priority_fns=priority_function,
                                          **sequence_kwargs)
        # Play the demonstrations and write them to the demonstration table by
        # exhausting the generator.
        # TODO: respect the max replay size.
        _prev_action = 1  # This should come from the environment spec.
        _add_first = True
        # Include this to make the datasets equivalent.
        numpy_state = tf2_utils.to_numpy_squeeze(network.initial_state(1))
        for ts, action in demonstration_generator:
            if _add_first:
                demo_adder.add_first(ts)
                _add_first = False
            else:
                demo_adder.add(_prev_action, ts, extras=(numpy_state, ))
            _prev_action = action
            # reset to new episode
            if ts.last():
                _prev_action = None
                _add_first = True

        # replay dataset
        max_in_flight_samples_per_worker = 2 * batch_size if batch_size else 100
        dataset = reverb.ReplayDataset.from_table_signature(
            server_address=address,
            table=adders.DEFAULT_PRIORITY_TABLE,
            max_in_flight_samples_per_worker=max_in_flight_samples_per_worker,
            # Memory/perf improvement attempt; see
            # https://github.com/deepmind/acme/issues/33.
            num_workers_per_iterator=2,
            sequence_length=sequence_length,
            emit_timesteps=sequence_length is None)

        # Demonstration dataset.
        d_dataset = reverb.ReplayDataset.from_table_signature(
            server_address=address,
            table=demonstration_table.name,
            max_in_flight_samples_per_worker=max_in_flight_samples_per_worker,
            num_workers_per_iterator=2,
            sequence_length=sequence_length,
            emit_timesteps=sequence_length is None)

        dataset = tf.data.experimental.sample_from_datasets(
            [dataset, d_dataset],
            [1 - demonstration_ratio, demonstration_ratio])

        # Batch and prefetch.
        dataset = dataset.batch(batch_size, drop_remainder=True)
        dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

        tf2_utils.create_variables(network, [environment_spec.observations])
        tf2_utils.create_variables(target_network,
                                   [environment_spec.observations])

        learner = learning.R2D2Learner(
            environment_spec=environment_spec,
            network=network,
            target_network=target_network,
            burn_in_length=burn_in_length,
            dataset=dataset,
            reverb_client=reverb.TFClient(address),
            counter=counter,
            logger=logger,
            sequence_length=sequence_length,
            discount=discount,
            target_update_period=target_update_period,
            importance_sampling_exponent=importance_sampling_exponent,
            max_replay_size=max_replay_size,
            learning_rate=learning_rate,
            store_lstm_state=False,
        )

        self._checkpointer = tf2_savers.Checkpointer(
            directory=model_directory,
            subdirectory='r2d2_learner_v1',
            time_delta_minutes=15,
            objects_to_save=learner.state,
            enable_checkpointing=checkpoint,
        )

        self._snapshotter = tf2_savers.Snapshotter(objects_to_save=None,
                                                   time_delta_minutes=15000.,
                                                   directory=model_directory)

        policy_network = snt.DeepRNN([
            network,
            lambda qs: trfl.epsilon_greedy(qs, epsilon=epsilon).sample(),
        ])

        actor = actors.RecurrentActor(policy_network, adder)
        observations_per_step = (float(replay_period * batch_size) /
                                 samples_per_insert)
        super().__init__(actor=actor,
                         learner=learner,
                         min_observations=replay_period *
                         max(batch_size, min_replay_size),
                         observations_per_step=observations_per_step)
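
The demonstration_generator consumed in the loop above is also project-specific; a hypothetical generator with the matching interface simply yields (timestep, action) pairs over complete episodes:

def make_demonstration_generator(episodes):
    """Hypothetical helper: episodes is an iterable of (dm_env.TimeStep, action) lists."""
    for episode in episodes:
        for timestep, action in episode:
            yield timestep, action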
Example #15
def main(_):
    wb_run = init_or_resume()

    if FLAGS.seed:
        tf.random.set_seed(FLAGS.seed)
    # Create an environment and grab the spec.
    environment, env_spec = _build_environment(
        FLAGS.environment_name, max_steps=FLAGS.max_eval_episode_len)

    # Load demonstration dataset.
    raw_dataset = load_tf_dataset(directory=FLAGS.dataset_dir)
    empirical_policy = compute_empirical_policy(raw_dataset)

    dataset = preprocess_dataset(raw_dataset, FLAGS.batch_size,
                                 FLAGS.n_step_returns, FLAGS.discount)

    # Create the main critic network
    critic_network = networks.get_default_critic(env_spec)

    policy_network = snt.Sequential([
        critic_network,
        lambda q: trfl.epsilon_greedy(q, epsilon=FLAGS.epsilon).sample(),
    ])

    tf2_utils.create_variables(critic_network, [env_spec.observations])

    # Create the actor which defines how we take actions.
    evaluation_actor = actors.FeedForwardActor(policy_network)

    counter = counting.Counter()

    disp, disp_loop = _build_custom_loggers(wb_run)

    eval_loop = EnvironmentLoop(environment=environment,
                                actor=evaluation_actor,
                                counter=counter,
                                logger=disp_loop)

    learner = CQLLearner(network=critic_network,
                         dataset=dataset,
                         discount=FLAGS.discount,
                         importance_sampling_exponent=0.2,
                         learning_rate=FLAGS.learning_rate,
                         cql_alpha=FLAGS.cql_alpha,
                         translate_lse=FLAGS.translate_lse,
                         target_update_period=100,
                         empirical_policy=empirical_policy,
                         logger=disp,
                         counter=counter)

    # Run the environment loop.
    for e in tqdm(range(FLAGS.epochs)):
        for _ in range(FLAGS.evaluate_every):
            learner.step()
        eval_loop.run(FLAGS.evaluation_episodes)
        # Visualization of the policy
        Q = evaluate_q(learner._network, environment)
        plot = visualize_policy(Q, environment)
        wb_run.log({'chart': plot, 'epoch_counter': e})

    learner.save(tag=FLAGS.logs_tag)
Example #16
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        network: snt.Module,
        batch_size: int = 256,
        prefetch_size: int = 4,
        target_update_period: int = 100,
        samples_per_insert: float = 32.0,
        min_replay_size: int = 1000,
        max_replay_size: int = 1000000,
        importance_sampling_exponent: float = 0.2,
        priority_exponent: float = 0.6,
        n_step: int = 5,
        epsilon: tf.Tensor = None,
        learning_rate: float = 1e-3,
        discount: float = 0.99,
        cql_alpha: float = 1.,
        logger: loggers.Logger = None,
        counter: counting.Counter = None,
        checkpoint_subpath: str = '~/acme/',
    ):
        """Initialize the agent.

    Args:
      environment_spec: description of the actions, observations, etc.
      network: the online Q network (the one being optimized)
      batch_size: batch size for updates.
      prefetch_size: size to prefetch from replay.
      target_update_period: number of learner steps to perform before updating
        the target networks.
      samples_per_insert: number of samples to take from replay for every insert
        that is made.
      min_replay_size: minimum replay size before updating. This and all
        following arguments are related to dataset construction and will be
        ignored if a dataset argument is passed.
      max_replay_size: maximum replay size.
      importance_sampling_exponent: power to which importance weights are raised
        before normalizing.
      priority_exponent: exponent used in prioritized sampling.
      n_step: number of steps to squash into a single transition.
      epsilon: probability of taking a random action; ignored if a policy
        network is given.
      learning_rate: learning rate for the q-network update.
      discount: discount to use for TD updates.
      logger: logger object to be used by learner.
      counter: counter object used to keep track of learner steps.
      checkpoint_subpath: directory for the checkpoint.
    """

        # Create a replay server to add data to. This uses no limiter behavior in
        # order to allow the Agent interface to handle it.
        replay_table = reverb.Table(
            name=adders.DEFAULT_PRIORITY_TABLE,
            sampler=reverb.selectors.Prioritized(priority_exponent),
            remover=reverb.selectors.Fifo(),
            max_size=max_replay_size,
            rate_limiter=reverb.rate_limiters.MinSize(1),
            signature=adders.NStepTransitionAdder.signature(environment_spec))
        self._server = reverb.Server([replay_table], port=None)

        # The adder is used to insert observations into replay.
        address = f'localhost:{self._server.port}'
        adder = adders.NStepTransitionAdder(client=reverb.Client(address),
                                            n_step=n_step,
                                            discount=discount)

        # The dataset provides an interface to sample from replay.
        replay_client = reverb.TFClient(address)
        dataset = datasets.make_reverb_dataset(
            client=replay_client,
            environment_spec=environment_spec,
            batch_size=batch_size,
            prefetch_size=prefetch_size,
            transition_adder=True)

        # Use constant 0.05 epsilon greedy policy by default.
        if epsilon is None:
            epsilon = tf.Variable(0.05, trainable=False)
        policy_network = snt.Sequential([
            network,
            lambda q: trfl.epsilon_greedy(q, epsilon=epsilon).sample(),
        ])

        # Create a target network.
        target_network = copy.deepcopy(network)

        # Ensure that we create the variables before proceeding (maybe not needed).
        tf2_utils.create_variables(network, [environment_spec.observations])
        tf2_utils.create_variables(target_network,
                                   [environment_spec.observations])

        # Create the actor which defines how we take actions.
        actor = actors.FeedForwardActor(policy_network, adder)

        # The learner updates the parameters (and initializes them).
        learner = CQLLearner(
            network=network,
            discount=discount,
            importance_sampling_exponent=importance_sampling_exponent,
            learning_rate=learning_rate,
            cql_alpha=cql_alpha,
            target_update_period=target_update_period,
            dataset=dataset,
            replay_client=replay_client,
            logger=logger,
            counter=counter,
            checkpoint_subpath=checkpoint_subpath)

        super().__init__(actor=actor,
                         learner=learner,
                         min_observations=max(batch_size, min_replay_size),
                         observations_per_step=float(batch_size) /
                         samples_per_insert)
Example #17
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        network: snt.RNNCore,
        burn_in_length: int,
        trace_length: int,
        replay_period: int,
        counter: counting.Counter = None,
        logger: loggers.Logger = None,
        discount: float = 0.99,
        batch_size: int = 32,
        prefetch_size: int = tf.data.experimental.AUTOTUNE,
        target_update_period: int = 100,
        importance_sampling_exponent: float = 0.2,
        priority_exponent: float = 0.6,
        epsilon_init: float = 1.0,
        epsilon_final: float = 0.01,
        epsilon_schedule_timesteps: float = 20000,
        learning_rate: float = 1e-3,
        min_replay_size: int = 1000,
        max_replay_size: int = 1000000,
        samples_per_insert: float = 32.0,
        store_lstm_state: bool = True,
        max_priority_weight: float = 0.9,
        checkpoint: bool = True,
    ):

        if store_lstm_state:
            extra_spec = {
                'core_state':
                tf2_utils.squeeze_batch_dim(network.initial_state(1)),
            }
        else:
            extra_spec = ()

        replay_table = reverb.Table(
            name=adders.DEFAULT_PRIORITY_TABLE,
            sampler=reverb.selectors.Prioritized(priority_exponent),
            remover=reverb.selectors.Fifo(),
            max_size=max_replay_size,
            rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1),
            signature=adders.SequenceAdder.signature(environment_spec,
                                                     extra_spec))
        self._server = reverb.Server([replay_table], port=None)
        address = f'localhost:{self._server.port}'

        sequence_length = burn_in_length + trace_length + 1
        # Component to add things into replay.
        self._adder = adders.SequenceAdder(
            client=reverb.Client(address),
            period=replay_period,
            sequence_length=sequence_length,
        )

        # The dataset object to learn from.
        dataset = make_reverb_dataset(server_address=address,
                                      batch_size=batch_size,
                                      prefetch_size=prefetch_size,
                                      sequence_length=sequence_length)

        target_network = copy.deepcopy(network)
        tf2_utils.create_variables(network, [environment_spec.observations])
        tf2_utils.create_variables(target_network,
                                   [environment_spec.observations])

        learner = learning.R2D2Learner(
            environment_spec=environment_spec,
            network=network,
            target_network=target_network,
            burn_in_length=burn_in_length,
            sequence_length=sequence_length,
            dataset=dataset,
            reverb_client=reverb.TFClient(address),
            counter=counter,
            logger=logger,
            discount=discount,
            target_update_period=target_update_period,
            importance_sampling_exponent=importance_sampling_exponent,
            max_replay_size=max_replay_size,
            learning_rate=learning_rate,
            store_lstm_state=store_lstm_state,
            max_priority_weight=max_priority_weight,
        )

        self._saver = tf2_savers.Saver(learner.state)

        policy_network = snt.DeepRNN([
            network,
            EpsilonGreedyExploration(
                epsilon_init=epsilon_init,
                epsilon_final=epsilon_final,
                epsilon_schedule_timesteps=epsilon_schedule_timesteps)
        ])
        actor = actors.RecurrentActor(policy_network,
                                      self._adder,
                                      store_recurrent_state=store_lstm_state)

        max_Q_network = snt.DeepRNN([
            network,
            lambda qs: trfl.epsilon_greedy(qs, epsilon=0.0).sample(),
        ])
        self._deterministic_actor = actors.RecurrentActor(
            max_Q_network, self._adder, store_recurrent_state=store_lstm_state)

        observations_per_step = (float(replay_period * batch_size) /
                                 samples_per_insert)

        super().__init__(actor=actor,
                         learner=learner,
                         min_observations=replay_period *
                         max(batch_size, min_replay_size),
                         observations_per_step=observations_per_step)
Example #18
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        network: snt.RNNCore,
        burn_in_length: int,
        trace_length: int,
        replay_period: int,
        counter: counting.Counter = None,
        logger: loggers.Logger = None,
        discount: float = 0.99,
        batch_size: int = 32,
        prefetch_size: int = tf.data.experimental.AUTOTUNE,
        target_update_period: int = 100,
        importance_sampling_exponent: float = 0.2,
        priority_exponent: float = 0.6,
        epsilon: float = 0.01,
        learning_rate: float = 1e-3,
        min_replay_size: int = 1000,
        max_replay_size: int = 1000000,
        samples_per_insert: float = 32.0,
        store_lstm_state: bool = True,
        max_priority_weight: float = 0.9,
        checkpoint: bool = True,
    ):

        replay_table = reverb.Table(
            name=adders.DEFAULT_PRIORITY_TABLE,
            sampler=reverb.selectors.Prioritized(priority_exponent),
            remover=reverb.selectors.Fifo(),
            max_size=max_replay_size,
            rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1))
        self._server = reverb.Server([replay_table], port=None)
        address = f'localhost:{self._server.port}'

        sequence_length = burn_in_length + trace_length + 1
        # Component to add things into replay.
        adder = adders.SequenceAdder(
            client=reverb.Client(address),
            period=replay_period,
            sequence_length=sequence_length,
        )

        # The dataset object to learn from.
        reverb_client = reverb.TFClient(address)
        extra_spec = {
            'core_state': network.initial_state(1),
        }
        # Remove batch dimensions.
        extra_spec = tf2_utils.squeeze_batch_dim(extra_spec)
        dataset = datasets.make_reverb_dataset(
            client=reverb_client,
            environment_spec=environment_spec,
            batch_size=batch_size,
            prefetch_size=prefetch_size,
            extra_spec=extra_spec,
            sequence_length=sequence_length)

        target_network = copy.deepcopy(network)
        tf2_utils.create_variables(network, [environment_spec.observations])
        tf2_utils.create_variables(target_network,
                                   [environment_spec.observations])

        learner = learning.R2D2Learner(
            environment_spec=environment_spec,
            network=network,
            target_network=target_network,
            burn_in_length=burn_in_length,
            sequence_length=sequence_length,
            dataset=dataset,
            reverb_client=reverb_client,
            counter=counter,
            logger=logger,
            discount=discount,
            target_update_period=target_update_period,
            importance_sampling_exponent=importance_sampling_exponent,
            max_replay_size=max_replay_size,
            learning_rate=learning_rate,
            store_lstm_state=store_lstm_state,
            max_priority_weight=max_priority_weight,
        )

        self._checkpointer = tf2_savers.Checkpointer(
            subdirectory='r2d2_learner',
            time_delta_minutes=60,
            objects_to_save=learner.state,
            enable_checkpointing=checkpoint,
        )
        self._snapshotter = tf2_savers.Snapshotter(
            objects_to_save={'network': network}, time_delta_minutes=60.)

        policy_network = snt.DeepRNN([
            network,
            lambda qs: trfl.epsilon_greedy(qs, epsilon=epsilon).sample(),
        ])

        actor = actors.RecurrentActor(policy_network, adder)
        observations_per_step = (float(replay_period * batch_size) /
                                 samples_per_insert)
        super().__init__(actor=actor,
                         learner=learner,
                         min_observations=replay_period *
                         max(batch_size, min_replay_size),
                         observations_per_step=observations_per_step)
Example #19
    def __init__(self,
                 environment_spec: specs.EnvironmentSpec,
                 network: snt.RNNCore,
                 target_network: snt.RNNCore,
                 burn_in_length: int,
                 trace_length: int,
                 replay_period: int,
                 demonstration_dataset: tf.data.Dataset,
                 demonstration_ratio: float,
                 counter: counting.Counter = None,
                 logger: loggers.Logger = None,
                 discount: float = 0.99,
                 batch_size: int = 32,
                 target_update_period: int = 100,
                 importance_sampling_exponent: float = 0.2,
                 epsilon: float = 0.01,
                 learning_rate: float = 1e-3,
                 log_to_bigtable: bool = False,
                 log_name: str = 'agent',
                 checkpoint: bool = True,
                 min_replay_size: int = 1000,
                 max_replay_size: int = 1000000,
                 samples_per_insert: float = 32.0):

        extra_spec = {
            'core_state': network.initial_state(1),
        }
        # Remove batch dimensions.
        extra_spec = tf2_utils.squeeze_batch_dim(extra_spec)
        replay_table = reverb.Table(
            name=adders.DEFAULT_PRIORITY_TABLE,
            sampler=reverb.selectors.Uniform(),
            remover=reverb.selectors.Fifo(),
            max_size=max_replay_size,
            rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1),
            signature=adders.SequenceAdder.signature(environment_spec,
                                                     extra_spec))
        self._server = reverb.Server([replay_table], port=None)
        address = f'localhost:{self._server.port}'

        sequence_length = burn_in_length + trace_length + 1
        # Component to add things into replay.
        sequence_kwargs = dict(
            period=replay_period,
            sequence_length=sequence_length,
        )
        adder = adders.SequenceAdder(client=reverb.Client(address),
                                     **sequence_kwargs)

        # The dataset object to learn from.
        dataset = datasets.make_reverb_dataset(server_address=address,
                                               sequence_length=sequence_length)

        # Combine with demonstration dataset.
        transition = functools.partial(_sequence_from_episode,
                                       extra_spec=extra_spec,
                                       **sequence_kwargs)
        dataset_demos = demonstration_dataset.map(transition)
        dataset = tf.data.experimental.sample_from_datasets(
            [dataset, dataset_demos],
            [1 - demonstration_ratio, demonstration_ratio])

        # Batch and prefetch.
        dataset = dataset.batch(batch_size, drop_remainder=True)
        dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

        tf2_utils.create_variables(network, [environment_spec.observations])
        tf2_utils.create_variables(target_network,
                                   [environment_spec.observations])

        learner = learning.R2D2Learner(
            environment_spec=environment_spec,
            network=network,
            target_network=target_network,
            burn_in_length=burn_in_length,
            dataset=dataset,
            reverb_client=reverb.TFClient(address),
            counter=counter,
            logger=logger,
            sequence_length=sequence_length,
            discount=discount,
            target_update_period=target_update_period,
            importance_sampling_exponent=importance_sampling_exponent,
            max_replay_size=max_replay_size,
            learning_rate=learning_rate,
            store_lstm_state=False,
        )

        self._checkpointer = tf2_savers.Checkpointer(
            subdirectory='r2d2_learner',
            time_delta_minutes=60,
            objects_to_save=learner.state,
            enable_checkpointing=checkpoint,
        )

        self._snapshotter = tf2_savers.Snapshotter(
            objects_to_save={'network': network}, time_delta_minutes=60.)

        policy_network = snt.DeepRNN([
            network,
            lambda qs: trfl.epsilon_greedy(qs, epsilon=epsilon).sample(),
        ])

        actor = actors.RecurrentActor(policy_network, adder)
        observations_per_step = (float(replay_period * batch_size) /
                                 samples_per_insert)
        super().__init__(actor=actor,
                         learner=learner,
                         min_observations=replay_period *
                         max(batch_size, min_replay_size),
                         observations_per_step=observations_per_step)
Example #20
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        network: snt.Module,
        params=None,
        logger: loggers.Logger = None,
        checkpoint: bool = True,
        paths: Save_paths = None,
    ):
        """Initialize the agent.

        Args:
          environment_spec: description of the actions, observations, etc.
          network: the online Q network (the one being optimized)
          params: optional dict of hyperparameters; if None, the defaults below
            are used. The remaining entries describe its keys.
          batch_size: batch size for updates.
          prefetch_size: size to prefetch from replay.
          target_update_period: number of learner steps to perform before updating
            the target networks.
          samples_per_insert: number of samples to take from replay for every insert
            that is made.
          min_replay_size: minimum replay size before updating. This and all
            following arguments are related to dataset construction and will be
            ignored if a dataset argument is passed.
          max_replay_size: maximum replay size.
          importance_sampling_exponent: power to which importance weights are raised
            before normalizing.
          priority_exponent: exponent used in prioritized sampling.
          n_step: number of steps to squash into a single transition.
          epsilon: probability of taking a random action; ignored if a policy
            network is given.
          learning_rate: learning rate for the q-network update.
          discount: discount to use for TD updates.
          logger: logger object to be used by learner.
          checkpoint: boolean indicating whether to checkpoint the learner.
        """

        # Create a replay server to add data to. This uses no limiter behavior in
        # order to allow the Agent interface to handle it.
        if params is None:
            params = {
                'batch_size': 256,
                'prefetch_size': 4,
                'target_update_period': 100,
                'samples_per_insert': 32.0,
                'min_replay_size': 1000,
                'max_replay_size': 1000000,
                'importance_sampling_exponent': 0.2,
                'priority_exponent': 0.6,
                'n_step': 5,
                'epsilon': 0.05,
                'learning_rate': 1e-3,
                'discount': 0.99,
            }
        replay_table = reverb.Table(
            name=adders.DEFAULT_PRIORITY_TABLE,
            sampler=reverb.selectors.Prioritized(params['priority_exponent']),
            remover=reverb.selectors.Fifo(),
            max_size=params['max_replay_size'],
            rate_limiter=reverb.rate_limiters.MinSize(1))
        self._server = reverb.Server([replay_table], port=None)

        # The adder is used to insert observations into replay.
        address = f'localhost:{self._server.port}'
        adder = adders.NStepTransitionAdder(client=reverb.Client(address),
                                            n_step=params['n_step'],
                                            discount=params['discount'])

        # The dataset provides an interface to sample from replay.
        replay_client = reverb.TFClient(address)
        dataset = datasets.make_reverb_dataset(
            client=replay_client,
            environment_spec=environment_spec,
            batch_size=params['batch_size'],
            prefetch_size=params['prefetch_size'],
            transition_adder=True)

        # Use constant 0.05 epsilon greedy policy by default.
        epsilon = tf.Variable(params['epsilon'], trainable=False)

        policy_network = snt.Sequential([
            network,
            lambda q: trfl.epsilon_greedy(q, epsilon=epsilon).sample(),
        ])

        # Create a target network.
        target_network = copy.deepcopy(network)

        # Ensure the networks' variables are created before they are used by the
        # actor, learner and checkpointer below.
        tf2_utils.create_variables(network, [environment_spec.observations])
        tf2_utils.create_variables(target_network,
                                   [environment_spec.observations])

        # Create the actor which defines how we take actions.
        actor = actors.FeedForwardActor(policy_network, adder)

        # The learner updates the parameters (and initializes them).
        learner = learning.DQNLearner(
            network=network,
            target_network=target_network,
            discount=params['discount'],
            importance_sampling_exponent=params[
                'importance_sampling_exponent'],
            learning_rate=params['learning_rate'],
            target_update_period=params['target_update_period'],
            dataset=dataset,
            replay_client=replay_client,
            logger=logger,
            checkpoint=checkpoint)

        if checkpoint:
            self._checkpointer = tf2_savers.Checkpointer(
                add_uid=False,
                objects_to_save=learner.state,
                directory=paths.data_dir,
                subdirectory=paths.experiment_name,
                time_delta_minutes=60.)
        else:
            self._checkpointer = None

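        # The Agent base class interleaves acting and learning: it waits until
        # min_observations transitions have been collected before the first
        # learner step, and thereafter runs roughly one learner step every
        # observations_per_step (= batch_size / samples_per_insert) env steps.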
        super().__init__(actor=actor,
                         learner=learner,
                         min_observations=max(params['batch_size'],
                                              params['min_replay_size']),
                         observations_per_step=float(params['batch_size']) /
                         params['samples_per_insert'])
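
A minimal usage sketch for the agent above, assuming the constructor belongs to a class accepting (environment_spec, network, params=None, ...) as the docstring suggests; the class name DQNParamsAgent, its module, and the CartPole environment are placeholders, not part of the snippet:

import acme
import gym
import sonnet as snt
from acme import specs, wrappers

from my_agents import DQNParamsAgent  # hypothetical module and class name

environment = wrappers.SinglePrecisionWrapper(
    wrappers.GymWrapper(gym.make('CartPole-v1')))
environment_spec = specs.make_environment_spec(environment)

network = snt.Sequential([
    snt.Flatten(),
    snt.nets.MLP([64, 64, environment_spec.actions.num_values]),
])

# With params=None the defaults above (batch_size=256, epsilon=0.05, ...) apply.
agent = DQNParamsAgent(environment_spec=environment_spec, network=network)

acme.EnvironmentLoop(environment, agent).run(num_episodes=100)
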
Example No. 21
def R2D2AtariActorNetwork(num_actions: int, epsilon: tf.Variable):
    network = networks.R2D2AtariNetwork(num_actions)
    return snt.DeepRNN([
        network,
        lambda q: trfl.epsilon_greedy(q, epsilon=epsilon).sample(),
    ])
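
Because R2D2AtariActorNetwork returns a recurrent snt.DeepRNN, it would typically be paired with a recurrent actor rather than a feed-forward one. A short sketch under that assumption (the RecurrentActor wiring and the epsilon values below are illustrative, not part of the snippet):

import tensorflow as tf
from acme.agents.tf import actors

epsilon = tf.Variable(0.01, trainable=False)
policy_network = R2D2AtariActorNetwork(num_actions=18, epsilon=epsilon)

# RecurrentActor carries the DeepRNN state between environment steps.
actor = actors.RecurrentActor(policy_network)

# Since epsilon is a tf.Variable, it can be annealed without rebuilding the policy.
epsilon.assign(0.001)
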
Example No. 22
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        network: snt.Module,
        demonstration_dataset: tf.data.Dataset,
        demonstration_ratio: float,
        batch_size: int = 256,
        prefetch_size: int = 4,
        target_update_period: int = 100,
        samples_per_insert: float = 32.0,
        min_replay_size: int = 1000,
        max_replay_size: int = 1000000,
        importance_sampling_exponent: float = 0.2,
        n_step: int = 5,
        epsilon: tf.Tensor = None,
        learning_rate: float = 1e-3,
        discount: float = 0.99,
    ):
        """Initialize the agent.

        Args:
          environment_spec: description of the actions, observations, etc.
          network: the online Q network (the one being optimized)
          demonstration_dataset: tf.data.Dataset producing (timestep, action)
            tuples containing full episodes.
          demonstration_ratio: ratio of transitions coming from demonstrations.
          batch_size: batch size for updates.
          prefetch_size: size to prefetch from replay.
          target_update_period: number of learner steps to perform before updating
            the target networks.
          samples_per_insert: number of samples to take from replay for every insert
            that is made.
          min_replay_size: minimum replay size before updating.
          max_replay_size: maximum replay size.
          importance_sampling_exponent: power to which importance weights are raised
            before normalizing.
          n_step: number of steps to squash into a single transition.
          epsilon: probability of taking a random action; defaults to a constant
            0.05 if not given.
          learning_rate: learning rate for the q-network update.
          discount: discount to use for TD updates.
        """

        # Create a replay server to add data to. This uses no limiter behavior in
        # order to allow the Agent interface to handle it.
        replay_table = reverb.Table(
            name=adders.DEFAULT_PRIORITY_TABLE,
            sampler=reverb.selectors.Uniform(),
            remover=reverb.selectors.Fifo(),
            max_size=max_replay_size,
            rate_limiter=reverb.rate_limiters.MinSize(1))
        self._server = reverb.Server([replay_table], port=None)

        # The adder is used to insert observations into replay.
        address = f'localhost:{self._server.port}'
        adder = adders.NStepTransitionAdder(client=reverb.Client(address),
                                            n_step=n_step,
                                            discount=discount)

        # The dataset provides an interface to sample from replay.
        replay_client = reverb.TFClient(address)
        dataset = datasets.make_reverb_dataset(
            client=replay_client,
            environment_spec=environment_spec,
            transition_adder=True)

        # Combine with demonstration dataset.
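        # The demonstration episodes are converted into n-step transitions (via
        # the _n_step_transition_from_episode helper, not shown in this snippet)
        # so they match the structure of replay samples, and the two streams are
        # then mixed according to demonstration_ratio.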
        transition = functools.partial(_n_step_transition_from_episode,
                                       n_step=n_step,
                                       discount=discount)
        dataset_demos = demonstration_dataset.map(transition)
        dataset = tf.data.experimental.sample_from_datasets(
            [dataset, dataset_demos],
            [1 - demonstration_ratio, demonstration_ratio])

        # Batch and prefetch.
        dataset = dataset.batch(batch_size, drop_remainder=True)
        dataset = dataset.prefetch(prefetch_size)

        # Use constant 0.05 epsilon greedy policy by default.
        if epsilon is None:
            epsilon = tf.Variable(0.05, trainable=False)
        policy_network = snt.Sequential([
            network,
            lambda q: trfl.epsilon_greedy(q, epsilon=epsilon).sample(),
        ])

        # Create a target network.
        target_network = copy.deepcopy(network)

        # Ensure the networks' variables are created before they are used below.
        tf2_utils.create_variables(network, [environment_spec.observations])
        tf2_utils.create_variables(target_network,
                                   [environment_spec.observations])

        # Create the actor which defines how we take actions.
        actor = actors.FeedForwardActor(policy_network, adder)

        # The learner updates the parameters (and initializes them).
        learner = dqn.DQNLearner(
            network=network,
            target_network=target_network,
            discount=discount,
            importance_sampling_exponent=importance_sampling_exponent,
            learning_rate=learning_rate,
            target_update_period=target_update_period,
            dataset=dataset,
            replay_client=replay_client)

        super().__init__(actor=actor,
                         learner=learner,
                         min_observations=max(batch_size, min_replay_size),
                         observations_per_step=float(batch_size) /
                         samples_per_insert)
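
A usage sketch for the demonstration-based agent above; the class name DemonstrationDQN and the load_demonstration_episodes helper are placeholders (only the constructor body is shown), and the dataset must yield (timestep, action) tuples over full episodes, as the docstring requires:

import acme
import gym
import sonnet as snt
from acme import specs, wrappers

from my_agents import DemonstrationDQN            # hypothetical class name
from my_data import load_demonstration_episodes   # hypothetical helper

environment = wrappers.SinglePrecisionWrapper(
    wrappers.GymWrapper(gym.make('CartPole-v1')))
environment_spec = specs.make_environment_spec(environment)

network = snt.Sequential([
    snt.Flatten(),
    snt.nets.MLP([64, 64, environment_spec.actions.num_values]),
])

agent = DemonstrationDQN(
    environment_spec=environment_spec,
    network=network,
    demonstration_dataset=load_demonstration_episodes(),
    demonstration_ratio=0.5,  # on average, half the sampled transitions come from demonstrations
)

acme.EnvironmentLoop(environment, agent).run(num_episodes=100)
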
Example No. 23
def DQNAtariActorNetwork(num_actions: int, epsilon: tf.Variable):
    network = networks.DQNAtariNetwork(num_actions)
    return snt.Sequential([
        network,
        lambda q: trfl.epsilon_greedy(q, epsilon=epsilon).sample(),
    ])
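
A usage sketch wrapping the Atari Q-network policy in a feed-forward actor, mirroring how FeedForwardActor is used elsewhere in this listing; the epsilon value and action count are illustrative:

import tensorflow as tf
from acme.agents.tf import actors

epsilon = tf.Variable(0.05, trainable=False)
policy_network = DQNAtariActorNetwork(num_actions=18, epsilon=epsilon)
actor = actors.FeedForwardActor(policy_network)

# Because epsilon is a tf.Variable, it can be decayed during training, e.g.
# epsilon.assign(0.01), without rebuilding the policy network.
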