Example #1
0
    def test_lmhead_model_from_pretrained(self):
        """Check that the auto classes resolve the first TF BERT checkpoint to a masked-LM head.

        For the first name in ``TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST`` this asserts
        that ``AutoConfig`` yields a ``BertConfig`` and that ``TFAutoModelWithLMHead``
        yields a ``TFBertForMaskedLM``.
        """
        checkpoints = TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]
        for checkpoint in checkpoints:
            cfg = AutoConfig.from_pretrained(checkpoint)
            self.assertIsNotNone(cfg)
            self.assertIsInstance(cfg, BertConfig)

            lm_model = TFAutoModelWithLMHead.from_pretrained(checkpoint)
            self.assertIsNotNone(lm_model)
            self.assertIsInstance(lm_model, TFBertForMaskedLM)
Example #2
0
    def test_question_answering_model_from_pretrained(self):
        """Check that the auto classes map ``bert-base-uncased`` to the TF BERT QA head.

        Asserts that ``AutoConfig`` yields a ``BertConfig`` and that
        ``TFAutoModelForQuestionAnswering`` yields a ``TFBertForQuestionAnswering``.
        """
        # Uses a fixed checkpoint name rather than TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1].
        checkpoint_names = ["bert-base-uncased"]
        for checkpoint in checkpoint_names:
            cfg = AutoConfig.from_pretrained(checkpoint)
            self.assertIsNotNone(cfg)
            self.assertIsInstance(cfg, BertConfig)

            qa_model = TFAutoModelForQuestionAnswering.from_pretrained(checkpoint)
            self.assertIsNotNone(qa_model)
            self.assertIsInstance(qa_model, TFBertForQuestionAnswering)
Example #3
0
    def test_model_for_causal_lm(self):
        """Check that the auto classes resolve the first TF BERT checkpoint to a causal-LM head.

        For the first name in ``TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST`` this asserts
        that ``AutoConfig`` yields a ``BertConfig`` and that ``TFAutoModelForCausalLM``
        (loaded with ``output_loading_info=True``) yields a ``TFBertLMHeadModel``.
        """
        for model_name in TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            config = AutoConfig.from_pretrained(model_name)
            self.assertIsNotNone(config)
            self.assertIsInstance(config, BertConfig)

            # Load the model exactly once. A separate plain from_pretrained call would
            # be discarded immediately and only doubles the load cost.
            # loading_info is requested but not inspected by this test.
            model, loading_info = TFAutoModelForCausalLM.from_pretrained(
                model_name, output_loading_info=True)
            self.assertIsNotNone(model)
            self.assertIsInstance(model, TFBertLMHeadModel)
Example #4
0
    def test_sequence_classification_model_from_pretrained(self):
        """Check that the auto classes map ``bert-base-uncased`` to the TF BERT classifier head.

        Asserts that ``AutoConfig`` yields a ``BertConfig`` and that
        ``TFAutoModelForSequenceClassification`` yields a
        ``TFBertForSequenceClassification``.
        """
        # Uses a fixed checkpoint name rather than TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1].
        checkpoint_names = ["bert-base-uncased"]
        for checkpoint in checkpoint_names:
            cfg = AutoConfig.from_pretrained(checkpoint)
            self.assertIsNotNone(cfg)
            self.assertIsInstance(cfg, BertConfig)

            classifier = TFAutoModelForSequenceClassification.from_pretrained(checkpoint)
            self.assertIsNotNone(classifier)
            self.assertIsInstance(classifier, TFBertForSequenceClassification)
Example #5
0
    def test_model_for_pretraining_from_pretrained(self):
        """Check that the auto classes map ``bert-base-uncased`` to the TF BERT pretraining model.

        First asserts that the HDF5 library linked by ``h5py`` is a 1.10.x release.
        Then asserts that ``AutoConfig`` yields a ``BertConfig`` and that
        ``TFAutoModelForPreTraining`` yields a ``TFBertForPreTraining``.
        """
        import h5py

        # NOTE(review): presumably guards against an HDF5 release mismatch when reading
        # .h5 weights. Confirm why exactly the 1.10 series is required.
        hdf5_release = h5py.version.hdf5_version
        self.assertTrue(hdf5_release.startswith("1.10"))

        # Uses a fixed checkpoint name rather than TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1].
        checkpoint_names = ["bert-base-uncased"]
        for checkpoint in checkpoint_names:
            cfg = AutoConfig.from_pretrained(checkpoint)
            self.assertIsNotNone(cfg)
            self.assertIsInstance(cfg, BertConfig)

            pretraining_model = TFAutoModelForPreTraining.from_pretrained(checkpoint)
            self.assertIsNotNone(pretraining_model)
            self.assertIsInstance(pretraining_model, TFBertForPreTraining)
Example #6
0
 def test_inference_with_configs_graph(self):
     """Run an inference-only TensorFlow benchmark with an explicitly supplied config.

     Loads the ``albert-base-v2`` config up front and passes it to
     ``TensorFlowBenchmark`` together with the arguments. Then checks that both
     the inference-time and inference-memory result dicts are non-empty.
     """
     model_id = "albert-base-v2"
     albert_config = AutoConfig.from_pretrained(model_id)
     args = TensorFlowBenchmarkArguments(
         models=[model_id],
         training=False,
         no_inference=False,
         sequence_lengths=[8],
         batch_sizes=[1],
         no_multi_process=True,
     )
     outcome = TensorFlowBenchmark(args, [albert_config]).run()
     self.check_results_dict_not_empty(outcome.time_inference_result)
     self.check_results_dict_not_empty(outcome.memory_inference_result)
Example #7
0
    def __init__(self,
                 args: BenchmarkArguments = None,
                 configs: PretrainedConfig = None):
        """Store the benchmark arguments and build the model-name -> config mapping.

        Args:
            args: Benchmark arguments. ``args.model_names`` and ``args.no_memory``
                are read here.
            configs: Optional configs paired positionally with ``args.model_names``.
                They are combined with ``zip``, so surplus entries on either side are
                dropped. Despite the annotation, this is iterated, so it is expected
                to be a sequence of configs. When ``None``, each config is loaded
                with ``AutoConfig.from_pretrained(model_name)``.
        """
        self.args = args
        if configs is None:
            self.config_dict = {
                model_name: AutoConfig.from_pretrained(model_name)
                for model_name in self.args.model_names
            }
        else:
            self.config_dict = dict(zip(self.args.model_names, configs))

        # os.getenv returns a str (or None), so compare against the string "0".
        # Comparing against the int 0 is always False, which meant the warning
        # could never be emitted.
        if not self.args.no_memory and os.getenv(
                "TRANSFORMERS_USE_MULTIPROCESSING") == "0":
            logger.warning(
                "Memory consumption will not be measured accurately if `args.no_multi_process` is set to `True.` The flag 'TRANSFORMERS_USE_MULTIPROCESSING' should only be disabled for debugging / testing."
            )

        self._print_fn = None
        self._framework_version = None
        self._environment_info = None