コード例 #1
0
def lstm_noaf_cell(units: int,
                   peephole: bool = True,
                   kernel_activation: str = "tanh",
                   recurrent_activation: str = "hard_sigmoid",
                   kernel_initializer: str = "glorot_uniform",
                   recurrent_initializer: str = "orthogonal",
                   bias_initializer: str = "zeros",
                   use_bias: bool = True,
                   **kwargs) -> RNNCellBuilder:
    """Build an LSTM-style cell whose hidden-state update has no activation.

    The cell carries two states, ``h`` and ``C``. The ``input``, ``forget`` and
    ``output`` gates are registered with ``add_recurrent``, and the candidate
    ``cell`` is registered with ``add_kernel`` from ``X`` and ``h``. The state
    updates are:

        C_next = recurrent_activation(forget * C + input * cell)
        h_next = C_next * output        # no activation here ("noaf")

    Args:
        units: Number of units, forwarded to ``RNNCellBuilder``.
        peephole: When true, every gate additionally receives the previous
            cell state ``C`` as an input (besides ``X`` and ``h``).
        kernel_activation: Forwarded to ``RNNCellBuilder``.
        recurrent_activation: Forwarded to ``RNNCellBuilder``; the builder's
            ``recurrent_activation`` callable is also applied to ``C_next``.
        kernel_initializer: Forwarded to ``RNNCellBuilder``.
        recurrent_initializer: Forwarded to ``RNNCellBuilder``.
        bias_initializer: Forwarded to ``RNNCellBuilder``.
        use_bias: Forwarded to ``RNNCellBuilder``.
        **kwargs: Extra keyword arguments forwarded to ``RNNCellBuilder``.

    Returns:
        Whatever the last ``add_var`` call of the builder chain returns
        (presumably the configured builder itself -- confirm in RNNCellBuilder).
    """
    builder = RNNCellBuilder(units, ["h", "C"], kernel_activation, recurrent_activation,
                             kernel_initializer, recurrent_initializer, bias_initializer,
                             use_bias, **kwargs)

    # Peephole connections let each gate also look at the previous cell state.
    gate_inputs = ("X", "h", "C") if peephole else ("X", "h")

    # A fresh list is handed to every gate so no two calls share one object.
    # NOTE(review): the gate activation (hard_sigmoid by default) is applied to
    # the cell state C_next, while h_next gets no activation at all -- confirm
    # this is the intended variant rather than kernel_activation on C_next.
    return (builder
            .add_recurrent("input", list(gate_inputs))
            .add_recurrent("forget", list(gate_inputs))
            .add_recurrent("output", list(gate_inputs))
            .add_kernel("cell", ["X", "h"])
            .add_var("C_next", ["forget", "C", "input", "cell"],
                     lambda x: builder.recurrent_activation(x[0] * x[1] + x[2] * x[3]))
            .add_var("h_next", ["C_next", "output"], lambda x: x[0] * x[1]))
コード例 #2
0
def gru_cell(units: int,
             kernel_activation: str = "tanh",
             recurrent_activation: str = "hard_sigmoid",
             kernel_initializer: str = "glorot_uniform",
             recurrent_initializer: str = "orthogonal",
             bias_initializer: str = "zeros",
             use_bias: bool = True,
             **kwargs) -> RNNCellBuilder:
    """Build a GRU-style cell with a single hidden state ``h``.

    The ``update`` and ``reset`` gates are registered with ``add_recurrent``
    from ``X`` and ``h``. The candidate ``output`` is registered with
    ``add_kernel`` from ``X`` and the ``reset`` gate. The new hidden state
    interpolates between the previous state and the candidate:

        h_next = (1 - update) * h + update * output

    Args:
        units: Number of units, forwarded to ``RNNCellBuilder``.
        kernel_activation: Forwarded to ``RNNCellBuilder``.
        recurrent_activation: Forwarded to ``RNNCellBuilder``.
        kernel_initializer: Forwarded to ``RNNCellBuilder``.
        recurrent_initializer: Forwarded to ``RNNCellBuilder``.
        bias_initializer: Forwarded to ``RNNCellBuilder``.
        use_bias: Forwarded to ``RNNCellBuilder``.
        **kwargs: Extra keyword arguments forwarded to ``RNNCellBuilder``.

    Returns:
        Whatever the final ``add_var`` call returns (presumably the configured
        builder itself -- confirm in RNNCellBuilder).
    """
    builder = RNNCellBuilder(units, ["h"], kernel_activation, recurrent_activation,
                             kernel_initializer, recurrent_initializer,
                             bias_initializer, use_bias, **kwargs)

    # Both gates read the current input and the previous hidden state.
    gated = (builder
             .add_recurrent("update", ["X", "h"])
             .add_recurrent("reset", ["X", "h"]))

    # NOTE(review): a textbook GRU feeds ``reset * h`` into the candidate; here
    # the reset gate itself is used, so ``h`` reaches the candidate only through
    # the gate -- confirm this is intentional.
    with_candidate = gated.add_kernel("output", ["X", "reset"])

    return with_candidate.add_var("h_next", ["update", "h", "output"],
                                  lambda x: (1 - x[0]) * x[1] + x[0] * x[2])