def _link(self, x, **kwargs):
  y, pkg = linker.sparse_affine(
    x, self.num_neurons, self.heads, self.use_bit_max,
    self._logits_initializer, self._coef_initializer,
    self._use_bias, self._bias_initializer, return_package=True)

  # Encourage softmax activation to be saturated
  ds_penalty = self._kwargs.get('desaturate_penalty', 0.0)
  if ds_penalty > 0:
    a = pkg['activation']
    a_bar = tf.subtract(1.0, a)
    context.add_loss_tensor(
      ds_penalty * tf.reduce_mean(tf.minimum(a, a_bar)))
    console.show_status('Desaturate penalty added in {}'.format(
      tf.get_variable_scope().name), '++')

  # Export variables if necessary
  if hub.export_sparse_weights:
    scope = '/'.join(tf.get_variable_scope().name.split('/')[1:])
    # context.variables_to_export[scope + '/weights'] = pkg['weights']
    # context.variables_to_export[scope + '/coef'] = pkg['coef']
    context.weights_list.append(pkg['weights'])

  return y
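# NOTE: The snippet below is an illustrative numpy sketch, not part of tframe.
# It shows why minimizing mean(min(a, 1 - a)) saturates a softmax activation:
# the penalty vanishes as each entry of `a` approaches 0 or 1 and peaks at
# a = 0.5, so gradient descent pushes activations toward the extremes.
import numpy as np

def _demo_desaturate_penalty():
  a_saturated = np.array([0.01, 0.99, 0.98, 0.02])
  a_diffuse = np.array([0.40, 0.60, 0.50, 0.50])
  penalty = lambda a: np.mean(np.minimum(a, 1.0 - a))
  print(penalty(a_saturated))  # ~0.015: near-binary activations, tiny loss
  print(penalty(a_diffuse))    # 0.45: undecided activations, large loss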
def _register_gates(self):
  assert self.is_root
  for cell in self.rnn_cells:
    for k, g in cell._gate_dict.items():
      if hub.export_gates: context.add_tensor_to_export(k, g)
      if hub.train_gates:
        coef = hub.gate_loss_strength
        context.add_loss_tensor(tf.multiply(
          coef, tf.reduce_sum(g), name='{}_loss'.format(k)))
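# NOTE: Illustrative sketch, not tframe code. When `hub.train_gates` is on,
# coef * sum(g) is added to the total loss for every registered gate tensor,
# which pressures gate activations toward zero, i.e. sparser gating. A plain
# numpy analogue with two hypothetical gate arrays:
import numpy as np

def _demo_gate_loss(gate_loss_strength=0.01):
  gate_dict = {'input_gate': np.array([0.9, 0.1, 0.7]),
               'forget_gate': np.array([0.2, 0.8, 0.5])}
  # Sum of per-gate penalties, added on top of the task loss during training
  return sum(gate_loss_strength * g.sum() for g in gate_dict.values())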
def neurons(num, external_input, activation=None, memory=None,
            fc_memory=True, scope=None, use_bias=True, truncate=False,
            num_or_size_splits=None,
            weight_initializer='glorot_uniform',
            bias_initializer='zeros',
            weight_regularizer=None,
            bias_regularizer=None,
            activity_regularizer=None,
            **kwargs):
  """Analogous to tf.keras.layers.Dense"""
  # Get activation, initializers and regularizers
  if activation is not None: activation = activations.get(activation)
  weight_initializer = initializers.get(weight_initializer)
  bias_initializer = initializers.get(bias_initializer)
  weight_regularizer = regularizers.get(weight_regularizer)
  bias_regularizer = regularizers.get(bias_regularizer)
  activity_regularizer = regularizers.get(activity_regularizer)

  # a. Check prune configs
  if 'prune_frac' in kwargs:
    x_prune_frac, s_prune_frac = (kwargs['prune_frac'],) * 2
  else:
    x_prune_frac = kwargs.get('x_prune_frac', 0)
    s_prune_frac = kwargs.get('s_prune_frac', 0)
  prune_is_on = hub.pruning_rate_fc > 0.0 and x_prune_frac + s_prune_frac > 0

  # b. Check sparse configs
  x_heads = kwargs.get('x_heads', 0)
  s_heads = kwargs.get('s_heads', 0)
  sparse_is_on = x_heads + s_heads > 0

  # :: Decide to concatenate or not considering a and b
  # .. a
  if memory is None: should_concate = False
  elif prune_is_on: should_concate = x_prune_frac == s_prune_frac
  else: should_concate = fc_memory
  # .. b
  should_concate = should_concate and not sparse_is_on

  separate_memory_neurons = memory is not None and not should_concate

  def get_weights(name, tensor, p_frac, heads):
    shape = [get_dimension(tensor), num]
    if prune_is_on and p_frac > 0:
      assert heads == 0
      return get_weights_to_prune(name, shape, weight_initializer, p_frac)
    elif heads > 0:
      return _get_sparse_weights(
        shape[0], shape[1], heads, use_bit_max=True,
        coef_initializer=weight_initializer)
    else:
      return get_variable(name, shape, weight_initializer)

  def forward():
    # Prepare a weight list for potential regularizer calculation
    weight_list = []
    # Get x
    x = (tf.concat([external_input, memory], axis=1, name='x_concat_s')
         if should_concate else external_input)

    # - Calculate net input for x
    # .. get weights
    name = 'Wx' if separate_memory_neurons else 'W'
    Wx = get_weights(name, x, x_prune_frac, x_heads)
    weight_list.append(Wx)
    # .. append weights to context, currently only some extractors will use it
    context.weights_list.append(Wx)
    # .. do matrix multiplication
    net_y = get_matmul(truncate)(x, Wx)

    # - Calculate net input for memory and add to net_y if necessary
    if separate_memory_neurons:
      if not fc_memory:
        assert not (prune_is_on and s_prune_frac > 0)
        memory_dim = get_dimension(memory)
        assert memory_dim == num
        Ws = get_variable('Ws', [1, num], weight_initializer)
        net_s = get_multiply(truncate)(memory, Ws)
      else:
        assert prune_is_on or sparse_is_on
        Ws = get_weights('Ws', memory, s_prune_frac, s_heads)
        net_s = get_matmul(truncate)(memory, Ws)
      # Append Ws to weight list and add net_s to net_y
      weight_list.append(Ws)
      net_y = tf.add(net_y, net_s)

    # - Add bias if necessary
    b = None
    if use_bias:
      b = get_bias('bias', num, bias_initializer)
      net_y = tf.nn.bias_add(net_y, b)

    # - Activate and return
    if callable(activation): net_y = activation(net_y)
    return net_y, weight_list, b

  if scope is not None:
    with tf.variable_scope(scope):
      y, W_list, b = forward()
  else:
    y, W_list, b = forward()

  # Add regularizer if necessary
  if callable(weight_regularizer):
    context.add_loss_tensor(
      tf.add_n([weight_regularizer(w) for w in W_list]))
  if callable(bias_regularizer) and b is not None:
    context.add_loss_tensor(bias_regularizer(b))
  if callable(activity_regularizer):
    context.add_loss_tensor(activity_regularizer(y))

  # Split if necessary
  if num_or_size_splits is not None:
    return tf.split(y, num_or_size_splits=num_or_size_splits, axis=1)
  return y
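# NOTE: Usage sketch only; the tensor names and argument values below are
# illustrative assumptions, and a built tframe graph is required to run them.
#
#   # GRU-style candidate activation with a pruned input connection:
#   h_tilde = neurons(state_size, x, activation='tanh', memory=prev_s,
#                     fc_memory=True, scope='h_tilde', x_prune_frac=0.5)
#
#   # One matmul producing two gates at once via `num_or_size_splits`:
#   r, z = neurons(2 * state_size, x, activation='sigmoid', memory=prev_s,
#                  scope='gates', num_or_size_splits=2)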
def sparse_sog(self, axis, group_size):
  """Given x of shape (bs, dim_x), computes
       y = x @ (W_bar \odot C),
  where W_bar and C both have shape (dim_x, dim_y), and
       C = \eta_{SxN}(C_tilde, axis),
  \eta being the softmax-over-groups operator.

  `axis` may be passed from
    SparseSOG(..., axis, ...)
    -> neuro_base.sparse_sog(..., axis, ...)
    -> neural_array.add_kernel(..., axis=axis)
    -> PsiKernel(..., axis=axis)
    -> kernel_base.kwargs['axis']
  So does `group_size`.
  """
  S = group_size
  # Check dim and calculate N (num_groups)
  dim_to_be_partitioned = self.input_dim if axis == 0 else self.num_neurons
  assert dim_to_be_partitioned % S == 0
  N = dim_to_be_partitioned // S

  # Prepare weight matrix W_bar
  W_bar = self._get_weights(
    'W_bar', shape=[self.input_dim, self.num_neurons])
  if S == 1: return self.input_ @ W_bar
  # .. make sure inputs are vectors
  assert len(self.input_.shape) == 2

  # .. create connection matrix C according to axis
  # .. (While shape_C can be determined by 1 line of code, readability is
  #     of more importance)
  if axis == 0:
    assert S * N == self.input_dim
    shape_C = [S, self.num_neurons * N]
  elif axis == 1:
    assert S * N == self.num_neurons
    shape_C = [self.input_dim * N, S]
  else:
    raise AssertionError('`axis` must be either 0 or 1')
  C_tilde = self._get_weights('C_tilde', shape=shape_C)
  C_bar = tf.nn.softmax(C_tilde, axis=axis, name='C_bar')
  C = tf.reshape(
    C_bar, shape=[self.input_dim, self.num_neurons], name='C')
  # assert all(tf.reduce_sum(C, axis) == N)
  W = tf.multiply(W_bar, C, name='W')

  # Export weights if necessary
  if hub.export_sparse_weights:
    context.add_var_to_export('connection', C)

  # Encourage saturation
  if hub.saturation_penalty is not None and hub.saturation_penalty > 0:
    from tframe.losses import saturate_loss
    sta_loss = saturate_loss(C)
    context.add_loss_tensor(sta_loss)
    # TODO: STILL DEVELOPING
    # from tframe.losses import saturate_loss
    # sta_loss = saturate_loss(C, mu=1/S) * hub.saturation_penalty
    # vips = tf.reduce_max(C_bar, axis=axis)
    # right_loss = tf.reduce_mean(1. - vips)
    # left = C_bar[tf.less(C_bar, 1 / S)]
    # left_loss = tf.reduce_mean(left)
    # sta_loss = (left_loss + right_loss) * hub.saturation_penalty
    # context.add_loss_tensor(sta_loss)

  # Calculate output and return
  return tf.matmul(self.input_, W)
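# NOTE: Worked numpy example, not part of tframe, assuming axis=1 with
# input_dim=2, num_neurons=4 and group_size S=2 (hence N=2 groups). C_tilde
# has shape [input_dim * N, S]; a row-wise softmax followed by a row-major
# reshape back to [input_dim, num_neurons] makes every group of S consecutive
# entries in each row of C sum to 1, matching the commented assertion
# `tf.reduce_sum(C, axis) == N` above.
import numpy as np

def _demo_softmax_over_groups(input_dim=2, num_neurons=4, S=2):
  N = num_neurons // S
  rng = np.random.default_rng(0)
  C_tilde = rng.normal(size=(input_dim * N, S))
  e = np.exp(C_tilde - C_tilde.max(axis=1, keepdims=True))
  C_bar = e / e.sum(axis=1, keepdims=True)       # softmax within each group
  C = C_bar.reshape(input_dim, num_neurons)
  print(C.reshape(input_dim, N, S).sum(axis=2))  # all ones: groups sum to 1
  return C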