def _set_global_seed(self, seed):
    """Set a global eager mode seed for random ops."""
    self._seed = seed
    self._rng = random.Random(self._seed)
    # Also clear the kernel cache, to reset any existing seeds
    if self._context_handle is not None:
        pywrap_tensorflow.TFE_ContextClearCaches(self._context_handle)
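
# A minimal sketch of the pattern that triggers the leak this section works
# around (see https://github.com/tensorflow/tensorflow/issues/36164): with a
# global seed set and no per-op seed, each eager random op draws a fresh op
# seed via get_seed, so the eager kernel cache grows on every iteration. The
# loop bound and tensor shape below are illustrative assumptions, not taken
# from the original code.
import tensorflow as tf

tf.random.set_seed(0)
for _ in range(10_000):
    # Each call gets a new op-level seed, adding another kernel-cache entry
    # in affected TensorFlow releases; memory usage climbs steadily.
    sample = tf.random.normal([16])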

def monkey_patch_tf_get_seed(seed: int, default_op_seed: int = 1923746) -> None:
    """
    Monkey patch tensorflow.random.get_seed to avoid the increasing memory usage
    arising from repeated random sampling from tensorflow distributions.

    This code is taken from https://github.com/lerobitaille/tf-issue-36164-workaround
    which remedies issue 36164 (https://github.com/tensorflow/tensorflow/issues/36164).
    We have raised our own clearer, more concise issue, which should be the
    reference point for this memory leak:
    https://github.com/tensorflow/tensorflow/issues/37252

    :param seed: Seed to set as the TensorFlow global seed.
    :param default_op_seed: Default seed for any random operations if required.
    """
    warn(
        "WARNING: Patching native TensorFlow functionality to avoid memory leak when setting "
        "a random seed.")
    warn("WARNING: Patch required due to TensorFlow issue 37252. "
         "Check if the issue is resolved at "
         "https://github.com/tensorflow/tensorflow/issues/37252")

    # Lazy imports to show which imports to remove once the issue is resolved, and to
    # avoid wider use of monkey patching and of the TensorFlow back end, whose imports
    # the linter does not like.
    # pylint: disable=no-name-in-module,import-error
    from tensorflow.python.eager import context
    from tensorflow.python import pywrap_tensorflow
    from tensorflow.python.framework import random_seed
    # Remove the gorilla dependency completely when the issue is fixed.
    # (Remove from requirements.txt as well.)
    import gorilla

    def better_get_seed(global_seed, op_seed):
        if op_seed is not None:
            return global_seed, op_seed
        else:
            return global_seed, default_op_seed

    # Monkey patch get_seed. Note the return: without it the patched get_seed
    # would return None instead of a (global_seed, op_seed) tuple.
    def func(op_seed):
        return better_get_seed(seed, op_seed)

    settings = gorilla.Settings(allow_hit=True, store_hit=True)
    patch = gorilla.Patch(random_seed, 'get_seed', func, settings=settings)
    gorilla.apply(patch)

    # Also clear the kernel cache, to reset any existing seeds
    # pylint: disable=protected-access
    _context = context.context()
    if _context._context_handle is not None:
        pywrap_tensorflow.TFE_ContextClearCaches(_context._context_handle)
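
# A minimal usage sketch for the patch above. The module name `tf_seed_patch`
# is an assumption for illustration; import from wherever the function lives.
# Call it once, before any sampling: subsequent random ops reuse
# default_op_seed instead of accumulating fresh kernel-cache entries.
import tensorflow as tf
from tf_seed_patch import monkey_patch_tf_get_seed  # hypothetical module path

monkey_patch_tf_get_seed(seed=42)
for _ in range(10_000):
    sample = tf.random.normal([16])  # no longer grows the eager kernel cache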

def __init__(self,
             configuration_name: str,
             max_seq_len: int,
             cache_dir: typing.Optional[str] = None,
             seed: typing.Optional[int] = None,
             max_task_examples: float = 2e21,
             temperature: float = 2.,
             dynamic_mixing: bool = True):
    self.max_seq_len = max_seq_len
    self.cache_dir = cache_dir

    # Task mixing constants
    self.max_examples = max_task_examples
    self.temperature = temperature
    self.dynamic_mixing = dynamic_mixing

    if seed:
        logging.debug('Setting seed to %d', seed)
        self.seed = seed
        np.random.seed(seed)

        # Alternate get_seed, see https://github.com/lerobitaille/tf-issue-36164-workaround
        def _patched_get_seed(op_seed):
            if op_seed is not None:
                return seed, op_seed
            else:
                return seed, _DEFAULT_OP_SEED

        # Monkey patch get_seed from tf.random_seed
        patch_settings = gorilla.Settings(allow_hit=True, store_hit=True)
        seed_patch = gorilla.Patch(random_seed, 'get_seed', _patched_get_seed,
                                   settings=patch_settings)
        gorilla.apply(seed_patch)

        # Also clear the kernel cache, to reset any existing seeds
        _context = tf_eager_context.context()
        # noinspection PyProtectedMember
        if _context._context_handle is not None:
            # noinspection PyProtectedMember
            pywrap_tensorflow.TFE_ContextClearCaches(_context._context_handle)

    logging.debug('Loading configuration from %s...', configuration_name)
    self.config: transformers.PretrainedConfig = transformers.AutoConfig.from_pretrained(
        configuration_name)

    logging.debug('Loading tokenizer from %s...', configuration_name)
    self.tokenizer: transformers.PreTrainedTokenizer = \
        transformers.AutoTokenizer.from_pretrained(configuration_name, config=self.config)
    # if not self.tokenizer.pad_token:
    #     self.tokenizer.pad_token = self.tokenizer.eos_token
    #     self.config.pad_token_id = self.tokenizer.pad_token_id
    #     logging.warning('Tokenizer does not provide a pad token, using %s (id: %d)',
    #                     self.tokenizer.pad_token, self.tokenizer.pad_token_id)

    self.encoder_fn = functools.partial(self.tokenizer.encode_plus,
                                        add_special_tokens=False,
                                        add_space_before_punct_symbol=True,
                                        max_length=self.max_seq_len,
                                        pad_to_max_length=True,
                                        truncation_strategy="only_first",
                                        return_token_type_ids=True,
                                        return_attention_mask=True)
    self.decoder_fn = functools.partial(self.tokenizer.decode,
                                        skip_special_tokens=True,
                                        clean_up_tokenization_spaces=False)
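
# A brief usage sketch for the encoder/decoder partials above. The class name
# `TaskMixer` and the model name are assumptions for illustration; the
# encode_plus keyword arguments in __init__ (pad_to_max_length,
# truncation_strategy) also assume a transformers release contemporary with
# this code, as later versions replaced them with padding/truncation.
mixer = TaskMixer('gpt2', max_seq_len=128)  # hypothetical class and model
encoded = mixer.encoder_fn('Monkey patching TensorFlow seeds.')
# encode_plus returns a dict with 'input_ids', 'token_type_ids' and
# 'attention_mask', each padded or truncated to max_seq_len.
text = mixer.decoder_fn(encoded['input_ids'])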