コード例 #1
0
ファイル: LSTM.py プロジェクト: aniltrue/MLProject
def lstm_cell(units: int,
              peephole: bool = True,
              kernel_activation: str = "tanh",
              recurrent_activation: str = "hard_sigmoid",
              kernel_initializer: str = "glorot_uniform",
              recurrent_initializer: str = "orthogonal",
              bias_initializer: str = "zeros",
              use_bias: bool = True,
              **kwargs) -> RNNCellBuilder:
    """Build an LSTM cell with hidden state ``h`` and cell state ``C``.

    The graph built here is::

        input, forget, output = recurrent gates over gate_inputs
        cell   = kernel(X, h)
        C_next = forget * C + input * cell
        h_next = kernel_activation(C_next) * output

    Args:
        units: Number of units, forwarded to ``RNNCellBuilder``.
        peephole: If True, the input/forget/output gates also read the
            previous cell state ``C`` in addition to ``X`` and ``h``.
        kernel_activation: Activation name forwarded to the builder; the
            builder's ``kernel_activation`` callable is also applied to
            ``C_next`` when computing ``h_next``.
        recurrent_activation: Activation name forwarded to the builder
            (presumably used by the ``add_recurrent`` gates — confirm in
            ``RNNCellBuilder``).
        kernel_initializer: Forwarded unchanged to ``RNNCellBuilder``.
        recurrent_initializer: Forwarded unchanged to ``RNNCellBuilder``.
        bias_initializer: Forwarded unchanged to ``RNNCellBuilder``.
        use_bias: Forwarded unchanged to ``RNNCellBuilder``.
        **kwargs: Forwarded unchanged to ``RNNCellBuilder``.

    Returns:
        The configured ``RNNCellBuilder``.
    """
    cell = RNNCellBuilder(units, ["h", "C"], kernel_activation,
                          recurrent_activation, kernel_initializer,
                          recurrent_initializer, bias_initializer, use_bias,
                          **kwargs)

    # With peepholes every gate additionally looks at the previous cell state.
    # NOTE(review): in the classic peephole LSTM (Gers & Schmidhuber) the
    # output gate peeks at the *updated* cell state C_next, not the previous
    # C. Kept as the previous state here — confirm this variant is intended.
    gate_inputs = ["X", "h", "C"] if peephole else ["X", "h"]

    # A fresh list is passed per call, as the original did with separate
    # literals, in case the builder stores or mutates its argument.
    return (
        cell
        .add_recurrent("input", list(gate_inputs))
        .add_recurrent("forget", list(gate_inputs))
        .add_recurrent("output", list(gate_inputs))
        .add_kernel("cell", ["X", "h"])
        .add_var("C_next", ["forget", "C", "input", "cell"],
                 # The cell state is carried linearly (no activation), as in
                 # the standard LSTM; squashing it would clamp C to the gate
                 # range and destroy the long-term memory path.
                 lambda x: x[0] * x[1] + x[2] * x[3])
        .add_var("h_next", ["C_next", "output"],
                 lambda x: cell.kernel_activation(x[0]) * x[1])
    )
コード例 #2
0
def gru_cell(units: int,
             kernel_activation: str = "tanh",
             recurrent_activation: str = "hard_sigmoid",
             kernel_initializer: str = "glorot_uniform",
             recurrent_initializer: str = "orthogonal",
             bias_initializer: str = "zeros",
             use_bias: bool = True,
             **kwargs) -> RNNCellBuilder:
    """Build a GRU cell with a single hidden state ``h``.

    The graph built here is::

        update  = recurrent gate over (X, h)
        reset   = recurrent gate over (X, h)
        reset_h = reset * h
        output  = kernel(X, reset_h)              # candidate state
        h_next  = (1 - update) * h + update * output

    Args:
        units: Number of units, forwarded to ``RNNCellBuilder``.
        kernel_activation: Activation name forwarded to the builder
            (presumably used by ``add_kernel`` — confirm in ``RNNCellBuilder``).
        recurrent_activation: Activation name forwarded to the builder
            (presumably used by ``add_recurrent`` gates — confirm).
        kernel_initializer: Forwarded unchanged to ``RNNCellBuilder``.
        recurrent_initializer: Forwarded unchanged to ``RNNCellBuilder``.
        bias_initializer: Forwarded unchanged to ``RNNCellBuilder``.
        use_bias: Forwarded unchanged to ``RNNCellBuilder``.
        **kwargs: Forwarded unchanged to ``RNNCellBuilder``.

    Returns:
        The configured ``RNNCellBuilder``.
    """
    cell = RNNCellBuilder(units, ["h"], kernel_activation,
                          recurrent_activation, kernel_initializer,
                          recurrent_initializer, bias_initializer, use_bias,
                          **kwargs)

    return (
        cell
        .add_recurrent("update", ["X", "h"])
        .add_recurrent("reset", ["X", "h"])
        # The reset gate scales the previous hidden state before it enters
        # the candidate. Feeding the raw gate (as before) let the candidate
        # see h only through the gate, losing its sign and magnitude.
        # NOTE(review): assumes the builder accepts an add_var result as an
        # add_kernel input; reset * h has the same width as reset, so the
        # kernel's weight shapes are unchanged — TODO confirm.
        .add_var("reset_h", ["reset", "h"], lambda x: x[0] * x[1])
        .add_kernel("output", ["X", "reset_h"])
        .add_var("h_next", ["update", "h", "output"],
                 lambda x: (1 - x[0]) * x[1] + x[0] * x[2])
    )