def get_pretraining_model(model_name, ctx_l):
    cfg, tokenizer, _, _ = get_pretrained_bert(model_name, load_backbone=False, load_mlm=False)
    cfg = BertModel.get_cfg().clone_merge(cfg)
    model = BertForPretrain(cfg)
    return cfg, tokenizer, model
def get_pretraining_model(model_name, ctx_l, max_seq_length=512):
    cfg, tokenizer, _, _ = get_pretrained_bert(
        model_name, load_backbone=False, load_mlm=False)
    cfg = BertModel.get_cfg().clone_merge(cfg)
    cfg.defrost()
    cfg.MODEL.max_length = max_seq_length
    cfg.freeze()
    model = BertForPretrain(cfg)
    model.initialize(ctx=ctx_l)
    model.hybridize()
    return cfg, tokenizer, model
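# Minimal usage sketch (not part of the original script): `ctx_l` is assumed to be a
# list of MXNet contexts, and the model name is an assumed example of an entry from
# list_pretrained_bert(). Parameters are initialized on every context in the list.
import mxnet as mx

ctx_l = [mx.gpu(i) for i in range(mx.context.num_gpus())] or [mx.cpu()]
cfg, tokenizer, model = get_pretraining_model('google_en_uncased_bert_base',
                                              ctx_l, max_seq_length=128)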
def test_bert_get_pretrained(model_name, ctx):
    assert len(list_pretrained_bert()) > 0
    with tempfile.TemporaryDirectory() as root, ctx:
        cfg, tokenizer, backbone_params_path, mlm_params_path =\
            get_pretrained_bert(model_name, load_backbone=True, load_mlm=True, root=root)
        assert cfg.MODEL.vocab_size == len(tokenizer.vocab)
        # test the backbone model
        bert_model = BertModel.from_cfg(cfg)
        bert_model.load_parameters(backbone_params_path)
        # test the MLM model, loading full MLM weights when they are available
        bert_mlm_model = BertForMLM(cfg)
        if mlm_params_path is not None:
            bert_mlm_model.load_parameters(mlm_params_path)
        # test loading only the backbone weights into a fresh MLM model
        bert_mlm_model = BertForMLM(cfg)
        bert_mlm_model.backbone_model.load_parameters(backbone_params_path)
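# Minimal smoke-check sketch (an assumption, not part of the original test suite): the
# test expects `model_name` and `ctx` to be injected by pytest; to exercise it by hand,
# an MXNet context such as mx.cpu() works, since Context objects act as context managers.
import mxnet as mx

test_bert_get_pretrained(list_pretrained_bert()[0], mx.cpu())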
# standard-library imports used by main()
import os
import random
import time
import logging
from multiprocessing import Pool


def main():
    """Main function."""
    time_start = time.time()
    # random seed
    random.seed(args.random_seed)
    # create output dir
    output_dir = os.path.expanduser(args.output_dir)
    if not os.path.exists(output_dir):
        os.mkdir(output_dir)
    # vocabulary and tokenizer
    _, tokenizer, _, _ = get_pretrained_bert(args.model_name, load_backbone=False,
                                             load_mlm=False)
    # collect the input files
    input_files = []
    datasets = args.input_dir.split(',')
    for dataset in datasets:
        tmp_names = os.listdir(dataset)
        for file in tmp_names:
            input_files.append(os.path.expanduser(os.path.join(dataset, file)))
    # keep only the slice of input files belonging to the current shard
    total_num = len(input_files)
    part_num = total_num // args.shard_num
    input_files.sort()
    input_files = input_files[
        part_num * args.current_shard:min(part_num * (args.current_shard + 1), total_num)]
    for input_file in input_files:
        logging.info('\t%s', input_file)
    num_inputs = len(input_files)
    num_outputs = min(args.num_outputs, len(input_files))
    logging.info('*** Reading from %d input files ***', num_inputs)
    # calculate the number of splits
    file_splits = []
    split_size = (num_inputs + num_outputs - 1) // num_outputs
    for i in range(num_outputs):
        split_start = i * split_size
        split_end = min(num_inputs, (i + 1) * split_size)
        file_splits.append(input_files[split_start:split_end])
    # prepare workload
    count = 0
    process_args = []
    for i, file_split in enumerate(file_splits):
        output_file = os.path.join(
            output_dir,
            'shard-{}-{}.npz'.format(str(args.current_shard), str(i).zfill(3)))
        count += len(file_split)
        process_args.append(
            (file_split, tokenizer, args.max_seq_length, args.short_seq_prob,
             args.masked_lm_prob, args.max_predictions_per_seq,
             args.whole_word_mask, tokenizer.vocab, args.dupe_factor, 1,
             None, output_file, args.random_next_sentence))
    # sanity check
    assert count == len(input_files)
    # dispatch to workers
    nworker = args.num_workers
    if nworker > 1:
        with Pool(nworker) as pool:
            pool.map(create_training_instances, process_args)
    else:
        for process_arg in process_args:
            create_training_instances(process_arg)
    time_end = time.time()
    logging.info('Time cost=%.1f', time_end - time_start)
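# Argument-parsing sketch covering every attribute that main() reads from the global
# `args`. The flag names mirror the attribute names used above; the defaults are
# illustrative assumptions, not the original script's values.
import argparse


def parse_args():
    parser = argparse.ArgumentParser(
        description='Generate .npz shards of BERT pretraining instances.')
    parser.add_argument('--input_dir', type=str, required=True,
                        help='Comma-separated list of directories holding raw text files.')
    parser.add_argument('--output_dir', type=str, default='pretrain_data')
    parser.add_argument('--model_name', type=str, default='google_en_uncased_bert_base')
    parser.add_argument('--max_seq_length', type=int, default=128)
    parser.add_argument('--short_seq_prob', type=float, default=0.1)
    parser.add_argument('--masked_lm_prob', type=float, default=0.15)
    parser.add_argument('--max_predictions_per_seq', type=int, default=20)
    parser.add_argument('--whole_word_mask', action='store_true')
    parser.add_argument('--random_next_sentence', action='store_true')
    parser.add_argument('--dupe_factor', type=int, default=1)
    parser.add_argument('--num_outputs', type=int, default=1)
    parser.add_argument('--num_workers', type=int, default=1)
    parser.add_argument('--shard_num', type=int, default=1)
    parser.add_argument('--current_shard', type=int, default=0)
    parser.add_argument('--random_seed', type=int, default=12345)
    return parser.parse_args()


if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    args = parse_args()
    main()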