def conv_crop_pool_op(X, sizes, output_sizes, W, b, n_in, n_maps, filter_height, filter_width, filter_dilation, poolsize):
  """
  Apply a valid cuDNN convolution, crop to the per-batch image sizes, max-pool,
  and crop again to the pooled output sizes. On a non-GPU device this is a
  pass-through and X is returned unchanged.

  :param X: input images (HWBC layout, per the ops used -- TODO confirm)
  :param sizes: per-batch image sizes used for cropping after the convolution
  :param output_sizes: per-batch sizes used for zero-cropping after pooling
  :param W: convolution filter weights
  :param b: convolution bias
  :param n_in: unused here; presumably kept for interface symmetry -- TODO confirm
  :param n_maps: unused here -- TODO confirm
  :param filter_height: filter height; a 0-area filter skips the convolution
  :param filter_width: filter width; a 0-area filter skips the convolution
  :param filter_dilation: unused here -- TODO confirm
  :param poolsize: pooling window for PoolHWBCOp
  :return: the convolved/cropped/pooled tensor, or X on CPU
  """
  from Device import is_using_gpu
  if not is_using_gpu():
    # CPU fallback: none of the GPU-only ops are available; pass input through.
    return X
  # Skip the convolution entirely for a degenerate (zero-area) filter.
  convolved = CuDNNConvHWBCOpValidInstance(X, W, b) if filter_height * filter_width > 0 else X
  cropped = CropToBatchImageSizeInstance(convolved, sizes)
  pooled = PoolHWBCOp(poolsize)(cropped)
  return CropToBatchImageSizeZeroInstance(pooled, output_sizes)
def circular_convolution(a, b):
  """
  Circular convolution of two tensors via the convolution theorem:
  ifft(fft(a) * fft(b)). Uses the GPU FFT ops when a GPU device is active
  and pygpu is importable, otherwise the CPU Theano FFT ops.

  :param a: first operand (Theano tensor)
  :param b: second operand (Theano tensor)
  :return: symbolic circular convolution of a and b
  """
  from Device import is_using_gpu
  gpu_fft_ok = is_using_gpu()
  # Probe pygpu availability; the GPU FFT ops require it. The probe is done
  # unconditionally, mirroring how availability is established here.
  try:
    import pygpu
  except Exception:
    gpu_fft_ok = False
  if gpu_fft_ok:
    from theano.gpuarray.fft import curfft as fft
    from theano.gpuarray.fft import cuirfft as ifft
  else:
    from theano.tensor.fft import rfft as fft
    from theano.tensor.fft import irfft as ifft
  return ifft(fft(a) * fft(b))
def __init__(self, n_out = None, n_units = None, direction = 1, truncation = -1, sampling = 1, encoder = None,
             unit = 'lstm', n_dec = 0, attention = "none", recurrent_transform = "none",
             recurrent_transform_attribs = "{}", attention_template = 128, attention_distance = 'l2',
             attention_step = "linear", attention_beam = 0, attention_norm = "exp", attention_momentum = "none",
             attention_sharpening = 1.0, attention_nbest = 0, attention_store = False, attention_smooth = False,
             attention_glimpse = 1, attention_filters = 1, attention_accumulator = 'sum', attention_loss = 0,
             attention_bn = 0, attention_lm = 'none', attention_ndec = 1, attention_memory = 0,
             attention_alnpts = 0, attention_epoch = 1, attention_segstep=0.01, attention_offset=0.95,
             attention_method="epoch", attention_scale=10, context=-1, base = None, aligner = None,
             lm = False, force_lm = False, droplm = 1.0, forward_weights_init=None,
             bias_random_init_forget_shift=0.0, copy_weights_from_base=False, segment_input=False,
             join_states=False, sample_segment=None, **kwargs):
  """
  Build a recurrent layer: select a concrete unit implementation, register layer
  attributes, create/share the weights, build the Theano scan over the sequence
  and produce the layer output.

  :param n_out: number of cells; derived from the encoder's n_out sum if None
  :param n_units: used when initialized via Network.from_hdf_model_topology; must equal n_out
  :param direction: process sequence in forward (1) or backward (-1) direction
  :param truncation: gradient truncation (passed to the unit's scan)
  :param sampling: scan every nth frame only
  :param encoder: list of encoder layers used as initialization for the hidden state
  :param unit: cell type (one of 'lstm', 'vanilla', 'gru', 'sru'); 'lstm' is auto-specialized below
  :param n_dec: number of steps to unfold: absolute if int, relative to the source length if float,
    0 means use the own/source index length; negative flips sign and uses self.index as source index
  :param attention: not read in this body; presumably legacy/compatibility -- TODO confirm
  :param recurrent_transform: name of recurrent transform (attention mechanism variant)
  :param recurrent_transform_attribs: dict (or JSON string) of parameters for the recurrent transform
  :param attention_template: stored as attr; consumed by the recurrent transform
  :param attention_distance: distance measure name, e.g. 'l2'
  :param attention_step: attention step mode
  :param attention_beam: attention beam width
  :param attention_norm: normalization of attention weights, e.g. 'exp'
  :param attention_momentum: momentum mode; any value != 'none' forces attention_store
  :param attention_sharpening: sharpening factor
  :param attention_nbest: keep only n-best attention positions
  :param attention_store: store attention weights (self.attention); also forced by smooth/momentum
  :param attention_smooth: smooth stored attention weights
  :param attention_glimpse: number of attention glimpses
  :param attention_filters: number of attention filter channels
  :param attention_accumulator: accumulation op for attention, e.g. 'sum'
  :param attention_loss: weight of an auxiliary attention loss
  :param attention_bn: batch norm flag/parameter for attention -- TODO confirm semantics
  :param attention_lm: attention-based LM mode; != 'none' enables the LM weight setup below
  :param attention_ndec: attention decoder factor
  :param attention_memory: attention memory size
  :param attention_alnpts: alignment points parameter
  :param attention_epoch: epoch parameter for attention scheduling
  :param attention_segstep: segment step for attention scheduling
  :param attention_offset: offset for attention scheduling
  :param attention_method: scheduling method, e.g. "epoch"
  :param attention_scale: attention scale factor
  :param context: if > 0, feed a sliding window of `context` frames per time step
  :param base: list of layers whose outputs are the base of the attention mechanism; defaults to encoder
  :param aligner: aligner layer providing attention points for segmental attention
  :param lm: activate RNNLM (label feedback)
  :param force_lm: expect previous labels to be given during testing
  :param droplm: probability to take the expected output as predecessor instead of the real one when LM=true
  :param forward_weights_init: initializer spec string for forward weights
  :param bias_random_init_forget_shift: initialize forget gate bias of lstm networks with this value
  :param copy_weights_from_base: share parameters with base[0] instead of creating new ones
  :param segment_input: enable segmental processing (selects 'lstms'/'lstmps' units)
  :param join_states: with segment_input, take only every nstates-th attention entry of the source
  :param sample_segment: not read in this body -- TODO confirm
  """
  # A single 'length' source only provides the target index, not features:
  # drop it as a source and remember the index for n_dec computation.
  source_index = None
  if len(kwargs['sources']) == 1 and (kwargs['sources'][0].layer_class.endswith('length') or kwargs['sources'][0].layer_class.startswith('length')):
    kwargs['sources'] = []
    source_index = kwargs['index']
  unit_given = unit
  from Device import is_using_gpu
  # Auto-select the concrete LSTM implementation depending on device and features.
  if unit == 'lstm':  # auto selection
    if not is_using_gpu():
      unit = 'lstme'
    elif recurrent_transform == 'none' and (not lm or droplm == 0.0):
      unit = 'lstmp'
    else:
      unit = 'lstmc'
  elif unit in ("lstmc", "lstmp") and not is_using_gpu():
    unit = "lstme"
  # Segmental input needs the segment-aware unit variants.
  if segment_input:
    if is_using_gpu():
      unit = "lstmps"
    else:
      unit = "lstms"
  if n_out is None:
    assert encoder
    n_out = sum([enc.attrs['n_out'] for enc in encoder])
  kwargs.setdefault("n_out", n_out)
  if n_units is not None:
    assert n_units == n_out
  self.attention_weight = T.constant(1.,'float32')
  # 'length'/'signal' pseudo-layers carry no input features; treat as no sources.
  if len(kwargs['sources']) == 1 and kwargs['sources'][0].layer_class.startswith('length'):
    kwargs['sources'] = []
  elif len(kwargs['sources']) == 1 and kwargs['sources'][0].layer_class.startswith('signal'):
    kwargs['sources'] = []
  super(RecurrentUnitLayer, self).__init__(**kwargs)
  # Record all layer hyperparameters as attributes (serialized with the model).
  self.set_attr('from', ",".join([s.name for s in self.sources]) if self.sources else "null")
  self.set_attr('n_out', n_out)
  self.set_attr('unit', unit_given.encode("utf8"))
  self.set_attr('truncation', truncation)
  self.set_attr('sampling', sampling)
  self.set_attr('direction', direction)
  self.set_attr('lm', lm)
  self.set_attr('force_lm', force_lm)
  self.set_attr('droplm', droplm)
  if bias_random_init_forget_shift:
    self.set_attr("bias_random_init_forget_shift", bias_random_init_forget_shift)
  self.set_attr('attention_beam', attention_beam)
  self.set_attr('recurrent_transform', recurrent_transform.encode("utf8"))
  if isinstance(recurrent_transform_attribs, str):
    recurrent_transform_attribs = json.loads(recurrent_transform_attribs)
  if attention_template is not None:
    self.set_attr('attention_template', attention_template)
  self.set_attr('recurrent_transform_attribs', recurrent_transform_attribs)
  self.set_attr('attention_distance', attention_distance.encode("utf8"))
  self.set_attr('attention_step', attention_step.encode("utf8"))
  self.set_attr('attention_norm', attention_norm.encode("utf8"))
  self.set_attr('attention_sharpening', attention_sharpening)
  self.set_attr('attention_nbest', attention_nbest)
  # Smoothing and momentum both require the stored attention weights.
  attention_store = attention_store or attention_smooth or attention_momentum != 'none'
  self.set_attr('attention_store', attention_store)
  self.set_attr('attention_smooth', attention_smooth)
  self.set_attr('attention_momentum', attention_momentum.encode('utf8'))
  self.set_attr('attention_glimpse', attention_glimpse)
  self.set_attr('attention_filters', attention_filters)
  self.set_attr('attention_lm', attention_lm)
  self.set_attr('attention_bn', attention_bn)
  self.set_attr('attention_accumulator', attention_accumulator)
  self.set_attr('attention_ndec', attention_ndec)
  self.set_attr('attention_memory', attention_memory)
  self.set_attr('attention_loss', attention_loss)
  self.set_attr('n_dec', n_dec)
  self.set_attr('segment_input', segment_input)
  self.set_attr('attention_alnpts', attention_alnpts)
  self.set_attr('attention_epoch', attention_epoch)
  self.set_attr('attention_segstep', attention_segstep)
  self.set_attr('attention_offset', attention_offset)
  self.set_attr('attention_method', attention_method)
  self.set_attr('attention_scale', attention_scale)
  if segment_input:
    # Derive the inverse attention used by the segmental units from the source
    # layer's attention (or reuse it from a preceding RecurrentUnitLayer).
    if not self.eval_flag:
    #if self.eval_flag:
      if isinstance(self.sources[0],RecurrentUnitLayer):
        self.inv_att = self.sources[0].inv_att #NBT
      else:
        if not join_states:
          self.inv_att = self.sources[0].attention #NBT
        else:
          assert hasattr(self.sources[0], "nstates"), "source does not have number of states!"
          ns = self.sources[0].nstates
          # With joined states, only the last attention entry per state group counts.
          self.inv_att = self.sources[0].attention[(ns-1)::ns]
      # Shift by one step and zero the first frame; then collapse to a 2D (time, batch) mask.
      inv_att = T.roll(self.inv_att.dimshuffle(2, 1, 0),1,axis=0)#TBN
      inv_att = T.set_subtensor(inv_att[0],T.zeros((inv_att.shape[1],inv_att.shape[2])))
      inv_att = T.max(inv_att,axis=-1)
    else:
      # During evaluation no alignment is available; use an all-zero mask.
      inv_att = T.zeros((self.sources[0].output.shape[0],self.sources[0].output.shape[1]))
  if encoder and hasattr(encoder[0],'act'):
    self.set_attr('encoder', ",".join([e.name for e in encoder]))
  if base:
    self.set_attr('base', ",".join([b.name for b in base]))
  else:
    base = encoder
  self.base = base
  self.encoder = encoder
  if aligner:
    self.aligner = aligner
  self.set_attr('n_units', n_out)
  # Resolve the unit class by its (upper-cased) name, e.g. 'lstmp' -> LSTMP.
  # NOTE(review): eval on an internal config string; assumed trusted input.
  unit = eval(unit.upper())(**self.attrs)
  assert isinstance(unit, Unit)
  self.unit = unit
  kwargs.setdefault("n_out", unit.n_out)
  n_out = unit.n_out
  self.set_attr('n_out', unit.n_out)
  # Determine the number of decoder steps and the corresponding time index.
  if n_dec < 0:
    source_index = self.index
    n_dec *= -1
  if n_dec != 0:
    self.target_index = self.index
    if isinstance(n_dec,float):
      # Relative length: scale the per-sequence source lengths by n_dec and
      # build a fresh 0/1 index matrix of the resulting lengths.
      if not source_index:
        source_index = encoder[0].index if encoder else base[0].index
      lengths = T.cast(T.ceil(T.sum(T.cast(source_index,'float32'),axis=0) * n_dec), 'int32')
      idx, _ = theano.map(lambda l_i, l_m:T.concatenate([T.ones((l_i,),'int8'),T.zeros((l_m-l_i,),'int8')]), [lengths], [T.max(lengths)+1])
      self.index = idx.dimshuffle(1,0)[:-1]
      n_dec = T.cast(T.ceil(T.cast(source_index.shape[0],'float32') * numpy.float32(n_dec)),'int32')
    else:
      # Absolute length: full 1-index of n_dec steps.
      if encoder:
        self.index = encoder[0].index
      self.index = T.ones((n_dec,self.index.shape[1]),'int8')
  else:
    n_dec = self.index.shape[0]
  # initialize recurrent weights
  self.W_re = None
  if unit.n_re > 0:
    self.W_re = self.add_param(self.create_recurrent_weights(unit.n_units, unit.n_re, name="W_re_%s" % self.name))
  # initialize forward weights
  bias_init_value = self.create_bias(unit.n_in).get_value()
  if bias_random_init_forget_shift:
    assert unit.n_units * 4 == unit.n_in  # (input gate, forget gate, output gate, net input)
    # Shift only the forget-gate slice of the bias vector.
    bias_init_value[unit.n_units:2 * unit.n_units] += bias_random_init_forget_shift
  self.b.set_value(bias_init_value)
  if not forward_weights_init:
    forward_weights_init = "random_uniform(p_add=%i)" % unit.n_re
  else:
    self.set_attr('forward_weights_init', forward_weights_init)
  self.forward_weights_init = forward_weights_init
  self.W_in = []
  sample_mean, gamma = None, None
  if copy_weights_from_base:
    # Share all parameters with base[0] instead of registering our own.
    self.params = {}
    #self.W_re = self.add_param(base[0].W_re)
    #self.W_in = [ self.add_param(W) for W in base[0].W_in ]
    #self.b = self.add_param(base[0].b)
    self.W_re = base[0].W_re
    self.W_in = base[0].W_in
    self.b = base[0].b
    if self.attrs.get('batch_norm', False):
      sample_mean = base[0].sample_mean
      gamma = base[0].gamma
    #self.masks = base[0].masks
    #self.mass = base[0].mass
  else:
    for s in self.sources:
      W = self.create_forward_weights(s.attrs['n_out'], unit.n_in, name="W_in_%s_%s" % (s.name, self.name))
      self.W_in.append(self.add_param(W))
  # make input: z accumulates b + sum over sources of (masked) input projections.
  z = self.b
  for x_t, m, W in zip(self.sources, self.masks, self.W_in):
    if x_t.attrs['sparse']:
      # NOTE(review): out_dim is computed but never used below.
      if x_t.output.ndim == 3: out_dim = x_t.output.shape[2]
      elif x_t.output.ndim == 2: out_dim = 1
      else: assert False, x_t.output.ndim
      # Sparse input: indices select embedding rows of W directly.
      if x_t.output.ndim == 3:
        z += W[T.cast(x_t.output[:,:,0], 'int32')]
      elif x_t.output.ndim == 2:
        z += W[T.cast(x_t.output, 'int32')]
      else:
        assert False, x_t.output.ndim
    elif m is None:
      z += T.dot(x_t.output, W)
    else:
      z += self.dot(self.mass * m * x_t.output, W)
  #if self.attrs['batch_norm']:
  #  z = self.batch_norm(z, unit.n_in)
  num_batches = self.index.shape[1]
  self.num_batches = num_batches
  non_sequences = []
  # Optional RNNLM: create label-feedback weights and the drop-LM mask.
  if self.attrs['lm'] or attention_lm != 'none':
    if not 'target' in self.attrs:
      self.attrs['target'] = 'classes'
    if self.attrs['droplm'] > 0.0 or not (self.train_flag or force_lm):
      if copy_weights_from_base:
        self.W_lm_in = base[0].W_lm_in
        self.b_lm_in = base[0].b_lm_in
      else:
        l = sqrt(6.) / sqrt(unit.n_out + self.y_in[self.attrs['target']].n_out)  # Glorot-style bound
        values = numpy.asarray(self.rng.uniform(low=-l, high=l, size=(unit.n_out, self.y_in[self.attrs['target']].n_out)), dtype=theano.config.floatX)
        self.W_lm_in = self.add_param(self.shared(value=values, borrow=True, name = "W_lm_in_"+self.name))
        self.b_lm_in = self.create_bias(self.y_in[self.attrs['target']].n_out, 'b_lm_in')
    l = sqrt(6.) / sqrt(unit.n_in + self.y_in[self.attrs['target']].n_out)
    values = numpy.asarray(self.rng.uniform(low=-l, high=l, size=(self.y_in[self.attrs['target']].n_out, unit.n_in)), dtype=theano.config.floatX)
    if copy_weights_from_base:
      self.W_lm_out = base[0].W_lm_out
    else:
      self.W_lm_out = self.add_param(self.shared(value=values, borrow=True, name = "W_lm_out_"+self.name))
    if self.attrs['droplm'] == 0.0 and (self.train_flag or force_lm):
      # Always feed back the true labels (mask degenerates to a constant).
      self.lmmask = 1
      #if recurrent_transform != 'none':
      #  recurrent_transform = recurrent_transform[:-3]
    elif self.attrs['droplm'] < 1.0 and (self.train_flag or force_lm):
      # Bernoulli mask: with prob. (1 - droplm) take the true label per frame.
      from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams
      srng = RandomStreams(self.rng.randint(1234) + 1)
      self.lmmask = T.cast(srng.binomial(n=1, p=1.0 - self.attrs['droplm'], size=self.index.shape), theano.config.floatX).dimshuffle(0,1,'x').repeat(unit.n_in,axis=2)
    else:
      # Never feed back true labels (pure prediction feedback).
      self.lmmask = T.zeros_like(self.index, dtype='float32').dimshuffle(0,1,'x').repeat(unit.n_in,axis=2)
  if recurrent_transform == 'input': # attention is just a sequence dependent bias (lstmp compatible)
    src = []
    src_names = []
    n_in = 0
    for e in base:
      #src_base = [ s for s in e.sources if s.name not in src_names ]
      #src_names += [ s.name for s in e.sources ]
      src_base = [ e ]
      src_names += [e.name]
      src += [s.output for s in src_base]
      n_in += sum([s.attrs['n_out'] for s in src_base])
    self.xc = T.concatenate(src, axis=2)
    l = sqrt(6.) / sqrt(self.attrs['n_out'] + n_in)
    values = numpy.asarray(self.rng.uniform(low=-l, high=l, size=(n_in, 1)), dtype=theano.config.floatX)
    self.W_att_xc = self.add_param(self.shared(value=values, borrow=True, name = "W_att_xc"))
    values = numpy.asarray(self.rng.uniform(low=-l, high=l, size=(n_in, self.attrs['n_out'] * 4)), dtype=theano.config.floatX)
    self.W_att_in = self.add_param(self.shared(value=values, borrow=True, name = "W_att_in"))
    # Static attention: softmax-like weighting over time, collapsed to one bias.
    zz = T.exp(T.tanh(T.dot(self.xc, self.W_att_xc))) # TB1
    self.zc = T.dot(T.sum(self.xc * (zz / T.sum(zz, axis=0, keepdims=True)).repeat(self.xc.shape[2],axis=2), axis=0, keepdims=True), self.W_att_in)
    # Handled entirely here; downstream code sees no recurrent transform.
    recurrent_transform = 'none'
  elif recurrent_transform == 'attention_align':
    # Alignment-predicting attention: parameters for a skip/transition model.
    max_skip = base[0].attrs['max_skip']
    values = numpy.zeros((max_skip,), dtype=theano.config.floatX)
    self.T_b = self.add_param(self.shared(value=values, borrow=True, name="T_b"), name="T_b")
    l = sqrt(6.) / sqrt(self.attrs['n_out'] + max_skip)
    values = numpy.asarray(self.rng.uniform(
      low=-l, high=l, size=(self.attrs['n_out'], max_skip)), dtype=theano.config.floatX)
    self.T_W = self.add_param(self.shared(value=values, borrow=True, name="T_W"), name="T_W")
    # Expected source position per output step, differenced to per-step skips.
    y_t = T.dot(self.base[0].attention, T.arange(self.base[0].output.shape[0], dtype='float32'))  # NB
    y_t = T.concatenate([T.zeros_like(y_t[:1]), y_t], axis=0)  # (N+1)B
    y_t = y_t[1:] - y_t[:-1]  # NB
    self.y_t = y_t # T.clip(y_t,numpy.float32(0),numpy.float32(max_skip - 1))
    # NOTE(review): immediately overwritten with the base layer's backtrace.
    self.y_t = T.cast(self.base[0].backtrace,'float32')
  elif recurrent_transform == 'attention_segment':
    assert aligner.attention, "Segment-wise attention requires attention points!"
  # Instantiate the recurrent transform by registered name and attach it.
  recurrent_transform_inst = RecurrentTransform.transform_classes[recurrent_transform](layer=self)
  assert isinstance(recurrent_transform_inst, RecurrentTransform.RecurrentTransformBase)
  unit.recurrent_transform = recurrent_transform_inst
  self.recurrent_transform = recurrent_transform_inst
  # scan over sequence
  for s in range(self.attrs['sampling']):
    index = self.index[s::self.attrs['sampling']]
    if context > 0:
      # Sliding-window input: every step sees a window of `context` frames,
      # realized by remapping (time, batch) -> (context, time*batch).
      from TheanoUtil import context_batched
      n_batches = z.shape[1]
      time, batch, dim = z.shape[0], z.shape[1], z.shape[2]
      #z = context_batched(z[::direction or 1], window=context)[::direction or 1] # TB(CD)
      from theano.ifelse import ifelse
      def context_window(idx, x_in, i_in):
        # Extract one window; zero-pad the index for positions before frame 0.
        x_out = x_in[idx:idx + context]
        x_out = x_out.dimshuffle('x',1,0,2).reshape((1, batch, dim * context))
        i_out = i_in[idx:idx+1].repeat(context, axis=0)
        i_out = ifelse(T.lt(idx,context),T.set_subtensor(i_out[:context - idx],numpy.int8(0)),i_out).reshape((1, batch * context))
        return x_out, i_out
      z = z[::direction or 1]
      i = index[::direction or 1]
      out, _ = theano.map(context_window, sequences = [T.arange(z.shape[0])], non_sequences = [T.concatenate([T.zeros((context - 1,z.shape[1],z.shape[2]),dtype='float32'),z],axis=0), i])
      z = out[0][::direction or 1]
      i = out[1][::direction or 1] # T(BC)
      direction = 1
      z = z.reshape((time * batch, context * dim)) # (TB)(CD)
      z = z.reshape((time * batch, context, dim)).dimshuffle(1,0,2) # C(TB)D
      i = i.reshape((time, context, batch)).dimshuffle(1,0,2).reshape((context, time * batch))
      index = i
      num_batches = time * batch
    sequences = z
    sources = self.sources
    # Initial hidden state: from the encoder's last frame if given, else zeros.
    if encoder:
      if recurrent_transform == "attention_segment":
        if hasattr(encoder[0],'act'):
          outputs_info = [T.concatenate([e.act[i][-1] for e in encoder], axis=1) for i in range(unit.n_act)]
        else:
          # NOTE(review): outputs_info is referenced here before assignment in
          # this branch (the defining line is commented out) -- likely latent bug.
          # outputs_info = [ T.concatenate([e[i] for e in encoder], axis=1) for i in range(unit.n_act) ]
          outputs_info[0] = self.aligner.output[-1]
      elif hasattr(encoder[0],'act'):
        outputs_info = [ T.concatenate([e.act[i][-1] for e in encoder], axis=1) for i in range(unit.n_act) ]
      else:
        outputs_info = [ T.concatenate([e[i] for e in encoder], axis=1) for i in range(unit.n_act) ]
      # Broadcast z to the decoder length; add the static attention bias if used.
      sequences += T.alloc(numpy.cast[theano.config.floatX](0), n_dec, num_batches, unit.n_in) + (self.zc if self.attrs['recurrent_transform'] == 'input' else numpy.float32(0))
    else:
      outputs_info = [ T.alloc(numpy.cast[theano.config.floatX](0), num_batches, unit.n_units) for a in range(unit.n_act) ]
    if self.attrs['lm'] and self.attrs['droplm'] == 0.0 and (self.train_flag or force_lm):
      # Pure true-label feedback: add the embedded target labels to the input.
      if self.network.y[self.attrs['target']].ndim == 3:
        sequences += T.dot(self.network.y[self.attrs['target']],self.W_lm_out)
      else:
        y = self.y_in[self.attrs['target']].flatten()
        sequences += self.W_lm_out[y].reshape((index.shape[0],index.shape[1],unit.n_in))
    # If nothing was added, sequences is still (identically) the bias variable;
    # broadcast it to the full decoder shape.
    if sequences == self.b:
      sequences += T.alloc(numpy.cast[theano.config.floatX](0), n_dec, num_batches, unit.n_in) + (self.zc if self.attrs['recurrent_transform'] == 'input' else numpy.float32(0))
    if unit.recurrent_transform:
      outputs_info += unit.recurrent_transform.get_sorted_state_vars_initial()
    index_f = T.cast(index, theano.config.floatX)
    unit.set_parent(self)
    # Run the actual scan of the unit over the (sub-sampled) sequence.
    if segment_input:
      outputs = unit.scan_seg(x=sources,
                              z=sequences[s::self.attrs['sampling']],
                              att = inv_att,
                              non_sequences=non_sequences,
                              i=index_f,
                              outputs_info=outputs_info,
                              W_re=self.W_re,
                              W_in=self.W_in,
                              b=self.b,
                              go_backwards=direction == -1,
                              truncate_gradient=self.attrs['truncation'])
    else:
      outputs = unit.scan(x=sources,
                          z=sequences[s::self.attrs['sampling']],
                          non_sequences=non_sequences,
                          i=index_f,
                          outputs_info=outputs_info,
                          W_re=self.W_re,
                          W_in=self.W_in,
                          b=self.b,
                          go_backwards=direction == -1,
                          truncate_gradient=self.attrs['truncation'])
    if not isinstance(outputs, list):
      outputs = [outputs]
    if outputs:
      outputs[0].name = "%s.act[0]" % self.name
    if context > 0:
      # Undo the (context, time*batch) remapping: keep the last window frame.
      for i in range(len(outputs)):
        outputs[i] = outputs[i][-1].reshape((outputs[i].shape[1]//n_batches,n_batches,outputs[i].shape[2]))
    if unit.recurrent_transform:
      unit.recurrent_transform_state_var_seqs = outputs[-len(unit.recurrent_transform.state_vars):]
    if self.attrs['sampling'] > 1:
      # Interleave the sub-sampled scans back into the full-length activation.
      if s == 0:
        self.act = [ T.alloc(numpy.cast['float32'](0), self.index.shape[0], self.index.shape[1], n_out) for act in outputs ]
      self.act = [ T.set_subtensor(tot[s::self.attrs['sampling']], act) for tot,act in zip(self.act, outputs) ]
    else:
      self.act = outputs[:unit.n_act]
      if len(outputs) > unit.n_act:
        self.aux = outputs[unit.n_act:]
  if self.attrs['attention_store']:
    # Collect the attention state vars ('att_*') and shift them by one step so
    # attention[t] belongs to output frame t; direction-aware.
    self.attention = [ self.aux[i].dimshuffle(0,2,1) for i,v in enumerate(sorted(unit.recurrent_transform.state_vars.keys())) if v.startswith('att_') ] # NBT
    for i in range(len(self.attention)):
      vec = T.eye(self.attention[i].shape[2], 1, -direction * (self.attention[i].shape[2] - 1))
      last = vec.dimshuffle(1, 'x', 0).repeat(self.index.shape[1], axis=1)
      self.attention[i] = T.concatenate([self.attention[i][1:],last],axis=0)[::direction]
  self.cost_val = numpy.float32(0)
  if recurrent_transform == 'attention_align':
    # Build the alignment output and an auxiliary cross-entropy skip loss from
    # the stored backtrace; this branch fully defines the layer and returns.
    back = T.ceil(self.aux[sorted(unit.recurrent_transform.state_vars.keys()).index('t')])
    def make_output(base, yout, trace, length):
      # Re-order one sequence along its (reversed) trace and zero-pad to full length.
      length = T.cast(length, 'int32')
      idx = T.cast(trace[:length][::-1],'int32')
      x_out = T.concatenate([base[idx],T.zeros((self.index.shape[0] + 1 - length, base.shape[1]), 'float32')],axis=0)
      y_out = T.concatenate([yout[idx,T.arange(length)],T.zeros((self.index.shape[0] + 1 - length, ), 'float32')],axis=0)
      return x_out, y_out
    output, _ = theano.map(make_output,
                           sequences = [base[0].output.dimshuffle(1,0,2),
                                        self.y_t.dimshuffle(1,2,0),
                                        back.dimshuffle(1,0),
                                        T.sum(self.index,axis=0,dtype='float32')])
    self.attrs['n_out'] = base[0].attrs['n_out']
    self.params.update(unit.params)
    self.output = output[0].dimshuffle(1,0,2)[:-1]
    # Auxiliary skip-prediction loss and error count.
    z = T.dot(self.act[0], self.T_W)[:-1] + self.T_b
    z = z.reshape((z.shape[0] * z.shape[1], z.shape[2]))
    idx = (self.index[1:].flatten() > 0).nonzero()
    idy = (self.index[1:][::-1].flatten() > 0).nonzero()
    y_out = T.cast(output[1],'int32').dimshuffle(1, 0)[:-1].flatten()
    nll, _ = T.nnet.crossentropy_softmax_1hot(x=z[idx], y_idx=y_out[idy])
    self.cost_val = T.sum(nll)
    recog = T.argmax(z[idx], axis=1)
    real = y_out[idy]
    self.errors = lambda: T.sum(T.neq(recog, real))
    return
    # NOTE(review): everything below this return is unreachable dead code,
    # apparently kept from earlier variants of this branch.
    back += T.arange(self.index.shape[1], dtype='float32') * T.cast(self.base[0].index.shape[0], 'float32')
    idx = (self.index[:-1].flatten() > 0).nonzero()
    idx = T.cast(back[::-1].flatten()[idx],'int32')
    x_out = base[0].output
    #x_out = x_out.dimshuffle(1,0,2).reshape((x_out.shape[0] * x_out.shape[1], x_out.shape[2]))[idx]
    #x_out = x_out.reshape((self.index.shape[1], self.index.shape[0] - 1, x_out.shape[1])).dimshuffle(1,0,2)
    x_out = x_out.reshape((x_out.shape[0] * x_out.shape[1], x_out.shape[2]))[idx]
    x_out = x_out.reshape((self.index.shape[0] - 1, self.index.shape[1], x_out.shape[1]))
    self.output = T.concatenate([x_out, base[0].output[1:]],axis=0)
    self.attrs['n_out'] = base[0].attrs['n_out']
    self.params.update(unit.params)
    return
    skips = T.dot(T.nnet.softmax(z), T.arange(z.shape[1], dtype='float32')).reshape(self.index[1:].shape)
    shift = T.arange(self.index.shape[1], dtype='float32') * T.cast(self.base[0].index.shape[0], 'float32')
    skips = T.concatenate([T.zeros_like(self.y_t[:1]),self.y_t[:-1]],axis=0)
    idx = shift + T.cumsum(skips, axis=0)
    idx = T.cast(idx[:-1].flatten(),'int32')
    #idx = (idx.flatten() > 0).nonzero()
    #idx = base[0].attention.flatten()
    x_out = base[0].output[::-1]
    x_out = x_out.reshape((x_out.shape[0] * x_out.shape[1], x_out.shape[2]))[idx]
    x_out = x_out.reshape((self.index.shape[0], self.index.shape[1], x_out.shape[1]))
    self.output = T.concatenate([base[0].output[-1:], x_out], axis=0)[::-1]
    self.attrs['n_out'] = base[0].attrs['n_out']
    self.params.update(unit.params)
    return
  if recurrent_transform == 'batch_norm':
    # Keep a running estimate of the recurrent input mean for batch norm.
    self.params['sample_mean_batch_norm'].custom_update = T.dot(T.mean(self.act[0],axis=[0,1]),self.W_re)
    self.params['sample_mean_batch_norm'].custom_update_normalized = True
  self.make_output(self.act[0][::direction or 1], sample_mean=sample_mean, gamma=gamma)
  self.params.update(unit.params)
def __init__(self, n_out=None, n_units=None, direction=1, truncation=-1, sampling=1, encoder=None, unit='lstm', n_dec=0, attention="none", recurrent_transform="none", recurrent_transform_attribs="{}", attention_template=128, attention_distance='l2', attention_step="linear", attention_beam=0, attention_norm="exp", attention_momentum="none", attention_sharpening=1.0, attention_nbest=0, attention_store=False, attention_smooth=False, attention_align=False, attention_glimpse=1, attention_filters=1, attention_accumulator='sum', attention_loss=0, attention_bn=0, attention_lm='none', attention_ndec=1, attention_memory=0, base=None, lm=False, force_lm=False, droplm=1.0, forward_weights_init=None, bias_random_init_forget_shift=0.0, copy_weights_from_base=False, **kwargs): """ :param n_out: number of cells :param n_units: used when initialized via Network.from_hdf_model_topology :param direction: process sequence in forward (1) or backward (-1) direction :param truncation: gradient truncation :param sampling: scan every nth frame only :param encoder: list of encoder layers used as initalization for the hidden state :param unit: cell type (one of 'lstm', 'vanilla', 'gru', 'sru') :param n_dec: absolute number of steps to unfold the network if integer, else relative number of steps from encoder :param recurrent_transform: name of recurrent transform :param recurrent_transform_attribs: dictionary containing parameters for a recurrent transform :param attention_template: :param attention_distance: :param attention_step: :param attention_beam: :param attention_norm: :param attention_sharpening: :param attention_nbest: :param attention_store: :param attention_align: :param attention_glimpse: :param attention_lm: :param base: list of layers which outputs are considered as based during attention mechanisms :param lm: activate RNNLM :param force_lm: expect previous labels to be given during testing :param droplm: probability to take the expected output as predecessor instead of the 
real one when LM=true :param bias_random_init_forget_shift: initialize forget gate bias of lstm networks with this value """ source_index = None if len(kwargs['sources']) == 1 and ( kwargs['sources'][0].layer_class.endswith('length') or kwargs['sources'][0].layer_class.startswith('length')): kwargs['sources'] = [] source_index = kwargs['index'] unit_given = unit from Device import is_using_gpu if unit == 'lstm': # auto selection if not is_using_gpu(): unit = 'lstme' elif recurrent_transform == 'none' and (not lm or droplm == 0.0): unit = 'lstmp' else: unit = 'lstmc' elif unit in ("lstmc", "lstmp") and not is_using_gpu(): unit = "lstme" if n_out is None: assert encoder n_out = sum([enc.attrs['n_out'] for enc in encoder]) kwargs.setdefault("n_out", n_out) if n_units is not None: assert n_units == n_out self.attention_weight = T.constant(1., 'float32') if len( kwargs['sources'] ) == 1 and kwargs['sources'][0].layer_class.startswith('length'): kwargs['sources'] = [] elif len( kwargs['sources'] ) == 1 and kwargs['sources'][0].layer_class.startswith('signal'): kwargs['sources'] = [] super(RecurrentUnitLayer, self).__init__(**kwargs) self.set_attr( 'from', ",".join([s.name for s in self.sources]) if self.sources else "null") self.set_attr('n_out', n_out) self.set_attr('unit', unit_given.encode("utf8")) self.set_attr('truncation', truncation) self.set_attr('sampling', sampling) self.set_attr('direction', direction) self.set_attr('lm', lm) self.set_attr('force_lm', force_lm) self.set_attr('droplm', droplm) if bias_random_init_forget_shift: self.set_attr("bias_random_init_forget_shift", bias_random_init_forget_shift) self.set_attr('attention_beam', attention_beam) self.set_attr('recurrent_transform', recurrent_transform.encode("utf8")) if isinstance(recurrent_transform_attribs, str): recurrent_transform_attribs = json.loads( recurrent_transform_attribs) if attention_template is not None: self.set_attr('attention_template', attention_template) 
self.set_attr('recurrent_transform_attribs', recurrent_transform_attribs) self.set_attr('attention_distance', attention_distance.encode("utf8")) self.set_attr('attention_step', attention_step.encode("utf8")) self.set_attr('attention_norm', attention_norm.encode("utf8")) self.set_attr('attention_sharpening', attention_sharpening) self.set_attr('attention_nbest', attention_nbest) attention_store = attention_store or attention_smooth or attention_momentum != 'none' self.set_attr('attention_store', attention_store) self.set_attr('attention_smooth', attention_smooth) self.set_attr('attention_momentum', attention_momentum.encode('utf8')) self.set_attr('attention_align', attention_align) self.set_attr('attention_glimpse', attention_glimpse) self.set_attr('attention_filters', attention_filters) self.set_attr('attention_lm', attention_lm) self.set_attr('attention_bn', attention_bn) self.set_attr('attention_accumulator', attention_accumulator) self.set_attr('attention_ndec', attention_ndec) self.set_attr('attention_memory', attention_memory) self.set_attr('attention_loss', attention_loss) self.set_attr('n_dec', n_dec) if encoder and hasattr(encoder[0], 'act'): self.set_attr('encoder', ",".join([e.name for e in encoder])) if base: self.set_attr('base', ",".join([b.name for b in base])) else: base = encoder self.base = base self.set_attr('n_units', n_out) unit = eval(unit.upper())(**self.attrs) assert isinstance(unit, Unit) self.unit = unit kwargs.setdefault("n_out", unit.n_out) n_out = unit.n_out self.set_attr('n_out', unit.n_out) if n_dec != 0: self.target_index = self.index if isinstance(n_dec, float): if not source_index: source_index = encoder[0].index if encoder else base[ 0].index lengths = T.cast( T.ceil( T.sum(T.cast(source_index, 'float32'), axis=0) * n_dec), 'int32') idx, _ = theano.map( lambda l_i, l_m: T.concatenate([ T.ones((l_i, ), 'int8'), T.zeros((l_m - l_i, ), 'int8') ]), [lengths], [T.max(lengths) + 1]) self.index = idx.dimshuffle(1, 0)[:-1] n_dec = T.cast( 
T.ceil( T.cast(source_index.shape[0], 'float32') * numpy.float32(n_dec)), 'int32') else: if encoder: self.index = encoder[0].index self.index = T.ones((n_dec, self.index.shape[1]), 'int8') else: n_dec = self.index.shape[0] # initialize recurrent weights self.W_re = None if unit.n_re > 0: self.W_re = self.add_param( self.create_recurrent_weights(unit.n_units, unit.n_re, name="W_re_%s" % self.name)) # initialize forward weights bias_init_value = self.create_bias(unit.n_in).get_value() if bias_random_init_forget_shift: assert unit.n_units * 4 == unit.n_in # (input gate, forget gate, output gate, net input) bias_init_value[unit.n_units:2 * unit.n_units] += bias_random_init_forget_shift self.b.set_value(bias_init_value) if not forward_weights_init: forward_weights_init = "random_uniform(p_add=%i)" % unit.n_re else: self.set_attr('forward_weights_init', forward_weights_init) self.forward_weights_init = forward_weights_init self.W_in = [] if copy_weights_from_base: self.params = {} #self.W_re = self.add_param(base[0].W_re) #self.W_in = [ self.add_param(W) for W in base[0].W_in ] #self.b = self.add_param(base[0].b) self.W_re = base[0].W_re self.W_in = base[0].W_in self.b = base[0].b self.masks = base[0].masks self.mass = base[0].mass else: for s in self.sources: W = self.create_forward_weights(s.attrs['n_out'], unit.n_in, name="W_in_%s_%s" % (s.name, self.name)) self.W_in.append(self.add_param(W)) # make input z = self.b for x_t, m, W in zip(self.sources, self.masks, self.W_in): if x_t.attrs['sparse']: if x_t.output.ndim == 3: out_dim = x_t.output.shape[2] elif x_t.output.ndim == 2: out_dim = 1 else: assert False, x_t.output.ndim if x_t.output.ndim == 3: z += W[T.cast(x_t.output[:, :, 0], 'int32')] elif x_t.output.ndim == 2: z += W[T.cast(x_t.output, 'int32')] else: assert False, x_t.output.ndim elif m is None: z += T.dot(x_t.output, W) else: z += self.dot(self.mass * m * x_t.output, W) #if self.attrs['batch_norm']: # z = self.batch_norm(z, unit.n_in) num_batches = 
self.index.shape[1] self.num_batches = num_batches non_sequences = [] if self.attrs['lm'] or attention_lm != 'none': if not 'target' in self.attrs: self.attrs['target'] = 'classes' if self.attrs['droplm'] > 0.0 or not (self.train_flag or force_lm): if copy_weights_from_base: self.W_lm_in = base[0].W_lm_in self.b_lm_in = base[0].b_lm_in else: l = sqrt(6.) / sqrt(unit.n_out + self.y_in[self.attrs['target']].n_out) values = numpy.asarray(self.rng.uniform( low=-l, high=l, size=(unit.n_out, self.y_in[self.attrs['target']].n_out)), dtype=theano.config.floatX) self.W_lm_in = self.add_param( self.shared(value=values, borrow=True, name="W_lm_in_" + self.name)) self.b_lm_in = self.create_bias( self.y_in[self.attrs['target']].n_out, 'b_lm_in') l = sqrt(6.) / sqrt(unit.n_in + self.y_in[self.attrs['target']].n_out) values = numpy.asarray(self.rng.uniform( low=-l, high=l, size=(self.y_in[self.attrs['target']].n_out, unit.n_in)), dtype=theano.config.floatX) if copy_weights_from_base: self.W_lm_out = base[0].W_lm_out else: self.W_lm_out = self.add_param( self.shared(value=values, borrow=True, name="W_lm_out_" + self.name)) if self.attrs['droplm'] == 0.0 and (self.train_flag or force_lm): self.lmmask = 1 #if recurrent_transform != 'none': # recurrent_transform = recurrent_transform[:-3] elif self.attrs['droplm'] < 1.0 and (self.train_flag or force_lm): from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams srng = RandomStreams(self.rng.randint(1234) + 1) self.lmmask = T.cast( srng.binomial(n=1, p=1.0 - self.attrs['droplm'], size=self.index.shape), theano.config.floatX).dimshuffle(0, 1, 'x').repeat(unit.n_in, axis=2) else: self.lmmask = T.zeros_like(self.index, dtype='float32').dimshuffle( 0, 1, 'x').repeat(unit.n_in, axis=2) if recurrent_transform == 'input': # attention is just a sequence dependent bias (lstmp compatible) src = [] src_names = [] n_in = 0 for e in base: #src_base = [ s for s in e.sources if s.name not in src_names ] #src_names += [ s.name for s in 
e.sources ] src_base = [e] src_names += [e.name] src += [s.output for s in src_base] n_in += sum([s.attrs['n_out'] for s in src_base]) self.xc = T.concatenate(src, axis=2) l = sqrt(6.) / sqrt(self.attrs['n_out'] + n_in) values = numpy.asarray(self.rng.uniform(low=-l, high=l, size=(n_in, 1)), dtype=theano.config.floatX) self.W_att_xc = self.add_param( self.shared(value=values, borrow=True, name="W_att_xc")) values = numpy.asarray(self.rng.uniform( low=-l, high=l, size=(n_in, self.attrs['n_out'] * 4)), dtype=theano.config.floatX) self.W_att_in = self.add_param( self.shared(value=values, borrow=True, name="W_att_in")) zz = T.exp(T.tanh(T.dot(self.xc, self.W_att_xc))) # TB1 self.zc = T.dot( T.sum(self.xc * (zz / T.sum(zz, axis=0, keepdims=True)).repeat( self.xc.shape[2], axis=2), axis=0, keepdims=True), self.W_att_in) recurrent_transform = 'none' recurrent_transform_inst = RecurrentTransform.transform_classes[ recurrent_transform](layer=self) assert isinstance(recurrent_transform_inst, RecurrentTransform.RecurrentTransformBase) unit.recurrent_transform = recurrent_transform_inst self.recurrent_transform = recurrent_transform_inst # scan over sequence for s in range(self.attrs['sampling']): index = self.index[s::self.attrs['sampling']] sequences = z sources = self.sources if encoder: if hasattr(encoder[0], 'act'): outputs_info = [ T.concatenate([e.act[i][-1] for e in encoder], axis=1) for i in range(unit.n_act) ] else: outputs_info = [ T.concatenate([e[i] for e in encoder], axis=1) for i in range(unit.n_act) ] sequences += T.alloc( numpy.cast[theano.config.floatX](0), n_dec, num_batches, unit.n_in) + (self.zc if self.attrs['recurrent_transform'] == 'input' else numpy.float32(0)) else: outputs_info = [ T.alloc(numpy.cast[theano.config.floatX](0), num_batches, unit.n_units) for a in range(unit.n_act) ] if self.attrs['lm'] and self.attrs['droplm'] == 0.0 and ( self.train_flag or force_lm): if self.network.y[self.attrs['target']].ndim == 3: sequences += 
T.dot(self.network.y[self.attrs['target']], self.W_lm_out) else: y = self.y_in[self.attrs['target']].flatten() sequences += self.W_lm_out[y].reshape( (index.shape[0], index.shape[1], unit.n_in)) if sequences == self.b: sequences += T.alloc( numpy.cast[theano.config.floatX](0), n_dec, num_batches, unit.n_in) + (self.zc if self.attrs['recurrent_transform'] == 'input' else numpy.float32(0)) if unit.recurrent_transform: outputs_info += unit.recurrent_transform.get_sorted_state_vars_initial( ) index_f = T.cast(index, theano.config.floatX) unit.set_parent(self) outputs = unit.scan(x=sources, z=sequences[s::self.attrs['sampling']], non_sequences=non_sequences, i=index_f, outputs_info=outputs_info, W_re=self.W_re, W_in=self.W_in, b=self.b, go_backwards=direction == -1, truncate_gradient=self.attrs['truncation']) if not isinstance(outputs, list): outputs = [outputs] if outputs: outputs[0].name = "%s.act[0]" % self.name if unit.recurrent_transform: unit.recurrent_transform_state_var_seqs = outputs[ -len(unit.recurrent_transform.state_vars):] if self.attrs['sampling'] > 1: if s == 0: self.act = [ T.alloc(numpy.cast['float32'](0), self.index.shape[0], self.index.shape[1], n_out) for act in outputs ] self.act = [ T.set_subtensor(tot[s::self.attrs['sampling']], act) for tot, act in zip(self.act, outputs) ] else: self.act = outputs[:unit.n_act] if len(outputs) > unit.n_act: self.aux = outputs[unit.n_act:] if self.attrs['attention_store']: self.attention = [ self.aux[i].dimshuffle(0, 2, 1) for i, v in enumerate( sorted(unit.recurrent_transform.state_vars.keys())) if v.startswith('att_') ] # NBT for i in range(len(self.attention)): vec = T.eye(self.attention[i].shape[2], 1, -direction * (self.attention[i].shape[2] - 1)) last = vec.dimshuffle(1, 'x', 0).repeat(self.index.shape[1], axis=1) self.attention[i] = T.concatenate( [self.attention[i][1:], last], axis=0)[::direction] if self.attrs['attention_align']: bp = [ self.aux[i] for i, v in enumerate( 
sorted(unit.recurrent_transform.state_vars.keys())) if v.startswith('K_') ] def backtrace(k, i_p): return i_p - k[i_p, T.arange(k.shape[1], dtype='int32')] self.alignment = [] for K in bp: # K: NTB aln, _ = theano.scan( backtrace, sequences=[T.cast(K, 'int32').dimshuffle(1, 0, 2)], outputs_info=[ T.cast(self.index.shape[0] - 1, 'int32') + T.zeros( (K.shape[2], ), 'int32') ]) aln = theano.printing.Print("aln")(aln) self.alignment.append(aln) # TB if recurrent_transform == 'batch_norm': self.params['sample_mean_batch_norm'].custom_update = T.dot( T.mean(self.act[0], axis=[0, 1]), self.W_re) self.params[ 'sample_mean_batch_norm'].custom_update_normalized = True self.make_output(self.act[0][::direction or 1]) self.params.update(unit.params)
def __init__(self, n_out, n_units = None, direction = 1, truncation = -1, sampling = 1, encoder = None,
             unit = 'lstm', n_dec = 0, attention = "none", recurrent_transform = "none",
             recurrent_transform_attribs = "{}", attention_template = 128, attention_distance = 'l2',
             attention_step = "linear", attention_beam = 0, attention_norm = "exp", attention_momentum = "none",
             attention_sharpening = 1.0, attention_nbest = 0, attention_store = False, attention_smooth = False,
             attention_align = False, attention_glimpse = 1, attention_filters = 1, attention_accumulator = 'sum',
             attention_bn = 0, attention_lm = 'none', attention_ndec = 1, base = None, lm = False,
             force_lm = False, droplm = 1.0, forward_weights_init=None, bias_random_init_forget_shift=0.0,
             **kwargs):
  """
  Recurrent layer driven by a recurrent ``Unit`` (LSTM variants, GRU, ...), optionally combined with
  an attention mechanism over the ``base`` layers and/or an integrated RNN language model.

  :param n_out: number of cells
  :param n_units: used when initialized via Network.from_hdf_model_topology
  :param direction: process sequence in forward (1) or backward (-1) direction
  :param truncation: gradient truncation
  :param sampling: scan every nth frame only
  :param encoder: list of encoder layers used as initialization for the hidden state
  :param unit: cell type (one of 'lstm', 'vanilla', 'gru', 'sru')
  :param n_dec: absolute number of steps to unfold the network if integer, else relative number of steps from encoder
  :param recurrent_transform: name of recurrent transform
  :param recurrent_transform_attribs: dictionary containing parameters for a recurrent transform
  :param attention_template:
  :param attention_distance:
  :param attention_step:
  :param attention_beam:
  :param attention_norm:
  :param attention_sharpening:
  :param attention_nbest:
  :param attention_store:
  :param attention_align:
  :param attention_glimpse:
  :param attention_lm:
  :param base: list of layers which outputs are considered as based during attention mechanisms
  :param lm: activate RNNLM
  :param force_lm: expect previous labels to be given during testing
  :param droplm: probability to take the expected output as predecessor instead of the real one when LM=true
  :param bias_random_init_forget_shift: initialize forget gate bias of lstm networks with this value
  """
  # A lone "length" pseudo-source carries only the target index, not real input data.
  source_index = None
  if len(kwargs['sources']) == 1 and (kwargs['sources'][0].layer_class.endswith('length') or kwargs['sources'][0].layer_class.startswith('length')):
    kwargs['sources'] = []
    source_index = kwargs['index']
  unit_given = unit  # remember the requested name; 'unit' is remapped to a concrete implementation below
  from Device import is_using_gpu
  if unit == 'lstm':  # auto selection of the concrete LSTM implementation
    if not is_using_gpu():
      unit = 'lstme'
    elif recurrent_transform == 'none' and (not lm or droplm == 0.0):
      unit = 'lstmp'  # plain scan-based LSTM suffices without transform/LM sampling
    else:
      unit = 'lstmc'
  elif unit in ("lstmc", "lstmp") and not is_using_gpu():
    unit = "lstme"  # lstmc/lstmp need the GPU; fall back on CPU
  kwargs.setdefault("n_out", n_out)
  if n_units is not None:
    assert n_units == n_out
  self.attention_weight = T.constant(1.,'float32')
  # "length"/"signal" pseudo-sources are control inputs only; drop them before base-class init.
  if len(kwargs['sources']) == 1 and kwargs['sources'][0].layer_class.startswith('length'):
    kwargs['sources'] = []
  elif len(kwargs['sources']) == 1 and kwargs['sources'][0].layer_class.startswith('signal'):
    kwargs['sources'] = []
  super(RecurrentUnitLayer, self).__init__(**kwargs)
  # Record all construction arguments as layer attributes (used for model serialization).
  self.set_attr('from', ",".join([s.name for s in self.sources]) if self.sources else "null")
  self.set_attr('n_out', n_out)
  self.set_attr('unit', unit_given.encode("utf8"))
  self.set_attr('truncation', truncation)
  self.set_attr('sampling', sampling)
  self.set_attr('direction', direction)
  self.set_attr('lm', lm)
  self.set_attr('force_lm', force_lm)
  self.set_attr('droplm', droplm)
  if bias_random_init_forget_shift:
    self.set_attr("bias_random_init_forget_shift", bias_random_init_forget_shift)
  self.set_attr('attention_beam', attention_beam)
  self.set_attr('recurrent_transform', recurrent_transform.encode("utf8"))
  if isinstance(recurrent_transform_attribs, str):
    recurrent_transform_attribs = json.loads(recurrent_transform_attribs)
  if attention_template is not None:
    self.set_attr('attention_template', attention_template)
  self.set_attr('recurrent_transform_attribs', recurrent_transform_attribs)
  self.set_attr('attention_distance', attention_distance.encode("utf8"))
  self.set_attr('attention_step', attention_step.encode("utf8"))
  self.set_attr('attention_norm', attention_norm.encode("utf8"))
  self.set_attr('attention_sharpening', attention_sharpening)
  self.set_attr('attention_nbest', attention_nbest)
  # Smoothing and momentum both require the stored attention sequences.
  attention_store = attention_store or attention_smooth or attention_momentum != 'none'
  self.set_attr('attention_store', attention_store)
  self.set_attr('attention_smooth', attention_smooth)
  self.set_attr('attention_momentum', attention_momentum.encode('utf8'))
  self.set_attr('attention_align', attention_align)
  self.set_attr('attention_glimpse', attention_glimpse)
  self.set_attr('attention_filters', attention_filters)
  self.set_attr('attention_lm', attention_lm)
  self.set_attr('attention_bn', attention_bn)
  self.set_attr('attention_accumulator', attention_accumulator)
  self.set_attr('attention_ndec', attention_ndec)
  self.set_attr('n_dec', n_dec)
  if encoder and hasattr(encoder[0],'act'):
    self.set_attr('encoder', ",".join([e.name for e in encoder]))
  if base:
    self.set_attr('base', ",".join([b.name for b in base]))
  else:
    base = encoder  # attention falls back to the encoder layers when no base is given
  self.base = base
  self.set_attr('n_units', n_out)
  # Instantiate the unit class by name, e.g. 'lstmp' -> LSTMP(**self.attrs).
  # NOTE(review): eval on the unit name — safe only because 'unit' comes from the
  # config-controlled remapping above, never from untrusted input.
  unit = eval(unit.upper())(**self.attrs)
  assert isinstance(unit, Unit)
  self.unit = unit
  kwargs.setdefault("n_out", unit.n_out)
  n_out = unit.n_out
  self.set_attr('n_out', unit.n_out)
  if n_dec != 0:
    # Decoder mode: unfold for n_dec steps instead of following the input length.
    self.target_index = self.index
    if isinstance(n_dec,float):
      # Relative unfolding: n_dec is a factor on the source sequence lengths.
      if not source_index:
        source_index = encoder[0].index if encoder else base[0].index
      lengths = T.cast(T.ceil(T.sum(T.cast(source_index,'float32'),axis=0) * n_dec), 'int32')
      # Build a fresh index matrix: per sequence, l_i ones followed by zeros up to the max length.
      idx, _ = theano.map(lambda l_i, l_m:T.concatenate([T.ones((l_i,),'int8'),T.zeros((l_m-l_i,),'int8')]), [lengths], [T.max(lengths)+1])
      self.index = idx.dimshuffle(1,0)[:-1]
      n_dec = T.cast(T.ceil(T.cast(source_index.shape[0],'float32') * numpy.float32(n_dec)),'int32')
    else:
      # Absolute unfolding: all sequences are active for exactly n_dec steps.
      self.index = encoder[0].index
      self.index = T.ones((n_dec,self.index.shape[1]),'int64')  # TODO: this gives a graph replacement error for int8
  else:
    n_dec = self.index.shape[0]
  # initialize recurrent weights
  self.W_re = None
  if unit.n_re > 0:
    self.W_re = self.add_param(self.create_recurrent_weights(unit.n_units, unit.n_re, name="W_re_%s" % self.name))
  # initialize forward weights
  bias_init_value = self.create_bias(unit.n_in).get_value()
  if bias_random_init_forget_shift:
    assert unit.n_units * 4 == unit.n_in  # (input gate, forget gate, output gate, net input)
    # The forget-gate slice is the second quarter of the bias vector.
    bias_init_value[unit.n_units:2 * unit.n_units] += bias_random_init_forget_shift
  self.b.set_value(bias_init_value)
  if not forward_weights_init:
    forward_weights_init = "random_uniform(p_add=%i)" % unit.n_re
  else:
    self.set_attr('forward_weights_init', forward_weights_init)
  self.forward_weights_init = forward_weights_init
  self.W_in = []
  # One input projection matrix per source layer.
  for s in self.sources:
    W = self.create_forward_weights(s.attrs['n_out'], unit.n_in, name="W_in_%s_%s" % (s.name, self.name))
    self.W_in.append(self.add_param(W))
  # make input: z accumulates bias + projected sources (sparse sources index into W directly)
  z = self.b
  for x_t, m, W in zip(self.sources, self.masks, self.W_in):
    if x_t.attrs['sparse']:
      if x_t.output.ndim == 3: out_dim = x_t.output.shape[2]  # NOTE(review): out_dim is computed but unused here
      elif x_t.output.ndim == 2: out_dim = 1
      else: assert False, x_t.output.ndim
      if x_t.output.ndim == 3: z += W[T.cast(x_t.output[:,:,0], 'int32')]
      elif x_t.output.ndim == 2: z += W[T.cast(x_t.output, 'int32')]
      else: assert False, x_t.output.ndim
    elif m is None:
      z += T.dot(x_t.output, W)
    else:
      # Masked (dropout) source: rescale by self.mass to keep the expectation.
      z += self.dot(self.mass * m * x_t.output, W)
  #if self.attrs['batch_norm']:
  #  z = self.batch_norm(z, unit.n_in)
  num_batches = self.index.shape[1]
  self.num_batches = num_batches
  non_sequences = []
  if self.attrs['lm'] or attention_lm != 'none':
    # Integrated RNN LM: W_lm_in projects the hidden state to label scores,
    # W_lm_out embeds the (predicted or true) previous label back into the unit input.
    if not 'target' in self.attrs: self.attrs['target'] = 'classes'
    if self.attrs['droplm'] > 0.0 or not (self.train_flag or force_lm):
      l = sqrt(6.) / sqrt(unit.n_out + self.y_in[self.attrs['target']].n_out)  # Glorot-style init bound
      values = numpy.asarray(self.rng.uniform(low=-l, high=l, size=(unit.n_out, self.y_in[self.attrs['target']].n_out)), dtype=theano.config.floatX)
      self.W_lm_in = self.add_param(self.shared(value=values, borrow=True, name = "W_lm_in_"+self.name))
      self.b_lm_in = self.create_bias(self.y_in[self.attrs['target']].n_out, 'b_lm_in')
    l = sqrt(6.) / sqrt(unit.n_in + self.y_in[self.attrs['target']].n_out)
    values = numpy.asarray(self.rng.uniform(low=-l, high=l, size=(self.y_in[self.attrs['target']].n_out, unit.n_in)), dtype=theano.config.floatX)
    self.W_lm_out = self.add_param(self.shared(value=values, borrow=True, name = "W_lm_out_"+self.name))
    if self.attrs['droplm'] == 0.0 and (self.train_flag or force_lm):
      # Always feed the true previous label (no sampling) -> mask degenerates to a constant.
      self.lmmask = 1
      #if recurrent_transform != 'none':
      #  recurrent_transform = recurrent_transform[:-3]
    elif self.attrs['droplm'] < 1.0 and (self.train_flag or force_lm):
      # Scheduled-sampling style mask: keep the true label with prob. 1-droplm per frame.
      from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams
      srng = RandomStreams(self.rng.randint(1234) + 1)
      self.lmmask = T.cast(srng.binomial(n=1, p=1.0 - self.attrs['droplm'], size=self.index.shape), theano.config.floatX).dimshuffle(0,1,'x').repeat(unit.n_in,axis=2)
    else:
      # Pure prediction feedback: never feed the true label.
      self.lmmask = T.zeros_like(self.index, dtype='float32').dimshuffle(0,1,'x').repeat(unit.n_in,axis=2)

  if recurrent_transform == 'input': # attention is just a sequence dependent bias (lstmp compatible)
    # Precompute a fixed attention summary self.zc over the base outputs; the
    # transform is then disabled so a plain unit can be used.
    src = []
    src_names = []
    n_in = 0
    for e in base:
      src_base = [ s for s in e.sources if s.name not in src_names ]
      src_names += [ s.name for s in e.sources ]
      src += [s.output for s in src_base]
      n_in += sum([s.attrs['n_out'] for s in src_base])
    self.xc = T.concatenate(src, axis=2)
    l = sqrt(6.) / sqrt(self.attrs['n_out'] + n_in)
    values = numpy.asarray(self.rng.uniform(low=-l, high=l, size=(n_in, 1)), dtype=theano.config.floatX)
    self.W_att_xc = self.add_param(self.shared(value=values, borrow=True, name = "W_att_xc"))
    values = numpy.asarray(self.rng.uniform(low=-l, high=l, size=(n_in, self.attrs['n_out'] * 4)), dtype=theano.config.floatX)
    self.W_att_in = self.add_param(self.shared(value=values, borrow=True, name = "W_att_in"))
    zz = T.exp(T.tanh(T.dot(self.xc, self.W_att_xc))) # TB1
    # Normalize zz over time and use it to weight the base outputs, then project into the unit input space.
    self.zc = T.dot(T.sum(self.xc * (zz / T.sum(zz, axis=0, keepdims=True)).repeat(self.xc.shape[2],axis=2), axis=0, keepdims=True), self.W_att_in)
    recurrent_transform = 'none'
  recurrent_transform_inst = RecurrentTransform.transform_classes[recurrent_transform](layer=self)
  assert isinstance(recurrent_transform_inst, RecurrentTransform.RecurrentTransformBase)
  unit.recurrent_transform = recurrent_transform_inst
  self.recurrent_transform = recurrent_transform_inst

  # scan over sequence; with sampling > 1, each pass s handles frames s, s+sampling, ...
  for s in range(self.attrs['sampling']):
    index = self.index[s::self.attrs['sampling']]
    sequences = z
    sources = self.sources
    if encoder:
      # Initialize the hidden state from the last encoder activations.
      if hasattr(encoder[0],'act'):
        outputs_info = [ T.concatenate([e.act[i][-1] for e in encoder], axis=1) for i in range(unit.n_act) ]
      else:
        outputs_info = [ T.concatenate([e[i] for e in encoder], axis=1) for i in range(unit.n_act) ]
      # Broadcast the (possibly frame-independent) input over the n_dec decoder steps.
      sequences += T.alloc(numpy.cast[theano.config.floatX](0), n_dec, num_batches, unit.n_in) + (self.zc if self.attrs['recurrent_transform'] == 'input' else 0)
    else:
      outputs_info = [ T.alloc(numpy.cast[theano.config.floatX](0), num_batches, unit.n_units) for a in range(unit.n_act) ]

    if self.attrs['lm'] and self.attrs['droplm'] == 0.0 and (self.train_flag or force_lm):
      # Feed the true previous label embedding directly into the input sequence.
      if self.network.y[self.attrs['target']].ndim == 3:
        sequences += T.dot(self.network.y[self.attrs['target']],self.W_lm_out)
      else:
        y = self.y_in[self.attrs['target']].flatten()
        sequences += self.W_lm_out[y].reshape((index.shape[0],index.shape[1],unit.n_in))

    # NOTE(review): Theano variable identity check — true only when no source/LM term
    # was added above, i.e. sequences is still exactly the bias variable.
    if sequences == self.b:
      sequences += T.alloc(numpy.cast[theano.config.floatX](0), n_dec, num_batches, unit.n_in) + (self.zc if self.attrs['recurrent_transform'] == 'input' else 0)

    if unit.recurrent_transform:
      outputs_info += unit.recurrent_transform.get_sorted_state_vars_initial()

    index_f = T.cast(index, theano.config.floatX)
    unit.set_parent(self)
    # Run the recurrence; outputs holds the unit activations followed by any transform state vars.
    outputs = unit.scan(x=sources,
                        z=sequences[s::self.attrs['sampling']],
                        non_sequences=non_sequences,
                        i=index_f,
                        outputs_info=outputs_info,
                        W_re=self.W_re,
                        W_in=self.W_in,
                        b=self.b,
                        go_backwards=direction == -1,
                        truncate_gradient=self.attrs['truncation'])

    if not isinstance(outputs, list):
      outputs = [outputs]
    if outputs:
      outputs[0].name = "%s.act[0]" % self.name

    if unit.recurrent_transform:
      unit.recurrent_transform_state_var_seqs = outputs[-len(unit.recurrent_transform.state_vars):]

    if self.attrs['sampling'] > 1:
      # Scatter this pass's frames back into the full-length activation tensors.
      if s == 0:
        self.act = [ T.alloc(numpy.cast['float32'](0), self.index.shape[0], self.index.shape[1], n_out) for act in outputs ]
      self.act = [ T.set_subtensor(tot[s::self.attrs['sampling']], act) for tot,act in zip(self.act, outputs) ]
    else:
      self.act = outputs[:unit.n_act]
      if len(outputs) > unit.n_act:
        self.aux = outputs[unit.n_act:]
  if self.attrs['attention_store']:
    # Collect the attention weight sequences from the transform's auxiliary outputs.
    self.attention = [ self.aux[i].dimshuffle(0,2,1) for i,v in enumerate(sorted(unit.recurrent_transform.state_vars.keys())) if v.startswith('att_') ] # NBT
    for i in range(len(self.attention)):
      # Shift by one step and pad with a one-hot vector at the sequence end.
      vec = T.eye(self.attention[i].shape[2], 1, -direction * (self.attention[i].shape[2] - 1))
      last = vec.dimshuffle(1,'x', 0).repeat(self.index.shape[1], axis=1)
      self.attention[i] = T.concatenate([self.attention[i][1:],last],axis=0)[::direction]

  if self.attrs['attention_align']:
    # Reconstruct hard alignments by backtracing the stored skip decisions (K_* state vars).
    bp = [ self.aux[i] for i,v in enumerate(sorted(unit.recurrent_transform.state_vars.keys())) if v.startswith('K_') ]
    def backtrace(k,i_p):
      return i_p - k[i_p,T.arange(k.shape[1],dtype='int32')]
    self.alignment = []
    for K in bp: # K: NTB
      aln, _ = theano.scan(backtrace, sequences=[T.cast(K,'int32').dimshuffle(1,0,2)],
                           outputs_info=[T.cast(self.index.shape[0] - 1,'int32') + T.zeros((K.shape[2],),'int32')])
      aln = theano.printing.Print("aln")(aln)  # NOTE(review): debug print left in the graph
      self.alignment.append(aln) # TB

  if recurrent_transform == 'batch_norm':
    # Update the running mean from the mean activation projected through W_re.
    self.params['sample_mean_batch_norm'].custom_update = T.dot(T.mean(self.act[0],axis=[0,1]),self.W_re)
    self.params['sample_mean_batch_norm'].custom_update_normalized = True

  # Re-reverse for backward layers so the output is always in forward time order.
  self.make_output(self.act[0][::direction or 1])
  self.params.update(unit.params)
def __init__(self, n_out = None, n_units = None, direction = 1, truncation = -1, sampling = 1, encoder = None, unit = 'lstm', n_dec = 0, attention = "none", recurrent_transform = "none", recurrent_transform_attribs = "{}", attention_template = 128, attention_distance = 'l2', attention_step = "linear", attention_beam = 0, attention_norm = "exp", attention_momentum = "none", attention_sharpening = 1.0, attention_nbest = 0, attention_store = False, attention_smooth = False, attention_glimpse = 1, attention_filters = 1, attention_accumulator = 'sum', attention_loss = 0, attention_bn = 0, attention_lm = 'none', attention_ndec = 1, attention_memory = 0, attention_alnpts = 0, attention_epoch = 1, attention_segstep=0.01, attention_offset=0.95, attention_method="epoch", attention_scale=10, context=-1, base = None, aligner = None, lm = False, force_lm = False, droplm = 1.0, forward_weights_init=None, bias_random_init_forget_shift=0.0, copy_weights_from_base=False, segment_input=False, join_states=False, state_memory=False, sample_segment=None, **kwargs): """ :param n_out: number of cells :param n_units: used when initialized via Network.from_hdf_model_topology :param direction: process sequence in forward (1) or backward (-1) direction :param truncation: gradient truncation :param sampling: scan every nth frame only :param encoder: list of encoder layers used as initalization for the hidden state :param unit: cell type (one of 'lstm', 'vanilla', 'gru', 'sru') :param n_dec: absolute number of steps to unfold the network if integer, else relative number of steps from encoder :param recurrent_transform: name of recurrent transform :param recurrent_transform_attribs: dictionary containing parameters for a recurrent transform :param attention_template: :param attention_distance: :param attention_step: :param attention_beam: :param attention_norm: :param attention_sharpening: :param attention_nbest: :param attention_store: :param attention_align: :param attention_glimpse: :param 
attention_lm: :param base: list of layers which outputs are considered as based during attention mechanisms :param lm: activate RNNLM :param force_lm: expect previous labels to be given during testing :param droplm: probability to take the expected output as predecessor instead of the real one when LM=true :param bias_random_init_forget_shift: initialize forget gate bias of lstm networks with this value """ source_index = None if len(kwargs['sources']) == 1 and (kwargs['sources'][0].layer_class.endswith('length') or kwargs['sources'][0].layer_class.startswith('length')): kwargs['sources'] = [] source_index = kwargs['index'] unit_given = unit from Device import is_using_gpu if unit == 'lstm': # auto selection if not is_using_gpu(): unit = 'lstme' elif recurrent_transform == 'none' and (not lm or droplm == 0.0): unit = 'lstmp' else: unit = 'lstmc' elif unit in ("lstmc", "lstmp") and not is_using_gpu(): unit = "lstme" if segment_input: if is_using_gpu(): unit = "lstmps" else: unit = "lstms" if n_out is None: assert encoder n_out = sum([enc.attrs['n_out'] for enc in encoder]) kwargs.setdefault("n_out", n_out) if n_units is not None: assert n_units == n_out self.attention_weight = T.constant(1.,'float32') if len(kwargs['sources']) == 1 and kwargs['sources'][0].layer_class.startswith('length'): kwargs['sources'] = [] elif len(kwargs['sources']) == 1 and kwargs['sources'][0].layer_class.startswith('signal'): kwargs['sources'] = [] super(RecurrentUnitLayer, self).__init__(**kwargs) self.set_attr('from', ",".join([s.name for s in self.sources]) if self.sources else "null") self.set_attr('n_out', n_out) self.set_attr('unit', unit_given.encode("utf8")) self.set_attr('truncation', truncation) self.set_attr('sampling', sampling) self.set_attr('direction', direction) self.set_attr('lm', lm) self.set_attr('force_lm', force_lm) self.set_attr('droplm', droplm) if bias_random_init_forget_shift: self.set_attr("bias_random_init_forget_shift", bias_random_init_forget_shift) 
self.set_attr('attention_beam', attention_beam) self.set_attr('recurrent_transform', recurrent_transform.encode("utf8")) if isinstance(recurrent_transform_attribs, str): recurrent_transform_attribs = json.loads(recurrent_transform_attribs) if attention_template is not None: self.set_attr('attention_template', attention_template) self.set_attr('recurrent_transform_attribs', recurrent_transform_attribs) self.set_attr('attention_distance', attention_distance.encode("utf8")) self.set_attr('attention_step', attention_step.encode("utf8")) self.set_attr('attention_norm', attention_norm.encode("utf8")) self.set_attr('attention_sharpening', attention_sharpening) self.set_attr('attention_nbest', attention_nbest) attention_store = attention_store or attention_smooth or attention_momentum != 'none' self.set_attr('attention_store', attention_store) self.set_attr('attention_smooth', attention_smooth) self.set_attr('attention_momentum', attention_momentum.encode('utf8')) self.set_attr('attention_glimpse', attention_glimpse) self.set_attr('attention_filters', attention_filters) self.set_attr('attention_lm', attention_lm) self.set_attr('attention_bn', attention_bn) self.set_attr('attention_accumulator', attention_accumulator) self.set_attr('attention_ndec', attention_ndec) self.set_attr('attention_memory', attention_memory) self.set_attr('attention_loss', attention_loss) self.set_attr('n_dec', n_dec) self.set_attr('segment_input', segment_input) self.set_attr('attention_alnpts', attention_alnpts) self.set_attr('attention_epoch', attention_epoch) self.set_attr('attention_segstep', attention_segstep) self.set_attr('attention_offset', attention_offset) self.set_attr('attention_method', attention_method) self.set_attr('attention_scale', attention_scale) if segment_input: if not self.eval_flag: #if self.eval_flag: if isinstance(self.sources[0],RecurrentUnitLayer): self.inv_att = self.sources[0].inv_att #NBT else: if not join_states: self.inv_att = self.sources[0].attention #NBT else: 
assert hasattr(self.sources[0], "nstates"), "source does not have number of states!" ns = self.sources[0].nstates self.inv_att = self.sources[0].attention[(ns-1)::ns] inv_att = T.roll(self.inv_att.dimshuffle(2, 1, 0),1,axis=0)#TBN inv_att = T.set_subtensor(inv_att[0],T.zeros((inv_att.shape[1],inv_att.shape[2]))) inv_att = T.max(inv_att,axis=-1) else: inv_att = T.zeros((self.sources[0].output.shape[0],self.sources[0].output.shape[1])) if encoder and hasattr(encoder[0],'act'): self.set_attr('encoder', ",".join([e.name for e in encoder])) if base: self.set_attr('base', ",".join([b.name for b in base])) else: base = encoder self.base = base self.encoder = encoder if aligner: self.aligner = aligner self.set_attr('n_units', n_out) unit = eval(unit.upper())(**self.attrs) assert isinstance(unit, Unit) self.unit = unit kwargs.setdefault("n_out", unit.n_out) n_out = unit.n_out self.set_attr('n_out', unit.n_out) if n_dec < 0: source_index = self.index n_dec *= -1 if n_dec != 0: self.target_index = self.index if isinstance(n_dec,float): if not source_index: source_index = encoder[0].index if encoder else base[0].index lengths = T.cast(T.ceil(T.sum(T.cast(source_index,'float32'),axis=0) * n_dec), 'int32') idx, _ = theano.map(lambda l_i, l_m:T.concatenate([T.ones((l_i,),'int8'),T.zeros((l_m-l_i,),'int8')]), [lengths], [T.max(lengths)+1]) self.index = idx.dimshuffle(1,0)[:-1] n_dec = T.cast(T.ceil(T.cast(source_index.shape[0],'float32') * numpy.float32(n_dec)),'int32') else: if encoder: self.index = encoder[0].index self.index = T.ones((n_dec,self.index.shape[1]),'int8') else: n_dec = self.index.shape[0] # initialize recurrent weights self.W_re = None if unit.n_re > 0: self.W_re = self.add_param(self.create_recurrent_weights(unit.n_units, unit.n_re, name="W_re_%s" % self.name)) # initialize forward weights bias_init_value = self.create_bias(unit.n_in).get_value() if bias_random_init_forget_shift: assert unit.n_units * 4 == unit.n_in # (input gate, forget gate, output gate, net 
input) bias_init_value[unit.n_units:2 * unit.n_units] += bias_random_init_forget_shift self.b.set_value(bias_init_value) if not forward_weights_init: forward_weights_init = "random_uniform(p_add=%i)" % unit.n_re else: self.set_attr('forward_weights_init', forward_weights_init) self.forward_weights_init = forward_weights_init self.W_in = [] sample_mean, gamma = None, None if copy_weights_from_base: self.params = {} #self.W_re = self.add_param(base[0].W_re) #self.W_in = [ self.add_param(W) for W in base[0].W_in ] #self.b = self.add_param(base[0].b) self.W_re = base[0].W_re self.W_in = base[0].W_in self.b = base[0].b if self.attrs.get('batch_norm', False): sample_mean = base[0].sample_mean gamma = base[0].gamma #self.masks = base[0].masks #self.mass = base[0].mass else: for s in self.sources: W = self.create_forward_weights(s.attrs['n_out'], unit.n_in, name="W_in_%s_%s" % (s.name, self.name)) self.W_in.append(self.add_param(W)) # make input z = self.b for x_t, m, W in zip(self.sources, self.masks, self.W_in): if x_t.attrs['sparse']: if x_t.output.ndim == 3: out_dim = x_t.output.shape[2] elif x_t.output.ndim == 2: out_dim = 1 else: assert False, x_t.output.ndim if x_t.output.ndim == 3: z += W[T.cast(x_t.output[:,:,0], 'int32')] elif x_t.output.ndim == 2: z += W[T.cast(x_t.output, 'int32')] else: assert False, x_t.output.ndim elif m is None: z += T.dot(x_t.output, W) else: z += self.dot(self.mass * m * x_t.output, W) #if self.attrs['batch_norm']: # z = self.batch_norm(z, unit.n_in) num_batches = self.index.shape[1] self.num_batches = num_batches non_sequences = [] if self.attrs['lm'] or attention_lm != 'none': if not 'target' in self.attrs: self.attrs['target'] = 'classes' if self.attrs['droplm'] > 0.0 or not (self.train_flag or force_lm): if copy_weights_from_base: self.W_lm_in = base[0].W_lm_in self.b_lm_in = base[0].b_lm_in else: l = sqrt(6.) 
/ sqrt(unit.n_out + self.y_in[self.attrs['target']].n_out) values = numpy.asarray(self.rng.uniform(low=-l, high=l, size=(unit.n_out, self.y_in[self.attrs['target']].n_out)), dtype=theano.config.floatX) self.W_lm_in = self.add_param(self.shared(value=values, borrow=True, name = "W_lm_in_"+self.name)) self.b_lm_in = self.create_bias(self.y_in[self.attrs['target']].n_out, 'b_lm_in') l = sqrt(6.) / sqrt(unit.n_in + self.y_in[self.attrs['target']].n_out) values = numpy.asarray(self.rng.uniform(low=-l, high=l, size=(self.y_in[self.attrs['target']].n_out, unit.n_in)), dtype=theano.config.floatX) if copy_weights_from_base: self.W_lm_out = base[0].W_lm_out else: self.W_lm_out = self.add_param(self.shared(value=values, borrow=True, name = "W_lm_out_"+self.name)) if self.attrs['droplm'] == 0.0 and (self.train_flag or force_lm): self.lmmask = 1 #if recurrent_transform != 'none': # recurrent_transform = recurrent_transform[:-3] elif self.attrs['droplm'] < 1.0 and (self.train_flag or force_lm): from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams srng = RandomStreams(self.rng.randint(1234) + 1) self.lmmask = T.cast(srng.binomial(n=1, p=1.0 - self.attrs['droplm'], size=self.index.shape), theano.config.floatX).dimshuffle(0,1,'x').repeat(unit.n_in,axis=2) else: self.lmmask = T.zeros_like(self.index, dtype='float32').dimshuffle(0,1,'x').repeat(unit.n_in,axis=2) if recurrent_transform == 'input': # attention is just a sequence dependent bias (lstmp compatible) src = [] src_names = [] n_in = 0 for e in base: #src_base = [ s for s in e.sources if s.name not in src_names ] #src_names += [ s.name for s in e.sources ] src_base = [ e ] src_names += [e.name] src += [s.output for s in src_base] n_in += sum([s.attrs['n_out'] for s in src_base]) self.xc = T.concatenate(src, axis=2) l = sqrt(6.) 
# --- NOTE(review): this span is the MIDDLE of a very large recurrent-layer
# __init__ (signature visible in the file head). The text is whitespace-
# flattened (many original statements per physical line), so comments are
# inserted only at statement/bracket boundaries; every original token is
# left untouched. Covers: the tail of the recurrent-transform parameter
# setup (attention_align / attention_segment branches), the scan over the
# input sequence, and the post-scan output / attention / alignment handling.
# Assumes T = theano.tensor and self.rng is a numpy RandomState — both are
# bound outside this view; confirm against the file header.
/ sqrt(self.attrs['n_out'] + n_in) values = numpy.asarray(self.rng.uniform(low=-l, high=l, size=(n_in, 1)), dtype=theano.config.floatX) self.W_att_xc = self.add_param(self.shared(value=values, borrow=True, name = "W_att_xc")) values = numpy.asarray(self.rng.uniform(low=-l, high=l, size=(n_in, self.attrs['n_out'] * 4)), dtype=theano.config.floatX) self.W_att_in = self.add_param(self.shared(value=values, borrow=True, name = "W_att_in")) zz = T.exp(T.tanh(T.dot(self.xc, self.W_att_xc))) # TB1 self.zc = T.dot(T.sum(self.xc * (zz / T.sum(zz, axis=0, keepdims=True)).repeat(self.xc.shape[2],axis=2), axis=0, keepdims=True), self.W_att_in) recurrent_transform = 'none' elif recurrent_transform == 'attention_align': max_skip = base[0].attrs['max_skip'] values = numpy.zeros((max_skip,), dtype=theano.config.floatX) self.T_b = self.add_param(self.shared(value=values, borrow=True, name="T_b"), name="T_b") l = sqrt(6.) / sqrt(self.attrs['n_out'] + max_skip) values = numpy.asarray(self.rng.uniform( low=-l, high=l, size=(self.attrs['n_out'], max_skip)), dtype=theano.config.floatX) self.T_W = self.add_param(self.shared(value=values, borrow=True, name="T_W"), name="T_W") y_t = T.dot(self.base[0].attention, T.arange(self.base[0].output.shape[0], dtype='float32')) # NB y_t = T.concatenate([T.zeros_like(y_t[:1]), y_t], axis=0) # (N+1)B y_t = y_t[1:] - y_t[:-1] # NB self.y_t = y_t # T.clip(y_t,numpy.float32(0),numpy.float32(max_skip - 1)) self.y_t = T.cast(self.base[0].backtrace,'float32') elif recurrent_transform == 'attention_segment': assert aligner.attention, "Segment-wise attention requires attention points!"
# Instantiate the chosen RecurrentTransform subclass (looked up by name in
# RecurrentTransform.transform_classes) and attach it to both the cell
# ("unit") and this layer; `state_memory *= self.train_flag` zeroes the
# persistent-state feature outside training. The loop below scans the
# sequence, optionally stacking a sliding context window of `context`
# frames per step (the dangling `context_window` closure builds one
# (1, batch, dim*context) slice plus its padded index row per time step).
recurrent_transform_inst = RecurrentTransform.transform_classes[recurrent_transform](layer=self) assert isinstance(recurrent_transform_inst, RecurrentTransform.RecurrentTransformBase) unit.recurrent_transform = recurrent_transform_inst self.recurrent_transform = recurrent_transform_inst state_memory *= self.train_flag # scan over sequence for s in range(self.attrs['sampling']): index = self.index[s::self.attrs['sampling']] if context > 0: from TheanoUtil import context_batched n_batches = z.shape[1] time, batch, dim = z.shape[0], z.shape[1], z.shape[2] #z = context_batched(z[::direction or 1], window=context)[::direction or 1] # TB(CD) from theano.ifelse import ifelse def context_window(idx, x_in, i_in): x_out = x_in[idx:idx + context] x_out = x_out.dimshuffle('x',1,0,2).reshape((1, batch, dim * context)) i_out = i_in[idx:idx+1].repeat(context, axis=0) i_out = ifelse(T.lt(idx,context),T.set_subtensor(i_out[:context - idx],numpy.int8(0)),i_out).reshape((1, batch * context)) return x_out, i_out z = z[::direction or 1] i = index[::direction or 1] out, _ = theano.map(context_window, sequences = [T.arange(z.shape[0])], non_sequences = [T.concatenate([T.zeros((context - 1,z.shape[1],z.shape[2]),dtype='float32'),z],axis=0), i]) z = out[0][::direction or 1] i = out[1][::direction or 1] # T(BC) direction = 1 z = z.reshape((time * batch, context * dim)) # (TB)(CD) z = z.reshape((time * batch, context, dim)).dimshuffle(1,0,2) # C(TB)D i = i.reshape((time * batch, context)).dimshuffle(1,0) # C(TB) index = i num_batches = time * batch sequences = z sources = self.sources if state_memory: self.init_state = [ self.add_param(self.shared(numpy.zeros((state_memory or 1, unit.n_units), dtype='float32'), name='init_%d_%s' % (a, self.name))) for a in range(unit.n_act)] # has to be initialized for train and test if encoder: if recurrent_transform == "attention_segment": if hasattr(encoder[0],'act'): outputs_info = [T.concatenate([e.act[i][-1] for e in encoder], axis=1) for i in
# outputs_info: initial hidden/cell state per activation for the scan —
# taken from the encoders' final time step (e.act[i][-1] concatenated over
# encoders) when an encoder is given, from the persistent init_state when
# state_memory is set, else zeros via T.alloc. With a language model
# (lm + droplm == 0.0, train or force_lm) the target embedding W_lm_out is
# added into the scan input `sequences`.
range(unit.n_act)] else: # outputs_info = [ T.concatenate([e[i] for e in encoder], axis=1) for i in range(unit.n_act) ] outputs_info[0] = self.aligner.output[-1] elif hasattr(encoder[0],'act'): outputs_info = [ T.concatenate([e.act[i][-1] for e in encoder], axis=1) for i in range(unit.n_act) ] else: outputs_info = [ T.concatenate([e[i] for e in encoder], axis=1) for i in range(unit.n_act) ] sequences += T.alloc(numpy.cast[theano.config.floatX](0), n_dec, num_batches, unit.n_in) + (self.zc if self.attrs['recurrent_transform'] == 'input' else numpy.float32(0)) elif state_memory: outputs_info = self.init_state else: outputs_info = [ T.alloc(numpy.cast[theano.config.floatX](0), num_batches, unit.n_units) for a in range(unit.n_act) ] if self.attrs['lm'] and self.attrs['droplm'] == 0.0 and (self.train_flag or force_lm): if self.network.y[self.attrs['target']].ndim == 3: sequences += T.dot(self.network.y[self.attrs['target']],self.W_lm_out) else: y = self.y_in[self.attrs['target']].flatten() sequences += self.W_lm_out[y].reshape((index.shape[0],index.shape[1],unit.n_in)) if sequences == self.b: sequences += T.alloc(numpy.cast[theano.config.floatX](0), n_dec, num_batches, unit.n_in) + (self.zc if self.attrs['recurrent_transform'] == 'input' else numpy.float32(0)) if unit.recurrent_transform: outputs_info += unit.recurrent_transform.get_sorted_state_vars_initial() index_f = T.cast(index, theano.config.floatX) unit.set_parent(self) if segment_input: outputs = unit.scan_seg(x=sources, z=sequences[s::self.attrs['sampling']], att = inv_att, non_sequences=non_sequences, i=index_f, outputs_info=outputs_info, W_re=self.W_re, W_in=self.W_in, b=self.b, go_backwards=direction == -1, truncate_gradient=self.attrs['truncation']) else: outputs = unit.scan(x=sources, z=sequences[s::self.attrs['sampling']], non_sequences=non_sequences, i=index_f, outputs_info=outputs_info, W_re=self.W_re, W_in=self.W_in, b=self.b, go_backwards=direction == -1, truncate_gradient=self.attrs['truncation']) if
not isinstance(outputs, list): outputs = [outputs] if outputs: outputs[0].name = "%s.act[0]" % self.name if context > 0: for i in range(len(outputs)): outputs[i] = outputs[i][-1].reshape((outputs[i].shape[1]//n_batches,n_batches,outputs[i].shape[2])) if unit.recurrent_transform: unit.recurrent_transform_state_var_seqs = outputs[-len(unit.recurrent_transform.state_vars):] if self.attrs['sampling'] > 1: if s == 0: self.act = [ T.alloc(numpy.cast['float32'](0), self.index.shape[0], self.index.shape[1], n_out) for act in outputs ] self.act = [ T.set_subtensor(tot[s::self.attrs['sampling']], act) for tot,act in zip(self.act, outputs) ] else: self.act = outputs[:unit.n_act] if len(outputs) > unit.n_act: self.aux = outputs[unit.n_act:] if state_memory: for i in range(len(self.act)): self.init_state[i].live_update = self.act[i][-1] if self.attrs['attention_store']: self.attention = [ self.aux[i].dimshuffle(0,2,1) for i,v in enumerate(sorted(unit.recurrent_transform.state_vars.keys())) if v.startswith('att_') ] # NBT for i in range(len(self.attention)): vec = T.eye(self.attention[i].shape[2], 1, -direction * (self.attention[i].shape[2] - 1)) last = vec.dimshuffle(1, 'x', 0).repeat(self.index.shape[1], axis=1) self.attention[i] = T.concatenate([self.attention[i][1:],last],axis=0)[::direction] self.cost_val = numpy.float32(0) if recurrent_transform == 'attention_align': back = T.ceil(self.aux[sorted(unit.recurrent_transform.state_vars.keys()).index('t')]) def make_output(base, yout, trace, length): length = T.cast(length, 'int32') idx = T.cast(trace[:length][::-1],'int32') x_out = T.concatenate([base[idx],T.zeros((self.index.shape[0] + 1 - length, base.shape[1]), 'float32')],axis=0) y_out = T.concatenate([yout[idx,T.arange(length)],T.zeros((self.index.shape[0] + 1 - length, ), 'float32')],axis=0) return x_out, y_out output, _ = theano.map(make_output, sequences = [base[0].output.dimshuffle(1,0,2), self.y_t.dimshuffle(1,2,0), back.dimshuffle(1,0),
# attention_align branch: make_output follows the per-batch hard backtrace
# indices (trace[:length][::-1]) to re-gather the base layer's frames, zero-
# padding to the decoder length; the subsequent block computes a skip-
# prediction cross-entropy over T_W/T_b logits (stored in self.cost_val,
# with errors()) and returns early from __init__.
T.sum(self.index,axis=0,dtype='float32')]) self.attrs['n_out'] = base[0].attrs['n_out'] self.params.update(unit.params) self.output = output[0].dimshuffle(1,0,2)[:-1] z = T.dot(self.act[0], self.T_W)[:-1] + self.T_b z = z.reshape((z.shape[0] * z.shape[1], z.shape[2])) idx = (self.index[1:].flatten() > 0).nonzero() idy = (self.index[1:][::-1].flatten() > 0).nonzero() y_out = T.cast(output[1],'int32').dimshuffle(1, 0)[:-1].flatten() nll, _ = T.nnet.crossentropy_softmax_1hot(x=z[idx], y_idx=y_out[idy]) self.cost_val = T.sum(nll) recog = T.argmax(z[idx], axis=1) real = y_out[idy] self.errors = lambda: T.sum(T.neq(recog, real)) return back += T.arange(self.index.shape[1], dtype='float32') * T.cast(self.base[0].index.shape[0], 'float32') idx = (self.index[:-1].flatten() > 0).nonzero() idx = T.cast(back[::-1].flatten()[idx],'int32') x_out = base[0].output #x_out = x_out.dimshuffle(1,0,2).reshape((x_out.shape[0] * x_out.shape[1], x_out.shape[2]))[idx] #x_out = x_out.reshape((self.index.shape[1], self.index.shape[0] - 1, x_out.shape[1])).dimshuffle(1,0,2) x_out = x_out.reshape((x_out.shape[0] * x_out.shape[1], x_out.shape[2]))[idx] x_out = x_out.reshape((self.index.shape[0] - 1, self.index.shape[1], x_out.shape[1])) self.output = T.concatenate([x_out, base[0].output[1:]],axis=0) self.attrs['n_out'] = base[0].attrs['n_out'] self.params.update(unit.params) return skips = T.dot(T.nnet.softmax(z), T.arange(z.shape[1], dtype='float32')).reshape(self.index[1:].shape) shift = T.arange(self.index.shape[1], dtype='float32') * T.cast(self.base[0].index.shape[0], 'float32') skips = T.concatenate([T.zeros_like(self.y_t[:1]),self.y_t[:-1]],axis=0) idx = shift + T.cumsum(skips, axis=0) idx = T.cast(idx[:-1].flatten(),'int32') #idx = (idx.flatten() > 0).nonzero() #idx = base[0].attention.flatten() x_out = base[0].output[::-1] x_out = x_out.reshape((x_out.shape[0] * x_out.shape[1], x_out.shape[2]))[idx] x_out = x_out.reshape((self.index.shape[0], self.index.shape[1], x_out.shape[1]))
# Final gather: index the (time-reversed) base output at the cumulative
# skip positions (shift + cumsum of y_t skips, flattened over time*batch),
# prepend the base layer's last frame and flip back to forward time order.
# The batch_norm branch registers a custom update for the running mean
# before the output projection is built.
self.output = T.concatenate([base[0].output[-1:], x_out], axis=0)[::-1] self.attrs['n_out'] = base[0].attrs['n_out'] self.params.update(unit.params) return if recurrent_transform == 'batch_norm': self.params['sample_mean_batch_norm'].custom_update = T.dot(T.mean(self.act[0],axis=[0,1]),self.W_re) self.params['sample_mean_batch_norm'].custom_update_normalized = True self.make_output(self.act[0][::direction or 1], sample_mean=sample_mean, gamma=gamma) self.params.update(unit.params)