Example #1
    def apply(
        self,
        x,
        action_dim,
        max_action,
        key=None,
        MPO=False,
        sample=False,
        log_sig_min=-20,
        log_sig_max=2,
    ):
        # Shared trunk: Dense -> LayerNorm -> tanh, then Dense -> ELU
        x = nn.Dense(x, features=200)
        x = nn.LayerNorm(x)
        x = nn.tanh(x)
        x = nn.Dense(x, features=200)
        x = nn.elu(x)
        # Head outputs the mean and log std of a diagonal Gaussian over actions
        x = nn.Dense(x, features=2 * action_dim)

        mu, log_sig = jnp.split(x, 2, axis=-1)
        log_sig = nn.softplus(log_sig)
        log_sig = jnp.clip(log_sig, log_sig_min, log_sig_max)

        if MPO:
            return mu, log_sig

        if not sample:
            # Deterministic action: tanh-squashed mean, scaled to the action range
            return max_action * nn.tanh(mu), log_sig
        else:
            # Reparameterized sample, tanh squashing, and the corresponding
            # log-probability correction for the change of variables
            pi = mu + random.normal(key, mu.shape) * jnp.exp(log_sig)
            log_pi = gaussian_likelihood(pi, mu, log_sig)
            pi = nn.tanh(pi)
            log_pi -= jnp.sum(jnp.log(nn.relu(1 - pi ** 2) + 1e-6), axis=1)
            return max_action * pi, log_pi
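
The sampling branch above calls a gaussian_likelihood helper that is not shown in this snippet. A minimal sketch of such a helper, assuming the usual diagonal-Gaussian log-density summed over the action dimension (the repository's exact definition may differ):

import jax.numpy as jnp

def gaussian_likelihood(sample, mu, log_sig, eps=1e-6):
    # Log density of a diagonal Gaussian N(mu, exp(log_sig)^2),
    # summed over the action dimension (axis 1).
    pre_sum = -0.5 * (((sample - mu) / (jnp.exp(log_sig) + eps)) ** 2
                      + 2.0 * log_sig + jnp.log(2.0 * jnp.pi))
    return jnp.sum(pre_sum, axis=1)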
Example #2
def apply_activation(intermediate_output, intermediate_activation):
    """Applies selected activation function to intermediate output."""
    if intermediate_activation is None:
        return intermediate_output

    if intermediate_activation == 'gelu':
        intermediate_output = nn.gelu(intermediate_output)
    elif intermediate_activation == 'relu':
        intermediate_output = nn.relu(intermediate_output)
    elif intermediate_activation == 'sigmoid':
        intermediate_output = nn.sigmoid(intermediate_output)
    elif intermediate_activation == 'softmax':
        intermediate_output = nn.softmax(intermediate_output)
    elif intermediate_activation == 'celu':
        intermediate_output = nn.celu(intermediate_output)
    elif intermediate_activation == 'elu':
        intermediate_output = nn.elu(intermediate_output)
    elif intermediate_activation == 'log_sigmoid':
        intermediate_output = nn.log_sigmoid(intermediate_output)
    elif intermediate_activation == 'log_softmax':
        intermediate_output = nn.log_softmax(intermediate_output)
    elif intermediate_activation == 'soft_sign':
        intermediate_output = nn.soft_sign(intermediate_output)
    elif intermediate_activation == 'softplus':
        intermediate_output = nn.softplus(intermediate_output)
    elif intermediate_activation == 'swish':
        intermediate_output = nn.swish(intermediate_output)
    elif intermediate_activation == 'tanh':
        intermediate_output = jnp.tanh(intermediate_output)
    else:
        raise NotImplementedError(
            '%s activation function is not yet supported.' %
            intermediate_activation)

    return intermediate_output
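
The chain of elif branches maps a string name to an activation. A more compact, table-driven sketch of the same dispatch, written against jax.nn directly (an illustration, not the original module; it assumes the nn.* aliases above behave like their jax.nn counterparts):

import jax.numpy as jnp
from jax import nn as jnn

_ACTIVATIONS = {
    'gelu': jnn.gelu, 'relu': jnn.relu, 'sigmoid': jnn.sigmoid,
    'softmax': jnn.softmax, 'celu': jnn.celu, 'elu': jnn.elu,
    'log_sigmoid': jnn.log_sigmoid, 'log_softmax': jnn.log_softmax,
    'soft_sign': jnn.soft_sign, 'softplus': jnn.softplus,
    'swish': jnn.swish, 'tanh': jnp.tanh,
}

def apply_activation_by_name(x, name=None):
    # None means "no activation", matching the original function.
    if name is None:
        return x
    if name not in _ACTIVATIONS:
        raise NotImplementedError(
            '%s activation function is not yet supported.' % name)
    return _ACTIVATIONS[name](x)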
Example #3
    def apply(self, state, action, Q1=False):
        state_action = jnp.concatenate([state, action], axis=1)

        # First critic head
        q1 = nn.Dense(state_action, features=500)
        q1 = nn.LayerNorm(q1)
        q1 = nn.tanh(q1)
        q1 = nn.Dense(q1, features=500)
        q1 = nn.elu(q1)
        q1 = nn.Dense(q1, features=1)

        if Q1:
            return q1

        # Second critic head with an identical architecture (twin critics)
        q2 = nn.Dense(state_action, features=500)
        q2 = nn.LayerNorm(q2)
        q2 = nn.tanh(q2)
        q2 = nn.Dense(q2, features=500)
        q2 = nn.elu(q2)
        q2 = nn.Dense(q2, features=1)

        return q1, q2
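
A sketch of how twin critic outputs are commonly combined into a pessimistic (clipped double-Q) target; this illustrates the design choice and is not code from the original repository:

import jax.numpy as jnp

def clipped_double_q_target(reward, not_done, discount, q1_next, q2_next):
    # Use the minimum of the two critic estimates to reduce overestimation bias.
    return reward + not_done * discount * jnp.minimum(q1_next, q2_next)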
Example #4
def PixelCNNPP(images, depth=5, features=160, k=10, dropout_p=.5):
    # Special convolutional and resnet blocks which allow information flow
    # downwards and to the right.
    ConvDown_ = ConvDown.partial(features=features)
    ConvDownRight_ = ConvDownRight.partial(features=features)

    ResDown_ = ResDown.partial(dropout_p=dropout_p)
    ResDownRight_ = ResDownRight.partial(dropout_p=dropout_p)

    # Conv Modules which halve or double the spatial dimensions
    HalveDown = ConvDown_.partial(strides=(2, 2))
    HalveDownRight = ConvDownRight_.partial(strides=(2, 2))

    DoubleDown = ConvTransposeDown.partial(features=features)
    DoubleDownRight = ConvTransposeDownRight.partial(features=features)

    # Add channel of ones to distinguish image from padding later on
    images = np.pad(images, ((0, 0), (0, 0), (0, 0), (0, 1)),
                    constant_values=1)

    # Stack of `(down, down_right)` pairs, where information flows downwards
    # through `down` and downwards and to the right through `down_right`.
    # We refer to building the stack as the 'forward pass' and unwinding it
    # as the 'reverse pass'.
    stack = []

    # -------------------------- FORWARD PASS ----------------------------------
    down = shift_down(ConvDown_(images, kernel_size=(2, 3)))
    down_right = (shift_down(ConvDown_(images, kernel_size=(1, 3))) +
                  shift_right(ConvDownRight_(images, kernel_size=(2, 1))))

    stack.append((down, down_right))
    for _ in range(depth):
        down, down_right = ResDown_(down), ResDownRight_(down_right, down)
        stack.append((down, down_right))

    # Resize spatial dims 32 x 32  -->  16 x 16
    down, down_right = HalveDown(down), HalveDownRight(down_right)
    stack.append((down, down_right))

    for _ in range(depth):
        down, down_right = ResDown_(down), ResDownRight_(down_right, down)
        stack.append((down, down_right))

    # Resize spatial dims 16 x 16  -->  8 x 8
    down, down_right = HalveDown(down), HalveDownRight(down_right)
    stack.append((down, down_right))

    for _ in range(depth):
        down, down_right = ResDown_(down), ResDownRight_(down_right, down)
        stack.append((down, down_right))

    # The stack now contains (in order from last appended):
    #
    #   Number of layers     Spatial dims
    #   depth + 1             8 x  8
    #   depth + 1            16 x 16
    #   depth + 1            32 x 32

    # -------------------------- REVERSE PASS ----------------------------------
    down, down_right = stack.pop()

    for _ in range(depth):
        down_fwd, down_right_fwd = stack.pop()
        down = ResDown_(down, down_fwd)
        down_right = ResDownRight_(down_right,
                                   np.concatenate((down, down_right_fwd), -1))

    # Resize spatial dims 8 x 8  -->  16 x 16
    down, down_right = DoubleDown(down), DoubleDownRight(down_right)

    for _ in range(depth + 1):
        down_fwd, down_right_fwd = stack.pop()
        down = ResDown_(down, down_fwd)
        down_right = ResDownRight_(down_right,
                                   np.concatenate((down, down_right_fwd), -1))

    # Resize spatial dims 16 x 16  -->  32 x 32
    down, down_right = DoubleDown(down), DoubleDownRight(down_right)

    for _ in range(depth + 1):
        down_fwd, down_right_fwd = stack.pop()
        down = ResDown_(down, down_fwd)
        down_right = ResDownRight_(down_right,
                                   np.concatenate((down, down_right_fwd), -1))

    assert len(stack) == 0

    # Note init_scale=0.1 on this layer was not in the original implementation,
    # but seems to make training more stable.
    return ConvOneByOne(nn.elu(down_right), 10 * k, init_scale=0.1)
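
The final 1x1 convolution emits 10 * k channels per pixel. For a k-component discretized logistic mixture over RGB these are usually split into mixture logits, per-channel means, log scales, and autoregressive colour coefficients; the slicing below is the standard PixelCNN++ parameterization and may differ in detail from this repository:

import jax.numpy as jnp

def split_mixture_params(out, k=10):
    # out has shape (..., H, W, 10 * k)
    mix_logits = out[..., :k]                  # k mixture weights
    means = out[..., k:4 * k]                  # 3k channel means
    log_scales = out[..., 4 * k:7 * k]         # 3k channel log scales
    coeffs = jnp.tanh(out[..., 7 * k:10 * k])  # 3k RGB dependency coefficients
    return mix_logits, means, log_scales, coeffs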
Example #5
def concat_elu(x):
    return nn.elu(np.concatenate((x, -x), -1))
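
concat_elu applies ELU to both x and -x, doubling the channel dimension and preserving information from both signs of the pre-activation (a CReLU-style nonlinearity). A quick shape check, illustrative only:

import jax.numpy as jnp
from jax import nn

x = jnp.ones((1, 8, 8, 32))
y = nn.elu(jnp.concatenate((x, -x), -1))
assert y.shape == (1, 8, 8, 64)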
Example #6
def elu_feature_map(x):
    return nn.elu(x) + 1
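
elu(x) + 1 is strictly positive (ELU has range (-1, inf)), which is what makes it usable as a kernel feature map in linearized attention. A minimal, non-causal sketch of that use, not taken from the original repository:

import jax.numpy as jnp
from jax import nn

def linear_attention(q, k, v, eps=1e-6):
    # q, k: (..., seq, dim); v: (..., seq, dim_v)
    q = nn.elu(q) + 1.0  # positive feature map phi(q)
    k = nn.elu(k) + 1.0  # positive feature map phi(k)
    kv = jnp.einsum('...sd,...se->...de', k, v)  # sum_s phi(k_s) v_s^T
    z = 1.0 / (jnp.einsum('...sd,...d->...s', q, k.sum(axis=-2)) + eps)
    return jnp.einsum('...sd,...de,...s->...se', q, kv, z)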