Ejemplo n.º 1
0
def create_seq_data_graph(in_data, out_data, prefix='decoder'):
    """Store a whole sequence dataset as graph constants and gather batches.

    The input/label sequences are stacked into arrays, stored (transposed) as
    ``tf.constant`` nodes, and sliced with a ``batch_idx`` placeholder so a
    batch is selected by feeding example indices rather than the data itself.

    Args:
        in_data: input sequences, passed to ``util.hstack_list``.
        out_data: label sequences, passed to ``util.hstack_list``.
        prefix: string prepended to op names and to the keys of the result.

    Returns:
        dict mapping ``f'{prefix}_{name}'`` to every local node of this
        function whose variable name ends with ``_`` (batch_idx, input, label,
        seq_len, seq_weight, token_weight).
    """
    # NOTE(review): hstack_list presumably pads to a common length and returns
    # (array, lengths) -- confirm against util.
    inputs, input_len = util.hstack_list(in_data, padding=0, dtype=np.int32)
    labels, label_len = util.hstack_list(out_data, padding=0, dtype=np.int32)

    # A sequence gets weight 1.0 only when its label length is positive.
    seq_mask = np.where(label_len > 0, 1, 0).astype(np.float32)
    token_mask, num_tokens = util.masked_full_like(
        labels, 1, num_non_padding=label_len)

    # Constants are stored transposed so tf.gather (axis 0) picks examples;
    # the gathered slices are transposed back below.
    all_x = tf.constant(inputs.T, name='data_input')
    all_y = tf.constant(labels.T, name='data_label')
    all_len = tf.constant(input_len, name='data_len')
    all_seq_weight = tf.constant(seq_mask, name='data_seq_weight')
    all_token_weight = tf.constant(token_mask.T, name='data_token_weight')

    # Names ending in '_' below are harvested via locals(); do not rename them.
    batch_idx_ = tf.placeholder(
        tf.int32, shape=[None], name=f'{prefix}_batch_idx')
    input_ = tf.transpose(
        tf.gather(all_x, batch_idx_, name=f'{prefix}_input'))
    label_ = tf.transpose(
        tf.gather(all_y, batch_idx_, name=f'{prefix}_label'))
    seq_len_ = tf.gather(all_len, batch_idx_, name=f'{prefix}_seq_len')
    seq_weight_ = tf.gather(
        all_seq_weight, batch_idx_, name=f'{prefix}_seq_weight')
    token_weight_ = tf.transpose(
        tf.gather(all_token_weight, batch_idx_,
                  name=f'{prefix}_token_weight'))

    graph_nodes = {}
    for key, node in util.dict_with_key_endswith(locals(), '_').items():
        graph_nodes[f'{prefix}_{key}'] = node
    return graph_nodes
Ejemplo n.º 2
0
 def _build_wbdef(self, opt, reuse, enc_nodes, enc_scope, collect_key):
     """Build the 'wbdef' feature (word embedding + char TDNN) and hook it in.

     Creates placeholders for a word id, its characters and the character
     length, embeds them, concatenates the word lookup with the char-TDNN
     output along the last axis, and optionally applies dropout. The result
     is then bound into ``self._build_logit`` and ``self._decode_late_attn``
     through ``functools.partial``.

     Args:
         opt: full option dict; keys under 'wbdef:' and 'enc:emb:' are read.
         reuse: passed to ``tf.variable_scope('wbdef', ...)`` and forwarded
             to ``_build_attn_logit``.
         enc_nodes: encoder node dict; only ``enc_nodes['final_state']`` is
             read here.
         enc_scope: scope under which the word embedding lookup is created.
         collect_key: placeholders are collected under
             ``f'{collect_key}_wbdef'``.

     Returns:
         ``(enc_nodes['final_state'], nodes)`` where ``nodes`` holds every
         local of this method whose name ends with ``_``.
     """
     prefix = 'wbdef'
     wbdef_opt = util.dict_with_key_startswith(opt, 'wbdef:')
     with tf.variable_scope('wbdef', reuse=reuse) as wbdef_scope:
         # Placeholders: word ids (None,), char ids (None, None), char
         # lengths (None,), as declared by the shape arguments below.
         with tfg.tfph_collection(f'{collect_key}_wbdef', True) as get:
             wbdef_word_ = get(f'{prefix}_word', tf.int32, (None,))
             wbdef_char_ = get(f'{prefix}_char', tf.int32, (None, None))
             wbdef_char_len_ = get(f'{prefix}_char_len', tf.int32, (None,))
         word_emb_scope = enc_scope   # always the encoder scope; not switched by opt['share:enc_word_emb']
         word_emb_opt = util.dict_with_key_startswith(opt, 'enc:emb:')
         # Word lookup is built with the encoder's embedding options inside
         # the encoder scope (second argument True).
         # NOTE(review): presumably this reuses the encoder embedding
         # variables -- confirm against tfg.maybe_scope.
         with tfg.maybe_scope(word_emb_scope, True):
             wbdef_word_lookup_, _e = tfg.create_lookup(
                 wbdef_word_, prefix='wbdef_word', **word_emb_opt)
         # Character path: lookup followed by a TDNN over the characters.
         char_emb_opt = util.dict_with_key_startswith(wbdef_opt, 'char_emb:')
         wbdef_char_lookup_, wbdef_char_emb_vars_ = tfg.create_lookup(
             wbdef_char_, **char_emb_opt)
         char_tdnn_opt = util.dict_with_key_startswith(wbdef_opt, 'char_tdnn:')
         wbdef_char_tdnn_ = tfg.create_tdnn(wbdef_char_lookup_, wbdef_char_len_,
                                            **char_tdnn_opt)
         wbdef_ = tf.concat((wbdef_word_lookup_, wbdef_char_tdnn_), axis=-1)
         # NOTE(review): the condition reads wbdef_opt['keep_prob'] while the
         # dropout reads opt['wbdef:keep_prob']; presumably the same value
         # if dict_with_key_startswith strips the prefix -- confirm.
         if wbdef_opt['keep_prob'] < 1.0:
             wbdef_ = tf.nn.dropout(wbdef_, opt['wbdef:keep_prob'])
     # Harvest every local whose name ends with '_' (placeholders, lookups,
     # tdnn output, final feature). Renaming such a local changes this dict.
     nodes = util.dict_with_key_endswith(locals(), '_')
     # Replace the instance's _build_logit with the attention variant, with
     # the wbdef tensors pre-bound so callers keep the same call signature.
     self._build_logit = partial(self._build_attn_logit, wbdef=wbdef_,
                                 wbdef_scope=wbdef_scope, wbdef_nodes=nodes,
                                 full_opt=opt, reuse=reuse)
     # _build_decoder looks this attribute up with hasattr() and passes it
     # on as late_attn_fn.
     self._decode_late_attn = partial(self._build_decode_late_attn, wbdef=wbdef_,
                                      wbdef_scope=wbdef_scope)
     return enc_nodes['final_state'], nodes
Ejemplo n.º 3
0
    def _build(self, opt, reuse_scope, initial_state=None, reuse=False,
               collect_key='seq_model', prefix='lm', **kwargs):
        """Build the sequence-model graph: embedding, RNN, then optional heads.

        The logit, loss and decode sub-graphs are added according to
        ``opt['out:logit']``, ``opt['out:loss']`` and ``opt['out:decode']``.

        Args:
            opt: option dict; keys under 'emb:', 'cell:', 'rnn:fn', 'out:*'
                and 'decode:*' are read here.
            reuse_scope: mapping indexed by ``self._RSK_EMB_`` and
                ``self._RSK_RNN_`` giving the scope to reuse for each part.
            initial_state: optional initial RNN state passed to
                ``tfg.create_rnn``.
            reuse: whether RNN cells are created in reuse mode; also forced
                on when a reuse scope is supplied for the RNN.
            collect_key: collection key used for placeholders.
            prefix: name prefix used for placeholders.
            **kwargs: accepted but not used by this method.

        Returns:
            ``(nodes, graph_args)``: ``nodes`` maps names to graph nodes;
            ``graph_args`` holds the feed/fetch structures (feature_feed,
            predict_fetch, node_dict, state_feed, state_fetch and, when
            built, label_feed, train_fetch, eval_fetch).

        Raises:
            ValueError: if ``out:loss`` or ``out:decode`` is requested
                without ``out:logit``, or if ``out:decode`` is requested
                while both ``decode:add_greedy`` and ``decode:add_sampling``
                are false.
        """
        collect_kwargs = {'add_to_collection': True, 'collect_key': collect_key,
                          'prefix': prefix}
        # input and embedding
        input_, seq_len_ = tfg.get_seq_input_placeholders(**collect_kwargs)
        emb_opt = util.dict_with_key_startswith(opt, 'emb:')
        with tfg.maybe_scope(reuse_scope[self._RSK_EMB_], reuse=True) as scope:
            lookup_, emb_vars_ = tfg.create_lookup(input_, **emb_opt)
        # cell and rnn
        cell_opt = util.dict_with_key_startswith(opt, 'cell:')
        with tfg.maybe_scope(reuse_scope[self._RSK_RNN_], reuse=True) as scope:
            # Cells must be built in reuse mode whenever a scope was supplied.
            _reuse = reuse or scope is not None
            cell_ = tfg.create_cells(reuse=_reuse, input_size=opt['emb:dim'], **cell_opt)
            cell_output_, initial_state_, final_state_ = tfg.create_rnn(
                cell_, lookup_, seq_len_, initial_state, rnn_fn=opt['rnn:fn'])
        predict_fetch = {'cell_output': cell_output_}
        # Harvest locals ending in '_' defined so far; names bound after this
        # line are added explicitly through nodes.update(...).
        nodes = util.dict_with_key_endswith(locals(), '_')
        graph_args = {'feature_feed': dstruct.SeqFeatureTuple(input_, seq_len_),
                      'predict_fetch': predict_fetch, 'node_dict': nodes,
                      'state_feed': initial_state_, 'state_fetch': final_state_}
        # output
        if opt['out:logit']:
            logit, label_feed, output_fetch, output_nodes = self._build_logit(
                opt, reuse_scope, collect_kwargs, emb_vars_, cell_output_)
            predict_fetch.update(output_fetch)
            nodes.update(output_nodes)
            graph_args.update(label_feed=label_feed)
        # loss
        if opt['out:loss'] and opt['out:logit']:
            train_fetch, eval_fetch, loss_nodes = self._build_loss(
                opt, logit, *label_feed, collect_key,
                collect_kwargs['add_to_collection'])
            nodes.update(loss_nodes)
            graph_args.update(train_fetch=train_fetch, eval_fetch=eval_fetch)
        elif not opt['out:logit'] and opt['out:loss']:
            raise ValueError('out:logit is False, cannot build loss graph')
        # decode
        if opt['out:decode'] and opt['out:logit']:
            if not (opt['decode:add_greedy'] or opt['decode:add_sampling']):
                # Must be `raise`: `assert <exception instance>` is always
                # truthy and therefore never reports the misconfiguration.
                raise ValueError('Both decode:add_greedy and decode:add_sampling are '
                                 'False. out:decode should not be True.')
            decode_result, decode_nodes = self._build_decoder(
                opt, nodes, reuse_scope[self._RSK_RNN_], collect_key,
                collect_kwargs['add_to_collection'])
            predict_fetch.update(decode_result)
            nodes.update(decode_nodes)
        elif not opt['out:logit'] and opt['out:decode']:
            raise ValueError('out:logit is False, cannot build decode graph')

        return nodes, graph_args
Ejemplo n.º 4
0
    def _build_loss(self, opt, logit, label, weight, seq_weight,
                    collect_key, add_to_collection):
        if opt['loss:type'] == 'xent':
            with tfg.tfph_collection(collect_key, add_to_collection) as get:
                name = 'train_loss_denom'
                train_loss_denom_ = get(name, tf.float32, shape=[])
            mean_loss_, train_loss_, batch_loss_, nll_ = tfg.create_xent_loss(
                logit, label, weight, seq_weight, train_loss_denom_)
            if opt['loss:add_entropy']:
                _sum_minus_ent, minus_avg_ent_ = tfg.create_ent_loss(
                    tf.nn.softmax(logit), tf.abs(weight), tf.abs(seq_weight))
                train_loss_ = train_loss_ + minus_avg_ent_
            train_fetch = {'train_loss': train_loss_, 'eval_loss': mean_loss_}
            eval_fetch = {'eval_loss': mean_loss_}

        else:
            raise ValueError(f'{opt["loss:type"]} is not supported, use (xent or mse)')
        nodes = util.dict_with_key_endswith(locals(), '_')
        return train_fetch, eval_fetch, nodes
Ejemplo n.º 5
0
 def _build_logit(self, opt, reuse_scope, collect_kwargs, emb_vars, cell_output):
     """Build the logit layer, selection ops and label placeholders.

     Args:
         opt: option dict; reads 'share:input_emb_logit' and keys under
             'logit:'.
         reuse_scope: mapping indexed by ``self._RSK_LOGIT_``.
         collect_kwargs: keyword arguments forwarded to the placeholder and
             logit helpers.
         emb_vars: embedding variables, used as the logit weights when
             ``opt['share:input_emb_logit']`` is truthy.
         cell_output: RNN output tensor fed into the logit layer.

     Returns:
         ``(logit_, label_feed, predict_fetch, nodes)``; ``nodes`` holds
         every local of this method whose name ends with ``_``.
     """
     # Logit weights: share the input embedding when requested, otherwise
     # leave it to get_logit_layer (None).
     if opt['share:input_emb_logit']:
         logit_w_ = emb_vars
     else:
         logit_w_ = None
     logit_opt = util.dict_with_key_startswith(opt, 'logit:')
     with tfg.maybe_scope(reuse_scope[self._RSK_LOGIT_]):
         logit_, temperature_, logit_w_, logit_b_ = tfg.get_logit_layer(
             cell_output, logit_w=logit_w_, **logit_opt, **collect_kwargs)
     dist_, dec_max_, dec_sample_ = tfg.select_from_logit(logit_)

     # Label placeholders (labels plus token- and sequence-level weights).
     label_, token_weight_, seq_weight_ = tfg.get_seq_label_placeholders(
         label_dtype=tf.int32, **collect_kwargs)

     # Pack the outputs in the structures the caller expects.
     predict_fetch = dict(
         logit=logit_,
         dist=dist_,
         dec_max=dec_max_,
         dec_max_id=dec_max_.index,
         dec_sample=dec_sample_,
         dec_sample_id=dec_sample_.index,
     )
     label_feed = dstruct.SeqLabelTuple(label_, token_weight_, seq_weight_)
     # Locals ending in '_' are exported as nodes; keep their names stable.
     nodes = util.dict_with_key_endswith(locals(), '_')
     return logit_, label_feed, predict_fetch, nodes
Ejemplo n.º 6
0
 def _build_attn_logit(self, opt, reuse_scope, collect_kwargs, emb_vars, cell_output,
                       wbdef, wbdef_scope, wbdef_nodes, full_opt, reuse):
     """Update the cell output with the wbdef feature, then build the logit.

     ``wbdef`` is tiled along a new leading axis to match
     ``tf.shape(cell_output)[0]``, combined with the cell output through
     ``tfg.create_gru_layer``, and the updated output is handed to the
     parent class's ``_build_logit``.

     Args:
         opt, reuse_scope, collect_kwargs, emb_vars: forwarded unchanged to
             ``super()._build_logit``.
         cell_output: RNN output tensor.
         wbdef: feature tensor to attend to.
         wbdef_scope: scope in which the update layer is created.
         wbdef_nodes: dict updated in place with this method's nodes; a
             fresh dict is used when None.
         full_opt: option dict; reads 'dec:cell:out_keep_prob'.
         reuse: reuse flag passed to ``tfg.maybe_scope``.

     Returns:
         The ``(logit_, label_feed, predict_fetch, nodes)`` tuple produced
         by ``super()._build_logit``.
     """
     if wbdef_nodes is None:
         wbdef_nodes = {}
     with tfg.maybe_scope(wbdef_scope, reuse):
         lead_dim = tf.shape(cell_output)[0]
         tiled_wbdef_ = tf.tile(tf.expand_dims(wbdef, 0), [lead_dim, 1, 1])
         # The un-dropped output is what the update layer carries over.
         carried_output = cell_output
         keep_prob = full_opt['dec:cell:out_keep_prob']
         use_dropout = keep_prob < 1.0  # no variational dropout :(
         if use_dropout:
             cell_output = tf.nn.dropout(cell_output, keep_prob)
         updated_output_, attention_ = tfg.create_gru_layer(
             cell_output, tiled_wbdef_, carried_output)
         if use_dropout:
             updated_output_ = tf.nn.dropout(updated_output_, keep_prob)
     # Export locals ending in '_'. Because this method uses super(), the
     # implicit __class__ cell is also in locals() and comes out as
     # '__class_', hence the pop.
     harvested = util.dict_with_key_endswith(locals(), '_')
     wbdef_nodes.update(harvested)
     wbdef_nodes.pop('__class_', None)
     logit_, label_feed, predict_fetch, nodes = super()._build_logit(
         opt, reuse_scope, collect_kwargs, emb_vars, updated_output_)
     return logit_, label_feed, predict_fetch, nodes
Ejemplo n.º 7
0
 def _build_decoder(self, opt, nodes, cell_scope, collect_key,
                    add_to_collection, start_id=1, end_id=0):
     """Add greedy and/or sampling decode sub-graphs.

     Args:
         opt: option dict; reads 'decode:add_greedy' and
             'decode:add_sampling'.
         nodes: node dict from the main graph; reads 'emb_vars', 'cell',
             'logit_w', 'logit_b', 'initial_state', 'temperature' and, when
             ``self._batch_size`` is absent, 'input'.
         cell_scope: scope passed to ``tfg.create_decode``.
         collect_key: collection key for the ``decode_max_len`` placeholder.
         add_to_collection: whether that placeholder is collected.
         start_id: token id used to fill the initial input of every
             sequence in the batch (default 1).
         end_id: accepted for interface compatibility; not used by this
             method.

     Returns:
         ``(output, nodes)``: ``output`` maps 'decode_greedy*' and/or
         'decode_sampling*' keys to tensors; ``nodes`` holds every local of
         this method whose name ends with ``_``.
     """
     output = {}
     with tfg.tfph_collection(collect_key, add_to_collection) as get:
         decode_max_len_ = get('decode_max_len', tf.int32, None)
     # Prefer a batch size set on the instance; otherwise take it from
     # axis 1 of the input placeholder.
     if hasattr(self, '_batch_size'):
         batch_size = self._batch_size
     else:
         batch_size = tf.shape(nodes['input'])[1]
     # Optional hook installed by subclasses (None when not present).
     late_attn_fn = getattr(self, '_decode_late_attn', None)
     # Honour start_id instead of a hard-coded 1 (same value by default).
     initial_input = tf.tile((start_id, ), (batch_size, ))
     initial_finished = tf.tile([False], (batch_size, ))
     decode_fn = partial(
         tfg.create_decode, nodes['emb_vars'], nodes['cell'], nodes['logit_w'],
         nodes['initial_state'], initial_input, initial_finished,
         logit_b=nodes['logit_b'], logit_temperature=nodes['temperature'],
         max_len=decode_max_len_, cell_scope=cell_scope,
         late_attn_fn=late_attn_fn)
     if opt['decode:add_greedy']:
         decode_greedy_, decode_greedy_score_, decode_greedy_len_ = decode_fn()
         output['decode_greedy'] = decode_greedy_
         # The score entry is fetched together with the decoded ids.
         output['decode_greedy_score'] = (decode_greedy_, decode_greedy_score_)
         output['decode_greedy_len'] = decode_greedy_len_
     if opt['decode:add_sampling']:
         def select_fn(logit):
             """Sample one id per row and return it with its log-probability."""
             idx = tf.cast(tf.multinomial(logit, 1), tf.int32)
             # Pair each row number with its sampled column for gather_nd.
             gather_idx = tf.expand_dims(
                 tf.range(start=0, limit=tf.shape(idx)[0]), axis=-1)
             gather_idx = tf.concat([gather_idx, idx], axis=-1)
             score = tf.gather_nd(tf.nn.log_softmax(logit), gather_idx)
             idx = tf.squeeze(idx, axis=(1, ))
             return idx, score
         decode_sampling_, decode_sampling_score_, decode_sampling_len_ = decode_fn(
             select_fn=select_fn)
         output['decode_sampling'] = decode_sampling_
         output['decode_sampling_score'] = (decode_sampling_, decode_sampling_score_)
         output['decode_sampling_len'] = decode_sampling_len_
     # Locals ending in '_' are exported as nodes; this rebinds the `nodes`
     # argument to the decoder-only dict that is returned.
     nodes = util.dict_with_key_endswith(locals(), '_')
     return output, nodes
Ejemplo n.º 8
0
 def test_dict_with_key_endswith(self):
     """Only keys ending with '_' survive, returned without that suffix."""
     source = {
         'biggus_': 'dickus',
         'incontinentia_': 'buttocks',
         'jew:jesus': 'christ',
         'jew:brian': 'cohen',
     }
     expected = {'biggus': 'dickus', 'incontinentia': 'buttocks'}
     actual = util.dict_with_key_endswith(source, '_')
     self.assertEqual(expected, actual, 'filter and remove suffix')