Example 1
def conv_crop_pool_op(X, sizes, output_sizes, W, b, n_in, n_maps, filter_height, filter_width, filter_dilation, poolsize):
  from Device import is_using_gpu
  if is_using_gpu():
    conv_op = CuDNNConvHWBCOpValidInstance  # cuDNN 'valid' convolution in HWBC layout
    pool_op = PoolHWBCOp(poolsize)
    conv_out = conv_op(X, W, b) if filter_height * filter_width > 0 else X  # skip conv for empty filters
    crop_out = CropToBatchImageSizeInstance(conv_out, sizes)  # mask activations beyond each image's true size
    Y = pool_op(crop_out)
    Y = CropToBatchImageSizeZeroInstance(Y, output_sizes)  # zero the padded regions after pooling
  else:
    Y = X  # no CPU implementation of these ops here; the input passes through unchanged
  return Y
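
On the CPU path the function above just passes X through unchanged. For reference, a minimal sketch of what a CPU fallback could look like with standard Theano ops, assuming the HWBC (height, width, batch, channel) layout that the GPU op names suggest; the helper name is illustrative and not part of the original code:

import theano.tensor as T
from theano.tensor.nnet import conv2d
from theano.tensor.signal.pool import pool_2d

def conv_pool_cpu(X, W, b, poolsize):
  # X: HWBC -> BCHW, the layout conv2d/pool_2d expect; W: (n_maps, n_in, fh, fw)
  x = X.dimshuffle(2, 3, 0, 1)
  y = conv2d(x, W) + b.dimshuffle('x', 0, 'x', 'x')  # 'valid' convolution plus per-map bias
  y = pool_2d(y, poolsize, ignore_border=True)       # non-overlapping max pooling
  return y.dimshuffle(2, 3, 0, 1)                    # back to HWBC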
Example 2
def circular_convolution(a, b):
    from Device import is_using_gpu
    has_gpuarray = is_using_gpu()
    try:
        import pygpu  # the GPU FFT ops require the pygpu backend
    except Exception:
        has_gpuarray = False
    if has_gpuarray:
        from theano.gpuarray.fft import curfft as fft
        from theano.gpuarray.fft import cuirfft as ifft
    else:
        from theano.tensor.fft import rfft as fft
        from theano.tensor.fft import irfft as ifft
    # Convolution theorem: multiply in the frequency domain, then transform back.
    return ifft(fft(a) * fft(b))
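
Note that rfft returns the transform as (..., 2) arrays of (real, imaginary) pairs, so the elementwise `*` above multiplies the two parts independently. If a true circular convolution is intended, the frequency-domain product has to be a complex multiplication; a minimal CPU-path sketch (the helper name is illustrative, not part of the original):

import numpy
import theano
import theano.tensor as T
from theano.tensor.fft import rfft, irfft

def circular_convolution_complex(a, b):
    fa, fb = rfft(a), rfft(b)  # shape (batch, n//2 + 1, 2): (real, imag) pairs
    re = fa[:, :, 0] * fb[:, :, 0] - fa[:, :, 1] * fb[:, :, 1]
    im = fa[:, :, 0] * fb[:, :, 1] + fa[:, :, 1] * fb[:, :, 0]
    return irfft(T.stack([re, im], axis=2))

a = T.matrix('a')
b = T.matrix('b')
f = theano.function([a, b], circular_convolution_complex(a, b))
x = numpy.random.randn(2, 8).astype(theano.config.floatX)
y = numpy.random.randn(2, 8).astype(theano.config.floatX)
print(f(x, y))  # matches numpy.fft.ifft(numpy.fft.fft(x) * numpy.fft.fft(y)).real up to rounding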
Example 3
  def __init__(self,
               n_out = None,
               n_units = None,
               direction = 1,
               truncation = -1,
               sampling = 1,
               encoder = None,
               unit = 'lstm',
               n_dec = 0,
               attention = "none",
               recurrent_transform = "none",
               recurrent_transform_attribs = "{}",
               attention_template = 128,
               attention_distance = 'l2',
               attention_step = "linear",
               attention_beam = 0,
               attention_norm = "exp",
               attention_momentum = "none",
               attention_sharpening = 1.0,
               attention_nbest = 0,
               attention_store = False,
               attention_smooth = False,
               attention_glimpse = 1,
               attention_filters = 1,
               attention_accumulator = 'sum',
               attention_loss = 0,
               attention_bn = 0,
               attention_lm = 'none',
               attention_ndec = 1,
               attention_memory = 0,
               attention_alnpts = 0,
               attention_epoch  = 1,
               attention_segstep=0.01,
               attention_offset=0.95,
               attention_method="epoch",
               attention_scale=10,
               context=-1,
               base = None,
               aligner = None,
               lm = False,
               force_lm = False,
               droplm = 1.0,
               forward_weights_init=None,
               bias_random_init_forget_shift=0.0,
               copy_weights_from_base=False,
               segment_input=False,
               join_states=False,
               sample_segment=None,
               **kwargs):
    """
    :param n_out: number of cells
    :param n_units: used when initialized via Network.from_hdf_model_topology
    :param direction: process sequence in forward (1) or backward (-1) direction
    :param truncation: gradient truncation
    :param sampling: scan every nth frame only
    :param encoder: list of encoder layers used as initialization for the hidden state
    :param unit: cell type (one of 'lstm', 'vanilla', 'gru', 'sru')
    :param n_dec: absolute number of steps to unfold the network if integer, else relative number of steps from encoder
    :param recurrent_transform: name of recurrent transform
    :param recurrent_transform_attribs: dictionary containing parameters for a recurrent transform
    :param attention_template:
    :param attention_distance:
    :param attention_step:
    :param attention_beam:
    :param attention_norm:
    :param attention_sharpening:
    :param attention_nbest:
    :param attention_store:
    :param attention_glimpse:
    :param attention_lm:
    :param base: list of layers whose outputs are used as the base during attention mechanisms
    :param lm: activate RNNLM
    :param force_lm: expect previous labels to be given during testing
    :param droplm: probability of taking the expected output as predecessor instead of the real one when lm is enabled
    :param bias_random_init_forget_shift: initialize the forget gate bias of LSTM networks with this value
    """
    source_index = None
    if len(kwargs['sources']) == 1 and (kwargs['sources'][0].layer_class.endswith('length') or kwargs['sources'][0].layer_class.startswith('length')):
      kwargs['sources'] = []
      source_index = kwargs['index']
    unit_given = unit
    from Device import is_using_gpu
    if unit == 'lstm':  # auto selection
      if not is_using_gpu():
        unit = 'lstme'
      elif recurrent_transform == 'none' and (not lm or droplm == 0.0):
        unit = 'lstmp'
      else:
        unit = 'lstmc'
    elif unit in ("lstmc", "lstmp") and not is_using_gpu():
      unit = "lstme"
    if segment_input:
      if is_using_gpu():
        unit = "lstmps"
      else:
        unit = "lstms"
    if n_out is None:
      assert encoder
      n_out = sum([enc.attrs['n_out'] for enc in encoder])
    kwargs.setdefault("n_out", n_out)
    if n_units is not None:
      assert n_units == n_out
    self.attention_weight = T.constant(1.,'float32')
    if len(kwargs['sources']) == 1 and kwargs['sources'][0].layer_class.startswith('length'):
      kwargs['sources'] = []
    elif len(kwargs['sources']) == 1 and kwargs['sources'][0].layer_class.startswith('signal'):
      kwargs['sources'] = []
    super(RecurrentUnitLayer, self).__init__(**kwargs)
    self.set_attr('from', ",".join([s.name for s in self.sources]) if self.sources else "null")
    self.set_attr('n_out', n_out)
    self.set_attr('unit', unit_given.encode("utf8"))
    self.set_attr('truncation', truncation)
    self.set_attr('sampling', sampling)
    self.set_attr('direction', direction)
    self.set_attr('lm', lm)
    self.set_attr('force_lm', force_lm)
    self.set_attr('droplm', droplm)
    if bias_random_init_forget_shift:
      self.set_attr("bias_random_init_forget_shift", bias_random_init_forget_shift)
    self.set_attr('attention_beam', attention_beam)
    self.set_attr('recurrent_transform', recurrent_transform.encode("utf8"))
    if isinstance(recurrent_transform_attribs, str):
      recurrent_transform_attribs = json.loads(recurrent_transform_attribs)
    if attention_template is not None:
      self.set_attr('attention_template', attention_template)
    self.set_attr('recurrent_transform_attribs', recurrent_transform_attribs)
    self.set_attr('attention_distance', attention_distance.encode("utf8"))
    self.set_attr('attention_step', attention_step.encode("utf8"))
    self.set_attr('attention_norm', attention_norm.encode("utf8"))
    self.set_attr('attention_sharpening', attention_sharpening)
    self.set_attr('attention_nbest', attention_nbest)
    attention_store = attention_store or attention_smooth or attention_momentum != 'none'
    self.set_attr('attention_store', attention_store)
    self.set_attr('attention_smooth', attention_smooth)
    self.set_attr('attention_momentum', attention_momentum.encode('utf8'))
    self.set_attr('attention_glimpse', attention_glimpse)
    self.set_attr('attention_filters', attention_filters)
    self.set_attr('attention_lm', attention_lm)
    self.set_attr('attention_bn', attention_bn)
    self.set_attr('attention_accumulator', attention_accumulator)
    self.set_attr('attention_ndec', attention_ndec)
    self.set_attr('attention_memory', attention_memory)
    self.set_attr('attention_loss', attention_loss)
    self.set_attr('n_dec', n_dec)
    self.set_attr('segment_input', segment_input)
    self.set_attr('attention_alnpts', attention_alnpts)
    self.set_attr('attention_epoch', attention_epoch)
    self.set_attr('attention_segstep', attention_segstep)
    self.set_attr('attention_offset', attention_offset)
    self.set_attr('attention_method', attention_method)
    self.set_attr('attention_scale', attention_scale)
    if segment_input:
      if not self.eval_flag:
        if isinstance(self.sources[0],RecurrentUnitLayer):
          self.inv_att = self.sources[0].inv_att #NBT
        else:
          if not join_states:
            self.inv_att = self.sources[0].attention #NBT
          else:
            assert hasattr(self.sources[0], "nstates"), "source does not have number of states!"
            ns = self.sources[0].nstates
            self.inv_att = self.sources[0].attention[(ns-1)::ns]
        inv_att = T.roll(self.inv_att.dimshuffle(2, 1, 0),1,axis=0)#TBN
        inv_att = T.set_subtensor(inv_att[0],T.zeros((inv_att.shape[1],inv_att.shape[2])))
        inv_att = T.max(inv_att,axis=-1)
      else:
        inv_att = T.zeros((self.sources[0].output.shape[0],self.sources[0].output.shape[1]))
    if encoder and hasattr(encoder[0],'act'):
      self.set_attr('encoder', ",".join([e.name for e in encoder]))
    if base:
      self.set_attr('base', ",".join([b.name for b in base]))
    else:
      base = encoder
    self.base = base
    self.encoder = encoder
    if aligner:
      self.aligner = aligner
    self.set_attr('n_units', n_out)
    unit = eval(unit.upper())(**self.attrs)
    assert isinstance(unit, Unit)
    self.unit = unit
    kwargs.setdefault("n_out", unit.n_out)
    n_out = unit.n_out
    self.set_attr('n_out', unit.n_out)
    if n_dec < 0:
      source_index = self.index
      n_dec *= -1
    if n_dec != 0:
      self.target_index = self.index
      if isinstance(n_dec,float):
        if not source_index:
          source_index = encoder[0].index if encoder else base[0].index
        lengths = T.cast(T.ceil(T.sum(T.cast(source_index,'float32'),axis=0) * n_dec), 'int32')
        idx, _ = theano.map(lambda l_i, l_m:T.concatenate([T.ones((l_i,),'int8'),T.zeros((l_m-l_i,),'int8')]),
                            [lengths], [T.max(lengths)+1])
        self.index = idx.dimshuffle(1,0)[:-1]
        n_dec = T.cast(T.ceil(T.cast(source_index.shape[0],'float32') * numpy.float32(n_dec)),'int32')
      else:
        if encoder:
          self.index = encoder[0].index
        self.index = T.ones((n_dec,self.index.shape[1]),'int8')
    else:
      n_dec = self.index.shape[0]
    # initialize recurrent weights
    self.W_re = None
    if unit.n_re > 0:
      self.W_re = self.add_param(self.create_recurrent_weights(unit.n_units, unit.n_re, name="W_re_%s" % self.name))
    # initialize forward weights
    bias_init_value = self.create_bias(unit.n_in).get_value()
    if bias_random_init_forget_shift:
      assert unit.n_units * 4 == unit.n_in  # (input gate, forget gate, output gate, net input)
      bias_init_value[unit.n_units:2 * unit.n_units] += bias_random_init_forget_shift
    self.b.set_value(bias_init_value)
    if not forward_weights_init:
      forward_weights_init = "random_uniform(p_add=%i)" % unit.n_re
    else:
      self.set_attr('forward_weights_init', forward_weights_init)
    self.forward_weights_init = forward_weights_init
    self.W_in = []
    sample_mean, gamma = None, None
    if copy_weights_from_base:
      self.params = {}
      #self.W_re = self.add_param(base[0].W_re)
      #self.W_in = [ self.add_param(W) for W in base[0].W_in ]
      #self.b = self.add_param(base[0].b)
      self.W_re = base[0].W_re
      self.W_in = base[0].W_in
      self.b = base[0].b
      if self.attrs.get('batch_norm', False):
        sample_mean = base[0].sample_mean
        gamma = base[0].gamma
      #self.masks = base[0].masks
      #self.mass = base[0].mass
    else:
      for s in self.sources:
        W = self.create_forward_weights(s.attrs['n_out'], unit.n_in, name="W_in_%s_%s" % (s.name, self.name))
        self.W_in.append(self.add_param(W))
    # make input
    z = self.b
    for x_t, m, W in zip(self.sources, self.masks, self.W_in):
      if x_t.attrs['sparse']:
        if x_t.output.ndim == 3: out_dim = x_t.output.shape[2]
        elif x_t.output.ndim == 2: out_dim = 1
        else: assert False, x_t.output.ndim
        if x_t.output.ndim == 3:
          z += W[T.cast(x_t.output[:,:,0], 'int32')]
        elif x_t.output.ndim == 2:
          z += W[T.cast(x_t.output, 'int32')]
        else:
          assert False, x_t.output.ndim
      elif m is None:
        z += T.dot(x_t.output, W)
      else:
        z += self.dot(self.mass * m * x_t.output, W)
    #if self.attrs['batch_norm']:
    #  z = self.batch_norm(z, unit.n_in)
    num_batches = self.index.shape[1]
    self.num_batches = num_batches
    non_sequences = []
    if self.attrs['lm'] or attention_lm != 'none':
      if not 'target' in self.attrs:
        self.attrs['target'] = 'classes'
      if self.attrs['droplm'] > 0.0 or not (self.train_flag or force_lm):
        if copy_weights_from_base:
          self.W_lm_in = base[0].W_lm_in
          self.b_lm_in = base[0].b_lm_in
        else:
          l = sqrt(6.) / sqrt(unit.n_out + self.y_in[self.attrs['target']].n_out)
          values = numpy.asarray(self.rng.uniform(low=-l, high=l, size=(unit.n_out, self.y_in[self.attrs['target']].n_out)), dtype=theano.config.floatX)
          self.W_lm_in = self.add_param(self.shared(value=values, borrow=True, name = "W_lm_in_"+self.name))
          self.b_lm_in = self.create_bias(self.y_in[self.attrs['target']].n_out, 'b_lm_in')
      l = sqrt(6.) / sqrt(unit.n_in + self.y_in[self.attrs['target']].n_out)
      values = numpy.asarray(self.rng.uniform(low=-l, high=l, size=(self.y_in[self.attrs['target']].n_out, unit.n_in)), dtype=theano.config.floatX)
      if copy_weights_from_base:
        self.W_lm_out = base[0].W_lm_out
      else:
        self.W_lm_out = self.add_param(self.shared(value=values, borrow=True, name = "W_lm_out_"+self.name))
      if self.attrs['droplm'] == 0.0 and (self.train_flag or force_lm):
        self.lmmask = 1
        #if recurrent_transform != 'none':
        #  recurrent_transform = recurrent_transform[:-3]
      elif self.attrs['droplm'] < 1.0 and (self.train_flag or force_lm):
        from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams
        srng = RandomStreams(self.rng.randint(1234) + 1)
        self.lmmask = T.cast(srng.binomial(n=1, p=1.0 - self.attrs['droplm'], size=self.index.shape), theano.config.floatX).dimshuffle(0,1,'x').repeat(unit.n_in,axis=2)
      else:
        self.lmmask = T.zeros_like(self.index, dtype='float32').dimshuffle(0,1,'x').repeat(unit.n_in,axis=2)

    if recurrent_transform == 'input': # attention is just a sequence dependent bias (lstmp compatible)
      src = []
      src_names = []
      n_in = 0
      for e in base:
        #src_base = [ s for s in e.sources if s.name not in src_names ]
        #src_names += [ s.name for s in e.sources ]
        src_base = [ e ]
        src_names += [e.name]
        src += [s.output for s in src_base]
        n_in += sum([s.attrs['n_out'] for s in src_base])
      self.xc = T.concatenate(src, axis=2)
      l = sqrt(6.) / sqrt(self.attrs['n_out'] + n_in)
      values = numpy.asarray(self.rng.uniform(low=-l, high=l, size=(n_in, 1)), dtype=theano.config.floatX)
      self.W_att_xc = self.add_param(self.shared(value=values, borrow=True, name = "W_att_xc"))
      values = numpy.asarray(self.rng.uniform(low=-l, high=l, size=(n_in, self.attrs['n_out'] * 4)), dtype=theano.config.floatX)
      self.W_att_in = self.add_param(self.shared(value=values, borrow=True, name = "W_att_in"))
      zz = T.exp(T.tanh(T.dot(self.xc, self.W_att_xc))) # TB1
      self.zc = T.dot(T.sum(self.xc * (zz / T.sum(zz, axis=0, keepdims=True)).repeat(self.xc.shape[2],axis=2), axis=0, keepdims=True), self.W_att_in)
      recurrent_transform = 'none'
    elif recurrent_transform == 'attention_align':
      max_skip = base[0].attrs['max_skip']
      values = numpy.zeros((max_skip,), dtype=theano.config.floatX)
      self.T_b = self.add_param(self.shared(value=values, borrow=True, name="T_b"), name="T_b")
      l = sqrt(6.) / sqrt(self.attrs['n_out'] + max_skip)
      values = numpy.asarray(self.rng.uniform(
        low=-l, high=l, size=(self.attrs['n_out'], max_skip)), dtype=theano.config.floatX)
      self.T_W = self.add_param(self.shared(value=values, borrow=True, name="T_W"), name="T_W")
      y_t = T.dot(self.base[0].attention, T.arange(self.base[0].output.shape[0], dtype='float32'))  # NB
      y_t = T.concatenate([T.zeros_like(y_t[:1]), y_t], axis=0)  # (N+1)B
      y_t = y_t[1:] - y_t[:-1]  # NB
      self.y_t = y_t # T.clip(y_t,numpy.float32(0),numpy.float32(max_skip - 1))

      self.y_t = T.cast(self.base[0].backtrace,'float32')
    elif recurrent_transform == 'attention_segment':
      assert aligner.attention, "Segment-wise attention requires attention points!"

    recurrent_transform_inst = RecurrentTransform.transform_classes[recurrent_transform](layer=self)
    assert isinstance(recurrent_transform_inst, RecurrentTransform.RecurrentTransformBase)
    unit.recurrent_transform = recurrent_transform_inst
    self.recurrent_transform = recurrent_transform_inst
    # scan over sequence
    for s in range(self.attrs['sampling']):
      index = self.index[s::self.attrs['sampling']]

      if context > 0:
        from TheanoUtil import context_batched
        n_batches = z.shape[1]
        time, batch, dim = z.shape[0], z.shape[1], z.shape[2]
        #z = context_batched(z[::direction or 1], window=context)[::direction or 1] # TB(CD)

        from theano.ifelse import ifelse
        def context_window(idx, x_in, i_in):
          x_out = x_in[idx:idx + context]
          x_out = x_out.dimshuffle('x',1,0,2).reshape((1, batch, dim * context))
          i_out = i_in[idx:idx+1].repeat(context, axis=0)
          i_out = ifelse(T.lt(idx,context),T.set_subtensor(i_out[:context - idx],numpy.int8(0)),i_out).reshape((1, batch * context))
          return x_out, i_out

        z = z[::direction or 1]
        i = index[::direction or 1]
        out, _ = theano.map(context_window, sequences = [T.arange(z.shape[0])], non_sequences = [T.concatenate([T.zeros((context - 1,z.shape[1],z.shape[2]),dtype='float32'),z],axis=0), i])
        z = out[0][::direction or 1]
        i = out[1][::direction or 1] # T(BC)
        direction = 1
        z = z.reshape((time * batch, context * dim)) # (TB)(CD)
        z = z.reshape((time * batch, context, dim)).dimshuffle(1,0,2) # C(TB)D
        i = i.reshape((time, context, batch)).dimshuffle(1,0,2).reshape((context, time * batch))
        index = i
        num_batches = time * batch

      sequences = z
      sources = self.sources
      if encoder:
        if recurrent_transform == "attention_segment":
          if hasattr(encoder[0],'act'):
            outputs_info = [T.concatenate([e.act[i][-1] for e in encoder], axis=1) for i in range(unit.n_act)]
          else:
            outputs_info = [ T.concatenate([e[i] for e in encoder], axis=1) for i in range(unit.n_act) ]
            outputs_info[0] = self.aligner.output[-1]  # override the first state with the aligner output
        elif hasattr(encoder[0],'act'):
          outputs_info = [ T.concatenate([e.act[i][-1] for e in encoder], axis=1) for i in range(unit.n_act) ]
        else:
          outputs_info = [ T.concatenate([e[i] for e in encoder], axis=1) for i in range(unit.n_act) ]
        sequences += T.alloc(numpy.cast[theano.config.floatX](0), n_dec, num_batches, unit.n_in) + (self.zc if self.attrs['recurrent_transform'] == 'input' else numpy.float32(0))
      else:
        outputs_info = [ T.alloc(numpy.cast[theano.config.floatX](0), num_batches, unit.n_units) for a in range(unit.n_act) ]

      if self.attrs['lm'] and self.attrs['droplm'] == 0.0 and (self.train_flag or force_lm):
        if self.network.y[self.attrs['target']].ndim == 3:
          sequences += T.dot(self.network.y[self.attrs['target']],self.W_lm_out)
        else:
          y = self.y_in[self.attrs['target']].flatten()
          sequences += self.W_lm_out[y].reshape((index.shape[0],index.shape[1],unit.n_in))

      if sequences == self.b:
        sequences += T.alloc(numpy.cast[theano.config.floatX](0), n_dec, num_batches, unit.n_in) + (self.zc if self.attrs['recurrent_transform'] == 'input' else numpy.float32(0))

      if unit.recurrent_transform:
        outputs_info += unit.recurrent_transform.get_sorted_state_vars_initial()

      index_f = T.cast(index, theano.config.floatX)
      unit.set_parent(self)

      if segment_input:
        outputs = unit.scan_seg(x=sources,
                                z=sequences[s::self.attrs['sampling']],
                                att = inv_att,
                                non_sequences=non_sequences,
                                i=index_f,
                                outputs_info=outputs_info,
                                W_re=self.W_re,
                                W_in=self.W_in,
                                b=self.b,
                                go_backwards=direction == -1,
                                truncate_gradient=self.attrs['truncation'])
      else:
        outputs = unit.scan(x=sources,
                            z=sequences[s::self.attrs['sampling']],
                            non_sequences=non_sequences,
                            i=index_f,
                            outputs_info=outputs_info,
                            W_re=self.W_re,
                            W_in=self.W_in,
                            b=self.b,
                            go_backwards=direction == -1,
                            truncate_gradient=self.attrs['truncation'])

      if not isinstance(outputs, list):
        outputs = [outputs]
      if outputs:
        outputs[0].name = "%s.act[0]" % self.name
        if context > 0:
          for i in range(len(outputs)):
            outputs[i] = outputs[i][-1].reshape((outputs[i].shape[1]//n_batches,n_batches,outputs[i].shape[2]))

      if unit.recurrent_transform:
        unit.recurrent_transform_state_var_seqs = outputs[-len(unit.recurrent_transform.state_vars):]

      if self.attrs['sampling'] > 1:
        if s == 0:
          self.act = [ T.alloc(numpy.cast['float32'](0), self.index.shape[0], self.index.shape[1], n_out) for act in outputs ]
        self.act = [ T.set_subtensor(tot[s::self.attrs['sampling']], act) for tot,act in zip(self.act, outputs) ]
      else:
        self.act = outputs[:unit.n_act]
        if len(outputs) > unit.n_act:
          self.aux = outputs[unit.n_act:]
    if self.attrs['attention_store']:
      self.attention = [ self.aux[i].dimshuffle(0,2,1) for i,v in enumerate(sorted(unit.recurrent_transform.state_vars.keys())) if v.startswith('att_') ] # NBT
      for i in range(len(self.attention)):
        vec = T.eye(self.attention[i].shape[2], 1, -direction * (self.attention[i].shape[2] - 1))
        last = vec.dimshuffle(1, 'x', 0).repeat(self.index.shape[1], axis=1)
        self.attention[i] = T.concatenate([self.attention[i][1:],last],axis=0)[::direction]

    self.cost_val = numpy.float32(0)
    if recurrent_transform == 'attention_align':
      back = T.ceil(self.aux[sorted(unit.recurrent_transform.state_vars.keys()).index('t')])
      def make_output(base, yout, trace, length):
        length = T.cast(length, 'int32')
        idx = T.cast(trace[:length][::-1],'int32')
        x_out = T.concatenate([base[idx],T.zeros((self.index.shape[0] + 1 - length, base.shape[1]), 'float32')],axis=0)
        y_out = T.concatenate([yout[idx,T.arange(length)],T.zeros((self.index.shape[0] + 1 - length, ), 'float32')],axis=0)
        return x_out, y_out

      output, _ = theano.map(make_output,
                             sequences = [base[0].output.dimshuffle(1,0,2),
                                          self.y_t.dimshuffle(1,2,0),
                                          back.dimshuffle(1,0),
                                          T.sum(self.index,axis=0,dtype='float32')])
      self.attrs['n_out'] = base[0].attrs['n_out']
      self.params.update(unit.params)
      self.output = output[0].dimshuffle(1,0,2)[:-1]

      z = T.dot(self.act[0], self.T_W)[:-1] + self.T_b
      z = z.reshape((z.shape[0] * z.shape[1], z.shape[2]))
      idx = (self.index[1:].flatten() > 0).nonzero()
      idy = (self.index[1:][::-1].flatten() > 0).nonzero()
      y_out = T.cast(output[1],'int32').dimshuffle(1, 0)[:-1].flatten()
      nll, _ = T.nnet.crossentropy_softmax_1hot(x=z[idx], y_idx=y_out[idy])
      self.cost_val = T.sum(nll)
      recog = T.argmax(z[idx], axis=1)
      real = y_out[idy]
      self.errors = lambda: T.sum(T.neq(recog, real))

      return
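      # NOTE: everything below this return statement is unreachable; these are
      # experimental variants left in place in the original source.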

      back += T.arange(self.index.shape[1], dtype='float32') * T.cast(self.base[0].index.shape[0], 'float32')
      idx = (self.index[:-1].flatten() > 0).nonzero()
      idx = T.cast(back[::-1].flatten()[idx],'int32')
      x_out = base[0].output
      #x_out = x_out.dimshuffle(1,0,2).reshape((x_out.shape[0] * x_out.shape[1], x_out.shape[2]))[idx]
      #x_out = x_out.reshape((self.index.shape[1], self.index.shape[0] - 1, x_out.shape[1])).dimshuffle(1,0,2)
      x_out = x_out.reshape((x_out.shape[0] * x_out.shape[1], x_out.shape[2]))[idx]
      x_out = x_out.reshape((self.index.shape[0] - 1, self.index.shape[1], x_out.shape[1]))
      self.output = T.concatenate([x_out, base[0].output[1:]],axis=0)
      self.attrs['n_out'] = base[0].attrs['n_out']
      self.params.update(unit.params)
      return


      skips = T.dot(T.nnet.softmax(z), T.arange(z.shape[1], dtype='float32')).reshape(self.index[1:].shape)
      shift = T.arange(self.index.shape[1], dtype='float32') * T.cast(self.base[0].index.shape[0], 'float32')
      skips = T.concatenate([T.zeros_like(self.y_t[:1]),self.y_t[:-1]],axis=0)
      idx = shift + T.cumsum(skips, axis=0)
      idx = T.cast(idx[:-1].flatten(),'int32')
      #idx = (idx.flatten() > 0).nonzero()
      #idx = base[0].attention.flatten()
      x_out = base[0].output[::-1]
      x_out = x_out.reshape((x_out.shape[0] * x_out.shape[1], x_out.shape[2]))[idx]
      x_out = x_out.reshape((self.index.shape[0], self.index.shape[1], x_out.shape[1]))
      self.output = T.concatenate([base[0].output[-1:], x_out], axis=0)[::-1]
      self.attrs['n_out'] = base[0].attrs['n_out']
      self.params.update(unit.params)
      return

    if recurrent_transform == 'batch_norm':
      self.params['sample_mean_batch_norm'].custom_update = T.dot(T.mean(self.act[0],axis=[0,1]),self.W_re)
      self.params['sample_mean_batch_norm'].custom_update_normalized = True

    self.make_output(self.act[0][::direction or 1], sample_mean=sample_mean, gamma=gamma)
    self.params.update(unit.params)
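
The unit auto-selection at the top of this constructor is the core is_using_gpu usage here: it picks a CUDA LSTM kernel when a GPU is active and falls back to the pure-Theano 'lstme' implementation otherwise. The same logic, distilled into a standalone sketch (unit names as in the code above; the comments reflect the branching conditions, not documented guarantees):

def select_lstm_unit(unit, recurrent_transform, lm, droplm, segment_input, gpu):
  if unit == 'lstm':  # auto selection
    if not gpu:
      unit = 'lstme'  # pure-Theano LSTM, the only option on CPU
    elif recurrent_transform == 'none' and (not lm or droplm == 0.0):
      unit = 'lstmp'  # CUDA kernel for the plain case without transform or LM feedback
    else:
      unit = 'lstmc'  # CUDA kernel that supports a custom recurrent transform
  elif unit in ('lstmc', 'lstmp') and not gpu:
    unit = 'lstme'  # GPU-only units fall back to the CPU implementation
  if segment_input:
    unit = 'lstmps' if gpu else 'lstms'  # segment-wise variants
  return unit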
Example 4
    def __init__(self,
                 n_out=None,
                 n_units=None,
                 direction=1,
                 truncation=-1,
                 sampling=1,
                 encoder=None,
                 unit='lstm',
                 n_dec=0,
                 attention="none",
                 recurrent_transform="none",
                 recurrent_transform_attribs="{}",
                 attention_template=128,
                 attention_distance='l2',
                 attention_step="linear",
                 attention_beam=0,
                 attention_norm="exp",
                 attention_momentum="none",
                 attention_sharpening=1.0,
                 attention_nbest=0,
                 attention_store=False,
                 attention_smooth=False,
                 attention_align=False,
                 attention_glimpse=1,
                 attention_filters=1,
                 attention_accumulator='sum',
                 attention_loss=0,
                 attention_bn=0,
                 attention_lm='none',
                 attention_ndec=1,
                 attention_memory=0,
                 base=None,
                 lm=False,
                 force_lm=False,
                 droplm=1.0,
                 forward_weights_init=None,
                 bias_random_init_forget_shift=0.0,
                 copy_weights_from_base=False,
                 **kwargs):
        """
    :param n_out: number of cells
    :param n_units: used when initialized via Network.from_hdf_model_topology
    :param direction: process sequence in forward (1) or backward (-1) direction
    :param truncation: gradient truncation
    :param sampling: scan every nth frame only
    :param encoder: list of encoder layers used as initalization for the hidden state
    :param unit: cell type (one of 'lstm', 'vanilla', 'gru', 'sru')
    :param n_dec: absolute number of steps to unfold the network if integer, else relative number of steps from encoder
    :param recurrent_transform: name of recurrent transform
    :param recurrent_transform_attribs: dictionary containing parameters for a recurrent transform
    :param attention_template:
    :param attention_distance:
    :param attention_step:
    :param attention_beam:
    :param attention_norm:
    :param attention_sharpening:
    :param attention_nbest:
    :param attention_store:
    :param attention_align:
    :param attention_glimpse:
    :param attention_lm:
    :param base: list of layers which outputs are considered as based during attention mechanisms
    :param lm: activate RNNLM
    :param force_lm: expect previous labels to be given during testing
    :param droplm: probability to take the expected output as predecessor instead of the real one when LM=true
    :param bias_random_init_forget_shift: initialize forget gate bias of lstm networks with this value
    """
        source_index = None
        if len(kwargs['sources']) == 1 and (
                kwargs['sources'][0].layer_class.endswith('length')
                or kwargs['sources'][0].layer_class.startswith('length')):
            kwargs['sources'] = []
            source_index = kwargs['index']
        unit_given = unit
        from Device import is_using_gpu
        if unit == 'lstm':  # auto selection
            if not is_using_gpu():
                unit = 'lstme'
            elif recurrent_transform == 'none' and (not lm or droplm == 0.0):
                unit = 'lstmp'
            else:
                unit = 'lstmc'
        elif unit in ("lstmc", "lstmp") and not is_using_gpu():
            unit = "lstme"
        if n_out is None:
            assert encoder
            n_out = sum([enc.attrs['n_out'] for enc in encoder])
        kwargs.setdefault("n_out", n_out)
        if n_units is not None:
            assert n_units == n_out
        self.attention_weight = T.constant(1., 'float32')
        if len(
                kwargs['sources']
        ) == 1 and kwargs['sources'][0].layer_class.startswith('length'):
            kwargs['sources'] = []
        elif len(
                kwargs['sources']
        ) == 1 and kwargs['sources'][0].layer_class.startswith('signal'):
            kwargs['sources'] = []
        super(RecurrentUnitLayer, self).__init__(**kwargs)
        self.set_attr(
            'from',
            ",".join([s.name
                      for s in self.sources]) if self.sources else "null")
        self.set_attr('n_out', n_out)
        self.set_attr('unit', unit_given.encode("utf8"))
        self.set_attr('truncation', truncation)
        self.set_attr('sampling', sampling)
        self.set_attr('direction', direction)
        self.set_attr('lm', lm)
        self.set_attr('force_lm', force_lm)
        self.set_attr('droplm', droplm)
        if bias_random_init_forget_shift:
            self.set_attr("bias_random_init_forget_shift",
                          bias_random_init_forget_shift)
        self.set_attr('attention_beam', attention_beam)
        self.set_attr('recurrent_transform',
                      recurrent_transform.encode("utf8"))
        if isinstance(recurrent_transform_attribs, str):
            recurrent_transform_attribs = json.loads(
                recurrent_transform_attribs)
        if attention_template is not None:
            self.set_attr('attention_template', attention_template)
        self.set_attr('recurrent_transform_attribs',
                      recurrent_transform_attribs)
        self.set_attr('attention_distance', attention_distance.encode("utf8"))
        self.set_attr('attention_step', attention_step.encode("utf8"))
        self.set_attr('attention_norm', attention_norm.encode("utf8"))
        self.set_attr('attention_sharpening', attention_sharpening)
        self.set_attr('attention_nbest', attention_nbest)
        attention_store = attention_store or attention_smooth or attention_momentum != 'none'
        self.set_attr('attention_store', attention_store)
        self.set_attr('attention_smooth', attention_smooth)
        self.set_attr('attention_momentum', attention_momentum.encode('utf8'))
        self.set_attr('attention_align', attention_align)
        self.set_attr('attention_glimpse', attention_glimpse)
        self.set_attr('attention_filters', attention_filters)
        self.set_attr('attention_lm', attention_lm)
        self.set_attr('attention_bn', attention_bn)
        self.set_attr('attention_accumulator', attention_accumulator)
        self.set_attr('attention_ndec', attention_ndec)
        self.set_attr('attention_memory', attention_memory)
        self.set_attr('attention_loss', attention_loss)
        self.set_attr('n_dec', n_dec)
        if encoder and hasattr(encoder[0], 'act'):
            self.set_attr('encoder', ",".join([e.name for e in encoder]))
        if base:
            self.set_attr('base', ",".join([b.name for b in base]))
        else:
            base = encoder
        self.base = base
        self.set_attr('n_units', n_out)
        unit = eval(unit.upper())(**self.attrs)
        assert isinstance(unit, Unit)
        self.unit = unit
        kwargs.setdefault("n_out", unit.n_out)
        n_out = unit.n_out
        self.set_attr('n_out', unit.n_out)
        if n_dec != 0:
            self.target_index = self.index
            if isinstance(n_dec, float):
                if not source_index:
                    source_index = encoder[0].index if encoder else base[
                        0].index
                lengths = T.cast(
                    T.ceil(
                        T.sum(T.cast(source_index, 'float32'), axis=0) *
                        n_dec), 'int32')
                idx, _ = theano.map(
                    lambda l_i, l_m: T.concatenate([
                        T.ones((l_i, ), 'int8'),
                        T.zeros((l_m - l_i, ), 'int8')
                    ]), [lengths], [T.max(lengths) + 1])
                self.index = idx.dimshuffle(1, 0)[:-1]
                n_dec = T.cast(
                    T.ceil(
                        T.cast(source_index.shape[0], 'float32') *
                        numpy.float32(n_dec)), 'int32')
            else:
                if encoder:
                    self.index = encoder[0].index
                self.index = T.ones((n_dec, self.index.shape[1]), 'int8')
        else:
            n_dec = self.index.shape[0]
        # initialize recurrent weights
        self.W_re = None
        if unit.n_re > 0:
            self.W_re = self.add_param(
                self.create_recurrent_weights(unit.n_units,
                                              unit.n_re,
                                              name="W_re_%s" % self.name))
        # initialize forward weights
        bias_init_value = self.create_bias(unit.n_in).get_value()
        if bias_random_init_forget_shift:
            assert unit.n_units * 4 == unit.n_in  # (input gate, forget gate, output gate, net input)
            bias_init_value[unit.n_units:2 *
                            unit.n_units] += bias_random_init_forget_shift
        self.b.set_value(bias_init_value)
        if not forward_weights_init:
            forward_weights_init = "random_uniform(p_add=%i)" % unit.n_re
        else:
            self.set_attr('forward_weights_init', forward_weights_init)
        self.forward_weights_init = forward_weights_init
        self.W_in = []
        if copy_weights_from_base:
            self.params = {}
            #self.W_re = self.add_param(base[0].W_re)
            #self.W_in = [ self.add_param(W) for W in base[0].W_in ]
            #self.b = self.add_param(base[0].b)
            self.W_re = base[0].W_re
            self.W_in = base[0].W_in
            self.b = base[0].b
            self.masks = base[0].masks
            self.mass = base[0].mass
        else:
            for s in self.sources:
                W = self.create_forward_weights(s.attrs['n_out'],
                                                unit.n_in,
                                                name="W_in_%s_%s" %
                                                (s.name, self.name))
                self.W_in.append(self.add_param(W))
        # make input
        z = self.b
        for x_t, m, W in zip(self.sources, self.masks, self.W_in):
            if x_t.attrs['sparse']:
                if x_t.output.ndim == 3: out_dim = x_t.output.shape[2]
                elif x_t.output.ndim == 2: out_dim = 1
                else: assert False, x_t.output.ndim
                if x_t.output.ndim == 3:
                    z += W[T.cast(x_t.output[:, :, 0], 'int32')]
                elif x_t.output.ndim == 2:
                    z += W[T.cast(x_t.output, 'int32')]
                else:
                    assert False, x_t.output.ndim
            elif m is None:
                z += T.dot(x_t.output, W)
            else:
                z += self.dot(self.mass * m * x_t.output, W)
        #if self.attrs['batch_norm']:
        #  z = self.batch_norm(z, unit.n_in)
        num_batches = self.index.shape[1]
        self.num_batches = num_batches
        non_sequences = []
        if self.attrs['lm'] or attention_lm != 'none':
            if not 'target' in self.attrs:
                self.attrs['target'] = 'classes'
            if self.attrs['droplm'] > 0.0 or not (self.train_flag or force_lm):
                if copy_weights_from_base:
                    self.W_lm_in = base[0].W_lm_in
                    self.b_lm_in = base[0].b_lm_in
                else:
                    l = sqrt(6.) / sqrt(unit.n_out +
                                        self.y_in[self.attrs['target']].n_out)
                    values = numpy.asarray(self.rng.uniform(
                        low=-l,
                        high=l,
                        size=(unit.n_out,
                              self.y_in[self.attrs['target']].n_out)),
                                           dtype=theano.config.floatX)
                    self.W_lm_in = self.add_param(
                        self.shared(value=values,
                                    borrow=True,
                                    name="W_lm_in_" + self.name))
                    self.b_lm_in = self.create_bias(
                        self.y_in[self.attrs['target']].n_out, 'b_lm_in')
            l = sqrt(6.) / sqrt(unit.n_in +
                                self.y_in[self.attrs['target']].n_out)
            values = numpy.asarray(self.rng.uniform(
                low=-l,
                high=l,
                size=(self.y_in[self.attrs['target']].n_out, unit.n_in)),
                                   dtype=theano.config.floatX)
            if copy_weights_from_base:
                self.W_lm_out = base[0].W_lm_out
            else:
                self.W_lm_out = self.add_param(
                    self.shared(value=values,
                                borrow=True,
                                name="W_lm_out_" + self.name))
            if self.attrs['droplm'] == 0.0 and (self.train_flag or force_lm):
                self.lmmask = 1
                #if recurrent_transform != 'none':
                #  recurrent_transform = recurrent_transform[:-3]
            elif self.attrs['droplm'] < 1.0 and (self.train_flag or force_lm):
                from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams
                srng = RandomStreams(self.rng.randint(1234) + 1)
                self.lmmask = T.cast(
                    srng.binomial(n=1,
                                  p=1.0 - self.attrs['droplm'],
                                  size=self.index.shape),
                    theano.config.floatX).dimshuffle(0, 1,
                                                     'x').repeat(unit.n_in,
                                                                 axis=2)
            else:
                self.lmmask = T.zeros_like(self.index,
                                           dtype='float32').dimshuffle(
                                               0, 1, 'x').repeat(unit.n_in,
                                                                 axis=2)

        if recurrent_transform == 'input':  # attention is just a sequence dependent bias (lstmp compatible)
            src = []
            src_names = []
            n_in = 0
            for e in base:
                #src_base = [ s for s in e.sources if s.name not in src_names ]
                #src_names += [ s.name for s in e.sources ]
                src_base = [e]
                src_names += [e.name]
                src += [s.output for s in src_base]
                n_in += sum([s.attrs['n_out'] for s in src_base])
            self.xc = T.concatenate(src, axis=2)
            l = sqrt(6.) / sqrt(self.attrs['n_out'] + n_in)
            values = numpy.asarray(self.rng.uniform(low=-l,
                                                    high=l,
                                                    size=(n_in, 1)),
                                   dtype=theano.config.floatX)
            self.W_att_xc = self.add_param(
                self.shared(value=values, borrow=True, name="W_att_xc"))
            values = numpy.asarray(self.rng.uniform(
                low=-l, high=l, size=(n_in, self.attrs['n_out'] * 4)),
                                   dtype=theano.config.floatX)
            self.W_att_in = self.add_param(
                self.shared(value=values, borrow=True, name="W_att_in"))
            zz = T.exp(T.tanh(T.dot(self.xc, self.W_att_xc)))  # TB1
            self.zc = T.dot(
                T.sum(self.xc * (zz / T.sum(zz, axis=0, keepdims=True)).repeat(
                    self.xc.shape[2], axis=2),
                      axis=0,
                      keepdims=True), self.W_att_in)
            recurrent_transform = 'none'
        recurrent_transform_inst = RecurrentTransform.transform_classes[
            recurrent_transform](layer=self)
        assert isinstance(recurrent_transform_inst,
                          RecurrentTransform.RecurrentTransformBase)
        unit.recurrent_transform = recurrent_transform_inst
        self.recurrent_transform = recurrent_transform_inst

        # scan over sequence
        for s in range(self.attrs['sampling']):
            index = self.index[s::self.attrs['sampling']]
            sequences = z
            sources = self.sources
            if encoder:
                if hasattr(encoder[0], 'act'):
                    outputs_info = [
                        T.concatenate([e.act[i][-1] for e in encoder], axis=1)
                        for i in range(unit.n_act)
                    ]
                else:
                    outputs_info = [
                        T.concatenate([e[i] for e in encoder], axis=1)
                        for i in range(unit.n_act)
                    ]
                sequences += T.alloc(
                    numpy.cast[theano.config.floatX](0), n_dec, num_batches,
                    unit.n_in) + (self.zc if self.attrs['recurrent_transform']
                                  == 'input' else numpy.float32(0))
            else:
                outputs_info = [
                    T.alloc(numpy.cast[theano.config.floatX](0), num_batches,
                            unit.n_units) for a in range(unit.n_act)
                ]

            if self.attrs['lm'] and self.attrs['droplm'] == 0.0 and (
                    self.train_flag or force_lm):
                if self.network.y[self.attrs['target']].ndim == 3:
                    sequences += T.dot(self.network.y[self.attrs['target']],
                                       self.W_lm_out)
                else:
                    y = self.y_in[self.attrs['target']].flatten()
                    sequences += self.W_lm_out[y].reshape(
                        (index.shape[0], index.shape[1], unit.n_in))

            if sequences == self.b:
                sequences += T.alloc(
                    numpy.cast[theano.config.floatX](0), n_dec, num_batches,
                    unit.n_in) + (self.zc if self.attrs['recurrent_transform']
                                  == 'input' else numpy.float32(0))

            if unit.recurrent_transform:
                outputs_info += unit.recurrent_transform.get_sorted_state_vars_initial(
                )

            index_f = T.cast(index, theano.config.floatX)
            unit.set_parent(self)
            outputs = unit.scan(x=sources,
                                z=sequences[s::self.attrs['sampling']],
                                non_sequences=non_sequences,
                                i=index_f,
                                outputs_info=outputs_info,
                                W_re=self.W_re,
                                W_in=self.W_in,
                                b=self.b,
                                go_backwards=direction == -1,
                                truncate_gradient=self.attrs['truncation'])

            if not isinstance(outputs, list):
                outputs = [outputs]
            if outputs:
                outputs[0].name = "%s.act[0]" % self.name

            if unit.recurrent_transform:
                unit.recurrent_transform_state_var_seqs = outputs[
                    -len(unit.recurrent_transform.state_vars):]

            if self.attrs['sampling'] > 1:
                if s == 0:
                    self.act = [
                        T.alloc(numpy.cast['float32'](0), self.index.shape[0],
                                self.index.shape[1], n_out) for act in outputs
                    ]
                self.act = [
                    T.set_subtensor(tot[s::self.attrs['sampling']], act)
                    for tot, act in zip(self.act, outputs)
                ]
            else:
                self.act = outputs[:unit.n_act]
                if len(outputs) > unit.n_act:
                    self.aux = outputs[unit.n_act:]
        if self.attrs['attention_store']:
            self.attention = [
                self.aux[i].dimshuffle(0, 2, 1) for i, v in enumerate(
                    sorted(unit.recurrent_transform.state_vars.keys()))
                if v.startswith('att_')
            ]  # NBT
            for i in range(len(self.attention)):
                vec = T.eye(self.attention[i].shape[2], 1,
                            -direction * (self.attention[i].shape[2] - 1))
                last = vec.dimshuffle(1, 'x', 0).repeat(self.index.shape[1],
                                                        axis=1)
                self.attention[i] = T.concatenate(
                    [self.attention[i][1:], last], axis=0)[::direction]

        if self.attrs['attention_align']:
            bp = [
                self.aux[i] for i, v in enumerate(
                    sorted(unit.recurrent_transform.state_vars.keys()))
                if v.startswith('K_')
            ]

            def backtrace(k, i_p):
                return i_p - k[i_p, T.arange(k.shape[1], dtype='int32')]

            self.alignment = []
            for K in bp:  # K: NTB
                aln, _ = theano.scan(
                    backtrace,
                    sequences=[T.cast(K, 'int32').dimshuffle(1, 0, 2)],
                    outputs_info=[
                        T.cast(self.index.shape[0] - 1, 'int32') + T.zeros(
                            (K.shape[2], ), 'int32')
                    ])
                aln = theano.printing.Print("aln")(aln)  # debug print left in the original code
                self.alignment.append(aln)  # TB

        if recurrent_transform == 'batch_norm':
            self.params['sample_mean_batch_norm'].custom_update = T.dot(
                T.mean(self.act[0], axis=[0, 1]), self.W_re)
            self.params[
                'sample_mean_batch_norm'].custom_update_normalized = True

        self.make_output(self.act[0][::direction or 1])
        self.params.update(unit.params)
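
The lmmask built in this constructor implements the droplm feedback dropout: a binary mask over (time, batch) decides per frame whether the ground-truth label or the model's own prediction is fed back. A standalone sketch of the mask construction, following the code above (shapes and the seed are illustrative):

import numpy
import theano
import theano.tensor as T
from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams

srng = RandomStreams(1234)
index = T.imatrix('index')  # (time, batch) sequence index
droplm = 0.3                # probability of replacing the true label by the prediction
# 1 -> keep the ground-truth label, 0 -> feed back the model's own prediction
lmmask = T.cast(srng.binomial(n=1, p=1.0 - droplm, size=index.shape),
                theano.config.floatX).dimshuffle(0, 1, 'x')

f = theano.function([index], lmmask)
print(f(numpy.ones((5, 2), dtype='int32')).shape)  # (5, 2, 1) binary mask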
Example 5
  def __init__(self,
               n_out,
               n_units = None,
               direction = 1,
               truncation = -1,
               sampling = 1,
               encoder = None,
               unit = 'lstm',
               n_dec = 0,
               attention = "none",
               recurrent_transform = "none",
               recurrent_transform_attribs = "{}",
               attention_template = 128,
               attention_distance = 'l2',
               attention_step = "linear",
               attention_beam = 0,
               attention_norm = "exp",
               attention_momentum = "none",
               attention_sharpening = 1.0,
               attention_nbest = 0,
               attention_store = False,
               attention_smooth = False,
               attention_align = False,
               attention_glimpse = 1,
               attention_filters = 1,
               attention_accumulator = 'sum',
               attention_bn = 0,
               attention_lm = 'none',
               attention_ndec = 1,
               base = None,
               lm = False,
               force_lm = False,
               droplm = 1.0,
               forward_weights_init=None,
               bias_random_init_forget_shift=0.0,
               **kwargs):
    """
    :param n_out: number of cells
    :param n_units: used when initialized via Network.from_hdf_model_topology
    :param direction: process sequence in forward (1) or backward (-1) direction
    :param truncation: gradient truncation
    :param sampling: scan every nth frame only
    :param encoder: list of encoder layers used as initialization for the hidden state
    :param unit: cell type (one of 'lstm', 'vanilla', 'gru', 'sru')
    :param n_dec: absolute number of steps to unfold the network if integer, else relative number of steps from encoder
    :param recurrent_transform: name of recurrent transform
    :param recurrent_transform_attribs: dictionary containing parameters for a recurrent transform
    :param attention_template:
    :param attention_distance:
    :param attention_step:
    :param attention_beam:
    :param attention_norm:
    :param attention_sharpening:
    :param attention_nbest:
    :param attention_store:
    :param attention_align:
    :param attention_glimpse:
    :param attention_lm:
    :param base: list of layers whose outputs are used as the base during attention mechanisms
    :param lm: activate RNNLM
    :param force_lm: expect previous labels to be given during testing
    :param droplm: probability of taking the expected output as predecessor instead of the real one when lm is enabled
    :param bias_random_init_forget_shift: initialize the forget gate bias of LSTM networks with this value
    """
    source_index = None
    if len(kwargs['sources']) == 1 and (kwargs['sources'][0].layer_class.endswith('length') or kwargs['sources'][0].layer_class.startswith('length')):
      kwargs['sources'] = []
      source_index = kwargs['index']
    unit_given = unit
    from Device import is_using_gpu
    if unit == 'lstm':  # auto selection
      if not is_using_gpu():
        unit = 'lstme'
      elif recurrent_transform == 'none' and (not lm or droplm == 0.0):
        unit = 'lstmp'
      else:
        unit = 'lstmc'
    elif unit in ("lstmc", "lstmp") and not is_using_gpu():
      unit = "lstme"
    kwargs.setdefault("n_out", n_out)
    if n_units is not None:
      assert n_units == n_out
    self.attention_weight = T.constant(1.,'float32')
    if len(kwargs['sources']) == 1 and kwargs['sources'][0].layer_class.startswith('length'):
      kwargs['sources'] = []
    elif len(kwargs['sources']) == 1 and kwargs['sources'][0].layer_class.startswith('signal'):
      kwargs['sources'] = []
    super(RecurrentUnitLayer, self).__init__(**kwargs)
    self.set_attr('from', ",".join([s.name for s in self.sources]) if self.sources else "null")
    self.set_attr('n_out', n_out)
    self.set_attr('unit', unit_given.encode("utf8"))
    self.set_attr('truncation', truncation)
    self.set_attr('sampling', sampling)
    self.set_attr('direction', direction)
    self.set_attr('lm', lm)
    self.set_attr('force_lm', force_lm)
    self.set_attr('droplm', droplm)
    if bias_random_init_forget_shift:
      self.set_attr("bias_random_init_forget_shift", bias_random_init_forget_shift)
    self.set_attr('attention_beam', attention_beam)
    self.set_attr('recurrent_transform', recurrent_transform.encode("utf8"))
    if isinstance(recurrent_transform_attribs, str):
      recurrent_transform_attribs = json.loads(recurrent_transform_attribs)
    if attention_template is not None:
      self.set_attr('attention_template', attention_template)
    self.set_attr('recurrent_transform_attribs', recurrent_transform_attribs)
    self.set_attr('attention_distance', attention_distance.encode("utf8"))
    self.set_attr('attention_step', attention_step.encode("utf8"))
    self.set_attr('attention_norm', attention_norm.encode("utf8"))
    self.set_attr('attention_sharpening', attention_sharpening)
    self.set_attr('attention_nbest', attention_nbest)
    attention_store = attention_store or attention_smooth or attention_momentum != 'none'
    self.set_attr('attention_store', attention_store)
    self.set_attr('attention_smooth', attention_smooth)
    self.set_attr('attention_momentum', attention_momentum.encode('utf8'))
    self.set_attr('attention_align', attention_align)
    self.set_attr('attention_glimpse', attention_glimpse)
    self.set_attr('attention_filters', attention_filters)
    self.set_attr('attention_lm', attention_lm)
    self.set_attr('attention_bn', attention_bn)
    self.set_attr('attention_accumulator', attention_accumulator)
    self.set_attr('attention_ndec', attention_ndec)
    self.set_attr('n_dec', n_dec)
    if encoder and hasattr(encoder[0],'act'):
      self.set_attr('encoder', ",".join([e.name for e in encoder]))
    if base:
      self.set_attr('base', ",".join([b.name for b in base]))
    else:
      base = encoder
    self.base = base
    self.set_attr('n_units', n_out)
    unit = eval(unit.upper())(**self.attrs)
    assert isinstance(unit, Unit)
    self.unit = unit
    kwargs.setdefault("n_out", unit.n_out)
    n_out = unit.n_out
    self.set_attr('n_out', unit.n_out)
    if n_dec != 0:
      self.target_index = self.index
      if isinstance(n_dec,float):
        if not source_index:
          source_index = encoder[0].index if encoder else base[0].index
        lengths = T.cast(T.ceil(T.sum(T.cast(source_index,'float32'),axis=0) * n_dec), 'int32')
        idx, _ = theano.map(lambda l_i, l_m:T.concatenate([T.ones((l_i,),'int8'),T.zeros((l_m-l_i,),'int8')]),
                            [lengths], [T.max(lengths)+1])
        self.index = idx.dimshuffle(1,0)[:-1]
        n_dec = T.cast(T.ceil(T.cast(source_index.shape[0],'float32') * numpy.float32(n_dec)),'int32')
      else:
        self.index = encoder[0].index
        self.index = T.ones((n_dec,self.index.shape[1]),'int64') # TODO: this gives a graph replacement error for int8
    else:
      n_dec = self.index.shape[0]
    # initialize recurrent weights
    self.W_re = None
    if unit.n_re > 0:
      self.W_re = self.add_param(self.create_recurrent_weights(unit.n_units, unit.n_re, name="W_re_%s" % self.name))
    # initialize forward weights
    bias_init_value = self.create_bias(unit.n_in).get_value()
    if bias_random_init_forget_shift:
      assert unit.n_units * 4 == unit.n_in  # (input gate, forget gate, output gate, net input)
      bias_init_value[unit.n_units:2 * unit.n_units] += bias_random_init_forget_shift
    self.b.set_value(bias_init_value)
    if not forward_weights_init:
      forward_weights_init = "random_uniform(p_add=%i)" % unit.n_re
    else:
      self.set_attr('forward_weights_init', forward_weights_init)
    self.forward_weights_init = forward_weights_init
    self.W_in = []
    for s in self.sources:
      W = self.create_forward_weights(s.attrs['n_out'], unit.n_in, name="W_in_%s_%s" % (s.name, self.name))
      self.W_in.append(self.add_param(W))
    # make input
    z = self.b
    for x_t, m, W in zip(self.sources, self.masks, self.W_in):
      if x_t.attrs['sparse']:
        # sparse input: use the rows of W as an embedding lookup table
        if x_t.output.ndim == 3:
          z += W[T.cast(x_t.output[:,:,0], 'int32')]
        elif x_t.output.ndim == 2:
          z += W[T.cast(x_t.output, 'int32')]
        else:
          assert False, x_t.output.ndim
      elif m is None:
        z += T.dot(x_t.output, W)
      else:
        z += self.dot(self.mass * m * x_t.output, W)
    #if self.attrs['batch_norm']:
    #  z = self.batch_norm(z, unit.n_in)
    num_batches = self.index.shape[1]
    self.num_batches = num_batches
    non_sequences = []
    if self.attrs['lm'] or attention_lm != 'none':
      if not 'target' in self.attrs:
        self.attrs['target'] = 'classes'
      if self.attrs['droplm'] > 0.0 or not (self.train_flag or force_lm):
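        # Glorot-style uniform initialization: limit l = sqrt(6 / (fan_in + fan_out))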
        l = sqrt(6.) / sqrt(unit.n_out + self.y_in[self.attrs['target']].n_out)
        values = numpy.asarray(self.rng.uniform(low=-l, high=l, size=(unit.n_out, self.y_in[self.attrs['target']].n_out)), dtype=theano.config.floatX)
        self.W_lm_in = self.add_param(self.shared(value=values, borrow=True, name = "W_lm_in_"+self.name))
        self.b_lm_in = self.create_bias(self.y_in[self.attrs['target']].n_out, 'b_lm_in')
      l = sqrt(6.) / sqrt(unit.n_in + self.y_in[self.attrs['target']].n_out)
      values = numpy.asarray(self.rng.uniform(low=-l, high=l, size=(self.y_in[self.attrs['target']].n_out, unit.n_in)), dtype=theano.config.floatX)
      self.W_lm_out = self.add_param(self.shared(value=values, borrow=True, name = "W_lm_out_"+self.name))
      if self.attrs['droplm'] == 0.0 and (self.train_flag or force_lm):
        self.lmmask = 1
        #if recurrent_transform != 'none':
        #  recurrent_transform = recurrent_transform[:-3]
      elif self.attrs['droplm'] < 1.0 and (self.train_flag or force_lm):
        from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams
        srng = RandomStreams(self.rng.randint(1234) + 1)
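        # per (time, batch) position, sample 1 with probability 1 - droplm, then broadcast the mask over the unit input dimension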
        self.lmmask = T.cast(srng.binomial(n=1, p=1.0 - self.attrs['droplm'], size=self.index.shape), theano.config.floatX).dimshuffle(0,1,'x').repeat(unit.n_in,axis=2)
      else:
        self.lmmask = T.zeros_like(self.index, dtype='float32').dimshuffle(0,1,'x').repeat(unit.n_in,axis=2)

    if recurrent_transform == 'input': # attention is just a sequence dependent bias (lstmp compatible)
      src = []
      src_names = []
      n_in = 0
      for e in base:
        src_base = [ s for s in e.sources if s.name not in src_names ]
        src_names += [ s.name for s in e.sources ]
        src += [s.output for s in src_base]
        n_in += sum([s.attrs['n_out'] for s in src_base])
      self.xc = T.concatenate(src, axis=2)
      l = sqrt(6.) / sqrt(self.attrs['n_out'] + n_in)
      values = numpy.asarray(self.rng.uniform(low=-l, high=l, size=(n_in, 1)), dtype=theano.config.floatX)
      self.W_att_xc = self.add_param(self.shared(value=values, borrow=True, name = "W_att_xc"))
      values = numpy.asarray(self.rng.uniform(low=-l, high=l, size=(n_in, self.attrs['n_out'] * 4)), dtype=theano.config.floatX)
      self.W_att_in = self.add_param(self.shared(value=values, borrow=True, name = "W_att_in"))
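      # zz / sum_t(zz) is a per-sequence weighting over time; zc is the resulting
      # sequence-dependent bias that is added to the recurrent unit's input below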
      zz = T.exp(T.tanh(T.dot(self.xc, self.W_att_xc))) # TB1
      self.zc = T.dot(T.sum(self.xc * (zz / T.sum(zz, axis=0, keepdims=True)).repeat(self.xc.shape[2],axis=2), axis=0, keepdims=True), self.W_att_in)
      recurrent_transform = 'none'
    recurrent_transform_inst = RecurrentTransform.transform_classes[recurrent_transform](layer=self)
    assert isinstance(recurrent_transform_inst, RecurrentTransform.RecurrentTransformBase)
    unit.recurrent_transform = recurrent_transform_inst
    self.recurrent_transform = recurrent_transform_inst

    # scan over sequence
    for s in range(self.attrs['sampling']):
      index = self.index[s::self.attrs['sampling']]
      sequences = z
      sources = self.sources
      if encoder:
        if hasattr(encoder[0],'act'):
          outputs_info = [ T.concatenate([e.act[i][-1] for e in encoder], axis=1) for i in range(unit.n_act) ]
        else:
          outputs_info = [ T.concatenate([e[i] for e in encoder], axis=1) for i in range(unit.n_act) ]
        sequences += T.alloc(numpy.cast[theano.config.floatX](0), n_dec, num_batches, unit.n_in) + (self.zc if self.attrs['recurrent_transform'] == 'input' else 0)
      else:
        outputs_info = [ T.alloc(numpy.cast[theano.config.floatX](0), num_batches, unit.n_units) for a in range(unit.n_act) ]

      if self.attrs['lm'] and self.attrs['droplm'] == 0.0 and (self.train_flag or force_lm):
        if self.network.y[self.attrs['target']].ndim == 3:
          sequences += T.dot(self.network.y[self.attrs['target']],self.W_lm_out)
        else:
          y = self.y_in[self.attrs['target']].flatten()
          sequences += self.W_lm_out[y].reshape((index.shape[0],index.shape[1],unit.n_in))

      if sequences is self.b:  # no source contributed any input above; expand the bias to (n_dec, batch, n_in)
        sequences += T.alloc(numpy.cast[theano.config.floatX](0), n_dec, num_batches, unit.n_in) + (self.zc if self.attrs['recurrent_transform'] == 'input' else 0)

      if unit.recurrent_transform:
        outputs_info += unit.recurrent_transform.get_sorted_state_vars_initial()

      index_f = T.cast(index, theano.config.floatX)
      unit.set_parent(self)
      outputs = unit.scan(x=sources,
                          z=sequences[s::self.attrs['sampling']],
                          non_sequences=non_sequences,
                          i=index_f,
                          outputs_info=outputs_info,
                          W_re=self.W_re,
                          W_in=self.W_in,
                          b=self.b,
                          go_backwards=direction == -1,
                          truncate_gradient=self.attrs['truncation'])

      if not isinstance(outputs, list):
        outputs = [outputs]
      if outputs:
        outputs[0].name = "%s.act[0]" % self.name

      if unit.recurrent_transform:
        unit.recurrent_transform_state_var_seqs = outputs[-len(unit.recurrent_transform.state_vars):]

      if self.attrs['sampling'] > 1:
        if s == 0:
          self.act = [ T.alloc(numpy.cast['float32'](0), self.index.shape[0], self.index.shape[1], n_out) for act in outputs ]
        self.act = [ T.set_subtensor(tot[s::self.attrs['sampling']], act) for tot,act in zip(self.act, outputs) ]
      else:
        self.act = outputs[:unit.n_act]
        if len(outputs) > unit.n_act:
          self.aux = outputs[unit.n_act:]
    if self.attrs['attention_store']:
      self.attention = [ self.aux[i].dimshuffle(0,2,1) for i,v in enumerate(sorted(unit.recurrent_transform.state_vars.keys())) if v.startswith('att_') ] # NBT
      for i in range(len(self.attention)):
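        # for direction = 1, T.eye(n, 1, -(n-1)) is a one-hot column with its single 1 at row n-1, marking the last time frame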
        vec = T.eye(self.attention[i].shape[2], 1, -direction * (self.attention[i].shape[2] - 1))
        last = vec.dimshuffle(1,'x', 0).repeat(self.index.shape[1], axis=1)
        self.attention[i] = T.concatenate([self.attention[i][1:],last],axis=0)[::direction]

    if self.attrs['attention_align']:
      bp = [ self.aux[i] for i,v in enumerate(sorted(unit.recurrent_transform.state_vars.keys())) if v.startswith('K_') ]
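      # K stores, per source position and batch entry, how many frames to jump back;
      # scanning from the last position recovers the alignment path step by step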
      def backtrace(k,i_p):
        return i_p - k[i_p,T.arange(k.shape[1],dtype='int32')]
      self.alignment = []
      for K in bp: # K: NTB
        aln, _ = theano.scan(backtrace, sequences=[T.cast(K,'int32').dimshuffle(1,0,2)],
                             outputs_info=[T.cast(self.index.shape[0] - 1,'int32') + T.zeros((K.shape[2],),'int32')])
        #aln = theano.printing.Print("aln")(aln)  # leftover debug print, disabled
        self.alignment.append(aln) # TB

    if recurrent_transform == 'batch_norm':
      self.params['sample_mean_batch_norm'].custom_update = T.dot(T.mean(self.act[0],axis=[0,1]),self.W_re)
      self.params['sample_mean_batch_norm'].custom_update_normalized = True

    self.make_output(self.act[0][::direction or 1])
    self.params.update(unit.params)
Example no. 8
  def __init__(self,
               n_out = None,
               n_units = None,
               direction = 1,
               truncation = -1,
               sampling = 1,
               encoder = None,
               unit = 'lstm',
               n_dec = 0,
               attention = "none",
               recurrent_transform = "none",
               recurrent_transform_attribs = "{}",
               attention_template = 128,
               attention_distance = 'l2',
               attention_step = "linear",
               attention_beam = 0,
               attention_norm = "exp",
               attention_momentum = "none",
               attention_sharpening = 1.0,
               attention_nbest = 0,
               attention_store = False,
               attention_smooth = False,
               attention_glimpse = 1,
               attention_filters = 1,
               attention_accumulator = 'sum',
               attention_loss = 0,
               attention_bn = 0,
               attention_lm = 'none',
               attention_ndec = 1,
               attention_memory = 0,
               attention_alnpts = 0,
               attention_epoch  = 1,
               attention_segstep=0.01,
               attention_offset=0.95,
               attention_method="epoch",
               attention_scale=10,
               context=-1,
               base = None,
               aligner = None,
               lm = False,
               force_lm = False,
               droplm = 1.0,
               forward_weights_init=None,
               bias_random_init_forget_shift=0.0,
               copy_weights_from_base=False,
               segment_input=False,
               join_states=False,
               state_memory=False,
               sample_segment=None,
               **kwargs):
    """
    :param n_out: number of cells
    :param n_units: used when initialized via Network.from_hdf_model_topology
    :param direction: process sequence in forward (1) or backward (-1) direction
    :param truncation: gradient truncation
    :param sampling: scan every nth frame only
    :param encoder: list of encoder layers used as initialization for the hidden state
    :param unit: cell type (one of 'lstm', 'vanilla', 'gru', 'sru')
    :param n_dec: number of decoder steps: an integer gives the absolute number of steps to unfold, a float gives the number of steps relative to the encoder length
    :param recurrent_transform: name of recurrent transform
    :param recurrent_transform_attribs: dictionary containing parameters for a recurrent transform
    :param attention_template:
    :param attention_distance:
    :param attention_step:
    :param attention_beam:
    :param attention_norm:
    :param attention_sharpening:
    :param attention_nbest:
    :param attention_store:
    :param attention_align:
    :param attention_glimpse:
    :param attention_lm:
    :param base: list of layers whose outputs form the base of the attention mechanism
    :param lm: activate RNNLM
    :param force_lm: expect previous labels to be given during testing
    :param droplm: probability of taking the model's predicted output as predecessor instead of the ground-truth label when lm is enabled
    :param bias_random_init_forget_shift: shift the initial forget-gate bias of LSTM networks by this value
    """
    source_index = None
    if len(kwargs['sources']) == 1 and (kwargs['sources'][0].layer_class.endswith('length') or kwargs['sources'][0].layer_class.startswith('length')):
      kwargs['sources'] = []
      source_index = kwargs['index']
    unit_given = unit
    from Device import is_using_gpu
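    # auto-select a concrete LSTM implementation: the plain Theano variant ('lstme') on CPU,
    # the fast CUDA kernel ('lstmp') when neither a recurrent transform nor sampled LM feedback
    # is needed, and the custom op ('lstmc') otherwise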
    if unit == 'lstm':  # auto selection
      if not is_using_gpu():
        unit = 'lstme'
      elif recurrent_transform == 'none' and (not lm or droplm == 0.0):
        unit = 'lstmp'
      else:
        unit = 'lstmc'
    elif unit in ("lstmc", "lstmp") and not is_using_gpu():
      unit = "lstme"
    if segment_input:
      if is_using_gpu():
        unit = "lstmps"
      else:
        unit = "lstms"
    if n_out is None:
      assert encoder
      n_out = sum([enc.attrs['n_out'] for enc in encoder])
    kwargs.setdefault("n_out", n_out)
    if n_units is not None:
      assert n_units == n_out
    self.attention_weight = T.constant(1.,'float32')
    if len(kwargs['sources']) == 1 and kwargs['sources'][0].layer_class.startswith('length'):
      kwargs['sources'] = []
    elif len(kwargs['sources']) == 1 and kwargs['sources'][0].layer_class.startswith('signal'):
      kwargs['sources'] = []
    super(RecurrentUnitLayer, self).__init__(**kwargs)
    self.set_attr('from', ",".join([s.name for s in self.sources]) if self.sources else "null")
    self.set_attr('n_out', n_out)
    self.set_attr('unit', unit_given.encode("utf8"))
    self.set_attr('truncation', truncation)
    self.set_attr('sampling', sampling)
    self.set_attr('direction', direction)
    self.set_attr('lm', lm)
    self.set_attr('force_lm', force_lm)
    self.set_attr('droplm', droplm)
    if bias_random_init_forget_shift:
      self.set_attr("bias_random_init_forget_shift", bias_random_init_forget_shift)
    self.set_attr('attention_beam', attention_beam)
    self.set_attr('recurrent_transform', recurrent_transform.encode("utf8"))
    if isinstance(recurrent_transform_attribs, str):
      recurrent_transform_attribs = json.loads(recurrent_transform_attribs)
    if attention_template is not None:
      self.set_attr('attention_template', attention_template)
    self.set_attr('recurrent_transform_attribs', recurrent_transform_attribs)
    self.set_attr('attention_distance', attention_distance.encode("utf8"))
    self.set_attr('attention_step', attention_step.encode("utf8"))
    self.set_attr('attention_norm', attention_norm.encode("utf8"))
    self.set_attr('attention_sharpening', attention_sharpening)
    self.set_attr('attention_nbest', attention_nbest)
    attention_store = attention_store or attention_smooth or attention_momentum != 'none'
    self.set_attr('attention_store', attention_store)
    self.set_attr('attention_smooth', attention_smooth)
    self.set_attr('attention_momentum', attention_momentum.encode('utf8'))
    self.set_attr('attention_glimpse', attention_glimpse)
    self.set_attr('attention_filters', attention_filters)
    self.set_attr('attention_lm', attention_lm)
    self.set_attr('attention_bn', attention_bn)
    self.set_attr('attention_accumulator', attention_accumulator)
    self.set_attr('attention_ndec', attention_ndec)
    self.set_attr('attention_memory', attention_memory)
    self.set_attr('attention_loss', attention_loss)
    self.set_attr('n_dec', n_dec)
    self.set_attr('segment_input', segment_input)
    self.set_attr('attention_alnpts', attention_alnpts)
    self.set_attr('attention_epoch', attention_epoch)
    self.set_attr('attention_segstep', attention_segstep)
    self.set_attr('attention_offset', attention_offset)
    self.set_attr('attention_method', attention_method)
    self.set_attr('attention_scale', attention_scale)
    if segment_input:
      if not self.eval_flag:
      #if self.eval_flag:
        if isinstance(self.sources[0],RecurrentUnitLayer):
          self.inv_att = self.sources[0].inv_att #NBT
        else:
          if not join_states:
            self.inv_att = self.sources[0].attention #NBT
          else:
            assert hasattr(self.sources[0], "nstates"), "source does not have number of states!"
            ns = self.sources[0].nstates
            self.inv_att = self.sources[0].attention[(ns-1)::ns]
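        # transpose the (N,B,T) attention to (T,B,N), shift it one step forward in time,
        # zero the first frame, then take the max over decoder steps to get one value per (frame, batch)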
        inv_att = T.roll(self.inv_att.dimshuffle(2, 1, 0),1,axis=0)#TBN
        inv_att = T.set_subtensor(inv_att[0],T.zeros((inv_att.shape[1],inv_att.shape[2])))
        inv_att = T.max(inv_att,axis=-1)
      else:
        inv_att = T.zeros((self.sources[0].output.shape[0],self.sources[0].output.shape[1]))
    if encoder and hasattr(encoder[0],'act'):
      self.set_attr('encoder', ",".join([e.name for e in encoder]))
    if base:
      self.set_attr('base', ",".join([b.name for b in base]))
    else:
      base = encoder
    self.base = base
    self.encoder = encoder
    if aligner:
      self.aligner = aligner
    self.set_attr('n_units', n_out)
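    # resolve the unit name to its implementing class (e.g. 'lstmp' -> LSTMP) and construct it from the layer attributes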
    unit = eval(unit.upper())(**self.attrs)
    assert isinstance(unit, Unit)
    self.unit = unit
    kwargs.setdefault("n_out", unit.n_out)
    n_out = unit.n_out
    self.set_attr('n_out', unit.n_out)
    if n_dec < 0:
      source_index = self.index
      n_dec *= -1
    if n_dec != 0:
      self.target_index = self.index
      if isinstance(n_dec,float):
        if source_index is None:  # truth-testing a symbolic Theano index would raise
          source_index = encoder[0].index if encoder else base[0].index
        lengths = T.cast(T.ceil(T.sum(T.cast(source_index,'float32'),axis=0) * n_dec), 'int32')
        idx, _ = theano.map(lambda l_i, l_m:T.concatenate([T.ones((l_i,),'int8'),T.zeros((l_m-l_i,),'int8')]),
                            [lengths], [T.max(lengths)+1])
        self.index = idx.dimshuffle(1,0)[:-1]
        n_dec = T.cast(T.ceil(T.cast(source_index.shape[0],'float32') * numpy.float32(n_dec)),'int32')
      else:
        if encoder:
          self.index = encoder[0].index
        self.index = T.ones((n_dec,self.index.shape[1]),'int8')
    else:
      n_dec = self.index.shape[0]
    # initialize recurrent weights
    self.W_re = None
    if unit.n_re > 0:
      self.W_re = self.add_param(self.create_recurrent_weights(unit.n_units, unit.n_re, name="W_re_%s" % self.name))
    # initialize forward weights
    bias_init_value = self.create_bias(unit.n_in).get_value()
    if bias_random_init_forget_shift:
      assert unit.n_units * 4 == unit.n_in  # (input gate, forget gate, output gate, net input)
      bias_init_value[unit.n_units:2 * unit.n_units] += bias_random_init_forget_shift
    self.b.set_value(bias_init_value)
    if not forward_weights_init:
      forward_weights_init = "random_uniform(p_add=%i)" % unit.n_re
    else:
      self.set_attr('forward_weights_init', forward_weights_init)
    self.forward_weights_init = forward_weights_init
    self.W_in = []
    sample_mean, gamma = None, None
    if copy_weights_from_base:
      self.params = {}
      #self.W_re = self.add_param(base[0].W_re)
      #self.W_in = [ self.add_param(W) for W in base[0].W_in ]
      #self.b = self.add_param(base[0].b)
      self.W_re = base[0].W_re
      self.W_in = base[0].W_in
      self.b = base[0].b
      if self.attrs.get('batch_norm', False):
        sample_mean = base[0].sample_mean
        gamma = base[0].gamma
      #self.masks = base[0].masks
      #self.mass = base[0].mass
    else:
      for s in self.sources:
        W = self.create_forward_weights(s.attrs['n_out'], unit.n_in, name="W_in_%s_%s" % (s.name, self.name))
        self.W_in.append(self.add_param(W))
    # make input
    z = self.b
    for x_t, m, W in zip(self.sources, self.masks, self.W_in):
      if x_t.attrs['sparse']:
        # sparse input: index into W as an embedding table
        if x_t.output.ndim == 3:
          z += W[T.cast(x_t.output[:,:,0], 'int32')]
        elif x_t.output.ndim == 2:
          z += W[T.cast(x_t.output, 'int32')]
        else:
          assert False, x_t.output.ndim
      elif m is None:
        z += T.dot(x_t.output, W)
      else:
        z += self.dot(self.mass * m * x_t.output, W)
    #if self.attrs['batch_norm']:
    #  z = self.batch_norm(z, unit.n_in)
    num_batches = self.index.shape[1]
    self.num_batches = num_batches
    non_sequences = []
    if self.attrs['lm'] or attention_lm != 'none':
      if not 'target' in self.attrs:
        self.attrs['target'] = 'classes'
      if self.attrs['droplm'] > 0.0 or not (self.train_flag or force_lm):
        if copy_weights_from_base:
          self.W_lm_in = base[0].W_lm_in
          self.b_lm_in = base[0].b_lm_in
        else:
          l = sqrt(6.) / sqrt(unit.n_out + self.y_in[self.attrs['target']].n_out)
          values = numpy.asarray(self.rng.uniform(low=-l, high=l, size=(unit.n_out, self.y_in[self.attrs['target']].n_out)), dtype=theano.config.floatX)
          self.W_lm_in = self.add_param(self.shared(value=values, borrow=True, name = "W_lm_in_"+self.name))
          self.b_lm_in = self.create_bias(self.y_in[self.attrs['target']].n_out, 'b_lm_in')
      l = sqrt(6.) / sqrt(unit.n_in + self.y_in[self.attrs['target']].n_out)
      values = numpy.asarray(self.rng.uniform(low=-l, high=l, size=(self.y_in[self.attrs['target']].n_out, unit.n_in)), dtype=theano.config.floatX)
      if copy_weights_from_base:
        self.W_lm_out = base[0].W_lm_out
      else:
        self.W_lm_out = self.add_param(self.shared(value=values, borrow=True, name = "W_lm_out_"+self.name))
      if self.attrs['droplm'] == 0.0 and (self.train_flag or force_lm):
        self.lmmask = 1
        #if recurrent_transform != 'none':
        #  recurrent_transform = recurrent_transform[:-3]
      elif self.attrs['droplm'] < 1.0 and (self.train_flag or force_lm):
        from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams
        srng = RandomStreams(self.rng.randint(1234) + 1)
        self.lmmask = T.cast(srng.binomial(n=1, p=1.0 - self.attrs['droplm'], size=self.index.shape), theano.config.floatX).dimshuffle(0,1,'x').repeat(unit.n_in,axis=2)
      else:
        self.lmmask = T.zeros_like(self.index, dtype='float32').dimshuffle(0,1,'x').repeat(unit.n_in,axis=2)

    if recurrent_transform == 'input': # attention is just a sequence dependent bias (lstmp compatible)
      src = []
      src_names = []
      n_in = 0
      for e in base:
        #src_base = [ s for s in e.sources if s.name not in src_names ]
        #src_names += [ s.name for s in e.sources ]
        src_base = [ e ]
        src_names += [e.name]
        src += [s.output for s in src_base]
        n_in += sum([s.attrs['n_out'] for s in src_base])
      self.xc = T.concatenate(src, axis=2)
      l = sqrt(6.) / sqrt(self.attrs['n_out'] + n_in)
      values = numpy.asarray(self.rng.uniform(low=-l, high=l, size=(n_in, 1)), dtype=theano.config.floatX)
      self.W_att_xc = self.add_param(self.shared(value=values, borrow=True, name = "W_att_xc"))
      values = numpy.asarray(self.rng.uniform(low=-l, high=l, size=(n_in, self.attrs['n_out'] * 4)), dtype=theano.config.floatX)
      self.W_att_in = self.add_param(self.shared(value=values, borrow=True, name = "W_att_in"))
      zz = T.exp(T.tanh(T.dot(self.xc, self.W_att_xc))) # TB1
      self.zc = T.dot(T.sum(self.xc * (zz / T.sum(zz, axis=0, keepdims=True)).repeat(self.xc.shape[2],axis=2), axis=0, keepdims=True), self.W_att_in)
      recurrent_transform = 'none'
    elif recurrent_transform == 'attention_align':
      max_skip = base[0].attrs['max_skip']
      values = numpy.zeros((max_skip,), dtype=theano.config.floatX)
      self.T_b = self.add_param(self.shared(value=values, borrow=True, name="T_b"), name="T_b")
      l = sqrt(6.) / sqrt(self.attrs['n_out'] + max_skip)
      values = numpy.asarray(self.rng.uniform(
        low=-l, high=l, size=(self.attrs['n_out'], max_skip)), dtype=theano.config.floatX)
      self.T_W = self.add_param(self.shared(value=values, borrow=True, name="T_W"), name="T_W")
      y_t = T.dot(self.base[0].attention, T.arange(self.base[0].output.shape[0], dtype='float32'))  # NB
      y_t = T.concatenate([T.zeros_like(y_t[:1]), y_t], axis=0)  # (N+1)B
      y_t = y_t[1:] - y_t[:-1]  # NB
      self.y_t = y_t # T.clip(y_t,numpy.float32(0),numpy.float32(max_skip - 1))

      self.y_t = T.cast(self.base[0].backtrace,'float32')
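      # the hard backtrace from the aligner overrides the soft attention-based skip estimate above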
    elif recurrent_transform == 'attention_segment':
      assert aligner.attention, "Segment-wise attention requires attention points!"

    recurrent_transform_inst = RecurrentTransform.transform_classes[recurrent_transform](layer=self)
    assert isinstance(recurrent_transform_inst, RecurrentTransform.RecurrentTransformBase)
    unit.recurrent_transform = recurrent_transform_inst
    self.recurrent_transform = recurrent_transform_inst
    state_memory *= self.train_flag
    # scan over sequence
    for s in range(self.attrs['sampling']):
      index = self.index[s::self.attrs['sampling']]

      if context > 0:
        from TheanoUtil import context_batched
        n_batches = z.shape[1]
        time, batch, dim = z.shape[0], z.shape[1], z.shape[2]
        #z = context_batched(z[::direction or 1], window=context)[::direction or 1] # TB(CD)

        from theano.ifelse import ifelse
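        # builds, for each position idx, a window of the `context` most recent frames
        # (z is left-padded with context - 1 zero frames via the non_sequences below)
        # and zeros the index entries that would fall into the padding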
        def context_window(idx, x_in, i_in):
          x_out = x_in[idx:idx + context]
          x_out = x_out.dimshuffle('x',1,0,2).reshape((1, batch, dim * context))
          i_out = i_in[idx:idx+1].repeat(context, axis=0)
          i_out = ifelse(T.lt(idx,context),T.set_subtensor(i_out[:context - idx],numpy.int8(0)),i_out).reshape((1, batch * context))
          return x_out, i_out

        z = z[::direction or 1]
        i = index[::direction or 1]
        out, _ = theano.map(context_window, sequences = [T.arange(z.shape[0])], non_sequences = [T.concatenate([T.zeros((context - 1,z.shape[1],z.shape[2]),dtype='float32'),z],axis=0), i])
        z = out[0][::direction or 1]
        i = out[1][::direction or 1]  # T(BC)
        direction = 1
        z = z.reshape((time * batch, context * dim))  # (TB)(CD)
        z = z.reshape((time * batch, context, dim)).dimshuffle(1,0,2)  # C(TB)D
        i = i.reshape((time * batch, context)).dimshuffle(1,0)  # C(TB)

        index = i
        num_batches = time * batch

      sequences = z
      sources = self.sources
      if state_memory:
        self.init_state = [
          self.add_param(self.shared(numpy.zeros((state_memory or 1, unit.n_units), dtype='float32'), name='init_%d_%s' % (a, self.name)))
          for a in range(unit.n_act)]  # has to be initialized for train and test
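        # these shared variables carry the final hidden state across batches; they are
        # refreshed below via live_update with the last time step of each activation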
      if encoder:
        if recurrent_transform == "attention_segment":
          if hasattr(encoder[0],'act'):
            outputs_info = [T.concatenate([e.act[i][-1] for e in encoder], axis=1) for i in range(unit.n_act)]
          else:
            # without the generic initialization the next line would reference an undefined
            # outputs_info, so build it first and then override the first state
            outputs_info = [ T.concatenate([e[i] for e in encoder], axis=1) for i in range(unit.n_act) ]
            outputs_info[0] = self.aligner.output[-1]
        elif hasattr(encoder[0],'act'):
          outputs_info = [ T.concatenate([e.act[i][-1] for e in encoder], axis=1) for i in range(unit.n_act) ]
        else:
          outputs_info = [ T.concatenate([e[i] for e in encoder], axis=1) for i in range(unit.n_act) ]
        sequences += T.alloc(numpy.cast[theano.config.floatX](0), n_dec, num_batches, unit.n_in) + (self.zc if self.attrs['recurrent_transform'] == 'input' else numpy.float32(0))
      elif state_memory:
        outputs_info = self.init_state
      else:
        outputs_info = [ T.alloc(numpy.cast[theano.config.floatX](0), num_batches, unit.n_units) for a in range(unit.n_act) ]

      if self.attrs['lm'] and self.attrs['droplm'] == 0.0 and (self.train_flag or force_lm):
        if self.network.y[self.attrs['target']].ndim == 3:
          sequences += T.dot(self.network.y[self.attrs['target']],self.W_lm_out)
        else:
          y = self.y_in[self.attrs['target']].flatten()
          sequences += self.W_lm_out[y].reshape((index.shape[0],index.shape[1],unit.n_in))

      if sequences is self.b:
        sequences += T.alloc(numpy.cast[theano.config.floatX](0), n_dec, num_batches, unit.n_in) + (self.zc if self.attrs['recurrent_transform'] == 'input' else numpy.float32(0))

      if unit.recurrent_transform:
        outputs_info += unit.recurrent_transform.get_sorted_state_vars_initial()

      index_f = T.cast(index, theano.config.floatX)
      unit.set_parent(self)

      if segment_input:
        outputs = unit.scan_seg(x=sources,
                                z=sequences[s::self.attrs['sampling']],
                                att = inv_att,
                                non_sequences=non_sequences,
                                i=index_f,
                                outputs_info=outputs_info,
                                W_re=self.W_re,
                                W_in=self.W_in,
                                b=self.b,
                                go_backwards=direction == -1,
                                truncate_gradient=self.attrs['truncation'])
      else:
        outputs = unit.scan(x=sources,
                            z=sequences[s::self.attrs['sampling']],
                            non_sequences=non_sequences,
                            i=index_f,
                            outputs_info=outputs_info,
                            W_re=self.W_re,
                            W_in=self.W_in,
                            b=self.b,
                            go_backwards=direction == -1,
                            truncate_gradient=self.attrs['truncation'])

      if not isinstance(outputs, list):
        outputs = [outputs]
      if outputs:
        outputs[0].name = "%s.act[0]" % self.name
        if context > 0:
          for i in range(len(outputs)):
            outputs[i] = outputs[i][-1].reshape((outputs[i].shape[1]//n_batches,n_batches,outputs[i].shape[2]))

      if unit.recurrent_transform:
        unit.recurrent_transform_state_var_seqs = outputs[-len(unit.recurrent_transform.state_vars):]

      if self.attrs['sampling'] > 1:
        if s == 0:
          self.act = [ T.alloc(numpy.cast['float32'](0), self.index.shape[0], self.index.shape[1], n_out) for act in outputs ]
        self.act = [ T.set_subtensor(tot[s::self.attrs['sampling']], act) for tot,act in zip(self.act, outputs) ]
      else:
        self.act = outputs[:unit.n_act]
        if len(outputs) > unit.n_act:
          self.aux = outputs[unit.n_act:]
        if state_memory:
          for i in range(len(self.act)):
            self.init_state[i].live_update = self.act[i][-1]
    if self.attrs['attention_store']:
      self.attention = [ self.aux[i].dimshuffle(0,2,1) for i,v in enumerate(sorted(unit.recurrent_transform.state_vars.keys())) if v.startswith('att_') ] # NBT
      for i in range(len(self.attention)):
        vec = T.eye(self.attention[i].shape[2], 1, -direction * (self.attention[i].shape[2] - 1))
        last = vec.dimshuffle(1, 'x', 0).repeat(self.index.shape[1], axis=1)
        self.attention[i] = T.concatenate([self.attention[i][1:],last],axis=0)[::direction]

    self.cost_val = numpy.float32(0)
    if recurrent_transform == 'attention_align':
      back = T.ceil(self.aux[sorted(unit.recurrent_transform.state_vars.keys()).index('t')])
      def make_output(base, yout, trace, length):
        length = T.cast(length, 'int32')
        idx = T.cast(trace[:length][::-1],'int32')
        x_out = T.concatenate([base[idx],T.zeros((self.index.shape[0] + 1 - length, base.shape[1]), 'float32')],axis=0)
        y_out = T.concatenate([yout[idx,T.arange(length)],T.zeros((self.index.shape[0] + 1 - length, ), 'float32')],axis=0)
        return x_out, y_out

      output, _ = theano.map(make_output,
                             sequences = [base[0].output.dimshuffle(1,0,2),
                                          self.y_t.dimshuffle(1,2,0),
                                          back.dimshuffle(1,0),
                                          T.sum(self.index,axis=0,dtype='float32')])
      self.attrs['n_out'] = base[0].attrs['n_out']
      self.params.update(unit.params)
      self.output = output[0].dimshuffle(1,0,2)[:-1]

      z = T.dot(self.act[0], self.T_W)[:-1] + self.T_b
      z = z.reshape((z.shape[0] * z.shape[1], z.shape[2]))
      idx = (self.index[1:].flatten() > 0).nonzero()
      idy = (self.index[1:][::-1].flatten() > 0).nonzero()
      y_out = T.cast(output[1],'int32').dimshuffle(1, 0)[:-1].flatten()
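      # frame-wise cross-entropy between the predicted skip scores and the backtraced
      # targets, evaluated only on non-padded positions (idx/idy filter on the index)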
      nll, _ = T.nnet.crossentropy_softmax_1hot(x=z[idx], y_idx=y_out[idy])
      self.cost_val = T.sum(nll)
      recog = T.argmax(z[idx], axis=1)
      real = y_out[idy]
      self.errors = lambda: T.sum(T.neq(recog, real))

      return
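
      # NOTE: everything below this `return` is unreachable; these are experimental
      # variants of the output construction that were left in place.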

      back += T.arange(self.index.shape[1], dtype='float32') * T.cast(self.base[0].index.shape[0], 'float32')
      idx = (self.index[:-1].flatten() > 0).nonzero()
      idx = T.cast(back[::-1].flatten()[idx],'int32')
      x_out = base[0].output
      #x_out = x_out.dimshuffle(1,0,2).reshape((x_out.shape[0] * x_out.shape[1], x_out.shape[2]))[idx]
      #x_out = x_out.reshape((self.index.shape[1], self.index.shape[0] - 1, x_out.shape[1])).dimshuffle(1,0,2)
      x_out = x_out.reshape((x_out.shape[0] * x_out.shape[1], x_out.shape[2]))[idx]
      x_out = x_out.reshape((self.index.shape[0] - 1, self.index.shape[1], x_out.shape[1]))
      self.output = T.concatenate([x_out, base[0].output[1:]],axis=0)
      self.attrs['n_out'] = base[0].attrs['n_out']
      self.params.update(unit.params)
      return


      skips = T.dot(T.nnet.softmax(z), T.arange(z.shape[1], dtype='float32')).reshape(self.index[1:].shape)
      shift = T.arange(self.index.shape[1], dtype='float32') * T.cast(self.base[0].index.shape[0], 'float32')
      skips = T.concatenate([T.zeros_like(self.y_t[:1]),self.y_t[:-1]],axis=0)
      idx = shift + T.cumsum(skips, axis=0)
      idx = T.cast(idx[:-1].flatten(),'int32')
      #idx = (idx.flatten() > 0).nonzero()
      #idx = base[0].attention.flatten()
      x_out = base[0].output[::-1]
      x_out = x_out.reshape((x_out.shape[0] * x_out.shape[1], x_out.shape[2]))[idx]
      x_out = x_out.reshape((self.index.shape[0], self.index.shape[1], x_out.shape[1]))
      self.output = T.concatenate([base[0].output[-1:], x_out], axis=0)[::-1]
      self.attrs['n_out'] = base[0].attrs['n_out']
      self.params.update(unit.params)
      return

    if recurrent_transform == 'batch_norm':
      self.params['sample_mean_batch_norm'].custom_update = T.dot(T.mean(self.act[0],axis=[0,1]),self.W_re)
      self.params['sample_mean_batch_norm'].custom_update_normalized = True

    self.make_output(self.act[0][::direction or 1], sample_mean=sample_mean, gamma=gamma)
    self.params.update(unit.params)