Example #1
0
def convert_orig_tf1_checkpoint_to_pytorch(tf_checkpoint_path, convbert_config_file, pytorch_dump_path):
    """Convert an original TF1 ConvBERT checkpoint and save it with ``save_pretrained``.

    Steps:
      1. Build a ``ConvBertConfig`` from ``convbert_config_file``.
      2. Load the TF1 weights at ``tf_checkpoint_path`` into a fresh ``ConvBertModel``
         and save it to ``pytorch_dump_path``.
      3. Re-load that saved checkpoint as a ``TFConvBertModel`` (``from_pt=True``) and
         save it into the same ``pytorch_dump_path`` directory.

    NOTE(review): step 3 writes into the same directory as step 2; presumably the
    TF weights are stored under a different file name than the PyTorch ones, so
    both coexist. Verify against ``save_pretrained``'s file naming.

    Args:
        tf_checkpoint_path: Path to the original TF1 checkpoint.
        convbert_config_file: Path to the JSON model configuration.
        pytorch_dump_path: Output directory for the converted model(s).
    """
    config = ConvBertConfig.from_json_file(convbert_config_file)
    pt_model = load_tf_weights_in_convbert(ConvBertModel(config), config, tf_checkpoint_path)
    pt_model.save_pretrained(pytorch_dump_path)

    # Round-trip through the checkpoint that was just written to also emit TF weights.
    TFConvBertModel.from_pretrained(pytorch_dump_path, from_pt=True).save_pretrained(pytorch_dump_path)
Example #2
0
    def create_and_check_model(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """Check that ``TFConvBertModel`` works with each supported input convention.

        The model is called with a keyword-style dict, a positional list, and a bare
        ``input_ids`` tensor. The ``last_hidden_state`` shape of every call is asserted,
        so no calling convention goes unchecked.

        ``sequence_labels``, ``token_labels`` and ``choice_labels`` are accepted only to
        keep the shared tester signature; they are not used here.
        """
        model = TFConvBertModel(config=config)
        expected_shape = (self.batch_size, self.seq_length, self.hidden_size)

        # Keyword-style dict input (previously built but never passed to the model).
        inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
        result = model(inputs)
        self.parent.assertEqual(result.last_hidden_state.shape, expected_shape)

        # Positional list input: [input_ids, attention_mask].
        inputs = [input_ids, input_mask]
        result = model(inputs)
        self.parent.assertEqual(result.last_hidden_state.shape, expected_shape)

        # Bare tensor input.
        result = model(input_ids)
        self.parent.assertEqual(result.last_hidden_state.shape, expected_shape)
    def test_inference_masked_lm(self):
        """Compare a 3x3 corner of the pretrained model output against reference values.

        NOTE(review): despite the name, this loads ``TFConvBertModel`` (no LM head)
        and checks element ``[0]`` of its output, presumably ``last_hidden_state``.
        """
        model = TFConvBertModel.from_pretrained("YituTech/conv-bert-base")
        token_ids = tf.constant([[0, 1, 2, 3, 4, 5]])
        hidden = model(token_ids)[0]

        # One sequence of 6 tokens with a hidden width of 768.
        self.assertEqual(hidden.shape, [1, 6, 768])

        reference = tf.constant(
            [
                [
                    [-0.03475493, -0.4686034, -0.30638832],
                    [0.22637248, -0.26988646, -0.7423424],
                    [0.10324868, -0.45013508, -0.58280784],
                ]
            ]
        )
        tf.debugging.assert_near(hidden[:, :3, :3], reference, atol=1e-4)
Example #4
0
    def test_inference_masked_lm(self):
        """Compare a 3x3 corner of the pretrained model output against reference values.

        NOTE(review): despite the name, this loads ``TFConvBertModel`` (no LM head)
        and checks element ``[0]`` of its output, presumably ``last_hidden_state``.
        The reference numbers here differ from another variant of this test with the
        same name; confirm which checkpoint revision they correspond to.
        """
        model = TFConvBertModel.from_pretrained("YituTech/conv-bert-base")
        input_ids = tf.constant([[0, 1, 2, 3, 4, 5]])
        output = model(input_ids)[0]

        expected_shape = [1, 6, 768]
        self.assertEqual(output.shape, expected_shape)

        # The leftover debug print of output[:, :3, :3] was removed; on failure,
        # assert_near reports the mismatching values itself.
        expected_slice = tf.constant([[
            [-0.10334751, -0.37152207, -0.2682219],
            [0.20078957, -0.3918426, -0.78811496],
            [0.08000169, -0.509474, -0.59314483],
        ]])
        tf.debugging.assert_near(output[:, :3, :3], expected_slice, atol=1e-4)
 def test_model_from_pretrained(self):
     """Smoke test: the public ConvBERT base checkpoint loads into TFConvBertModel."""
     loaded = TFConvBertModel.from_pretrained("YituTech/conv-bert-base")
     self.assertIsNotNone(loaded)