Esempio n. 1
0
def create_seq_data_graph(in_data, out_data, prefix='decoder'):
    """Store a whole sequence dataset as graph constants and gather batches.

    Both datasets are stacked with ``util.hstack_list`` (padding value 0,
    ``int32``), stored transposed in ``tf.constant`` nodes, and then sliced by
    a ``batch_idx`` placeholder with ``tf.gather``.

    Args:
        in_data: input sequences, passed to ``util.hstack_list``.
        out_data: label sequences, passed to ``util.hstack_list``.
        prefix: string used to name the placeholder / gather ops and to
            prefix the keys of the returned dict.

    Returns:
        A dict ``{f'{prefix}_{k}': v}`` over the entries that
        ``util.dict_with_key_endswith(locals(), '_')`` selects, i.e. the
        locals whose names end with an underscore (``batch_idx_``,
        ``input_``, ``label_``, ``seq_len_``, ``seq_weight_``,
        ``token_weight_``). The exact form of ``k`` depends on that helper
        -- NOTE(review): confirm whether it keeps or strips the trailing
        underscore.

    NOTE: because the result is derived from ``locals()``, renaming, adding
    or removing any local whose name ends in ``_`` changes the returned dict.
    """
    # Stack the raw sequences into padded arrays plus their lengths.
    # Assumes hstack_list returns a (max_len, num_sequences) array and a
    # per-sequence length vector -- TODO confirm against util.hstack_list.
    x_arr, x_len = util.hstack_list(in_data, padding=0, dtype=np.int32)
    y_arr, y_len = util.hstack_list(out_data, padding=0, dtype=np.int32)
    # 1.0 for sequences with a non-zero label length, 0.0 otherwise.
    seq_weight = np.where(y_len > 0, 1, 0).astype(np.float32)
    # num_tokens is returned by the helper but not used in this function.
    token_weight, num_tokens = util.masked_full_like(y_arr,
                                                     1,
                                                     num_non_padding=y_len)
    # The arrays are transposed before being stored so that tf.gather (which
    # indexes the first axis) selects along what was the second axis.
    all_x = tf.constant(x_arr.T, name='data_input')
    all_y = tf.constant(y_arr.T, name='data_label')
    # The stored lengths are those of the inputs (x_len), not of the labels.
    all_len = tf.constant(x_len, name='data_len')
    all_seq_weight = tf.constant(seq_weight, name='data_seq_weight')
    all_token_weight = tf.constant(token_weight.T, name='data_token_weight')
    # 1-D int32 placeholder holding the indices to gather for one batch.
    batch_idx_ = tf.placeholder(tf.int32,
                                shape=[None],
                                name=f'{prefix}_batch_idx')
    # Gather the selected entries, then transpose back to the layout produced
    # by hstack_list. The `name=` applies to the gather op, not the transpose.
    input_ = tf.transpose(tf.gather(all_x, batch_idx_, name=f'{prefix}_input'))
    label_ = tf.transpose(tf.gather(all_y, batch_idx_, name=f'{prefix}_label'))
    seq_len_ = tf.gather(all_len, batch_idx_, name=f'{prefix}_seq_len')
    seq_weight_ = tf.gather(all_seq_weight,
                            batch_idx_,
                            name=f'{prefix}_seq_weight')
    token_weight_ = tf.transpose(
        tf.gather(all_token_weight, batch_idx_, name=f'{prefix}_token_weight'))
    # locals() is evaluated in this function's scope (it is the outermost
    # iterable of the comprehension), so it sees the variables defined above.
    return {
        f'{prefix}_{k}': v
        for k, v in util.dict_with_key_endswith(locals(), '_').items()
    }
Esempio n. 2
0
def lseq2seq_batch_iter(enc_data,
                        dec_data,
                        label_data,
                        mask_data,
                        batch_size=1,
                        shuffle=True):
    """Yield seq2seq batches that also carry a per-example label and mask.

    Batches are drawn from ``batch_iter`` over the four datasets; short
    batches are padded with ``[]`` for the two sequence inputs, ``0`` for the
    label and ``2`` for the mask.

    Args:
        enc_data: encoder sequences.
        dec_data: decoder sequences.
        label_data: one label per example (converted to ``int32``).
        mask_data: one mask value per example (converted to ``int32``).
        batch_size: forwarded to ``batch_iter``.
        shuffle: forwarded to ``batch_iter``.

    Yields:
        ``ds.BatchTuple(features, labels, num_tokens, False)`` where
        ``features`` is a ``ds.LSeq2SeqFeatureTuple`` and ``labels`` is a
        ``ds.SeqLabelTuple``.
    """
    batches = batch_iter(batch_size,
                         shuffle,
                         enc_data,
                         dec_data,
                         label_data,
                         mask_data,
                         pad=[[], [], 0, 2])
    for enc_seqs, dec_seqs, raw_labels, raw_masks in batches:
        enc, enc_len = util.hstack_list(enc_seqs)
        dec, dec_len = util.hstack_list(dec_seqs)
        label = np.array(raw_labels, dtype=np.int32)
        mask = np.array(raw_masks, dtype=np.int32)

        # Decoder input drops the last row of the stacked array and the
        # target drops the first one, so the two are offset by one step.
        dec_in, dec_out = dec[:-1, :], dec[1:, :]

        # 1 for every non-empty decoder sequence, 0 otherwise. Each non-empty
        # sequence loses one step to the shift above, so shrink its length
        # (in place, which keeps the dtype of dec_len).
        non_empty = np.where(dec_len > 0, 1, 0)
        dec_len -= non_empty

        token_weight, num_tokens = util.masked_full_like(
            dec_out, 1, num_non_padding=dec_len)

        features = ds.LSeq2SeqFeatureTuple(enc, enc_len, dec_in, dec_len,
                                           label, mask)
        labels = ds.SeqLabelTuple(dec_out, token_weight,
                                  non_empty.astype(np.float32))
        yield ds.BatchTuple(features, labels, num_tokens, False)
Esempio n. 3
0
def _format_word2def(x, w, c, y, sw):
    """Pack one raw word-to-definition batch into a ``ds.BatchTuple``.

    Args:
        x: encoder sequences, stacked with ``util.hstack_list``.
        w: one word id per example (converted to ``int32``).
        c: character sequences, stacked with ``util.vstack_list``.
        y: decoder sequences, stacked with ``util.hstack_list``.
        sw: one sequence weight per example (converted to ``float32``).

    Returns:
        ``ds.BatchTuple(features, labels, num_tokens, False)`` where
        ``features`` is a ``ds.Word2DefFeatureTuple`` and ``labels`` is a
        ``ds.SeqLabelTuple``.
    """
    enc, enc_len = util.hstack_list(x)
    dec, dec_len = util.hstack_list(y)
    word = np.array(w, dtype=np.int32)
    char, char_len = util.vstack_list(c)
    # Decoder input drops the last row of the stacked array and the target
    # drops the first one, so the two are offset by one step.
    in_dec = dec[:-1, :]
    out_dec = dec[1:, :]
    # Already float32 here, so no further cast is needed before it is used.
    seq_weight = np.array(sw, dtype=np.float32)
    # Every non-empty decoder sequence loses one step to the shift above.
    # The in-place update keeps the dtype of dec_len.
    dec_len -= np.where(dec_len > 0, 1, 0)
    token_weight, num_tokens = util.masked_full_like(out_dec,
                                                     1,
                                                     num_non_padding=dec_len)
    features = ds.Word2DefFeatureTuple(enc, enc_len, word, char, char_len,
                                       in_dec, dec_len)
    labels = ds.SeqLabelTuple(out_dec, token_weight, seq_weight)
    return ds.BatchTuple(features, labels, num_tokens, False)
Esempio n. 4
0
def seq2seq_batch_iter(enc_data, dec_data, batch_size=1, shuffle=True):
    """Iterate over ``batch_iter`` batches formatted for a seq2seq model.

    Args:
        enc_data: encoder sequences.
        dec_data: decoder sequences.
        batch_size: forwarded to ``batch_iter``.
        shuffle: forwarded to ``batch_iter``.

    Yields:
        ``ds.BatchTuple(features, labels, num_tokens, False)`` where
        ``features`` is a ``ds.Seq2SeqFeatureTuple`` and ``labels`` is a
        ``ds.SeqLabelTuple``.
    """
    batches = batch_iter(batch_size,
                         shuffle,
                         enc_data,
                         dec_data,
                         pad=[[], []])
    for enc_seqs, dec_seqs in batches:
        enc, enc_len = util.hstack_list(enc_seqs)
        dec, dec_len = util.hstack_list(dec_seqs)

        # Input and target are the same stacked array offset by one row.
        dec_in, dec_out = dec[:-1, :], dec[1:, :]

        # Flag the non-empty sequences and shorten each of them by the one
        # step consumed by the shift (in place, keeping dec_len's dtype).
        non_empty = np.where(dec_len > 0, 1, 0)
        dec_len -= non_empty

        token_weight, num_tokens = util.masked_full_like(
            dec_out, 1, num_non_padding=dec_len)

        yield ds.BatchTuple(
            ds.Seq2SeqFeatureTuple(enc, enc_len, dec_in, dec_len),
            ds.SeqLabelTuple(dec_out, token_weight,
                             non_empty.astype(np.float32)),
            num_tokens,
            False)
Esempio n. 5
0
def seq_batch_iter(in_data,
                   out_data,
                   weights,
                   batch_size=1,
                   shuffle=True,
                   keep_sentence=True):
    """Iterate over ``batch_iter`` batches formatted for a sequence model.

    Args:
        in_data: input sequences.
        out_data: label sequences.
        weights: optional per-example weights. When ``None`` or empty, every
            non-empty sequence and every non-padding token gets weight 1.
            Otherwise the weights are batched alongside the data (padded
            with 0) and used as both the sequence weight and the fill value
            for the token weights.
        batch_size: forwarded to ``batch_iter``.
        shuffle: forwarded to ``batch_iter``.
        keep_sentence: the yielded batch's last field is ``not keep_sentence``.

    Yields:
        ``ds.BatchTuple(features, labels, num_tokens, keep_state)`` where
        ``features`` is a ``ds.SeqFeatureTuple`` and ``labels`` is a
        ``ds.SeqLabelTuple``.
    """
    keep_state = not keep_sentence
    # An explicit None/length check is used instead of `if weights:` because
    # the truth value of a multi-element NumPy array raises ValueError.
    has_weights = weights is not None and len(weights) > 0
    if has_weights:
        batches = batch_iter(batch_size,
                             shuffle,
                             in_data,
                             out_data,
                             weights,
                             pad=[[], [], 0])
    else:
        # Attach a constant weight of 1 so both cases share one loop body.
        batches = ((x, y, 1) for x, y in batch_iter(batch_size,
                                                    shuffle,
                                                    in_data,
                                                    out_data,
                                                    pad=[[], []]))
    for x, y, w in batches:
        x_arr, x_len = util.hstack_list(x)
        y_arr, y_len = util.hstack_list(y)
        # Empty label sequences always get weight 0, whatever `w` says.
        seq_weight = np.where(y_len > 0, w, 0).astype(np.float32)
        token_weight, num_tokens = util.masked_full_like(
            y_arr, w, num_non_padding=y_len)
        features = ds.SeqFeatureTuple(x_arr, x_len)
        labels = ds.SeqLabelTuple(y_arr, token_weight, seq_weight)
        yield ds.BatchTuple(features, labels, num_tokens, keep_state)
Esempio n. 6
0
 def test_hstack_list(self):
     """hstack_list pads ragged lists with 0 and reports their lengths.

     Each input list is expected to become one column of the output
     (the target is built row-wise and then transposed), and the returned
     lengths must equal the original list lengths, including 0 for the
     empty list.
     """
     inputs = [[1, 2, 3, 4], [5, 6], [7, 8, 9, 10, 11], []]
     targets = np.array([[1, 2, 3, 4, 0],
                         [5, 6, 0, 0, 0],
                         [7, 8, 9, 10, 11],
                         [0, 0, 0, 0, 0]], dtype=np.int32).T
     outputs, lengths = util.hstack_list(inputs, padding=0, dtype=np.int32)
     # assert_array_equal also fails on a shape mismatch, which a plain
     # np.all(a == b) can hide through broadcasting, and it reports the
     # differing elements on failure.
     np.testing.assert_array_equal(
         outputs, targets,
         err_msg='stacked data does not match the expected padded array')
     expected_lengths = np.array([len(seq) for seq in inputs],
                                 dtype=np.int32)
     np.testing.assert_array_equal(
         lengths, expected_lengths,
         err_msg='returned lengths do not match the input list lengths')
Esempio n. 7
0
def seq_batch_iter(in_data,
                   out_data,
                   batch_size=1,
                   shuffle=True,
                   keep_sentence=True):
    """Iterate over ``batch_iter`` batches formatted for a sequence model.

    Args:
        in_data: input sequences.
        out_data: label sequences.
        batch_size: forwarded to ``batch_iter``.
        shuffle: forwarded to ``batch_iter``.
        keep_sentence: the yielded batch's last field is ``not keep_sentence``.

    Yields:
        ``ds.BatchTuple(features, labels, num_tokens, keep_state)`` where
        ``features`` is a ``ds.SeqFeatureTuple`` and ``labels`` is a
        ``ds.SeqLabelTuple``.
    """
    keep_state = not keep_sentence
    batches = batch_iter(batch_size,
                         shuffle,
                         in_data,
                         out_data,
                         pad=[[], []])
    for inputs, targets in batches:
        x_arr, x_len = util.hstack_list(inputs)
        y_arr, y_len = util.hstack_list(targets)
        # 1.0 for sequences with a non-zero label length, 0.0 otherwise.
        non_empty = np.where(y_len > 0, 1, 0).astype(np.float32)
        token_weight, num_tokens = util.masked_full_like(
            y_arr, 1, num_non_padding=y_len)
        yield ds.BatchTuple(
            ds.SeqFeatureTuple(x_arr, x_len),
            ds.SeqLabelTuple(y_arr, token_weight, non_empty),
            num_tokens,
            keep_state)