from pathlib import Path
import unittest

import tensorflow as tf
from tensorflow.python import ipu
from transformers import BertConfig, TFBertForPreTraining

# The remaining names used below (load_dataset, get_path, set_random_seeds,
# gather_positions, IpuTFBertForPreTraining, convert_tf_bert_model,
# post_process_bert_input_layer, mlm_loss, nsp_loss) are project-local
# helpers; their import paths are not shown in this excerpt.


class TFBertModelTester:  # class name assumed; the excerpt starts mid-class
    def create_and_check_bert_for_pretraining(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        model = TFBertForPreTraining(config=config)
        inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
        result = model(inputs)
        self.parent.assertEqual(result.prediction_logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
        self.parent.assertEqual(result.seq_relationship_logits.shape, (self.batch_size, 2))


class BertIpuComparisonTest:  # class name assumed; not shown in the excerpt
    @classmethod
    def setup_class(cls):
        cls.micro_batch_size = 2
        cls.seq_length = 128
        cls.max_predictions_per_seq = 20
        cls.sample_inputs, cls.sample_labels = load_dataset(
            micro_batch_size=cls.micro_batch_size,
            dataset_dir=Path(get_path()) / "data_utils" / "wikipedia",
            seq_length=cls.seq_length,
        )

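        # Configure the IPU system; ON_DEMAND defers attaching to a device
        # until a graph is actually executed.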
        cfg = ipu.config.IPUConfig()
        cfg.device_connection.type = ipu.config.DeviceConnectionType.ON_DEMAND
        cfg.configure_ipu_system()
        cls.strategy = ipu.ipu_strategy.IPUStrategy()
        with cls.strategy.scope():

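            # A deliberately small BERT (4 layers, 4 heads, hidden size 256)
            # keeps the comparison cheap while exercising the full code path.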
            config = BertConfig(
                vocab_size=30528,
                num_hidden_layers=4,
                num_attention_heads=4,
                hidden_size=256,
                intermediate_size=768,
                max_predictions_per_seq=cls.max_predictions_per_seq,
            )

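            # Baseline: the stock Hugging Face TF model, seeded so the IPU
            # variants below can start from identical random weights.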
            set_random_seeds(seed=42)
            cls.orig_model = TFBertForPreTraining(config)
            cls.orig_outputs = cls.orig_model(cls.sample_inputs)
            cls.orig_logits = gather_positions(
                cls.orig_outputs.prediction_logits,
                cls.sample_inputs['masked_lm_positions'])

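            # Re-seed so the IPU subclass model is initialised with the same
            # weights as the baseline above.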
            set_random_seeds(seed=42)
            ipu_subclass_model = IpuTFBertForPreTraining(config)
            cls.sub_outputs = ipu_subclass_model(cls.sample_inputs)

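            # Convert the subclassed model into a functional Keras model.
            # Judging by the arguments, this swaps layers for IPU-optimised
            # equivalents (replace_layers), outlines repeated code
            # (use_outlining), and splits the embedding lookup into two
            # serialised chunks (embedding_serialization_factor=2).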
            cls.functional_model = convert_tf_bert_model(
                ipu_subclass_model,
                cls.sample_inputs,
                post_process_bert_input_layer,
                replace_layers=True,
                use_outlining=True,
                embedding_serialization_factor=2)
            cls.functional_model.compile(loss={
                'mlm___cls': mlm_loss,
                'nsp___cls': nsp_loss
            })
            cls.func_outputs = cls.functional_model(cls.sample_inputs)
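
    # Sketch only: one check these fixtures make possible. The method name and
    # the assumption that IpuTFBertForPreTraining returns the same output type
    # as TFBertForPreTraining are illustrative, not taken from the source.
    def test_prediction_logit_shapes_match(self):
        assert (self.orig_outputs.prediction_logits.shape
                == self.sub_outputs.prediction_logits.shape)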


class TFBertModelIntegrationTest(unittest.TestCase):
    def test_inference_masked_lm(self):
        model = TFBertForPreTraining.from_pretrained(
            "lysandre/tiny-bert-random")
        input_ids = tf.constant([[0, 1, 2, 3, 4, 5]])
        output = model(input_ids)[0]

        expected_shape = [1, 6, 32000]
        self.assertEqual(output.shape, expected_shape)

        print(output[:, :3, :3])

        expected_slice = tf.constant([[
            [-0.05243197, -0.04498899, 0.05512108],
            [-0.07444685, -0.01064632, 0.04352357],
            [-0.05020351, 0.05530146, 0.00700043],
        ]])
        tf.debugging.assert_near(output[:, :3, :3], expected_slice, atol=1e-4)

    # Renamed: the excerpt defined test_inference_masked_lm twice, and the
    # second definition would silently shadow the first.
    def test_inference_masked_lm_tiny_vocab(self):
        model = TFBertForPreTraining.from_pretrained(
            "lysandre/tiny-bert-random")
        input_ids = tf.constant([[0, 1, 2, 3, 4, 5]])
        output = model(input_ids)[0]

        expected_shape = [1, 6, 10]
        self.assertEqual(output.shape, expected_shape)

        print(output[:, :3, :3])

        expected_slice = tf.constant([[
            [0.03706957, 0.10124919, 0.03616843],
            [-0.06099961, 0.02266058, 0.00601412],
            [-0.06066202, 0.05684517, 0.02038802],
        ]])
        tf.debugging.assert_near(output[:, :3, :3], expected_slice, atol=1e-4)