Example 1
def main(eval_idx):
    """Split the raw parallel corpus into train/eval sets and write each as MindRecord."""
    os.mkdir('path_cmn')
    tokenizer = tokenization.WhiteSpaceTokenizer(vocab_file=cfg.vocab_file)

    writer_train = FileWriter(cfg.train_file_mindrecord, cfg.num_splits)
    writer_eval = FileWriter(cfg.eval_file_mindrecord, cfg.num_splits)
    data_schema = {"source_sos_ids": {"type": "int32", "shape": [-1]},
                   "source_sos_mask": {"type": "int32", "shape": [-1]},
                   "source_eos_ids": {"type": "int32", "shape": [-1]},
                   "source_eos_mask": {"type": "int32", "shape": [-1]},
                   "target_sos_ids": {"type": "int32", "shape": [-1]},
                   "target_sos_mask": {"type": "int32", "shape": [-1]},
                   "target_eos_ids": {"type": "int32", "shape": [-1]},
                   "target_eos_mask": {"type": "int32", "shape": [-1]}
                   }

    writer_train.add_schema(data_schema, "transformer train")
    writer_eval.add_schema(data_schema, "transformer eval")

    index = 0
    f_train = open(cfg.train_file_source, 'w', encoding='utf-8')
    f_test = open(cfg.eval_file_source, 'w', encoding='utf-8')
    f = open(cfg.input_file, "r", encoding='utf-8')
    for s_line in f:
        line = tokenization.convert_to_unicode(s_line)

        source_line, target_line = line.strip().split("\t")
        source_tokens = tokenizer.tokenize(source_line)
        target_tokens = tokenizer.tokenize(target_line)

        if len(source_tokens) >= (cfg.max_seq_length-1) or len(target_tokens) >= (cfg.max_seq_length-1):
            if cfg.clip_to_max_len:
                source_tokens = source_tokens[:cfg.max_seq_length-1]
                target_tokens = target_tokens[:cfg.max_seq_length-1]
            else:
                continue
        
        index += 1
        print(source_tokens)
        instance = create_training_instance(source_tokens, target_tokens, cfg.max_seq_length)
        
        if index in eval_idx:
            f_test.write(s_line)
            features = write_instance_to_file(writer_eval, instance, tokenizer, cfg.max_seq_length)
        else:
            f_train.write(s_line)
            features = write_instance_to_file(writer_train, instance, tokenizer, cfg.max_seq_length)
    f.close()
    f_test.close()
    f_train.close()
    writer_train.commit()
    writer_eval.commit()
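
This snippet relies on module-level imports (os, tokenization, MindSpore's FileWriter) plus a cfg object and the helpers create_training_instance / write_instance_to_file, none of which are shown here. A minimal driver, under the assumption that cfg exposes the corpus size as a hypothetical cfg.num_samples attribute, might look like:

# Hypothetical driver: route roughly 10% of line indices to the eval split.
# cfg.num_samples is an assumed attribute, not part of the original script.
import random

if __name__ == "__main__":
    random.seed(1)
    eval_size = max(1, cfg.num_samples // 10)
    eval_idx = set(random.sample(range(1, cfg.num_samples + 1), eval_size))
    main(eval_idx)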
Example 2
def read_examples(input_file):
    """Read a list of `InputExample`s from an input file."""
    examples = []
    unique_id = 0
    with open(input_file, "r") as reader:
        while True:
            line = tokenization.convert_to_unicode(reader.readline())
            if not line:
                break
            line = line.strip()
            text_a = None
            text_b = None
            m = re.match(r"^(.*) \|\|\| (.*)$", line)
            if m is None:
                text_a = line
            else:
                text_a = m.group(1)
                text_b = m.group(2)
            examples.append(
                InputExample(unique_id=unique_id, text_a=text_a, text_b=text_b))
            unique_id += 1
    return examples
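
Given the regex above, each input line holds either a single segment or a pair separated by " ||| ". A small usage sketch (the file name and its contents are illustrative, not from the original):

# pairs.txt (hypothetical) could contain, one example per line:
#   the quick brown fox ||| jumped over the lazy dog
#   a single segment with no second part
examples = read_examples("pairs.txt")
for ex in examples:
    # text_b is None when the " ||| " separator is absent
    print(ex.unique_id, ex.text_a, ex.text_b)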
Example 3
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input_file",
        type=str,
        required=True,
        help='Input raw text file (or comma-separated list of files).')
    parser.add_argument("--output_file",
                        type=str,
                        required=True,
                        help='Output MindRecord file.')
    parser.add_argument(
        "--num_splits",
        type=int,
        default=16,
        help='The MindRecord file will be split into this number of partitions.')
    parser.add_argument(
        "--vocab_file",
        type=str,
        required=True,
        help='The vocabulary file that the Transformer model was trained on.')
    parser.add_argument("--clip_to_max_len",
                        type=bool,
                        default=False,
                        help='clip sequences to maximum sequence length.')
    parser.add_argument("--max_seq_length",
                        type=int,
                        default=128,
                        help='Maximum sequence length.')
    args = parser.parse_args()

    tokenizer = tokenization.WhiteSpaceTokenizer(vocab_file=args.vocab_file)

    input_files = args.input_file.split(",")

    logging.info("*** Reading from input files ***")
    for input_file in input_files:
        logging.info("  %s", input_file)

    output_file = args.output_file
    logging.info("*** Writing to output files ***")
    logging.info("  %s", output_file)

    writer = FileWriter(output_file, args.num_splits)
    data_schema = {
        "source_sos_ids": {
            "type": "int64",
            "shape": [-1]
        },
        "source_sos_mask": {
            "type": "int64",
            "shape": [-1]
        },
        "source_eos_ids": {
            "type": "int64",
            "shape": [-1]
        },
        "source_eos_mask": {
            "type": "int64",
            "shape": [-1]
        },
        "target_sos_ids": {
            "type": "int64",
            "shape": [-1]
        },
        "target_sos_mask": {
            "type": "int64",
            "shape": [-1]
        },
        "target_eos_ids": {
            "type": "int64",
            "shape": [-1]
        },
        "target_eos_mask": {
            "type": "int64",
            "shape": [-1]
        }
    }
    writer.add_schema(data_schema, "transformer hisi")

    total_written = 0
    total_read = 0

    for input_file in input_files:
        logging.info("*** Reading from   %s ***", input_file)
        with open(input_file, "r") as reader:
            while True:
                line = tokenization.convert_to_unicode(reader.readline())
                if not line:
                    break

                total_read += 1
                if total_read % 100000 == 0:
                    logging.info("%d ...", total_read)

                source_line, target_line = line.strip().split("\t")
                source_tokens = tokenizer.tokenize(source_line)
                target_tokens = tokenizer.tokenize(target_line)

                if (len(source_tokens) >= args.max_seq_length
                        or len(target_tokens) >= args.max_seq_length):
                    logging.info("Skipping sentence pair longer than max_seq_length.")
                    continue

                instance = create_training_instance(
                    source_tokens,
                    target_tokens,
                    args.max_seq_length,
                    clip_to_max_len=args.clip_to_max_len)
                if instance is None:
                    continue

                features = write_instance_to_file(writer, instance, tokenizer,
                                                  args.max_seq_length)
                total_written += 1

                if total_written <= 20:
                    logging.info("*** Example ***")
                    logging.info(
                        "source tokens: %s", " ".join([
                            tokenization.printable_text(x)
                            for x in instance.source_eos_tokens
                        ]))
                    logging.info(
                        "target tokens: %s", " ".join([
                            tokenization.printable_text(x)
                            for x in instance.target_sos_tokens
                        ]))

                    for feature_name in features.keys():
                        feature = features[feature_name]
                        logging.info("%s: %s", feature_name, feature)

    writer.commit()
    logging.info("Wrote %d total instances", total_written)
Example 4
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_file", type=str, required=True,
                        help='Input raw text file (or comma-separated list of files).')
    parser.add_argument("--output_file", type=str, required=True, help='Output MindRecord file.')
    parser.add_argument("--num_splits", type=int, default=16,
                        help='The MindRecord file will be split into the number of partition.')
    parser.add_argument("--vocab_file", type=str, required=True,
                        help='The vocabulary file that the Transformer model was trained on.')
    parser.add_argument("--clip_to_max_len", type=bool, default=False,
                        help='clip sequences to maximum sequence length.')
    parser.add_argument("--max_seq_length", type=int, default=128, help='Maximum sequence length.')
    parser.add_argument("--bucket", type=ast.literal_eval, default=[16, 32, 48, 64, 128],
                        help='bucket sequence length')

    args = parser.parse_args()

    tokenizer = tokenization.WhiteSpaceTokenizer(vocab_file=args.vocab_file)

    input_files = args.input_file.split(",")

    logging.info("*** Read from input files ***")
    for input_file in input_files:
        logging.info("  %s", input_file)

    output_file = args.output_file
    logging.info("*** Write to output files ***")
    logging.info("  %s", output_file)

    total_written = 0
    total_read = 0

    feature_dict = {}
    for i in args.bucket:
        feature_dict[i] = []

    for input_file in input_files:
        logging.info("*** Reading from   %s ***", input_file)
        with open(input_file, "r") as reader:
            while True:
                line = tokenization.convert_to_unicode(reader.readline())
                if not line:
                    break

                total_read += 1
                if total_read % 100000 == 0:
                    logging.info("Read %d ...", total_read)

                source_line, target_line = line.strip().split("\t")
                source_tokens = tokenizer.tokenize(source_line)
                target_tokens = tokenizer.tokenize(target_line)

                if len(source_tokens) >= args.max_seq_length or len(target_tokens) >= args.max_seq_length:
                    logging.info("ignore long sentence!")
                    continue

                instance = create_training_instance(source_tokens, target_tokens, args.max_seq_length,
                                                    clip_to_max_len=args.clip_to_max_len)
                if instance is None:
                    continue

                features, seq_max_bucket_length = get_instance_features(instance, tokenizer, args.max_seq_length,
                                                                        args.bucket)
                for key in feature_dict:
                    if key == seq_max_bucket_length:
                        feature_dict[key].append(features)

                if total_read <= 10:
                    logging.info("*** Example ***")
                    logging.info("source tokens: %s", " ".join(
                        [tokenization.convert_to_printable(x) for x in instance.source_eos_tokens]))
                    logging.info("target tokens: %s", " ".join(
                        [tokenization.convert_to_printable(x) for x in instance.target_sos_tokens]))

                    for feature_name in features.keys():
                        feature = features[feature_name]
                        logging.info("%s: %s", feature_name, feature)

    for i in args.bucket:
        if args.num_splits == 1:
            output_file_name = output_file
        else:
            output_file_name = output_file + '_' + str(i) + '_'
        writer = FileWriter(output_file_name, args.num_splits)
        data_schema = {"source_sos_ids": {"type": "int64", "shape": [-1]},
                       "source_sos_mask": {"type": "int64", "shape": [-1]},
                       "source_eos_ids": {"type": "int64", "shape": [-1]},
                       "source_eos_mask": {"type": "int64", "shape": [-1]},
                       "target_sos_ids": {"type": "int64", "shape": [-1]},
                       "target_sos_mask": {"type": "int64", "shape": [-1]},
                       "target_eos_ids": {"type": "int64", "shape": [-1]},
                       "target_eos_mask": {"type": "int64", "shape": [-1]}
                       }
        writer.add_schema(data_schema, "transformer")
        features_ = feature_dict[i]
        logging.info("Bucket length %d has %d samples, start writing...", i, len(features_))

        for item in features_:
            writer.write_raw_data([item])
            total_written += 1
        writer.commit()

    logging.info("Wrote %d total instances", total_written)