def require_fairscale(test_case):
    """Skip the decorated test unless fairscale is available.

    Args:
        test_case: The test function or class being decorated.

    Returns:
        ``test_case`` unchanged when ``is_fairscale_available()`` is true;
        otherwise the result of wrapping it with
        ``unittest.skip("test requires fairscale")``.
    """
    if is_fairscale_available():
        return test_case
    # Build the skip decorator only on the path that actually needs it.
    skip_without_fairscale = unittest.skip("test requires fairscale")
    return skip_without_fairscale(test_case)
from transformers.models.fsmt.configuration_fsmt import FSMTConfig
from transformers.optimization import (
    Adafactor,
    AdamW,
    get_constant_schedule,
    get_constant_schedule_with_warmup,
    get_cosine_schedule_with_warmup,
    get_cosine_with_hard_restarts_schedule_with_warmup,
    get_linear_schedule_with_warmup,
    get_polynomial_decay_schedule_with_warmup,
)
from transformers.trainer_pt_utils import get_tpu_sampler
from transformers.training_args import ParallelMode
from transformers.utils import is_torch_tpu_available


# NOTE(review): `is_fairscale_available`, `logging` and `Trainer` are used below
# but are not imported in the visible portion of the file -- presumably they are
# imported earlier; confirm.

# fairscale is an optional dependency: OSS is only imported when
# is_fairscale_available() reports that the package can be used.
if is_fairscale_available():
    from fairscale.optim import OSS


# Module-level logger named after this module.
logger = logging.get_logger(__name__)

# Lookup table from a scheduler name (string) to the schedule function of
# transformers.optimization imported above.
# NOTE(review): presumably the keys are the accepted values of a user-facing
# learning-rate-scheduler option consumed by Seq2SeqTrainer -- confirm against
# the class body.
arg_to_scheduler = {
    "linear": get_linear_schedule_with_warmup,
    "cosine": get_cosine_schedule_with_warmup,
    "cosine_w_restarts": get_cosine_with_hard_restarts_schedule_with_warmup,
    "polynomial": get_polynomial_decay_schedule_with_warmup,
    "constant": get_constant_schedule,
    "constant_w_warmup": get_constant_schedule_with_warmup,
}


class Seq2SeqTrainer(Trainer):