args = parser.parse_args() P.enable_static() if not os.path.exists(args.from_pretrained): raise ValueError('--from_pretrained not found: %s' % args.from_pretrained) cfg_file_path = os.path.join(args.from_pretrained, 'ernie_config.json') param_path = os.path.join(args.from_pretrained, 'params') vocab_path = os.path.join(args.from_pretrained, 'vocab.txt') assert os.path.exists(cfg_file_path) and os.path.exists( param_path) and os.path.exists(vocab_path) hparams_cli = propeller.parse_hparam(args) hparams_config_file = json.loads(open(cfg_file_path).read()) default_hparams = propeller.HParams(batch_size=32, num_labels=3, warmup_proportion=0.1, learning_rate=5e-5, weight_decay=0.01, use_task_id=False, use_fp16=args.use_amp) hparams = default_hparams.join( propeller.HParams(**hparams_config_file)).join(hparams_cli) default_run_config = dict(max_steps=args.epoch * 390000 / hparams.batch_size, save_steps=1000,
return {'acc': acc} if __name__ == '__main__': parser = propeller.ArgumentParser('distill model with ERNIE') parser.add_argument('--max_seqlen', type=int, default=128) parser.add_argument('--vocab_file', type=str, required=True) parser.add_argument('--teacher_vocab_file', type=str, required=True) parser.add_argument('--teacher_max_seqlen', type=int, default=128) parser.add_argument('--data_dir', type=str) parser.add_argument('--server_batch_size', type=int, default=64) parser.add_argument('--num_coroutine', type=int, default=1) parser.add_argument('--teacher_host', type=str, required=True) args = parser.parse_args() run_config = propeller.parse_runconfig(args) hparams = propeller.parse_hparam(args) teacher_vocab = { j.strip().split(b'\t')[0].decode('utf8'): i for i, j in enumerate(open(args.teacher_vocab_file, 'rb')) } vocab = { j.strip().split(b'\t')[0].decode('utf8'): i for i, j in enumerate(open(args.vocab_file, 'rb')) } teacher_sep_id = teacher_vocab['[SEP]'] teacher_cls_id = teacher_vocab['[CLS]'] teacher_unk_id = teacher_vocab['[UNK]'] unk_id = vocab['[UNK]'] char_tokenizer = utils.data.CharTokenizer(vocab.keys())