Example #1
0
    args = parser.parse_args()

    # NOTE(review): `P` is presumably the paddle module and enable_static()
    # presumably switches it to static-graph mode — confirm against imports.
    P.enable_static()

    # Validate the pretrained-model directory up front. Explicit raises are
    # used instead of `assert`, which Python strips when run with -O.
    if not os.path.exists(args.from_pretrained):
        raise ValueError('--from_pretrained not found: %s' %
                         args.from_pretrained)
    cfg_file_path = os.path.join(args.from_pretrained, 'ernie_config.json')
    param_path = os.path.join(args.from_pretrained, 'params')
    vocab_path = os.path.join(args.from_pretrained, 'vocab.txt')

    # The config file, parameter path and vocab file must all be present.
    missing = [
        path for path in (cfg_file_path, param_path, vocab_path)
        if not os.path.exists(path)
    ]
    if missing:
        raise ValueError('--from_pretrained is missing required files: %s' %
                         ', '.join(missing))

    hparams_cli = propeller.parse_hparam(args)
    # Read the model config with a context manager so the handle is closed
    # deterministically (the previous bare open() leaked it).
    with open(cfg_file_path, encoding='utf-8') as cfg_file:
        hparams_config_file = json.load(cfg_file)
    default_hparams = propeller.HParams(batch_size=32,
                                        num_labels=3,
                                        warmup_proportion=0.1,
                                        learning_rate=5e-5,
                                        weight_decay=0.01,
                                        use_task_id=False,
                                        use_fp16=args.use_amp)

    # Combine hyper-parameters: built-in defaults, then the model's config
    # file, then CLI values. NOTE(review): assumes HParams.join lets its
    # argument override existing keys — confirm against propeller.
    hparams = default_hparams.join(
        propeller.HParams(**hparams_config_file)).join(hparams_cli)
    default_run_config = dict(max_steps=args.epoch * 390000 /
                              hparams.batch_size,
                              save_steps=1000,
Example #2
0
        return {'acc': acc}


if __name__ == '__main__':
    parser = propeller.ArgumentParser('distill model with ERNIE')
    parser.add_argument('--max_seqlen', type=int, default=128)
    parser.add_argument('--vocab_file', type=str, required=True)
    parser.add_argument('--teacher_vocab_file', type=str, required=True)
    parser.add_argument('--teacher_max_seqlen', type=int, default=128)
    parser.add_argument('--data_dir', type=str)
    parser.add_argument('--server_batch_size', type=int, default=64)
    parser.add_argument('--num_coroutine', type=int, default=1)
    parser.add_argument('--teacher_host', type=str, required=True)
    args = parser.parse_args()
    run_config = propeller.parse_runconfig(args)
    hparams = propeller.parse_hparam(args)

    teacher_vocab = {
        j.strip().split(b'\t')[0].decode('utf8'): i
        for i, j in enumerate(open(args.teacher_vocab_file, 'rb'))
    }
    vocab = {
        j.strip().split(b'\t')[0].decode('utf8'): i
        for i, j in enumerate(open(args.vocab_file, 'rb'))
    }
    teacher_sep_id = teacher_vocab['[SEP]']
    teacher_cls_id = teacher_vocab['[CLS]']
    teacher_unk_id = teacher_vocab['[UNK]']
    unk_id = vocab['[UNK]']

    char_tokenizer = utils.data.CharTokenizer(vocab.keys())