def main(args):
    logger.info('Checking...')
    SEED = args.seed
    logger.info('seed: {}'.format(SEED))
    logger.info('model: {}'.format(args.model))
    check_manual_seed(SEED)
    check_args(args)

    logger.info('Loading config...')
    bert_config = Config('config/bert.ini')
    bert_config = bert_config(args.bert_type)

    # for oos-eval dataset
    data_config = Config('config/data.ini')
    data_config = data_config(args.dataset)

    # Prepare data processor
    data_path = os.path.join(data_config['DataDir'],
                             data_config[args.data_file])  # join the data directory and file name into one path
    label_path = data_path.replace('.json', '.label')

    if args.dataset == 'oos-eval':
        processor = OOSProcessor(bert_config, maxlen=32)
    elif args.dataset == 'smp':
        processor = SMPProcessor(bert_config, maxlen=32)
    else:
        raise ValueError('The dataset {} is not supported.'.format(args.dataset))

    processor.load_label(label_path)  # Adds label_to_id and id_to_label to the processor.
    n_class = len(processor.id_to_label)

    config = vars(args)  # returns the arguments as a dict
    config['gan_save_path'] = os.path.join(args.output_dir, 'save', 'gan.pt')
    config['bert_save_path'] = os.path.join(args.output_dir, 'save', 'bert.pt')
    config['n_class'] = n_class

    logger.info('config:')
    logger.info(config)

    model = import_module('model.' + args.model)
    model_d = import_module('model.' + 'detector')
    D = model.Discriminator(config)
    G = model.Generator(config)
    E = BertModel.from_pretrained(bert_config['PreTrainModelDir'])  # BERT encoder
    if args.loss == 'v1':
        detector = model_d.Detector(config)
    else:
        detector = model_d.Detector_v2(config)

    logger.info('Discriminator: {}'.format(D))
    logger.info('Generator: {}'.format(G))
    logger.info('Detector: {}'.format(detector))

    if args.fine_tune:
        for param in E.parameters():
            param.requires_grad = True
    else:
        for param in E.parameters():
            param.requires_grad = False

    D.to(device)
    G.to(device)
    E.to(device)
    detector.to(device)

    global_step = 0

    def train(train_dataset, dev_dataset):
        train_dataloader = DataLoader(train_dataset,
                                      batch_size=args.train_batch_size,
                                      shuffle=True,
                                      num_workers=2)

        global best_dev
        nonlocal global_step

        n_sample = len(train_dataloader)
        early_stopping = EarlyStopping(args.patience, logger=logger)

        # Loss functions
        adversarial_loss = torch.nn.BCELoss().to(device)
        adversarial_loss_v2 = torch.nn.CrossEntropyLoss().to(device)
        classified_loss = torch.nn.CrossEntropyLoss().to(device)

        # Optimizers
        optimizer_G = torch.optim.Adam(G.parameters(), lr=args.G_lr)  # optimizer for the generator
        optimizer_D = torch.optim.Adam(D.parameters(), lr=args.D_lr)  # optimizer for the discriminator
        optimizer_E = AdamW(E.parameters(), args.bert_lr)
        optimizer_detector = torch.optim.Adam(detector.parameters(), lr=args.detector_lr)

        G_total_train_loss = []
        D_total_fake_loss = []
        D_total_real_loss = []
        FM_total_train_loss = []
        D_total_class_loss = []
        valid_detection_loss = []
        valid_oos_ind_precision = []
        valid_oos_ind_recall = []
        valid_oos_ind_f_score = []
        detector_total_train_loss = []

        all_features = []
        result = dict()

        for i in range(args.n_epoch):
            # Initialize model state
            G.train()
            D.train()
            E.train()
            detector.train()

            G_train_loss = 0
            D_fake_loss = 0
            D_real_loss = 0
            FM_train_loss = 0
            D_class_loss = 0
            detector_train_loss = 0

            for sample in tqdm.tqdm(train_dataloader):
                sample = (i.to(device) for i in sample)
                token, mask, type_ids, y = sample
                batch = len(token)

                ood_sample = (y == 0.0)
                # weight = torch.ones(len(ood_sample)).to(device) - ood_sample * args.beta
                # real_loss_func = torch.nn.BCELoss(weight=weight).to(device)

                # the labels used to train the generator and discriminator
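                # (real features are scored toward 1.0, generated features
                # toward 0.0; .detach() keeps these constant label tensors
                # out of the autograd graph)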
                valid_label = FloatTensor(batch, 1).fill_(1.0).detach()
                fake_label = FloatTensor(batch, 1).fill_(0.0).detach()

                optimizer_E.zero_grad()
                sequence_output, pooled_output = E(token, mask, type_ids)
                real_feature = pooled_output

                # train D on real
                optimizer_D.zero_grad()
                real_f_vector, discriminator_output, classification_output = D(
                    real_feature, return_feature=True)
                # discriminator_output = discriminator_output.squeeze()
                real_loss = adversarial_loss(discriminator_output, valid_label)
                real_loss.backward(retain_graph=True)

                if args.do_vis:
                    all_features.append(real_f_vector.detach())

                # train D on fake
                if args.model == 'lstm_gan' or args.model == 'cnn_gan':
                    z = FloatTensor(np.random.normal(0, 1, (batch, 32, args.G_z_dim))).to(device)
                else:
                    z = FloatTensor(np.random.normal(0, 1, (batch, args.G_z_dim))).to(device)
                fake_feature = G(z).detach()
                fake_discriminator_output = D.detect_only(fake_feature)
                fake_loss = adversarial_loss(fake_discriminator_output, fake_label)
                fake_loss.backward()
                optimizer_D.step()
                # if args.fine_tune:
                #     optimizer_E.step()

                # train G
                optimizer_G.zero_grad()
                if args.model == 'lstm_gan' or args.model == 'cnn_gan':
                    z = FloatTensor(np.random.normal(0, 1, (batch, 32, args.G_z_dim))).to(device)
                else:
                    z = FloatTensor(np.random.normal(0, 1, (batch, args.G_z_dim))).to(device)
                fake_f_vector, D_decision = D.detect_only(G(z), return_feature=True)
                gd_loss = adversarial_loss(D_decision, valid_label)
                fm_loss = torch.abs(
                    torch.mean(real_f_vector.detach(), 0) -
                    torch.mean(fake_f_vector, 0)).mean()
                g_loss = gd_loss + 0 * fm_loss  # feature-matching term currently weighted 0
                g_loss.backward()
                optimizer_G.step()

                optimizer_E.zero_grad()

                # train detector
                optimizer_detector.zero_grad()
                if args.model == 'lstm_gan' or args.model == 'cnn_gan':
                    z = FloatTensor(np.random.normal(0, 1, (batch, 32, args.G_z_dim))).to(device)
                else:
                    z = FloatTensor(np.random.normal(0, 1, (batch, args.G_z_dim))).to(device)
                fake_feature = G(z).detach()
                if args.loss == 'v1':
                    loss_fake = adversarial_loss(detector(fake_feature), fake_label)  # fake samples are ood
                else:
                    loss_fake = adversarial_loss_v2(detector(fake_feature),
                                                    fake_label.long().squeeze())
                if args.loss == 'v1':
                    loss_real = adversarial_loss(detector(real_feature), y.float())
                else:
                    loss_real = adversarial_loss_v2(detector(real_feature), y.long())
                if args.detect_loss == 'v1':
                    detector_loss = args.beta * loss_fake + (1 - args.beta) * loss_real
                else:
                    detector_loss = args.beta * loss_fake + loss_real
                detector_loss = args.sigma * detector_loss
                detector_loss.backward()
                optimizer_detector.step()

                if args.fine_tune:
                    optimizer_E.step()

                global_step += 1

                D_fake_loss += fake_loss.detach()
                D_real_loss += real_loss.detach()
                G_train_loss += g_loss.detach() + fm_loss.detach()
                FM_train_loss += fm_loss.detach()
                detector_train_loss += detector_loss.detach()  # detach so the accumulator does not keep the graph alive

            logger.info('[Epoch {}] Train: D_fake_loss: {}'.format(i, D_fake_loss / n_sample))
            logger.info('[Epoch {}] Train: D_real_loss: {}'.format(i, D_real_loss / n_sample))
            logger.info('[Epoch {}] Train: D_class_loss: {}'.format(i, D_class_loss / n_sample))
            logger.info('[Epoch {}] Train: G_train_loss: {}'.format(i, G_train_loss / n_sample))
            logger.info('[Epoch {}] Train: FM_train_loss: {}'.format(i, FM_train_loss / n_sample))
            logger.info('[Epoch {}] Train: detector_train_loss: {}'.format(i, detector_train_loss / n_sample))
            logger.info('---------------------------------------------------------------------------')

            D_total_fake_loss.append(D_fake_loss / n_sample)
            D_total_real_loss.append(D_real_loss / n_sample)
            D_total_class_loss.append(D_class_loss / n_sample)
            G_total_train_loss.append(G_train_loss / n_sample)
            FM_total_train_loss.append(FM_train_loss / n_sample)
            detector_total_train_loss.append(detector_train_loss / n_sample)

            if dev_dataset:
                logger.info('#################### eval result at step {} ####################'.format(global_step))
                eval_result = eval(dev_dataset)

                valid_detection_loss.append(eval_result['detection_loss'])
                valid_oos_ind_precision.append(eval_result['oos_ind_precision'])
                valid_oos_ind_recall.append(eval_result['oos_ind_recall'])
                valid_oos_ind_f_score.append(eval_result['oos_ind_f_score'])

                # early_stopping returns 1 to save the model, 0 to do nothing,
                # and -1 when patience is exceeded and training should stop early
                signal = early_stopping(-eval_result['eer'])
                if signal == -1:
                    break
                elif signal == 0:
                    pass
                elif signal == 1:
                    save_gan_model(D, G, config['gan_save_path'])
                    if args.fine_tune:
                        save_model(E, path=config['bert_save_path'], model_name='bert')

                logger.info(eval_result)
                logger.info('valid_eer: {}'.format(eval_result['eer']))
                logger.info('valid_oos_ind_precision: {}'.format(eval_result['oos_ind_precision']))
                logger.info('valid_oos_ind_recall: {}'.format(eval_result['oos_ind_recall']))
                logger.info('valid_oos_ind_f_score: {}'.format(eval_result['oos_ind_f_score']))
                logger.info('valid_auc: {}'.format(eval_result['auc']))
                logger.info('valid_fpr95: {}'.format(
                    ErrorRateAt95Recall(eval_result['all_binary_y'], eval_result['y_score'])))

        if args.patience >= args.n_epoch:
            save_gan_model(D, G, config['gan_save_path'])
            if args.fine_tune:
                save_model(E, path=config['bert_save_path'], model_name='bert')

        freeze_data['D_total_fake_loss'] = D_total_fake_loss
        freeze_data['D_total_real_loss'] = D_total_real_loss
        freeze_data['D_total_class_loss'] = D_total_class_loss
        freeze_data['G_total_train_loss'] = G_total_train_loss
        freeze_data['FM_total_train_loss'] = FM_total_train_loss
        freeze_data['valid_real_loss'] = valid_detection_loss
        freeze_data['valid_oos_ind_precision'] = valid_oos_ind_precision
        freeze_data['valid_oos_ind_recall'] = valid_oos_ind_recall
        freeze_data['valid_oos_ind_f_score'] = valid_oos_ind_f_score

        best_dev = -early_stopping.best_score

        if args.do_vis:
            all_features = torch.cat(all_features, 0).cpu().numpy()
            result['all_features'] = all_features
        return result

    def eval(dataset):
        dev_dataloader = DataLoader(dataset,
                                    batch_size=args.predict_batch_size,
                                    shuffle=False,
                                    num_workers=2)
        n_sample = len(dev_dataloader)
        result = dict()

        # Loss functions
        detection_loss = torch.nn.BCELoss().to(device)
        detection_loss_v2 = torch.nn.CrossEntropyLoss().to(device)
        classified_loss = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)

        # G.eval()
        # D.eval()
        E.eval()
        detector.eval()

        all_detection_preds = []
        all_class_preds = []
        all_logit = []

        for sample in tqdm.tqdm(dev_dataloader):
            sample = (i.to(device) for i in sample)
            token, mask, type_ids, y = sample
            batch = len(token)

            # ------------------------- evaluate D ------------------------- #
            # BERT encodes the sentence into a feature vector
            with torch.no_grad():
                sequence_output, pooled_output = E(token, mask, type_ids)
                real_feature = pooled_output

            # n_class > 2 means a classifier is trained alongside the discriminator
            if n_class > 2:
                # f_vector, discriminator_output, classification_output = D(real_feature, return_feature=True)
                # all_detection_preds.append(discriminator_output)
                # all_class_preds.append(classification_output)
                pass
            # otherwise only run the detector
            else:
                # f_vector, discriminator_output = D.detect_only(real_feature, return_feature=True)
                # all_detection_preds.append(discriminator_output)
                # f_vector = D.get_vector(real_feature)
                if args.loss == 'v1':
                    detector_out = detector(real_feature)
                    all_detection_preds.append(detector_out)
                else:
                    detector_out = detector(real_feature)
                    all_logit.append(detector_out)
                    all_detection_preds.append(torch.argmax(detector_out, 1))

        all_y = LongTensor(dataset.dataset[:, -1].astype(int)).cpu()  # [length, n_class]
        all_binary_y = (all_y != 0).long()  # [length, 1]; label 0 is oos
        all_detection_preds = torch.cat(all_detection_preds, 0).cpu()  # [length, 1]
        if args.loss == 'v1':
            all_detection_binary_preds = convert_to_int_by_threshold(all_detection_preds.squeeze())  # [length, 1]
        else:
            all_detection_binary_preds = all_detection_preds
            all_logit = torch.cat(all_logit, 0).cpu()

        # compute the loss
        if args.loss == 'v1':
            loss = detection_loss(all_detection_preds, all_binary_y.float())
        else:
            loss = detection_loss_v2(all_logit, all_y.long())
        result['detection_loss'] = loss

        if n_class > 2:
            class_one_hot_preds = torch.cat(all_class_preds, 0).detach().cpu()  # one-hot labels
            class_loss = classified_loss(class_one_hot_preds, all_y)  # compute loss
            all_class_preds = torch.argmax(class_one_hot_preds, 1)  # labels
            class_acc = metrics.ind_class_accuracy(all_class_preds, all_y, oos_index=0)  # accuracy for ind classes
            logger.info(metrics.classification_report(all_y, all_class_preds,
                                                      target_names=processor.id_to_label))

        logger.info(metrics.classification_report(all_binary_y, all_detection_binary_preds,
                                                  target_names=['oos', 'in']))

        # report
        oos_ind_precision, oos_ind_recall, oos_ind_fscore, _ = metrics.binary_recall_fscore(
            all_detection_binary_preds, all_binary_y)
        detection_acc = metrics.accuracy(all_detection_binary_preds, all_binary_y)

        y_score = all_detection_preds.squeeze().tolist()
        eer = metrics.cal_eer(all_binary_y, y_score)

        result['eer'] = eer
        result['all_detection_binary_preds'] = all_detection_binary_preds
        result['detection_acc'] = detection_acc
        result['all_binary_y'] = all_binary_y
        result['oos_ind_precision'] = oos_ind_precision
        result['oos_ind_recall'] = oos_ind_recall
        result['oos_ind_f_score'] = oos_ind_fscore
        result['y_score'] = y_score
        result['auc'] = roc_auc_score(all_binary_y, y_score)
        if n_class > 2:
            result['class_loss'] = class_loss
            result['class_acc'] = class_acc

        freeze_data['valid_all_y'] = all_y
        freeze_data['valid_all_pred'] = all_detection_binary_preds
        freeze_data['valid_score'] = y_score

        return result

    def test(dataset):
        # load BERT and GAN
        load_gan_model(D, G, config['gan_save_path'])
        if args.fine_tune:
            load_model(E, path=config['bert_save_path'], model_name='bert')

        test_dataloader = DataLoader(dataset,
                                     batch_size=args.predict_batch_size,
                                     shuffle=False,
                                     num_workers=2)
        n_sample = len(test_dataloader)
        result = dict()

        # Loss functions
        detection_loss = torch.nn.BCELoss().to(device)
        detection_loss_v2 = torch.nn.CrossEntropyLoss().to(device)
        classified_loss = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)

        G.eval()
        D.eval()
        E.eval()
        detector.eval()

        all_detection_preds = []
        all_class_preds = []
        all_features = []
        all_logit = []

        for sample in tqdm.tqdm(test_dataloader):
            sample = (i.to(device) for i in sample)
            token, mask, type_ids, y = sample
            batch = len(token)

            # ------------------------- evaluate D ------------------------- #
            # BERT encodes the sentence into a feature vector
            with torch.no_grad():
                sequence_output, pooled_output = E(token, mask, type_ids)
                real_feature = pooled_output

            # n_class > 2 means a classifier is trained alongside the discriminator
            if n_class > 2:
                # f_vector, discriminator_output, classification_output = D(real_feature, return_feature=True)
                # all_detection_preds.append(discriminator_output)
                # all_class_preds.append(classification_output)
                pass
            else:
                if args.loss == 'v1':
                    detector_out = detector(real_feature)
                    all_detection_preds.append(detector_out)
                else:
                    detector_out = detector(real_feature)
                    all_logit.append(detector_out)
                    all_detection_preds.append(torch.argmax(detector_out, 1))
            # if args.do_vis:
            #     all_features.append(f_vector)

        all_y = LongTensor(dataset.dataset[:, -1].astype(int)).cpu()  # [length, n_class]
        all_binary_y = (all_y != 0).long()  # [length, 1]; label 0 is oos
        all_detection_preds = torch.cat(all_detection_preds, 0).cpu()  # [length, 1]
        if args.loss == 'v1':
            all_detection_binary_preds = convert_to_int_by_threshold(all_detection_preds.squeeze())  # [length, 1]
        else:
            all_detection_binary_preds = all_detection_preds
            all_logit = torch.cat(all_logit, 0).cpu()

        # compute the loss
        if args.loss == 'v1':
            loss = detection_loss(all_detection_preds, all_binary_y.float())
        else:
            loss = detection_loss_v2(all_logit, all_y.long())
        result['detection_loss'] = loss

        if n_class > 2:
            class_one_hot_preds = torch.cat(all_class_preds, 0).detach().cpu()  # one-hot labels
            class_loss = classified_loss(class_one_hot_preds, all_y)  # compute loss
            all_class_preds = torch.argmax(class_one_hot_preds, 1)  # labels
            class_acc = metrics.ind_class_accuracy(all_class_preds, all_y, oos_index=0)  # accuracy for ind classes
            logger.info(metrics.classification_report(all_y, all_class_preds,
                                                      target_names=processor.id_to_label))

        logger.info(metrics.classification_report(all_binary_y, all_detection_binary_preds,
                                                  target_names=['oos', 'in']))

        # report
        oos_ind_precision, oos_ind_recall, oos_ind_fscore, _ = metrics.binary_recall_fscore(
            all_detection_binary_preds, all_binary_y)
        detection_acc = metrics.accuracy(all_detection_binary_preds, all_binary_y)

        y_score = all_detection_preds.squeeze().tolist()
        eer = metrics.cal_eer(all_binary_y, y_score)

        result['eer'] = eer
        result['all_detection_binary_preds'] = all_detection_binary_preds
        result['detection_acc'] = detection_acc
        result['all_binary_y'] = all_binary_y
        result['all_y'] = all_y
        result['oos_ind_precision'] = oos_ind_precision
        result['oos_ind_recall'] = oos_ind_recall
        result['oos_ind_f_score'] = oos_ind_fscore
        result['score'] = y_score
        result['y_score'] = y_score
        result['auc'] = roc_auc_score(all_binary_y, y_score)
        if n_class > 2:
            result['class_loss'] = class_loss
            result['class_acc'] = class_acc
        if args.do_vis:
            all_features = torch.cat(all_features, 0).cpu().numpy()
            result['all_features'] = all_features

        freeze_data['test_all_y'] = all_y.tolist()
        freeze_data['test_all_pred'] = all_detection_binary_preds.tolist()
        freeze_data['test_score'] = y_score

        return result

    def get_fake_feature(num_output):
        """Generate a given number of fake features."""
        G.eval()
        fake_features = []
        start = 0
        batch = args.predict_batch_size
        with torch.no_grad():
            while start < num_output:
                end = min(num_output, start + batch)
                if args.model == 'lstm_gan' or args.model == 'cnn_gan':
                    z = FloatTensor(np.random.normal(0, 1, size=(end - start, 32, args.G_z_dim)))
                else:
                    z = FloatTensor(np.random.normal(0, 1, size=(end - start, args.G_z_dim)))
                fake_feature = G(z)
                f_vector, _ = D.detect_only(fake_feature, return_feature=True)
                fake_features.append(f_vector)
                start += batch
            return torch.cat(fake_features, 0).cpu().numpy()

    if args.do_train:
        if config['data_file'].startswith('binary'):
            text_train_set = processor.read_dataset(data_path, ['train'])
            text_dev_set = processor.read_dataset(data_path, ['val'])
        elif config['dataset'] == 'oos-eval':
            text_train_set = processor.read_dataset(data_path, ['train', 'oos_train'])
            text_dev_set = processor.read_dataset(data_path, ['val', 'oos_val'])
        elif config['dataset'] == 'smp':
            text_train_set = processor.read_dataset(data_path, ['train'])
            text_dev_set = processor.read_dataset(data_path, ['val'])

        train_features = processor.convert_to_ids(text_train_set)
        train_dataset = OOSDataset(train_features)
        dev_features = processor.convert_to_ids(text_dev_set)
        dev_dataset = OOSDataset(dev_features)

        train_result = train(train_dataset, dev_dataset)
        # save_feature(train_result['all_features'], os.path.join(args.output_dir, 'train_feature'))

    if args.do_eval:
        logger.info('#################### eval result at step {} ####################'.format(global_step))
        if config['data_file'].startswith('binary'):
            text_dev_set = processor.read_dataset(data_path, ['val'])
        elif config['dataset'] == 'oos-eval':
            text_dev_set = processor.read_dataset(data_path, ['val', 'oos_val'])
        elif config['dataset'] == 'smp':
            text_dev_set = processor.read_dataset(data_path, ['val'])

        dev_features = processor.convert_to_ids(text_dev_set)
        dev_dataset = OOSDataset(dev_features)

        eval_result = eval(dev_dataset)
        logger.info(eval_result)
        logger.info('eval_eer: {}'.format(eval_result['eer']))
        logger.info('eval_oos_ind_precision: {}'.format(eval_result['oos_ind_precision']))
        logger.info('eval_oos_ind_recall: {}'.format(eval_result['oos_ind_recall']))
        logger.info('eval_oos_ind_f_score: {}'.format(eval_result['oos_ind_f_score']))
        logger.info('eval_auc: {}'.format(eval_result['auc']))
        logger.info('eval_fpr95: {}'.format(
            ErrorRateAt95Recall(eval_result['all_binary_y'], eval_result['y_score'])))

    if args.do_test:
        logger.info('#################### test result at step {} ####################'.format(global_step))
        if config['data_file'].startswith('binary'):
            text_test_set = processor.read_dataset(data_path, ['test'])
        elif config['dataset'] == 'oos-eval':
            text_test_set = processor.read_dataset(data_path, ['test', 'oos_test'])
        elif config['dataset'] == 'smp':
            text_test_set = processor.read_dataset(data_path, ['test'])

        test_features = processor.convert_to_ids(text_test_set)
        test_dataset = OOSDataset(test_features)

        test_result = test(test_dataset)
        logger.info(test_result)
        logger.info('test_eer: {}'.format(test_result['eer']))
        logger.info('test_ood_ind_precision: {}'.format(test_result['oos_ind_precision']))
        logger.info('test_ood_ind_recall: {}'.format(test_result['oos_ind_recall']))
        logger.info('test_ood_ind_f_score: {}'.format(test_result['oos_ind_f_score']))
        logger.info('test_auc: {}'.format(test_result['auc']))
        logger.info('test_fpr95: {}'.format(
            ErrorRateAt95Recall(test_result['all_binary_y'], test_result['y_score'])))
        my_plot_roc(test_result['all_binary_y'], test_result['y_score'],
                    os.path.join(args.output_dir, 'roc_curve.png'))
        save_result(test_result, os.path.join(args.output_dir, 'test_result'))
        # save_feature(test_result['all_features'], os.path.join(args.output_dir, 'test_feature'))

        # output misclassified cases
        if config['dataset'] == 'oos-eval':
            texts = [line[0] for line in text_test_set]
        elif config['dataset'] == 'smp':
            texts = [line['text'] for line in text_test_set]
        else:
            raise ValueError('The dataset {} is not supported.'.format(args.dataset))

        output_cases(texts, test_result['all_binary_y'],
                     test_result['all_detection_binary_preds'],
                     os.path.join(args.output_dir, 'test_cases.csv'), processor)

        # confusion matrix
        plot_confusion_matrix(test_result['all_binary_y'],
                              test_result['all_detection_binary_preds'], args.output_dir)

        if args.do_vis:
            # [2 * length, feature_dim]
            features = np.concatenate(
                [test_result['all_features'],
                 get_fake_feature(len(test_dataset) // 2)], axis=0)
            features = TSNE(n_components=2, verbose=1, n_jobs=-1).fit_transform(features)  # [2 * length, 2]
            # [2 * length, 1]
            if n_class > 2:
                labels = np.concatenate(
                    [test_result['all_y'],
                     np.array([-1] * (len(test_dataset) // 2))], 0).reshape((-1, 1))
            else:
                labels = np.concatenate(
                    [test_result['all_binary_y'],
                     np.array([-1] * (len(test_dataset) // 2))], 0).reshape((-1, 1))
            # [2 * length, 3]
            data = np.concatenate([features, labels], 1)
            fig = scatter_plot(data, processor)
            fig.savefig(os.path.join(args.output_dir, 'plot.png'))
            fig.show()
            freeze_data['feature_label'] = data
            # plot_train_test(train_result['all_features'], test_result['all_features'], args.output_dir)

    with open(os.path.join(config['output_dir'], 'freeze_data.pkl'), 'wb') as f:
        pickle.dump(freeze_data, f)
    df = pd.DataFrame(data={
        'valid_y': freeze_data['valid_all_y'],
        'valid_score': freeze_data['valid_score'],
    })
    df.to_csv(os.path.join(config['output_dir'], 'valid_score.csv'))
    df = pd.DataFrame(data={
        'test_y': freeze_data['test_all_y'],
        'test_score': freeze_data['test_score'],
    })
    df.to_csv(os.path.join(config['output_dir'], 'test_score.csv'))
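
# ---------------------------------------------------------------------------
# A minimal sketch of the EarlyStopping helper used by the training loops in
# these scripts; its body is not shown, so the implementation below is an
# assumption reconstructed from the call sites. The assumed contract: it is
# called once per epoch with a score where higher is better, and returns
# 1 (new best score, save the model), 0 (no improvement yet), or
# -1 (patience exhausted, stop training).
class EarlyStoppingSketch:
    def __init__(self, patience, logger=None):
        self.patience = patience
        self.logger = logger
        self.counter = 0
        self.best_score = None

    def __call__(self, score):
        if self.best_score is None or score > self.best_score:
            self.best_score = score
            self.counter = 0
            return 1  # improved: the caller saves a checkpoint
        self.counter += 1
        if self.counter >= self.patience:
            return -1  # patience exhausted: the caller breaks out of training
        return 0  # no improvement yet: keep training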
def main(args):
    logger.info('Checking...')
    SEED = args.seed
    check_manual_seed(SEED)
    check_args(args)
    logger.info('seed: {}'.format(args.seed))
    gross_result['seed'] = args.seed

    logger.info('Loading config...')
    bert_config = BertConfig('config/bert.ini')
    bert_config = bert_config(args.bert_type)

    # for oos-eval dataset
    data_config = Config('config/data.ini')
    data_config = data_config(args.dataset)

    # Prepare data processor
    data_path = os.path.join(data_config['DataDir'],
                             data_config[args.data_file])  # join the data directory and file name into one path
    label_path = data_path.replace('.json', '.label')

    if args.dataset == 'oos-eval':
        processor = OOSProcessor(bert_config, maxlen=32)
    elif args.dataset == 'smp':
        processor = SMPProcessor(bert_config, maxlen=32)
    else:
        raise ValueError('The dataset {} is not supported.'.format(args.dataset))

    processor.load_label(label_path)  # Adds label_to_id and id_to_label to the processor.
    n_class = len(processor.id_to_label)

    config = vars(args)  # returns the arguments as a dict
    config['model_save_path'] = os.path.join(args.output_dir, 'save', 'bert.pt')
    config['n_class'] = n_class

    logger.info('config:')
    logger.info(config)

    model = TextCNN(bert_config, n_class)  # BERT encoder
    if args.fine_tune:
        model.unfreeze_bert_encoder()
    else:
        model.freeze_bert_encoder()
    model.to(device)

    global_step = 0

    def train(train_dataset, dev_dataset):
        train_dataloader = DataLoader(
            train_dataset,
            batch_size=args.train_batch_size // args.gradient_accumulation_steps,
            shuffle=True,
            num_workers=2)

        nonlocal global_step
        n_sample = len(train_dataloader)
        early_stopping = EarlyStopping(args.patience, logger=logger)

        # Loss function
        classified_loss = torch.nn.CrossEntropyLoss().to(device)

        # Optimizer
        optimizer = AdamW(model.parameters(), args.lr)

        train_loss = []
        if dev_dataset:
            valid_loss = []
            valid_ind_class_acc = []
        iteration = 0

        for i in range(args.n_epoch):
            model.train()
            total_loss = 0
            for sample in tqdm.tqdm(train_dataloader):
                sample = (i.to(device) for i in sample)
                token, mask, type_ids, y = sample
                batch = len(token)

                logits = model(token, mask, type_ids)
                loss = classified_loss(logits, y.long())
                total_loss += loss.item()
                loss = loss / args.gradient_accumulation_steps
                loss.backward()

                # backprop and update parameters
                if (global_step + 1) % args.gradient_accumulation_steps == 0:
                    optimizer.step()
                    optimizer.zero_grad()
                global_step += 1

            logger.info('[Epoch {}] Train: train_loss: {}'.format(i, total_loss / n_sample))
            logger.info('-' * 30)
            train_loss.append(total_loss / n_sample)
            iteration += 1

            if dev_dataset:
                logger.info('#################### eval result at step {} ####################'.format(global_step))
                eval_result = eval(dev_dataset)
                valid_loss.append(eval_result['loss'])
                valid_ind_class_acc.append(eval_result['ind_class_acc'])

                # early_stopping returns 1 to save the model, 0 to do nothing,
                # and -1 when patience is exceeded and training should stop early
                signal = early_stopping(eval_result['accuracy'])
                if signal == -1:
                    break
                elif signal == 0:
                    pass
                elif signal == 1:
                    save_model(model, path=config['model_save_path'], model_name='bert')
                # logger.info(eval_result)

        from utils.visualization import draw_curve
        draw_curve(train_loss, iteration, 'train_loss', args.output_dir)
        if dev_dataset:
            draw_curve(valid_loss, iteration, 'valid_loss', args.output_dir)
            draw_curve(valid_ind_class_acc, iteration, 'valid_ind_class_accuracy', args.output_dir)

        if args.patience >= args.n_epoch:
            save_model(model, path=config['model_save_path'], model_name='bert')

        freeze_data['train_loss'] = train_loss
        freeze_data['valid_loss'] = valid_loss

    def eval(dataset):
        dev_dataloader = DataLoader(dataset,
                                    batch_size=args.predict_batch_size,
                                    shuffle=False,
                                    num_workers=2)
        n_sample = len(dev_dataloader)
        result = dict()
        model.eval()

        # Loss function
        classified_loss = torch.nn.CrossEntropyLoss().to(device)

        all_pred = []
        all_logit = []
        total_loss = 0

        for sample in tqdm.tqdm(dev_dataloader):
            sample = (i.to(device) for i in sample)
            token, mask, type_ids, y = sample
            batch = len(token)
            with torch.no_grad():
                logit = model(token, mask, type_ids)
                all_logit.append(logit)
                all_pred.append(torch.argmax(logit, 1))
                total_loss += classified_loss(logit, y.long())

        all_y = LongTensor(dataset.dataset[:, -1].astype(int)).cpu()  # [length, n_class]
        all_binary_y = (all_y != 0).long()  # [length, 1]; label 0 is oos
        all_pred = torch.cat(all_pred, 0).cpu()
        all_logit = torch.cat(all_logit, 0).cpu()

        ind_class_acc = metrics.ind_class_accuracy(all_pred, all_y)
        report = metrics.classification_report(all_y, all_pred, output_dict=True)
        result.update(report)

        y_score = all_logit.softmax(1)[:, 1].tolist()
        eer = metrics.cal_eer(all_binary_y, y_score)
        oos_ind_precision, oos_ind_recall, oos_ind_fscore, _ = metrics.binary_recall_fscore(
            all_pred, all_binary_y)

        result['eer'] = eer
        result['ind_class_acc'] = ind_class_acc
        result['loss'] = total_loss / n_sample
        result['oos_ind_precision'] = oos_ind_precision
        result['oos_ind_recall'] = oos_ind_recall
        result['oos_ind_f_score'] = oos_ind_fscore
        result['auc'] = roc_auc_score(all_binary_y, y_score)
        result['y_score'] = y_score
        result['all_binary_y'] = all_binary_y

        freeze_data['valid_all_y'] = all_y
        freeze_data['valid_all_pred'] = all_pred
        freeze_data['valid_score'] = y_score

        return result

    def test(dataset):
        load_model(model, path=config['model_save_path'], model_name='bert')
        test_dataloader = DataLoader(dataset,
                                     batch_size=args.predict_batch_size,
                                     shuffle=False,
                                     num_workers=2)
        n_sample = len(test_dataloader)
        result = dict()
        model.eval()

        # Loss function
        classified_loss = torch.nn.CrossEntropyLoss().to(device)

        all_pred = []
        total_loss = 0
        all_logit = []

        for sample in tqdm.tqdm(test_dataloader):
            sample = (i.to(device) for i in sample)
            token, mask, type_ids, y = sample
            batch = len(token)
            with torch.no_grad():
                logit = model(token, mask, type_ids)
                all_logit.append(logit)
                all_pred.append(torch.argmax(logit, 1))
                total_loss += classified_loss(logit, y.long())

        all_y = LongTensor(dataset.dataset[:, -1].astype(int)).cpu()  # [length, n_class]
        all_binary_y = (all_y != 0).long()  # [length, 1]; label 0 is oos
        all_pred = torch.cat(all_pred, 0).cpu()
        all_logit = torch.cat(all_logit, 0).cpu()

        # classification report
        ind_class_acc = metrics.ind_class_accuracy(all_pred, all_y)
        report = metrics.classification_report(all_y, all_pred, output_dict=True)
        oos_ind_precision, oos_ind_recall, oos_ind_fscore, _ = metrics.binary_recall_fscore(
            all_pred, all_binary_y)
        result.update(report)

        # EER is only meaningful for binary classification
        y_score = all_logit.softmax(1)[:, 1].tolist()
        eer = metrics.cal_eer(all_binary_y, y_score)

        result['eer'] = eer
        result['ind_class_acc'] = ind_class_acc
        result['loss'] = total_loss / n_sample
        result['all_y'] = all_y.tolist()
        result['all_pred'] = all_pred.tolist()
        result['all_binary_y'] = all_binary_y

        freeze_data['test_all_y'] = all_y.tolist()
        freeze_data['test_all_pred'] = all_pred.tolist()
        freeze_data['test_score'] = y_score

        result['oos_ind_precision'] = oos_ind_precision
        result['oos_ind_recall'] = oos_ind_recall
        result['oos_ind_f_score'] = oos_ind_fscore
        result['auc'] = roc_auc_score(all_binary_y, y_score)
        result['y_score'] = y_score

        return result

    if args.do_train:
        if config['data_file'].startswith('binary'):
            text_train_set = processor.read_dataset(data_path, ['train'])
            text_dev_set = processor.read_dataset(data_path, ['val'])
        elif config['dataset'] == 'oos-eval':
            text_train_set = processor.read_dataset(data_path, ['train', 'oos_train'])
            text_dev_set = processor.read_dataset(data_path, ['val', 'oos_val'])
        elif config['dataset'] == 'smp':
            text_train_set = processor.read_dataset(data_path, ['train'])
            text_dev_set = processor.read_dataset(data_path, ['val'])

        train_features = processor.convert_to_ids(text_train_set)
        train_dataset = OOSDataset(train_features)
        dev_features = processor.convert_to_ids(text_dev_set)
        dev_dataset = OOSDataset(dev_features)

        train(train_dataset, dev_dataset)

    if args.do_eval:
        logger.info('#################### eval result at step {} ####################'.format(global_step))
        if config['data_file'].startswith('binary'):
            text_dev_set = processor.read_dataset(data_path, ['val'])
        elif config['dataset'] == 'oos-eval':
            text_dev_set = processor.read_dataset(data_path, ['val', 'oos_val'])
        elif config['dataset'] == 'smp':
            text_dev_set = processor.read_dataset(data_path, ['val'])

        dev_features = processor.convert_to_ids(text_dev_set)
        dev_dataset = OOSDataset(dev_features)

        eval_result = eval(dev_dataset)
        # logger.info(eval_result)
        logger.info('eval_eer: {}'.format(eval_result['eer']))
        logger.info('eval_oos_ind_precision: {}'.format(eval_result['oos_ind_precision']))
        logger.info('eval_oos_ind_recall: {}'.format(eval_result['oos_ind_recall']))
        logger.info('eval_oos_ind_f_score: {}'.format(eval_result['oos_ind_f_score']))
        logger.info('eval_auc: {}'.format(eval_result['auc']))
        logger.info('eval_fpr95: {}'.format(
            ErrorRateAt95Recall(eval_result['all_binary_y'], eval_result['y_score'])))
        gross_result['eval_eer'] = eval_result['eer']
        gross_result['eval_auc'] = eval_result['auc']
        gross_result['eval_fpr95'] = ErrorRateAt95Recall(eval_result['all_binary_y'],
                                                         eval_result['y_score'])
        gross_result['eval_oos_ind_precision'] = eval_result['oos_ind_precision']
        gross_result['eval_oos_ind_recall'] = eval_result['oos_ind_recall']
        gross_result['eval_oos_ind_f_score'] = eval_result['oos_ind_f_score']

    if args.do_test:
        logger.info('#################### test result at step {} ####################'.format(global_step))
        if config['data_file'].startswith('binary'):
            text_test_set = processor.read_dataset(data_path, ['test'])
        elif config['dataset'] == 'oos-eval':
            text_test_set = processor.read_dataset(data_path, ['test', 'oos_test'])
        elif config['dataset'] == 'smp':
            text_test_set = processor.read_dataset(data_path, ['test'])

        test_features = processor.convert_to_ids(text_test_set)
        test_dataset = OOSDataset(test_features)

        test_result = test(test_dataset)
        save_result(test_result, os.path.join(args.output_dir, 'test_result'))
        # logger.info(test_result)
        logger.info('test_eer: {}'.format(test_result['eer']))
        logger.info('test_ood_ind_precision: {}'.format(test_result['oos_ind_precision']))
        logger.info('test_ood_ind_recall: {}'.format(test_result['oos_ind_recall']))
        logger.info('test_ood_ind_f_score: {}'.format(test_result['oos_ind_f_score']))
        logger.info('test_auc: {}'.format(test_result['auc']))
        logger.info('test_fpr95: {}'.format(
            ErrorRateAt95Recall(test_result['all_binary_y'], test_result['y_score'])))
        my_plot_roc(test_result['all_binary_y'], test_result['y_score'],
                    os.path.join(args.output_dir, 'roc_curve.png'))
        save_result(test_result, os.path.join(args.output_dir, 'test_result'))
        gross_result['test_eer'] = test_result['eer']
        gross_result['test_auc'] = test_result['auc']
        gross_result['test_fpr95'] = ErrorRateAt95Recall(test_result['all_binary_y'],
                                                         test_result['y_score'])
        gross_result['test_oos_ind_precision'] = test_result['oos_ind_precision']
        gross_result['test_oos_ind_recall'] = test_result['oos_ind_recall']
        gross_result['test_oos_ind_f_score'] = test_result['oos_ind_f_score']

        # output misclassified cases
        if config['dataset'] == 'oos-eval':
            texts = [line[0] for line in text_test_set]
        elif config['dataset'] == 'smp':
            texts = [line['text'] for line in text_test_set]
        else:
            raise ValueError('The dataset {} is not supported.'.format(args.dataset))

        output_cases(texts, test_result['all_y'], test_result['all_pred'],
                     os.path.join(args.output_dir, 'test_cases.csv'), processor)

        # confusion matrix
        plot_confusion_matrix(test_result['all_y'], test_result['all_pred'], args.output_dir)

    with open(os.path.join(config['output_dir'], 'freeze_data.pkl'), 'wb') as f:
        pickle.dump(freeze_data, f)
    df = pd.DataFrame(data={
        'valid_y': freeze_data['valid_all_y'],
        'valid_score': freeze_data['valid_score'],
    })
    df.to_csv(os.path.join(config['output_dir'], 'valid_score.csv'))
    df = pd.DataFrame(data={
        'test_y': freeze_data['test_all_y'],
        'test_score': freeze_data['test_score'],
    })
    df.to_csv(os.path.join(config['output_dir'], 'test_score.csv'))

    if args.result != 'no':
        pd_result = pd.DataFrame(gross_result)
        if args.seed == 16:
            pd_result.to_csv(args.result + '_gross_result.csv', index=False)
        else:
            pd_result.to_csv(args.result + '_gross_result.csv',
                             index=False, mode='a', header=False)
        if args.seed == 8192:
            print(args.result)
            std_mean(args.result + '_gross_result.csv')
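
# ---------------------------------------------------------------------------
# A hedged sketch of the std_mean aggregation invoked above once the last
# seed (8192) finishes; std_mean itself is not shown, so the body below is an
# assumption. The idea: the gross-result CSV accumulates one row per seed,
# and the summary appends the column-wise mean and standard deviation.
def std_mean_sketch(csv_path):
    import pandas as pd
    df = pd.read_csv(csv_path)
    numeric = df.select_dtypes('number')
    summary = pd.DataFrame([numeric.mean(), numeric.std()], index=['mean', 'std'])
    summary.to_csv(csv_path.replace('.csv', '_summary.csv'))
    return summary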
def main(config, needs_save):
    os.environ['CUDA_VISIBLE_DEVICES'] = config.training.visible_devices

    seed = check_manual_seed(config.training.seed)
    print('Using manual seed: {}'.format(seed))

    if config.dataset.patient_ids == 'TRAIN_PATIENT_IDS':
        patient_ids = TRAIN_PATIENT_IDS
    elif config.dataset.patient_ids == 'TEST_PATIENT_IDS':
        patient_ids = TEST_PATIENT_IDS
    else:
        raise NotImplementedError

    data_loader = get_data_loader(mode=config.dataset.mode,
                                  dataset_name=config.dataset.name,
                                  patient_ids=patient_ids,
                                  root_dir_path=config.dataset.root_dir_path,
                                  use_augmentation=config.dataset.use_augmentation,
                                  batch_size=config.dataset.batch_size,
                                  num_workers=config.dataset.num_workers,
                                  image_size=config.dataset.image_size)

    E = Encoder(input_dim=config.model.input_dim,
                z_dim=config.model.z_dim,
                filters=config.model.enc_filters,
                activation=config.model.enc_activation).float()
    D = Decoder(input_dim=config.model.input_dim,
                z_dim=config.model.z_dim,
                filters=config.model.dec_filters,
                activation=config.model.dec_activation,
                final_activation=config.model.dec_final_activation).float()

    if config.model.enc_spectral_norm:
        apply_spectral_norm(E)
    if config.model.dec_spectral_norm:
        apply_spectral_norm(D)

    if config.training.use_cuda:
        E.cuda()
        D.cuda()
        E = nn.DataParallel(E)
        D = nn.DataParallel(D)

    if config.model.saved_E:
        print(config.model.saved_E)
        E.load_state_dict(torch.load(config.model.saved_E))
    if config.model.saved_D:
        print(config.model.saved_D)
        D.load_state_dict(torch.load(config.model.saved_D))

    print(E)
    print(D)

    e_optim = optim.Adam(filter(lambda p: p.requires_grad, E.parameters()),
                         config.optimizer.enc_lr, [0.9, 0.9999])
    d_optim = optim.Adam(filter(lambda p: p.requires_grad, D.parameters()),
                         config.optimizer.dec_lr, [0.9, 0.9999])

    alpha = config.training.alpha
    beta = config.training.beta
    margin = config.training.margin
    batch_size = config.dataset.batch_size

    fixed_z = torch.randn(calc_latent_dim(config))

    if 'ssim' in config.training.loss:
        ssim_loss = pytorch_ssim.SSIM(window_size=11)

    def l_recon(recon: torch.Tensor, target: torch.Tensor):
        if config.training.loss == 'l2':
            loss = F.mse_loss(recon, target, reduction='sum')
        elif config.training.loss == 'l1':
            loss = F.l1_loss(recon, target, reduction='sum')
        elif config.training.loss == 'ssim':
            loss = (1.0 - ssim_loss(recon, target)) * torch.numel(recon)
        elif config.training.loss == 'ssim+l1':
            loss = (1.0 - ssim_loss(recon, target)) * torch.numel(recon) \
                + F.l1_loss(recon, target, reduction='sum')
        elif config.training.loss == 'ssim+l2':
            loss = (1.0 - ssim_loss(recon, target)) * torch.numel(recon) \
                + F.mse_loss(recon, target, reduction='sum')
        else:
            raise NotImplementedError
        return beta * loss / batch_size

    def l_reg(mu: torch.Tensor, log_var: torch.Tensor):
        loss = -0.5 * torch.sum(1 + log_var - mu**2 - torch.exp(log_var))
        return loss / batch_size

    def update(engine, batch):
        E.train()
        D.train()

        image = norm(batch['image'])
        if config.training.use_cuda:
            image = image.cuda(non_blocking=True).float()
        else:
            image = image.float()

        e_optim.zero_grad()
        d_optim.zero_grad()

        z, z_mu, z_logvar = E(image)
        x_r = D(z)

        l_vae_reg = l_reg(z_mu, z_logvar)
        l_vae_recon = l_recon(x_r, image)
        l_vae_total = l_vae_reg + l_vae_recon
        l_vae_total.backward()
        e_optim.step()
        d_optim.step()

        if config.training.use_cuda:
            torch.cuda.synchronize()

        return {
            'TotalLoss': l_vae_total.item(),
            'EncodeLoss': l_vae_reg.item(),
            'ReconLoss': l_vae_recon.item(),
        }

    output_dir = get_output_dir_path(config)
    trainer = Engine(update)
    timer = Timer(average=True)

    monitoring_metrics = ['TotalLoss', 'EncodeLoss', 'ReconLoss']
    for metric in monitoring_metrics:
        RunningAverage(alpha=0.98,
                       output_transform=partial(lambda x, metric: x[metric],
                                                metric=metric)).attach(trainer, metric)

    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)

    @trainer.on(Events.STARTED)
    def save_config(engine):
        config_to_save = defaultdict(dict)
        for key, child in config._asdict().items():
            for k, v in child._asdict().items():
                config_to_save[key][k] = v
        config_to_save['seed'] = seed
        config_to_save['output_dir'] = output_dir
        print('Training starts by the following configuration: ', config_to_save)
        if needs_save:
            save_path = os.path.join(output_dir, 'config.json')
            with open(save_path, 'w') as f:
                json.dump(config_to_save, f)

    @trainer.on(Events.ITERATION_COMPLETED)
    def show_logs(engine):
        if (engine.state.iteration - 1) % config.save.log_iter_interval == 0:
            columns = ['epoch', 'iteration'] + list(engine.state.metrics.keys())
            values = [str(engine.state.epoch), str(engine.state.iteration)] \
                + [str(value) for value in engine.state.metrics.values()]
            message = '[{epoch}/{max_epoch}][{i}/{max_i}]'.format(
                epoch=engine.state.epoch,
                max_epoch=config.training.n_epochs,
                i=engine.state.iteration,
                max_i=len(data_loader))
            for name, value in zip(columns, values):
                message += ' | {name}: {value}'.format(name=name, value=value)
            pbar.log_message(message)

    @trainer.on(Events.EPOCH_COMPLETED)
    def save_logs(engine):
        if needs_save:
            fname = os.path.join(output_dir, 'logs.tsv')
            columns = ['epoch', 'iteration'] + list(engine.state.metrics.keys())
            values = [str(engine.state.epoch), str(engine.state.iteration)] \
                + [str(value) for value in engine.state.metrics.values()]
            with open(fname, 'a') as f:
                if f.tell() == 0:
                    print('\t'.join(columns), file=f)
                print('\t'.join(values), file=f)

    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        pbar.log_message('Epoch {} done. Time per batch: {:.3f}[s]'.format(
            engine.state.epoch, timer.value()))
        timer.reset()

    @trainer.on(Events.EPOCH_COMPLETED)
    def save_images(engine):
        if needs_save:
            if engine.state.epoch % config.save.save_epoch_interval == 0:
                image = norm(engine.state.batch['image'])
                with torch.no_grad():
                    z, _, _ = E(image)
                    x_r = D(z)
                    x_p = D(fixed_z)

                image = denorm(image).detach().cpu()
                x_r = denorm(x_r).detach().cpu()
                x_p = denorm(x_p).detach().cpu()

                image = image[:config.save.n_save_images, ...]
                x_r = x_r[:config.save.n_save_images, ...]
                x_p = x_p[:config.save.n_save_images, ...]
                save_path = os.path.join(output_dir,
                                         'result_{}.png'.format(engine.state.epoch))
                save_image(torch.cat([image, x_r, x_p]).data, save_path)

    if needs_save:
        checkpoint_handler = ModelCheckpoint(
            output_dir,
            config.save.study_name,
            save_interval=config.save.save_epoch_interval,
            n_saved=config.save.n_saved,
            create_dir=True,
        )
        trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                                  handler=checkpoint_handler,
                                  to_save={'E': E, 'D': D})

    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    print('Training starts: [max_epochs] {}, [max_iterations] {}'.format(
        config.training.n_epochs, config.training.n_epochs * len(data_loader)))

    trainer.run(data_loader, config.training.n_epochs)
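
# ---------------------------------------------------------------------------
# The update/save_images handlers above rely on norm/denorm helpers that are
# not shown. A minimal sketch under the assumption that inputs arrive in
# [0, 1] and the decoder's final activation produces values in [-1, 1]; the
# project's actual helpers may normalize differently.
def norm_sketch(x):
    return x * 2.0 - 1.0  # [0, 1] -> [-1, 1]

def denorm_sketch(x):
    return (x + 1.0) / 2.0  # [-1, 1] -> [0, 1] for saving images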
def main(config, needs_save, i):
    if config.run.visible_devices:
        os.environ['CUDA_VISIBLE_DEVICES'] = config.run.visible_devices

    assert config.train_dataset.root_dir_path == config.val_dataset.root_dir_path
    # train_patient_ids, val_patient_ids = divide_patients(config.train_dataset.root_dir_path)
    train_patient_ids, val_patient_ids = get_cv_splits(
        config.train_dataset.root_dir_path, i)

    seed = check_manual_seed()
    print('Using seed: {}'.format(seed))

    class_name_to_index = config.label_to_id._asdict()
    index_to_class_name = {v: k for k, v in class_name_to_index.items()}

    train_data_loader = get_data_loader(
        mode='train',
        dataset_name=config.train_dataset.dataset_name,
        root_dir_path=config.train_dataset.root_dir_path,
        patient_ids=train_patient_ids,
        batch_size=config.train_dataset.batch_size,
        num_workers=config.train_dataset.num_workers,
        volume_size=config.train_dataset.volume_size,
    )
    val_data_loader = get_data_loader(
        mode='val',
        dataset_name=config.val_dataset.dataset_name,
        root_dir_path=config.val_dataset.root_dir_path,
        patient_ids=val_patient_ids,
        batch_size=config.val_dataset.batch_size,
        num_workers=config.val_dataset.num_workers,
        volume_size=config.val_dataset.volume_size,
    )

    model = ResUNet(
        input_dim=config.model.input_dim,
        output_dim=config.model.output_dim,
        filters=config.model.filters,
    )
    print(model)

    if config.run.use_cuda:
        model.cuda()
        model = nn.DataParallel(model)

    if config.model.saved_model:
        print('Loading saved model: {}'.format(config.model.saved_model))
        model.load_state_dict(torch.load(config.model.saved_model))
    else:
        print('Initializing weights.')
        init_weights(model, init_type=config.model.init_type)

    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                           lr=config.optimizer.lr,
                           betas=config.optimizer.betas,
                           weight_decay=config.optimizer.weight_decay)

    dice_loss = SoftDiceLoss()
    focal_loss = FocalLoss(
        gamma=config.focal_loss.gamma,
        alpha=config.focal_loss.alpha,
    )
    active_contour_loss = ActiveContourLoss(
        weight=config.active_contour_loss.weight,
    )
    dice_coeff = DiceCoefficient(
        n_classes=config.metric.n_classes,
        index_to_class_name=index_to_class_name,
    )
    one_hot_encoder = OneHotEncoder(
        n_classes=config.metric.n_classes,
    ).forward

    def train(engine, batch):
        adjust_learning_rate(optimizer,
                             engine.state.epoch,
                             initial_lr=config.optimizer.lr,
                             n_epochs=config.run.n_epochs,
                             gamma=config.optimizer.gamma)
        model.train()

        image = batch['image']
        label = batch['label']
        if config.run.use_cuda:
            image = image.cuda(non_blocking=True).float()
            label = label.cuda(non_blocking=True).long()
        else:
            image = image.float()
            label = label.long()

        optimizer.zero_grad()
        output = model(image)
        target = one_hot_encoder(label)[:, 1:, ...]

        l_dice = dice_loss(output, target)
        l_focal = focal_loss(output, target)
        l_active_contour = active_contour_loss(output, target)
        l_total = l_dice + l_focal + l_active_contour
        l_total.backward()
        optimizer.step()

        m_dice = dice_coeff.update(output.detach(), label)
        measures = {
            'SoftDiceLoss': l_dice.item(),
            'FocalLoss': l_focal.item(),
            'ActiveContourLoss': l_active_contour.item(),
        }
        measures.update(m_dice)

        if config.run.use_cuda:
            torch.cuda.synchronize()

        return measures

    def evaluate(engine, batch):
        model.eval()

        image = batch['image']
        label = batch['label']
        if config.run.use_cuda:
            image = image.cuda(non_blocking=True).float()
            label = label.cuda(non_blocking=True).long()
        else:
            image = image.float()
            label = label.long()

        with torch.no_grad():
            output = model(image)
            target = one_hot_encoder(label)[:, 1:, ...]
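            # NOTE: channel 0 (the background class) is sliced off above, so
            # the Dice/focal/active-contour losses below only score the
            # foreground channels.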
            l_dice = dice_loss(output, target)
            l_focal = focal_loss(output, target)
            l_active_contour = active_contour_loss(output, target)

            m_dice = dice_coeff.update(output.detach(), label)
            measures = {
                'SoftDiceLoss': l_dice.item(),
                'FocalLoss': l_focal.item(),
                'ActiveContourLoss': l_active_contour.item(),
            }
            measures.update(m_dice)

        if config.run.use_cuda:
            torch.cuda.synchronize()

        return measures

    output_dir_path = get_output_dir_path(config, i)
    trainer = Engine(train)
    evaluator = Engine(evaluate)
    timer = Timer(average=True)

    if needs_save:
        checkpoint_handler = ModelCheckpoint(
            output_dir_path,
            config.save.study_name,
            save_interval=config.save.save_epoch_interval,
            n_saved=config.run.n_epochs + 1,
            create_dir=True,
        )

    monitoring_metrics = ['SoftDiceLoss', 'FocalLoss', 'ActiveContourLoss']
    monitoring_metrics += class_name_to_index.keys()

    for metric in monitoring_metrics:
        RunningAverage(alpha=0.98,
                       output_transform=partial(lambda x, metric: x[metric],
                                                metric=metric)).attach(trainer, metric)
    for metric in monitoring_metrics:
        RunningAverage(alpha=0.98,
                       output_transform=partial(lambda x, metric: x[metric],
                                                metric=metric)).attach(evaluator, metric)

    pbar = ProgressBar(persist=True)
    pbar.attach(trainer, metric_names=monitoring_metrics)
    pbar.attach(evaluator, metric_names=monitoring_metrics)

    @trainer.on(Events.STARTED)
    def call_save_config(engine):
        if needs_save:
            return save_config(engine, config, seed, output_dir_path)

    @trainer.on(Events.EPOCH_COMPLETED)
    def call_save_logs(engine):
        if needs_save:
            return save_logs('train', engine, config, output_dir_path)

    @trainer.on(Events.EPOCH_COMPLETED)
    def call_print_times(engine):
        return print_times(engine, config, pbar, timer)

    @trainer.on(Events.EPOCH_COMPLETED)
    def run_validation(engine):
        evaluator.run(val_data_loader, 1)
        if needs_save:
            save_logs('val', evaluator, config, output_dir_path)
            save_images(evaluator, trainer.state.epoch)

    def save_images(evaluator, epoch):
        batch = evaluator.state.batch
        image = batch['image']
        label = batch['label']
        if config.run.use_cuda:
            image = image.cuda(non_blocking=True).float()
            label = label.cuda(non_blocking=True).long()
        else:
            image = image.float()
            label = label.long()

        with torch.no_grad():
            pred = model(image)

        output = torch.ones_like(label)
        mask_0 = pred[:, 0, ...] < 0.5
        mask_1 = pred[:, 1, ...] < 0.5
        mask_2 = pred[:, 2, ...] < 0.5
        mask = mask_0 * mask_1 * mask_2
        pred = pred.argmax(1)
        output += pred
        output[mask] = 0

        image = image.detach().cpu().float()
        label = label.detach().cpu().unsqueeze(1).float()
        output = output.detach().cpu().unsqueeze(1).float()

        z_middle = image.shape[-1] // 2
        image = image[:, 0, ..., z_middle]
        label = label[:, 0, ..., z_middle]
        output = output[:, 0, ..., z_middle]

        if config.save.image_vmax is not None:
            vmax = config.save.image_vmax
        else:
            vmax = image.max()
        if config.save.image_vmin is not None:
            vmin = config.save.image_vmin
        else:
            vmin = image.min()

        image = np.clip(image, vmin, vmax)
        image -= vmin
        image /= (vmax - vmin)
        image *= 255.0

        save_path = os.path.join(output_dir_path, 'result_{}.png'.format(epoch))
        save_images_via_plt(image, label, output, config.save.n_save_images,
                            config, save_path)

    if needs_save:
        trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                                  handler=checkpoint_handler,
                                  to_save={'model': model, 'optim': optimizer})

    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    print('Training starts: [max_epochs] {}, [max_iterations] {}'.format(
        config.run.n_epochs, config.run.n_epochs * len(train_data_loader)))

    trainer.run(train_data_loader, config.run.n_epochs)
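
# ---------------------------------------------------------------------------
# adjust_learning_rate, called at the top of train() above, is not shown. A
# plausible sketch is the common polynomial ("poly") decay schedule,
# lr = initial_lr * (1 - epoch / n_epochs) ** gamma, applied in place to
# every parameter group; the real helper may implement a different schedule.
def adjust_learning_rate_sketch(optimizer, epoch, initial_lr, n_epochs, gamma):
    lr = initial_lr * (1.0 - float(epoch) / n_epochs) ** gamma
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr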
def main(args): logger.info('Checking...') check_manual_seed(args.seed) check_args(args) logger.info('Loading config...') bert_config = BertConfig('config/bert.ini') bert_config = bert_config(args.bert_type) # for oos-eval dataset data_config = Config('config/data.ini') data_config = data_config(args.dataset) # Prepare data processor data_path = os.path.join(data_config['DataDir'], data_config[args.data_file]) # 把目录和文件名合成一个路径 label_path = data_path.replace('.json', '.label') if args.dataset == 'oos-eval': processor = OOSProcessor(bert_config, maxlen=32) elif args.dataset == 'smp': processor = SMPProcessor(bert_config, maxlen=32) else: raise ValueError('The dataset {} is not supported.'.format( args.dataset)) processor.load_label( label_path) # Adding label_to_id and id_to_label ot processor. n_class = len(processor.id_to_label) config = vars(args) # 返回参数字典 config['model_save_path'] = os.path.join(args.output_dir, 'save', 'bert.pt') config['n_class'] = n_class logger.info('config:') logger.info(config) model = BertClassifier(bert_config, config) # Bert encoder if args.fine_tune: model.unfreeze_bert_encoder() else: model.freeze_bert_encoder() model.to(device) global_step = 0 def train(train_dataset, dev_dataset): train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size // args.gradient_accumulation_steps, shuffle=True, num_workers=2) nonlocal global_step n_sample = len(train_dataloader) early_stopping = EarlyStopping(args.patience, logger=logger) # Loss function classified_loss = torch.nn.CrossEntropyLoss().to(device) adversarial_loss = torch.nn.BCELoss().to(device) # Optimizers optimizer = AdamW(model.parameters(), args.lr) train_loss = [] if dev_dataset: valid_loss = [] valid_ind_class_acc = [] iteration = 0 for i in range(args.n_epoch): model.train() total_loss = 0 for sample in tqdm.tqdm(train_dataloader): sample = (i.to(device) for i in sample) token, mask, type_ids, y = sample batch = len(token) f_vector, discriminator_output, classification_output = model( token, mask, type_ids, return_feature=True) discriminator_output = discriminator_output.squeeze() if args.BCE: loss = adversarial_loss(discriminator_output, (y != 0.0).float()) else: loss = classified_loss(discriminator_output, y.long()) total_loss += loss.item() loss = loss / args.gradient_accumulation_steps loss.backward() # bp and update parameters if (global_step + 1) % args.gradient_accumulation_steps == 0: optimizer.step() optimizer.zero_grad() global_step += 1 logger.info('[Epoch {}] Train: train_loss: {}'.format( i, total_loss / n_sample)) logger.info('-' * 30) train_loss.append(total_loss / n_sample) iteration += 1 if dev_dataset: logger.info( '#################### eval result at step {} ####################' .format(global_step)) eval_result = eval(dev_dataset) valid_loss.append(eval_result['loss']) valid_ind_class_acc.append(eval_result['ind_class_acc']) # 1 表示要保存模型 # 0 表示不需要保存模型 # -1 表示不需要模型,且超过了patience,需要early stop signal = early_stopping(-eval_result['eer']) if signal == -1: break elif signal == 0: pass elif signal == 1: save_model(model, path=config['model_save_path'], model_name='bert') logger.info(eval_result) logger.info('valid_eer: {}'.format(eval_result['eer'])) logger.info('valid_oos_ind_precision: {}'.format( eval_result['oos_ind_precision'])) logger.info('valid_oos_ind_recall: {}'.format( eval_result['oos_ind_recall'])) logger.info('valid_oos_ind_f_score: {}'.format( eval_result['oos_ind_f_score'])) logger.info('valid_auc: {}'.format(eval_result['auc'])) logger.info('valid_fpr95: 
{}'.format( ErrorRateAt95Recall(eval_result['all_binary_y'], eval_result['y_score']))) from utils.visualization import draw_curve draw_curve(train_loss, iteration, 'train_loss', args.output_dir) if dev_dataset: draw_curve(valid_loss, iteration, 'valid_loss', args.output_dir) draw_curve(valid_ind_class_acc, iteration, 'valid_ind_class_accuracy', args.output_dir) if args.patience >= args.n_epoch: save_model(model, path=config['model_save_path'], model_name='bert') freeze_data['train_loss'] = train_loss freeze_data['valid_loss'] = valid_loss def eval(dataset): dev_dataloader = DataLoader(dataset, batch_size=args.predict_batch_size, shuffle=False, num_workers=2) n_sample = len(dev_dataloader) result = dict() model.eval() # Loss function classified_loss = torch.nn.CrossEntropyLoss().to(device) detection_loss = torch.nn.BCELoss().to(device) all_detection_preds = [] all_class_preds = [] all_pred = [] all_logit = [] total_loss = 0 for sample in tqdm.tqdm(dev_dataloader): sample = (i.to(device) for i in sample) token, mask, type_ids, y = sample batch = len(token) with torch.no_grad(): f_vector, discriminator_output, classification_output = model( token, mask, type_ids, return_feature=True) discriminator_output = discriminator_output.squeeze() all_detection_preds.append(discriminator_output) all_y = LongTensor( dataset.dataset[:, -1].astype(int)).cpu() # [length, n_class] all_binary_y = (all_y != 0).long() # [length, 1] label 0 is oos all_detection_preds = torch.cat(all_detection_preds, 0).cpu() # [length, 1] all_detection_binary_preds = convert_to_int_by_threshold( all_detection_preds.squeeze()) # [length, 1] # 计算损失 detection_loss = detection_loss(all_detection_preds, all_binary_y.float()) result['detection_loss'] = detection_loss logger.info( metrics.classification_report(all_binary_y, all_detection_binary_preds, target_names=['oos', 'in'])) # report oos_ind_precision, oos_ind_recall, oos_ind_fscore, _ = metrics.binary_recall_fscore( all_detection_binary_preds, all_binary_y) detection_acc = metrics.accuracy(all_detection_binary_preds, all_binary_y) y_score = all_detection_preds.squeeze().tolist() eer = metrics.cal_eer(all_binary_y, y_score) ind_class_acc = metrics.ind_class_accuracy(all_detection_binary_preds, all_y) result['ind_class_acc'] = ind_class_acc result['loss'] = total_loss / n_sample result['eer'] = eer result['all_detection_binary_preds'] = all_detection_binary_preds result['detection_acc'] = detection_acc result['all_binary_y'] = all_binary_y result['oos_ind_precision'] = oos_ind_precision result['oos_ind_recall'] = oos_ind_recall result['oos_ind_f_score'] = oos_ind_fscore result['y_score'] = y_score result['auc'] = roc_auc_score(all_binary_y, y_score) return result def test(dataset): load_model(model, path=config['model_save_path'], model_name='bert') test_dataloader = DataLoader(dataset, batch_size=args.predict_batch_size, shuffle=False, num_workers=2) n_sample = len(test_dataloader) result = dict() model.eval() # Loss function classified_loss = torch.nn.CrossEntropyLoss().to(device) detection_loss = torch.nn.BCELoss().to(device) all_detection_preds = [] all_features = [] all_pred = [] total_loss = 0 all_logit = [] for sample in tqdm.tqdm(test_dataloader): sample = (i.to(device) for i in sample) token, mask, type_ids, y = sample batch = len(token) with torch.no_grad(): f_vector, discriminator_output, classification_output = model( token, mask, type_ids, return_feature=True) discriminator_output = discriminator_output.squeeze() all_detection_preds.append(discriminator_output) if 
args.do_vis: all_features.append(f_vector) all_y = LongTensor( dataset.dataset[:, -1].astype(int)).cpu() # [length, n_class] all_binary_y = (all_y != 0).long() # [length, 1] label 0 is oos all_detection_preds = torch.cat(all_detection_preds, 0).cpu() # [length, 1] all_detection_binary_preds = convert_to_int_by_threshold( all_detection_preds.squeeze()) # [length, 1] # 计算损失 detection_loss = detection_loss(all_detection_preds, all_binary_y.float()) result['detection_loss'] = detection_loss logger.info( metrics.classification_report(all_binary_y, all_detection_binary_preds, target_names=['oos', 'in'])) # report oos_ind_precision, oos_ind_recall, oos_ind_fscore, _ = metrics.binary_recall_fscore( all_detection_binary_preds, all_binary_y) detection_acc = metrics.accuracy(all_detection_binary_preds, all_binary_y) y_score = all_detection_preds.squeeze().tolist() eer = metrics.cal_eer(all_binary_y, y_score) if args.do_vis: all_features = torch.cat(all_features, 0).cpu().numpy() result['all_features'] = all_features ind_class_acc = metrics.ind_class_accuracy(all_detection_binary_preds, all_y) result['ind_class_acc'] = ind_class_acc result['loss'] = total_loss / n_sample result['eer'] = eer result['all_detection_binary_preds'] = all_detection_binary_preds result['detection_acc'] = detection_acc result['all_binary_y'] = all_binary_y result['all_y'] = all_y result['oos_ind_precision'] = oos_ind_precision result['oos_ind_recall'] = oos_ind_recall result['oos_ind_f_score'] = oos_ind_fscore result['score'] = y_score result['y_score'] = y_score result['all_pred'] = all_detection_binary_preds result['auc'] = roc_auc_score(all_binary_y, y_score) freeze_data['test_all_y'] = all_y.tolist() freeze_data['test_all_pred'] = all_detection_binary_preds.tolist() freeze_data['test_score'] = y_score return result if args.do_train: if config['data_file'].startswith('binary'): text_train_set = processor.read_dataset(data_path, ['train']) text_dev_set = processor.read_dataset(data_path, ['val']) elif config['dataset'] == 'oos-eval': text_train_set = processor.read_dataset(data_path, ['train', 'oos_train']) text_dev_set = processor.read_dataset(data_path, ['val', 'oos_val']) elif config['dataset'] == 'smp': text_train_set = processor.read_dataset(data_path, ['train']) text_dev_set = processor.read_dataset(data_path, ['val']) train_features = processor.convert_to_ids(text_train_set) train_dataset = OOSDataset(train_features) dev_features = processor.convert_to_ids(text_dev_set) dev_dataset = OOSDataset(dev_features) train(train_dataset, dev_dataset) if args.do_eval: logger.info( '#################### eval result at step {} ####################'. 
format(global_step)) if config['data_file'].startswith('binary'): text_dev_set = processor.read_dataset(data_path, ['val']) elif config['dataset'] == 'oos-eval': text_dev_set = processor.read_dataset(data_path, ['val', 'oos_val']) elif config['dataset'] == 'smp': text_dev_set = processor.read_dataset(data_path, ['val']) dev_features = processor.convert_to_ids(text_dev_set) dev_dataset = OOSDataset(dev_features) eval_result = eval(dev_dataset) logger.info(eval_result) logger.info('eval_eer: {}'.format(eval_result['eer'])) logger.info('eval_oos_ind_precision: {}'.format( eval_result['oos_ind_precision'])) logger.info('eval_oos_ind_recall: {}'.format( eval_result['oos_ind_recall'])) logger.info('eval_oos_ind_f_score: {}'.format( eval_result['oos_ind_f_score'])) logger.info('eval_auc: {}'.format(eval_result['auc'])) logger.info('eval_fpr95: {}'.format( ErrorRateAt95Recall(eval_result['all_binary_y'], eval_result['y_score']))) if args.do_test: logger.info( '#################### test result at step {} ####################'. format(global_step)) if config['data_file'].startswith('binary'): text_test_set = processor.read_dataset(data_path, ['test']) elif config['dataset'] == 'oos-eval': text_test_set = processor.read_dataset(data_path, ['test', 'oos_test']) elif config['dataset'] == 'smp': text_test_set = processor.read_dataset(data_path, ['test']) test_features = processor.convert_to_ids(text_test_set) test_dataset = OOSDataset(test_features) test_result = test(test_dataset) logger.info(test_result) logger.info('test_eer: {}'.format(test_result['eer'])) logger.info('test_ood_ind_precision: {}'.format( test_result['oos_ind_precision'])) logger.info('test_ood_ind_recall: {}'.format( test_result['oos_ind_recall'])) logger.info('test_ood_ind_f_score: {}'.format( test_result['oos_ind_f_score'])) logger.info('test_auc: {}'.format(test_result['auc'])) logger.info('test_fpr95: {}'.format( ErrorRateAt95Recall(test_result['all_binary_y'], test_result['y_score']))) # 输出错误cases if config['dataset'] == 'oos-eval': texts = [line[0] for line in text_test_set] elif config['dataset'] == 'smp': texts = [line['text'] for line in text_test_set] else: raise ValueError('The dataset {} is not supported.'.format( args.dataset)) # output_cases(texts, test_result['all_y'], test_result['all_pred'], # os.path.join(args.output_dir, 'test_cases.csv'), processor, test_result['test_logit']) # confusion matrix plot_confusion_matrix(test_result['all_y'], test_result['all_pred'], args.output_dir)
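The eval and test paths above lean on two helpers whose definitions live elsewhere in the repo: convert_to_int_by_threshold and ErrorRateAt95Recall. A minimal sketch of both, assuming a 0.5 score cutoff, a torch tensor of scores for the first, and the usual "FPR at 95% in-domain recall" definition for the second; the project's real implementations may differ in detail:

import numpy as np

def convert_to_int_by_threshold(scores, threshold=0.5):
    # map sigmoid scores in [0, 1] to hard 0/1 predictions (assumed cutoff: 0.5)
    return (scores >= threshold).long()

def ErrorRateAt95Recall(labels, scores):
    # false-positive rate at the smallest threshold where recall on the
    # positive (in-domain, label 1) class reaches 95%
    labels = np.asarray(labels)
    scores = np.asarray(scores)
    order = np.argsort(-scores)              # sort by descending confidence
    labels = labels[order]
    tp = np.cumsum(labels == 1)
    fp = np.cumsum(labels == 0)
    recall = tp / max(tp[-1], 1)             # non-decreasing in this ordering
    idx = min(np.searchsorted(recall, 0.95), len(fp) - 1)
    return fp[idx] / max((labels == 0).sum(), 1)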
def main(args): check_manual_seed(args.seed) logger.info('seed: {}'.format(args.seed)) logger.info('Loading config...') bert_config = Config('config/bert.ini') bert_config = bert_config(args.bert_type) # for oos-eval dataset data_config = Config('config/data.ini') data_config = data_config(args.dataset) # Prepare data processor data_path = os.path.join(data_config['DataDir'], data_config[args.data_file]) # 把目录和文件名合成一个路径 label_path = data_path.replace('.json', '.label') with open(data_path, 'r', encoding='utf-8') as fp: data = json.load(fp) for type in data: logger.info('{} : {}'.format(type, len(data[type]))) with open(label_path, 'r', encoding='utf-8') as fp: logger.info(json.load(fp)) if args.dataset == 'oos-eval': processor = OOSProcessor(bert_config, maxlen=32) logger.info('OOSProcessor') elif args.dataset == 'smp': # processor = SMPProcessor(bert_config, maxlen=32) processor = PosSMPProcessor(bert_config, maxlen=32) logger.info('SMPProcessor') else: raise ValueError('The dataset {} is not supported.'.format( args.dataset)) processor.load_label( label_path) # Adding label_to_id and id_to_label ot processor. processor.load_pos('data/pos.json') logger.info("label_to_id: {}".format(processor.label_to_id)) logger.info("id_to_label: {}".format(processor.id_to_label)) n_class = len(processor.id_to_label) config = vars(args) # 返回参数字典 config['gan_save_path'] = os.path.join(args.output_dir, 'save', 'gan.pt') config['bert_save_path'] = os.path.join(args.output_dir, 'save', 'bert.pt') config['n_class'] = n_class logger.info('config:') logger.info(config) from model.pos_emb_v2 import Pos_emb E = BertModel.from_pretrained( bert_config['PreTrainModelDir']) # Bert encoder config['pos_dim'] = args.pos_dim config['batch_size'] = args.train_batch_size config['n_pos'] = len(processor.pos) config['device'] = device config['nhead'] = 2 config['num_layers'] = 1 config['maxlen'] = processor.maxlen print('config', config) print(processor.pos) pos = Pos_emb(config) if args.fine_tune: for param in E.parameters(): param.requires_grad = True else: for param in E.parameters(): param.requires_grad = False pos.to(device) E.to(device) # logger.info(('pos_dim: {}, feature_dim'.format(config['pos_dim'], config['feature_dim']))) global_step = 0 def train(train_dataset, dev_dataset): train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=2) global best_dev nonlocal global_step n_sample = len(train_dataloader) early_stopping = EarlyStopping(args.patience, logger=logger) # Loss function adversarial_loss = torch.nn.BCELoss().to(device) # Optimizers optimizer_pos = torch.optim.Adam(pos.parameters(), lr=args.pos_lr) optimizer_E = AdamW(E.parameters(), args.bert_lr) valid_detection_loss = [] valid_oos_ind_precision = [] valid_oos_ind_recall = [] valid_oos_ind_f_score = [] train_loss = [] iteration = 0 for i in range(args.n_epoch): logger.info('***********************************') logger.info('epoch: {}'.format(i)) # Initialize model state pos.train() E.train() total_loss = 0 for sample in tqdm(train_dataloader): sample = (i.to(device) for i in sample) token, mask, type_ids, pos1, pos2, pos_mask, y = sample batch = len(token) optimizer_E.zero_grad() optimizer_pos.zero_grad() sequence_output, pooled_output = E(token, mask, type_ids) real_feature = pooled_output out = pos(pos1, pos2, real_feature) loss = adversarial_loss(out, y.float()) loss.backward() total_loss += loss.detach() if args.fine_tune: optimizer_E.step() optimizer_pos.step() logger.info('[Epoch {}] Train: loss: 
{}'.format( i, total_loss / n_sample)) logger.info( '---------------------------------------------------------------------------' ) train_loss.append(total_loss / n_sample) iteration += 1 if dev_dataset: logger.info( '#################### eval result at step {} ####################' .format(global_step)) eval_result = eval(dev_dataset) valid_detection_loss.append(eval_result['detection_loss']) valid_oos_ind_precision.append( eval_result['oos_ind_precision']) valid_oos_ind_recall.append(eval_result['oos_ind_recall']) valid_oos_ind_f_score.append(eval_result['oos_ind_f_score']) # 1 表示要保存模型 # 0 表示不需要保存模型 # -1 表示不需要模型,且超过了patience,需要early stop signal = early_stopping(-eval_result['eer']) if signal == -1: break # elif signal == 0: # pass # elif signal == 1: # save_gan_model(D, G, config['gan_save_path']) # if args.fine_tune: # save_model(E, path=config['bert_save_path'], model_name='bert') logger.info(eval_result) logger.info('valid_eer: {}'.format(eval_result['eer'])) logger.info('valid_oos_ind_precision: {}'.format( eval_result['oos_ind_precision'])) logger.info('valid_oos_ind_recall: {}'.format( eval_result['oos_ind_recall'])) logger.info('valid_oos_ind_f_score: {}'.format( eval_result['oos_ind_f_score'])) logger.info('valid_auc: {}'.format(eval_result['auc'])) logger.info('valid_fpr95: {}'.format( ErrorRateAt95Recall(eval_result['all_binary_y'], eval_result['y_score']))) best_dev = -early_stopping.best_score # 绘制训练损失曲线 from utils.visualization import draw_curve draw_curve(train_loss, iteration, 'train_loss', args.output_dir) def eval(dataset): dev_dataloader = DataLoader(dataset, batch_size=args.predict_batch_size, shuffle=False, num_workers=2) n_sample = len(dev_dataloader) result = dict() detection_loss = torch.nn.BCELoss().to(device) pos.eval() E.eval() all_detection_preds = [] for sample in tqdm(dev_dataloader): sample = (i.to(device) for i in sample) token, mask, type_ids, pos1, pos2, pos_mask, y = sample batch = len(token) # -------------------------evaluate D------------------------- # # BERT encode sentence to feature vector with torch.no_grad(): sequence_output, pooled_output = E(token, mask, type_ids) real_feature = pooled_output out = pos(pos1, pos2, real_feature) all_detection_preds.append(out) all_y = LongTensor( dataset.dataset[:, -4].astype(int)).cpu() # [length, n_class] all_binary_y = (all_y != 0).long() # [length, 1] label 0 is oos all_detection_preds = torch.cat(all_detection_preds, 0).cpu() # [length, 1] all_detection_binary_preds = convert_to_int_by_threshold( all_detection_preds.squeeze()) # [length, 1] # 计算损失 detection_loss = detection_loss(all_detection_preds, all_binary_y.float()) result['detection_loss'] = detection_loss logger.info( metrics.classification_report(all_binary_y, all_detection_binary_preds, target_names=['oos', 'in'])) # report oos_ind_precision, oos_ind_recall, oos_ind_fscore, _ = metrics.binary_recall_fscore( all_detection_binary_preds, all_binary_y) detection_acc = metrics.accuracy(all_detection_binary_preds, all_binary_y) y_score = all_detection_preds.squeeze().tolist() eer = metrics.cal_eer(all_binary_y, y_score) result['eer'] = eer result['all_detection_binary_preds'] = all_detection_binary_preds result['detection_acc'] = detection_acc result['all_binary_y'] = all_binary_y result['oos_ind_precision'] = oos_ind_precision result['oos_ind_recall'] = oos_ind_recall result['oos_ind_f_score'] = oos_ind_fscore result['y_score'] = y_score result['auc'] = roc_auc_score(all_binary_y, y_score) return result def test(dataset): # # load BERT and GAN # 
load_gan_model(D, G, config['gan_save_path']) # if args.fine_tune: # load_model(E, path=config['bert_save_path'], model_name='bert') # test_dataloader = DataLoader(dataset, batch_size=args.predict_batch_size, shuffle=False, num_workers=2) n_sample = len(test_dataloader) result = dict() # Loss function detection_loss = torch.nn.BCELoss().to(device) classified_loss = torch.nn.CrossEntropyLoss(ignore_index=0).to(device) pos.eval() E.eval() all_detection_preds = [] all_class_preds = [] all_features = [] for sample in tqdm(test_dataloader): sample = (i.to(device) for i in sample) token, mask, type_ids, pos1, pos2, pos_mask, y = sample batch = len(token) # -------------------------evaluate D------------------------- # # BERT encode sentence to feature vector with torch.no_grad(): sequence_output, pooled_output = E(token, mask, type_ids) real_feature = pooled_output out = pos(pos1, pos2, real_feature) all_detection_preds.append(out) all_y = LongTensor( dataset.dataset[:, -4].astype(int)).cpu() # [length, n_class] all_binary_y = (all_y != 0).long() # [length, 1] label 0 is oos all_detection_preds = torch.cat(all_detection_preds, 0).cpu() # [length, 1] all_detection_binary_preds = convert_to_int_by_threshold( all_detection_preds.squeeze()) # [length, 1] # 计算损失 detection_loss = detection_loss(all_detection_preds, all_binary_y.float()) result['detection_loss'] = detection_loss logger.info( metrics.classification_report(all_binary_y, all_detection_binary_preds, target_names=['oos', 'in'])) # report oos_ind_precision, oos_ind_recall, oos_ind_fscore, _ = metrics.binary_recall_fscore( all_detection_binary_preds, all_binary_y) detection_acc = metrics.accuracy(all_detection_binary_preds, all_binary_y) y_score = all_detection_preds.squeeze().tolist() eer = metrics.cal_eer(all_binary_y, y_score) result['eer'] = eer result['all_detection_binary_preds'] = all_detection_binary_preds result['detection_acc'] = detection_acc result['all_binary_y'] = all_binary_y result['oos_ind_precision'] = oos_ind_precision result['oos_ind_recall'] = oos_ind_recall result['oos_ind_f_score'] = oos_ind_fscore result['y_score'] = y_score result['auc'] = roc_auc_score(all_binary_y, y_score) return result if args.do_train: if config['data_file'].startswith('binary'): text_train_set = processor.read_dataset(data_path, ['train']) text_dev_set = processor.read_dataset(data_path, ['val']) elif config['dataset'] == 'oos-eval': text_train_set = processor.read_dataset(data_path, ['train', 'oos_train']) text_dev_set = processor.read_dataset(data_path, ['val', 'oos_val']) elif config['dataset'] == 'smp': text_train_set = processor.read_dataset(data_path, ['train']) text_dev_set = processor.read_dataset(data_path, ['val']) train_features = processor.convert_to_ids(text_train_set) train_dataset = PosOOSDataset(train_features) dev_features = processor.convert_to_ids(text_dev_set) dev_dataset = PosOOSDataset(dev_features) train(train_dataset, dev_dataset) if args.do_eval: logger.info( '#################### eval result at step {} ####################'. 
format(global_step)) if config['data_file'].startswith('binary'): text_dev_set = processor.read_dataset(data_path, ['val']) elif config['dataset'] == 'oos-eval': text_dev_set = processor.read_dataset(data_path, ['val', 'oos_val']) elif config['dataset'] == 'smp': text_dev_set = processor.read_dataset(data_path, ['val']) dev_features = processor.convert_to_ids(text_dev_set) dev_dataset = PosOOSDataset(dev_features) eval_result = eval(dev_dataset) logger.info(eval_result) logger.info('eval_eer: {}'.format(eval_result['eer'])) logger.info('eval_oos_ind_precision: {}'.format( eval_result['oos_ind_precision'])) logger.info('eval_oos_ind_recall: {}'.format( eval_result['oos_ind_recall'])) logger.info('eval_oos_ind_f_score: {}'.format( eval_result['oos_ind_f_score'])) logger.info('eval_auc: {}'.format(eval_result['auc'])) logger.info('eval_fpr95: {}'.format( ErrorRateAt95Recall(eval_result['all_binary_y'], eval_result['y_score']))) if args.do_test: logger.info( '#################### test result at step {} ####################'. format(global_step)) if config['data_file'].startswith('binary'): text_test_set = processor.read_dataset(data_path, ['test']) elif config['dataset'] == 'oos-eval': text_test_set = processor.read_dataset(data_path, ['test', 'oos_test']) elif config['dataset'] == 'smp': text_test_set = processor.read_dataset(data_path, ['test']) test_features = processor.convert_to_ids(text_test_set) test_dataset = PosOOSDataset(test_features) test_result = test(test_dataset) logger.info(test_result) logger.info('test_eer: {}'.format(test_result['eer'])) logger.info('test_ood_ind_precision: {}'.format( test_result['oos_ind_precision'])) logger.info('test_ood_ind_recall: {}'.format( test_result['oos_ind_recall'])) logger.info('test_ood_ind_f_score: {}'.format( test_result['oos_ind_f_score'])) logger.info('test_auc: {}'.format(test_result['auc'])) logger.info('test_fpr95: {}'.format( ErrorRateAt95Recall(test_result['all_binary_y'], test_result['y_score']))) my_plot_roc(test_result['all_binary_y'], test_result['y_score'], os.path.join(args.output_dir, 'roc_curve.png')) save_result(test_result, os.path.join(args.output_dir, 'test_result')) # 输出错误cases if config['dataset'] == 'oos-eval': texts = [line[0] for line in text_test_set] elif config['dataset'] == 'smp': texts = [line['text'] for line in text_test_set] else: raise ValueError('The dataset {} is not supported.'.format( args.dataset)) output_cases(texts, test_result['all_binary_y'], test_result['all_detection_binary_preds'], os.path.join(args.output_dir, 'test_cases.csv'), processor, test_result['y_score']) # confusion matrix plot_confusion_matrix(test_result['all_binary_y'], test_result['all_detection_binary_preds'], args.output_dir) beta_log_path = 'beta_log.txt' if os.path.exists(beta_log_path): flag = True else: flag = False with open(beta_log_path, 'a', encoding='utf-8') as f: if flag == False: f.write('seed\tdataset\tdev_eer\ttest_eer\tdata_size\n') line = '\t'.join([ str(config['seed']), str(config['data_file']), str(best_dev), str(test_result['eer']), '100' ]) f.write(line + '\n')
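Early stopping in the loops above is driven by the equal error rate via early_stopping(-eval_result['eer']), but metrics.cal_eer is not shown. A plausible sketch, assuming the EER is taken where the false-positive and false-negative rates cross on the ROC curve:

import numpy as np
from sklearn.metrics import roc_curve

def cal_eer(y_true, y_score):
    # equal error rate: the operating point where FPR and FNR meet
    fpr, tpr, _ = roc_curve(y_true, y_score)
    fnr = 1 - tpr
    idx = np.nanargmin(np.abs(fnr - fpr))
    return (fpr[idx] + fnr[idx]) / 2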
def main(args): logger.info('Checking...') print('torch.cuda.is_available:', torch.cuda.is_available()) # print('torch.cuda.current_device:', torch.cuda.current_device()) logger.info('device: {}'.format(device)) logger.info('ood: {}'.format(args.ood)) SEED = args.seed gross_result['seed'] = args.seed logger.info('seed: {}'.format(SEED)) logger.info('model: {}'.format(args.model)) check_manual_seed(SEED) check_args(args) if 0 <= args.beta <= 1: logger.info('beta: {}'.format(args.beta)) logger.info('mode: {}'.format(args.mode)) logger.info('maxlen: {}'.format(args.maxlen)) logger.info('minlen: {}'.format(args.minlen)) logger.info('optim_mode: {}'.format(args.optim_mode)) logger.info('length_weight: {}'.format(args.length_weight)) logger.info('sample_weight: {}'.format(args.sample_weight)) logger.info('Loading config...') bert_config = Config('config/bert.ini') bert_config = bert_config(args.bert_type) # for oos-eval dataset data_config = Config('config/data.ini') data_config = data_config(args.dataset) # Prepare data processor data_path = os.path.join(data_config['DataDir'], data_config[args.data_file]) # 把目录和文件名合成一个路径 label_path = data_path.replace('.json', '.label') # with open(data_path, 'r', encoding='utf-8') as fp: # source = json.load(fp) # for type in source: # n = 0 # n_id = 0 # n_ood = 0 # text_len = {} # for line in source[type]: # if line['domain'] == 'chat': # n_ood += 1 # else: # n_id += 1 # n += 1 # text_len[len(line['text'])] = text_len.get(len(line['text']), 0) + 1 # print(type, n) # print('ood', n_ood) # print('id', n_id) # print(sorted(text_len.items(), key=lambda d: d[0], reverse=False)) if args.dataset == 'oos-eval': processor = OOSProcessor(bert_config, maxlen=32) elif args.dataset == 'smp': if args.mode == -1: processor = SMPProcessor(bert_config, maxlen=32) print('processor') else: processor = SMPProcessor_v3(bert_config, maxlen=32) print('processor_v3') else: raise ValueError('The dataset {} is not supported.'.format(args.dataset)) processor.load_label(label_path) # Adding label_to_id and id_to_label ot processor. n_class = len(processor.id_to_label) print('label: ', processor.id_to_label) config = vars(args) # 返回参数字典 config['gan_save_path'] = os.path.join(args.output_dir, 'save', 'gan.pt') config['bert_save_path'] = os.path.join(args.output_dir, 'save', 'bert.pt') config['n_class'] = n_class logger.info('config:') logger.info(config) model = import_module('model.' 
+ args.model) D = model.Discriminator(config) G = model.Generator(config) E = BertModel.from_pretrained(bert_config['PreTrainModelDir']) # Bert encoder # logger.info('Discriminator: {}'.format(D)) # logger.info('Generator: {}'.format(G)) if args.fine_tune: for param in E.parameters(): param.requires_grad = True else: for param in E.parameters(): param.requires_grad = False D.to(device) G.to(device) E.to(device) global_step = 0 def train(train_dataset, dev_dataset): train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=2) global best_dev nonlocal global_step n_sample = len(train_dataloader) early_stopping = EarlyStopping(args.patience, logger=logger) # Loss function adversarial_loss = torch.nn.BCELoss().to(device) classified_loss = torch.nn.CrossEntropyLoss().to(device) # Optimizers optimizer_G = torch.optim.Adam(G.parameters(), lr=args.G_lr) # optimizer for generator optimizer_D = torch.optim.Adam(D.parameters(), lr=args.D_lr) # optimizer for discriminator optimizer_E = AdamW(E.parameters(), args.bert_lr) G_total_train_loss = [] D_total_fake_loss = [] D_total_real_loss = [] FM_total_train_loss = [] D_total_class_loss = [] valid_detection_loss = [] valid_oos_ind_precision = [] valid_oos_ind_recall = [] valid_oos_ind_f_score = [] all_features = [] result = dict() for i in range(args.n_epoch): # Initialize model state G.train() D.train() E.train() G_train_loss = 0 G_d_loss = 0 D_fake_loss = 0 D_real_loss = 0 FM_train_loss = 0 D_class_loss = 0 G_features = [] for sample in tqdm.tqdm(train_dataloader): sample = (i.to(device) for i in sample) if args.dataset == 'smp': token, mask, type_ids, knowledge_tag, y = sample batch = len(token) ood_sample = (y == 0.0).float() # weight = torch.ones(len(ood_sample)).to(device) - ood_sample * args.beta # real_loss_func = torch.nn.BCELoss(weight=weight).to(device) # length weight length_sample = FloatTensor([0] * batch) if args.minlen != -1: short_sample = (mask[:, args.minlen] == 0).float() length_sample = length_sample.add(short_sample) if args.maxlen != -1: long_sample = mask[:, args.maxlen].float() length_sample = length_sample.add(long_sample) # get knowledge sample weight by knowledge_tag exclude_sample = knowledge_tag # initailize weight weight = torch.ones(batch).to(device) # optimize without weights if args.optim_mode == 0 and 0 <= args.beta <= 1: weight -= ood_sample * args.beta # only optimize length by weight if args.optim_mode == 1: # set all exclude_sample's weight to 0 weight -= exclude_sample length_sample -= exclude_sample length_sample = (length_sample > 0).float() weight -= length_sample * (1 - args.length_weight) # set ood sample weight if 0 <= args.beta <= 1: ood_sample -= exclude_sample ood_sample = (ood_sample > 0).float() temp = torch.ones(batch).to(device) temp -= ood_sample * args.beta weight *= temp # only optimize sample by weight if args.optim_mode == 2: # set all length_sample's weight to 0 weight -= length_sample exclude_sample -= length_sample exclude_sample = (exclude_sample > 0).float() weight -= exclude_sample * (1 - args.sample_weight) # set ood sample weight if 0 <= args.beta <= 1: ood_sample -= length_sample ood_sample = (ood_sample > 0).float() temp = torch.ones(batch) temp -= ood_sample * args.beta weight *= temp # optimize length and sample by weight # if args.optim_mode == 3: # alpha = 0.5 # beta = 0.5 # weight = torch.ones(len(length_sample)).to(device) \ # - alpha * length_sample * (1 - args.length_weight) \ # - beta * exclude_sample * (1 - args.sample_weight) if 
args.dataset == 'oos-eval': token, mask, type_ids, y = sample batch = len(token) ood_sample = (y == 0.0).float() # weight = torch.ones(len(ood_sample)).to(device) - ood_sample * args.beta # real_loss_func = torch.nn.BCELoss(weight=weight).to(device) # length weight length_sample = FloatTensor([0] * batch) if args.minlen != -1: short_sample = (mask[:, args.minlen] == 0).float() length_sample = length_sample.add(short_sample) if args.maxlen != -1: long_sample = mask[:, args.maxlen].float() length_sample = length_sample.add(long_sample) # initailize weight weight = torch.ones(batch).to(device) # optimize without weights if args.optim_mode == 0 and 0 <= args.beta <= 1: weight -= ood_sample * args.beta # only optimize length by weight if args.optim_mode == 1: weight -= length_sample * (1 - args.length_weight) # set ood sample weight if 0 <= args.beta <= 1: ood_sample -= length_sample ood_sample = (ood_sample > 0).float() temp = torch.ones(batch).to(device) temp -= ood_sample * args.beta weight *= temp real_loss_func = torch.nn.BCELoss(weight=weight).to(device) # the label used to train generator and discriminator. valid_label = FloatTensor(batch, 1).fill_(1.0).detach() fake_label = FloatTensor(batch, 1).fill_(0.0).detach() optimizer_E.zero_grad() sequence_output, pooled_output = E(token, mask, type_ids) real_feature = pooled_output # train D on real optimizer_D.zero_grad() real_f_vector, discriminator_output, classification_output = D(real_feature, return_feature=True) discriminator_output = discriminator_output.squeeze() # real_loss = adversarial_loss(discriminator_output, (y != 0.0).float()) real_loss = real_loss_func(discriminator_output, (y != 0.0).float()) if n_class > 2: # 大于2表示除了训练判别器还要训练分类器 class_loss = classified_loss(classification_output, y.long()) real_loss += class_loss D_class_loss += class_loss.detach() real_loss.backward() if args.do_vis: all_features.append(real_f_vector.detach()) # # train D on fake if args.model == 'lstm_gan' or args.model == 'cnn_gan': z = FloatTensor(np.random.normal(0, 1, (batch, 32, args.G_z_dim))).to(device) else: # uniform (-1,1) # z = FloatTensor(np.random.uniform(-1, 1, (batch, args.G_z_dim))).to(device) z = FloatTensor(np.random.normal(0, 1, (batch, args.G_z_dim))).to(device) fake_feature = G(z).detach() fake_discriminator_output = D.detect_only(fake_feature) # beta of fake if 0 <= args.beta <= 1: fake_loss = args.beta * adversarial_loss(fake_discriminator_output, fake_label) else: fake_loss = adversarial_loss(fake_discriminator_output, fake_label) fake_loss.backward() optimizer_D.step() if args.fine_tune: optimizer_E.step() # train G optimizer_G.zero_grad() if args.model == 'lstm_gan' or args.model == 'cnn_gan': z = FloatTensor(np.random.normal(0, 1, (batch, 32, args.G_z_dim))).to(device) else: # uniform (-1,1) # z = FloatTensor(np.random.uniform(-1, 1, (batch, args.G_z_dim))).to(device) z = FloatTensor(np.random.normal(0, 1, (batch, args.G_z_dim))).to(device) fake_f_vector, D_decision = D.detect_only(G(z), return_feature=True) if args.do_vis: G_features.append(fake_f_vector.detach()) gd_loss = adversarial_loss(D_decision, valid_label) # feature matching loss fm_loss = torch.abs(torch.mean(real_f_vector.detach(), 0) - torch.mean(fake_f_vector, 0)).mean() # fm_loss = feature_matching_loss(torch.mean(fake_f_vector, 0), torch.mean(real_f_vector.detach(), 0)) g_loss = gd_loss + 0 * fm_loss g_loss.backward() optimizer_G.step() global_step += 1 D_fake_loss += fake_loss.detach() D_real_loss += real_loss.detach() G_d_loss += g_loss.detach() G_train_loss += 
g_loss.detach() + fm_loss.detach() FM_train_loss += fm_loss.detach() # logger.info('[Epoch {}] Train: D_fake_loss: {}'.format(i, D_fake_loss / n_sample)) # logger.info('[Epoch {}] Train: D_real_loss: {}'.format(i, D_real_loss / n_sample)) # logger.info('[Epoch {}] Train: D_class_loss: {}'.format(i, D_class_loss / n_sample)) # logger.info('[Epoch {}] Train: G_train_loss: {}'.format(i, G_train_loss / n_sample)) # logger.info('[Epoch {}] Train: G_d_loss: {}'.format(i, G_d_loss / n_sample)) # logger.info('[Epoch {}] Train: FM_train_loss: {}'.format(i, FM_train_loss / n_sample)) # logger.info('---------------------------------------------------------------------------') D_total_fake_loss.append(D_fake_loss / n_sample) D_total_real_loss.append(D_real_loss / n_sample) D_total_class_loss.append(D_class_loss / n_sample) G_total_train_loss.append(G_train_loss / n_sample) FM_total_train_loss.append(FM_train_loss / n_sample) if dev_dataset: # logger.info('#################### eval result at step {} ####################'.format(global_step)) eval_result = eval(dev_dataset) if args.do_vis and args.do_g_eval_vis: G_features = torch.cat(G_features, 0).cpu().numpy() features = np.concatenate([eval_result['all_features'], G_features], axis=0) features = TSNE(n_components=2, verbose=1, n_jobs=-1).fit_transform(features) labels = np.concatenate([eval_result['all_binary_y'], np.array([-1] * len(G_features))], 0).reshape( -1, 1) data = np.concatenate([features, labels], 1) fig = scatter_plot(data, processor) fig.savefig(os.path.join(args.output_dir, 'plot_epoch_' + str(i) + '.png')) valid_detection_loss.append(eval_result['detection_loss']) valid_oos_ind_precision.append(eval_result['oos_ind_precision']) valid_oos_ind_recall.append(eval_result['oos_ind_recall']) valid_oos_ind_f_score.append(eval_result['oos_ind_f_score']) # 1 表示要保存模型 # 0 表示不需要保存模型 # -1 表示不需要模型,且超过了patience,需要early stop signal = early_stopping(-eval_result['eer']) if signal == -1: break elif signal == 0: pass elif signal == 1: save_gan_model(D, G, config['gan_save_path']) if args.fine_tune: save_model(E, path=config['bert_save_path'], model_name='bert') # logger.info(eval_result) # logger.info('valid_eer: {}'.format(eval_result['eer'])) # logger.info('valid_oos_ind_precision: {}'.format(eval_result['oos_ind_precision'])) # logger.info('valid_oos_ind_recall: {}'.format(eval_result['oos_ind_recall'])) # logger.info('valid_oos_ind_f_score: {}'.format(eval_result['oos_ind_f_score'])) # logger.info('valid_auc: {}'.format(eval_result['auc'])) # logger.info( # 'valid_fpr95: {}'.format(ErrorRateAt95Recall(eval_result['all_binary_y'], eval_result['y_score']))) if args.patience >= args.n_epoch: save_gan_model(D, G, config['gan_save_path']) if args.fine_tune: save_model(E, path=config['bert_save_path'], model_name='bert') freeze_data['D_total_fake_loss'] = D_total_fake_loss freeze_data['D_total_real_loss'] = D_total_real_loss freeze_data['D_total_class_loss'] = D_total_class_loss freeze_data['G_total_train_loss'] = G_total_train_loss freeze_data['FM_total_train_loss'] = FM_total_train_loss freeze_data['valid_real_loss'] = valid_detection_loss freeze_data['valid_oos_ind_precision'] = valid_oos_ind_precision freeze_data['valid_oos_ind_recall'] = valid_oos_ind_recall freeze_data['valid_oos_ind_f_score'] = valid_oos_ind_f_score best_dev = -early_stopping.best_score if args.do_vis: all_features = torch.cat(all_features, 0).cpu().numpy() result['all_features'] = all_features return result def eval(dataset): dev_dataloader = DataLoader(dataset, 
batch_size=args.predict_batch_size, shuffle=False, num_workers=2) n_sample = len(dev_dataloader) result = dict() # Loss function detection_loss = torch.nn.BCELoss().to(device) classified_loss = torch.nn.CrossEntropyLoss(ignore_index=0).to(device) G.eval() D.eval() E.eval() all_detection_preds = [] all_class_preds = [] all_features = [] for sample in tqdm.tqdm(dev_dataloader): sample = (i.to(device) for i in sample) if args.dataset == 'smp': token, mask, type_ids, knowledge_tag, y = sample if args.dataset == 'oos-eval': token, mask, type_ids, y = sample batch = len(token) # -------------------------evaluate D------------------------- # # BERT encode sentence to feature vector with torch.no_grad(): sequence_output, pooled_output = E(token, mask, type_ids) real_feature = pooled_output # 大于2表示除了训练判别器还要训练分类器 if n_class > 2: f_vector, discriminator_output, classification_output = D(real_feature, return_feature=True) all_detection_preds.append(discriminator_output) all_class_preds.append(classification_output) # 只预测判别器 else: f_vector, discriminator_output = D.detect_only(real_feature, return_feature=True) all_detection_preds.append(discriminator_output) if args.do_vis: all_features.append(f_vector) all_y = LongTensor(dataset.dataset[:, -1].astype(int)).cpu() # [length, n_class] all_binary_y = (all_y != 0).long() # [length, 1] label 0 is oos all_detection_preds = torch.cat(all_detection_preds, 0).cpu() # [length, 1] all_detection_binary_preds = convert_to_int_by_threshold(all_detection_preds.squeeze()) # [length, 1] # print('all_detection_preds', all_detection_preds.size()) # print('all_binary_y', all_binary_y.size()) # 计算损失 detection_loss = detection_loss(all_detection_preds.squeeze(), all_binary_y.float()) result['detection_loss'] = detection_loss if n_class > 2: class_one_hot_preds = torch.cat(all_class_preds, 0).detach().cpu() # one hot label class_loss = classified_loss(class_one_hot_preds, all_y) # compute loss all_class_preds = torch.argmax(class_one_hot_preds, 1) # label class_acc = metrics.ind_class_accuracy(all_class_preds, all_y, oos_index=0) # accuracy for ind class logger.info(metrics.classification_report(all_y, all_class_preds, target_names=processor.id_to_label)) # logger.info(metrics.classification_report(all_binary_y, all_detection_binary_preds, target_names=['oos', 'in'])) # report oos_ind_precision, oos_ind_recall, oos_ind_fscore, _ = metrics.binary_recall_fscore(all_detection_binary_preds, all_binary_y) detection_acc = metrics.accuracy(all_detection_binary_preds, all_binary_y) y_score = all_detection_preds.squeeze().tolist() eer = metrics.cal_eer(all_binary_y, y_score) result['eer'] = eer result['all_detection_binary_preds'] = all_detection_binary_preds result['detection_acc'] = detection_acc result['all_binary_y'] = all_binary_y result['oos_ind_precision'] = oos_ind_precision result['oos_ind_recall'] = oos_ind_recall result['oos_ind_f_score'] = oos_ind_fscore result['y_score'] = y_score result['auc'] = roc_auc_score(all_binary_y, y_score) result['fpr95'] = ErrorRateAt95Recall(all_binary_y, y_score) if n_class > 2: result['class_loss'] = class_loss result['class_acc'] = class_acc if args.do_vis: all_features = torch.cat(all_features, 0).cpu().numpy() result['all_features'] = all_features freeze_data['valid_all_y'] = all_y freeze_data['vaild_all_pred'] = all_detection_binary_preds freeze_data['valid_score'] = y_score return result def test(dataset): # load BERT and GAN load_gan_model(D, G, config['gan_save_path']) if args.fine_tune: load_model(E, path=config['bert_save_path'], 
model_name='bert') test_dataloader = DataLoader(dataset, batch_size=args.predict_batch_size, shuffle=False, num_workers=2) n_sample = len(test_dataloader) result = dict() # Loss function detection_loss = torch.nn.BCELoss().to(device) classified_loss = torch.nn.CrossEntropyLoss(ignore_index=0).to(device) G.eval() D.eval() E.eval() all_detection_preds = [] all_class_preds = [] all_features = [] for sample in tqdm.tqdm(test_dataloader): sample = (i.to(device) for i in sample) if args.dataset == 'smp': token, mask, type_ids, knowledge_tag, y = sample if args.dataset == 'oos-eval': token, mask, type_ids, y = sample batch = len(token) # -------------------------evaluate D------------------------- # # BERT encode sentence to feature vector with torch.no_grad(): sequence_output, pooled_output = E(token, mask, type_ids) real_feature = pooled_output # 大于2表示除了训练判别器还要训练分类器 if n_class > 2: f_vector, discriminator_output, classification_output = D(real_feature, return_feature=True) all_detection_preds.append(discriminator_output) all_class_preds.append(classification_output) # 只预测判别器 else: f_vector, discriminator_output = D.detect_only(real_feature, return_feature=True) all_detection_preds.append(discriminator_output) if args.do_vis: all_features.append(f_vector) all_y = LongTensor(dataset.dataset[:, -1].astype(int)).cpu() # [length, n_class] all_binary_y = (all_y != 0).long() # [length, 1] label 0 is oos all_detection_preds = torch.cat(all_detection_preds, 0).cpu() # [length, 1] all_detection_binary_preds = convert_to_int_by_threshold(all_detection_preds.squeeze()) # [length, 1] # 计算损失 detection_loss = detection_loss(all_detection_preds, all_binary_y.float()) result['detection_loss'] = detection_loss if n_class > 2: class_one_hot_preds = torch.cat(all_class_preds, 0).detach().cpu() # one hot label class_loss = classified_loss(class_one_hot_preds, all_y) # compute loss all_class_preds = torch.argmax(class_one_hot_preds, 1) # label class_acc = metrics.ind_class_accuracy(all_class_preds, all_y, oos_index=0) # accuracy for ind class logger.info(metrics.classification_report(all_y, all_class_preds, target_names=processor.id_to_label)) # logger.info(metrics.classification_report(all_binary_y, all_detection_binary_preds, target_names=['oos', 'in'])) # report oos_ind_precision, oos_ind_recall, oos_ind_fscore, _ = metrics.binary_recall_fscore(all_detection_binary_preds, all_binary_y) detection_acc = metrics.accuracy(all_detection_binary_preds, all_binary_y) y_score = all_detection_preds.squeeze().tolist() eer = metrics.cal_eer(all_binary_y, y_score) result['eer'] = eer result['all_detection_binary_preds'] = all_detection_binary_preds result['detection_acc'] = detection_acc result['all_binary_y'] = all_binary_y result['all_y'] = all_y result['oos_ind_precision'] = oos_ind_precision result['oos_ind_recall'] = oos_ind_recall result['oos_ind_f_score'] = oos_ind_fscore result['score'] = y_score result['y_score'] = y_score result['auc'] = roc_auc_score(all_binary_y, y_score) result['fpr95'] = ErrorRateAt95Recall(all_binary_y, y_score) if n_class > 2: result['class_loss'] = class_loss result['class_acc'] = class_acc if args.do_vis: all_features = torch.cat(all_features, 0).cpu().numpy() result['all_features'] = all_features freeze_data['test_all_y'] = all_y.tolist() freeze_data['test_all_pred'] = all_detection_binary_preds.tolist() freeze_data['test_score'] = y_score return result def get_fake_feature(num_output): """ 生成一定数量的假特征 """ G.eval() fake_features = [] start = 0 batch = args.predict_batch_size with 
torch.no_grad(): while start < num_output: end = min(num_output, start + batch) if args.model == 'lstm_gan' or args.model == 'cnn_gan': z = FloatTensor(np.random.normal(0, 1, size=(end - start, 32, args.G_z_dim))) else: z = FloatTensor(np.random.normal(0, 1, size=(end - start, args.G_z_dim))) fake_feature = G(z) f_vector, _ = D.detect_only(fake_feature, return_feature=True) fake_features.append(f_vector) start += batch return torch.cat(fake_features, 0).cpu().numpy() if args.do_train: if config['data_file'].startswith('binary'): if args.optim_mode == 0: text_train_set = processor.read_dataset(data_path, ['train'], args.mode, args.maxlen, args.minlen, pre_exclude=True) else: # optimize length or sample by weight text_train_set = processor.read_dataset(data_path, ['train'], args.mode, args.maxlen, args.minlen, pre_exclude=False) text_dev_set = processor.read_dataset(data_path, ['val'], args.mode, args.maxlen, args.minlen, pre_exclude=False) elif config['dataset'] == 'oos-eval': text_train_set = processor.read_dataset(data_path, ['train', 'oos_train']) text_dev_set = processor.read_dataset(data_path, ['val', 'oos_val']) elif config['dataset'] == 'smp': text_train_set, text_train_len = processor.read_dataset(data_path, ['train']) text_dev_set, text_dev_len = processor.read_dataset(data_path, ['val']) if args.ood: text_train_set = [sample for sample in text_train_set if sample['domain'] != 'chat'] train_features = processor.convert_to_ids(text_train_set) dev_features = processor.convert_to_ids(text_dev_set) if config['dataset'] == 'oos-eval': train_dataset = OOSDataset(train_features) dev_dataset = OOSDataset(dev_features) if config['dataset'] == 'smp': train_dataset = SMPDataset(train_features) dev_dataset = SMPDataset(dev_features) train_result = train(train_dataset, dev_dataset) # save_feature(train_result['all_features'], os.path.join(args.output_dir, 'train_feature')) if args.do_eval: logger.info('#################### eval result at step {} ####################'.format(global_step)) if config['data_file'].startswith('binary'): # don't optim dev_set by weight, don't pre_exclude it text_dev_set = processor.read_dataset(data_path, ['val'], args.mode, args.maxlen, args.minlen, pre_exclude=False) elif config['dataset'] == 'oos-eval': text_dev_set = processor.read_dataset(data_path, ['val', 'oos_val']) elif config['dataset'] == 'smp': text_dev_set = processor.read_dataset(data_path, ['val']) dev_features = processor.convert_to_ids(text_dev_set) if config['dataset'] == 'oos-eval': dev_dataset = OOSDataset(dev_features) if config['dataset'] == 'smp': dev_dataset = SMPDataset(dev_features) eval_result = eval(dev_dataset) # logger.info(eval_result) logger.info('eval_eer: {}'.format(eval_result['eer'])) logger.info('eval_oos_ind_precision: {}'.format(eval_result['oos_ind_precision'])) logger.info('eval_oos_ind_recall: {}'.format(eval_result['oos_ind_recall'])) logger.info('eval_oos_ind_f_score: {}'.format(eval_result['oos_ind_f_score'])) logger.info('eval_auc: {}'.format(eval_result['auc'])) logger.info( 'eval_fpr95: {}'.format(ErrorRateAt95Recall(eval_result['all_binary_y'], eval_result['y_score']))) gross_result['eval_oos_ind_precision'] = eval_result['oos_ind_precision'] gross_result['eval_oos_ind_recall'] = eval_result['oos_ind_recall'] gross_result['eval_oos_ind_f_score'] = eval_result['oos_ind_f_score'] gross_result['eval_eer'] = eval_result['eer'] gross_result['eval_fpr95'] = ErrorRateAt95Recall(eval_result['all_binary_y'], eval_result['y_score']) gross_result['eval_auc'] = eval_result['auc'] 
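    # Note: gross_result accumulates one row of metrics per seed; at the end of
    # main() it is written to '<result>_gross_result.csv' (the run with seed 16
    # writes the header, later seeds append, and seed 8192 triggers std_mean
    # aggregation over the whole sweep).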
freeze_data['eval_result'] = eval_result if args.do_test: logger.info('#################### test result at step {} ####################'.format(global_step)) if config['data_file'].startswith('binary'): # always keep test_set unchanged text_test_set = processor.read_dataset(data_path, ['test']) elif config['dataset'] == 'oos-eval': text_test_set = processor.read_dataset(data_path, ['test', 'oos_test']) elif config['dataset'] == 'smp': text_test_set = processor.read_dataset(data_path, ['test']) test_features = processor.convert_to_ids(text_test_set) if config['dataset'] == 'oos-eval': test_dataset = OOSDataset(test_features) if config['dataset'] == 'smp': test_dataset = SMPDataset(test_features) test_result = test(test_dataset) # logger.info(test_result) logger.info('test_eer: {}'.format(test_result['eer'])) logger.info('test_ood_ind_precision: {}'.format(test_result['oos_ind_precision'])) logger.info('test_ood_ind_recall: {}'.format(test_result['oos_ind_recall'])) logger.info('test_ood_ind_f_score: {}'.format(test_result['oos_ind_f_score'])) logger.info('test_auc: {}'.format(test_result['auc'])) logger.info('test_fpr95: {}'.format(ErrorRateAt95Recall(test_result['all_binary_y'], test_result['y_score']))) my_plot_roc(test_result['all_binary_y'], test_result['y_score'], os.path.join(args.output_dir, 'roc_curve.png')) save_result(test_result, os.path.join(args.output_dir, 'test_result')) # save_feature(test_result['all_features'], os.path.join(args.output_dir, 'test_feature')) gross_result['test_oos_ind_precision'] = test_result['oos_ind_precision'] gross_result['test_oos_ind_recall'] = test_result['oos_ind_recall'] gross_result['test_oos_ind_f_score'] = test_result['oos_ind_f_score'] gross_result['test_eer'] = test_result['eer'] gross_result['test_fpr95'] = ErrorRateAt95Recall(test_result['all_binary_y'], test_result['y_score']) gross_result['test_auc'] = test_result['auc'] freeze_data['test_result'] = test_result # 输出错误cases if config['dataset'] == 'oos-eval': texts = [line[0] for line in text_test_set] elif config['dataset'] == 'smp': texts = [line['text'] for line in text_test_set] else: raise ValueError('The dataset {} is not supported.'.format(args.dataset)) output_cases(texts, test_result['all_binary_y'], test_result['all_detection_binary_preds'], os.path.join(args.output_dir, 'test_cases.csv'), processor) # confusion matrix plot_confusion_matrix(test_result['all_binary_y'], test_result['all_detection_binary_preds'], args.output_dir) # beta_log_path = 'beta_log.txt' # if os.path.exists(beta_log_path): # flag = True # else: # flag = False # with open(beta_log_path, 'a', encoding='utf-8') as f: # if flag == False: # f.write('seed\tbeta\tdataset\tdev_eer\ttest_eer\tdata_size\n') # line = '\t'.join([str(config['seed']), str(config['beta']), str(config['data_file']), str(best_dev), str(test_result['eer']), '100']) # f.write(line + '\n') if args.do_vis: # [2 * length, feature_fim] features = np.concatenate([test_result['all_features'], get_fake_feature(len(test_dataset) // 2)], axis=0) features = TSNE(n_components=2, verbose=1, n_jobs=-1).fit_transform(features) # [2 * length, 2] # [2 * length, 1] if n_class > 2: labels = np.concatenate([test_result['all_y'], np.array([-1] * (len(test_dataset) // 2))], 0).reshape( (-1, 1)) else: labels = np.concatenate([test_result['all_binary_y'], np.array([-1] * (len(test_dataset) // 2))], 0).reshape((-1, 1)) # [2 * length, 3] data = np.concatenate([features, labels], 1) fig = scatter_plot(data, processor) fig.savefig(os.path.join(args.output_dir, 
'plot.png'))
        fig.show()
        freeze_data['feature_label'] = data
        # plot_train_test(train_result['all_features'], test_result['all_features'], args.output_dir)

    with open(os.path.join(config['output_dir'], 'freeze_data.pkl'), 'wb') as f:
        pickle.dump(freeze_data, f)

    df = pd.DataFrame(data={'valid_y': freeze_data['valid_all_y'],
                            'valid_score': freeze_data['valid_score']})
    df.to_csv(os.path.join(config['output_dir'], 'valid_score.csv'))

    df = pd.DataFrame(data={'test_y': freeze_data['test_all_y'],
                            'test_score': freeze_data['test_score']})
    df.to_csv(os.path.join(config['output_dir'], 'test_score.csv'))

    if args.result != 'no':
        pd_result = pd.DataFrame(gross_result)
        if args.seed == 16:
            # first seed of the sweep: write the CSV with a header
            pd_result.to_csv(args.result + '_gross_result.csv', index=False)
        else:
            # later seeds: append rows without repeating the header
            pd_result.to_csv(args.result + '_gross_result.csv', index=False, mode='a', header=False)
        if args.seed == 8192:
            # last seed of the sweep: aggregate mean/std over all rows
            print(args.result)
            std_mean(args.result + '_gross_result.csv')
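Every training loop in these scripts calls early_stopping(-eval_result['eer']) and later reads early_stopping.best_score; the translated comments describe the return codes as 1 = save the model, 0 = no-op, -1 = stop. A minimal sketch of an EarlyStopping class consistent with that contract (an assumption; the utils version may track extra state):

class EarlyStopping:
    def __init__(self, patience, logger=None):
        self.patience = patience
        self.logger = logger
        self.best_score = None
        self.counter = 0

    def __call__(self, score):
        if self.best_score is None or score > self.best_score:
            self.best_score = score   # new best: caller checkpoints the model
            self.counter = 0
            return 1
        self.counter += 1
        if self.counter >= self.patience:
            return -1                 # patience exhausted: stop training
        return 0                      # no improvement yet: keep going

Because callers pass -eer, best_dev = -early_stopping.best_score recovers the best (lowest) validation EER.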
def inference(config):
    if config.run.visible_devices:
        os.environ['CUDA_VISIBLE_DEVICES'] = config.run.visible_devices

    test_patient_ids = os.listdir(config.test_dataset.root_dir_path)

    seed = check_manual_seed()
    print('Using seed: {}'.format(seed))

    class_name_to_index = config.label_to_id._asdict()
    index_to_class_name = {v: k for k, v in class_name_to_index.items()}

    test_data_loader = get_data_loader(
        mode='test',
        dataset_name=config.test_dataset.dataset_name,
        root_dir_path=config.test_dataset.root_dir_path,
        patient_ids=test_patient_ids,
        batch_size=config.test_dataset.batch_size,
        num_workers=config.test_dataset.num_workers,
        volume_size=config.test_dataset.volume_size,
    )

    # five-model ensemble, all in eval mode
    model_1 = get_trained_model(config.model_1)
    model_2 = get_trained_model(config.model_2)
    model_3 = get_trained_model(config.model_3)
    model_4 = get_trained_model(config.model_4)
    model_5 = get_trained_model(config.model_5)
    model_1.eval()
    model_2.eval()
    model_3.eval()
    model_4.eval()
    model_5.eval()

    for batch in tqdm(test_data_loader):
        image = batch['image'].cuda().float()
        assert image.size(0) == 1
        patient_id = batch['patient_id'][0]
        nii_path = batch['nii_path'][0]

        # pad the last (width) axis by (2, 3) so the volume fits the networks
        image = F.pad(image, (2, 3, 0, 0, 0, 0, 0, 0, 0, 0), 'constant', 0)
        output = torch.ones((1, image.shape[2], image.shape[3], image.shape[4]))

        with torch.no_grad():
            pred_1 = model_1(image)
            pred_2 = model_2(image)
            pred_3 = model_3(image)
            pred_4 = model_4(image)
            pred_5 = model_5(image)
            pred = (pred_1 + pred_2 + pred_3 + pred_4 + pred_5) / 5.0

        # voxels where every class probability is below 0.5 fall back to background
        mask_0 = pred[:, 0, ...] < 0.5
        mask_1 = pred[:, 1, ...] < 0.5
        mask_2 = pred[:, 2, ...] < 0.5
        mask = (mask_0 * mask_1 * mask_2).cpu()  # move to CPU to match output

        pred = pred.argmax(1).cpu()
        output += pred
        output[mask] = 0

        # undo the width padding
        image = image[..., 2:-3]
        output = output[..., 2:-3]

        save_dir_path = os.path.join(config.save.save_root_dir, patient_id)
        os.makedirs(save_dir_path, exist_ok=True)

        image = image.cpu().numpy()[0, 1, ...]
        output = output.cpu().numpy()[0, ...].astype(np.int16)

        nii_image = nib.load(nii_path)
        nii_output = nib.Nifti1Image(output, affine=nii_image.affine)
        nib.save(nii_output, os.path.join(save_dir_path, patient_id + '_output.nii.gz'))
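The masking rule in inference() is easy to misread: the torch.ones base shifts the argmax labels up by one, and the mask then forces low-confidence voxels to background. A tiny worked example with hypothetical numbers (1 volume, 3 foreground classes, 2 voxels):

import torch

pred = torch.tensor([[[0.2, 0.7],
                      [0.1, 0.3],
                      [0.1, 0.4]]])     # [1, 3 classes, 2 voxels]
mask = (pred < 0.5).all(dim=1)          # True where no class is confident
output = pred.argmax(dim=1) + 1         # the ones base makes labels start at 1
output[mask] = 0                        # low-confidence voxels become background
print(output)                           # tensor([[0, 1]])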
def main(config, needs_save, study_name, k, n_splits):
    if config.run.visible_devices:
        os.environ['CUDA_VISIBLE_DEVICES'] = config.run.visible_devices

    seed = check_manual_seed(config.run.seed)
    print('Using seed: {}'.format(seed))

    train_data_loader, test_data_loader, data_train = get_k_hold_data_loader(
        config.dataset,
        k=k,
        n_splits=n_splits,
    )

    data_train = torch.from_numpy(data_train).float().cuda(non_blocking=True)
    data_train = torch.t(data_train)

    model = get_model(config.model)
    model.cuda()
    model = nn.DataParallel(model)
    print('count params: ', count_parameters(model.module))

    saved_model_path, _, _ = get_saved_model_path(
        config, study_name, config.model.checkpoint_epoch, k, n_splits,
    )
    model.load_state_dict(torch.load(saved_model_path)['model'])
    model.eval()

    if config.model.model_name == 'MLP':
        embedding = model.module.get_embedding()
    elif config.model.model_name == 'ModifiedMLP':
        embedding = model.module.get_embedding()
    elif config.model.model_name == 'DietNetworks':
        embedding = model.module.get_embedding(data_train)
    elif config.model.model_name == 'ModifiedDietNetworks':
        embedding = model.module.get_embedding(data_train)

    embedding = embedding.detach().cpu().numpy()

    emb_pca = PCA(n_components=2)
    emb_pca.fit_transform(embedding)  # only the fit is needed; the transform result is unused

    if config.run.decomp == '1D':
        print('Approximate by 1D PCA')
        axis_1 = torch.from_numpy(emb_pca.components_[0])
        score_1 = np.dot(embedding, axis_1)
        approx = np.outer(score_1, axis_1)
    elif config.run.decomp == '2D':
        print('Approximate by 2D PCA')
        axis_1 = torch.from_numpy(emb_pca.components_[0])
        score_1 = np.dot(embedding, axis_1)
        axis_2 = torch.from_numpy(emb_pca.components_[1])
        score_2 = np.dot(embedding, axis_2)
        approx = np.outer(score_1, axis_1) + np.outer(score_2, axis_2)
        # approx = np.outer(score_2, axis_2)

    approx = torch.from_numpy(approx).float().cuda(non_blocking=True)

    criterion = nn.CrossEntropyLoss()

    def inference(engine, batch):
        x = batch['data'].float().cuda(non_blocking=True)
        y = batch['label'].long().cuda(non_blocking=True)

        assert config.run.transposed_matrix == 'overall'
        x_t = data_train

        with torch.no_grad():
            out, _ = model.module.approx(x, approx)
            l_discriminative = criterion(out, y)
            l_total = l_discriminative

        metrics = calc_metrics(out, y)
        metrics.update({
            'l_total': l_total.item(),
            'l_discriminative': l_discriminative.item(),
        })
        torch.cuda.synchronize()
        return metrics

    evaluator = Engine(inference)

    monitoring_metrics = ['l_total', 'l_discriminative', 'accuracy']
    for metric in monitoring_metrics:
        RunningAverage(
            alpha=0.98,
            output_transform=partial(lambda x, metric: x[metric], metric=metric)
        ).attach(evaluator, metric)

    pbar = ProgressBar()
    pbar.attach(evaluator, metric_names=monitoring_metrics)

    evaluator.run(test_data_loader, 1)

    columns = ['k', 'n_splits', 'epoch', 'iteration'] + list(evaluator.state.metrics.keys())
    values = [str(k), str(n_splits), str(evaluator.state.epoch), str(evaluator.state.iteration)] \
        + [str(value) for value in evaluator.state.metrics.values()]
    values = {c: v for (c, v) in zip(columns, values)}
    values.update({
        'variance_ratio_1': emb_pca.explained_variance_ratio_[0],
        'variance_ratio_2': emb_pca.explained_variance_ratio_[1],
    })
    return values
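The decomp branches above amount to a rank-1 or rank-2 reconstruction of the embedding from its top principal axes. A compact equivalent for any k, assuming (as sklearn guarantees) unit-norm orthogonal components; note that sklearn's PCA centers the data internally while the code above projects the uncentered embedding, and this sketch mirrors that choice:

import numpy as np
from sklearn.decomposition import PCA

def rank_k_approx(embedding, k):
    # project the (uncentered) embedding onto the top-k principal axes and
    # reconstruct: equivalent to summing np.outer(score_i, axis_i) over i
    axes = PCA(n_components=k).fit(embedding).components_  # [k, dim]
    scores = embedding @ axes.T                            # [n, k]
    return scores @ axes                                   # [n, dim]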
def main(args): check_manual_seed(args.seed) logger.info('seed: {}'.format(args.seed)) logger.info('Loading config...') bert_config = Config('config/bert.ini') bert_config = bert_config(args.bert_type) # for oos-eval dataset data_config = Config('config/data.ini') data_config = data_config(args.dataset) # Prepare data processor data_path = os.path.join(data_config['DataDir'], data_config[args.data_file]) # 把目录和文件名合成一个路径 label_path = data_path.replace('.json', '.label') with open(data_path, 'r', encoding='utf-8') as fp: data = json.load(fp) for type in data: logger.info('{} : {}'.format(type, len(data[type]))) with open(label_path, 'r', encoding='utf-8') as fp: logger.info(json.load(fp)) if args.dataset == 'oos-eval': processor = OOSProcessor(bert_config, maxlen=32) logger.info('OOSProcessor') elif args.dataset == 'smp': processor = SMPProcessor(bert_config, maxlen=32) logger.info('SMPProcessor') else: raise ValueError('The dataset {} is not supported.'.format( args.dataset)) processor.load_label( label_path) # Adding label_to_id and id_to_label ot processor. logger.info("label_to_id: {}".format(processor.label_to_id)) logger.info("id_to_label: {}".format(processor.id_to_label)) n_class = len(processor.id_to_label) config = vars(args) # 返回参数字典 config['gan_save_path'] = os.path.join(args.output_dir, 'save', 'gan.pt') config['bert_save_path'] = os.path.join(args.output_dir, 'save', 'bert.pt') config['n_class'] = n_class logger.info('config:') logger.info(config) D_detect = Discriminator(config) D_g = Discriminator(config) G = Generator(config) E = BertModel.from_pretrained( bert_config['PreTrainModelDir']) # Bert encoder if args.fine_tune: for param in E.parameters(): param.requires_grad = True else: for param in E.parameters(): param.requires_grad = False D_detect.to(device) D_g.to(device) G.to(device) E.to(device) global_step = 0 def train(train_dataset, dev_dataset): train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=2) global best_dev nonlocal global_step n_sample = len(train_dataloader) early_stopping = EarlyStopping(args.patience, logger=logger) # Loss function adversarial_loss = torch.nn.BCELoss().to(device) classified_loss = torch.nn.CrossEntropyLoss().to(device) # Optimizers optimizer_G = torch.optim.Adam(G.parameters(), lr=args.G_lr) # optimizer for generator optimizer_D_detect = torch.optim.Adam( D_detect.parameters(), lr=args.D_detect_lr) # optimizer for discriminator optimizer_D_g = torch.optim.Adam(D_g.parameters(), lr=args.D_g_lr) optimizer_E = AdamW(E.parameters(), args.bert_lr) G_total_train_loss = [] D_total_fake_loss = [] D_total_real_loss = [] FM_total_train_loss = [] D_total_class_loss = [] valid_detection_loss = [] valid_oos_ind_precision = [] valid_oos_ind_recall = [] valid_oos_ind_f_score = [] for i in range(args.n_epoch): logger.info('***********************************') logger.info('epoch: {}'.format(i)) # Initialize model state G.train() D_detect.train() D_g.train() E.train() D_g_real_loss = 0 D_g_fake_loss = 0 D_detect_real_loss = 0 D_detect_fake_loss = 0 G_loss = 0 for sample in tqdm(train_dataloader): sample = (i.to(device) for i in sample) token, mask, type_ids, y = sample batch = len(token) all_g_D_g_loss = 0 D_gen_real_loss = None D_gen_fake_loss = None # the label used to train generator and discriminator. 
        valid_label = FloatTensor(batch, 1).fill_(1.0).detach()
        fake_label = FloatTensor(batch, 1).fill_(0.0).detach()

        optimizer_E.zero_grad()
        sequence_output, pooled_output = E(token, mask, type_ids)
        real_feature = pooled_output

        for gan_i in range(args.time):
            # ------------------------- train D_g ------------------------- #
            # train D_g on real features
            id_sample = (y == 1.0)
            # zero out the IND samples so only OOD data contributes to this loss
            weight = torch.ones(len(id_sample)).to(device) - id_sample * 1.0
            real_loss_func = torch.nn.BCELoss(weight=weight).to(device)

            optimizer_D_g.zero_grad()
            D_gen_real_discriminator_output, f_vector = D_g(real_feature)
            # D_gen_real_loss = adversarial_loss(D_gen_real_discriminator_output, valid_label)
            # discriminator loss on real samples
            D_gen_real_loss = real_loss_func(D_gen_real_discriminator_output.squeeze(),
                                             valid_label.squeeze())

            # train D_g on fake features
            z = FloatTensor(np.random.normal(0, 1, (batch, args.G_z_dim))).to(device)
            fake_feature = G(z).detach()
            D_gen_fake_discriminator_output, f_vector = D_g(fake_feature)
            # discriminator loss on fake samples
            D_gen_fake_loss = adversarial_loss(D_gen_fake_discriminator_output.squeeze(),
                                               fake_label.squeeze())

            D_gen_loss = D_gen_real_loss + D_gen_fake_loss
            D_gen_loss.backward(retain_graph=True)  # retain the graph; the generator still needs it
            optimizer_D_g.step()

            # ------------------------- train G ------------------------- #
            list_g_D_g_loss = []
            for gi in range(args.g_time):
                optimizer_G.zero_grad()
                z = FloatTensor(np.random.normal(0, 1, (batch, args.G_z_dim))).to(device)
                # do not detach here: detaching cuts the graph and G would
                # never receive gradients from this loss
                fake_feature = G(z)
                D_gen_fake_discriminator_output, f_vector = D_g(fake_feature)
                # the generator tries to fool D_g into judging fakes as real
                g_D_g_loss = adversarial_loss(D_gen_fake_discriminator_output.squeeze(),
                                              valid_label.squeeze())
                g_D_g_loss.backward()
                optimizer_G.step()
                all_g_D_g_loss += g_D_g_loss.detach()
                list_g_D_g_loss.append(g_D_g_loss)

            # ------------------------- train D_detect ------------------------- #
            # train on real samples (detect real)
            optimizer_D_detect.zero_grad()
            ood_real_detect_discriminator_output, f_vector = D_detect(real_feature)
            # OOD detector loss on real samples
            ood_real_detect_loss = adversarial_loss(
                ood_real_detect_discriminator_output.squeeze(), (y != 0.0).float())

            # train on fake samples (detect fake); a fake sample is fake IND -> OOD
            z = FloatTensor(np.random.normal(0, 1, (batch, args.G_z_dim))).to(device)
            fake_feature = G(z).detach()
            ood_fake_detect_discriminator_output, f_vector = D_detect(fake_feature)
            # fake OOD samples are labelled as OOD
            ood_fake_detect_loss = adversarial_loss(
                ood_fake_detect_discriminator_output.squeeze(), fake_label.squeeze())

            # beta balances the influence of real samples against fake OOD samples
            D_detect_loss = args.beta * ood_real_detect_loss + (1 - args.beta) * ood_fake_detect_loss
            D_detect_loss.backward()
            optimizer_D_detect.step()

        if args.fine_tune:
            optimizer_E.step()

        global_step += 1

        D_g_real_loss += D_gen_real_loss.detach()
        D_g_fake_loss += D_gen_fake_loss.detach()
        D_detect_real_loss += ood_real_detect_loss.detach()
        D_detect_fake_loss += ood_fake_detect_loss.detach()
        G_loss += all_g_D_g_loss

    logger.info('[Epoch {}] Train: D_g_real_loss: {}'.format(i, D_g_real_loss / n_sample))
    logger.info('[Epoch {}] Train: D_g_fake_loss: {}'.format(i, D_g_fake_loss / n_sample))
    logger.info('[Epoch {}] Train: D_detect_real_loss: {}'.format(i, D_detect_real_loss / n_sample))
    logger.info('[Epoch {}] Train: D_detect_fake_loss: {}'.format(i, D_detect_fake_loss / n_sample))
    logger.info('[Epoch {}] Train: G_loss: {}'.format(i, G_loss / n_sample))
    logger.info('---------------------------------------------------------------------------')

    if dev_dataset:
        logger.info('#################### eval result at step {} ####################'.format(global_step))
eval_result = eval(dev_dataset) valid_detection_loss.append(eval_result['detection_loss']) valid_oos_ind_precision.append( eval_result['oos_ind_precision']) valid_oos_ind_recall.append(eval_result['oos_ind_recall']) valid_oos_ind_f_score.append(eval_result['oos_ind_f_score']) # 1 表示要保存模型 # 0 表示不需要保存模型 # -1 表示不需要模型,且超过了patience,需要early stop signal = early_stopping(-eval_result['eer']) if signal == -1: break # elif signal == 0: # pass # elif signal == 1: # save_gan_model(D, G, config['gan_save_path']) # if args.fine_tune: # save_model(E, path=config['bert_save_path'], model_name='bert') logger.info(eval_result) logger.info('valid_eer: {}'.format(eval_result['eer'])) logger.info('valid_oos_ind_precision: {}'.format( eval_result['oos_ind_precision'])) logger.info('valid_oos_ind_recall: {}'.format( eval_result['oos_ind_recall'])) logger.info('valid_oos_ind_f_score: {}'.format( eval_result['oos_ind_f_score'])) logger.info('valid_auc: {}'.format(eval_result['auc'])) logger.info('valid_fpr95: {}'.format( ErrorRateAt95Recall(eval_result['all_binary_y'], eval_result['y_score']))) best_dev = -early_stopping.best_score def eval(dataset): dev_dataloader = DataLoader(dataset, batch_size=args.predict_batch_size, shuffle=False, num_workers=2) n_sample = len(dev_dataloader) result = dict() detection_loss = torch.nn.BCELoss().to(device) D_detect.eval() E.eval() all_detection_preds = [] for sample in tqdm(dev_dataloader): sample = (i.to(device) for i in sample) token, mask, type_ids, y = sample batch = len(token) # -------------------------evaluate D------------------------- # # BERT encode sentence to feature vector with torch.no_grad(): sequence_output, pooled_output = E(token, mask, type_ids) real_feature = pooled_output discriminator_output, f_vector = D_detect(real_feature) all_detection_preds.append(discriminator_output) all_y = LongTensor( dataset.dataset[:, -1].astype(int)).cpu() # [length, n_class] all_binary_y = (all_y != 0).long() # [length, 1] label 0 is oos all_detection_preds = torch.cat(all_detection_preds, 0).cpu() # [length, 1] all_detection_binary_preds = convert_to_int_by_threshold( all_detection_preds.squeeze()) # [length, 1] # 计算损失 detection_loss = detection_loss(all_detection_preds, all_binary_y.float()) result['detection_loss'] = detection_loss logger.info( metrics.classification_report(all_binary_y, all_detection_binary_preds, target_names=['oos', 'in'])) # report oos_ind_precision, oos_ind_recall, oos_ind_fscore, _ = metrics.binary_recall_fscore( all_detection_binary_preds, all_binary_y) detection_acc = metrics.accuracy(all_detection_binary_preds, all_binary_y) y_score = all_detection_preds.squeeze().tolist() eer = metrics.cal_eer(all_binary_y, y_score) result['eer'] = eer result['all_detection_binary_preds'] = all_detection_binary_preds result['detection_acc'] = detection_acc result['all_binary_y'] = all_binary_y result['oos_ind_precision'] = oos_ind_precision result['oos_ind_recall'] = oos_ind_recall result['oos_ind_f_score'] = oos_ind_fscore result['y_score'] = y_score result['auc'] = roc_auc_score(all_binary_y, y_score) return result def test(dataset): # # load BERT and GAN # load_gan_model(D, G, config['gan_save_path']) # if args.fine_tune: # load_model(E, path=config['bert_save_path'], model_name='bert') # test_dataloader = DataLoader(dataset, batch_size=args.predict_batch_size, shuffle=False, num_workers=2) n_sample = len(test_dataloader) result = dict() # Loss function detection_loss = torch.nn.BCELoss().to(device) classified_loss = 
    def test(dataset):
        # # load BERT and GAN
        # load_gan_model(D, G, config['gan_save_path'])
        # if args.fine_tune:
        #     load_model(E, path=config['bert_save_path'], model_name='bert')

        test_dataloader = DataLoader(dataset,
                                     batch_size=args.predict_batch_size,
                                     shuffle=False,
                                     num_workers=2)
        n_sample = len(test_dataloader)
        result = dict()

        # Loss function
        detection_loss = torch.nn.BCELoss().to(device)
        classified_loss = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)

        D_detect.eval()
        E.eval()

        all_detection_preds = []
        all_class_preds = []
        all_features = []

        for sample in tqdm(test_dataloader):
            sample = (i.to(device) for i in sample)
            token, mask, type_ids, y = sample
            batch = len(token)

            # ------------------------- evaluate D ------------------------- #
            # BERT encodes the sentence into a feature vector
            with torch.no_grad():
                sequence_output, pooled_output = E(token, mask, type_ids)
                real_feature = pooled_output

                discriminator_output, f_vector = D_detect(real_feature)
                all_detection_preds.append(discriminator_output)
                if args.do_vis:
                    all_features.append(f_vector)

        all_y = LongTensor(dataset.dataset[:, -1].astype(int)).cpu()  # [length]
        all_binary_y = (all_y != 0).long()  # [length]; label 0 is oos
        all_detection_preds = torch.cat(all_detection_preds, 0).cpu()  # [length, 1]
        all_detection_binary_preds = convert_to_int_by_threshold(
            all_detection_preds.squeeze())  # [length]

        # compute the loss
        detection_loss = detection_loss(all_detection_preds,
                                        all_binary_y.float())
        result['detection_loss'] = detection_loss

        logger.info(
            metrics.classification_report(all_binary_y,
                                          all_detection_binary_preds,
                                          target_names=['oos', 'in']))

        # report
        oos_ind_precision, oos_ind_recall, oos_ind_fscore, _ = metrics.binary_recall_fscore(
            all_detection_binary_preds, all_binary_y)
        detection_acc = metrics.accuracy(all_detection_binary_preds, all_binary_y)

        y_score = all_detection_preds.squeeze().tolist()
        eer = metrics.cal_eer(all_binary_y, y_score)

        result['eer'] = eer
        result['all_detection_binary_preds'] = all_detection_binary_preds
        result['detection_acc'] = detection_acc
        result['all_binary_y'] = all_binary_y
        result['all_y'] = all_y  # needed by the do_vis branch when n_class > 2
        result['oos_ind_precision'] = oos_ind_precision
        result['oos_ind_recall'] = oos_ind_recall
        result['oos_ind_f_score'] = oos_ind_fscore
        result['y_score'] = y_score
        result['auc'] = roc_auc_score(all_binary_y, y_score)

        if args.do_vis:
            all_features = torch.cat(all_features, 0).cpu().numpy()
            result['all_features'] = all_features

        return result

    def get_fake_feature(num_output):
        """Generate a fixed number of fake feature vectors with G."""
        G.eval()
        fake_features = []
        start = 0
        batch = args.predict_batch_size
        with torch.no_grad():
            while start < num_output:
                end = min(num_output, start + batch)
                z = FloatTensor(
                    np.random.normal(0, 1, size=(end - start, args.G_z_dim)))
                fake_feature = G(z)
                discriminator_output, f_vector = D_detect(fake_feature)
                fake_features.append(f_vector)
                start += batch
        return torch.cat(fake_features, 0).cpu().numpy()

    if args.do_train:
        if config['data_file'].startswith('binary'):
            text_train_set = processor.read_dataset(data_path, ['train'])
            text_dev_set = processor.read_dataset(data_path, ['val'])
        elif config['dataset'] == 'oos-eval':
            text_train_set = processor.read_dataset(data_path,
                                                    ['train', 'oos_train'])
            text_dev_set = processor.read_dataset(data_path,
                                                  ['val', 'oos_val'])
        elif config['dataset'] == 'smp':
            text_train_set = processor.read_dataset(data_path, ['train'])
            text_dev_set = processor.read_dataset(data_path, ['val'])

        train_features = processor.convert_to_ids(text_train_set)
        train_dataset = OOSDataset(train_features)
        dev_features = processor.convert_to_ids(text_dev_set)
        dev_dataset = OOSDataset(dev_features)

        train(train_dataset, dev_dataset)
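    # `ErrorRateAt95Recall` (used in the logging below) is assumed to be
    # FPR@95: the false-positive rate at the smallest threshold whose
    # in-domain recall (TPR) reaches 0.95. This is an assumption about the
    # helper, not the repo's actual code.
    def _fpr_at_95_tpr_sketch(labels, scores):
        from sklearn.metrics import roc_curve
        fpr, tpr, _ = roc_curve(labels, scores)
        return fpr[np.argmax(tpr >= 0.95)]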
    if args.do_eval:
        logger.info(
            '#################### eval result at step {} ####################'.
            format(global_step))
        if config['data_file'].startswith('binary'):
            text_dev_set = processor.read_dataset(data_path, ['val'])
        elif config['dataset'] == 'oos-eval':
            text_dev_set = processor.read_dataset(data_path, ['val', 'oos_val'])
        elif config['dataset'] == 'smp':
            text_dev_set = processor.read_dataset(data_path, ['val'])

        dev_features = processor.convert_to_ids(text_dev_set)
        dev_dataset = OOSDataset(dev_features)

        eval_result = eval(dev_dataset)
        logger.info(eval_result)
        logger.info('eval_eer: {}'.format(eval_result['eer']))
        logger.info('eval_oos_ind_precision: {}'.format(
            eval_result['oos_ind_precision']))
        logger.info('eval_oos_ind_recall: {}'.format(
            eval_result['oos_ind_recall']))
        logger.info('eval_oos_ind_f_score: {}'.format(
            eval_result['oos_ind_f_score']))
        logger.info('eval_auc: {}'.format(eval_result['auc']))
        logger.info('eval_fpr95: {}'.format(
            ErrorRateAt95Recall(eval_result['all_binary_y'],
                                eval_result['y_score'])))

    if args.do_test:
        logger.info(
            '#################### test result at step {} ####################'.
            format(global_step))
        if config['data_file'].startswith('binary'):
            text_test_set = processor.read_dataset(data_path, ['test'])
        elif config['dataset'] == 'oos-eval':
            text_test_set = processor.read_dataset(data_path,
                                                   ['test', 'oos_test'])
        elif config['dataset'] == 'smp':
            text_test_set = processor.read_dataset(data_path, ['test'])

        test_features = processor.convert_to_ids(text_test_set)
        test_dataset = OOSDataset(test_features)

        test_result = test(test_dataset)
        logger.info(test_result)
        logger.info('test_eer: {}'.format(test_result['eer']))
        logger.info('test_ood_ind_precision: {}'.format(
            test_result['oos_ind_precision']))
        logger.info('test_ood_ind_recall: {}'.format(
            test_result['oos_ind_recall']))
        logger.info('test_ood_ind_f_score: {}'.format(
            test_result['oos_ind_f_score']))
        logger.info('test_auc: {}'.format(test_result['auc']))
        logger.info('test_fpr95: {}'.format(
            ErrorRateAt95Recall(test_result['all_binary_y'],
                                test_result['y_score'])))

        my_plot_roc(test_result['all_binary_y'], test_result['y_score'],
                    os.path.join(args.output_dir, 'roc_curve.png'))
        save_result(test_result, os.path.join(args.output_dir, 'test_result'))

        # output the misclassified cases
        if config['dataset'] == 'oos-eval':
            texts = [line[0] for line in text_test_set]
        elif config['dataset'] == 'smp':
            texts = [line['text'] for line in text_test_set]
        else:
            raise ValueError('The dataset {} is not supported.'.format(
                args.dataset))

        output_cases(texts, test_result['all_binary_y'],
                     test_result['all_detection_binary_preds'],
                     os.path.join(args.output_dir, 'test_cases.csv'), processor)

        # confusion matrix
        plot_confusion_matrix(test_result['all_binary_y'],
                              test_result['all_detection_binary_preds'],
                              args.output_dir)

        beta_log_path = 'beta_log.txt'
        write_header = not os.path.exists(beta_log_path)
        with open(beta_log_path, 'a', encoding='utf-8') as f:
            if write_header:
                f.write('seed\tbeta\tdataset\tdev_eer\ttest_eer\tdata_size\n')
            line = '\t'.join([
                str(config['seed']),
                str(config['beta']),
                str(config['data_file']),
                str(best_dev),
                str(test_result['eer']), '100'
            ])
            f.write(line + '\n')
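        # The t-SNE plot below mixes the real test features with
        # len(test_dataset) // 2 generated features; the generated points are
        # given the pseudo-label -1 so scatter_plot can color them separately.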
        if args.do_vis:
            # [2 * length, feature_dim]
            features = np.concatenate([
                test_result['all_features'],
                get_fake_feature(len(test_dataset) // 2)
            ], axis=0)
            features = TSNE(n_components=2, verbose=1,
                            n_jobs=-1).fit_transform(features)  # [2 * length, 2]
            # [2 * length, 1]
            if n_class > 2:
                labels = np.concatenate([
                    test_result['all_y'],
                    np.array([-1] * (len(test_dataset) // 2))
                ], 0).reshape((-1, 1))
            else:
                labels = np.concatenate([
                    test_result['all_binary_y'],
                    np.array([-1] * (len(test_dataset) // 2))
                ], 0).reshape((-1, 1))
            # [2 * length, 3]
            data = np.concatenate([features, labels], 1)
            fig = scatter_plot(data, processor)
            fig.savefig(os.path.join(args.output_dir, 'plot.png'))
            fig.show()
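# Side note on the masked BCE trick used in train() above: a zero per-sample
# weight makes BCELoss ignore in-domain samples, so D_g's "real" loss is
# computed on OOD samples only. A minimal toy illustration (not part of the
# training code; names are hypothetical):
def _masked_bce_demo():
    probs = torch.tensor([0.9, 0.2, 0.7])   # discriminator outputs
    target = torch.ones(3)                  # all targets are "real"
    is_id = torch.tensor([1.0, 0.0, 1.0])   # 1 = in-domain sample
    weight = torch.ones(3) - is_id          # zero weight for ID samples
    # only the middle (OOD) sample contributes to the weighted loss
    return torch.nn.BCELoss(weight=weight)(probs, target)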
def main(config, needs_save, study_name, k, n_splits, output_dir_path):
    if config.run.visible_devices:
        os.environ['CUDA_VISIBLE_DEVICES'] = config.run.visible_devices

    seed = check_manual_seed(config.run.seed)
    print('Using seed: {}'.format(seed))

    train_data_loader, test_data_loader, data_train = get_k_hold_data_loader(
        config.dataset,
        k=k,
        n_splits=n_splits,
    )

    data_train = torch.from_numpy(data_train).float().cuda(non_blocking=True)
    data_train = torch.t(data_train)

    model = get_model(config.model)
    model.cuda()
    model = nn.DataParallel(model)

    criterion = nn.CrossEntropyLoss()

    if config.optimizer.optimizer_name == 'Adam':
        optimizer = optim.Adam(
            filter(lambda p: p.requires_grad, model.parameters()),
            config.optimizer.lr,
            [0.9, 0.9999],
            weight_decay=config.optimizer.weight_decay,
        )
    else:
        raise NotImplementedError

    # scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: 0.99 ** epoch)

    def update(engine, batch):
        model.train()

        x = batch['data'].float().cuda(non_blocking=True)
        y = batch['label'].long().cuda(non_blocking=True)

        if config.run.transposed_matrix == 'overall':
            x_t = data_train
        elif config.run.transposed_matrix == 'batch':
            x_t = torch.t(x)

        def closure():
            optimizer.zero_grad()
            if 'MLP' in config.model.model_name:
                out, x_hat = model(x)
            else:
                out, x_hat = model(x, x_t)
            l_discriminative = criterion(out, y)
            l_feature = torch.tensor(0.0).cuda()
            if config.run.w_feature_selection:
                l_feature += config.run.w_feature_selection * torch.sum(
                    torch.abs(model.module.Ue))
            l_recon = torch.tensor(0.0).cuda()
            if config.run.w_reconstruction:
                l_recon += config.run.w_reconstruction * F.mse_loss(x, x_hat)
            l_total = l_discriminative + l_feature + l_recon
            l_total.backward()
            return l_total, l_discriminative, l_feature, l_recon, out

        l_total, l_discriminative, l_feature, l_recon, out = optimizer.step(closure)

        metrics = calc_metrics(out, y)
        metrics.update({
            'l_total': l_total.item(),
            'l_discriminative': l_discriminative.item(),
            'l_feature': l_feature.item(),
            'l_recon': l_recon.item(),
        })

        torch.cuda.synchronize()
        return metrics

    def inference(engine, batch):
        model.eval()

        x = batch['data'].float().cuda(non_blocking=True)
        y = batch['label'].long().cuda(non_blocking=True)

        if config.run.transposed_matrix == 'overall':
            x_t = data_train
        elif config.run.transposed_matrix == 'batch':
            x_t = torch.t(x)

        with torch.no_grad():
            if 'MLP' in config.model.model_name:
                out, x_hat = model(x)
            else:
                out, x_hat = model(x, x_t)
            l_discriminative = criterion(out, y)
            l_feature = torch.tensor(0.0).cuda()
            if config.run.w_feature_selection:
                l_feature += config.run.w_feature_selection * torch.sum(
                    torch.abs(model.module.Ue))
            l_recon = torch.tensor(0.0).cuda()
            if config.run.w_reconstruction:
                l_recon += config.run.w_reconstruction * F.mse_loss(x, x_hat)
            l_total = l_discriminative + l_feature + l_recon

        metrics = calc_metrics(out, y)
        metrics.update({
            'l_total': l_total.item(),
            'l_discriminative': l_discriminative.item(),
            'l_feature': l_feature.item(),
            'l_recon': l_recon.item(),
        })

        torch.cuda.synchronize()
        return metrics

    trainer = Engine(update)
    evaluator = Engine(inference)
    timer = Timer(average=True)

    monitoring_metrics = ['l_total', 'l_discriminative', 'l_feature',
                          'l_recon', 'accuracy']

    for metric in monitoring_metrics:
        RunningAverage(
            alpha=0.98,
            output_transform=partial(lambda x, metric: x[metric], metric=metric)
        ).attach(trainer, metric)

    for metric in monitoring_metrics:
        RunningAverage(
            alpha=0.98,
            output_transform=partial(lambda x, metric: x[metric], metric=metric)
        ).attach(evaluator, metric)
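    # NOTE: partial() above binds the metric name eagerly. A bare
    # `lambda x: x[metric]` would capture the loop variable by reference, and
    # every RunningAverage would end up reporting the last metric
    # ('accuracy'). Toy illustration of the pitfall:
    #
    #   fs = [lambda: m for m in ['a', 'b']]                # both return 'b'
    #   gs = [partial(lambda m: m, m) for m in ['a', 'b']]  # return 'a', 'b'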
    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)
    pbar.attach(evaluator, metric_names=monitoring_metrics)

    @trainer.on(Events.STARTED)
    def events_started(engine):
        if needs_save:
            save_config(config, seed, output_dir_path)

    @trainer.on(Events.EPOCH_COMPLETED)
    def switch_training_to_evaluation(engine):
        if needs_save:
            save_logs('train', k, n_splits, trainer, trainer.state.epoch,
                      trainer.state.iteration, config, output_dir_path)
        evaluator.run(test_data_loader, max_epochs=1)

    @evaluator.on(Events.EPOCH_COMPLETED)
    def switch_evaluation_to_training(engine):
        if needs_save:
            save_logs('val', k, n_splits, evaluator, trainer.state.epoch,
                      trainer.state.iteration, config, output_dir_path)
            if trainer.state.epoch % 100 == 0:
                save_models(model, optimizer, k, n_splits,
                            trainer.state.epoch, trainer.state.iteration,
                            config, output_dir_path)
        # scheduler.step()

    @trainer.on(Events.EPOCH_COMPLETED)
    @evaluator.on(Events.EPOCH_COMPLETED)
    def show_logs(engine):
        columns = ['k', 'n_splits', 'epoch', 'iteration'] \
            + list(engine.state.metrics.keys())
        values = [str(k), str(n_splits), str(engine.state.epoch),
                  str(engine.state.iteration)] \
            + [str(value) for value in engine.state.metrics.values()]

        message = '[{epoch}/{max_epoch}][{i}/{max_i}]'.format(
            epoch=engine.state.epoch,
            max_epoch=config.run.n_epochs,
            i=engine.state.iteration,
            max_i=len(train_data_loader))
        for name, value in zip(columns, values):
            message += ' | {name}: {value}'.format(name=name, value=value)

        pbar.log_message(message)

    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    print('Training starts: [max_epochs] {}, [max_iterations] {}'.format(
        config.run.n_epochs, config.run.n_epochs * len(train_data_loader)))

    trainer.run(train_data_loader, config.run.n_epochs)
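# NOTE: update() above leans on torch.optim.Optimizer.step(closure): when a
# closure is given, Adam re-evaluates it under torch.enable_grad() and returns
# whatever the closure returns, which is why the tuple of losses survives the
# call. A minimal sketch of the same pattern with a single loss (hypothetical
# helper, not repo code):
def _closure_step_sketch(model, optimizer, criterion, x, y):
    def closure():
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        return loss
    return optimizer.step(closure)  # returns the closure's return value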
def main(config, study_name, i, n_splits, NUM=None, CHECKPOINT_EPOCHS=[500]):
    if config.run.visible_devices:
        os.environ['CUDA_VISIBLE_DEVICES'] = config.run.visible_devices

    seed = check_manual_seed(config.run.seed)
    print('Using seed: {}'.format(seed))

    _, _, data_train, label_train = get_k_hold_data_loader(
        config.dataset,
        k=i,
        n_splits=n_splits,
        with_label_train=True,
    )

    attributes, p_values = calc_freq(data_train, label_train, 0.05)
    gene_symbols = get_gene_symbols(config.dataset.data_path)

    if NUM is None:
        NUM = len(gene_symbols)

    labels = {}
    for g, symbol in enumerate(gene_symbols):
        if g < NUM:
            labels[g] = symbol
        else:
            break

    attributes = attributes[:NUM]
    p_values = p_values[:NUM]

    data_train = torch.from_numpy(data_train).float().cuda(non_blocking=True)
    data_train = torch.t(data_train)

    model = get_model(config.model)
    model.cuda()
    model = nn.DataParallel(model)

    for checkpoint_epoch in CHECKPOINT_EPOCHS:
        saved_model_path, model_name, saved_dir_path = get_saved_model_path(
            config,
            study_name,
            checkpoint_epoch,
            i,
            n_splits,
        )

        model.load_state_dict(torch.load(saved_model_path)['model'])
        model.eval()

        with torch.no_grad():
            if config.model.model_name == 'MLP':
                embedding = model.module.get_embedding()
            elif config.model.model_name == 'ModifiedMLP':
                embedding = model.module.get_embedding()
            elif config.model.model_name == 'DietNetworks':
                embedding = model.module.get_embedding(data_train)
            elif config.model.model_name == 'ModifiedDietNetworks':
                embedding = model.module.get_embedding(data_train)

        embedding = embedding.detach().cpu().numpy()
        embedding = embedding[:NUM, :]

        X_pca = PCA(n_components=2)
        embedding_2d = X_pca.fit_transform(embedding)

        # fig, ax = plt.subplots()
        # ax.scatter(embedding_2d[:, 0], embedding_2d[:, 1], s=10., c=attributes)
        # for i in range(NUM):
        #     dist = np.sqrt(np.power(embedding_2d[i, 0], 2) + np.power(embedding_2d[i, 1], 2))
        #     if dist > 1.0:
        #         ax.annotate(labels[i], (embedding_2d[i, 0], embedding_2d[i, 1]))
        # plt.show()
        # plt.clf()

        # project each gene embedding onto the principal axes and rank by score
        axis_1 = X_pca.components_[0]
        score_1 = np.dot(embedding, axis_1)
        axis_2 = X_pca.components_[1]
        score_2 = np.dot(embedding, axis_2)

        for arg in np.argsort(score_2)[::-1]:
            print(score_2[arg], labels[arg])

        input()  # pause before moving on to the next checkpoint
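# NOTE: the scores above are raw dot products with the principal axes, while
# sklearn's fit_transform additionally subtracts the training mean. The offset
# is the same constant for every row, so the argsort ranking printed above is
# unaffected. A small self-contained check of that relationship (hypothetical
# helper, not repo code):
def _pca_projection_check():
    X = np.random.randn(100, 16)
    pca = PCA(n_components=2)
    Z = pca.fit_transform(X)                     # mean-centered projection
    Z_dot = (X - pca.mean_) @ pca.components_.T  # identical by construction
    assert np.allclose(Z, Z_dot)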