def decode_write_all(writer, logger, epoch, config, model, dataloader, mode):
    """Summarize, score and export every example in *dataloader*.

    For each batch the module-level summarizer ``s`` is called once per
    article, with that article's entry of ``batch.key_words`` as the query.
    The first element of each result is used as the decoded summary.
    Summaries are scored against the reference abstracts with
    ``total_evaulate``. A batch whose scoring raises is logged and skipped.
    Per-example rows from all scored batches are combined into one DataFrame.
    That DataFrame and the mean of every score are passed to ``total_output``.

    Args:
        writer: Not used by this function.
        logger: Not used by this function (see NOTE in the ``except`` branch).
        epoch: Not used; ``total_output`` is always called with epoch ``0``.
        config: Object exposing ``batch_size``; used to turn the summed
            decode time into a per-example average.
        model: Not used; decoding goes through the module-level ``s``.
        dataloader: Sized iterable of batches exposing ``original_article``,
            ``original_abstract``, ``key_words`` and ``review_IDS``.
        mode: Forwarded unchanged to ``total_output``.

    Returns:
        Tuple ``(mean of the 'rouge_l_f' scores, per-example DataFrame)``.

    Raises:
        ValueError: If no batch could be scored (this includes an empty
            *dataloader*).
    """
    import logging

    num_batches = len(dataloader)
    batch_frames = []
    total_scores = {}
    avg_time = 0
    for batch in dataloader:
        start = time.time()
        article_sents = list(batch.original_article)
        ref_sents = list(batch.original_abstract)
        # Use a distinct index name so it cannot be confused with the batch
        # counter below.
        decoded_sents = [
            s.summarize(text=article, summary_length=1,
                        query_based_token=batch.key_words[art_idx])[0]
            for art_idx, article in enumerate(article_sents)
        ]
        # Outputs of 5 characters or fewer are replaced by a fixed placeholder
        # (presumably so the scorers never receive an empty/degenerate
        # hypothesis -- confirm).
        decoded_sents = [
            sent if len(sent) > 5 else "xxx xxx xxx xxx xxx"
            for sent in decoded_sents
        ]
        keywords_list = [str(word_list) for word_list in batch.key_words]
        avg_time += time.time() - start

        try:
            multi_scores, batch_frame = total_evaulate(article_sents,
                                                       keywords_list,
                                                       decoded_sents,
                                                       ref_sents)
            batch_frame['review_ID'] = list(batch.review_IDS)
        except Exception as e:
            # Best effort: skip a batch the scorer cannot handle, but report
            # it instead of dropping it silently.
            # NOTE(review): the type of the ``logger`` argument is not visible
            # here, so the stdlib logger is used; switch to ``logger`` if it is
            # a logging.Logger.
            logging.getLogger(__name__).warning(
                "decode_write_all: skipping batch that failed to score: %r", e)
            continue

        processed = len(batch_frames)
        if processed % 1000 == 0 and processed > 0:
            print(processed)
        if processed == 0:
            # Adopt the first batch's score lists; later batches extend them.
            total_scores = multi_scores
        else:
            for key, scores in total_scores.items():
                scores.extend(multi_scores[key])
        batch_frames.append(batch_frame)

    if not batch_frames:
        raise ValueError(
            "decode_write_all: no batch in the dataloader could be scored")

    # Average over the nominal number of examples; skipped batches still count.
    avg_time = avg_time / (num_batches * config.batch_size)

    # Concatenate once instead of re-copying the growing frame every batch.
    if len(batch_frames) == 1:
        out_frame = batch_frames[0]
    else:
        out_frame = pd.concat(batch_frames, axis=0, ignore_index=True)

    scalar_acc = {}
    num_scored = 0
    for key, scores in total_scores.items():
        num_scored = len(scores)
        scalar_acc[key] = sum(scores) / len(scores)

    total_output(0, mode, writerPath, out_frame, avg_time, num_scored,
                 scalar_acc)
    return scalar_acc['rouge_l_f'], out_frame
def decode(writer, dataloader, epoch):
    """Beam-search decode *dataloader*, score the outputs and log them.

    Each batch is encoded with the module-level ``parallel_model``, decoded
    with ``beam_search`` and converted back to text with ``prepare_result``.
    The text is then scored with ``total_evaulate``. Logging happens in
    three steps:

    - The mean of every score is logged as a TensorBoard scalar, grouped by
      metric family.
    - All results are passed to ``total_output`` under the ``'test'`` mode.
    - The ``head()`` and ``tail()`` rows by ``rouge_l`` (best and worst
      examples) are logged as markdown text.

    Args:
        writer: Writer providing ``add_scalars`` and ``add_text``.
        dataloader: Sized iterable of batches accepted by
            ``get_input_from_batch``/``prepare_result`` and exposing
            ``review_IDS``.
        epoch: Global step used for every logged scalar and text entry.

    Returns:
        Per-example DataFrame sorted by ``rouge_l`` in descending order.

    Raises:
        ValueError: If *dataloader* yields no batches.
    """
    num_batches = len(dataloader)
    batch_frames = []
    total_scores = {}
    avg_time = 0
    for inputs in dataloader:
        start = time.time()
        # Encoder data
        (enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab,
         extra_zeros, coverage, ct_e, enc_key_batch, enc_key_mask,
         enc_key_lens) = get_input_from_batch(inputs, config, batch_first=True)
        # Longest encoder length in the batch. Each element of
        # enc_lens.tolist() is indexed, so rows are presumably 1-element
        # lists -- confirm.
        # Do NOT use max(T.max(enc_lens, dim=0)): torch.max with ``dim``
        # returns a (values, indices) pair, so that form can yield the
        # argmax *position* instead of a length. Batches where that happened
        # used to be skipped silently.
        max_enc_len = max(enc_lens.tolist())[0]

        model = parallel_model.module
        enc_batch = model.embeds(enc_batch)  # embeddings for encoder input
        enc_key_batch = model.embeds(enc_key_batch)  # keyword embeddings
        enc_out, enc_hidden = model.encoder(enc_batch, enc_lens, max_enc_len)

        # Feed the encoder output to beam search to predict token ids.
        pred_ids = beam_search(enc_hidden, enc_out, enc_padding_mask, ct_e,
                               extra_zeros, enc_batch_extend_vocab,
                               enc_key_batch, enc_key_mask, model,
                               START, END, UNKNOWN_TOKEN)

        (article_sents, decoded_sents, keywords_list, ref_sents,
         long_seq_index) = prepare_result(vocab, inputs, pred_ids)
        avg_time += time.time() - start

        multi_scores, batch_frame = total_evaulate(article_sents,
                                                   keywords_list,
                                                   decoded_sents, ref_sents)
        batch_frame['review_ID'] = list(inputs.review_IDS)

        processed = len(batch_frames)
        if processed % 1000 == 0 and processed > 0:
            print(processed)
        if processed == 0:
            # Adopt the first batch's score lists; later batches extend them.
            total_scores = multi_scores
        else:
            for key, scores in total_scores.items():
                scores.extend(multi_scores[key])
        batch_frames.append(batch_frame)

    if not batch_frames:
        raise ValueError("decode: the dataloader yielded no batches")

    # Average over the nominal number of examples (len * batch_size).
    avg_time = avg_time / (num_batches * config.batch_size)

    # Concatenate once instead of re-copying the growing frame every batch.
    if len(batch_frames) == 1:
        out_frame = batch_frames[0]
    else:
        out_frame = pd.concat(batch_frames, axis=0, ignore_index=True)

    scalar_acc = {}
    num_scored = 0
    for key, scores in total_scores.items():
        num_scored = len(scores)
        scalar_acc[key] = sum(scores) / len(scores)

    # Route each mean score to the first metric family its name contains;
    # names matching no family are not logged as scalars.
    for scalar_name, accuracy in scalar_acc.items():
        for family in ('rouge_1', 'rouge_2', 'rouge_l', 'bleu', 'meteor'):
            if family in scalar_name:
                writer.add_scalars('scalar/' + family,
                                   {scalar_name: accuracy}, epoch)
                break

    # -----------------------------------------------------------
    total_output(epoch, 'test', writerPath, out_frame, avg_time, num_scored,
                 scalar_acc)
    # -----------------------------------------------------------
    out_frame = out_frame.sort_values(by=['rouge_l'], ascending=False)
    _add_decode_examples(writer, out_frame.head(), 'BigTest', epoch)
    _add_decode_examples(writer, out_frame.tail(), 'SmallTest', epoch)
    return out_frame


def _add_decode_examples(writer, frame, section, epoch):
    """Log each row of *frame* as five markdown text entries under *section*."""
    # The 24 extra spaces keep the logged text identical to earlier runs,
    # whose backslash-continued literals embedded that padding.
    padding = " " * 24
    for row_no, row in enumerate(frame.to_dict('records')):
        tag = '%s/epoch_%s/##%s' % (section, epoch, row_no)
        for field in ('rouge_l', 'decoded', 'reference', 'keywords',
                      'article'):
            value = str(row[field]) if field == 'rouge_l' else row[field]
            writer.add_text(tag, "### " + field + " :    " + padding + value,
                            epoch)