import torch

# NOTE: this class also depends on BertConfig/BertModel (HuggingFace BERT) and
# on the project-local MTLRouting, BatchType, and BertPretrainingLoss.
class BertMultiTask:
    def __init__(self, job_config, use_pretrain, tokenizer, cache_dir, device,
                 write_log, summary_writer):
        self.job_config = job_config

        if use_pretrain:
            # Load pretrained BERT weights.
            self.bert_encoder = BertModel.from_pretrained(
                self.job_config.get_model_file_type(), cache_dir=cache_dir)
            bert_config = self.bert_encoder.config
        else:
            # Build BERT from scratch, sizing the vocabulary to the tokenizer.
            model_config = self.job_config.get_model_config()
            bert_config = BertConfig(**model_config)
            bert_config.vocab_size = len(tokenizer.vocab)
            self.bert_encoder = BertModel(bert_config)

        self.network = MTLRouting(self.bert_encoder,
                                  write_log=write_log,
                                  summary_writer=summary_writer)

        # Register the pretraining dataset's batch type and its loss function.
        self.network.register_batch(BatchType.PRETRAIN_BATCH,
                                    "pretrain_dataset",
                                    loss_calculation=BertPretrainingLoss(
                                        self.bert_encoder, bert_config))

        self.device = device

    def save(self, filename: str):
        # self.network is assumed to be wrapped (e.g. in DistributedDataParallel);
        # unwrap .module so the checkpoint stores unprefixed parameter names.
        network = self.network.module
        return torch.save(network.state_dict(), filename)

    def load(self, model_state_dict: str):
        return self.network.module.load_state_dict(
            torch.load(model_state_dict,
                       map_location=lambda storage, loc: storage))

    def move_batch(self, batch, non_blocking=False):
        # non_blocking must be passed by keyword: the second positional
        # argument of Tensor.to is dtype, not non_blocking.
        return batch.to(self.device, non_blocking=non_blocking)

    def eval(self):
        self.network.eval()

    def train(self):
        self.network.train()

    def save_bert(self, filename: str):
        return torch.save(self.bert_encoder.state_dict(), filename)

    def to(self, device):
        assert isinstance(device, torch.device)
        self.network.to(device)

    def half(self):
        self.network.half()
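
A minimal usage sketch (not from the source repo): job_config, tokenizer, and
the log/summary writers are assumed to come from the surrounding training
script, and save() further assumes the network has been wrapped (e.g. in
DistributedDataParallel) so that .module exists.

model = BertMultiTask(job_config=job_config, use_pretrain=True,
                      tokenizer=tokenizer, cache_dir="/tmp/bert_cache",
                      device=torch.device("cuda:0"), write_log=print,
                      summary_writer=None)
model.to(torch.device("cuda:0"))
model.train()
# ... training loop over the registered batch types ...
model.save_bert("bert_encoder.pt")  # saves the encoder weights only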
Example #2
import os

import numpy as np
import tensorflow as tf
from transformers import BertModel  # HF BERT, per the docstring below


def convert_pytorch_checkpoint_to_tf(model: BertModel, ckpt_dir: str,
                                     model_name: str):
    """
    :param model: BertModel PyTorch model instance to be converted
    :param ckpt_dir: TensorFlow checkpoint directory
    :param model_name: model name
    :return: None

    Currently supported HF models:
        Y BertModel
        N BertForMaskedLM
        N BertForPreTraining
        N BertForMultipleChoice
        N BertForNextSentencePrediction
        N BertForSequenceClassification
        N BertForQuestionAnswering
    """

    tensors_to_transpose = ("dense.weight", "attention.self.query",
                            "attention.self.key", "attention.self.value")

    var_map = (('layer.', 'layer_'), ('word_embeddings.weight',
                                      'word_embeddings'),
               ('position_embeddings.weight', 'position_embeddings'),
               ('token_type_embeddings.weight', 'token_type_embeddings'),
               ('.', '/'), ('LayerNorm/weight', 'LayerNorm/gamma'),
               ('LayerNorm/bias', 'LayerNorm/beta'), ('weight', 'kernel'))
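    # Illustrative effect of var_map (substitutions applied in order):
    #   embeddings.word_embeddings.weight -> bert/embeddings/word_embeddings
    #   encoder.layer.0.attention.self.query.weight
    #       -> bert/encoder/layer_0/attention/self/query/kernel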

    if not os.path.isdir(ckpt_dir):
        os.makedirs(ckpt_dir)

    state_dict = model.state_dict()

    def to_tf_var_name(name: str):
        for patt, repl in var_map:
            name = name.replace(patt, repl)
        return 'bert/{}'.format(name)

    def create_tf_var(tensor: np.ndarray, name: str, session: tf.Session):
        tf_dtype = tf.dtypes.as_dtype(tensor.dtype)
        # Create a zero-initialized variable with the target name/shape/dtype;
        # the real weights are copied in afterwards via set_value.
        tf_var = tf.get_variable(dtype=tf_dtype,
                                 shape=tensor.shape,
                                 name=name,
                                 initializer=tf.zeros_initializer())
        session.run(tf.variables_initializer([tf_var]))
        return tf_var

    tf.reset_default_graph()
    with tf.Session() as session:
        for var_name in state_dict:
            tf_name = to_tf_var_name(var_name)
            torch_tensor = state_dict[var_name].numpy()
            if any(x in var_name for x in tensors_to_transpose):
                torch_tensor = torch_tensor.T
            tf_var = create_tf_var(tensor=torch_tensor,
                                   name=tf_name,
                                   session=session)
            tf.keras.backend.set_value(tf_var, torch_tensor)
            tf_weight = session.run(tf_var)
            print("Successfully created {}: {}".format(
                tf_name, np.allclose(tf_weight, torch_tensor)))

        saver = tf.train.Saver(tf.trainable_variables())
        saver.save(
            session,
            os.path.join(ckpt_dir,
                         model_name.replace("-", "_") + ".ckpt"))
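
A hedged usage sketch for the converter above; the checkpoint directory and
model name are placeholders:

model = BertModel.from_pretrained("bert-base-uncased")
convert_pytorch_checkpoint_to_tf(model=model,
                                 ckpt_dir="/tmp/bert_tf",
                                 model_name="bert-base-uncased")
# Writes /tmp/bert_tf/bert_base_uncased.ckpt*, loadable from TF1-style BERT code.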
Example #3

# An alternative version of the same converter that writes weights with
# tf.assign instead of tf.keras.backend.set_value.
def convert_pytorch_checkpoint_to_tf(model: BertModel, ckpt_dir: str,
                                     model_name: str):
    """
    :param model: BertModel PyTorch model instance to be converted
    :param ckpt_dir: TensorFlow checkpoint directory
    :param model_name: model name
    :return: None

    Currently supported HF models:
        Y BertModel
        N BertForMaskedLM
        N BertForPreTraining
        N BertForMultipleChoice
        N BertForNextSentencePrediction
        N BertForSequenceClassification
        N BertForQuestionAnswering
    """

    tensors_to_transpose = ("dense.weight", "attention.self.query",
                            "attention.self.key", "attention.self.value")

    var_map = (('layer.', 'layer_'), ('word_embeddings.weight',
                                      'word_embeddings'),
               ('position_embeddings.weight', 'position_embeddings'),
               ('token_type_embeddings.weight', 'token_type_embeddings'),
               ('.', '/'), ('LayerNorm/weight', 'LayerNorm/gamma'),
               ('LayerNorm/bias', 'LayerNorm/beta'), ('weight', 'kernel'))

    if not os.path.isdir(ckpt_dir):
        os.makedirs(ckpt_dir)

    tf.reset_default_graph()
    session = tf.Session()
    state_dict = model.state_dict()
    tf_vars = []

    def to_tf_var_name(name: str):
        for patt, repl in var_map:
            name = name.replace(patt, repl)
        return 'bert/{}'.format(name)

    def assign_tf_var(tensor: np.ndarray, name: str):
        # Seed a temporary variable with the tensor's value, then copy it into
        # a variable with the target name via tf.assign.
        tmp_var = tf.Variable(initial_value=tensor)
        tf_var = tf.get_variable(dtype=tmp_var.dtype,
                                 shape=tmp_var.shape,
                                 name=name)
        op = tf.assign(ref=tf_var, value=tmp_var)
        session.run(tf.variables_initializer([tmp_var, tf_var]))
        session.run(fetches=[op, tf_var])
        return tf_var

    for var_name in state_dict:
        tf_name = to_tf_var_name(var_name)
        torch_tensor = state_dict[var_name].numpy()
        if any(x in var_name for x in tensors_to_transpose):
            torch_tensor = torch_tensor.T
        tf_tensor = assign_tf_var(tensor=torch_tensor, name=tf_name)
        tf_vars.append(tf_tensor)
        print("{0}{1}initialized".format(tf_name, " " * (60 - len(tf_name))))

    saver = tf.train.Saver(tf_vars)
    saver.save(session,
               os.path.join(ckpt_dir,
                            model_name.replace("-", "_") + ".ckpt"))
    session.close()
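
Both converters target the TensorFlow 1.x graph API (tf.Session, tf.get_variable,
tf.assign). A minimal sketch, assuming a TensorFlow 2.x install: import the v1
compatibility shim instead of tensorflow directly.

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # restores graph-mode Session/get_variable semantics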