import argparse
import os

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter  # the original project may use tensorboardX instead

# Project-local helpers; the exact import paths are assumed from the surrounding repo.
import baselineUtils
from transformer.noam_opt import NoamOpt


def main():
    parser = argparse.ArgumentParser(
        description='Train the individual Transformer model')
    parser.add_argument('--dataset_folder', type=str, default='datasets')
    parser.add_argument('--dataset_name', type=str, default='zara1')
    parser.add_argument('--obs', type=int, default=8)
    parser.add_argument('--preds', type=int, default=12)
    parser.add_argument('--emb_size', type=int, default=1024)
    parser.add_argument('--heads', type=int, default=8)
    parser.add_argument('--layers', type=int, default=6)
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--cpu', action='store_true')
    parser.add_argument('--output_folder', type=str, default='Output')
    parser.add_argument('--val_size', type=int, default=50)
    parser.add_argument('--gpu_device', type=str, default="0")
    parser.add_argument('--verbose', action='store_true')
    parser.add_argument('--max_epoch', type=int, default=100)
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--validation_epoch_start', type=int, default=30)
    parser.add_argument('--resume_train', action='store_true')
    parser.add_argument('--delim', type=str, default='\t')
    parser.add_argument('--name', type=str, default="zara1")

    args = parser.parse_args()
    model_name = args.name

    # Create the output directories (replaces the original chain of try/os.mkdir/except blocks).
    for folder in ['models', 'output', 'output/BERT', 'models/BERT',
                   f'output/BERT/{args.name}', f'models/BERT/{args.name}']:
        os.makedirs(folder, exist_ok=True)

    log = SummaryWriter('logs/BERT_%s' % model_name)
    log.add_scalar('eval/mad', 0, 0)
    log.add_scalar('eval/fad', 0, 0)
    os.makedirs(args.name, exist_ok=True)

    device = torch.device("cuda")
    if args.cpu or not torch.cuda.is_available():
        device = torch.device("cpu")
    args.verbose = True

    # Creation of the dataloaders for train and validation
    train_dataset, _ = baselineUtils.create_dataset(
        args.dataset_folder, args.dataset_name, 0, args.obs, args.preds,
        delim=args.delim, train=True, verbose=args.verbose)
    val_dataset, _ = baselineUtils.create_dataset(
        args.dataset_folder, args.dataset_name, 0, args.obs, args.preds,
        delim=args.delim, train=False, verbose=args.verbose)
    test_dataset, _ = baselineUtils.create_dataset(
        args.dataset_folder, args.dataset_name, 0, args.obs, args.preds,
        delim=args.delim, train=False, eval=True, verbose=args.verbose)

    from transformers import BertTokenizer, BertModel, BertForMaskedLM, BertConfig, AdamW

    config = BertConfig(vocab_size=30522,
                        hidden_size=768,
                        num_hidden_layers=12,
                        num_attention_heads=12,
                        intermediate_size=3072,
                        hidden_act='relu',
                        hidden_dropout_prob=0.1,
                        attention_probs_dropout_prob=0.1,
                        max_position_embeddings=512,
                        type_vocab_size=2,
                        initializer_range=0.02,
                        layer_norm_eps=1e-12)
    model = BertModel(config).to(device)

    from individual_TF import LinearEmbedding as NewEmbed, Generator as GeneratorTS
    a = NewEmbed(3, 768).to(device)
    model.set_input_embeddings(a)
    generator = GeneratorTS(768, 2).to(device)
    # model.set_output_embeddings(GeneratorTS(1024, 2))

    tr_dl = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size,
                                        shuffle=True, num_workers=0)
    val_dl = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size,
                                         shuffle=True, num_workers=0)
    test_dl = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size,
                                          shuffle=False, num_workers=0)

    # optim = SGD(list(a.parameters()) + list(model.parameters()) + list(generator.parameters()), lr=0.01)
    # sched = torch.optim.lr_scheduler.StepLR(optim, 0.0005)
    optim = NoamOpt(
        768, 0.1, len(tr_dl),
        torch.optim.Adam(list(a.parameters()) + list(model.parameters()) +
                         list(generator.parameters()),
                         lr=0, betas=(0.9, 0.98), eps=1e-9))
    # optim = Adagrad(list(a.parameters()) + list(model.parameters()) + list(generator.parameters()), lr=0.01, lr_decay=0.001)

    epoch = 0
    # Multiplying by 0 (and adding 1 for std) deliberately disables normalisation
    # while keeping the code path intact.
    mean = train_dataset[:]['src'][:, :, 2:4].mean((0, 1)) * 0
    std = train_dataset[:]['src'][:, :, 2:4].std((0, 1)) * 0 + 1

    while epoch < args.max_epoch:
        epoch_loss = 0
        model.train()

        for id_b, batch in enumerate(tr_dl):
            optim.optimizer.zero_grad()
            r = 0
            rot_mat = np.array([[np.cos(r), np.sin(r)],
                                [-np.sin(r), np.cos(r)]])

            # Observed velocities, zero-masked future steps and a 0/1 flag channel.
            inp = ((batch['src'][:, :, 2:4] - mean) / std).to(device)
            inp = torch.matmul(inp, torch.from_numpy(rot_mat).float().to(device))
            trg_masked = torch.zeros((inp.shape[0], args.preds, 2)).to(device)
            inp_cls = torch.ones(inp.shape[0], inp.shape[1], 1).to(device)
            trg_cls = torch.zeros(trg_masked.shape[0], trg_masked.shape[1], 1).to(device)
            inp_cat = torch.cat((inp, trg_masked), 1)
            cls_cat = torch.cat((inp_cls, trg_cls), 1)
            net_input = torch.cat((inp_cat, cls_cat), 2)

            position = torch.arange(0, net_input.shape[1]).repeat(inp.shape[0], 1).long().to(device)
            token = torch.zeros((inp.shape[0], net_input.shape[1])).long().to(device)
            attention_mask = torch.ones((inp.shape[0], net_input.shape[1])).long().to(device)

            out = model(input_ids=net_input, position_ids=position,
                        token_type_ids=token, attention_mask=attention_mask)
            pred = generator(out[0])

            loss = F.pairwise_distance(
                pred[:, :].contiguous().view(-1, 2),
                torch.matmul(
                    torch.cat((batch['src'][:, :, 2:4], batch['trg'][:, :, 2:4]),
                              1).contiguous().view(-1, 2).to(device),
                    torch.from_numpy(rot_mat).float().to(device))).mean()
            loss.backward()
            optim.step()
            print("epoch %03i/%03i frame %04i / %04i loss: %7.4f" %
                  (epoch, args.max_epoch, id_b, len(tr_dl), loss.item()))
            epoch_loss += loss.item()
        # sched.step()
        log.add_scalar('Loss/train', epoch_loss / len(tr_dl), epoch)

        with torch.no_grad():
            model.eval()

            # Validation
            gt = []
            pr = []
            val_loss = 0
            for batch in val_dl:
                inp = ((batch['src'][:, :, 2:4] - mean) / std).to(device)
                trg_masked = torch.zeros((inp.shape[0], args.preds, 2)).to(device)
                inp_cls = torch.ones(inp.shape[0], inp.shape[1], 1).to(device)
                trg_cls = torch.zeros(trg_masked.shape[0], trg_masked.shape[1], 1).to(device)
                inp_cat = torch.cat((inp, trg_masked), 1)
                cls_cat = torch.cat((inp_cls, trg_cls), 1)
                net_input = torch.cat((inp_cat, cls_cat), 2)

                position = torch.arange(0, net_input.shape[1]).repeat(inp.shape[0], 1).long().to(device)
                token = torch.zeros((inp.shape[0], net_input.shape[1])).long().to(device)
                attention_mask = torch.zeros((inp.shape[0], net_input.shape[1])).long().to(device)

                out = model(input_ids=net_input, position_ids=position,
                            token_type_ids=token, attention_mask=attention_mask)
                pred = generator(out[0])

                loss = F.pairwise_distance(
                    pred[:, :].contiguous().view(-1, 2),
                    torch.cat((batch['src'][:, :, 2:4], batch['trg'][:, :, 2:4]),
                              1).contiguous().view(-1, 2).to(device)).mean()
                val_loss += loss.item()

                gt_b = batch['trg'][:, :, 0:2]
                preds_tr_b = pred[:, args.obs:].cumsum(1).to('cpu').detach() + batch['src'][:, -1:, 0:2]
                gt.append(gt_b)
                pr.append(preds_tr_b)

            gt = np.concatenate(gt, 0)
            pr = np.concatenate(pr, 0)
            mad, fad, errs = baselineUtils.distance_metrics(gt, pr)
            log.add_scalar('validation/loss', val_loss / len(val_dl), epoch)
            log.add_scalar('validation/mad', mad, epoch)
            log.add_scalar('validation/fad', fad, epoch)

            # Test
            model.eval()
            gt = []
            pr = []
            for batch in test_dl:
                inp = ((batch['src'][:, :, 2:4] - mean) / std).to(device)
                trg_masked = torch.zeros((inp.shape[0], args.preds, 2)).to(device)
                inp_cls = torch.ones(inp.shape[0], inp.shape[1], 1).to(device)
                trg_cls = torch.zeros(trg_masked.shape[0], trg_masked.shape[1], 1).to(device)
                inp_cat = torch.cat((inp, trg_masked), 1)
                cls_cat = torch.cat((inp_cls, trg_cls), 1)
                net_input = torch.cat((inp_cat, cls_cat), 2)

                position = torch.arange(0, net_input.shape[1]).repeat(inp.shape[0], 1).long().to(device)
                token = torch.zeros((inp.shape[0], net_input.shape[1])).long().to(device)
                attention_mask = torch.zeros((inp.shape[0], net_input.shape[1])).long().to(device)

                out = model(input_ids=net_input, position_ids=position,
                            token_type_ids=token, attention_mask=attention_mask)
                pred = generator(out[0])

                gt_b = batch['trg'][:, :, 0:2]
                preds_tr_b = pred[:, args.obs:].cumsum(1).to('cpu').detach() + batch['src'][:, -1:, 0:2]
                gt.append(gt_b)
                pr.append(preds_tr_b)

            gt = np.concatenate(gt, 0)
            pr = np.concatenate(pr, 0)
            mad, fad, errs = baselineUtils.distance_metrics(gt, pr)

            torch.save(model.state_dict(), "models/BERT/%s/ep_%03i.pth" % (args.name, epoch))
            torch.save(generator.state_dict(), "models/BERT/%s/gen_%03i.pth" % (args.name, epoch))
            torch.save(a.state_dict(), "models/BERT/%s/emb_%03i.pth" % (args.name, epoch))

            log.add_scalar('eval/mad', mad, epoch)
            log.add_scalar('eval/fad', fad, epoch)

        epoch += 1
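# A standalone toy sketch (random tensors, not the project's dataloader) of how the
# script above packs a trajectory batch into a single BERT input: observed velocities,
# zero-masked future steps, and a 0/1 flag channel marking the observed positions,
# giving the 3-channel input expected by LinearEmbedding(3, 768).
import torch

obs_len, pred_len, batch = 8, 12, 4
obs_vel = torch.randn(batch, obs_len, 2)                  # observed (vx, vy) steps
masked_future = torch.zeros(batch, pred_len, 2)           # future steps to be predicted
flags = torch.cat((torch.ones(batch, obs_len, 1),         # 1 = observed
                   torch.zeros(batch, pred_len, 1)), 1)   # 0 = masked
net_input = torch.cat((torch.cat((obs_vel, masked_future), 1), flags), 2)
print(net_input.shape)  # torch.Size([4, 20, 3])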
def train(config, bert_config, train_path, dev_path, rel2id, id2rel, tokenizer):
    if not os.path.exists(config.output_dir):
        os.makedirs(config.output_dir, exist_ok=True)

    if os.path.exists('./data/train_file.pkl'):
        train_data = pickle.load(open("./data/train_file.pkl", mode='rb'))
    else:
        # num_rels is assumed to be defined at module level in the original project.
        train_data = data.load_data(train_path, tokenizer, rel2id, num_rels)
        pickle.dump(train_data, open("./data/train_file.pkl", mode='wb'))

    dev_data = json.load(open(dev_path))
    for sent in dev_data:
        data.to_tuple(sent)

    data_manager = data.SPO(train_data)
    train_sampler = RandomSampler(data_manager)
    train_data_loader = DataLoader(data_manager, sampler=train_sampler,
                                   batch_size=config.batch_size, drop_last=True)
    num_train_steps = int(len(data_manager) / config.batch_size) * config.max_epoch

    if config.bert_pretrained_model is not None:
        logger.info('load bert weight')
        Bert_model = BertModel.from_pretrained(config.bert_pretrained_model,
                                               config=bert_config)
    else:
        logger.info('randomly initialize bert model')
        # BertModel(...) already initializes its weights; the original chained
        # .init_weights(), which returns None.
        Bert_model = BertModel(config=bert_config)
    Bert_model.to(device)

    submodel = sub_model(config).to(device)
    objmodel = obj_model(config).to(device)

    loss_fuc = nn.BCELoss(reduction='none')
    params = list(Bert_model.parameters()) + list(submodel.parameters()) + list(objmodel.parameters())
    optimizer = AdamW(params, lr=config.lr)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(data_manager))
    logger.info("  Num Epochs = %d", config.max_epoch)
    logger.info("  Total train batch size = %d", config.batch_size)
    logger.info("  Total optimization steps = %d", num_train_steps)
    logger.info("  Logging steps = %d", config.print_freq)
    logger.info("  Save steps = %d", config.save_freq)

    global_step = 0
    Bert_model.train()
    submodel.train()
    objmodel.train()

    for _ in range(config.max_epoch):
        optimizer.zero_grad()
        epoch_itorator = tqdm(train_data_loader, disable=None)
        for step, batch in enumerate(epoch_itorator):
            batch = tuple(t.to(device) for t in batch)
            input_ids, segment_ids, input_masks, sub_positions, sub_heads, sub_tails, obj_heads, obj_tails = batch

            bert_output = Bert_model(input_ids, input_masks, segment_ids)[0]
            pred_sub_heads, pred_sub_tails = submodel(bert_output)  # [batch_size, seq_len, 1]
            pred_obj_heads, pred_obj_tails = objmodel(bert_output, sub_positions)

            # Compute the loss
            mask = input_masks.view(-1)

            # loss1: subject head/tail prediction
            sub_heads = sub_heads.unsqueeze(-1)  # [batch_size, seq_len, 1]
            sub_tails = sub_tails.unsqueeze(-1)
            loss1_head = loss_fuc(pred_sub_heads, sub_heads).view(-1)
            loss1_head = torch.sum(loss1_head * mask) / torch.sum(mask)
            loss1_tail = loss_fuc(pred_sub_tails, sub_tails).view(-1)
            loss1_tail = torch.sum(loss1_tail * mask) / torch.sum(mask)
            loss1 = loss1_head + loss1_tail

            # loss2: object head/tail prediction per relation
            loss2_head = loss_fuc(pred_obj_heads, obj_heads).view(-1, obj_heads.shape[-1])
            loss2_head = torch.sum(loss2_head * mask.unsqueeze(-1)) / torch.sum(mask)
            loss2_tail = loss_fuc(pred_obj_tails, obj_tails).view(-1, obj_tails.shape[-1])
            loss2_tail = torch.sum(loss2_tail * mask.unsqueeze(-1)) / torch.sum(mask)
            loss2 = loss2_head + loss2_tail

            # optimize
            loss = loss1 + loss2
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            global_step += 1

            if (global_step + 1) % config.print_freq == 0:
                logger.info("epoch : {} step: {} #### loss1: {}  loss2: {}".format(
                    _, global_step + 1, loss1.cpu().item(), loss2.cpu().item()))

            if (global_step + 1) % config.eval_freq == 0:
                logger.info("***** Running evaluating *****")
                with torch.no_grad():
                    Bert_model.eval()
                    submodel.eval()
                    objmodel.eval()
                    P, R, F1 = utils.metric(Bert_model, submodel, objmodel,
                                            dev_data, id2rel, tokenizer)
                    logger.info(f'precision:{P}\nrecall:{R}\nF1:{F1}')
                Bert_model.train()
                submodel.train()
                objmodel.train()

            if (global_step + 1) % config.save_freq == 0:
                # Save a trained model
                model_name = "pytorch_model_%d" % (global_step + 1)
                output_model_file = os.path.join(config.output_dir, model_name)
                state = {
                    'bert_state_dict': Bert_model.state_dict(),
                    'subject_state_dict': submodel.state_dict(),
                    'object_state_dict': objmodel.state_dict(),
                }
                torch.save(state, output_model_file)

    model_name = "pytorch_model_last"
    output_model_file = os.path.join(config.output_dir, model_name)
    state = {
        'bert_state_dict': Bert_model.state_dict(),
        'subject_state_dict': submodel.state_dict(),
        'object_state_dict': objmodel.state_dict(),
    }
    torch.save(state, output_model_file)
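# A toy-tensor sketch of the masked BCE pattern used in the loop above: per-element
# losses are weighted by the attention mask so padded positions do not contribute.
# The shapes and values here are illustrative only.
import torch
import torch.nn as nn

loss_fuc = nn.BCELoss(reduction='none')
pred = torch.sigmoid(torch.randn(2, 5, 1))               # [batch, seq_len, 1]
gold = torch.randint(0, 2, (2, 5, 1)).float()
mask = torch.tensor([[1., 1., 1., 0., 0.],
                     [1., 1., 1., 1., 1.]]).view(-1)     # flattened attention mask
loss = loss_fuc(pred, gold).view(-1)
loss = torch.sum(loss * mask) / torch.sum(mask)          # mean over real tokens only
print(loss.item())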
class BertForQuestionAnsweringWithCRF(BertPreTrainedModel):

    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config)
        self.hidden_size = self.bert.config.hidden_size
        self.CRF_fc1 = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(self.hidden_size, config.num_labels + 2, bias=True),
        )
        self.CRF = CRF(target_size=self.bert.config.num_labels,
                       device=torch.device("cuda"))
        self.CrossEntropyLoss = nn.CrossEntropyLoss()
        self.fc2 = nn.Linear(self.hidden_size, 2, bias=True)

    def forward(self, tokens_id_l, token_type_ids_l, answer_offset_l,
                answer_seq_label_l, IsQA_l):
        ## Token IDs [batch_size, seq_length]
        tokens_x_2d = torch.LongTensor(tokens_id_l).to(self.device)
        token_type_ids_2d = torch.LongTensor(token_type_ids_l).to(self.device)

        # Compute seq_length, excluding [CLS]
        batch_size, seq_length = tokens_x_2d[:, 1:].size()

        ## CRF label IDs [batch_size, seq_length]
        y_2d = torch.LongTensor(answer_seq_label_l).to(self.device)[:, 1:]
        ## (batch_size,)
        y_IsQA_2d = torch.LongTensor(IsQA_l).to(self.device)

        if self.training:  # self.training is inherited from the nn.Module base class
            self.bert.train()
            output = self.bert(input_ids=tokens_x_2d,
                               token_type_ids=token_type_ids_2d,
                               output_hidden_states=True,
                               return_dict=True)  # [batch_size, seq_len, hidden_size]
        else:
            self.bert.eval()
            with torch.no_grad():
                output = self.bert(input_ids=tokens_x_2d,
                                   token_type_ids=token_type_ids_2d,
                                   output_hidden_states=True,
                                   return_dict=True)

        ## [CLS] for IsQA [batch_size, hidden_size]
        cls_emb = output.last_hidden_state[:, 0, :]
        IsQA_logits = self.fc2(cls_emb)  ## [batch_size, 2]
        IsQA_loss = self.CrossEntropyLoss(IsQA_logits, y_IsQA_2d)

        ## [batch_size, 1]
        IsQA_prediction = IsQA_logits.argmax(dim=-1).unsqueeze(dim=-1)

        # CRF mask, [batch_size, seq_len], no [CLS]
        mask = np.ones(shape=[batch_size, seq_length], dtype=np.uint8)
        mask = torch.ByteTensor(mask).to(self.device)

        crf_logits = self.CRF_fc1(output.last_hidden_state[:, 1:, :])
        crf_loss = self.CRF.neg_log_likelihood_loss(feats=crf_logits, mask=mask, tags=y_2d)
        _, CRFprediction = self.CRF.forward(feats=crf_logits, mask=mask)

        # (batch_size,) -> (batch_size, 1)
        return IsQA_prediction, CRFprediction, IsQA_loss, crf_loss, y_2d, y_IsQA_2d.unsqueeze(dim=-1)
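# A minimal, self-contained sketch of the two-head pattern above: the [CLS] state feeds
# an "is this answerable" classifier and the remaining token states feed a tagging head.
# The project-specific CRF is replaced here by a plain linear layer, and the tiny config
# sizes are assumptions chosen purely for illustration.
import torch
import torch.nn as nn
from transformers import BertConfig, BertModel

config = BertConfig(hidden_size=64, num_hidden_layers=2, num_attention_heads=2,
                    intermediate_size=128, num_labels=5)
bert = BertModel(config)
isqa_head = nn.Linear(config.hidden_size, 2)
tag_head = nn.Linear(config.hidden_size, config.num_labels + 2)

tokens = torch.randint(0, config.vocab_size, (2, 16))
out = bert(input_ids=tokens, return_dict=True)
isqa_logits = isqa_head(out.last_hidden_state[:, 0, :])   # [2, 2]
tag_logits = tag_head(out.last_hidden_state[:, 1:, :])    # [2, 15, num_labels + 2]
print(isqa_logits.shape, tag_logits.shape)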
class DialogBERT(nn.Module):
    '''Hierarchical BERT for dialog v5 with two features:
       - Masked context utterance prediction with direct MSE matching of their vectors
       - Energy-based utterance order prediction: a novel approach that shuffles the
         context and predicts the original order with distributed order prediction'''

    # TODO: 1. Enhance the sorting net
    #       2. Better data loader for permutation (avoid returning perm_id and use max(pos_ids) instead)

    def __init__(self, args, base_model_name='bert-base-uncased'):
        super(DialogBERT, self).__init__()

        if args.language == 'chinese':
            base_model_name = 'bert-base-chinese'

        self.tokenizer = BertTokenizer.from_pretrained(base_model_name, cache_dir='./cache/')

        if args.model_size == 'tiny':
            self.encoder_config = BertConfig(vocab_size=30522, hidden_size=256,
                                             num_hidden_layers=6, num_attention_heads=2,
                                             intermediate_size=1024)
            self.utt_encoder = BertForPreTraining(self.encoder_config)
        elif args.model_size == 'small':
            self.encoder_config = BertConfig(vocab_size=30522, hidden_size=512,
                                             num_hidden_layers=8, num_attention_heads=4,
                                             intermediate_size=2048)
            self.utt_encoder = BertForPreTraining(self.encoder_config)
        else:
            self.encoder_config = BertConfig.from_pretrained(base_model_name, cache_dir='./cache/')
            self.utt_encoder = BertForPreTraining.from_pretrained(
                base_model_name, config=self.encoder_config, cache_dir='./cache/')

        # context encoder: encodes the context into a vector
        self.context_encoder = BertModel(self.encoder_config)

        self.mlm_mode = 'mse'  # 'mdn', 'mse'
        if self.mlm_mode == 'mdn':
            self.context_mlm_trans = MixtureDensityNetwork(
                self.encoder_config.hidden_size, self.encoder_config.hidden_size, 3)
        else:
            # transforms context hidden states back to utterance encodings
            self.context_mlm_trans = BertPredictionHeadTransform(self.encoder_config)

        self.dropout = nn.Dropout(self.encoder_config.hidden_dropout_prob)
        self.context_order_trans = SelfSorting(self.encoder_config.hidden_size)
        # self.context_order_trans = MLP(self.encoder_config.hidden_size, '200-200-200', 1)

        self.decoder_config = deepcopy(self.encoder_config)
        self.decoder_config.is_decoder = True
        self.decoder_config.add_cross_attention = True
        self.decoder = BertLMHeadModel(self.decoder_config)

    def init_weights(self, m):  # Initialize Linear weights (for GAN)
        if isinstance(m, nn.Linear):
            m.weight.data.uniform_(-0.08, 0.08)  # nn.init.xavier_normal_(m.weight)
            nn.init.constant_(m.bias, 0.)

    @classmethod
    def from_pretrained(self, model_dir):
        # `args` is assumed to be available at module level in the original project.
        self.encoder_config = BertConfig.from_pretrained(model_dir)
        self.tokenizer = BertTokenizer.from_pretrained(
            path.join(model_dir, 'tokenizer'), do_lower_case=args.do_lower_case)
        self.utt_encoder = BertForPreTraining.from_pretrained(path.join(model_dir, 'utt_encoder'))
        self.context_encoder = BertForSequenceClassification.from_pretrained(
            path.join(model_dir, 'context_encoder'))
        self.context_mlm_trans = BertPredictionHeadTransform(self.encoder_config)
        self.context_mlm_trans.load_state_dict(
            torch.load(path.join(model_dir, 'context_mlm_trans.pkl')))
        self.context_order_trans = SelfSorting(self.encoder_config.hidden_size)
        self.context_order_trans.load_state_dict(
            torch.load(path.join(model_dir, 'context_order_trans.pkl')))
        self.decoder_config = BertConfig.from_pretrained(model_dir)
        self.decoder = BertLMHeadModel.from_pretrained(path.join(model_dir, 'decoder'))

    def save_pretrained(self, output_dir):
        def save_module(model, save_path):
            torch.save(model.state_dict(), save_path)

        def make_list_dirs(dir_list):
            for dir_ in dir_list:
                os.makedirs(dir_, exist_ok=True)

        make_list_dirs([path.join(output_dir, name)
                        for name in ['tokenizer', 'utt_encoder', 'context_encoder', 'decoder']])
        model_to_save = self.module if hasattr(self, 'module') else self
        model_to_save.encoder_config.save_pretrained(output_dir)  # Save configuration file
        model_to_save.tokenizer.save_pretrained(path.join(output_dir, 'tokenizer'))
        model_to_save.utt_encoder.save_pretrained(path.join(output_dir, 'utt_encoder'))
        model_to_save.context_encoder.save_pretrained(path.join(output_dir, 'context_encoder'))
        save_module(model_to_save.context_mlm_trans, path.join(output_dir, 'context_mlm_trans.pkl'))
        save_module(model_to_save.context_order_trans, path.join(output_dir, 'context_order_trans.pkl'))
        model_to_save.decoder_config.save_pretrained(output_dir)  # Save configuration file
        model_to_save.decoder.save_pretrained(path.join(output_dir, 'decoder'))

    def utt_encoding(self, context, utts_attn_mask):
        batch_size, max_ctx_len, max_utt_len = context.size()  # context: [batch_size x diag_len x max_utt_len]

        utts = context.view(-1, max_utt_len)  # [(batch_size*diag_len) x max_utt_len]
        utts_attn_mask = utts_attn_mask.view(-1, max_utt_len)
        _, utts_encodings, *_ = self.utt_encoder.bert(utts, utts_attn_mask)
        utts_encodings = utts_encodings.view(batch_size, max_ctx_len, -1)
        return utts_encodings

    def context_encoding(self, context, utts_attn_mask, ctx_attn_mask):
        # with torch.no_grad():
        utt_encodings = self.utt_encoding(context, utts_attn_mask)
        context_hiddens, pooled_output, *_ = self.context_encoder(
            None, ctx_attn_mask, None, None, None, utt_encodings)
        # context_hiddens: [batch_size x ctx_len x dim]; pooled_output: [batch_size x dim]
        return context_hiddens, pooled_output

    def train_dialog_flow(self, context, context_utts_attn_mask, context_attn_mask,
                          context_lm_targets, context_position_perm_id,
                          context_position_ids, response):
        """ Only train the dialog flow model. """
        self.context_encoder.train()  # set the module in training mode
        self.context_mlm_trans.train()

        context_hiddens, context_encoding = self.context_encoding(
            context, context_utts_attn_mask, context_attn_mask)
        lm_pred_encodings = self.context_mlm_trans(self.dropout(context_hiddens))

        context_lm_targets[context_lm_targets == -100] = 0
        ctx_lm_mask = context_lm_targets.sum(2)
        if (ctx_lm_mask > 0).sum() == 0:
            ctx_lm_mask[0, 0] = 1

        lm_pred_encodings = lm_pred_encodings[ctx_lm_mask > 0]
        context_lm_targets = context_lm_targets[ctx_lm_mask > 0]
        context_lm_targets_attn_mask = context_utts_attn_mask[ctx_lm_mask > 0]

        with torch.no_grad():
            _, lm_tgt_encodings, *_ = self.utt_encoder.bert(
                context_lm_targets, context_lm_targets_attn_mask)

        loss_ctx_mlm = MSELoss()(lm_pred_encodings, lm_tgt_encodings)  # [num_selected_utts x dim]

        # context order prediction
        if isinstance(self.context_order_trans, SelfSorting):
            sorting_scores = self.context_order_trans(context_hiddens, context_attn_mask)
        else:
            sorting_scores = self.context_order_trans(context_hiddens)
        sorting_pad_mask = context_attn_mask == 0
        sorting_pad_mask[context_position_perm_id < 1] = True  # exclude single-turn and unshuffled dialogs
        loss_ctx_uop = listNet(sorting_scores, context_position_ids, sorting_pad_mask)
        # loss_ctx_uop = listMLE(sorting_scores, context_position_ids, sorting_pad_mask)

        loss = loss_ctx_mlm + loss_ctx_uop

        return {
            'loss': loss,
            'loss_ctx_mlm': loss_ctx_mlm,
            'loss_ctx_uop': loss_ctx_uop
        }

    def train_decoder(self, context, context_utts_attn_mask, context_attn_mask,
                      context_lm_targets, context_position_perm_id,
                      context_position_ids, response):
        """ Only train the decoder. """
        self.decoder.train()

        with torch.no_grad():
            context_hiddens, context_encoding = self.context_encoding(
                context, context_utts_attn_mask, context_attn_mask)

        ## train decoder
        dec_input, dec_target = response[:, :-1].contiguous(), response[:, 1:].clone()

        dec_output, *_ = self.decoder(
            dec_input,
            dec_input.ne(self.tokenizer.pad_token_id).long(),
            None, None, None, None,
            encoder_hidden_states=context_hiddens,
            encoder_attention_mask=context_attn_mask,
        )

        batch_size, seq_len, vocab_size = dec_output.size()
        dec_target[response[:, 1:] == self.tokenizer.pad_token_id] = -100
        dec_target[context_position_perm_id > 1] = -100  # ignore responses whose contexts are shuffled
        loss_decoder = CrossEntropyLoss()(dec_output.view(-1, vocab_size), dec_target.view(-1))

        results = {'loss': loss_decoder, 'loss_decoder': loss_decoder}
        return results

    def forward(self, context, context_utts_attn_mask, context_attn_mask,
                context_mlm_targets, context_position_perm_id,
                context_position_ids, response):
        self.train()
        batch_size, max_ctx_len, max_utt_len = context.size()  # context: [batch_size x diag_len x max_utt_len]

        context_hiddens, context_encoding = self.context_encoding(
            context, context_utts_attn_mask, context_attn_mask)

        ## train dialog flow modeling
        context_mlm_targets[context_mlm_targets == -100] = 0
        ctx_mlm_mask = context_mlm_targets.sum(2)  # [batch_size x num_utts]
        if (ctx_mlm_mask > 0).sum() == 0:
            ctx_mlm_mask[0, 0] = 1
        ctx_mlm_mask = ctx_mlm_mask > 0

        with torch.no_grad():
            _, mlm_tgt_encodings, *_ = self.utt_encoder.bert(
                context_mlm_targets[ctx_mlm_mask], context_utts_attn_mask[ctx_mlm_mask])

        if self.mlm_mode == 'mdn':  # mixture density network
            mlm_pred_pi, mlm_pred_normal = self.context_mlm_trans(
                self.dropout(context_hiddens[ctx_mlm_mask]))
            loss_ctx_mlm = self.context_mlm_trans.loss(mlm_pred_pi, mlm_pred_normal, mlm_tgt_encodings)
        else:  # simple mean squared error
            mlm_pred_encodings = self.context_mlm_trans(self.dropout(context_hiddens[ctx_mlm_mask]))
            loss_ctx_mlm = MSELoss()(mlm_pred_encodings, mlm_tgt_encodings)  # [num_selected_utts x dim]

        # context order prediction
        if isinstance(self.context_order_trans, SelfSorting):
            sorting_scores = self.context_order_trans(context_hiddens, context_attn_mask)
        else:
            sorting_scores = self.context_order_trans(context_hiddens)
        sorting_pad_mask = context_attn_mask == 0
        sorting_pad_mask[context_position_perm_id < 1] = True  # exclude single-turn and unshuffled dialogs
        loss_ctx_uop = listNet(sorting_scores, context_position_ids, sorting_pad_mask)
        # loss_ctx_uop = listMLE(sorting_scores, context_position_ids, sorting_pad_mask)

        ## train decoder
        dec_input, dec_target = response[:, :-1].contiguous(), response[:, 1:].clone()

        dec_output, *_ = self.decoder(
            dec_input,
            dec_input.ne(self.tokenizer.pad_token_id).long(),
            None, None, None, None,
            encoder_hidden_states=context_hiddens,
            encoder_attention_mask=context_attn_mask,
        )

        batch_size, seq_len, vocab_size = dec_output.size()
        dec_target[response[:, 1:] == self.tokenizer.pad_token_id] = -100
        dec_target[context_position_perm_id > 1] = -100  # ignore responses whose context was shuffled
        loss_decoder = CrossEntropyLoss()(dec_output.view(-1, vocab_size), dec_target.view(-1))

        loss = loss_ctx_mlm + loss_ctx_uop + loss_decoder

        results = {
            'loss': loss,
            'loss_ctx_mlm': loss_ctx_mlm,
            'loss_ctx_uop': loss_ctx_uop,
            'loss_decoder': loss_decoder
        }
        return results

    def validate(self, context, context_utts_attn_mask, context_attn_mask,
                 context_lm_targets, context_position_perm_id,
                 context_position_ids, response):
        results = self.train_decoder(context, context_utts_attn_mask, context_attn_mask,
                                     context_lm_targets, context_position_perm_id,
                                     context_position_ids, response)
        return results['loss'].item()

    def generate(self, input_batch, max_len=30, num_samples=1, mode='sample'):
        self.eval()
        device = next(self.parameters()).device
        context, context_utts_attn_mask, context_attn_mask = [t.to(device) for t in input_batch[:3]]
        ground_truth = input_batch[6].numpy()

        context_hiddens, context_encoding = self.context_encoding(
            context, context_utts_attn_mask, context_attn_mask)

        generated = torch.zeros((num_samples, 1), dtype=torch.long, device=device).fill_(
            self.tokenizer.cls_token_id)  # [batch_sz x 1] (1 = seq_len)
        sample_lens = torch.ones((num_samples, 1), dtype=torch.long, device=device)
        len_inc = torch.ones((num_samples, 1), dtype=torch.long, device=device)

        for _ in range(max_len):
            outputs, *_ = self.decoder(
                generated,
                generated.ne(self.tokenizer.pad_token_id).long(),
                None, None, None, None,
                encoder_hidden_states=context_hiddens,
                encoder_attention_mask=context_attn_mask,
            )  # [batch_size x seq_len x vocab_size]
            next_token_logits = outputs[:, -1, :] / self.decoder_config.temperature

            # repetition penalty from CTRL (https://arxiv.org/abs/1909.05858)
            for i in range(num_samples):
                for _ in set(generated[i].tolist()):
                    next_token_logits[i, _] /= self.decoder_config.repetition_penalty

            filtered_logits = top_k_top_p_filtering(next_token_logits,
                                                    top_k=self.decoder_config.top_k,
                                                    top_p=self.decoder_config.top_p)
            if mode == 'greedy':  # greedy sampling
                next_token = torch.argmax(filtered_logits, dim=-1).unsqueeze(-1)
            else:
                next_token = torch.multinomial(torch.softmax(filtered_logits, dim=-1),
                                               num_samples=num_samples)
            next_token[len_inc == 0] = self.tokenizer.pad_token_id
            generated = torch.cat((generated, next_token), dim=1)
            # stop increasing the length (set bit to 0) once EOS is encountered
            len_inc = len_inc * (next_token != self.tokenizer.sep_token_id).long()
            if len_inc.sum() < 1:
                break
            sample_lens = sample_lens + len_inc

        # to numpy
        sample_words = generated.data.cpu().numpy()
        sample_lens = sample_lens.data.cpu().numpy()
        context = context.data.cpu().numpy()
        return sample_words, sample_lens, context, ground_truth  # nparray: [repeat x seq_len]
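# A toy sketch of the hierarchical encoding idea behind context_encoding(): each
# utterance is encoded into one vector, and the sequence of utterance vectors is fed to
# a second BERT via inputs_embeds. The tiny random configs here are assumptions for
# illustration, not DialogBERT's actual sizes.
import torch
from transformers import BertConfig, BertModel

cfg = BertConfig(hidden_size=64, num_hidden_layers=2, num_attention_heads=2,
                 intermediate_size=128)
utt_encoder, ctx_encoder = BertModel(cfg), BertModel(cfg)

batch, ctx_len, utt_len = 2, 3, 10
context = torch.randint(0, cfg.vocab_size, (batch, ctx_len, utt_len))
utt_mask = torch.ones(batch, ctx_len, utt_len, dtype=torch.long)

flat = context.view(-1, utt_len)                            # [(batch*ctx_len), utt_len]
flat_mask = utt_mask.view(-1, utt_len)
utt_vecs = utt_encoder(flat, attention_mask=flat_mask, return_dict=True).pooler_output
utt_vecs = utt_vecs.view(batch, ctx_len, -1)                # [batch, ctx_len, hidden]

ctx_mask = torch.ones(batch, ctx_len, dtype=torch.long)
ctx_out = ctx_encoder(inputs_embeds=utt_vecs, attention_mask=ctx_mask, return_dict=True)
print(ctx_out.last_hidden_state.shape)                      # [batch, ctx_len, hidden]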
class BertVisdEmbedding(nn.Module):
    ''' Layer that generates BERT contextual representations. '''

    def __init__(self, config=None, device=t.device("cpu")):
        '''
        Args:
            @config: configuration of the internal BERT layer
        '''
        super(BertVisdEmbedding, self).__init__()
        if config is None:
            self.bert = BertModel.from_pretrained('bert-base-uncased')
        else:
            self.bert = BertModel(config=config)  # transformers correspondence
        self.device = device
        self.bert_hidden_size = self.bert.config.hidden_size
        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.CLS = tokenizer.convert_tokens_to_ids(['[CLS]'])[0]  # ID of the BERT [CLS] token
        self.SEP = tokenizer.convert_tokens_to_ids(['[SEP]'])[0]  # ID of the BERT [SEP] token
        self.PAD = tokenizer.convert_tokens_to_ids(['[PAD]'])[0]  # ID of the BERT [PAD] token

    def make_bert_input(self, content_idxs, content_type, seg_ids):
        '''
        Args:
            @content_idxs (tensor): BERT IDs of the content. (batch_size, max_seq_len)
                Note that max_seq_len is a fixed number due to the padding/clamping policy.
            @content_type (str): whether the content is "question", "history" or "answer".
            @seg_ids: the initial segment IDs. For "question" and "answer" this should be
                None; for "history" it should be well-initialized as [0,..,0,1,...,1].
        Return:
            cmp_idx (tensor): [CLS] content_idxs [SEP]. (batch_size, max_seq_len+2)
            segment_ids (tensor): for "question" and "answer" this is all zeros; for
                "history" it is "seg_ids[0], seg_ids, seg_ids[-1]". (batch_size, max_seq_len+2)
            input_mask (tensor): attention mask over the real tokens in the content. Note
                that [CLS] and [SEP] count as real tokens. (batch_size, max_seq_len+2)
        '''
        mask = content_idxs != self.PAD  # mask indicating the non-padding tokens in the content
        if content_type == 'question' or content_type == 'answer':  # question/answer type
            seg_ids = t.zeros_like(content_idxs, dtype=content_idxs.dtype,
                                   device=content_idxs.device)
        seq_len = mask.sum(dim=1)  # (batch_size,) length of each sequence
        batch_size, _ = content_idxs.size()

        content_idxs = t.cat(
            (content_idxs, t.tensor([[self.PAD]] * batch_size, device=content_idxs.device)),
            dim=1)  # (batch_size, max_seq_len+1)
        content_idxs[t.arange(0, batch_size), seq_len] = self.SEP  # append [SEP] to obtain "content_idxs [SEP]"

        seg_last = seg_ids[t.arange(0, batch_size), seq_len - 1]  # last segment id of each sequence
        seg_ids = t.cat(
            (seg_ids, t.tensor([[0]] * batch_size, device=content_idxs.device)),
            dim=1)  # (batch_size, max_seq_len+1)
        seg_ids[t.arange(0, batch_size), seq_len] = seg_last  # segment id of the newly appended [SEP]

        content_idxs = t.cat(
            (t.tensor([[self.CLS]] * batch_size, device=content_idxs.device), content_idxs),
            dim=1)  # (batch_size, max_seq_len+2) prepend [CLS] to obtain "[CLS] content_idxs [SEP]"
        seg_ids = t.cat((seg_ids[:, 0].view(-1, 1), seg_ids),
                        dim=1)  # (batch_size, max_seq_len+2) extend the first column of the segment ids

        input_mask = (content_idxs != self.PAD).long()  # (batch_size, max_seq_len+2)
        return content_idxs, seg_ids, input_mask

    def parse_bert_output(self, bert_output, orig_PAD_mask):
        '''
        Args:
            @bert_output (tensor): BERT output with the [CLS] and [SEP] embeddings.
                (batch_size, 1+max_seq_len+1, bert_hidden_size)
            @orig_PAD_mask (tensor): 1 for PAD tokens, 0 for non-PAD tokens. (batch_size, max_seq_len)
        Return:
            bert_enc (tensor): BERT output without the [CLS] and [SEP] embeddings, with
                zero embeddings for all PAD tokens. (batch_size, max_seq_len, bert_hidden_size)
        '''
        bert_enc = bert_output[:, 1:-1]  # (batch_size, max_seq_len, bert_hidden_size)
        pad_emb = t.zeros(self.bert_hidden_size,
                          device=bert_output.device)  # manually set the PAD embedding to zero
        # print(bert_enc.size(), orig_PAD_mask.size(), pad_emb.size(), bert_enc.device, orig_PAD_mask.device, pad_emb.device)
        bert_enc = bert_enc.contiguous()
        bert_enc[orig_PAD_mask] = pad_emb  # set the PAD token embeddings to zero
        return bert_enc

    def forward(self, content_idxs, content_type, seg_ids=None):
        '''
        Args:
            @content_idxs (tensor): BERT IDs of the contents. (batch_size, max_seq_len)
                Note that max_seq_len is a fixed number due to the padding/clamping policy.
            @content_type (str): whether the tensor is "question", "history" or "answer".
        Return:
            bert_enc (tensor): contextual embedding conditioned on the question.
                (batch_size, max_seq_len, bert_hidden_size)
        '''
        orig_PAD_mask = content_idxs == self.PAD
        cmp_idxs, segment_ids, bert_att = self.make_bert_input(content_idxs, content_type, seg_ids)
        # Keyword arguments avoid swapping attention_mask and token_type_ids in the
        # transformers positional signature.
        outputs = self.bert(input_ids=cmp_idxs, token_type_ids=segment_ids,
                            attention_mask=bert_att)
        bert_output = outputs[0]
        bert_enc = self.parse_bert_output(bert_output, orig_PAD_mask)
        return bert_enc

    def train(self, mode=True):
        ''' Specifically set self.bert into training mode. '''
        self.training = mode
        self.bert.train(mode)
        return self

    def eval(self):
        ''' Specifically set self.bert into evaluation mode. '''
        return self.train(False)

    def to(self, *args, **kwargs):
        ''' Override the to() interface. '''
        print("bert emd to() called!")
        self = super().to(*args, **kwargs)
        self.bert = self.bert.to(*args, **kwargs)
        return self
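# A hedged usage sketch of BertVisdEmbedding as defined above: feed a padded batch of
# question token IDs and get per-token contextual embeddings with PAD positions zeroed.
# The token IDs below are arbitrary placeholders (PAD id is 0 for bert-base-uncased).
import torch as t

emb_layer = BertVisdEmbedding()          # config=None -> loads bert-base-uncased
ids = t.tensor([[2054, 3609, 2003, 1996, 4937, 0, 0, 0]])  # a padded "question"
bert_enc = emb_layer(ids, content_type='question')
print(bert_enc.shape)                    # torch.Size([1, 8, 768]); PAD rows are zero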
# download vocab
vocab_info = tokenizer
vocab_path = download(vocab_info['url'],
                      vocab_info['fname'],
                      vocab_info['chksum'],
                      cachedir=cachedir)

#################################################################################################
print('Declaring the BERT model')
bertmodel = BertModel(config=BertConfig.from_dict(bert_config))
bertmodel.load_state_dict(torch.load(model_path))

print("Setting up the GPU device")
device = torch.device(ctx)
bertmodel.to(device)
bertmodel.train()

vocab = nlp.vocab.BERTVocab.from_sentencepiece(vocab_path, padding_token='[PAD]')

#################################################################################################
# Hyperparameter settings
tokenizer = get_tokenizer()
tok = nlp.data.BERTSPTokenizer(tokenizer, vocab, lower=False)

max_len = 64
batch_size = 64
warmup_ratio = 0.1
max_grad_norm = 1
log_interval = 200
learning_rate = 5e-5
#################################################################################################
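# A hedged sketch (following the usual KoBERT/gluonnlp recipe) of how `tok`, `nlp`, and
# `max_len` defined above are typically used next: wrap the tokenizer in a
# BERTSentenceTransform that turns a raw sentence into fixed-length ID arrays.
transform = nlp.data.BERTSentenceTransform(tok, max_seq_length=max_len,
                                           pad=True, pair=False)
token_ids, valid_length, segment_ids = transform(['안녕하세요'])
print(token_ids.shape)   # (max_len,) array of SentencePiece subword IDs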
def flat_accuracy(preds, labels):
    # Fraction of argmax predictions that match the labels (standard helper whose
    # header is reconstructed here from the variable names in the original tail).
    pred_flat = np.argmax(preds, axis=1).flatten()
    labels_flat = labels.flatten()
    return np.sum(pred_flat == labels_flat) / len(labels_flat)


#@title The Training Loop

t = []

# Store our loss and accuracy for plotting
train_loss_set = []

# trange is a tqdm wrapper around the normal python range
for _ in trange(epochs, desc="Epoch"):

    # Training

    # Set our model to training mode (as opposed to evaluation mode)
    model.train()
    gc.collect()
    torch.cuda.empty_cache()

    # Tracking variables
    tr_loss = 0
    nb_tr_examples, nb_tr_steps = 0, 0

    # Train the data for one epoch
    for step, batch in enumerate(train_dataloader):
        # Add batch to GPU
        batch = tuple(t.to(device) for t in batch)
        # Unpack the inputs from our dataloader
        b_input_ids, b_input_mask, b_labels = batch
        b_input_ids = torch.tensor(b_input_ids).long()
        # Clear out the gradients (by default they accumulate)