Example #1
    def __init__(
        self,
        obs_spec: dm_env.specs.Array,
        network: snt.RNNCore,
        optimizer: snt.Optimizer,
        sequence_length: int,
        td_lambda: float,
        discount: float,
        seed: int,
    ):
        """A recurrent actor-critic agent."""

        # Internalise network and optimizer.
        self._forward = tf.function(network)
        self._network = network
        self._optimizer = optimizer

        # Initialise recurrent state.
        self._state: snt.LSTMState = network.initial_state(1)
        self._rollout_initial_state: snt.LSTMState = network.initial_state(1)

        # Set seed and internalise hyperparameters.
        tf.random.set_seed(seed)
        self._sequence_length = sequence_length
        self._num_transitions_in_buffer = 0
        self._discount = discount
        self._td_lambda = td_lambda

        # Initialise rolling experience buffer.
        shapes = [obs_spec.shape, (), (), (), ()]
        dtypes = [obs_spec.dtype, np.int32, np.float32, np.float32, np.float32]
        self._buffer = [
            np.zeros(shape=(self._sequence_length, 1) + shape, dtype=dtype)
            for shape, dtype in zip(shapes, dtypes)
        ]
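
For context, here is a minimal sketch of how such a rolling sequence buffer is typically filled. The helper below (`add_transition`) and its shapes are assumptions for illustration only, not part of the agent above; the real agent's observe/update methods may differ.

import numpy as np

# Hypothetical rolling-buffer fill, mirroring the buffer layout above:
# one array per field, each of shape (sequence_length, batch=1, *field_shape).
sequence_length = 4
obs_shape, obs_dtype = (3,), np.float32
shapes = [obs_shape, (), (), (), ()]
dtypes = [obs_dtype, np.int32, np.float32, np.float32, np.float32]
buffer = [np.zeros(shape=(sequence_length, 1) + shape, dtype=dtype)
          for shape, dtype in zip(shapes, dtypes)]
num_transitions = 0

def add_transition(observation, action, reward, discount, mask):
    """Writes one transition into the next free slot of each buffer array."""
    global num_transitions
    values = (observation, action, reward, discount, mask)
    for array, value in zip(buffer, values):
        array[num_transitions, 0] = value
    num_transitions += 1
    if num_transitions == sequence_length:
        num_transitions = 0  # A full sequence is ready to hand to the learner.

add_transition(np.ones(obs_shape, obs_dtype), 1, 0.5, 0.99, 1.0)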
Example #2
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        network: snt.RNNCore,
        queue: adder.Adder,
        counter: counting.Counter = None,
        logger: loggers.Logger = None,
        discount: float = 0.99,
        n_step_horizon: int = 16,
        learning_rate: float = 1e-3,
        entropy_cost: float = 0.01,
        baseline_cost: float = 0.5,
        max_abs_reward: Optional[float] = None,
        max_gradient_norm: Optional[float] = None,
        verbose_level: Optional[int] = 0,
    ):
        num_actions = environment_spec.actions.num_values
        self._logger = logger or loggers.TerminalLogger('agent')

        extra_spec = {
            'core_state': network.initial_state(1),
            'logits': tf.ones(shape=(1, num_actions), dtype=tf.float32)
        }
        # Remove batch dimensions.
        extra_spec = tf2_utils.squeeze_batch_dim(extra_spec)
        tf2_utils.create_variables(network, [environment_spec.observations])

        actor = acting.A2CActor(environment_spec=environment_spec,
                                verbose_level=verbose_level,
                                network=network,
                                queue=queue)
        learner = learning.A2CLearner(
            environment_spec=environment_spec,
            network=network,
            dataset=queue,
            counter=counter,
            logger=logger,
            discount=discount,
            learning_rate=learning_rate,
            entropy_cost=entropy_cost,
            baseline_cost=baseline_cost,
            max_gradient_norm=max_gradient_norm,
            max_abs_reward=max_abs_reward,
        )

        super().__init__(actor=actor,
                         learner=learner,
                         min_observations=0,
                         observations_per_step=n_step_horizon)
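
The pair min_observations=0 and observations_per_step=n_step_horizon asks the base agent to run one learner step for every n_step_horizon actor observations. A rough, self-contained sketch of that ratio-based gating follows; it is an assumption about how the base class interprets these values, not Acme's actual implementation.

def learner_steps_to_run(num_observations, min_observations, observations_per_step):
    """Hypothetical gating rule: how many learner steps to take after an observation."""
    n = num_observations - min_observations
    if n < 0:
        return 0
    if observations_per_step >= 1:
        # One learner step every `observations_per_step` observations.
        return int(n % int(observations_per_step) == 0)
    # Several learner steps per observation when the ratio is below one.
    return int(1 / observations_per_step)

assert learner_steps_to_run(16, min_observations=0, observations_per_step=16) == 1
assert learner_steps_to_run(17, min_observations=0, observations_per_step=16) == 0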
Example #3
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        network: snt.RNNCore,
        target_network: snt.RNNCore,
        burn_in_length: int,
        trace_length: int,
        replay_period: int,
        demonstration_generator: iter,
        demonstration_ratio: float,
        model_directory: str,
        counter: counting.Counter = None,
        logger: loggers.Logger = None,
        discount: float = 0.99,
        batch_size: int = 32,
        target_update_period: int = 100,
        importance_sampling_exponent: float = 0.2,
        epsilon: float = 0.01,
        learning_rate: float = 1e-3,
        log_to_bigtable: bool = False,
        log_name: str = 'agent',
        checkpoint: bool = True,
        min_replay_size: int = 1000,
        max_replay_size: int = 1000000,
        samples_per_insert: float = 32.0,
    ):

        extra_spec = {
            'core_state': network.initial_state(1),
        }
        # Remove batch dimensions.
        extra_spec = tf2_utils.squeeze_batch_dim(extra_spec)

        # Replay table.
        replay_table = reverb.Table(
            name=adders.DEFAULT_PRIORITY_TABLE,
            sampler=reverb.selectors.Prioritized(0.8),
            remover=reverb.selectors.Fifo(),
            max_size=max_replay_size,
            rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1),
            signature=adders.SequenceAdder.signature(environment_spec,
                                                     extra_spec))
        # Demonstration table.
        demonstration_table = reverb.Table(
            name='demonstration_table',
            sampler=reverb.selectors.Prioritized(0.8),
            remover=reverb.selectors.Fifo(),
            max_size=max_replay_size,
            rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1),
            signature=adders.SequenceAdder.signature(environment_spec,
                                                     extra_spec))

        # Launch the replay server.
        self._server = reverb.Server([replay_table, demonstration_table],
                                     port=None)
        address = f'localhost:{self._server.port}'

        sequence_length = burn_in_length + trace_length + 1

        # Components to add experience into the replay and demonstration tables.
        sequence_kwargs = dict(
            period=replay_period,
            sequence_length=sequence_length,
        )
        adder = adders.SequenceAdder(client=reverb.Client(address),
                                     **sequence_kwargs)
        priority_function = {demonstration_table.name: lambda x: 1.}
        demo_adder = adders.SequenceAdder(client=reverb.Client(address),
                                          priority_fns=priority_function,
                                          **sequence_kwargs)
        # Play back the demonstrations and write them into the demonstration
        # table, exhausting the generator.
        # TODO: MAX REPLAY SIZE
        _prev_action = 1  # This should come from the action spec.
        _add_first = True
        # Include this to make the datasets equivalent.
        numpy_state = tf2_utils.to_numpy_squeeze(network.initial_state(1))
        for ts, action in demonstration_generator:
            if _add_first:
                demo_adder.add_first(ts)
                _add_first = False
            else:
                demo_adder.add(_prev_action, ts, extras=(numpy_state, ))
            _prev_action = action
            # reset to new episode
            if ts.last():
                _prev_action = None
                _add_first = True

        # Replay dataset.
        max_in_flight_samples_per_worker = 2 * batch_size if batch_size else 100
        dataset = reverb.ReplayDataset.from_table_signature(
            server_address=address,
            table=adders.DEFAULT_PRIORITY_TABLE,
            max_in_flight_samples_per_worker=max_in_flight_samples_per_worker,
            # Memory/performance improvement attempt; see
            # https://github.com/deepmind/acme/issues/33.
            num_workers_per_iterator=2,
            sequence_length=sequence_length,
            emit_timesteps=sequence_length is None)

        # Demonstration dataset.
        d_dataset = reverb.ReplayDataset.from_table_signature(
            server_address=address,
            table=demonstration_table.name,
            max_in_flight_samples_per_worker=max_in_flight_samples_per_worker,
            num_workers_per_iterator=2,
            sequence_length=sequence_length,
            emit_timesteps=sequence_length is None)

        dataset = tf.data.experimental.sample_from_datasets(
            [dataset, d_dataset],
            [1 - demonstration_ratio, demonstration_ratio])

        # Batch and prefetch.
        dataset = dataset.batch(batch_size, drop_remainder=True)
        dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

        tf2_utils.create_variables(network, [environment_spec.observations])
        tf2_utils.create_variables(target_network,
                                   [environment_spec.observations])

        learner = learning.R2D2Learner(
            environment_spec=environment_spec,
            network=network,
            target_network=target_network,
            burn_in_length=burn_in_length,
            dataset=dataset,
            reverb_client=reverb.TFClient(address),
            counter=counter,
            logger=logger,
            sequence_length=sequence_length,
            discount=discount,
            target_update_period=target_update_period,
            importance_sampling_exponent=importance_sampling_exponent,
            max_replay_size=max_replay_size,
            learning_rate=learning_rate,
            store_lstm_state=False,
        )

        self._checkpointer = tf2_savers.Checkpointer(
            directory=model_directory,
            subdirectory='r2d2_learner_v1',
            time_delta_minutes=15,
            objects_to_save=learner.state,
            enable_checkpointing=checkpoint,
        )

        self._snapshotter = tf2_savers.Snapshotter(objects_to_save=None,
                                                   time_delta_minutes=15000.,
                                                   directory=model_directory)

        policy_network = snt.DeepRNN([
            network,
            lambda qs: trfl.epsilon_greedy(qs, epsilon=epsilon).sample(),
        ])

        actor = actors.RecurrentActor(policy_network, adder)
        observations_per_step = (float(replay_period * batch_size) /
                                 samples_per_insert)
        super().__init__(actor=actor,
                         learner=learner,
                         min_observations=replay_period *
                         max(batch_size, min_replay_size),
                         observations_per_step=observations_per_step)
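
The demonstration_ratio mixing above relies on weighted sampling across two datasets. The following self-contained snippet illustrates that mechanism with toy data; the values and the 0.25 ratio are made up for illustration.

import tensorflow as tf

replay = tf.data.Dataset.from_tensor_slices(tf.zeros([100]))  # stand-in for replay data
demos = tf.data.Dataset.from_tensor_slices(tf.ones([100]))    # stand-in for demonstrations
demonstration_ratio = 0.25

mixed = tf.data.experimental.sample_from_datasets(
    [replay, demos], [1 - demonstration_ratio, demonstration_ratio])

# Roughly a quarter of the sampled elements come from the demonstration dataset.
batch = next(iter(mixed.batch(32)))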
Example #4
  def __init__(
      self,
      obs_spec: specs.Array,
      action_spec: specs.DiscreteArray,
      network: snt.RNNCore,
      optimizer: tf.train.Optimizer,
      sequence_length: int,
      td_lambda: float,
      agent_discount: float,
      seed: int,
  ):
    """A recurrent actor-critic agent."""
    del action_spec  # unused
    tf.set_random_seed(seed)
    self._sequence_length = sequence_length
    self._num_transitions_in_buffer = 0

    # Create the policy ops.
    obs = tf.placeholder(shape=(1,) + obs_spec.shape, dtype=obs_spec.dtype)
    mask = tf.placeholder(shape=(1,), dtype=tf.float32)
    state = self._placeholders_like(network.initial_state(batch_size=1))
    (online_logits, _), next_state = network((obs, mask), state)
    action = tf.squeeze(tf.multinomial(online_logits, 1, output_dtype=tf.int32))

    # Create placeholders and numpy arrays for learning from trajectories.
    shapes = [obs_spec.shape, (), (), (), ()]
    dtypes = [obs_spec.dtype, np.int32, np.float32, np.float32, np.float32]

    placeholders = [
        tf.placeholder(shape=(self._sequence_length, 1) + shape, dtype=dtype)
        for shape, dtype in zip(shapes, dtypes)]
    observations, actions, rewards, discounts, masks = placeholders

    # Build actor and critic losses.
    (logits, values), final_state = tf.nn.dynamic_rnn(
        network, (observations, tf.expand_dims(masks, -1)),
        initial_state=state, dtype=tf.float32, time_major=True)
    (_, bootstrap_value), _ = network((obs, mask), final_state)
    values, bootstrap_value = tree.map_structure(
        lambda t: tf.squeeze(t, axis=-1), (values, bootstrap_value))
    critic_loss, (advantages, _) = td_lambda_loss(
        state_values=values,
        rewards=rewards,
        pcontinues=agent_discount * discounts,
        bootstrap_value=bootstrap_value,
        lambda_=td_lambda)
    actor_loss = discrete_policy_gradient_loss(logits, actions, advantages)

    # Updates.
    grads_and_vars = optimizer.compute_gradients(actor_loss + critic_loss)
    grads, _ = tf.clip_by_global_norm([g for g, _ in grads_and_vars], 5.)
    grads_and_vars = [(g, pair[1]) for g, pair in zip(grads, grads_and_vars)]
    train_op = optimizer.apply_gradients(grads_and_vars)

    # Create TF session and callables.
    session = tf.Session()
    self._reset_fn = session.make_callable(
        network.initial_state(batch_size=1))
    self._policy_fn = session.make_callable(
        [action, next_state], [obs, mask, state])
    self._update_fn = session.make_callable(
        [train_op, final_state], placeholders + [obs, mask, state])
    session.run(tf.global_variables_initializer())

    # Initialize numpy buffers
    self.state = self._reset_fn()
    self.update_init_state = self._reset_fn()
    self.arrays = [
        np.zeros(shape=(self._sequence_length, 1) + shape, dtype=dtype)
        for shape, dtype in zip(shapes, dtypes)]
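
The td_lambda_loss above regresses the state values towards lambda-returns. As a reference, here is a small numpy sketch of the standard backward recursion G_t = r_t + p_t * ((1 - lambda) * V_{t+1} + lambda * G_{t+1}); it shows the definition only and is not the trfl implementation used above.

import numpy as np

def lambda_returns(rewards, pcontinues, values, bootstrap_value, lambda_):
    """Backward recursion for TD(lambda) targets over a single sequence."""
    next_values = np.append(values[1:], bootstrap_value)
    g = bootstrap_value
    returns = np.zeros_like(rewards)
    for t in reversed(range(len(rewards))):
        g = rewards[t] + pcontinues[t] * ((1 - lambda_) * next_values[t] + lambda_ * g)
        returns[t] = g
    return returns

rewards = np.array([0.0, 0.0, 1.0])
pcontinues = np.array([0.99, 0.99, 0.99])
values = np.array([0.1, 0.2, 0.3])
targets = lambda_returns(rewards, pcontinues, values, bootstrap_value=0.0, lambda_=0.9)
# A squared-error critic loss against these targets would be
# 0.5 * np.sum((targets - values) ** 2).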
Example #5
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        network: snt.RNNCore,
        burn_in_length: int,
        trace_length: int,
        replay_period: int,
        counter: counting.Counter = None,
        logger: loggers.Logger = None,
        discount: float = 0.99,
        batch_size: int = 32,
        prefetch_size: int = tf.data.experimental.AUTOTUNE,
        target_update_period: int = 100,
        importance_sampling_exponent: float = 0.2,
        priority_exponent: float = 0.6,
        epsilon_init: float = 1.0,
        epsilon_final: float = 0.01,
        epsilon_schedule_timesteps: float = 20000,
        learning_rate: float = 1e-3,
        min_replay_size: int = 1000,
        max_replay_size: int = 1000000,
        samples_per_insert: float = 32.0,
        store_lstm_state: bool = True,
        max_priority_weight: float = 0.9,
        checkpoint: bool = True,
    ):

        if store_lstm_state:
            extra_spec = {
                'core_state':
                tf2_utils.squeeze_batch_dim(network.initial_state(1)),
            }
        else:
            extra_spec = ()

        replay_table = reverb.Table(
            name=adders.DEFAULT_PRIORITY_TABLE,
            sampler=reverb.selectors.Prioritized(priority_exponent),
            remover=reverb.selectors.Fifo(),
            max_size=max_replay_size,
            rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1),
            signature=adders.SequenceAdder.signature(environment_spec,
                                                     extra_spec))
        self._server = reverb.Server([replay_table], port=None)
        address = f'localhost:{self._server.port}'

        sequence_length = burn_in_length + trace_length + 1
        # Component to add things into replay.
        self._adder = adders.SequenceAdder(
            client=reverb.Client(address),
            period=replay_period,
            sequence_length=sequence_length,
        )

        # The dataset object to learn from.
        dataset = make_reverb_dataset(server_address=address,
                                      batch_size=batch_size,
                                      prefetch_size=prefetch_size,
                                      sequence_length=sequence_length)

        target_network = copy.deepcopy(network)
        tf2_utils.create_variables(network, [environment_spec.observations])
        tf2_utils.create_variables(target_network,
                                   [environment_spec.observations])

        learner = learning.R2D2Learner(
            environment_spec=environment_spec,
            network=network,
            target_network=target_network,
            burn_in_length=burn_in_length,
            sequence_length=sequence_length,
            dataset=dataset,
            reverb_client=reverb.TFClient(address),
            counter=counter,
            logger=logger,
            discount=discount,
            target_update_period=target_update_period,
            importance_sampling_exponent=importance_sampling_exponent,
            max_replay_size=max_replay_size,
            learning_rate=learning_rate,
            store_lstm_state=store_lstm_state,
            max_priority_weight=max_priority_weight,
        )

        self._saver = tf2_savers.Saver(learner.state)

        policy_network = snt.DeepRNN([
            network,
            EpsilonGreedyExploration(
                epsilon_init=epsilon_init,
                epsilon_final=epsilon_final,
                epsilon_schedule_timesteps=epsilon_schedule_timesteps)
        ])
        actor = actors.RecurrentActor(policy_network,
                                      self._adder,
                                      store_recurrent_state=store_lstm_state)

        max_Q_network = snt.DeepRNN([
            network,
            lambda qs: trfl.epsilon_greedy(qs, epsilon=0.0).sample(),
        ])
        self._deterministic_actor = actors.RecurrentActor(
            max_Q_network, self._adder, store_recurrent_state=store_lstm_state)

        observations_per_step = (float(replay_period * batch_size) /
                                 samples_per_insert)

        super().__init__(actor=actor,
                         learner=learner,
                         min_observations=replay_period *
                         max(batch_size, min_replay_size),
                         observations_per_step=observations_per_step)
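
EpsilonGreedyExploration above anneals epsilon from epsilon_init to epsilon_final over epsilon_schedule_timesteps. A plain-Python sketch of one such schedule, assuming a linear decay (the module's exact shape may differ):

def linear_epsilon(step, epsilon_init=1.0, epsilon_final=0.01,
                   epsilon_schedule_timesteps=20000):
    """Linearly interpolates epsilon, then holds it at epsilon_final."""
    fraction = min(float(step) / epsilon_schedule_timesteps, 1.0)
    return epsilon_init + fraction * (epsilon_final - epsilon_init)

assert linear_epsilon(0) == 1.0
assert abs(linear_epsilon(10000) - 0.505) < 1e-6
assert abs(linear_epsilon(20000) - 0.01) < 1e-9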
Example #6
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        network: snt.RNNCore,
        burn_in_length: int,
        trace_length: int,
        replay_period: int,
        counter: counting.Counter = None,
        logger: loggers.Logger = None,
        discount: float = 0.99,
        batch_size: int = 32,
        prefetch_size: int = tf.data.experimental.AUTOTUNE,
        target_update_period: int = 100,
        importance_sampling_exponent: float = 0.2,
        priority_exponent: float = 0.6,
        epsilon: float = 0.01,
        learning_rate: float = 1e-3,
        min_replay_size: int = 1000,
        max_replay_size: int = 1000000,
        samples_per_insert: float = 32.0,
        store_lstm_state: bool = True,
        max_priority_weight: float = 0.9,
        checkpoint: bool = True,
    ):

        replay_table = reverb.Table(
            name=adders.DEFAULT_PRIORITY_TABLE,
            sampler=reverb.selectors.Prioritized(priority_exponent),
            remover=reverb.selectors.Fifo(),
            max_size=max_replay_size,
            rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1))
        self._server = reverb.Server([replay_table], port=None)
        address = f'localhost:{self._server.port}'

        sequence_length = burn_in_length + trace_length + 1
        # Component to add things into replay.
        adder = adders.SequenceAdder(
            client=reverb.Client(address),
            period=replay_period,
            sequence_length=sequence_length,
        )

        # The dataset object to learn from.
        reverb_client = reverb.TFClient(address)
        extra_spec = {
            'core_state': network.initial_state(1),
        }
        # Remove batch dimensions.
        extra_spec = tf2_utils.squeeze_batch_dim(extra_spec)
        dataset = datasets.make_reverb_dataset(
            client=reverb_client,
            environment_spec=environment_spec,
            batch_size=batch_size,
            prefetch_size=prefetch_size,
            extra_spec=extra_spec,
            sequence_length=sequence_length)

        target_network = copy.deepcopy(network)
        tf2_utils.create_variables(network, [environment_spec.observations])
        tf2_utils.create_variables(target_network,
                                   [environment_spec.observations])

        learner = learning.R2D2Learner(
            environment_spec=environment_spec,
            network=network,
            target_network=target_network,
            burn_in_length=burn_in_length,
            sequence_length=sequence_length,
            dataset=dataset,
            reverb_client=reverb_client,
            counter=counter,
            logger=logger,
            discount=discount,
            target_update_period=target_update_period,
            importance_sampling_exponent=importance_sampling_exponent,
            max_replay_size=max_replay_size,
            learning_rate=learning_rate,
            store_lstm_state=store_lstm_state,
            max_priority_weight=max_priority_weight,
        )

        self._checkpointer = tf2_savers.Checkpointer(
            subdirectory='r2d2_learner',
            time_delta_minutes=60,
            objects_to_save=learner.state,
            enable_checkpointing=checkpoint,
        )
        self._snapshotter = tf2_savers.Snapshotter(
            objects_to_save={'network': network}, time_delta_minutes=60.)

        policy_network = snt.DeepRNN([
            network,
            lambda qs: trfl.epsilon_greedy(qs, epsilon=epsilon).sample(),
        ])

        actor = actors.RecurrentActor(policy_network, adder)
        observations_per_step = (float(replay_period * batch_size) /
                                 samples_per_insert)
        super().__init__(actor=actor,
                         learner=learner,
                         min_observations=replay_period *
                         max(batch_size, min_replay_size),
                         observations_per_step=observations_per_step)
Example #7
    def __init__(self,
                 environment_spec: specs.EnvironmentSpec,
                 network: snt.RNNCore,
                 target_network: snt.RNNCore,
                 burn_in_length: int,
                 trace_length: int,
                 replay_period: int,
                 demonstration_dataset: tf.data.Dataset,
                 demonstration_ratio: float,
                 counter: counting.Counter = None,
                 logger: loggers.Logger = None,
                 discount: float = 0.99,
                 batch_size: int = 32,
                 target_update_period: int = 100,
                 importance_sampling_exponent: float = 0.2,
                 epsilon: float = 0.01,
                 learning_rate: float = 1e-3,
                 log_to_bigtable: bool = False,
                 log_name: str = 'agent',
                 checkpoint: bool = True,
                 min_replay_size: int = 1000,
                 max_replay_size: int = 1000000,
                 samples_per_insert: float = 32.0):

        extra_spec = {
            'core_state': network.initial_state(1),
        }
        # Remove batch dimensions.
        extra_spec = tf2_utils.squeeze_batch_dim(extra_spec)
        replay_table = reverb.Table(
            name=adders.DEFAULT_PRIORITY_TABLE,
            sampler=reverb.selectors.Uniform(),
            remover=reverb.selectors.Fifo(),
            max_size=max_replay_size,
            rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1),
            signature=adders.SequenceAdder.signature(environment_spec,
                                                     extra_spec))
        self._server = reverb.Server([replay_table], port=None)
        address = f'localhost:{self._server.port}'

        sequence_length = burn_in_length + trace_length + 1
        # Component to add things into replay.
        sequence_kwargs = dict(
            period=replay_period,
            sequence_length=sequence_length,
        )
        adder = adders.SequenceAdder(client=reverb.Client(address),
                                     **sequence_kwargs)

        # The dataset object to learn from.
        dataset = datasets.make_reverb_dataset(server_address=address,
                                               sequence_length=sequence_length)

        # Combine with demonstration dataset.
        transition = functools.partial(_sequence_from_episode,
                                       extra_spec=extra_spec,
                                       **sequence_kwargs)
        dataset_demos = demonstration_dataset.map(transition)
        dataset = tf.data.experimental.sample_from_datasets(
            [dataset, dataset_demos],
            [1 - demonstration_ratio, demonstration_ratio])

        # Batch and prefetch.
        dataset = dataset.batch(batch_size, drop_remainder=True)
        dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

        tf2_utils.create_variables(network, [environment_spec.observations])
        tf2_utils.create_variables(target_network,
                                   [environment_spec.observations])

        learner = learning.R2D2Learner(
            environment_spec=environment_spec,
            network=network,
            target_network=target_network,
            burn_in_length=burn_in_length,
            dataset=dataset,
            reverb_client=reverb.TFClient(address),
            counter=counter,
            logger=logger,
            sequence_length=sequence_length,
            discount=discount,
            target_update_period=target_update_period,
            importance_sampling_exponent=importance_sampling_exponent,
            max_replay_size=max_replay_size,
            learning_rate=learning_rate,
            store_lstm_state=False,
        )

        self._checkpointer = tf2_savers.Checkpointer(
            subdirectory='r2d2_learner',
            time_delta_minutes=60,
            objects_to_save=learner.state,
            enable_checkpointing=checkpoint,
        )

        self._snapshotter = tf2_savers.Snapshotter(
            objects_to_save={'network': network}, time_delta_minutes=60.)

        policy_network = snt.DeepRNN([
            network,
            lambda qs: trfl.epsilon_greedy(qs, epsilon=epsilon).sample(),
        ])

        actor = actors.RecurrentActor(policy_network, adder)
        observations_per_step = (float(replay_period * batch_size) /
                                 samples_per_insert)
        super().__init__(actor=actor,
                         learner=learner,
                         min_observations=replay_period *
                         max(batch_size, min_replay_size),
                         observations_per_step=observations_per_step)
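
For a sense of scale, the two quantities handed to super().__init__ above work out as follows with the default batch_size=32, samples_per_insert=32.0 and min_replay_size=1000; the replay_period of 40 is an arbitrary value chosen only for this worked example.

replay_period = 40            # illustrative choice; not a default of the agent
batch_size = 32
samples_per_insert = 32.0
min_replay_size = 1000

observations_per_step = float(replay_period * batch_size) / samples_per_insert
min_observations = replay_period * max(batch_size, min_replay_size)

assert observations_per_step == 40.0   # one learner step per 40 actor observations
assert min_observations == 40000       # learning starts after 40 * 1000 observations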
Example #8
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        network: snt.RNNCore,
        sequence_length: int,
        sequence_period: int,
        counter: counting.Counter = None,
        logger: loggers.Logger = None,
        discount: float = 0.99,
        max_queue_size: int = 100000,
        batch_size: int = 16,
        learning_rate: float = 1e-3,
        entropy_cost: float = 0.01,
        baseline_cost: float = 0.5,
        max_abs_reward: Optional[float] = None,
        max_gradient_norm: Optional[float] = None,
    ):

        num_actions = environment_spec.actions.num_values
        self._logger = logger or loggers.TerminalLogger('agent')

        extra_spec = {
            'core_state': network.initial_state(1),
            'logits': tf.ones(shape=(1, num_actions), dtype=tf.float32)
        }
        # Remove batch dimensions.
        extra_spec = tf2_utils.squeeze_batch_dim(extra_spec)

        queue = reverb.Table.queue(name=adders.DEFAULT_PRIORITY_TABLE,
                                   max_size=max_queue_size,
                                   signature=adders.SequenceAdder.signature(
                                       environment_spec,
                                       extras_spec=extra_spec,
                                       sequence_length=sequence_length))
        self._server = reverb.Server([queue], port=None)
        self._can_sample = lambda: queue.can_sample(batch_size)
        address = f'localhost:{self._server.port}'

        # Component to add things into replay.
        adder = adders.SequenceAdder(
            client=reverb.Client(address),
            period=sequence_period,
            sequence_length=sequence_length,
        )

        # The dataset object to learn from.
        dataset = datasets.make_reverb_dataset(server_address=address,
                                               batch_size=batch_size)

        tf2_utils.create_variables(network, [environment_spec.observations])

        self._actor = acting.IMPALAActor(network, adder)
        self._learner = learning.IMPALALearner(
            environment_spec=environment_spec,
            network=network,
            dataset=dataset,
            counter=counter,
            logger=logger,
            discount=discount,
            learning_rate=learning_rate,
            entropy_cost=entropy_cost,
            baseline_cost=baseline_cost,
            max_gradient_norm=max_gradient_norm,
            max_abs_reward=max_abs_reward,
        )
Example #9
def build_learner(agent: snt.RNNCore,
                  agent_state,
                  env_outputs,
                  agent_outputs,
                  reward_clipping: str,
                  discounting: float,
                  baseline_cost: float,
                  entropy_cost: float,
                  policy_cloning_cost: float,
                  value_cloning_cost: float,
                  clip_grad_norm: float,
                  clip_advantage: bool,
                  learning_rate: float,
                  batch_size: int,
                  batch_size_from_replay: int,
                  unroll_length: int,
                  reward_scaling: float = 1.0,
                  adam_beta1: float = 0.9,
                  adam_beta2: float = 0.999,
                  adam_epsilon: float = 1e-8,
                  fixed_step_mul: bool = False,
                  step_mul: int = 8):
    """Builds the learner loop.

    Returns:
        A tuple of (done, infos, and environment frames) where
        the environment frames tensor causes an update.

    """
    learner_outputs, _ = agent.unroll(agent_outputs.action, env_outputs,
                                      agent_state)

    # Use last baseline value (from the value function) to bootstrap.
    bootstrap_value = learner_outputs.baseline[-1]

    # At this point, the environment outputs at time step `t` are the inputs that
    # lead to the learner_outputs at time step `t`. After the following shifting,
    # the actions in agent_outputs and learner_outputs at time step `t` are what
    # lead to the environment outputs at time step `t`.
    agent_outputs = tf.nest.map_structure(lambda t: t[1:], agent_outputs)
    agent_outputs_from_buffer = tf.nest.map_structure(
        lambda t: t[:, :batch_size_from_replay], agent_outputs)
    learner_outputs_from_buffer = tf.nest.map_structure(
        lambda t: t[:-1, :batch_size_from_replay], learner_outputs)

    rewards, infos, done, _ = tf.nest.map_structure(lambda t: t[1:],
                                                    env_outputs)
    learner_outputs = tf.nest.map_structure(lambda t: t[:-1], learner_outputs)

    rewards = rewards * reward_scaling
    clipped_rewards = clip_rewards(rewards, reward_clipping)
    discounts = tf.to_float(~done) * discounting

    # We only need to learn a step_mul policy if the step multiplier is not fixed.
    if not fixed_step_mul:
        agent_outputs.action['step_mul'] = agent_outputs.step_mul
        agent_outputs.action_logits['step_mul'] = agent_outputs.step_mul_logits
        learner_outputs.action_logits[
            'step_mul'] = learner_outputs.step_mul_logits
        agent_outputs_from_buffer.action_logits[
            'step_mul'] = agent_outputs_from_buffer.step_mul_logits
        learner_outputs_from_buffer.action_logits[
            'step_mul'] = learner_outputs_from_buffer.step_mul_logits

    actions = tf.nest.flatten(
        tf.nest.map_structure(lambda x: tf.squeeze(x, axis=2),
                              agent_outputs.action))
    behaviour_logits = tf.nest.flatten(agent_outputs.action_logits)
    target_logits = tf.nest.flatten(learner_outputs.action_logits)
    behaviour_logits_from_buffer = tf.nest.flatten(
        agent_outputs_from_buffer.action_logits)
    target_logits_from_buffer = tf.nest.flatten(
        learner_outputs_from_buffer.action_logits)

    behaviour_neg_log_probs = sum(
        tf.nest.map_structure(compute_neg_log_probs, behaviour_logits,
                              actions))
    target_neg_log_probs = sum(
        tf.nest.map_structure(compute_neg_log_probs, target_logits, actions))
    entropy_loss = sum(
        tf.nest.map_structure(compute_entropy_loss, target_logits))

    with tf.device('/cpu'):
        vtrace_returns = vtrace.from_importance_weights(
            log_rhos=behaviour_neg_log_probs - target_neg_log_probs,
            discounts=discounts,
            rewards=clipped_rewards,
            values=learner_outputs.baseline,
            bootstrap_value=bootstrap_value)

    advantages = tf.stop_gradient(vtrace_returns.pg_advantages)
    # Mask out non-positive advantages, keeping only strictly positive ones:
    if clip_advantage:
        advantages *= tf.where(advantages > 0.0, tf.ones_like(advantages),
                               tf.zeros_like(advantages))
    policy_gradient_loss = tf.reduce_sum(
        target_neg_log_probs * tf.stop_gradient(vtrace_returns.pg_advantages))
    baseline_loss = .5 * tf.reduce_sum(
        tf.square(vtrace_returns.vs - learner_outputs.baseline))
    entropy_loss = tf.reduce_sum(entropy_loss)

    # Compute the CLEAR policy-cloning and value-cloning losses, as described
    # in https://arxiv.org/abs/1811.11682:
    policy_cloning_loss = sum(
        tf.nest.map_structure(compute_policy_cloning_loss,
                              target_logits_from_buffer,
                              behaviour_logits_from_buffer))
    value_cloning_loss = tf.reduce_sum(
        tf.square(learner_outputs_from_buffer.baseline -
                  tf.stop_gradient(agent_outputs_from_buffer.baseline)))

    # Combine individual losses, weighted by cost factors, to build overall loss:
    total_loss = policy_gradient_loss \
                 + baseline_cost * baseline_loss \
                 + entropy_cost * entropy_loss \
                 + policy_cloning_cost * policy_cloning_loss \
                 + value_cloning_cost * value_cloning_loss

    optimizer = tf.train.AdamOptimizer(learning_rate, adam_beta1, adam_beta2,
                                       adam_epsilon)
    parameters = tf.trainable_variables()
    gradients = tf.gradients(total_loss, parameters)
    gradients, grad_norm = clip_gradients(gradients, clip_grad_norm)
    train_op = optimizer.apply_gradients(list(zip(gradients, parameters)))

    # Merge updating the network and environment frames into a single tensor.
    with tf.control_dependencies([train_op]):
        if fixed_step_mul:
            step_env_frames = unroll_length * (
                batch_size - batch_size_from_replay) * step_mul
        else:
            # Do not count replay samples towards the number of environment frames.
            step_env_frames = tf.to_int64(
                tf.reduce_sum(
                    learner_outputs.step_mul[:, batch_size_from_replay:] + 1))
        num_env_frames_and_train = tf.train.get_global_step().assign_add(
            step_env_frames)

    # Add a few summaries.
    tf.summary.scalar('learning_rate', learning_rate)
    tf.summary.scalar('entropy_cost', entropy_cost)
    tf.summary.scalar('loss/policy_gradient', policy_gradient_loss)
    tf.summary.scalar('loss/baseline', baseline_loss)
    tf.summary.scalar('loss/entropy', entropy_loss)
    tf.summary.scalar('loss/policy_cloning', policy_cloning_loss)
    tf.summary.scalar('loss/value_cloning', value_cloning_loss)
    tf.summary.scalar('loss/total_loss', total_loss)
    for action_name, action in agent_outputs.action.items():
        tf.summary.histogram(f'action/{action_name}', action)
    tf.summary.scalar('grad_norm', grad_norm)

    return done, infos, num_env_frames_and_train
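
Note that the clip_advantage branch above zeroes out non-positive advantages rather than clipping their magnitude. A tiny numpy equivalent, with made-up numbers, for reference:

import numpy as np

advantages = np.array([-0.5, 0.0, 0.3, 1.2])
mask = np.where(advantages > 0.0, 1.0, 0.0)  # mirrors the tf.where(...) above
masked_advantages = advantages * mask
# Only the strictly positive entries (0.3 and 1.2) still contribute to the
# policy-gradient loss; the rest are zeroed out.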
Example #10
def build_critic_learner(agent: snt.RNNCore,
                         agent_state,
                         env_outputs,
                         agent_outputs,
                         reward_clipping: str,
                         discounting: float,
                         clip_grad_norm: float,
                         learning_rate: float,
                         batch_size: int,
                         batch_size_from_replay: int,
                         unroll_length: int,
                         reward_scaling: float = 1.0,
                         adam_beta1: float = 0.9,
                         adam_beta2: float = 0.999,
                         adam_epsilon: float = 1e-8,
                         fixed_step_mul: bool = False,
                         step_mul: int = 8):
    learner_outputs, _ = agent.unroll(agent_outputs.action, env_outputs,
                                      agent_state)

    bootstrap_value = learner_outputs.baseline[-1]
    rewards, infos, done, _ = tf.nest.map_structure(lambda t: t[1:],
                                                    env_outputs)
    learner_outputs = tf.nest.map_structure(lambda t: t[:-1], learner_outputs)

    rewards = rewards * reward_scaling
    clipped_rewards = clip_rewards(rewards, reward_clipping)
    discounts = tf.to_float(~done) * discounting

    returns = tf.scan(lambda a, x: x[0] + x[1] * a,
                      elems=[clipped_rewards, discounts],
                      initializer=bootstrap_value,
                      parallel_iterations=1,
                      reverse=True,
                      back_prop=False)

    baseline_loss = .5 * tf.reduce_sum(
        tf.square(returns - learner_outputs.baseline))

    # Optimization
    optimizer = tf.train.AdamOptimizer(learning_rate, adam_beta1, adam_beta2,
                                       adam_epsilon)
    parameters = tf.trainable_variables()
    gradients = tf.gradients(baseline_loss, parameters)
    gradients, grad_norm = clip_gradients(gradients, clip_grad_norm)
    train_op = optimizer.apply_gradients(list(zip(gradients, parameters)))

    # Merge updating the network and environment frames into a single tensor.
    with tf.control_dependencies([train_op]):
        if fixed_step_mul:
            step_env_frames = unroll_length * (
                batch_size - batch_size_from_replay) * step_mul
        else:
            # Do not count replay samples towards the number of environment frames.
            step_env_frames = tf.to_int64(
                tf.reduce_sum(
                    learner_outputs.step_mul[:, batch_size_from_replay:] + 1))
        num_env_frames_and_train = tf.train.get_global_step().assign_add(
            step_env_frames)

    # Add a few summaries.
    tf.summary.scalar('critic_pretrain/learning_rate', learning_rate,
                      ['critic_pretrain_summaries'])
    tf.summary.scalar('critic_pretrain/baseline_loss', baseline_loss,
                      ['critic_pretrain_summaries'])
    tf.summary.scalar('critic_pretrain/grad_norm', grad_norm,
                      ['critic_pretrain_summaries'])

    summary_op = tf.summary.merge_all('critic_pretrain_summaries')

    return num_env_frames_and_train, summary_op
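
The tf.scan above accumulates discounted returns backwards through time, return_t = reward_t + discount_t * return_{t+1}, starting from the bootstrap value. The numpy sketch below reproduces that recursion on made-up 1-D inputs for reference.

import numpy as np

def discounted_returns(clipped_rewards, discounts, bootstrap_value):
    """Backward accumulation matching tf.scan(..., reverse=True) above."""
    acc = bootstrap_value
    returns = np.zeros_like(clipped_rewards)
    for t in reversed(range(len(clipped_rewards))):
        acc = clipped_rewards[t] + discounts[t] * acc
        returns[t] = acc
    return returns

returns = discounted_returns(
    clipped_rewards=np.array([1.0, 0.0, 1.0]),
    discounts=np.array([0.99, 0.99, 0.0]),  # terminal step, so discount 0
    bootstrap_value=0.5)
# returns is approximately [1.9801, 0.99, 1.0].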