Example #1
File: gpt2_test.py  Project: lileicc/neurst
def test_openai_gpt2():
    # NOTE: the neurst import paths below are assumed from the project layout.
    import tensorflow as tf
    from neurst.tasks import build_task
    from neurst.utils.checkpoints import restore_checkpoint_if_possible_v2
    from neurst.utils.hparams_sets import get_hyper_parameters
    from neurst.utils.misc import assert_equal_numpy
    from transformers import GPT2Model, GPT2Tokenizer

    # Encode a sample sentence with the HuggingFace (PyTorch) GPT-2 as the reference.
    input_text = "Here is some text to encode"
    pt_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    pt_model = GPT2Model.from_pretrained("gpt2", return_dict=True)
    pt_outputs = pt_model(**pt_tokenizer([input_text], return_tensors="pt"))

    # Build the equivalent neurst language-model task with the GPT-2 data pipeline.
    task = build_task({
        "class": "lm",
        "params": {
            "data_pipeline.class": "GPT2DataPipeline",
            "max_len": 50,
            "begin_of_sentence": "eos"
        }
    })

    # Build the neurst GPT-2 (117M) model and restore the released OpenAI weights.
    model_cfgs = get_hyper_parameters("gpt2_117m")
    model = task.build_model(model_cfgs)
    restore_checkpoint_if_possible_v2(model, "117M", model_name="OpenAIGPT2")
    # Run the same sentence through the neurst pipeline and decoder (teacher forcing).
    input_ids = task._data_pipeline.process(input_text)
    tf_inputs = {
        "trg_input": tf.convert_to_tensor([input_ids], tf.int64),
        "trg_length": tf.convert_to_tensor([len(input_ids)], tf.int64)
    }
    _, gen_init = model.get_symbols_to_logits_fn(tf_inputs, is_training=False, is_inference=False)
    tf_outputs = model.get_decoder_output(gen_init["decoder_input"],
                                          cache=gen_init["decoder_internal_cache"],
                                          is_training=False)
    # The hidden states should match the PyTorch reference within a 5e-4 tolerance;
    # the trailing position is sliced off so the sequence lengths line up.
    assert_equal_numpy(pt_outputs.last_hidden_state.detach().numpy(), tf_outputs[:, :-1].numpy(), 5e-4)
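
The body of assert_equal_numpy is not shown in this listing. A minimal sketch of such a tolerance check (the helper name and exact semantics here are assumptions, not the neurst implementation):

import numpy as np

def assert_allclose_sketch(a, b, epsilon):
    # Hypothetical stand-in for assert_equal_numpy: fail when any element-wise
    # absolute difference exceeds epsilon.
    assert a.shape == b.shape, f"shape mismatch: {a.shape} vs {b.shape}"
    max_diff = float(np.max(np.abs(a - b)))
    assert max_diff <= epsilon, f"max abs diff {max_diff} exceeds {epsilon}"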
Example #2
    def _restore_ckpt_or_pretrain(self):
        """ Restores a checkpoint from model_dir, or falls back to the pretrain_model dir(s). """
        stat = restore_checkpoint_if_possible(self.model, self.model_dir)
        continue_training = False
        if stat:
            logging.info(f"Successfully restoring checkpoint from model_dir={self.model_dir}")
            continue_training = True
        else:
            logging.info(f"No checkpoint restored from model_dir={self.model_dir}")
            if self._pretrain_model:
                if self._pretrain_v2:
                    for pt in self._pretrain_model:
                        logging.info(f"Trying to restore from pretrain_model={pt}")
                        logging.info("NOTE THAT, one must first check the variable names in this checkpoint, "
                                     "otherwise no variables will be restored.")
                        restore_checkpoint_if_possible_v2(self.model, **pt)
                else:
                    for pt, pt_varname in zip(self._pretrain_model, self._pretrain_variable_pattern):
                        logging.info(f"Trying to restore from pretrain_model={pt}")
                        logging.info("NOTE THAT, one must first check the variable names in this checkpoint, "
                                     "otherwise no variables will be restored.")
                        restore_checkpoint_if_possible(self.model, pt, var_name_pattern=pt_varname)

        if self._initial_global_step is None and continue_training:
            _step = compat.hack_global_step(self.model_dir)
            if _step:
                compat.register_initial_step(_step or 0)  # must do this before creating optimizer and training
                logging.info(f"Restored initial global step={_step}")
        else:
            compat.register_initial_step(self._initial_global_step or 0)
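
As a usage note for the non-v2 branch above, self._pretrain_model and self._pretrain_variable_pattern are parallel lists of checkpoint paths and variable-name patterns. A minimal, hypothetical sketch of how such settings drive the restore loop (the path, pattern, and import location are placeholders/assumptions):

from neurst.utils.checkpoints import restore_checkpoint_if_possible

# Hypothetical values: the checkpoint path and variable-name pattern are
# placeholders, and `model` is assumed to be built as in Example #1.
pretrain_model = ["/path/to/pretrained/ckpt"]
pretrain_variable_pattern = ["gpt2_decoder"]
for pt, pt_varname in zip(pretrain_model, pretrain_variable_pattern):
    restore_checkpoint_if_possible(model, pt, var_name_pattern=pt_varname)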