Example #1
    def create_and_check_model(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        fake_token_labels,
    ):
        model = FunnelModel(config=config)
        model.to(torch_device)
        model.eval()
        # Forward passes with all optional inputs, then with progressively fewer.
        result = model(input_ids,
                       attention_mask=input_mask,
                       token_type_ids=token_type_ids)
        result = model(input_ids, token_type_ids=token_type_ids)
        result = model(input_ids)
        self.parent.assertEqual(
            result.last_hidden_state.shape,
            (self.batch_size, self.seq_length, self.d_model))

        # The decoder upsamples the hidden states back to the input resolution,
        # so disabling sequence truncation does not change the output shape.
        model.config.truncate_seq = False
        result = model(input_ids)
        self.parent.assertEqual(
            result.last_hidden_state.shape,
            (self.batch_size, self.seq_length, self.d_model))

        # The same holds when the classification token is no longer pooled separately.
        model.config.separate_cls = False
        result = model(input_ids)
        self.parent.assertEqual(
            result.last_hidden_state.shape,
            (self.batch_size, self.seq_length, self.d_model))
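
The method above is written against a tester object (self.parent, self.batch_size, and so on), so it is not runnable on its own. Below is a minimal standalone sketch of the same shape check; the tiny FunnelConfig values are illustrative assumptions, not the test suite's defaults.

import torch
from transformers import FunnelConfig, FunnelModel

# Tiny, illustrative configuration (arbitrary small values, not the defaults).
config = FunnelConfig(
    vocab_size=99,
    block_sizes=[1, 1, 1],
    num_decoder_layers=1,
    d_model=32,
    n_head=2,
    d_head=16,
    d_inner=64,
)
model = FunnelModel(config)
model.eval()

input_ids = torch.randint(0, config.vocab_size, (2, 8))  # (batch_size, seq_length)
with torch.no_grad():
    result = model(input_ids)

# The decoder upsamples back to the input resolution, so the hidden states
# keep the original sequence length: expected shape (2, 8, 32).
print(result.last_hidden_state.shape)
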
Example #2
import torch

from transformers import FunnelBaseModel, FunnelConfig, FunnelModel, load_tf_weights_in_funnel


def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path, base_model):
    # Initialise the PyTorch model from the JSON configuration file
    config = FunnelConfig.from_json_file(config_file)
    print(f"Building PyTorch model from configuration: {config}")
    model = FunnelBaseModel(config) if base_model else FunnelModel(config)

    # Load weights from the TensorFlow checkpoint
    load_tf_weights_in_funnel(model, config, tf_checkpoint_path)

    # Save the PyTorch model
    print(f"Save PyTorch model to {pytorch_dump_path}")
    torch.save(model.state_dict(), pytorch_dump_path)
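
A minimal sketch of how this converter might be invoked; the paths below are placeholders, and in the Transformers repository the function is driven by an argparse-based command-line wrapper rather than called directly.

if __name__ == "__main__":
    # Placeholder paths; point these at a real Funnel TF checkpoint and config.
    convert_tf_checkpoint_to_pytorch(
        tf_checkpoint_path="/path/to/funnel/model.ckpt",
        config_file="/path/to/funnel/config.json",
        pytorch_dump_path="/path/to/output/pytorch_model.bin",
        base_model=False,  # True exports FunnelBaseModel (no upsampling decoder)
    )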