示例#1
0
    def train(train_dataset, dev_dataset):
        train_dataloader = DataLoader(train_dataset,
                                      batch_size=args.train_batch_size //
                                      args.gradient_accumulation_steps,
                                      shuffle=True,
                                      num_workers=2)

        nonlocal global_step
        n_sample = len(train_dataloader)
        early_stopping = EarlyStopping(args.patience, logger=logger)
        # Loss function
        classified_loss = torch.nn.CrossEntropyLoss().to(device)
        adversarial_loss = torch.nn.BCELoss().to(device)

        # Optimizers
        optimizer = AdamW(model.parameters(), args.lr)

        train_loss = []
        if dev_dataset:
            valid_loss = []
            valid_ind_class_acc = []
        iteration = 0
        for i in range(args.n_epoch):

            model.train()

            total_loss = 0
            for sample in tqdm.tqdm(train_dataloader):
                sample = (i.to(device) for i in sample)
                token, mask, type_ids, y = sample
                batch = len(token)

                f_vector, discriminator_output, classification_output = model(
                    token, mask, type_ids, return_feature=True)
                discriminator_output = discriminator_output.squeeze()
                if args.BCE:
                    loss = adversarial_loss(discriminator_output,
                                            (y != 0.0).float())
                else:
                    loss = classified_loss(discriminator_output, y.long())
                total_loss += loss.item()
                loss = loss / args.gradient_accumulation_steps
                loss.backward()
                # bp and update parameters
                if (global_step + 1) % args.gradient_accumulation_steps == 0:
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

            logger.info('[Epoch {}] Train: train_loss: {}'.format(
                i, total_loss / n_sample))
            logger.info('-' * 30)

            train_loss.append(total_loss / n_sample)
            iteration += 1

            if dev_dataset:
                logger.info(
                    '#################### eval result at step {} ####################'
                    .format(global_step))
                eval_result = eval(dev_dataset)

                valid_loss.append(eval_result['loss'])
                valid_ind_class_acc.append(eval_result['ind_class_acc'])

                # 1 表示要保存模型
                # 0 表示不需要保存模型
                # -1 表示不需要模型,且超过了patience,需要early stop
                signal = early_stopping(-eval_result['eer'])
                if signal == -1:
                    break
                elif signal == 0:
                    pass
                elif signal == 1:
                    save_model(model,
                               path=config['model_save_path'],
                               model_name='bert')

                logger.info(eval_result)
                logger.info('valid_eer: {}'.format(eval_result['eer']))
                logger.info('valid_oos_ind_precision: {}'.format(
                    eval_result['oos_ind_precision']))
                logger.info('valid_oos_ind_recall: {}'.format(
                    eval_result['oos_ind_recall']))
                logger.info('valid_oos_ind_f_score: {}'.format(
                    eval_result['oos_ind_f_score']))
                logger.info('valid_auc: {}'.format(eval_result['auc']))
                logger.info('valid_fpr95: {}'.format(
                    ErrorRateAt95Recall(eval_result['all_binary_y'],
                                        eval_result['y_score'])))

        from utils.visualization import draw_curve
        draw_curve(train_loss, iteration, 'train_loss', args.output_dir)
        if dev_dataset:
            draw_curve(valid_loss, iteration, 'valid_loss', args.output_dir)
            draw_curve(valid_ind_class_acc, iteration,
                       'valid_ind_class_accuracy', args.output_dir)

        if args.patience >= args.n_epoch:
            save_model(model,
                       path=config['model_save_path'],
                       model_name='bert')

        freeze_data['train_loss'] = train_loss
        freeze_data['valid_loss'] = valid_loss
示例#2
0
    def train(train_dataset, dev_dataset):
        train_dataloader = DataLoader(train_dataset,
                                      batch_size=args.train_batch_size //
                                      args.gradient_accumulation_steps,
                                      shuffle=True,
                                      num_workers=2)

        nonlocal global_step
        n_sample = len(train_dataloader)
        early_stopping = EarlyStopping(args.patience, logger=logger)
        # Loss function
        classified_loss = torch.nn.CrossEntropyLoss().to(device)

        # Optimizers
        optimizer = AdamW(model.parameters(), args.lr)

        train_loss = []
        if dev_dataset:
            valid_loss = []
            valid_ind_class_acc = []
        iteration = 0
        for i in range(args.n_epoch):

            model.train()

            total_loss = 0
            for sample in tqdm.tqdm(train_dataloader):
                sample = (i.to(device) for i in sample)
                token, mask, type_ids, y = sample
                batch = len(token)

                logits = model(token, mask, type_ids)
                loss = classified_loss(logits, y.long())
                total_loss += loss.item()
                loss = loss / args.gradient_accumulation_steps
                loss.backward()
                # bp and update parameters
                if (global_step + 1) % args.gradient_accumulation_steps == 0:
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

            logger.info('[Epoch {}] Train: train_loss: {}'.format(
                i, total_loss / n_sample))
            logger.info('-' * 30)

            train_loss.append(total_loss / n_sample)
            iteration += 1

            if dev_dataset:
                logger.info(
                    '#################### eval result at step {} ####################'
                    .format(global_step))
                eval_result = eval(dev_dataset)

                valid_loss.append(eval_result['loss'])
                valid_ind_class_acc.append(eval_result['ind_class_acc'])

                # 1 表示要保存模型
                # 0 表示不需要保存模型
                # -1 表示不需要模型,且超过了patience,需要early stop
                signal = early_stopping(eval_result['accuracy'])
                if signal == -1:
                    break
                elif signal == 0:
                    pass
                elif signal == 1:
                    save_model(model,
                               path=config['model_save_path'],
                               model_name='bert')

                # logger.info(eval_result)

        from utils.visualization import draw_curve
        draw_curve(train_loss, iteration, 'train_loss', args.output_dir)
        if dev_dataset:
            draw_curve(valid_loss, iteration, 'valid_loss', args.output_dir)
            draw_curve(valid_ind_class_acc, iteration,
                       'valid_ind_class_accuracy', args.output_dir)

        if args.patience >= args.n_epoch:
            save_model(model,
                       path=config['model_save_path'],
                       model_name='bert')

        freeze_data['train_loss'] = train_loss
        freeze_data['valid_loss'] = valid_loss
示例#3
0
    def train(train_dataset, dev_dataset):
        train_dataloader = DataLoader(train_dataset,
                                      batch_size=args.train_batch_size,
                                      shuffle=True,
                                      num_workers=2)

        global best_dev
        nonlocal global_step

        n_sample = len(train_dataloader)
        early_stopping = EarlyStopping(args.patience, logger=logger)
        # Loss function
        adversarial_loss = torch.nn.BCELoss().to(device)

        # Optimizers
        optimizer_pos = torch.optim.Adam(pos.parameters(), lr=args.pos_lr)
        optimizer_E = AdamW(E.parameters(), args.bert_lr)

        valid_detection_loss = []
        valid_oos_ind_precision = []
        valid_oos_ind_recall = []
        valid_oos_ind_f_score = []

        train_loss = []
        iteration = 0

        for i in range(args.n_epoch):
            logger.info('***********************************')
            logger.info('epoch: {}'.format(i))

            # Initialize model state
            pos.train()
            E.train()

            total_loss = 0
            for sample in tqdm(train_dataloader):
                sample = (i.to(device) for i in sample)
                token, mask, type_ids, pos1, pos2, pos_mask, y = sample
                batch = len(token)

                optimizer_E.zero_grad()
                optimizer_pos.zero_grad()
                sequence_output, pooled_output = E(token, mask, type_ids)
                real_feature = pooled_output

                out = pos(pos1, pos2, real_feature)
                loss = adversarial_loss(out, y.float())
                loss.backward()
                total_loss += loss.detach()

                if args.fine_tune:
                    optimizer_E.step()

                optimizer_pos.step()

            logger.info('[Epoch {}] Train: loss: {}'.format(
                i, total_loss / n_sample))
            logger.info(
                '---------------------------------------------------------------------------'
            )
            train_loss.append(total_loss / n_sample)
            iteration += 1

            if dev_dataset:
                logger.info(
                    '#################### eval result at step {} ####################'
                    .format(global_step))
                eval_result = eval(dev_dataset)

                valid_detection_loss.append(eval_result['detection_loss'])
                valid_oos_ind_precision.append(
                    eval_result['oos_ind_precision'])
                valid_oos_ind_recall.append(eval_result['oos_ind_recall'])
                valid_oos_ind_f_score.append(eval_result['oos_ind_f_score'])

                # 1 表示要保存模型
                # 0 表示不需要保存模型
                # -1 表示不需要模型,且超过了patience,需要early stop
                signal = early_stopping(-eval_result['eer'])
                if signal == -1:
                    break
                # elif signal == 0:
                #     pass
                # elif signal == 1:
                #     save_gan_model(D, G, config['gan_save_path'])
                #     if args.fine_tune:
                #         save_model(E, path=config['bert_save_path'], model_name='bert')

                logger.info(eval_result)
                logger.info('valid_eer: {}'.format(eval_result['eer']))
                logger.info('valid_oos_ind_precision: {}'.format(
                    eval_result['oos_ind_precision']))
                logger.info('valid_oos_ind_recall: {}'.format(
                    eval_result['oos_ind_recall']))
                logger.info('valid_oos_ind_f_score: {}'.format(
                    eval_result['oos_ind_f_score']))
                logger.info('valid_auc: {}'.format(eval_result['auc']))
                logger.info('valid_fpr95: {}'.format(
                    ErrorRateAt95Recall(eval_result['all_binary_y'],
                                        eval_result['y_score'])))

        best_dev = -early_stopping.best_score
        # 绘制训练损失曲线
        from utils.visualization import draw_curve
        draw_curve(train_loss, iteration, 'train_loss', args.output_dir)