# Common imports assumed by the snippets below.
from functools import partial

import jax.numpy as jnp
from jax import random
from flax import linen as nn


# Gaussian policy head: returns (mu, log_sig) for MPO, a deterministic
# squashed action, or a sampled action with its log-probability.
def __call__(self, x, key=None, sample=False, MPO=False):
    x = nn.Dense(features=200)(x)
    x = nn.LayerNorm()(x)
    x = nn.tanh(x)
    x = nn.Dense(features=200)(x)
    x = nn.elu(x)
    x = nn.Dense(features=2 * self.action_dim)(x)

    mu, log_sig = jnp.split(x, 2, axis=-1)
    log_sig = jnp.clip(log_sig, self.log_sig_min, self.log_sig_max)

    if MPO:
        return mu, log_sig

    if not sample:
        return self.max_action * nn.tanh(mu), log_sig
    else:
        sig = jnp.exp(log_sig)
        pi = mu + random.normal(key, mu.shape) * sig
        log_pi = gaussian_likelihood(pi, mu, log_sig)
        pi = nn.tanh(pi)
        # Change-of-variables correction for the tanh squash (as in SAC);
        # the 1e-6 guards against log(0) at saturation.
        log_pi -= jnp.sum(
            jnp.log(nn.relu(1 - pi ** 2) + 1e-6),
            axis=1,
            keepdims=True,
        )
        return self.max_action * pi, log_pi
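
# The `gaussian_likelihood` helper above is not defined in this snippet. A
# minimal sketch of a diagonal-Gaussian log-density consistent with how it is
# called (summed over the action axis, matching the axis=1 reduction above);
# the 1e-6 term is an illustrative numerical guard, not taken from the source:
def gaussian_likelihood(sample, mu, log_sig):
    # log N(sample; mu, exp(log_sig)) summed over the action dimension.
    pre_sum = -0.5 * (((sample - mu) / (jnp.exp(log_sig) + 1e-6)) ** 2
                      + 2 * log_sig + jnp.log(2 * jnp.pi))
    return jnp.sum(pre_sum, axis=1, keepdims=True)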

def __call__(self, hidden_states):
    hidden_states = nn.Dense(hidden_states.shape[-1], name="dense",
                             dtype=self.dtype)(hidden_states)
    hidden_states = nn.elu(hidden_states)  # TODO: ACT2FN[config.hidden_act]
    return FlaxBertLayerNorm(name="LayerNorm", dtype=self.dtype)(hidden_states)
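
# The TODO above refers to the ACT2FN table that HuggingFace Transformers uses
# to map activation-name strings from the config to callables. A minimal Flax
# sketch of such a lookup (these entries are illustrative, not the library's
# full table):
ACT2FN = {
    "gelu": nn.gelu,
    "relu": nn.relu,
    "elu": nn.elu,
    "tanh": nn.tanh,
}
# With it, the hard-coded nn.elu could become:
#   hidden_states = ACT2FN[self.config.hidden_act](hidden_states)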

def __call__(self, state, action, Q1=False):
    state_action = jnp.concatenate([state, action], axis=-1)

    # First Q head.
    q1 = nn.Dense(features=500)(state_action)
    q1 = nn.LayerNorm()(q1)
    q1 = nn.tanh(q1)
    q1 = nn.Dense(features=500)(q1)
    q1 = nn.elu(q1)
    q1 = nn.Dense(features=1)(q1)
    if Q1:
        return q1

    # Second Q head (twin critic).
    q2 = nn.Dense(features=500)(state_action)
    q2 = nn.LayerNorm()(q2)
    q2 = nn.tanh(q2)
    q2 = nn.Dense(features=500)(q2)
    q2 = nn.elu(q2)
    q2 = nn.Dense(features=1)(q2)
    return q1, q2
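
# Minimal usage sketch for the twin critic above, assuming the method lives on
# a @nn.compact Flax module; `DoubleCritic` and the shapes are hypothetical:
#
#   import jax
#   critic = DoubleCritic()
#   state = jnp.zeros((32, 17))
#   action = jnp.zeros((32, 6))
#   params = critic.init(jax.random.PRNGKey(0), state, action)
#   q1, q2 = critic.apply(params, state, action)
#   q1_only = critic.apply(params, state, action, Q1=True)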

def __call__(self, inputs, deterministic=True):
    """Applies Transformer MlpBlock module."""
    cfg = self.config
    actual_out_dim = (inputs.shape[-1] if self.out_dim is None
                      else self.out_dim)
    x = nn.Dense(cfg.mlp_dim,
                 dtype=cfg.dtype,
                 kernel_init=cfg.kernel_init,
                 bias_init=cfg.bias_init)(inputs)
    x = nn.elu(x)
    x = nn.Dropout(rate=cfg.dropout_rate)(x, deterministic=deterministic)
    output = nn.Dense(actual_out_dim,
                      dtype=cfg.dtype,
                      kernel_init=cfg.kernel_init,
                      bias_init=cfg.bias_init)(x)
    output = nn.Dropout(rate=cfg.dropout_rate)(
        output, deterministic=deterministic)
    return output
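
# Sketch of the config fields the MLP block above reads. Field names follow
# the code; the defaults here are illustrative assumptions (common Transformer
# settings), not a specific upstream config:
import dataclasses
from typing import Any, Callable

@dataclasses.dataclass
class TransformerConfig:
    mlp_dim: int = 2048
    dtype: Any = jnp.float32
    dropout_rate: float = 0.1
    kernel_init: Callable = nn.initializers.xavier_uniform()
    bias_init: Callable = nn.initializers.normal(stddev=1e-6)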

def __call__(self, images):
    # Special convolutional and resnet blocks which allow information flow
    # downwards and to the right.
    conv_down = partial(ConvDown, features=self.features)
    conv_down_right = partial(ConvDownRight, features=self.features)

    res_down = partial(ResDown, dropout_p=self.dropout_p)
    res_down_right = partial(ResDownRight, dropout_p=self.dropout_p)

    # Conv Modules which halve or double the spatial dimensions
    halve_down = partial(conv_down, strides=(2, 2))
    halve_down_right = partial(conv_down_right, strides=(2, 2))

    double_down = partial(ConvTransposeDown, features=self.features)
    double_down_right = partial(ConvTransposeDownRight, features=self.features)

    # Add channel of ones to distinguish image from padding later on
    images = jnp.pad(images, ((0, 0), (0, 0), (0, 0), (0, 1)),
                     constant_values=1)

    # Stack of `(down, down_right)` pairs, where information flows downwards
    # through `down` and downwards and to the right through `down_right`.
    # We refer to the building of the stack as the 'forward pass' and the
    # undoing of the stack as the 'reverse pass'.
    stack = []

    # -------------------------- FORWARD PASS ------------------------------
    down = shift_down(conv_down(kernel_size=(2, 3))(images))
    down_right = (shift_down(conv_down(kernel_size=(1, 3))(images))
                  + shift_right(conv_down_right(kernel_size=(2, 1))(images)))

    stack.append((down, down_right))
    for _ in range(self.depth):
        down, down_right = res_down()(down), res_down_right()(down_right, down)
        stack.append((down, down_right))

    # Resize spatial dims 32 x 32 --> 16 x 16
    down, down_right = halve_down()(down), halve_down_right()(down_right)
    stack.append((down, down_right))

    for _ in range(self.depth):
        down, down_right = res_down()(down), res_down_right()(down_right, down)
        stack.append((down, down_right))

    # Resize spatial dims 16 x 16 --> 8 x 8
    down, down_right = halve_down()(down), halve_down_right()(down_right)
    stack.append((down, down_right))

    for _ in range(self.depth):
        down, down_right = res_down()(down), res_down_right()(down_right, down)
        stack.append((down, down_right))

    # The stack now contains (in order from last appended):
    #
    #   Number of layers    Spatial dims
    #   depth + 1           8 x 8
    #   depth + 1           16 x 16
    #   depth + 1           32 x 32

    # -------------------------- REVERSE PASS ------------------------------
    down, down_right = stack.pop()

    for _ in range(self.depth):
        down_fwd, down_right_fwd = stack.pop()
        down = res_down()(down, down_fwd)
        down_right = res_down_right()(
            down_right, jnp.concatenate((down, down_right_fwd), -1))

    # Resize spatial dims 8 x 8 --> 16 x 16
    down, down_right = double_down()(down), double_down_right()(down_right)

    for _ in range(self.depth + 1):
        down_fwd, down_right_fwd = stack.pop()
        down = res_down()(down, down_fwd)
        down_right = res_down_right()(
            down_right, jnp.concatenate((down, down_right_fwd), -1))

    # Resize spatial dims 16 x 16 --> 32 x 32
    down, down_right = double_down()(down), double_down_right()(down_right)

    for _ in range(self.depth + 1):
        down_fwd, down_right_fwd = stack.pop()
        down = res_down()(down, down_fwd)
        down_right = res_down_right()(
            down_right, jnp.concatenate((down, down_right_fwd), -1))

    assert not stack

    # Note init_scale=0.1 on this layer was not in the original implementation,
    # but seems to make training more stable.
    return ConvOneByOne(10 * self.logistic_components,
                        init_scale=0.1)(nn.elu(down_right))
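
# `shift_down` and `shift_right` above enforce the autoregressive ordering by
# shifting feature maps so each pixel only sees rows above it / columns to its
# left. A minimal pad-and-crop sketch consistent with their use, assuming
# NHWC layout (the exact upstream definitions are not shown here):
def shift_down(images):
    # Pad one row at the top, drop the bottom row.
    return jnp.pad(images, ((0, 0), (1, 0), (0, 0), (0, 0)))[:, :-1, :, :]

def shift_right(images):
    # Pad one column on the left, drop the rightmost column.
    return jnp.pad(images, ((0, 0), (0, 0), (1, 0), (0, 0)))[:, :, :-1, :]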

def concat_elu(x):
    """Apply ELU to both x and -x, doubling the trailing (channel) axis."""
    return nn.elu(jnp.concatenate((x, -x), -1))
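
# concat_elu is a CReLU-style nonlinearity: because ELU is applied to both x
# and -x, no information is lost on the negative half, at the cost of doubling
# the channel count. A quick shape check (the shape is illustrative):
x = jnp.zeros((1, 8, 8, 32))
assert concat_elu(x).shape == (1, 8, 8, 64)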