import torch

# NOTE: BertConfig / BertModel and the multi-task helpers (MTLRouting,
# BatchType, BertPretrainingLoss) are assumed to be imported from this
# project's own modules; the exact import paths depend on the repo layout.


class BertMultiTask:
    def __init__(self, job_config, use_pretrain, tokenizer, cache_dir,
                 device, write_log, summary_writer):
        self.job_config = job_config

        if not use_pretrain:
            # Build a randomly initialized encoder from the job's model config.
            model_config = self.job_config.get_model_config()
            bert_config = BertConfig(**model_config)

            # The vocab size is fixed by the tokenizer.
            bert_config.vocab_size = len(tokenizer.vocab)

            self.bert_encoder = BertModel(bert_config)
        else:
            # Use pretrained BERT weights.
            self.bert_encoder = BertModel.from_pretrained(
                self.job_config.get_model_file_type(), cache_dir=cache_dir)
            bert_config = self.bert_encoder.config

        self.network = MTLRouting(self.bert_encoder,
                                  write_log=write_log,
                                  summary_writer=summary_writer)

        # config_data = self.config['data']

        # Pretrain dataset
        self.network.register_batch(
            BatchType.PRETRAIN_BATCH,
            "pretrain_dataset",
            loss_calculation=BertPretrainingLoss(self.bert_encoder,
                                                 bert_config))

        self.device = device
        # self.network = self.network.float()
        # print(f"Bert ID: {id(self.bert_encoder)} from GPU: {dist.get_rank()}")

    def save(self, filename: str):
        # self.network is expected to be wrapped (e.g. DataParallel / DDP),
        # hence the .module indirection.
        network = self.network.module
        return torch.save(network.state_dict(), filename)

    def load(self, model_state_dict: str):
        return self.network.module.load_state_dict(
            torch.load(model_state_dict,
                       map_location=lambda storage, loc: storage))

    def move_batch(self, batch, non_blocking=False):
        # Pass non_blocking by keyword: passed positionally it would be
        # interpreted as the dtype argument of Tensor.to().
        return batch.to(self.device, non_blocking=non_blocking)

    def eval(self):
        self.network.eval()

    def train(self):
        self.network.train()

    def save_bert(self, filename: str):
        return torch.save(self.bert_encoder.state_dict(), filename)

    def to(self, device):
        assert isinstance(device, torch.device)
        self.network.to(device)

    def half(self):
        self.network.half()
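# Illustrative usage sketch (not part of the original module): `job_config`,
# `tokenizer`, and the DDP wrapping of `self.network` all come from this
# project's training driver, so the names below are placeholders.
def _example_bert_multi_task_usage(job_config, tokenizer):
    model = BertMultiTask(job_config=job_config,
                          use_pretrain=True,
                          tokenizer=tokenizer,
                          cache_dir="/tmp/bert_cache",
                          device=torch.device("cuda"),
                          write_log=print,
                          summary_writer=None)
    model.to(torch.device("cuda"))
    model.train()
    # ... run the multi-task training loop via model.network ...
    model.save_bert("bert_encoder.pt")  # encoder weights only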
import os

import numpy as np
import tensorflow as tf

# NOTE: this converter targets the TF 1.x API (tf.Session, tf.get_variable);
# on TF 2.x the same calls live under tf.compat.v1. BertModel is the
# Hugging Face PyTorch implementation; depending on the library generation
# the import is "from pytorch_pretrained_bert import BertModel" instead.
from transformers import BertModel


def convert_pytorch_checkpoint_to_tf(model: BertModel, ckpt_dir: str,
                                     model_name: str):
    """
    :param model: BertModel PyTorch model instance to be converted
    :param ckpt_dir: TensorFlow model directory
    :param model_name: model name

    Currently supported HF models:

        Y BertModel
        N BertForMaskedLM
        N BertForPreTraining
        N BertForMultipleChoice
        N BertForNextSentencePrediction
        N BertForSequenceClassification
        N BertForQuestionAnswering
    """
    # PyTorch stores linear layers as (out, in); TF expects (in, out),
    # so these weights are transposed during conversion.
    tensors_to_transpose = ("dense.weight",
                            "attention.self.query",
                            "attention.self.key",
                            "attention.self.value")

    # Ordered rename rules mapping PyTorch parameter names onto the
    # variable names used by the original TF BERT checkpoints.
    var_map = (("layer.", "layer_"),
               ("word_embeddings.weight", "word_embeddings"),
               ("position_embeddings.weight", "position_embeddings"),
               ("token_type_embeddings.weight", "token_type_embeddings"),
               (".", "/"),
               ("LayerNorm/weight", "LayerNorm/gamma"),
               ("LayerNorm/bias", "LayerNorm/beta"),
               ("weight", "kernel"))

    if not os.path.isdir(ckpt_dir):
        os.makedirs(ckpt_dir)

    state_dict = model.state_dict()

    def to_tf_var_name(name: str):
        for patt, repl in var_map:
            name = name.replace(patt, repl)
        return "bert/{}".format(name)

    def create_tf_var(tensor: np.ndarray, name: str, session: tf.Session):
        tf_dtype = tf.dtypes.as_dtype(tensor.dtype)
        tf_var = tf.get_variable(dtype=tf_dtype,
                                 shape=tensor.shape,
                                 name=name,
                                 initializer=tf.zeros_initializer())
        session.run(tf.variables_initializer([tf_var]))
        session.run(tf_var)
        return tf_var

    tf.reset_default_graph()
    with tf.Session() as session:
        for var_name in state_dict:
            tf_name = to_tf_var_name(var_name)
            torch_tensor = state_dict[var_name].numpy()
            if any(x in var_name for x in tensors_to_transpose):
                torch_tensor = torch_tensor.T
            tf_var = create_tf_var(tensor=torch_tensor, name=tf_name,
                                   session=session)
            tf.keras.backend.set_value(tf_var, torch_tensor)
            tf_weight = session.run(tf_var)
            print("Successfully created {}: {}".format(
                tf_name, np.allclose(tf_weight, torch_tensor)))

        saver = tf.train.Saver(tf.trainable_variables())
        saver.save(session,
                   os.path.join(ckpt_dir,
                                model_name.replace("-", "_") + ".ckpt"))
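# A minimal driver for the converter above, modeled on the argparse main()
# that usually accompanies this script; the CLI flag names below are
# illustrative and may differ from the original.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, required=True,
                        help="model name, e.g. bert-base-uncased")
    parser.add_argument("--cache_dir", type=str, default=None,
                        help="Hugging Face cache directory (optional)")
    parser.add_argument("--pytorch_model_path", type=str, required=True,
                        help="path to the PyTorch model .bin file")
    parser.add_argument("--tf_cache_dir", type=str, required=True,
                        help="directory in which to save the TF checkpoint")
    args = parser.parse_args()

    model = BertModel.from_pretrained(
        pretrained_model_name_or_path=args.model_name,
        state_dict=torch.load(args.pytorch_model_path),
        cache_dir=args.cache_dir)

    convert_pytorch_checkpoint_to_tf(model=model,
                                     ckpt_dir=args.tf_cache_dir,
                                     model_name=args.model_name)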