def gru(dh, x):
    """Build one recurrence step of a gated recurrent unit (GRU).

    Weights, stabilizers and configuration (``W``, ``H``, ``H1``, ``b``,
    ``Wmr``, ``Sdh``, ``Sht``, ``activation``, ``stack_axis``,
    ``stacked_dim``, ``has_projection``) are free names taken from the
    enclosing scope.

    Args:
        dh: previous hidden state h(t-1); it is stabilized with ``Sdh``.
        x: current input x(t). It is NOT stabilized here; callers are
            expected to do that themselves.

    Returns:
        ``Function.NamedOutput(h=h)`` where ``h`` is the new hidden state,
        passed through ``Sht`` and projected by ``Wmr`` when
        ``has_projection`` is set. The order of named outputs matters.
    """
    h_prev = Sdh(dh)  # previous value, stabilized

    # The input side (plus bias) holds three stacked blocks: update, reset and
    # candidate. The recurrent side via H holds only update and reset. The
    # candidate's recurrent term is computed separately, after the reset gate.
    from_input = b + times(x, W)
    from_hidden = times(h_prev, H)

    def _block(projection, k):
        # k-th stacked_dim-wide slice of a stacked projection along stack_axis
        return slice(projection, stack_axis, k * stacked_dim, (k + 1) * stacked_dim)

    update_pre = _block(from_input, 0) + _block(from_hidden, 0)
    reset_pre = _block(from_input, 1) + _block(from_hidden, 1)
    candidate_pre = _block(from_input, 2)

    update_gate = sigmoid(update_pre)  # update gate z(t)
    reset_gate = sigmoid(reset_pre)    # reset gate r(t)

    # The reset gate scales h(t-1) BEFORE its projection through H1.
    reset_prev = h_prev * reset_gate
    candidate = activation(candidate_pre + times(reset_prev, H1))

    # Interpolate between the candidate and the (stabilized) previous state.
    ht = (1 - update_gate) * candidate + update_gate * h_prev

    # For comparison, the equations documented for cuDNN's CUDNN_GRU mode:
    #   i(t)  = sigmoid(W_i x(t) + R_i h(t-1) + b_Wi + b_Ri)
    #   r(t)  = sigmoid(W_r x(t) + R_r h(t-1) + b_Wr + b_Rr)
    #   h'(t) = tanh(W_h x(t) + r(t) .* (R_h h(t-1) + b_Rh) + b_Wh)
    #   h(t)  = (1 - i(t)) .* h'(t) + i(t) .* h(t-1)
    # The gates match this block. cuDNN applies r(t) AFTER the recurrent
    # projection; this block applies it before (times(reset_prev, H1)).

    if has_projection:
        h = times(Sht(ht), Wmr)
    else:
        h = ht

    # returns the new state with names; order matters
    return Function.NamedOutput(h=h)
def lstm(dh, dc, x):
    """Build one recurrence step of a long short-term memory (LSTM) cell.

    Weights, stabilizers and configuration (``W``, ``H``, ``b``, ``Wmr``,
    ``Ci``, ``Cf``, ``Co``, ``Sdh``, ``Sdc``, ``Sct``, ``Sht``,
    ``activation``, ``stack_axis``, ``stacked_dim``, ``use_peepholes``,
    ``has_projection``) are free names taken from the enclosing scope.

    Args:
        dh: previous hidden state h(t-1); stabilized with ``Sdh``.
        dc: previous cell state c(t-1); a copy stabilized with ``Sdc``
            feeds the input/forget peepholes.
        x: current input x(t). It is NOT stabilized here; callers are
            expected to do that themselves.

    Returns:
        ``(Function.NamedOutput(h=h), Function.NamedOutput(c=c))``. The
        order matters. ``h`` is passed through ``Sht`` and projected by
        ``Wmr`` when ``has_projection`` is set.
    """
    h_prev = Sdh(dh)  # previous values, stabilized
    c_prev = Sdc(dc)

    # A single joint projection of input, hidden state and bias. It holds four
    # stacked blocks: input gate, input network, forget gate, output gate.
    stacked = b + times(x, W) + times(h_prev, H)

    def _block(k):
        # k-th stacked_dim-wide slice of the joint projection along stack_axis
        return slice(stacked, stack_axis, k * stacked_dim, (k + 1) * stacked_dim)

    def _with_peephole(proj, cell, weight):
        # Add the peephole term weight * cell only when peepholes are requested.
        if use_peepholes:
            return proj + weight * cell
        return proj

    in_gate = sigmoid(_with_peephole(_block(0), c_prev, Ci))  # input gate i(t)
    # TODO: should both activations be replaced?
    gated_input = in_gate * activation(_block(1))  # gate applied to the input network

    forget_gate = sigmoid(_with_peephole(_block(2), c_prev, Cf))  # forget gate f(t)
    # Note: the carried-over cell multiplies the raw dc, not the stabilized
    # c_prev. c_prev is used only in the peephole terms above.
    kept_cell = forget_gate * dc

    c = kept_cell + gated_input  # new cell value c(t)

    # The output gate's peephole looks at the NEW cell (stabilized via Sct).
    out_gate = sigmoid(_with_peephole(_block(3), Sct(c), Co))
    ht = out_gate * activation(c)

    if has_projection:
        h = times(Sht(ht), Wmr)
    else:
        h = ht

    # returns the new state as a tuple with names; order matters
    return (Function.NamedOutput(h=h), Function.NamedOutput(c=c))
def rnn(dh, x):
    """Build one recurrence step of a plain recurrent layer.

    Computes ``activation(x W + h(t-1) H + b)``. ``W``, ``H``, ``b``,
    ``Wmr``, ``Sdh``, ``Sht``, ``activation`` and ``has_projection`` are
    free names taken from the enclosing scope.

    Args:
        dh: previous hidden state h(t-1); stabilized with ``Sdh``.
        x: current input x(t).

    Returns:
        ``Function.NamedOutput(h=h)``. ``h`` is passed through ``Sht`` and
        projected by ``Wmr`` when ``has_projection`` is set.
    """
    h_prev = Sdh(dh)  # previous value, stabilized
    pre_activation = times(x, W) + times(h_prev, H) + b
    ht = activation(pre_activation)

    if not has_projection:
        return Function.NamedOutput(h=ht)
    return Function.NamedOutput(h=times(Sht(ht), Wmr))