示例#1
0
def test_from_csv(tmpdir):
    """Build a datamodule from a single training CSV and inspect its first batch.

    The first training batch must carry both a ``labels`` and an ``input_ids``
    entry.
    """
    datamodule = SummarizationData.from_csv(
        "input",
        "target",
        backbone=TEST_BACKBONE,
        train_file=csv_data(tmpdir),
        batch_size=1,
        src_lang="en_XX",
        tgt_lang="ro_RO",
    )
    first_batch = next(iter(datamodule.train_dataloader()))
    for required_key in ("labels", "input_ids"):
        assert required_key in first_batch
示例#2
0
def test_postprocess_tokenizer(tmpdir):
    """Check that ``SummarizationPostprocess`` resolves its tokenizer for a non-default backbone.

    The datamodule is built with ``sshleifer/bart-tiny-random``. After the data
    pipeline is initialized, its postprocess pipeline must report that same
    backbone and expose a non-``None`` tokenizer.
    """
    other_backbone = "sshleifer/bart-tiny-random"
    datamodule = SummarizationData.from_csv(
        "input",
        "target",
        backbone=other_backbone,
        train_file=csv_data(tmpdir),
        batch_size=1,
        src_lang="en_XX",
        tgt_lang="ro_RO",
    )
    data_pipeline = datamodule.data_pipeline
    data_pipeline.initialize()

    assert data_pipeline._postprocess_pipeline.backbone_state.backbone == other_backbone
    assert data_pipeline._postprocess_pipeline.tokenizer is not None
示例#3
0
def test_from_files(tmpdir):
    """Use one CSV for train/val/test and inspect the val and test loaders.

    The first batch from the validation loader, then from the test loader, must
    each contain a ``labels`` and an ``input_ids`` entry.
    """
    shared_csv = csv_data(tmpdir)
    datamodule = SummarizationData.from_csv(
        "input",
        "target",
        backbone=TEST_BACKBONE,
        train_file=shared_csv,
        val_file=shared_csv,
        test_file=shared_csv,
        batch_size=1,
    )
    # Loaders are created lazily, in the order val -> test.
    for make_loader in (datamodule.val_dataloader, datamodule.test_dataloader):
        first_batch = next(iter(make_loader()))
        for required_key in ("labels", "input_ids"):
            assert required_key in first_batch
示例#4
0
def from_xsum(
    batch_size: int = 4,
    num_workers: int = 0,
    **input_transform_kwargs,
) -> SummarizationData:
    """Fetch the XSum archive and wrap its CSV splits in a ``SummarizationData``.

    The archive is downloaded into ``./data/``. The train and validation splits
    are then read from ``data/xsum/train.csv`` and ``data/xsum/valid.csv``,
    with ``"input"`` as the input column and ``"target"`` as the target column.

    Args:
        batch_size: Forwarded to ``SummarizationData.from_csv``.
        num_workers: Forwarded to ``SummarizationData.from_csv``.
        **input_transform_kwargs: Extra keyword arguments forwarded unchanged
            to ``SummarizationData.from_csv``.

    Returns:
        The ``SummarizationData`` built by ``from_csv``.
    """
    xsum_archive = "https://pl-flash-data.s3.amazonaws.com/xsum.zip"
    download_data(xsum_archive, "./data/")

    train_csv = "data/xsum/train.csv"
    valid_csv = "data/xsum/valid.csv"
    datamodule = SummarizationData.from_csv(
        "input",
        "target",
        train_file=train_csv,
        val_file=valid_csv,
        batch_size=batch_size,
        num_workers=num_workers,
        **input_transform_kwargs,
    )
    return datamodule
示例#5
0
def test_from_files(tmpdir):
    """Use one CSV for train/val/test and check the raw text in the val and test batches.

    In the first batch of the validation loader, then of the test loader, the
    leading ``DataKeys.INPUT`` and ``DataKeys.TARGET`` items must be plain
    strings.
    """
    shared_csv = csv_data(tmpdir)
    datamodule = SummarizationData.from_csv(
        "input",
        "target",
        train_file=shared_csv,
        val_file=shared_csv,
        test_file=shared_csv,
        batch_size=1,
    )
    # Loaders are created lazily, in the order val -> test.
    for make_loader in (datamodule.val_dataloader, datamodule.test_dataloader):
        first_batch = next(iter(make_loader()))
        for field in (DataKeys.INPUT, DataKeys.TARGET):
            assert isinstance(first_batch[field][0], str)
示例#6
0
def from_xsum(
    backbone: str = "sshleifer/distilbart-xsum-1-1",
    batch_size: int = 4,
    num_workers: int = 0,
    **preprocess_kwargs,
) -> SummarizationData:
    """Fetch the XSum archive and return a ``SummarizationData`` over its CSV splits.

    The archive is downloaded into ``./data/``. The train and validation splits
    are read from ``data/xsum/train.csv`` and ``data/xsum/valid.csv``, with
    ``"input"`` as the input column and ``"target"`` as the target column.

    Args:
        backbone: Backbone name forwarded to ``SummarizationData.from_csv``.
        batch_size: Forwarded to ``SummarizationData.from_csv``.
        num_workers: Forwarded to ``SummarizationData.from_csv``.
        **preprocess_kwargs: Extra keyword arguments forwarded unchanged to
            ``SummarizationData.from_csv``.

    Returns:
        The ``SummarizationData`` built by ``from_csv``.
    """
    xsum_archive = "https://pl-flash-data.s3.amazonaws.com/xsum.zip"
    download_data(xsum_archive, "./data/")

    train_csv = "data/xsum/train.csv"
    valid_csv = "data/xsum/valid.csv"
    datamodule = SummarizationData.from_csv(
        "input",
        "target",
        train_file=train_csv,
        val_file=valid_csv,
        backbone=backbone,
        batch_size=batch_size,
        num_workers=num_workers,
        **preprocess_kwargs,
    )
    return datamodule
示例#7
0
    """
    Camilla bought a box of mangoes with a Brixton £10 note, introduced last year to try to keep the money of local
    people within the community.The couple were surrounded by shoppers as they walked along Electric Avenue.
    They came to Brixton to see work which has started to revitalise the borough.
    It was Charles' first visit to the area since 1996, when he was accompanied by the former
    South African president Nelson Mandela.Greengrocer Derek Chong, who has run a stall on Electric Avenue
    for 20 years, said Camilla had been ""nice and pleasant"" when she purchased the fruit.
    ""She asked me what was nice, what would I recommend, and I said we've got some nice mangoes.
    She asked me were they ripe and I said yes - they're from the Dominican Republic.""
    Mr Chong is one of 170 local retailers who accept the Brixton Pound.
    Customers exchange traditional pound coins for Brixton Pounds and then spend them at the market
    or in participating shops.
    During the visit, Prince Charles spent time talking to youth worker Marcus West, who works with children
    nearby on an estate off Coldharbour Lane. Mr West said:
    ""He's on the level, really down-to-earth. They were very cheery. The prince is a lovely man.""
    He added: ""I told him I was working with young kids and he said, 'Keep up all the good work.'""
    Prince Charles also visited the Railway Hotel, at the invitation of his charity The Prince's Regeneration Trust.
    The trust hopes to restore and refurbish the building,
    where once Jimi Hendrix and The Clash played, as a new community and business centre."
    """
])
# Print the predictions returned by the inline-text ``model.predict`` call above.
print(predictions)

# 2b. Or generate summaries from a CSV file!
# Only the input column ("input") is named here; no target column is passed,
# since this datamodule is only used for prediction.
datamodule = SummarizationData.from_csv(
    "input",
    predict_file="data/xsum/predict.csv",
)
# A new Trainer with default settings runs inference with ``model``
# (``model`` is defined earlier in the script).
predictions = Trainer().predict(model, datamodule=datamodule)
print(predictions)
示例#8
0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from flash import Trainer
from flash.core.data.utils import download_data
from flash.text import SummarizationData, SummarizationTask

# 1. Create the DataModule
# Download the XSum archive into ./data/. The CSV paths below assume it unpacks
# to data/xsum/ (NOTE(review): confirm the archive layout).
download_data("https://pl-flash-data.s3.amazonaws.com/xsum.zip", "./data/")

# "input" names the input column and "target" the target column of each CSV.
datamodule = SummarizationData.from_csv(
    "input",
    "target",
    train_file="data/xsum/train.csv",
    val_file="data/xsum/valid.csv",
)

# 2. Build the task
# No arguments are passed, so SummarizationTask's own defaults (including its
# default backbone) are used.
model = SummarizationTask()

# 3. Create the trainer and finetune the model
trainer = Trainer(max_epochs=3)
trainer.finetune(model, datamodule=datamodule)

# 4. Summarize some text!
predictions = model.predict("""
    Camilla bought a box of mangoes with a Brixton £10 note, introduced last year to try to keep the money of local
    people within the community.The couple were surrounded by shoppers as they walked along Electric Avenue.
    They came to Brixton to see work which has started to revitalise the borough.
示例#9
0
    parser.add_argument('--test_file', type=str, default="data/xsum/test.csv")
    parser.add_argument('--max_epochs', type=int, default=1)
    parser.add_argument('--learning_rate', type=float, default=1e-3)
    parser.add_argument('--gpus', type=int, default=None)
    args = parser.parse_args()

    # 1. Download the data
    if args.download:
        download_data("https://pl-flash-data.s3.amazonaws.com/xsum.zip",
                      "data/")

    # 2. Load the data
    datamodule = SummarizationData.from_csv(
        "input",
        "target",
        train_file=args.train_file,
        val_file=args.valid_file,
        test_file=args.test_file,
    )

    # 3. Build the model
    model = SummarizationTask(backbone=args.backbone,
                              learning_rate=args.learning_rate)

    # 4. Create the trainer. Run once on data
    trainer = Trainer(gpus=args.gpus,
                      max_epochs=args.max_epochs,
                      fast_dev_run=True)

    # 5. Fine-tune the model
    trainer.finetune(model, datamodule=datamodule)