예제 #1
0
    # Fail early with an explicit error when the pretrained directory is absent.
    if not os.path.exists(args.from_pretrained):
        raise ValueError('--from_pretrained not found: %s' %
                         args.from_pretrained)
    cfg_file_path = os.path.join(args.from_pretrained, 'ernie_config.json')
    param_path = os.path.join(args.from_pretrained, 'params')
    vocab_path = os.path.join(args.from_pretrained, 'vocab.txt')

    # All three artifacts must exist inside the pretrained directory.
    # NOTE(review): `assert` is skipped under `python -O`; consider raising
    # an explicit exception if this check must always run.
    assert os.path.exists(cfg_file_path) and os.path.exists(
        param_path) and os.path.exists(vocab_path)

    hparams_cli = propeller.parse_hparam(args)
    # Read the config through a context manager so the file handle is closed
    # deterministically instead of being left to the garbage collector.
    with open(cfg_file_path) as cfg_file:
        hparams_config_file = json.load(cfg_file)
    default_hparams = propeller.HParams(batch_size=32,
                                        num_labels=3,
                                        warmup_proportion=0.1,
                                        learning_rate=5e-5,
                                        weight_decay=0.01,
                                        use_task_id=False,
                                        use_fp16=args.use_amp)

    # Combine defaults, then the config file, then command-line values.
    # NOTE(review): presumably later sources override earlier ones — confirm
    # against propeller.HParams.join.
    hparams = default_hparams.join(
        propeller.HParams(**hparams_config_file)).join(hparams_cli)

    # NOTE(review): `/` is true division, so max_steps is a float on
    # Python 3 — confirm the consumer of this dict accepts a non-integer.
    # NOTE(review): 390000 is presumably the training-set size — confirm.
    default_run_config = dict(max_steps=args.epoch * 390000 /
                              hparams.batch_size,
                              save_steps=1000,
                              log_steps=10,
                              max_ckpt=1,
                              skip_steps=0,
                              model_dir=tempfile.mkdtemp(),
                              eval_steps=100)
예제 #2
0
파일: pretrain.py 프로젝트: leo038/ERNIE
    # Fail early with an explicit error when the pretrained directory is absent.
    if not os.path.exists(args.from_pretrained):
        raise ValueError('--from_pretrained not found: %s' %
                         args.from_pretrained)
    cfg_file_path = os.path.join(args.from_pretrained, 'ernie_config.json')
    param_path = os.path.join(args.from_pretrained, 'params')
    vocab_path = os.path.join(args.from_pretrained, 'vocab.txt')
    # All three artifacts must exist inside the pretrained directory.
    # NOTE(review): `assert` is skipped under `python -O`; consider raising
    # an explicit exception if this check must always run.
    assert os.path.exists(cfg_file_path) and os.path.exists(
        param_path) and os.path.exists(vocab_path)

    hparams_cli = propeller.parse_hparam(args)
    # Read the config through a context manager so the file handle is closed
    # deterministically instead of being left to the garbage collector.
    with open(cfg_file_path) as cfg_file:
        hparams_config_file = json.load(cfg_file)
    default_hparams = propeller.HParams(
        batch_size=50,
        warmup_steps=10000,
        learning_rate=1e-4,
        weight_decay=0.01,
        use_fp16=False,
    )

    # Combine defaults, then the config file, then command-line values.
    # NOTE(review): presumably later sources override earlier ones — confirm
    # against propeller.HParams.join.
    hparams = default_hparams.join(
        propeller.HParams(**hparams_config_file)).join(hparams_cli)

    default_run_config = dict(max_steps=1000000,
                              save_steps=10000,
                              log_steps=10,
                              max_ckpt=3,
                              skip_steps=0,
                              eval_steps=-1)

    # Keys supplied in the --run_config JSON string replace the defaults above.
    run_config = dict(default_run_config, **json.loads(args.run_config))