def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file,
                                     pytorch_dump_path):
    # Initialise PyTorch model
    config = BertConfig.from_json_file(bert_config_file)
    print("Building PyTorch model from configuration: {}".format(str(config)))
    model = BertForPreTraining(config)

    # Load weights from tf checkpoint
    load_tf_weights_in_bert(model, config, tf_checkpoint_path)

    # Save pytorch-model
    print("Save PyTorch model to {}".format(pytorch_dump_path))
    torch.save(model.state_dict(), pytorch_dump_path)
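For context, a minimal sketch of the imports and command-line wrapper this snippet assumes; the symbols come from the `transformers` package (older releases shipped them in `pytorch_pretrained_bert`), the flag names are illustrative, and running the conversion also requires TensorFlow to be installed.

import argparse

import torch
from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--tf_checkpoint_path", required=True,
                        help="Path to the TensorFlow checkpoint (e.g. bert_model.ckpt)")
    parser.add_argument("--bert_config_file", required=True,
                        help="Config JSON of the pre-trained BERT model")
    parser.add_argument("--pytorch_dump_path", required=True,
                        help="Where to write the converted state dict")
    args = parser.parse_args()
    convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path,
                                     args.bert_config_file,
                                     args.pytorch_dump_path)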
Example #2
def convert_tf_checkpoint_to_pytorch(config):
    tf_checkpoint_path = config["tf_checkpoint_path"]
    bert_config_file = config["bert_config_file"]
    pytorch_dump_path = Path(config["pytorch_dump_path"])
    # Initialise the PyTorch model (note: this rebinds `config` from the
    # options dict to the BertConfig instance)
    config = BertConfig.from_json_file(bert_config_file)
    print("Building PyTorch model from configuration: {}".format(str(config)))
    model = BertForPreTraining(config)

    # Load weights from the TF checkpoint
    load_tf_weights_in_bert(model, config, tf_checkpoint_path)

    # Save the PyTorch model
    print("Save PyTorch model to {}".format(pytorch_dump_path))
    torch.save(model.state_dict(), pytorch_dump_path)
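This variant reads its arguments from a dict (and additionally needs `from pathlib import Path`). A hypothetical invocation, with placeholder paths:

convert_tf_checkpoint_to_pytorch({
    "tf_checkpoint_path": "uncased_L-12_H-768_A-12/bert_model.ckpt",
    "bert_config_file": "uncased_L-12_H-768_A-12/bert_config.json",
    "pytorch_dump_path": "outputs/pytorch_model.bin",
})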
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file,
                                     pytorch_dump_path):
    # Initialise PyTorch model
    config = BertConfig.from_json_file(bert_config_file)
    # Build a classification model (binary head) instead of the pre-training model
    config.num_labels = 2
    print(f"Building PyTorch model from configuration: {config}")
    model = BertForSequenceClassification(config)

    # Load weights from tf checkpoint
    load_tf_weights_in_bert(model, config, tf_checkpoint_path)

    # Save pytorch-model
    print(f"Save PyTorch model to {pytorch_dump_path}")
    torch.save(model.state_dict(), pytorch_dump_path)
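Since a pre-training TF checkpoint carries no classifier weights, `load_tf_weights_in_bert` should leave the new classification head at its random initialisation. A sketch of reloading the converted weights for fine-tuning (hypothetical paths, `BertForSequenceClassification` as above):

config = BertConfig.from_json_file("bert_config.json")  # placeholder path
config.num_labels = 2
model = BertForSequenceClassification(config)
model.load_state_dict(torch.load("pytorch_model.bin"))  # placeholder path
model.train()  # the classifier head still needs fine-tuning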
Example #4
    def download(self):
        # Iterate over urls: download, unzip, verify sha256sum
        found_mismatch_sha = False
        for model in self.model_urls:
          url = self.model_urls[model][0]
          file = self.save_path + '/' + self.model_urls[model][1]

          print('Downloading', url)
          response = urllib.request.urlopen(url)
          with open(file, 'wb') as handle:
            handle.write(response.read())

          print('Unzipping', file)
          extract_to_path = pathlib.Path(self.save_path) / (pathlib.Path(file).stem if model == "bert_tiny_uncased" else "")
          # Context manager avoids shadowing the built-in `zip` and guarantees closure
          with zipfile.ZipFile(file, 'r') as archive:
            archive.extractall(path=extract_to_path)

          sha_dict = self.model_sha[model]
          for extracted_file in sha_dict:
            sha = sha_dict[extracted_file]
            # file[:-4] strips the '.zip' suffix to locate the extracted directory
            if sha != self.sha256sum(file[:-4] + '/' + extracted_file):
              found_mismatch_sha = True
              print('SHA256sum does not match on file:', extracted_file, 'from download url:', url)
            else:
              print(file[:-4] + '/' + extracted_file, '\t', 'verified')

          config = BertConfig.from_json_file(extract_to_path / "bert_config.json")
          print("Building PyTorch model from configuration: {}".format(str(config)))
          model = BertForPreTraining(config)

          # Load weights from tf checkpoint
          load_tf_weights_in_bert(model, config, extract_to_path / "bert_model.ckpt")

          # Save pytorch-model
          print("Save PyTorch model to {}".format(extract_to_path))
          torch.save({'model': model.state_dict(),
                      'optimizer': None,
                      'master params': None,
                      'files': None,
                      'epoch': None,
                      'data_loader': None}, extract_to_path / "ckpt_pretrained.pt")

        if not found_mismatch_sha:
          print("All downloads pass sha256sum verification.")
Example #5
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file,
                                     pytorch_dump_path):
    """
    :param tf_checkpoint_path: Path to the TensorFlow checkpoint path.
    :param bert_config_file: The config json file corresponding to the pre-trained BERT model.
    :param pytorch_dump_path: Path to the output PyTorch model.
    :return:
    """
    # Initialise PyTorch model
    config = BertConfig.from_json_file(bert_config_file)
    print("Building PyTorch model from configuration: {}".format(str(config)))
    model = BertForPreTraining(config)

    # Load weights from tf checkpoint
    load_tf_weights_in_bert(model, config, tf_checkpoint_path)

    # Save pytorch-model
    print("Save PyTorch model to {}".format(pytorch_dump_path))
    torch.save(model.state_dict(), pytorch_dump_path)
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file,
                                     pytorch_dump_path):
    '''
        tf_checkpoint_path: path to the TF .ckpt file
        bert_config_file: path to the .json config file
        pytorch_dump_path: where to save the PyTorch model
    '''

    # Initialise the PyTorch model
    config = BertConfig.from_json_file(bert_config_file)
    print("Building PyTorch model from configuration: {}".format(str(config)))
    model = BertForPreTraining(config)

    # Load weights from the TF checkpoint
    load_tf_weights_in_bert(model, config, tf_checkpoint_path)

    # Save the PyTorch model
    print("Save PyTorch model to {}".format(pytorch_dump_path))
    torch.save(model.state_dict(), pytorch_dump_path)
Example #7
    def __init__(self, bert_config, checkpoint_path=None):
        super(BertForQuestionAnswering, self).__init__(bert_config)
        # `config` below refers to a module-level runtime config, distinct
        # from the HuggingFace `bert_config` passed in
        if config.bert == 'bert':
            self.bert = BertModel(bert_config)
        elif config.bert == 'albert':
            self.bert = AlbertModel(bert_config)
        if checkpoint_path:
            # Load weights from a TF checkpoint
            self.bert = load_tf_weights_in_bert(
                self.bert, config=None, tf_checkpoint_path=checkpoint_path)

        if config.use_origin_bert == 'dym':
            self.hidden_size = 512  # dynamic-fusion path uses a reduced hidden size
        else:
            self.hidden_size = 768

        self.qa_outputs = nn.Linear(self.hidden_size, 2)  # start/end
        self.dropout = nn.Dropout(bert_config.hidden_dropout_prob)
        # self.cls = nn.Linear(self.hidden_size, 2) #for cls
        self.classifier = nn.Linear(768, 1)  # layer scorer for the dynamic fusion
        self.dense_final = nn.Sequential(nn.Linear(768, self.hidden_size),
                                         nn.ReLU(True))  # final projection to the dynamic hidden size
        # self.classifier = nn.Linear(1024, 1)  # same, for the large model
        # self.dense_final = nn.Sequential(nn.Linear(1024, self.hidden_size), nn.ReLU(True))  # same, for the large model
        if config.lstm:
            num_layers = config.num_layer
            lstm_num = int(self.hidden_size / 2)
            self.lstm = nn.LSTM(
                self.hidden_size,
                lstm_num,
                num_layers,
                batch_first=True,  # first dimension is the batch size
                bidirectional=True)
        elif config.gru:
            num_layers = config.num_layer
            lstm_num = int(self.hidden_size / 2)
            self.lstm = nn.GRU(
                self.hidden_size,
                lstm_num,
                num_layers,
                batch_first=True,  # first dimension is the batch size
                bidirectional=True)

        self.init_weights()
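The constructor reads several switches from a module-level `config` object that this snippet does not show. A hypothetical stand-in, just to make the dependencies explicit (names and values are assumptions):

class RunConfig:
    bert = 'bert'            # or 'albert', selects the encoder class
    use_origin_bert = 'dym'  # 'dym' enables the 512-d dynamic-fusion path
    lstm = True              # add a BiLSTM on top of the encoder
    gru = False              # mutually exclusive with `lstm` above
    num_layer = 1            # recurrent layers in the LSTM/GRU

config = RunConfig()
bert_config = BertConfig.from_json_file("bert_config.json")  # placeholder path
model = BertForQuestionAnswering(bert_config, checkpoint_path=None)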