Example #1
0
def fusion_gru(
        x,  # T x M
        lod,  # 1 x N
        h0,  # N x D
        wx,  # M x 3D
        wh,  # D x 3D
        bias,  # 1 x 3D
        is_reverse,
        act_state,
        act_gate):
    """Fused-GRU reference: input projection (fc) followed by a plain GRU.

    The fused op folds `x @ wx + bias` into the cell, so the inner GRU
    is run with an all-zero hidden-gate bias of shape 1 x 3D.
    """
    projected = fc(x, wx, bias)
    zero_bias = np.zeros((1, wh.shape[1]), dtype='float32')
    return gru(projected, lod, h0, wh, zero_bias, is_reverse, act_state,
               act_gate)
Example #2
0
def fusion_seqexpand_concat_fc(xs, lod, w, b, fc_act):
    """Expand per-sequence inputs along time, concat with xs[0], then FC+act.

    xs[0] is already T x * (one row per timestep); every later xs[i] is
    N x * (one row per sequence) and is repeated lod[0][i] times so it
    lines up with the timesteps. The concatenated T x sum(widths) matrix
    goes through fc(w, b) and the given activation.
    """
    T = sum(lod[0])  # total timesteps across the batch
    N = len(lod[0])  # number of sequences
    D = w.shape[1]   # FC output width

    expanded_inputs = [xs[0]]
    for extra in xs[1:]:
        assert extra.shape[0] == N
        tiled = np.repeat(extra, lod[0], axis=0)
        assert tiled.shape[0] == T
        assert tiled.shape[1] == extra.shape[1]
        expanded_inputs.append(tiled)

    fc_input = np.concatenate(expanded_inputs, axis=1)
    assert fc_input.shape[0] == T
    assert fc_input.shape[1] == w.shape[0]
    fc_out = fc_act(fc(fc_input, w, b))
    assert fc_out.shape[0] == T
    assert fc_out.shape[1] == D
    return fc_out
Example #3
0
def attention_lstm(
        x,  # T x M
        lod,  # 1 x N
        h0,  # N x D
        c0,  # N x D
        fcws,  # (M+D) x 1, 1x1
        fcbs,  # 1 x 1, 1x1
        w,  # (M+D) x 4D
        b,  # 1 x 4D
        act_gate,
        act_cell,
        act_cand):
    """Reference attention-LSTM over LoD-batched sequences.

    x stacks the timesteps of all N sequences (T x M); lod[0] holds each
    sequence's length. At every step, each row of the current sequence is
    scored against the previous cell state through the fcws/fcbs stack
    (ReLU between layers) plus a softmax, and the attention-weighted sum
    of the sequence's rows is used as the LSTM input for that step.

    Returns (hidden, cell): each T x D float32, rows in input order.
    """
    T = sum(lod[0])      # total timesteps over all sequences
    N = len(lod[0])      # batch size (number of sequences)
    M = x.shape[1]       # input feature width
    D = b.shape[1] // 4  # hidden size: b packs the 4 LSTM gates
    assert T == x.shape[0]
    assert len(fcws) == len(fcbs)
    hidden = []
    cell = []

    start_offset = 0
    for bid in range(N):  # process one sequence at a time
        seq_len = lod[0][bid]
        # Rows of x belonging to this sequence, copied so attention
        # arithmetic below cannot alias the caller's array.
        xi = np.copy(x[start_offset:start_offset + seq_len, :]).reshape(
            seq_len, M)
        prev_cell = np.copy(c0[bid]).reshape([1, D])
        prev_hidden = np.copy(h0[bid]).reshape([1, D])
        for step in range(seq_len):
            # Attention scoring: pair every input row with the previous
            # cell state -> seq_len x (M + D).
            expanded_cell = np.repeat(prev_cell, seq_len, axis=0)
            tmp = np.concatenate((xi, expanded_cell), axis=1)
            assert tmp.shape[0] == seq_len
            assert tmp.shape[1] == M + D
            # FC stack with ReLU reduces each row to a scalar score
            # (fcws shapes collapse the width down to 1).
            for fcid in range(len(fcbs)):
                tmp = fc(tmp, fcws[fcid], fcbs[fcid])
                tmp = ACTIVATION['relu'](tmp)
            tmp = np.reshape(tmp, (1, seq_len))
            # Softmax over the sequence -> attention weights, seq_len x 1.
            tmp = stable_softmax(tmp).reshape(seq_len, 1)
            # Attention-weighted pooling of the input rows -> 1 x M.
            lstmx = xi * tmp  # seq * M
            lstmx = np.sum(lstmx.reshape(seq_len, M), axis=0).reshape([1, M])
            # One LSTM step on [prev_hidden, pooled input] -> 1 x 4D.
            lstmin = np.concatenate((prev_hidden, lstmx), axis=1)
            lstmout = fc(lstmin, w, b).reshape([1, 4 * D])

            # Gate layout in w/b (per the variable names): forget, input,
            # output, candidate.
            g_f, g_i, g_o, cand = np.split(lstmout, 4, axis=1)
            g_f = act_gate(g_f).reshape([1, D])
            g_i = act_gate(g_i).reshape([1, D])
            g_o = act_gate(g_o).reshape([1, D])
            cand = act_cand(cand).reshape([1, D])

            # Standard LSTM state update.
            cell_t = (prev_cell * g_f) + (g_i * cand)
            hidden_t = g_o * act_cell(cell_t)

            hidden.append(hidden_t.flatten())
            cell.append(cell_t.flatten())

            prev_cell = cell_t.reshape([1, D])
            prev_hidden = hidden_t.reshape([1, D])

        start_offset += seq_len

    # Stack the per-step rows back into T x D outputs.
    hidden = np.array(hidden).astype('float32').reshape([T, D])
    cell = np.array(cell).astype('float32').reshape([T, D])
    return hidden, cell