def scatter(self, points, indices, values, shape, duplicates_handling='undefined'):
    """Scatter ``values`` into a zero tensor of the given ``shape`` at ``indices``.

    Args:
      points: Positions associated with the scattered values. Only used by the
        custom gradient of the ``'add'`` mode: on the backward pass the
        upstream gradient is passed through the ``gradient`` helper (central
        differences) and resampled at ``points`` via ``self.resample``.
      indices: Integer index tensor. The batch number is included as the first
        element of each index, e.g. ``[0, 31, 24]`` addresses the first batch
        (batch 0) at 2D coordinates (31, 24).
      values: Values to scatter; their dtype is used for the output.
      shape: Shape of the dense output tensor.
      duplicates_handling: How values landing on the same index are combined:
        ``'add'`` sums them, ``'mean'`` averages them, and anything else
        (``'last'``, ``'any'``, ``'undefined'``) keeps one of them without
        guaranteeing which.

    Returns:
      A dense tensor of shape ``shape``.
    """
    z = tf.zeros(shape, dtype=values.dtype)
    if duplicates_handling == 'add':
        # Only for TensorFlow: custom gradient w.r.t. ``points``; ``indices`` and
        # ``values`` receive no gradient (None).
        @tf.custom_gradient
        def scatter_density(points, indices, values):
            result = tf.tensor_scatter_add(z, indices, values)

            def grad(dr):
                return self.resample(gradient(dr, difference='central'), points), None, None

            return result, grad

        return scatter_density(points, indices, values)
    elif duplicates_handling == 'mean':
        # Won't entirely work with out-of-bounds particles (still counted in mean).
        count = tf.tensor_scatter_add(z, indices, tf.ones_like(values))
        total = tf.tensor_scatter_add(z, indices, values)
        return total / tf.maximum(1.0, count)
    else:  # last, any, undefined
        # tf.SparseTensor requires int64 indices and dense_shape; cast so that
        # int32 index tensors are accepted as well.
        st = tf.SparseTensor(tf.cast(indices, tf.int64), values, tf.cast(shape, tf.int64))
        st = tf.sparse.reorder(st)  # only needed if not ordered
        # The default validate_indices=True raises on repeated indices, which
        # are exactly the duplicates these modes are supposed to tolerate.
        return tf.sparse.to_dense(st, validate_indices=False)
def _build_graph(self, int_group):
    """Build TF ops that re-solve a slice of one factor matrix with conjugate gradient.

    For ``int_group == 0`` the rows ``self.P[start_x:next_x]`` are re-solved
    against ``self.Q`` with regularizer ``opt.reg_u``. For any other value the
    roles are swapped: rows of ``self.Q`` are re-solved against ``self.P`` with
    regularizer ``opt.reg_i``. For each local row x the linear system built is

        (FF + reg * I + sum_k coeff_k q_k q_k^T) x = sum_k (1 + coeff_k) q_k

    where the sum runs over the entries k with ``rows[k]`` equal to that row,
    ``q_k = Q[keys[k]]``, ``coeff = vals * opt.alpha`` and ``FF`` is ``self.FF``.

    NOTE(review): this looks like the confidence-weighted (implicit-feedback)
    ALS update, with ``self.FF`` presumably holding Q^T Q from the other group's
    precompute op — confirm. ``self._dot`` is assumed to be a row-wise dot
    product — confirm.

    CG runs exactly ``opt.num_cg_max_iters`` iterations, warm-started from the
    current rows. No convergence test is built, and ``opt.eps`` only guards the
    divisions.

    Side effects (attributes set on ``self``):
      * ``updateP`` / ``updateQ``: a ``tf.constant(True)`` whose evaluation
        forces assigning the solved rows back into ``P[start_x:next_x]``.
      * ``precomputeP`` / ``precomputeQ``: a ``tf.constant(True)`` whose
        evaluation forces ``self.FF = P^T P`` for the whole matrix this group
        updates. It has no control dependency on the update op, so callers must
        sequence the two themselves.
      * ``err`` (only when ``int_group == 1``): ``sum((vals - dots)**2)``, where
        ``dots`` comes from the rows *before* the CG update. It is a float32
        constant 0.0 when ``opt.compute_loss_on_training`` is false.

    Args:
      int_group: 0 to build the ``P`` update; any other value builds the ``Q``
        update.
    """
    if int_group == 0:
        P, Q, reg = self.P, self.Q, self.opt.reg_u
    else:
        P, Q, reg = self.Q, self.P, self.opt.reg_i
    start_x, next_x, rows, keys, vals = \
        self.start_x, self.next_x, self.rows, self.keys, self.vals
    # compute ys: right-hand side sum_k (1 + coeff_k) q_k, scattered into the
    # local row each entry belongs to (rows are indices into the slice).
    Fgtr = tf.gather(Q, keys)
    coeff = self.vals * self.opt.alpha
    ys = tf.scatter_nd(tf.expand_dims(rows, axis=1),
                       Fgtr * tf.expand_dims(coeff + 1, axis=1),
                       shape=(next_x - start_x, self.opt.d))
    # prepare cg: warm-start from the current rows and form the residual
    # r = b - A x, where A x = x FF + reg x + sum_k coeff_k (x . q_k) q_k.
    _P = P[start_x:next_x]
    Axs = tf.matmul(_P, self.FF) + reg * _P
    dots = self._dot(tf.gather(_P, rows), Fgtr)
    Axs = tf.tensor_scatter_add(
        Axs, tf.expand_dims(rows, axis=1),
        Fgtr * tf.expand_dims(dots * coeff, axis=1))
    rs = ys - Axs
    ps = rs
    rss_old = tf.reduce_sum(tf.square(rs), axis=1)
    # iterate cg steps (independently per row; alphas/betas are per-row vectors)
    for i in range(self.opt.num_cg_max_iters):
        # A p, built the same way as A x above.
        Aps = tf.matmul(ps, self.FF) + ps * reg
        _dots = coeff * self._dot(tf.gather(ps, rows), Fgtr)
        Aps = tf.tensor_scatter_add(Aps, tf.expand_dims(rows, axis=1),
                                    Fgtr * tf.expand_dims(_dots, axis=1))
        pAps = self._dot(Aps, ps)
        alphas = rss_old / (pAps + self.opt.eps)
        _P = _P + ps * tf.expand_dims(alphas, axis=1)
        rs = rs - tf.expand_dims(alphas, axis=1) * Aps
        rss_new = tf.reduce_sum(tf.square(rs), axis=1)
        betas = rss_new / (rss_old + self.opt.eps)
        ps = rs + (tf.expand_dims(betas, axis=1) * ps)
        rss_old = rss_new

    if int_group == 1:
        if self.opt.compute_loss_on_training:
            # Uses `dots` from the pre-update rows.
            self.err = tf.reduce_sum(tf.square(vals - dots))
        else:
            self.err = tf.constant(0.0, dtype=tf.float32)

    name = "updateP" if int_group == 0 else "updateQ"
    _update = P[start_x:next_x].assign(_P)
    # Evaluating this constant forces the assignment to run first.
    with self.graph.control_dependencies([_update]):
        update = tf.constant(True)
    setattr(self, name, update)
    _FF = tf.assign(self.FF, tf.matmul(P, P, transpose_a=True))
    # Evaluating this constant forces FF = P^T P to be recomputed first.
    with self.graph.control_dependencies([_FF]):
        FF = tf.constant(True)
    name = "precomputeP" if int_group == 0 else "precomputeQ"
    setattr(self, name, FF)
def scatter(self, indices, values, shape, duplicates_handling='undefined', outside_handling='undefined'):
    """Scatter ``values`` into a zero tensor of the given ``shape`` at ``indices``.

    Args:
      indices: Integer index tensor whose last axis holds the coordinates
        passed to ``tf.tensor_scatter_add``. The batch number is included as
        the first coordinate, e.g. ``[0, 31, 24]`` addresses the first batch
        (batch 0) at 2D coordinates (31, 24).
      values: Values to scatter. One tiling factor is built per axis of
        ``indices``, so ``values`` is assumed to have the same rank. Every axis
        except the last is matched against ``indices``: size-1 axes are tiled
        (via ``self.tile``) to the ``indices`` size, and other sizes must be
        equal. The last axis is never tiled.
      shape: Shape of the dense output; ``values.dtype`` is used as its dtype.
      duplicates_handling: ``'add'`` sums values that share an index.
        ``'mean'``, ``'any'`` and ``'undefined'`` all return the mean of the
        values sharing an index; positions that receive no value stay 0.
      outside_handling: One of ``'discard'``, ``'clamp'``, ``'undefined'``.
        Currently only validated; no clamping or discarding is applied.

    Returns:
      A dense tensor of shape ``shape``.

    Raises:
      AssertionError: If an option is not recognised, or a non-size-1 axis of
        ``values`` does not match the corresponding axis of ``indices``.
    """
    assert duplicates_handling in ('undefined', 'add', 'mean', 'any')
    assert outside_handling in ('discard', 'clamp', 'undefined')
    # TODO: implement outside_handling; out-of-bounds indices currently reach
    # tf.tensor_scatter_add unchanged.
    buffer = tf.zeros(shape, dtype=values.dtype)
    repetitions = []
    for dim in range(len(indices.shape) - 1):
        if values.shape[dim] == 1:
            repetitions.append(indices.shape[dim])
        else:
            assert indices.shape[dim] == values.shape[dim]
            repetitions.append(1)
    repetitions.append(1)  # never tile the trailing (value) axis
    values = self.tile(values, repetitions)
    if duplicates_handling == 'add':
        # Plain scatter-add. TensorFlow provides the gradient w.r.t. ``values``.
        # No custom spatial gradient is used because this signature has no
        # sample points to resample it at.
        return tf.tensor_scatter_add(buffer, indices, values)
    # 'mean', 'any', 'undefined': average the values that share an index.
    # Won't entirely work with out-of-bounds particles (still counted in mean).
    count = tf.tensor_scatter_add(buffer, indices, tf.ones_like(values))
    total = tf.tensor_scatter_add(buffer, indices, values)
    return total / tf.maximum(1.0, count)
def scatter_density(points, indices, values):
    """Scatter-add ``values`` into ``z`` and pair the result with a backward function.

    ``z``, ``self``, ``gradient`` and ``tf`` are free names resolved from the
    enclosing scope. The returned ``(result, grad_fn)`` pair follows the
    ``tf.custom_gradient`` convention; presumably the decorator is applied at
    the definition site — confirm.

    Args:
      points: Positions at which the upstream gradient is resampled on the
        backward pass.
      indices: Index tensor forwarded to ``tf.tensor_scatter_add``.
      values: Updates forwarded to ``tf.tensor_scatter_add``.

    Returns:
      ``(scattered, backward)``. ``backward(upstream)`` yields a gradient for
      ``points`` only and ``None`` for ``indices`` and ``values``.
    """
    def backward(upstream):
        # Central-difference gradient of the upstream signal, sampled at points.
        spatial = gradient(upstream, difference='central')
        return self.resample(spatial, points), None, None

    scattered = tf.tensor_scatter_add(z, indices, values)
    return scattered, backward
def _do_update(x_update_diff_norm_sq, x_update, hess_matmul_x_update):  # pylint: disable=missing-docstring
    """Apply the scalar step ``delta`` at coordinate ``coord`` and update running terms.

    Relies on names from the enclosing scope that are not visible here:
    ``coord``, ``delta``, ``l2_regularizer``, ``hessian_unregularized_loss_outer``,
    ``hessian_unregularized_loss_middle`` and the helpers
    ``sparse_or_dense_matvecmul``, ``_sparse_or_dense_matmul_onehot`` and
    ``_one_hot_like``.

    Args:
      x_update_diff_norm_sq: Running sum to which ``delta**2`` is added.
      x_update: Tensor whose slice at index ``coord`` along the first axis has
        ``delta`` added to it.
      hess_matmul_x_update: Running term to which ``delta`` times the
        (optionally L2-regularized) Hessian column for ``coord`` is added.
        Presumably this keeps it equal to Hessian @ ``x_update`` after the
        step — confirm at the call site.

    Returns:
      ``[x_update_diff_norm_sq_, x_update, hess_matmul_x_update_]`` with static
      shapes merged with the corresponding inputs. Per the "loop vars" comment
      below, this is presumably the body of a ``tf.while_loop``.
    """
    # Presumably column `coord` of outer^T * diag(middle) * outer; the exact
    # semantics depend on the helper functions — confirm.
    hessian_column_with_l2 = sparse_or_dense_matvecmul(
        hessian_unregularized_loss_outer,
        hessian_unregularized_loss_middle * _sparse_or_dense_matmul_onehot(
            hessian_unregularized_loss_outer, coord),
        adjoint_a=True)

    if l2_regularizer is not None:
        # Adds 2 * l2_regularizer at position `coord` (via `_one_hot_like`),
        # presumably the matching column of the Hessian of l2 * ||x||^2.
        hessian_column_with_l2 += _one_hot_like(
            hessian_column_with_l2, coord, on_value=2. * l2_regularizer)

    # Move the batch dimensions of `hessian_column_with_l2` to rightmost in
    # order to conform to `hess_matmul_x_update`.
    n = tf.rank(hessian_column_with_l2)
    perm = tf.roll(tf.range(n), shift=1, axis=0)
    hessian_column_with_l2 = tf.transpose(
        a=hessian_column_with_l2, perm=perm)

    # Update the entire batch at `coord` even if `delta` may be 0 at some
    # batch coordinates. In those cases, adding `delta` is a no-op.
    x_update = tf.tensor_scatter_add(x_update, [[coord]], [delta])

    # Order the remaining ops after the x_update write.
    with tf.control_dependencies([x_update]):
        x_update_diff_norm_sq_ = x_update_diff_norm_sq + delta**2
        hess_matmul_x_update_ = (
            hess_matmul_x_update + delta * hessian_column_with_l2)

        # Hint that loop vars retain the same shape.
        x_update_diff_norm_sq_.set_shape(
            x_update_diff_norm_sq_.shape.merge_with(
                x_update_diff_norm_sq.shape))
        hess_matmul_x_update_.set_shape(
            hess_matmul_x_update_.shape.merge_with(
                hess_matmul_x_update.shape))

        return [
            x_update_diff_norm_sq_,
            x_update,
            hess_matmul_x_update_
        ]