def lstm_layer(num_timesteps: int, num_inputs: int, num_units: int, init=glorot_uniform, bias_init=sym.zeros):
    """Build an LSTM layer unrolled over ``num_timesteps`` steps.

    A single set of gate parameters is created and shared by every step, so
    the same cell is applied once per slice of the input along axis 1.

    Args:
        num_timesteps: Number of sections ``X`` is split into along axis 1;
            the cell is applied once per section.
        num_inputs: Size of the last axis of ``X``.
        num_units: Second dimension of every gate weight and bias, and the
            width of the zero initial state.
        init: Callable invoked as ``init(shape)`` to initialise each gate
            weight of shape ``(num_inputs + num_units, num_units)``.
        bias_init: Callable invoked as ``bias_init(shape=...)`` to initialise
            each gate bias of shape ``(1, num_units)``. The forget-gate bias
            additionally has ones added to it.

    Returns:
        A pair ``(X, outputs)``: the input variable, declared with shape
        ``(_batch_size, num_timesteps, num_inputs)``, and a list with the
        hidden state produced at each timestep, in order.
        NOTE(review): each entry is presumably ``(batch, num_units)``; that
        depends on ``lstm_gate`` -- confirm.
    """
    inputs = Variable("X", shape=(_batch_size, num_timesteps, num_inputs), dtype='float32')

    weight_shape = (num_inputs + num_units, num_units)
    bias_shape = (1, num_units)

    def gate_params(weight_name, bias_name, unit_bias=False):
        # Weight first, then bias, matching the creation order of the
        # hand-written version so initialiser calls happen in the same order.
        weight = Variable(weight_name, init=init(weight_shape))
        bias_value = bias_init(shape=bias_shape)
        if unit_bias:
            bias_value = bias_value + sym.ones(shape=bias_shape)
        return weight, Variable(bias_name, init=bias_value)

    w_cand, b_cand = gate_params("Ug", "bg")
    w_in, b_in = gate_params("Ui", "bi")
    # The forget gate's bias is the regular initial value plus one.
    w_forget, b_forget = gate_params("Uf", "bf", unit_bias=True)
    w_out, b_out = gate_params("Uo", "bo")

    def step(x_now, cell_prev, hidden_prev):
        # Every gate sees the current input joined with the previous hidden
        # state along axis 1.
        joined = sym.concatenate(x_now, hidden_prev, axis=1)
        candidate = lstm_gate(sym.tanh, w_cand, b_cand, joined, num_units)
        input_gate = lstm_gate(sym.sigmoid, w_in, b_in, joined, num_units)
        forget_gate = lstm_gate(sym.sigmoid, w_forget, b_forget, joined, num_units)
        output_gate = lstm_gate(sym.sigmoid, w_out, b_out, joined, num_units)
        kept = cell_prev * forget_gate
        written = candidate * input_gate
        cell_next = kept + written
        hidden_next = sym.tanh(cell_next) * output_gate
        return (cell_next, hidden_next)

    # One slice per timestep, with the length-1 time axis squeezed away.
    steps = [
        sym.squeeze(piece, axis=1)
        for piece in sym.split(inputs, indices_or_sections=num_timesteps, axis=1)
    ]

    # The zero initial state is derived from the first timestep rather than
    # from an explicit shape: the first slice is padded or cut on axis 1 to
    # num_units columns and then passed to zeros_like. (The TensorFlow
    # equivalent would stack [shape(X)[0], num_units] and call zeros on it.)
    first_step = steps[0]
    if num_units <= num_inputs:
        state_template = first_step[:, 0:num_units]  # TODO untested
    else:
        state_template = sym.pad(first_step, pad_width=((0, 0), (0, num_units - num_inputs)))

    cell_state = sym.zeros_like(state_template)
    hidden_state = cell_state  # both states start from the same zero tensor

    outputs = []
    for x_now in steps:
        cell_state, hidden_state = step(x_now, cell_state, hidden_state)
        outputs.append(hidden_state)
    return inputs, outputs
def _get_model(dshape):
    """Build a two-output graph from a dense layer split in half.

    A variable named ``'data'`` with shape ``dshape`` feeds a biased dense
    layer whose unit count is twice the last entry of ``dshape``. The dense
    output is split into two sections along axis 1; one is added to the
    first section and one is subtracted from the second.

    Args:
        dshape: Shape given to the ``'data'`` variable; its last entry
            determines the dense layer width.

    Returns:
        ``sym.Group`` of the two shifted sections, first section first.
    """
    inp = sym.Variable('data', shape=dshape)
    doubled_units = dshape[-1] * 2
    dense_out = sym.dense(inp, units=doubled_units, use_bias=True)
    first_half, second_half = sym.split(dense_out, indices_or_sections=2, axis=1)
    shifted_up = first_half + 1
    shifted_down = second_half - 1
    return sym.Group((shifted_up, shifted_down))
def _get_model(dshape):
    """Return a grouped pair of outputs derived from one dense layer.

    The graph is: ``'data'`` variable of shape ``dshape`` -> dense layer
    with ``dshape[-1] * 2`` units and a bias -> split into two sections on
    axis 1 -> ``(first + 1, second - 1)``.

    Args:
        dshape: Shape of the ``'data'`` input variable. Only its last entry
            is read here, to size the dense layer.

    Returns:
        A ``sym.Group`` holding the two outputs in the order above.
    """
    source = sym.Variable('data', shape=dshape)
    projected = sym.dense(source, units=2 * dshape[-1], use_bias=True)
    sections = sym.split(projected, indices_or_sections=2, axis=1)
    upper, lower = sections
    plus_one = upper + 1
    minus_one = lower - 1
    return sym.Group((plus_one, minus_one))