Code example #1
def evaluate(args):
    """Evaluate a trained model on the test set and dump metric artifacts.

    Loads the pickled valid/test records, JSON meta files, the id-to-token
    map, and the token embeddings; restores model weights from
    ``args.model_dir/model.bin``; runs ``evaluate_batch`` over the test
    data; logs every metric; and writes FALSE/ROC/PRC JSON files plus the
    ROC/PRC curve pictures to ``args.result_dir`` / ``args.pics_dir``.

    Args:
        args: argparse-style namespace carrying file paths, model name,
            dropout rates, device, batch size and output directories.
    """
    logger = logging.getLogger('Causality')
    # The `with` statement closes each file on exit; the redundant
    # explicit close() calls that followed each block were removed.
    logger.info('Loading valid file...')
    with open(args.valid_record_file, 'rb') as fh:
        # NOTE(review): valid_file is loaded but never used below;
        # kept for parity with the original behavior (the file read
        # still happens) — confirm whether it can be dropped.
        valid_file = pkl.load(fh)
    logger.info('Loading test file...')
    with open(args.test_record_file, 'rb') as fh:
        test_file = pkl.load(fh)
    logger.info('Loading valid meta...')
    valid_meta = load_json(args.valid_meta)
    logger.info('Loading test meta...')
    test_meta = load_json(args.test_meta)
    logger.info('Loading id to token file...')
    id2token_file = load_json(args.id2token_file)
    logger.info('Loading token embeddings...')
    with open(args.token_emb_file, 'rb') as fh:
        token_embeddings = pkl.load(fh)
    valid_num = valid_meta['total']
    test_num = test_meta['total']

    logger.info('Loading shape meta...')
    logger.info('Num valid data {} test data {}'.format(valid_num, test_num))

    # Model classes read dropout rates from a dict on args.
    args.dropout = {'emb': args.emb_dropout, 'layer': args.layer_dropout}
    model = getattr(models, args.model)(token_embeddings, args,
                                        logger).to(device=args.device)
    model.load_state_dict(torch.load(os.path.join(args.model_dir,
                                                  'model.bin')))

    eval_metrics, fpr, tpr, precision, recall = evaluate_batch(
        model, test_num, args.batch_eval, test_file, args.device, args.is_fc,
        'eval', logger)
    logger.info('Eval Loss - {}'.format(eval_metrics['loss']))
    logger.info('Eval Acc - {}'.format(eval_metrics['acc']))
    logger.info('Eval Precision - {}'.format(eval_metrics['precision']))
    logger.info('Eval Recall - {}'.format(eval_metrics['recall']))
    logger.info('Eval F1 - {}'.format(eval_metrics['f1']))
    logger.info('Eval AUCROC - {}'.format(eval_metrics['auc_roc']))
    logger.info('Eval AUCPRC - {}'.format(eval_metrics['auc_prc']))

    # Collect false positives/negatives and the raw curve points for dumping.
    FALSE = {'FP': eval_metrics['fp'], 'FN': eval_metrics['fn']}
    ROC = {'FPR': fpr, 'TPR': tpr}
    PRC = {'PRECISION': precision, 'RECALL': recall}

    dump_json(os.path.join(args.result_dir, 'FALSE_transfer.json'), FALSE)
    dump_json(os.path.join(args.result_dir, 'ROC_transfer.json'), ROC)
    dump_json(os.path.join(args.result_dir, 'PRC_transfer.json'), PRC)
    draw_curve(ROC['FPR'], ROC['TPR'], PRC['PRECISION'], PRC['RECALL'],
               args.pics_dir)
Code example #2
def train(args):
    """Train a model, checkpointing the best validation epoch.

    Loads pickled train/valid records, JSON meta files and token embeddings,
    builds the model named by ``args.model``, and trains for ``args.epochs``
    epochs with a warmup-cosine learning-rate schedule. After each epoch the
    model is evaluated on the validation set; the epoch with the highest
    ``auc_roc + auc_prc + f1`` sum has its weights saved to
    ``args.model_dir/model.bin`` and its curves/false cases dumped at the end.

    Args:
        args: argparse-style namespace carrying file paths, model and
            optimizer hyper-parameters, device, batch sizes and output dirs.
    """
    logger = logging.getLogger('Causality')
    # The `with` statement closes each file on exit; the redundant
    # explicit close() calls that followed each block were removed.
    logger.info('Loading train file...')
    with open(args.train_record_file, 'rb') as fh:
        train_file = pkl.load(fh)
    logger.info('Loading valid file...')
    with open(args.valid_record_file, 'rb') as fh:
        valid_file = pkl.load(fh)
    logger.info('Loading train meta...')
    train_meta = load_json(args.train_meta)
    logger.info('Loading valid meta...')
    valid_meta = load_json(args.valid_meta)
    logger.info('Loading token embeddings...')
    with open(args.token_emb_file, 'rb') as fh:
        token_embeddings = pkl.load(fh)
    train_num = train_meta['total']
    valid_num = valid_meta['total']

    logger.info('Loading shape meta...')
    logger.info('Num train data {} valid data {}'.format(train_num, valid_num))

    # Model classes read dropout rates from a dict on args.
    args.dropout = {'emb': args.emb_dropout, 'layer': args.layer_dropout}
    logger.info('Initialize the model...')
    model = getattr(models, args.model)(token_embeddings, args,
                                        logger).to(device=args.device)
    lr = args.lr
    optimizer = getattr(optim, args.optim)(model.parameters(),
                                           lr=lr,
                                           weight_decay=args.weight_decay)
    # Linear warmup for args.warmup steps, then cosine decay over the
    # total number of optimizer steps across all epochs.
    scheduler = WarmupCosineSchedule(optimizer, args.warmup,
                                     (train_num // args.batch_train + 1) *
                                     args.epochs)
    max_acc, max_p, max_r, max_f, max_roc, max_prc, max_sum, max_epoch = np.zeros(
        8)
    FALSE, ROC, PRC = {}, {}, {}
    train_loss, valid_loss = [], []
    for ep in range(1, args.epochs + 1):
        logger.info('Training the model for epoch {}'.format(ep))
        avg_loss = train_one_epoch(model, optimizer, scheduler, train_num,
                                   train_file, args, logger)
        train_loss.append(avg_loss)
        logger.info('Epoch {} AvgLoss {}'.format(ep, avg_loss))

        logger.info('Evaluating the model for epoch {}'.format(ep))
        eval_metrics, fpr, tpr, precision, recall = evaluate_batch(
            model, valid_num, args.batch_eval, valid_file, args.device,
            args.is_fc, 'valid', logger)
        valid_loss.append(eval_metrics['loss'])
        logger.info('Valid Loss - {}'.format(eval_metrics['loss']))
        logger.info('Valid Acc - {}'.format(eval_metrics['acc']))
        logger.info('Valid Precision - {}'.format(eval_metrics['precision']))
        logger.info('Valid Recall - {}'.format(eval_metrics['recall']))
        logger.info('Valid F1 - {}'.format(eval_metrics['f1']))
        logger.info('Valid AUCROC - {}'.format(eval_metrics['auc_roc']))
        logger.info('Valid AUCPRC - {}'.format(eval_metrics['auc_prc']))
        # NOTE(review): the running maxima below are overwritten by the
        # best-sum epoch's values whenever valid_sum improves, so the
        # final "Max *" logs mix the two notions. Behavior kept as-is;
        # confirm which semantics is intended.
        max_acc = max((eval_metrics['acc'], max_acc))
        max_p = max(eval_metrics['precision'], max_p)
        max_r = max(eval_metrics['recall'], max_r)
        max_f = max(eval_metrics['f1'], max_f)
        # Model selection criterion: AUC-ROC + AUC-PRC + F1.
        valid_sum = eval_metrics['auc_roc'] + eval_metrics[
            'auc_prc'] + eval_metrics['f1']
        if valid_sum > max_sum:
            max_acc = eval_metrics['acc']
            max_p = eval_metrics['precision']
            max_r = eval_metrics['recall']
            max_f = eval_metrics['f1']
            max_roc = eval_metrics['auc_roc']
            max_prc = eval_metrics['auc_prc']
            max_sum = valid_sum
            max_epoch = ep
            FALSE = {'FP': eval_metrics['fp'], 'FN': eval_metrics['fn']}
            ROC = {'FPR': fpr, 'TPR': tpr}
            PRC = {'PRECISION': precision, 'RECALL': recall}
            torch.save(model.state_dict(),
                       os.path.join(args.model_dir, 'model.bin'))

        # Reshuffle the training records between epochs.
        random.shuffle(train_file)

    logger.info('Max Acc - {}'.format(max_acc))
    logger.info('Max Precision - {}'.format(max_p))
    logger.info('Max Recall - {}'.format(max_r))
    logger.info('Max F1 - {}'.format(max_f))
    logger.info('Max ROC - {}'.format(max_roc))
    logger.info('Max PRC - {}'.format(max_prc))
    logger.info('Max Epoch - {}'.format(max_epoch))
    logger.info('Max Sum - {}'.format(max_sum))

    dump_json(os.path.join(args.result_dir, 'FALSE_valid.json'), FALSE)
    dump_json(os.path.join(args.result_dir, 'ROC_valid.json'), ROC)
    dump_json(os.path.join(args.result_dir, 'PRC_valid.json'), PRC)
    save_loss(train_loss, valid_loss, args.result_dir)
    draw_curve(ROC['FPR'], ROC['TPR'], PRC['PRECISION'], PRC['RECALL'],
               args.pics_dir)
def train(args, file_paths):
    """Train the MCIN model and keep the best-epoch checkpoint.

    Loads pickled train/valid records, JSON meta files and token embeddings
    from ``file_paths``, trains for ``args.epochs`` epochs with a
    ReduceLROnPlateau schedule keyed on validation F1, and saves the weights
    of the epoch maximizing ``precision + recall + f1`` to
    ``args.model_dir/model.bin``. FALSE/ROC/PRC JSON files and curve
    pictures are written at the end.

    Args:
        args: argparse-style namespace with model/optimizer hyper-parameters,
            device, batch sizes and output directories.
        file_paths: object exposing the record/meta/embedding file paths.
    """
    logger = logging.getLogger('2010')
    # The `with` statement closes each file on exit; the redundant
    # explicit close() calls that followed each block were removed.
    logger.info('Loading train file...')
    with open(file_paths.train_record_file, 'rb') as fh:
        train_file = pkl.load(fh)
    logger.info('Loading valid file...')
    with open(file_paths.valid_record_file, 'rb') as fh:
        valid_file = pkl.load(fh)
    logger.info('Loading train meta...')
    with open(file_paths.train_meta, 'r') as fh:
        train_meta = json.load(fh)
    logger.info('Loading valid meta...')
    with open(file_paths.valid_meta, 'r') as fh:
        valid_meta = json.load(fh)
    logger.info('Loading token embeddings...')
    with open(file_paths.token_emb_file, 'rb') as fh:
        token_embeddings = pkl.load(fh)
    train_num = train_meta['total']
    valid_num = valid_meta['total']

    logger.info('Loading shape meta...')
    logger.info('Num train data {} valid data {}'.format(train_num, valid_num))

    dropout = {'emb': args.emb_dropout, 'layer': args.layer_dropout}
    logger.info('Initialize the model...')

    model = MCIN(token_embeddings, args.max_len, args.n_class, args.n_hidden,
                 args.n_layer, args.n_kernels, args.n_filter, args.n_block,
                 args.n_head, args.is_sinusoid, args.is_ffn, dropout,
                 logger).to(device=args.device)
    lr = args.lr
    optimizer = getattr(optim, args.optim)(model.parameters(),
                                           lr=lr,
                                           weight_decay=args.weight_decay)
    # Halve the LR when validation F1 plateaus for args.patience epochs.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     'max',
                                                     0.5,
                                                     patience=args.patience,
                                                     verbose=True)
    max_acc, max_p, max_r, max_f, max_sum, max_epoch = 0, 0, 0, 0, 0, 0
    FALSE = {}
    ROC = {}
    PRC = {}
    for ep in range(1, args.epochs + 1):
        logger.info('Training the model for epoch {}'.format(ep))
        avg_loss = train_one_epoch(model, optimizer, train_num, train_file,
                                   args, logger)
        logger.info('Epoch {} AvgLoss {}'.format(ep, avg_loss))

        logger.info('Evaluating the model for epoch {}'.format(ep))
        eval_metrics, fpr, tpr, precision, recall = evaluate_batch(
            model, valid_num, args.batch_eval, valid_file, args.device,
            args.is_fc, 'valid', logger)
        logger.info('Valid Loss - {}'.format(eval_metrics['loss']))
        logger.info('Valid Acc - {}'.format(eval_metrics['acc']))
        logger.info('Valid Precision - {}'.format(eval_metrics['precision']))
        logger.info('Valid Recall - {}'.format(eval_metrics['recall']))
        logger.info('Valid F1 - {}'.format(eval_metrics['f1']))
        logger.info('Valid AUCROC - {}'.format(eval_metrics['auc_roc']))
        logger.info('Valid AUCPRC - {}'.format(eval_metrics['auc_prc']))
        # Running per-metric maxima (reported independently of the
        # best-sum epoch below).
        max_acc = max((eval_metrics['acc'], max_acc))
        max_p = max(eval_metrics['precision'], max_p)
        max_r = max(eval_metrics['recall'], max_r)
        max_f = max(eval_metrics['f1'], max_f)
        # Model selection criterion: precision + recall + F1.
        valid_sum = eval_metrics['precision'] + eval_metrics[
            'recall'] + eval_metrics['f1']
        if valid_sum > max_sum:
            max_sum = valid_sum
            max_epoch = ep
            FALSE = {'FP': eval_metrics['fp'], 'FN': eval_metrics['fn']}
            ROC = {'FPR': fpr, 'TPR': tpr}
            PRC = {'PRECISION': precision, 'RECALL': recall}
            torch.save(model.state_dict(),
                       os.path.join(args.model_dir, 'model.bin'))

        scheduler.step(metrics=eval_metrics['f1'])
        # Reshuffle the training records between epochs.
        random.shuffle(train_file)

    logger.info('Max Acc - {}'.format(max_acc))
    logger.info('Max Precision - {}'.format(max_p))
    logger.info('Max Recall - {}'.format(max_r))
    logger.info('Max F1 - {}'.format(max_f))
    logger.info('Max Epoch - {}'.format(max_epoch))
    logger.info('Max Sum - {}'.format(max_sum))
    with open(os.path.join(args.result_dir, 'FALSE.json'), 'w') as f:
        f.write(json.dumps(FALSE) + '\n')
    with open(os.path.join(args.result_dir, 'ROC.json'), 'w') as f:
        f.write(json.dumps(ROC) + '\n')
    with open(os.path.join(args.result_dir, 'PRC.json'), 'w') as f:
        f.write(json.dumps(PRC) + '\n')
    draw_curve(ROC['FPR'], ROC['TPR'], PRC['PRECISION'], PRC['RECALL'],
               args.pics_dir)
Code example #4
def train(args, file_paths):
    """Train ``args.multi`` independent runs and keep the overall best model.

    Loads pickled train/valid records, JSON meta files and token embeddings,
    then repeats the full training loop ``args.multi`` times (fresh model and
    optimizer each turn). Within a turn the epoch maximizing
    ``acc + precision + recall + f1`` is tracked; across turns the single best
    model is checkpointed to ``args.model_dir/model.bin`` and its curves
    dumped. Per-turn best metrics are collected and written via
    ``save_metrics``.

    Args:
        args: argparse-style namespace with hyper-parameters, ``multi`` (the
            number of repeated runs), device, batch sizes and output dirs.
        file_paths: object exposing the record/meta/embedding file paths.
    """
    logger = logging.getLogger('Causality')
    # The `with` statement closes each file on exit; the redundant
    # explicit close() calls that followed each block were removed.
    logger.info('Loading train file...')
    with open(file_paths.train_record_file, 'rb') as fh:
        train_file = pkl.load(fh)
    logger.info('Loading valid file...')
    with open(file_paths.valid_record_file, 'rb') as fh:
        valid_file = pkl.load(fh)
    logger.info('Loading train meta...')
    train_meta = load_json(file_paths.train_meta)
    logger.info('Loading valid meta...')
    valid_meta = load_json(file_paths.valid_meta)
    logger.info('Loading token embeddings...')
    with open(file_paths.token_emb_file, 'rb') as fh:
        token_embeddings = pkl.load(fh)
    train_num = train_meta['total']
    valid_num = valid_meta['total']

    logger.info('Loading shape meta...')
    logger.info('Num train data {} valid data {}'.format(train_num, valid_num))

    # Model classes read dropout rates from a dict on args.
    args.dropout = {'emb': args.emb_dropout, 'layer': args.layer_dropout}
    # One row of best metrics per turn: acc, p, r, f1, roc, prc, epoch.
    # BUGFIX: `dtype=np.float` raises AttributeError on NumPy >= 1.24
    # (the alias was removed); the builtin `float` gives the same float64.
    records = np.zeros((args.multi, 7), dtype=float)
    best_sum = -1
    for i in range(1, args.multi + 1):
        logger.info('Initialize the model...')
        model = SelfAttentive(token_embeddings, args.n_class, args.n_hidden,
                              args.n_layer, 128, 32, args.dropout,
                              logger).to(device=args.device)
        lr = args.lr
        optimizer = getattr(optim, args.optim)(model.parameters(),
                                               lr=lr,
                                               weight_decay=args.weight_decay)
        # Linear warmup for args.warmup steps, then cosine decay over the
        # total number of optimizer steps across all epochs.
        scheduler = WarmupCosineSchedule(optimizer, args.warmup,
                                         (train_num // args.batch_train + 1) *
                                         args.epochs)
        logger.info('Turn {}'.format(i))
        max_acc, max_p, max_r, max_f, max_roc, max_prc, max_sum, max_epoch = np.zeros(
            8)
        FALSE, ROC, PRC = {}, {}, {}
        train_loss, valid_loss = [], []
        # True when this turn produced the best model seen across all turns.
        is_best = False
        for ep in range(1, args.epochs + 1):
            logger.info('Training the model for epoch {}'.format(ep))
            avg_loss = train_one_epoch(model, optimizer, scheduler, train_num,
                                       train_file, args, logger)
            train_loss.append(avg_loss)
            logger.info('Epoch {} AvgLoss {}'.format(ep, avg_loss))

            logger.info('Evaluating the model for epoch {}'.format(ep))
            eval_metrics, fpr, tpr, precision, recall = evaluate_batch(
                model, valid_num, args.batch_eval, valid_file, args.device,
                args.is_fc, 'valid', logger)
            valid_loss.append(eval_metrics['loss'])
            logger.info('Valid Loss - {}'.format(eval_metrics['loss']))
            logger.info('Valid Acc - {}'.format(eval_metrics['acc']))
            logger.info('Valid Precision - {}'.format(
                eval_metrics['precision']))
            logger.info('Valid Recall - {}'.format(eval_metrics['recall']))
            logger.info('Valid F1 - {}'.format(eval_metrics['f1']))
            logger.info('Valid AUCROC - {}'.format(eval_metrics['auc_roc']))
            logger.info('Valid AUCPRC - {}'.format(eval_metrics['auc_prc']))
            # Model selection criterion: acc + precision + recall + F1.
            valid_sum = eval_metrics['acc'] + eval_metrics[
                'precision'] + eval_metrics['recall'] + eval_metrics['f1']
            if valid_sum > max_sum:
                max_acc = eval_metrics['acc']
                max_p = eval_metrics['precision']
                max_r = eval_metrics['recall']
                max_f = eval_metrics['f1']
                max_roc = eval_metrics['auc_roc']
                max_prc = eval_metrics['auc_prc']
                max_sum = valid_sum
                max_epoch = ep
                FALSE = {'FP': eval_metrics['fp'], 'FN': eval_metrics['fn']}
                ROC = {'FPR': fpr, 'TPR': tpr}
                PRC = {'PRECISION': precision, 'RECALL': recall}
                # Only checkpoint when this turn beats every previous turn.
                if max_sum > best_sum:
                    best_sum = max_sum
                    is_best = True
                    torch.save(model.state_dict(),
                               os.path.join(args.model_dir, 'model.bin'))

            # Reshuffle the training records between epochs.
            random.shuffle(train_file)

        logger.info('Max Acc - {}'.format(max_acc))
        logger.info('Max Precision - {}'.format(max_p))
        logger.info('Max Recall - {}'.format(max_r))
        logger.info('Max F1 - {}'.format(max_f))
        logger.info('Max ROC - {}'.format(max_roc))
        logger.info('Max PRC - {}'.format(max_prc))
        logger.info('Max Epoch - {}'.format(max_epoch))
        logger.info('Max Sum - {}'.format(max_sum))

        if is_best:
            dump_json(os.path.join(args.result_dir, 'FALSE_valid.json'), FALSE)
            dump_json(os.path.join(args.result_dir, 'ROC_valid.json'), ROC)
            dump_json(os.path.join(args.result_dir, 'PRC_valid.json'), PRC)
            save_loss(train_loss, valid_loss, args.result_dir)
            draw_curve(ROC['FPR'], ROC['TPR'], PRC['PRECISION'], PRC['RECALL'],
                       args.pics_dir)

        records[i - 1] = [
            max_acc, max_p, max_r, max_f, max_roc, max_prc, max_epoch
        ]

    save_metrics(records, args.result_dir)