def test_from_csv(tmpdir):
    """Build a ``TranslationData`` from one CSV and inspect its first train batch.

    The batch is expected to expose both a ``"labels"`` and an ``"input_ids"``
    entry (checked in that order).
    """
    path = csv_data(tmpdir)
    datamodule = TranslationData.from_csv(
        "input",
        "target",
        backbone=TEST_BACKBONE,
        train_file=path,
        batch_size=1,
    )
    first_batch = next(iter(datamodule.train_dataloader()))
    for expected_key in ("labels", "input_ids"):
        assert expected_key in first_batch
def test_from_csv(tmpdir):
    """Load train data from a CSV (no backbone given) and check the batch holds strings.

    The first element under both ``DataKeys.INPUT`` and ``DataKeys.TARGET`` of
    the first train batch must be a ``str``.
    """
    path = csv_data(tmpdir)
    datamodule = TranslationData.from_csv("input", "target", train_file=path, batch_size=1)
    first_batch = next(iter(datamodule.train_dataloader()))
    for field in (DataKeys.INPUT, DataKeys.TARGET):
        assert isinstance(first_batch[field][0], str)
def from_wmt_en_ro(
    batch_size: int = 4,
    num_workers: int = 0,
    **input_transform_kwargs,
) -> TranslationData:
    """Downloads and loads the WMT EN RO data set.

    The archive is fetched into ``./data``. Its train and validation CSVs are
    then handed to ``TranslationData.from_csv``; no test split is loaded.

    Args:
        batch_size: Forwarded to ``TranslationData.from_csv``.
        num_workers: Forwarded to ``TranslationData.from_csv``.
        **input_transform_kwargs: Extra keyword arguments forwarded unchanged
            to ``TranslationData.from_csv``.

    Returns:
        Whatever ``TranslationData.from_csv`` returns for these files.
    """
    download_data("https://pl-flash-data.s3.amazonaws.com/wmt_en_ro.zip", "./data")
    # Paths are relative to the current working directory, matching the
    # "./data" download target above.
    train_csv = "data/wmt_en_ro/train.csv"
    val_csv = "data/wmt_en_ro/valid.csv"
    # NOTE(review): "input"/"target" presumably name the CSV source/target
    # columns -- confirm against TranslationData.from_csv.
    return TranslationData.from_csv(
        "input",
        "target",
        train_file=train_csv,
        val_file=val_csv,
        batch_size=batch_size,
        num_workers=num_workers,
        **input_transform_kwargs,
    )
def from_wmt_en_ro(
    backbone: str = "Helsinki-NLP/opus-mt-en-ro",
    batch_size: int = 4,
    num_workers: int = 0,
    **preprocess_kwargs,
) -> TranslationData:
    """Downloads and loads the WMT EN RO data set.

    The archive is fetched into ``./data``. Its train and validation CSVs are
    then handed to ``TranslationData.from_csv``; no test split is loaded.

    Args:
        backbone: Checkpoint name forwarded to ``from_csv`` as ``backbone``.
        batch_size: Forwarded to ``TranslationData.from_csv``.
        num_workers: Forwarded to ``TranslationData.from_csv``.
        **preprocess_kwargs: Extra keyword arguments forwarded unchanged to
            ``TranslationData.from_csv``.

    Returns:
        Whatever ``TranslationData.from_csv`` returns for these files.
    """
    download_data("https://pl-flash-data.s3.amazonaws.com/wmt_en_ro.zip", "./data")
    # Paths are relative to the current working directory, matching the
    # "./data" download target above.
    train_csv = "data/wmt_en_ro/train.csv"
    val_csv = "data/wmt_en_ro/valid.csv"
    # NOTE(review): "input"/"target" presumably name the CSV source/target
    # columns -- confirm against TranslationData.from_csv.
    return TranslationData.from_csv(
        "input",
        "target",
        train_file=train_csv,
        val_file=val_csv,
        backbone=backbone,
        batch_size=batch_size,
        num_workers=num_workers,
        **preprocess_kwargs,
    )
def test_from_files(tmpdir):
    """Use one CSV for train/val/test and inspect the first val and test batches.

    Each batch must contain ``"labels"`` and ``"input_ids"``. The val loader is
    created and consumed before the test loader is created.
    """
    path = csv_data(tmpdir)
    datamodule = TranslationData.from_csv(
        "input",
        "target",
        backbone=TEST_BACKBONE,
        train_file=path,
        val_file=path,
        test_file=path,
        batch_size=1,
        src_lang="en_XX",
        tgt_lang="ro_RO",
    )
    # Keep the loader factories uncalled here so each loader is built only
    # when its turn comes, preserving the original val-then-test order.
    for make_loader in (datamodule.val_dataloader, datamodule.test_dataloader):
        first_batch = next(iter(make_loader()))
        assert "labels" in first_batch
        assert "input_ids" in first_batch
# `torch` is used below to pick precision/GPU settings, so it must be
# imported explicitly; without this the script raises NameError.
import torch

import flash
from flash.core.data.utils import download_data
from flash.text import TranslationData, TranslationTask

# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/wmt_en_ro.zip", "data/")

# The same checkpoint name is passed to both the datamodule and the task
# below, so they are configured from one pretrained model.
backbone = "Helsinki-NLP/opus-mt-en-ro"

# 2. Load the data
datamodule = TranslationData.from_csv(
    "input",
    "target",
    train_file="data/wmt_en_ro/train.csv",
    val_file="data/wmt_en_ro/valid.csv",
    test_file="data/wmt_en_ro/test.csv",
    batch_size=1,
    backbone=backbone,
)

# 3. Build the model
model = TranslationTask(backbone=backbone)

# 4. Create the trainer
# With CUDA: 16-bit precision on one GPU. Without: 32-bit on CPU (0 GPUs).
# `fast_dev_run=True` makes this a quick debugging run, not full training.
trainer = flash.Trainer(
    precision=16 if torch.cuda.is_available() else 32,
    gpus=int(torch.cuda.is_available()),
    fast_dev_run=True,
)
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import torch import flash from flash.core.data.utils import download_data from flash.text import TranslationData, TranslationTask # 1. Create the DataModule download_data("https://pl-flash-data.s3.amazonaws.com/wmt_en_ro.zip", "./data") datamodule = TranslationData.from_csv( "input", "target", train_file="data/wmt_en_ro/train.csv", val_file="data/wmt_en_ro/valid.csv", backbone="Helsinki-NLP/opus-mt-en-ro", ) # 2. Build the task model = TranslationTask(backbone="Helsinki-NLP/opus-mt-en-ro") # 3. Create the trainer and finetune the model trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count()) trainer.finetune(model, datamodule=datamodule) # 4. Translate something! predictions = model.predict([ "BBC News went to meet one of the project's first graduates.", "A recession has come as quickly as 11 months after the first rate hike and as long as 86 months.",