Code example #1
    def train(train_dataset, dev_dataset):
        train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=2)

        global best_dev
        nonlocal global_step
        n_sample = len(train_dataloader)
        early_stopping = EarlyStopping(args.patience, logger=logger)
        # Loss function
        adversarial_loss = torch.nn.BCELoss().to(device)
        classified_loss = torch.nn.CrossEntropyLoss().to(device)

        # Optimizers
        optimizer_G = torch.optim.Adam(G.parameters(), lr=args.G_lr)  # optimizer for generator
        optimizer_D = torch.optim.Adam(D.parameters(), lr=args.D_lr)  # optimizer for discriminator
        optimizer_E = AdamW(E.parameters(), args.bert_lr)

        G_total_train_loss = []
        D_total_fake_loss = []
        D_total_real_loss = []
        FM_total_train_loss = []
        D_total_class_loss = []
        valid_detection_loss = []
        valid_oos_ind_precision = []
        valid_oos_ind_recall = []
        valid_oos_ind_f_score = []

        all_features = []
        result = dict()

        for i in range(args.n_epoch):

            # Initialize model state
            G.train()
            D.train()
            E.train()

            G_train_loss = 0
            G_d_loss = 0
            D_fake_loss = 0
            D_real_loss = 0
            FM_train_loss = 0
            D_class_loss = 0

            G_features = []

            for sample in tqdm.tqdm(train_dataloader):
                sample = tuple(t.to(device) for t in sample)  # move the batch to the device; 't' avoids reusing the epoch index name
                if args.dataset == 'smp':
                    token, mask, type_ids, knowledge_tag, y = sample
                    batch = len(token)

                    ood_sample = (y == 0.0).float()
                    # weight = torch.ones(len(ood_sample)).to(device) - ood_sample * args.beta
                    # real_loss_func = torch.nn.BCELoss(weight=weight).to(device)

                    # length weight
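                    # assuming a standard 0/1 attention mask: mask[:, minlen] == 0
                    # flags sequences ending before position minlen, and
                    # mask[:, maxlen] == 1 flags sequences running past maxlen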
                    length_sample = FloatTensor([0] * batch)
                    if args.minlen != -1:
                        short_sample = (mask[:, args.minlen] == 0).float()
                        length_sample = length_sample.add(short_sample)
                    if args.maxlen != -1:
                        long_sample = mask[:, args.maxlen].float()
                        length_sample = length_sample.add(long_sample)

                    # get knowledge sample weight by knowledge_tag
                    exclude_sample = knowledge_tag

                    # initialize weight
                    weight = torch.ones(batch).to(device)
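                    # args.optim_mode selects one of the weighting schemes below:
                    #   0 -- down-weight OOD samples by beta only
                    #   1 -- exclude knowledge-tagged samples, down-weight length outliers
                    #   2 -- exclude length outliers, down-weight knowledge-tagged samples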

                    # optimize without weights
                    if args.optim_mode == 0 and 0 <= args.beta <= 1:
                        weight -= ood_sample * args.beta

                    # only optimize length by weight
                    if args.optim_mode == 1:
                        # set all exclude_sample's weight to 0
                        weight -= exclude_sample
                        length_sample -= exclude_sample
                        length_sample = (length_sample > 0).float()
                        weight -= length_sample * (1 - args.length_weight)

                        # set ood sample weight
                        if 0 <= args.beta <= 1:
                            ood_sample -= exclude_sample
                            ood_sample = (ood_sample > 0).float()
                            temp = torch.ones(batch).to(device)
                            temp -= ood_sample * args.beta
                            weight *= temp

                    # only optimize sample by weight
                    if args.optim_mode == 2:
                        # set all length_sample's weight to 0
                        weight -= length_sample

                        exclude_sample -= length_sample
                        exclude_sample = (exclude_sample > 0).float()
                        weight -= exclude_sample * (1 - args.sample_weight)

                        # set ood sample weight
                        if 0 <= args.beta <= 1:
                            ood_sample -= length_sample
                            ood_sample = (ood_sample > 0).float()
                            temp = torch.ones(batch).to(device)  # .to(device) added to match the other branches and avoid a CPU/GPU mismatch
                            temp -= ood_sample * args.beta
                            weight *= temp

                    # optimize length and sample by weight
                    # if args.optim_mode == 3:
                    #     alpha = 0.5
                    #     beta = 0.5
                    #     weight = torch.ones(len(length_sample)).to(device) \
                    #              - alpha * length_sample * (1 - args.length_weight) \
                    #              - beta * exclude_sample * (1 - args.sample_weight)

                if args.dataset == 'oos-eval':
                    token, mask, type_ids, y = sample
                    batch = len(token)

                    ood_sample = (y == 0.0).float()
                    # weight = torch.ones(len(ood_sample)).to(device) - ood_sample * args.beta
                    # real_loss_func = torch.nn.BCELoss(weight=weight).to(device)

                    # length weight
                    length_sample = FloatTensor([0] * batch)
                    if args.minlen != -1:
                        short_sample = (mask[:, args.minlen] == 0).float()
                        length_sample = length_sample.add(short_sample)
                    if args.maxlen != -1:
                        long_sample = mask[:, args.maxlen].float()
                        length_sample = length_sample.add(long_sample)

                    # initialize weight
                    weight = torch.ones(batch).to(device)

                    # optimize without weights
                    if args.optim_mode == 0 and 0 <= args.beta <= 1:
                        weight -= ood_sample * args.beta

                    # only optimize length by weight
                    if args.optim_mode == 1:
                        weight -= length_sample * (1 - args.length_weight)

                        # set ood sample weight
                        if 0 <= args.beta <= 1:
                            ood_sample -= length_sample
                            ood_sample = (ood_sample > 0).float()
                            temp = torch.ones(batch).to(device)
                            temp -= ood_sample * args.beta
                            weight *= temp

                real_loss_func = torch.nn.BCELoss(weight=weight).to(device)  # BCE weighted by the per-sample weights built above

                # the label used to train generator and discriminator.
                valid_label = FloatTensor(batch, 1).fill_(1.0).detach()
                fake_label = FloatTensor(batch, 1).fill_(0.0).detach()

                optimizer_E.zero_grad()
                sequence_output, pooled_output = E(token, mask, type_ids)
                real_feature = pooled_output

                # train D on real
                optimizer_D.zero_grad()
                real_f_vector, discriminator_output, classification_output = D(real_feature, return_feature=True)
                discriminator_output = discriminator_output.squeeze()
                # real_loss = adversarial_loss(discriminator_output, (y != 0.0).float())
                real_loss = real_loss_func(discriminator_output, (y != 0.0).float())
                if n_class > 2:  # n_class > 2: train the classifier in addition to the discriminator
                    class_loss = classified_loss(classification_output, y.long())
                    real_loss += class_loss
                    D_class_loss += class_loss.detach()
                real_loss.backward()

                if args.do_vis:
                    all_features.append(real_f_vector.detach())

                # train D on fake
                if args.model == 'lstm_gan' or args.model == 'cnn_gan':
                    z = FloatTensor(np.random.normal(0, 1, (batch, 32, args.G_z_dim))).to(device)
                else:
                    # uniform (-1,1)
                    # z = FloatTensor(np.random.uniform(-1, 1, (batch, args.G_z_dim))).to(device)
                    z = FloatTensor(np.random.normal(0, 1, (batch, args.G_z_dim))).to(device)
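                # detach so the discriminator update does not backpropagate into G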
                fake_feature = G(z).detach()
                fake_discriminator_output = D.detect_only(fake_feature)
                # beta of fake
                if 0 <= args.beta <= 1:
                    fake_loss = args.beta * adversarial_loss(fake_discriminator_output, fake_label)
                else:
                    fake_loss = adversarial_loss(fake_discriminator_output, fake_label)
                fake_loss.backward()
                optimizer_D.step()

                if args.fine_tune:
                    optimizer_E.step()

                # train G
                optimizer_G.zero_grad()
                if args.model == 'lstm_gan' or args.model == 'cnn_gan':
                    z = FloatTensor(np.random.normal(0, 1, (batch, 32, args.G_z_dim))).to(device)
                else:
                    # uniform (-1,1)
                    # z = FloatTensor(np.random.uniform(-1, 1, (batch, args.G_z_dim))).to(device)
                    z = FloatTensor(np.random.normal(0, 1, (batch, args.G_z_dim))).to(device)
                fake_f_vector, D_decision = D.detect_only(G(z), return_feature=True)

                if args.do_vis:
                    G_features.append(fake_f_vector.detach())

                gd_loss = adversarial_loss(D_decision, valid_label)
                # feature matching loss
                fm_loss = torch.abs(torch.mean(real_f_vector.detach(), 0) - torch.mean(fake_f_vector, 0)).mean()
                # fm_loss = feature_matching_loss(torch.mean(fake_f_vector, 0), torch.mean(real_f_vector.detach(), 0))
                g_loss = gd_loss + 0 * fm_loss  # FM term is zero-weighted: tracked for logging, no effect on G's gradient
                g_loss.backward()
                optimizer_G.step()

                global_step += 1

                D_fake_loss += fake_loss.detach()
                D_real_loss += real_loss.detach()
                G_d_loss += g_loss.detach()
                G_train_loss += g_loss.detach() + fm_loss.detach()
                FM_train_loss += fm_loss.detach()

            # logger.info('[Epoch {}] Train: D_fake_loss: {}'.format(i, D_fake_loss / n_sample))
            # logger.info('[Epoch {}] Train: D_real_loss: {}'.format(i, D_real_loss / n_sample))
            # logger.info('[Epoch {}] Train: D_class_loss: {}'.format(i, D_class_loss / n_sample))
            # logger.info('[Epoch {}] Train: G_train_loss: {}'.format(i, G_train_loss / n_sample))
            # logger.info('[Epoch {}] Train: G_d_loss: {}'.format(i, G_d_loss / n_sample))
            # logger.info('[Epoch {}] Train: FM_train_loss: {}'.format(i, FM_train_loss / n_sample))
            # logger.info('---------------------------------------------------------------------------')

            D_total_fake_loss.append(D_fake_loss / n_sample)
            D_total_real_loss.append(D_real_loss / n_sample)
            D_total_class_loss.append(D_class_loss / n_sample)
            G_total_train_loss.append(G_train_loss / n_sample)
            FM_total_train_loss.append(FM_train_loss / n_sample)

            if dev_dataset:
                # logger.info('#################### eval result at step {} ####################'.format(global_step))
                eval_result = eval(dev_dataset)

                if args.do_vis and args.do_g_eval_vis:
                    G_features = torch.cat(G_features, 0).cpu().numpy()

                    features = np.concatenate([eval_result['all_features'], G_features], axis=0)
                    features = TSNE(n_components=2, verbose=1, n_jobs=-1).fit_transform(features)
                    labels = np.concatenate([eval_result['all_binary_y'], np.array([-1] * len(G_features))], 0).reshape(
                        -1, 1)

                    data = np.concatenate([features, labels], 1)
                    fig = scatter_plot(data, processor)
                    fig.savefig(os.path.join(args.output_dir, 'plot_epoch_' + str(i) + '.png'))

                valid_detection_loss.append(eval_result['detection_loss'])
                valid_oos_ind_precision.append(eval_result['oos_ind_precision'])
                valid_oos_ind_recall.append(eval_result['oos_ind_recall'])
                valid_oos_ind_f_score.append(eval_result['oos_ind_f_score'])

                #  1: score improved, save the model
                #  0: no improvement yet, keep training
                # -1: patience exceeded, trigger early stopping
                signal = early_stopping(-eval_result['eer'])
                if signal == -1:
                    break
                elif signal == 0:
                    pass
                elif signal == 1:
                    save_gan_model(D, G, config['gan_save_path'])
                    if args.fine_tune:
                        save_model(E, path=config['bert_save_path'], model_name='bert')

                # logger.info(eval_result)
                # logger.info('valid_eer: {}'.format(eval_result['eer']))
                # logger.info('valid_oos_ind_precision: {}'.format(eval_result['oos_ind_precision']))
                # logger.info('valid_oos_ind_recall: {}'.format(eval_result['oos_ind_recall']))
                # logger.info('valid_oos_ind_f_score: {}'.format(eval_result['oos_ind_f_score']))
                # logger.info('valid_auc: {}'.format(eval_result['auc']))
                # logger.info(
                #     'valid_fpr95: {}'.format(ErrorRateAt95Recall(eval_result['all_binary_y'], eval_result['y_score'])))

        if args.patience >= args.n_epoch:
            save_gan_model(D, G, config['gan_save_path'])
            if args.fine_tune:
                save_model(E, path=config['bert_save_path'], model_name='bert')

        freeze_data['D_total_fake_loss'] = D_total_fake_loss
        freeze_data['D_total_real_loss'] = D_total_real_loss
        freeze_data['D_total_class_loss'] = D_total_class_loss
        freeze_data['G_total_train_loss'] = G_total_train_loss
        freeze_data['FM_total_train_loss'] = FM_total_train_loss
        freeze_data['valid_real_loss'] = valid_detection_loss
        freeze_data['valid_oos_ind_precision'] = valid_oos_ind_precision
        freeze_data['valid_oos_ind_recall'] = valid_oos_ind_recall
        freeze_data['valid_oos_ind_f_score'] = valid_oos_ind_f_score

        best_dev = -early_stopping.best_score

        if args.do_vis:
            all_features = torch.cat(all_features, 0).cpu().numpy()
            result['all_features'] = all_features
        return result
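
All three examples depend on an `EarlyStopping` helper whose contract is only implied by the comments: called with a score (here the negated EER, so higher is better), it returns 1 on improvement (save the model), 0 otherwise, and -1 once `patience` evaluations pass without improvement. A minimal sketch consistent with those calls (the constructor signature is taken from the code above; the internals are an assumption):

    class EarlyStopping:
        """Minimal sketch of the 1/0/-1 early-stopping contract assumed above."""

        def __init__(self, patience, logger=None):
            self.patience = patience
            self.logger = logger
            self.best_score = None  # read back as best_dev = -best_score
            self.counter = 0

        def __call__(self, score):
            # Improvement (or first call): remember the score, ask the caller to save.
            if self.best_score is None or score > self.best_score:
                self.best_score = score
                self.counter = 0
                return 1
            # No improvement: count towards patience.
            self.counter += 1
            if self.counter > self.patience:
                return -1  # patience exhausted, caller breaks out of training
            return 0

    # Tiny check of the contract:
    stopper = EarlyStopping(patience=2)
    print([stopper(s) for s in [-0.3, -0.2, -0.25, -0.25, -0.25]])  # [1, 1, 0, 0, -1]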
Code example #2
    def train(train_dataset, dev_dataset):
        train_dataloader = DataLoader(train_dataset,
                                      batch_size=args.train_batch_size,
                                      shuffle=True,
                                      num_workers=2)

        global best_dev
        nonlocal global_step
        n_sample = len(train_dataloader)
        early_stopping = EarlyStopping(args.patience, logger=logger)
        # Loss function
        adversarial_loss = torch.nn.BCELoss().to(device)
        adversarial_loss_v2 = torch.nn.CrossEntropyLoss().to(device)
        classified_loss = torch.nn.CrossEntropyLoss().to(device)

        # Optimizers
        optimizer_G = torch.optim.Adam(G.parameters(),
                                       lr=args.G_lr)  # optimizer for generator
        optimizer_D = torch.optim.Adam(
            D.parameters(), lr=args.D_lr)  # optimizer for discriminator
        optimizer_E = AdamW(E.parameters(), args.bert_lr)
        optimizer_detector = torch.optim.Adam(detector.parameters(),
                                              lr=args.detector_lr)

        G_total_train_loss = []
        D_total_fake_loss = []
        D_total_real_loss = []
        FM_total_train_loss = []
        D_total_class_loss = []
        valid_detection_loss = []
        valid_oos_ind_precision = []
        valid_oos_ind_recall = []
        valid_oos_ind_f_score = []
        detector_total_train_loss = []

        all_features = []
        result = dict()

        for i in range(args.n_epoch):

            # Initialize model state
            G.train()
            D.train()
            E.train()
            detector.train()

            G_train_loss = 0
            D_fake_loss = 0
            D_real_loss = 0
            FM_train_loss = 0
            D_class_loss = 0
            detector_train_loss = 0

            for sample in tqdm.tqdm(train_dataloader):
                sample = tuple(t.to(device) for t in sample)  # move the batch to the device; 't' avoids reusing the epoch index name
                token, mask, type_ids, y = sample
                batch = len(token)

                ood_sample = (y == 0.0)
                # weight = torch.ones(len(ood_sample)).to(device) - ood_sample * args.beta
                # real_loss_func = torch.nn.BCELoss(weight=weight).to(device)

                # the label used to train generator and discriminator.
                valid_label = FloatTensor(batch, 1).fill_(1.0).detach()
                fake_label = FloatTensor(batch, 1).fill_(0.0).detach()

                optimizer_E.zero_grad()
                sequence_output, pooled_output = E(token, mask, type_ids)
                real_feature = pooled_output

                # train D on real
                optimizer_D.zero_grad()
                real_f_vector, discriminator_output, classification_output = D(
                    real_feature, return_feature=True)
                # discriminator_output = discriminator_output.squeeze()
                real_loss = adversarial_loss(discriminator_output, valid_label)
                real_loss.backward(retain_graph=True)

                if args.do_vis:
                    all_features.append(real_f_vector.detach())

                # train D on fake
                if args.model == 'lstm_gan' or args.model == 'cnn_gan':
                    z = FloatTensor(
                        np.random.normal(0, 1,
                                         (batch, 32, args.G_z_dim))).to(device)
                else:
                    z = FloatTensor(
                        np.random.normal(0, 1,
                                         (batch, args.G_z_dim))).to(device)
                fake_feature = G(z).detach()
                fake_discriminator_output = D.detect_only(fake_feature)
                fake_loss = adversarial_loss(fake_discriminator_output,
                                             fake_label)
                fake_loss.backward()
                optimizer_D.step()

                # if args.fine_tune:
                #     optimizer_E.step()

                # train G
                optimizer_G.zero_grad()
                if args.model == 'lstm_gan' or args.model == 'cnn_gan':
                    z = FloatTensor(
                        np.random.normal(0, 1,
                                         (batch, 32, args.G_z_dim))).to(device)
                else:
                    z = FloatTensor(
                        np.random.normal(0, 1,
                                         (batch, args.G_z_dim))).to(device)
                fake_f_vector, D_decision = D.detect_only(G(z),
                                                          return_feature=True)
                gd_loss = adversarial_loss(D_decision, valid_label)
                fm_loss = torch.abs(
                    torch.mean(real_f_vector.detach(), 0) -
                    torch.mean(fake_f_vector, 0)).mean()
                g_loss = gd_loss + 0 * fm_loss
                g_loss.backward()
                optimizer_G.step()

                optimizer_E.zero_grad()

                # train detector
                optimizer_detector.zero_grad()
                if args.model == 'lstm_gan' or args.model == 'cnn_gan':
                    z = FloatTensor(
                        np.random.normal(0, 1,
                                         (batch, 32, args.G_z_dim))).to(device)
                else:
                    z = FloatTensor(
                        np.random.normal(0, 1,
                                         (batch, args.G_z_dim))).to(device)
                fake_feature = G(z).detach()
                if args.loss == 'v1':
                    # v1: binary BCE objective; generated samples are labeled OOD
                    loss_fake = adversarial_loss(detector(fake_feature),
                                                 fake_label)
                    loss_real = adversarial_loss(detector(real_feature),
                                                 y.float())
                else:
                    # otherwise CrossEntropy over the detector's class logits
                    loss_fake = adversarial_loss_v2(detector(fake_feature),
                                                    fake_label.long().squeeze())
                    loss_real = adversarial_loss_v2(detector(real_feature),
                                                    y.long())
                if args.detect_loss == 'v1':
                    detector_loss = args.beta * loss_fake + (
                        1 - args.beta) * loss_real
                else:
                    detector_loss = args.beta * loss_fake + loss_real
                    detector_loss = args.sigma * detector_loss
                detector_loss.backward()
                optimizer_detector.step()

                if args.fine_tune:
                    optimizer_E.step()

                global_step += 1

                D_fake_loss += fake_loss.detach()
                D_real_loss += real_loss.detach()
                G_train_loss += g_loss.detach() + fm_loss.detach()
                FM_train_loss += fm_loss.detach()
                detector_train_loss += detector_loss.detach()  # detach to match the other running losses and free the graph

            logger.info('[Epoch {}] Train: D_fake_loss: {}'.format(
                i, D_fake_loss / n_sample))
            logger.info('[Epoch {}] Train: D_real_loss: {}'.format(
                i, D_real_loss / n_sample))
            logger.info('[Epoch {}] Train: D_class_loss: {}'.format(
                i, D_class_loss / n_sample))
            logger.info('[Epoch {}] Train: G_train_loss: {}'.format(
                i, G_train_loss / n_sample))
            logger.info('[Epoch {}] Train: FM_train_loss: {}'.format(
                i, FM_train_loss / n_sample))
            logger.info('[Epoch {}] Train: detector_train_loss: {}'.format(
                i, detector_train_loss / n_sample))
            logger.info(
                '---------------------------------------------------------------------------'
            )

            D_total_fake_loss.append(D_fake_loss / n_sample)
            D_total_real_loss.append(D_real_loss / n_sample)
            D_total_class_loss.append(D_class_loss / n_sample)
            G_total_train_loss.append(G_train_loss / n_sample)
            FM_total_train_loss.append(FM_train_loss / n_sample)
            detector_total_train_loss.append(detector_train_loss / n_sample)

            if dev_dataset:
                logger.info(
                    '#################### eval result at step {} ####################'
                    .format(global_step))
                eval_result = eval(dev_dataset)

                valid_detection_loss.append(eval_result['detection_loss'])
                valid_oos_ind_precision.append(
                    eval_result['oos_ind_precision'])
                valid_oos_ind_recall.append(eval_result['oos_ind_recall'])
                valid_oos_ind_f_score.append(eval_result['oos_ind_f_score'])

                #  1: score improved, save the model
                #  0: no improvement yet, keep training
                # -1: patience exceeded, trigger early stopping
                signal = early_stopping(-eval_result['eer'])
                if signal == -1:
                    break
                elif signal == 0:
                    pass
                elif signal == 1:
                    save_gan_model(D, G, config['gan_save_path'])
                    if args.fine_tune:
                        save_model(E,
                                   path=config['bert_save_path'],
                                   model_name='bert')

                logger.info(eval_result)
                logger.info('valid_eer: {}'.format(eval_result['eer']))
                logger.info('valid_oos_ind_precision: {}'.format(
                    eval_result['oos_ind_precision']))
                logger.info('valid_oos_ind_recall: {}'.format(
                    eval_result['oos_ind_recall']))
                logger.info('valid_oos_ind_f_score: {}'.format(
                    eval_result['oos_ind_f_score']))
                logger.info('valid_auc: {}'.format(eval_result['auc']))
                logger.info('valid_fpr95: {}'.format(
                    ErrorRateAt95Recall(eval_result['all_binary_y'],
                                        eval_result['y_score'])))

        if args.patience >= args.n_epoch:
            save_gan_model(D, G, config['gan_save_path'])
            if args.fine_tune:
                save_model(E, path=config['bert_save_path'], model_name='bert')

        freeze_data['D_total_fake_loss'] = D_total_fake_loss
        freeze_data['D_total_real_loss'] = D_total_real_loss
        freeze_data['D_total_class_loss'] = D_total_class_loss
        freeze_data['G_total_train_loss'] = G_total_train_loss
        freeze_data['FM_total_train_loss'] = FM_total_train_loss
        freeze_data['valid_real_loss'] = valid_detection_loss
        freeze_data['valid_oos_ind_precision'] = valid_oos_ind_precision
        freeze_data['valid_oos_ind_recall'] = valid_oos_ind_recall
        freeze_data['valid_oos_ind_f_score'] = valid_oos_ind_f_score

        best_dev = -early_stopping.best_score

        if args.do_vis:
            all_features = torch.cat(all_features, 0).cpu().numpy()
            result['all_features'] = all_features
        return result
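
The feature-matching term appears in all three examples but enters `g_loss` with weight 0, so it is tracked for logging rather than optimized. The standalone snippet below (dummy tensors; the batch and feature sizes are arbitrary assumptions) shows what the expression computes: the mean absolute difference between the per-dimension batch means of real and generated discriminator features.

    import torch

    real_f_vector = torch.randn(64, 256)  # stand-in for D's features on real data
    fake_f_vector = torch.randn(64, 256)  # stand-in for D's features on G(z)

    # L1 distance between the batch-mean feature vectors, averaged over dimensions.
    fm_loss = torch.abs(torch.mean(real_f_vector.detach(), 0) -
                        torch.mean(fake_f_vector, 0)).mean()
    print(fm_loss.item())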
Code example #3
File: run_gan.py  Project: Ancrilin/Gan-s-Exploration
    def train(train_dataset, dev_dataset):
        train_dataloader = DataLoader(train_dataset,
                                      batch_size=args.train_batch_size,
                                      shuffle=True,
                                      num_workers=2)

        nonlocal global_step
        n_sample = len(train_dataloader)
        early_stopping = EarlyStopping(args.patience, logger=logger)
        # Loss function
        adversarial_loss = torch.nn.BCELoss().to(device)
        classified_loss = torch.nn.CrossEntropyLoss().to(device)

        # Optimizers
        optimizer_G = torch.optim.Adam(G.parameters(),
                                       lr=args.G_lr)  # optimizer for generator
        optimizer_D = torch.optim.Adam(
            D.parameters(), lr=args.D_lr)  # optimizer for discriminator
        optimizer_E = AdamW(E.parameters(), args.bert_lr)

        G_total_train_loss = []
        D_total_fake_loss = []
        D_total_real_loss = []
        FM_total_train_loss = []
        D_total_class_loss = []
        valid_detection_loss = []
        valid_oos_ind_recall = []
        valid_oos_ind_f_score = []

        train_loss = []
        iteration = 0

        for i in range(args.n_epoch):

            # Initialize model state
            G.train()
            D.train()
            E.train()

            G_train_loss = 0
            D_fake_loss = 0
            D_real_loss = 0
            FM_train_loss = 0
            D_class_loss = 0

            total_loss = 0

            for sample in tqdm.tqdm(train_dataloader):
                sample = tuple(t.to(device) for t in sample)  # move the batch to the device; 't' avoids reusing the epoch index name
                token, mask, type_ids, y = sample
                batch = len(token)

                # the label used to train generator and discriminator.
                valid_label = FloatTensor(batch, 1).fill_(1.0).detach()
                fake_label = FloatTensor(batch, 1).fill_(0.0).detach()

                optimizer_E.zero_grad()
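                # return_dict=False keeps the tuple-style (sequence_output,
                # pooled_output) return on newer transformers versions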
                sequence_output, pooled_output = E(token,
                                                   mask,
                                                   type_ids,
                                                   return_dict=False)
                real_feature = pooled_output

                # train D on real
                optimizer_D.zero_grad()
                real_f_vector, discriminator_output, classification_output = D(
                    real_feature, return_feature=True)
                discriminator_output = discriminator_output.squeeze()
                real_loss = adversarial_loss(discriminator_output,
                                             (y != 0.0).float())
                if n_class > 2:  # n_class > 2: train the classifier in addition to the discriminator
                    class_loss = classified_loss(classification_output,
                                                 y.long())
                    real_loss += class_loss
                    D_class_loss += class_loss.detach()
                real_loss.backward()

                # train D on fake
                z = FloatTensor(np.random.normal(
                    0, 1, (batch, args.G_z_dim))).to(device)
                fake_feature = G(z).detach()
                fake_discriminator_output = D.detect_only(fake_feature)
                fake_loss = args.beta * adversarial_loss(
                    fake_discriminator_output, fake_label)
                fake_loss.backward()
                optimizer_D.step()

                if args.fine_tune:
                    optimizer_E.step()

                # train G
                optimizer_G.zero_grad()
                z = FloatTensor(np.random.normal(
                    0, 1, (batch, args.G_z_dim))).to(device)
                fake_f_vector, D_decision = D.detect_only(G(z),
                                                          return_feature=True)
                gd_loss = adversarial_loss(D_decision, valid_label)
                fm_loss = torch.abs(
                    torch.mean(real_f_vector.detach(), 0) -
                    torch.mean(fake_f_vector, 0)).mean()
                g_loss = gd_loss + 0 * fm_loss
                g_loss.backward()
                optimizer_G.step()

                global_step += 1

                D_fake_loss += fake_loss.detach()
                D_real_loss += real_loss.detach()
                G_train_loss += g_loss.detach() + fm_loss.detach()
                FM_train_loss += fm_loss.detach()

            # train_loss.append(total_loss / n_sample)
            # iteration += 1

            logger.info('[Epoch {}] Train: D_fake_loss: {}'.format(
                i, D_fake_loss / n_sample))
            logger.info('[Epoch {}] Train: D_real_loss: {}'.format(
                i, D_real_loss / n_sample))
            logger.info('[Epoch {}] Train: D_class_loss: {}'.format(
                i, D_class_loss / n_sample))
            logger.info('[Epoch {}] Train: G_train_loss: {}'.format(
                i, G_train_loss / n_sample))
            logger.info('[Epoch {}] Train: FM_train_loss: {}'.format(
                i, FM_train_loss / n_sample))
            logger.info(
                '---------------------------------------------------------------------------'
            )

            D_total_fake_loss.append(D_fake_loss / n_sample)
            D_total_real_loss.append(D_real_loss / n_sample)
            D_total_class_loss.append(D_class_loss / n_sample)
            G_total_train_loss.append(G_train_loss / n_sample)
            FM_total_train_loss.append(FM_train_loss / n_sample)

            if dev_dataset:
                logger.info(
                    '#################### eval result at step {} ####################'
                    .format(global_step))
                eval_result = eval(dev_dataset)

                valid_detection_loss.append(eval_result['detection_loss'])
                valid_oos_ind_recall.append(eval_result['oos_ind_recall'])
                valid_oos_ind_f_score.append(eval_result['oos_ind_f_score'])

                #  1: score improved, save the model
                #  0: no improvement yet, keep training
                # -1: patience exceeded, trigger early stopping
                signal = early_stopping(-eval_result['eer'])
                if signal == -1:
                    break
                elif signal == 0:
                    pass
                elif signal == 1:
                    save_gan_model(D, G, config['gan_save_path'])
                    if args.fine_tune:
                        save_model(E,
                                   path=config['bert_save_path'],
                                   model_name='bert')

                logger.info(eval_result)
                logger.info('valid_eer: {}'.format(eval_result['eer']))
                logger.info('valid_oos_ind_precision: {}'.format(
                    eval_result['oos_ind_precision']))
                logger.info('valid_oos_ind_recall: {}'.format(
                    eval_result['oos_ind_recall']))
                logger.info('valid_oos_ind_f_score: {}'.format(
                    eval_result['oos_ind_f_score']))
                logger.info('valid_fpr95: {}'.format(
                    ErrorRateAt95Recall(eval_result['all_binary_y'],
                                        eval_result['y_score'])))

        if args.patience >= args.n_epoch:
            save_gan_model(D, G, config['gan_save_path'])
            if args.fine_tune:
                save_model(E, path=config['bert_save_path'], model_name='bert')

        freeze_data['D_total_fake_loss'] = D_total_fake_loss
        freeze_data['D_total_real_loss'] = D_total_real_loss
        freeze_data['D_total_class_loss'] = D_total_class_loss
        freeze_data['G_total_train_loss'] = G_total_train_loss
        freeze_data['FM_total_train_loss'] = FM_total_train_loss
        freeze_data['valid_real_loss'] = valid_detection_loss
        freeze_data['valid_oos_ind_recall'] = valid_oos_ind_recall
        freeze_data['valid_oos_ind_f_score'] = valid_oos_ind_f_score
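
The detached 1/0 label tensors and the BCE objective are common to all three examples. The self-contained snippet below (the batch size and the random stand-in for the discriminator output are assumptions) illustrates how those labels pair with `torch.nn.BCELoss`, which expects probabilities in [0, 1]:

    import torch

    batch = 16  # arbitrary batch size for illustration
    adversarial_loss = torch.nn.BCELoss()

    # The all-ones / all-zeros targets used to train D and G above.
    valid_label = torch.FloatTensor(batch, 1).fill_(1.0)
    fake_label = torch.FloatTensor(batch, 1).fill_(0.0)

    # A stand-in discriminator output (already squashed to probabilities).
    d_out = torch.rand(batch, 1)
    print(adversarial_loss(d_out, valid_label).item(),
          adversarial_loss(d_out, fake_label).item())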