コード例 #1
0
    def _create_non_trackable_mask_cache(self):
        """Create the caches for the dropout and recurrent dropout masks.

        Two `backend.ContextValueCache` objects are stored on the instance:
        `_dropout_mask_cache`, whose values come from
        `self._create_dropout_mask`, and `_recurrent_dropout_mask_cache`,
        whose values come from `self._create_recurrent_dropout_mask`.

        These masks are used in "graph function" mode, i.e. they are symbolic
        tensors there. In eager mode the `eager_*_mask` tensors are generated
        differently than in the "graph function" case, and they are cached
        as well.

        In graph mode the masks are still cached, solely because the RNN may
        be created with `unroll=True`. In that case `cell.call()` is invoked
        multiple times, and the same mask must be used on every invocation.

        The caches are meant to be created without tracking: Python cannot
        pickle them during a deepcopy, so `layer._obj_reference_counts_dict`
        should not track them by default. (Nothing in this body disables
        tracking itself; presumably a decorator on this method does. TODO
        confirm.)
        """
        cache_cls = backend.ContextValueCache
        self._dropout_mask_cache = cache_cls(self._create_dropout_mask)
        self._recurrent_dropout_mask_cache = cache_cls(
            self._create_recurrent_dropout_mask)
コード例 #2
0
 def __setstate__(self, state):
     """Restore instance state, installing fresh dropout-mask caches.

     New `backend.ContextValueCache` objects are written into `state` under
     the `_dropout_mask_cache` and `_recurrent_dropout_mask_cache` keys. The
     caches wrap `self._create_dropout_mask` and
     `self._create_recurrent_dropout_mask` respectively. `state` is then
     handed to the parent class's `__setstate__`.

     NOTE(review): presumably the caches are dropped from the pickled state
     because they cannot be pickled. This is why they are rebuilt here
     rather than restored. Confirm against the matching `__getstate__`.

     Args:
       state: the state mapping being restored. It is mutated in place.
     """
     cache_cls = backend.ContextValueCache
     # Build both caches first, then insert them in the original key order.
     dropout_cache = cache_cls(self._create_dropout_mask)
     recurrent_cache = cache_cls(self._create_recurrent_dropout_mask)
     state['_dropout_mask_cache'] = dropout_cache
     state['_recurrent_dropout_mask_cache'] = recurrent_cache
     super(DropoutRNNCellMixin, self).__setstate__(state)
コード例 #3
0
 def __setstate__(self, state):
     """Restore instance state with newly built dropout-mask caches.

     Before delegating to the parent class's `__setstate__`, this stores
     fresh `backend.ContextValueCache` objects in `state`:
     `"_dropout_mask_cache"` wraps `self._create_dropout_mask`, and
     `"_recurrent_dropout_mask_cache"` wraps
     `self._create_recurrent_dropout_mask`.

     NOTE(review): this assumes the caches are excluded from the saved state
     because they are not picklable. Verify against `__getstate__`.

     Args:
       state: the state mapping being restored. It is modified in place.
     """
     new_cache = backend.ContextValueCache
     state["_dropout_mask_cache"] = new_cache(self._create_dropout_mask)
     state["_recurrent_dropout_mask_cache"] = new_cache(
         self._create_recurrent_dropout_mask
     )
     super().__setstate__(state)