def func_boxspace(S, is_training):
    batch_norm = hk.BatchNorm(False, False, 0.99)
    mu = hk.Sequential((
        hk.Flatten(),
        hk.Linear(8),
        jax.nn.relu,
        partial(hk.dropout, hk.next_rng_key(), 0.25 if is_training else 0.),
        partial(batch_norm, is_training=is_training),
        hk.Linear(8),
        jnp.tanh,
        hk.Linear(onp.prod(boxspace.shape)),
        hk.Reshape(boxspace.shape),
    ))
    logvar = hk.Sequential((
        hk.Flatten(),
        hk.Linear(8),
        jax.nn.relu,
        partial(hk.dropout, hk.next_rng_key(), 0.25 if is_training else 0.),
        partial(batch_norm, is_training=is_training),
        hk.Linear(8),
        jnp.tanh,
        hk.Linear(onp.prod(boxspace.shape)),
        hk.Reshape(boxspace.shape),
    ))
    return {'mu': mu(S), 'logvar': logvar(S)}
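# A minimal usage sketch (an assumption, not part of the snippet above): because
# func_boxspace uses hk.next_rng_key() (dropout) and carries BatchNorm state, it
# must be transformed with hk.transform_with_state to obtain a pure init/apply
# pair. `boxspace` is assumed to be a gym-style Box space defined elsewhere.
def demo_transform_func_boxspace():
    model = hk.transform_with_state(func_boxspace)
    rng = jax.random.PRNGKey(0)
    S = jnp.zeros((1, *boxspace.shape))  # dummy batch of states
    params, state = model.init(rng, S, is_training=True)
    out, state = model.apply(params, state, rng, S, is_training=True)
    return out['mu'].shape, out['logvar'].shape  # both (1, *boxspace.shape)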
def func(S, is_training):
    flatten = hk.Flatten()
    batch_norm_m = hk.BatchNorm(create_scale=True, create_offset=True, decay_rate=0.95)
    batch_norm_v = hk.BatchNorm(create_scale=True, create_offset=True, decay_rate=0.95)
    batch_norm_m = partial(batch_norm_m, is_training=is_training)
    batch_norm_v = partial(batch_norm_v, is_training=is_training)
    mu = hk.Sequential((
        hk.Linear(7), batch_norm_m, jnp.tanh,
        hk.Linear(3), jnp.tanh,
        hk.Linear(onp.prod(self.env_boxspace.action_space.shape)),
        hk.Reshape(self.env_boxspace.action_space.shape),
    ))
    logvar = hk.Sequential((
        hk.Linear(7), batch_norm_v, jnp.tanh,
        hk.Linear(3), jnp.tanh,
        hk.Linear(onp.prod(self.env_boxspace.action_space.shape)),
        hk.Reshape(self.env_boxspace.action_space.shape),
    ))
    return {'mu': mu(flatten(S)), 'logvar': logvar(flatten(S))}
def func(S, is_training):
    env = self.env_discrete
    output_shape = (env.action_space.n, *env.observation_space.shape)
    flatten = hk.Flatten()
    batch_norm_m = hk.BatchNorm(create_scale=True, create_offset=True, decay_rate=0.95)
    batch_norm_v = hk.BatchNorm(create_scale=True, create_offset=True, decay_rate=0.95)
    batch_norm_m = partial(batch_norm_m, is_training=is_training)
    batch_norm_v = partial(batch_norm_v, is_training=is_training)
    mu = hk.Sequential((
        hk.Linear(7), batch_norm_m, jnp.tanh,
        hk.Linear(3), jnp.tanh,
        hk.Linear(onp.prod(output_shape)),
        hk.Reshape(output_shape),
    ))
    logvar = hk.Sequential((
        hk.Linear(7), batch_norm_v, jnp.tanh,
        hk.Linear(3), jnp.tanh,
        hk.Linear(onp.prod(output_shape)),
        hk.Reshape(output_shape),
    ))
    X = flatten(S)
    return {'mu': mu(X), 'logvar': logvar(X)}
def __init__(self, C, position_enc_fn, name=None):
    super().__init__(name=name)
    he_init = hk.initializers.VarianceScaling(scale=2.0)
    channels = C['encoder_cnn_channels']
    kernels = C['encoder_cnn_kernels']
    strides = C['encoder_cnn_strides']
    hidden_size = channels[-1]
    self.cnn_layers = hk.Sequential([
        hk.Conv2D(channels[0], kernels[0], stride=strides[0], padding='SAME', w_init=he_init, with_bias=True),
        jax.nn.relu,
        hk.Conv2D(channels[1], kernels[1], stride=strides[1], padding='SAME', w_init=he_init, with_bias=True),
        jax.nn.relu,
        hk.Conv2D(channels[2], kernels[2], stride=strides[2], padding='SAME', w_init=he_init, with_bias=True),
        jax.nn.relu,
        hk.Conv2D(hidden_size, kernels[3], stride=strides[3], padding='SAME', w_init=he_init, with_bias=True),
        jax.nn.relu,
    ])
    self.pos_embed = SoftPositionEmbed(hidden_size, C['hidden_res'], position_enc_fn)
    self.linears = hk.Sequential([  # i.e. 1x1 convolution (shared 32 neurons across all locations)
        hk.Reshape((-1, hidden_size)),  # flatten spatial dims (works with batch)
        hk.LayerNorm(axis=-1, create_scale=True, create_offset=True),
        hk.Linear(32, w_init=he_init),
        jax.nn.relu,
        hk.Linear(32, w_init=he_init),
    ])
def q_net(obs):
    layers_ = tuple(layers) + (onp.prod(output_shape),)
    if use_noisy_network:
        network = NoisyMLP(layers_, factorized_noise=use_factorized_noise)
    else:
        network = hk.nets.MLP(layers_)
    return hk.Reshape(output_shape=output_shape)(network(obs))
def func_discrete_type2(S, is_training):
    batch_norm = hk.BatchNorm(False, False, 0.99)
    seq = hk.Sequential((
        hk.Flatten(),
        hk.Linear(8),
        jax.nn.relu,
        partial(hk.dropout, hk.next_rng_key(), 0.25 if is_training else 0.),
        partial(batch_norm, is_training=is_training),
        hk.Linear(8),
        jnp.tanh,
        hk.Linear(discrete.n * discrete.n),
        hk.Reshape((discrete.n, discrete.n)),
        jax.nn.softmax,
    ))
    return seq(S)
def func(S, is_training):
    flatten = hk.Flatten()
    batch_norm = hk.BatchNorm(create_scale=True, create_offset=True, decay_rate=0.95)
    batch_norm = partial(batch_norm, is_training=is_training)
    seq = hk.Sequential((
        hk.Linear(7), batch_norm, jnp.tanh,
        hk.Linear(3), jnp.tanh,
        hk.Linear(self.env_discrete.action_space.n * 51),
        hk.Reshape((self.env_discrete.action_space.n, 51)),
    ))
    return seq(flatten(S))
def func_pi(S, is_training):
    seq = hk.Sequential((
        hk.Linear(8), jax.nn.relu,
        hk.Linear(8), jax.nn.relu,
        hk.Linear(8), jax.nn.relu,
        hk.Linear(prod(env.action_space.shape) * 2, w_init=jnp.zeros),
        hk.Reshape((*env.action_space.shape, 2)),
    ))
    x = seq(S)
    mu, logvar = x[..., 0], x[..., 1]
    return {'mu': mu, 'logvar': logvar}
def make_conditioner(event_shape: Sequence[int],
                     hidden_sizes: Sequence[int],
                     num_bijector_params: int) -> hk.Sequential:
    """Creates an MLP conditioner for each layer of the flow."""
    return hk.Sequential([
        hk.Flatten(),
        hk.nets.MLP(hidden_sizes, activate_final=True),
        # We initialize this linear layer to zero so that the flow is initialized
        # to the identity function.
        hk.Linear(np.prod(event_shape) * num_bijector_params,
                  w_init=jnp.zeros, b_init=jnp.zeros),
        hk.Reshape(tuple(event_shape) + (num_bijector_params,)),
    ])
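# Hedged usage sketch: in the distrax-style flow this conditioner is written
# for, its output parametrizes a rational-quadratic spline inside a masked
# coupling layer. The `mask`, the spline ranges, and the hidden sizes here are
# illustrative assumptions; a spline with K bins needs 3 * K + 1 parameters.
import distrax

def make_flow_layer(event_shape, mask, num_bins=8):
    def bijector_fn(params):
        return distrax.RationalQuadraticSpline(params, range_min=0., range_max=1.)
    return distrax.MaskedCoupling(
        mask=mask,
        bijector=bijector_fn,
        conditioner=make_conditioner(event_shape,
                                     hidden_sizes=(64, 64),
                                     num_bijector_params=3 * num_bins + 1),
    )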
def func_pi(S, is_training):
    shared = hk.Sequential((
        hk.Linear(8), jax.nn.relu,
        hk.Linear(8), jax.nn.relu,
    ))
    mu = hk.Sequential((
        shared,
        hk.Linear(8), jax.nn.relu,
        hk.Linear(prod(env.action_space.shape), w_init=jnp.zeros),
        hk.Reshape(env.action_space.shape),
    ))
    logvar = hk.Sequential((
        shared,
        hk.Linear(8), jax.nn.relu,
        hk.Linear(prod(env.action_space.shape), w_init=jnp.zeros),
        hk.Reshape(env.action_space.shape),
    ))
    return {'mu': mu(S), 'logvar': logvar(S)}
def func_boxspace_type1(S, A, is_training):
    batch_norm = hk.BatchNorm(False, False, 0.99)
    seq = hk.Sequential((
        hk.Flatten(),
        hk.Linear(8),
        jax.nn.relu,
        partial(hk.dropout, hk.next_rng_key(), 0.25 if is_training else 0.),
        partial(batch_norm, is_training=is_training),
        hk.Linear(8),
        jnp.tanh,
        hk.Linear(onp.prod(boxspace.shape)),
        hk.Reshape(boxspace.shape),
    ))
    # Batched Kronecker (outer) product combines state and action features:
    # S: (batch, n), A: (batch, m)  ->  X: (batch, n * m)
    X = jax.vmap(jnp.kron)(S, A)
    return seq(X)
def func_type2(S, is_training):
    batch_norm = hk.BatchNorm(False, False, 0.99)
    logits = hk.Sequential((
        hk.Flatten(),
        hk.Linear(8),
        jax.nn.relu,
        partial(hk.dropout, hk.next_rng_key(), 0.25 if is_training else 0.),
        partial(batch_norm, is_training=is_training),
        hk.Linear(8),
        jnp.tanh,
        hk.Linear(discrete.n * num_bins),
        hk.Reshape((discrete.n, num_bins)),
    ))
    return {'logits': logits(S)}
def func_pi(S, is_training):
    seq = hk.Sequential((
        hk.Linear(8), jax.nn.relu,
        hk.Linear(8), jax.nn.relu,
        hk.Linear(8), jax.nn.relu,
        hk.Linear(prod(env.action_space.shape), w_init=jnp.zeros),
        hk.Reshape(env.action_space.shape),
    ))
    mu = seq(S)
    return {'mu': mu, 'logvar': jnp.full_like(mu, -10)}  # (almost) deterministic
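# Why logvar = -10 is "(almost) deterministic": the induced Gaussian policy has
# sigma = exp(logvar / 2) = exp(-5) ≈ 0.0067, i.e. nearly all mass at mu.
# A minimal reparametrized-sampling sketch (`sample_action` is a hypothetical
# helper; `dist` is the dict returned by func_pi above):
def sample_action(dist, rng):
    sigma = jnp.exp(0.5 * dist['logvar'])
    return dist['mu'] + sigma * jax.random.normal(rng, dist['mu'].shape)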
def func_quantile_type2(S, is_training):
    """ type-2 q-function: s -> q(s,.) """
    encoder = hk.Sequential((hk.Flatten(), hk.Linear(8), jax.nn.relu))
    quantile_fractions = quantiles_uniform(
        rng=hk.next_rng_key(),
        batch_size=jax.tree_leaves(S)[0].shape[0],
        num_quantiles=num_bins)
    x = encoder(S)
    quantile_x = quantile_net(x, quantile_fractions=quantile_fractions)
    quantile_values = hk.Sequential((
        hk.Linear(discrete.n),
        hk.Reshape((discrete.n, num_bins)),
    ))(quantile_x)
    return {
        'values': quantile_values,
        'quantile_fractions': quantile_fractions[:, None, :].tile([1, discrete.n, 1]),
    }
def __init__(
    self,
    latent_size: int,
    hidden_size: int,
    output_shape: Sequence[int] = MNIST_IMAGE_SHAPE,
):
    super().__init__(name="model")
    self._latent_size = latent_size
    self._hidden_size = hidden_size
    self._output_shape = output_shape
    self.generative_network = hk.Sequential([
        hk.Linear(self._hidden_size),
        jax.nn.relu,
        hk.Linear(self._hidden_size),
        jax.nn.relu,
        hk.Linear(np.prod(self._output_shape)),
        hk.Reshape(self._output_shape, preserve_dims=2),
    ])
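# Sketch of how generative_network is typically consumed downstream in a VAE
# (the Bernoulli likelihood is an assumption about the surrounding model, not
# shown above): the reshaped logits parametrize one independent per-pixel
# distribution over the whole image event.
def decode(self, z: jnp.ndarray) -> tfd.Distribution:
    logits = self.generative_network(z)
    return tfd.Independent(
        tfd.Bernoulli(logits=logits),
        reinterpreted_batch_ndims=len(self._output_shape))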
def pi(S, is_training):
    rng1, rng2, rng3 = hk.next_rng_keys(3)
    shape = env.action_space.shape
    rate = hparams.dropout_actor * is_training  # dropout disabled at evaluation time
    seq = hk.Sequential((
        hk.Linear(hparams.h1_actor), jax.nn.relu, partial(hk.dropout, rng1, rate),
        hk.Linear(hparams.h2_actor), jax.nn.relu, partial(hk.dropout, rng2, rate),
        hk.Linear(hparams.h3_actor), jax.nn.relu, partial(hk.dropout, rng3, rate),
        hk.Linear(onp.prod(shape)),
        hk.Reshape(shape),
        # lambda x: low + (high - low) * jax.nn.sigmoid(x),  # disabled: BoxActionsToReals
    ))
    return seq(S)  # batch of actions
def __call__(self, inputs: jnp.ndarray) -> tfd.Distribution:
    logits = self._linear(inputs)
    if not isinstance(self._logit_shape, int):
        logits = hk.Reshape(self._logit_shape)(logits)
    return tfd.Categorical(logits=logits, dtype=self._dtype)
def initial_state(batch_size: Optional[int] = None):
    network = hk.DeepRNN([hk.Reshape([-1], preserve_dims=1), hk.LSTM(output_size)])
    return network.initial_state(batch_size)
def network(inputs: jnp.ndarray, state: hk.LSTMState):
    return hk.DeepRNN([hk.Reshape([-1], preserve_dims=1), hk.LSTM(output_size)])(inputs, state)
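# A minimal sketch of running the same core over a whole time-major sequence
# with hk.dynamic_unroll (the sequence shape here is an assumption):
def unroll(sequence: jnp.ndarray):  # sequence: [num_steps, batch, ...]
    core = hk.DeepRNN([hk.Reshape([-1], preserve_dims=1), hk.LSTM(output_size)])
    initial = core.initial_state(batch_size=sequence.shape[1])
    # dynamic_unroll scans the core over the leading (time) axis.
    outputs, final_state = hk.dynamic_unroll(core, sequence, initial)
    return outputs, final_state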
def q_net(obs):
    layers_ = tuple(layers) + (onp.prod(output_shape),)
    network = NoisyMLP(layers_)
    return hk.Reshape(output_shape=output_shape)(network(obs))