Example #1
    def __call__(self, x, key=None, sample=False, MPO=False):
        x = nn.Dense(features=200)(x)
        x = nn.LayerNorm()(x)
        x = nn.tanh(x)
        x = nn.Dense(features=200)(x)
        x = nn.elu(x)
        x = nn.Dense(features=2 * self.action_dim)(x)

        # Split the trunk output into the mean and (clipped) log std of a
        # diagonal Gaussian over actions.
        mu, log_sig = jnp.split(x, 2, axis=-1)
        log_sig = jnp.clip(log_sig, self.log_sig_min, self.log_sig_max)

        # Return the raw Gaussian parameters when the caller asks for them.
        if MPO:
            return mu, log_sig

        if not sample:
            # Deterministic action: squash the mean with tanh and rescale.
            return self.max_action * nn.tanh(mu), log_sig
        else:
            # Stochastic action: sample, squash with tanh, and correct the
            # log-probability for the change of variables.
            sig = jnp.exp(log_sig)
            pi = mu + random.normal(key, mu.shape) * sig
            log_pi = gaussian_likelihood(pi, mu, log_sig)
            pi = nn.tanh(pi)
            log_pi -= jnp.sum(
                jnp.log(nn.relu(1 - pi ** 2) + 1e-6), axis=1, keepdims=True,
            )
            return self.max_action * pi, log_pi
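A reduced, runnable sketch of the sampling branch above: draw from the diagonal Gaussian, squash with tanh, and apply the log(1 - tanh(u)^2) correction to the log-probability. The gaussian_likelihood helper and the action dimension used here are assumptions matching the usual SAC-style formulation, not necessarily the project's own definitions.

import flax.linen as nn
import jax.numpy as jnp
from jax import random

def gaussian_likelihood(x, mu, log_sig):
    # Diagonal Gaussian log-density, summed over the action dimension.
    return jnp.sum(
        -0.5 * ((x - mu) / jnp.exp(log_sig)) ** 2
        - log_sig - 0.5 * jnp.log(2.0 * jnp.pi),
        axis=-1, keepdims=True)

mu = jnp.zeros((1, 4))                  # assumed action_dim = 4
log_sig = jnp.full((1, 4), -1.0)
key = random.PRNGKey(0)

pi = mu + random.normal(key, mu.shape) * jnp.exp(log_sig)
log_pi = gaussian_likelihood(pi, mu, log_sig)
pi = nn.tanh(pi)
log_pi -= jnp.sum(jnp.log(nn.relu(1 - pi ** 2) + 1e-6), axis=1, keepdims=True)
print(pi.shape, log_pi.shape)           # (1, 4) (1, 1)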
Example #2
    def __call__(self, hidden_states):
        hidden_states = nn.Dense(hidden_states.shape[-1],
                                 name="dense",
                                 dtype=self.dtype)(hidden_states)
        hidden_states = nn.elu(hidden_states)  # TODO: ACT2FN[config.hidden_act]
        return FlaxBertLayerNorm(name="LayerNorm",
                                 dtype=self.dtype)(hidden_states)
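A minimal self-contained sketch of the same Dense -> ELU -> LayerNorm pattern, with flax.linen's own nn.LayerNorm standing in for the project's FlaxBertLayerNorm; the module name and the input shape are placeholders chosen for the demo.

from typing import Any
import flax.linen as nn
import jax.numpy as jnp
from jax import random

class BertOutputSketch(nn.Module):  # hypothetical name
    dtype: Any = jnp.float32

    @nn.compact
    def __call__(self, hidden_states):
        hidden_states = nn.Dense(hidden_states.shape[-1],
                                 dtype=self.dtype)(hidden_states)
        hidden_states = nn.elu(hidden_states)
        return nn.LayerNorm(dtype=self.dtype)(hidden_states)

x = jnp.ones((2, 8, 64))
module = BertOutputSketch()
y = module.apply(module.init(random.PRNGKey(0), x), x)  # y.shape == (2, 8, 64)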
Example #3
    def __call__(self, state, action, Q1=False):
        state_action = jnp.concatenate([state, action], axis=-1)

        # First Q head.
        q1 = nn.Dense(features=500)(state_action)
        q1 = nn.LayerNorm()(q1)
        q1 = nn.tanh(q1)
        q1 = nn.Dense(features=500)(q1)
        q1 = nn.elu(q1)
        q1 = nn.Dense(features=1)(q1)

        # Optionally return only the first estimate.
        if Q1:
            return q1

        # Second Q head.
        q2 = nn.Dense(features=500)(state_action)
        q2 = nn.LayerNorm()(q2)
        q2 = nn.tanh(q2)
        q2 = nn.Dense(features=500)(q2)
        q2 = nn.elu(q2)
        q2 = nn.Dense(features=1)(q2)

        return q1, q2
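A hedged sketch wrapping the twin-Q head above in a standalone Flax module and querying it both ways. The Critic name, the q_head helper, and the state/action dimensions (17 and 4) are assumptions for illustration; the layer sizes match the example.

import flax.linen as nn
import jax.numpy as jnp
from jax import random

class Critic(nn.Module):  # hypothetical wrapper name
    @nn.compact
    def __call__(self, state, action, Q1=False):
        def q_head(sa):
            # One 500-500-1 head, matching the layer sizes in the example.
            q = nn.Dense(500)(sa)
            q = nn.LayerNorm()(q)
            q = nn.tanh(q)
            q = nn.Dense(500)(q)
            q = nn.elu(q)
            return nn.Dense(1)(q)

        sa = jnp.concatenate([state, action], axis=-1)
        q1 = q_head(sa)
        if Q1:
            return q1
        return q1, q_head(sa)

state, action = jnp.ones((8, 17)), jnp.ones((8, 4))
critic = Critic()
params = critic.init(random.PRNGKey(0), state, action)
q1, q2 = critic.apply(params, state, action)             # both heads
q1_only = critic.apply(params, state, action, Q1=True)   # first head only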
Example #4
    def __call__(self, inputs, deterministic=True):
        """Applies Transformer MlpBlock module."""
        cfg = self.config
        actual_out_dim = (inputs.shape[-1]
                          if self.out_dim is None else self.out_dim)
        x = nn.Dense(cfg.mlp_dim,
                     dtype=cfg.dtype,
                     kernel_init=cfg.kernel_init,
                     bias_init=cfg.bias_init)(inputs)
        x = nn.elu(x)
        x = nn.Dropout(rate=cfg.dropout_rate)(x, deterministic=deterministic)
        output = nn.Dense(actual_out_dim,
                          dtype=cfg.dtype,
                          kernel_init=cfg.kernel_init,
                          bias_init=cfg.bias_init)(x)
        output = nn.Dropout(rate=cfg.dropout_rate)(output,
                                                   deterministic=deterministic)
        return output
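A reduced, runnable sketch of the dropout handling above: when deterministic=False, nn.Dropout draws from the 'dropout' PRNG stream, so apply must be given rngs={'dropout': key}. The TinyMlpBlock name and the sizes are placeholders, not the original config.

import flax.linen as nn
import jax.numpy as jnp
from jax import random

class TinyMlpBlock(nn.Module):  # hypothetical, reduced stand-in for the block above
    mlp_dim: int = 128
    dropout_rate: float = 0.1

    @nn.compact
    def __call__(self, inputs, deterministic=True):
        x = nn.Dense(self.mlp_dim)(inputs)
        x = nn.elu(x)
        x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)
        return nn.Dense(inputs.shape[-1])(x)

x = jnp.ones((2, 16, 32))
block = TinyMlpBlock()
params = block.init(random.PRNGKey(0), x)                   # deterministic by default
y_eval = block.apply(params, x)                             # dropout disabled
y_train = block.apply(params, x, deterministic=False,
                      rngs={'dropout': random.PRNGKey(1)})  # dropout active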
Example #5
File: pixelcnn.py Project: tokusumi/flax
    def __call__(self, images):
        # Special convolutional and resnet blocks which allow information flow
        # downwards and to the right.
        conv_down = partial(ConvDown, features=self.features)
        conv_down_right = partial(ConvDownRight, features=self.features)

        res_down = partial(ResDown, dropout_p=self.dropout_p)
        res_down_right = partial(ResDownRight, dropout_p=self.dropout_p)

        # Conv Modules which halve or double the spatial dimensions
        halve_down = partial(conv_down, strides=(2, 2))
        halve_down_right = partial(conv_down_right, strides=(2, 2))

        double_down = partial(ConvTransposeDown, features=self.features)
        double_down_right = partial(ConvTransposeDownRight,
                                    features=self.features)

        # Add channel of ones to distinguish image from padding later on
        images = jnp.pad(images, ((0, 0), (0, 0), (0, 0), (0, 1)),
                         constant_values=1)

        # Stack of `(down, down_right)` pairs, where information flows downwards
        # through `down` and downwards and to the right through `down_right`.
        # We refer to the building of the stack as the 'forward pass' and the
        # undoing of the stack as the 'reverse pass'.
        stack = []

        # -------------------------- FORWARD PASS ----------------------------------
        down = shift_down(conv_down(kernel_size=(2, 3))(images))
        down_right = (shift_down(conv_down(kernel_size=(1, 3))(images)) +
                      shift_right(conv_down_right(kernel_size=(2, 1))(images)))

        stack.append((down, down_right))
        for _ in range(self.depth):
            down, down_right = res_down()(down), res_down_right()(down_right,
                                                                  down)
            stack.append((down, down_right))

        # Resize spatial dims 32 x 32  -->  16 x 16
        down, down_right = halve_down()(down), halve_down_right()(down_right)
        stack.append((down, down_right))

        for _ in range(self.depth):
            down, down_right = res_down()(down), res_down_right()(down_right,
                                                                  down)
            stack.append((down, down_right))

        # Resize spatial dims 16 x 16  -->  8 x 8
        down, down_right = halve_down()(down), halve_down_right()(down_right)
        stack.append((down, down_right))

        for _ in range(self.depth):
            down, down_right = res_down()(down), res_down_right()(down_right,
                                                                  down)
            stack.append((down, down_right))

        # The stack now contains (in order from last appended):
        #
        #   Number of layers     Spatial dims
        #   depth + 1             8 x  8
        #   depth + 1            16 x 16
        #   depth + 1            32 x 32

        # -------------------------- REVERSE PASS ----------------------------------
        down, down_right = stack.pop()

        for _ in range(self.depth):
            down_fwd, down_right_fwd = stack.pop()
            down = res_down()(down, down_fwd)
            down_right = res_down_right()(down_right,
                                          jnp.concatenate(
                                              (down, down_right_fwd), -1))

        # Resize spatial dims 8 x 8  -->  16 x 16
        down, down_right = double_down()(down), double_down_right()(down_right)

        for _ in range(self.depth + 1):
            down_fwd, down_right_fwd = stack.pop()
            down = res_down()(down, down_fwd)
            down_right = res_down_right()(down_right,
                                          jnp.concatenate(
                                              (down, down_right_fwd), -1))

        # Resize spatial dims 16 x 16  -->  32 x 32
        down, down_right = double_down()(down), double_down_right()(down_right)

        for _ in range(self.depth + 1):
            down_fwd, down_right_fwd = stack.pop()
            down = res_down()(down, down_fwd)
            down_right = res_down_right()(down_right,
                                          jnp.concatenate(
                                              (down, down_right_fwd), -1))

        assert not stack

        # Note init_scale=0.1 on this layer was not in the original implementation,
        # but seems to make training more stable.
        return ConvOneByOne(10 * self.logistic_components,
                            init_scale=0.1)(nn.elu(down_right))
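A quick standalone check of the padding step near the top of this example: appending a constant channel of ones lets later layers distinguish real pixels from the zero padding introduced by the shifted convolutions. The 32x32x3 input shape here is just an assumption for the demo.

import jax.numpy as jnp

images = jnp.zeros((1, 32, 32, 3))
padded = jnp.pad(images, ((0, 0), (0, 0), (0, 0), (0, 1)), constant_values=1)
print(padded.shape)       # (1, 32, 32, 4)
print(padded[0, 0, 0])    # [0. 0. 0. 1.]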
Example #6
File: pixelcnn.py Project: tokusumi/flax
def concat_elu(x):
    return nn.elu(jnp.concatenate((x, -x), -1))
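A small usage sketch: concat_elu applies ELU to both x and -x, so it doubles the channel dimension while keeping information from both signs of the pre-activation. The input shape is an arbitrary choice for the demo.

import flax.linen as nn
import jax.numpy as jnp

def concat_elu(x):
    return nn.elu(jnp.concatenate((x, -x), -1))

x = jnp.ones((1, 8, 8, 16))
print(concat_elu(x).shape)   # (1, 8, 8, 32)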