示例#1
0
def test_from_json(tmpdir):
    """``from_files`` on a JSON train file yields batches with ``labels`` and ``input_ids``."""
    data_file = json_data(tmpdir)
    datamodule = TranslationData.from_files(
        backbone=TEST_BACKBONE,
        train_file=data_file,
        input="input",
        target="target",
        filetype="json",
        batch_size=1,
    )
    first_batch = next(iter(datamodule.train_dataloader()))
    # Both keys must be present in the very first training batch.
    for expected_key in ("labels", "input_ids"):
        assert expected_key in first_batch
示例#2
0
def test_from_csv(tmpdir):
    """``from_csv`` with an explicit backbone yields batches with ``labels`` and ``input_ids``."""
    train_csv = csv_data(tmpdir)
    datamodule = TranslationData.from_csv(
        "input",
        "target",
        backbone=TEST_BACKBONE,
        train_file=train_csv,
        batch_size=1,
    )
    first_batch = next(iter(datamodule.train_dataloader()))
    for expected_key in ("labels", "input_ids"):
        assert expected_key in first_batch
示例#3
0
def test_from_csv(tmpdir):
    """``from_csv`` without a backbone yields raw string inputs and targets."""
    train_csv = csv_data(tmpdir)
    datamodule = TranslationData.from_csv("input", "target", train_file=train_csv, batch_size=1)
    first_batch = next(iter(datamodule.train_dataloader()))
    # The first sample of each field must still be a plain ``str``.
    for field_key in (DataKeys.INPUT, DataKeys.TARGET):
        assert isinstance(first_batch[field_key][0], str)
示例#4
0
def test_from_json_with_field(tmpdir):
    """``from_json`` with ``field="data"`` yields raw string inputs and targets."""
    nested_json = json_data_with_field(tmpdir)
    datamodule = TranslationData.from_json(
        "input",
        "target",
        train_file=nested_json,
        batch_size=1,
        field="data",
    )
    first_batch = next(iter(datamodule.train_dataloader()))
    for field_key in (DataKeys.INPUT, DataKeys.TARGET):
        assert isinstance(first_batch[field_key][0], str)
示例#5
0
def test_from_csv(tmpdir):
    """``from_files`` on a CSV train file yields batches with ``labels`` and ``input_ids``.

    Skipped on Windows (``os.name == "nt"``), per the TODO below.
    """
    if os.name == "nt":
        # TODO: huggingface stuff timing out on windows
        # Raise SkipTest so pytest reports this as *skipped*. Returning a value
        # from a test makes it count as passed and triggers pytest's
        # PytestReturnNotNoneWarning. pytest translates unittest.SkipTest into a skip.
        import unittest

        raise unittest.SkipTest("huggingface stuff timing out on windows")
    csv_path = csv_data(tmpdir)
    dm = TranslationData.from_files(backbone=TEST_BACKBONE,
                                    train_file=csv_path,
                                    input="input",
                                    target="target",
                                    batch_size=1)
    batch = next(iter(dm.train_dataloader()))
    assert "labels" in batch
    assert "input_ids" in batch
示例#6
0
def test_from_json(tmpdir):
    """``from_json`` with source/target language codes yields ``labels`` and ``input_ids``."""
    data_file = json_data(tmpdir)
    datamodule = TranslationData.from_json(
        "input",
        "target",
        backbone=TEST_BACKBONE,
        train_file=data_file,
        batch_size=1,
        src_lang="en_XX",
        tgt_lang="ro_RO",
    )
    first_batch = next(iter(datamodule.train_dataloader()))
    for expected_key in ("labels", "input_ids"):
        assert expected_key in first_batch
示例#7
0
def test_from_files(tmpdir):
    """Validation and test loaders built by ``from_files`` both yield ``labels`` and ``input_ids``."""
    shared_csv = csv_data(tmpdir)
    datamodule = TranslationData.from_files(
        backbone=TEST_BACKBONE,
        train_file=shared_csv,
        valid_file=shared_csv,
        test_file=shared_csv,
        input="input",
        target="target",
        batch_size=1,
    )
    # Check the validation loader first, then the test loader (same order as before).
    for stage in ("val", "test"):
        sample = next(iter(getattr(datamodule, stage + "_dataloader")()))
        assert "labels" in sample
        assert "input_ids" in sample
示例#8
0
def from_wmt_en_ro(
    batch_size: int = 4,
    num_workers: int = 0,
    **input_transform_kwargs,
) -> TranslationData:
    """Download the WMT EN-RO archive and build a ``TranslationData`` from its CSV splits.

    Args:
        batch_size: forwarded to ``TranslationData.from_csv``.
        num_workers: forwarded to ``TranslationData.from_csv``.
        **input_transform_kwargs: extra keyword arguments, forwarded unchanged
            to ``TranslationData.from_csv``.

    Returns:
        The datamodule built from ``data/wmt_en_ro/train.csv`` (training) and
        ``data/wmt_en_ro/valid.csv`` (validation).
    """
    download_data("https://pl-flash-data.s3.amazonaws.com/wmt_en_ro.zip", "./data")
    split_files = {
        "train_file": "data/wmt_en_ro/train.csv",
        "val_file": "data/wmt_en_ro/valid.csv",
    }
    loader_options = {"batch_size": batch_size, "num_workers": num_workers}
    return TranslationData.from_csv(
        "input",
        "target",
        **split_files,
        **loader_options,
        **input_transform_kwargs,
    )
示例#9
0
def from_wmt_en_ro(
    backbone: str = "Helsinki-NLP/opus-mt-en-ro",
    batch_size: int = 4,
    num_workers: int = 0,
    **preprocess_kwargs,
) -> TranslationData:
    """Download the WMT EN-RO archive and build a ``TranslationData`` for ``backbone``.

    Args:
        backbone: backbone name forwarded to ``TranslationData.from_csv``.
        batch_size: forwarded to ``TranslationData.from_csv``.
        num_workers: forwarded to ``TranslationData.from_csv``.
        **preprocess_kwargs: extra keyword arguments, forwarded unchanged to
            ``TranslationData.from_csv``.

    Returns:
        The datamodule built from ``data/wmt_en_ro/train.csv`` (training) and
        ``data/wmt_en_ro/valid.csv`` (validation).
    """
    download_data("https://pl-flash-data.s3.amazonaws.com/wmt_en_ro.zip", "./data")
    split_files = {
        "train_file": "data/wmt_en_ro/train.csv",
        "val_file": "data/wmt_en_ro/valid.csv",
    }
    model_and_loader_options = {
        "backbone": backbone,
        "batch_size": batch_size,
        "num_workers": num_workers,
    }
    return TranslationData.from_csv(
        "input",
        "target",
        **split_files,
        **model_and_loader_options,
        **preprocess_kwargs,
    )
示例#10
0
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning import Trainer

from flash import download_data
from flash.text import TranslationData, TranslationTask

# 1. Fetch the WMT EN-RO archive into ``data/``.
download_data("https://pl-flash-data.s3.amazonaws.com/wmt_en_ro.zip", "data/")

# 2. Restore a TranslationTask from the published checkpoint URL.
model = TranslationTask.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/translation_model_en_ro.pt")

# 2a. Translate an in-memory list of sentences.
sentences = [
    "BBC News went to meet one of the project's first graduates.",
    "A recession has come as quickly as 11 months after the first rate hike and as long as 86 months.",
]
predictions = model.predict(sentences)
print(predictions)

# 2b. Or predict from a CSV file via a datamodule and a Trainer.
datamodule = TranslationData.from_file(predict_file="data/wmt_en_ro/predict.csv", input="input")
predictions = Trainer().predict(model, datamodule=datamodule)
print(predictions)
示例#11
0
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

import flash
from flash import download_data
from flash.text import TranslationData, TranslationTask

# 1. Fetch the WMT EN-RO archive into ``data/``.
download_data("https://pl-flash-data.s3.amazonaws.com/wmt_en_ro.zip", "data/")

# 2. Build the datamodule from the train / validation / test CSV splits.
datamodule = TranslationData.from_files(train_file="data/wmt_en_ro/train.csv",
                                        val_file="data/wmt_en_ro/valid.csv",
                                        test_file="data/wmt_en_ro/test.csv",
                                        input="input",
                                        target="target",
                                        batch_size=1)

# 3. Build the model with its default arguments.
model = TranslationTask()

# 4. Create the trainer. fast_dev_run=True keeps this a quick smoke run;
#    gpus is 1 when CUDA is available and 0 otherwise.
trainer = flash.Trainer(precision=32,
                        gpus=int(torch.cuda.is_available()),
                        fast_dev_run=True)

# 5. Fine-tune the model
trainer.finetune(model, datamodule=datamodule)

# 6. Test model
# Evaluate on the test split loaded above. ``model`` is passed explicitly so
# the in-memory fine-tuned weights are tested rather than a "best" checkpoint.
# Lightning's fast_dev_run disables checkpoint callbacks, so no such
# checkpoint would exist.
trainer.test(model, datamodule=datamodule)
示例#12
0
import torch

import flash
from flash.core.data.utils import download_data
from flash.text import TranslationData, TranslationTask

# 1. Fetch the WMT EN-RO archive into ``data/``.
download_data("https://pl-flash-data.s3.amazonaws.com/wmt_en_ro.zip", "data/")

# Shared by the datamodule and the task so both are configured with the same backbone.
backbone = "Helsinki-NLP/opus-mt-en-ro"

# 2. Build the datamodule from the train / validation / test CSV splits.
datamodule = TranslationData.from_csv(
    "input",
    "target",
    train_file="data/wmt_en_ro/train.csv",
    val_file="data/wmt_en_ro/valid.csv",
    test_file="data/wmt_en_ro/test.csv",
    batch_size=1,
    backbone=backbone,
)

# 3. Build the model on the same backbone as the datamodule.
model = TranslationTask(backbone=backbone)

# 4. Create the trainer. 16-bit precision and one GPU are used only when CUDA
#    is available; CPU runs stay at 32-bit with no GPUs. fast_dev_run=True keeps
#    this a quick smoke run. ``torch`` must be imported at module level.
trainer = flash.Trainer(
    precision=16 if torch.cuda.is_available() else 32,
    gpus=int(torch.cuda.is_available()),
    fast_dev_run=True,
)
示例#13
0
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

import flash
from flash.core.data.utils import download_data
from flash.text import TranslationData, TranslationTask

# 1. Create the DataModule
download_data("https://pl-flash-data.s3.amazonaws.com/wmt_en_ro.zip", "./data")

datamodule = TranslationData.from_csv(
    "input",
    "target",
    train_file="data/wmt_en_ro/train.csv",
    val_file="data/wmt_en_ro/valid.csv",
    backbone="Helsinki-NLP/opus-mt-en-ro",
)

# 2. Build the task
model = TranslationTask(backbone="Helsinki-NLP/opus-mt-en-ro")

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule)

# 4. Translate something!
predictions = model.predict([
    "BBC News went to meet one of the project's first graduates.",
    "A recession has come as quickly as 11 months after the first rate hike and as long as 86 months.",