Example #1
import os
import unittest

from transformers.integrations import is_fairscale_available
from transformers.testing_utils import (
    ExtendSysPath,
    TestCasePlus,
    execute_subprocess_async,
    get_gpu_count,
    require_torch_gpu,
    require_torch_multi_gpu,
    require_torch_non_multi_gpu,
    slow,
)
from transformers.trainer_callback import TrainerState
from transformers.trainer_utils import set_seed

bindir = os.path.abspath(os.path.dirname(__file__))
with ExtendSysPath(f"{bindir}/../../examples/seq2seq"):
    from run_translation import main  # noqa

set_seed(42)
MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"
MBART_TINY = "sshleifer/tiny-mbart"


# a candidate for testing_utils
def require_fairscale(test_case):
    """
    Decorator marking a test that requires fairscale
    """
    if not is_fairscale_available():
        return unittest.skip("test requires fairscale")(test_case)
    else:
        return test_case
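

# A minimal usage sketch, not part of the original snippet: the decorator is
# applied like the other require_* markers from testing_utils. The test class
# and method names here are hypothetical.
@require_fairscale
@require_torch_gpu
class FairScaleSmokeTest(TestCasePlus):
    def test_fairscale_is_importable(self):
        # Runs only when fairscale is installed and a GPU is present.
        self.assertTrue(is_fairscale_available())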

Example #2
import os

from parameterized import parameterized
from transformers import is_torch_available
from transformers.testing_utils import (
    ExtendSysPath,
    TestCasePlus,
    execute_subprocess_async,
    get_gpu_count,
    require_deepspeed,
    require_torch_gpu,
    slow,
)
from transformers.trainer_utils import set_seed

tests_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
root_dir = os.path.dirname(tests_dir)
with ExtendSysPath(tests_dir):
    from test_trainer import TrainerIntegrationCommon  # noqa

    if is_torch_available():
        from test_trainer import RegressionModelConfig, RegressionPreTrainedModel, get_regression_trainer  # noqa

set_seed(42)

# default torch.distributed port
DEFAULT_MASTER_PORT = "10999"

# translation
FSMT_TINY = "stas/tiny-wmt19-en-de"
BART_TINY = "sshleifer/bart-tiny-random"
T5_SMALL = "t5-small"
T5_TINY = "patrickvonplaten/t5-tiny-random"
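

# A hedged sketch, not in the original snippet: a helper like this is commonly
# used to let an environment variable override DEFAULT_MASTER_PORT so that
# concurrent CI jobs don't collide on the same torch.distributed port. The
# DS_TEST_PORT variable name is an assumption.
def get_master_port():
    return os.environ.get("DS_TEST_PORT", DEFAULT_MASTER_PORT)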

Example #3
import os
import unittest

from transformers.integrations import is_fairscale_available
from transformers.testing_utils import (
    ExtendSysPath,
    TestCasePlus,
    execute_subprocess_async,
    get_gpu_count,
    get_torch_dist_unique_port,
    require_torch,
    require_torch_gpu,
    require_torch_multi_gpu,
    require_torch_non_multi_gpu,
    slow,
)
from transformers.trainer_callback import TrainerState
from transformers.trainer_utils import set_seed

bindir = os.path.abspath(os.path.dirname(__file__))
with ExtendSysPath(f"{bindir}/../../examples/pytorch/translation"):
    from run_translation import main  # noqa

set_seed(42)
MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"
MBART_TINY = "sshleifer/tiny-mbart"


# a candidate for testing_utils
def require_fairscale(test_case):
    """
    Decorator marking a test that requires fairscale
    """
    if not is_fairscale_available():
        return unittest.skip("test requires fairscale")(test_case)
    else:
        return test_case

Example #4
import json
import os

from transformers import is_torch_available
from transformers.testing_utils import (
    CaptureStderr,
    ExtendSysPath,
    LoggingLevel,
    TestCasePlus,
    execute_subprocess_async,
    get_gpu_count,
    mockenv_context,
    require_deepspeed,
    require_torch_gpu,
    require_torch_multi_gpu,
    slow,
)
from transformers.trainer_utils import set_seed

bindir = os.path.abspath(os.path.dirname(__file__))
with ExtendSysPath(f"{bindir}/.."):
    from test_trainer import TrainerIntegrationCommon  # noqa

    if is_torch_available():
        from test_trainer import RegressionModelConfig, RegressionPreTrainedModel, get_regression_trainer  # noqa

set_seed(42)
MBART_TINY = "sshleifer/tiny-mbart"
T5_SMALL = "t5-small"
T5_TINY = "patrickvonplaten/t5-tiny-random"


def load_json(path):
    with open(path) as f:
        return json.load(f)
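

# A minimal usage sketch, not from the original snippet: in these trainer
# tests a helper like load_json is typically used to read back the
# trainer_state.json that Trainer saves inside each checkpoint directory.
# The checkpoint path below is hypothetical.
state = load_json("output_dir/checkpoint-10/trainer_state.json")
last_log = state["log_history"][-1]  # most recent logged metrics entry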