Example #1
from gluonnlp.models.bert import BertModel, BertForPretrain, get_pretrained_bert


def get_pretraining_model(model_name, ctx_l):
    """Build a BERT pretraining model from a pretrained configuration."""
    # Fetch only the config and tokenizer; skip downloading the backbone
    # and MLM parameter files.
    cfg, tokenizer, _, _ = get_pretrained_bert(model_name,
                                               load_backbone=False,
                                               load_mlm=False)
    # Merge the pretrained hyperparameters into the default BertModel config.
    cfg = BertModel.get_cfg().clone_merge(cfg)
    model = BertForPretrain(cfg)
    return cfg, tokenizer, model
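A minimal usage sketch for the helper above (assuming MXNet and the GluonNLP numpy-version API; the model name below is an illustrative entry of list_pretrained_bert(), not taken from the snippet):

import mxnet as mx

# Hypothetical usage: build the pretraining model on CPU. The model name is
# an assumption; pick one from gluonnlp.models.bert.list_pretrained_bert().
ctx_l = [mx.cpu()]
cfg, tokenizer, model = get_pretraining_model('google_en_uncased_bert_base', ctx_l)
model.initialize(ctx=ctx_l)
model.hybridize()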
Example #2
from gluonnlp.models.bert import BertModel, BertForPretrain, get_pretrained_bert


def get_pretraining_model(model_name, ctx_l, max_seq_length=512):
    """Build, initialize, and hybridize a BERT pretraining model."""
    cfg, tokenizer, _, _ = get_pretrained_bert(
        model_name, load_backbone=False, load_mlm=False)
    cfg = BertModel.get_cfg().clone_merge(cfg)
    # The merged config is immutable; unfreeze it to override the maximum
    # sequence length, then freeze it again.
    cfg.defrost()
    cfg.MODEL.max_length = max_seq_length
    cfg.freeze()
    model = BertForPretrain(cfg)
    model.initialize(ctx=ctx_l)
    model.hybridize()
    return cfg, tokenizer, model
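Compared with Example #1, this variant also overrides the maximum sequence length and eagerly initializes and hybridizes the model. The defrost()/freeze() pair follows the yacs-style immutable-config convention: the merged configuration is read-only, so it must be unfrozen before MODEL.max_length can be changed and re-frozen afterwards to guard against accidental mutation later in the run.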
Example #3
import tempfile

from gluonnlp.models.bert import (BertModel, BertForMLM, get_pretrained_bert,
                                  list_pretrained_bert)


def test_bert_get_pretrained(model_name, ctx):
    assert len(list_pretrained_bert()) > 0
    with tempfile.TemporaryDirectory() as root, ctx:
        cfg, tokenizer, backbone_params_path, mlm_params_path =\
            get_pretrained_bert(model_name, load_backbone=True, load_mlm=True, root=root)
        assert cfg.MODEL.vocab_size == len(tokenizer.vocab)
        # The backbone weights load into a plain BertModel.
        bert_model = BertModel.from_cfg(cfg)
        bert_model.load_parameters(backbone_params_path)
        # The full MLM weights load into BertForMLM when available.
        bert_mlm_model = BertForMLM(cfg)
        if mlm_params_path is not None:
            bert_mlm_model.load_parameters(mlm_params_path)
        # Separately, the backbone-only weights should load into a fresh
        # MLM model's backbone submodel.
        bert_mlm_model = BertForMLM(cfg)
        bert_mlm_model.backbone_model.load_parameters(backbone_params_path)
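This test expects model_name and ctx to be supplied by the test runner. A sketch of how it might be parametrized with pytest (the model name and context are assumptions, not taken from the snippet):

import pytest
import mxnet as mx

# Hypothetical parametrization; the model name is assumed to be an entry of
# list_pretrained_bert().
@pytest.mark.parametrize('model_name', ['google_en_uncased_bert_base'])
@pytest.mark.parametrize('ctx', [mx.cpu()])
def test_bert_get_pretrained(model_name, ctx):
    ...  # body as in Example #3 above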
Example #4
import logging
import os
import random
import time
from multiprocessing import Pool

from gluonnlp.models.bert import get_pretrained_bert

# `args` and `create_training_instances` are defined elsewhere in the
# original script and are not shown in this snippet.

def main():
    """Main function."""
    time_start = time.time()

    # random seed
    random.seed(args.random_seed)

    # create output dir
    output_dir = os.path.expanduser(args.output_dir)
    if not os.path.exists(output_dir):
        os.mkdir(output_dir)
    # vocabulary and tokenizer
    _, tokenizer, _, _ = get_pretrained_bert(args.model_name,
                                             load_backbone=False,
                                             load_mlm=False)
    # collect the input files from each dataset directory
    input_files = []
    datasets = args.input_dir.split(',')
    for dataset in datasets:
        tmp_names = os.listdir(dataset)
        for file in tmp_names:
            input_files.append(os.path.expanduser(os.path.join(dataset, file)))

    # separate input_files: select this shard's slice
    total_num = len(input_files)
    part_num = total_num // args.shard_num
    input_files.sort()
    input_files = input_files[
        part_num * args.current_shard:min(part_num *
                                          (args.current_shard + 1), total_num)]

    for input_file in input_files:
        logging.info('\t%s', input_file)
    num_inputs = len(input_files)
    num_outputs = min(args.num_outputs, len(input_files))
    logging.info('*** Reading from %d input files ***', num_inputs)

    # split the input files evenly across the output files
    file_splits = []
    split_size = (num_inputs + num_outputs - 1) // num_outputs
    for i in range(num_outputs):
        split_start = i * split_size
        split_end = min(num_inputs, (i + 1) * split_size)
        file_splits.append(input_files[split_start:split_end])

    # prepare workload
    count = 0
    process_args = []

    for i, file_split in enumerate(file_splits):
        output_file = os.path.join(
            output_dir, 'shard-{}-{}.npz'.format(str(args.current_shard),
                                                 str(i).zfill(3)))
        count += len(file_split)
        process_args.append(
            (file_split, tokenizer, args.max_seq_length, args.short_seq_prob,
             args.masked_lm_prob, args.max_predictions_per_seq,
             args.whole_word_mask, tokenizer.vocab, args.dupe_factor, 1, None,
             output_file, args.random_next_sentence))

    # sanity check
    assert count == len(input_files)

    # dispatch to workers
    nworker = args.num_workers
    if nworker > 1:
        with Pool(nworker) as pool:
            pool.map(create_training_instances, process_args)
    else:
        for process_arg in process_args:
            create_training_instances(process_arg)

    time_end = time.time()
    logging.info('Time cost=%.1f', time_end - time_start)
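main() reads every option from a module-level args object that this snippet omits. A sketch of an argument parser covering exactly the flags the function references (the flag names come from the args.* uses above; defaults and help strings are illustrative assumptions, not values from the original script):

import argparse

def parse_args():
    # Flags inferred from the `args.*` references in main(); defaults are
    # assumptions, not taken from the original script.
    parser = argparse.ArgumentParser(description='Prepare BERT pretraining data.')
    parser.add_argument('--model_name', type=str, default='google_en_uncased_bert_base')
    parser.add_argument('--input_dir', type=str, required=True,
                        help='Comma-separated list of input directories.')
    parser.add_argument('--output_dir', type=str, required=True)
    parser.add_argument('--random_seed', type=int, default=0)
    parser.add_argument('--shard_num', type=int, default=1)
    parser.add_argument('--current_shard', type=int, default=0)
    parser.add_argument('--num_outputs', type=int, default=1)
    parser.add_argument('--num_workers', type=int, default=1)
    parser.add_argument('--max_seq_length', type=int, default=512)
    parser.add_argument('--short_seq_prob', type=float, default=0.1)
    parser.add_argument('--masked_lm_prob', type=float, default=0.15)
    parser.add_argument('--max_predictions_per_seq', type=int, default=80)
    parser.add_argument('--dupe_factor', type=int, default=1)
    parser.add_argument('--whole_word_mask', action='store_true')
    parser.add_argument('--random_next_sentence', action='store_true')
    return parser.parse_args()

if __name__ == '__main__':
    args = parse_args()
    main()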