def __call__(self, x):
  initializer = nn.initializers.variance_scaling(
      scale=1.0 / jnp.sqrt(3.0), mode='fan_in', distribution='uniform')
  x = x.astype(jnp.float32) / 255.
  x = nn.Conv(features=32, kernel_size=(8, 8), strides=(4, 4),
              kernel_init=initializer)(x)
  x = nn.relu(x)
  x = nn.Conv(features=64, kernel_size=(4, 4), strides=(2, 2),
              kernel_init=initializer)(x)
  x = nn.relu(x)
  x = nn.Conv(features=64, kernel_size=(3, 3), strides=(1, 1),
              kernel_init=initializer)(x)
  x = nn.relu(x)
  x = x.reshape((-1))  # flatten
  x = nn.Dense(features=512, kernel_init=initializer)(x)
  x = nn.relu(x)
  x = nn.Dense(features=self.num_actions * self.num_atoms,
               kernel_init=initializer)(x)
  logits = x.reshape((self.num_actions, self.num_atoms))
  probabilities = nn.softmax(logits)
  q_values = jnp.mean(logits, axis=1)
  return atari_lib.RainbowNetworkType(q_values, logits, probabilities)
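# Usage sketch (assumption: the __call__ above belongs to a flax.linen module
# with `num_actions` and `num_atoms` fields; the name RainbowNetwork is
# illustrative, not confirmed by this file). The method is written for a
# single observation, so a batch is handled by vmapping `apply` over the
# batch axis of the inputs while sharing the parameters.
import jax
import jax.numpy as jnp

network = RainbowNetwork(num_actions=6, num_atoms=51)
params = network.init(jax.random.PRNGKey(0), jnp.zeros((84, 84, 4)))
batch = jnp.zeros((32, 84, 84, 4))  # a batch of stacked Atari frames
outputs = jax.vmap(network.apply, in_axes=(None, 0))(params, batch)
# outputs.q_values has shape (32, 6); outputs.logits has shape (32, 6, 51).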
def apply(self, x, num_actions, num_atoms):
  initializer = jax.nn.initializers.variance_scaling(
      scale=1.0 / jnp.sqrt(3.0), mode='fan_in', distribution='uniform')
  # We need to add a "batch dimension" as nn.Conv expects it, yet vmap will
  # have removed the true batch dimension.
  x = x[None, ...]
  x = x.astype(jnp.float32) / 255.
  x = nn.Conv(x, features=32, kernel_size=(8, 8), strides=(4, 4),
              kernel_init=initializer)
  x = jax.nn.relu(x)
  x = nn.Conv(x, features=64, kernel_size=(4, 4), strides=(2, 2),
              kernel_init=initializer)
  x = jax.nn.relu(x)
  x = nn.Conv(x, features=64, kernel_size=(3, 3), strides=(1, 1),
              kernel_init=initializer)
  x = jax.nn.relu(x)
  x = x.reshape((x.shape[0], -1))  # flatten
  x = nn.Dense(x, features=512, kernel_init=initializer)
  x = jax.nn.relu(x)
  x = nn.Dense(x, features=num_actions * num_atoms, kernel_init=initializer)
  logits = x.reshape((x.shape[0], num_actions, num_atoms))
  probabilities = nn.softmax(logits)
  q_values = jnp.mean(logits, axis=2)
  return atari_lib.RainbowNetworkType(q_values, logits, probabilities)
def call(self, state):
  inputs = tf.constant(
      np.zeros((state.shape[0], stack_size)), dtype=tf.float32)
  net = self.layer(inputs)
  logits = tf.reshape(net, [-1, self.num_actions, self.num_atoms])
  probabilities = tf.keras.activations.softmax(logits)
  qs = tf.reduce_sum(self.support * probabilities, axis=2)
  return atari_lib.RainbowNetworkType(qs, logits, probabilities)
def apply(self, x, num_actions, minatar, env, normalize_obs, noisy, dueling,
          num_atoms, hidden_layer=2, neurons=512):
  del normalize_obs
  if minatar:
    x = x.squeeze(3)
    x = x[None, ...]
    x = x.astype(jnp.float32)
    x = nn.Conv(x, features=16, kernel_size=(3, 3), strides=(1, 1),
                kernel_init=nn.initializers.xavier_uniform())
    x = jax.nn.relu(x)
    x = x.reshape((x.shape[0], -1))
  else:
    x = x[None, ...]
    x = x.astype(jnp.float32)
    x = x.reshape((x.shape[0], -1))
    if env is not None:
      x = x - env_inf[env]['MIN_VALS']
      x /= env_inf[env]['MAX_VALS'] - env_inf[env]['MIN_VALS']
      x = 2.0 * x - 1.0

  if noisy:
    def net(x, features):
      return NoisyNetwork(x, features)
  else:
    def net(x, features):
      return nn.Dense(x, features,
                      kernel_init=nn.initializers.xavier_uniform())

  for _ in range(hidden_layer):
    x = net(x, features=neurons)
    x = jax.nn.relu(x)

  if dueling:
    adv = net(x, features=num_actions * num_atoms)
    value = net(x, features=num_atoms)
    adv = adv.reshape((adv.shape[0], num_actions, num_atoms))
    value = value.reshape((value.shape[0], 1, num_atoms))
    # Dueling aggregation: subtract the mean advantage over actions (axis 1).
    logits = value + (adv - jnp.mean(adv, axis=1, keepdims=True))
    probabilities = nn.softmax(logits)
    q_values = jnp.mean(logits, axis=2)
  else:
    x = net(x, features=num_actions * num_atoms)
    logits = x.reshape((x.shape[0], num_actions, num_atoms))
    probabilities = nn.softmax(logits)
    q_values = jnp.mean(logits, axis=2)
  return atari_lib.RainbowNetworkType(q_values, logits, probabilities)
def __call__(self, x, support):
  initializer = nn.initializers.xavier_uniform()
  x = x.astype(jnp.float32)
  x = nn.Conv(features=16, kernel_size=(3, 3), strides=(1, 1),
              kernel_init=initializer)(x)
  x = nn.relu(x)
  x = x.reshape(-1)  # flatten
  x = nn.Dense(features=self.num_actions * self.num_atoms,
               kernel_init=initializer)(x)
  logits = x.reshape((self.num_actions, self.num_atoms))
  probabilities = nn.softmax(logits)
  q_values = jnp.sum(support * probabilities, axis=1)
  return atari_lib.RainbowNetworkType(q_values, logits, probabilities)
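# Standalone sketch of the C51 head used above: softmax the per-action atom
# logits into distributions over a fixed support, then take the expectation
# to recover Q-values. Shapes and support bounds here are illustrative.
import jax
import jax.numpy as jnp

num_actions, num_atoms = 3, 51
support = jnp.linspace(-10.0, 10.0, num_atoms)       # atom locations
logits = jnp.zeros((num_actions, num_atoms))         # stand-in network output
probabilities = jax.nn.softmax(logits, axis=-1)      # one distribution per action
q_values = jnp.sum(support * probabilities, axis=1)  # expected return, shape (3,)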
def __call__(self, x, support):
  x = x.astype(jnp.float32)
  x = x.reshape((-1))  # flatten
  if self.min_vals is not None:
    x -= self.min_vals
    x /= self.max_vals - self.min_vals
    x = 2.0 * x - 1.0  # Rescale in range [-1, 1].
  for layer in self.layers:
    x = layer(x)
    x = nn.relu(x)
  x = self.final_layer(x)
  logits = x.reshape((self.num_actions, self.num_atoms))
  probabilities = nn.softmax(logits)
  q_values = jnp.sum(support * probabilities, axis=1)
  return atari_lib.RainbowNetworkType(q_values, logits, probabilities)
def __call__(self, x, support):
  initializer = nn.initializers.xavier_uniform()
  x = x.astype(jnp.float32)
  x = x.reshape((-1))  # flatten
  x -= gym_lib.ACROBOT_MIN_VALS
  x /= gym_lib.ACROBOT_MAX_VALS - gym_lib.ACROBOT_MIN_VALS
  x = 2.0 * x - 1.0  # Rescale in range [-1, 1].
  x = nn.Dense(features=512, kernel_init=initializer)(x)
  x = nn.relu(x)
  x = nn.Dense(features=512, kernel_init=initializer)(x)
  x = nn.relu(x)
  x = nn.Dense(features=self.num_actions * self.num_atoms,
               kernel_init=initializer)(x)
  logits = x.reshape((self.num_actions, self.num_atoms))
  probabilities = nn.softmax(logits)
  q_values = jnp.sum(support * probabilities, axis=1)
  return atari_lib.RainbowNetworkType(q_values, logits, probabilities)
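# Standalone sketch of the input rescaling above: map an observation from
# [MIN_VALS, MAX_VALS] to [-1, 1]. The bounds here are illustrative, not the
# actual gym_lib constants.
import jax.numpy as jnp

min_vals = jnp.array([-1.0, -2.0])
max_vals = jnp.array([1.0, 2.0])
obs = jnp.array([0.0, 1.0])
obs = (obs - min_vals) / (max_vals - min_vals)  # now in [0, 1]
obs = 2.0 * obs - 1.0                           # now in [-1, 1]; here [0.0, 0.5]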
def apply(self, x, num_actions, num_atoms, support):
  def custom_init(key, shape, dtype=jnp.float32):
    del key
    to_pick_first_action = onp.ones(shape, dtype)
    to_pick_first_action[:, :num_atoms] = onp.arange(1, num_atoms + 1)
    return to_pick_first_action

  x = x[None, :]
  x = x.astype(jnp.float32)
  x = x.reshape((x.shape[0], -1))  # flatten
  x = nn.Dense(x, features=num_actions * num_atoms,
               kernel_init=custom_init, bias_init=jax.nn.initializers.ones)
  logits = x.reshape((-1, num_actions, num_atoms))
  probabilities = nn.softmax(logits)
  qs = jnp.sum(support * probabilities, axis=2)
  return atari_lib.RainbowNetworkType(qs, logits, probabilities)
def __call__(self, x):
  def custom_init(key, shape, dtype=jnp.float32):
    del key
    to_pick_first_action = onp.ones(shape, dtype)
    to_pick_first_action[:, :self.num_atoms] = onp.arange(
        1, self.num_atoms + 1)
    return to_pick_first_action

  x = x.astype(jnp.float32)
  x = x.reshape((-1))  # flatten
  x = linen.Dense(features=self.num_actions * self.num_atoms,
                  kernel_init=custom_init,
                  bias_init=linen.initializers.ones)(x)
  logits = x.reshape((self.num_actions, self.num_atoms))
  probabilities = linen.softmax(logits)
  qs = jnp.mean(logits, axis=1)
  return atari_lib.RainbowNetworkType(qs, logits, probabilities)
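# What custom_init above achieves, checked standalone: only the first
# num_atoms output columns receive the ramp 1..num_atoms while the rest of
# the kernel is ones, so with a ones bias the first action's mean logit (its
# Q-value in the snippet above) is always the largest. Shapes illustrative.
import numpy as onp

num_actions, num_atoms = 4, 5
kernel = onp.ones((8, num_actions * num_atoms))
kernel[:, :num_atoms] = onp.arange(1, num_atoms + 1)
logits = (onp.ones(8) @ kernel + 1.0).reshape(num_actions, num_atoms)
assert logits.mean(axis=1).argmax() == 0  # the first action wins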
def apply(self, x, num_actions, num_atoms, support, noisy, dueling):
  # We need to add a "batch dimension" as the network expects one, yet vmap
  # will have removed the true batch dimension.
  x = x[None, ...]
  x = x.astype(jnp.float32)
  x = x.reshape((x.shape[0], -1))  # flatten

  if noisy:
    # NoisyNet variant.
    initializer = None
    bias = True
    def net(x, features, bias, kernel_init):
      return NoisyNetwork(x, features, bias, kernel_init)
  else:
    initializer = nn.initializers.xavier_uniform()
    bias = None
    def net(x, features, bias, kernel_init):
      return nn.Dense(x, features, kernel_init=kernel_init)

  x = net(x, features=512, bias=bias, kernel_init=initializer)
  x = jax.nn.relu(x)
  x = net(x, features=512, bias=bias, kernel_init=initializer)
  x = jax.nn.relu(x)

  if dueling:
    adv = net(x, features=num_actions * num_atoms, bias=bias,
              kernel_init=initializer)
    value = net(x, features=num_atoms, bias=bias, kernel_init=initializer)
    adv = adv.reshape((adv.shape[0], num_actions, num_atoms))
    value = value.reshape((value.shape[0], 1, num_atoms))
    # Dueling aggregation: subtract the mean advantage over actions (axis 1).
    logits = value + (adv - jnp.mean(adv, axis=1, keepdims=True))
    probabilities = nn.softmax(logits)
    q_values = jnp.sum(support * probabilities, axis=2)
  else:
    x = net(x, features=num_actions * num_atoms, bias=bias,
            kernel_init=initializer)
    logits = x.reshape((x.shape[0], num_actions, num_atoms))
    probabilities = nn.softmax(logits)
    q_values = jnp.sum(support * probabilities, axis=2)
  return atari_lib.RainbowNetworkType(q_values, logits, probabilities)
def __call__(self, x, support, eval_mode=False, key=None):
  # Generate a random key if one was not provided.
  if key is None:
    key = jax.random.PRNGKey(int(time.time() * 1e6))
  if not self.inputs_preprocessed:
    x = preprocess_atari_inputs(x)

  hidden_sizes = [32, 64, 64]
  kernel_sizes = [8, 4, 3]
  stride_sizes = [4, 2, 1]
  for hidden_size, kernel_size, stride_size in zip(hidden_sizes, kernel_sizes,
                                                   stride_sizes):
    x = nn.Conv(features=hidden_size,
                kernel_size=(kernel_size, kernel_size),
                strides=(stride_size, stride_size),
                kernel_init=nn.initializers.xavier_uniform())(x)
    x = nn.relu(x)
  x = x.reshape((-1))  # flatten

  net = feature_layer(key, self.noisy, eval_mode=eval_mode)
  x = net(x, features=512)  # Single hidden layer of size 512.
  x = nn.relu(x)

  if self.dueling:
    adv = net(x, features=self.num_actions * self.num_atoms)
    value = net(x, features=self.num_atoms)
    adv = adv.reshape((self.num_actions, self.num_atoms))
    value = value.reshape((1, self.num_atoms))
    logits = value + (adv - jnp.mean(adv, axis=0, keepdims=True))
  else:
    x = net(x, features=self.num_actions * self.num_atoms)
    logits = x.reshape((self.num_actions, self.num_atoms))

  if self.distributional:
    probabilities = nn.softmax(logits)
    q_values = jnp.sum(support * probabilities, axis=1)
    return atari_lib.RainbowNetworkType(q_values, logits, probabilities)
  q_values = jnp.sum(logits, axis=1)  # Sum over all the num_atoms.
  return atari_lib.DQNNetworkType(q_values)
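# Usage sketch (assumptions: the __call__ above sits on a linen module, here
# called FullRainbowNetwork, with num_actions/num_atoms/noisy/dueling/
# distributional/inputs_preprocessed fields; the name and field set are
# illustrative). Greedy evaluation on a single observation.
import jax
import jax.numpy as jnp

network = FullRainbowNetwork(num_actions=6, num_atoms=51, noisy=True,
                             dueling=True, distributional=True)
support = jnp.linspace(-10.0, 10.0, 51)
obs = jnp.zeros((84, 84, 4))
params = network.init(jax.random.PRNGKey(0), obs, support)
out = network.apply(params, obs, support, eval_mode=True,
                    key=jax.random.PRNGKey(1))
action = jnp.argmax(out.q_values)  # greedy action under the expected values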
def apply(self, x, num_actions, num_atoms):
  initializer = nn.initializers.xavier_uniform()
  # We need to add a "batch dimension" as the network expects one, yet vmap
  # will have removed the true batch dimension.
  x = x[None, ...]
  x = x.astype(jnp.float32)
  x = x.reshape((x.shape[0], -1))  # flatten
  x -= gym_lib.CARTPOLE_MIN_VALS
  x /= gym_lib.CARTPOLE_MAX_VALS - gym_lib.CARTPOLE_MIN_VALS
  x = 2.0 * x - 1.0  # Rescale in range [-1, 1].
  x = nn.Dense(x, features=512, kernel_init=initializer)
  x = jax.nn.relu(x)
  x = nn.Dense(x, features=512, kernel_init=initializer)
  x = jax.nn.relu(x)
  x = nn.Dense(x, features=num_actions * num_atoms, kernel_init=initializer)
  logits = x.reshape((x.shape[0], num_actions, num_atoms))
  probabilities = nn.softmax(logits)
  q_values = jnp.mean(logits, axis=2)
  return atari_lib.RainbowNetworkType(q_values, logits, probabilities)
def apply(self, x, num_actions, num_atoms, support, noisy, dueling):
  # We need to add a "batch dimension" as nn.Conv expects it, yet vmap will
  # have removed the true batch dimension.
  initializer_conv = nn.initializers.xavier_uniform()
  x = x[None, ...]
  x = x.astype(jnp.float32)
  x = nn.Conv(x, features=16, kernel_size=(3, 3), strides=(1, 1),
              kernel_init=initializer_conv)
  x = jax.nn.relu(x)
  x = x.reshape((x.shape[0], -1))  # flatten.

  if noisy:
    # NoisyNet variant.
    initializer = None
    bias = True
    def net(x, features, bias, kernel_init):
      return NoisyNetwork(x, features, bias, kernel_init)
  else:
    initializer = nn.initializers.xavier_uniform()
    bias = None
    def net(x, features, bias, kernel_init):
      return nn.Dense(x, features, kernel_init=kernel_init)

  if dueling:
    adv = net(x, features=num_actions * num_atoms, bias=bias,
              kernel_init=initializer)
    value = net(x, features=num_atoms, bias=bias, kernel_init=initializer)
    adv = adv.reshape((adv.shape[0], num_actions, num_atoms))
    value = value.reshape((value.shape[0], 1, num_atoms))
    # Dueling aggregation: subtract the mean advantage over actions (axis 1).
    logits = value + (adv - jnp.mean(adv, axis=1, keepdims=True))
    probabilities = nn.softmax(logits)
    q_values = jnp.sum(support * probabilities, axis=2)
  else:
    x = net(x, features=num_actions * num_atoms, bias=bias,
            kernel_init=initializer)
    logits = x.reshape((x.shape[0], num_actions, num_atoms))
    probabilities = nn.softmax(logits)
    q_values = jnp.sum(support * probabilities, axis=2)
  return atari_lib.RainbowNetworkType(q_values, logits, probabilities)
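# Standalone sketch of the dueling aggregation used above: per-atom logits
# are the value stream plus the advantage stream re-centred by its mean over
# actions, which makes the value/advantage decomposition identifiable.
# Shapes are illustrative (batch 1, 4 actions, 51 atoms).
import jax.numpy as jnp

adv = jnp.arange(4 * 51, dtype=jnp.float32).reshape((1, 4, 51))
value = jnp.zeros((1, 1, 51))
logits = value + (adv - jnp.mean(adv, axis=1, keepdims=True))
# After re-centring, the advantage contribution averages to zero per atom.
assert jnp.allclose(jnp.mean(logits - value, axis=1), 0.0, atol=1e-4)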
def __call__(self, x, support, eval_mode=False, key=None):
  def custom_init(key, shape, dtype=jnp.float32):
    del key
    to_pick_first_action = onp.ones(shape, dtype)
    to_pick_first_action[:, :self.num_atoms] = onp.arange(
        1, self.num_atoms + 1)
    return to_pick_first_action

  x = x.astype(jnp.float32)
  x = x.reshape((-1))  # flatten
  x = nn.Dense(features=self.num_actions * self.num_atoms,
               kernel_init=custom_init, bias_init=nn.initializers.ones)(x)
  logits = x.reshape((self.num_actions, self.num_atoms))
  if not self.distributional:
    qs = jnp.sum(logits, axis=-1)  # Sum over all the num_atoms
    return atari_lib.DQNNetworkType(qs)
  probabilities = nn.softmax(logits)
  qs = jnp.sum(support * probabilities, axis=1)
  return atari_lib.RainbowNetworkType(qs, logits, probabilities)
def __call__(self, x, support, rng):
  if self.net_conf == 'minatar':
    x = x.squeeze(3)
    x = x.astype(jnp.float32)
    for _ in range(self.hidden_conv):
      x = nn.Conv(features=16, kernel_size=(3, 3), strides=(1, 1),
                  padding='SAME', kernel_init=self.initzer)(x)
      x = layer_funct_inf[self.layer_funct](x)
    x = nn.Conv(features=16, kernel_size=(3, 3), strides=(1, 1),
                kernel_init=self.initzer)(x)
    x = layer_funct_inf[self.layer_funct](x)
    x = x.reshape((-1))
  elif self.net_conf == 'atari':
    x = x.astype(jnp.float32) / 255.
    x = nn.Conv(features=32, kernel_size=(8, 8), strides=(4, 4),
                kernel_init=self.initzer)(x)
    x = layer_funct_inf[self.layer_funct](x)
    x = nn.Conv(features=64, kernel_size=(4, 4), strides=(2, 2),
                kernel_init=self.initzer)(x)
    x = layer_funct_inf[self.layer_funct](x)
    x = nn.Conv(features=64, kernel_size=(3, 3), strides=(1, 1),
                kernel_init=self.initzer)(x)
    x = layer_funct_inf[self.layer_funct](x)
    x = x.reshape((-1))  # flatten
  elif self.net_conf == 'classic':
    x = x.astype(jnp.float32)
    x = x.reshape((-1))
    if self.env is not None and self.env in env_inf:
      x = x - env_inf[self.env]['MIN_VALS']
      x /= env_inf[self.env]['MAX_VALS'] - env_inf[self.env]['MIN_VALS']
      x = 2.0 * x - 1.0

  if self.noisy:
    def net(x, features, rng):
      return NoisyNetwork(features, rng=rng, bias_in=True)(x)
  else:
    def net(x, features, rng):
      return nn.Dense(features, kernel_init=self.initzer)(x)

  for _ in range(self.hidden_layer):
    x = net(x, features=self.neurons, rng=rng)
    if self.normalization == 'non_normalization':
      if self.layer_funct != 'non_activation':
        x = layer_funct_inf[self.layer_funct](x)
    elif self.normalization == 'BatchNorm':
      x = nn.BatchNorm(use_running_average=True)(x)
      if self.layer_funct != 'non_activation':
        x = layer_funct_inf[self.layer_funct](x)
    elif self.normalization == 'LayerNorm':
      if self.layer_funct != 'non_activation':
        x = layer_funct_inf[self.layer_funct](x)
      x = nn.LayerNorm()(x)
    else:
      raise ValueError('Choose a correct normalization module.')

  if self.dueling:
    adv = net(x, features=self.num_actions * self.num_atoms, rng=rng)
    value = net(x, features=self.num_atoms, rng=rng)
    adv = adv.reshape((self.num_actions, self.num_atoms))
    value = value.reshape((1, self.num_atoms))
    # Dueling aggregation: subtract the mean advantage over actions (axis -2).
    logits = value + (adv - jnp.mean(adv, axis=-2, keepdims=True))
    probabilities = nn.softmax(logits)
    q_values = jnp.sum(support * probabilities, axis=1)
  else:
    x = net(x, features=self.num_actions * self.num_atoms, rng=rng)
    logits = x.reshape((self.num_actions, self.num_atoms))
    probabilities = nn.softmax(logits)
    q_values = jnp.sum(support * probabilities, axis=1)
  return atari_lib.RainbowNetworkType(q_values, logits, probabilities)
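# The variants above index module-level tables `layer_funct_inf` and
# `env_inf` that are not defined in this file. A minimal sketch consistent
# with how they are used; the keys and bounds are assumptions, not the
# original definitions.
import jax
import jax.numpy as jnp

layer_funct_inf = {
    'relu': jax.nn.relu,
    'tanh': jnp.tanh,
    'sigmoid': jax.nn.sigmoid,
    'non_activation': lambda x: x,  # identity when no activation is wanted
}

env_inf = {
    'CartPole': {
        'MIN_VALS': jnp.array([-2.4, -5.0, -0.26, -6.28]),  # illustrative bounds
        'MAX_VALS': jnp.array([2.4, 5.0, 0.26, 6.28]),
    },
}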
def call(self, state):
  x = self.net(state)
  logits = tf.reshape(x, [-1, self.num_actions, self.num_atoms])
  probabilities = layers.softmax(logits)
  q_values = tf.reduce_sum(self.support * probabilities, axis=2)
  return atari_lib.RainbowNetworkType(q_values, logits, probabilities)
def __call__(self, x, rng):
  if self.net_conf == 'minatar':
    x = x.squeeze(3)
    x = x.astype(jnp.float32)
    x = nn.Conv(features=16, kernel_size=(3, 3), strides=(1, 1),
                kernel_init=self.initzer)(x)
    x = jax.nn.relu(x)
    x = x.reshape((-1))
  elif self.net_conf == 'atari':
    x = x.astype(jnp.float32) / 255.
    x = nn.Conv(features=32, kernel_size=(8, 8), strides=(4, 4),
                kernel_init=self.initzer)(x)
    x = jax.nn.relu(x)
    x = nn.Conv(features=64, kernel_size=(4, 4), strides=(2, 2),
                kernel_init=self.initzer)(x)
    x = jax.nn.relu(x)
    x = nn.Conv(features=64, kernel_size=(3, 3), strides=(1, 1),
                kernel_init=self.initzer)(x)
    x = jax.nn.relu(x)
    x = x.reshape((-1))  # flatten
  elif self.net_conf == 'classic':
    # Classic-control environments.
    x = x.astype(jnp.float32)
    x = x.reshape((-1))
    if self.env is not None and self.env in env_inf:
      x = x - env_inf[self.env]['MIN_VALS']
      x /= env_inf[self.env]['MAX_VALS'] - env_inf[self.env]['MIN_VALS']
      x = 2.0 * x - 1.0

  if self.noisy:
    def net(x, features, rng):
      return NoisyNetwork(features, rng=rng, bias_in=True)(x)
  else:
    def net(x, features, rng):
      return nn.Dense(features, kernel_init=self.initzer)(x)

  for _ in range(self.hidden_layer):
    x = net(x, features=self.neurons, rng=rng)
    x = jax.nn.relu(x)

  if self.dueling:
    adv = net(x, features=self.num_actions * self.num_atoms, rng=rng)
    value = net(x, features=self.num_atoms, rng=rng)
    adv = adv.reshape((self.num_actions, self.num_atoms))
    value = value.reshape((1, self.num_atoms))
    # Dueling aggregation: subtract the mean advantage over actions (axis -2).
    logits = value + (adv - jnp.mean(adv, axis=-2, keepdims=True))
    probabilities = nn.softmax(logits)
    q_values = jnp.mean(logits, axis=1)
  else:
    x = net(x, features=self.num_actions * self.num_atoms, rng=rng)
    logits = x.reshape((self.num_actions, self.num_atoms))
    probabilities = nn.softmax(logits)
    q_values = jnp.mean(logits, axis=1)
  return atari_lib.RainbowNetworkType(q_values, logits, probabilities)
def apply(self, x, num_actions, net_conf, env, normalize_obs, noisy, dueling,
          num_atoms, hidden_layer=2, neurons=512):
  del normalize_obs
  if net_conf == 'minatar':
    x = x.squeeze(3)
    x = x[None, ...]
    x = x.astype(jnp.float32)
    x = nn.Conv(x, features=16, kernel_size=(3, 3), strides=(1, 1),
                kernel_init=nn.initializers.xavier_uniform())
    x = jax.nn.relu(x)
    x = x.reshape((x.shape[0], -1))
  elif net_conf == 'atari':
    # We need to add a "batch dimension" as nn.Conv expects it, yet vmap will
    # have removed the true batch dimension.
    x = x[None, ...]
    x = x.astype(jnp.float32) / 255.
    x = nn.Conv(x, features=32, kernel_size=(8, 8), strides=(4, 4),
                kernel_init=nn.initializers.xavier_uniform())
    x = jax.nn.relu(x)
    x = nn.Conv(x, features=64, kernel_size=(4, 4), strides=(2, 2),
                kernel_init=nn.initializers.xavier_uniform())
    x = jax.nn.relu(x)
    x = nn.Conv(x, features=64, kernel_size=(3, 3), strides=(1, 1),
                kernel_init=nn.initializers.xavier_uniform())
    x = jax.nn.relu(x)
    x = x.reshape((x.shape[0], -1))  # flatten
  elif net_conf == 'classic':
    # Classic-control environments.
    x = x[None, ...]
    x = x.astype(jnp.float32)
    x = x.reshape((x.shape[0], -1))
    if env is not None:
      x = x - env_inf[env]['MIN_VALS']
      x /= env_inf[env]['MAX_VALS'] - env_inf[env]['MIN_VALS']
      x = 2.0 * x - 1.0

  if noisy:
    def net(x, features):
      return NoisyNetwork(x, features)
  else:
    def net(x, features):
      return nn.Dense(x, features,
                      kernel_init=nn.initializers.xavier_uniform())

  for _ in range(hidden_layer):
    x = net(x, features=neurons)
    x = jax.nn.relu(x)

  if dueling:
    adv = net(x, features=num_actions * num_atoms)
    value = net(x, features=num_atoms)
    adv = adv.reshape((adv.shape[0], num_actions, num_atoms))
    value = value.reshape((value.shape[0], 1, num_atoms))
    # Dueling aggregation: subtract the mean advantage over actions (axis 1).
    logits = value + (adv - jnp.mean(adv, axis=1, keepdims=True))
    probabilities = nn.softmax(logits)
    q_values = jnp.mean(logits, axis=2)
  else:
    x = net(x, features=num_actions * num_atoms)
    logits = x.reshape((x.shape[0], num_actions, num_atoms))
    probabilities = nn.softmax(logits)
    q_values = jnp.mean(logits, axis=2)
  return atari_lib.RainbowNetworkType(q_values, logits, probabilities)
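# Several variants above call a NoisyNetwork layer that is not defined in
# this file. A minimal flax.linen sketch of a factorised NoisyNet linear
# layer in the spirit of Fortunato et al. (2018); the (features, rng,
# bias_in) signature mirrors how the linen variants above invoke it, but the
# parameterisation details are assumptions, not the original implementation.
import jax
import jax.numpy as jnp
import flax.linen as nn

class NoisyNetwork(nn.Module):
  features: int
  rng: jax.Array
  bias_in: bool = True

  @nn.compact
  def __call__(self, x):
    def f(z):
      # Scaling used by the factorised-Gaussian noise parameterisation.
      return jnp.sign(z) * jnp.sqrt(jnp.abs(z))

    in_features = x.shape[-1]
    rng_in, rng_out = jax.random.split(self.rng)
    eps_in = f(jax.random.normal(rng_in, (in_features, 1)))
    eps_out = f(jax.random.normal(rng_out, (1, self.features)))
    eps_w = eps_in * eps_out  # rank-1 weight noise
    mu_w = self.param('kernel', nn.initializers.lecun_normal(),
                      (in_features, self.features))
    sigma_w = self.param('sigma_kernel', nn.initializers.constant(0.1),
                         (in_features, self.features))
    y = x @ (mu_w + sigma_w * eps_w)
    if self.bias_in:
      mu_b = self.param('bias', nn.initializers.zeros, (self.features,))
      sigma_b = self.param('sigma_bias', nn.initializers.constant(0.1),
                           (self.features,))
      y = y + mu_b + sigma_b * jnp.squeeze(eps_out, axis=0)
    return y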