def init_bert_model_with_teacher(
    student: BertModel,
    teacher: BertModel,
    layers_to_transfer: List[int] = None,
) -> BertModel:
    """Initialize a student model with layers taken from a teacher model.

    Args:
        student (BertModel): Student model.
        teacher (BertModel): Teacher model.
        layers_to_transfer (List[int], optional): Which teacher layers to
            transfer. If None, the last layers are transferred.
            Defaults to None.

    Returns:
        BertModel: Student model initialized with the selected teacher layers.
    """
    teacher_hidden_size = teacher.config.hidden_size
    student_hidden_size = student.config.hidden_size
    if teacher_hidden_size != student_hidden_size:
        raise Exception("Teacher and student hidden size should be the same")

    teacher_layers_num = teacher.config.num_hidden_layers
    student_layers_num = student.config.num_hidden_layers
    if layers_to_transfer is None:
        layers_to_transfer = list(
            range(teacher_layers_num - student_layers_num, teacher_layers_num))

    prefix_teacher = list(teacher.state_dict().keys())[0].split(".")[0]
    prefix_student = list(student.state_dict().keys())[0].split(".")[0]
    student_sd = _extract_layers(
        teacher_model=teacher,
        layers=layers_to_transfer,
        # pass the detected prefixes so the extracted keys follow the
        # student's naming scheme
        prefix_teacher=prefix_teacher,
        prefix_student=prefix_student,
    )
    student.load_state_dict(student_sd)
    return student
def convert_tf2_checkpoint_to_pytorch(tf_checkpoint_path, config_path,
                                      pytorch_dump_path):
    # Instantiate model
    logger.info(f"Loading model based on config from {config_path}...")
    config = BertConfig.from_json_file(config_path)
    model = BertModel(config)

    # Load weights from checkpoint
    logger.info(f"Loading weights from checkpoint {tf_checkpoint_path}...")
    load_tf2_weights_in_bert(model, tf_checkpoint_path, config)

    # Save pytorch-model
    logger.info(f"Saving PyTorch model to {pytorch_dump_path}...")
    torch.save(model.state_dict(), pytorch_dump_path)
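# Illustrative CLI wrapper for the converter above (a sketch, not part of the
# original snippet); the flag names and paths are assumptions for the example.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--tf_checkpoint_path", type=str, required=True,
                        help="Path to the TensorFlow 2.x checkpoint.")
    parser.add_argument("--config_path", type=str, required=True,
                        help="Path to the BERT config JSON file.")
    parser.add_argument("--pytorch_dump_path", type=str, required=True,
                        help="Where to write the converted PyTorch weights.")
    args = parser.parse_args()
    convert_tf2_checkpoint_to_pytorch(args.tf_checkpoint_path,
                                      args.config_path,
                                      args.pytorch_dump_path)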
def check_compability(torch_model: BertModel, tf_model: TFBertModel):
    torch_weights = []
    for k, v in torch_model.state_dict().items():
        # "embeddings.position_ids" is a buffer, not a trainable weight,
        # so it has no TensorFlow counterpart.
        if k == "embeddings.position_ids":
            continue
        if not k.startswith("embeddings.") and k.endswith(".weight"):
            # PyTorch Linear weights are stored transposed relative to TF kernels.
            torch_weights.append(v.t().numpy())
        else:
            torch_weights.append(v.numpy())

    # Swap entries 1 and 2 so the embedding ordering matches tf_model.get_weights().
    torch_weights[1], torch_weights[2] = torch_weights[2], torch_weights[1]
    tf_weights = tf_model.get_weights()

    check = [(torch_weight == tf_weight).all()
             for torch_weight, tf_weight in zip(torch_weights, tf_weights)]
    return all(check)
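# Illustrative compatibility check (a sketch, not part of the original code):
# load the same pretrained checkpoint in PyTorch and TensorFlow and compare.
# The checkpoint name is an assumption, and the expected weight ordering may
# differ across transformers versions.
from transformers import BertModel, TFBertModel

pt_model = BertModel.from_pretrained("bert-base-uncased")
tf_model = TFBertModel.from_pretrained("bert-base-uncased")
print(check_compability(pt_model, tf_model))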
def train(config, bert_config, train_path, dev_path, rel2id, id2rel,
          tokenizer):
    if not os.path.exists(config.output_dir):
        os.makedirs(config.output_dir, exist_ok=True)

    num_rels = len(rel2id)  # assumed: total number of relation types
    if os.path.exists('./data/train_file.pkl'):
        train_data = pickle.load(open("./data/train_file.pkl", mode='rb'))
    else:
        train_data = data.load_data(train_path, tokenizer, rel2id, num_rels)
        pickle.dump(train_data, open("./data/train_file.pkl", mode='wb'))
    dev_data = json.load(open(dev_path))
    for sent in dev_data:
        data.to_tuple(sent)

    data_manager = data.SPO(train_data)
    train_sampler = RandomSampler(data_manager)
    train_data_loader = DataLoader(data_manager,
                                   sampler=train_sampler,
                                   batch_size=config.batch_size,
                                   drop_last=True)
    num_train_steps = int(
        len(data_manager) / config.batch_size) * config.max_epoch

    if config.bert_pretrained_model is not None:
        logger.info('load bert weight')
        Bert_model = BertModel.from_pretrained(config.bert_pretrained_model,
                                               config=bert_config)
    else:
        logger.info('random initialize bert model')
        # init_weights() returns None, so keep the model reference and
        # initialize it in place.
        Bert_model = BertModel(config=bert_config)
        Bert_model.init_weights()
    Bert_model.to(device)

    submodel = sub_model(config).to(device)
    objmodel = obj_model(config).to(device)

    loss_fuc = nn.BCELoss(reduction='none')
    params = list(Bert_model.parameters()) + list(
        submodel.parameters()) + list(objmodel.parameters())
    optimizer = AdamW(params, lr=config.lr)

    # Train!
    logger.info("***** Running training *****")
    logger.info(" Num examples = %d", len(data_manager))
    logger.info(" Num Epochs = %d", config.max_epoch)
    logger.info(" Total train batch size = %d", config.batch_size)
    logger.info(" Total optimization steps = %d", num_train_steps)
    logger.info(" Logging steps = %d", config.print_freq)
    logger.info(" Save steps = %d", config.save_freq)

    global_step = 0
    Bert_model.train()
    submodel.train()
    objmodel.train()

    for epoch in range(config.max_epoch):
        optimizer.zero_grad()
        epoch_iterator = tqdm(train_data_loader, disable=None)
        for step, batch in enumerate(epoch_iterator):
            batch = tuple(t.to(device) for t in batch)
            input_ids, segment_ids, input_masks, sub_positions, sub_heads, \
                sub_tails, obj_heads, obj_tails = batch
            bert_output = Bert_model(input_ids, input_masks, segment_ids)[0]
            pred_sub_heads, pred_sub_tails = submodel(
                bert_output)  # [batch_size, seq_len, 1]
            pred_obj_heads, pred_obj_tails = objmodel(bert_output,
                                                      sub_positions)

            # compute losses
            mask = input_masks.view(-1)

            # loss1: subject head/tail prediction
            sub_heads = sub_heads.unsqueeze(-1)  # [batch_size, seq_len, 1]
            sub_tails = sub_tails.unsqueeze(-1)
            loss1_head = loss_fuc(pred_sub_heads, sub_heads).view(-1)
            loss1_head = torch.sum(loss1_head * mask) / torch.sum(mask)
            loss1_tail = loss_fuc(pred_sub_tails, sub_tails).view(-1)
            loss1_tail = torch.sum(loss1_tail * mask) / torch.sum(mask)
            loss1 = loss1_head + loss1_tail

            # loss2: object head/tail prediction per relation
            loss2_head = loss_fuc(pred_obj_heads,
                                  obj_heads).view(-1, obj_heads.shape[-1])
            loss2_head = torch.sum(
                loss2_head * mask.unsqueeze(-1)) / torch.sum(mask)
            loss2_tail = loss_fuc(pred_obj_tails,
                                  obj_tails).view(-1, obj_tails.shape[-1])
            loss2_tail = torch.sum(
                loss2_tail * mask.unsqueeze(-1)) / torch.sum(mask)
            loss2 = loss2_head + loss2_tail

            # optimize
            loss = loss1 + loss2
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            global_step += 1

            if (global_step + 1) % config.print_freq == 0:
                logger.info(
                    "epoch : {} step: {} #### loss1: {} loss2: {}".format(
                        epoch, global_step + 1,
                        loss1.cpu().item(),
                        loss2.cpu().item()))

            if (global_step + 1) % config.eval_freq == 0:
                logger.info("***** Running evaluating *****")
                with torch.no_grad():
                    Bert_model.eval()
                    submodel.eval()
                    objmodel.eval()
                    P, R, F1 = utils.metric(Bert_model, submodel, objmodel,
                                            dev_data, id2rel, tokenizer)
                    logger.info(f'precision:{P}\nrecall:{R}\nF1:{F1}')
                    Bert_model.train()
                    submodel.train()
                    objmodel.train()

            if (global_step + 1) % config.save_freq == 0:
                # Save a trained model
                model_name = "pytorch_model_%d" % (global_step + 1)
                output_model_file = os.path.join(config.output_dir,
                                                 model_name)
                state = {
                    'bert_state_dict': Bert_model.state_dict(),
                    'subject_state_dict': submodel.state_dict(),
                    'object_state_dict': objmodel.state_dict(),
                }
                torch.save(state, output_model_file)

    model_name = "pytorch_model_last"
    output_model_file = os.path.join(config.output_dir, model_name)
    state = {
        'bert_state_dict': Bert_model.state_dict(),
        'subject_state_dict': submodel.state_dict(),
        'object_state_dict': objmodel.state_dict(),
    }
    torch.save(state, output_model_file)
def convert_pytorch_checkpoint_to_tf(model: BertModel, ckpt_dir: str,
                                     model_name: str):
    """
    :param model: BertModel Pytorch model instance to be converted
    :param ckpt_dir: Tensorflow model directory
    :param model_name: model name
    :return:

    Currently supported HF models:
        Y BertModel
        N BertForMaskedLM
        N BertForPreTraining
        N BertForMultipleChoice
        N BertForNextSentencePrediction
        N BertForSequenceClassification
        N BertForQuestionAnswering
    """
    tensors_to_transpose = ("dense.weight", "attention.self.query",
                            "attention.self.key", "attention.self.value")

    var_map = (
        ("layer.", "layer_"),
        ("word_embeddings.weight", "word_embeddings"),
        ("position_embeddings.weight", "position_embeddings"),
        ("token_type_embeddings.weight", "token_type_embeddings"),
        (".", "/"),
        ("LayerNorm/weight", "LayerNorm/gamma"),
        ("LayerNorm/bias", "LayerNorm/beta"),
        ("weight", "kernel"),
    )

    if not os.path.isdir(ckpt_dir):
        os.makedirs(ckpt_dir)

    state_dict = model.state_dict()

    def to_tf_var_name(name: str):
        for patt, repl in iter(var_map):
            name = name.replace(patt, repl)
        return "bert/{}".format(name)

    def create_tf_var(tensor: np.ndarray, name: str, session: tf.Session):
        tf_dtype = tf.dtypes.as_dtype(tensor.dtype)
        tf_var = tf.get_variable(dtype=tf_dtype,
                                 shape=tensor.shape,
                                 name=name,
                                 initializer=tf.zeros_initializer())
        session.run(tf.variables_initializer([tf_var]))
        session.run(tf_var)
        return tf_var

    tf.reset_default_graph()
    with tf.Session() as session:
        for var_name in state_dict:
            tf_name = to_tf_var_name(var_name)
            torch_tensor = state_dict[var_name].numpy()
            if any([x in var_name for x in tensors_to_transpose]):
                torch_tensor = torch_tensor.T
            tf_var = create_tf_var(tensor=torch_tensor,
                                   name=tf_name,
                                   session=session)
            tf.keras.backend.set_value(tf_var, torch_tensor)
            tf_weight = session.run(tf_var)
            print("Successfully created {}: {}".format(
                tf_name, np.allclose(tf_weight, torch_tensor)))

        saver = tf.train.Saver(tf.trainable_variables())
        saver.save(
            session,
            os.path.join(ckpt_dir, model_name.replace("-", "_") + ".ckpt"))
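# Illustrative driver for the converter above (a sketch, not part of the
# original snippet); the checkpoint name and output directory are assumptions.
model_to_convert = BertModel.from_pretrained("bert-base-uncased")
convert_pytorch_checkpoint_to_tf(model=model_to_convert,
                                 ckpt_dir="tf_checkpoint",
                                 model_name="bert-base-uncased")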
def main():
    parser = argparse.ArgumentParser(
        description='Train the individual Transformer model')
    parser.add_argument('--dataset_folder', type=str, default='datasets')
    parser.add_argument('--dataset_name', type=str, default='zara1')
    parser.add_argument('--obs', type=int, default=8)
    parser.add_argument('--preds', type=int, default=12)
    parser.add_argument('--emb_size', type=int, default=1024)
    parser.add_argument('--heads', type=int, default=8)
    parser.add_argument('--layers', type=int, default=6)
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--cpu', action='store_true')
    parser.add_argument('--output_folder', type=str, default='Output')
    parser.add_argument('--val_size', type=int, default=50)
    parser.add_argument('--gpu_device', type=str, default="0")
    parser.add_argument('--verbose', action='store_true')
    parser.add_argument('--max_epoch', type=int, default=100)
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--validation_epoch_start', type=int, default=30)
    parser.add_argument('--resume_train', action='store_true')
    parser.add_argument('--delim', type=str, default='\t')
    parser.add_argument('--name', type=str, default="zara1")

    args = parser.parse_args()
    model_name = args.name

    # Create the output/model folders if they do not already exist.
    for folder in ('models', 'output', 'output/BERT', 'models/BERT',
                   f'output/BERT/{args.name}', f'models/BERT/{args.name}'):
        try:
            os.mkdir(folder)
        except OSError:
            pass

    log = SummaryWriter('logs/BERT_%s' % model_name)
    log.add_scalar('eval/mad', 0, 0)
    log.add_scalar('eval/fad', 0, 0)

    try:
        os.mkdir(args.name)
    except OSError:
        pass

    device = torch.device("cuda")
    if args.cpu or not torch.cuda.is_available():
        device = torch.device("cpu")
    args.verbose = True

    ## creation of the dataloaders for train and validation
    train_dataset, _ = baselineUtils.create_dataset(args.dataset_folder,
                                                    args.dataset_name,
                                                    0,
                                                    args.obs,
                                                    args.preds,
                                                    delim=args.delim,
                                                    train=True,
                                                    verbose=args.verbose)
    val_dataset, _ = baselineUtils.create_dataset(args.dataset_folder,
                                                  args.dataset_name,
                                                  0,
                                                  args.obs,
                                                  args.preds,
                                                  delim=args.delim,
                                                  train=False,
                                                  verbose=args.verbose)
    test_dataset, _ = baselineUtils.create_dataset(args.dataset_folder,
                                                   args.dataset_name,
                                                   0,
                                                   args.obs,
                                                   args.preds,
                                                   delim=args.delim,
                                                   train=False,
                                                   eval=True,
                                                   verbose=args.verbose)

    from transformers import BertTokenizer, BertModel, BertForMaskedLM, BertConfig, AdamW

    config = BertConfig(vocab_size=30522,
                        hidden_size=768,
                        num_hidden_layers=12,
                        num_attention_heads=12,
                        intermediate_size=3072,
                        hidden_act='relu',
                        hidden_dropout_prob=0.1,
                        attention_probs_dropout_prob=0.1,
                        max_position_embeddings=512,
                        type_vocab_size=2,
                        initializer_range=0.02,
                        layer_norm_eps=1e-12)
    model = BertModel(config).to(device)

    from individual_TF import LinearEmbedding as NewEmbed, Generator as GeneratorTS
    a = NewEmbed(3, 768).to(device)
    model.set_input_embeddings(a)
    generator = GeneratorTS(768, 2).to(device)
    # model.set_output_embeddings(GeneratorTS(1024, 2))

    tr_dl = torch.utils.data.DataLoader(train_dataset,
                                        batch_size=args.batch_size,
                                        shuffle=True,
                                        num_workers=0)
    val_dl = torch.utils.data.DataLoader(val_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=0)
    test_dl = torch.utils.data.DataLoader(test_dataset,
                                          batch_size=args.batch_size,
                                          shuffle=False,
                                          num_workers=0)

    # optim = SGD(list(a.parameters())+list(model.parameters())+list(generator.parameters()), lr=0.01)
    # sched = torch.optim.lr_scheduler.StepLR(optim, 0.0005)
    optim = NoamOpt(
        768, 0.1, len(tr_dl),
        torch.optim.Adam(list(a.parameters()) + list(model.parameters()) +
                         list(generator.parameters()),
                         lr=0,
                         betas=(0.9, 0.98),
                         eps=1e-9))
    # optim = Adagrad(list(a.parameters())+list(model.parameters())+list(generator.parameters()), lr=0.01, lr_decay=0.001)
    epoch = 0

    mean = train_dataset[:]['src'][:, :, 2:4].mean((0, 1)) * 0
    std = train_dataset[:]['src'][:, :, 2:4].std((0, 1)) * 0 + 1

    while epoch < args.max_epoch:
        epoch_loss = 0
        model.train()

        for id_b, batch in enumerate(tr_dl):
            optim.optimizer.zero_grad()
            r = 0
            rot_mat = np.array([[np.cos(r), np.sin(r)],
                                [-np.sin(r), np.cos(r)]])

            inp = ((batch['src'][:, :, 2:4] - mean) / std).to(device)
            inp = torch.matmul(inp,
                               torch.from_numpy(rot_mat).float().to(device))
            trg_masked = torch.zeros((inp.shape[0], args.preds,
                                      2)).to(device)
            inp_cls = torch.ones(inp.shape[0], inp.shape[1], 1).to(device)
            trg_cls = torch.zeros(trg_masked.shape[0], trg_masked.shape[1],
                                  1).to(device)
            inp_cat = torch.cat((inp, trg_masked), 1)
            cls_cat = torch.cat((inp_cls, trg_cls), 1)
            net_input = torch.cat((inp_cat, cls_cat), 2)

            position = torch.arange(0, net_input.shape[1]).repeat(
                inp.shape[0], 1).long().to(device)
            token = torch.zeros(
                (inp.shape[0], net_input.shape[1])).long().to(device)
            attention_mask = torch.ones(
                (inp.shape[0], net_input.shape[1])).long().to(device)

            out = model(input_ids=net_input,
                        position_ids=position,
                        token_type_ids=token,
                        attention_mask=attention_mask)
            pred = generator(out[0])

            loss = F.pairwise_distance(
                pred[:, :].contiguous().view(-1, 2),
                torch.matmul(
                    torch.cat(
                        (batch['src'][:, :, 2:4], batch['trg'][:, :, 2:4]),
                        1).contiguous().view(-1, 2).to(device),
                    torch.from_numpy(rot_mat).float().to(device))).mean()
            loss.backward()
            optim.step()
            print("epoch %03i/%03i frame %04i / %04i loss: %7.4f" %
                  (epoch, args.max_epoch, id_b, len(tr_dl), loss.item()))
            epoch_loss += loss.item()
        # sched.step()

        log.add_scalar('Loss/train', epoch_loss / len(tr_dl), epoch)

        with torch.no_grad():
            model.eval()

            gt = []
            pr = []
            val_loss = 0
            for batch in val_dl:
                inp = ((batch['src'][:, :, 2:4] - mean) / std).to(device)
                trg_masked = torch.zeros(
                    (inp.shape[0], args.preds, 2)).to(device)
                inp_cls = torch.ones(inp.shape[0], inp.shape[1],
                                     1).to(device)
                trg_cls = torch.zeros(trg_masked.shape[0],
                                      trg_masked.shape[1], 1).to(device)
                inp_cat = torch.cat((inp, trg_masked), 1)
                cls_cat = torch.cat((inp_cls, trg_cls), 1)
                net_input = torch.cat((inp_cat, cls_cat), 2)

                position = torch.arange(0, net_input.shape[1]).repeat(
                    inp.shape[0], 1).long().to(device)
                token = torch.zeros(
                    (inp.shape[0], net_input.shape[1])).long().to(device)
                attention_mask = torch.zeros(
                    (inp.shape[0], net_input.shape[1])).long().to(device)

                out = model(input_ids=net_input,
                            position_ids=position,
                            token_type_ids=token,
                            attention_mask=attention_mask)
                pred = generator(out[0])

                loss = F.pairwise_distance(
                    pred[:, :].contiguous().view(-1, 2),
                    torch.cat(
                        (batch['src'][:, :, 2:4], batch['trg'][:, :, 2:4]),
                        1).contiguous().view(-1, 2).to(device)).mean()
                val_loss += loss.item()

                gt_b = batch['trg'][:, :, 0:2]
                preds_tr_b = pred[:, args.obs:].cumsum(1).to(
                    'cpu').detach() + batch['src'][:, -1:, 0:2]
                gt.append(gt_b)
                pr.append(preds_tr_b)

            gt = np.concatenate(gt, 0)
            pr = np.concatenate(pr, 0)
            mad, fad, errs = baselineUtils.distance_metrics(gt, pr)
            log.add_scalar('validation/loss', val_loss / len(val_dl), epoch)
            log.add_scalar('validation/mad', mad, epoch)
            log.add_scalar('validation/fad', fad, epoch)

            model.eval()
            gt = []
            pr = []
            for batch in test_dl:
                inp = ((batch['src'][:, :, 2:4] - mean) / std).to(device)
                trg_masked = torch.zeros(
                    (inp.shape[0], args.preds, 2)).to(device)
                inp_cls = torch.ones(inp.shape[0], inp.shape[1],
                                     1).to(device)
                trg_cls = torch.zeros(trg_masked.shape[0],
                                      trg_masked.shape[1], 1).to(device)
                inp_cat = torch.cat((inp, trg_masked), 1)
                cls_cat = torch.cat((inp_cls, trg_cls), 1)
                net_input = torch.cat((inp_cat, cls_cat), 2)

                position = torch.arange(0, net_input.shape[1]).repeat(
                    inp.shape[0], 1).long().to(device)
                token = torch.zeros(
                    (inp.shape[0], net_input.shape[1])).long().to(device)
                attention_mask = torch.zeros(
                    (inp.shape[0], net_input.shape[1])).long().to(device)

                out = model(input_ids=net_input,
                            position_ids=position,
                            token_type_ids=token,
                            attention_mask=attention_mask)
                pred = generator(out[0])

                gt_b = batch['trg'][:, :, 0:2]
                preds_tr_b = pred[:, args.obs:].cumsum(1).to(
                    'cpu').detach() + batch['src'][:, -1:, 0:2]
                gt.append(gt_b)
                pr.append(preds_tr_b)

            gt = np.concatenate(gt, 0)
            pr = np.concatenate(pr, 0)
            mad, fad, errs = baselineUtils.distance_metrics(gt, pr)

            torch.save(model.state_dict(),
                       "models/BERT/%s/ep_%03i.pth" % (args.name, epoch))
            torch.save(generator.state_dict(),
                       "models/BERT/%s/gen_%03i.pth" % (args.name, epoch))
            torch.save(a.state_dict(),
                       "models/BERT/%s/emb_%03i.pth" % (args.name, epoch))

            log.add_scalar('eval/mad', mad, epoch)
            log.add_scalar('eval/fad', fad, epoch)

        epoch += 1
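# Assumed entry point for the training script above; the script filename in
# the example invocation is hypothetical.
if __name__ == '__main__':
    main()
    # e.g. python train_bert_trajectory.py --dataset_name zara1 --max_epoch 100 --batch_size 256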
class RenamingModelHybrid(nn.Module):
    def __init__(self, vocab, top_k, config, device):
        super(RenamingModelHybrid, self).__init__()
        self.vocab = vocab
        self.top_k = top_k
        self.source_vocab_size = len(self.vocab.source_tokens) + 1
        self.graph_encoder = GraphASTEncoder.build(
            config['encoder']['graph_encoder'])
        self.graph_emb_size = config['encoder']['graph_encoder']['gnn'][
            'hidden_size']
        self.emb_size = 256

        state_dict = torch.load(
            'saved_checkpoints/bert_2604/bert_pretrained_epoch_23_batch_140000.pth',
            map_location=device)
        keys_to_delete = [
            "cls.predictions.bias",
            "cls.predictions.transform.dense.weight",
            "cls.predictions.transform.dense.bias",
            "cls.predictions.transform.LayerNorm.weight",
            "cls.predictions.transform.LayerNorm.bias",
            "cls.predictions.decoder.weight", "cls.predictions.decoder.bias",
            "cls.seq_relationship.weight", "cls.seq_relationship.bias"
        ]
        from collections import OrderedDict
        new_state_dict = OrderedDict()
        for k, v in state_dict['model'].items():
            if k in keys_to_delete:
                continue
            name = k[5:]  # remove `bert.`
            new_state_dict[name] = v

        bert_config = BertConfig(vocab_size=self.source_vocab_size,
                                 max_position_embeddings=512,
                                 num_hidden_layers=6,
                                 hidden_size=self.emb_size,
                                 num_attention_heads=4)
        self.bert_encoder = BertModel(bert_config)
        self.bert_encoder.load_state_dict(new_state_dict)

        self.target_vocab_size = len(self.vocab.all_subtokens) + 1
        bert_config = BertConfig(vocab_size=self.target_vocab_size,
                                 max_position_embeddings=1000,
                                 num_hidden_layers=6,
                                 hidden_size=self.emb_size,
                                 num_attention_heads=4,
                                 is_decoder=True)
        self.bert_decoder = BertModel(bert_config)

        state_dict = torch.load(
            'saved_checkpoints/bert_0905/bert_decoder_epoch_19_batch_220000.pth',
            map_location=device)
        new_state_dict = OrderedDict()
        for k, v in state_dict['model'].items():
            if k in keys_to_delete:
                continue
            if 'crossattention' in k:
                continue
            name = k[5:]  # remove `bert.`
            new_state_dict[name] = v
        for key in new_state_dict:
            self.bert_decoder.state_dict()[key].copy_(new_state_dict[key])

        self.enc_graph_map = nn.Linear(self.emb_size + self.graph_emb_size,
                                       self.emb_size)
        self.fc_final = nn.Linear(self.emb_size, self.target_vocab_size)
        self.fc_final.weight.data = state_dict['model'][
            'cls.predictions.decoder.weight']

    def forward(self, src_tokens, src_mask, variable_ids, target_tokens,
                graph_input):
        encoder_attention_mask = torch.ones_like(src_tokens).float().to(
            src_tokens.device)
        encoder_attention_mask[src_tokens == PAD_ID] = 0.0

        assert torch.max(src_tokens) < self.source_vocab_size
        assert torch.min(src_tokens) >= 0
        assert torch.max(target_tokens) < self.target_vocab_size
        assert torch.min(target_tokens) >= 0

        encoder_output = self.bert_encoder(
            input_ids=src_tokens, attention_mask=encoder_attention_mask)[0]

        graph_output = self.graph_encoder(graph_input)
        variable_emb = graph_output['variable_encoding']
        graph_embedding = torch.gather(
            variable_emb, 1,
            variable_ids.unsqueeze(2).repeat(
                1, 1, variable_emb.shape[2])) * src_mask.unsqueeze(2)

        full_enc_output = self.enc_graph_map(
            torch.cat((encoder_output, graph_embedding), dim=2))

        decoder_attention_mask = torch.ones_like(target_tokens).float().to(
            target_tokens.device)
        decoder_attention_mask[target_tokens == PAD_ID] = 0.0

        decoder_output = self.bert_decoder(
            input_ids=target_tokens,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=full_enc_output,
            encoder_attention_mask=encoder_attention_mask)[0]

        predictions = self.fc_final(decoder_output)
        return predictions

    def predict(self, src_tokens, src_mask, variable_ids, graph_input,
                approx=False):
        end_token = self.vocab.all_subtokens.word2id['</s>']
        start_token = self.vocab.all_subtokens.word2id['<s>']
        batch_size = src_tokens.shape[0]

        encoder_attention_mask = torch.ones_like(src_tokens).float().to(
            src_tokens.device)
        encoder_attention_mask[src_tokens == PAD_ID] = 0.0

        assert torch.max(src_tokens) < self.source_vocab_size
        assert torch.min(src_tokens) >= 0

        encoder_output = self.bert_encoder(
            input_ids=src_tokens, attention_mask=encoder_attention_mask)[0]

        graph_output = self.graph_encoder(graph_input)
        variable_emb = graph_output['variable_encoding']
        graph_embedding = torch.gather(
            variable_emb, 1,
            variable_ids.unsqueeze(2).repeat(
                1, 1, variable_emb.shape[2])) * src_mask.unsqueeze(2)

        full_enc_output = self.enc_graph_map(
            torch.cat((encoder_output, graph_embedding), dim=2))

        source_vocab_to_target = {
            self.vocab.source_tokens.word2id[t]:
            self.vocab.all_subtokens.word2id[t]
            for t in self.vocab.source_tokens.word2id.keys()
        }

        src_target_maps = []
        confidences = []
        for i in range(batch_size):
            if src_tokens[i][0] != start_token:
                input_sequence = torch.zeros(src_tokens.shape[1] + 1).to(
                    src_tokens.device)
                input_mask = torch.zeros(src_mask.shape[1] + 1).to(
                    src_mask.device)
                input_sequence[1:] = src_tokens[i]
                input_mask[1:] = src_mask[i]
            else:
                input_sequence = src_tokens[i]
                input_mask = src_mask[i]

            num_vars = int(input_mask.sum())
            seq_len = torch.sum((input_sequence != PAD_ID).long())
            generated_seqs = torch.zeros(1, min(
                seq_len + 10 * num_vars, 1000)).long().to(src_tokens.device)

            source_marker = 0
            gen_markers = torch.LongTensor([0]).to(generated_seqs.device)
            prior_probs = torch.FloatTensor([0]).to(generated_seqs.device)
            candidate_maps = [{}]

            for _ in range(num_vars):
                # Filling up the known (non-identifier) tokens
                while source_marker < seq_len and input_mask[
                        source_marker] != 1:
                    token = input_sequence[source_marker]
                    values = source_vocab_to_target[token.item(
                    )] * torch.ones_like(gen_markers).to(
                        generated_seqs.device)
                    generated_seqs = torch.scatter(generated_seqs, 1,
                                                   gen_markers.unsqueeze(1),
                                                   values.unsqueeze(1))
                    source_marker += 1
                    gen_markers += 1

                if source_marker >= seq_len:
                    break

                curr_var = input_sequence[source_marker].item()
                if curr_var in candidate_maps[0]:
                    if approx is True:
                        source_marker += 1
                        continue
                    # If we've seen this variable before, just use the previous
                    # predictions and update the scores.
                    # Note - it's enough to check candidate_maps[0] because if it
                    # is in the first map, it is in all of them.
                    orig_markers = gen_markers.clone()
                    for j in range(len(candidate_maps)):
                        pred = candidate_maps[j][curr_var]
                        generated_seqs[j][gen_markers[j]:gen_markers[j] +
                                          len(pred)] = torch.LongTensor(
                                              pred).to(generated_seqs.device)
                        gen_markers[j] += len(pred)

                    decoder_attention_mask = torch.ones_like(
                        generated_seqs).float().to(generated_seqs.device)
                    decoder_attention_mask[generated_seqs == PAD_ID] = 0.0
                    decoder_output = self.bert_decoder(
                        input_ids=generated_seqs,
                        attention_mask=decoder_attention_mask,
                        encoder_hidden_states=full_enc_output[i].unsqueeze(0),
                        encoder_attention_mask=encoder_attention_mask[i].
                        unsqueeze(0))[0]
                    probabilities = F.log_softmax(
                        self.fc_final(decoder_output), dim=-1)

                    # Add up the scores of the token at the __next__ time step
                    scores = torch.zeros(generated_seqs.shape[0]).to(
                        generated_seqs.device)
                    active = torch.ones(generated_seqs.shape[0]).long().to(
                        generated_seqs.device)
                    temp_markers = orig_markers
                    while torch.sum(active) != 0:
                        position_probs = torch.gather(
                            probabilities, 1,
                            (temp_markers - 1).reshape(-1, 1, 1).repeat(
                                1, 1, probabilities.shape[2])).squeeze(1)
                        curr_tokens = torch.gather(generated_seqs, 1,
                                                   temp_markers.unsqueeze(1))
                        tok_probs = torch.gather(position_probs, 1,
                                                 curr_tokens).squeeze(1)
                        tok_probs *= active
                        scores += tok_probs
                        active *= (temp_markers != (gen_markers - 1)).long()
                        temp_markers += active

                    # Update the prior probabilities
                    prior_probs = prior_probs + scores
                else:
                    # You encounter a new variable which hasn't been seen before.
                    # Generate <beam_width> possibilities for its name.
                    generated_seqs, gen_markers, prior_probs, candidate_maps = self.beam_search(
                        generated_seqs,
                        gen_markers,
                        prior_probs,
                        candidate_maps,
                        curr_var,
                        full_enc_output[i].unsqueeze(0),
                        encoder_attention_mask[i].unsqueeze(0),
                        beam_width=5,
                        top_k=self.top_k)

                source_marker += 1

            final_ind = torch.argmax(prior_probs)
            confidence = torch.max(prior_probs).item()
            src_target_map = candidate_maps[final_ind]
            src_target_maps.append(src_target_map)
            confidences.append(confidence)

        return src_target_maps, confidences

    def beam_search(self,
                    generated_seqs,
                    gen_markers,
                    prior_probs,
                    candidate_maps,
                    curr_var,
                    full_enc_output,
                    encoder_attention_mask,
                    beam_width=5,
                    top_k=10):
        if generated_seqs.shape[0] * beam_width < top_k:
            beam_width = top_k
        active = torch.ones_like(gen_markers).to(gen_markers.device)
        beam_alpha = 0.7
        end_token = self.vocab.all_subtokens.word2id['</s>']
        candidate_maps = candidate_maps
        orig_markers = gen_markers.clone()

        for _ in range(10):  # Predict at most 10 subtokens
            decoder_attention_mask = torch.ones_like(
                generated_seqs).float().to(generated_seqs.device)
            decoder_attention_mask[generated_seqs == PAD_ID] = 0.0
            decoder_output = self.bert_decoder(
                input_ids=generated_seqs,
                attention_mask=decoder_attention_mask,
                encoder_hidden_states=full_enc_output,
                encoder_attention_mask=encoder_attention_mask)[0]
            probabilities = F.log_softmax(self.fc_final(decoder_output),
                                          dim=-1)

            # Gather the predictions at the current markers
            # (gen_marker - 1) because prediction happens one step ahead
            probabilities = torch.gather(
                probabilities, 1,
                (gen_markers - 1).reshape(-1, 1, 1).repeat(
                    1, 1, probabilities.shape[2])).squeeze(1)

            probs, preds = probabilities.sort(dim=-1, descending=True)
            probs *= active.unsqueeze(
                1)  # Set log prob of non-active ones to 0
            preds[
                active ==
                0] = end_token  # Set preds of non-active ones to the end token (ie, remain unchanged)

            # Repeat active ones only once. Repeat the rest beam_width no. of times.
            filter_mask = torch.ones(
                (preds.shape[0], beam_width)).long().to(preds.device)
            filter_mask *= active.unsqueeze(1)
            filter_mask[:, 0][active == 0] = 1
            filter_mask = filter_mask.reshape(-1)

            preds = preds[:, :beam_width].reshape(-1)[filter_mask == 1]
            probs = probs[:, :beam_width].reshape(-1)[filter_mask == 1]
            generated_seqs = torch.repeat_interleave(generated_seqs,
                                                     beam_width,
                                                     dim=0)[filter_mask == 1]
            orig_markers = torch.repeat_interleave(orig_markers,
                                                   beam_width,
                                                   dim=0)[filter_mask == 1]
            gen_markers = torch.repeat_interleave(gen_markers,
                                                  beam_width,
                                                  dim=0)[filter_mask == 1]
            active = torch.repeat_interleave(active, beam_width,
                                             dim=0)[filter_mask == 1]
            prior_probs = torch.repeat_interleave(prior_probs,
                                                  beam_width,
                                                  dim=0)[filter_mask == 1]
            candidate_maps = [
                item.copy() for item in candidate_maps
                for _ in range(beam_width)
            ]
            candidate_maps = [
                candidate_maps[i] for i in range(len(candidate_maps))
                if filter_mask[i] == 1
            ]

            generated_seqs.scatter_(1, gen_markers.unsqueeze(1),
                                    preds.unsqueeze(1))

            # lengths = (gen_markers - gen_marker + 1).float()
            # penalties = torch.pow(5 + lengths, beam_alpha) / math.pow(6, beam_alpha)
            penalties = torch.ones_like(probs).to(probs.device)
            updated_probs = probs + prior_probs
            sort_inds = (updated_probs / penalties).argsort(descending=True)
            updated_probs = updated_probs[sort_inds]

            prior_probs = updated_probs[:top_k]
            new_preds = preds[sort_inds[:top_k]]
            generated_seqs = generated_seqs[sort_inds[:top_k]]
            gen_markers = gen_markers[sort_inds[:top_k]]
            active = active[sort_inds[:top_k]]
            orig_markers = orig_markers[sort_inds[:top_k]]
            candidate_maps = [
                candidate_maps[ind.item()] for ind in sort_inds[:top_k]
            ]

            active = active * (new_preds != end_token).long()
            gen_markers += active

            if torch.sum(active) == 0:
                break

        # gen_markers are pointing at the end_token. Move them one ahead
        gen_markers += 1
        assert generated_seqs.shape[0] == top_k

        for i in range(top_k):
            candidate_maps[i][curr_var] = generated_seqs[i][
                orig_markers[i]:gen_markers[i]].cpu().tolist()

        return generated_seqs, gen_markers, prior_probs, candidate_maps
model_info = pytorch_kobert
model_path = download(model_info['url'],
                      model_info['fname'],
                      model_info['chksum'],
                      cachedir=cachedir)

# download vocab
vocab_info = tokenizer
vocab_path = download(vocab_info['url'],
                      vocab_info['fname'],
                      vocab_info['chksum'],
                      cachedir=cachedir)

#################################################################################################
print('Declaring the BERT model')
bertmodel = BertModel(config=BertConfig.from_dict(bert_config))
# load_state_dict (not state_dict) is required to actually load the downloaded weights
bertmodel.load_state_dict(torch.load(model_path))

print("Setting the GPU device")
device = torch.device(ctx)
bertmodel.to(device)
bertmodel.train()

vocab = nlp.vocab.BERTVocab.from_sentencepiece(vocab_path,
                                               padding_token='[PAD]')

#################################################################################################
# parameter settings
tokenizer = get_tokenizer()
tok = nlp.data.BERTSPTokenizer(tokenizer, vocab, lower=False)

max_len = 64
batch_size = 64
def _extract_layers(
    teacher_model: BertModel,
    layers: List[int],
    prefix_teacher="bert",
    prefix_student="bert",
    encoder_name="encoder",
):
    state_dict = teacher_model.state_dict()
    compressed_sd = {}

    # extract embeddings
    for w in ["word_embeddings", "position_embeddings"]:
        compressed_sd[f"{prefix_student}.embeddings.{w}.weight"] = state_dict[
            f"{prefix_teacher}.embeddings.{w}.weight"]
    for w in ["weight", "bias"]:
        compressed_sd[
            f"{prefix_student}.embeddings.LayerNorm.{w}"] = state_dict[
                f"{prefix_teacher}.embeddings.LayerNorm.{w}"]

    # extract encoder
    for std_idx, teacher_idx in enumerate(layers):
        for w in ["weight", "bias"]:
            compressed_sd[
                f"{prefix_student}.encoder.layer.{std_idx}.attention.q_lin.{w}"  # noqa: E501
            ] = state_dict[
                f"{prefix_teacher}.encoder.layer.{teacher_idx}.attention.self.query.{w}"  # noqa: E501
            ]
            compressed_sd[
                f"{prefix_student}.encoder.layer.{std_idx}.attention.k_lin.{w}"  # noqa: E501
            ] = state_dict[
                f"{prefix_teacher}.encoder.layer.{teacher_idx}.attention.self.key.{w}"  # noqa: E501
            ]
            compressed_sd[
                f"{prefix_student}.encoder.layer.{std_idx}.attention.v_lin.{w}"  # noqa: E501
            ] = state_dict[
                f"{prefix_teacher}.encoder.layer.{teacher_idx}.attention.self.value.{w}"  # noqa: E501
            ]
            compressed_sd[
                f"{prefix_student}.encoder.layer.{std_idx}.attention.out_lin.{w}"  # noqa: E501
            ] = state_dict[
                f"{prefix_teacher}.encoder.layer.{teacher_idx}.attention.output.dense.{w}"  # noqa: E501
            ]
            compressed_sd[
                f"{prefix_student}.encoder.layer.{std_idx}.sa_layer_norm.{w}"  # noqa: E501
            ] = state_dict[
                f"{prefix_teacher}.encoder.layer.{teacher_idx}.attention.output.LayerNorm.{w}"  # noqa: E501
            ]
            compressed_sd[
                f"{prefix_student}.encoder.layer.{std_idx}.ffn.lin1.{w}"  # noqa: E501
            ] = state_dict[
                f"{prefix_teacher}.encoder.layer.{teacher_idx}.intermediate.dense.{w}"  # noqa: E501
            ]
            compressed_sd[
                f"{prefix_student}.encoder.layer.{std_idx}.ffn.lin2.{w}"  # noqa: E501
            ] = state_dict[
                f"{prefix_teacher}.encoder.layer.{teacher_idx}.output.dense.{w}"  # noqa: E501
            ]
            compressed_sd[
                f"{prefix_student}.encoder.layer.{std_idx}.output_layer_norm.{w}"  # noqa: E501
            ] = state_dict[
                f"{prefix_teacher}.encoder.layer.{teacher_idx}.output.LayerNorm.{w}"  # noqa: E501
            ]

    # extract vocab
    compressed_sd["cls.predictions.decoder.weight"] = state_dict[
        "cls.predictions.decoder.weight"]
    compressed_sd["cls.predictions.bias"] = state_dict["cls.predictions.bias"]
    for w in ["weight", "bias"]:
        compressed_sd[f"vocab_transform.{w}"] = state_dict[
            f"cls.predictions.transform.dense.{w}"]
        compressed_sd[f"vocab_layer_norm.{w}"] = state_dict[
            f"cls.predictions.transform.LayerNorm.{w}"]

    return compressed_sd
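# Minimal usage sketch (an assumption, not part of the original code): the
# teacher checkpoint must carry both the `bert.*` encoder weights and the
# `cls.predictions.*` MLM-head weights referenced above, so a masked-LM model
# is loaded here; the layer selection mirrors the default used by
# init_bert_model_with_teacher for a 6-layer student.
from transformers import BertForMaskedLM

teacher = BertForMaskedLM.from_pretrained("bert-base-uncased")
student_sd = _extract_layers(teacher_model=teacher, layers=list(range(6, 12)))
print(sorted(student_sd.keys())[:5])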