def main():
    parser = build_argparse()
    parser.add_argument('--look_k', type=int, default=5)
    parser.add_argument('--look_alpha', type=float, default=0.5)
    args = parser.parse_args()
    # output dir
    if args.model_name is None:
        args.model_name = args.model_path.split("/")[-1]
    args.output_dir = args.output_dir + '{}'.format(args.model_name)
    os.makedirs(args.output_dir, exist_ok=True)
    prefix = "_".join([args.model_name, args.task_name])
    logger = TrainLogger(log_dir=args.output_dir, prefix=prefix)
    # device
    logger.info("initializing device")
    args.device, args.n_gpu = prepare_device(args.gpu, args.local_rank)
    seed_everything(args.seed)
    args.model_type = args.model_type.lower()
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    # data processor
    logger.info("initializing data processor")
    tokenizer = tokenizer_class.from_pretrained(
        args.model_path, do_lower_case=args.do_lower_case)
    processor = ColaProcessor(data_dir=args.data_dir,
                              tokenizer=tokenizer,
                              prefix=prefix)
    label_list = processor.get_labels()
    num_labels = len(label_list)
    args.num_labels = num_labels
    # model
    logger.info("initializing model and config")
    config = config_class.from_pretrained(
        args.model_path,
        num_labels=num_labels,
        cache_dir=args.cache_dir if args.cache_dir else None)
    model = model_class.from_pretrained(args.model_path, config=config)
    model.to(args.device)
    # trainer
    logger.info("initializing traniner")
    trainer = LookaheadTrainer(logger=logger,
                               args=args,
                               collate_fn=processor.collate_fn,
                               batch_input_keys=processor.get_batch_keys(),
                               metrics=[MattewsCorrcoef()])
    # do train
    if args.do_train:
        train_dataset = processor.create_dataset(args.train_max_seq_length,
                                                 'train.tsv', 'train')
        eval_dataset = processor.create_dataset(args.eval_max_seq_length,
                                                'dev.tsv', 'dev')
        trainer.train(model,
                      train_dataset=train_dataset,
                      eval_dataset=eval_dataset)
    if args.do_eval and args.local_rank in [-1, 0]:
        results = {}
        eval_dataset = processor.create_dataset(args.eval_max_seq_length,
                                                'dev.tsv', 'dev')
        checkpoints = [args.output_dir]
        if args.eval_all_checkpoints or args.checkpoint_number > 0:
            checkpoints = get_checkpoints(args.output_dir,
                                          args.checkpoint_number, WEIGHTS_NAME)
        logger.info("Evaluate the following checkpoints: %s", checkpoints)
        for checkpoint in checkpoints:
            global_step = checkpoint.split("/")[-1].split("-")[-1]
            model = model_class.from_pretrained(checkpoint, config=config)
            model.to(args.device)
            trainer.evaluate(model,
                             eval_dataset,
                             save_preds=True,
                             prefix=str(global_step))
            if global_step:
                result = {
                    "{}_{}".format(global_step, k): v
                    for k, v in trainer.records['result'].items()
                }
                results.update(result)
        output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
        dict_to_text(output_eval_file, results)
    if args.do_predict:
        test_dataset = processor.create_dataset(args.eval_max_seq_length,
                                                'test.tsv', 'test')
        if args.checkpoint_number == 0:
            raise ValueError("checkpoint number should > 0,but get %d",
                             args.checkpoint_number)
        checkpoints = get_checkpoints(args.output_dir, args.checkpoint_number,
                                      WEIGHTS_NAME)
        for checkpoint in checkpoints:
            global_step = checkpoint.split("/")[-1].split("-")[-1]
            model = model_class.from_pretrained(checkpoint)
            model.to(args.device)
            trainer.predict(model,
                            test_dataset=test_dataset,
                            prefix=str(global_step))
Example #2
def main():
    parser = build_argparse()
    parser.add_argument('--markup', type=str, default='bios', choices=['bios', 'bio'])
    # note: with action='store_true' and default=True, this flag is always True
    parser.add_argument('--use_crf', action='store_true', default=True)
    parser.add_argument('--crf_learning_rate', default=1e-3, type=float)
    args = parser.parse_args()

    # output dir
    if args.model_name is None:
        args.model_name = args.model_path.split("/")[-1]
    args.output_dir = args.output_dir + '{}'.format(args.model_name)
    os.makedirs(args.output_dir, exist_ok=True)
    prefix = "_".join([args.model_name, args.task_name])
    logger = TrainLogger(log_dir=args.output_dir, prefix=prefix)

    # device
    logger.info("initializing device")
    args.device, args.n_gpu = prepare_device(args.gpu, args.local_rank)
    seed_everything(args.seed)
    args.model_type = args.model_type.lower()
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]

    # data processor
    logger.info("initializing data processor")
    tokenizer = tokenizer_class.from_pretrained(args.model_path, do_lower_case=args.do_lower_case)
    processor = CnerProcessor(data_dir=args.data_dir, tokenizer=tokenizer, prefix=prefix)
    label_list = processor.get_labels()
    num_labels = len(label_list)
    id2label = {i: label for i, label in enumerate(label_list)}
    args.id2label = id2label
    args.num_labels = num_labels

    # model
    logger.info("initializing model and config")
    config = config_class.from_pretrained(args.model_path, num_labels=num_labels,
                                          cache_dir=args.cache_dir if args.cache_dir else None)
    model = model_class.from_pretrained(args.model_path, config=config)
    model.to(args.device)
    # trainer
    logger.info("initializing traniner")
    trainer = LayerLRTrainer(logger=logger, args=args, collate_fn=processor.collate_fn,
                             input_keys=processor.get_input_keys(),
                             metrics=[SequenceLabelingScore(id2label, markup=args.markup)])
    if args.do_train:
        train_dataset = processor.create_dataset(args.train_max_seq_length, 'train.char.bmes', 'train')
        eval_dataset = processor.create_dataset(args.eval_max_seq_length, 'dev.char.bmes', 'dev')
        trainer.train(model, train_dataset=train_dataset, eval_dataset=eval_dataset)
    # do eval
    if args.do_eval and args.local_rank in [-1, 0]:
        results = {}
        eval_dataset = processor.create_dataset(args.eval_max_seq_length, 'dev.char.bmes', 'dev')
        checkpoints = [args.output_dir]
        if args.eval_all_checkpoints or args.checkpoint_number > 0:
            checkpoints = get_checkpoints(args.output_dir, args.checkpoint_number, WEIGHTS_NAME)
        logger.info("Evaluate the following checkpoints: %s", checkpoints)
        for checkpoint in checkpoints:
            global_step = checkpoint.split("/")[-1].split("-")[-1]
            model = model_class.from_pretrained(checkpoint, config=config)
            model.to(args.device)
            trainer.evaluate(model, eval_dataset, save_preds=True, prefix=str(global_step))
            if global_step:
                result = {"{}_{}".format(global_step, k): v for k, v in trainer.records['result'].items()}
                results.update(result)
        output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
        dict_to_text(output_eval_file, results)
    # do predict
    if args.do_predict:
        test_dataset = processor.create_dataset(args.eval_max_seq_length, 'test.char.bmes', 'test')
        if args.checkpoint_number == 0:
            raise ValueError("checkpoint number should > 0,but get %d", args.checkpoint_number)
        checkpoints = get_checkpoints(args.output_dir, args.checkpoint_number, WEIGHTS_NAME)
        for checkpoint in checkpoints:
            global_step = checkpoint.split("/")[-1].split("-")[-1]
            model = model_class.from_pretrained(checkpoint)
            model.to(args.device)
            trainer.predict(model, test_dataset=test_dataset, prefix=str(global_step))
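Example #2 above trains a character-level NER model whose metric (SequenceLabelingScore) is parameterised by the --markup scheme. As a purely illustrative sketch of what BIO-style labels look like for such data (the characters and label names below are made up, not taken from the CNER files; the 'bios' variant typically adds an 'S-' tag for single-character entities):

# Hypothetical BIO-tagged sample for character-level Chinese NER.
# "B-"/"I-" mark the beginning/inside of an entity span, "O" marks
# non-entity characters.
tokens = ["张", "三", "在", "北", "京"]
labels = ["B-NAME", "I-NAME", "O", "B-LOC", "I-LOC"]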
Example #3
def main():
    args = build_argparse().parse_args()
    if args.model_name is None:
        args.model_name = args.model_path.split("/")[-1]
    args.output_dir = args.output_dir + '{}'.format(args.model_name)
    os.makedirs(args.output_dir, exist_ok=True)

    # output dir
    prefix = "_".join([args.model_name, args.task_name])
    logger = TrainLogger(log_dir=args.output_dir, prefix=prefix)
    # device
    logger.info("initializing device")
    args.device, args.n_gpu = prepare_device(args.gpu, args.local_rank)
    seed_everything(args.seed)
    args.model_type = args.model_type.lower()
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]

    # data processor
    logger.info("initializing data processor")
    tokenizer = tokenizer_class.from_pretrained(
        args.model_path, do_lower_case=args.do_lower_case)
    processor = WSCProcessor(data_dir=args.data_dir,
                             tokenizer=tokenizer,
                             prefix=prefix)
    label_list = processor.get_labels()
    num_labels = len(label_list)
    id2label = {i: label for i, label in enumerate(label_list)}
    args.num_labels = num_labels
    # model
    logger.info("initializing model and config")
    config = config_class.from_pretrained(
        args.model_path,
        num_labels=num_labels,
        cache_dir=args.cache_dir if args.cache_dir else None)
    model = model_class.from_pretrained(args.model_path, config=config)
    model.to(args.device)

    # trainer
    logger.info("initializing traniner")
    trainer = TextClassifierTrainer(
        logger=logger,
        args=args,
        collate_fn=processor.collate_fn,
        batch_input_keys=processor.get_batch_keys(),
        metrics=[Accuracy()])

    # do train
    if args.do_train:
        train_dataset = processor.create_dataset(args.train_max_seq_length,
                                                 'train.json', 'train')
        eval_dataset = processor.create_dataset(args.eval_max_seq_length,
                                                'dev.json', 'dev')
        trainer.train(model,
                      train_dataset=train_dataset,
                      eval_dataset=eval_dataset)

    # do eval
    if args.do_eval and args.local_rank in [-1, 0]:
        results = {}
        eval_dataset = processor.create_dataset(args.eval_max_seq_length,
                                                'dev.json', 'dev')
        checkpoints = [args.output_dir]
        if args.eval_all_checkpoints or args.checkpoint_number > 0:
            checkpoints = get_checkpoints(args.output_dir,
                                          args.checkpoint_number, WEIGHTS_NAME)
        logger.info("Evaluate the following checkpoints: %s", checkpoints)
        for checkpoint in checkpoints:
            global_step = checkpoint.split("/")[-1].split("-")[-1]
            model = model_class.from_pretrained(checkpoint, config=config)
            model.to(args.device)
            trainer.evaluate(model,
                             eval_dataset,
                             save_preds=True,
                             prefix=str(global_step))
            if global_step:
                result = {
                    "{}_{}".format(global_step, k): v
                    for k, v in trainer.records['result'].items()
                }
                results.update(result)
        output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
        dict_to_text(output_eval_file, results)

    # do predict
    if args.do_predict:
        test_dataset = processor.create_dataset(args.eval_max_seq_length,
                                                'test.json', 'test')
        checkpoints = get_checkpoints(args.output_dir, args.checkpoint_number,
                                      WEIGHTS_NAME)
        logger.info("Evaluate the following checkpoints: %s", checkpoints)
        for checkpoint in checkpoints:
            global_step = checkpoint.split("/")[-1].split("-")[-1]
            model = model_class.from_pretrained(checkpoint)
            model.to(args.device)
            trainer.predict(model,
                            test_dataset=test_dataset,
                            prefix=str(global_step))
            predict_label = trainer.records['preds'].argmax(dim=1).numpy()
            output_submit_file = os.path.join(
                args.output_dir, f"{args.task_name}_predict.json")
            # save the predicted labels to the submission file
            with open(output_submit_file, "w") as writer:
                for i, pred in enumerate(predict_label):
                    json_d = {}
                    json_d['id'] = i
                    json_d['label'] = str(id2label[pred])
                    writer.write(json.dumps(json_d) + '\n')
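A minimal sketch of the argmax step used just above to turn trainer.records['preds'] into label ids (the logits here are made up; the real tensor comes from the trainer):

import torch

logits = torch.tensor([[0.2, 1.3],
                       [2.1, -0.4]])          # hypothetical model outputs
label_ids = logits.argmax(dim=1).numpy()      # -> array([1, 0])
# each id is then mapped through id2label and written as one JSON line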
Example #4
def main():
    # args = build_argparse().parse_args()
    parser = build_argparse()
    parser.add_argument('--adv_lr', type=float, default=1e-2)
    parser.add_argument('--adv_K', type=int, default=3, help="should be at least 1")
    parser.add_argument('--adv_init_mag', type=float, default=2e-2)
    parser.add_argument('--adv_norm_type', type=str, default="l2", choices=["l2", "linf"])
    parser.add_argument('--adv_max_norm', type=float, default=0, help="set to 0 to be unlimited")
    parser.add_argument('--base_model', default='bert')
    parser.add_argument('--hidden_dropout_prob', type=float, default=0.1)
    parser.add_argument('--attention_probs_dropout_prob', type=float, default=0)
    args = parser.parse_args()
    if args.model_path is None:
        args.model_path = args.model_name
    if args.model_name is None:
        args.model_name = args.model_path.split("/")[-1]
    args.output_dir = args.output_dir + '{}'.format(args.model_name)
    os.makedirs(args.output_dir, exist_ok=True)

    # output dir
    prefix = "_".join([args.model_name, args.task_name])
    logger = TrainLogger(log_dir=args.output_dir, prefix=prefix)

    # device
    logger.info("initializing device")
    args.device, args.n_gpu = prepare_device(args.gpu, args.local_rank)
    seed_everything(args.seed)
    args.model_type = args.model_type.lower()
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]

    # data processor
    logger.info("initializing data processor")
    tokenizer = tokenizer_class.from_pretrained(args.model_path, do_lower_case=args.do_lower_case)
    # processor = ChnSentiProcessor(data_dir=args.data_dir, tokenizer=tokenizer, prefix=prefix)
    processor = CommonDataProcessor(data_dir=args.data_dir, tokenizer=tokenizer, prefix=prefix)
    label_list = processor.get_labels()
    num_labels = len(label_list)
    args.num_labels = num_labels

    # model
    logger.info("initializing model and config")
    config = config_class.from_pretrained(args.model_path, num_labels=num_labels,
                                          cache_dir=args.cache_dir if args.cache_dir else None)
    model = model_class.from_pretrained(args.model_path, config=config)
    model.to(args.device)

    # trainer
    logger.info("initializing traniner")
    trainer = FreelbTrainer(logger=logger, args=args, collate_fn=processor.collate_fn,
                            input_keys=processor.get_input_keys(),
                            metrics=[Accuracy()])
    # do train
    if args.do_train:
        train_dataset = processor.create_dataset(args.train_max_seq_length, 'train.csv', 'train')
        eval_dataset = processor.create_dataset(args.eval_max_seq_length, 'dev.csv', 'dev')
        trainer.train(model, train_dataset=train_dataset, eval_dataset=eval_dataset)
    # do eval
    checkpoint_numbers = list()
    loss_list = list()
    if args.do_eval and args.local_rank in [-1, 0]:
        results = {}
        eval_dataset = processor.create_dataset(args.eval_max_seq_length, 'dev.csv', 'dev')
        checkpoints = [args.output_dir]
        if args.eval_all_checkpoints or args.checkpoint_number > 0:
            checkpoints = get_checkpoints(args.output_dir, args.checkpoint_number, WEIGHTS_NAME)
        logger.info("Evaluate the following checkpoints: %s", checkpoints)
        for checkpoint in checkpoints:
            global_step = checkpoint.split("/")[-1].split("-")[-1]
            model = model_class.from_pretrained(checkpoint, config=config)
            model.to(args.device)
            trainer.evaluate(model, eval_dataset, save_preds=True, prefix=str(global_step))
            if global_step:
                result = {"{}_{}".format(global_step, k): v for k, v in trainer.records['result'].items()}
                results.update(result)
            # record each checkpoint's eval loss so the best three can be selected below
            loss_list.append(trainer.records['result']['eval_loss'])

        output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
        dict_to_text(output_eval_file, results)

        if len(loss_list) > 3:
            # indices of the three checkpoints with the lowest eval loss
            best_losses = sorted(loss_list)[:3]
            for i, loss in enumerate(loss_list):
                if loss in best_losses:
                    checkpoint_numbers.append(i)
        else:
            checkpoint_numbers = list(range(len(loss_list)))
    # do predict
    if args.do_predict:
        test_dataset = processor.create_dataset(args.eval_max_seq_length, 'test.csv', 'test')
        if args.checkpoint_number != 0:
            checkpoints = get_checkpoints(args.output_dir, args.checkpoint_number, WEIGHTS_NAME)
        else:
            checkpoints = list()
            for i in checkpoint_numbers:
                checkpoints.extend(get_checkpoints(args.output_dir, i, WEIGHTS_NAME))

        for checkpoint in checkpoints:
            global_step = checkpoint.split("/")[-1].split("-")[-1]
            model = model_class.from_pretrained(checkpoint)
            model.to(args.device)
            trainer.predict(model, test_dataset=test_dataset, prefix=str(global_step))
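The eval branch of Example #4 keeps the indices of the three checkpoints with the lowest eval loss before predicting; an equivalent, self-contained way to get those indices (a minimal sketch, assuming loss_list[i] is the eval loss of the i-th evaluated checkpoint):

def best_checkpoint_indices(loss_list, k=3):
    # indices of the k smallest eval losses, ordered from best to worst
    order = sorted(range(len(loss_list)), key=lambda i: loss_list[i])
    return order[:k]

# e.g. best_checkpoint_indices([0.42, 0.31, 0.58, 0.29]) -> [3, 1, 0]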
Example #5
def main():
    parser = build_argparse()
    # bert for theseus
    parser.add_argument('--replacing_rate', default=0.3, required=True, type=float,
                        help="Constant replacing rate. Also base replacing rate if using a scheduler.")
    parser.add_argument("--scheduler_type", default='none', choices=['none', 'linear'], help="Scheduler function.")
    parser.add_argument("--scheduler_linear_k", default=0, type=float, help="Linear k for replacement scheduler.")
    parser.add_argument("--steps_for_replacing", default=0, type=int,
                        help="Steps before entering successor fine_tuning (only useful for constant replacing)")
    parser.add_argument('--predecessor_model_path', type=str, required=True)
    args = parser.parse_args()
    # output dir
    if args.model_name is None:
        args.model_name = args.model_path.split("/")[-1]
    args.output_dir = args.output_dir + '{}'.format(args.model_name)
    os.makedirs(args.output_dir, exist_ok=True)
    prefix = "_".join([args.model_name, args.task_name])
    logger = TrainLogger(log_dir=args.output_dir, prefix=prefix)
    # device
    logger.info("initializing device")
    args.device, args.n_gpu = prepare_device(args.gpu, args.local_rank)
    seed_everything(args.seed)
    args.model_type = args.model_type.lower()
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    # data processor
    logger.info("initializing data processor")
    tokenizer = tokenizer_class.from_pretrained(args.model_path, do_lower_case=args.do_lower_case)
    processor = Sst2Processor(data_dir=args.data_dir, tokenizer=tokenizer, prefix=prefix)
    label_list = processor.get_labels()
    num_labels = len(label_list)
    args.num_labels = num_labels
    # model
    logger.info("initializing model and config")
    config = config_class.from_pretrained(args.model_path, num_labels=num_labels,
                                          cache_dir=args.cache_dir if args.cache_dir else None)
    config.output_hidden_states = True
    model = model_class.from_pretrained(args.predecessor_model_path, config=config)
    scc_n_layer = model.bert.encoder.scc_n_layer
    model.bert.encoder.scc_layer = nn.ModuleList([deepcopy(model.bert.encoder.layer[ix]) for ix in range(scc_n_layer)])
    model.to(args.device)
    # trainer
    logger.info("initializing traniner")
    # Replace rate scheduler
    if args.scheduler_type == 'linear':
        replacing_rate_scheduler = LinearReplacementScheduler(bert_encoder=model.bert.encoder,
                                                              base_replacing_rate=args.replacing_rate,
                                                              k=args.scheduler_linear_k)
    elif args.scheduler_type == 'none':
        replacing_rate_scheduler = ConstantReplacementScheduler(bert_encoder=model.bert.encoder,
                                                                replacing_rate=args.replacing_rate,
                                                                replacing_steps=args.steps_for_replacing)
    trainer = TheseusTrainer(logger=logger, args=args,
                             batch_input_keys=processor.get_batch_keys(),
                             replacing_rate_scheduler=replacing_rate_scheduler,
                             collate_fn=processor.collate_fn,
                             metrics=[Accuracy()])
    # do train
    if args.do_train:
        train_dataset = processor.create_dataset(args.train_max_seq_length, 'train.tsv', 'train')
        eval_dataset = processor.create_dataset(args.eval_max_seq_length, 'dev.tsv', 'dev')
        trainer.train(model, train_dataset=train_dataset, eval_dataset=eval_dataset)
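Example #5's linear scheduler follows the BERT-of-Theseus replacement curriculum, where the module-replacing probability grows linearly with the training step; a sketch of the usual formula (assuming this implementation matches the paper, with base rate b and slope k):

def linear_replacing_rate(step, base_rate, k):
    # p_t = min(1.0, k * t + b): start near base_rate and ramp up until
    # predecessor layers are always replaced by successor layers
    return min(1.0, k * step + base_rate)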
Example #6
def main():
    args = build_argparse().parse_args()
    if args.model_name is None:
        args.model_name = args.model_path.split("/")[-1]
    args.output_dir = args.output_dir + "{}".format(args.model_name)
    os.makedirs(args.output_dir, exist_ok=True)

    # output dir
    prefix = "_".join([args.model_name, args.task_name])
    logger = TrainLogger(log_dir=args.output_dir, prefix=prefix)

    # device
    logger.info("initializing device")
    args.device, args.n_gpu = prepare_device(args.gpu, args.local_rank)
    seed_everything(args.seed)
    args.model_type = args.model_type.lower()
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]

    # data processor
    logger.info("initializing data processor")
    tokenizer = tokenizer_class.from_pretrained(
        args.model_path, do_lower_case=args.do_lower_case
    )
    processor = ChnSentiProcessor(
        data_dir=args.data_dir, tokenizer=tokenizer, prefix=prefix
    )
    label_list = processor.get_labels()
    num_labels = len(label_list)
    args.num_labels = num_labels

    # model
    logger.info("initializing model and config")
    config = config_class.from_pretrained(
        args.model_path,
        num_labels=num_labels,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )
    model = model_class.from_pretrained(args.model_path, config=config)
    model.to(args.device)

    # trainer
    logger.info("initializing traniner")
    trainer = TextClassifierTrainer(
        logger=logger,
        args=args,
        collate_fn=processor.collate_fn,
        input_keys=processor.get_input_keys(),
        metrics=[Accuracy()],
    )
    # do train
    if args.do_train:
        train_dataset = processor.create_dataset(
            args.train_max_seq_length, "train.tsv", "train"
        )
        eval_dataset = processor.create_dataset(
            args.eval_max_seq_length, "dev.tsv", "dev"
        )
        trainer.train(model, train_dataset=train_dataset, eval_dataset=eval_dataset)
    # do eval
    if args.do_eval and args.local_rank in [-1, 0]:
        results = {}
        eval_dataset = processor.create_dataset(
            args.eval_max_seq_length, "test.tsv", "test"
        )
        checkpoints = [args.output_dir]
        if args.eval_all_checkpoints or args.checkpoint_number > 0:
            checkpoints = get_checkpoints(
                args.output_dir, args.checkpoint_number, WEIGHTS_NAME
            )
        logger.info("Evaluate the following checkpoints: %s", checkpoints)
        for checkpoint in checkpoints:
            global_step = checkpoint.split("/")[-1].split("-")[-1]
            model = model_class.from_pretrained(checkpoint, config=config)
            model.to(args.device)
            trainer.evaluate(
                model, eval_dataset, save_preds=True, prefix=str(global_step)
            )
            if global_step:
                result = {
                    "{}_{}".format(global_step, k): v
                    for k, v in trainer.records["result"].items()
                }
                results.update(result)
        output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
        dict_to_text(output_eval_file, results)
Example #7
def main():
    args = build_argparse().parse_args()
    if args.model_name is None:
        args.model_name = args.model_path.split("/")[-1]

    # output dir
    args.output_dir = args.output_dir + '{}'.format(args.model_name)
    os.makedirs(args.output_dir, exist_ok=True)
    prefix = "_".join([args.model_name, args.task_name])
    logger = TrainLogger(log_dir=args.output_dir, prefix=prefix)

    # device
    logger.info("initializing device")
    args.device, args.n_gpu = prepare_device(args.gpu, args.local_rank)
    seed_everything(args.seed)
    args.model_type = args.model_type.lower()
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]

    # data processor
    logger.info("initializing data processor")
    tokenizer = tokenizer_class.from_pretrained(
        args.model_path, do_lower_case=args.do_lower_case)
    processor = ToxicProcessor(data_dir=args.data_dir,
                               tokenizer=tokenizer,
                               prefix=prefix)
    label_list = processor.get_labels()
    id2label = {i: key for i, key in enumerate(label_list)}
    num_labels = len(label_list)
    args.num_labels = num_labels

    # model
    logger.info("initializing model and config")
    config = config_class.from_pretrained(
        args.model_path,
        num_labels=num_labels,
        cache_dir=args.cache_dir if args.cache_dir else None)
    model = model_class.from_pretrained(args.model_path, config=config)
    model.to(args.device)
    # Trainer
    logger.info("initializing traniner")
    trainer = TextClassifierTrainer(logger=logger,
                                    args=args,
                                    collate_fn=processor.collate_fn,
                                    input_keys=processor.get_input_keys(),
                                    metrics=[
                                        AUC(average='micro',
                                            task_type='binary'),
                                        MultiLabelReport(id2label)
                                    ])
    # do train
    if args.do_train:
        train_dataset = processor.create_dataset(args.train_max_seq_length,
                                                 'train.csv', 'train')
        eval_dataset = processor.create_dataset(args.eval_max_seq_length,
                                                'dev.csv', 'dev')
        trainer.train(model,
                      train_dataset=train_dataset,
                      eval_dataset=eval_dataset)
    # do eval
    if args.do_eval and args.local_rank in [-1, 0]:
        results = {}
        eval_dataset = processor.create_dataset(args.eval_max_seq_length,
                                                'dev.csv', 'dev')
        checkpoints = [args.output_dir]
        if args.eval_all_checkpoints or args.checkpoint_number > 0:
            checkpoints = get_checkpoints(args.output_dir,
                                          args.checkpoint_number, WEIGHTS_NAME)
        logger.info("Evaluate the following checkpoints: %s", checkpoints)
        for checkpoint in checkpoints:
            global_step = checkpoint.split("/")[-1].split("-")[-1]
            model = model_class.from_pretrained(checkpoint, config=config)
            model.to(args.device)
            trainer.evaluate(model,
                             eval_dataset,
                             save_preds=True,
                             prefix=str(global_step))
            if global_step:
                result = {
                    "{}_{}".format(global_step, k): v
                    for k, v in trainer.records['result'].items()
                }
                results.update(result)
        output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
        dict_to_text(output_eval_file, results)
    # do predict
    if args.do_predict:
        test_dataset = processor.create_dataset(args.eval_max_seq_length,
                                                'test.csv', 'test')
        if args.checkpoint_number == 0:
            raise ValueError("checkpoint number should > 0,but get %d",
                             args.checkpoint_number)
        checkpoints = get_checkpoints(args.output_dir, args.checkpoint_number,
                                      WEIGHTS_NAME)
        for checkpoint in checkpoints:
            global_step = checkpoint.split("/")[-1].split("-")[-1]
            model = model_class.from_pretrained(checkpoint)
            model.to(args.device)
            trainer.predict(model,
                            test_dataset=test_dataset,
                            prefix=str(global_step))
Example #8
def main():
    parser = build_argparse()
    parser.add_argument('--distance_metric',
                        type=str,
                        default="educlidean",
                        choices=["cosine", 'educlidean', "manhattan"])
    args = parser.parse_args()

    # output dir
    if args.model_name is None:
        args.model_name = args.model_path.split("/")[-1]
    args.output_dir = args.output_dir + '{}'.format(args.model_name)
    os.makedirs(args.output_dir, exist_ok=True)
    prefix = "_".join([args.model_name, args.task_name])
    logger = TrainLogger(log_dir=args.output_dir, prefix=prefix)

    # device
    logger.info("initializing device")
    args.device, args.n_gpu = prepare_device(args.gpu, args.local_rank)
    seed_everything(args.seed)
    args.model_type = args.model_type.lower()
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]

    # data processor
    logger.info("initializing data processor")
    tokenizer = tokenizer_class.from_pretrained(
        args.model_path, do_lower_case=args.do_lower_case)
    processor = EpidemicProcessor(data_dir=args.data_dir,
                                  tokenizer=tokenizer,
                                  prefix=prefix,
                                  encode_mode='triple')

    # model
    logger.info("initializing model and config")
    config = config_class.from_pretrained(
        args.model_path, cache_dir=args.cache_dir if args.cache_dir else None)
    config.distance_metric = args.distance_metric
    model = model_class.from_pretrained(args.model_path, config=config)
    model.to(args.device)

    # trainer
    logger.info("initializing traniner")
    trainer = TripleTrainer(logger=logger,
                            args=args,
                            metrics=[Accuracy()],
                            input_keys=processor.get_input_keys(),
                            collate_fn=processor.collate_fn)
    # do train
    if args.do_train:
        train_dataset = processor.create_dataset(args.train_max_seq_length,
                                                 'train.json', 'train')
        eval_dataset = processor.create_dataset(args.eval_max_seq_length,
                                                'dev.json', 'dev')
        trainer.train(model,
                      train_dataset=train_dataset,
                      eval_dataset=eval_dataset)
    if args.do_eval and args.local_rank in [-1, 0]:
        results = {}
        eval_dataset = processor.create_dataset(args.eval_max_seq_length,
                                                'dev.json', 'dev')
        checkpoints = [args.output_dir]
        if args.eval_all_checkpoints or args.checkpoint_number > 0:
            checkpoints = get_checkpoints(args.output_dir,
                                          args.checkpoint_number, WEIGHTS_NAME)
        logger.info("Evaluate the following checkpoints: %s", checkpoints)
        for checkpoint in checkpoints:
            global_step = checkpoint.split("/")[-1].split("-")[-1]
            model = model_class.from_pretrained(checkpoint, config=config)
            model.to(args.device)
            trainer.evaluate(model,
                             eval_dataset,
                             save_preds=True,
                             prefix=str(global_step))
            if global_step:
                result = {
                    "{}_{}".format(global_step, k): v
                    for k, v in trainer.records['result'].items()
                }
                results.update(result)
        output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
        dict_to_text(output_eval_file, results)
    if args.do_predict:
        test_dataset = processor.create_dataset(args.eval_max_seq_length,
                                                'test.json', 'test')
        if args.checkpoint_number == 0:
            raise ValueError("checkpoint number should > 0,but get %d",
                             args.checkpoint_number)
        checkpoints = get_checkpoints(args.output_dir, args.checkpoint_number,
                                      WEIGHTS_NAME)
        for checkpoint in checkpoints:
            global_step = checkpoint.split("/")[-1].split("-")[-1]
            model = model_class.from_pretrained(checkpoint)
            model.to(args.device)
            trainer.predict(model,
                            test_dataset=test_dataset,
                            prefix=str(global_step))
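The examples above are shown without their import blocks. A minimal set covering the standard-library and PyTorch names used here would be roughly the following; the project-specific helpers (build_argparse, TrainLogger, MODEL_CLASSES, WEIGHTS_NAME, the processors, trainers, and metrics) are assumed to come from the surrounding repository, so their import paths are not reproduced:

import os
import json
from copy import deepcopy

import torch.nn as nn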