def Net(aa, yt, x):
    s = aa.shape[1]
    with tf.sg_context(name='NNReg', stride=1, act='leaky_relu', bn=True, reuse=tf.AUTO_REUSE):
        yt = tf.expand_dims(yt, 2)

        # point-wise (1x1) conv tower over x
        v1 = tf.expand_dims(x, 2).sg_conv(dim=16, size=(1, 1), name='gen9', pad="SAME", bn=True)
        v2 = v1.sg_conv(dim=64, size=(1, 1), name='gen1', pad="SAME", bn=True)
        v3 = v2.sg_conv(dim=128, size=(1, 1), name='gen2', pad="SAME", bn=True)
        v4 = v3.sg_conv(dim=256, size=(1, 1), name='gen3', pad="SAME", bn=True)
        v5 = v4.sg_conv(dim=512, size=(1, 1), name='gen4', pad="SAME", bn=True)
        # global max-pool over the point axis, broadcast back to every point
        v5 = tf.tile(tf.expand_dims(tf.reduce_max(v5, axis=1), axis=1), [1, s, 1, 1])
        vv5 = v5

        # same point-wise conv tower over yt
        v1 = yt.sg_conv(dim=16, size=(1, 1), name='gen99', pad="SAME", bn=True)
        v2 = v1.sg_conv(dim=64, size=(1, 1), name='gen11', pad="SAME", bn=True)
        v3 = v2.sg_conv(dim=128, size=(1, 1), name='gen22', pad="SAME", bn=True)
        v4 = v3.sg_conv(dim=256, size=(1, 1), name='gen33', pad="SAME", bn=True)
        v5 = v4.sg_conv(dim=512, size=(1, 1), name='gen44', pad="SAME", bn=True)
        v5 = tf.tile(tf.expand_dims(tf.reduce_max(v5, axis=1), axis=1), [1, s, 1, 1])

        # concatenate the per-point input with both global descriptors
        ff = tf.concat([tf.expand_dims(aa, 2), v5], axis=-1)
        ff = tf.concat([ff, vv5], axis=-1)
        # regression head: 2-D output per point
        f1 = ff.sg_conv(dim=256, size=(1, 1), name='f1', pad="SAME", bn=True)
        f2 = f1.sg_conv(dim=128, size=(1, 1), name='f2', pad="SAME", bn=True)
        f3 = f2.sg_conv(dim=2, size=(1, 1), name='f3', pad="SAME", bn=False, act="linear")
        f3 = tf.squeeze(f3, axis=2)
    return f3
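# Usage sketch for Net (illustrative only): the sugartensor import, batch size,
# point count, and placeholder shapes below are assumptions, not part of the
# original code. Net only requires 2-D points and a shared batch dimension.
import sugartensor as tf  # monkey-patches the sg_* ops onto tf.Tensor

batch_size, n_points = 8, 1024
aa = tf.placeholder(tf.float32, [batch_size, n_points, 2])  # per-point input features
yt = tf.placeholder(tf.float32, [batch_size, n_points, 2])  # target point set
x = tf.placeholder(tf.float32, [batch_size, n_points, 2])   # source point set

pred = Net(aa, yt, x)  # (batch_size, n_points, 2) regressed 2-D output per point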
def pairwise_dist(xt, y_p):
    a = xt.shape[1]
    b = y_p.shape[1]
    # broadcast both point sets against each other: (batch, a, b, 2)
    dist = tf.tile(tf.expand_dims(y_p, 1), [1, a, 1, 1]) - tf.tile(
        tf.expand_dims(xt, 2), [1, 1, b, 1])
    # squared Euclidean distance between every pair of points
    dist = dist[:, :, :, 0]**2 + dist[:, :, :, 1]**2
    return dist
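# Because pairwise_dist returns the full matrix of squared distances, it can be
# reduced to a symmetric Chamfer-style loss. The helper below is a hypothetical
# sketch built on top of it, not part of the original code.
def chamfer_loss(xt, y_p):
    dist = pairwise_dist(xt, y_p)              # (batch, a, b) squared distances
    forward = tf.reduce_min(dist, axis=2)      # nearest y_p point for every xt point
    backward = tf.reduce_min(dist, axis=1)     # nearest xt point for every y_p point
    return tf.reduce_mean(forward) + tf.reduce_mean(backward)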
def sg_quasi_rnn(tensor, opt):
    # Split
    if opt.att:
        H, Z, F, O = tf.split(tensor, 4, axis=0)  # (16, 150, 320) for all
    else:
        Z, F, O = tf.split(tensor, 3, axis=0)  # (16, 150, 320) for all

    # step func
    def step(z, f, o, c):
        '''
        Runs fo-pooling at each time step
        '''
        c = f * c + (1 - f) * z
        if opt.att:  # attention
            a = tf.nn.softmax(tf.einsum("ijk,ik->ij", H, c))  # alpha. (16, 150)
            k = (a.sg_expand_dims() * H).sg_sum(axis=1)  # attentional sum. (16, 320)
            h = o * (k.sg_dense(act="linear") + c.sg_dense(act="linear"))
        else:
            h = o * c
        return h, c  # hidden states, (new) cell memories

    # Do rnn loop
    c, hs = 0, []
    timesteps = tensor.get_shape().as_list()[1]
    for t in range(timesteps):
        z = Z[:, t, :]  # (16, 320)
        f = F[:, t, :]  # (16, 320)
        o = O[:, t, :]  # (16, 320)
        # apply step function
        h, c = step(z, f, o, c)  # (16, 320), (16, 320)
        # save result
        hs.append(h.sg_expand_dims(axis=1))

    # Concat to return
    H = tf.concat(hs, 1)  # (16, 150, 320)
    seqlen = tf.to_int32(
        tf.reduce_sum(tf.sign(tf.abs(tf.reduce_sum(H, axis=-1))), 1))  # (16,) int32
    h = tf.reverse_sequence(input=H, seq_lengths=seqlen, seq_dim=1)[:, 0, :]  # last hidden state vector

    if opt.is_enc:
        H_z = tf.tile((h.sg_dense(act="linear").sg_expand_dims(axis=1)), [1, timesteps, 1])
        H_f = tf.tile((h.sg_dense(act="linear").sg_expand_dims(axis=1)), [1, timesteps, 1])
        H_o = tf.tile((h.sg_dense(act="linear").sg_expand_dims(axis=1)), [1, timesteps, 1])
        concatenated = tf.concat([H, H_z, H_f, H_o], 0)  # (16*4, 150, 320)
        return concatenated
    else:
        return H  # (16, 150, 320)
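# For reference, step() above implements fo-pooling: c_t = f_t * c_{t-1} + (1 - f_t) * z_t,
# h_t = o_t * c_t. The NumPy snippet below restates that update outside the graph;
# it is purely illustrative and the shapes are made up.
import numpy as np

def fo_pool(z, f, o):
    # z, f, o: (timesteps, dim) pre-computed gate activations
    c = np.zeros(z.shape[1])
    hs = []
    for t in range(z.shape[0]):
        c = f[t] * c + (1 - f[t]) * z[t]  # forget-gated cell update
        hs.append(o[t] * c)               # output gate applied to the cell
    return np.stack(hs)                   # (timesteps, dim) hidden states

# e.g. fo_pool(np.random.rand(5, 3), np.random.rand(5, 3), np.random.rand(5, 3))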
def sg_quasi_rnn(tensor, opt):
    # Split
    if opt.att:
        H, Z, F, O = tf.split(axis=0, num_or_size_splits=4, value=tensor)  # (16, 150, 320) for all
    else:
        Z, F, O = tf.split(axis=0, num_or_size_splits=3, value=tensor)  # (16, 150, 320) for all

    # step func
    def step(z, f, o, c):
        '''
        Runs fo-pooling at each time step
        '''
        c = f * c + (1 - f) * z
        if opt.att:  # attention
            a = tf.nn.softmax(tf.einsum("ijk,ik->ij", H, c))  # alpha. (16, 150)
            k = (a.sg_expand_dims() * H).sg_sum(dims=1)  # attentional sum. (16, 320)
            h = o * (k.sg_dense(act="linear") + c.sg_dense(act="linear"))
        else:
            h = o * c
        return h, c  # hidden states, (new) cell memories

    # Do rnn loop
    c, hs = 0, []
    timesteps = tensor.get_shape().as_list()[1]
    for t in range(timesteps):
        z = Z[:, t, :]  # (16, 320)
        f = F[:, t, :]  # (16, 320)
        o = O[:, t, :]  # (16, 320)
        # apply step function
        h, c = step(z, f, o, c)  # (16, 320), (16, 320)
        # save result
        hs.append(h.sg_expand_dims(dim=1))

    # Concat to return
    H = tf.concat(axis=1, values=hs)  # (16, 150, 320)

    if opt.is_enc:
        H_z = tf.tile((h.sg_dense(act="linear").sg_expand_dims(dim=1)), [1, timesteps, 1])
        H_f = tf.tile((h.sg_dense(act="linear").sg_expand_dims(dim=1)), [1, timesteps, 1])
        H_o = tf.tile((h.sg_dense(act="linear").sg_expand_dims(dim=1)), [1, timesteps, 1])
        concatenated = tf.concat(axis=0, values=[H, H_z, H_f, H_o])  # (16*4, 150, 320)
        return concatenated
    else:
        return H  # (16, 150, 320)
def zero_state2(self, batch_size):
    dtype = tf.float32
    state_size = self.state_size
    zeros = [0] * 2
    for i in range(2):
        zeros_size = _state_size_with_prefix(state_size, prefix=[batch_size])
        temp = array_ops.zeros(array_ops.stack(zeros_size), dtype=dtype)
        temp.set_shape(_state_size_with_prefix(state_size, prefix=[None]))
        # tile the zero state across the sequence dimension
        zeros[i] = tf.tile((temp.sg_expand_dims(axis=1)), [1, self._seqlen, 1])
    return zeros
def tower_infer_dec(chars,
                    scope,
                    rnn_cell,
                    dec_cell,
                    word_emb,
                    rnn_state,
                    out_reuse_vars=False,
                    dev='/cpu:0'):
    with tf.device(dev):
        with tf.variable_scope('embatch_size', reuse=True):
            # (vocab_size, latent_dim)
            emb_char = tf.sg_emb(name='emb_char', voca_size=Hp.char_vs, dim=Hp.hd, dev=dev)
            emb_word = tf.sg_emb(name='emb_word', emb=word_emb, voca_size=Hp.word_vs, dim=300, dev=dev)

    print(chars)
    ch = chars
    ch = tf.reverse_sequence(input=ch, seq_lengths=[Hp.c_maxlen] * Hp.batch_size, seq_dim=1)
    reuse_vars = reuse_vars_enc = True

    # -------------------------- BYTENET ENCODER --------------------------
    with tf.variable_scope('encoder'):
        # embed table lookup
        enc = ch.sg_lookup(emb=emb_char)  # (batch, sentlen, latentdim)
        # loop dilated conv block
        for i in range(Hp.num_blocks):
            enc = (enc.sg_res_block(size=5, rate=1, name="enc1_%d" % (i), is_first=True,
                                    reuse_vars=reuse_vars, dev=dev)
                      .sg_res_block(size=5, rate=2, name="enc2_%d" % (i),
                                    reuse_vars=reuse_vars, dev=dev)
                      .sg_res_block(size=5, rate=4, name="enc4_%d" % (i),
                                    reuse_vars=reuse_vars, dev=dev)
                      .sg_res_block(size=5, rate=8, name="enc8_%d" % (i),
                                    reuse_vars=reuse_vars, dev=dev)
                      .sg_res_block(size=5, rate=16, name="enc16_%d" % (i),
                                    reuse_vars=reuse_vars, dev=dev))
        byte_enc = enc

    # -------------------------- QCNN + QPOOL ENCODER #1 --------------------------
    with tf.variable_scope('quazi'):
        # quasi cnn layer ZFO [batch * 3, seqlen, dim2]
        conv = byte_enc.sg_quasi_conv1d(is_enc=True, size=4, name="qconv_1",
                                        dev=dev, reuse_vars=reuse_vars)
        # c = f * c + (1 - f) * z, h = o * c   [batch * 4, seqlen, hd]
        pool0 = conv.sg_quasi_rnn(is_enc=False, att=False, name="qrnn_1",
                                  reuse_vars=reuse_vars, dev=dev)
        qpool_last = pool0[:, -1, :]

    # -------------------------- MAXPOOL along time dimension --------------------------
    inpt_maxpl = tf.expand_dims(byte_enc, 1)  # [batch, 1, seqlen, channels]
    maxpool = tf.nn.max_pool(inpt_maxpl, [1, 1, Hp.c_maxlen, 1], [1, 1, 1, 1], 'VALID')
    maxpool = tf.squeeze(maxpool, [1, 2])

    # -------------------------- HIGHWAY --------------------------
    concat = qpool_last + maxpool
    with tf.variable_scope('highway', reuse=reuse_vars):
        input_lstm = highway(concat, concat.get_shape()[-1], num_layers=1)

    # -------------------------- CONTEXT LSTM --------------------------
    input_lstm = tf.nn.dropout(input_lstm, Hp.keep_prob)
    with tf.variable_scope('contx_lstm', reuse=reuse_vars):
        output, rnn_state = rnn_cell(input_lstm, rnn_state)

    beam_size = 8
    reuse_vars = out_reuse_vars
    greedy = False

    if greedy:
        dec_state = rnn_state
        dec_out = []
        d_out = tf.constant([1] * Hp.batch_size)
        for idx in range(Hp.w_maxlen):
            w_input = d_out.sg_lookup(emb=emb_word)
            dec_state = tf.contrib.rnn.LSTMStateTuple(c=dec_state.c, h=dec_state.h)
            with tf.variable_scope('dec_lstm', reuse=idx > 0 or reuse_vars):
                d_out, dec_state = dec_cell(w_input, dec_state)
            dec_out.append(d_out)
            d_out = tf.expand_dims(d_out, 1).sg_conv1d_gpus(size=1, dim=Hp.word_vs,
                                                            name="out_conv", act="linear",
                                                            dev=dev,
                                                            reuse=idx > 0 or reuse_vars)
            d_out = tf.squeeze(d_out).sg_argmax()
        dec_out = tf.stack(dec_out, 1)
        dec = dec_out.sg_conv1d_gpus(size=1, dim=Hp.word_vs, name="out_conv",
                                     act="linear", dev=dev, reuse=True)
        return dec.sg_argmax(), rnn_state
    else:
        # ------------------ BEAM SEARCH --------------------
        dec_state = tf.contrib.rnn.LSTMStateTuple(
            tf.tile(tf.expand_dims(rnn_state[0], 1), [1, beam_size, 1]),
            tf.tile(tf.expand_dims(rnn_state[1], 1), [1, beam_size, 1]))
        initial_ids = tf.constant([1] * Hp.batch_size)

        def symbols_to_logits_fn(ids, dec_state):
            dec = []
            dec_c, dec_h = [], []
            # (batch x beam_size x decoded_seq)
            ids = tf.reshape(ids, [Hp.batch_size, beam_size, -1])
            print("dec_state ", dec_state[0].get_shape().as_list())
            for ind in range(beam_size):
                with tf.variable_scope('dec_lstm', reuse=ind > 0 or reuse_vars):
                    w_input = ids[:, ind, -1].sg_lookup(emb=emb_word)
                    dec_state0 = tf.contrib.rnn.LSTMStateTuple(
                        c=dec_state.c[:, ind, :], h=dec_state.h[:, ind, :])
                    dec_out, dec_state_i = dec_cell(w_input, dec_state0)
                    dec_out = tf.expand_dims(dec_out, 1)
                dec_i = dec_out.sg_conv1d_gpus(size=1, dim=Hp.word_vs, name="out_conv",
                                               act="linear", dev=dev,
                                               reuse=ind > 0 or reuse_vars)
                dec.append(tf.squeeze(dec_i, 1))
                dec_c.append(dec_state_i[0])
                dec_h.append(dec_state_i[1])
            return tf.stack(dec, 1), tf.contrib.rnn.LSTMStateTuple(tf.stack(dec_c, 1),
                                                                   tf.stack(dec_h, 1))

        final_ids, final_probs = beam_search.beam_search(symbols_to_logits_fn, dec_state,
                                                         initial_ids, beam_size,
                                                         Hp.w_maxlen - 1, Hp.word_vs,
                                                         3.5, eos_id=2)
        return final_ids[:, 0, :], rnn_state
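# A hedged sketch of one way to drive tower_infer_dec at inference time. The cell
# constructors, placeholder shapes, and scope name below are assumptions, not taken
# from the original code; word_emb is the pretrained word-embedding matrix the
# function already expects, and the 'embatch_size' embedding variables must have
# been created earlier (the function opens that scope with reuse=True).
rnn_cell = tf.contrib.rnn.LSTMCell(Hp.hd)   # sentence-level context LSTM
dec_cell = tf.contrib.rnn.LSTMCell(Hp.hd)   # word-level decoder LSTM

chars = tf.placeholder(tf.int32, [Hp.batch_size, Hp.c_maxlen])
rnn_state = rnn_cell.zero_state(Hp.batch_size, tf.float32)

word_ids, rnn_state = tower_infer_dec(chars, 'tower_0', rnn_cell, dec_cell,
                                      word_emb, rnn_state, dev='/gpu:0')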
def sg_quasi_rnn(tensor, opt):
    # Split
    if opt.att:
        H, Z, F, O = tf.split(axis=0, num_or_size_splits=4, value=tensor)  # (b, seqlen, hd) for all
    else:
        Z, F, O = tf.split(axis=0, num_or_size_splits=3, value=tensor)  # (b, seqlen, hd) for all

    # step func
    def step(z, f, o, c):
        '''
        Runs fo-pooling at each time step
        '''
        c = f * c + (1 - f) * z
        if opt.att:  # attention
            a = tf.nn.softmax(tf.einsum("ijk,ik->ij", H, c))  # alpha. (b, seqlen)
            k = (a.sg_expand_dims() * H).sg_sum(axis=1)  # attentional sum. (b, hd)
            h = o * (k.sg_dense_gpus(act="linear", name="k%d_%s" % (t, opt.name),
                                     dev=opt.dev, reuse=opt.reuse_vars)
                     + c.sg_dense_gpus(act="linear", name="c%d_%s" % (t, opt.name),
                                       dev=opt.dev, reuse=opt.reuse_vars))
        else:
            h = o * c
        return h, c  # hidden states, (new) cell memories

    # Do rnn loop
    c, hs = 0, []
    timesteps = tensor.get_shape().as_list()[1]
    for t in range(timesteps):
        z = Z[:, t, :]  # (b, hd)
        f = F[:, t, :]  # (b, hd)
        o = O[:, t, :]  # (b, hd)
        # apply step function
        h, c = step(z, f, o, c)  # (b, hd), (b, hd)
        # save result
        hs.append(h.sg_expand_dims(axis=1))

    # Concat to return
    H = tf.concat(hs, 1)  # (b, seqlen, hd)

    if opt.is_enc:
        H_z = tf.tile((h.sg_dense_gpus(act="linear", name="z_%s" % (opt.name),
                                       dev=opt.dev, reuse=opt.reuse_vars)
                        .sg_expand_dims(axis=1)), [1, timesteps, 1])
        H_f = tf.tile((h.sg_dense_gpus(act="linear", name="f_%s" % (opt.name),
                                       dev=opt.dev, reuse=opt.reuse_vars)
                        .sg_expand_dims(axis=1)), [1, timesteps, 1])
        H_o = tf.tile((h.sg_dense_gpus(act="linear", name="o_%s" % (opt.name),
                                       dev=opt.dev, reuse=opt.reuse_vars)
                        .sg_expand_dims(axis=1)), [1, timesteps, 1])
        concatenated = tf.concat(axis=0, values=[H, H_z, H_f, H_o])  # (b*4, seqlen, hd)
        return concatenated
    else:
        return H  # (b, seqlen, hd)
def rnn_body(time, subrec1, subrec2, rnn_state, rnn_h, crnn_state, crnn_h, losses):
    x = x_sent.read(time)
    y = x_sent.read(time + 1)  # (batch, sentlen) = (16, 200)

    # shift target by one step for training source
    y_src = tf.concat([tf.zeros((Hp.batch_size, 1), tf.int32), y[:, :-1]], 1)
    reuse_vars = time == tf.constant(0) or reu_vars

    # -------------------------- BYTENET ENCODER --------------------------
    # embed table lookup
    enc = x.sg_lookup(emb=emb_x)  # (batch, sentlen, latentdim)
    # loop dilated conv block
    for i in range(num_blocks):
        enc = (enc.sg_res_block(size=5, rate=1, name="enc1_%d" % (i), reuse_vars=reuse_vars)
                  .sg_res_block(size=5, rate=2, name="enc2_%d" % (i), reuse_vars=reuse_vars)
                  .sg_res_block(size=5, rate=4, name="enc4_%d" % (i), reuse_vars=reuse_vars)
                  .sg_res_block(size=5, rate=8, name="enc8_%d" % (i), reuse_vars=reuse_vars)
                  .sg_res_block(size=5, rate=16, name="enc16_%d" % (i), reuse_vars=reuse_vars))

    # -------------------------- QCNN + QPOOL ENCODER with attention #1 --------------------------
    # quasi cnn layer ZFO [batch * 3, t, dim2]
    conv = enc.sg_quasi_conv1d(is_enc=True, size=3, name="qconv_1", reuse_vars=reuse_vars)
    # attention layer
    # recurrent layer #1 + final encoder hidden state
    subrec1 = tf.tile((subrec1.sg_expand_dims(axis=1)), [1, Hp.maxlen, 1])
    concat = conv.sg_concat(target=subrec1, axis=0)  # (batch*4, sentlen, latentdim)
    pool = concat.sg_quasi_rnn(is_enc=True, att=True, name="qrnn_1", reuse_vars=reuse_vars)
    subrec1 = pool[:Hp.batch_size, -1, :]  # last character in sequence

    # -------------------------- QCNN + QPOOL ENCODER with attention #2 --------------------------
    # quasi cnn ZFO (batch*3, sentlen, latentdim)
    conv = pool.sg_quasi_conv1d(is_enc=True, size=2, name="qconv_2", reuse_vars=reuse_vars)
    # (batch, sentlen-duplicated, latentdim)
    subrec2 = tf.tile((subrec2.sg_expand_dims(axis=1)), [1, Hp.maxlen, 1])
    # (batch*4, sentlen, latentdim)
    concat = conv.sg_concat(target=subrec2, axis=0)
    pool = concat.sg_quasi_rnn(is_enc=True, att=True, name="qrnn_2", reuse_vars=reuse_vars)
    subrec2 = pool[:Hp.batch_size, -1, :]  # last character in sequence

    # -------------------------- ConvLSTM with RESIDUAL connection and MULTIPLICATIVE block --------------------------
    # residual block
    causal = False  # for encoder
    crnn_input = (pool[:Hp.batch_size, :, :]
                  .sg_bypass_gpus(name='relu_0', act='relu', bn=(not causal), ln=causal)
                  .sg_conv1d_gpus(name="dimred_0", size=1, dev="/cpu:0", reuse=reuse_vars,
                                  dim=Hp.hd // 2,  # integer division keeps the channel dim an int
                                  act='relu', bn=(not causal), ln=causal))
    # conv LSTM
    with tf.variable_scope("mem/clstm") as scp:
        (crnn_state, crnn_h) = crnn_cell(crnn_input, (crnn_state, crnn_h), size=5,
                                         reuse_vars=reuse_vars)
    # dimension recover and residual connection
    rnn_input0 = pool[:Hp.batch_size, :, :] + crnn_h.sg_conv1d_gpus(
        name="diminc_0", size=1, dev="/cpu:0", dim=Hp.hd, reuse=reuse_vars,
        act='relu', bn=(not causal), ln=causal)

    # -------------------------- QCNN + QPOOL ENCODER with attention #3 --------------------------
    # pooling for lstm input
    # quasi cnn ZFO (batch*3, sentlen, latentdim)
    conv = rnn_input0.sg_quasi_conv1d(is_enc=True, size=2, name="qconv_3", reuse_vars=reuse_vars)
    pool = conv.sg_quasi_rnn(is_enc=True, att=False, name="qrnn_3", reuse_vars=reuse_vars)
    rnn_input = pool[:Hp.batch_size, -1, :]  # last character in sequence

    # -------------------------- LSTM with RESIDUAL connection and MULTIPLICATIVE block --------------------------
    # recurrent block
    with tf.variable_scope("mem/lstm") as scp:
        (rnn_state, rnn_h) = rnn_cell(rnn_input, (rnn_state, rnn_h))
    rnn_h2 = tf.tile(((rnn_h + rnn_input).sg_expand_dims(axis=1)), [1, Hp.maxlen, 1])

    # -------------------------- BYTENET DECODER --------------------------
    # CNN decoder
    dec = y_src.sg_lookup(emb=emb_y).sg_concat(target=rnn_h2, name="dec")
    for i in range(num_blocks):
        dec = (dec.sg_res_block(size=3, rate=1, causal=True, name="dec1_%d" % (i), reuse_vars=reuse_vars)
                  .sg_res_block(size=3, rate=2, causal=True, name="dec2_%d" % (i), reuse_vars=reuse_vars)
                  .sg_res_block(size=3, rate=4, causal=True, name="dec4_%d" % (i), reuse_vars=reuse_vars)
                  .sg_res_block(size=3, rate=8, causal=True, name="dec8_%d" % (i), reuse_vars=reuse_vars)
                  .sg_res_block(size=3, rate=16, causal=True, name="dec16_%d" % (i), reuse_vars=reuse_vars))

    # final fully convolutional layer for softmax
    dec = dec.sg_conv1d_gpus(size=1, dim=Hp.vs, name="out", summary=False,
                             dev=self._dev, reuse=reuse_vars)
    ce_array = dec.sg_ce(target=y, mask=True, name="cross_ent_example")
    cross_entropy_mean = tf.reduce_mean(ce_array, name='cross_entropy')
    losses = tf.add_n([losses, cross_entropy_mean], name='total_loss')
    return (time + 1, subrec1, subrec2, rnn_state, rnn_h, crnn_state, crnn_h, losses)
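# rnn_body has the (time, ..., losses) -> (time + 1, ..., losses) signature of a
# tf.while_loop body. The wiring below is a hedged sketch of that usage; the initial
# state tensors and the n_sents bound are assumptions, not taken from the original code.
loop_vars = (tf.constant(0),                 # time
             subrec1_init, subrec2_init,     # quasi-RNN summary vectors
             rnn_state_init, rnn_h_init,     # LSTM state
             crnn_state_init, crnn_h_init,   # conv-LSTM state
             tf.constant(0.0))               # accumulated cross-entropy

def cond(time, *_):
    return time < n_sents - 1                # one step per adjacent sentence pair

*final_state, total_loss = tf.while_loop(cond, rnn_body, loop_vars)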