def main(args):
    """Run the GAN-based out-of-scope detection pipeline.

    Builds the discriminator ``D`` and generator ``G`` from the module
    ``model.<args.model>``, the detector from ``model.detector`` and a BERT
    encoder ``E``; then, depending on ``args.do_train`` / ``args.do_eval`` /
    ``args.do_test``, trains, validates and tests them.  Collected curves and
    scores are written to ``freeze_data.pkl``, ``valid_score.csv`` and
    ``test_score.csv`` under ``args.output_dir``.

    Args:
        args: Parsed command-line namespace.  It is also exposed as the
            ``config`` dict (``vars(args)``), so keys added to ``config``
            become attributes of ``args``.

    Raises:
        ValueError: If ``args.dataset`` is neither ``'oos-eval'`` nor ``'smp'``.
    """
    logger.info('Checking...')
    SEED = args.seed
    logger.info('seed: {}'.format(SEED))
    logger.info('model: {}'.format(args.model))
    check_manual_seed(SEED)
    check_args(args)

    logger.info('Loading config...')
    bert_config = Config('config/bert.ini')
    bert_config = bert_config(args.bert_type)

    # for oos-eval dataset
    data_config = Config('config/data.ini')
    data_config = data_config(args.dataset)

    # Prepare data processor: join the data directory and the file name.
    data_path = os.path.join(data_config['DataDir'],
                             data_config[args.data_file])
    label_path = data_path.replace('.json', '.label')

    if args.dataset == 'oos-eval':
        processor = OOSProcessor(bert_config, maxlen=32)
    elif args.dataset == 'smp':
        processor = SMPProcessor(bert_config, maxlen=32)
    else:
        raise ValueError('The dataset {} is not supported.'.format(
            args.dataset))

    # Adds label_to_id and id_to_label to the processor.
    processor.load_label(label_path)

    n_class = len(processor.id_to_label)
    config = vars(args)  # argument namespace viewed as a dict
    config['gan_save_path'] = os.path.join(args.output_dir, 'save', 'gan.pt')
    config['bert_save_path'] = os.path.join(args.output_dir, 'save', 'bert.pt')
    config['n_class'] = n_class

    logger.info('config:')
    logger.info(config)

    model = import_module('model.' + args.model)
    model_d = import_module('model.' + 'detector')

    D = model.Discriminator(config)
    G = model.Generator(config)
    E = BertModel.from_pretrained(bert_config['PreTrainModelDir'])  # Bert encoder
    if args.loss == 'v1':
        detector = model_d.Detector(config)
    else:
        detector = model_d.Detector_v2(config)

    logger.info('Discriminator: {}'.format(D))
    logger.info('Generator: {}'.format(G))
    logger.info('Detector: {}'.format(detector))

    # The encoder is trainable only when fine-tuning is requested.
    fine_tune = bool(args.fine_tune)
    for param in E.parameters():
        param.requires_grad = fine_tune

    D.to(device)
    G.to(device)
    E.to(device)
    detector.to(device)

    global_step = 0

    def sample_noise(n):
        """Draw ``n`` standard-normal generator inputs on ``device``."""
        if args.model == 'lstm_gan' or args.model == 'cnn_gan':
            shape = (n, 32, args.G_z_dim)
        else:
            shape = (n, args.G_z_dim)
        return FloatTensor(np.random.normal(0, 1, shape)).to(device)

    def read_split(split):
        """Read ``split`` from the data file, plus ``oos_<split>`` when needed."""
        parts = [split]
        if (not config['data_file'].startswith('binary')
                and config['dataset'] == 'oos-eval'):
            parts.append('oos_' + split)
        return processor.read_dataset(data_path, parts)

    def train(train_dataset, dev_dataset):
        """Train D, G and the detector; validate after every epoch.

        Returns a dict that holds ``'all_features'`` when ``args.do_vis`` is set.
        """
        train_dataloader = DataLoader(train_dataset,
                                      batch_size=args.train_batch_size,
                                      shuffle=True,
                                      num_workers=2)

        global best_dev
        nonlocal global_step

        n_sample = len(train_dataloader)
        early_stopping = EarlyStopping(args.patience, logger=logger)

        # Loss functions
        adversarial_loss = torch.nn.BCELoss().to(device)
        adversarial_loss_v2 = torch.nn.CrossEntropyLoss().to(device)

        # Optimizers
        optimizer_G = torch.optim.Adam(G.parameters(), lr=args.G_lr)
        optimizer_D = torch.optim.Adam(D.parameters(), lr=args.D_lr)
        optimizer_E = AdamW(E.parameters(), args.bert_lr)
        optimizer_detector = torch.optim.Adam(detector.parameters(),
                                              lr=args.detector_lr)

        G_total_train_loss = []
        D_total_fake_loss = []
        D_total_real_loss = []
        FM_total_train_loss = []
        D_total_class_loss = []
        valid_detection_loss = []
        valid_oos_ind_precision = []
        valid_oos_ind_recall = []
        valid_oos_ind_f_score = []
        detector_total_train_loss = []

        all_features = []
        result = dict()

        for epoch in range(args.n_epoch):
            # Initialize model state
            G.train()
            D.train()
            E.train()
            detector.train()

            G_train_loss = 0
            D_fake_loss = 0
            D_real_loss = 0
            FM_train_loss = 0
            D_class_loss = 0  # never updated; logged for completeness
            detector_train_loss = 0

            for sample in tqdm.tqdm(train_dataloader):
                token, mask, type_ids, y = (t.to(device) for t in sample)
                batch = len(token)

                # Targets used to train generator and discriminator.
                valid_label = FloatTensor(batch, 1).fill_(1.0).detach()
                fake_label = FloatTensor(batch, 1).fill_(0.0).detach()

                optimizer_E.zero_grad()
                _, pooled_output = E(token, mask, type_ids)
                real_feature = pooled_output

                # train D on real
                optimizer_D.zero_grad()
                real_f_vector, discriminator_output, _ = D(
                    real_feature, return_feature=True)
                real_loss = adversarial_loss(discriminator_output, valid_label)
                # The encoder graph is reused by the detector loss below.
                real_loss.backward(retain_graph=True)

                if args.do_vis:
                    all_features.append(real_f_vector.detach())

                # train D on fake
                fake_feature = G(sample_noise(batch)).detach()
                fake_discriminator_output = D.detect_only(fake_feature)
                fake_loss = adversarial_loss(fake_discriminator_output,
                                             fake_label)
                fake_loss.backward()
                optimizer_D.step()

                # train G
                optimizer_G.zero_grad()
                fake_f_vector, D_decision = D.detect_only(
                    G(sample_noise(batch)), return_feature=True)
                gd_loss = adversarial_loss(D_decision, valid_label)
                fm_loss = torch.abs(
                    torch.mean(real_f_vector.detach(), 0) -
                    torch.mean(fake_f_vector, 0)).mean()
                # Feature matching is computed for logging but weighted by 0.
                g_loss = gd_loss + 0 * fm_loss
                g_loss.backward()
                optimizer_G.step()

                # Drop encoder gradients produced by the discriminator step.
                optimizer_E.zero_grad()

                # train detector
                optimizer_detector.zero_grad()
                fake_feature = G(sample_noise(batch)).detach()
                if args.loss == 'v1':
                    # fake sample is ood
                    loss_fake = adversarial_loss(detector(fake_feature),
                                                 fake_label)
                    loss_real = adversarial_loss(detector(real_feature),
                                                 y.float())
                else:
                    loss_fake = adversarial_loss_v2(
                        detector(fake_feature), fake_label.long().squeeze())
                    loss_real = adversarial_loss_v2(detector(real_feature),
                                                    y.long())

                if args.detect_loss == 'v1':
                    detector_loss = args.beta * loss_fake + (
                        1 - args.beta) * loss_real
                else:
                    detector_loss = args.beta * loss_fake + loss_real
                # NOTE(review): the source layout does not show whether the
                # sigma scaling belongs to the else-branch only; confirm.
                detector_loss = args.sigma * detector_loss
                detector_loss.backward()
                optimizer_detector.step()

                if args.fine_tune:
                    optimizer_E.step()

                global_step += 1

                D_fake_loss += fake_loss.detach()
                D_real_loss += real_loss.detach()
                G_train_loss += g_loss.detach() + fm_loss.detach()
                FM_train_loss += fm_loss.detach()
                # Detach so the running sum does not keep every batch's graph.
                detector_train_loss += detector_loss.detach()

            epoch_means = [
                ('D_fake_loss', D_fake_loss / n_sample),
                ('D_real_loss', D_real_loss / n_sample),
                ('D_class_loss', D_class_loss / n_sample),
                ('G_train_loss', G_train_loss / n_sample),
                ('FM_train_loss', FM_train_loss / n_sample),
                ('detector_train_loss', detector_train_loss / n_sample),
            ]
            for name, value in epoch_means:
                logger.info('[Epoch {}] Train: {}: {}'.format(
                    epoch, name, value))
            logger.info(
                '---------------------------------------------------------------------------'
            )

            D_total_fake_loss.append(D_fake_loss / n_sample)
            D_total_real_loss.append(D_real_loss / n_sample)
            D_total_class_loss.append(D_class_loss / n_sample)
            G_total_train_loss.append(G_train_loss / n_sample)
            FM_total_train_loss.append(FM_train_loss / n_sample)
            detector_total_train_loss.append(detector_train_loss / n_sample)

            if dev_dataset:
                logger.info(
                    '#################### eval result at step {} ####################'
                    .format(global_step))
                eval_result = eval(dev_dataset)

                valid_detection_loss.append(eval_result['detection_loss'])
                valid_oos_ind_precision.append(
                    eval_result['oos_ind_precision'])
                valid_oos_ind_recall.append(eval_result['oos_ind_recall'])
                valid_oos_ind_f_score.append(eval_result['oos_ind_f_score'])

                # signal:  1 -> save the model
                #          0 -> do not save the model
                #         -1 -> do not save; patience exceeded, stop early
                signal = early_stopping(-eval_result['eer'])
                if signal == -1:
                    break
                elif signal == 1:
                    save_gan_model(D, G, config['gan_save_path'])
                    if args.fine_tune:
                        save_model(E,
                                   path=config['bert_save_path'],
                                   model_name='bert')

                logger.info(eval_result)
                logger.info('valid_eer: {}'.format(eval_result['eer']))
                logger.info('valid_oos_ind_precision: {}'.format(
                    eval_result['oos_ind_precision']))
                logger.info('valid_oos_ind_recall: {}'.format(
                    eval_result['oos_ind_recall']))
                logger.info('valid_oos_ind_f_score: {}'.format(
                    eval_result['oos_ind_f_score']))
                logger.info('valid_auc: {}'.format(eval_result['auc']))
                logger.info('valid_fpr95: {}'.format(
                    ErrorRateAt95Recall(eval_result['all_binary_y'],
                                        eval_result['y_score'])))

        # Early stopping can never trigger in this case, so save the final state.
        if args.patience >= args.n_epoch:
            save_gan_model(D, G, config['gan_save_path'])
            if args.fine_tune:
                save_model(E,
                           path=config['bert_save_path'],
                           model_name='bert')

        freeze_data['D_total_fake_loss'] = D_total_fake_loss
        freeze_data['D_total_real_loss'] = D_total_real_loss
        freeze_data['D_total_class_loss'] = D_total_class_loss
        freeze_data['G_total_train_loss'] = G_total_train_loss
        freeze_data['FM_total_train_loss'] = FM_total_train_loss
        freeze_data['valid_real_loss'] = valid_detection_loss
        freeze_data['valid_oos_ind_precision'] = valid_oos_ind_precision
        freeze_data['valid_oos_ind_recall'] = valid_oos_ind_recall
        freeze_data['valid_oos_ind_f_score'] = valid_oos_ind_f_score

        # The stopper was fed -eer, so negate to recover the best EER.
        best_dev = -early_stopping.best_score

        if args.do_vis:
            all_features = torch.cat(all_features, 0).cpu().numpy()
            result['all_features'] = all_features
        return result

    def run_detector(dataset, is_test=False):
        """Score ``dataset`` with encoder + detector and compute the metrics.

        With ``is_test`` the result additionally carries ``'all_y'``,
        ``'score'`` and, under ``args.do_vis``, ``'all_features'``.

        Returns:
            ``(result, all_y)``: the metrics dict and the label tensor.
        """
        dataloader = DataLoader(dataset,
                                batch_size=args.predict_batch_size,
                                shuffle=False,
                                num_workers=2)
        collect_features = is_test and args.do_vis

        # Loss functions
        detection_loss = torch.nn.BCELoss().to(device)
        detection_loss_v2 = torch.nn.CrossEntropyLoss().to(device)
        classified_loss = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)

        E.eval()
        detector.eval()

        all_detection_preds = []
        all_class_preds = []
        all_features = []
        all_logit = []

        for sample in tqdm.tqdm(dataloader):
            token, mask, type_ids, y = (t.to(device) for t in sample)

            # No gradients are needed anywhere during evaluation.
            with torch.no_grad():
                # BERT encodes the sentence into a feature vector.
                _, pooled_output = E(token, mask, type_ids)
                real_feature = pooled_output

                if collect_features:
                    f_vector, _ = D.detect_only(real_feature,
                                                return_feature=True)
                    all_features.append(f_vector)

                if n_class > 2:
                    # NOTE(review): the multi-class path (discriminator +
                    # classifier) is disabled in the source, so nothing is
                    # collected here and the concatenation below will fail.
                    continue

                # Binary case: only the detector is evaluated.
                detector_out = detector(real_feature)
                if args.loss == 'v1':
                    all_detection_preds.append(detector_out)
                else:
                    all_logit.append(detector_out)
                    all_detection_preds.append(torch.argmax(detector_out, 1))

        all_y = LongTensor(dataset.dataset[:, -1].astype(int)).cpu()
        all_binary_y = (all_y != 0).long()  # label 0 is oos
        all_detection_preds = torch.cat(all_detection_preds, 0).cpu()

        if args.loss == 'v1':
            all_detection_binary_preds = convert_to_int_by_threshold(
                all_detection_preds.squeeze())
            loss = detection_loss(all_detection_preds, all_binary_y.float())
        else:
            all_detection_binary_preds = all_detection_preds
            all_logit = torch.cat(all_logit, 0).cpu()
            loss = detection_loss_v2(all_logit, all_y.long())

        if n_class > 2:
            class_one_hot_preds = torch.cat(all_class_preds, 0).detach().cpu()
            class_loss = classified_loss(class_one_hot_preds, all_y)
            all_class_preds = torch.argmax(class_one_hot_preds, 1)
            # accuracy for ind class
            class_acc = metrics.ind_class_accuracy(all_class_preds,
                                                   all_y,
                                                   oos_index=0)
            logger.info(
                metrics.classification_report(
                    all_y, all_class_preds,
                    target_names=processor.id_to_label))

        logger.info(
            metrics.classification_report(all_binary_y,
                                          all_detection_binary_preds,
                                          target_names=['oos', 'in']))

        oos_ind_precision, oos_ind_recall, oos_ind_fscore, _ = \
            metrics.binary_recall_fscore(all_detection_binary_preds,
                                         all_binary_y)
        detection_acc = metrics.accuracy(all_detection_binary_preds,
                                         all_binary_y)

        y_score = all_detection_preds.squeeze().tolist()
        eer = metrics.cal_eer(all_binary_y, y_score)

        result = dict()
        result['detection_loss'] = loss
        result['eer'] = eer
        result['all_detection_binary_preds'] = all_detection_binary_preds
        result['detection_acc'] = detection_acc
        result['all_binary_y'] = all_binary_y
        if is_test:
            result['all_y'] = all_y
        result['oos_ind_precision'] = oos_ind_precision
        result['oos_ind_recall'] = oos_ind_recall
        result['oos_ind_f_score'] = oos_ind_fscore
        if is_test:
            result['score'] = y_score
        result['y_score'] = y_score
        result['auc'] = roc_auc_score(all_binary_y, y_score)
        if n_class > 2:
            result['class_loss'] = class_loss
            result['class_acc'] = class_acc
        if collect_features:
            result['all_features'] = torch.cat(all_features, 0).cpu().numpy()
        return result, all_y

    def eval(dataset):
        """Validate on ``dataset`` and record labels/scores in ``freeze_data``."""
        result, all_y = run_detector(dataset)
        freeze_data['valid_all_y'] = all_y
        freeze_data['vaild_all_pred'] = result['all_detection_binary_preds']
        freeze_data['valid_score'] = result['y_score']
        return result

    def test(dataset):
        """Reload the saved models, score ``dataset`` and record the outputs."""
        # load BERT and GAN
        load_gan_model(D, G, config['gan_save_path'])
        if args.fine_tune:
            load_model(E, path=config['bert_save_path'], model_name='bert')

        G.eval()
        D.eval()

        result, all_y = run_detector(dataset, is_test=True)
        freeze_data['test_all_y'] = all_y.tolist()
        freeze_data['test_all_pred'] = \
            result['all_detection_binary_preds'].tolist()
        freeze_data['test_score'] = result['y_score']
        return result

    def get_fake_feature(num_output):
        """Generate ``num_output`` fake feature vectors as a numpy array."""
        G.eval()
        fake_features = []
        start = 0
        batch = args.predict_batch_size
        with torch.no_grad():
            while start < num_output:
                end = min(num_output, start + batch)
                fake_feature = G(sample_noise(end - start))
                f_vector, _ = D.detect_only(fake_feature, return_feature=True)
                fake_features.append(f_vector)
                start += batch
        return torch.cat(fake_features, 0).cpu().numpy()

    if args.do_train:
        text_train_set = read_split('train')
        text_dev_set = read_split('val')

        train_features = processor.convert_to_ids(text_train_set)
        train_dataset = OOSDataset(train_features)
        dev_features = processor.convert_to_ids(text_dev_set)
        dev_dataset = OOSDataset(dev_features)

        train(train_dataset, dev_dataset)

    if args.do_eval:
        logger.info(
            '#################### eval result at step {} ####################'.
            format(global_step))
        text_dev_set = read_split('val')

        dev_features = processor.convert_to_ids(text_dev_set)
        dev_dataset = OOSDataset(dev_features)
        eval_result = eval(dev_dataset)
        logger.info(eval_result)
        logger.info('eval_eer: {}'.format(eval_result['eer']))
        logger.info('eval_oos_ind_precision: {}'.format(
            eval_result['oos_ind_precision']))
        logger.info('eval_oos_ind_recall: {}'.format(
            eval_result['oos_ind_recall']))
        logger.info('eval_oos_ind_f_score: {}'.format(
            eval_result['oos_ind_f_score']))
        logger.info('eval_auc: {}'.format(eval_result['auc']))
        logger.info('eval_fpr95: {}'.format(
            ErrorRateAt95Recall(eval_result['all_binary_y'],
                                eval_result['y_score'])))

    if args.do_test:
        logger.info(
            '#################### test result at step {} ####################'.
            format(global_step))
        text_test_set = read_split('test')

        test_features = processor.convert_to_ids(text_test_set)
        test_dataset = OOSDataset(test_features)
        test_result = test(test_dataset)
        logger.info(test_result)
        logger.info('test_eer: {}'.format(test_result['eer']))
        logger.info('test_ood_ind_precision: {}'.format(
            test_result['oos_ind_precision']))
        logger.info('test_ood_ind_recall: {}'.format(
            test_result['oos_ind_recall']))
        logger.info('test_ood_ind_f_score: {}'.format(
            test_result['oos_ind_f_score']))
        logger.info('test_auc: {}'.format(test_result['auc']))
        logger.info('test_fpr95: {}'.format(
            ErrorRateAt95Recall(test_result['all_binary_y'],
                                test_result['y_score'])))
        my_plot_roc(test_result['all_binary_y'], test_result['y_score'],
                    os.path.join(args.output_dir, 'roc_curve.png'))
        save_result(test_result, os.path.join(args.output_dir, 'test_result'))

        # Write out the misclassified cases.
        if config['dataset'] == 'oos-eval':
            texts = [line[0] for line in text_test_set]
        elif config['dataset'] == 'smp':
            texts = [line['text'] for line in text_test_set]
        else:
            raise ValueError('The dataset {} is not supported.'.format(
                args.dataset))

        output_cases(texts, test_result['all_binary_y'],
                     test_result['all_detection_binary_preds'],
                     os.path.join(args.output_dir, 'test_cases.csv'),
                     processor)

        # confusion matrix
        plot_confusion_matrix(test_result['all_binary_y'],
                              test_result['all_detection_binary_preds'],
                              args.output_dir)

        if args.do_vis:
            n_fake = len(test_dataset) // 2
            # Real test features followed by generated ones.
            features = np.concatenate(
                [test_result['all_features'],
                 get_fake_feature(n_fake)],
                axis=0)
            features = TSNE(n_components=2, verbose=1,
                            n_jobs=-1).fit_transform(features)

            # Generated samples are labelled -1.
            real_labels = (test_result['all_y']
                           if n_class > 2 else test_result['all_binary_y'])
            labels = np.concatenate(
                [real_labels, np.array([-1] * n_fake)], 0).reshape((-1, 1))

            # Columns: the two t-SNE coordinates, then the label.
            data = np.concatenate([features, labels], 1)
            fig = scatter_plot(data, processor)
            fig.savefig(os.path.join(args.output_dir, 'plot.png'))
            fig.show()

            freeze_data['feature_label'] = data

    with open(os.path.join(config['output_dir'], 'freeze_data.pkl'),
              'wb') as f:
        pickle.dump(freeze_data, f)

    # Each table exists only if the matching phase ran in this invocation.
    if 'valid_all_y' in freeze_data:
        df = pd.DataFrame(
            data={
                'valid_y': freeze_data['valid_all_y'],
                'valid_score': freeze_data['valid_score'],
            })
        df.to_csv(os.path.join(config['output_dir'], 'valid_score.csv'))

    if 'test_all_y' in freeze_data:
        df = pd.DataFrame(
            data={
                'test_y': freeze_data['test_all_y'],
                'test_score': freeze_data['test_score']
            })
        df.to_csv(os.path.join(config['output_dir'], 'test_score.csv'))
def main(args):
    """Run the TextCNN-over-BERT classification pipeline.

    Builds a ``TextCNN`` model and, depending on ``args.do_train`` /
    ``args.do_eval`` / ``args.do_test``, trains, validates and tests it.
    Collected curves and scores go to ``freeze_data.pkl``,
    ``valid_score.csv`` and ``test_score.csv`` under ``args.output_dir``;
    headline metrics are gathered in ``gross_result`` and, unless
    ``args.result == 'no'``, written to ``<args.result>_gross_result.csv``.

    Args:
        args: Parsed command-line namespace.  It is also exposed as the
            ``config`` dict (``vars(args)``), so keys added to ``config``
            become attributes of ``args``.

    Raises:
        ValueError: If ``args.dataset`` is neither ``'oos-eval'`` nor ``'smp'``.
    """
    logger.info('Checking...')
    SEED = args.seed
    check_manual_seed(SEED)
    check_args(args)
    logger.info('seed: {}'.format(args.seed))
    gross_result['seed'] = args.seed

    logger.info('Loading config...')
    bert_config = BertConfig('config/bert.ini')
    bert_config = bert_config(args.bert_type)

    # for oos-eval dataset
    data_config = Config('config/data.ini')
    data_config = data_config(args.dataset)

    # Prepare data processor: join the data directory and the file name.
    data_path = os.path.join(data_config['DataDir'],
                             data_config[args.data_file])
    label_path = data_path.replace('.json', '.label')

    if args.dataset == 'oos-eval':
        processor = OOSProcessor(bert_config, maxlen=32)
    elif args.dataset == 'smp':
        processor = SMPProcessor(bert_config, maxlen=32)
    else:
        raise ValueError('The dataset {} is not supported.'.format(
            args.dataset))

    # Adds label_to_id and id_to_label to the processor.
    processor.load_label(label_path)

    n_class = len(processor.id_to_label)
    config = vars(args)  # argument namespace viewed as a dict
    config['model_save_path'] = os.path.join(args.output_dir, 'save',
                                             'bert.pt')
    config['n_class'] = n_class

    logger.info('config:')
    logger.info(config)

    model = TextCNN(bert_config, n_class)  # Bert encoder
    if args.fine_tune:
        model.unfreeze_bert_encoder()
    else:
        model.freeze_bert_encoder()
    model.to(device)

    global_step = 0

    def read_split(split):
        """Read ``split`` from the data file, plus ``oos_<split>`` when needed."""
        parts = [split]
        if (not config['data_file'].startswith('binary')
                and config['dataset'] == 'oos-eval'):
            parts.append('oos_' + split)
        return processor.read_dataset(data_path, parts)

    def train(train_dataset, dev_dataset):
        """Train ``model``; validate after every epoch when a dev set is given."""
        train_dataloader = DataLoader(
            train_dataset,
            batch_size=args.train_batch_size //
            args.gradient_accumulation_steps,
            shuffle=True,
            num_workers=2)

        nonlocal global_step

        n_sample = len(train_dataloader)
        early_stopping = EarlyStopping(args.patience, logger=logger)
        # Loss function
        classified_loss = torch.nn.CrossEntropyLoss().to(device)

        # Optimizers
        optimizer = AdamW(model.parameters(), args.lr)

        train_loss = []
        # Always defined: they are stored in freeze_data after the loop even
        # when no dev set is supplied.
        valid_loss = []
        valid_ind_class_acc = []
        iteration = 0
        for epoch in range(args.n_epoch):
            model.train()
            total_loss = 0
            for sample in tqdm.tqdm(train_dataloader):
                token, mask, type_ids, y = (t.to(device) for t in sample)

                logits = model(token, mask, type_ids)
                loss = classified_loss(logits, y.long())
                total_loss += loss.item()
                loss = loss / args.gradient_accumulation_steps
                loss.backward()
                # Update the parameters once enough batches are accumulated.
                if (global_step + 1) % args.gradient_accumulation_steps == 0:
                    optimizer.step()
                    optimizer.zero_grad()
                # Counts batches; it must advance every iteration for the
                # accumulation test above to ever fire.
                global_step += 1

            logger.info('[Epoch {}] Train: train_loss: {}'.format(
                epoch, total_loss / n_sample))
            logger.info('-' * 30)

            train_loss.append(total_loss / n_sample)
            iteration += 1

            if dev_dataset:
                logger.info(
                    '#################### eval result at step {} ####################'
                    .format(global_step))
                eval_result = eval(dev_dataset)

                valid_loss.append(eval_result['loss'])
                valid_ind_class_acc.append(eval_result['ind_class_acc'])

                # signal:  1 -> save the model
                #          0 -> do not save the model
                #         -1 -> do not save; patience exceeded, stop early
                signal = early_stopping(eval_result['accuracy'])
                if signal == -1:
                    break
                elif signal == 1:
                    save_model(model,
                               path=config['model_save_path'],
                               model_name='bert')

        from utils.visualization import draw_curve
        draw_curve(train_loss, iteration, 'train_loss', args.output_dir)
        if dev_dataset:
            draw_curve(valid_loss, iteration, 'valid_loss', args.output_dir)
            draw_curve(valid_ind_class_acc, iteration,
                       'valid_ind_class_accuracy', args.output_dir)

        # Early stopping can never trigger in this case, so save the final state.
        if args.patience >= args.n_epoch:
            save_model(model,
                       path=config['model_save_path'],
                       model_name='bert')

        freeze_data['train_loss'] = train_loss
        freeze_data['valid_loss'] = valid_loss

    def predict(dataset):
        """Run ``model`` over ``dataset`` without gradients.

        Returns:
            ``(all_y, all_binary_y, all_pred, all_logit, mean_loss)`` where
            ``mean_loss`` is the per-batch average cross-entropy as a float.
        """
        dataloader = DataLoader(dataset,
                                batch_size=args.predict_batch_size,
                                shuffle=False,
                                num_workers=2)
        n_sample = len(dataloader)

        model.eval()

        # Loss function
        classified_loss = torch.nn.CrossEntropyLoss().to(device)

        all_pred = []
        all_logit = []
        total_loss = 0
        for sample in tqdm.tqdm(dataloader):
            token, mask, type_ids, y = (t.to(device) for t in sample)
            with torch.no_grad():
                logit = model(token, mask, type_ids)
                all_logit.append(logit)
                all_pred.append(torch.argmax(logit, 1))
                # .item() keeps the sum a plain float, as in training.
                total_loss += classified_loss(logit, y.long()).item()

        all_y = LongTensor(dataset.dataset[:, -1].astype(int)).cpu()
        all_binary_y = (all_y != 0).long()  # label 0 is oos
        all_pred = torch.cat(all_pred, 0).cpu()
        all_logit = torch.cat(all_logit, 0).cpu()
        return all_y, all_binary_y, all_pred, all_logit, total_loss / n_sample

    def eval(dataset):
        """Validate on ``dataset`` and record labels/scores in ``freeze_data``."""
        all_y, all_binary_y, all_pred, all_logit, mean_loss = predict(dataset)
        result = dict()

        ind_class_acc = metrics.ind_class_accuracy(all_pred, all_y)
        report = metrics.classification_report(all_y,
                                               all_pred,
                                               output_dict=True)
        result.update(report)
        # Softmax probability of class index 1 serves as the detection score.
        y_score = all_logit.softmax(1)[:, 1].tolist()
        eer = metrics.cal_eer(all_binary_y, y_score)
        oos_ind_precision, oos_ind_recall, oos_ind_fscore, _ = \
            metrics.binary_recall_fscore(all_pred, all_binary_y)

        result['eer'] = eer
        result['ind_class_acc'] = ind_class_acc
        result['loss'] = mean_loss
        result['oos_ind_precision'] = oos_ind_precision
        result['oos_ind_recall'] = oos_ind_recall
        result['oos_ind_f_score'] = oos_ind_fscore
        result['auc'] = roc_auc_score(all_binary_y, y_score)
        result['y_score'] = y_score
        result['all_binary_y'] = all_binary_y

        freeze_data['valid_all_y'] = all_y
        freeze_data['vaild_all_pred'] = all_pred
        freeze_data['valid_score'] = y_score

        return result

    def test(dataset):
        """Reload the saved model, score ``dataset`` and record the outputs."""
        load_model(model, path=config['model_save_path'], model_name='bert')

        all_y, all_binary_y, all_pred, all_logit, mean_loss = predict(dataset)
        result = dict()

        # classification report
        ind_class_acc = metrics.ind_class_accuracy(all_pred, all_y)
        report = metrics.classification_report(all_y,
                                               all_pred,
                                               output_dict=True)
        oos_ind_precision, oos_ind_recall, oos_ind_fscore, _ = \
            metrics.binary_recall_fscore(all_pred, all_binary_y)
        result.update(report)

        # EER is only meaningful for binary classification.
        y_score = all_logit.softmax(1)[:, 1].tolist()
        eer = metrics.cal_eer(all_binary_y, y_score)

        result['eer'] = eer
        result['ind_class_acc'] = ind_class_acc
        result['loss'] = mean_loss
        result['all_y'] = all_y.tolist()
        result['all_pred'] = all_pred.tolist()
        result['all_binary_y'] = all_binary_y

        freeze_data['test_all_y'] = all_y.tolist()
        freeze_data['test_all_pred'] = all_pred.tolist()
        freeze_data['test_score'] = y_score

        result['oos_ind_precision'] = oos_ind_precision
        result['oos_ind_recall'] = oos_ind_recall
        result['oos_ind_f_score'] = oos_ind_fscore
        result['auc'] = roc_auc_score(all_binary_y, y_score)
        result['y_score'] = y_score

        return result

    if args.do_train:
        text_train_set = read_split('train')
        text_dev_set = read_split('val')

        train_features = processor.convert_to_ids(text_train_set)
        train_dataset = OOSDataset(train_features)
        dev_features = processor.convert_to_ids(text_dev_set)
        dev_dataset = OOSDataset(dev_features)

        train(train_dataset, dev_dataset)

    if args.do_eval:
        logger.info(
            '#################### eval result at step {} ####################'.
            format(global_step))
        text_dev_set = read_split('val')

        dev_features = processor.convert_to_ids(text_dev_set)
        dev_dataset = OOSDataset(dev_features)
        eval_result = eval(dev_dataset)
        eval_fpr95 = ErrorRateAt95Recall(eval_result['all_binary_y'],
                                         eval_result['y_score'])

        logger.info('eval_eer: {}'.format(eval_result['eer']))
        logger.info('eval_oos_ind_precision: {}'.format(
            eval_result['oos_ind_precision']))
        logger.info('eval_oos_ind_recall: {}'.format(
            eval_result['oos_ind_recall']))
        logger.info('eval_oos_ind_f_score: {}'.format(
            eval_result['oos_ind_f_score']))
        logger.info('eval_auc: {}'.format(eval_result['auc']))
        logger.info('eval_fpr95: {}'.format(eval_fpr95))

        gross_result['eval_eer'] = eval_result['eer']
        gross_result['eval_auc'] = eval_result['auc']
        gross_result['eval_fpr95'] = eval_fpr95
        gross_result['eval_oos_ind_precision'] = eval_result[
            'oos_ind_precision']
        gross_result['eval_oos_ind_recall'] = eval_result['oos_ind_recall']
        gross_result['eval_oos_ind_f_score'] = eval_result['oos_ind_f_score']

    if args.do_test:
        logger.info(
            '#################### test result at step {} ####################'.
            format(global_step))
        text_test_set = read_split('test')

        test_features = processor.convert_to_ids(text_test_set)
        test_dataset = OOSDataset(test_features)
        test_result = test(test_dataset)
        # NOTE(review): save_result is called again below with the same
        # arguments; one of the two calls looks redundant.
        save_result(test_result, os.path.join(args.output_dir, 'test_result'))
        test_fpr95 = ErrorRateAt95Recall(test_result['all_binary_y'],
                                         test_result['y_score'])

        logger.info('test_eer: {}'.format(test_result['eer']))
        logger.info('test_ood_ind_precision: {}'.format(
            test_result['oos_ind_precision']))
        logger.info('test_ood_ind_recall: {}'.format(
            test_result['oos_ind_recall']))
        logger.info('test_ood_ind_f_score: {}'.format(
            test_result['oos_ind_f_score']))
        logger.info('test_auc: {}'.format(test_result['auc']))
        logger.info('test_fpr95: {}'.format(test_fpr95))
        my_plot_roc(test_result['all_binary_y'], test_result['y_score'],
                    os.path.join(args.output_dir, 'roc_curve.png'))
        save_result(test_result, os.path.join(args.output_dir, 'test_result'))

        gross_result['test_eer'] = test_result['eer']
        gross_result['test_auc'] = test_result['auc']
        gross_result['test_fpr95'] = test_fpr95
        gross_result['test_oos_ind_precision'] = test_result[
            'oos_ind_precision']
        gross_result['test_oos_ind_recall'] = test_result['oos_ind_recall']
        gross_result['test_oos_ind_f_score'] = test_result['oos_ind_f_score']

        # Write out the misclassified cases.
        if config['dataset'] == 'oos-eval':
            texts = [line[0] for line in text_test_set]
        elif config['dataset'] == 'smp':
            texts = [line['text'] for line in text_test_set]
        else:
            raise ValueError('The dataset {} is not supported.'.format(
                args.dataset))

        output_cases(texts, test_result['all_y'], test_result['all_pred'],
                     os.path.join(args.output_dir, 'test_cases.csv'),
                     processor)

        # confusion matrix
        plot_confusion_matrix(test_result['all_y'], test_result['all_pred'],
                              args.output_dir)

    with open(os.path.join(config['output_dir'], 'freeze_data.pkl'),
              'wb') as f:
        pickle.dump(freeze_data, f)

    # Each table exists only if the matching phase ran in this invocation.
    if 'valid_all_y' in freeze_data:
        df = pd.DataFrame(
            data={
                'valid_y': freeze_data['valid_all_y'],
                'valid_score': freeze_data['valid_score'],
            })
        df.to_csv(os.path.join(config['output_dir'], 'valid_score.csv'))

    if 'test_all_y' in freeze_data:
        df = pd.DataFrame(
            data={
                'test_y': freeze_data['test_all_y'],
                'test_score': freeze_data['test_score']
            })
        df.to_csv(os.path.join(config['output_dir'], 'test_score.csv'))

    if args.result != 'no':
        pd_result = pd.DataFrame(gross_result)
        gross_path = args.result + '_gross_result.csv'
        if args.seed == 16:
            # Seed 16 starts a fresh file with a header row.
            pd_result.to_csv(gross_path, index=False)
        else:
            # Every other seed appends its row to the existing file.
            pd_result.to_csv(gross_path,
                             index=False,
                             mode='a',
                             header=False)
        if args.seed == 8192:
            # On seed 8192 the accumulated file is summarised by std_mean.
            print(args.result)
            std_mean(gross_path)
def main(args):
    """Train and/or evaluate the POS-augmented binary out-of-scope detector.

    A BERT encoder ``E`` produces a pooled sentence feature; ``Pos_emb``
    combines it with two part-of-speech id sequences and emits one score per
    example, trained with binary cross-entropy against the batch labels.

    Which phases run is controlled by ``args.do_train``, ``args.do_eval`` and
    ``args.do_test``.  The test phase additionally writes a ROC plot, the
    serialized result, a per-case CSV, a confusion matrix, and appends one
    line to ``beta_log.txt``.

    Args:
        args: parsed command-line namespace.  ``vars(args)`` is used as the
            ``config`` dict, so every ``config[...] = ...`` assignment below
            also sets the attribute on ``args``.

    Raises:
        ValueError: if ``args.dataset`` is neither ``'oos-eval'`` nor ``'smp'``.
    """
    check_manual_seed(args.seed)
    logger.info('seed: {}'.format(args.seed))

    logger.info('Loading config...')
    bert_config = Config('config/bert.ini')
    bert_config = bert_config(args.bert_type)

    # for oos-eval dataset
    data_config = Config('config/data.ini')
    data_config = data_config(args.dataset)

    # Prepare data processor.
    # Join the data directory and the file name into a single path.
    data_path = os.path.join(data_config['DataDir'],
                             data_config[args.data_file])
    label_path = data_path.replace('.json', '.label')

    # Log the size of every split and the raw label file for traceability.
    with open(data_path, 'r', encoding='utf-8') as fp:
        data = json.load(fp)
    for split_name in data:
        logger.info('{} : {}'.format(split_name, len(data[split_name])))

    with open(label_path, 'r', encoding='utf-8') as fp:
        logger.info(json.load(fp))

    if args.dataset == 'oos-eval':
        processor = OOSProcessor(bert_config, maxlen=32)
        logger.info('OOSProcessor')
    elif args.dataset == 'smp':
        processor = PosSMPProcessor(bert_config, maxlen=32)
        logger.info('SMPProcessor')
    else:
        raise ValueError('The dataset {} is not supported.'.format(
            args.dataset))

    # Adds label_to_id and id_to_label to the processor.
    processor.load_label(label_path)
    processor.load_pos('data/pos.json')
    logger.info("label_to_id: {}".format(processor.label_to_id))
    logger.info("id_to_label: {}".format(processor.id_to_label))

    n_class = len(processor.id_to_label)
    config = vars(args)  # the argument namespace viewed as a dict
    config['gan_save_path'] = os.path.join(args.output_dir, 'save', 'gan.pt')
    config['bert_save_path'] = os.path.join(args.output_dir, 'save', 'bert.pt')
    config['n_class'] = n_class

    logger.info('config:')
    logger.info(config)

    from model.pos_emb_v2 import Pos_emb

    E = BertModel.from_pretrained(bert_config['PreTrainModelDir'])  # BERT encoder

    # Hyper-parameters read by Pos_emb.
    config['pos_dim'] = args.pos_dim
    config['batch_size'] = args.train_batch_size
    config['n_pos'] = len(processor.pos)
    config['device'] = device
    config['nhead'] = 2
    config['num_layers'] = 1
    config['maxlen'] = processor.maxlen
    print('config', config)
    print(processor.pos)

    pos = Pos_emb(config)

    # Freeze or unfreeze the whole encoder depending on --fine_tune.
    train_encoder = bool(args.fine_tune)
    for param in E.parameters():
        param.requires_grad = train_encoder

    pos.to(device)
    E.to(device)

    global_step = 0

    def _split_names(split):
        """Return the JSON split names to read for *split* ('train'/'val'/'test').

        Only the non-binary oos-eval data keeps its out-of-scope examples in a
        separate ``oos_<split>`` section; every other setup reads one split.
        """
        is_binary = config['data_file'].startswith('binary')
        if not is_binary and config['dataset'] == 'oos-eval':
            return [split, 'oos_' + split]
        return [split]

    def _load_dataset(split):
        """Read *split*, convert it to ids and wrap it; return (raw_rows, dataset)."""
        rows = processor.read_dataset(data_path, _split_names(split))
        return rows, PosOOSDataset(processor.convert_to_ids(rows))

    def _log_metrics(prefix, ood_prefix, result):
        """Log the headline metrics of *result* under the given key prefixes."""
        logger.info('{}_eer: {}'.format(prefix, result['eer']))
        logger.info('{}_ind_precision: {}'.format(
            ood_prefix, result['oos_ind_precision']))
        logger.info('{}_ind_recall: {}'.format(
            ood_prefix, result['oos_ind_recall']))
        logger.info('{}_ind_f_score: {}'.format(
            ood_prefix, result['oos_ind_f_score']))
        logger.info('{}_auc: {}'.format(prefix, result['auc']))
        logger.info('{}_fpr95: {}'.format(
            prefix,
            ErrorRateAt95Recall(result['all_binary_y'], result['y_score'])))

    def train(train_dataset, dev_dataset):
        """Optimise ``pos`` (and ``E`` when fine-tuning) with early stopping.

        After each epoch the dev set, if given, is evaluated and its EER drives
        early stopping.  The best dev EER is published through the module-level
        ``best_dev``.  No checkpoint is written by this variant.
        """
        global best_dev
        nonlocal global_step

        train_dataloader = DataLoader(train_dataset,
                                      batch_size=args.train_batch_size,
                                      shuffle=True,
                                      num_workers=2)
        n_sample = len(train_dataloader)
        early_stopping = EarlyStopping(args.patience, logger=logger)

        # Loss function
        adversarial_loss = torch.nn.BCELoss().to(device)

        # Optimizers
        optimizer_pos = torch.optim.Adam(pos.parameters(), lr=args.pos_lr)
        optimizer_E = AdamW(E.parameters(), args.bert_lr)

        train_loss = []

        for epoch in range(args.n_epoch):
            logger.info('***********************************')
            logger.info('epoch: {}'.format(epoch))

            # Initialize model state
            pos.train()
            E.train()

            total_loss = 0
            for sample in tqdm(train_dataloader):
                token, mask, type_ids, pos1, pos2, _pos_mask, y = (
                    t.to(device) for t in sample)

                optimizer_E.zero_grad()
                optimizer_pos.zero_grad()

                _sequence_output, pooled_output = E(token, mask, type_ids)
                out = pos(pos1, pos2, pooled_output)
                loss = adversarial_loss(out, y.float())
                loss.backward()
                # Accumulate a plain float so the curve can be plotted and no
                # tensor is kept alive on the training device.
                total_loss += loss.item()

                if args.fine_tune:
                    optimizer_E.step()
                optimizer_pos.step()
                # Count optimisation steps so the "at step N" logs are meaningful.
                global_step += 1

            epoch_loss = total_loss / n_sample
            logger.info('[Epoch {}] Train: loss: {}'.format(epoch, epoch_loss))
            logger.info(
                '---------------------------------------------------------------------------'
            )
            train_loss.append(epoch_loss)

            if dev_dataset:
                logger.info(
                    '#################### eval result at step {} ####################'
                    .format(global_step))
                eval_result = evaluate(dev_dataset)

                # Signal meaning:  1 = save the model,  0 = do not save,
                # -1 = do not save and patience is exhausted -> stop early.
                # The EER is negated before being handed to EarlyStopping
                # (presumably it treats higher as better -- confirm).
                signal = early_stopping(-eval_result['eer'])
                best_dev = -early_stopping.best_score
                if signal == -1:
                    break

                logger.info(eval_result)
                _log_metrics('valid', 'valid_oos', eval_result)

        # Plot the training-loss curve (one point per completed epoch).
        from utils.visualization import draw_curve
        draw_curve(train_loss, len(train_loss), 'train_loss', args.output_dir)

    def _run_detection(dataset):
        """Score every example of *dataset* and compute detection metrics.

        Returns:
            dict with the BCE ``detection_loss``, ``eer``, ``auc``, accuracy,
            in-domain/out-of-scope precision, recall and F-score, the
            thresholded predictions, the binary gold labels and the raw
            ``y_score`` list.
        """
        dataloader = DataLoader(dataset,
                                batch_size=args.predict_batch_size,
                                shuffle=False,
                                num_workers=2)
        criterion = torch.nn.BCELoss().to(device)

        pos.eval()
        E.eval()

        batch_preds = []
        for sample in tqdm(dataloader):
            token, mask, type_ids, pos1, pos2, _pos_mask, _y = (
                t.to(device) for t in sample)

            # BERT encodes the sentence to a feature vector; nothing here
            # needs gradients.
            with torch.no_grad():
                _sequence_output, pooled_output = E(token, mask, type_ids)
                out = pos(pos1, pos2, pooled_output)
            batch_preds.append(out)

        # NOTE(review): gold labels are read from column -4 of the dataset's
        # feature matrix -- confirm against PosOOSDataset's layout.
        all_y = LongTensor(dataset.dataset[:, -4].astype(int)).cpu()
        all_binary_y = (all_y != 0).long()  # label 0 is oos
        all_detection_preds = torch.cat(batch_preds, 0).cpu()
        all_detection_binary_preds = convert_to_int_by_threshold(
            all_detection_preds.squeeze())

        result = dict()
        # Compute the loss over the whole dataset at once.
        result['detection_loss'] = criterion(all_detection_preds,
                                             all_binary_y.float())

        logger.info(
            metrics.classification_report(all_binary_y,
                                          all_detection_binary_preds,
                                          target_names=['oos', 'in']))

        oos_ind_precision, oos_ind_recall, oos_ind_fscore, _ = \
            metrics.binary_recall_fscore(all_detection_binary_preds,
                                         all_binary_y)
        detection_acc = metrics.accuracy(all_detection_binary_preds,
                                         all_binary_y)

        y_score = all_detection_preds.squeeze().tolist()
        result['eer'] = metrics.cal_eer(all_binary_y, y_score)
        result['all_detection_binary_preds'] = all_detection_binary_preds
        result['detection_acc'] = detection_acc
        result['all_binary_y'] = all_binary_y
        result['oos_ind_precision'] = oos_ind_precision
        result['oos_ind_recall'] = oos_ind_recall
        result['oos_ind_f_score'] = oos_ind_fscore
        result['y_score'] = y_score
        result['auc'] = roc_auc_score(all_binary_y, y_score)
        return result

    def evaluate(dataset):
        """Evaluate the current in-memory models on a validation dataset."""
        return _run_detection(dataset)

    def test(dataset):
        """Evaluate on the test dataset.

        No checkpoint is loaded in this variant, so the weights currently in
        memory are the ones being tested.
        """
        return _run_detection(dataset)

    if args.do_train:
        _, train_dataset = _load_dataset('train')
        _, dev_dataset = _load_dataset('val')
        train(train_dataset, dev_dataset)

    if args.do_eval:
        logger.info(
            '#################### eval result at step {} ####################'.
            format(global_step))
        _, dev_dataset = _load_dataset('val')
        eval_result = evaluate(dev_dataset)
        logger.info(eval_result)
        _log_metrics('eval', 'eval_oos', eval_result)

    if args.do_test:
        logger.info(
            '#################### test result at step {} ####################'.
            format(global_step))
        text_test_set, test_dataset = _load_dataset('test')
        test_result = test(test_dataset)
        logger.info(test_result)
        _log_metrics('test', 'test_ood', test_result)

        my_plot_roc(test_result['all_binary_y'], test_result['y_score'],
                    os.path.join(args.output_dir, 'roc_curve.png'))
        save_result(test_result, os.path.join(args.output_dir, 'test_result'))

        # Dump the individual cases for error analysis.
        if config['dataset'] == 'oos-eval':
            texts = [line[0] for line in text_test_set]
        elif config['dataset'] == 'smp':
            texts = [line['text'] for line in text_test_set]
        else:
            raise ValueError('The dataset {} is not supported.'.format(
                args.dataset))

        output_cases(texts, test_result['all_binary_y'],
                     test_result['all_detection_binary_preds'],
                     os.path.join(args.output_dir, 'test_cases.csv'),
                     processor, test_result['y_score'])

        # confusion matrix
        plot_confusion_matrix(test_result['all_binary_y'],
                              test_result['all_detection_binary_preds'],
                              args.output_dir)

        # Append one summary line per run; write the header only when the
        # log file is being created.
        # NOTE(review): best_dev is a module-level name set by train(); it
        # must already exist when testing without --do_train.
        beta_log_path = 'beta_log.txt'
        write_header = not os.path.exists(beta_log_path)
        with open(beta_log_path, 'a', encoding='utf-8') as f:
            if write_header:
                f.write('seed\tdataset\tdev_eer\ttest_eer\tdata_size\n')
            line = '\t'.join([
                str(config['seed']),
                str(config['data_file']),
                str(best_dev),
                str(test_result['eer']),
                '100',
            ])
            f.write(line + '\n')
def main(args):
    """Train and/or evaluate a ``BertClassifier`` as an out-of-scope detector.

    The model's discriminator output is trained either with binary
    cross-entropy against "label != 0" (``args.BCE``) or with cross-entropy
    against the class label.  The best model according to dev EER is saved to
    ``config['model_save_path']`` and re-loaded by the test phase.

    Which phases run is controlled by ``args.do_train``, ``args.do_eval`` and
    ``args.do_test``.

    Args:
        args: parsed command-line namespace.  ``vars(args)`` is used as the
            ``config`` dict, so ``config[...] = ...`` also sets the attribute
            on ``args``.

    Raises:
        ValueError: if ``args.dataset`` is neither ``'oos-eval'`` nor ``'smp'``.
    """
    logger.info('Checking...')
    check_manual_seed(args.seed)
    check_args(args)

    logger.info('Loading config...')
    bert_config = BertConfig('config/bert.ini')
    bert_config = bert_config(args.bert_type)

    # for oos-eval dataset
    data_config = Config('config/data.ini')
    data_config = data_config(args.dataset)

    # Prepare data processor.
    # Join the data directory and the file name into a single path.
    data_path = os.path.join(data_config['DataDir'],
                             data_config[args.data_file])
    label_path = data_path.replace('.json', '.label')

    if args.dataset == 'oos-eval':
        processor = OOSProcessor(bert_config, maxlen=32)
    elif args.dataset == 'smp':
        processor = SMPProcessor(bert_config, maxlen=32)
    else:
        raise ValueError('The dataset {} is not supported.'.format(
            args.dataset))

    # Adds label_to_id and id_to_label to the processor.
    processor.load_label(label_path)

    n_class = len(processor.id_to_label)
    config = vars(args)  # the argument namespace viewed as a dict
    config['model_save_path'] = os.path.join(args.output_dir, 'save',
                                             'bert.pt')
    config['n_class'] = n_class

    logger.info('config:')
    logger.info(config)

    model = BertClassifier(bert_config, config)  # BERT encoder
    if args.fine_tune:
        model.unfreeze_bert_encoder()
    else:
        model.freeze_bert_encoder()

    model.to(device)

    global_step = 0

    def _split_names(split):
        """Return the JSON split names to read for *split* ('train'/'val'/'test').

        Only the non-binary oos-eval data keeps its out-of-scope examples in a
        separate ``oos_<split>`` section; every other setup reads one split.
        """
        is_binary = config['data_file'].startswith('binary')
        if not is_binary and config['dataset'] == 'oos-eval':
            return [split, 'oos_' + split]
        return [split]

    def _load_dataset(split):
        """Read *split*, convert it to ids and wrap it in an ``OOSDataset``."""
        rows = processor.read_dataset(data_path, _split_names(split))
        return OOSDataset(processor.convert_to_ids(rows))

    def _log_metrics(prefix, ood_prefix, result):
        """Log the headline metrics of *result* under the given key prefixes."""
        logger.info('{}_eer: {}'.format(prefix, result['eer']))
        logger.info('{}_ind_precision: {}'.format(
            ood_prefix, result['oos_ind_precision']))
        logger.info('{}_ind_recall: {}'.format(
            ood_prefix, result['oos_ind_recall']))
        logger.info('{}_ind_f_score: {}'.format(
            ood_prefix, result['oos_ind_f_score']))
        logger.info('{}_auc: {}'.format(prefix, result['auc']))
        logger.info('{}_fpr95: {}'.format(
            prefix,
            ErrorRateAt95Recall(result['all_binary_y'], result['y_score'])))

    def train(train_dataset, dev_dataset):
        """Train ``model`` with gradient accumulation and early stopping.

        The loader batch size is ``train_batch_size // gradient_accumulation_steps``
        and the optimizer steps once every ``gradient_accumulation_steps``
        batches.  Loss curves are plotted and stored in the module-level
        ``freeze_data`` dict.
        """
        nonlocal global_step

        train_dataloader = DataLoader(
            train_dataset,
            batch_size=args.train_batch_size //
            args.gradient_accumulation_steps,
            shuffle=True,
            num_workers=2)
        n_sample = len(train_dataloader)
        early_stopping = EarlyStopping(args.patience, logger=logger)

        # Loss functions
        classified_loss = torch.nn.CrossEntropyLoss().to(device)
        adversarial_loss = torch.nn.BCELoss().to(device)

        # Optimizer
        optimizer = AdamW(model.parameters(), args.lr)

        train_loss = []
        # Created unconditionally: they are read after the loop even when no
        # dev set is supplied.
        valid_loss = []
        valid_ind_class_acc = []

        for epoch in range(args.n_epoch):
            model.train()

            total_loss = 0
            for sample in tqdm.tqdm(train_dataloader):
                token, mask, type_ids, y = (t.to(device) for t in sample)

                _f_vector, discriminator_output, _classification_output = model(
                    token, mask, type_ids, return_feature=True)
                discriminator_output = discriminator_output.squeeze()

                if args.BCE:
                    # Binary target: 1 for every label other than 0.
                    loss = adversarial_loss(discriminator_output,
                                            (y != 0.0).float())
                else:
                    loss = classified_loss(discriminator_output, y.long())

                total_loss += loss.item()
                # Scale so the accumulated gradient matches one full batch.
                loss = loss / args.gradient_accumulation_steps
                loss.backward()

                # Back-propagate every batch; update parameters only once
                # enough batches have been accumulated.
                if (global_step + 1) % args.gradient_accumulation_steps == 0:
                    optimizer.step()
                    optimizer.zero_grad()
                global_step += 1

            epoch_loss = total_loss / n_sample
            logger.info('[Epoch {}] Train: train_loss: {}'.format(
                epoch, epoch_loss))
            logger.info('-' * 30)
            train_loss.append(epoch_loss)

            if dev_dataset:
                logger.info(
                    '#################### eval result at step {} ####################'
                    .format(global_step))
                eval_result = evaluate(dev_dataset)

                valid_loss.append(eval_result['loss'])
                valid_ind_class_acc.append(eval_result['ind_class_acc'])

                # Signal meaning:  1 = save the model,  0 = do not save,
                # -1 = do not save and patience is exhausted -> stop early.
                signal = early_stopping(-eval_result['eer'])
                if signal == -1:
                    break
                if signal == 1:
                    save_model(model,
                               path=config['model_save_path'],
                               model_name='bert')

                logger.info(eval_result)
                _log_metrics('valid', 'valid_oos', eval_result)

        n_epochs_run = len(train_loss)  # one point per completed epoch

        from utils.visualization import draw_curve
        draw_curve(train_loss, n_epochs_run, 'train_loss', args.output_dir)
        if dev_dataset:
            draw_curve(valid_loss, n_epochs_run, 'valid_loss', args.output_dir)
            draw_curve(valid_ind_class_acc, n_epochs_run,
                       'valid_ind_class_accuracy', args.output_dir)

        # When patience can never run out within n_epoch, also save the
        # final weights.
        if args.patience >= args.n_epoch:
            save_model(model,
                       path=config['model_save_path'],
                       model_name='bert')

        freeze_data['train_loss'] = train_loss
        freeze_data['valid_loss'] = valid_loss

    def _detect(dataset, collect_features=False):
        """Score every example of *dataset* and compute detection metrics.

        Args:
            dataset: dataset whose ``dataset`` array holds the label in its
                last column.
            collect_features: when true, also gather the model's feature
                vectors into ``result['all_features']`` (a NumPy array).

        Returns:
            ``(result, all_y)`` where ``result`` is the metrics dict and
            ``all_y`` the original integer labels on CPU.
        """
        dataloader = DataLoader(dataset,
                                batch_size=args.predict_batch_size,
                                shuffle=False,
                                num_workers=2)
        criterion = torch.nn.BCELoss().to(device)

        model.eval()

        batch_scores = []
        batch_features = []
        for sample in tqdm.tqdm(dataloader):
            token, mask, type_ids, _y = (t.to(device) for t in sample)

            with torch.no_grad():
                f_vector, discriminator_output, _classification_output = model(
                    token, mask, type_ids, return_feature=True)
            batch_scores.append(discriminator_output.squeeze())
            if collect_features:
                batch_features.append(f_vector)

        all_y = LongTensor(dataset.dataset[:, -1].astype(int)).cpu()
        all_binary_y = (all_y != 0).long()  # label 0 is oos
        all_detection_preds = torch.cat(batch_scores, 0).cpu()
        all_detection_binary_preds = convert_to_int_by_threshold(
            all_detection_preds.squeeze())

        result = dict()
        # Compute the loss over the whole dataset at once.
        detection_loss = criterion(all_detection_preds, all_binary_y.float())
        result['detection_loss'] = detection_loss

        logger.info(
            metrics.classification_report(all_binary_y,
                                          all_detection_binary_preds,
                                          target_names=['oos', 'in']))

        oos_ind_precision, oos_ind_recall, oos_ind_fscore, _ = \
            metrics.binary_recall_fscore(all_detection_binary_preds,
                                         all_binary_y)
        detection_acc = metrics.accuracy(all_detection_binary_preds,
                                         all_binary_y)

        y_score = all_detection_preds.squeeze().tolist()
        eer = metrics.cal_eer(all_binary_y, y_score)

        if collect_features:
            result['all_features'] = torch.cat(batch_features,
                                               0).cpu().numpy()

        result['ind_class_acc'] = metrics.ind_class_accuracy(
            all_detection_binary_preds, all_y)
        # Report the dataset-level BCE loss as a float; nothing is accumulated
        # per batch during evaluation.
        result['loss'] = detection_loss.item()
        result['eer'] = eer
        result['all_detection_binary_preds'] = all_detection_binary_preds
        result['detection_acc'] = detection_acc
        result['all_binary_y'] = all_binary_y
        result['oos_ind_precision'] = oos_ind_precision
        result['oos_ind_recall'] = oos_ind_recall
        result['oos_ind_f_score'] = oos_ind_fscore
        result['y_score'] = y_score
        result['auc'] = roc_auc_score(all_binary_y, y_score)
        return result, all_y

    def evaluate(dataset):
        """Evaluate the current in-memory model on a validation dataset."""
        result, _ = _detect(dataset)
        return result

    def test(dataset):
        """Load the saved model, evaluate it on *dataset* and record the scores.

        Besides the shared metrics, the result carries ``all_y``, ``score`` and
        ``all_pred``; labels, predictions and scores are also copied into the
        module-level ``freeze_data`` dict.
        """
        load_model(model, path=config['model_save_path'], model_name='bert')

        result, all_y = _detect(dataset, collect_features=args.do_vis)
        result['all_y'] = all_y
        result['score'] = result['y_score']
        result['all_pred'] = result['all_detection_binary_preds']

        freeze_data['test_all_y'] = all_y.tolist()
        freeze_data['test_all_pred'] = result[
            'all_detection_binary_preds'].tolist()
        freeze_data['test_score'] = result['y_score']
        return result

    if args.do_train:
        train_dataset = _load_dataset('train')
        dev_dataset = _load_dataset('val')
        train(train_dataset, dev_dataset)

    if args.do_eval:
        logger.info(
            '#################### eval result at step {} ####################'.
            format(global_step))
        dev_dataset = _load_dataset('val')
        eval_result = evaluate(dev_dataset)
        logger.info(eval_result)
        _log_metrics('eval', 'eval_oos', eval_result)

    if args.do_test:
        logger.info(
            '#################### test result at step {} ####################'.
            format(global_step))
        test_dataset = _load_dataset('test')
        test_result = test(test_dataset)
        logger.info(test_result)
        _log_metrics('test', 'test_ood', test_result)

        # NOTE: the per-case CSV export (output_cases) is disabled in this
        # variant; only the dataset check that preceded it is kept.
        if config['dataset'] not in ('oos-eval', 'smp'):
            raise ValueError('The dataset {} is not supported.'.format(
                args.dataset))

        # confusion matrix
        plot_confusion_matrix(test_result['all_y'], test_result['all_pred'],
                              args.output_dir)
def main(args): check_manual_seed(args.seed) logger.info('seed: {}'.format(args.seed)) logger.info('Loading config...') bert_config = Config('config/bert.ini') bert_config = bert_config(args.bert_type) # for oos-eval dataset data_config = Config('config/data.ini') data_config = data_config(args.dataset) # Prepare data processor data_path = os.path.join(data_config['DataDir'], data_config[args.data_file]) # 把目录和文件名合成一个路径 label_path = data_path.replace('.json', '.label') with open(data_path, 'r', encoding='utf-8') as fp: data = json.load(fp) for type in data: logger.info('{} : {}'.format(type, len(data[type]))) with open(label_path, 'r', encoding='utf-8') as fp: logger.info(json.load(fp)) if args.dataset == 'oos-eval': processor = OOSProcessor(bert_config, maxlen=32) logger.info('OOSProcessor') elif args.dataset == 'smp': processor = SMPProcessor(bert_config, maxlen=32) logger.info('SMPProcessor') else: raise ValueError('The dataset {} is not supported.'.format( args.dataset)) processor.load_label( label_path) # Adding label_to_id and id_to_label ot processor. 
logger.info("label_to_id: {}".format(processor.label_to_id)) logger.info("id_to_label: {}".format(processor.id_to_label)) n_class = len(processor.id_to_label) config = vars(args) # 返回参数字典 config['gan_save_path'] = os.path.join(args.output_dir, 'save', 'gan.pt') config['bert_save_path'] = os.path.join(args.output_dir, 'save', 'bert.pt') config['n_class'] = n_class logger.info('config:') logger.info(config) D_detect = Discriminator(config) D_g = Discriminator(config) G = Generator(config) E = BertModel.from_pretrained( bert_config['PreTrainModelDir']) # Bert encoder if args.fine_tune: for param in E.parameters(): param.requires_grad = True else: for param in E.parameters(): param.requires_grad = False D_detect.to(device) D_g.to(device) G.to(device) E.to(device) global_step = 0 def train(train_dataset, dev_dataset): train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=2) global best_dev nonlocal global_step n_sample = len(train_dataloader) early_stopping = EarlyStopping(args.patience, logger=logger) # Loss function adversarial_loss = torch.nn.BCELoss().to(device) classified_loss = torch.nn.CrossEntropyLoss().to(device) # Optimizers optimizer_G = torch.optim.Adam(G.parameters(), lr=args.G_lr) # optimizer for generator optimizer_D_detect = torch.optim.Adam( D_detect.parameters(), lr=args.D_detect_lr) # optimizer for discriminator optimizer_D_g = torch.optim.Adam(D_g.parameters(), lr=args.D_g_lr) optimizer_E = AdamW(E.parameters(), args.bert_lr) G_total_train_loss = [] D_total_fake_loss = [] D_total_real_loss = [] FM_total_train_loss = [] D_total_class_loss = [] valid_detection_loss = [] valid_oos_ind_precision = [] valid_oos_ind_recall = [] valid_oos_ind_f_score = [] for i in range(args.n_epoch): logger.info('***********************************') logger.info('epoch: {}'.format(i)) # Initialize model state G.train() D_detect.train() D_g.train() E.train() D_g_real_loss = 0 D_g_fake_loss = 0 D_detect_real_loss = 0 
D_detect_fake_loss = 0 G_loss = 0 for sample in tqdm(train_dataloader): sample = (i.to(device) for i in sample) token, mask, type_ids, y = sample batch = len(token) all_g_D_g_loss = 0 D_gen_real_loss = None D_gen_fake_loss = None # the label used to train generator and discriminator. valid_label = FloatTensor(batch, 1).fill_(1.0).detach() fake_label = FloatTensor(batch, 1).fill_(0.0).detach() optimizer_E.zero_grad() sequence_output, pooled_output = E(token, mask, type_ids) real_feature = pooled_output for gan_i in range(args.time): # ------------------------- train D_g -------------------------# # train on D_g real id_sample = (y == 1.0) weight = torch.ones(len(id_sample)).to( device) - id_sample * 1.0 # 除去id损失, 只用ood数据 real_loss_func = torch.nn.BCELoss(weight=weight).to(device) optimizer_D_g.zero_grad() D_gen_real_discriminator_output, f_vector = D_g( real_feature) # D_gen_real_loss = adversarial_loss(D_gen_real_discriminator_output, valid_label) # 判别器对真实样本的损失 D_gen_real_loss = real_loss_func( D_gen_real_discriminator_output.squeeze(), valid_label.squeeze()) # train on D_g fake z = FloatTensor( np.random.normal(0, 1, (batch, args.G_z_dim))).to(device) fake_feature = G(z).detach() D_gen_fake_discriminator_output, f_vector = D_g( fake_feature) D_gen_fake_loss = adversarial_loss( D_gen_fake_discriminator_output.squeeze(), fake_label.squeeze()) # 判别器对假样本的损失 D_gen_loss = D_gen_real_loss + D_gen_fake_loss D_gen_loss.backward(retain_graph=True) # 保存计算图,生成器还要使用 optimizer_D_g.step() # ------------------------- train G -------------------------# list_g_D_g_loss = [] for gi in range(args.g_time): optimizer_G.zero_grad() z = FloatTensor( np.random.normal(0, 1, (batch, args.G_z_dim))).to(device) fake_feature = G(z).detach() D_gen_fake_discriminator_output, f_vector = D_g( fake_feature) g_D_g_loss = adversarial_loss( D_gen_fake_discriminator_output.squeeze(), valid_label.squeeze()) # 生成器欺骗 D_g, 认为是真实样本 g_D_g_loss.backward() optimizer_G.step() all_g_D_g_loss += 
g_D_g_loss.detach() list_g_D_g_loss.append(g_D_g_loss) # ------------------------- train D_detect_ood -------------------------# # train on real(detect real sample) optimizer_D_detect.zero_grad() ood_real_detect_discriminator_output, f_vector = D_detect( real_feature) ood_real_detect_loss = adversarial_loss( ood_real_detect_discriminator_output.squeeze(), (y != 0.0).float()) # ood 判别器对真实样本的损失 # train on fake(detect fake sample) fake sample is fake id -> ood z = FloatTensor(np.random.normal( 0, 1, (batch, args.G_z_dim))).to(device) fake_feature = G(z).detach() ood_fake_detect_discriminator_output, f_vector = D_detect( fake_feature) ood_fake_detect_loss = adversarial_loss( ood_fake_detect_discriminator_output.squeeze(), fake_label.squeeze()) # 假ood认为是ood样本 D_detect_loss = args.beta * ood_real_detect_loss + ( 1 - args.beta) * ood_fake_detect_loss # 真实样本与假ood样本影响比例 D_detect_loss.backward() optimizer_D_detect.step() if args.fine_tune: optimizer_E.step() global_step += 1 D_g_real_loss += D_gen_real_loss.detach() D_g_fake_loss += D_gen_fake_loss.detach() D_detect_real_loss += ood_real_detect_loss.detach() D_detect_fake_loss += ood_fake_detect_loss.detach() G_loss += all_g_D_g_loss logger.info('[Epoch {}] Train: D_g_real_loss: {}'.format( i, D_g_real_loss / n_sample)) logger.info('[Epoch {}] Train: D_g_fake_loss: {}'.format( i, D_g_fake_loss / n_sample)) logger.info('[Epoch {}] Train: D_detect_real_loss: {}'.format( i, D_detect_real_loss / n_sample)) logger.info('[Epoch {}] Train: D_detect_fake_loss: {}'.format( i, D_detect_fake_loss / n_sample)) logger.info('[Epoch {}] Train: G_loss: {}'.format( i, G_loss / n_sample)) logger.info( '---------------------------------------------------------------------------' ) if dev_dataset: logger.info( '#################### eval result at step {} ####################' .format(global_step)) eval_result = eval(dev_dataset) valid_detection_loss.append(eval_result['detection_loss']) valid_oos_ind_precision.append( 
eval_result['oos_ind_precision']) valid_oos_ind_recall.append(eval_result['oos_ind_recall']) valid_oos_ind_f_score.append(eval_result['oos_ind_f_score']) # 1 表示要保存模型 # 0 表示不需要保存模型 # -1 表示不需要模型,且超过了patience,需要early stop signal = early_stopping(-eval_result['eer']) if signal == -1: break # elif signal == 0: # pass # elif signal == 1: # save_gan_model(D, G, config['gan_save_path']) # if args.fine_tune: # save_model(E, path=config['bert_save_path'], model_name='bert') logger.info(eval_result) logger.info('valid_eer: {}'.format(eval_result['eer'])) logger.info('valid_oos_ind_precision: {}'.format( eval_result['oos_ind_precision'])) logger.info('valid_oos_ind_recall: {}'.format( eval_result['oos_ind_recall'])) logger.info('valid_oos_ind_f_score: {}'.format( eval_result['oos_ind_f_score'])) logger.info('valid_auc: {}'.format(eval_result['auc'])) logger.info('valid_fpr95: {}'.format( ErrorRateAt95Recall(eval_result['all_binary_y'], eval_result['y_score']))) best_dev = -early_stopping.best_score def eval(dataset): dev_dataloader = DataLoader(dataset, batch_size=args.predict_batch_size, shuffle=False, num_workers=2) n_sample = len(dev_dataloader) result = dict() detection_loss = torch.nn.BCELoss().to(device) D_detect.eval() E.eval() all_detection_preds = [] for sample in tqdm(dev_dataloader): sample = (i.to(device) for i in sample) token, mask, type_ids, y = sample batch = len(token) # -------------------------evaluate D------------------------- # # BERT encode sentence to feature vector with torch.no_grad(): sequence_output, pooled_output = E(token, mask, type_ids) real_feature = pooled_output discriminator_output, f_vector = D_detect(real_feature) all_detection_preds.append(discriminator_output) all_y = LongTensor( dataset.dataset[:, -1].astype(int)).cpu() # [length, n_class] all_binary_y = (all_y != 0).long() # [length, 1] label 0 is oos all_detection_preds = torch.cat(all_detection_preds, 0).cpu() # [length, 1] all_detection_binary_preds = convert_to_int_by_threshold( 
all_detection_preds.squeeze()) # [length, 1] # 计算损失 detection_loss = detection_loss(all_detection_preds, all_binary_y.float()) result['detection_loss'] = detection_loss logger.info( metrics.classification_report(all_binary_y, all_detection_binary_preds, target_names=['oos', 'in'])) # report oos_ind_precision, oos_ind_recall, oos_ind_fscore, _ = metrics.binary_recall_fscore( all_detection_binary_preds, all_binary_y) detection_acc = metrics.accuracy(all_detection_binary_preds, all_binary_y) y_score = all_detection_preds.squeeze().tolist() eer = metrics.cal_eer(all_binary_y, y_score) result['eer'] = eer result['all_detection_binary_preds'] = all_detection_binary_preds result['detection_acc'] = detection_acc result['all_binary_y'] = all_binary_y result['oos_ind_precision'] = oos_ind_precision result['oos_ind_recall'] = oos_ind_recall result['oos_ind_f_score'] = oos_ind_fscore result['y_score'] = y_score result['auc'] = roc_auc_score(all_binary_y, y_score) return result def test(dataset): # # load BERT and GAN # load_gan_model(D, G, config['gan_save_path']) # if args.fine_tune: # load_model(E, path=config['bert_save_path'], model_name='bert') # test_dataloader = DataLoader(dataset, batch_size=args.predict_batch_size, shuffle=False, num_workers=2) n_sample = len(test_dataloader) result = dict() # Loss function detection_loss = torch.nn.BCELoss().to(device) classified_loss = torch.nn.CrossEntropyLoss(ignore_index=0).to(device) D_detect.eval() E.eval() all_detection_preds = [] all_class_preds = [] all_features = [] for sample in tqdm(test_dataloader): sample = (i.to(device) for i in sample) token, mask, type_ids, y = sample batch = len(token) # -------------------------evaluate D------------------------- # # BERT encode sentence to feature vector with torch.no_grad(): sequence_output, pooled_output = E(token, mask, type_ids) real_feature = pooled_output discriminator_output, f_vector = D_detect(real_feature) all_detection_preds.append(discriminator_output) if args.do_vis: 
all_features.append(f_vector) all_y = LongTensor( dataset.dataset[:, -1].astype(int)).cpu() # [length, n_class] all_binary_y = (all_y != 0).long() # [length, 1] label 0 is oos all_detection_preds = torch.cat(all_detection_preds, 0).cpu() # [length, 1] all_detection_binary_preds = convert_to_int_by_threshold( all_detection_preds.squeeze()) # [length, 1] # 计算损失 detection_loss = detection_loss(all_detection_preds, all_binary_y.float()) result['detection_loss'] = detection_loss logger.info( metrics.classification_report(all_binary_y, all_detection_binary_preds, target_names=['oos', 'in'])) # report oos_ind_precision, oos_ind_recall, oos_ind_fscore, _ = metrics.binary_recall_fscore( all_detection_binary_preds, all_binary_y) detection_acc = metrics.accuracy(all_detection_binary_preds, all_binary_y) y_score = all_detection_preds.squeeze().tolist() eer = metrics.cal_eer(all_binary_y, y_score) result['eer'] = eer result['all_detection_binary_preds'] = all_detection_binary_preds result['detection_acc'] = detection_acc result['all_binary_y'] = all_binary_y result['oos_ind_precision'] = oos_ind_precision result['oos_ind_recall'] = oos_ind_recall result['oos_ind_f_score'] = oos_ind_fscore result['y_score'] = y_score result['auc'] = roc_auc_score(all_binary_y, y_score) if args.do_vis: all_features = torch.cat(all_features, 0).cpu().numpy() result['all_features'] = all_features return result def get_fake_feature(num_output): """ 生成一定数量的假特征 """ G.eval() fake_features = [] start = 0 batch = args.predict_batch_size with torch.no_grad(): while start < num_output: end = min(num_output, start + batch) z = FloatTensor( np.random.normal(0, 1, size=(end - start, args.G_z_dim))) fake_feature = G(z) discriminator_output, f_vector = D_detect(fake_feature) fake_features.append(f_vector) start += batch return torch.cat(fake_features, 0).cpu().numpy() if args.do_train: if config['data_file'].startswith('binary'): text_train_set = processor.read_dataset(data_path, ['train']) text_dev_set = 
processor.read_dataset(data_path, ['val']) elif config['dataset'] == 'oos-eval': text_train_set = processor.read_dataset(data_path, ['train', 'oos_train']) text_dev_set = processor.read_dataset(data_path, ['val', 'oos_val']) elif config['dataset'] == 'smp': text_train_set = processor.read_dataset(data_path, ['train']) text_dev_set = processor.read_dataset(data_path, ['val']) train_features = processor.convert_to_ids(text_train_set) train_dataset = OOSDataset(train_features) dev_features = processor.convert_to_ids(text_dev_set) dev_dataset = OOSDataset(dev_features) train(train_dataset, dev_dataset) if args.do_eval: logger.info( '#################### eval result at step {} ####################'. format(global_step)) if config['data_file'].startswith('binary'): text_dev_set = processor.read_dataset(data_path, ['val']) elif config['dataset'] == 'oos-eval': text_dev_set = processor.read_dataset(data_path, ['val', 'oos_val']) elif config['dataset'] == 'smp': text_dev_set = processor.read_dataset(data_path, ['val']) dev_features = processor.convert_to_ids(text_dev_set) dev_dataset = OOSDataset(dev_features) eval_result = eval(dev_dataset) logger.info(eval_result) logger.info('eval_eer: {}'.format(eval_result['eer'])) logger.info('eval_oos_ind_precision: {}'.format( eval_result['oos_ind_precision'])) logger.info('eval_oos_ind_recall: {}'.format( eval_result['oos_ind_recall'])) logger.info('eval_oos_ind_f_score: {}'.format( eval_result['oos_ind_f_score'])) logger.info('eval_auc: {}'.format(eval_result['auc'])) logger.info('eval_fpr95: {}'.format( ErrorRateAt95Recall(eval_result['all_binary_y'], eval_result['y_score']))) if args.do_test: logger.info( '#################### test result at step {} ####################'. 
format(global_step)) if config['data_file'].startswith('binary'): text_test_set = processor.read_dataset(data_path, ['test']) elif config['dataset'] == 'oos-eval': text_test_set = processor.read_dataset(data_path, ['test', 'oos_test']) elif config['dataset'] == 'smp': text_test_set = processor.read_dataset(data_path, ['test']) test_features = processor.convert_to_ids(text_test_set) test_dataset = OOSDataset(test_features) test_result = test(test_dataset) logger.info(test_result) logger.info('test_eer: {}'.format(test_result['eer'])) logger.info('test_ood_ind_precision: {}'.format( test_result['oos_ind_precision'])) logger.info('test_ood_ind_recall: {}'.format( test_result['oos_ind_recall'])) logger.info('test_ood_ind_f_score: {}'.format( test_result['oos_ind_f_score'])) logger.info('test_auc: {}'.format(test_result['auc'])) logger.info('test_fpr95: {}'.format( ErrorRateAt95Recall(test_result['all_binary_y'], test_result['y_score']))) my_plot_roc(test_result['all_binary_y'], test_result['y_score'], os.path.join(args.output_dir, 'roc_curve.png')) save_result(test_result, os.path.join(args.output_dir, 'test_result')) # 输出错误cases if config['dataset'] == 'oos-eval': texts = [line[0] for line in text_test_set] elif config['dataset'] == 'smp': texts = [line['text'] for line in text_test_set] else: raise ValueError('The dataset {} is not supported.'.format( args.dataset)) output_cases(texts, test_result['all_binary_y'], test_result['all_detection_binary_preds'], os.path.join(args.output_dir, 'test_cases.csv'), processor) # confusion matrix plot_confusion_matrix(test_result['all_binary_y'], test_result['all_detection_binary_preds'], args.output_dir) beta_log_path = 'beta_log.txt' if os.path.exists(beta_log_path): flag = True else: flag = False with open(beta_log_path, 'a', encoding='utf-8') as f: if flag == False: f.write('seed\tbeta\tdataset\tdev_eer\ttest_eer\tdata_size\n') line = '\t'.join([ str(config['seed']), str(config['beta']), str(config['data_file']), 
str(best_dev), str(test_result['eer']), '100' ]) f.write(line + '\n') if args.do_vis: # [2 * length, feature_fim] features = np.concatenate([ test_result['all_features'], get_fake_feature(len(test_dataset) // 2) ], axis=0) features = TSNE(n_components=2, verbose=1, n_jobs=-1).fit_transform( features) # [2 * length, 2] # [2 * length, 1] if n_class > 2: labels = np.concatenate([ test_result['all_y'], np.array([-1] * (len(test_dataset) // 2)) ], 0).reshape((-1, 1)) else: labels = np.concatenate([ test_result['all_binary_y'], np.array([-1] * (len(test_dataset) // 2)) ], 0).reshape((-1, 1)) # [2 * length, 3] data = np.concatenate([features, labels], 1) fig = scatter_plot(data, processor) fig.savefig(os.path.join(args.output_dir, 'plot.png')) fig.show()