def test_from_json(tmpdir):
    """Build a ``TranslationData`` module from a JSON file and inspect one batch.

    The training file is the path returned by ``json_data(tmpdir)`` and is
    loaded with ``filetype="json"``. The first batch drawn from the training
    dataloader must contain both the ``"labels"`` and ``"input_ids"`` keys.
    """
    source_file = json_data(tmpdir)
    datamodule = TranslationData.from_files(
        backbone=TEST_BACKBONE,
        train_file=source_file,
        input="input",
        target="target",
        filetype="json",
        batch_size=1,
    )
    train_loader = datamodule.train_dataloader()
    first_batch = next(iter(train_loader))
    # Checked in the same order as before: labels first, then input_ids.
    for expected_key in ("labels", "input_ids"):
        assert expected_key in first_batch
# Example #2
# 0
def test_from_csv(tmpdir):
    """Build a ``TranslationData`` module from ``csv_data(tmpdir)`` and inspect one batch.

    No ``filetype`` is passed, so ``from_files`` falls back to its own default.
    The first batch drawn from the training dataloader must contain both the
    ``"labels"`` and ``"input_ids"`` keys.

    On Windows (``os.name == "nt"``) the body is not executed at all.
    """
    if os.name == "nt":
        # TODO: huggingface stuff timing out on windows
        # Bare ``return`` on purpose: pytest flags test functions that return
        # a non-None value, so the early exit must not return ``True``.
        return
    csv_path = csv_data(tmpdir)
    dm = TranslationData.from_files(
        backbone=TEST_BACKBONE,
        train_file=csv_path,
        input="input",
        target="target",
        batch_size=1,
    )
    batch = next(iter(dm.train_dataloader()))
    assert "labels" in batch
    assert "input_ids" in batch
# Example #3
# 0
def test_from_files(tmpdir):
    """Check the validation and test dataloaders of a ``TranslationData`` module.

    The single path returned by ``csv_data(tmpdir)`` is passed as the train,
    validation and test file. The first batch from the validation dataloader,
    and then the first batch from the test dataloader, must each contain the
    ``"labels"`` and ``"input_ids"`` keys.
    """
    shared_csv = csv_data(tmpdir)
    datamodule = TranslationData.from_files(
        backbone=TEST_BACKBONE,
        train_file=shared_csv,
        valid_file=shared_csv,
        test_file=shared_csv,
        input="input",
        target="target",
        batch_size=1,
    )
    # Validation first, then test; each loader is looked up and built only
    # when its turn comes, so the order of calls is unchanged.
    for loader_name in ("val_dataloader", "test_dataloader"):
        loader = getattr(datamodule, loader_name)()
        first_batch = next(iter(loader))
        assert "labels" in first_batch
        assert "input_ids" in first_batch
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

import flash
from flash import download_data
from flash.text import TranslationData, TranslationTask

# 1. Fetch the dataset archive, passing "data/" as the destination
download_data("https://pl-flash-data.s3.amazonaws.com/wmt_en_ro.zip", "data/")

# 2. Wrap the train / validation / test CSV files in a datamodule
# NOTE(review): the validation keyword used here is `val_file`; confirm it
# matches the `from_files` signature of the installed flash version.
datamodule = TranslationData.from_files(
    train_file="data/wmt_en_ro/train.csv",
    val_file="data/wmt_en_ro/valid.csv",
    test_file="data/wmt_en_ro/test.csv",
    input="input",
    target="target",
    batch_size=1,
)

# 3. Instantiate the translation task with its default arguments
model = TranslationTask()

# 4. Configure the trainer: `gpus` is 1 when CUDA is available, otherwise 0
trainer = flash.Trainer(
    precision=32,
    gpus=int(torch.cuda.is_available()),
    fast_dev_run=True,
)

# 5. Fine-tune the task on the datamodule
trainer.finetune(model, datamodule=datamodule)

# 6. Test model