def _read_sector(self, arm, s, scope, num_heads):
  """Read from the state tensor `s` with `num_heads` softmax read heads.

  Head i's logits come from `self.dense(self.total_size, arm, ...)` under
  the name `<scope>_read_<i>`. Within every group the logits go through a
  softmax (axis 1). The weights multiply the matching slice of `s` and are
  summed, giving one value per group member. The softmax weights of each
  head are concatenated and registered with `self._register_gate` as
  `<scope>_head_<i>`.

  When `self._diff_head` is truthy, head i uses its raw logits minus the
  sum of the raw logits of all previous heads.

  Returns:
    The per-head read results joined with `linker.concatenate`. Each head
    contributes `self.num_groups` columns (checked by an assertion).
  """
  previous_sum = 0.
  reads = []
  for head_idx in range(num_heads):
    suffix = str(head_idx)
    # Logits for this read head
    logits = self.dense(self.total_size, arm, scope + '_read_' + suffix)
    if self._diff_head:
      raw_logits = logits
      logits = raw_logits - previous_sum
      previous_sum = raw_logits + previous_sum

    # Softmax weights collected group by group by `_weighted_read`
    weights = []

    # Must keep exactly three positional parameters: the group helper
    # inspects the arity and then passes the group's flat size (s * n).
    def _weighted_read(state, net_h, size):
      w = tf.nn.softmax(net_h, axis=1)
      weights.append(tf.reshape(w, [-1, size]))
      return tf.reduce_sum(state * w, axis=1, keepdims=True)

    read = self._binary_operate_over_groups(
      s, logits, _weighted_read, reshape2=lambda _, n: n)
    assert self.get_dimension(read) == self.num_groups
    reads.append(read)

    # One weight tensor per group must have been produced
    assert len(weights) == len(self._groups)
    gate = linker.concatenate(weights)
    self._register_gate(scope + '_head_' + suffix, gate)
  return linker.concatenate(reads)
def _distribute(self, s_bar, net_h, gutter=False, full_write=False):
  """Distribute `s_bar` over group members using softmax write heads.

  This method should be used only for hd-write methods.

  For every group (s, n) in `self._groups`, the group's chunk of `net_h` is
  viewed as rows of `s` logits, or `s + 1` logits when `gutter` is True.
  Each row goes through a softmax. With `gutter`, the last column is
  dropped after the softmax, so the kept weights may sum to less than 1.
  The weights multiply the matching chunk of `s_bar`.

  Args:
    s_bar: tensor split per group by `self.group_duplicates`, with each
      chunk viewed as rows of width 1 when n > 1. When `full_write` is
      True it is split by `self.group_sizes` and viewed as rows of width s.
    net_h: logits tensor whose chunk for group (s, n) holds s * n entries,
      or (s + 1) * n entries when `gutter` is True.
    gutter: if True, every member has an extra logit that takes part in
      the softmax but whose weight is not written.
    full_write: selects how `s_bar` is split (see above).

  Returns:
    A tuple (data, head). `data` holds the distributed values concatenated
    over groups. `head` holds the concatenated softmax weights, with the
    gutter column removed.
  """
  head_list = []

  def operator(s_block, h_block, size):
    # s_block.shape = [batch_size*n, 1] ([batch_size*n, s] if full_write)
    # h_block.shape = [batch_size*n, s] or [batch_size*n, s+1] with gutter
    h = tf.nn.softmax(h_block)
    if gutter:
      # Drop the gutter column: its probability mass is not written
      h = h[:, :-1]
    # y.shape = [batch_size*n, s]
    y = s_block * h
    head_list.append(tf.reshape(h, [-1, size]))
    return y

  sizes1 = self.group_sizes if full_write else self.group_duplicates
  if gutter:
    # Each of the n members of a group owns s + 1 logits, because the
    # operator sees rows of width s + 1. The group's chunk of net_h must
    # therefore hold (s + 1) * n entries. Adding one entry per group
    # (s * n + 1) only works for n == 1. For n > 1 that chunk cannot be
    # reshaped into n rows of s + 1 logits.
    sizes2 = [(s + 1) * n for s, n in self._groups]
  else:
    sizes2 = self.group_sizes
  data = self._binary_operate_over_groups(
    s_bar, net_h, operator, sizes1=sizes1, sizes2=sizes2,
    reshape1_1=lambda s, n: s if full_write else 1,
    reshape1_2=lambda s, n: s + int(gutter))

  # Concatenate head_list to head (one entry per group)
  assert len(head_list) == len(self._groups)
  head = linker.concatenate(head_list)
  return data, head
def _binary_operate_over_groups(
    self, tensor1, tensor2, operator, sizes1=None, sizes2=None,
    reshape1_1=None, reshape1_2=None, reshape2=None):
  """Apply a binary `operator` group by group to two tensors.

  Both tensors are split with `linker.split_by_sizes`, one chunk per group
  (s, n) in `self._groups`. For groups with n > 1, each chunk is reshaped
  to rows of a per-group width before `operator` is called. The result is
  then reshaped back. Per-group results are joined with
  `linker.concatenate`.

  Args:
    tensor1, tensor2: `tf.Tensor`s to split by `sizes1` / `sizes2`.
    operator: callable taking `(data1, data2)` or `(data1, data2, size)`,
      where `size == s * n` for the current group.
    sizes1, sizes2: split sizes for tensor1 / tensor2. Both default to
      `self.group_sizes`.
    reshape1_1, reshape1_2: optional callables `(s, n) -> width` giving
      the row width of the tensor1 / tensor2 chunk when n > 1. Both
      default to s.
    reshape2: optional callable `(s, n) -> width` giving the row width of
      the operator's result when n > 1. Defaults to s * n.

  Returns:
    The concatenation of the per-group results.

  Raises:
    AssertionError: if the tensors are not `tf.Tensor`s, if `operator` is
      not callable, or if it does not take 2 or 3 positional arguments.
  """
  # Sanity check
  assert isinstance(tensor1, tf.Tensor) and isinstance(tensor2, tf.Tensor)
  assert callable(operator)

  # Decide once, outside the loop, how to call `operator`.
  # `inspect.signature` excludes the already-bound `self` of bound methods
  # and callable objects. `getfullargspec` counts it, and would then pass
  # the extra size argument to a 2-argument method.
  positional_kinds = (inspect.Parameter.POSITIONAL_ONLY,
                      inspect.Parameter.POSITIONAL_OR_KEYWORD)
  num_args = sum(
    1 for p in inspect.signature(operator).parameters.values()
    if p.kind in positional_kinds)
  if num_args not in (2, 3):
    raise AssertionError(
      '!! Illegal operator with {} args'.format(num_args))
  pass_size = num_args == 3

  if sizes1 is None: sizes1 = self.group_sizes
  if sizes2 is None: sizes2 = self.group_sizes

  # Split tensors
  splitted1 = linker.split_by_sizes(tensor1, sizes1)
  splitted2 = linker.split_by_sizes(tensor2, sizes2)

  output_list = []
  for (s, n), data1, data2 in zip(self._groups, splitted1, splitted2):
    # Reshape if necessary
    dim1_1 = reshape1_1(s, n) if callable(reshape1_1) else s
    dim1_2 = reshape1_2(s, n) if callable(reshape1_2) else s
    if n > 1:
      data1 = tf.reshape(data1, [-1, dim1_1])
      data2 = tf.reshape(data2, [-1, dim1_2])
    # Call operator
    extra_args = (s * n,) if pass_size else ()
    data = operator(data1, data2, *extra_args)
    # Reshape back if necessary
    dim2 = reshape2(s, n) if callable(reshape2) else s * n
    if n > 1: data = tf.reshape(data, [-1, dim2])
    output_list.append(data)

  # Concatenate and return
  return linker.concatenate(output_list)
def _operate_over_groups(
    self, tensor, operator, sizes=None, reshape1=None, reshape2=None):
  """Apply a unary `operator` to each group's chunk of `tensor`.

  `tensor` is split with `linker.split_by_sizes` (by `sizes`, default
  `self.group_sizes`), giving one chunk per group (s, n) in `self._groups`.
  For n > 1, the chunk is reshaped to rows of width `reshape1(s, n)`
  (default s) before `operator` is applied. The result is reshaped to
  rows of width `reshape2(s, n)` (default s * n). Per-group results are
  joined with `linker.concatenate`.
  """
  assert isinstance(tensor, tf.Tensor) and callable(operator)
  chunks = linker.split_by_sizes(
    tensor, self.group_sizes if sizes is None else sizes)

  def _apply(group, chunk):
    s, n = group
    in_width = reshape1(s, n) if callable(reshape1) else s
    stacked = n > 1
    if stacked:
      chunk = tf.reshape(chunk, [-1, in_width])
    result = operator(chunk)
    out_width = reshape2(s, n) if callable(reshape2) else s * n
    if stacked:
      result = tf.reshape(result, [-1, out_width])
    return result

  return linker.concatenate(
    [_apply(group, chunk) for group, chunk in zip(self._groups, chunks)])
def elect(self, groups, votes=None):
  """Elect one representative value per group member from `self.input_`.

  groups = ((size1, num1), (size2, num2), ...)
  x.shape = [batch_size, Dx]          (x is self.input_)
  y.shape = [batch_size, num_groups]  (the returned tensor)

  `self.input_` and `votes` are both split per group with `linker.split`.
  For a group (s, n) with s > 1, each of its n members yields the
  vote-weighted sum (`tf.multiply` then `reduce_sum`) of its s units.
  Groups with s == 1 are passed through unchanged, and their votes are
  ignored.

  Args:
    groups: list or tuple of (size, num, ...) items. Only the first two
      entries of each item are used, and sum(size * num) must equal
      `self.input_dim`.
    votes: vote weights laid out like the input. If None (the default), a
      weight named 'V' of shape [1, self.input_dim] is created via
      `self._get_weights` with the 'glorot_uniform' initializer.

  Returns:
    The per-group results joined with `linker.concatenate`.
  """
  # Sanity check
  assert isinstance(groups, (list, tuple))
  groups = [g[:2] for g in groups]
  total_units = sum(s * n for s, n in groups)
  assert total_units == self.input_dim

  # Get votes
  if votes is None:
    votes = self._get_weights(
      'V', [1, self.input_dim], initializer='glorot_uniform')

  # Calculate output
  splitted_x = linker.split(self.input_, groups)
  splitted_v = linker.split(votes, groups)
  output_list = []
  for (s, n), x, v in zip(groups, splitted_x, splitted_v):
    if s == 1:
      # A single-unit member is its own representative
      output_list.append(x)
      continue
    y = tf.multiply(v, x)
    # View each member's s units as one row, sum them, then restore n cols
    if n > 1: y = tf.reshape(y, [-1, s])
    y = tf.reduce_sum(y, axis=1, keepdims=True)
    if n > 1: y = tf.reshape(y, [-1, n])
    output_list.append(y)
  return linker.concatenate(output_list)