def load_from_tf(config, tf_path):
    """Load TensorFlow checkpoint weights at ``tf_path`` into a
    ``BertForQuestionAnswering`` built from ``config`` and return it.

    Adapted from HuggingFace Transformers' ``load_tf_weights_in_bert`` with a
    workaround: the checkpoint stores the QA head under a ``squad`` scope that
    the upstream loader maps to an attribute named ``classifier``, which
    ``BertForQuestionAnswering`` does not have.  We temporarily alias
    ``model.classifier`` to ``model.qa_outputs`` so that mapping resolves, then
    remove the alias before returning.
    """
    model = BertForQuestionAnswering(config)
    # Temporary alias so the "squad" scope below can resolve to "classifier".
    model.classifier = model.qa_outputs

    # This part is copied from HuggingFace Transformers with a fix to bypass an error
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    # First pass: pull every variable (name, ndarray) out of the TF checkpoint.
    for name, shape in init_vars:
        # print("Loading TF weight {} with shape {}".format(name, shape))
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array)

    # Second pass: walk each slash-separated TF scope path and follow the
    # corresponding attribute chain on the PyTorch model.
    for name, array in zip(names, arrays):
        name = name.split("/")
        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to
        # compute m and v, which are not required for using the pretrained model
        if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
            print("Skipping {}".format("/".join(name)))
            continue
        pointer = model
        for m_name in name:
            # Scopes like "layer_11" split into ["layer", "11"]; the numeric
            # part indexes into a module list further down.
            if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
                scope_names = re.split(r"_(\d+)", m_name)
            else:
                scope_names = [m_name]
            # Map TF variable names onto PyTorch parameter attribute names.
            if scope_names[0] == "kernel" or scope_names[0] == "gamma":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
                pointer = getattr(pointer, "bias")
            elif scope_names[0] == "output_weights":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "squad":
                # This line is causing the issue upstream: it expects a
                # "classifier" attribute, hence the alias set above.
                pointer = getattr(pointer, "classifier")
            else:
                try:
                    pointer = getattr(pointer, scope_names[0])
                except AttributeError:
                    # NOTE(review): upstream behavior — skips this path
                    # component only, not the whole variable.
                    print("Skipping {}".format("/".join(name)))
                    continue
            if len(scope_names) >= 2:
                # Descend into the numbered submodule (e.g. encoder layer 11).
                num = int(scope_names[1])
                pointer = pointer[num]
        # Embedding tables live one attribute deeper on the PyTorch side.
        if m_name[-11:] == "_embeddings":
            pointer = getattr(pointer, "weight")
        elif m_name == "kernel":
            # TF dense kernels are stored transposed relative to torch Linear.
            array = np.transpose(array)
        try:
            assert pointer.shape == array.shape
        except AssertionError as e:
            # Attach both shapes to the error for easier debugging.
            e.args += (pointer.shape, array.shape)
            raise
        print("Initialize PyTorch weight {}".format(name))
        pointer.data = torch.from_numpy(array)
    # Undo the workaround alias so the returned model has the normal interface.
    model.qa_outputs = model.classifier
    del model.classifier
    return model
def construct_qa_transformer(options: KaggleEvaluationOptions) -> Reranker:
    """Build a question-answering reranker by grafting the weights of a
    sequence-classification checkpoint onto a ``BertForQuestionAnswering``.
    """
    # We load a sequence classification model first -- again, as a workaround. Refactor.
    try:
        seq_cls_model = AutoModelForSequenceClassification.from_pretrained(
            options.model_name)
    except OSError:
        # Checkpoint is not available in PyTorch format; convert from TensorFlow.
        seq_cls_model = AutoModelForSequenceClassification.from_pretrained(
            options.model_name, from_tf=True)
    # Transplant the encoder and the classification head into a QA model.
    qa_model = BertForQuestionAnswering(seq_cls_model.config)
    qa_model.qa_outputs = seq_cls_model.classifier
    qa_model.bert = seq_cls_model.bert
    qa_model = qa_model.to(torch.device(options.device)).eval()
    tokenizer = AutoTokenizer.from_pretrained(
        options.tokenizer_name, do_lower_case=options.do_lower_case)
    return QuestionAnsweringTransformerReranker(qa_model, tokenizer)