コード例 #1
0
ファイル: main.py プロジェクト: robertkibet/TransformerSum
    parser.add_argument("--weight_decay", default=1e-2, type=float)
    # Logging verbosity; the value is stored on the namespace as `logLevel`.
    parser.add_argument(
        "-l",
        "--log",
        dest="logLevel",
        default="INFO",
        choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
        help="Set the logging level (default: 'Info').",
    )

    # First pass: parse only the options registered so far.
    # `parse_known_args()` returns an `(args, unrecognized)` tuple, hence the
    # `main_args[0]` accesses below.
    main_args = parser.parse_known_args()

    # Register the model-specific options for the selected mode on the parser.
    if main_args[0].mode == "abstractive":
        parser = AbstractiveSummarizer.add_model_specific_args(parser)
    else:
        parser = ExtractiveSummarizer.add_model_specific_args(parser)

    # NOTE(review): this only logs the problem and then continues; confirm
    # whether execution should stop here. It also runs before
    # `logging.basicConfig()` is called below.
    if main_args[0].custom_checkpoint_every_n and (not main_args[0].weights_save_path):
        logger.error(
            "You must specify the `--weights_save_path` to use `--custom_checkpoint_every_n`."
        )

    # Second pass: full parse, now that the model-specific options are registered.
    main_args = parser.parse_args()

    # Setup logging config
    # `getLevelName()` is given the level name chosen via `--log`.
    logging.basicConfig(
        format="%(asctime)s|%(name)s|%(levelname)s> %(message)s",
        level=logging.getLevelName(main_args.logLevel),
    )

    # Set the `nlp` logging verbosity since its default is not INFO.
コード例 #2
0
ファイル: inference.py プロジェクト: sentian/TransformerSum
!git clone [email protected]:sentian/TransformerSum.git
!cd transformersum
!conda env create --file environment.yml

import os
os.chdir('/home/sentian/Projects/code/gentext/TransformerSum/src')
from extractive import ExtractiveSummarizer

## the model checkpoint needs to be downloaded from https://github.com/sentian/TransformerSum#extractive
## I use the roberta-base-ext-sum model
## change the path_to_checkpoint below
model = ExtractiveSummarizer.load_from_checkpoint("/home/sentian/Projects/model/gentext/transformersum/roberta-base-ext-sum/epoch=3.ckpt", strict=False)
text_to_summarize = "I love the extra heat rods which make more even heat. The farther back on the rack, \
the hotter it is. The controls are easy to understand and the pre-heat function  is a very nice thing to \
have. It heats up fast. Because the unit is small, the 1800W coils are able to heat it rapidly. \
The heating elements are radiant coils in a quartz tube. I really wish the unit had an extra heating \
element both top and bottom. The Breville is easy to use and heats up quickly. It heats up quickly and \
maintains temperature perfectly. My unit has a 'hot spot' in the right rear corner. It gets a good even \
heat on both sides. It gets hot very quickly so preheating takes no time. The temperature and time display \
is very easy to read. The unit is very light so I can move it to my island if I am worried about the heat. \
It is easy to operate, great dial choices, easy to set or adjust time and temp. Great preheat function. \
The automatic preheat feature is nicer than I expected and the size is perfect. Heats up VERY quickly. \
Way louder than my full size Breville."
summary = model.predict(text_to_summarize, num_summary_sentences=5)
print(summary)
コード例 #3
0
def summarize_text(text, model_choice):
    """Summarize *text* with the extractive checkpoint at *model_choice*.

    The checkpoint is loaded anew on every call via
    ``ExtractiveSummarizer.load_from_checkpoint`` and the result of its
    ``predict(text)`` call is returned unchanged.
    """
    return ExtractiveSummarizer.load_from_checkpoint(model_choice).predict(text)
import time
import json
from datetime import timedelta
from pandas import json_normalize
from extractive import ExtractiveSummarizer

from tqdm import tqdm
tqdm.pandas()

# ======================================================================================================================
# Define summarizer (transformersum)
# ======================================================================================================================

# using extractive model "distilroberta-base-ext-sum"
# The checkpoint is loaded once at module level and kept in `model`.
# NOTE(review): the relative path uses a backslash separator, so it presumably
# targets Windows; confirm before running on another platform.
model = ExtractiveSummarizer.load_from_checkpoint("models\\epoch=3.ckpt")

# ======================================================================================================================
# ACM dataset
# ======================================================================================================================

# Path of the dataset file that is read line by line further down.
# NOTE(review): same backslash-separated relative path as the checkpoint above.
file = 'datasets\\ACM.json'  # TEST data to evaluate the final model

# ======================================================================================================================
# Read data
# ======================================================================================================================

# Parse the dataset as JSON Lines: one JSON document per line of `file`.
# The `with` block closes the file handle as soon as reading finishes instead
# of leaving it open until garbage collection. Blank lines are skipped because
# `json.loads` would raise on them.
with open(file, 'r', encoding="utf8") as json_file:
    json_data = [json.loads(line) for line in json_file if line.strip()]

# convert json to dataframe
コード例 #5
0
import torch
import time
from extractive import ExtractiveSummarizer

# Load the trained extractive summarizer from a Lightning checkpoint.
fpath = "./lightning_logs/version_12/checkpoints/epoch=12.ckpt"
model = ExtractiveSummarizer.load_from_checkpoint(fpath)
# Use the first GPU when one is available; otherwise stay on the CPU so the
# script still runs on machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

print(f"device: {model.device}")

t = "Do you like games? I want to get a refund for a order I placed. 2 items were damaged. Are you a bot? 123082HSF."

print(f"text: {t}\n\n")

# Time `k` consecutive predictions and report the mean latency per call.
# `perf_counter()` is the clock intended for measuring short intervals;
# `time.time()` is wall-clock time and can jump if the system clock is adjusted.
k = 100
start = time.perf_counter()
for _ in range(k):
    summary_sents = model.predict(t, raw_scores=True)
end = time.perf_counter()

print(f"time taken: {(end-start)/k * 1000}ms")

# Result of the last timed call (made with `raw_scores=True`).
print(summary_sents)

print("\n")

# Same input again with the default `predict` arguments.
summary = model.predict(t)

print(summary)