Example #1
def main():
    config = QGConfig()
    args = parser.parse_args()

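    # Load the pretrained KoGPT2 backbone, then overwrite its weights with the fine-tuned QG checkpoint.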
    model = GPT2LMHeadModel.from_pretrained("taeminlee/kogpt2")
    model.load_state_dict(torch.load(args.model_path, map_location="cpu"))
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

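    # Rebuild the BPE tokenizer from the training-time vocab/merges files, then sample num_samples dev examples to decode.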
    tokenizer = SentencePieceBPETokenizer.from_file(
        vocab_filename="tokenizer/vocab.json",
        merges_filename="tokenizer/merges.txt",
        add_prefix_space=False)
    examples = load_korquad_dataset(config.dev_dataset)
    random.shuffle(examples)
    examples = examples[:args.num_samples]
    dataset = QGDecodingDataset(examples, tokenizer,
                                config.max_sequence_length)
    dataloader = DataLoader(dataset, batch_size=1)

    model.eval()

    generated_results = []

    for i, batch in tqdm(enumerate(dataloader),
                         desc="generate",
                         total=len(dataloader)):
        input_ids, attention_mask = (v.to(device) for v in batch)
        origin_seq_len = input_ids.size(-1)

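        # Beam-search decoding: the question is generated within MIN/MAX_QUESTION_SPACE extra tokens after the
        # context prompt, with a repetition penalty and 3-gram blocking to curb degenerate repeats.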
        decoded_sequences = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=origin_seq_len + MAX_QUESTION_SPACE,
            min_length=origin_seq_len + MIN_QUESTION_SPACE,
            pad_token_id=0,
            bos_token_id=1,
            eos_token_id=2,
            num_beams=args.num_beams,
            repetition_penalty=1.3,
            no_repeat_ngram_size=3,
            num_return_sequences=1,
        )

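        # Keep only the newly generated tokens, then cut at </s> and drop any <s> markers.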
        for decoded_tokens in decoded_sequences.tolist():
            decoded_question_text = tokenizer.decode(
                decoded_tokens[origin_seq_len:])
            decoded_question_text = decoded_question_text.split(
                "</s>")[0].replace("<s>", "")
            generated_results.append(
                (examples[i].context, examples[i].answer, examples[i].question,
                 decoded_question_text))

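    # Write one tab-separated block per example (문맥=context, 답변=answer, 생성된 질문=generated question, 실제 질문=reference question).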
    with open(args.output_path, "w", encoding="utf-8") as f:
        for context, answer, question, generated_question in generated_results:
            f.write(f"문맥\t{context}\n")
            f.write(f"답변\t{answer}\n")
            f.write(f"생성된 질문\t{generated_question}\n")
            f.write(f"실제 질문\t{question}\n\n")

Example #2
def main():
    config = QGConfig()
    args = parser.parse_args()

    model = GPT2LMHeadModel.from_pretrained("taeminlee/kogpt2")
    model.load_state_dict(torch.load(args.model_path, map_location="cpu"))
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    tokenizer = SentencePieceBPETokenizer.from_file(
        vocab_filename="tokenizer/vocab.json", merges_filename="tokenizer/merges.txt", add_prefix_space=False
    )
    examples = load_korquad_dataset(config.dev_dataset)
    dataset = QGDataset(examples, tokenizer, config.max_sequence_length)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, collate_fn=dynamic_padding_collate_fn)

    model.eval()
    loss_list = []
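    # Accumulate the per-batch LM loss; perplexity is exp(mean loss).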
    for batch_data in tqdm(dataloader, desc="[EVAL]"):
        with torch.no_grad():
            input_ids, attention_mask, labels = tuple(value.to(device) for value in batch_data)
            model_outputs = model(input_ids, attention_mask=attention_mask, labels=labels, return_dict=True)
            loss_list.append(model_outputs.loss.item())

    mean_loss = np.mean(loss_list)
    print(f"loss:{mean_loss:.4f} perplexity:{math.exp(mean_loss):.4f}")
    model.train()  # restore training mode (relevant when this loop runs mid-training)

Example #3
def __init__(self, path, max_tokens):
    self.tokenizer = SentencePieceBPETokenizer.from_file(
        str(path / "vocab.json"),
        str(path / "merges.txt"),
    )
    self.max_tokens = max_tokens
    self.idx = {}
    for s in ['</s>', '<s>', '<pad>']:
        self.idx[s] = self.tokenizer.token_to_id(s)

Example #4
def __init__(self, path, max_tokens):
    self.logger = log.getLogger("Tokenizer")
    self.logger.info("loading tokenizer")
    self.logger.info(f"path: {path}")
    self.logger.info(f"max_tokens: {max_tokens}")
    self.tokenizer = SentencePieceBPETokenizer.from_file(
        str(path / "vocab.json"),
        str(path / "merges.txt"),
    )
    self.max_tokens = max_tokens
    self.idx = {}
    for s in ['</s>', '<s>', '<pad>']:
        self.idx[s] = self.tokenizer.token_to_id(s)
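
A minimal usage sketch for the wrapper above. The enclosing class name (Tokenizer) and the tokenizer directory are assumptions, not shown in the source; encode() is the standard huggingface tokenizers API and returns an Encoding whose .ids is a list of token ids:

from pathlib import Path

tok = Tokenizer(Path("tokenizer"), max_tokens=512)       # hypothetical class name and path
encoding = tok.tokenizer.encode("코로나의 온도는 얼마인가?")
ids = encoding.ids[: tok.max_tokens]                     # respect the configured token budget
sequence = [tok.idx['<s>']] + ids + [tok.idx['</s>']]    # wrap with the cached special-token ids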

Example #5
def main():
    config = QGConfig()
    args = parser.parse_args()

    model = GPT2LMHeadModel.from_pretrained("taeminlee/kogpt2")
    model.load_state_dict(torch.load(args.model_path, map_location="cpu"))
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    tokenizer = SentencePieceBPETokenizer.from_file(
        vocab_filename="tokenizer/vocab.json", merges_filename="tokenizer/merges.txt", add_prefix_space=False
    )
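    # A single hand-written (context, answer) pair; no reference question is given, so QAExample's
    # question field is assumed to default to None (checked before writing below).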
    examples = [
        QAExample(
            "코로나의 온도는 섭씨 수십만~수백만 도로 추정되는데 이는 태양 표면 온도인 5,000~6,000 K의 100배 이상 되는 수치다.[1] 일식 당시 보이는 코로나의 색깔은 백색이나 보라색인데, 섭씨 수십만도 이상은 올라가야 보라색이 된다. 이것은 태양을 점에너지원에서 파생된 단순한 불덩어리로 가정한다면, 즉 봐서는 열역학 제 2법칙에 정면으로 위배되는 것처럼 보이지만 어디 가서 열역학 제 2법칙은 코로나 현상 때문에 위기에 놓여있다는 소리는 하지 말자. 정보가 없을 뿐, 코로나도 열역학 제 2법칙을 정면으로 위배한다는 근거는 어디에도 없다. 코로나의 온도가 태양 표면보다도 높은 기현상을 설명하기 위한 가설로는 크게 2가지 주류설이 있다.",
            "코로나",
        ),
    ]
    dataset = QGDecodingDataset(examples, tokenizer, config.max_sequence_length)
    dataloader = DataLoader(dataset, batch_size=1)

    model.eval()

    generated_results = []

    for i, batch in tqdm(enumerate(dataloader), desc="generate", total=len(dataloader)):
        input_ids, attention_mask = (v.to(device) for v in batch)
        origin_seq_len = input_ids.size(-1)

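        # Sampling on top of beam search (do_sample=True, num_beams=5) returns 3 candidate questions per example.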
        decoded_sequences = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=origin_seq_len + MAX_QUESTION_SPACE,
            min_length=origin_seq_len + MIN_QUESTION_SPACE,
            pad_token_id=0,
            bos_token_id=1,
            eos_token_id=2,
            do_sample=True,
            num_beams=5,
            repetition_penalty=1.3,
            no_repeat_ngram_size=3,
            num_return_sequences=3,
        )

        for decoded_tokens in decoded_sequences.tolist():
            decoded_question_text = tokenizer.decode(decoded_tokens[origin_seq_len:])
            decoded_question_text = decoded_question_text.split("</s>")[0].replace("<s>", "")
            generated_results.append(
                (examples[i].context, examples[i].answer, examples[i].question, decoded_question_text)
            )

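    # Append the results to article_qg.tsv and echo them to stdout, skipping the reference-question line when absent.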
    with open("article_qg.tsv", "a") as f:
        for context, answer, question, generated_question in generated_results:
            f.write(f"문맥\t{context}\n")
            f.write(f"답변\t{answer}\n")
            f.write(f"생성된 질문\t{generated_question}\n")
            if question is not None:
                f.write(f"실제 질문\t{question}\n")
            f.write("\n")

            print(f"문맥\t{context}\n")
            print(f"답변\t{answer}\n")
            print(f"생성된 질문\t{generated_question}\n")
            if question is not None:
                print(f"실제 질문\t{question}")
            print()

Example #6
def main(config: QGConfig):
    logger = _create_logger(output_dir=config.output_dir)
    logger.info("============================")
    for key, value in config._asdict().items():
        logger.info(f"{key:30}:{value}")
    logger.info("============================")
    torch.manual_seed(config.random_seed)

    tokenizer = SentencePieceBPETokenizer.from_file(
        vocab_filename=config.vocab_path, merges_filename=config.tokenizer_merges_path, add_prefix_space=False
    )

    logger.info("loading train dataset")
    train_examples = load_korquad_dataset(config.train_dataset)
    train_dataset = QGDataset(train_examples, tokenizer, config.max_sequence_length)
    train_dataloader = DataLoader(
        train_dataset, config.train_batch_size, shuffle=True, collate_fn=dynamic_padding_collate_fn
    )

    logger.info("loading dev dataset")
    dev_examples = load_korquad_dataset(config.dev_dataset)
    dev_dataset = QGDataset(dev_examples, tokenizer, config.max_sequence_length, is_train=False)
    dev_dataloader = DataLoader(dev_dataset, config.dev_batch_size, collate_fn=dynamic_padding_collate_fn)

    # create the model
    model = GPT2LMHeadModel.from_pretrained(config.gpt_model_hub_name)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

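    # Adam with a linear warmup over warmup_ratio of the total steps, then linear decay to zero.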
    optimizer = Adam(model.parameters(), lr=config.lr)
    total_steps = len(train_dataloader) * config.epochs
    warmup_steps = int(total_steps * config.warmup_ratio)
    scheduler = get_linear_schedule_with_warmup(optimizer, warmup_steps, total_steps)

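    # Teacher-forced LM training; GPT2LMHeadModel shifts the labels internally for next-token prediction.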
    loss_list_between_log_interval = []
    for epoch_id in range(config.epochs):
        for step_index, batch_data in tqdm(
            enumerate(train_dataloader), f"[TRAIN] EP:{epoch_id}", total=len(train_dataloader)
        ):
            global_step = len(train_dataloader) * epoch_id + step_index + 1
            optimizer.zero_grad()

            input_ids, attention_mask, labels = tuple(value.to(device) for value in batch_data)
            model_outputs = model(input_ids, attention_mask=attention_mask, labels=labels, return_dict=True)

            model_outputs.loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
            optimizer.step()
            scheduler.step()

            # for logging
            loss_list_between_log_interval.append(model_outputs.loss.item())

            if global_step % config.train_log_interval == 0:
                mean_loss = np.mean(loss_list_between_log_interval)
                logger.info(
                    f"EP:{epoch_id} global_step:{global_step} "
                    f"loss:{mean_loss:.4f} perplexity:{math.exp(mean_loss):.4f}"
                )
                loss_list_between_log_interval.clear()

            if global_step % config.validation_interval == 0:
                _validate(model, dev_dataloader, device, logger, global_step)

            if global_step % config.save_interval == 0:
                state_dict = model.state_dict()
                model_path = os.path.join(config.output_dir, f"gpt2_step_{global_step}.pth")
                logger.info(f"global_step: {global_step} model saved at {model_path}")
                torch.save(state_dict, model_path)
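
_validate is called in the loop above but not shown. A minimal sketch, assuming it mirrors the evaluation loop of Example #2 (which even ends with model.train(), consistent with being called mid-training):

def _validate(model, dev_dataloader, device, logger, global_step):
    model.eval()
    loss_list = []
    for batch_data in tqdm(dev_dataloader, desc="[EVAL]"):
        with torch.no_grad():
            input_ids, attention_mask, labels = tuple(v.to(device) for v in batch_data)
            outputs = model(input_ids, attention_mask=attention_mask, labels=labels, return_dict=True)
            loss_list.append(outputs.loss.item())
    mean_loss = np.mean(loss_list)
    logger.info(f"[VALID] global_step:{global_step} loss:{mean_loss:.4f} perplexity:{math.exp(mean_loss):.4f}")
    model.train()  # resume training mode before returning to the train loop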