示例#1
0
def run_test(args):
    """Predict on the configured test file and print the predictor's output.

    Rows read from ``config['test_path']`` are turned into examples, features
    and a dataset by ``BertProcessor`` (cache file locations are passed under
    ``config['data_dir']``), fed sequentially to the model loaded from
    ``config['checkpoint_dir']``, and whatever ``Predictor.predict`` returns
    is printed.

    Args:
        args: Parsed command-line namespace. Reads ``do_lower_case``,
            ``arch``, ``eval_max_seq_len``, ``train_batch_size`` and
            ``n_gpu``.
    """
    from pybert.io.task_data import TaskData
    from pybert.test.predictor import Predictor

    # --- read the raw test rows ---------------------------------------------
    task_data = TaskData()
    targets, sentences = task_data.read_data(
        raw_data_path=config['test_path'],
        preprocessor=EnglishPreProcessor(),
        is_train=False,
    )
    sentence_target_pairs = list(zip(sentences, targets))

    # --- build examples / features / dataset --------------------------------
    bert_processor = BertProcessor(
        vocab_path=config['bert_vocab_path'],
        do_lower_case=args.do_lower_case,
    )
    labels = bert_processor.get_labels()
    id2label = dict(enumerate(labels))  # index -> label; not used further here

    test_rows = bert_processor.get_test(lines=sentence_target_pairs)

    examples_cache = config['data_dir'] / f"cached_test_examples_{args.arch}"
    examples = bert_processor.create_examples(
        lines=test_rows,
        example_type='test',
        cached_examples_file=examples_cache,
    )

    features_cache = config['data_dir'] / "cached_test_features_{}_{}".format(
        args.eval_max_seq_len, args.arch)
    features = bert_processor.create_features(
        examples=examples,
        max_seq_len=args.eval_max_seq_len,
        cached_features_file=features_cache,
    )

    dataset = bert_processor.create_dataset(features)
    # Sequential sampling keeps predictions in the same order as the input.
    loader = DataLoader(
        dataset,
        sampler=SequentialSampler(dataset),
        batch_size=args.train_batch_size,
    )

    # --- load the model and predict -----------------------------------------
    model = BertForMultiLable.from_pretrained(
        config['checkpoint_dir'],
        num_labels=len(labels),
    )

    logger.info('model predicting....')
    predictor = Predictor(model=model, logger=logger, n_gpu=args.n_gpu)
    result = predictor.predict(data=loader)
    print(result)
def run_test(args,
             output_path="/home/LAB/liqian/test/game/Fin/CCKS-Cls/test_output/cls_out.csv"):
    """Predict on the test file and write thresholded labels to a CSV.

    Rows read from ``config['test_path']`` are converted to BERT features,
    scored by the model loaded from ``config['checkpoint_dir']``, and the
    scores are binarised (``> 0.5`` becomes 1, otherwise 0) before being
    written next to the row ids.

    Args:
        args: Parsed command-line namespace. Reads ``do_lower_case``,
            ``arch``, ``eval_max_seq_len``, ``train_batch_size`` and
            ``n_gpu``.
        output_path: Destination CSV. Defaults to the location that used to
            be hard-coded, so existing callers are unaffected.

    Raises:
        ValueError: Raised by pandas when the prediction matrix does not
            have exactly two columns, because the frame is labelled
            ``['id', 'sg', 'pj']``.
    """
    from pybert.io.task_data import TaskData
    from pybert.test.predictor import Predictor

    data = TaskData()
    ids, targets, sentences = data.read_data(
        raw_data_path=config['test_path'],
        preprocessor=ChinesePreProcessor(),
        is_train=False)
    lines = list(zip(sentences, targets))

    processor = BertProcessor(vocab_path=config['bert_vocab_path'],
                              do_lower_case=args.do_lower_case)
    label_list = processor.get_labels()

    test_data = processor.get_test(lines=lines)
    test_examples = processor.create_examples(
        lines=test_data,
        example_type='test',
        cached_examples_file=config['data_dir'] /
        f"cached_test_examples_{args.arch}")
    test_features = processor.create_features(
        examples=test_examples,
        max_seq_len=args.eval_max_seq_len,
        cached_features_file=config['data_dir'] /
        "cached_test_features_{}_{}".format(args.eval_max_seq_len, args.arch))
    test_dataset = processor.create_dataset(test_features)
    # Sequential sampling keeps predictions aligned with ``ids``.
    test_sampler = SequentialSampler(test_dataset)
    test_dataloader = DataLoader(test_dataset,
                                 sampler=test_sampler,
                                 batch_size=args.train_batch_size,
                                 collate_fn=collate_fn)
    model = BertForMultiLable.from_pretrained(config['checkpoint_dir'],
                                              num_labels=len(label_list))

    # ----------- predicting
    logger.info('model predicting....')
    predictor = Predictor(model=model, logger=logger, n_gpu=args.n_gpu)
    result = predictor.predict(data=test_dataloader)

    # Put ids and scores side by side; both frames get a default RangeIndex,
    # so the column-wise concat aligns them row by row.
    id_frame = pd.DataFrame(np.array(ids), index=None)
    score_frame = pd.DataFrame(result, index=None)
    all_df = pd.concat([id_frame, score_frame], axis=1)

    # NOTE: the column names must match the number of labels the model
    # outputs; adjust this list when switching to another label set.
    all_df.columns = ['id', 'sg', 'pj']
    for label_col in ('sg', 'pj'):
        # Binarise the score column: strictly greater than 0.5 -> 1, else 0.
        all_df[label_col] = (all_df[label_col] > 0.5).astype(int)

    all_df.to_csv(output_path, index=False)
def main():
    """Command-line entry point: parse options, then prepare data, train and/or test.

    The stages are selected with ``--do_data``, ``--do_train`` and
    ``--do_test``; any combination may be given and they run in that order.
    Before any stage runs, the parsed arguments are saved to
    ``training_args.bin`` inside ``config['checkpoint_dir'] / <arch>``.
    """
    parser = ArgumentParser()
    parser.add_argument("--arch", default='bert', type=str)
    parser.add_argument("--do_data", action='store_true')
    parser.add_argument("--do_train", action='store_true')
    parser.add_argument("--do_test", action='store_true')
    parser.add_argument("--save_best", action='store_true')
    parser.add_argument("--do_lower_case", action='store_true')
    parser.add_argument('--data_name', default='train', type=str)
    parser.add_argument("--epochs", default=4, type=int)
    parser.add_argument("--resume_path", default='', type=str)
    parser.add_argument("--mode", default='max', type=str)
    parser.add_argument("--monitor", default='valid_f1', type=str)
    parser.add_argument("--valid_size", default=0.2, type=float)
    parser.add_argument("--local_rank", type=int, default=-1)
    parser.add_argument("--sorted", default=1, type=int, help='1 : True  0:False ')
    parser.add_argument("--n_gpu", type=str, default='0', help='"0,1,.." or "0" or "" ')
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1)
    parser.add_argument("--train_batch_size", default=8, type=int)
    parser.add_argument('--eval_batch_size', default=8, type=int)
    parser.add_argument("--train_max_seq_len", default=256, type=int)
    parser.add_argument("--eval_max_seq_len", default=256, type=int)
    parser.add_argument('--loss_scale', type=float, default=0)
    # Must be float: argparse applies ``type`` to values given on the command
    # line, so ``type=int`` would reject fractional values such as "0.1".
    parser.add_argument("--warmup_proportion", default=0.1, type=float)
    parser.add_argument("--weight_decay", default=0.01, type=float)
    parser.add_argument("--adam_epsilon", default=1e-8, type=float)
    parser.add_argument("--grad_clip", default=1.0, type=float)
    parser.add_argument("--learning_rate", default=2e-5, type=float)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--fp16', action='store_true')
    parser.add_argument('--fp16_opt_level', type=str, default='O1')

    args = parser.parse_args()

    # Each architecture gets its own checkpoint sub-directory.
    config['checkpoint_dir'] = config['checkpoint_dir'] / args.arch
    config['checkpoint_dir'].mkdir(exist_ok=True)
    # Good practice: save your training arguments together with the trained model
    torch.save(args, config['checkpoint_dir'] / 'training_args.bin')
    seed_everything(args.seed)
    init_logger(log_file=config['log_dir'] / f"{args.arch}.log")

    logger.info("Training/evaluation parameters %s", args)

    if args.do_data:
        from pybert.io.task_data import TaskData
        processor = BertProcessor(vocab_path=config['bert_vocab_path'],
                                  do_lower_case=args.do_lower_case)
        label_list = processor.get_labels()
        # Label name -> index mapping handed to the reader.
        label2id = {label: i for i, label in enumerate(label_list)}
        data = TaskData()
        targets, sentences = data.read_data(raw_data_path=config['raw_data_path'],
                                            preprocessor=None,
                                            is_train=True,
                                            label2id=label2id)
        data.train_val_split(X=sentences, y=targets, shuffle=True, stratify=targets,
                             valid_size=args.valid_size, data_dir=config['data_dir'],
                             data_name=args.data_name)

    if args.do_train:
        run_train(args)

    if args.do_test:
        run_test(args)
示例#4
0
def run_test(args):
    """Predict labels for the test file and write them to ``config['predictions']``.

    Each output line has the form ``<id>\\t<label>\\t...\\t\\n`` where ``<id>``
    is the first CSV column of the corresponding row in
    ``config['test_path']`` and the labels are those whose score is at
    least 0.5.

    Args:
        args: Parsed command-line namespace. Reads ``arch``,
            ``eval_max_seq_len``, ``train_batch_size`` and ``n_gpu``.

    Raises:
        IndexError: If the test CSV has more rows than there are predictions.
    """
    from pybert.io.task_data import TaskData
    from pybert.test.predictor import Predictor

    data = TaskData()
    _, _, targets, sentences = data.read_data(config,
                                              raw_data_path=config['test_path'],
                                              is_train=False)
    lines = list(zip(sentences, targets))
    processor = BertProcessor()
    label_list = processor.get_labels()
    id2label = dict(enumerate(label_list))

    test_data = processor.get_test(lines=lines)
    test_examples = processor.create_examples(
        lines=test_data,
        example_type='test',
        cached_examples_file=config['data_dir'] / f"cached_test_examples_{args.arch}")
    test_features = processor.create_features(
        examples=test_examples,
        max_seq_len=args.eval_max_seq_len,
        cached_features_file=config['data_dir'] / "cached_test_features_{}_{}".format(
            args.eval_max_seq_len, args.arch))
    test_dataset = processor.create_dataset(test_features)
    # Sequential sampling keeps predictions in the same order as the CSV rows.
    test_sampler = SequentialSampler(test_dataset)
    test_dataloader = DataLoader(test_dataset,
                                 sampler=test_sampler,
                                 batch_size=args.train_batch_size,
                                 collate_fn=collate_fn)
    model = BertForMultiLable.from_pretrained(config['checkpoint_dir'],
                                              num_labels=len(label_list))

    # ----------- predicting
    logger.info('model predicting....')
    predictor = Predictor(model=model,
                          logger=logger,
                          n_gpu=args.n_gpu)
    result = predictor.predict(data=test_dataloader)

    # A label is assigned when its score reaches the 0.5 threshold.
    labels = [[id2label[idx] for idx in np.where(row >= 0.5)[0]]
              for row in result]

    # Open the output once in write mode: this truncates any previous file
    # (no separate remove step) and avoids reopening it for every row.
    # ``newline=''`` on the input is what the csv module expects.
    with open(config['test_path'], 'r', newline='') as src, \
            open(config['predictions'], 'w') as dst:
        for row_idx, row in enumerate(csv.reader(src)):
            fields = [row[0], *labels[row_idx]]
            # Every field, including the last, is followed by a tab.
            dst.write("".join("{}\t".format(field) for field in fields))
            dst.write("\n")
def run_test(args):
    """Predict one class per test row and write ``id,class`` lines to ``submit1.csv``.

    The predicted class is the arg-max over the second axis of what
    ``Predictor.predict`` returns. The array of predicted class indices is
    also printed.

    Args:
        args: Parsed command-line namespace. Reads ``do_lower_case``,
            ``arch``, ``eval_max_seq_len``, ``train_batch_size`` and
            ``n_gpu``.
    """
    import numpy as np
    from pybert.io.task_data import TaskData
    from pybert.test.predictor import Predictor

    data = TaskData()
    ids, targets, sentences = data.read_data(raw_data_path=config['test_path'],
                                             preprocessor=None,
                                             is_train=False)
    lines = list(zip(sentences, targets))
    processor = BertProcessor(vocab_path=config['bert_vocab_path'],
                              do_lower_case=args.do_lower_case)
    label_list = processor.get_labels()

    test_data = processor.get_test(lines=lines)
    test_examples = processor.create_examples(
        lines=test_data,
        example_type='test',
        cached_examples_file=config['data_dir'] / f"cached_test_examples_{args.arch}")
    test_features = processor.create_features(
        examples=test_examples,
        max_seq_len=args.eval_max_seq_len,
        cached_features_file=config['data_dir'] / "cached_test_features_{}_{}".format(
            args.eval_max_seq_len, args.arch))
    test_dataset = processor.create_dataset(test_features)
    # Sequential sampling keeps predictions aligned with ``ids``.
    test_sampler = SequentialSampler(test_dataset)
    test_dataloader = DataLoader(test_dataset,
                                 sampler=test_sampler,
                                 batch_size=args.train_batch_size)
    model = BertForMultiClass.from_pretrained(config['checkpoint_dir'],
                                              num_labels=len(label_list))

    # ----------- predicting
    logger.info('model predicting....')
    predictor = Predictor(model=model, logger=logger, n_gpu=args.n_gpu)
    result = predictor.predict(data=test_dataloader)
    result = np.argmax(result, axis=1)

    with open('submit1.csv', 'w', encoding='utf-8') as f:
        # f-string formatting also works when the ids are not strings
        # (string concatenation would raise TypeError for e.g. int ids).
        f.writelines(f"{sample_id},{pred}\n"
                     for sample_id, pred in zip(ids, result))
    print(result)
def main():
    """Parse the command line, set up logging/checkpoints, then run the stages.

    Stages are ``--do_data`` (split the raw data), ``--do_train`` and
    ``--do_test``. Note that several options are overwritten in code after
    parsing (see the override section below), so training always runs.
    """
    parser = ArgumentParser()
    add = parser.add_argument  # shorthand; registration order is unchanged

    # -- stage switches and model selection
    add("--arch", default='bert', type=str)
    add("--do_data", action='store_true')
    add("--do_train", action='store_true')
    add("--do_test", action='store_true')
    add("--save_best", action='store_true')
    add("--do_lower_case", action='store_true')
    add('--data_name', default='kaggle', type=str)
    add("--mode", default='min', type=str)
    add("--monitor", default='valid_loss', type=str)

    # -- run control
    add("--epochs", default=20, type=int)
    add("--resume_path", default='', type=str)
    add("--predict_checkpoints", type=int, default=0)
    add("--valid_size", default=0.2, type=float)
    add("--local_rank", type=int, default=-1)
    add("--sorted", default=1, type=int, help='1 : True  0:False ')
    add("--n_gpu", type=str, default='0', help='"0,1,.." or "0" or "" ')

    # -- batching / sequence lengths
    add('--gradient_accumulation_steps', type=int, default=1)
    add("--train_batch_size", default=8, type=int)
    add('--eval_batch_size', default=8, type=int)
    add("--train_max_seq_len", default=256, type=int)
    add("--eval_max_seq_len", default=256, type=int)

    # -- optimisation
    add('--loss_scale', type=float, default=0)
    add("--warmup_proportion", default=0.1, type=float)
    add("--weight_decay", default=0.01, type=float)
    add("--adam_epsilon", default=1e-8, type=float)
    add("--grad_clip", default=1.0, type=float)
    add("--learning_rate", default=2e-5, type=float)
    add('--seed', type=int, default=42)
    add('--fp16', action='store_true')
    add('--fp16_opt_level', type=str, default='O1')
    args = parser.parse_args()

    # One log file per run, named after the architecture and the start time.
    started_at = time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime())
    init_logger(log_file=config['log_dir'] / f'{args.arch}-{started_at}.log')

    checkpoint_dir = config['checkpoint_dir'] / args.arch
    config['checkpoint_dir'] = checkpoint_dir
    checkpoint_dir.mkdir(exist_ok=True)
    # Good practice: save your training arguments together with the trained model
    torch.save(args, checkpoint_dir / 'training_args.bin')
    seed_everything(args.seed)
    logger.info("Training/evaluation parameters %s", args)

    # Hard-coded overrides: these replace whatever was given on the command
    # line, and they are applied after the arguments were saved and logged.
    args.save_best = False
    args.do_train = True
    args.resume_path = 'pybert/output/checkpoints/bert/checkpoint-epoch-3'
    args.do_lower_case = True

    if args.do_data:
        from pybert.io.task_data import TaskData
        task_data = TaskData()
        targets, sentences = task_data.read_data(
            raw_data_path=config['raw_data_path'],
            preprocessor=EnglishPreProcessor(),
            is_train=True,
        )
        task_data.train_val_split(
            X=sentences,
            y=targets,
            shuffle=True,
            stratify=False,
            valid_size=args.valid_size,
            data_dir=config['data_dir'],
            data_name=args.data_name,
        )

    if args.do_train:
        run_train(args)

    if args.do_test:
        run_test(args)
示例#7
0
def run_test(args, test=False, k=7, med_map='pybert/dataset/med_map.csv'):
    """Either evaluate the model on the test file or return its top-k predictions.

    Args:
        args: Parsed command-line namespace. Reads ``do_lower_case``,
            ``arch``, ``eval_max_seq_len``, ``train_batch_size`` and
            ``n_gpu``.
        test: Passed to ``read_data`` as ``is_train`` and to ``Predictor`` as
            ``test``. When true, the predictor is expected to return
            ``(results, targets)`` and metrics are computed; otherwise only
            predictions are returned.
        k: Number of highest-scoring columns to keep per row in prediction
            mode. Should be at least 1 (``k=0`` would keep every column
            because ``-0`` slices from the start).
        med_map: CSV file whose first column is an integer column index and
            whose second column is the name to report for it.

    Returns:
        In evaluation mode, a dict mapping ``valid_<metric name>`` to the
        metric value (metrics whose value is ``None`` are omitted). In
        prediction mode, an array with ``k`` mapped names per row, ordered
        from highest to lowest score.
    """
    from pybert.io.task_data import TaskData
    from pybert.test.predictor import Predictor

    data = TaskData()
    targets, sentences = data.read_data(raw_data_path=config['test_path'],
                                        preprocessor=EnglishPreProcessor(),
                                        is_train=test)
    print(
        f'-----------------------------------------\ntargets {targets}\n---------------------------------------------------'
    )
    lines = list(zip(sentences, targets))
    processor = BertProcessor(vocab_path=config['bert_vocab_path'],
                              do_lower_case=args.do_lower_case)

    test_data = processor.get_test(lines=lines)
    test_examples = processor.create_examples(
        lines=test_data,
        example_type='test',
        cached_examples_file=config['data_dir'] /
        f"cached_test_examples_{args.arch}")
    test_features = processor.create_features(
        examples=test_examples,
        max_seq_len=args.eval_max_seq_len,
        cached_features_file=config['data_dir'] /
        "cached_test_features_{}_{}".format(args.eval_max_seq_len, args.arch))
    test_dataset = processor.create_dataset(test_features)
    test_sampler = SequentialSampler(test_dataset)
    test_dataloader = DataLoader(test_dataset,
                                 sampler=test_sampler,
                                 batch_size=args.train_batch_size)
    model = BertForMultiLable.from_pretrained(config['checkpoint_dir'])

    # ----------- predicting
    logger.info('model predicting....')
    predictor = Predictor(model=model,
                          logger=logger,
                          n_gpu=args.n_gpu,
                          test=test)

    if test:
        results, targets = predictor.predict(data=test_dataloader)
        result = dict()
        for metric in (Recall(), Acc()):
            metric.reset()
            metric(logits=results, target=targets)
            value = metric.value()
            if value is not None:
                result[f'valid_{metric.name()}'] = value
        return result

    results = predictor.predict(data=test_dataloader)
    # argsort is ascending: take the last k columns and reverse them so the
    # highest-scoring index comes first.
    pred = np.argsort(results)[:, -k:][:, ::-1]
    # Honour the ``med_map`` argument instead of a hard-coded path.
    with open(med_map, mode='r', newline='') as infile:
        med_dict = {int(row[0]): row[1] for row in csv.reader(infile)}
    # Indices missing from the map come back as None (dict.get default).
    return np.vectorize(med_dict.get)(pred)
示例#8
0
def main():
    """Command-line entry point: parse options, then prepare data, train and/or test.

    The stages are selected with ``--do_data``, ``--train`` and ``--test``
    and run in that order. The parsed arguments are saved to
    ``training_args.bin`` inside ``config['checkpoint_dir'] / <arch>``.
    """

    def _str2bool(value):
        """Convert a command-line string to bool; argparse reports ValueError as a usage error."""
        text = str(value).strip().lower()
        if text in ('true', 't', 'yes', 'y', '1'):
            return True
        if text in ('false', 'f', 'no', 'n', '0', ''):
            return False
        raise ValueError(f"expected a boolean value, got {value!r}")

    parser = ArgumentParser()
    parser.add_argument("--arch", default='bert', type=str)
    parser.add_argument("--do_data", action='store_true')
    parser.add_argument("--train", action='store_true')
    parser.add_argument("--test", action='store_true')
    parser.add_argument("--save_best", action='store_true')
    parser.add_argument("--do_lower_case", action='store_true')
    parser.add_argument('--data_name', default='job_dataset', type=str)
    parser.add_argument("--epochs", default=10, type=int)
    parser.add_argument("--resume_path", default='', type=str)
    parser.add_argument("--test_path", default='', type=str)
    parser.add_argument("--mode", default='min', type=str)
    parser.add_argument("--monitor", default='valid_loss', type=str)
    parser.add_argument("--valid_size", default=0.05, type=float)
    parser.add_argument("--local_rank", type=int, default=-1)
    parser.add_argument("--sorted",
                        default=1,
                        type=int,
                        help='1 : True  0:False ')
    parser.add_argument("--n_gpu",
                        type=str,
                        default='0',
                        help='"0,1,.." or "0" or "" ')
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1)
    parser.add_argument("--train_batch_size", default=4, type=int)
    parser.add_argument('--eval_batch_size', default=4, type=int)
    parser.add_argument("--train_max_seq_len", default=256, type=int)
    parser.add_argument("--eval_max_seq_len", default=256, type=int)
    parser.add_argument('--loss_scale', type=float, default=0)
    # Must be float: with ``type=int`` argparse would reject fractional
    # values such as "0.1" given on the command line.
    parser.add_argument("--warmup_proportion", default=0.1, type=float)
    parser.add_argument("--weight_decay", default=0.01, type=float)
    parser.add_argument("--adam_epsilon", default=1e-8, type=float)
    parser.add_argument("--grad_clip", default=1.0, type=float)
    parser.add_argument("--learning_rate", default=1.0e-4, type=float)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--fp16', action='store_true')
    parser.add_argument('--fp16_opt_level', type=str, default='O1')
    # ``type=bool`` would turn every non-empty string (including "False")
    # into True, so parse the text explicitly.
    parser.add_argument('--predict_labels', type=_str2bool, default=False)
    parser.add_argument('--predict_idx',
                        type=str,
                        default="0",
                        help=' "idx" or "start-end" or "all" ')

    args = parser.parse_args()

    # Each architecture gets its own checkpoint sub-directory.
    config['checkpoint_dir'] = config['checkpoint_dir'] / args.arch
    config['checkpoint_dir'].mkdir(exist_ok=True)
    torch.save(args, config['checkpoint_dir'] / 'training_args.bin')
    seed_everything(args.seed)
    init_logger(log_file=config['log_dir'] / f"{args.arch}.log")

    logger.info("Training/evaluation parameters %s", args)

    if args.do_data:
        from pybert.io.task_data import TaskData
        data = TaskData()
        targets, sentences = data.read_data(
            raw_data_path=config['raw_data_path'],
            preprocessor=EnglishPreProcessor(),
            is_train=True)
        data.train_val_split(X=sentences,
                             y=targets,
                             shuffle=False,
                             stratify=False,
                             valid_size=args.valid_size,
                             data_dir=config['data_dir'],
                             data_name=args.data_name)

    if args.train:
        run_train(args)

    if args.test:
        run_test(args)
示例#9
0
def main():
    """Command-line entry point: optionally merge summary files, then train and/or test.

    With ``--do_data``, every file in ``config['summary_path']`` whose
    trailing numeric id is greater than 3500 is read, and the training
    portions are merged and pickled under the name ``all_valid``. Files that
    fail to load are skipped, and the failure is logged. The list of data
    names is always loaded from the JSON file at ``config['data_name']`` and
    passed to ``run_train``.
    """
    parser = ArgumentParser()
    parser.add_argument("--arch", default='bert', type=str)
    parser.add_argument("--do_data", action='store_true')
    parser.add_argument("--do_train", action='store_true')
    parser.add_argument("--do_test", action='store_true')
    parser.add_argument("--save_best", action='store_true')
    parser.add_argument("--do_lower_case", action='store_true')
    # parser.add_argument('--data_name', default='HPC', type=str)
    parser.add_argument("--mode", default='min', type=str)
    parser.add_argument("--monitor", default='valid_loss', type=str)

    parser.add_argument("--epochs", default=10, type=int)
    parser.add_argument("--resume_path", default='', type=str)
    parser.add_argument("--predict_checkpoints", type=int, default=0)
    parser.add_argument("--valid_size", default=0.2, type=float)
    parser.add_argument("--local_rank", type=int, default=-1)
    parser.add_argument("--sorted", default=1, type=int, help='1 : True  0:False ')
    parser.add_argument("--n_gpu", type=str, default='0', help='"0,1,.." or "0" or "" ')
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1)
    parser.add_argument("--train_batch_size", default=8, type=int)
    parser.add_argument('--eval_batch_size', default=8, type=int)
    parser.add_argument("--train_max_seq_len", default=256, type=int)
    parser.add_argument("--eval_max_seq_len", default=256, type=int)
    parser.add_argument('--loss_scale', type=float, default=0)
    parser.add_argument("--warmup_proportion", default=0.1, type=float)
    parser.add_argument("--weight_decay", default=0.01, type=float)
    parser.add_argument("--adam_epsilon", default=1e-8, type=float)
    parser.add_argument("--grad_clip", default=1.0, type=float)
    parser.add_argument("--learning_rate", default=2e-5, type=float)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--fp16', action='store_true')
    parser.add_argument('--fp16_opt_level', type=str, default='O1')
    args = parser.parse_args()

    init_logger(log_file=config['log_dir'] / f'{args.arch}-{time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime())}.log')
    config['checkpoint_dir'] = config['checkpoint_dir'] / args.arch
    config['checkpoint_dir'].mkdir(exist_ok=True)
    # Good practice: save your training arguments together with the trained model
    torch.save(args, config['checkpoint_dir'] / 'training_args.bin')
    seed_everything(args.seed)
    logger.info("Training/evaluation parameters %s", args)

    if args.do_data:
        train_sentenses_all = []
        train_target_all = []
        from pybert.io.task_data import TaskData
        data = TaskData()
        for filename in os.listdir(config['summary_path']):
            if filename == ".DS_Store" or filename == "summary":
                continue
            # The numeric id is the last "_"-separated token before the extension.
            filename_int = int(filename.split('.')[0].split('_')[-1])
            if filename_int <= 3500:
                continue
            try:
                raw_data_path = os.path.join(config['summary_path'], filename)
                train_targets, train_sentences, _val_targets, _val_sentences = data.read_data(
                    config, raw_data_path=raw_data_path)
                train_sentenses_all = train_sentenses_all + train_sentences
                train_target_all = train_target_all + train_targets
                total_valid = len(train_target_all)
                print("valid number: ", total_valid)
            except Exception:
                # Best effort: a file that cannot be read or merged is
                # skipped, but the failure is logged rather than hidden.
                # ``Exception`` (not a bare except) lets Ctrl-C and
                # SystemExit propagate.
                logger.exception("Skipping %s: failed to read or merge it", filename)
        total_valid = len(train_target_all)
        print("valid number: ", total_valid)
        data.save_pickle(train_sentenses_all, train_target_all, data_dir=config['data_dir'],
                         data_name="all_valid", is_train=False)

    # The data-name list is read from disk regardless of --do_data.
    with open(config['data_name'], 'r') as f:
        data_names = json.load(f)

    if args.do_train:
        run_train(args, data_names)

    if args.do_test:
        run_test(args)
def main():
    """Command-line entry point: parse options, set up per-task paths, run the stages.

    ``--task_type`` selects the data files
    (``pybert/dataset/train_<task>_sample.csv`` and
    ``pybert/dataset/test_<task>.csv``) and the output sub-directories
    (``checkpoint_dir/<arch>/<task>`` and ``figure_dir/<task>``). The stages
    ``--do_data``, ``--do_train`` and ``--do_test`` then run in that order.
    """
    parser = ArgumentParser()
    parser.add_argument("--arch", default='bert', type=str)  # pretrained language model to use
    parser.add_argument("--do_data", action='store_true')  # split the data
    parser.add_argument("--do_train", action='store_true')  # train the model
    parser.add_argument("--do_test", action='store_true')  # run inference
    parser.add_argument("--save_best", action='store_true')  # keep the best model
    parser.add_argument("--do_lower_case", action='store_true')
    parser.add_argument('--data_name', default='ccks', type=str)  # dataset name
    parser.add_argument("--mode", default='min', type=str)  # how the monitored quantity is judged
    parser.add_argument("--monitor", default='valid_loss', type=str)
    parser.add_argument("--task_type", default='base', type=str)

    parser.add_argument("--epochs", default=4, type=int)
    parser.add_argument("--resume_path", default='',
                        type=str)  # resume path: load the model from a pretrained model
    parser.add_argument("--predict_checkpoints", type=int, default=0)
    parser.add_argument("--valid_size", default=0.2, type=float)  # validation set size
    parser.add_argument("--local_rank", type=int, default=-1)
    parser.add_argument("--sorted",
                        default=1,
                        type=int,
                        help='1 : True  0:False ')  # whether to sort by sequence length
    parser.add_argument("--n_gpu",
                        type=str,
                        default='0',
                        help='"0,1,.." or "0" or "" ')
    # Gradient accumulation steps: a workaround when memory is too small for
    # a large batch size.
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1)
    parser.add_argument("--train_batch_size", default=8,
                        type=int)  # training batch size
    parser.add_argument('--eval_batch_size', default=8,
                        type=int)  # evaluation batch size
    parser.add_argument("--train_max_seq_len", default=256,
                        type=int)  # maximum sequence length for training
    parser.add_argument("--eval_max_seq_len", default=256,
                        type=int)  # maximum sequence length for evaluation
    parser.add_argument('--loss_scale', type=float,
                        default=0)  # TODO: understand what loss scale does
    parser.add_argument("--warmup_proportion", default=0.1,
                        type=float)  # warmup proportion for the learning rate
    parser.add_argument("--weight_decay", default=0.01,
                        type=float)  # TODO: understand weight decay
    parser.add_argument("--adam_epsilon", default=1e-8,
                        type=float)  # Adam optimizer parameter
    parser.add_argument("--grad_clip", default=1.0,
                        type=float)  # TODO: understand grad clip
    parser.add_argument("--learning_rate", default=2e-5, type=float)  # learning rate
    parser.add_argument('--seed', type=int, default=42)  # random seed
    parser.add_argument('--fp16', action='store_true')  # TODO: understand what fp16 is
    parser.add_argument('--fp16_opt_level', type=str, default='O1')
    args = parser.parse_args()

    # Initialise the logger. The upstream code did not create the log
    # directory, so do it here.
    config['log_dir'].mkdir(parents=True, exist_ok=True)
    init_logger(
        log_file=config['log_dir'] /
        f'{args.arch}-{time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime())}.log'
    )

    # Redirect the outputs to a per-architecture, per-task directory. The
    # path is two levels below the base directory, so ``parents=True`` is
    # required when the ``<arch>`` level does not exist yet.
    config['checkpoint_dir'] = config[
        'checkpoint_dir'] / args.arch / args.task_type
    config['checkpoint_dir'].mkdir(parents=True, exist_ok=True)

    # Derive the data file paths from the task type.
    BASE_DIR = Path('pybert')
    config[
        'raw_data_path'] = BASE_DIR / f'dataset/train_{args.task_type}_sample.csv'
    config['test_path'] = BASE_DIR / f'dataset/test_{args.task_type}.csv'
    config['figure_dir'] = config['figure_dir'] / f'{args.task_type}'
    config['figure_dir'].mkdir(parents=True, exist_ok=True)

    # Good practice: save your training arguments together with the trained model
    torch.save(args, config['checkpoint_dir'] / 'training_args.bin')
    seed_everything(args.seed)  # one call that sets every seed
    logger.info("Training/evaluation parameters %s", args)

    if args.do_data:
        from pybert.io.task_data import TaskData
        data = TaskData()
        ids, targets, sentences = data.read_data(
            raw_data_path=config['raw_data_path'],
            preprocessor=ChinesePreProcessor(),
            is_train=True)
        data.train_val_split(X=sentences,
                             y=targets,
                             shuffle=True,
                             stratify=False,
                             valid_size=args.valid_size,
                             data_dir=config['data_dir'],
                             data_name=args.data_name,
                             task_type=args.task_type)  # task_type was added to the split call

    if args.do_train:
        run_train(args)

    if args.do_test:
        run_test(args)
def run_test(args):
    """Predict on the test set, report micro F1 and write thresholded labels.

    Steps, in order:

    1. Read ``config['test_path']`` with ``TaskData.read_data`` to obtain ids,
       targets and sentences.
    2. Build (cached) examples/features with ``BertProcessor`` and wrap them in
       a sequential ``DataLoader``.
    3. Load the model from ``config['checkpoint_dir']``: ``BertForMultiLable``
       when ``args.task_type == 'base'``, otherwise
       ``BertForMultiLable_Fewshot``.
    4. Run ``Predictor.predict`` and score the output against the targets with
       ``F1Score(task_type='binary', average='micro', search_thresh=True)``.
    5. Print the score, append it to ``test_output/test.log``, binarise the
       predictions with ``f1_metric.thresh`` and write them to
       ``test_output/<task_type>/cls_out.csv``.

    The output directories are created if they do not exist yet.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``task_type``, ``do_lower_case``, ``arch``, ``eval_max_seq_len``,
            ``train_batch_size`` and ``n_gpu``.

    Returns:
        None. Results are printed and written to disk.
    """
    # TODO: evaluate on the training set with micro F1-score
    from pathlib import Path

    from pybert.io.task_data import TaskData
    from pybert.test.predictor import Predictor

    data = TaskData()
    # is_train=True is passed on purpose (per the original note); the returned
    # targets are needed below to compute the F1 score.
    ids, targets, sentences = data.read_data(
        raw_data_path=config['test_path'],
        preprocessor=ChinesePreProcessor(),
        is_train=True)
    lines = list(zip(sentences, targets))

    processor = BertProcessor(vocab_path=config['bert_vocab_path'],
                              do_lower_case=args.do_lower_case)
    label_list = processor.get_labels(args.task_type)

    test_data = processor.get_test(lines=lines)
    test_examples = processor.create_examples(
        lines=test_data,
        example_type=f'test_{args.task_type}',
        cached_examples_file=config['data_dir'] /
        f"cached_test_{args.task_type}_examples_{args.arch}")
    test_features = processor.create_features(
        examples=test_examples,
        max_seq_len=args.eval_max_seq_len,
        cached_features_file=config['data_dir'] /
        "cached_test_{}_features_{}_{}".format(
            args.task_type, args.eval_max_seq_len, args.arch))
    test_dataset = processor.create_dataset(test_features)
    test_sampler = SequentialSampler(test_dataset)
    # NOTE(review): the training batch size is reused for testing, as in the
    # original code - confirm this is intended.
    test_dataloader = DataLoader(test_dataset,
                                 sampler=test_sampler,
                                 batch_size=args.train_batch_size,
                                 collate_fn=collate_fn)

    # The 'base' task uses the plain multi-label head; every other task type
    # uses the few-shot variant.
    if args.task_type == 'base':
        model_cls = BertForMultiLable
    else:
        model_cls = BertForMultiLable_Fewshot
    model = model_cls.from_pretrained(config['checkpoint_dir'],
                                      num_labels=len(label_list))

    # ----------- predicting
    logger.info('model predicting....')
    predictor = Predictor(model=model, logger=logger, n_gpu=args.n_gpu)
    # Original author's note: `all_logits` might be a better name for this.
    result = predictor.predict(data=test_dataloader)

    # TODO: F1-score computation; this part still needs to be exercised by tests
    f1_metric = F1Score(task_type='binary',
                        average='micro',
                        search_thresh=True)
    all_logits = torch.tensor(result, dtype=torch.float)
    all_labels = torch.tensor(targets, dtype=torch.long)
    f1_metric(all_logits, all_labels)  # prints the result itself (per original note)
    f1_value = f1_metric.value()
    print(f1_value)

    # Make sure the output locations exist before writing: neither open() nor
    # DataFrame.to_csv creates missing parent directories.
    output_root = Path('test_output')
    task_output_dir = output_root / args.task_type
    task_output_dir.mkdir(parents=True, exist_ok=True)

    # Append the score to a running log so successive runs accumulate.
    with open(output_root / 'test.log', 'a', encoding='utf-8') as f:
        f.write(str(f1_value) + "\n")
    thresh = f1_metric.thresh

    # One row per sample: the id followed by one score column per label.
    df_ids = pd.DataFrame(np.array(ids), index=None)
    df_scores = pd.DataFrame(result, index=None)
    all_df = pd.concat([df_ids, df_scores], axis=1)
    if args.task_type == 'base':
        all_df.columns = ['id', 'zy', 'gfgqzr', 'qs', 'tz', 'ggjc']
    else:
        all_df.columns = ['id', 'sg', 'pj', 'zb', 'qsht', 'db']

    # Binarise every score column with the threshold found by the metric.
    # NOTE(review): the threshold is applied to `result` as returned by the
    # predictor - confirm it is on the same scale F1Score searched over.
    score_columns = list(all_df.columns[1:])
    all_df[score_columns] = (all_df[score_columns] > thresh).astype(int)
    all_df.to_csv(task_output_dir / 'cls_out.csv', index=False)