def train(train_dataset, dev_dataset): train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=2) global best_dev nonlocal global_step n_sample = len(train_dataloader) early_stopping = EarlyStopping(args.patience, logger=logger) # Loss function adversarial_loss = torch.nn.BCELoss().to(device) classified_loss = torch.nn.CrossEntropyLoss().to(device) # Optimizers optimizer_G = torch.optim.Adam(G.parameters(), lr=args.G_lr) # optimizer for generator optimizer_D = torch.optim.Adam(D.parameters(), lr=args.D_lr) # optimizer for discriminator optimizer_E = AdamW(E.parameters(), args.bert_lr) G_total_train_loss = [] D_total_fake_loss = [] D_total_real_loss = [] FM_total_train_loss = [] D_total_class_loss = [] valid_detection_loss = [] valid_oos_ind_precision = [] valid_oos_ind_recall = [] valid_oos_ind_f_score = [] all_features = [] result = dict() for i in range(args.n_epoch): # Initialize model state G.train() D.train() E.train() G_train_loss = 0 G_d_loss = 0 D_fake_loss = 0 D_real_loss = 0 FM_train_loss = 0 D_class_loss = 0 G_features = [] for sample in tqdm.tqdm(train_dataloader): sample = (i.to(device) for i in sample) if args.dataset == 'smp': token, mask, type_ids, knowledge_tag, y = sample batch = len(token) ood_sample = (y == 0.0).float() # weight = torch.ones(len(ood_sample)).to(device) - ood_sample * args.beta # real_loss_func = torch.nn.BCELoss(weight=weight).to(device) # length weight length_sample = FloatTensor([0] * batch) if args.minlen != -1: short_sample = (mask[:, args.minlen] == 0).float() length_sample = length_sample.add(short_sample) if args.maxlen != -1: long_sample = mask[:, args.maxlen].float() length_sample = length_sample.add(long_sample) # get knowledge sample weight by knowledge_tag exclude_sample = knowledge_tag # initailize weight weight = torch.ones(batch).to(device) # optimize without weights if args.optim_mode == 0 and 0 <= args.beta <= 1: weight -= ood_sample * args.beta # only optimize 
length by weight if args.optim_mode == 1: # set all exclude_sample's weight to 0 weight -= exclude_sample length_sample -= exclude_sample length_sample = (length_sample > 0).float() weight -= length_sample * (1 - args.length_weight) # set ood sample weight if 0 <= args.beta <= 1: ood_sample -= exclude_sample ood_sample = (ood_sample > 0).float() temp = torch.ones(batch).to(device) temp -= ood_sample * args.beta weight *= temp # only optimize sample by weight if args.optim_mode == 2: # set all length_sample's weight to 0 weight -= length_sample exclude_sample -= length_sample exclude_sample = (exclude_sample > 0).float() weight -= exclude_sample * (1 - args.sample_weight) # set ood sample weight if 0 <= args.beta <= 1: ood_sample -= length_sample ood_sample = (ood_sample > 0).float() temp = torch.ones(batch) temp -= ood_sample * args.beta weight *= temp # optimize length and sample by weight # if args.optim_mode == 3: # alpha = 0.5 # beta = 0.5 # weight = torch.ones(len(length_sample)).to(device) \ # - alpha * length_sample * (1 - args.length_weight) \ # - beta * exclude_sample * (1 - args.sample_weight) if args.dataset == 'oos-eval': token, mask, type_ids, y = sample batch = len(token) ood_sample = (y == 0.0).float() # weight = torch.ones(len(ood_sample)).to(device) - ood_sample * args.beta # real_loss_func = torch.nn.BCELoss(weight=weight).to(device) # length weight length_sample = FloatTensor([0] * batch) if args.minlen != -1: short_sample = (mask[:, args.minlen] == 0).float() length_sample = length_sample.add(short_sample) if args.maxlen != -1: long_sample = mask[:, args.maxlen].float() length_sample = length_sample.add(long_sample) # initailize weight weight = torch.ones(batch).to(device) # optimize without weights if args.optim_mode == 0 and 0 <= args.beta <= 1: weight -= ood_sample * args.beta # only optimize length by weight if args.optim_mode == 1: weight -= length_sample * (1 - args.length_weight) # set ood sample weight if 0 <= args.beta <= 1: ood_sample 
-= length_sample ood_sample = (ood_sample > 0).float() temp = torch.ones(batch).to(device) temp -= ood_sample * args.beta weight *= temp real_loss_func = torch.nn.BCELoss(weight=weight).to(device) # the label used to train generator and discriminator. valid_label = FloatTensor(batch, 1).fill_(1.0).detach() fake_label = FloatTensor(batch, 1).fill_(0.0).detach() optimizer_E.zero_grad() sequence_output, pooled_output = E(token, mask, type_ids) real_feature = pooled_output # train D on real optimizer_D.zero_grad() real_f_vector, discriminator_output, classification_output = D(real_feature, return_feature=True) discriminator_output = discriminator_output.squeeze() # real_loss = adversarial_loss(discriminator_output, (y != 0.0).float()) real_loss = real_loss_func(discriminator_output, (y != 0.0).float()) if n_class > 2: # 大于2表示除了训练判别器还要训练分类器 class_loss = classified_loss(classification_output, y.long()) real_loss += class_loss D_class_loss += class_loss.detach() real_loss.backward() if args.do_vis: all_features.append(real_f_vector.detach()) # # train D on fake if args.model == 'lstm_gan' or args.model == 'cnn_gan': z = FloatTensor(np.random.normal(0, 1, (batch, 32, args.G_z_dim))).to(device) else: # uniform (-1,1) # z = FloatTensor(np.random.uniform(-1, 1, (batch, args.G_z_dim))).to(device) z = FloatTensor(np.random.normal(0, 1, (batch, args.G_z_dim))).to(device) fake_feature = G(z).detach() fake_discriminator_output = D.detect_only(fake_feature) # beta of fake if 0 <= args.beta <= 1: fake_loss = args.beta * adversarial_loss(fake_discriminator_output, fake_label) else: fake_loss = adversarial_loss(fake_discriminator_output, fake_label) fake_loss.backward() optimizer_D.step() if args.fine_tune: optimizer_E.step() # train G optimizer_G.zero_grad() if args.model == 'lstm_gan' or args.model == 'cnn_gan': z = FloatTensor(np.random.normal(0, 1, (batch, 32, args.G_z_dim))).to(device) else: # uniform (-1,1) # z = FloatTensor(np.random.uniform(-1, 1, (batch, 
args.G_z_dim))).to(device) z = FloatTensor(np.random.normal(0, 1, (batch, args.G_z_dim))).to(device) fake_f_vector, D_decision = D.detect_only(G(z), return_feature=True) if args.do_vis: G_features.append(fake_f_vector.detach()) gd_loss = adversarial_loss(D_decision, valid_label) # feature matching loss fm_loss = torch.abs(torch.mean(real_f_vector.detach(), 0) - torch.mean(fake_f_vector, 0)).mean() # fm_loss = feature_matching_loss(torch.mean(fake_f_vector, 0), torch.mean(real_f_vector.detach(), 0)) g_loss = gd_loss + 0 * fm_loss g_loss.backward() optimizer_G.step() global_step += 1 D_fake_loss += fake_loss.detach() D_real_loss += real_loss.detach() G_d_loss += g_loss.detach() G_train_loss += g_loss.detach() + fm_loss.detach() FM_train_loss += fm_loss.detach() # logger.info('[Epoch {}] Train: D_fake_loss: {}'.format(i, D_fake_loss / n_sample)) # logger.info('[Epoch {}] Train: D_real_loss: {}'.format(i, D_real_loss / n_sample)) # logger.info('[Epoch {}] Train: D_class_loss: {}'.format(i, D_class_loss / n_sample)) # logger.info('[Epoch {}] Train: G_train_loss: {}'.format(i, G_train_loss / n_sample)) # logger.info('[Epoch {}] Train: G_d_loss: {}'.format(i, G_d_loss / n_sample)) # logger.info('[Epoch {}] Train: FM_train_loss: {}'.format(i, FM_train_loss / n_sample)) # logger.info('---------------------------------------------------------------------------') D_total_fake_loss.append(D_fake_loss / n_sample) D_total_real_loss.append(D_real_loss / n_sample) D_total_class_loss.append(D_class_loss / n_sample) G_total_train_loss.append(G_train_loss / n_sample) FM_total_train_loss.append(FM_train_loss / n_sample) if dev_dataset: # logger.info('#################### eval result at step {} ####################'.format(global_step)) eval_result = eval(dev_dataset) if args.do_vis and args.do_g_eval_vis: G_features = torch.cat(G_features, 0).cpu().numpy() features = np.concatenate([eval_result['all_features'], G_features], axis=0) features = TSNE(n_components=2, verbose=1, 
n_jobs=-1).fit_transform(features) labels = np.concatenate([eval_result['all_binary_y'], np.array([-1] * len(G_features))], 0).reshape( -1, 1) data = np.concatenate([features, labels], 1) fig = scatter_plot(data, processor) fig.savefig(os.path.join(args.output_dir, 'plot_epoch_' + str(i) + '.png')) valid_detection_loss.append(eval_result['detection_loss']) valid_oos_ind_precision.append(eval_result['oos_ind_precision']) valid_oos_ind_recall.append(eval_result['oos_ind_recall']) valid_oos_ind_f_score.append(eval_result['oos_ind_f_score']) # 1 表示要保存模型 # 0 表示不需要保存模型 # -1 表示不需要模型,且超过了patience,需要early stop signal = early_stopping(-eval_result['eer']) if signal == -1: break elif signal == 0: pass elif signal == 1: save_gan_model(D, G, config['gan_save_path']) if args.fine_tune: save_model(E, path=config['bert_save_path'], model_name='bert') # logger.info(eval_result) # logger.info('valid_eer: {}'.format(eval_result['eer'])) # logger.info('valid_oos_ind_precision: {}'.format(eval_result['oos_ind_precision'])) # logger.info('valid_oos_ind_recall: {}'.format(eval_result['oos_ind_recall'])) # logger.info('valid_oos_ind_f_score: {}'.format(eval_result['oos_ind_f_score'])) # logger.info('valid_auc: {}'.format(eval_result['auc'])) # logger.info( # 'valid_fpr95: {}'.format(ErrorRateAt95Recall(eval_result['all_binary_y'], eval_result['y_score']))) if args.patience >= args.n_epoch: save_gan_model(D, G, config['gan_save_path']) if args.fine_tune: save_model(E, path=config['bert_save_path'], model_name='bert') freeze_data['D_total_fake_loss'] = D_total_fake_loss freeze_data['D_total_real_loss'] = D_total_real_loss freeze_data['D_total_class_loss'] = D_total_class_loss freeze_data['G_total_train_loss'] = G_total_train_loss freeze_data['FM_total_train_loss'] = FM_total_train_loss freeze_data['valid_real_loss'] = valid_detection_loss freeze_data['valid_oos_ind_precision'] = valid_oos_ind_precision freeze_data['valid_oos_ind_recall'] = valid_oos_ind_recall 
freeze_data['valid_oos_ind_f_score'] = valid_oos_ind_f_score best_dev = -early_stopping.best_score if args.do_vis: all_features = torch.cat(all_features, 0).cpu().numpy() result['all_features'] = all_features return result
def train(train_dataset, dev_dataset): train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=2) global best_dev nonlocal global_step n_sample = len(train_dataloader) early_stopping = EarlyStopping(args.patience, logger=logger) # Loss function adversarial_loss = torch.nn.BCELoss().to(device) adversarial_loss_v2 = torch.nn.CrossEntropyLoss().to(device) classified_loss = torch.nn.CrossEntropyLoss().to(device) # Optimizers optimizer_G = torch.optim.Adam(G.parameters(), lr=args.G_lr) # optimizer for generator optimizer_D = torch.optim.Adam( D.parameters(), lr=args.D_lr) # optimizer for discriminator optimizer_E = AdamW(E.parameters(), args.bert_lr) optimizer_detector = torch.optim.Adam(detector.parameters(), lr=args.detector_lr) G_total_train_loss = [] D_total_fake_loss = [] D_total_real_loss = [] FM_total_train_loss = [] D_total_class_loss = [] valid_detection_loss = [] valid_oos_ind_precision = [] valid_oos_ind_recall = [] valid_oos_ind_f_score = [] detector_total_train_loss = [] all_features = [] result = dict() for i in range(args.n_epoch): # Initialize model state G.train() D.train() E.train() detector.train() G_train_loss = 0 D_fake_loss = 0 D_real_loss = 0 FM_train_loss = 0 D_class_loss = 0 detector_train_loss = 0 for sample in tqdm.tqdm(train_dataloader): sample = (i.to(device) for i in sample) token, mask, type_ids, y = sample batch = len(token) ood_sample = (y == 0.0) # weight = torch.ones(len(ood_sample)).to(device) - ood_sample * args.beta # real_loss_func = torch.nn.BCELoss(weight=weight).to(device) # the label used to train generator and discriminator. 
valid_label = FloatTensor(batch, 1).fill_(1.0).detach() fake_label = FloatTensor(batch, 1).fill_(0.0).detach() optimizer_E.zero_grad() sequence_output, pooled_output = E(token, mask, type_ids) real_feature = pooled_output # train D on real optimizer_D.zero_grad() real_f_vector, discriminator_output, classification_output = D( real_feature, return_feature=True) # discriminator_output = discriminator_output.squeeze() real_loss = adversarial_loss(discriminator_output, valid_label) real_loss.backward(retain_graph=True) if args.do_vis: all_features.append(real_f_vector.detach()) # # train D on fake if args.model == 'lstm_gan' or args.model == 'cnn_gan': z = FloatTensor( np.random.normal(0, 1, (batch, 32, args.G_z_dim))).to(device) else: z = FloatTensor( np.random.normal(0, 1, (batch, args.G_z_dim))).to(device) fake_feature = G(z).detach() fake_discriminator_output = D.detect_only(fake_feature) fake_loss = adversarial_loss(fake_discriminator_output, fake_label) fake_loss.backward() optimizer_D.step() # if args.fine_tune: # optimizer_E.step() # train G optimizer_G.zero_grad() if args.model == 'lstm_gan' or args.model == 'cnn_gan': z = FloatTensor( np.random.normal(0, 1, (batch, 32, args.G_z_dim))).to(device) else: z = FloatTensor( np.random.normal(0, 1, (batch, args.G_z_dim))).to(device) fake_f_vector, D_decision = D.detect_only(G(z), return_feature=True) gd_loss = adversarial_loss(D_decision, valid_label) fm_loss = torch.abs( torch.mean(real_f_vector.detach(), 0) - torch.mean(fake_f_vector, 0)).mean() g_loss = gd_loss + 0 * fm_loss g_loss.backward() optimizer_G.step() optimizer_E.zero_grad() # train detector optimizer_detector.zero_grad() if args.model == 'lstm_gan' or args.model == 'cnn_gan': z = FloatTensor( np.random.normal(0, 1, (batch, 32, args.G_z_dim))).to(device) else: z = FloatTensor( np.random.normal(0, 1, (batch, args.G_z_dim))).to(device) fake_feature = G(z).detach() if args.loss == 'v1': loss_fake = adversarial_loss( detector(fake_feature), fake_label) # 
fake sample is ood else: loss_fake = adversarial_loss_v2( detector(fake_feature), fake_label.long().squeeze()) if args.loss == 'v1': loss_real = adversarial_loss(detector(real_feature), y.float()) else: loss_real = adversarial_loss_v2(detector(real_feature), y.long()) if args.detect_loss == 'v1': detector_loss = args.beta * loss_fake + ( 1 - args.beta) * loss_real else: detector_loss = args.beta * loss_fake + loss_real detector_loss = args.sigma * detector_loss detector_loss.backward() optimizer_detector.step() if args.fine_tune: optimizer_E.step() global_step += 1 D_fake_loss += fake_loss.detach() D_real_loss += real_loss.detach() G_train_loss += g_loss.detach() + fm_loss.detach() FM_train_loss += fm_loss.detach() detector_train_loss += detector_loss logger.info('[Epoch {}] Train: D_fake_loss: {}'.format( i, D_fake_loss / n_sample)) logger.info('[Epoch {}] Train: D_real_loss: {}'.format( i, D_real_loss / n_sample)) logger.info('[Epoch {}] Train: D_class_loss: {}'.format( i, D_class_loss / n_sample)) logger.info('[Epoch {}] Train: G_train_loss: {}'.format( i, G_train_loss / n_sample)) logger.info('[Epoch {}] Train: FM_train_loss: {}'.format( i, FM_train_loss / n_sample)) logger.info('[Epoch {}] Train: detector_train_loss: {}'.format( i, detector_train_loss / n_sample)) logger.info( '---------------------------------------------------------------------------' ) D_total_fake_loss.append(D_fake_loss / n_sample) D_total_real_loss.append(D_real_loss / n_sample) D_total_class_loss.append(D_class_loss / n_sample) G_total_train_loss.append(G_train_loss / n_sample) FM_total_train_loss.append(FM_train_loss / n_sample) detector_total_train_loss.append(detector_train_loss / n_sample) if dev_dataset: logger.info( '#################### eval result at step {} ####################' .format(global_step)) eval_result = eval(dev_dataset) valid_detection_loss.append(eval_result['detection_loss']) valid_oos_ind_precision.append( eval_result['oos_ind_precision']) 
valid_oos_ind_recall.append(eval_result['oos_ind_recall']) valid_oos_ind_f_score.append(eval_result['oos_ind_f_score']) # 1 表示要保存模型 # 0 表示不需要保存模型 # -1 表示不需要模型,且超过了patience,需要early stop signal = early_stopping(-eval_result['eer']) if signal == -1: break elif signal == 0: pass elif signal == 1: save_gan_model(D, G, config['gan_save_path']) if args.fine_tune: save_model(E, path=config['bert_save_path'], model_name='bert') logger.info(eval_result) logger.info('valid_eer: {}'.format(eval_result['eer'])) logger.info('valid_oos_ind_precision: {}'.format( eval_result['oos_ind_precision'])) logger.info('valid_oos_ind_recall: {}'.format( eval_result['oos_ind_recall'])) logger.info('valid_oos_ind_f_score: {}'.format( eval_result['oos_ind_f_score'])) logger.info('valid_auc: {}'.format(eval_result['auc'])) logger.info('valid_fpr95: {}'.format( ErrorRateAt95Recall(eval_result['all_binary_y'], eval_result['y_score']))) if args.patience >= args.n_epoch: save_gan_model(D, G, config['gan_save_path']) if args.fine_tune: save_model(E, path=config['bert_save_path'], model_name='bert') freeze_data['D_total_fake_loss'] = D_total_fake_loss freeze_data['D_total_real_loss'] = D_total_real_loss freeze_data['D_total_class_loss'] = D_total_class_loss freeze_data['G_total_train_loss'] = G_total_train_loss freeze_data['FM_total_train_loss'] = FM_total_train_loss freeze_data['valid_real_loss'] = valid_detection_loss freeze_data['valid_oos_ind_precision'] = valid_oos_ind_precision freeze_data['valid_oos_ind_recall'] = valid_oos_ind_recall freeze_data['valid_oos_ind_f_score'] = valid_oos_ind_f_score best_dev = -early_stopping.best_score if args.do_vis: all_features = torch.cat(all_features, 0).cpu().numpy() result['all_features'] = all_features return result
def train(train_dataset, dev_dataset):
    """Train the plain GAN variant (G, D, E) without a separate detector.

    Per epoch: trains D on real features (BCE against the in-domain
    indicator, plus a classification loss when n_class > 2) and on generated
    features (fake loss scaled by args.beta), then trains G against D.
    Evaluates on `dev_dataset` after each epoch with early stopping on the
    validation EER.

    Side effects: updates `global_step` (enclosing scope), fills
    `freeze_data`, and saves checkpoints via save_gan_model/save_model.

    NOTE(review): unlike the sibling train() variants, this one does not
    update `best_dev` or return `result` — the function looks truncated at
    the end of this chunk; confirm against the full file.
    """
    train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=2)
    nonlocal global_step
    n_sample = len(train_dataloader)
    early_stopping = EarlyStopping(args.patience, logger=logger)

    # Loss functions
    adversarial_loss = torch.nn.BCELoss().to(device)
    classified_loss = torch.nn.CrossEntropyLoss().to(device)

    # Optimizers
    optimizer_G = torch.optim.Adam(G.parameters(), lr=args.G_lr)  # optimizer for generator
    optimizer_D = torch.optim.Adam(D.parameters(), lr=args.D_lr)  # optimizer for discriminator
    optimizer_E = AdamW(E.parameters(), args.bert_lr)  # optimizer for the BERT encoder

    # Per-epoch training/validation histories (written into freeze_data below).
    G_total_train_loss = []
    D_total_fake_loss = []
    D_total_real_loss = []
    FM_total_train_loss = []
    D_total_class_loss = []
    valid_detection_loss = []
    valid_oos_ind_recall = []
    valid_oos_ind_f_score = []

    # NOTE(review): train_loss/iteration/total_loss are initialized but only
    # referenced from commented-out code below — effectively unused.
    train_loss = []
    iteration = 0

    for i in range(args.n_epoch):
        # Put all models in training mode at the start of every epoch.
        G.train()
        D.train()
        E.train()

        G_train_loss = 0
        D_fake_loss = 0
        D_real_loss = 0
        FM_train_loss = 0
        D_class_loss = 0
        total_loss = 0

        for sample in tqdm.tqdm(train_dataloader):
            # Move every tensor of the batch to `device`.
            sample = (i.to(device) for i in sample)
            token, mask, type_ids, y = sample
            batch = len(token)

            # the label used to train generator and discriminator.
            valid_label = FloatTensor(batch, 1).fill_(1.0).detach()
            fake_label = FloatTensor(batch, 1).fill_(0.0).detach()

            optimizer_E.zero_grad()
            sequence_output, pooled_output = E(token, mask, type_ids, return_dict=False)
            real_feature = pooled_output

            # train D on real
            optimizer_D.zero_grad()
            real_f_vector, discriminator_output, classification_output = D(
                real_feature, return_feature=True)
            discriminator_output = discriminator_output.squeeze()
            real_loss = adversarial_loss(discriminator_output, (y != 0.0).float())
            if n_class > 2:  # n_class > 2 means the classifier is trained alongside the discriminator
                class_loss = classified_loss(classification_output, y.long())
                real_loss += class_loss
                D_class_loss += class_loss.detach()
            real_loss.backward()

            # train D on fake (generated) samples; fake loss scaled by beta
            z = FloatTensor(np.random.normal(
                0, 1, (batch, args.G_z_dim))).to(device)
            fake_feature = G(z).detach()  # detach: G is not updated by D's loss
            fake_discriminator_output = D.detect_only(fake_feature)
            fake_loss = args.beta * adversarial_loss(
                fake_discriminator_output, fake_label)
            fake_loss.backward()
            optimizer_D.step()
            if args.fine_tune:
                optimizer_E.step()

            # train G
            optimizer_G.zero_grad()
            z = FloatTensor(np.random.normal(
                0, 1, (batch, args.G_z_dim))).to(device)
            fake_f_vector, D_decision = D.detect_only(G(z), return_feature=True)
            gd_loss = adversarial_loss(D_decision, valid_label)
            # Feature-matching loss (weighted by 0 in g_loss: tracked only).
            fm_loss = torch.abs(
                torch.mean(real_f_vector.detach(), 0) -
                torch.mean(fake_f_vector, 0)).mean()
            g_loss = gd_loss + 0 * fm_loss
            g_loss.backward()
            optimizer_G.step()

            global_step += 1

            # Accumulate detached losses for per-epoch reporting.
            D_fake_loss += fake_loss.detach()
            D_real_loss += real_loss.detach()
            G_train_loss += g_loss.detach() + fm_loss.detach()
            FM_train_loss += fm_loss.detach()

        # train_loss.append(total_loss / n_sample)
        # iteration += 1

        logger.info('[Epoch {}] Train: D_fake_loss: {}'.format(
            i, D_fake_loss / n_sample))
        logger.info('[Epoch {}] Train: D_real_loss: {}'.format(
            i, D_real_loss / n_sample))
        logger.info('[Epoch {}] Train: D_class_loss: {}'.format(
            i, D_class_loss / n_sample))
        logger.info('[Epoch {}] Train: G_train_loss: {}'.format(
            i, G_train_loss / n_sample))
        logger.info('[Epoch {}] Train: FM_train_loss: {}'.format(
            i, FM_train_loss / n_sample))
        logger.info(
            '---------------------------------------------------------------------------'
        )

        D_total_fake_loss.append(D_fake_loss / n_sample)
        D_total_real_loss.append(D_real_loss / n_sample)
        D_total_class_loss.append(D_class_loss / n_sample)
        G_total_train_loss.append(G_train_loss / n_sample)
        FM_total_train_loss.append(FM_train_loss / n_sample)

        if dev_dataset:
            logger.info(
                '#################### eval result at step {} ####################'
                .format(global_step))
            eval_result = eval(dev_dataset)

            valid_detection_loss.append(eval_result['detection_loss'])
            valid_oos_ind_recall.append(eval_result['oos_ind_recall'])
            valid_oos_ind_f_score.append(eval_result['oos_ind_f_score'])

            # Early-stopping signal: 1 = improvement, save the model;
            # 0 = no improvement, keep going; -1 = patience exceeded, stop.
            signal = early_stopping(-eval_result['eer'])
            if signal == -1:
                break
            elif signal == 0:
                pass
            elif signal == 1:
                save_gan_model(D, G, config['gan_save_path'])
                if args.fine_tune:
                    save_model(E, path=config['bert_save_path'], model_name='bert')

            logger.info(eval_result)
            logger.info('valid_eer: {}'.format(eval_result['eer']))
            logger.info('valid_oos_ind_precision: {}'.format(
                eval_result['oos_ind_precision']))
            logger.info('valid_oos_ind_recall: {}'.format(
                eval_result['oos_ind_recall']))
            logger.info('valid_oos_ind_f_score: {}'.format(
                eval_result['oos_ind_f_score']))
            logger.info('valid_fpr95: {}'.format(
                ErrorRateAt95Recall(eval_result['all_binary_y'],
                                    eval_result['y_score'])))

    # If patience can never trigger within n_epoch, save the final model.
    if args.patience >= args.n_epoch:
        save_gan_model(D, G, config['gan_save_path'])
        if args.fine_tune:
            save_model(E, path=config['bert_save_path'], model_name='bert')

    freeze_data['D_total_fake_loss'] = D_total_fake_loss
    freeze_data['D_total_real_loss'] = D_total_real_loss
    freeze_data['D_total_class_loss'] = D_total_class_loss
    freeze_data['G_total_train_loss'] = G_total_train_loss
    freeze_data['FM_total_train_loss'] = FM_total_train_loss
    freeze_data['valid_real_loss'] = valid_detection_loss
    freeze_data['valid_oos_ind_recall'] = valid_oos_ind_recall
    freeze_data['valid_oos_ind_f_score'] = valid_oos_ind_f_score