Example 1
 def tokenize_fn(text):
     text = preprocess_text(text, lower=FLAGS.uncased)
     return encode_ids(sp, text)
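
This tokenize_fn presumes a module-level SentencePiece processor `sp` and an absl `FLAGS.uncased` flag defined elsewhere; a minimal, self-contained sketch of that setup (the model path and the stand-in flags object are illustrative assumptions, not from the original code) could look like:

import sentencepiece as spm
from prepro_utils import preprocess_text, encode_ids

# Stand-in for absl FLAGS in this sketch only.
class _Flags:
    uncased = False

FLAGS = _Flags()

sp = spm.SentencePieceProcessor()
sp.Load('spiece.model')  # path to the XLNet SentencePiece model (illustrative)

def tokenize_fn(text):
    text = preprocess_text(text, lower=FLAGS.uncased)
    return encode_ids(sp, text)

# tokenize_fn('An example sentence.') -> list of SentencePiece ids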
Example 2
def convert_examples_to_features(
    examples,
    sp_model,
    max_seq_length,
    doc_stride,
    max_query_length,
    is_training,
):
    """Loads a data file into a list of `InputBatch`s."""

    cnt_pos, cnt_neg = 0, 0
    unique_id = 1000000000
    max_N, max_M = 1024, 1024
    f = np.zeros((max_N, max_M), dtype = np.float32)
    features = []

    for n in tqdm(range(len(examples))):
        example_index = n
        example = examples[n]

        query_tokens = encode_ids(
            sp_model, preprocess_text(example.question_text, lower = False)
        )

        if len(query_tokens) > max_query_length:
            query_tokens = query_tokens[0:max_query_length]

        paragraph_text = example.paragraph_text
        para_tokens = encode_pieces(
            sp_model, preprocess_text(example.paragraph_text, lower = False)
        )

        chartok_to_tok_index = []
        tok_start_to_chartok_index = []
        tok_end_to_chartok_index = []
        char_cnt = 0
        for i, token in enumerate(para_tokens):
            chartok_to_tok_index.extend([i] * len(token))
            tok_start_to_chartok_index.append(char_cnt)
            char_cnt += len(token)
            tok_end_to_chartok_index.append(char_cnt - 1)

        tok_cat_text = ''.join(para_tokens).replace(SPIECE_UNDERLINE, ' ')
        N, M = len(paragraph_text), len(tok_cat_text)

        if N > max_N or M > max_M:
            max_N = max(N, max_N)
            max_M = max(M, max_M)
            f = np.zeros((max_N, max_M), dtype = np.float32)
            gc.collect()

        g = {}
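        # `g` holds LCS backpointers: 0 = from (i - 1, j) (skip a paragraph char),
        # 1 = from (i, j - 1) (skip a tokenized-text char), 2 = diagonal char match.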

        def _lcs_match(max_dist):
            f.fill(0)
            g.clear()

            ### longest common subsequence
            # f[i, j] = max(f[i - 1, j], f[i, j - 1], f[i - 1, j - 1] + match(i, j))
            for i in range(N):

                # note(zhiliny):
                # unlike standard LCS, this is specifically optimized for the setting
                # because the mismatch between sentence pieces and original text will
                # be small
                for j in range(i - max_dist, i + max_dist):
                    if j >= M or j < 0:
                        continue

                    if i > 0:
                        g[(i, j)] = 0
                        f[i, j] = f[i - 1, j]

                    if j > 0 and f[i, j - 1] > f[i, j]:
                        g[(i, j)] = 1
                        f[i, j] = f[i, j - 1]

                    f_prev = f[i - 1, j - 1] if i > 0 and j > 0 else 0
                    if (
                        preprocess_text(
                            paragraph_text[i],
                            lower = False,
                            remove_space = False,
                        )
                        == tok_cat_text[j]
                        and f_prev + 1 > f[i, j]
                    ):
                        g[(i, j)] = 2
                        f[i, j] = f_prev + 1

        max_dist = abs(N - M) + 5
        for _ in range(2):
            _lcs_match(max_dist)
            if f[N - 1, M - 1] > 0.8 * N:
                break
            max_dist *= 2

        orig_to_chartok_index = [None] * N
        chartok_to_orig_index = [None] * M
        i, j = N - 1, M - 1
        while i >= 0 and j >= 0:
            if (i, j) not in g:
                break
            if g[(i, j)] == 2:
                orig_to_chartok_index[i] = j
                chartok_to_orig_index[j] = i
                i, j = i - 1, j - 1
            elif g[(i, j)] == 1:
                j = j - 1
            else:
                i = i - 1

        if (
            all(v is None for v in orig_to_chartok_index)
            or f[N - 1, M - 1] < 0.8 * N
        ):
            print('MISMATCH DETECTED!')
            continue

        tok_start_to_orig_index = []
        tok_end_to_orig_index = []
        for i in range(len(para_tokens)):
            start_chartok_pos = tok_start_to_chartok_index[i]
            end_chartok_pos = tok_end_to_chartok_index[i]
            start_orig_pos = _convert_index(
                chartok_to_orig_index, start_chartok_pos, N, is_start = True
            )
            end_orig_pos = _convert_index(
                chartok_to_orig_index, end_chartok_pos, N, is_start = False
            )

            tok_start_to_orig_index.append(start_orig_pos)
            tok_end_to_orig_index.append(end_orig_pos)

        if not is_training:
            tok_start_position = tok_end_position = None

        if is_training and example.is_impossible:
            tok_start_position = -1
            tok_end_position = -1

        if is_training and not example.is_impossible:
            start_position = example.start_position
            end_position = start_position + len(example.orig_answer_text) - 1

            start_chartok_pos = _convert_index(
                orig_to_chartok_index, start_position, is_start = True
            )
            tok_start_position = chartok_to_tok_index[start_chartok_pos]

            end_chartok_pos = _convert_index(
                orig_to_chartok_index, end_position, is_start = False
            )
            tok_end_position = chartok_to_tok_index[end_chartok_pos]
            assert tok_start_position <= tok_end_position

        def _piece_to_id(x):
            if six.PY2 and isinstance(x, unicode):
                x = x.encode('utf-8')
            return sp_model.PieceToId(x)

        all_doc_tokens = list(map(_piece_to_id, para_tokens))

        # The -3 accounts for [CLS], [SEP] and [SEP]
        max_tokens_for_doc = max_seq_length - len(query_tokens) - 3

        # We can have documents that are longer than the maximum sequence length.
        # To deal with this we do a sliding window approach, where we take chunks
        # of up to our max length with a stride of `doc_stride`.
        _DocSpan = collections.namedtuple(  # pylint: disable=invalid-name
            'DocSpan', ['start', 'length']
        )
        doc_spans = []
        start_offset = 0
        while start_offset < len(all_doc_tokens):
            length = len(all_doc_tokens) - start_offset
            if length > max_tokens_for_doc:
                length = max_tokens_for_doc
            doc_spans.append(_DocSpan(start = start_offset, length = length))
            if start_offset + length == len(all_doc_tokens):
                break
            start_offset += min(length, doc_stride)

        for (doc_span_index, doc_span) in enumerate(doc_spans):
            tokens = []
            token_is_max_context = {}
            segment_ids = []
            p_mask = []

            cur_tok_start_to_orig_index = []
            cur_tok_end_to_orig_index = []

            for i in range(doc_span.length):
                split_token_index = doc_span.start + i

                cur_tok_start_to_orig_index.append(
                    tok_start_to_orig_index[split_token_index]
                )
                cur_tok_end_to_orig_index.append(
                    tok_end_to_orig_index[split_token_index]
                )

                is_max_context = _check_is_max_context(
                    doc_spans, doc_span_index, split_token_index
                )
                token_is_max_context[len(tokens)] = is_max_context
                tokens.append(all_doc_tokens[split_token_index])
                segment_ids.append(SEG_ID_P)
                p_mask.append(0)

            paragraph_len = len(tokens)

            tokens.append(SEP_ID)
            segment_ids.append(SEG_ID_P)
            p_mask.append(1)

            # note(zhiliny): we put P before Q
            # because during pretraining, B is always shorter than A
            for token in query_tokens:
                tokens.append(token)
                segment_ids.append(SEG_ID_Q)
                p_mask.append(1)
            tokens.append(SEP_ID)
            segment_ids.append(SEG_ID_Q)
            p_mask.append(1)

            cls_index = len(segment_ids)
            tokens.append(CLS_ID)
            segment_ids.append(SEG_ID_CLS)
            p_mask.append(0)

            input_ids = tokens

            # The mask has 0 for real tokens and 1 for padding tokens. Only real
            # tokens are attended to.
            input_mask = [0] * len(input_ids)

            # Zero-pad up to the sequence length.
            while len(input_ids) < max_seq_length:
                input_ids.append(0)
                input_mask.append(1)
                segment_ids.append(SEG_ID_PAD)
                p_mask.append(1)

            assert len(input_ids) == max_seq_length
            assert len(input_mask) == max_seq_length
            assert len(segment_ids) == max_seq_length
            assert len(p_mask) == max_seq_length

            span_is_impossible = example.is_impossible
            start_position = None
            end_position = None
            if is_training and not span_is_impossible:
                # For training, if our document chunk does not contain an annotation
                # we throw it out, since there is nothing to predict.
                doc_start = doc_span.start
                doc_end = doc_span.start + doc_span.length - 1
                out_of_span = False
                if not (
                    tok_start_position >= doc_start
                    and tok_end_position <= doc_end
                ):
                    out_of_span = True
                if out_of_span:
                    # continue
                    start_position = 0
                    end_position = 0
                    span_is_impossible = True
                else:
                    # note(zhiliny): we put P before Q, so doc_offset should be zero.
                    # doc_offset = len(query_tokens) + 2
                    doc_offset = 0
                    start_position = tok_start_position - doc_start + doc_offset
                    end_position = tok_end_position - doc_start + doc_offset

            if is_training and span_is_impossible:
                start_position = cls_index
                end_position = cls_index

            if example_index < 20:
                tf.logging.info('*** Example ***')
                tf.logging.info('unique_id: %s' % (unique_id))
                tf.logging.info('example_index: %s' % (example_index))
                tf.logging.info('doc_span_index: %s' % (doc_span_index))
                tf.logging.info(
                    'tok_start_to_orig_index: %s'
                    % ' '.join([str(x) for x in cur_tok_start_to_orig_index])
                )
                tf.logging.info(
                    'tok_end_to_orig_index: %s'
                    % ' '.join([str(x) for x in cur_tok_end_to_orig_index])
                )
                tf.logging.info(
                    'token_is_max_context: %s'
                    % ' '.join(
                        [
                            '%d:%s' % (x, y)
                            for (x, y) in six.iteritems(token_is_max_context)
                        ]
                    )
                )
                tf.logging.info(
                    'input_ids: %s' % ' '.join([str(x) for x in input_ids])
                )
                tf.logging.info(
                    'input_mask: %s' % ' '.join([str(x) for x in input_mask])
                )
                tf.logging.info(
                    'segment_ids: %s' % ' '.join([str(x) for x in segment_ids])
                )

                if is_training and span_is_impossible:
                    tf.logging.info('impossible example span')

                if is_training and not span_is_impossible:
                    pieces = [
                        sp_model.IdToPiece(token)
                        for token in tokens[start_position : (end_position + 1)]
                    ]
                    answer_text = sp_model.DecodePieces(pieces)
                    tf.logging.info('start_position: %d' % (start_position))
                    tf.logging.info('end_position: %d' % (end_position))
                    tf.logging.info(
                        'answer: %s' % (printable_text(answer_text))
                    )

            # note(zhiliny): With multi processing,
            # the example_index is actually the index within the current process,
            # therefore we use example_index=None to avoid it being used in the future.
            # The current code does not use example_index of training data.
            if is_training:
                feat_example_index = None
            else:
                feat_example_index = example_index

            feature = InputFeatures(
                unique_id = unique_id,
                example_index = feat_example_index,
                doc_span_index = doc_span_index,
                tok_start_to_orig_index = cur_tok_start_to_orig_index,
                tok_end_to_orig_index = cur_tok_end_to_orig_index,
                token_is_max_context = token_is_max_context,
                input_ids = input_ids,
                input_mask = input_mask,
                p_mask = p_mask,
                segment_ids = segment_ids,
                paragraph_len = paragraph_len,
                cls_index = cls_index,
                start_position = start_position,
                end_position = end_position,
                is_impossible = span_is_impossible,
            )

            features.append(feature)

            unique_id += 1
            if span_is_impossible:
                cnt_neg += 1
            else:
                cnt_pos += 1

    tf.logging.info(
        'Total number of instances: {} = pos {} neg {}'.format(
            cnt_pos + cnt_neg, cnt_pos, cnt_neg
        )
    )
    return features
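
A call-site sketch for the function above; `SquadExampleStub`, the model path, and the hyper-parameter values are illustrative assumptions, and the module-level constants used inside (SEP_ID, CLS_ID, SEG_ID_*, SPIECE_UNDERLINE, ...) are expected to come from the same module as the snippet:

import collections
import sentencepiece as spm

# Illustrative stand-in for XLNet's SquadExample objects.
SquadExampleStub = collections.namedtuple(
    'SquadExampleStub',
    ['question_text', 'paragraph_text', 'is_impossible',
     'start_position', 'orig_answer_text'])

sp_model = spm.SentencePieceProcessor()
sp_model.Load('spiece.model')  # illustrative path

example = SquadExampleStub(
    question_text='Who wrote the report?',
    paragraph_text='The report was written by the committee in 2018.',
    is_impossible=False,
    start_position=26,               # character offset of the answer
    orig_answer_text='the committee')

features = convert_examples_to_features(
    examples=[example],
    sp_model=sp_model,
    max_seq_length=512,
    doc_stride=128,
    max_query_length=64,
    is_training=False)               # eval mode: answer positions are not required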
Example 3
def _create_data(idx, input_paths):
    # Load sentence-piece model
    sp = spm.SentencePieceProcessor()
    sp.Load(FLAGS.sp_path)

    input_shards = []
    total_line_cnt = 0
    for input_path in input_paths:
        input_data, sent_ids = [], []
        sent_id, line_cnt = True, 0
        tf.logging.info('Processing %s', input_path)
        for line in tf.gfile.Open(input_path):
            if line_cnt % 100000 == 0:
                tf.logging.info('Loading line %d', line_cnt)
            line_cnt += 1

            if not line.strip():
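                # An empty line marks a document boundary: when `use_eod` is set,
                # emit EOD_ID and flip `sent_id` so the next doc starts a fresh segment.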
                if FLAGS.use_eod:
                    sent_id = not sent_id
                    cur_sent = [EOD_ID]
                else:
                    continue
            else:
                if FLAGS.from_raw_text:
                    cur_sent = preprocess_text(line.strip(),
                                               lower=FLAGS.uncased)
                    cur_sent = encode_ids(sp, cur_sent)
                else:
                    cur_sent = list(map(int, line.strip().split()))

            input_data.extend(cur_sent)
            sent_ids.extend([sent_id] * len(cur_sent))
            sent_id = not sent_id

        tf.logging.info('Finish with line %d', line_cnt)
        if line_cnt == 0:
            continue

        input_data = np.array(input_data, dtype=np.int64)
        sent_ids = np.array(sent_ids, dtype=bool)  # np.bool was removed in NumPy 1.24+

        total_line_cnt += line_cnt
        input_shards.append((input_data, sent_ids))

    tf.logging.info('[Task %d] Total number of lines: %d', idx, total_line_cnt)

    tfrecord_dir = os.path.join(FLAGS.save_dir, 'tfrecords')

    filenames, num_batch = [], 0

    # Randomly shuffle input shards (with a fixed but distinct random seed)
    np.random.seed(100 * FLAGS.task + FLAGS.pass_id)

    perm_indices = np.random.permutation(len(input_shards))
    tf.logging.info(
        'Using perm indices %s for pass %d',
        perm_indices.tolist(),
        FLAGS.pass_id,
    )

    input_data_list, sent_ids_list = [], []
    prev_sent_id = None
    for perm_idx in perm_indices:
        input_data, sent_ids = input_shards[perm_idx]
        # make sure that `sent_ids[0] == not prev_sent_id`
        if prev_sent_id is not None and sent_ids[0] == prev_sent_id:
            sent_ids = np.logical_not(sent_ids)

        # append to temporary list
        input_data_list.append(input_data)
        sent_ids_list.append(sent_ids)

        # update `prev_sent_id`
        prev_sent_id = sent_ids[-1]

    input_data = np.concatenate(input_data_list)
    sent_ids = np.concatenate(sent_ids_list)

    file_name, cur_num_batch = create_tfrecords(
        save_dir=tfrecord_dir,
        basename='{}-{}-{}'.format(FLAGS.split, idx, FLAGS.pass_id),
        data=[input_data, sent_ids],
        bsz_per_host=FLAGS.bsz_per_host,
        seq_len=FLAGS.seq_len,
        bi_data=FLAGS.bi_data,
        sp=sp,
    )

    filenames.append(file_name)
    num_batch += cur_num_batch

    record_info = {'filenames': filenames, 'num_batch': num_batch}

    return record_info
Example 4
import sentencepiece as spm
from prepro_utils import preprocess_text, encode_ids
import os
# some code omitted here...
# initialize FLAGS

wkdir = os.path.dirname(os.path.abspath(''))
spiece_model_file = os.path.join(
    wkdir, 'model_cache/xlnet_cased_L-12_H-768_A-12/spiece.model')

text = "An input text string. pan viva build ElasticSearch to host netgear v7610 documents"

sp_model = spm.SentencePieceProcessor()
sp_model.Load(spiece_model_file)
text2 = preprocess_text(text, lower=False)
ids = encode_ids(sp_model, text2)

cc = 0
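
To inspect what the resulting ids correspond to, the same processor can map them back to pieces or text (a small sketch continuing the snippet above; the exact pieces depend on the loaded spiece.model):

pieces = [sp_model.IdToPiece(i) for i in ids]   # SentencePiece sub-tokens
print(pieces)
print(sp_model.DecodeIds(ids))                  # detokenized text, close to the input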
Example 5
def tokenize_fn(text):
    text = preprocess_text(text, lower = False)
    return encode_ids(sp_model, text)
Example 6
def convert_single_example(ex_index, example, label_list, max_seq_length, tokenize_fn):
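    """Converts a single sequence-labelling example into padded token, mask and label id features."""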
    label_map = {}
    for (i, label) in enumerate(label_list):
        label_map[label] = i
    with open(FLAGS.output_dir + "/label2id.pkl", "wb") as w:
        pickle.dump(label_map, w)
    textlist = example.text.split(" ")
    labellist = example.label.split(" ")
    tokens = []
    labels = []
    for i, word in enumerate(textlist):
        # print(word)
        token = [preprocess_text(word, lower=True)]
        # print(token)
        tokens.extend(token)
        label_1 = labellist[i]
        for m in range(len(token)):
            if m == 0:
                labels.append(label_1)
            else:
                labels.append("X")

    if len(tokens) >= max_seq_length - 1:
        tokens = tokens[0:(max_seq_length - 2)]
        labels = labels[0:(max_seq_length - 2)]
    ntokens = []
    segment_ids = []
    label_ids = []
    input_ids = []
    ntokens.append("[CLS]")
    segment_ids.append(0)
    # append("O") or append("[CLS]") not sure!
    label_ids.append(label_map["[CLS]"])
    # print('tokens')
    # print(tokens)
    for i, token in enumerate(tokens):
        ntokens.append(token)
        segment_ids.append(0)
        label_ids.append(label_map[labels[i]])
    #ntokens.append("[SEP]")
    #segment_ids.append(0)
    # append("O") or append("[SEP]") not sure!
    #label_ids.append(label_map["[SEP]"])
    for i in ntokens:
        input_ids.append(tokenize_fn(i))
    input_mask = [1] * len(input_ids)

    while len(input_ids) < max_seq_length:
        input_ids.append(0)
        input_mask.append(0)
        segment_ids.append(0)
        ntokens.append(0)
        label_ids.append(0)
    # print(len(input_mask))
    # print(len(input_ids))

    assert len(input_ids) == max_seq_length
    assert len(input_mask) == max_seq_length
    assert len(segment_ids) == max_seq_length
    assert len(label_ids) == max_seq_length
    assert len(ntokens) == max_seq_length
    if ex_index < 3:
        tf.logging.info("*** Example ***")
        tf.logging.info("guid: %s" % (example.guid))
        tf.logging.info("tokens: %s" % " ".join([printable_text(x) for x in tokens]))
        tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
        tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
        tf.logging.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
        tf.logging.info("label_ids: %s" % " ".join([str(x) for x in label_ids]))

    # print('input_ids')
    # print(input_ids)
    # print(type(input_ids))

    feature = InputFeatures(
        input_ids=input_ids,
        input_mask=input_mask,
        segment_ids=segment_ids,
        label_ids=label_ids)
    return feature
Example 7
 def tokenize_fn(text):
     text = preprocess_text(text, lower=FLAGS.uncased)
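     # PieceToId returns the unknown-token id (0 by SentencePiece default) for
     # out-of-vocabulary pieces, hence the sentinel value below.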
     if sp.PieceToId(text) == 0:
         return 99999
     return sp.PieceToId(text)
Example 8
 def encode(self,
            text):
     """Encode text for XLNet"""
     processed_text = prepro_utils.preprocess_text(text, lower=self.lower_case)
     encoded_ids = prepro_utils.encode_ids(self.sp_processor, processed_text)
     return encoded_ids
Example 9
 def tokenize(self,
              text):
     """Tokenize text for XLNet"""
     processed_text = prepro_utils.preprocess_text(text, lower=self.lower_case)
     tokenized_pieces = prepro_utils.encode_pieces(self.sp_processor, processed_text, return_unicode=False)
     return tokenized_pieces
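
Examples 8 and 9 read like methods of a tokenizer-style wrapper; a minimal sketch of such a class (the class name, constructor, and attribute wiring are assumptions, only the attributes used above are required) might be:

import sentencepiece as spm
import prepro_utils

class XLNetTokenizer(object):
    """Illustrative wrapper holding the attributes the two methods above rely on."""

    def __init__(self, model_path, lower_case=False):
        self.sp_processor = spm.SentencePieceProcessor()
        self.sp_processor.Load(model_path)
        self.lower_case = lower_case

    def encode(self, text):
        """Encode text for XLNet as SentencePiece ids."""
        processed_text = prepro_utils.preprocess_text(text, lower=self.lower_case)
        return prepro_utils.encode_ids(self.sp_processor, processed_text)

    def tokenize(self, text):
        """Tokenize text for XLNet into sub-token pieces."""
        processed_text = prepro_utils.preprocess_text(text, lower=self.lower_case)
        return prepro_utils.encode_pieces(self.sp_processor, processed_text,
                                          return_unicode=False)

# tokenizer = XLNetTokenizer('spiece.model')  # illustrative path
# tokenizer.encode('Hello world')   -> list of ids
# tokenizer.tokenize('Hello world') -> list of pieces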
Example 10
def convert_examples_to_features(examples, max_seq_length, sp_model, uncased):
    """Converts a single `InputExample` into a single `InputFeatures`."""

    features = []
    for ex_index, example in enumerate(examples):
        if isinstance(example, PaddingInputExample):
            features.append(
                InputFeatures(unique_id=ex_index,
                              tokens=[''] * max_seq_length,
                              input_ids=[0] * max_seq_length,
                              input_mask=[1] * max_seq_length,
                              segment_ids=[0] * max_seq_length,
                              label_id=0,
                              is_real_example=False))
            continue

        tokens_a_preprocessed = preprocess_text(example.text_a, lower=uncased)
        tokens_a_unicode, tokens_a = _encode_ids(sp_model,
                                                 tokens_a_preprocessed)
        tokens_a_str = [
            token.encode("ascii", "ignore").decode('utf-8', 'ignore')
            for token in tokens_a_unicode
        ]
        tokens_b = None
        tokens_b_str = None
        if example.text_b:
            tokens_b_preprocessed = preprocess_text(example.text_b,
                                                    lower=uncased)
            tokens_b_unicode, tokens_b = _encode_ids(sp_model,
                                                     tokens_b_preprocessed)
            tokens_b_str = [
                token.encode("ascii", "ignore").decode('utf-8', 'ignore')
                for token in tokens_b_unicode
            ]

        if tokens_b:
            # Modifies `tokens_a` and `tokens_b` in place so that the total
            # length is less than the specified length.
            # Account for two [SEP] & one [CLS] with "- 3"
            _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
        else:
            # Account for one [SEP] & one [CLS] with "- 2"
            if len(tokens_a) > max_seq_length - 2:
                tokens_a = tokens_a[:max_seq_length - 2]

        tokens = []
        tokens_str = []
        segment_ids = []
        for token, token_str in zip(tokens_a, tokens_a_str):
            tokens.append(token)
            tokens_str.append(token_str)
            segment_ids.append(SEG_ID_A)
        tokens.append(SEP_ID)
        tokens_str.append("<sep>")
        segment_ids.append(SEG_ID_A)

        if tokens_b:
            for token, token_str in zip(tokens_b, tokens_b_str):
                tokens.append(token)
                tokens_str.append(token_str)
                segment_ids.append(SEG_ID_B)
            tokens.append(SEP_ID)
            tokens_str.append("<sep>")
            segment_ids.append(SEG_ID_B)

        tokens.append(CLS_ID)
        tokens_str.append("<sep>")
        segment_ids.append(SEG_ID_CLS)

        input_ids = tokens

        # The mask has 0 for real tokens and 1 for padding tokens. Only real
        # tokens are attended to.
        input_mask = [0] * len(input_ids)

        # Zero-pad up to the sequence length.
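        # Note: padding is prepended (left padding), following XLNet's classifier
        # convention; Example 2 above appends padding on the right instead.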
        if len(input_ids) < max_seq_length:
            delta_len = max_seq_length - len(input_ids)
            input_ids = [0] * delta_len + input_ids
            input_mask = [1] * delta_len + input_mask
            segment_ids = [SEG_ID_PAD] * delta_len + segment_ids

        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length

        if ex_index < 5:
            tf.logging.info("*** Example ***")
            tf.logging.info("guid: %s" % ex_index)
            tf.logging.info("input_ids: %s" %
                            " ".join([str(x) for x in input_ids]))
            tf.logging.info("input_mask: %s" %
                            " ".join([str(x) for x in input_mask]))
            tf.logging.info("segment_ids: %s" %
                            " ".join([str(x) for x in segment_ids]))
            tf.logging.info("label: {} (id = {})".format(0.0, 0))

        features.append(
            InputFeatures(unique_id=ex_index,
                          tokens=tokens_str,
                          input_ids=input_ids,
                          input_mask=input_mask,
                          segment_ids=segment_ids,
                          label_id=0,
                          is_real_example=True))
    return features
Example 11
 def tokenize_fn(text):
     text = preprocess_text(text)
     return encode_ids(text)
Example 12
def _create_data(idx,
                 src_file,
                 tgt_file,
                 src_lang,
                 tgt_lang,
                 transliterate=True,
                 language_tag=True):
    # Load sentence-piece model
    sp = spm.SentencePieceProcessor()
    sp.Load(FLAGS.sp_path)

    input_data = []
    target_data = []
    target_mask_data = []
    input_mask_data = []
    total_line_cnt = 0
    for src_line, tgt_line in zip(tf.gfile.Open(src_file),
                                  tf.gfile.Open(tgt_file)):
        if total_line_cnt % 100000 == 0:
            tf.logging.info("Loading line %d", total_line_cnt)

        if not src_line.strip() or not tgt_line.strip():
            continue

        if FLAGS.from_raw_text:
            src_sent = preprocess_text(src_line.strip(), lower=FLAGS.uncased)
            tgt_sent = preprocess_text(tgt_line.strip(), lower=FLAGS.uncased)
            src_sent = encode_ids(sp,
                                  src_sent,
                                  transliterate=transliterate,
                                  language_tag=False)
            tgt_sent = encode_ids(sp,
                                  tgt_sent,
                                  transliterate=transliterate,
                                  language_tag=False)
            tgt_sent = tgt_sent + [EOS_ID]
            tgt_sent_input = tgt_sent[:-1]
            tgt_sent_output = tgt_sent

            # Maximum size allowed for target
            tgt_sent_output = tgt_sent_output[:FLAGS.tgt_len]
            tgt_sent_input = tgt_sent_input[:FLAGS.tgt_len]

            # Ensure `src_sent_e` is defined even when no language tag is added.
            src_sent_e = list(src_sent)
            if FLAGS.language_tag:
                src_id = ENG_ID if src_lang == "english" else HIN_ID
                tgt_id = ENG_ID if tgt_lang == "english" else HIN_ID
                src_sent_e = [src_id] + src_sent
                tgt_sent_input = [tgt_id] + tgt_sent_input

            if FLAGS.use_sos:
                src_sent_e = [SOS_ID] + src_sent_e
                tgt_sent_input = [SOS_ID] + tgt_sent_input

            input_len = len(src_sent_e) + len(
                tgt_sent_input) + 1  #One extra for EOS after source
            if input_len > FLAGS.seq_len:
                if FLAGS.long_sentences == 'ignore':
                    continue
                else:
                    # Truncate in the ratio of their original lengths
                    to_trunc = input_len - FLAGS.seq_len
                    len_ratio = len(src_sent_e) / len(tgt_sent_input)
                    to_trunc_src = min(int(len_ratio * to_trunc), to_trunc)
                    to_trunc_tgt = to_trunc - to_trunc_src
                    if to_trunc_src > 0:
                        src_sent_e = src_sent_e[:-to_trunc_src]
                    if to_trunc_tgt > 0:
                        tgt_sent_input = tgt_sent_input[:-to_trunc_tgt]
                        tgt_sent_output = tgt_sent_output[:-to_trunc_tgt]
                    input_len = FLAGS.seq_len
                    assert len(src_sent_e) + len(
                        tgt_sent_input) + 1 == input_len

            # Target padding to tgt_len on the left side
            target_mask = [0] * (FLAGS.tgt_len - len(tgt_sent_output)
                                 ) + [1] * len(tgt_sent_output)
            target = [PAD_ID] * (FLAGS.tgt_len -
                                 len(tgt_sent_output)) + tgt_sent_output

            # Paddings for input
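            # Packed layout: [PAD] * k + src_sent_e + [EOS] + tgt_sent_input,
            # left-padded so each instance is exactly FLAGS.seq_len tokens long.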
            pads = [PAD_ID] * (FLAGS.seq_len - input_len)
            instance = pads + src_sent_e + [EOS_ID] + tgt_sent_input
            input_mask = [0] * len(pads) + [1] * (len(instance) - len(pads))

            assert len(instance) == FLAGS.seq_len, len(instance)
            assert len(input_mask) == FLAGS.seq_len, len(input_mask)
            assert len(target) == FLAGS.tgt_len, len(target)
            assert len(target_mask) == FLAGS.tgt_len, len(target_mask)
        else:
            raise Exception("Loading from id files not yet supported")

        input_data.append(np.array(instance, dtype=np.int64))
        target_data.append(np.array(target, dtype=np.int64))
        target_mask_data.append(np.array(target_mask, dtype=np.float32))
        input_mask_data.append(np.array(input_mask, dtype=np.float32))
        total_line_cnt += 1

    tf.logging.info("Finish with line %d", total_line_cnt)
    if total_line_cnt == 0:
        raise Exception("Files have no valid data")

    tf.logging.info("[Task %d] Total number line: %d", idx, total_line_cnt)

    tfrecord_dir = os.path.join(FLAGS.save_dir, "tfrecords")

    file_name, num_batch = create_tfrecords(
        save_dir=tfrecord_dir,
        basename="{}-{}-{}".format(FLAGS.split, idx, FLAGS.pass_id),
        data=(input_data, target_data, target_mask_data, input_mask_data),
        seq_len=FLAGS.seq_len,
        tgt_len=FLAGS.tgt_len,
        bi_data=FLAGS.bi_data,
        sp=sp)

    record_info = {
        "filenames": [file_name],
        "langs": [src_lang, tgt_lang],
        "num_batch": num_batch
    }

    return record_info