def call(self, state, num_quantiles):
  """Deterministic forward pass that ignores the contents of `state`.

  Feeds an all-ones tensor of shape (batch * num_quantiles, num_actions)
  through `self.layer` and pairs the result with all-ones quantiles.

  Args:
    state: input tensor; only its leading (batch) dimension is used.
    num_quantiles: number of quantile samples per batch element.

  Returns:
    An `atari_lib.ImplicitQuantileNetworkType` of (quantile values,
    quantiles), each with leading dimension batch * num_quantiles.
  """
  num_rows = state.get_shape().as_list()[0] * num_quantiles
  ones_input = tf.constant(
      np.ones((num_rows, self.num_actions)), dtype=tf.float32)
  taus = tf.ones([num_rows, 1])
  return atari_lib.ImplicitQuantileNetworkType(self.layer(ones_input), taus)
def __call__(self, x, num_quantiles, rng):
  """Deterministic quantile head over an element-wise squared, tiled state.

  Flattens `x`, tiles it `num_quantiles` times, multiplies the flat state
  into the tiled copy (broadcasting to num_quantiles rows of the squared
  state), and maps it through a ones/zeros-initialized Dense layer.

  Args:
    x: input observation; flattened before use.
    num_quantiles: number of quantile rows to produce.
    rng: unused (no sampling here).

  Returns:
    An `atari_lib.ImplicitQuantileNetworkType` of (quantile values,
    all-ones quantiles of shape [num_quantiles, 1]).
  """
  del rng  # Deterministic: quantiles are constant, no sampling needed.
  flat = x.reshape((-1))
  tiled = jnp.tile(flat, [num_quantiles, 1])
  squared = flat * tiled
  quantile_values = linen.Dense(
      features=self.num_actions,
      kernel_init=linen.initializers.ones,
      bias_init=linen.initializers.zeros)(squared)
  taus = jnp.ones([num_quantiles, 1])
  return atari_lib.ImplicitQuantileNetworkType(quantile_values, taus)
def apply(self, x, num_actions, quantile_embedding_dim, num_quantiles, rng): del rng # This weights_initializer gives action 0 a higher weight, ensuring # that it gets picked by the argmax. batch_size = x.shape[0] x = x[None, :] x = x.astype(jnp.float32) x = x.reshape((x.shape[0], -1)) # flatten quantile_values = nn.Dense(x, features=num_actions, kernel_init=jax.nn.initializers.ones, bias_init=jax.nn.initializers.zeros) quantiles = jnp.ones([num_quantiles * batch_size, 1]) return atari_lib.ImplicitQuantileNetworkType(quantile_values, quantiles)
def apply(self, x, num_actions, quantile_embedding_dim, num_quantiles, rng):
  """Implicit Quantile Network (old flax `nn` API).

  Runs the Atari conv torso, samples `num_quantiles` quantiles uniformly,
  embeds them with a cosine basis, mixes the embedding into the tiled
  state features by element-wise product, and predicts per-quantile
  action values.

  Args:
    x: unbatched uint8 observation; a singleton batch axis is added.
    num_actions: number of output action values per quantile.
    quantile_embedding_dim: width of the cosine quantile embedding.
    num_quantiles: number of quantile samples to draw.
    rng: PRNG key used to sample the quantiles.

  Returns:
    An `atari_lib.ImplicitQuantileNetworkType` of (quantile values,
    sampled quantiles).
  """
  initializer = jax.nn.initializers.variance_scaling(
      scale=1.0 / jnp.sqrt(3.0), mode='fan_in', distribution='uniform')
  # vmap strips the real batch dimension; nn.Conv still expects one.
  x = x[None, ...]
  x = x.astype(jnp.float32) / 255.
  # Convolutional torso (module creation order matches the original).
  for n_features, kernel, stride in ((32, (8, 8), (4, 4)),
                                     (64, (4, 4), (2, 2)),
                                     (64, (3, 3), (1, 1))):
    x = jax.nn.relu(nn.Conv(x, features=n_features, kernel_size=kernel,
                            strides=stride, kernel_init=initializer))
  x = x.reshape((x.shape[0], -1))  # flatten
  feature_dim = x.shape[-1]
  tiled_state = jnp.tile(x, [num_quantiles, 1])
  quantiles = jax.random.uniform(rng, shape=[num_quantiles, 1])
  # Cosine embedding: cos(i * pi * tau) for i in 1..quantile_embedding_dim.
  embedding = jnp.tile(quantiles, [1, quantile_embedding_dim])
  frequencies = jnp.arange(
      1, quantile_embedding_dim + 1, 1).astype(jnp.float32)
  embedding = jnp.cos(frequencies * onp.pi * embedding)
  embedding = jax.nn.relu(
      nn.Dense(embedding, features=feature_dim, kernel_init=initializer))
  # Hadamard merge of state features and quantile embedding, then head.
  hidden = tiled_state * embedding
  hidden = jax.nn.relu(
      nn.Dense(hidden, features=512, kernel_init=initializer))
  quantile_values = nn.Dense(
      hidden, features=num_actions, kernel_init=initializer)
  return atari_lib.ImplicitQuantileNetworkType(quantile_values, quantiles)
def __call__(self, x, num_quantiles, rng):
  """Implicit Quantile Network (flax linen API).

  Applies the Atari conv torso, draws `num_quantiles` uniform quantiles,
  embeds them with a cosine basis, gates the tiled state features with the
  embedding, and predicts per-quantile action values.

  Args:
    x: input observation (unbatched; preprocessed here unless the caller
      already did so).
    num_quantiles: number of quantile samples to draw.
    rng: PRNG key used to sample the quantiles.

  Returns:
    An `atari_lib.ImplicitQuantileNetworkType` of (quantile values,
    sampled quantiles).
  """
  initializer = nn.initializers.variance_scaling(
      scale=1.0 / jnp.sqrt(3.0), mode='fan_in', distribution='uniform')
  if not self.inputs_preprocessed:
    x = preprocess_atari_inputs(x)
  # Convolutional torso; submodule creation order matches the original.
  for n_features, kernel, stride in ((32, 8, 4), (64, 4, 2), (64, 3, 1)):
    x = nn.Conv(features=n_features, kernel_size=(kernel, kernel),
                strides=(stride, stride), kernel_init=initializer)(x)
    x = nn.relu(x)
  x = x.reshape((-1))  # flatten
  feature_dim = x.shape[-1]
  tiled_state = jnp.tile(x, [num_quantiles, 1])
  quantiles = jax.random.uniform(rng, shape=[num_quantiles, 1])
  # Cosine embedding: cos(i * pi * tau), i in 1..quantile_embedding_dim.
  embedding = jnp.tile(quantiles, [1, self.quantile_embedding_dim])
  frequencies = jnp.arange(
      1, self.quantile_embedding_dim + 1, 1).astype(jnp.float32)
  embedding = jnp.cos(frequencies * onp.pi * embedding)
  embedding = nn.Dense(
      features=feature_dim, kernel_init=initializer)(embedding)
  embedding = nn.relu(embedding)
  # Hadamard merge of state features and quantile embedding, then head.
  merged = tiled_state * embedding
  merged = nn.relu(nn.Dense(features=512, kernel_init=initializer)(merged))
  quantile_values = nn.Dense(
      features=self.num_actions, kernel_init=initializer)(merged)
  return atari_lib.ImplicitQuantileNetworkType(quantile_values, quantiles)
def __call__(self, x, num_quantiles, rng):
  """Configurable IQN network: torso per `net_conf`, noisy/dueling heads.

  Builds a feature torso for 'minatar', 'atari', or 'classic' inputs,
  optionally min-max normalizes classic-control observations to [-1, 1],
  runs `hidden_layer` fully-connected (noisy or plain Dense) layers,
  mixes in the cosine quantile embedding, and selects between dueling and
  non-dueling quantile heads with `jnp.where(self.dueling, ...)`.

  Args:
    x: input observation (unbatched; vmap is assumed to have removed the
      true batch dimension — TODO confirm at call sites).
    num_quantiles: number of quantile samples to draw.
    rng: PRNG key; reused for every noisy layer AND the quantile draw.

  Returns:
    An `atari_lib.ImplicitQuantileNetworkType` of (quantile values,
    sampled quantiles).
  """
  if self.net_conf == 'minatar':
    x = x.squeeze(3)
    x = x.astype(jnp.float32)
    x = nn.Conv(features=16, kernel_size=(3, 3, 3),
                strides=(1, 1, 1), kernel_init=self.initzer)(x)
    x = jax.nn.relu(x)
    x = x.reshape((x.shape[0], -1))
  elif self.net_conf == 'atari':
    # We need to add a "batch dimension" as nn.Conv expects it, yet vmap
    # will have removed the true batch dimension.
    x = x.astype(jnp.float32) / 255.
    x = nn.Conv(features=32, kernel_size=(8, 8), strides=(4, 4),
                kernel_init=self.initzer)(x)
    x = jax.nn.relu(x)
    x = nn.Conv(features=64, kernel_size=(4, 4), strides=(2, 2),
                kernel_init=self.initzer)(x)
    x = jax.nn.relu(x)
    x = nn.Conv(features=64, kernel_size=(3, 3), strides=(1, 1),
                kernel_init=self.initzer)(x)
    x = jax.nn.relu(x)
    x = x.reshape((-1))  # flatten
  elif self.net_conf == 'classic':
    # Classic-control environments: flat float features, no conv torso.
    x = x.astype(jnp.float32)
    x = x.reshape((-1))
  # Min-max normalize known environments to [-1, 1] using per-env bounds.
  if self.env is not None and self.env in env_inf:
    x = x - env_inf[self.env]['MIN_VALS']
    x /= env_inf[self.env]['MAX_VALS'] - env_inf[self.env]['MIN_VALS']
    x = 2.0 * x - 1.0
  # Pick the hidden-layer factory once: NoisyNet layers or plain Dense.
  if self.noisy:
    def net(x, features, rng):
      return NoisyNetwork(features, rng=rng, bias_in=True)(x)
  else:
    def net(x, features, rng):
      return nn.Dense(features, kernel_init=self.initzer)(x)
  # NOTE(review): the same `rng` is passed to every noisy layer here and
  # reused for the quantile draw below, so all noisy layers share one
  # noise sample and it correlates with the quantiles — confirm intended.
  for _ in range(self.hidden_layer):
    x = net(x, features=self.neurons, rng=rng)
    x = jax.nn.relu(x)
  state_vector_length = x.shape[-1]
  state_net_tiled = jnp.tile(x, [num_quantiles, 1])
  quantiles_shape = [num_quantiles, 1]
  quantiles = jax.random.uniform(rng, shape=quantiles_shape)
  # Cosine quantile embedding: cos(i * pi * tau), i = 1..embedding_dim.
  quantile_net = jnp.tile(quantiles, [1, self.quantile_embedding_dim])
  quantile_net = (jnp.arange(1, self.quantile_embedding_dim + 1,
                             1).astype(jnp.float32)
                  * onp.pi
                  * quantile_net)
  quantile_net = jnp.cos(quantile_net)
  quantile_net = nn.Dense(features=state_vector_length,
                          kernel_init=self.initzer)(quantile_net)
  quantile_net = jax.nn.relu(quantile_net)
  # Hadamard merge of tiled state features and quantile embedding.
  x = state_net_tiled * quantile_net
  # Both dueling and non-dueling heads are always built; `jnp.where`
  # selects between the two computed results (it does not short-circuit).
  adv = net(x, features=self.num_actions, rng=rng)
  val = net(x, features=1, rng=rng)
  dueling_q = val + (adv - (jnp.mean(adv, -1, keepdims=True)))
  non_dueling_q = net(x, features=self.num_actions, rng=rng)
  quantile_values = jnp.where(self.dueling, dueling_q, non_dueling_q)
  return atari_lib.ImplicitQuantileNetworkType(quantile_values, quantiles)