Example #1
    def _read_sector(self, arm, s, scope, num_heads):
        heads_sum = 0.
        data_list = []
        for i in range(num_heads):
            # Prepare
            net_head = self.dense(self.total_size, arm,
                                  scope + '_read_' + str(i))
            if self._diff_head:
                # Differential heads: emit the difference between the current
                # raw head and the running sum of previous raw heads
                net_head, heads_sum = net_head - heads_sum, net_head + heads_sum
            # Read
            head_list = []

            def operator(state, net_h, size):
                h = tf.nn.softmax(net_h, axis=1)
                head_list.append(tf.reshape(h, [-1, size]))
                return tf.reduce_sum(state * h, axis=1, keepdims=True)

            reshape2 = lambda _, n: n
            data = self._binary_operate_over_groups(s,
                                                    net_head,
                                                    operator,
                                                    reshape2=reshape2)
            assert self.get_dimension(data) == self.num_groups
            data_list.append(data)
            # Concatenate head_list and register
            assert len(head_list) == len(self._groups)
            head = linker.concatenate(head_list)
            self._register_gate(scope + '_head_' + str(i), head)

        return linker.concatenate(data_list)
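
The read step above is a within-group softmax attention: each duplicate of size s gets a distribution over its s units, and the weighted sum collapses it to one value. A minimal standalone sketch of that step for a single group, assuming eager TensorFlow 2.x; the shapes and names here are illustrative, not the library's API:

import tensorflow as tf

batch_size, s, n = 4, 3, 2                 # one group: n duplicates of size s
state = tf.random.normal([batch_size, s * n])
net_head = tf.random.normal([batch_size, s * n])

# _binary_operate_over_groups reshapes so each duplicate becomes a row
state_r = tf.reshape(state, [-1, s])       # [batch_size * n, s]
net_r = tf.reshape(net_head, [-1, s])      # [batch_size * n, s]

h = tf.nn.softmax(net_r, axis=1)           # attention weights within a group
read = tf.reduce_sum(state_r * h, axis=1, keepdims=True)  # [batch_size * n, 1]

# reshape2 = lambda _, n: n restores the batch axis: one value per duplicate
print(tf.reshape(read, [-1, n]).shape)     # (4, 2)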
Example #2
    def _distribute(self, s_bar, net_h, gutter=False, full_write=False):
        """This method should be used only for hd-write methods"""
        head_list = []

        def operator(s_block, h_block, size):
            # s_block.shape = [batch_size*n, 1]
            # h_block.shape = [batch_size*n, s] or [batch_size*n, s+1]
            h = tf.nn.softmax(h_block)
            if gutter: h = h[:, :-1]
            # y.shape = [batch_size*n, s]
            y = s_block * h
            head_list.append(tf.reshape(h, [-1, size]))
            return y

        sizes1 = self.group_sizes if full_write else self.group_duplicates
        sizes2 = [s + 1
                  for s in self.group_sizes] if gutter else self.group_sizes
        reshape1_1 = lambda s, n: s if full_write else 1
        reshape1_2 = lambda s, n: s + int(gutter)
        data = self._binary_operate_over_groups(s_bar,
                                                net_h,
                                                operator,
                                                sizes1=sizes1,
                                                sizes2=sizes2,
                                                reshape1_1=reshape1_1,
                                                reshape1_2=reshape1_2)
        # Concatenate head_list to head
        assert len(head_list) == len(self._groups)
        head = linker.concatenate(head_list)
        return data, head
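
The gutter option makes the write weights sub-normalized: softmax runs over s + 1 logits and the extra "gutter" column is discarded, so the kept weights sum to less than one. A small demonstration of just that trick, assuming eager TensorFlow 2.x; the numbers are illustrative:

import tensorflow as tf

s = 3
logits = tf.constant([[1.0, 2.0, 0.5, 0.0]])   # s + 1 logits, last is the gutter
h = tf.nn.softmax(logits)                      # sums to 1 over s + 1 slots
h = h[:, :-1]                                  # drop the gutter column
print(float(tf.reduce_sum(h)))                 # < 1: some mass fell in the gutter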
Example #3
 def _binary_operate_over_groups(
     self, tensor1, tensor2, operator, sizes1=None, sizes2=None,
     reshape1_1=None, reshape1_2=None, reshape2=None):
   # Sanity check
   assert isinstance(tensor1, tf.Tensor) and isinstance(tensor2, tf.Tensor)
   if sizes1 is None: sizes1 = self.group_sizes
   if sizes2 is None: sizes2 = self.group_sizes
   # Split tensors
   splitted1 = linker.split_by_sizes(tensor1, sizes1)
   splitted2 = linker.split_by_sizes(tensor2, sizes2)
   output_list = []
   for (s, n), data1, data2 in zip(self._groups, splitted1, splitted2):
     # Reshape if necessary
     dim1_1 = reshape1_1(s, n) if callable(reshape1_1) else s
     dim1_2 = reshape1_2(s, n) if callable(reshape1_2) else s
     if n > 1:
       data1 = tf.reshape(data1, [-1, dim1_1])
       data2 = tf.reshape(data2, [-1, dim1_2])
     # Call operator
     num_args = len(inspect.getfullargspec(operator).args)
     if num_args == 2: args = []
     elif num_args == 3: args = [s * n]
     else: raise AssertionError(
       '!! Illegal operator with {} args'.format(num_args))
     data = operator(data1, data2, *args)
     # Reshape back if necessary
     dim2 = reshape2(s, n) if callable(reshape2) else s * n
     if n > 1: data = tf.reshape(data, [-1, dim2])
     # Add result to output list
     output_list.append(data)
   # Concatenate and return
   return linker.concatenate(output_list)
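
_binary_operate_over_groups inspects the operator's arity to decide whether to pass the flat group size s * n as a third argument. A minimal sketch of just that dispatch, using only the standard-library inspect module; call_operator is a hypothetical helper for illustration, not part of the library:

import inspect

def call_operator(operator, data1, data2, s, n):
    num_args = len(inspect.getfullargspec(operator).args)
    if num_args == 2:
        return operator(data1, data2)
    elif num_args == 3:
        return operator(data1, data2, s * n)
    raise AssertionError('!! Illegal operator with {} args'.format(num_args))

print(call_operator(lambda a, b: a + b, 1, 2, 3, 1))        # 3
print(call_operator(lambda a, b, size: size, 1, 2, 3, 2))   # 6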
Example #4
 def _operate_over_groups(
     self, tensor, operator, sizes=None, reshape1=None, reshape2=None):
   assert isinstance(tensor, tf.Tensor) and callable(operator)
   if sizes is None: sizes = self.group_sizes
   # Split tensor
   splitted = linker.split_by_sizes(tensor, sizes)
   output_list = []
   for (s, n), data in zip(self._groups, splitted):
     dim1 = reshape1(s, n) if callable(reshape1) else s
     if n > 1: data = tf.reshape(data, [-1, dim1])
     data = operator(data)
     dim2 = reshape2(s, n) if callable(reshape2) else s * n
     if n > 1: data = tf.reshape(data, [-1, dim2])
     output_list.append(data)
   # Concatenate and return
   return linker.concatenate(output_list)
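
_operate_over_groups splits the incoming tensor by group, reshapes each group so every duplicate is a row, applies the operator, and concatenates the results. A self-contained approximation, assuming linker.split_by_sizes and linker.concatenate behave like tf.split and tf.concat on the feature axis (eager TensorFlow 2.x); operate_over_groups below is a hypothetical stand-in, not the library's method:

import tensorflow as tf

def operate_over_groups(tensor, groups, operator):
    sizes = [s * n for s, n in groups]
    output_list = []
    for (s, n), data in zip(groups, tf.split(tensor, sizes, axis=1)):
        if n > 1: data = tf.reshape(data, [-1, s])    # one duplicate per row
        data = operator(data)
        if n > 1: data = tf.reshape(data, [-1, s * n])
        output_list.append(data)
    return tf.concat(output_list, axis=1)

x = tf.random.normal([4, 2 * 3 + 5])                  # groups: (2, 3) and (5, 1)
y = operate_over_groups(x, [(2, 3), (5, 1)],
                        lambda d: tf.nn.softmax(d, axis=1))
print(y.shape)                                        # (4, 11)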
Example #5
    def elect(self, groups, votes):
        """Given a vector with group specification, one representative will be
       elected.
       groups = ((size1, num1), (size2, num2), ...)
       x.shape = [batch_size, Dx]
       y.shape = [batch_size, num_groups]
    """
        # Sanity check
        assert isinstance(groups, (list, tuple))
        groups = [g[:2] for g in groups]
        total_units = sum([s * n for s, n in groups])
        assert total_units == self.input_dim

        # Get votes
        # initializer = tf.constant_initializer(np.concatenate(
        #     [np.ones([1, s * n], dtype=np.float32) / s for s, n in groups], axis=1))
        if votes is None:
            initializer = 'glorot_uniform'
            votes = self._get_weights('V', [1, self.input_dim],
                                      initializer=initializer)

        # Calculate output
        splitted_x = linker.split(self.input_, groups)
        splitted_v = linker.split(votes, groups)
        output_list = []
        for (s, n), x, v in zip(groups, splitted_x, splitted_v):
            if s == 1:
                output_list.append(x)
                continue
            y = tf.multiply(v, x)
            if n > 1: y = tf.reshape(y, [-1, s])
            y = tf.reduce_sum(y, axis=1, keepdims=True)
            if n > 1: y = tf.reshape(y, [-1, n])
            output_list.append(y)

        return linker.concatenate(output_list)
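
elect reduces each duplicate of each group to a single representative via the learned vote vector 'V'. A hedged toy version of its semantics, with fixed votes instead of learned weights and slicing instead of the linker helpers (assuming eager TensorFlow 2.x; all names are illustrative):

import tensorflow as tf

batch_size = 2
groups = [(3, 2), (1, 4)]                        # (size, num) pairs, Dx = 10
x = tf.random.normal([batch_size, 10])
v = tf.ones([1, 10])                             # stand-in for the learned 'V'

output_list = []
offset = 0
for s, n in groups:
    xs, vs = x[:, offset:offset + s * n], v[:, offset:offset + s * n]
    offset += s * n
    if s == 1:                                   # size-1 groups pass through
        output_list.append(xs)
        continue
    y = tf.reshape(vs * xs, [-1, s])             # one duplicate per row
    y = tf.reduce_sum(y, axis=1, keepdims=True)  # elect one representative
    output_list.append(tf.reshape(y, [-1, n]))
y = tf.concat(output_list, axis=1)
print(y.shape)                                   # (2, 6): num_groups = 2 + 4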