示例#1
0
    def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
        """Return a deep copy of ``inputs_dict`` adapted to ``model_class``.

        The caller's dict is never modified; all work happens on a deep copy.

        For multiple-choice model classes, every ``tf.Tensor`` with at least one
        dimension gets a new axis inserted at position 1 and is repeated
        ``self.model_tester.num_choices`` times along it. Scalars and
        non-tensor values pass through unchanged.

        When ``return_labels`` is true, int32 dummy labels are added depending
        on which mapping the model class belongs to:

        * multiple choice: ``labels`` of ones, length ``batch_size``;
        * question answering: ``start_positions`` and ``end_positions`` of
          zeros, length ``batch_size``;
        * sequence classification: ``labels`` of zeros, length ``batch_size``;
        * token classification / causal LM / masked LM / seq2seq causal LM:
          ``labels`` of zeros shaped ``(batch_size, seq_length)``.

        A model class found in none of these mappings receives no labels.
        """
        prepared = copy.deepcopy(inputs_dict)
        multiple_choice = model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values()

        def add_choice_axis(value):
            # Only non-scalar tensors are expanded; everything else is untouched.
            if not (isinstance(value, tf.Tensor) and value.ndim > 0):
                return value
            multiples = (1, self.model_tester.num_choices) + (1,) * (value.ndim - 1)
            return tf.tile(tf.expand_dims(value, 1), multiples)

        if multiple_choice:
            prepared = {name: add_choice_axis(value) for name, value in prepared.items()}

        if not return_labels:
            return prepared

        def int_zeros(shape):
            return tf.zeros(shape, dtype=tf.int32)

        if multiple_choice:
            prepared["labels"] = tf.ones(self.model_tester.batch_size, dtype=tf.int32)
        elif model_class in TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING.values():
            prepared["start_positions"] = int_zeros(self.model_tester.batch_size)
            prepared["end_positions"] = int_zeros(self.model_tester.batch_size)
        elif model_class in TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values():
            prepared["labels"] = int_zeros(self.model_tester.batch_size)
        else:
            # Heads whose labels carry one entry per token position.
            per_token_heads = (
                *TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.values(),
                *TF_MODEL_FOR_CAUSAL_LM_MAPPING.values(),
                *TF_MODEL_FOR_MASKED_LM_MAPPING.values(),
                *TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.values(),
            )
            if model_class in per_token_heads:
                prepared["labels"] = int_zeros(
                    (self.model_tester.batch_size, self.model_tester.seq_length)
                )
        return prepared
示例#2
0
    def _prepare_for_class(self,
                           inputs_dict,
                           model_class,
                           return_labels=False):
        """Adapt ``inputs_dict`` to ``model_class`` and optionally add dummy labels.

        Args:
            inputs_dict: mapping of input names to values, typically
                ``tf.Tensor`` objects. It is not modified.
            model_class: the TF model class the inputs are prepared for.
            return_labels: when true, add int32 label tensors matching the
                head type of ``model_class`` (multiple choice, question
                answering, sequence classification or token classification).

        Returns:
            A new dict. For multiple-choice classes, every non-scalar tensor
            has an axis inserted at position 1 and is repeated
            ``self.model_tester.num_choices`` times along it.
        """
        # Copy first: the label keys added below must not leak into the
        # caller's dict, which callers may reuse across model classes.
        inputs_dict = dict(inputs_dict)

        if model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
            # The multiples tuple is sized from each tensor's rank, so inputs
            # of any rank >= 1 are supported (rank 2 behaves as before).
            inputs_dict = {
                k: tf.tile(
                    tf.expand_dims(v, 1),
                    (1, self.model_tester.num_choices) + (1,) * (v.ndim - 1))
                if isinstance(v, tf.Tensor) and v.ndim != 0 else v
                for k, v in inputs_dict.items()
            }

        if return_labels:
            batch_size = self.model_tester.batch_size
            # Class indices and span positions are integers, hence int32.
            if model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
                inputs_dict["labels"] = tf.ones(batch_size, dtype=tf.int32)
            elif model_class in TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING.values():
                inputs_dict["start_positions"] = tf.zeros(batch_size,
                                                          dtype=tf.int32)
                inputs_dict["end_positions"] = tf.zeros(batch_size,
                                                        dtype=tf.int32)
            elif model_class in TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values():
                inputs_dict["labels"] = tf.zeros(batch_size, dtype=tf.int32)
            elif model_class in TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.values():
                inputs_dict["labels"] = tf.zeros(
                    (batch_size, self.model_tester.seq_length),
                    dtype=tf.int32)
        return inputs_dict
示例#3
0
 def _prepare_for_class(self, inputs_dict, model_class):
     """Adapt ``inputs_dict`` to the input layout expected by ``model_class``.

     For multiple-choice model classes, returns a new dict in which every
     ``tf.Tensor`` with at least one dimension has an axis inserted at
     position 1 and is repeated ``self.model_tester.num_choices`` times along
     it. Scalars and non-tensor values pass through unchanged. For any other
     model class, ``inputs_dict`` itself is returned unchanged.
     """
     if model_class not in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
         return inputs_dict
     # Size the tile multiples from each tensor's rank so inputs of any rank
     # >= 1 work. A fixed (1, num_choices, 1) only matches rank-2 tensors.
     return {
         k: tf.tile(tf.expand_dims(v, 1),
                    (1, self.model_tester.num_choices) + (1,) * (v.ndim - 1))
         if isinstance(v, tf.Tensor) and v.ndim != 0
         else v
         for k, v in inputs_dict.items()
     }
    def test_compile_tf_model(self):
        """Check each model class can be wrapped in and compiled as a Keras model.

        For every class in ``self.all_model_classes`` this test:

        * builds the model on the common inputs;
        * round-trips it through ``save_pretrained`` / ``from_pretrained``;
        * calls it on symbolic ``tf.keras.Input`` placeholders;
        * stacks a softmax ``Dense`` layer on the first output element;
        * compiles the resulting functional ``tf.keras.Model``.
        """
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5,
                                             epsilon=1e-08,
                                             clipnorm=1.0)
        loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        metric = tf.keras.metrics.SparseCategoricalAccuracy("accuracy")

        def int32_input(batch_shape, name):
            # Symbolic int32 placeholder with a fully static batch shape.
            return tf.keras.Input(batch_shape=batch_shape,
                                  name=name,
                                  dtype="int32")

        for model_class in self.all_model_classes:
            if self.is_encoder_decoder:
                input_ids = {
                    "decoder_input_ids": int32_input((2, 2000),
                                                     "decoder_input_ids"),
                    "input_ids": int32_input((2, 2000), "input_ids"),
                }
            elif model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
                # Multiple-choice models take an extra axis at position 1
                # (the choices axis, matching how _prepare_for_class tiles).
                input_ids = int32_input((4, 2, 2000), "input_ids")
            else:
                input_ids = int32_input((2, 2000), "input_ids")

            # Prepare our model
            model = model_class(config)

            # Build the model with one forward pass on concrete inputs, then
            # reload it from disk to be sure we can use pretrained weights.
            with tempfile.TemporaryDirectory() as tmpdirname:
                model(self._prepare_for_class(inputs_dict, model_class))
                model.save_pretrained(tmpdirname)
                model = model_class.from_pretrained(tmpdirname)

            outputs_dict = model(input_ids)
            # First output element. Named hidden_states here, but for head
            # models it is presumably the head's output (e.g. logits) — confirm.
            hidden_states = outputs_dict[0]

            # Add a dense layer on top to test integration with other keras modules
            outputs = tf.keras.layers.Dense(2,
                                            activation="softmax",
                                            name="outputs")(hidden_states)

            # Compile extended model
            extended_model = tf.keras.Model(inputs=[input_ids],
                                            outputs=[outputs])
            extended_model.compile(optimizer=optimizer,
                                   loss=loss,
                                   metrics=[metric])