def __init__(
    self,
    units,
    activation="tanh",
    recurrent_activation="sigmoid",
    use_bias=True,
    kernel_initializer="glorot_uniform",
    recurrent_initializer="orthogonal",
    bias_initializer="zeros",
    kernel_regularizer=None,
    recurrent_regularizer=None,
    bias_regularizer=None,
    activity_regularizer=None,
    kernel_constraint=None,
    recurrent_constraint=None,
    bias_constraint=None,
    dropout=0.0,
    recurrent_dropout=0.0,
    return_sequences=False,
    return_state=False,
    go_backwards=False,
    stateful=False,
    unroll=False,
    time_major=False,
    reset_after=True,
    **kwargs,
):
    """Build the underlying `GRUCell` and configure the wrapping layer.

    Cell-level options (`units`, the activations, initializers,
    regularizers, constraints, `use_bias`, `dropout`, `recurrent_dropout`
    and `reset_after`) are forwarded to `GRUCell`. Sequence-level options
    (`return_sequences`, `return_state`, `go_backwards`, `stateful`,
    `unroll`, `time_major`) and whatever remains in `kwargs` are forwarded
    to the parent constructor. `activity_regularizer` is resolved through
    `regularizers.get` and stored on this layer.

    Extra keyword arguments consumed here (popped from `kwargs`):
        return_runtime: Testing flag, stored as `self._return_runtime`.
            Defaults to `False`.
        implementation: Passed to `GRUCell`. Defaults to 2. A value of 0
            logs a deprecation warning; the value itself is still handed
            to the cell unchanged.
        enable_caching_device: Passed to `GRUCell` only when supplied.

    `dtype` and `trainable` are read from `kwargs` without being removed,
    so they reach both the cell and the parent constructor.
    """
    # return_runtime is a flag for testing, which shows the real backend
    # implementation chosen by grappler in graph mode.
    self._return_runtime = kwargs.pop("return_runtime", False)

    implementation = kwargs.pop("implementation", 2)
    if implementation == 0:
        # The literals below are concatenated; each piece that is followed
        # by another must end with a space so the sentences stay separated.
        logging.warning(
            "`implementation=0` has been deprecated, "
            "and now defaults to `implementation=2`. "
            "Please update your layer call."
        )

    # Only forward `enable_caching_device` when the caller provided it, so
    # the cell keeps its own default otherwise.
    cell_kwargs = {}
    if "enable_caching_device" in kwargs:
        cell_kwargs["enable_caching_device"] = kwargs.pop(
            "enable_caching_device"
        )

    cell = GRUCell(
        units,
        activation=activation,
        recurrent_activation=recurrent_activation,
        use_bias=use_bias,
        kernel_initializer=kernel_initializer,
        recurrent_initializer=recurrent_initializer,
        bias_initializer=bias_initializer,
        kernel_regularizer=kernel_regularizer,
        recurrent_regularizer=recurrent_regularizer,
        bias_regularizer=bias_regularizer,
        kernel_constraint=kernel_constraint,
        recurrent_constraint=recurrent_constraint,
        bias_constraint=bias_constraint,
        dropout=dropout,
        recurrent_dropout=recurrent_dropout,
        implementation=implementation,
        reset_after=reset_after,
        dtype=kwargs.get("dtype"),
        trainable=kwargs.get("trainable", True),
        **cell_kwargs,
    )
    super().__init__(
        cell,
        return_sequences=return_sequences,
        return_state=return_state,
        go_backwards=go_backwards,
        stateful=stateful,
        unroll=unroll,
        time_major=time_major,
        **kwargs,
    )
    self.activity_regularizer = regularizers.get(activity_regularizer)
    self.input_spec = [InputSpec(ndim=3)]

    # GPU kernel uses following setting by default and not configurable.
    # The chain short-circuits, so the eager-mode query only runs when every
    # preceding condition already holds.
    self._could_use_gpu_kernel = (
        self.activation in (activations.tanh, tf.tanh)
        and self.recurrent_activation in (activations.sigmoid, tf.sigmoid)
        and recurrent_dropout == 0
        and not unroll
        and use_bias
        and reset_after
        and tf.compat.v1.executing_eagerly_outside_functions()
    )
    if tf.config.list_logical_devices("GPU"):
        # Only show the message when there is GPU available, user will not
        # care about the cuDNN if there isn't any GPU.
        if self._could_use_gpu_kernel:
            logging.debug(gru_lstm_utils.CUDNN_AVAILABLE_MSG % self.name)
        else:
            logging.warning(
                gru_lstm_utils.CUDNN_NOT_AVAILABLE_MSG % self.name
            )

    if gru_lstm_utils.use_new_gru_lstm_impl():
        self._defun_wrapper = gru_lstm_utils.DefunWrapper(
            time_major, go_backwards, "gru"
        )
def __init__(
    self,
    units,
    activation="tanh",
    recurrent_activation="sigmoid",
    use_bias=True,
    kernel_initializer="glorot_uniform",
    recurrent_initializer="orthogonal",
    bias_initializer="zeros",
    unit_forget_bias=True,
    kernel_regularizer=None,
    recurrent_regularizer=None,
    bias_regularizer=None,
    activity_regularizer=None,
    kernel_constraint=None,
    recurrent_constraint=None,
    bias_constraint=None,
    dropout=0.0,
    recurrent_dropout=0.0,
    return_sequences=False,
    return_state=False,
    go_backwards=False,
    stateful=False,
    time_major=False,
    unroll=False,
    **kwargs,
):
    """Configure the layer by delegating to the parent LSTM constructor.

    Every named argument is forwarded unchanged to the parent
    `__init__`, together with whatever remains in `kwargs`. Two extra
    keyword arguments are consumed here first:
        return_runtime: Testing flag, stored as `self.return_runtime`.
            Defaults to `False`.
        implementation: Forwarded to the parent constructor. Defaults
            to 2.

    After the parent is initialised, this sets `state_spec`, records
    whether the configuration is eligible for the GPU kernel, logs the
    cuDNN availability message when a logical GPU is present, and creates
    the defun wrapper when the new GRU/LSTM implementation is enabled.
    """
    # Testing-only switch: exposes which backend implementation grappler
    # actually selected when running in graph mode.
    self.return_runtime = kwargs.pop("return_runtime", False)
    # Popped before the call so it is not passed twice through **kwargs.
    implementation = kwargs.pop("implementation", 2)

    super(LSTM, self).__init__(
        units,
        activation=activation,
        recurrent_activation=recurrent_activation,
        use_bias=use_bias,
        kernel_initializer=kernel_initializer,
        recurrent_initializer=recurrent_initializer,
        bias_initializer=bias_initializer,
        unit_forget_bias=unit_forget_bias,
        kernel_regularizer=kernel_regularizer,
        recurrent_regularizer=recurrent_regularizer,
        bias_regularizer=bias_regularizer,
        activity_regularizer=activity_regularizer,
        kernel_constraint=kernel_constraint,
        recurrent_constraint=recurrent_constraint,
        bias_constraint=bias_constraint,
        dropout=dropout,
        recurrent_dropout=recurrent_dropout,
        implementation=implementation,
        return_sequences=return_sequences,
        return_state=return_state,
        go_backwards=go_backwards,
        stateful=stateful,
        time_major=time_major,
        unroll=unroll,
        **kwargs,
    )

    # Two state entries, each specified with shape (None, units).
    self.state_spec = [
        InputSpec(shape=(None, state_width))
        for state_width in (self.units, self.units)
    ]

    # Conditions are evaluated left to right and short-circuit, so the
    # eager-mode query only runs when everything before it holds.
    self._could_use_gpu_kernel = (
        self.activation in (activations.tanh, tf.tanh)
        and self.recurrent_activation in (activations.sigmoid, tf.sigmoid)
        and recurrent_dropout == 0
        and not unroll
        and use_bias
        and tf.compat.v1.executing_eagerly_outside_functions()
    )

    if tf.config.list_logical_devices("GPU"):
        # The cuDNN message is only relevant when a GPU is present, so it
        # is not logged otherwise.
        if self._could_use_gpu_kernel:
            emit = logging.debug
            template = gru_lstm_utils.CUDNN_AVAILABLE_MSG
        else:
            emit = logging.warning
            template = gru_lstm_utils.CUDNN_NOT_AVAILABLE_MSG
        emit(template % self.name)

    if gru_lstm_utils.use_new_gru_lstm_impl():
        self._defun_wrapper = gru_lstm_utils.DefunWrapper(
            time_major, go_backwards, "lstm"
        )