def lstm_doc_dec(input_cnn, final_enc_state,
                 batch_size=20,
                 num_rnn_layers=2,
                 rnn_size=650,
                 max_doc_length=35,
                 dropout=0.0):
    # scoring each sentence with another LSTM that reads the doc again
    with tf.variable_scope('LSTMdec'):

        def create_rnn_cell():
            cell = tf.contrib.rnn.BasicLSTMCell(rnn_size, state_is_tuple=True, forget_bias=0.0)
            if dropout > 0.0:
                cell = tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob=1. - dropout)
            return cell

        if num_rnn_layers > 1:
            cell = tf.contrib.rnn.MultiRNNCell(
                [create_rnn_cell() for _ in range(num_rnn_layers)], state_is_tuple=True)
        else:
            cell = create_rnn_cell()

        initial_rnn_state = final_enc_state

        input_cnn = tf.reshape(input_cnn, [batch_size, max_doc_length, -1])
        input_cnn2 = [tf.squeeze(x, [1]) for x in tf.split(input_cnn, max_doc_length, 1)]

        outputs, final_rnn_state = tf.contrib.rnn.static_rnn(
            cell, input_cnn2, initial_state=initial_rnn_state, dtype=tf.float32)

    return adict(initial_dec_state=initial_rnn_state,
                 final_dec_state=final_rnn_state,
                 dec_outputs=outputs)
def compute_tstats(self, t):
    nw, ns = zip(*t)
    nw = np.array(nw, dtype=np.float32)
    ns = np.array(ns, dtype=np.float32)
    d = {}
    d['w'] = np.mean(nw)
    d['s'] = np.std(ns)
    return U.adict(d)
def set_flags(FLAGS, i):
    flags = U.adict(FLAGS.copy())
    flags.rnn_cell = FLAGS.rnn_cells[i]
    flags.rnn_size = FLAGS.rnn_sizes[i]
    flags.bidirectional = FLAGS.rnn_bis[i]
    flags.attn_size = FLAGS.attn_sizes[i]
    flags.attn_depth = FLAGS.attn_depths[i]
    flags.attn_temp = FLAGS.attn_temps[i]
    flags.pad = FLAGS.pads[i]
    return flags
def append(self, value):
    block = self.blocks[value] = adict(val=value)
    if self.current is None:
        self.current = block
    else:
        self.tail.r = block
    block.r = self.current
    self.tail = block
def cnn_sen_enc(word_vocab_size,
                word_embed_size=50,
                batch_size=20,
                num_highway_layers=2,
                max_sen_length=65,
                kernels=[1, 2, 3, 4, 5, 6, 7],
                kernel_features=[50, 100, 150, 200, 200, 200, 200],
                max_doc_length=35,
                pretrained=None):
    # cnn sentence encoder
    assert len(kernels) == len(kernel_features), 'Kernel and Features must have the same size'

    input_ = tf.placeholder(tf.int32,
                            shape=[batch_size, max_doc_length, max_sen_length],
                            name="input")

    ''' First, embed words to sentence '''
    with tf.variable_scope('embedding'):
        if pretrained is not None:
            word_embedding = tf.get_variable(
                name='word_embedding',
                shape=[word_vocab_size, word_embed_size],
                initializer=tf.constant_initializer(pretrained))
        else:
            word_embedding = tf.get_variable(
                name='word_embedding',
                shape=[word_vocab_size, word_embed_size])

        ''' This op clears the embedding vector of the first symbol (the symbol at
        position 0, which is by convention the padding symbol). It can be used to mimic
        the Torch7 embedding operator, which keeps padding mapped to a zero embedding
        vector and ignores gradient updates. To do that in TF:
        1. after parameter initialization, apply this op to zero out the padding embedding vector
        2. after each gradient update, apply this op to keep padding at zero '''
        clear_word_embedding_padding = tf.scatter_update(
            word_embedding, [0], tf.constant(0.0, shape=[1, word_embed_size]))

        # [batch_size, max_doc_length, max_sen_length, word_embed_size]
        input_embedded = tf.nn.embedding_lookup(word_embedding, input_)
        input_embedded = tf.reshape(input_embedded, [-1, max_sen_length, word_embed_size])

    ''' Second, apply convolutions '''
    # [batch_size x max_doc_length, cnn_size]    where cnn_size = sum(kernel_features)
    input_cnn = tdnn(input_embedded, kernels, kernel_features)

    ''' Maybe apply Highway '''
    if num_highway_layers > 0:
        input_cnn = highway(input_cnn, input_cnn.get_shape()[-1], num_layers=num_highway_layers)

    return adict(input=input_,
                 clear_word_embedding_padding=clear_word_embedding_padding,
                 input_embedded=input_embedded,
                 input_cnn=input_cnn)
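# The padding-clearing protocol described in the docstring above, as a minimal sketch
# (not part of the original repo). It assumes a `train_op` and a `batches` iterator built
# elsewhere; the point is only when `clear_word_embedding_padding` is run: once after
# initialization and once after every gradient update.
#
#   enc = cnn_sen_enc(word_vocab_size=10000)
#   with tf.Session() as sess:
#       sess.run(tf.global_variables_initializer())
#       sess.run(enc.clear_word_embedding_padding)      # 1. zero the padding row after init
#       for x_batch in batches:                         # hypothetical batch iterator
#           sess.run(train_op, feed_dict={enc.input: x_batch})
#           sess.run(enc.clear_word_embedding_padding)  # 2. re-zero it after each update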
def batch_stream(self, stop=False):
    tok_stream = self.reader.chunk_stream(stop=stop)
    while True:
        batches = self.make_batches(tok_stream)
        if batches is None:
            break
        for c, w in zip(batches[0], batches[1]):
            if self.trim_chars:
                c = self.trim_batch(c)
            yield U.adict({'w': w, 'c': c})
def label_prediction_att(outputs_enc, outputs_dec):
    # scoring labels with att
    logits = []
    with tf.variable_scope('Prediction') as scope:
        for idx, output in enumerate(zip(outputs_enc, outputs_dec)):
            if idx > 0:
                scope.reuse_variables()
            output_enc, output_dec = output
            logits.append(linear(tf.concat([output_enc, output_dec], 1), 2))
    return adict(logits=logits)
def self_prediction(outputs, word_vocab_size):
    # predicting the words contained therein, like a paragraph vector
    logits_pretrain = []
    with tf.variable_scope('SelfPrediction') as scope:
        for idx, output in enumerate(outputs):
            if idx > 0:
                scope.reuse_variables()
            logits_pretrain.append(linear(output, word_vocab_size))
    return adict(plogits=logits_pretrain)
def label_prediction(outputs):
    # scoring labels
    logits = []
    with tf.variable_scope('Prediction') as scope:
        for idx, output in enumerate(outputs):
            if idx > 0:
                scope.reuse_variables()
            logits.append(linear(output, 2))
    return adict(logits=logits)
def bilstm_doc_enc(input_cnn,
                   batch_size=20,
                   num_rnn_layers=2,
                   rnn_size=650,
                   max_doc_length=35,
                   dropout=0.0):
    '''BiLSTM document encoder.
    It constructs a list of sentence vectors in the document.
    '''
    with tf.variable_scope('bilstm_enc'):

        def create_rnn_cell():
            cell = tf.contrib.rnn.BasicLSTMCell(rnn_size, state_is_tuple=True, forget_bias=0.0)
            if dropout > 0.0:
                cell = tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob=1. - dropout)
            return cell

        if num_rnn_layers > 1:
            cell_fw = tf.contrib.rnn.MultiRNNCell(
                [create_rnn_cell() for _ in range(num_rnn_layers)], state_is_tuple=True)
            cell_bw = tf.contrib.rnn.MultiRNNCell(
                [create_rnn_cell() for _ in range(num_rnn_layers)], state_is_tuple=True)
        else:
            cell_fw = create_rnn_cell()
            cell_bw = create_rnn_cell()

        initial_rnn_state_fw = cell_fw.zero_state(batch_size, dtype=tf.float32)
        initial_rnn_state_bw = cell_bw.zero_state(batch_size, dtype=tf.float32)

        input_cnn = tf.reshape(input_cnn, [batch_size, max_doc_length, -1])
        input_cnn2 = [tf.squeeze(x, [1]) for x in tf.split(input_cnn, max_doc_length, 1)]

        outputs, final_rnn_state_fw, final_rnn_state_bw = tf.contrib.rnn.static_bidirectional_rnn(
            cell_fw, cell_bw, input_cnn2,
            initial_state_fw=initial_rnn_state_fw,
            initial_state_bw=initial_rnn_state_bw,
            dtype=tf.float32)

    return adict(initial_enc_state_fw=initial_rnn_state_fw,
                 initial_enc_state_bw=initial_rnn_state_bw,
                 final_enc_state_fw=final_rnn_state_fw,
                 final_enc_state_bw=final_rnn_state_bw,
                 enc_outputs=outputs)
def compute_ystats(self, y):
    y = np.array(y, dtype=np.float32)
    v, c = np.unique(y, return_counts=True)
    d = {}
    d['mean'] = np.mean(y)
    d['std'] = np.std(y)
    d['min'] = np.min(y)
    d['max'] = np.max(y)
    d['n'] = len(y)
    d['v'] = v
    d['c'] = c
    return U.adict(d)
def loss_extraction(logits, batch_size, max_doc_length):
    # extraction loss
    with tf.variable_scope('Loss'):
        targets = tf.placeholder(tf.int64, [batch_size, max_doc_length], name='targets')
        target_list = [tf.squeeze(x, [1]) for x in tf.split(targets, max_doc_length, 1)]

        loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=target_list),
            name='loss')

    return adict(targets=targets, loss=loss)
def loss_pretrain(logits, batch_size, max_doc_length, word_vocab_size):
    # reconstruction loss
    with tf.variable_scope('Loss'):
        targets = tf.placeholder(tf.float32,
                                 [batch_size, max_doc_length, word_vocab_size],
                                 name='targets')
        target_list = [tf.squeeze(x, [1]) for x in tf.split(targets, max_doc_length, 1)]

        loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=target_list),
            name='loss')

    return adict(targets=targets, loss=loss)
def parse_line(self, line, t=None, keys=None):
    rec = line.strip().split(self.sep)
    d = {}
    for k, v in self.fields.items():
        if keys and v not in keys:
            continue
        if isinstance(v, basestring):  # basestring assumes Python 2 (str/unicode)
            d[v] = rec[k].strip()
        else:
            d.update(v.parse_line(rec[k].strip()))
        # optional label-based sampling once the 'y' field has been parsed
        if t and v == 'y':
            y = float(d[v])
            p = t[y]
            if self.rng.rand() > p:
                #print('sample NO\t[{},{}]'.format(y, p))
                return None
            #print('sample YES\t[{},{}]'.format(y, p))
    return U.adict(d)
def loss_generation(logits, batch_size, max_output_length):
    '''compute sequence generation loss'''
    with tf.variable_scope('Loss'):
        targets = tf.placeholder(tf.int64, [batch_size, max_output_length], name='targets')
        mask = tf.placeholder(tf.float32, [batch_size, max_output_length], name='mask')

        target_list = [tf.squeeze(x, [1]) for x in tf.split(targets, max_output_length, 1)]
        mask_list = [tf.squeeze(x, [1]) for x in tf.split(mask, max_output_length, 1)]

        loss = tf.contrib.legacy_seq2seq.sequence_loss(logits, target_list, mask_list)

    return adict(targets=targets, mask=mask, loss=loss)
def training_graph(loss, learning_rate=1.0, max_grad_norm=5.0):
    ''' Builds training graph. '''
    global_step = tf.Variable(0, name='global_step', trainable=False)

    with tf.variable_scope('SGD_Training'):
        # SGD learning parameter
        learning_rate = tf.Variable(learning_rate, trainable=False, name='learning_rate')

        # collect all trainable variables
        tvars = tf.trainable_variables()
        grads, global_norm = tf.clip_by_global_norm(tf.gradients(loss, tvars), max_grad_norm)

        optimizer = tf.train.GradientDescentOptimizer(learning_rate)
        train_op = optimizer.apply_gradients(zip(grads, tvars), global_step=global_step)

    return adict(learning_rate=learning_rate,
                 global_step=global_step,
                 global_norm=global_norm,
                 train_op=train_op)
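# How these graph builders compose, as a minimal sketch (not part of the original repo).
# It assumes the no-question bilstm_doc_enc variant above and numpy feeds `x_batch`
# (shape [20, 35, 65]) and `y_batch` (shape [20, 35]); shapes follow the defaults
# (batch_size=20, max_doc_length=35, max_sen_length=65).
#
#   enc   = cnn_sen_enc(word_vocab_size=10000)                  # sentence vectors
#   doc   = bilstm_doc_enc(enc.input_cnn)                       # document-level BiLSTM
#   pred  = label_prediction(doc.enc_outputs)                   # per-sentence 2-way logits
#   loss  = loss_extraction(pred.logits, batch_size=20, max_doc_length=35)
#   train = training_graph(loss.loss, learning_rate=1.0, max_grad_norm=5.0)
#
#   with tf.Session() as sess:
#       sess.run(tf.global_variables_initializer())
#       _, step = sess.run([train.train_op, train.global_step],
#                          feed_dict={enc.input: x_batch, loss.targets: y_batch})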
def lstm_doc_enc(input_cnn, question_cnn,
                 batch_size=20,
                 num_rnn_layers=2,
                 rnn_size=150,
                 max_doc_length=35,
                 dropout=0.0):
    # lstm document encoder
    with tf.variable_scope('LSTMenc'):

        def create_rnn_cell():
            cell = tf.contrib.rnn.BasicLSTMCell(rnn_size, state_is_tuple=True, forget_bias=0.0)
            if dropout > 0.0:
                cell = tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob=1. - dropout)
            return cell

        if num_rnn_layers > 1:
            cell = tf.contrib.rnn.MultiRNNCell(
                [create_rnn_cell() for _ in range(num_rnn_layers)], state_is_tuple=True)
        else:
            cell = create_rnn_cell()

        initial_rnn_state = cell.zero_state(batch_size, dtype=tf.float32)

        input_cnn = tf.reshape(input_cnn, [batch_size, max_doc_length, -1])
        question_cnn = tf.reshape(question_cnn, [batch_size, max_doc_length, -1])
        print(input_cnn.get_shape(), question_cnn.get_shape())

        with tf.variable_scope('Sentence-level_Compare-Aggregate'):
            input_cnn, question_cnn = sent_compare_aggregate(
                input_cnn, question_cnn, batch_size, max_doc_length)
        print(input_cnn.get_shape(), question_cnn.get_shape())

        input_cnn2 = [tf.squeeze(x, [1]) for x in tf.split(input_cnn, max_doc_length, 1)]
        question_cnn2 = [tf.squeeze(x, [1]) for x in tf.split(question_cnn, max_doc_length, 1)]
        print(np.asarray(input_cnn2).shape, np.asarray(question_cnn2).shape)

        with tf.variable_scope('/input'):
            outputs, final_rnn_state = tf.contrib.rnn.static_rnn(
                cell, input_cnn2, initial_state=initial_rnn_state, dtype=tf.float32)
        with tf.variable_scope('/question'):
            question_outputs, _ = tf.contrib.rnn.static_rnn(
                cell, question_cnn2, initial_state=initial_rnn_state, dtype=tf.float32)
            #question_outputs = tf.reduce_mean(tf.transpose(question_outputs, [1,0,2]), 1)
        print(np.asarray(outputs).shape, np.asarray(question_outputs).shape)

    return adict(initial_enc_state=initial_rnn_state,
                 final_enc_state=final_rnn_state,
                 enc_outputs=outputs,
                 question_outputs=question_outputs)
def package(self, word_tokens, char_tokens, ws, cs, text):
    return U.adict({self.words: word_tokens,
                    self.chars: char_tokens,
                    self.ws: ws,
                    self.cs: cs,
                    self.text: text})
def flexible_attention_decoder(enc_outputs,
                               batch_size=20,
                               num_rnn_layers=1,
                               rnn_size=80,
                               enc_state_size=650,
                               max_output_length=5,
                               dropout=0.0,
                               word_vocab_size=100,
                               word_embed_size=50,
                               mode='train'):
    '''An attention decoder that supports customized output layers (FIX THIS).'''
    input_dec = tf.placeholder(tf.int32,
                               shape=[batch_size, max_output_length],
                               name="input_dec")
    dec_input = [tf.squeeze(s, [1]) for s in tf.split(input_dec, max_output_length, 1)]

    with tf.variable_scope('target_embedding'):
        target_embedding = tf.get_variable(
            'target_embedding', [word_vocab_size, word_embed_size],
            dtype=tf.float32,
            initializer=tf.truncated_normal_initializer(stddev=1e-4))
        clear_target_embedding_padding = tf.scatter_update(
            target_embedding, [0], tf.constant(0.0, shape=[1, word_embed_size]))
        embed_dec_input = [tf.nn.embedding_lookup(target_embedding, x) for x in dec_input]

    with tf.variable_scope('output_projection'):
        w = tf.get_variable(
            'w', [rnn_size, word_vocab_size],
            dtype=tf.float32,
            initializer=tf.truncated_normal_initializer(stddev=1e-4))
        w_t = tf.transpose(w)
        v = tf.get_variable(
            'v', [word_vocab_size],
            dtype=tf.float32,
            initializer=tf.truncated_normal_initializer(stddev=1e-4))

    with tf.variable_scope('lstm_dec'):

        def create_rnn_cell():
            cell = tf.contrib.rnn.BasicLSTMCell(rnn_size, state_is_tuple=True, forget_bias=0.0)
            if dropout > 0.0:
                cell = tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob=1. - dropout)
            return cell

        if num_rnn_layers > 1:
            cell = tf.contrib.rnn.MultiRNNCell(
                [create_rnn_cell() for _ in range(num_rnn_layers)], state_is_tuple=True)
        else:
            cell = create_rnn_cell()

        initial_rnn_state = cell.zero_state(batch_size, dtype=tf.float32)

        enc_outputs = [tf.reshape(x, [batch_size, 1, enc_state_size]) for x in enc_outputs]
        enc_states = tf.concat(axis=1, values=enc_outputs)

        initial_state_attention = (mode == 'decode')
        loop_function = None
        if mode == 'decode':
            loop_function = _extract_argmax_and_embed(target_embedding, (w, v),
                                                      update_embedding=False)

        decoder_outputs, dec_out_state = tf.contrib.legacy_seq2seq.attention_decoder(
            embed_dec_input, initial_rnn_state, enc_states, cell,
            num_heads=1,
            loop_function=loop_function,
            initial_state_attention=initial_state_attention)

    with tf.variable_scope('output'):
        model_outputs = []
        for i in range(len(decoder_outputs)):
            if i > 0:
                tf.get_variable_scope().reuse_variables()
            model_outputs.append(tf.nn.xw_plus_b(decoder_outputs[i], w, v))

    outputs = None
    topk_log_probs = None
    topk_ids = None
    if mode == 'decode':
        with tf.variable_scope('decode_output'):
            best_outputs = [tf.argmax(x, 1) for x in model_outputs]
            tf.logging.info('best_outputs%s', best_outputs[0].get_shape())
            outputs = tf.concat(
                axis=1, values=[tf.reshape(x, [batch_size, 1]) for x in best_outputs])
            topk_log_probs, topk_ids = tf.nn.top_k(
                tf.log(tf.nn.softmax(model_outputs[-1])), batch_size * 2)

    return adict(input_dec=input_dec,
                 clear_target_embedding_padding=clear_target_embedding_padding,
                 logits=model_outputs,
                 outputs=outputs,
                 topk_log_probs=topk_log_probs,
                 topk_ids=topk_ids)
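# How this decoder pairs with loss_generation above, as a minimal sketch (not part of the
# original repo). It assumes `doc.enc_outputs` comes from one of the document encoders in
# this file, that enc_state_size matches the per-step size of those outputs, and that
# `x_dec`, `y_dec`, and `mask_dec` are hypothetical numpy feeds of shape
# [batch_size, max_output_length].
#
#   dec  = flexible_attention_decoder(doc.enc_outputs, enc_state_size=650,
#                                     word_vocab_size=10000, mode='train')
#   loss = loss_generation(dec.logits, batch_size=20, max_output_length=5)
#   feed = {dec.input_dec: x_dec, loss.targets: y_dec, loss.mask: mask_dec}
#   # at decode time, rebuild with mode='decode' and read dec.outputs / dec.topk_ids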
def batch(self,
          ids=None,
          labels=None,
          words=None,
          chars=None,
          w=None,
          c=None,
          trim_words=None,
          trim_chars=None,
          spad='pre',
          wpad='post',
          cpad='post',
          split_sentences=False,
          batch_size=0,
          lines=None,
          is_test=False,
          ):
    if ids:
        self.last = (ids, labels, words, chars, w, c)
    else:
        (ids, labels, words, chars, w, c) = self.last

    if trim_words is None:
        trim_words = self.trim_words
    if not trim_words and self.max_text_length is None:
        self.max_text_length = self.reader.get_maxlen()
    if trim_chars is None:
        trim_chars = self.trim_chars

    n = len(ids)
    b = {'n': n}

    ''' partial ??? '''
    # # if not full batch.....just copy first item to fill
    # for i in range(batch_size - n):
    #     ids.append(ids[0])
    #     labels.append(labels[0])
    #     words.append(words[0])
    #     chars.append(chars[0])

    b['id'] = ids  # <-- THIS key ('id') SHOULD COME FROM FIELD_PARSER.fields

    y = np.array(labels, dtype=np.float32)
    if self.normy:
        y = self.normalize(y)
    y = y[..., None]  # y = np.expand_dims(y, 1)
    b['y'] = y  # <-- THIS key ('y') SHOULD COME FROM FIELD_PARSER.fields

    if w and not isListEmpty(words):
        m = (self.max_text_length,)
        if trim_words:
            m = (None,)
        if split_sentences:
            m = (None,) + m
        m = (None,) + m

        p = (wpad,)
        if split_sentences:
            p = (spad,) + p
        p = (None,) + p  ## p = (None, spad, wpad) ???

        word_tensor, seq_lengths = U.pad_sequences(words, m=m, p=p)
        b['w'] = word_tensor
        b['s'] = seq_lengths
        b['p'] = p[1:]  ## b.p = ( spad, wpad [,cpad] ) ???
        b['x'] = b['w']

    if c and not isListEmpty(chars):
        m = (self.max_word_length,)
        if trim_chars:
            m = (None,)
        if trim_words:
            m = (None,) + m
        else:
            m = (self.max_text_length,) + m
        if split_sentences:
            m = (None,) + m
        m = (None,) + m

        p = (wpad, cpad)
        if split_sentences:
            p = (spad,) + p
        p = (None,) + p  ## p = (None, spad, wpad, cpad) ???

        try:
            char_tensor, seq_lengths = U.pad_sequences(chars, m=m, p=p)
        except (IndexError, ValueError):
            print('ERROR!')
            for i, cc in enumerate(chars):
                if len(cc) == 0:
                    print('')
                    print(ids[i])
                    print(cc)
            sys.exit()

        b['c'] = char_tensor
        b['s'] = seq_lengths
        b['p'] = p[1:]  ## b.p = ( spad, wpad [,cpad] ) ???
        b['x'] = b['c']

    if not (c or w) and not isListEmpty(lines):
        b['t'] = lines
        #b['x'] = b['t']

    b['is_test'] = is_test

    ##################
    # pickle b here...
    if self.pkl and self.epoch_ct == 1:
        if self.batch_ct == 0:
            # delete old pkl files
            cmd = 'rm {}'.format(self.batch_file(ct='*'))
            print(cmd)
            os.system(cmd)
        self.batch_ct += 1
        self.max_batch = self.batch_ct
        with open(self.batch_file(), 'wb') as handle:
            pickle.dump(b, handle, protocol=pickle.HIGHEST_PROTOCOL)
    ##################

    return U.adict(b)
def dict2ns(dict):
    return U.adict(dict)
def batch_stream(self,
                 stop=False,
                 skip_ids=None,
                 hit_ids=None,
                 w=None,
                 c=None,
                 sample=False,
                 partial=False,
                 FLAGS=None,
                 SZ=0,
                 IS=None,
                 skip_test=None,
                 ):
    if FLAGS is None:
        FLAGS = self.FLAGS

    self.epoch_ct += 1
    self.batch_ct = 0

    ## replay pickled batches after the first epoch
    if self.pkl and self.epoch_ct > 1:
        for i in range(self.max_batch):
            with open(self.batch_file(ct=i + 1), 'rb') as handle:
                b = pickle.load(handle)
            yield U.adict(b)
        ## END FUNCTION HERE !!
        return
    #######################################################

    spad = FLAGS.spad
    wpad = FLAGS.wpad
    cpad = FLAGS.cpad

    if w is None:
        w = FLAGS.embed.word
    if c is None:
        c = FLAGS.embed.char

    trim_words = FLAGS.trim_words
    split_sentences = FLAGS.split_sentences
    if not split_sentences:
        SZ = 0

    self._word_count = 0
    i, ids, labels, words, chars, ns, nw, lines = 0, [], [], [], [], 0, 0, []
    is_test = FLAGS.is_test

    for d in self.reader.line_stream(stop=stop, t=self.t if sample else None):  # reader=FieldParser!
        if skip_ids is not None:
            if d.id in skip_ids:
                continue
        if is_test and skip_test is not None:
            if d.id in skip_test:
                continue
        if hit_ids:
            if d.id not in hit_ids:
                continue

        # SKIP ?
        if w and len(d.w) == 0:
            print('SKIP : len(d.w)==0!!!!')
            continue
        # SKIP ?
        if len(d.t) == 0:
            print('SKIP : len(d.t)==0!!!!')
            continue

        dw = d.w
        dc = d.c
        if split_sentences:
            dw = U.lindexsplit(d.w, d.ws)
            dc = U.lindexsplit(d.c, d.cs)

        #############################
        if SZ > 0:
            _ns = max(ns, len(dw))
            _nw = max(nw, max(map(len, dw)))
            sz = (i + 1) * _ns * _nw
            if sz > 2 * SZ:
                #print('SKIP\tsz={0}\t[{1}x{2}x{3}] '.format(sz, i+1, _ns, _nw))
                continue
            ns, nw = _ns, _nw
        #############################

        ids.append(d.id)
        labels.append(d.y)
        lines.append(d.t)
        words.append(dw)
        chars.append(dc)

        if c or w:
            self._word_count += len(d.w)
        else:
            self._word_count += d.t.count(' ')

        i += 1
        #if i==self.batch_size:
        if (SZ > 0 and (sz > SZ or i == self.batch_size * 4)) or (SZ == 0 and i == self.batch_size):
            #print('sz={0}\t[{1}x{2}x{3}] '.format(i*ns*nw, i, ns, nw))
            yield self.batch(ids, labels, words, chars, w, c, trim_words,
                             spad=spad, wpad=wpad, cpad=cpad,
                             split_sentences=split_sentences,
                             lines=lines, is_test=is_test)
            i, ids, labels, words, chars, ns, nw, lines = 0, [], [], [], [], 0, 0, []
            is_test = FLAGS.is_test

    if i > 0 and partial:
        yield self.batch(ids, labels, words, chars, w, c, trim_words,
                         spad=spad, wpad=wpad, cpad=cpad,
                         split_sentences=split_sentences,
                         lines=lines, is_test=is_test)
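# Typical consumption of batch_stream, as a minimal sketch (not part of the original
# repo). `loader` stands for whatever class owns batch_stream/batch above, and `model`
# for a graph exposing input/target placeholders and a train op; both names are
# hypothetical.
#
#   for b in loader.batch_stream(stop=True, partial=False):
#       # b is a U.adict: b.id, b.y, and (depending on flags) b.w / b.c, b.s, b.p, b.x
#       feed = {model.input: b.x, model.targets: b.y}
#       sess.run(model.train_op, feed_dict=feed)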
def parse_config(config_file, parser):
    #parser = options.get_parser()
    argv = []  # override config file here
    FLAGS = get_config(parser=parser, config_file=config_file, argv=argv)

    FLAGS.chkpt_dir = U.make_abs(FLAGS.chkpt_dir)
    if FLAGS.load_model:
        if FLAGS.load_chkpt_dir:
            FLAGS.load_chkpt_dir = U.make_abs(FLAGS.load_chkpt_dir)
        else:
            FLAGS.load_chkpt_dir = FLAGS.chkpt_dir
    else:
        if FLAGS.model == 'HANModel':
            FLAGS.epoch_unfreeze_word = 0

    FLAGS.cwd = os.getcwd()
    FLAGS.log_file = os.path.abspath(os.path.join(FLAGS.cwd, 'log.txt'))
    FLAGS.rand_seed = U.seed_random(FLAGS.rand_seed)

    if FLAGS.id_dir is None:
        FLAGS.id_dir = FLAGS.data_dir
    else:
        FLAGS.id_dir = os.path.join(FLAGS.data_dir, FLAGS.id_dir).format(FLAGS.item_id)

    if FLAGS.attn_size > 0:
        FLAGS.mean_pool = False
    if FLAGS.attn_type < 0:
        FLAGS.attn_type = 0
    if FLAGS.embed_type == 'word':
        FLAGS.model_std = None
        FLAGS.attn_std = None

    #### test ids
    test_ids, test_id_file = None, None
    FLAGS.test_y, FLAGS.test_yint = None, None
    if FLAGS.test_pat is None:
        FLAGS.save_test = None
        FLAGS.load_test = None
    else:
        trait = ''
        if FLAGS.trait is not None:
            trait = '_{}'.format(FLAGS.trait)
        test_id_file = os.path.join(FLAGS.data_dir, FLAGS.test_pat).format(FLAGS.item_id, trait)

        #########################
        if FLAGS.load_test and U.check_file(test_id_file):
            #################################
            data = U.read_cols(test_id_file)
            test_ids = data[:, 0]
            if test_ids.dtype.name.startswith('float'):
                test_ids = test_ids.astype('int32')
            test_ids = test_ids.astype('unicode')
            if data.shape[1] > 1:
                FLAGS.test_yint = data[:, 1].astype('int32')
                FLAGS.test_y = data[:, 2].astype('float32')
        #########################

        if FLAGS.save_test and test_ids is not None:
            FLAGS.save_test = False

    # FLAGS.test_ids = set(test_ids) if test_ids is not None else []
    FLAGS.test_ids = test_ids if test_ids is not None else []
    FLAGS.test_id_file = test_id_file

    ''' don't overwrite MLT test ids!!! '''
    # guard: test_id_file may be None when test_pat is unset
    if FLAGS.test_id_file and 'test_ids' in FLAGS.test_id_file:
        FLAGS.save_test = False

    #### valid ids
    valid_ids, valid_id_file = None, None
    if FLAGS.valid_pat is None:
        FLAGS.save_valid = None
        FLAGS.load_valid = None
    else:
        trait = ''
        if FLAGS.trait is not None:
            trait = '_{}'.format(FLAGS.trait)
        valid_id_file = os.path.join(FLAGS.data_dir, FLAGS.valid_pat).format(FLAGS.item_id, trait)
        if FLAGS.load_valid:
            valid_ids = get_ids(valid_id_file)
        if FLAGS.save_valid and valid_ids is not None:
            FLAGS.save_valid = False

    #FLAGS.valid_ids = set(valid_ids) if valid_ids is not None else []
    FLAGS.valid_ids = valid_ids if valid_ids is not None else []
    FLAGS.valid_id_file = valid_id_file

    #### train ids
    train_ids, train_id_file = None, None
    if FLAGS.train_pat:
        trait = ''
        if FLAGS.trait is not None:
            trait = '_{}'.format(FLAGS.trait)
        train_id_file = os.path.join(FLAGS.data_dir, FLAGS.train_pat).format(FLAGS.item_id, trait)
        train_ids = get_ids(train_id_file, default=[])

    #FLAGS.train_ids = set(train_ids) if train_ids is not None else []
    FLAGS.train_ids = train_ids if train_ids is not None else []
    FLAGS.train_id_file = train_id_file

    ###################################
    FLAGS.embed = U.adict({'type': FLAGS.embed_type,
                           'char': FLAGS.embed_type == 'char',
                           'word': FLAGS.embed_type == 'word'})

    FLAGS.word_embed_dir = os.path.join(FLAGS.embed_dir, 'word')
    FLAGS.char_embed_dir = os.path.join(FLAGS.embed_dir, 'char')

    feats = ['kernel_widths', 'kernel_features', 'rnn_cells', 'rnn_sizes', 'rnn_bis',
             'attn_sizes', 'attn_depths', 'attn_temps', 'pads', 'learning_rates']
    for feat in feats:
        if feat in FLAGS and FLAGS[feat]:
            FLAGS[feat] = eval(eval(FLAGS[feat]))

    FLAGS.wpad = FLAGS.pads[0]
    FLAGS.spad = None if len(FLAGS.pads) < 2 else FLAGS.pads[1]

    if FLAGS.attn_depths[0] > 1 or (len(FLAGS.attn_depths) > 1 and FLAGS.attn_depths[1] > 1):
        FLAGS.attn_vis = False
    if FLAGS.attn_sizes[0] < 1:
        FLAGS.attn_vis = False
    if FLAGS.embed.char:
        FLAGS.attn_vis = False
    ###################################

    return FLAGS
def bilstm_doc_enc(input_cnn, question_cnn,
                   batch_size=20,
                   num_rnn_layers=2,
                   rnn_size=150,
                   max_doc_length=35,
                   dropout=0.0):
    # bilstm document encoder
    with tf.variable_scope('BILSTMenc'):

        def create_rnn_cell():
            cell = tf.contrib.rnn.BasicLSTMCell(rnn_size, state_is_tuple=True, forget_bias=0.0)
            if dropout > 0.0:
                cell = tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob=1. - dropout)
            return cell

        if num_rnn_layers > 1:
            cell_fw = tf.contrib.rnn.MultiRNNCell(
                [create_rnn_cell() for _ in range(num_rnn_layers)], state_is_tuple=True)
            cell_bw = tf.contrib.rnn.MultiRNNCell(
                [create_rnn_cell() for _ in range(num_rnn_layers)], state_is_tuple=True)
        else:
            cell_fw = create_rnn_cell()
            cell_bw = create_rnn_cell()

        initial_rnn_state_fw = cell_fw.zero_state(batch_size, dtype=tf.float32)
        initial_rnn_state_bw = cell_bw.zero_state(batch_size, dtype=tf.float32)

        input_cnn = tf.reshape(input_cnn, [batch_size, max_doc_length, -1])
        input_cnn2 = [tf.squeeze(x, [1]) for x in tf.split(input_cnn, max_doc_length, 1)]

        question_cnn = tf.reshape(question_cnn, [batch_size, max_doc_length, -1])
        question_cnn2 = [tf.squeeze(x, [1]) for x in tf.split(question_cnn, max_doc_length, 1)]

        # input_cnn2 / question_cnn2 are Python lists of per-sentence tensors, so inspect
        # an element's shape rather than calling get_shape() on the list itself
        print(input_cnn2[0].get_shape(), question_cnn2[0].get_shape())

        with tf.variable_scope('Sentence-level_Compare-Aggregate'):
            input_cnn2, question_cnn2 = sent_compare_aggregate(
                input_cnn2, question_cnn2, batch_size, max_doc_length)

        outputs, final_rnn_state_fw, final_rnn_state_bw = tf.contrib.rnn.static_bidirectional_rnn(
            cell_fw, cell_bw, input_cnn2,
            initial_state_fw=initial_rnn_state_fw,
            initial_state_bw=initial_rnn_state_bw,
            dtype=tf.float32)

    return adict(initial_enc_state_fw=initial_rnn_state_fw,
                 initial_enc_state_bw=initial_rnn_state_bw,
                 final_enc_state_fw=final_rnn_state_fw,
                 final_enc_state_bw=final_rnn_state_bw,
                 enc_outputs=outputs)