# Example 1
 def _set_global_seed(self, seed):
     """Set a global eager mode seed for random ops."""
     self._seed = seed
     self._rng = random.Random(self._seed)
     # Also clear the kernel cache, to reset any existing seeds
     if self._context_handle is not None:
         pywrap_tensorflow.TFE_ContextClearCaches(self._context_handle)
# Example 2
def monkey_patch_tf_get_seed(seed: int,
                             default_op_seed: int = 1923746) -> None:
    """
    Monkey patch ``tensorflow.random.get_seed`` to avoid the increasing memory usage arising
    from repeated random sampling from tensorflow distributions.

    This code is taken from https://github.com/lerobitaille/tf-issue-36164-workaround which
    remedies issue 36164 (https://github.com/tensorflow/tensorflow/issues/36164).

    We have raised our own clearer and more concise issue, which should be the reference
    point for this memory leak: https://github.com/tensorflow/tensorflow/issues/37252

    :param seed: Seed to set as the TensorFlow global seed.
    :param default_op_seed: Default seed for any random operations if required.
    """
    warn(
        "WARNING: Patching native TensorFlow functionality to avoid memory leak when setting "
        "a random seed.")
    warn("WARNING: Patch required due to TensorFlow issue 37252. "
         "Check if the issue is resolved at "
         "https://github.com/tensorflow/tensorflow/issues/37252")
    # Lazy imports to show which imports to remove once the issue is resolved and to avoid wider
    # usage of monkey patching and usage of the TensorFlow back end which involves imports the
    # linter does not like.
    # pylint: disable=no-name-in-module,import-error
    from tensorflow.python.eager import context
    from tensorflow.python import pywrap_tensorflow
    from tensorflow.python.framework import random_seed
    # Remove gorilla dependency completely when issue fixed. (Remove from requirements.txt)
    import gorilla

    def better_get_seed(global_seed, op_seed):
        # Return a deterministic (global_seed, op_seed) pair without touching
        # TensorFlow's per-op seed bookkeeping — the state that leaks memory.
        if op_seed is not None:
            return global_seed, op_seed
        else:
            return global_seed, default_op_seed

    # Monkey Patch get_seed.
    # BUG FIX: the replacement must *return* the seed pair. Previously the
    # result of better_get_seed was discarded, so the patched
    # random_seed.get_seed() returned None for every op.
    def func(op_seed):
        return better_get_seed(seed, op_seed)

    settings = gorilla.Settings(allow_hit=True, store_hit=True)
    patch = gorilla.Patch(random_seed, 'get_seed', func, settings=settings)
    gorilla.apply(patch)

    # Also clear the kernel cache, to reset any existing seeds
    # pylint: disable=protected-access
    _context = context.context()
    if _context._context_handle is not None:
        pywrap_tensorflow.TFE_ContextClearCaches(_context._context_handle)
# Example 3
    def __init__(self,
                 configuration_name: str,
                 max_seq_len: int,
                 cache_dir: typing.Optional[str] = None,
                 seed: typing.Optional[int] = None,
                 max_task_examples: float = 2e21,
                 temperature: float = 2.,
                 dynamic_mixing: bool = True):
        """Load the config and tokenizer for *configuration_name* and set up task mixing.

        :param configuration_name: Name or path passed to the ``transformers`` auto classes.
        :param max_seq_len: Maximum sequence length used when encoding.
        :param cache_dir: Optional cache directory (stored on the instance).
        :param seed: Optional global seed. When given, seeds NumPy and patches
            TensorFlow's ``get_seed`` to work around the memory leak in
            TF issue 36164 (https://github.com/tensorflow/tensorflow/issues/36164).
        :param max_task_examples: Cap on per-task example counts used for mixing.
        :param temperature: Task-mixing temperature.
        :param dynamic_mixing: Whether mixing rates are computed dynamically.
        """
        self.max_seq_len = max_seq_len
        self.cache_dir = cache_dir

        # Task mixing constants
        self.max_examples = max_task_examples
        self.temperature = temperature
        self.dynamic_mixing = dynamic_mixing

        # BUG FIX: explicit None check so that seed=0 — a valid seed — is not
        # silently ignored by truthiness.
        if seed is not None:
            logging.debug('Setting seed to %d', seed)
            self.seed = seed
            np.random.seed(seed)

            # Alternate get_seed, see https://github.com/lerobitaille/tf-issue-36164-workaround
            def _patched_get_seed(op_seed):
                if op_seed is not None:
                    return seed, op_seed
                else:
                    return seed, _DEFAULT_OP_SEED

            # Monkey patch get_seed from tf.random_seed
            patch_settings = gorilla.Settings(allow_hit=True, store_hit=True)
            seed_patch = gorilla.Patch(random_seed,
                                       'get_seed',
                                       _patched_get_seed,
                                       settings=patch_settings)
            gorilla.apply(seed_patch)

        # Also clear the kernel cache, to reset any existing seeds
        _context = tf_eager_context.context()
        # noinspection PyProtectedMember
        if _context._context_handle is not None:
            # noinspection PyProtectedMember
            pywrap_tensorflow.TFE_ContextClearCaches(_context._context_handle)

        # BUG FIX: the '%s' placeholder previously had no argument, so the
        # literal format string was logged instead of the configuration name.
        logging.debug('Loading configuration from %s...', configuration_name)
        self.config: transformers.PretrainedConfig = transformers.AutoConfig.from_pretrained(
            configuration_name)

        logging.debug('Loading tokenizer from %s...', configuration_name)
        self.tokenizer: transformers.PreTrainedTokenizer = \
            transformers.AutoTokenizer.from_pretrained(configuration_name, config=self.config)

        self.encoder_fn = functools.partial(self.tokenizer.encode_plus,
                                            add_special_tokens=False,
                                            add_space_before_punct_symbol=True,
                                            max_length=self.max_seq_len,
                                            pad_to_max_length=True,
                                            truncation_strategy="only_first",
                                            return_token_type_ids=True,
                                            return_attention_mask=True)

        self.decoder_fn = functools.partial(self.tokenizer.decode,
                                            skip_special_tokens=True,
                                            clean_up_tokenization_spaces=False)