Code example #1
def inference(model, dev_data, save_predictions=False, verbose=False):
    predictions = []
    bos_token_id = dev_data.tokenizer.bos_token_id
    for idx, batch in enumerate(dev_data.dataloader.inference_dataloader()):

        if torch.cuda.is_available():
            batch = [b.to(torch.device("cuda")) for b in batch]

        pad_token_id = dev_data.tokenizer.pad_token_id
        batch[0], batch[1] = trim_batch(batch[0].unsqueeze(0), pad_token_id,
                                        batch[1].unsqueeze(0))
        batch[2], batch[3] = trim_batch(batch[2], pad_token_id, batch[3])

        with torch.no_grad():
            model.set_relation(batch[0], batch[1])

            outputs = model.model.generate(
                input_ids=batch[2],
                attention_mask=batch[3],
                num_beams=dev_data.args.num_beams,
                max_length=dev_data.args.max_output_length,
                decoder_start_token_id=model.config.bos_token_id,
                early_stopping=dev_data.gen_early_stop,
            )
        for input_, output in zip(batch[2], outputs):
            pred = dev_data.decode(output)
            predictions.append(pred)

    if save_predictions:
        dev_data.save_predictions(predictions)

    return dev_data.evaluate(predictions, verbose=verbose)
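
Every example on this page calls a `trim_batch` helper to drop all-padding columns from a batch before it reaches the model. As a point of reference, here is a minimal sketch of that helper, modeled on the version shipped with the transformers seq2seq examples; note that examples #5 and #8 pass an entire tuple of tensors instead, which implies a project-specific variant with a different signature.

import torch

def trim_batch(input_ids, pad_token_id, attention_mask=None):
    """Sketch of the transformers seq2seq example helper: drop the columns
    that contain only pad_token_id (verify against your project's copy)."""
    # keep a column if at least one row holds a non-pad token there
    keep_column_mask = input_ids.ne(pad_token_id).any(dim=0)
    if attention_mask is None:
        return input_ids[:, keep_column_mask]
    return input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask]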
Code example #2
    def __call__(self, batch) -> Dict[str, torch.Tensor]:
        if hasattr(self.tokenizer, "prepare_seq2seq_batch"):
            batch = self._encode(batch)
            input_ids, attention_mask, labels = (
                batch["input_ids"],
                batch["attention_mask"],
                batch["labels"],
            )
        else:
            input_ids = torch.stack([x["input_ids"] for x in batch])
            attention_mask = torch.stack([x["attention_mask"] for x in batch])
            labels = torch.stack([x["labels"] for x in batch])

            labels = trim_batch(labels, self.pad_token_id)
            input_ids, attention_mask = trim_batch(
                input_ids, self.pad_token_id, attention_mask=attention_mask)

        if isinstance(self.tokenizer, T5Tokenizer):
            decoder_input_ids = self._shift_right_t5(labels)
        else:
            decoder_input_ids = shift_tokens_right(labels, self.pad_token_id)

        batch = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "decoder_input_ids": decoder_input_ids,
            "labels": labels,
        }
        return batch
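
The collator above builds `decoder_input_ids` with `shift_tokens_right` for non-T5 models. As a hedged sketch of what that call does, older transformers releases implemented the BART-style shift roughly as follows (the last non-pad token, usually `<eos>`, is wrapped around to position 0); treat it as an illustration rather than the exact function this project imports.

def shift_tokens_right(input_ids, pad_token_id):
    """BART-style sketch: shift ids one position to the right and move the
    last non-pad token (usually <eos>) to the front."""
    prev_output_tokens = input_ids.clone()
    # index of the last non-pad token in each row
    index_of_eos = (input_ids.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
    prev_output_tokens[:, 0] = input_ids.gather(1, index_of_eos).squeeze()
    prev_output_tokens[:, 1:] = input_ids[:, :-1]
    return prev_output_tokens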
Code example #3
    def generate(self, queries, decode_method="beam", num_generate=5):
        with torch.no_grad():
            examples = queries

            decs = []
            for batch in list(chunks(examples, self.batch_size)):

                batch = self.tokenizer(batch,
                                       return_tensors="pt",
                                       truncation=True,
                                       padding="max_length").to(self.device)
                input_ids, attention_mask = trim_batch(
                    **batch, pad_token_id=self.tokenizer.pad_token_id)

                summaries = self.model.generate(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    decoder_start_token_id=self.decoder_start_token_id,
                    num_beams=num_generate,
                    num_return_sequences=num_generate,
                )

                dec = self.tokenizer.batch_decode(
                    summaries,
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=False)
                decs.append(dec)

            return decs
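
Examples #3, #4, and #10 split the input list with a `chunks` helper before tokenizing. A minimal sketch of such a fixed-size slicer (the transformers evaluation scripts define it essentially the same way) is:

def chunks(lst, n):
    """Yield successive n-sized chunks from lst."""
    for i in range(0, len(lst), n):
        yield lst[i:i + n]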
Code example #4
def generate_summaries_or_translations(
    examples: list,
    out_file: str,
    model_name: str,
    batch_size: int = 8,
    device: str = DEFAULT_DEVICE,
    fp16=False,
    task="summarization",
    **gen_kwargs,
) -> None:
    fout = Path(out_file).open("w", encoding="utf-8")
    model_name = str(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
    if fp16:
        model = model.half()

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # update config with summarization specific params
    use_task_specific_params(model, task)

    for batch in tqdm(list(chunks(examples, batch_size))):
        if "t5" in model_name:
            batch = [model.config.prefix + text for text in batch]
        batch = tokenizer(batch, max_length=1024, return_tensors="pt", truncation=True, padding="max_length").to(
            device
        )
        input_ids, attention_mask = trim_batch(**batch, pad_token_id=tokenizer.pad_token_id)
        summaries = model.generate(input_ids=input_ids, attention_mask=attention_mask, **gen_kwargs)
        dec = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        for hypothesis in dec:
            fout.write(hypothesis + "\n")
            fout.flush()
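
For reference, a hypothetical invocation of the function above could look like the following; the file paths and model name are placeholders chosen for illustration, not values from the original project.

# hypothetical usage; paths and model name are illustrative only
with open("source.txt", encoding="utf-8") as f:
    examples = [line.strip() for line in f]
generate_summaries_or_translations(
    examples,
    out_file="hypotheses.txt",
    model_name="sshleifer/distilbart-cnn-12-6",
    batch_size=8,
    num_beams=4,      # forwarded to model.generate via **gen_kwargs
    max_length=142,
)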
Code example #5
    def update(engine, batch):
        # remove extra pad from batches
        batch = trim_batch(batch, pad)
        qgen.train()

        loss = 0.0
        ###################################
        # MLE training with teacher forcing
        ###################################
        if 'sl' in args.learning:
            input_ids, lm_labels, token_type_ids, _, _, _ = tuple(
                input_tensor.to(args.device) for input_tensor in batch)
            loss_ce = qgen(input_ids=input_ids,
                           labels=lm_labels,
                           token_type_ids=token_type_ids)[0]
            loss = apply_loss(engine.state.iteration, qgen_optimizer, loss_ce,
                              args)
        # loss stays a plain float (0.0) when 'sl' is not in args.learning,
        # so only call .item() on an actual tensor
        return loss.item() if torch.is_tensor(loss) else loss
Code example #6
File: run_maml.py  Project: INK-USC/CrossFit
def train(args, logger, model, train_data, dev_data, optimizer, scheduler):
    model.train()
    global_batch = 0
    global_step = 0
    train_losses = []
    dev_losses = []
    best_accuracy = -1.0
    stop_training = False

    logger.info("Starting training!")
    for epoch in range(int(args.num_train_epochs)):
        for batch in tqdm(train_data.dataloader,
                          desc="Epoch {}".format(epoch)):

            global_batch += 1
            if torch.cuda.is_available():
                batch = [b.to(torch.device("cuda")) for b in batch[0]]

            pad_token_id = train_data.tokenizer.pad_token_id

            # train batch
            batch[0], batch[1] = trim_batch(batch[0], pad_token_id, batch[1])
            batch[2], batch[3] = trim_batch(batch[2], pad_token_id, batch[3])

            # dev batch
            batch[4], batch[5] = trim_batch(batch[4], pad_token_id, batch[5])
            batch[6], batch[7] = trim_batch(batch[6], pad_token_id, batch[7])

            inner_opt = torch.optim.SGD(model.parameters(), lr=args.inner_lr)
            with higher.innerloop_ctx(model,
                                      inner_opt,
                                      copy_initial_weights=False) as (fnet,
                                                                      diffopt):
                # print("train batch")
                train_loss = fnet(input_ids=batch[0],
                                  attention_mask=batch[1],
                                  decoder_input_ids=batch[2],
                                  decoder_attention_mask=batch[3],
                                  is_training=True)

                if torch.isnan(train_loss).data:
                    logger.info("Stop training because loss=%s" %
                                (train_loss.data))
                    stop_training = True
                    break  # does this ever happen?

                train_losses.append(train_loss.detach().cpu())
                diffopt.step(train_loss)

                # print("dev batch")
                dev_loss = fnet(input_ids=batch[4],
                                attention_mask=batch[5],
                                decoder_input_ids=batch[6],
                                decoder_attention_mask=batch[7],
                                is_training=True)
                dev_losses.append(dev_loss.detach().cpu())

                dev_loss.backward()

            if global_batch % args.gradient_accumulation_steps == 0:
                global_step += 1
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.max_grad_norm)
                optimizer.step()  # We have accumulated enough gradients
                scheduler.step()
                model.zero_grad()

                if global_step % args.eval_period == 0:
                    #     model.eval()
                    #     curr_em = inference(model if args.n_gpu==1 else model.module, dev_data)
                    #     logger.info("Step %d Train loss %.2f %s %s on epoch=%d" % (
                    #             global_step,
                    #             np.mean(train_losses),
                    #             dev_data.metric,
                    #             curr_em,
                    #             epoch))
                    logger.info("train loss: {}; dev loss: {}".format(
                        np.mean(train_losses), np.mean(dev_losses)))
                    train_losses = []
                    dev_losses = []
                #     if best_accuracy < curr_em:
                #         model_state_dict = {k:v.cpu() for (k, v) in model.state_dict().items()}
                #         torch.save(model_state_dict, os.path.join(args.output_dir, "best-model.pt"))
                #         logger.info("Saving model with best %s: %s -> %s on epoch=%d, global_step=%d" % \
                #                 (dev_data.metric, best_accuracy, curr_em, epoch, global_step))
                #         best_accuracy = curr_em
                #         wait_step = 0
                #         stop_training = False
                #     else:
                #         wait_step += 1
                #         if wait_step >= args.wait_step:
                #             stop_training = True
                #             break
                #     model.train()

            if global_step >= args.total_steps:
                stop_training = True
                break

        if stop_training:
            break

    model_state_dict = {k: v.cpu() for (k, v) in model.state_dict().items()}
    torch.save(model_state_dict, os.path.join(args.output_dir,
                                              "last-model.pt"))
Code example #7
File: run_multitask.py  Project: INK-USC/CrossFit
def train(args, logger, model, train_data, dev_data, optimizer, scheduler):
    model.train()
    global_step = 0
    train_losses = []
    best_accuracy = -1.0
    stop_training = False

    logger.info("Starting training!")
    for epoch in range(int(args.num_train_epochs)):
        for batch in tqdm(train_data.dataloader,
                          desc="Epoch {}".format(epoch)):
            global_step += 1
            if torch.cuda.is_available():
                batch = [b.to(torch.device("cuda")) for b in batch]

            pad_token_id = train_data.tokenizer.pad_token_id

            batch[0], batch[1] = trim_batch(batch[0], pad_token_id, batch[1])
            batch[2], batch[3] = trim_batch(batch[2], pad_token_id, batch[3])

            loss = model(input_ids=batch[0],
                         attention_mask=batch[1],
                         decoder_input_ids=batch[2],
                         decoder_attention_mask=batch[3],
                         is_training=True)
            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu.
            if torch.isnan(loss).data:
                logger.info("Stop training because loss=%s" % (loss.data))
                stop_training = True
                break
            train_losses.append(loss.detach().cpu())
            loss.backward()

            if global_step % args.gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.max_grad_norm)
                optimizer.step()  # We have accumulated enough gradients
                scheduler.step()
                model.zero_grad()

            if global_step % args.eval_period == 0:
                model.eval()
                curr_em = inference(model if args.n_gpu == 1 else model.module,
                                    dev_data)
                logger.info("Step %d Train loss %.2f %s %s on epoch=%d" %
                            (global_step, np.mean(train_losses),
                             dev_data.metric, curr_em, epoch))
                train_losses = []
                if best_accuracy < curr_em:
                    model_state_dict = {
                        k: v.cpu()
                        for (k, v) in model.state_dict().items()
                    }
                    torch.save(model_state_dict,
                               os.path.join(args.output_dir, "best-model.pt"))
                    logger.info("Saving model with best %s: %s -> %s on epoch=%d, global_step=%d" % \
                            (dev_data.metric, best_accuracy, curr_em, epoch, global_step))
                    best_accuracy = curr_em
                    wait_step = 0
                    stop_training = False
                else:
                    wait_step += 1
                    if wait_step >= args.wait_step:
                        stop_training = True
                        break
                model.train()
        if stop_training:
            break

    model_state_dict = {k: v.cpu() for (k, v) in model.state_dict().items()}
    torch.save(model_state_dict, os.path.join(args.output_dir,
                                              "last-model.pt"))
Code example #8
def main(args):

    debug = False
    # Load a pre-defined tokenizer (GPT-2), create config and model
    logger.info("Prepare tokenizer, pretrained model and optimizer - add \
                special tokens for fine-tuning")
    tokenizer = GPT2Tokenizer.from_pretrained(args.model_path,
                                              cache_dir=args.dataset_cache)
    tokenizer.add_tokens(SPECIAL_TOKENS)
    tokenizer.sep_token = '<sep>'

    if 'amr' in args.dataset_type:
        qgen = GPT2LMHeadModel.from_pretrained(args.model_path,
                                               cache_dir=args.dataset_cache)
    else:
        qgen = GPT2ConditionalLMHeadModel.\
            from_pretrained(args.model_path, cache_dir=args.dataset_cache)
    qgen.resize_token_embeddings(len(tokenizer))
    qgen.to(args.device)
    qgen.eval()

    logsoftmax = nn.LogSoftmax(dim=0)

    bos, eos, ctx, ans, que, pad, gen = \
        tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS)

    if args.n_gpu > 1:
        logger.info("Setting model to DataParallel.")
        qgen = torch.nn.DataParallel(qgen)

    logger.info("Prepare datasets")
    if "amr" in args.dataset_type:
        logger.info("Decoding with AMR dev set")
        dataloader = get_data_loaders(args,
                                      tokenizer,
                                      qgen,
                                      dataset_name=args.output_data,
                                      shuffle=False)
    else:
        dataloader = get_data_loaders(args, tokenizer, qgen, shuffle=False)

    if 'amr' in args.dataset_type:
        if args.output_data.lower() == "test":
            ref = os.path.join(args.dataset_path, "test.tok.text")
        else:
            ref = os.path.join(args.dataset_path, "dev.tok.text")
        ref = open(ref).readlines()

    logger.info("Decode: " + args.decoder)

    # Output file name
    f = open(os.path.join(args.checkpoint, 'output.txt'), 'w')
    text_outputs = list()

    # beam search variables
    beam_size = 1 if args.beam_size is None else args.beam_size
    output_size = 1 if args.output_size is None else args.output_size
    beam_candidates = args.beam_candidates

    # General variables
    max_length = args.max_input_length

    instance = 0
    for batch in tqdm(dataloader):

        batch = trim_batch(batch, pad)
        _, _, _, _, input_ids, _, token_type_ids, attention_mask = \
            tuple(input_tensor.to(args.device) for input_tensor in batch)

        past = None

        o = 0
        all_probs = torch.zeros(beam_size, 1).to(args.device)
        original_input_len = input_ids.shape[1]
        start = True

        # general variables
        questions = []

        for idx in range(max_length):
            ###################
            # Greedy decoding
            ###################
            if args.decoder == "greedy":
                with torch.no_grad():
                    logits, past = qgen(input_ids=input_ids,
                                        token_type_ids=token_type_ids,
                                        past=past)
                outputs = torch.argmax(logits[0, -1, :])
                outputs = outputs.unsqueeze(0).unsqueeze(0)

            ###################
            # Nucleus Sampling
            ###################
            elif args.decoder == "sampling":
                with torch.no_grad():
                    logits, past = qgen(input_ids=input_ids,
                                        token_type_ids=token_type_ids,
                                        past=past)
                # bs x seq_len x V
                logits = logits[:, -1, :] / args.temperature
                logits = top_k_top_p_filtering(logits,
                                               top_k=args.top_k,
                                               top_p=args.top_p)
                # bs x V
                probs = F.softmax(logits, dim=-1)
                # bs x 1
                outputs = torch.multinomial(probs, num_samples=1)
                outputs = torch.where(input_ids[:, -1:] == eos,
                                      input_ids[:, -1:], outputs)

            ###################
            # BEAM Search
            ###################
            elif args.decoder == "beam":
                # Beam search

                with torch.no_grad():
                    logits = qgen(input_ids)[0]

                out_paths = None
                probs = None

                for k in range(logits.shape[0]):
                    log_p = logsoftmax(logits[k, -1, :])
                    p = log_p + all_probs[k]

                    if start:
                        predicted_top_k = torch.topk(p, beam_size)
                        start = False
                    else:
                        predicted_top_k = torch.topk(p, beam_candidates)

                    p_top_k_tokens = predicted_top_k.indices[:, None]
                    p_top_k_probs = predicted_top_k.values[:, None]

                    # Store paths
                    if out_paths is None:
                        out_paths = torch.cat((input_ids[k].expand(
                            p_top_k_tokens.shape[0],
                            input_ids.shape[1]), p_top_k_tokens), 1)

                    else:
                        out_paths = torch.cat(
                            (out_paths,
                             torch.cat((input_ids[k].expand(
                                 p_top_k_tokens.shape[0],
                                 input_ids.shape[1]), p_top_k_tokens), 1)), 0)
                    if probs is None:
                        probs = p_top_k_probs
                    else:
                        probs = torch.cat((probs, p_top_k_probs), 0)

                global_top_k = torch.topk(probs, k=beam_size, dim=0)
                input_ids = out_paths[global_top_k.indices[:, 0], :]
                all_probs = global_top_k.values
                o += 1

            else:
                raise Exception('Invalid decoder: ' + args.decoder)

            #######################
            # Termination condition
            #######################
            if not args.decoder == 'beam':
                # correctly shape inputs for next round
                input_ids = outputs
                token_type_ids = token_type_ids[:, -1:]

                # if all the outputs are special tokens
                questions.append(outputs)
                if (outputs == eos).all():
                    break
            else:
                outputs = input_ids[:, original_input_len:]

                if ((outputs == eos).sum(dim=1) > 0).all():
                    break

        ################
        # Output to file
        ################
        if args.decoder != 'beam':
            # append an extra <eos> in case max length is reached
            questions.append(torch.zeros_like(outputs).fill_(eos))
            questions = torch.cat(questions, dim=1)

        else:
            questions.append(outputs)
            questions = questions[0]

        for i, question in enumerate(questions):
            question = question.tolist()

            if eos in question:
                idx = question.index(eos)
            else:
                idx = -1

            question = tokenizer.decode(question[:idx])
            if '<generate>' in question:
                question = question.split('<generate>')[1]

            # Print outputs to file and save in text_outputs
            print(question.replace('\n', ' '), file=f)
            f.flush()
            text_outputs.append(question.replace('\n', ' ').lower())

            # Limit number of outputs to output_size
            if i >= output_size - 1:
                break

        if 'amr' in args.dataset_type and debug:
            print("GOLD: ", ref[instance])
            print("final: ", text_outputs[instance])

        instance += 1
Code example #9
def train(args, logger, model, train_data, dev_data, optimizer, scheduler):

    model.train()
    global_step = 0
    train_losses = []
    best_accuracy = (-1.0, -1.0,
                     -1.0) if args.dataset == "zest_grouped" else -1.0
    stop_training = False

    # curr_em = inference(model if args.n_gpu==1 else model.module, dev_data)
    # logger.info("[Before Training] %s %s" % (
    #         dev_data.metric,
    #         curr_em))

    logger.info("Starting training!")

    model.model.backup_layer_norm_parameters()

    for epoch in range(int(args.num_train_epochs)):
        for batch in tqdm(train_data.dataloader,
                          desc="Epoch {}".format(epoch)):
            global_step += 1

            if torch.cuda.is_available():
                batch = [b.to(torch.device("cuda")) for b in batch[0]]

            rel_ids, rel_masks = batch[0].unsqueeze(0), batch[1].unsqueeze(0)
            input_ids, input_masks = batch[2], batch[3]
            output_ids, output_masks = batch[4], batch[5]

            pad_token_id = train_data.tokenizer.pad_token_id
            rel_ids, rel_masks = trim_batch(rel_ids, pad_token_id, rel_masks)
            input_ids, input_masks = trim_batch(input_ids, pad_token_id,
                                                input_masks)
            output_ids, output_masks = trim_batch(output_ids, pad_token_id,
                                                  output_masks)

            loss = model.forward(rel_ids=rel_ids,
                                 rel_masks=rel_masks,
                                 input_ids=input_ids,
                                 input_masks=input_masks,
                                 output_ids=output_ids,
                                 output_masks=output_masks,
                                 is_training=True)

            train_losses.append(loss.detach().cpu())
            loss.backward()

            model.model.restore_layer_norm_parameters()

            if global_step % args.gradient_accumulation_steps == 0:
                # for p in model.meta_model.decoders.parameters():
                # print(p)
                # print(p.grad)
                # break

                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.max_grad_norm)

                # for p in model.meta_model.decoders.parameters():
                #     print()
                #     print(p)
                #     print(p.grad)
                #     break

                optimizer.step()  # We have accumulated enough gradients
                scheduler.step()
                model.zero_grad()

                # print(model.meta_model.decoders[-1].linear1.weight)

            if global_step % args.eval_period == 0:
                model.eval()
                # curr_em = 0.0
                curr_em = inference(model if args.n_gpu == 1 else model.module,
                                    dev_data,
                                    save_predictions=True)
                logger.info("Step %d Train loss %.2f %s %s on epoch=%d" %
                            (global_step, np.mean(train_losses),
                             dev_data.metric, curr_em, epoch))
                train_losses = []
                if best_accuracy < curr_em:
                    model_state_dict = {
                        k: v.cpu()
                        for (k, v) in model.state_dict().items()
                    }
                    torch.save(model_state_dict,
                               os.path.join(args.output_dir, "best-model.pt"))
                    logger.info("Saving model with best %s: %s -> %s on epoch=%d, global_step=%d" % \
                            (dev_data.metric, best_accuracy, curr_em, epoch, global_step))
                    best_accuracy = curr_em
                    wait_step = 0
                    stop_training = False
                else:
                    wait_step += 1
                    if wait_step >= args.wait_step:
                        stop_training = True
                        break
                model.train()

        if stop_training:
            break

    model_state_dict = {k: v.cpu() for (k, v) in model.state_dict().items()}
    torch.save(model_state_dict, os.path.join(args.output_dir,
                                              "last-model.pt"))
Code example #10
File: run_eval.py  Project: linshaoxin-maker/taas
def generate_summaries(
    examples: list,
    out_file: str,
    model_name: str,
    batch_size: int = 8,
    device: str = DEFAULT_DEVICE,
    fp16=True,
    task="summarization",
    decoder_start_token_id=None,
    finetune_flag: int = 0,
    checkpoint_path: str = "",
    **gen_kwargs,
) -> None:
    fout = Path(out_file).open("w", encoding="utf-8")

    # initialize tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # if our goal is to evaluate the original checkpoint
    if finetune_flag < 1:
        # initialize the model checkpoints
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
    # if our goal is to evaluate our fine-tuned checkpoint
    else:
        # load the finetuned checkpoints
        model = AutoModelForSeq2SeqLM.from_pretrained(
            f"{checkpoint_path}/best_tfmr").to(device)

    if fp16:
        model = model.half()
    if decoder_start_token_id is None:
        decoder_start_token_id = gen_kwargs.pop("decoder_start_token_id", None)

    # update config with summarization specific params
    use_task_specific_params(model, task)

    for batch in tqdm(list(chunks(examples, batch_size))):
        batch = tokenizer(batch,
                          return_tensors="pt",
                          truncation=True,
                          padding="max_length").to(device)
        input_ids, attention_mask = trim_batch(
            **batch, pad_token_id=tokenizer.pad_token_id)

        # -----------------------------------------
        # Topic Modeling - GSM
        # -----------------------------------------
        docs = []
        # load dict
        dictionary = Dictionary.load(datapath('dict-www-cnndm-unigram'))
        # remove [SEP]
        sep_list = [
            '[SEP_0]', '[SEP_1]', '[SEP_2]', '[SEP_3]', '[SEP_4]', '[SEP_5]',
            '[SEP_6]', '[SEP_7]', '[SEP_8]', '[SEP_9]'
        ]
        # vocab size for topic modeling
        vocab_size = len(dictionary)
        # load config for GSM
        config = yaml_load(f"data/config/gsm.yaml")
        # model
        config['hidden']['features'][0] = vocab_size

        # trainer batch
        config['trainer_batch']['test_sample'] = 1
        config = extend_config_reference(config)
        gsm_trainer = config['GSMtrainer']
        gsm_trainer['base_dir'] = f"log/bart-large-cnn-finetune"
        gsm_trainer = GSMTrainer.from_config(gsm_trainer)

        total_sample = len(batch['input_ids'])

        for batch_num in range(total_sample):
            # extract the batch_sentence
            batch_sentence = tokenizer.decode(
                batch['input_ids'][batch_num].tolist(),
                skip_special_tokens=True)
            # change to lowercase and split to list
            batch_sentence_list = batch_sentence.split(" ")
            # remove [SEP]
            batch_sentence_list_nosep = [
                item for item in batch_sentence_list if item not in sep_list
            ]
            text = ' '.join(batch_sentence_list_nosep)
            fine_text = text.replace(' ##', '').lower()
            batch_sentence = re.sub(r'[^\w\s]', '', fine_text)
            # batch_sentence: change to the cleaned news for topic modeling
            # change to training data format in topic modeling
            gsm_data_bow = dictionary.doc2bow(batch_sentence.split(" "))
            docs.append(gsm_data_bow)

        # gsm_data: data for topic modeling
        gsm_data = DataLoader(DocDataset(docs, len(dictionary), device='cuda'),
                              batch_size=config['dataset']['batch_size'],
                              drop_last=False,
                              num_workers=0)

        gsm_trainer.__dict__['train_iterator'] = gsm_data

        gsm_loss, gsm_p = gsm_trainer.co_train(vocab_size=vocab_size,
                                               training=False)

        del gsm_data

        topic_p = gsm_p.cuda()

        summaries = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_start_token_id=decoder_start_token_id,
            topic_p=topic_p,
            **gen_kwargs,
        )
        dec = tokenizer.batch_decode(summaries,
                                     skip_special_tokens=True,
                                     clean_up_tokenization_spaces=False)
        for hypothesis in dec:
            fout.write(hypothesis + "\n")
            fout.flush()