def _create_network(self, name):
    """Constructs the multi-network Q-network used by this agent.

    If no transformation matrix has been cached yet and the transform
    strategy is 'STOCHASTIC', one is sampled once with
    `atari_helpers.random_stochastic_matrix` and cached on
    `self._q_networks_transform`. Any cached matrix is then forwarded to the
    network constructor as the `transform_matrix` keyword argument.

    Args:
      name: str, forwarded to `self.network` as `name`; tf.keras.Model uses
        it to create the variable scope under the hood.

    Returns:
      network: tf.keras.Model, the object built by `self.network(...)`.
    """
    def place_on_device(network_index):
      # Groups of four consecutive indices share one GPU:
      # 0-3 -> '/gpu:0', 4-7 -> '/gpu:1', ... (the index is presumably the
      # Q-network's position; confirm against the network implementation).
      return '/gpu:{}'.format(network_index // 4)

    extra_args = {'device_fn': place_on_device}
    needs_matrix = (self._q_networks_transform is None and
                    self.transform_strategy == 'STOCHASTIC')
    if needs_matrix:
      tf.logging.info('Creating q_networks transformation matrix..')
      self._q_networks_transform = atari_helpers.random_stochastic_matrix(
          self.num_networks, num_cols=self._num_convex_combinations)
    transform = self._q_networks_transform
    if transform is not None:
      extra_args['transform_matrix'] = transform
    return self.network(num_actions=self.num_actions,
                        num_networks=self.num_networks,
                        transform_strategy=self.transform_strategy,
                        name=name,
                        **extra_args)
 def _build_networks(self):
   """Builds the networks and the action-selection (argmax) ops.

   After delegating to the parent class's `_build_networks`, this defines:
     * `self._q_argmax_eval`: `tf.argmax(self._net_outputs.q_values, axis=1)`
       taken at index 0; used for picking an action.
     * `self._q_argmax_train`: if `self.use_deep_exploration` is set, the
       argmax of a per-episode Q-function chosen according to
       `self.transform_strategy`; otherwise the same op as
       `self._q_argmax_eval`.
     * `self._update_episode_q_function` (deep exploration only): an assign
       op that re-draws the per-episode Q-function when run.

   Raises:
     ValueError: if `self.use_deep_exploration` is set and
       `self.transform_strategy` neither ends with 'STOCHASTIC' nor equals
       'IDENTITY'. No per-episode Q-function can be built in that case.
   """
   super(MultiNetworkDQNAgent, self)._build_networks()
   # q_argmax is only used for picking an action
   self._q_argmax_eval = tf.argmax(self._net_outputs.q_values, axis=1)[0]
   if self.use_deep_exploration:
     if self.transform_strategy.endswith('STOCHASTIC'):
       # A single random stochastic column (num_cols=1) that mixes the
       # Q-networks into one episode Q-function.
       q_transform = atari_helpers.random_stochastic_matrix(
           self.num_networks, num_cols=1)
       # No initializer is passed, so the variable starts from the default
       # initializer. The sampled column is only written into it when
       # `_update_episode_q_function` runs. If `q_transform` is a random op,
       # presumably each run re-samples it; confirm in atari_helpers.
       self._q_episode_transform = tf.get_variable(
           trainable=False,
           dtype=tf.float32,
           shape=q_transform.get_shape().as_list(),
           name='q_episode_transform')
       self._update_episode_q_function = self._q_episode_transform.assign(
           q_transform)
       # Contract axis 2 of `unordered_q_networks` (presumably the network
       # axis; confirm) with axis 0 of the column. The trailing size-1 axis
       # of the result is dropped below.
       episode_q_function = tf.tensordot(
           self._net_outputs.unordered_q_networks,
           self._q_episode_transform, axes=[[2], [0]])
       self._q_argmax_train = tf.argmax(episode_q_function[:, :, 0], axis=1)[0]
     elif self.transform_strategy == 'IDENTITY':
       # Index of the single Q-network followed for the episode. Running the
       # update op re-draws it uniformly from [0, num_networks).
       self._q_function_index = tf.Variable(
           initial_value=0,
           trainable=False,
           dtype=tf.int32,
           shape=(),
           name='q_head_episode')
       self._update_episode_q_function = self._q_function_index.assign(
           tf.random.uniform(
               shape=(), maxval=self.num_networks, dtype=tf.int32))
       q_function = self._net_outputs.unordered_q_networks[
           :, :, self._q_function_index]
       # This is only used for picking an action
       self._q_argmax_train = tf.argmax(q_function, axis=1)[0]
     else:
       # Without this branch `_q_argmax_train` would be left unassigned, and
       # any later read of it would fail far from the actual cause.
       raise ValueError(
           'Deep exploration is not supported for transform_strategy {!r}; '
           'expected IDENTITY or a strategy ending with STOCHASTIC.'.format(
               self.transform_strategy))
   else:
     self._q_argmax_train = self._q_argmax_eval
Beispiel #3
0
 def _network_template(self, state):
     """Builds the multi-head Q-network for `state` via `self.network`.

     If no transformation matrix is cached and `self.transform_strategy` is
     'STOCHASTIC', one is sampled with `atari_helpers.random_stochastic_matrix`
     (using `self.num_heads` and `self._num_convex_combinations`) and cached
     on `self._q_heads_transform`. A cached matrix, if any, is forwarded as
     the `transform_matrix` keyword argument.

     Args:
       state: forwarded positionally to `self.network` (presumably the batch
         of observations; confirm against the caller).

     Returns:
       Whatever `self.network(...)` returns.
     """
     if (self._q_heads_transform is None
             and self.transform_strategy == 'STOCHASTIC'):
         tf.logging.info('Creating q_heads transformation matrix..')
         self._q_heads_transform = atari_helpers.random_stochastic_matrix(
             self.num_heads, num_cols=self._num_convex_combinations)
     matrix = self._q_heads_transform
     extra = {} if matrix is None else {'transform_matrix': matrix}
     return self.network(self.num_actions,
                         self.num_heads,
                         self._get_network_type(),
                         state,
                         self.transform_strategy,
                         **extra)
    def _create_network(self, name):
        """Builds a multi-head Q-network that outputs Q-values for multiple heads.

        If no transformation matrix is cached yet and `self.transform_strategy`
        is 'STOCHASTIC', one is sampled once with
        `atari_helpers.random_stochastic_matrix` and cached on
        `self._q_heads_transform`. Any cached matrix is forwarded to the
        network as the `transform_matrix` keyword argument.

        Args:
          name: str, this name is passed to the tf.keras.Model and used to
            create variable scope under the hood by the tf.keras.Model.

        Returns:
          network: tf.keras.Model, the network instantiated by the Keras model.
        """
        kwargs = {}  # Used for passing the transformation matrix if any
        if self._q_heads_transform is None:
            # Gate on the same attribute that is forwarded to the network
            # below (and that the sibling builders check). The matrix is then
            # created exactly when the network is configured for a
            # stochastic transform.
            if self.transform_strategy == 'STOCHASTIC':
                tf.logging.info('Creating q_heads transformation matrix..')
                self._q_heads_transform = atari_helpers.random_stochastic_matrix(
                    self.num_heads, num_cols=self._num_convex_combinations)
        if self._q_heads_transform is not None:
            kwargs.update({'transform_matrix': self._q_heads_transform})
        network = self.network(num_actions=self.num_actions,
                               num_heads=self.num_heads,
                               transform_strategy=self.transform_strategy,
                               name=name,
                               **kwargs)
        return network