def convert_orig_tf1_checkpoint_to_pytorch(tf_checkpoint_path, convbert_config_file, pytorch_dump_path):
    """Convert an original TF1 ConvBERT checkpoint into PyTorch and TF model weights.

    A PyTorch ``ConvBertModel`` is built from the JSON config, filled with the
    weights of the TF1 checkpoint and saved to ``pytorch_dump_path``. The saved
    PyTorch weights are then reloaded as a ``TFConvBertModel`` (``from_pt=True``)
    and saved to the same ``pytorch_dump_path`` directory.

    Args:
        tf_checkpoint_path: Location of the original TF1 checkpoint to read.
        convbert_config_file: JSON file passed to ``ConvBertConfig.from_json_file``.
        pytorch_dump_path: Directory both ``save_pretrained`` calls write into.
    """
    config = ConvBertConfig.from_json_file(convbert_config_file)

    # PyTorch side: fresh model -> fill from the TF1 checkpoint -> save.
    pt_model = load_tf_weights_in_convbert(ConvBertModel(config), config, tf_checkpoint_path)
    pt_model.save_pretrained(pytorch_dump_path)

    # TF side: reload the PyTorch weights just written and save them to the same directory.
    TFConvBertModel.from_pretrained(pytorch_dump_path, from_pt=True).save_pretrained(pytorch_dump_path)
def test_inference_masked_lm(self):
    """Compare TF ConvBERT output on a fixed input against reference values.

    Runs ``YituTech/conv-bert-base`` on the token ids ``[0, 1, 2, 3, 4, 5]``,
    checks that the first model output has shape ``[1, 6, 768]`` and that its
    top-left 3x3 slice matches the stored reference within ``atol=1e-4``.
    """
    checkpoint = TFConvBertModel.from_pretrained("YituTech/conv-bert-base")
    token_ids = tf.constant([[0, 1, 2, 3, 4, 5]])
    # Element 0 of the model output is the tensor under test
    # (presumably the last hidden state — shape checked just below).
    first_output = checkpoint(token_ids)[0]

    self.assertEqual(first_output.shape, [1, 6, 768])

    reference_corner = tf.constant(
        [
            [
                [-0.03475493, -0.4686034, -0.30638832],
                [0.22637248, -0.26988646, -0.7423424],
                [0.10324868, -0.45013508, -0.58280784],
            ]
        ]
    )
    tf.debugging.assert_near(first_output[:, :3, :3], reference_corner, atol=1e-4)
def test_inference_masked_lm(self):
    """Compare TF ConvBERT output on a fixed input against reference values.

    Loads ``YituTech/conv-bert-base``, runs it on token ids ``[0, 1, 2, 3, 4, 5]``,
    checks the first model output has shape ``[1, 6, 768]`` and that its
    top-left 3x3 slice matches ``expected_slice`` within ``atol=1e-4``.
    """
    model = TFConvBertModel.from_pretrained("YituTech/conv-bert-base")
    input_ids = tf.constant([[0, 1, 2, 3, 4, 5]])
    output = model(input_ids)[0]

    expected_shape = [1, 6, 768]
    self.assertEqual(output.shape, expected_shape)

    # The leftover debugging `print` of the slice was removed: on mismatch,
    # tf.debugging.assert_near already reports the offending values.
    #
    # NOTE(review): another `test_inference_masked_lm` with *different*
    # reference values exists in this file. If both are methods of the same
    # TestCase, the later definition silently replaces the earlier one and only
    # one set of values is ever checked. TODO: confirm which values match the
    # current checkpoint and delete (or rename) the stale definition.
    expected_slice = tf.constant(
        [
            [
                [-0.10334751, -0.37152207, -0.2682219],
                [0.20078957, -0.3918426, -0.78811496],
                [0.08000169, -0.509474, -0.59314483],
            ]
        ]
    )
    tf.debugging.assert_near(output[:, :3, :3], expected_slice, atol=1e-4)
def test_model_from_pretrained(self):
    """Loading the ``YituTech/conv-bert-base`` checkpoint must return a model object."""
    loaded = TFConvBertModel.from_pretrained("YituTech/conv-bert-base")
    self.assertIsNotNone(loaded)