Example #1

# Imports used by the excerpt below; the class that hosts lstm_unit is not
# shown on this page.
import caffe
from caffe import layers as L, params as P
  def lstm_unit(self, prefix, x, cont, static=None, h=None, c=None,
                batch_size=100, timestep=0, lstm_hidden=1000,
                weight_filler=None, bias_filler=None,
                weight_lr_mult=1, bias_lr_mult=2,
                weight_decay_mult=1, bias_decay_mult=0, concat_hidden=True):

    # Assume the static input has already been transformed to the gate dimension.
    if not weight_filler:
      weight_filler = self.uniform_weight_filler(-0.08, 0.08)
    if not bias_filler:
      bias_filler = self.constant_filler(0)
    if not h:
      h = self.dummy_data_layer([1, batch_size, lstm_hidden], 1)
    if not c:
      c = self.dummy_data_layer([1, batch_size, lstm_hidden], 1)
    gate_dim = self.gate_dim

    def get_name(name):
      return '%s_%s' % (prefix, name)

    def get_param(weight_name, bias_name=None):
      w = dict(lr_mult=weight_lr_mult, decay_mult=weight_decay_mult,
               name=get_name(weight_name))
      if bias_name is not None:
        b = dict(lr_mult=bias_lr_mult, decay_mult=bias_decay_mult,
                 name=get_name(bias_name))
        return [w, b]
      return [w]
    # gate_dim is the dimension of the gate inputs:
    # 4 gates (i, f, o, g), each of dimension lstm_hidden.
    # Add layer to transform all timesteps of x to the hidden state dimension.
    #     x_transform = W_xc * x + b_c
    x = L.InnerProduct(x, num_output=gate_dim, axis=2,
        weight_filler=weight_filler, bias_filler=bias_filler,
        param=get_param('W_xc', 'b_c'))
    setattr(self.n, get_name('%d_x_transform' % timestep), x)
    # Scale the previous hidden state by cont so it is reset (zeroed) at
    # sequence boundaries.
    h_conted = L.Eltwise(h, cont, coeff_blob=True)
    h = L.InnerProduct(h_conted, num_output=gate_dim, axis=2, bias_term=False,
        weight_filler=weight_filler, param=get_param('W_hc'))
    h_name = get_name('%d_h_transform' %timestep)
    if not hasattr(self.n, h_name):
        setattr(self.n, h_name, h)
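    # Sum the input transform, hidden transform, and optional static term to
    # form the pre-activations for the four LSTM gates.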
    gate_input_args = x, h
    if static is not None:
        gate_input_args += (static, )
    gate_input = L.Eltwise(*gate_input_args)
    assert cont is not None
    # LSTMUnit applies the gate nonlinearities and emits the updated cell
    # and hidden states for this timestep.
    c, h = L.LSTMUnit(c, gate_input, cont, ntop=2)
    return h, c
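
For context, lstm_unit depends on several helpers from its host class, which this page does not show. A minimal sketch of what that class and a two-step unroll could look like, with the helper bodies guessed from the call sites (Builder and every helper here are assumptions, not the original code, and lstm_unit above would be pasted into the class):

import caffe
from caffe import layers as L

class Builder(object):
    # Hypothetical host class; helper bodies are inferred from how
    # lstm_unit calls them and may differ from the real implementations.
    def __init__(self, lstm_hidden=1000):
        self.n = caffe.NetSpec()
        self.gate_dim = 4 * lstm_hidden  # i, f, o, g gates stacked

    def uniform_weight_filler(self, min_value, max_value):
        return dict(type='uniform', min=min_value, max=max_value)

    def constant_filler(self, value):
        return dict(type='constant', value=value)

    def dummy_data_layer(self, shape, value):
        return L.DummyData(shape=dict(dim=shape),
                           data_filler=dict(type='constant', value=value))

    # ... lstm_unit as defined above goes here ...

# Unroll two timesteps; both calls share weights via the 'lstm_*' param names.
b = Builder()
x = b.dummy_data_layer([2, 100, 512], 1)   # (T, N, D) inputs
cont = b.dummy_data_layer([2, 100], 1)     # 0 at sequence starts, 1 elsewhere
x_t = L.Slice(x, ntop=2, axis=0)
cont_t = L.Slice(cont, ntop=2, axis=0)
h = c = None
for t in range(2):
    h, c = b.lstm_unit('lstm', x_t[t], cont_t[t], h=h, c=c, timestep=t)
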
def add_rnn(n,
            data,
            act,
            clip,
            batch_size,
            T,
            K,
            num_step,
            lstm_dim=2048,
            mode='train'):
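    # Unrolled encoder-decoder LSTM over T steps: frames are encoded, and
    # from step T - num_step onward the decoder emits predictions x_hat;
    # `mode` selects the train / test_encode / test_decode wiring.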
    add_lstm_init(n, batch_size, lstm_dim)
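    # Reshape the clip markers to (1, T, N) so they can be sliced per step.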
    n.clip_reshape = L.Reshape(clip, shape=dict(dim=[1, T, batch_size]))
    if mode in ('train', 'test_encode'):
        clip_slice = L.Slice(n.clip_reshape, ntop=T, axis=1)
        if mode == 'train':
            act_slice = L.Slice(act, ntop=T - 1, axis=0)
            x = L.Slice(data, axis=0, ntop=T)
            x_set = ()
            label_set = ()
            silence_set = ()
        for i in range(T):
            t = tag(i + 1)
            n.tops['clip' + t] = clip_slice[i]
            if mode == 'train':
                n.tops['x' + t] = x[i]
                if i < T - 1:
                    n.tops['act' + t] = act_slice[i]
                if i < T - num_step:
                    x_set = x_set + (x[i], )
                if i < K - 1:
                    silence_set += (act_slice[i], )
                if i >= K:
                    label_set = label_set + (x[i], )
        if mode == 'train':
            n.x = L.Concat(*x_set, axis=0)
            n.label = L.Concat(*label_set, axis=0)
            add_lstm_encoder(n, n.x, batch_size, lstm_dim)
        else:
            add_lstm_encoder(n, data, batch_size, lstm_dim)
    if T > num_step:
        x_gate = L.Slice(n.x_gate, ntop=T - num_step, axis=0)
        # Slice returns a bare Top rather than a tuple when ntop == 1.
        if isinstance(x_gate, caffe.net_spec.Top):
            x_gate = (x_gate, )
    else:
        x_gate = ()

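    # Unroll the recurrence: encode each step, and start decoding once
    # i reaches T - num_step.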
    for i in range(T):
        t_1 = tag(i)
        t = tag(i + 1)

        clip_t = (n.tops['clip' + t]
                  if mode in ('train', 'test_encode') else n.clip_reshape)
        # Reset the previous hidden state at sequence boundaries.
        n.tops['h_conted' + t_1] = eltwise(n.tops['h' + t_1], clip_t,
                                           P.Eltwise.SUM, True)
        # Decoding
        if i == T - num_step:
            if mode == 'train':
                h_set = ()
                act_set = ()
                for j in range(K, T - num_step + 1):
                    t_j = tag(j)
                    h_set = h_set + (n.tops['h_conted' + t_j], )
                    act_set = act_set + (n.tops['act' + t_j], )
                n.h = L.Concat(*h_set, axis=0)
                n.act_concat = L.Concat(*act_set, axis=0)
                top = add_decoder(n, n.h, n.act_concat)
            else:
                top = add_decoder(n, n.tops['h_conted' + t_1], act)
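            # Split the stacked decoder output back into per-step predictions.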
            x_outs = L.Slice(top, axis=0, ntop=T - num_step - K + 1)
            if isinstance(x_outs, caffe.net_spec.Top):
                x_outs = [x_outs]
            for j in range(K, T - num_step + 1):
                n.tops['x_hat' + tag(j + 1)] = x_outs[j - K]
            dec_tag = tag(2) if mode == 'train' else ''
            if mode == 'test_decode':
                add_lstm_encoder(n,
                                 n.tops['x_hat' + t],
                                 batch_size,
                                 lstm_dim=lstm_dim,
                                 flatten=False)
                x_gate = x_gate + (n.tops['x_gate'], )
            elif num_step > 1:
                add_lstm_encoder(n,
                                 n.tops['x_hat' + t],
                                 batch_size,
                                 lstm_dim=lstm_dim,
                                 t=t,
                                 tag=dec_tag,
                                 flatten=False)
                x_gate = x_gate + (n.tops['x_gate' + t], )

        if i > T - num_step:
            dec_t = tag(i - T + num_step + 1)
            dec_tp = tag(i - T + num_step + 2)
            top = add_decoder(n,
                              n.tops['h_conted' + t_1],
                              n.tops['act' + t_1],
                              tag=dec_t)
            n.tops['x_hat' + t] = top
            if i < T - 1:
                add_lstm_encoder(n,
                                 n.tops['x_hat' + t],
                                 batch_size,
                                 lstm_dim=lstm_dim,
                                 t=t,
                                 tag=dec_tp,
                                 flatten=False)
                x_gate = x_gate + (n.tops['x_gate' + t], )

        if i < T - 1 or mode != 'train':
            # Hidden-to-gate transform: h_{t-1} feeds the gate pre-activations.
            if mode != 'test_decode':
                n.tops['x_gate' + t] = x_gate[i]
            n.tops['h_gate' + t] = fc(n.tops['h_conted' + t_1],
                                      4 * lstm_dim,
                                      weight_filler=dict(type='uniform',
                                                         min=-0.08,
                                                         max=0.08),
                                      param_name='Wh',
                                      axis=2,
                                      bias=False)
            n.tops['gate' + t] = eltwise(x_gate[i], n.tops['h_gate' + t],
                                         P.Eltwise.SUM)
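            # Advance the LSTM: produce c_t and h_t from c_{t-1}, the summed
            # gate pre-activations, and the clip marker.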
            n.tops['c' + t], n.tops['h' + t] = L.LSTMUnit(
                n.tops['c' + t_1],
                n.tops['gate' + t],
                clip_t,
                ntop=2,
                clip_gradients=[0, 0.1, 0])

    # Define the loss: L2 distance between predictions and ground truth.
    if mode == 'train':
        x_hat = ()
        for i in range(K, T):
            t = tag(i + 1)
            x_hat = x_hat + (n.tops['x_hat' + t], )
        # Silence tops that are not consumed elsewhere so Caffe does not
        # treat them as network outputs.
        silence_set += (n.tops['c' + tag(T - 1)], )
        n.silence = L.Silence(*silence_set, ntop=0)
        n.x_hat = L.Concat(*x_hat, axis=0)
        n.label_flat = L.Flatten(n.label, axis=0, end_axis=1)
        n.l2_loss = L.EuclideanLoss(n.x_hat, n.label_flat)
    return n
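
And a sketch of how add_rnn might be driven to emit a train-mode prototxt. The input shapes are illustrative guesses, and the sibling helpers it calls (add_lstm_init, add_lstm_encoder, add_decoder, tag, fc, eltwise) are assumed to come from the same module; none of them are shown on this page:

import caffe
from caffe import layers as L

def build_train_net(batch_size=8, T=20, K=10, num_step=10):
    n = caffe.NetSpec()
    # Illustrative shapes: T*N stacked frame features, (T-1)*N actions, and
    # T*N clip markers (sliced along axis 0 inside add_rnn).
    n.data = L.DummyData(shape=dict(dim=[T * batch_size, 4096]))
    n.act = L.DummyData(shape=dict(dim=[(T - 1) * batch_size, 6]))
    n.clip = L.DummyData(shape=dict(dim=[T * batch_size]))
    n = add_rnn(n, n.data, n.act, n.clip, batch_size, T, K, num_step,
                mode='train')
    return n.to_proto()

with open('train.prototxt', 'w') as f:
    f.write(str(build_train_net()))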