def graph_sampler(self, batch_size, seed, beta):
    """Draw `batch_size` configurations site by site using TensorFlow ops.

    Counterpart of `sample` intended for graph compilation: the lattice is a
    `tf.float32` tensor of shape `[batch_size, self.L, self.L, 1]` and every
    site is written with `tf.tensor_scatter_nd_add` instead of in-place
    assignment.

    Args:
        batch_size: number of configurations to draw.
        seed: object supporting `.assign`, advanced before every draw and
            passed as the seed of `tf.random.stateless_binomial`
            (presumably a `tf.Variable` -- confirm against the caller).
        beta: forwarded unchanged to `self.call`.

    Returns:
        The filled lattice tensor with entries in {-1, +1}.
    """
    lattice = tf.zeros([batch_size, self.L, self.L, 1], tf.float32)
    draw_binomial = tf.random.stateless_binomial
    ones_vec = tf.ones([batch_size], tf.int32)
    zeros_vec = tf.zeros_like(ones_vec)
    half_width = self.learn_range

    for row in range(self.L):
        for col in range(self.L):
            # Advance the seed with a linear congruential step so each site
            # gets a fresh stateless draw.
            seed.assign((seed * 1664525 + 1013904223) % 2**31)

            # Window fed to the network: the previous and current row, and
            # up to `half_width` columns on either side of the current one.
            row_lo = np.maximum(row - 1, 0)
            col_lo = np.maximum(col - half_width, 0)
            col_hi = np.minimum(col + half_width + 1, self.L)
            window = lattice[:, row_lo:row + 1, col_lo:col_hi]
            x_hat = self.call(window, beta)

            # Position of the current site inside the (possibly truncated)
            # window.
            local_row = tfm.minimum(row, 1)
            local_col = tfm.minimum(col, half_width)
            probs = (0.5 if row == 0 and col == 0
                     else x_hat[:, local_row, local_col, 0])

            site_index = tf.stack(
                [tf.range(batch_size), row * ones_vec, col * ones_vec,
                 zeros_vec],
                1)
            # Map the {0, 1} binomial outcome to a {-1, +1} value.
            spins = draw_binomial(
                [batch_size], seed, 1., probs, tf.float32) * 2 - 1
            lattice = tf.tensor_scatter_nd_add(
                lattice, tf.cast(site_index, tf.int32), spins)

    if self.z2:
        # Multiply each configuration by a random global sign.
        seed.assign((seed * 1664525 + 1013904223) % 2**31)
        sign = draw_binomial(
            [batch_size, 1, 1, 1], seed, 1., 0.5, tf.float32) * 2 - 1
        lattice = lattice * sign
    return lattice
def esat(T):
    """Blend `eliq(T)` and `eice(T)` according to `T`.

    Above 273.16 the result is `eliq(T)`, below 253.16 it is `eice(T)`, and
    in between the two are mixed linearly with a weight that rises from 0 at
    253.16 to 1 at 273.16.

    NOTE(review): the names suggest saturation vapour pressure over liquid
    water and ice with `T` in kelvin -- confirm against `eliq` / `eice`.
    """
    upper = np.float32(273.16)
    lower = np.float32(253.16)

    # Fraction of the way from `lower` to `upper`, clamped to [0, 1].
    raw_weight = (T - lower) / (upper - lower)
    liquid_weight = tfm.maximum(
        np.float32(0.0), tfm.minimum(np.float32(1.0), raw_weight))

    is_warm = T > upper
    warm_value = eliq(T)
    is_cold = T < lower
    cold_value = eice(T)
    mixed_value = liquid_weight * eliq(T) + (1 - liquid_weight) * eice(T)

    not_warm_value = tf.where(is_cold, cold_value, mixed_value)
    return tf.where(is_warm, warm_value, not_warm_value)
def _graph_update(self, sample, beta, seed, pos):
    """Draw the value of one lattice site and add it into `sample`.

    Args:
        sample: lattice tensor; its leading dimension is the batch size and
            it is sliced as `[batch, row, column, channel]`.
        beta: forwarded unchanged to `self.call`.
        seed: object supporting `.assign`, advanced once and then used as
            the seed of `tf.random.stateless_binomial`.
        pos: unstacked into the `(row, column)` of the site to update.

    Returns:
        `sample` with the newly drawn {-1, +1} values scatter-added at
        `pos` for every batch entry.
    """
    n_batch = sample.shape[0]
    half_width = self.learn_range
    row, col = tf.unstack(pos)
    seed.assign((seed * 1664525 + 1013904223) % 2**31)

    # Window fed to the network: previous and current row, and up to
    # `half_width` columns on either side of the current column.
    window_begin = tf.stack(
        [0, tfm.maximum(row - 1, 0), tfm.maximum(col - half_width, 0), 0])
    window_end = tf.stack(
        [n_batch, row + 1, tfm.minimum(col + half_width + 1, self.L), 1])
    window = tf.strided_slice(sample, window_begin, window_end)
    x_hat = self.call(window, beta)

    # Position of the site inside the (possibly truncated) window.
    local_row = tfm.minimum(row, 1)
    local_col = tfm.minimum(col, half_width)
    probs = 0.5 if row == 0 and col == 0 else x_hat[:, local_row, local_col, 0]

    site_index = tf.stack(
        [tf.range(n_batch), row * self.full_ones, col * self.full_ones,
         self.full_zeros],
        1)
    # Map the {0, 1} binomial outcome to a {-1, +1} value.
    spins = tf.random.stateless_binomial(
        [n_batch], seed, 1., probs, tf.float32) * 2 - 1
    return tf.tensor_scatter_nd_add(
        sample, tf.cast(site_index, tf.int32), spins)
def sample(self, batch_size):
    """Draw `batch_size` configurations site by site with NumPy sampling.

    The lattice starts as zeros of shape `[batch_size, self.L, self.L, 1]`
    and is filled in row-major order. For every site the network is
    evaluated on a window of already-visited neighbours and the site is set
    to -1 or +1 from a Bernoulli draw; the very first site uses
    probability 0.5.

    Returns:
        The filled lattice, multiplied by a random per-configuration sign
        when `self.z2` is truthy.
    """
    lattice = np.zeros([batch_size, self.L, self.L, 1], np.float32)
    half_width = self.learn_range

    for row in range(self.L):
        for col in range(self.L):
            # Window fed to the network: previous and current row, and up
            # to `half_width` columns on either side of the current one.
            row_lo = np.maximum(row - 1, 0)
            col_lo = np.maximum(col - half_width, 0)
            col_hi = np.minimum(col + half_width + 1, self.L)
            x_hat = self.call(lattice[:, row_lo:row + 1, col_lo:col_hi])

            # Position of the current site inside the window.
            local_row = tfm.minimum(row, 1)
            local_col = tfm.minimum(col, half_width)
            probs = (0.5 if row == 0 and col == 0
                     else x_hat[:, local_row, local_col, :])

            draws = np.random.binomial(1, probs, [batch_size, 1])
            # Map the {0, 1} outcome to {-1, +1}.
            lattice[:, row, col, :] = draws * 2 - 1

    if self.z2:
        sign = np.random.binomial(1, 0.5, [batch_size, 1, 1, 1]) * 2 - 1
        lattice = lattice * sign
    return lattice
def AlphaConstraint(w):
    """Constrain `w` to the range [0, 1].

    The sign is dropped first (`abs`) and the magnitude is then capped at 1,
    so this is not a plain clip: negative inputs are reflected, not zeroed.
    """
    return tfm.minimum(tfm.abs(w), 1.0)
def _delta_phi_tf(x, y):
    """Return the smallest absolute angular separation between `x` and `y`.

    TensorFlow variant. The arguments are angles in radians; the result is
    `min(d, 2*pi - d)` where `d` is `|x - y|` reduced modulo a full turn,
    so it lies in [0, pi] for finite inputs.

    The modulo step only matters when `|x - y|` exceeds `2*pi` (inputs that
    are not already wrapped); without it `2*pi - d` would be negative and
    would be returned as the "separation". Results for `|x - y| <= 2*pi`
    are unchanged.
    """
    from tensorflow.math import abs, floormod, minimum
    import math

    two_pi = 2 * math.pi
    d = floormod(abs(x - y), two_pi)
    return minimum(d, two_pi - d)
def _delta_phi_np(x, y): from numpy import abs, minimum import math pi = math.pi d = abs(x - y) return minimum(d, 2 * pi - d)
def __call__(self, step):
    """Return the learning rate for `step`.

    The rate is `2e-3` times the smaller of `step ** -0.5` and
    `step * self.warmup_steps ** -1.5`; the second term grows linearly in
    `step` and the two terms are equal at `step == self.warmup_steps`.
    """
    decay_term = step ** (-0.5)
    warmup_term = step * (self.warmup_steps ** (-1.5))
    return 2e-3 * minimum(decay_term, warmup_term)
def interpolate(grid, query_points, indexing="ij", name=None):
    """Sample `grid` at `query_points` using bilinear interpolation.

    Reference: https://github.com/tensorflow/addons/blob/master/tensorflow_addons/image/dense_image_warp.py

    Similar to Matlab's interp2 function. Each query coordinate is clamped so
    that the four pixels surrounding it are valid grid indices, and the
    interpolation weights are clamped to [0, 1].

    Args:
      grid: a 4-D float `Tensor` of shape `[batch, height, width, channels]`.
      query_points: a 3-D float `Tensor` of N points with shape
        `[batch, N, 2]`.
      indexing: whether the query points are specified as row and column
        (ij), or Cartesian coordinates (xy).
      name: a name for the operation (optional).

    Returns:
      values: a 3-D `Tensor` with shape `[batch, N, channels]`

    Raises:
      ValueError: if the indexing mode is invalid, or if the shape of the
        inputs invalid.
    """
    if indexing != "ij" and indexing != "xy":
        raise ValueError("Indexing mode must be \'ij\' or \'xy\'")

    with name_scope(name or "interpolate_bilinear"):
        grid = convert_to_tensor(grid)
        query_points = convert_to_tensor(query_points)

        # Static shape checks, applied wherever the dimension is known when
        # the function is called.
        if len(grid.shape) != 4:
            msg = "Grid must be 4 dimensional. Received size: "
            raise ValueError(msg + str(grid.shape))

        if len(query_points.shape) != 3:
            raise ValueError("Query points must be 3 dimensional.")

        if query_points.shape[2] is not None and query_points.shape[2] != 2:
            raise ValueError("Query points must be size 2 in dim 2.")

        if grid.shape[1] is not None and grid.shape[1] < 2:
            raise ValueError("Grid height must be at least 2.")

        if grid.shape[2] is not None and grid.shape[2] < 2:
            raise ValueError("Grid width must be at least 2.")

        grid_shape = shape(grid)
        query_shape = shape(query_points)

        batch_size, height, width, channels = (grid_shape[0], grid_shape[1],
                                               grid_shape[2], grid_shape[3])

        shape_list = [batch_size, height, width, channels]

        # The same checks again as assertion ops on the dynamic shapes.
        # pylint: disable=bad-continuation
        with control_dependencies([
                assert_equal(query_shape[2], 2,
                             message="Query points must be size 2 in dim 2.")
        ]):
            num_queries = query_shape[1]
        # pylint: enable=bad-continuation

        query_type = query_points.dtype
        grid_type = grid.dtype

        # pylint: disable=bad-continuation
        with control_dependencies([
                assert_greater_equal(
                    height, 2, message="Grid height must be at least 2."),
                assert_greater_equal(
                    width, 2, message="Grid width must be at least 2."),
        ]):
            alphas = []
            floors = []
            ceils = []
            index_order = [0, 1] if indexing == "ij" else [1, 0]
            unstacked_query_points = unstack(query_points, axis=2)
        # pylint: enable=bad-continuation

        # `index_order` lists the query component holding the row coordinate
        # first and the column coordinate second. The grid axis a coordinate
        # is clamped against therefore follows the loop position (`axis`),
        # not the component index (`dim`); using `dim` would clamp rows by
        # the width and columns by the height in "xy" mode.
        for axis, dim in enumerate(index_order):
            with name_scope("dim-" + str(dim)):
                queries = unstacked_query_points[dim]

                size_in_indexing_dimension = shape_list[axis + 1]

                # max_floor is size_in_indexing_dimension - 2 so that
                # max_floor + 1 is still a valid index into the grid.
                max_floor = cast(size_in_indexing_dimension - 2, query_type)
                min_floor = constant(0.0, dtype=query_type)
                floor_val = minimum(maximum(min_floor, floor(queries)),
                                    max_floor)
                int_floor = cast(floor_val, int32)
                floors.append(int_floor)
                ceil = int_floor + 1
                ceils.append(ceil)

                # alpha has the same type as the grid, as we will directly
                # use alpha when taking linear combinations of pixel values
                # from the image.
                alpha = cast(queries - floor_val, grid_type)
                min_alpha = constant(0.0, dtype=grid_type)
                max_alpha = constant(1.0, dtype=grid_type)
                alpha = minimum(maximum(min_alpha, alpha), max_alpha)

                # Expand alpha to [b, n, 1] so we can use broadcasting
                # (since the alpha values don't depend on the channel).
                alpha = expand_dims(alpha, 2)
                alphas.append(alpha)

        # pylint: disable=bad-continuation
        with control_dependencies([
                assert_less_equal(
                    cast(batch_size * height * width, dtype=float32),
                    np.iinfo(np.int32).max / 8.0,
                    message="The image size or batch size is sufficiently "
                    "large that the linearized addresses used by tf.gather "
                    "may exceed the int32 limit.")
        ]):
            flattened_grid = reshape(grid,
                                     [batch_size * height * width, channels])
            batch_offsets = reshape(
                tfrange(batch_size) * height * width, [batch_size, 1])
        # pylint: enable=bad-continuation

        # This wraps tf.gather. We reshape the image data such that the
        # batch, y, and x coordinates are pulled into the first dimension.
        # Then we gather. Finally, we reshape the output back. It's possible
        # this code would be made simpler by using tf.gather_nd.
        def gather_fn(y_coords, x_coords, name):
            with name_scope("gather-" + name):
                linear_coordinates = (
                    batch_offsets + y_coords * width + x_coords)
                gathered_values = gather(flattened_grid, linear_coordinates)
                return reshape(gathered_values,
                               [batch_size, num_queries, channels])

        # grab the pixel values in the 4 corners around each query point
        top_left = gather_fn(floors[0], floors[1], "top_left")
        top_right = gather_fn(floors[0], ceils[1], "top_right")
        bottom_left = gather_fn(ceils[0], floors[1], "bottom_left")
        bottom_right = gather_fn(ceils[0], ceils[1], "bottom_right")

        # now, do the actual interpolation
        with name_scope("interpolate"):
            interp_top = alphas[1] * (top_right - top_left) + top_left
            interp_bottom = alphas[1] * (bottom_right
                                         - bottom_left) + bottom_left
            interp = alphas[0] * (interp_bottom - interp_top) + interp_top

        return interp