def create_seq_data_graph(in_data, out_data, prefix='decoder'):
    # stack variable-length sequences into padded arrays and per-sequence lengths
    x_arr, x_len = util.hstack_list(in_data, padding=0, dtype=np.int32)
    y_arr, y_len = util.hstack_list(out_data, padding=0, dtype=np.int32)
    seq_weight = np.where(y_len > 0, 1, 0).astype(np.float32)
    token_weight, num_tokens = util.masked_full_like(
        y_arr, 1, num_non_padding=y_len)
    # keep the whole dataset in the graph as constants; batches are selected by index
    all_x = tf.constant(x_arr.T, name='data_input')
    all_y = tf.constant(y_arr.T, name='data_label')
    all_len = tf.constant(x_len, name='data_len')
    all_seq_weight = tf.constant(seq_weight, name='data_seq_weight')
    all_token_weight = tf.constant(token_weight.T, name='data_token_weight')
    batch_idx_ = tf.placeholder(
        tf.int32, shape=[None], name=f'{prefix}_batch_idx')
    input_ = tf.transpose(tf.gather(all_x, batch_idx_, name=f'{prefix}_input'))
    label_ = tf.transpose(tf.gather(all_y, batch_idx_, name=f'{prefix}_label'))
    seq_len_ = tf.gather(all_len, batch_idx_, name=f'{prefix}_seq_len')
    seq_weight_ = tf.gather(
        all_seq_weight, batch_idx_, name=f'{prefix}_seq_weight')
    token_weight_ = tf.transpose(
        tf.gather(all_token_weight, batch_idx_, name=f'{prefix}_token_weight'))
    return {f'{prefix}_{k}': v
            for k, v in util.dict_with_key_endswith(locals(), '_').items()}
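# Usage sketch (illustrative, not part of the module): the returned dict is keyed
# '<prefix>_<name>' with the trailing underscore stripped, and a batch is chosen
# by feeding row indices into the '<prefix>_batch_idx' placeholder. `sess`,
# `train_in`, and `train_out` below are hypothetical names.
#
#   data_nodes = create_seq_data_graph(train_in, train_out, prefix='decoder')
#   feed = {data_nodes['decoder_batch_idx']: np.array([0, 2, 5], dtype=np.int32)}
#   inp, lab = sess.run(
#       [data_nodes['decoder_input'], data_nodes['decoder_label']], feed_dict=feed)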
def _build_wbdef(self, opt, reuse, enc_nodes, enc_scope, collect_key):
    prefix = 'wbdef'
    wbdef_opt = util.dict_with_key_startswith(opt, 'wbdef:')
    with tf.variable_scope('wbdef', reuse=reuse) as wbdef_scope:
        with tfg.tfph_collection(f'{collect_key}_wbdef', True) as get:
            wbdef_word_ = get(f'{prefix}_word', tf.int32, (None, ))
            wbdef_char_ = get(f'{prefix}_char', tf.int32, (None, None))
            wbdef_char_len_ = get(f'{prefix}_char_len', tf.int32, (None, ))
        word_emb_scope = enc_scope  # if opt['share:enc_word_emb'] else None
        word_emb_opt = util.dict_with_key_startswith(opt, 'enc:emb:')
        with tfg.maybe_scope(word_emb_scope, True):
            wbdef_word_lookup_, _e = tfg.create_lookup(
                wbdef_word_, prefix='wbdef_word', **word_emb_opt)
        char_emb_opt = util.dict_with_key_startswith(wbdef_opt, 'char_emb:')
        wbdef_char_lookup_, wbdef_char_emb_vars_ = tfg.create_lookup(
            wbdef_char_, **char_emb_opt)
        char_tdnn_opt = util.dict_with_key_startswith(wbdef_opt, 'char_tdnn:')
        wbdef_char_tdnn_ = tfg.create_tdnn(
            wbdef_char_lookup_, wbdef_char_len_, **char_tdnn_opt)
        wbdef_ = tf.concat((wbdef_word_lookup_, wbdef_char_tdnn_), axis=-1)
        if wbdef_opt['keep_prob'] < 1.0:
            wbdef_ = tf.nn.dropout(wbdef_, opt['wbdef:keep_prob'])
    nodes = util.dict_with_key_endswith(locals(), '_')
    # bind the wbdef-specific arguments; _build_attn_logit forwards the rest
    # to super()._build_logit()
    self._build_logit = partial(
        self._build_attn_logit, wbdef=wbdef_, wbdef_scope=wbdef_scope,
        wbdef_nodes=nodes, full_opt=opt, reuse=reuse)
    self._decode_late_attn = partial(
        self._build_decode_late_attn, wbdef=wbdef_, wbdef_scope=wbdef_scope)
    return enc_nodes['final_state'], nodes
def _build(self, opt, reuse_scope, initial_state=None, reuse=False,
           collect_key='seq_model', prefix='lm', **kwargs):
    collect_kwargs = {'add_to_collection': True, 'collect_key': collect_key,
                      'prefix': prefix}
    # input and embedding
    input_, seq_len_ = tfg.get_seq_input_placeholders(**collect_kwargs)
    emb_opt = util.dict_with_key_startswith(opt, 'emb:')
    with tfg.maybe_scope(reuse_scope[self._RSK_EMB_], reuse=True) as scope:
        lookup_, emb_vars_ = tfg.create_lookup(input_, **emb_opt)
    # cell and rnn
    cell_opt = util.dict_with_key_startswith(opt, 'cell:')
    with tfg.maybe_scope(reuse_scope[self._RSK_RNN_], reuse=True) as scope:
        _reuse = reuse or scope is not None
        cell_ = tfg.create_cells(
            reuse=_reuse, input_size=opt['emb:dim'], **cell_opt)
        cell_output_, initial_state_, final_state_ = tfg.create_rnn(
            cell_, lookup_, seq_len_, initial_state, rnn_fn=opt['rnn:fn'])
    predict_fetch = {'cell_output': cell_output_}
    nodes = util.dict_with_key_endswith(locals(), '_')
    graph_args = {'feature_feed': dstruct.SeqFeatureTuple(input_, seq_len_),
                  'predict_fetch': predict_fetch, 'node_dict': nodes,
                  'state_feed': initial_state_, 'state_fetch': final_state_}
    # output
    if opt['out:logit']:
        logit, label_feed, output_fetch, output_nodes = self._build_logit(
            opt, reuse_scope, collect_kwargs, emb_vars_, cell_output_)
        predict_fetch.update(output_fetch)
        nodes.update(output_nodes)
        graph_args.update(label_feed=label_feed)
    # loss
    if opt['out:loss'] and opt['out:logit']:
        train_fetch, eval_fetch, loss_nodes = self._build_loss(
            opt, logit, *label_feed, collect_key,
            collect_kwargs['add_to_collection'])
        nodes.update(loss_nodes)
        graph_args.update(train_fetch=train_fetch, eval_fetch=eval_fetch)
    elif not opt['out:logit'] and opt['out:loss']:
        raise ValueError('out:logit is False, cannot build loss graph')
    # decode
    if opt['out:decode'] and opt['out:logit']:
        if not (opt['decode:add_greedy'] or opt['decode:add_sampling']):
            raise ValueError('Both decode:add_greedy and decode:add_sampling '
                             'are False. out:decode should not be True.')
        decode_result, decode_nodes = self._build_decoder(
            opt, nodes, reuse_scope[self._RSK_RNN_], collect_key,
            collect_kwargs['add_to_collection'])
        predict_fetch.update(decode_result)
        nodes.update(decode_nodes)
    elif not opt['out:logit'] and opt['out:decode']:
        raise ValueError('out:logit is False, cannot build decode graph')
    return nodes, graph_args
def _build_loss(self, opt, logit, label, weight, seq_weight, collect_key,
                add_to_collection):
    if opt['loss:type'] == 'xent':
        with tfg.tfph_collection(collect_key, add_to_collection) as get:
            name = 'train_loss_denom'
            train_loss_denom_ = get(name, tf.float32, shape=[])
        mean_loss_, train_loss_, batch_loss_, nll_ = tfg.create_xent_loss(
            logit, label, weight, seq_weight, train_loss_denom_)
        if opt['loss:add_entropy']:
            # entropy regularization: add the negative average entropy to the loss
            _sum_minus_ent, minus_avg_ent_ = tfg.create_ent_loss(
                tf.nn.softmax(logit), tf.abs(weight), tf.abs(seq_weight))
            train_loss_ = train_loss_ + minus_avg_ent_
        train_fetch = {'train_loss': train_loss_, 'eval_loss': mean_loss_}
        eval_fetch = {'eval_loss': mean_loss_}
    else:
        raise ValueError(f'loss:type `{opt["loss:type"]}` is not supported, '
                         'only `xent` is implemented')
    nodes = util.dict_with_key_endswith(locals(), '_')
    return train_fetch, eval_fetch, nodes
def _build_logit(self, opt, reuse_scope, collect_kwargs, emb_vars, cell_output):
    # logit
    logit_w_ = emb_vars if opt['share:input_emb_logit'] else None
    logit_opt = util.dict_with_key_startswith(opt, 'logit:')
    with tfg.maybe_scope(reuse_scope[self._RSK_LOGIT_]) as scope:
        logit_, temperature_, logit_w_, logit_b_ = tfg.get_logit_layer(
            cell_output, logit_w=logit_w_, **logit_opt, **collect_kwargs)
    dist_, dec_max_, dec_sample_ = tfg.select_from_logit(logit_)
    # label
    label_, token_weight_, seq_weight_ = tfg.get_seq_label_placeholders(
        label_dtype=tf.int32, **collect_kwargs)
    # format
    predict_fetch = {
        'logit': logit_, 'dist': dist_, 'dec_max': dec_max_,
        'dec_max_id': dec_max_.index, 'dec_sample': dec_sample_,
        'dec_sample_id': dec_sample_.index}
    label_feed = dstruct.SeqLabelTuple(label_, token_weight_, seq_weight_)
    nodes = util.dict_with_key_endswith(locals(), '_')
    return logit_, label_feed, predict_fetch, nodes
def _build_attn_logit(self, opt, reuse_scope, collect_kwargs, emb_vars,
                      cell_output, wbdef, wbdef_scope, wbdef_nodes, full_opt,
                      reuse):
    wbdef_nodes = {} if wbdef_nodes is None else wbdef_nodes
    with tfg.maybe_scope(wbdef_scope, reuse):
        # broadcast the definition vector over every time step of the decoder output
        _multiples = [tf.shape(cell_output)[0], 1, 1]
        tiled_wbdef_ = tf.tile(tf.expand_dims(wbdef, 0), _multiples)
        carried_output = cell_output
        if full_opt['dec:cell:out_keep_prob'] < 1.0:  # no variational dropout :(
            cell_output = tf.nn.dropout(
                cell_output, full_opt['dec:cell:out_keep_prob'])
        updated_output_, attention_ = tfg.create_gru_layer(
            cell_output, tiled_wbdef_, carried_output)
        if full_opt['dec:cell:out_keep_prob'] < 1.0:  # no variational dropout :(
            updated_output_ = tf.nn.dropout(
                updated_output_, full_opt['dec:cell:out_keep_prob'])
    wbdef_nodes.update(util.dict_with_key_endswith(locals(), '_'))
    # zero-arg super() puts an implicit `__class__` cell in locals();
    # dict_with_key_endswith strips the trailing '_', hence the key '__class_'
    wbdef_nodes.pop('__class_', None)
    logit_, label_feed, predict_fetch, nodes = super()._build_logit(
        opt, reuse_scope, collect_kwargs, emb_vars, updated_output_)
    return logit_, label_feed, predict_fetch, nodes
def _build_decoder(self, opt, nodes, cell_scope, collect_key, add_to_collection,
                   start_id=1, end_id=0):
    output = {}
    with tfg.tfph_collection(collect_key, add_to_collection) as get:
        decode_max_len_ = get('decode_max_len', tf.int32, None)
    if hasattr(self, '_batch_size'):
        batch_size = self._batch_size
    else:
        batch_size = tf.shape(nodes['input'])[1]
    late_attn_fn = None
    if hasattr(self, '_decode_late_attn'):
        late_attn_fn = self._decode_late_attn
    # every sequence starts from start_id and is initially not finished
    decode_fn = partial(
        tfg.create_decode, nodes['emb_vars'], nodes['cell'], nodes['logit_w'],
        nodes['initial_state'], tf.tile((start_id, ), (batch_size, )),
        tf.tile([False], (batch_size, )), logit_b=nodes['logit_b'],
        logit_temperature=nodes['temperature'], max_len=decode_max_len_,
        cell_scope=cell_scope, late_attn_fn=late_attn_fn)
    if opt['decode:add_greedy']:
        decode_greedy_, decode_greedy_score_, decode_greedy_len_ = decode_fn()
        output['decode_greedy'] = decode_greedy_
        output['decode_greedy_score'] = (decode_greedy_, decode_greedy_score_)
        output['decode_greedy_len'] = decode_greedy_len_
    if opt['decode:add_sampling']:
        def select_fn(logit):
            # sample one token id per row, then gather its log-probability
            idx = tf.cast(tf.multinomial(logit, 1), tf.int32)
            gather_idx = tf.expand_dims(
                tf.range(start=0, limit=tf.shape(idx)[0]), axis=-1)
            gather_idx = tf.concat([gather_idx, idx], axis=-1)
            score = tf.gather_nd(tf.nn.log_softmax(logit), gather_idx)
            idx = tf.squeeze(idx, axis=(1, ))
            return idx, score
        decode_sampling_, decode_sampling_score_, decode_sampling_len_ = decode_fn(
            select_fn=select_fn)
        output['decode_sampling'] = decode_sampling_
        output['decode_sampling_score'] = (decode_sampling_, decode_sampling_score_)
        output['decode_sampling_len'] = decode_sampling_len_
    nodes = util.dict_with_key_endswith(locals(), '_')
    return output, nodes
def test_dict_with_key_endswith(self):
    people = {'biggus_': 'dickus', 'incontinentia_': 'buttocks',
              'jew:jesus': 'christ', 'jew:brian': 'cohen'}
    roman = {'biggus': 'dickus', 'incontinentia': 'buttocks'}
    roman_ = util.dict_with_key_endswith(people, '_')
    self.assertEqual(roman, roman_, 'filter and remove suffix')
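# Companion sketch (not in the original suite) for the prefix-based helper used
# above, e.g. util.dict_with_key_startswith(opt, 'wbdef:'); it assumes the helper
# filters by prefix and strips it, mirroring the suffix behaviour asserted in
# test_dict_with_key_endswith.
def test_dict_with_key_startswith(self):
    people = {'biggus_': 'dickus', 'incontinentia_': 'buttocks',
              'jew:jesus': 'christ', 'jew:brian': 'cohen'}
    jews = {'jesus': 'christ', 'brian': 'cohen'}
    jews_ = util.dict_with_key_startswith(people, 'jew:')
    self.assertEqual(jews, jews_, 'filter and remove prefix')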