import os

from flash.text import TranslationData


def test_from_json(tmpdir):
    json_path = json_data(tmpdir)
    dm = TranslationData.from_files(
        backbone=TEST_BACKBONE,
        train_file=json_path,
        input="input",
        target="target",
        filetype="json",
        batch_size=1,
    )
    batch = next(iter(dm.train_dataloader()))
    assert "labels" in batch
    assert "input_ids" in batch
def test_from_csv(tmpdir):
    if os.name == "nt":
        # TODO: huggingface stuff timing out on windows
        return True
    csv_path = csv_data(tmpdir)
    dm = TranslationData.from_files(
        backbone=TEST_BACKBONE,
        train_file=csv_path,
        input="input",
        target="target",
        batch_size=1,
    )
    batch = next(iter(dm.train_dataloader()))
    assert "labels" in batch
    assert "input_ids" in batch
def test_from_files(tmpdir):
    csv_path = csv_data(tmpdir)
    dm = TranslationData.from_files(
        backbone=TEST_BACKBONE,
        train_file=csv_path,
        valid_file=csv_path,
        test_file=csv_path,
        input="input",
        target="target",
        batch_size=1,
    )
    batch = next(iter(dm.val_dataloader()))
    assert "labels" in batch
    assert "input_ids" in batch
    batch = next(iter(dm.test_dataloader()))
    assert "labels" in batch
    assert "input_ids" in batch
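# The fixtures and constants referenced above (TEST_BACKBONE, csv_data, json_data)
# are defined elsewhere in the test module. Below is a minimal sketch of what they
# might look like. It is hypothetical: it assumes only that TEST_BACKBONE names a
# small seq2seq checkpoint and that the helpers write tiny files containing the
# "input" and "target" fields the tests read.
import json
from pathlib import Path

TEST_BACKBONE = "sshleifer/tiny-mbart"  # assumption: any small translation checkpoint works

TEST_CSV_DATA = """input,target
this is a sentence one,this is a translated sentence one
this is a sentence two,this is a translated sentence two
this is a sentence three,this is a translated sentence three
"""


def csv_data(tmpdir):
    # Write a three-row CSV with the "input" and "target" columns the tests expect.
    path = Path(tmpdir) / "data.csv"
    path.write_text(TEST_CSV_DATA)
    return str(path)


def json_data(tmpdir):
    # Write the same rows as JSON lines, matching filetype="json" in test_from_json.
    path = Path(tmpdir) / "data.json"
    rows = [
        {"input": "this is a sentence one", "target": "this is a translated sentence one"},
        {"input": "this is a sentence two", "target": "this is a translated sentence two"},
        {"input": "this is a sentence three", "target": "this is a translated sentence three"},
    ]
    path.write_text("\n".join(json.dumps(row) for row in rows))
    return str(path)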
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

import flash
from flash import download_data
from flash.text import TranslationData, TranslationTask

# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/wmt_en_ro.zip", "data/")

# 2. Load the data
datamodule = TranslationData.from_files(
    train_file="data/wmt_en_ro/train.csv",
    val_file="data/wmt_en_ro/valid.csv",
    test_file="data/wmt_en_ro/test.csv",
    input="input",
    target="target",
    batch_size=1,
)

# 3. Build the model
model = TranslationTask()

# 4. Create the trainer
trainer = flash.Trainer(precision=32, gpus=int(torch.cuda.is_available()), fast_dev_run=True)

# 5. Fine-tune the model
trainer.finetune(model, datamodule=datamodule)

# 6. Test model
trainer.test()
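# A typical follow-up to the example above (not shown in the snippet) is to persist
# and reload the fine-tuned weights. The calls below use standard Lightning APIs
# (Trainer.save_checkpoint / LightningModule.load_from_checkpoint); the file name
# is only illustrative.
trainer.save_checkpoint("translation_model_en_ro.pt")
loaded_model = TranslationTask.load_from_checkpoint("translation_model_en_ro.pt")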