Пример #1
0
def main(cli_args):
    """Train and/or evaluate a classifier driven by a JSON config plus CLI options.

    The JSON file named by ``cli_args.config_dir`` / ``cli_args.config_file``
    supplies the base settings; selected CLI values are copied on top of it.
    When ``args.do_train`` is set the model is trained; when ``args.do_eval``
    is set the saved checkpoints under ``args.output_dir`` are evaluated on
    the test set and the metrics are written to ``eval_results.txt``.

    Args:
        cli_args: Parsed command-line namespace. This function reads
            ``config_dir``, ``config_file``, ``result_dir``, ``model_mode``,
            ``margin``, ``transformer_mode``, ``dataset`` and ``gpu``.

    Raises:
        ValueError: If ``cli_args.transformer_mode`` does not name a supported
            backbone, or if the config does not name a training file (the
            training set supplies the label count and the logging interval).
    """
    # Configure logging before the first log call so the parameter dumps
    # below go through the configured handlers.
    # NOTE(review): assumes init_logger() only sets up logging -- confirm.
    init_logger()

    # Read from config file and make args
    with open(os.path.join(cli_args.config_dir, cli_args.config_file)) as f:
        args = AttrDict(json.load(f))
    logger.info("Training/evaluation parameters {}".format(args))
    logger.info("cliargs parameters {}".format(cli_args))

    args.output_dir = os.path.join(args.ckpt_dir, cli_args.result_dir)
    args.model_mode = cli_args.model_mode
    args.margin = cli_args.margin

    set_seed(args)

    # Map the (case-insensitive) CLI backbone name to its pretrained model id.
    pretrained_names = {
        "T5": "t5-base",
        "ELECTRA": "google/electra-base-discriminator",
        "ALBERT": "albert-base-v2",
        "ROBERTA": "roberta-base",
        "BERT": "bert-base-uncased",
    }
    model_link = pretrained_names.get(cli_args.transformer_mode.upper())
    if model_link is None:
        # Fail early with a clear message instead of passing None downstream.
        raise ValueError(
            "Unsupported transformer_mode {!r}; expected one of: {}".format(
                cli_args.transformer_mode,
                ", ".join(sorted(pretrained_names))))

    logger.info("Pretrained backbone: %s", model_link)
    tokenizer = AutoTokenizer.from_pretrained(model_link)

    # Prefix only the splits that are actually configured. Joining an empty
    # name would yield a truthy "<dataset>/" path and defeat the checks below.
    if args.test_file:
        args.test_file = os.path.join(cli_args.dataset, args.test_file)
    if args.dev_file:
        args.dev_file = os.path.join(cli_args.dataset, args.dev_file)
    if args.train_file:
        args.train_file = os.path.join(cli_args.dataset, args.train_file)

    # Load dataset
    train_dataset = BaseDataset(args, tokenizer,
                                mode="train") if args.train_file else None
    dev_dataset = BaseDataset(args, tokenizer,
                              mode="dev") if args.dev_file else None
    test_dataset = BaseDataset(args, tokenizer,
                               mode="test") if args.test_file else None

    if train_dataset is None:
        raise ValueError(
            "A training file is required: it determines the number of labels "
            "and the logging/saving interval.")

    if dev_dataset is None:
        args.evaluate_test_during_training = True  # If there is no dev dataset, only use testset

    # Log and save once per pass over the training data.
    args.logging_steps = int(len(train_dataset) / args.train_batch_size) + 1
    args.save_steps = args.logging_steps
    labelNumber = train_dataset.getLabelNumber()

    config = AutoConfig.from_pretrained(model_link)

    # GPU or CPU
    args.device = "cuda:{}".format(
        cli_args.gpu
    ) if torch.cuda.is_available() and not args.no_cuda else "cpu"
    config.device = args.device

    model = MODEL_LIST[cli_args.model_mode](model_link, args.model_type,
                                            args.model_name_or_path, config,
                                            labelNumber, args.margin)
    model.to(args.device)

    if args.do_train:
        global_step, tr_loss = train(args, model, train_dataset, dev_dataset,
                                     test_dataset)
        logger.info(" global_step = {}, average loss = {}".format(
            global_step, tr_loss))

    results = {}
    if args.do_eval:
        # Every directory under output_dir that holds saved weights.
        checkpoints = [
            os.path.dirname(c) for c in sorted(
                glob.glob(args.output_dir + "/**/" + "pytorch_model.bin",
                          recursive=True))
        ]
        if not args.eval_all_checkpoints:
            # Keep only the last checkpoint in sorted path order.
            checkpoints = checkpoints[-1:]
        else:
            logging.getLogger("transformers.configuration_utils").setLevel(
                logging.WARN)  # Reduce logging
            logging.getLogger("transformers.modeling_utils").setLevel(
                logging.WARN)  # Reduce logging
        logger.info("Evaluate the following checkpoints: %s", checkpoints)
        for checkpoint in checkpoints:
            # The step id is the suffix after the last "-" in the directory name.
            global_step = checkpoint.split("-")[-1]
            # NOTE(review): MODEL_LIST is indexed by cli_args.model_mode when
            # the model is first built above but by args.model_type here --
            # confirm both keys exist and that the class supports
            # from_pretrained().
            model = MODEL_LIST[args.model_type].from_pretrained(checkpoint)
            model.to(args.device)
            result = evaluate(args,
                              model,
                              test_dataset,
                              mode="test",
                              global_step=global_step)
            # Suffix each metric name with the step so checkpoints do not collide.
            result = {
                k + "_{}".format(global_step): v for k, v in result.items()
            }
            results.update(result)

        # The directory may not exist yet when evaluation runs without training.
        os.makedirs(args.output_dir, exist_ok=True)
        output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
        with open(output_eval_file, "w") as f_w:
            for key in sorted(results):
                f_w.write("{} = {}\n".format(key, str(results[key])))
Пример #2
0
def main(cli_args):
    """Evaluate the ``checkpoint-best`` weights on the test set and export predictions.

    Settings come from the JSON file named by ``cli_args.config_dir`` /
    ``cli_args.config_file``. The model is rebuilt, its weights are loaded from
    ``ckpt/<result_dir>/checkpoint-best/training_model.bin``, the test set is
    evaluated, and one row per example (text, prediction, label, correctness,
    tokenization) is written to
    ``ckpt/<result_dir>/test_result_checkpoint-best.csv``.

    Args:
        cli_args: Parsed command-line namespace. This function reads
            ``config_dir``, ``config_file``, ``result_dir``, ``model_mode``,
            ``transformer_mode``, ``dataset`` and ``gpu``.

    Raises:
        ValueError: If ``cli_args.transformer_mode`` does not name a supported
            backbone, or if the config does not name a training file (the
            training set supplies the label count).
    """
    max_checkpoint = "checkpoint-best"
    checkpoint_dir = os.path.join("ckpt", cli_args.result_dir, max_checkpoint)

    # Configure logging before the first log call so the parameter dumps
    # below go through the configured handlers.
    # NOTE(review): assumes init_logger() only sets up logging -- confirm.
    init_logger()

    # Read from config file and make args
    with open(os.path.join(cli_args.config_dir, cli_args.config_file)) as f:
        args = AttrDict(json.load(f))
    logger.info("Training/evaluation parameters {}".format(args))
    logger.info("cliargs parameters {}".format(cli_args))

    args.output_dir = os.path.join(args.ckpt_dir, cli_args.result_dir)
    args.model_mode = cli_args.model_mode
    # GPU or CPU
    args.device = "cuda:{}".format(
        cli_args.gpu
    ) if torch.cuda.is_available() and not args.no_cuda else "cpu"

    set_seed(args)

    # Map the (case-insensitive) CLI backbone name to its pretrained model id.
    pretrained_names = {
        "T5": "t5-base",
        "ELECTRA": "google/electra-base-discriminator",
        "ALBERT": "albert-base-v2",
        "ROBERTA": "roberta-base",
        "BERT": "bert-base-uncased",
    }
    model_link = pretrained_names.get(cli_args.transformer_mode.upper())
    if model_link is None:
        # Fail early with a clear message instead of passing None downstream.
        raise ValueError(
            "Unsupported transformer_mode {!r}; expected one of: {}".format(
                cli_args.transformer_mode,
                ", ".join(sorted(pretrained_names))))

    tokenizer = AutoTokenizer.from_pretrained(model_link)

    # Prefix only the splits that are actually configured. Joining an empty
    # name would yield a truthy "<dataset>/" path and defeat the checks below.
    if args.test_file:
        args.test_file = os.path.join(cli_args.dataset, args.test_file)
    if args.dev_file:
        # The dev split is built from its own file name, not the training one.
        args.dev_file = os.path.join(cli_args.dataset, args.dev_file)
    if args.train_file:
        args.train_file = os.path.join(cli_args.dataset, args.train_file)

    # Load dataset
    train_dataset = BaseDataset(args, tokenizer,
                                mode="train") if args.train_file else None
    dev_dataset = BaseDataset(args, tokenizer,
                              mode="dev") if args.dev_file else None
    test_dataset = BaseDataset(args, tokenizer,
                               mode="test") if args.test_file else None

    if train_dataset is None:
        raise ValueError(
            "A training file is required: it determines the number of labels.")

    if dev_dataset is None:
        args.evaluate_test_during_training = True  # If there is no dev dataset, only use testset

    args.logging_steps = int(len(train_dataset) / args.train_batch_size) + 1
    args.save_steps = args.logging_steps
    labelNumber = train_dataset.getLabelNumber()

    config = AutoConfig.from_pretrained(model_link)
    config.device = args.device

    logger.info("Testing model checkpoint to {}".format(max_checkpoint))
    # For "checkpoint-best" this is the string "best".
    global_step = max_checkpoint.split("-")[-1]

    # Fixed margin passed to the model constructor for this test run.
    margin = -0.75
    model = MODEL_LIST[cli_args.model_mode](model_link, args.model_type,
                                            args.model_name_or_path, config,
                                            labelNumber, margin)
    # map_location lets weights saved on a GPU load on the selected device,
    # including CPU-only hosts.
    model.load_state_dict(
        torch.load(os.path.join(checkpoint_dir, "training_model.bin"),
                   map_location=args.device))

    model.to(args.device)

    preds, labels, _, txt_all = evaluate(args,
                                         model,
                                         test_dataset,
                                         mode="test",
                                         global_step=global_step)

    pred_and_labels = pd.DataFrame({
        "data": txt_all,
        "pred": preds,
        "label": labels,
    })
    # Compare column-wise so the flag is per row even when preds/labels are
    # plain Python lists rather than arrays.
    pred_and_labels["result"] = (
        pred_and_labels["pred"] == pred_and_labels["label"])
    # Store the token sequence the tokenizer produces for each input text.
    pred_and_labels["tokenizer"] = [
        tokenizer.convert_ids_to_tokens(tokenizer(text)["input_ids"])
        for text in pred_and_labels["data"]
    ]

    pred_and_labels.to_csv(os.path.join(
        "ckpt", cli_args.result_dir, "test_result_" + max_checkpoint + ".csv"),
                           encoding="utf-8")