def lstm_unit(self, prefix, x, cont, static=None, h=None, c=None,
              batch_size=100, timestep=0, lstm_hidden=1000,
              weight_filler=None, bias_filler=None,
              weight_lr_mult=1, bias_lr_mult=2,
              weight_decay_mult=1, bias_decay_mult=0,
              concat_hidden=True):
    # Assume the static input (if any) is already transformed.
    if weight_filler is None:
        weight_filler = self.uniform_weight_filler(-0.08, 0.08)
    if bias_filler is None:
        bias_filler = self.constant_filler(0)
    # Initialize hidden and cell states to zeros if not provided.
    if h is None:
        h = self.dummy_data_layer([1, batch_size, lstm_hidden], 1)
    if c is None:
        c = self.dummy_data_layer([1, batch_size, lstm_hidden], 1)
    gate_dim = self.gate_dim

    def get_name(name):
        return '%s_%s' % (prefix, name)

    def get_param(weight_name, bias_name=None):
        # Shared parameter specs: naming the params lets unrolled
        # timesteps share the same weights.
        w = dict(lr_mult=weight_lr_mult, decay_mult=weight_decay_mult,
                 name=get_name(weight_name))
        if bias_name is not None:
            b = dict(lr_mult=bias_lr_mult, decay_mult=bias_decay_mult,
                     name=get_name(bias_name))
            return [w, b]
        return [w]

    # gate_dim is the dimension of the cell state inputs:
    # 4 gates (i, f, o, g), each with the hidden dimension.
    # Add layer to transform all timesteps of x to the hidden state dimension.
    # x_transform = W_xc * x + b_c
    x = L.InnerProduct(x, num_output=gate_dim, axis=2,
                       weight_filler=weight_filler, bias_filler=bias_filler,
                       param=get_param('W_xc', 'b_c'))
    setattr(self.n, get_name('%d_x_transform' % timestep), x)
    # Zero out the previous hidden state at sequence boundaries (cont == 0).
    h_conted = L.Eltwise(h, cont, coeff_blob=True)
    # Recurrent transform, no bias: h_transform = W_hc * h_conted
    h = L.InnerProduct(h_conted, num_output=gate_dim, axis=2, bias_term=False,
                       weight_filler=weight_filler, param=get_param('W_hc'))
    h_name = get_name('%d_h_transform' % timestep)
    if not hasattr(self.n, h_name):
        setattr(self.n, h_name, h)
    # Sum the input, recurrent, and (optional) static contributions.
    gate_input_args = (x, h)
    if static is not None:
        gate_input_args += (static,)
    gate_input = L.Eltwise(*gate_input_args)
    assert cont is not None
    c, h = L.LSTMUnit(c, gate_input, cont, ntop=2)
    return h, c
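# --- Usage sketch (illustrative; not part of the original file) ---
# A minimal example of unrolling lstm_unit with shared weights. The host
# class below ("LSTMBuilder") and its helper methods are assumptions,
# stubbed in only to make the sketch self-contained; the real builder
# class lives elsewhere in this codebase, and gate_dim = 4 * lstm_hidden
# is an assumption consistent with the four-gate comment above. Note that
# Eltwise's coeff_blob and the LSTMUnit layer require the recurrent Caffe
# fork this code targets.
import caffe
from caffe import layers as L


class LSTMBuilder(object):
    def __init__(self, lstm_hidden=1000):
        self.n = caffe.NetSpec()
        # 4 gates (i, f, o, g), each of size lstm_hidden (assumed).
        self.gate_dim = 4 * lstm_hidden

    def uniform_weight_filler(self, min_val, max_val):
        return dict(type='uniform', min=min_val, max=max_val)

    def constant_filler(self, value):
        return dict(type='constant', value=value)

    def dummy_data_layer(self, shape, value):
        return L.DummyData(shape=dict(dim=shape),
                           data_filler=dict(type='constant', value=value))

    lstm_unit = lstm_unit  # reuse the function above as a method


builder = LSTMBuilder(lstm_hidden=1000)
feat = builder.dummy_data_layer([1, 100, 4096], 1)  # stand-in input features
cont = builder.dummy_data_layer([1, 100], 1)        # 1 = sequence continues
# Two unrolled steps; the shared param names from get_param()
# ('lstm_W_xc', 'lstm_W_hc', 'lstm_b_c') tie their weights together.
h, c = builder.lstm_unit('lstm', feat, cont, timestep=0)
h, c = builder.lstm_unit('lstm', feat, cont, h=h, c=c, timestep=1)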
def add_rnn(n, data, act, clip, batch_size, T, K, num_step, lstm_dim=2048,
            mode='train'):
    add_lstm_init(n, batch_size, lstm_dim)
    n.clip_reshape = L.Reshape(clip, shape=dict(dim=[1, T, batch_size]))
    if mode == 'train' or mode == 'test_encode':
        clip_slice = L.Slice(n.clip_reshape, ntop=T, axis=1)
        if mode == 'train':
            act_slice = L.Slice(act, ntop=T - 1, axis=0)
            x = L.Slice(data, axis=0, ntop=T)
        x_set = ()
        label_set = ()
        silence_set = ()
        # Expose per-timestep slices as named tops and collect encoder
        # inputs (x_set) and prediction targets (label_set).
        for i in range(T):
            t = tag(i + 1)
            n.tops['clip' + t] = clip_slice[i]
            if mode == 'train':
                n.tops['x' + t] = x[i]
                if i < T - 1:
                    n.tops['act' + t] = act_slice[i]
                if i < T - num_step:
                    x_set = x_set + (x[i],)
                if i < K - 1:
                    silence_set += (act_slice[i],)
                if i >= K:
                    label_set = label_set + (x[i],)
    if mode == 'train':
        n.x = L.Concat(*x_set, axis=0)
        n.label = L.Concat(*label_set, axis=0)
        add_lstm_encoder(n, n.x, batch_size, lstm_dim)
    else:
        add_lstm_encoder(n, data, batch_size, lstm_dim)
    # Slice the precomputed input-to-gate transform into per-timestep tops.
    # With ntop == 1, L.Slice returns a single Top rather than a tuple.
    if T > num_step:
        x_gate = L.Slice(n.x_gate, ntop=T - num_step, axis=0)
        if type(x_gate) is caffe.net_spec.Top:
            x_gate = (x_gate,)
    else:
        x_gate = ()
    for i in range(0, T):
        t_1 = tag(i)
        t = tag(i + 1)
        clip_t = (n.tops['clip' + t]
                  if mode == 'train' or mode == 'test_encode'
                  else n.clip_reshape)
        # Zero out the previous hidden state at sequence boundaries.
        n.tops['h_conted' + t_1] = eltwise(n.tops['h' + t_1], clip_t,
                                           P.Eltwise.SUM, True)
        # Decoding
        if i == T - num_step:
            if mode == 'train':
                # Decode all encoder steps K..T-num_step in one batched call.
                h_set = ()
                act_set = ()
                for j in range(K, T - num_step + 1):
                    t_j = tag(j)
                    h_set = h_set + (n.tops['h_conted' + t_j],)
                    act_set = act_set + (n.tops['act' + t_j],)
                n.h = L.Concat(*h_set, axis=0)
                n.act_concat = L.Concat(*act_set, axis=0)
                top = add_decoder(n, n.h, n.act_concat)
            else:
                top = add_decoder(n, n.tops['h_conted' + t_1], act)
            x_outs = L.Slice(top, axis=0, ntop=T - num_step - K + 1)
            if type(x_outs) is caffe.net_spec.Top:
                x_outs = [x_outs]
            for j in range(K, T - num_step + 1):
                n.tops['x_hat' + tag(j + 1)] = x_outs[j - K]
            dec_tag = tag(2) if mode == 'train' else ''
            if mode == 'test_decode':
                # Feed the prediction back in as the next encoder input.
                add_lstm_encoder(n, n.tops['x_hat' + t], batch_size,
                                 lstm_dim=lstm_dim, flatten=False)
                x_gate = x_gate + (n.tops['x_gate'],)
            elif num_step > 1:
                add_lstm_encoder(n, n.tops['x_hat' + t], batch_size,
                                 lstm_dim=lstm_dim, t=t, tag=dec_tag,
                                 flatten=False)
                x_gate = x_gate + (n.tops['x_gate' + t],)
        # Multi-step prediction: keep decoding from the predicted frames.
        if i > T - num_step:
            dec_t = tag(i - T + num_step + 1)
            dec_tp = tag(i - T + num_step + 2)
            top = add_decoder(n, n.tops['h_conted' + t_1],
                              n.tops['act' + t_1], tag=dec_t)
            n.tops['x_hat' + t] = top
            if i < T - 1:
                add_lstm_encoder(n, n.tops['x_hat' + t], batch_size,
                                 lstm_dim=lstm_dim, t=t, tag=dec_tp,
                                 flatten=False)
                x_gate = x_gate + (n.tops['x_gate' + t],)
        if i < T - 1 or mode != 'train':
            # h(t-1) -> h(t): combine input and recurrent gate activations
            # and advance the LSTM one step.
            if mode != 'test_decode':
                n.tops['x_gate' + t] = x_gate[i]
            n.tops['h_gate' + t] = fc(n.tops['h_conted' + t_1], 4 * lstm_dim,
                                      weight_filler=dict(type='uniform',
                                                         min=-0.08, max=0.08),
                                      param_name='Wh', axis=2, bias=False)
            n.tops['gate' + t] = eltwise(x_gate[i], n.tops['h_gate' + t],
                                         P.Eltwise.SUM)
            n.tops['c' + t], n.tops['h' + t] = L.LSTMUnit(
                n.tops['c' + t_1], n.tops['gate' + t], clip_t, ntop=2,
                clip_gradients=[0, 0.1, 0])
    # Define loss functions.
    if mode == 'train':
        x_hat = ()
        for i in range(K, T):
            t = tag(i + 1)
            x_hat = x_hat + (n.tops['x_hat' + t],)
        # Silence unused tops so Caffe does not treat them as net outputs.
        silence_set += (n.tops['c' + tag(T - 1)],)
        n.silence = L.Silence(*silence_set, ntop=0)
        n.x_hat = L.Concat(*x_hat, axis=0)
        n.label_flat = L.Flatten(n.label, axis=0, end_axis=1)
        n.l2_loss = L.EuclideanLoss(n.x_hat, n.label_flat)
    return n
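# --- Usage sketch (illustrative; not part of the original file) ---
# Assembling a training NetSpec with add_rnn. tag(), fc(), eltwise(),
# add_lstm_init(), add_lstm_encoder(), and add_decoder() come from
# elsewhere in this codebase; the DummyData shapes below are assumptions
# chosen only to show the expected wiring (T frames, T-1 actions, one clip
# marker per frame), not shapes the real data layers are guaranteed to use.
import caffe
from caffe import layers as L

n = caffe.NetSpec()
T, K, num_step, batch_size = 12, 4, 1, 8
n.data = L.DummyData(shape=dict(dim=[T * batch_size, 1024]))     # frame features
n.act = L.DummyData(shape=dict(dim=[(T - 1) * batch_size, 10]))  # one-hot actions
n.clip = L.DummyData(shape=dict(dim=[T * batch_size]))           # 0 at sequence starts
n = add_rnn(n, n.data, n.act, n.clip, batch_size, T, K, num_step,
            lstm_dim=2048, mode='train')
with open('train.prototxt', 'w') as f:
    f.write(str(n.to_proto()))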