def Quantized():
  """Fake-quantizes `x` to `bits` bits over the current min/max range.

  The range is obtained from `self._GetCurrentMinMax` and wrapped in
  `tf.stop_gradient`, so no gradient flows into the range endpoints.

  Returns:
    The result of `tf.quantization.fake_quant_with_min_max_vars` applied to
    the enclosing scope's `x`.
  """
  range_min, range_max = self._GetCurrentMinMax(state, start_cap, end_cap,
                                                bits)
  # Treat the range endpoints as constants for backprop.
  frozen_min = tf.stop_gradient(range_min)
  frozen_max = tf.stop_gradient(range_max)
  return tf.quantization.fake_quant_with_min_max_vars(
      x, frozen_min, frozen_max, num_bits=bits)
def Clipped():
  """Clips `x` to the current min/max range where clipping is enabled.

  `state[0]` is read as the clip ratio. Where it is non-negative the clipped
  value of `x` is selected; elsewhere `x` is passed through unchanged. The
  range endpoints are wrapped in `tf.stop_gradient`.

  Returns:
    The tensor selected by `tf.where` between clipped `x` and `x`.
  """
  clip_ratio = state[0]
  range_min, range_max = self._GetCurrentMinMax(state, start_cap, end_cap,
                                                bits)
  # Treat the range endpoints as constants for backprop.
  lower = tf.stop_gradient(range_min)
  upper = tf.stop_gradient(range_max)
  clipping_enabled = clip_ratio >= 0.0
  clipped_x = tf.clip_by_value(x, lower, upper)
  return tf.where(clipping_enabled, clipped_x, x)
def GetTensorRange(self, t_name, ts):
  """Returns a gradient-free `(min, max)` range for the tensor `ts`.

  Args:
    t_name: Name used to look up the memorized range state variables.
    ts: The tensor whose range is requested.

  Returns:
    A `(min, max)` tuple of tensors, both wrapped in `tf.stop_gradient`. When
    `self.do_eval` is set these are the memorized 'min'/'max' state variables
    for `t_name`; otherwise they are computed from `ts` and widened so that
    the range always contains 0.
  """
  if not self.do_eval:
    # Training: derive the range from the current values. Clamping against
    # 0.0 makes the range always straddle a real zero point.
    range_min = tf.minimum(tf.reduce_min(ts), 0.0)
    range_max = tf.maximum(tf.reduce_max(ts), 0.0)
    return (tf.stop_gradient(range_min), tf.stop_gradient(range_max))

  # Eval/inference: use the memorized range. These variables are only looked
  # up on this path so that training mode does not capture them needlessly.
  memorized_min = tf.stop_gradient(self._GetQStateVar(t_name, 'min'))
  memorized_max = tf.stop_gradient(self._GetQStateVar(t_name, 'max'))
  return (memorized_min, memorized_max)
def _CreateTargetLambdas(self,
                         atten_probs,
                         source_lambdas_pair,
                         source_paddings_pair,
                         target_paddings_pair,
                         smooth=0):
  """Derives target-side interpolation ratios from attention and source ratios.

  Args:
    atten_probs: A list containing two attention matrices.
    source_lambdas_pair: A list containing two source interpolation ratios.
    source_paddings_pair: A list containing two source paddings.
    target_paddings_pair: A list containing two target paddings.
    smooth: A real value added to the target interpolation ratios before
      normalization.

  Returns:
    source_lambdas_pair: The source interpolation ratios, returned unchanged.
    input_lambdas: Interpolation ratios for target input embeddings.
    label_lambdas: Interpolation ratios for target labels.
  """

  def _UnnormalizedTargetLambdas(index):
    """Attention-weighted sum of side `index`'s unpadded source ratios."""
    # Attention is treated as a constant: no gradient flows through it.
    probs = tf.stop_gradient(atten_probs[index])
    source_mask = 1.0 - source_paddings_pair[index]
    masked_source = source_lambdas_pair[index] * source_mask
    weighted = probs * tf.expand_dims(masked_source, 1)
    summed = tf.reduce_sum(weighted, -1)
    smoothed = summed + smooth
    # Zero out padded target positions.
    return smoothed * (1.0 - target_paddings_pair[index])

  first = _UnnormalizedTargetLambdas(0)
  second = _UnnormalizedTargetLambdas(1)

  # Normalize so the two label ratios sum to one; 1e-9 guards against 0/0.
  first_label = first / (first + second + 1e-9)
  label_lambdas = [first_label, (1.0 - first_label)]

  # Input ratios are the label ratios shifted right by one position along
  # axis 1, with 1.0 filled in at the first position.
  shifted = tf.pad(
      first_label, [[0, 0], [1, 0]], constant_values=1.)[:, :-1]
  input_lambdas = [
      shifted * (1. - target_paddings_pair[0]),
      (1.0 - shifted) * (1. - target_paddings_pair[1])
  ]
  return source_lambdas_pair, input_lambdas, label_lambdas
def _InputBatch(self):
  """Builds a batch of consecutive counting values shaped like `self.shape`.

  A 'CountingInputGenerator' stats counter is incremented by the number of
  elements in `self.shape`; the batch starts at the value returned by that
  increment minus the element count and counts up by one per element.

  Returns:
    A `NestedMap` with:
      src_ids: float32 tensor of shape `self.shape` holding the values.
      tgt_ids: `src_ids` summed over axis 0.
  """
  num_elements = tf.reduce_prod(self.shape)
  stats_counter = summary_utils.StatsCounter('CountingInputGenerator')
  incremented = tf.cast(stats_counter.IncBy(num_elements), dtype=tf.int32)
  start = tf.stop_gradient(incremented - num_elements)
  flat_values = start + tf.range(num_elements)
  src_ids = tf.reshape(tf.cast(flat_values, dtype=tf.float32), self.shape)
  tgt_ids = tf.reduce_sum(src_ids, axis=0)
  return py_utils.NestedMap(src_ids=src_ids, tgt_ids=tgt_ids)
def compute_relative_changes(eps, u, v, w, new_eps, new_u, new_v, new_w):
  """Returns the largest relative change, used as a stopping criterion.

  Reads `score_`, `a` and `b` from the enclosing scope. All inputs are wrapped
  in `tf.stop_gradient`, so no gradient flows through the result.

  Args:
    eps: Previous value of epsilon.
    u: Previous value of u.
    v: Previous value of v.
    w: Previous value of w.
    new_eps: Updated value of epsilon.
    new_u: Updated value of u.
    new_v: Updated value of v.
    new_w: Updated value of w.

  Returns:
    A scalar: the maximum over (i) the relative errors of the margins of P
    against `a` and `b`, and (ii) the relative change of `exp((u+v+w)/eps)`
    between the previous and updated values.
  """
  old_scaled_sum = tf.stop_gradient((u + v + w) / eps)
  new_scaled_sum = tf.stop_gradient((new_u + new_v + new_w) / new_eps)

  # Relative error on the margins of P. The last update on w guarantees the
  # margin constraint c is satisfied, so it is not checked here.
  assignment = tf.exp(tf.stop_gradient(score_ / new_eps + new_scaled_sum))
  margin_a = tf.reduce_sum(assignment, axis=-1, keepdims=True)
  margin_b = tf.reduce_sum(assignment, axis=-2, keepdims=True)
  rel_err_a = tf.abs(a - margin_a) / (a + 1e-6)
  rel_err_b = tf.abs(b - margin_b) / (b + 1e-6)
  largest = tf.maximum(tf.reduce_max(rel_err_a), tf.reduce_max(rel_err_b))

  # Relative change on the assignment solution between iterations.
  old_exp = tf.exp(old_scaled_sum)
  new_exp = tf.exp(new_scaled_sum)
  rel_change = tf.abs(old_exp - new_exp) / (new_exp + 1e-6)
  return tf.maximum(largest, tf.reduce_max(rel_change))
def FProp(self, theta, x, paddings=None, update=False):
  """Computes distances of the given input 'x' to all centroids.

  This implementation applies layer normalization on 'x' internally first
  (when `p.apply_layer_norm` is set), and the returned 'dists' is computed
  using the normalized 'x'.

  Args:
    theta: A `.NestedMap` of weights' values of this layer.
    x: A tensor of shape [B, L, N, H].
    paddings: If not None, a tensor of shape [B, L]. If None, all positions
      are treated as unpadded.
    update: bool, whether to update centroids using x.

  Returns:
    dists: "distances" of the given input 'x' to all centroids.
           Shape [B, L, N, K]. When `update` is True, evaluating this tensor
           also triggers the in-place update of `self.vars.means`.
    k_means_loss: the average squared Euclidean distances to the closest
                  centroid, a scalar.
  """
  p = self.params
  if paddings is None:
    paddings = tf.zeros_like(x[:, :, 0, 0])
  # Shape [B, L, 1, 1]
  paddings_4d = paddings[:, :, None, None]

  if p.apply_layer_norm:
    x = KMeansClusteringForAtten.LayerNorm(x, p.epsilon)

  # 'x' is normalized (but theta.means is not), we use negative dot product to
  # approximate the Euclidean distance here.
  dists = -tf.einsum('BLNH, NKH -> BLNK', x, theta.means)

  # For padded positions we update the distances to very large numbers.
  very_large_dists = tf.ones_like(dists) * tf.constant(
      0.1, dtype=dists.dtype) * dists.dtype.max
  paddings_tiled = tf.tile(paddings_4d, [1, 1, p.num_heads, p.num_clusters])
  dists = tf.where(paddings_tiled > 0.0, very_large_dists, dists)

  # Shape [B, L, N, K], the same as 'dists' above.
  nearest_one_hot = tf.one_hot(
      tf.math.argmin(dists, axis=-1),
      p.num_clusters,
      dtype=py_utils.FPropDtype(p))
  # Same shape as the input 'x'.
  nearest_centroid = tf.einsum('BLNK, NKH -> BLNH', nearest_one_hot,
                               theta.means)
  diff = tf.math.squared_difference(x, tf.stop_gradient(nearest_centroid))
  diff = py_utils.ApplyPadding(paddings_4d, diff)
  diff = tf.math.reduce_mean(diff, axis=2)

  # The commitment loss which when back proped against encourages the 'x'
  # values to commit to their chosen centroids.
  k_means_loss = tf.math.reduce_sum(diff) / tf.math.reduce_sum(1.0 - paddings)
  summary_utils.scalar('k_means/squared_distance_loss', k_means_loss)

  # TODO(zhouwk): investigate normalizing theta.means after each update.
  # Per-centroid L2 norms, shape [N, K]. The norm must be taken over the last
  # (H) axis only: without `axis`, tf.norm collapses to a single global
  # scalar and the min and mean summaries below would be identical.
  means_norm = tf.norm(theta.means, axis=-1)
  summary_utils.scalar('k_means/centroid_l2_norm/min',
                       tf.math.reduce_min(means_norm))
  summary_utils.scalar('k_means/centroid_l2_norm/mean',
                       tf.math.reduce_mean(means_norm))

  if not update:
    return dists, k_means_loss

  # To update the centroids (self.vars.means), we apply gradient descent on
  # the mini-batch of input 'x', which yields the following:
  #   new_centroid = centroid + (1 - decay) * (x_mean - centroid)
  # where x_mean is the average over all the input vectors closest to this
  # centroid.
  #
  # Note that this approach is equivalent with backprop via
  #    loss = tf.math.reduce_mean(
  #        tf.math.squared_difference(tf.stop_gradient(x), nearest_centroid)))
  # , except that here the learning rate is independently set via 'decay'.

  # Ensure that the padded positions are not used to update the centroids.
  nearest_one_hot = py_utils.ApplyPadding(paddings_4d, nearest_one_hot)

  # Sum away batch and sequence length dimensions to get per cluster count.
  # Shape: [N, K]
  per_cluster_count = tf.reduce_sum(nearest_one_hot, axis=[0, 1])
  summary_utils.histogram('k_means/per_cluster_vec_count', per_cluster_count)

  # Sum of the input 'x' per each closest centroid.
  sum_x = tf.einsum('BLNK, BLNH -> NKH', nearest_one_hot, x)

  if py_utils.use_tpu():
    per_cluster_count = tf.tpu.cross_replica_sum(per_cluster_count)
    sum_x = tf.tpu.cross_replica_sum(sum_x)

  # If per_cluster_count for a cluster is 0, then 'nearest_one_hot' in that
  # cluster's position will always be 0, hence 'sum_x' in that dimension will
  # be 0. Clamping the count to at least 1 avoids a division by zero.
  new_means = sum_x / tf.maximum(
      tf.constant(1.0, dtype=per_cluster_count.dtype),
      tf.expand_dims(per_cluster_count, axis=-1))

  # We use exponential moving average. TODO(zhouwk): investigate smooth this
  # over an exponentially moving averaged per cluster count.
  #
  # Note that we intentionally do not normalize the means after this update
  # as empirically this works better.
  update_means_diff = tf.cast((1.0 - p.decay) * (new_means - theta.means),
                              self.vars.means.dtype)
  return py_utils.with_dependencies(
      [tf.assign_add(self.vars.means, update_means_diff)],
      dists), k_means_loss
def ReverseAndGrad(self, theta, outputs, d_outputs, f_seed, g_seed,
                   *extra_inputs):
  """Reconstructs the inputs and computes gradients (revnet Algorithm 1).

  Args:
    theta: A NestedMap object containing weights' values of this layer and its
      children layers.
    outputs: A NestedMap: .split1 and .split2 corresponding to y1 and y2.
    d_outputs: A NestedMap: .split1 and .split2 corresponding to dy1 and dy2,
      the total derivatives.
    f_seed: Scalar tensor. The step seed used in forward for the f block.
    g_seed: Scalar tensor. The step seed used in forward for the g block. The
      step seeds are needed for deterministic randomness, e.g. to ensure
      dropout generate the same random mask in forward and reverse_grad.
    *extra_inputs: additional inputs that will be passed to both f and g. No
      gradient will be computed for these inputs.

  Returns:
    A tuple of NestedMaps

    - inputs: .split1 and .split2 corresponding to x1 and x2.
    - d_inputs: .split1 and .split2 corresponding to dx1 and dx2, the total
      derivatives with respect to inputs.
    - d_theta: has the same structure as theta. The total derivatives with
      respect to weights.
  """

  def _WeightGrads(block_output, block_theta, output_grad):
    """Gradients of `block_output` w.r.t. `block_theta`, packed like it."""
    flat_grads = tf.gradients(
        block_output,
        block_theta.Flatten(),
        output_grad,
        unconnected_gradients=tf.UnconnectedGradients.ZERO)
    return block_theta.Pack(flat_grads)

  # Stop gradient on the outputs to avoid circular symbolic dependency.
  y1 = tf.stop_gradient(outputs.split1)
  y2 = tf.stop_gradient(outputs.split2)
  dy1, dy2 = d_outputs.split1, d_outputs.split2

  # Reverse pass: recompute g then f, each under the step seed it used in the
  # forward pass, to recover x2 and then x1.
  z1 = y1
  py_utils.ResetStepSeed(g_seed)
  gz1 = self.g_block.FProp(theta.g_block, z1, *extra_inputs)
  x2 = y2 - gz1
  py_utils.ResetStepSeed(f_seed)
  fx2 = self.f_block.FProp(theta.f_block, x2, *extra_inputs)
  x1 = z1 - fx2

  # Gradients with respect to the inputs.
  dz1 = dy1 + tf.gradients(gz1, z1, dy2)[0]
  dx2 = dy2 + tf.gradients(fx2, x2, dz1)[0]

  # Gradients with respect to the weights of each block.
  dgw = _WeightGrads(gz1, theta.g_block, dy2)
  dfw = _WeightGrads(fx2, theta.f_block, dz1)

  inputs = py_utils.NestedMap(split1=x1, split2=x2)
  d_inputs = py_utils.NestedMap(split1=dz1, split2=dx2)
  d_theta = py_utils.NestedMap(
      f_block=dfw,
      g_block=dgw,
      global_step=tf.zeros_like(theta.global_step))
  return (inputs, d_inputs, d_theta)
def CrossReplicaConcat(local_tensor,
                       tpu_cores: int,
                       axis: int = 0,
                       stop_cross_gradients: bool = False):
  """Concatenates a single local tensor across all TPU cores.

  This is mostly a fork of //nlp/neon/dual_encoder/utils/tpu_utils.py, with
  some additional functionality to support int64-typed inputs.

  Args:
    local_tensor: The local tensor to concatenate across cores.
    tpu_cores: The total number of TPU cores.
    axis: The axis to concatenate.
    stop_cross_gradients: Whether or not to stop gradients on cross-replica
      slices.

  Returns:
    The tensor concatenated across all replicas. The local tensor comes first,
    followed by one permuted copy per rotation offset 1..tpu_cores-1.
  """
  # int64 inputs are a special case since collective_permute() doesn't
  # natively support them: each int64 is split into two 32-bit halves, each
  # half is concatenated separately, and the halves are then recombined.
  #
  # Implementation notes:
  # - The halves have to be int32 because collective_permute doesn't support
  #   uint32 inputs, either.
  # - uint64 <-> int64 casts also have to be avoided because XLA doesn't
  #   know how to compile them for TPU. (Error: "While rewriting computation
  #   to not contain X64 element types, XLA encountered an HLO for which this
  #   rewriting is not implemented...")
  if local_tensor.dtype == tf.int64:
    low_half = tf.cast(local_tensor, tf.int32)
    shifted = tf.bitwise.right_shift(local_tensor, 32)
    high_half = tf.cast(tf.bitwise.bitwise_and(shifted, 0xffffffff), tf.int32)

    # Concatenate each int32 half across replicas.
    low_half = CrossReplicaConcat(
        low_half,
        tpu_cores,
        axis=axis,
        stop_cross_gradients=stop_cross_gradients)
    high_half = CrossReplicaConcat(
        high_half,
        tpu_cores,
        axis=axis,
        stop_cross_gradients=stop_cross_gradients)

    # Recombine. The low half goes through uint32 before upcasting so that
    # its sign bit is not propagated into the upper 32 bits.
    low_bits = tf.cast(tf.cast(low_half, tf.uint32), tf.int64)
    high_bits = tf.bitwise.left_shift(tf.cast(high_half, tf.int64), 32)
    return tf.cast(tf.bitwise.bitwise_or(low_bits, high_bits), tf.int64)

  pieces = [local_tensor]
  for offset in range(1, tpu_cores):
    # Every core sends its tensor to the core `offset` positions ahead.
    source_target_pairs = tuple(
        (core, (core + offset) % tpu_cores) for core in range(tpu_cores))
    received = tf.raw_ops.CollectivePermute(
        input=local_tensor, source_target_pairs=source_target_pairs)
    if stop_cross_gradients:
      received = tf.stop_gradient(received)
    pieces.append(received)

  result = tf.concat(pieces, axis=axis)
  logging.info('TPU concat across %d cores; result shape %s', tpu_cores,
               result.shape)
  return result