#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import unittest
from typing import List

from pytext.data.sources.data_source import SafeFileWrapper
from pytext.data.sources.tsv import (
    BlockShardedTSVDataSource,
    SessionTSVDataSource,
    TSVDataSource,
)
from pytext.utils.test import import_tests_module


tests_module = import_tests_module()


class TSVDataSourceTest(unittest.TestCase):
    """Tests for ``TSVDataSource`` using the tiny dense-features fixture files."""

    def setUp(self):
        """Create a fresh ``TSVDataSource`` over the train/test fixtures.

        Columns are named ``label``, ``slots``, ``text`` and ``dense``. Only
        ``text`` and ``label`` appear in the schema, both as ``str``. No eval
        file is supplied.
        """
        # Wrap the fixtures one at a time (train, then test) so the files are
        # opened in the same order as in a single inline call.
        train_source = SafeFileWrapper(
            tests_module.test_file("train_dense_features_tiny.tsv")
        )
        test_source = SafeFileWrapper(
            tests_module.test_file("test_dense_features_tiny.tsv")
        )
        column_names = ["label", "slots", "text", "dense"]
        column_types = {"text": str, "label": str}
        self.data = TSVDataSource(
            train_source,
            test_source,
            eval_file=None,
            field_names=column_names,
            schema=column_types,
        )
import itertools import tempfile import unittest import torch from pytext.config import LATEST_VERSION, PyTextConfig from pytext.data import Data from pytext.data.sources import TSVDataSource from pytext.task import create_task from pytext.task.serialize import load, save from pytext.task.tasks import DocumentClassificationTask from pytext.utils import test tests_module = test.import_tests_module() class TaskLoadSaveTest(unittest.TestCase): def assertModulesEqual(self, mod1, mod2, message=None): for p1, p2 in itertools.zip_longest(mod1.parameters(), mod2.parameters()): self.assertTrue(p1.equal(p2), message) def test_load_saved_model(self): with tempfile.NamedTemporaryFile() as snapshot_file: train_data = tests_module.test_file("train_data_tiny.tsv") eval_data = tests_module.test_file("test_data_tiny.tsv") config = PyTextConfig( task=DocumentClassificationTask.Config( data=Data.Config( source=TSVDataSource.Config(