Пример #1
0
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import unittest
from typing import List

from pytext.data.sources.data_source import SafeFileWrapper
from pytext.data.sources.tsv import (
    BlockShardedTSVDataSource,
    SessionTSVDataSource,
    TSVDataSource,
)
from pytext.utils.test import import_tests_module

# Handle to pytext's shared tests module; below it is used only through
# `tests_module.test_file(...)`, presumably to resolve fixture file paths.
tests_module = import_tests_module()


class TSVDataSourceTest(unittest.TestCase):
    """Test case for ``TSVDataSource`` over the tiny dense-features TSV fixtures."""

    def setUp(self):
        """Build ``self.data`` from the train/test dense-feature fixture files.

        The train and test fixtures are each wrapped in ``SafeFileWrapper``;
        no eval file is supplied. ``field_names`` names four columns
        (``label``, ``slots``, ``text``, ``dense``), while the schema declares
        types only for ``text`` and ``label`` (both ``str``).
        """
        # Wrappers are created train-first, then test, matching the order in
        # which they are passed positionally to TSVDataSource.
        train_source = SafeFileWrapper(
            tests_module.test_file("train_dense_features_tiny.tsv")
        )
        test_source = SafeFileWrapper(
            tests_module.test_file("test_dense_features_tiny.tsv")
        )
        columns = ["label", "slots", "text", "dense"]
        column_types = {"text": str, "label": str}
        self.data = TSVDataSource(
            train_source,
            test_source,
            eval_file=None,
            field_names=columns,
            schema=column_types,
        )
Пример #2
0
import itertools
import tempfile
import unittest

import torch
from pytext.config import LATEST_VERSION, PyTextConfig
from pytext.data import Data
from pytext.data.sources import TSVDataSource
from pytext.task import create_task
from pytext.task.serialize import load, save
from pytext.task.tasks import DocumentClassificationTask
from pytext.utils import test


# Handle to pytext's shared tests module; below it is used through
# `tests_module.test_file(...)`, presumably to resolve fixture file paths.
tests_module = test.import_tests_module()


class TaskLoadSaveTest(unittest.TestCase):
    def assertModulesEqual(self, mod1, mod2, message=None):
        """Assert that two modules have pairwise-equal parameters, in order.

        Parameters are compared in the iteration order of ``parameters()``
        using each parameter's ``equal`` method. A differing number of
        parameters is reported as a test failure instead of surfacing as an
        ``AttributeError``/``TypeError`` from comparing against ``None``.

        Args:
            mod1: first module; must expose ``parameters()``.
            mod2: second module; must expose ``parameters()``.
            message: optional failure message forwarded to the assertions.
        """
        for p1, p2 in itertools.zip_longest(mod1.parameters(), mod2.parameters()):
            # zip_longest pads the shorter iterator with None, so a None here
            # means the modules have different parameter counts.
            if p1 is None or p2 is None:
                self.fail(message or "modules have different numbers of parameters")
            self.assertTrue(p1.equal(p2), message)

    def test_load_saved_model(self):
        with tempfile.NamedTemporaryFile() as snapshot_file:
            train_data = tests_module.test_file("train_data_tiny.tsv")
            eval_data = tests_module.test_file("test_data_tiny.tsv")
            config = PyTextConfig(
                task=DocumentClassificationTask.Config(
                    data=Data.Config(
                        source=TSVDataSource.Config(