def test_serve():
    """Smoke-test ``serve()`` on a SummarizationTask with pre/postprocessing attached.

    The preprocess and postprocess are assigned directly onto the private
    ``_preprocess`` / ``_postprocess`` attributes before switching to eval mode.
    """
    task = SummarizationTask(TEST_BACKBONE)
    # TODO: Currently only servable once a preprocess and postprocess have been attached
    task._preprocess = SummarizationPreprocess(backbone=TEST_BACKBONE)
    task._postprocess = Seq2SeqPostprocess()
    task.eval()
    task.serve()
def test_init_train(tmpdir):
    """Fit a SummarizationTask on ``DummyDataset`` for a single ``fast_dev_run`` step.

    On Windows the test is explicitly skipped (see the TODO below) instead of
    returning early: a bare ``return True`` would be reported as a pass, and
    pytest warns about test functions that return a non-None value.
    """
    import pytest  # local import keeps this test self-contained

    if os.name == "nt":
        # TODO: huggingface stuff timing out on windows
        pytest.skip("huggingface stuff timing out on windows")
    model = SummarizationTask(TEST_BACKBONE)
    train_dl = torch.utils.data.DataLoader(DummyDataset())
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    trainer.fit(model, train_dl)
def test_jit(tmpdir):
    """Check that a SummarizationTask survives a TorchScript trace/save/load round trip.

    The model is traced on a fixed sample batch, saved to ``tmpdir``, reloaded,
    and the reloaded module must return a ``torch.Tensor``.
    """
    sample_input = {
        "input_ids": torch.randint(1000, size=(1, 32)),
        # torch.randint's upper bound is exclusive, so torch.randint(1, ...) would
        # produce an all-zero mask (every token masked out). Use an all-ones int64
        # mask (same dtype/shape as randint's output) so tracing sees a normal batch.
        "attention_mask": torch.ones(1, 32, dtype=torch.long),
    }
    path = os.path.join(tmpdir, "test.pt")

    model = SummarizationTask(TEST_BACKBONE)
    model.eval()

    # Huggingface only supports `torch.jit.trace`
    model = torch.jit.trace(model, [sample_input])
    torch.jit.save(model, path)
    model = torch.jit.load(path)

    out = model(sample_input)
    assert isinstance(out, torch.Tensor)
# limitations under the License.
import torch

import flash
from flash import download_data, Trainer
from flash.text import SummarizationData, SummarizationTask

# 1. Download the XSum archive and extract it under data/
download_data("https://pl-flash-data.s3.amazonaws.com/xsum.zip", "data/")

# 2. Load the train/valid/test CSV splits. "input" and "target" are passed as
#    the source-text and summary column names (assumed from the keyword names;
#    confirm against SummarizationData.from_files).
datamodule = SummarizationData.from_files(
    train_file="data/xsum/train.csv",
    val_file="data/xsum/valid.csv",
    test_file="data/xsum/test.csv",
    input="input",
    target="target",
)

# 3. Build the model (no arguments: the task's default backbone)
model = SummarizationTask()

# 4. Create the trainer, requesting 1 GPU if CUDA is available and 0 otherwise.
#    fast_dev_run=True: run once on data.
num_gpus = int(torch.cuda.is_available())
trainer = flash.Trainer(gpus=num_gpus, fast_dev_run=True)

# 5. Fine-tune the model
trainer.finetune(model, datamodule=datamodule)

# 6. Save the fine-tuned model to a local checkpoint file
trainer.save_checkpoint("summarization_model_xsum.pt")
def test_load_from_checkpoint_dependency_error():
    """``load_from_checkpoint`` must raise ModuleNotFoundError naming ``'lightning-flash[text]'``.

    Presumably intended to run in an environment without the text extras
    installed; the checkpoint path is deliberately bogus.
    """
    expected_hint = re.escape("'lightning-flash[text]'")
    with pytest.raises(ModuleNotFoundError, match=expected_hint):
        SummarizationTask.load_from_checkpoint("not_a_real_checkpoint.pt")
def test_init_train(tmpdir):
    """Fit a SummarizationTask on ``DummyDataset`` for a single ``fast_dev_run`` step.

    NOTE(review): another ``test_init_train`` exists in this SOURCE; if both are
    in the same module, only the last definition is collected by pytest.
    """
    task = SummarizationTask(TEST_BACKBONE)
    loader = torch.utils.data.DataLoader(DummyDataset())
    Trainer(default_root_dir=tmpdir, fast_dev_run=True).fit(task, loader)
# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from pytorch_lightning import Trainer from flash.data.utils import download_data from flash.text import SummarizationData, SummarizationTask # 1. Download the data download_data("https://pl-flash-data.s3.amazonaws.com/xsum.zip", "data/") # 2. Load the model from a checkpoint model = SummarizationTask.load_from_checkpoint( "https://flash-weights.s3.amazonaws.com/summarization_model_xsum.pt") # 2a. Summarize an article! predictions = model.predict([ """ Camilla bought a box of mangoes with a Brixton £10 note, introduced last year to try to keep the money of local people within the community.The couple were surrounded by shoppers as they walked along Electric Avenue. They came to Brixton to see work which has started to revitalise the borough. It was Charles' first visit to the area since 1996, when he was accompanied by the former South African president Nelson Mandela.Greengrocer Derek Chong, who has run a stall on Electric Avenue for 20 years, said Camilla had been ""nice and pleasant"" when she purchased the fruit. ""She asked me what was nice, what would I recommend, and I said we've got some nice mangoes. She asked me were they ripe and I said yes - they're from the Dominican Republic."" Mr Chong is one of 170 local retailers who accept the Brixton Pound. Customers exchange traditional pound coins for Brixton Pounds and then spend them at the market or in participating shops.
from flash import Trainer from flash.core.data.utils import download_data from flash.text import SummarizationData, SummarizationTask # 1. Create the DataModule download_data("https://pl-flash-data.s3.amazonaws.com/xsum.zip", "./data/") datamodule = SummarizationData.from_csv( "input", "target", train_file="data/xsum/train.csv", val_file="data/xsum/valid.csv", ) # 2. Build the task model = SummarizationTask() # 3. Create the trainer and finetune the model trainer = Trainer(max_epochs=3) trainer.finetune(model, datamodule=datamodule) # 4. Summarize some text! predictions = model.predict(""" Camilla bought a box of mangoes with a Brixton £10 note, introduced last year to try to keep the money of local people within the community.The couple were surrounded by shoppers as they walked along Electric Avenue. They came to Brixton to see work which has started to revitalise the borough. It was Charles' first visit to the area since 1996, when he was accompanied by the former South African president Nelson Mandela.Greengrocer Derek Chong, who has run a stall on Electric Avenue for 20 years, said Camilla had been ""nice and pleasant"" when she purchased the fruit. ""She asked me what was nice, what would I recommend, and I said we've got some nice mangoes. She asked me were they ripe and I said yes - they're from the Dominican Republic.""
args = parser.parse_args()

# 1. Optionally download the XSum archive into data/
if args.download:
    download_data("https://pl-flash-data.s3.amazonaws.com/xsum.zip", "data/")

# 2. Build the datamodule from the CSV paths given on the command line.
#    "input" / "target" are presumably the source and summary column names.
datamodule = SummarizationData.from_csv(
    "input",
    "target",
    train_file=args.train_file,
    val_file=args.valid_file,
    test_file=args.test_file,
)

# 3. Build the model from the CLI-selected backbone and learning rate
model = SummarizationTask(backbone=args.backbone, learning_rate=args.learning_rate)

# 4. Create the trainer. Run once on data (fast_dev_run=True).
# NOTE(review): in PyTorch Lightning, fast_dev_run=True typically overrides
# max_epochs, so --max_epochs likely has no effect here; confirm this is intended.
trainer = Trainer(
    gpus=args.gpus,
    max_epochs=args.max_epochs,
    fast_dev_run=True,
)

# 5. Fine-tune the model
trainer.finetune(model, datamodule=datamodule)

# 6. Save the fine-tuned model to a local checkpoint file
trainer.save_checkpoint("summarization_model_xsum.pt")