def call(self, inputs):
    """Convolve each sample with the kernel (and bias) selected by its class.

    Args:
        inputs: A pair ``[x, class_ids]``. ``x`` is the batched convolution
            input. ``class_ids`` is squeezed on axis 1 and used to gather
            rows of ``self.kernel`` / ``self.bias``, so it is presumably an
            integer tensor shaped ``(batch, 1)`` -- confirm with the caller.

    Returns:
        The per-sample convolution outputs stacked over the batch, passed
        through ``self.activation`` when one is set.
    """
    def apply_separate_filter_for_each_batch(inputs):
        # `inputs` is one batch element: [x_i, kernel_i] or [x_i, kernel_i, bias_i].
        kernel = inputs[1]
        # conv2d needs a batch axis; treat the single sample as a batch of one.
        x = K.expand_dims(inputs[0], axis=0)
        outputs = K.conv2d(
            x,
            kernel,
            strides=self.strides,
            padding=self.padding,
            data_format=self.data_format,
            dilation_rate=self.dilation_rate)
        if self.bias is not None:
            bias = inputs[2]
            outputs = K.bias_add(outputs, bias, data_format=self.data_format)
        # Remove the temporary batch axis again.
        return K.squeeze(outputs, axis=0)

    x = inputs[0]
    classes = K.squeeze(inputs[1], axis=1)

    # Gather one kernel (and, if present, one bias) per sample by class id.
    elems = [x, K.gather(self.kernel, classes)]
    if self.bias is not None:
        elems.append(K.gather(self.bias, classes))

    # map_fn must be told the output dtype because the output structure
    # (one tensor) differs from `elems` (a list). Use the input's dtype
    # instead of a hard-coded 'float32' so float16/float64 inputs work too.
    outputs = K.map_fn(apply_separate_filter_for_each_batch, elems,
                       dtype=K.dtype(x))

    if self.activation is not None:
        return self.activation(outputs)
    return outputs
def slide_sum(paras):
    """Sum the trailing rows of each sample's 2-D embedding matrix.

    For every sample, ``n`` is taken from that sample's entry in ``paras[1]``.
    The function keeps rows ``MAX_SEQUENCE_LENGTH - n`` onward of the sample's
    embedding matrix and sums them over the row axis.

    Args:
        paras: A pair ``(embedded_results, categories_size_input)``.
            Each batch element of ``embedded_results`` is sliced as a 2-D
            (rows x features) tensor. Presumably the rows are padded to
            ``MAX_SEQUENCE_LENGTH``, with the valid rows at the end -- TODO
            confirm against the caller. Each batch element of
            ``categories_size_input`` must contain exactly one value, the
            row count ``n``.

    Returns:
        A float32 tensor holding one summed feature vector per sample.
    """
    embedded_results = paras[0]
    categories_size_input = paras[1]

    def _sum_trailing_rows(elements):
        embedded = elements[0]
        # Force the per-sample count to a true rank-0 scalar. Squeezing alone
        # could leave a rank-1 tensor, which made the slice below fail with
        # "From merging shape 0 with other shapes ... (op: 'Pack') with input
        # shapes: [1], []". Reshaping to [] always yields a static scalar.
        size = K.cast(tf.reshape(elements[1], []), tf.int32)
        # 2-D slice: keep the rows from MAX_SEQUENCE_LENGTH - size onward.
        tail = embedded[MAX_SEQUENCE_LENGTH - size:, :]
        # Plain backend op. Wrapping it in a Lambda layer here created a new
        # Keras layer every time the mapped function was traced.
        return K.sum(tail, axis=0)

    return K.map_fn(_sum_trailing_rows,
                    (embedded_results, categories_size_input),
                    dtype=tf.float32)
def get_cluster_centroids(self):
    """Derive initial cluster centroids from the empirical CDF of the weights.

    The steps are:

    1. Sample the CDF of ``self.weights`` on a 30-point grid that spans the
       weight range.
    2. Locate ``self.number_of_clusters`` probability levels in [0.01, 1] on
       that grid with an upper-bound search.
    3. For each level, hand the located grid point and its predecessor to
       ``TFLinearEquationSolver``, which yields one centroid.

    Returns:
        A tensor of shape ``(self.number_of_clusters,)`` with the centroids.
    """
    lowest = tf.reduce_min(self.weights)
    highest = tf.reduce_max(self.weights)

    # The range is widened by 0.01 on both sides so that the sampled CDF
    # starts at 0 and ends at 1. The value 30 is a guess; any sufficiently
    # large number works, since values are interpolated linearly and the
    # initial guess drifts away anyway. Lookup granularity is therefore not
    # critical.
    grid = tf.linspace(lowest - 0.01, highest + 0.01, 30)
    cdf = TFCumulativeDistributionFunction(weights=self.weights)
    cdf_on_grid = k.map_fn(cdf.get_cdf_value, grid)

    levels = tf.linspace(0 + 0.01, 1, self.number_of_clusters)

    # Upper-bound search (side='right'): for each level, the first grid index
    # whose CDF value is strictly greater than the level.
    upper_bounds = tf.searchsorted(sorted_sequence=cdf_on_grid,
                                   values=levels,
                                   side='right')

    def centroid_at(idx):
        # Clamp into the grid and pair the point with its predecessor.
        hi_idx = tf.minimum(idx, tf.size(cdf_on_grid) - 1)
        lo_idx = tf.maximum(0, hi_idx - 1)
        solver = TFLinearEquationSolver(x1=grid[hi_idx],
                                        y1=cdf_on_grid[hi_idx],
                                        x2=grid[lo_idx],
                                        y2=cdf_on_grid[lo_idx])
        # NOTE(review): the target y is y1 itself (the CDF value at hi_idx),
        # not the probability level. For a non-degenerate line this
        # presumably just returns x1. Confirm whether the level was intended,
        # and how TFLinearEquationSolver handles y1 == y2.
        return solver.solve_for_x(cdf_on_grid[hi_idx])

    centroids = k.map_fn(centroid_at, upper_bounds, dtype=tf.float32)
    return tf.reshape(centroids, (self.number_of_clusters, ))
def __call__(self, w):
    """Apply ``self._kernel_constraint`` to every (height x width) kernel slice.

    Args:
        w: A rank-4 weight tensor unpacked as
            ``(height, width, channels, kernels)``.

    Returns:
        A tensor with the same shape as ``w``. Each of its
        ``channels * kernels`` 2-D spatial slices has been passed through
        ``self._kernel_constraint``.

    Raises:
        ValueError: If the rank of ``w`` is unknown or not 4.
    """
    w_shape = w.shape
    if w_shape.rank is None or w_shape.rank != 4:
        raise ValueError(
            'The weight tensor must be of rank 4, but is of shape: %s' % w_shape)

    height, width, channels, kernels = w_shape
    # Flatten channels and kernels into one trailing axis of 2-D slices.
    w = K.reshape(w, (height, width, channels * kernels))
    # Move the slice axis to the front so map_fn iterates over the slices.
    # A single transpose replaces the previous unstack/stack pair, which
    # emitted channels * kernels ops and required a static size on that axis.
    # TODO(cpeter): Switch map_fn for a faster tf.vectorized_map once K.switch
    # is supported.
    w = K.map_fn(self._kernel_constraint, K.permute_dimensions(w, (2, 0, 1)))
    # Move the slice axis back to the end and restore the original 4-D shape.
    return K.reshape(K.permute_dimensions(w, (1, 2, 0)),
                     (height, width, channels, kernels))