def __init__(self, output_dim, force_real=False, use_bias=True,
             weight_initializer='xavier_uniform', bias_initializer='zeros',
             weight_regularizer=None, bias_regularizer=None, **kwargs):
  if not np.isscalar(output_dim):
    raise TypeError('!! output_dim must be a scalar, not {}'.format(
        type(output_dim)))
  self._output_dim = output_dim
  self._force_real = force_real
  self._use_bias = use_bias
  self._weight_initializer = initializers.get(weight_initializer)
  self._bias_initializer = initializers.get(bias_initializer)
  self._weight_regularizer = regularizers.get(weight_regularizer, **kwargs)
  self._bias_regularizer = regularizers.get(bias_regularizer, **kwargs)
  self.weights = None
  self.biases = None
  self.neuron_scale = [output_dim]
def __init__(self, state_size, periods=None, activation='tanh',
             use_bias=True, weight_initializer='xavier_uniform',
             bias_initializer='zeros', **kwargs):
  """
  :param state_size: State size
  :param periods: a list of integers. If not provided, periods will be set
                  by default to the exponential series
                  {2^{i-1}}_{i=1}^{state_size}
  """
  # Call parent's constructor
  RNet.__init__(self, ClockworkRNN.net_name)
  # Attributes
  self._state_size = checker.check_positive_integer(state_size)
  self._periods = self._get_periods(periods, **kwargs)
  self._activation = activations.get(activation, **kwargs)
  self._use_bias = checker.check_type(use_bias, bool)
  self._weight_initializer = initializers.get(weight_initializer)
  self._bias_initializer = initializers.get(bias_initializer)
  # modules = [(start_index, size, period)+]
  self._modules = []
  self._init_modules(**kwargs)
def __init__(self, state_size, activation='tanh', use_bias=True,
             weight_initializer='xavier_normal', bias_initializer='zeros',
             **kwargs):
  """
  :param state_size: state size: positive int
  :param activation: activation: string or callable
  :param use_bias: whether to use bias
  :param weight_initializer: weight initializer identifier
  :param bias_initializer: bias initializer identifier
  """
  # Call parent's constructor
  RNet.__init__(self, self.net_name)
  # Attributes
  self._state_size = state_size
  self._activation = activations.get(activation, **kwargs)
  self._use_bias = checker.check_type(use_bias, bool)
  self._weight_initializer = initializers.get(weight_initializer)
  self._bias_initializer = initializers.get(bias_initializer)
  self._output_scale = state_size
def _initial_define_test(self):
  for net in self.trunk_net:
    for var in net.var_list:
      if 'biases' in var.name or 'bias' in var.name:
        var.initializers = initializers.get(tf.zeros_initializer())
      else:
        var.initializers = initializers.get('identity')
def __init__(self, state_size, activation='tanh', use_bias=True,
             weight_initializer='xavier_uniform', bias_initializer='zeros',
             input_gate=True, output_gate=True, forget_gate=True,
             with_peepholes=False, **kwargs):
  """
  :param state_size: state size: positive int
  :param activation: activation: string or callable
  :param use_bias: whether to use bias
  :param weight_initializer: weight initializer identifier
  :param bias_initializer: bias initializer identifier
  """
  # Call parent's constructor
  RNet.__init__(self, BasicLSTMCell.net_name)
  # Attributes
  self._state_size = state_size
  self._activation = activations.get(activation, **kwargs)
  self._use_bias = checker.check_type(use_bias, bool)
  self._weight_initializer = initializers.get(weight_initializer)
  self._bias_initializer = initializers.get(bias_initializer)
  self._input_gate = checker.check_type(input_gate, bool)
  self._output_gate = checker.check_type(output_gate, bool)
  self._forget_gate = checker.check_type(forget_gate, bool)
  self._with_peepholes = checker.check_type(with_peepholes, bool)
  self._output_scale = state_size
def __init__(self, state_size, activation='tanh',
             weight_initializer='xavier_normal', use_bias=True,
             couple_fi=False, cell_bias_initializer='zeros',
             input_bias_initializer='zeros', output_bias_initializer='zeros',
             forget_bias_initializer='zeros', use_output_activation=True,
             **kwargs):
  # Call parent's constructor
  CellBase.__init__(self, activation, weight_initializer, use_bias,
                    cell_bias_initializer, **kwargs)
  # Specific attributes
  self._state_size = checker.check_positive_integer(state_size)
  self._input_bias_initializer = initializers.get(input_bias_initializer)
  self._output_bias_initializer = initializers.get(output_bias_initializer)
  self._forget_bias_initializer = initializers.get(forget_bias_initializer)
  self._couple_fi = checker.check_type(couple_fi, bool)
  self._use_output_activation = checker.check_type(
      use_output_activation, bool)
def _get_sparse_weights(x_dim, y_dim, heads=1, use_bit_max=False,
                        logits_initializer='random_normal',
                        coef_initializer='random_normal',
                        return_package=False):
  logits_initializer = initializers.get(logits_initializer)
  coef_initializer = initializers.get(coef_initializer)
  # Get 3-D variable of shape (x_dim, y_dim, heads)
  if use_bit_max:
    num_bits = int(np.ceil(np.log2(x_dim)))
    logits = tf.get_variable('brick', shape=[num_bits, y_dim, heads],
                             dtype=hub.dtype, initializer=logits_initializer,
                             trainable=True)
    activation = expand_bit(logits, axis=0)
    # Trim if necessary
    if 2**num_bits > x_dim:
      activation = _trim_and_normalize(
          activation, axis=0, dim=x_dim, normalize=True)
  else:
    logits = tf.get_variable('logits', shape=[x_dim, y_dim, heads],
                             dtype=hub.dtype, initializer=logits_initializer,
                             trainable=True)
    activation = tf.nn.softmax(logits, axis=0)
  # Get coef variable (its shape depends on hub.full_weight)
  coef_shape = [x_dim, y_dim, 1] if hub.full_weight else [1, y_dim, heads]
  coef = tf.get_variable('coef', shape=coef_shape, dtype=hub.dtype,
                         initializer=coef_initializer, trainable=True)
  # Calculate weight matrix
  weights = tf.reduce_sum(tf.multiply(coef, activation), axis=-1)
  assert weights.shape.as_list() == [x_dim, y_dim]
  context.sparse_weights_list.append(weights)
  # Return
  if return_package:
    package = {'logits': logits, 'activation': activation, 'coef': coef}
    return weights, package
  else:
    return weights
def masked_neurons(x, num, scope, activation=None, s=None, x_mask=None,
                   s_mask=None, use_bias=True,
                   weight_initializer='glorot_normal',
                   bias_initializer='zeros', **kwargs):
  # Sanity check
  assert isinstance(x, tf.Tensor)
  # Get activation and initializers
  if activation is not None: activation = activations.get(activation)
  weight_initializer = initializers.get(weight_initializer)
  bias_initializer = initializers.get(bias_initializer)

  def matmul(x, y):
    batch_matmul = len(x.shape) == len(y.shape) - 1
    if batch_matmul: x = tf.expand_dims(x, axis=1)
    assert len(x.shape) == len(y.shape)
    output = tf.matmul(x, y)
    if batch_matmul: output = tf.squeeze(output, axis=1)
    return output

  def get_weights(tensor, name, mask=None):
    shape = [get_dimension(tensor), num]
    if mask is None: return get_variable(name, shape, weight_initializer)
    else: return get_masked_weights(name, shape, weight_initializer, mask)

  def forward():
    # x -> y
    Wx = get_weights(x, 'Wx', x_mask)
    # .. do matrix multiplication
    net_y = matmul(x, Wx)
    # s -> y if s exists
    if s is not None:
      assert isinstance(s, tf.Tensor)
      Ws = get_weights(s, 'Ws', s_mask)
      # .. add s * Ws to net_y
      net_y = tf.add(net_y, matmul(s, Ws))
    # Add bias if necessary
    if use_bias:
      b = get_bias('bias', num, bias_initializer)
      net_y = tf.nn.bias_add(net_y, b)
    # Activate if necessary
    if activation is not None: return activation(net_y)
    else: return net_y

  with tf.variable_scope(scope):
    y = forward()
  # Return
  return y
def layer_normalize(x, axis=-1, epsilon=1e-3, center=True, scale=True,
                    beta_initializer='zeros', gamma_initializer='ones'):
  """Layer normalization for a single axis"""
  # Check axis
  x_shape = x.shape.as_list()
  ndims = len(x_shape)
  assert isinstance(axis, int)
  if axis < 0: axis += ndims
  assert 0 <= axis < ndims
  # Get gamma and beta
  gamma, beta = None, None
  param_shape = [x_shape[axis]]
  if scale:
    gamma = tf.get_variable(
        name='gamma', shape=param_shape, dtype=hub.dtype,
        initializer=initializers.get(gamma_initializer), trainable=True)
  if center:
    beta = tf.get_variable(
        name='beta', shape=param_shape, dtype=hub.dtype,
        initializer=initializers.get(beta_initializer), trainable=True)
  # Calculate the moments on the given axis (layer activations)
  mean, variance = tf.nn.moments(x, axis, keep_dims=True)
  # Broadcast gamma and beta
  broadcast_shape = [1] * ndims
  broadcast_shape[axis] = x_shape[axis]
  def _broadcast(v):
    if v is not None and len(v.shape) != ndims and axis != ndims - 1:
      return tf.reshape(v, broadcast_shape)
    return v
  scale, offset = _broadcast(gamma), _broadcast(beta)
  # Compute layer normalization using the batch_normalization function
  return tf.nn.batch_normalization(x, mean, variance, offset=offset,
                                   scale=scale, variance_epsilon=epsilon)
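# Usage sketch (illustrative, not part of the original source): applying the
# layer_normalize routine above to a rank-2 activation tensor inside a
# TF1-style graph. The scope name 'ln_demo' and the helper name
# `_layer_norm_demo` are made up for this example.
def _layer_norm_demo(x):
  # x: a tf.Tensor of shape [batch, features]
  with tf.variable_scope('ln_demo'):
    # Normalize over the feature axis with learnable gamma and beta
    return layer_normalize(x, axis=-1, center=True, scale=True)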
def __init__(self, state_size, mem_fc=True, **kwargs):
  # Call parent's constructor
  RNet.__init__(self, self.net_name)
  # Attributes
  self._state_size = state_size
  self._activation = activations.get('tanh', **kwargs)
  # self._use_bias = True
  self._weight_initializer = initializers.get('xavier_normal')
  self._bias_initializer = initializers.get('zeros')
  self._output_scale = state_size
  self._fully_connect_memories = mem_fc
def _initial_define(self):
  for i, net in enumerate(self.children):
    if i == 0 or i == len(self.children) - 1: continue
    if net.is_branch is True:
      # Branches are defined in the training step
      continue
    else:
      for f in net.children:
        if isinstance(f, Linear):
          f._weight_initializer = initializers.get('identity')
          f._bias_initializer = initializers.get(tf.zeros_initializer())
def _identity_define(self):
  for i, net in enumerate(self.children):
    if i == 0: continue
    else:
      for var in net.var_list:
        var.initializers = initializers.get(tf.zeros_initializer())
def __init__(self, kernel_key, num_neurons, input_, suffix,
             weight_initializer='glorot_normal', prune_frac=0, LN=False,
             gain_initializer='ones', etch=None, weight_dropout=0.0, **kwargs):
  # Call parent's initializer
  super().__init__(kernel_key, num_neurons, weight_initializer, prune_frac,
                   etch=etch, weight_dropout=weight_dropout, **kwargs)
  self.input_ = checker.check_type(input_, tf.Tensor)
  self.suffix = checker.check_type(suffix, str)
  self.LN = checker.check_type(LN, bool)
  self.gain_initializer = initializers.get(gain_initializer)
def sparse_affine(x, y_dim, heads=1, use_bit_max=False,
                  logits_initializer='random_normal',
                  coef_initializer='random_normal', use_bias=True,
                  bias_initializer='zeros', return_package=False):
  """This method should be used inside a variable scope"""
  bias_initializer = initializers.get(bias_initializer)
  # Sanity check
  assert isinstance(x, tf.Tensor) and len(x.shape) == 2
  x_dim = get_dimension(x)
  # Get sparse weights
  weights, package = _get_sparse_weights(
      x_dim, y_dim, heads, use_bit_max, logits_initializer,
      coef_initializer, True)
  assert weights.shape.as_list() == [x_dim, y_dim]
  # Calculate y
  y = tf.matmul(x, weights)
  # Add bias only when required (bias_add fails on a None bias)
  if use_bias:
    bias = get_bias('bias', y_dim, bias_initializer)
    y = tf.nn.bias_add(y, bias)
  # Return
  if return_package:
    package['weights'] = weights
    return y, package
  else:
    return y
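# Usage sketch (illustrative, not part of the original source): a sparse
# affine transform with two heads, wrapped in the variable scope that the
# docstring of sparse_affine requires. The names below are made up for this
# example.
def _sparse_affine_demo(x, y_dim):
  # x: a rank-2 tf.Tensor of shape [batch, x_dim]
  with tf.variable_scope('sparse_affine_demo'):
    return sparse_affine(x, y_dim, heads=2, use_bit_max=True, use_bias=True)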
def __init__(self, state_size, use_reset_gate=True, activation='tanh',
             weight_initializer='xavier_normal', use_bias=True,
             bias_initializer='zeros', z_bias_initializer='zeros',
             reset_who='s', dropout=0.0, zoneout=0.0, **kwargs):
  r"""
  :param reset_who: in ('s', 'a')
    's': a_h = W_h * (s_{t-1} \odot r_t)
    'a': a_h = r_t \odot (W_h * s_{t-1})
    \hat{h}_t = \varphi(W_x * x_t + a_h + b)
    in which r_t is the reset gate at time step t, \odot is the Hadamard
    product, and W_h is the hidden-to-hidden matrix
  """
  # Call parent's constructor
  CellBase.__init__(self, activation, weight_initializer, use_bias,
                    bias_initializer, **kwargs)
  # Specific attributes
  self._state_size = checker.check_positive_integer(state_size)
  self._use_reset_gate = checker.check_type(use_reset_gate, bool)
  self._z_bias_initializer = initializers.get(z_bias_initializer)
  self._dropout_rate = checker.check_type(dropout, float)
  self._zoneout_rate = checker.check_type(zoneout, float)
  assert reset_who in ('s', 'a')
  self._reset_who = reset_who
def layer_normalization(a, gain_initializer, use_bias=False):
  from tframe.operators.apis.neurobase import NeuroBase
  assert not use_bias
  gain_initializer = initializers.get(gain_initializer)
  return NeuroBase.layer_normalize(a, axis=1, center=False,
                                   gamma_initializer=gain_initializer)
def hyper16(self, seed, weight_initializer):
  shape = [linker.get_dimension(seed), self.num_neurons]
  weight_initializer = initializers.get(weight_initializer)
  Wzb = self._get_weights('Wzb', shape, initializer=weight_initializer)
  bias = seed @ Wzb
  b0 = tf.get_variable('bias', shape=[self.num_neurons], dtype=hub.dtype,
                       initializer=self.initializer)
  return tf.nn.bias_add(bias, b0)
def __init__(self, state_size, activation='tanh',
             weight_initializer='xavier_normal', input_gate=True,
             forget_gate=True, output_gate=True,
             use_g_bias=True, g_bias_initializer='zeros',
             use_i_bias=True, i_bias_initializer='zeros',
             use_f_bias=True, f_bias_initializer='zeros',
             use_o_bias=True, o_bias_initializer='zeros',
             output_as_mem=True, fully_connect_memory=True,
             activate_memory=True, truncate_grad=False, **kwargs):
  # Call parent's constructor
  RNet.__init__(self, self.net_name)
  # Attributes
  self._state_size = state_size
  self._input_gate = checker.check_type(input_gate, bool)
  self._forget_gate = checker.check_type(forget_gate, bool)
  self._output_gate = checker.check_type(output_gate, bool)
  self._activation = activations.get(activation, **kwargs)
  self._weight_initializer = initializers.get(weight_initializer)
  self._use_g_bias = checker.check_type(use_g_bias, bool)
  self._g_bias_initializer = initializers.get(g_bias_initializer)
  self._use_i_bias = checker.check_type(use_i_bias, bool)
  self._i_bias_initializer = initializers.get(i_bias_initializer)
  self._use_f_bias = checker.check_type(use_f_bias, bool)
  self._f_bias_initializer = initializers.get(f_bias_initializer)
  self._use_o_bias = checker.check_type(use_o_bias, bool)
  self._o_bias_initializer = initializers.get(o_bias_initializer)
  self._activate_mem = checker.check_type(activate_memory, bool)
  self._truncate_grad = checker.check_type(truncate_grad, bool)
  self._fc_memory = checker.check_type(fully_connect_memory, bool)
  self._output_as_mem = checker.check_type(output_as_mem, bool)
  self._kwargs = kwargs
def __init__(self, output_dim, neurons_per_amu=3, activation='tanh',
             use_bias=True, weight_initializer='xavier_uniform',
             bias_initializer='zeros', **kwargs):
  # Call parent's constructor
  RNet.__init__(self, AMU.net_name)
  # Attributes
  self._output_dim = output_dim
  self._neurons_per_amu = neurons_per_amu
  self._activation = activations.get(activation, **kwargs)
  self._use_bias = checker.check_type(use_bias, bool)
  self._weight_initializer = initializers.get(weight_initializer)
  self._bias_initializer = initializers.get(bias_initializer)
  self._output_scale = output_dim
def __init__(self, num_neurons, heads=1, use_bit_max=False,
             logits_initializer='random_normal',
             coef_initializer='random_normal', use_bias=True,
             bias_initializer='zeros', **kwargs):
  self.num_neurons = checker.check_positive_integer(num_neurons)
  self.heads = checker.check_positive_integer(heads)
  self.use_bit_max = checker.check_type(use_bit_max, bool)
  self._logits_initializer = initializers.get(logits_initializer)
  self._coef_initializer = initializers.get(coef_initializer)
  self._use_bias = checker.check_type(use_bias, bool)
  self._bias_initializer = initializers.get(bias_initializer)
  self.neuron_scale = [self.num_neurons]
  self._kwargs = kwargs
def __init__(self, vocab_size, hidden_size, initializer='default'):
  # Keep probability is initialized while linking so that the placeholder
  # is created in the right name scope
  # self._keep_prob = None
  self._vocab_size = vocab_size
  self._hidden_size = hidden_size
  if initializer == 'default':
    initializer = tf.random_uniform_initializer(-0.1, 0.1)
  self._initializer = initializers.get(initializer)
  self.neuron_scale = [hidden_size]
def __init__(self, activation=None, weight_initializer='xavier_normal',
             use_bias=False, bias_initializer='zeros',
             layer_normalization=False, weight_dropout=0.0, **kwargs):
  if activation: activation = activations.get(activation)
  self._activation = activation
  self._weight_initializer = initializers.get(weight_initializer)
  self._use_bias = checker.check_type(use_bias, bool)
  self._bias_initializer = initializers.get(bias_initializer)
  self._layer_normalization = checker.check_type(layer_normalization, bool)
  self._gain_initializer = initializers.get(
      kwargs.get('gain_initializer', 'ones'))
  self._normalize_each_psi = kwargs.pop('normalize_each_psi', False)
  self._weight_dropout = checker.check_type(weight_dropout, float)
  assert 0 <= self._weight_dropout < 1
  self._nb_kwargs = kwargs
def _get_weights(self, name, shape, dtype=None, initializer=None):
  if initializer is None: initializer = self.initializer
  else: initializer = initializers.get(initializer)
  # Set default dtype if not specified
  if dtype is None: dtype = hub.dtype
  # Get regularizer if necessary
  regularizer = None
  if hub.use_global_regularizer: regularizer = hub.get_global_regularizer()
  # Get constraint if necessary
  constraint = hub.get_global_constraint()
  # Get weights
  weights = tf.get_variable(name, shape, dtype=dtype, initializer=initializer,
                            regularizer=regularizer, constraint=constraint)
  # If weight dropout is positive, apply dropout and return
  if self.weight_dropout > 0:
    return linker.dropout(weights, self.weight_dropout, rescale=True)
  # If no mask needs to be created, return the weight variable directly
  if not any([self.prune_is_on, self.being_etched, hub.force_to_use_pruner]):
    return weights
  # Register; context.pruner should be created early in model.build
  assert context.pruner is not None
  # Lottery logic has been merged into etch logic
  if self.prune_is_on:
    assert not self.being_etched
    self.etch = 'lottery:prune_frac={}'.format(self.prune_frac)
  # Register etch kernel to pruner
  masked_weights = context.pruner.register_to_dense(weights, self.etch)
  # if self.prune_is_on:
  #   masked_weights = context.pruner.register_to_dense(
  #       weights, self.prune_frac)
  # else:
  #   # TODO
  #   assert self.being_etched
  #   mask = self._get_etched_surface(weights)
  #   masked_weights = context.pruner.register_with_mask(weights, mask)
  # Return
  assert isinstance(masked_weights, tf.Tensor)
  return masked_weights
def __init__(self, layer_width, num_layers, activation='relu', use_bias=True,
             weight_initializer='xavier_normal', bias_initializer='zeros',
             t_bias_initializer=-1, **kwargs):
  # Call parent's constructor
  LayerWithNeurons.__init__(self, activation, weight_initializer, use_bias,
                            bias_initializer, **kwargs)
  self._layer_width = checker.check_positive_integer(layer_width)
  self._num_layers = checker.check_positive_integer(num_layers)
  assert isinstance(activation, str)
  self._activation_string = activation
  self._t_bias_initializer = initializers.get(t_bias_initializer)
def __init__(self, kernel_key, num_neurons, initializer, prune_frac=0,
             etch=None, weight_dropout=0.0, **kwargs):
  self.kernel_key = checker.check_type(kernel_key, str)
  self.kernel = self._get_kernel(kernel_key)
  self.num_neurons = checker.check_positive_integer(num_neurons)
  self.initializer = initializers.get(initializer)
  assert 0 <= prune_frac <= 1
  # IMPORTANT
  self.prune_frac = prune_frac * hub.pruning_rate_fc
  self.etch = etch
  self.weight_dropout = checker.check_type(weight_dropout, float)
  assert 0 <= self.weight_dropout < 1
  self.kwargs = kwargs
  self._check_arguments()
def neurons(num, external_input, activation=None, memory=None,
            fc_memory=True, scope=None, use_bias=True, truncate=False,
            num_or_size_splits=None, weight_initializer='glorot_uniform',
            bias_initializer='zeros', weight_regularizer=None,
            bias_regularizer=None, activity_regularizer=None, **kwargs):
  """Analogous to tf.keras.layers.Dense"""
  # Get activation, initializers and regularizers
  if activation is not None: activation = activations.get(activation)
  weight_initializer = initializers.get(weight_initializer)
  bias_initializer = initializers.get(bias_initializer)
  weight_regularizer = regularizers.get(weight_regularizer)
  bias_regularizer = regularizers.get(bias_regularizer)
  activity_regularizer = regularizers.get(activity_regularizer)

  # a. Check prune configs
  if 'prune_frac' in kwargs.keys():
    x_prune_frac, s_prune_frac = (kwargs['prune_frac'],) * 2
  else:
    x_prune_frac = kwargs.get('x_prune_frac', 0)
    s_prune_frac = kwargs.get('s_prune_frac', 0)
  prune_is_on = hub.pruning_rate_fc > 0.0 and x_prune_frac + s_prune_frac > 0
  # b. Check sparse configs
  x_heads = kwargs.get('x_heads', 0)
  s_heads = kwargs.get('s_heads', 0)
  sparse_is_on = x_heads + s_heads > 0

  # :: Decide to concatenate or not considering a and b
  # .. a
  if memory is None: should_concate = False
  elif prune_is_on: should_concate = x_prune_frac == s_prune_frac
  else: should_concate = fc_memory
  # .. b
  should_concate = should_concate and not sparse_is_on
  separate_memory_neurons = memory is not None and not should_concate

  def get_weights(name, tensor, p_frac, heads):
    shape = [get_dimension(tensor), num]
    if prune_is_on and p_frac > 0:
      assert heads == 0
      return get_weights_to_prune(name, shape, weight_initializer, p_frac)
    elif heads > 0:
      return _get_sparse_weights(shape[0], shape[1], heads, use_bit_max=True,
                                 coef_initializer=weight_initializer)
    else:
      return get_variable(name, shape, weight_initializer)

  def forward():
    # Prepare a weight list for potential regularizer calculation
    weight_list = []
    # Get x
    x = (tf.concat([external_input, memory], axis=1, name='x_concat_s')
         if should_concate else external_input)
    # - Calculate net input for x
    # .. get weights
    name = 'Wx' if separate_memory_neurons else 'W'
    Wx = get_weights(name, x, x_prune_frac, x_heads)
    weight_list.append(Wx)
    # .. append weights to context; currently only some extractors use it
    context.weights_list.append(Wx)
    # .. do matrix multiplication
    net_y = get_matmul(truncate)(x, Wx)
    # - Calculate net input for memory and add to net_y if necessary
    if separate_memory_neurons:
      if not fc_memory:
        assert not (prune_is_on and s_prune_frac > 0)
        memory_dim = get_dimension(memory)
        assert memory_dim == num
        Ws = get_variable('Ws', [1, num], weight_initializer)
        net_s = get_multiply(truncate)(memory, Ws)
      else:
        assert prune_is_on or sparse_is_on
        Ws = get_weights('Ws', memory, s_prune_frac, s_heads)
        net_s = get_matmul(truncate)(memory, Ws)
      # Append Ws to weight list and add net_s to net_y
      weight_list.append(Ws)
      net_y = tf.add(net_y, net_s)
    # - Add bias if necessary
    b = None
    if use_bias:
      b = get_bias('bias', num, bias_initializer)
      net_y = tf.nn.bias_add(net_y, b)
    # - Activate and return
    if callable(activation): net_y = activation(net_y)
    return net_y, weight_list, b

  if scope is not None:
    with tf.variable_scope(scope):
      y, W_list, b = forward()
  else:
    y, W_list, b = forward()

  # Add regularizer if necessary
  if callable(weight_regularizer):
    context.add_loss_tensor(
        tf.add_n([weight_regularizer(w) for w in W_list]))
  if callable(bias_regularizer) and b is not None:
    context.add_loss_tensor(bias_regularizer(b))
  if callable(activity_regularizer):
    context.add_loss_tensor(activity_regularizer(y))

  # Split if necessary
  if num_or_size_splits is not None:
    return tf.split(y, num_or_size_splits=num_or_size_splits, axis=1)
  return y
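# Usage sketch (illustrative, not part of the original source): a dense layer
# built with the neurons routine above, combining an external input with a
# memory tensor in the style of a recurrent cell's net input. The scope and
# helper names are made up for this example.
def _neurons_demo(x, s, num):
  # x: external input of shape [batch, x_dim]; s: memory of shape [batch, s_dim]
  return neurons(num, x, activation='tanh', memory=s, fc_memory=True,
                 scope='neurons_demo', use_bias=True)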
def get_variable(name, shape, initializer='glorot_uniform'):
  initializer = initializers.get(initializer)
  v = tf.get_variable(name, shape, dtype=hub.dtype, initializer=initializer)
  return v
def get_bias(name, dim, initializer='zeros'):
  initializer = initializers.get(initializer)
  return tf.get_variable(
      name, shape=[dim], dtype=hub.dtype, initializer=initializer)
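# Usage sketch (illustrative, not part of the original source): combining
# get_variable and get_bias above into a plain affine map y = x W + b.
# The scope and helper names are made up for this example.
def _affine_demo(x, output_dim):
  # x: a rank-2 tf.Tensor of shape [batch, input_dim]
  with tf.variable_scope('affine_demo'):
    W = get_variable('W', [x.shape.as_list()[-1], output_dim],
                     'glorot_uniform')
    b = get_bias('bias', output_dim, 'zeros')
    return tf.nn.bias_add(tf.matmul(x, W), b)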
def __init__(self, state_size,
             cell_activation='sigmoid',            # g
             cell_activation_range=(-2, 2),
             memory_activation='sigmoid',          # h
             memory_activation_range=(-1, 1),
             weight_initializer='random_uniform',
             weight_initial_range=(-0.1, 0.1),
             use_cell_bias=False,
             cell_bias_initializer='random_uniform',
             cell_bias_init_range=(-0.1, 0.1),
             use_in_bias=True, in_bias_initializer='zeros',
             use_out_bias=True, out_bias_initializer='zeros',
             truncate=True, forward_gate=True, **kwargs):
  # Call parent's constructor
  RNet.__init__(self, OriginalLSTMCell.net_name)
  # Set state size
  self._state_size = state_size
  # Set activations
  # .. In LSTM98, cell activation is referred to as 'g',
  # .. while memory activation is 'h' and gate activation is 'f'
  self._cell_activation = activations.get(
      cell_activation, range=cell_activation_range)
  self._memory_activation = activations.get(
      memory_activation, range=memory_activation_range)
  self._gate_activation = activations.get('sigmoid')
  # Set weight and bias configs
  self._weight_initializer = initializers.get(
      weight_initializer, range=weight_initial_range)
  self._use_cell_bias = use_cell_bias
  self._cell_bias_initializer = initializers.get(
      cell_bias_initializer, range=cell_bias_init_range)
  self._use_in_bias = use_in_bias
  self._in_bias_initializer = initializers.get(in_bias_initializer)
  self._use_out_bias = use_out_bias
  self._out_bias_initializer = initializers.get(out_bias_initializer)
  if kwargs.get('rule97', False):
    self._cell_bias_initializer = self._weight_initializer
    self._in_bias_initializer = self._weight_initializer
  # Additional options
  self._truncate = truncate
  self._forward_gate = forward_gate
  # ...
  self._num_splits = 3
  self._output_scale = state_size
  self._h_size = (state_size * self._num_splits
                  if self._forward_gate else state_size)
  # TODO: BETA
  self.compute_gradients = self.truncated_rtrl