def _TopKGrad(op, grad, _):
  """Return the gradients for TopK.

  Args:
    op: The TopKOp for which we need to generate gradients.
    grad: Tensor. The gradients passed to the TopKOp.

  Returns:
    A list of two tensors, the first being the gradient w.r.t. the input of
    TopK, and the second being the gradient w.r.t. the indices (all zero).
  """
  in_shape = array_ops.shape(op.inputs[0])
  ind_shape = array_ops.shape(op.outputs[1])

  ind_lastdim = array_ops.gather(ind_shape, array_ops.size(ind_shape) - 1)
  # Flatten indices to 2D.
  ind_2d = array_ops.reshape(op.outputs[1], array_ops.stack([-1, ind_lastdim]))

  in_lastdim = array_ops.gather(in_shape, array_ops.size(in_shape) - 1)
  outerdim = array_ops.shape(ind_2d)[0]
  # Compute linear indices (flattened to 1D).
  ind = array_ops.reshape(
      ind_2d + array_ops.expand_dims(
          math_ops.range(0, outerdim * in_lastdim, in_lastdim), -1), [-1])

  # Substitute grad to appropriate locations and fill the rest with zeros,
  # finally reshaping it to the original input shape.
  return [
      array_ops.reshape(
          sparse_ops.sparse_to_dense(
              ind,
              array_ops.reshape(math_ops.reduce_prod(in_shape), [1]),
              array_ops.reshape(grad, [-1]),
              validate_indices=False), in_shape),
      array_ops.zeros([], dtype=dtypes.int32)
  ]
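# A minimal NumPy sketch (not part of the TF source) of what _TopKGrad
# computes: the incoming gradient is scattered back to the positions of the
# top-k entries along the last axis, zeros elsewhere. Names and shapes here
# are hypothetical.
import numpy as np

def topk_grad_reference(x, k, grad):
  """Scatter `grad` onto the top-k positions of `x` along the last axis."""
  flat_x = x.reshape(-1, x.shape[-1])
  flat_grad = grad.reshape(-1, k)
  out = np.zeros_like(flat_x)
  for row in range(flat_x.shape[0]):
    top_idx = np.argsort(-flat_x[row])[:k]  # positions of the k largest
    out[row, top_idx] = flat_grad[row]      # grad[0] pairs with the largest
  return out.reshape(x.shape)

# topk_grad_reference(np.array([[1., 5., 3.]]), 2, np.ones((1, 2)))
# -> [[0., 1., 1.]]: only the top-2 positions receive gradient.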
def testInstantError(self):
  if context.num_gpus():
    # TODO(nareshmodi): make this test better
    self.skipTest("Gather doesn't do index checking on GPUs")
  with self.assertRaisesRegexp(errors.InvalidArgumentError,
                               r'indices = 7 is not in \[0, 3\)'):
    array_ops.gather([0, 1, 2], 7)
def quantiles_ready():
  """The subgraph for when the quantiles are ready."""
  quantized_feature = quantile_ops.quantiles([sparse_column_values], [],
                                             [quantile_buckets], [])
  quantized_feature = math_ops.cast(quantized_feature[0], dtypes.int64)
  quantized_feature = array_ops.reshape(quantized_feature, [-1])
  example_indices, _ = array_ops.split(
      sparse_column_indices, num_or_size_splits=2, axis=1)
  example_indices = array_ops.squeeze(example_indices, [1])
  filtered_gradients = array_ops.gather(gradients, example_indices)
  filtered_hessians = array_ops.gather(hessians, example_indices)
  filtered_partition_ids = array_ops.gather(example_partition_ids,
                                            example_indices)
  unique_partitions, mapped_partitions = array_ops.unique(
      example_partition_ids)

  # Compute aggregate stats for each partition.
  per_partition_gradients = math_ops.unsorted_segment_sum(
      gradients, mapped_partitions, array_ops.size(unique_partitions))
  per_partition_hessians = math_ops.unsorted_segment_sum(
      hessians, mapped_partitions, array_ops.size(unique_partitions))

  # Prepend a bias feature per partition that accumulates the stats for all
  # examples in that partition.
  bias_feature_ids = array_ops.fill(
      array_ops.shape(unique_partitions), _BIAS_FEATURE_ID)
  bias_feature_ids = math_ops.cast(bias_feature_ids, dtypes.int64)
  partition_ids = array_ops.concat(
      [unique_partitions, filtered_partition_ids], 0)
  filtered_gradients = array_ops.concat(
      [per_partition_gradients, filtered_gradients], 0)
  filtered_hessians = array_ops.concat(
      [per_partition_hessians, filtered_hessians], 0)
  bucket_ids = array_ops.concat([bias_feature_ids, quantized_feature], 0)
  return partition_ids, bucket_ids, filtered_gradients, filtered_hessians
def _apply_sparse_shared(self, grad, var, indices, scatter_add,
                         scatter_update):
  beta1_power = self._get_beta_accumulators()
  beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
  lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
  beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
  beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
  epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)

  # m_t = beta1 * m + (1 - beta1) * g_t
  m = self.get_slot(var, "m")
  m_slice = array_ops.gather(m, indices)
  m_t_slice = m_slice * beta1_t + grad * (1 - beta1_t)
  with ops.control_dependencies([m_t_slice]):
    m_t = scatter_update(m, indices, m_t_slice)

  # u_t = max(beta2 * u, abs(g_t))
  v = self.get_slot(var, "v")
  v_slice = array_ops.gather(v, indices)
  v_t_slice = math_ops.maximum(v_slice * beta2_t, math_ops.abs(grad))
  with ops.control_dependencies([v_t_slice]):
    v_t = scatter_update(v, indices, v_t_slice)

  # theta_t = theta - lr / (1 - beta1^t) * m_t / u_t
  var_slice = -lr_t / (1 - beta1_power) * (
      m_t_slice / (v_t_slice + epsilon_t))
  with ops.control_dependencies([var_slice]):
    var_update = scatter_add(var, indices, var_slice)

  return control_flow_ops.group(*[var_update, m_t, v_t])
def _apply_sparse(self, grad, var):
  lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
  beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
  beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
  epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)

  # The following equations are given in [1].
  # m_t = beta1 * m + (1 - beta1) * g_t
  m = self.get_slot(var, "m")
  m_t = state_ops.scatter_update(
      m, grad.indices,
      beta1_t * array_ops.gather(m, grad.indices) +
      (1. - beta1_t) * grad.values,
      use_locking=self._use_locking)
  m_t_slice = array_ops.gather(m_t, grad.indices)

  # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
  v = self.get_slot(var, "v")
  v_t = state_ops.scatter_update(
      v, grad.indices,
      beta2_t * array_ops.gather(v, grad.indices) +
      (1. - beta2_t) * math_ops.square(grad.values),
      use_locking=self._use_locking)
  v_prime = self.get_slot(var, "v_prime")
  v_t_slice = array_ops.gather(v_t, grad.indices)
  v_prime_slice = array_ops.gather(v_prime, grad.indices)
  v_t_prime = state_ops.scatter_update(
      v_prime, grad.indices, math_ops.maximum(v_prime_slice, v_t_slice))

  v_t_prime_slice = array_ops.gather(v_t_prime, grad.indices)
  var_update = state_ops.scatter_sub(
      var, grad.indices,
      lr_t * m_t_slice / (math_ops.sqrt(v_t_prime_slice) + epsilon_t),
      use_locking=self._use_locking)

  return control_flow_ops.group(*[var_update, m_t, v_t, v_t_prime])
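# A NumPy sketch (not part of the source) of the dense per-row update that
# the sparse AMSGrad method above implements for a single touched row. All
# values are hypothetical.
import numpy as np

def amsgrad_row_step(theta, m, v, v_prime, g, lr, beta1, beta2, eps):
  m = beta1 * m + (1. - beta1) * g        # first moment
  v = beta2 * v + (1. - beta2) * g * g    # second moment
  v_prime = np.maximum(v_prime, v)        # non-decreasing v: the AMSGrad fix
  theta = theta - lr * m / (np.sqrt(v_prime) + eps)
  return theta, m, v, v_prime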
def f(params):
  index_values = [1, 3, 5, 6]
  indices = constant_op.constant(index_values, name="i")
  y = array_ops.gather(params, indices, name="y")
  index_values2 = [0, 2]
  indices2 = constant_op.constant(index_values2, name="i2")
  return array_ops.gather(y, indices2, name="y2")
def _sparse_moving_average(self, x_tm1, idxs, b_t_, name, beta=.9):
  """Creates a moving average for a sparse variable.

  Inputs:
    x_tm1: the associated parameter (e.g. a weight matrix)
    idxs: the tensor representing the indices used
    b_t_: the value to accumulate (e.g. slices of the gradient)
    name: a string to use to retrieve it later (e.g. 'm')
    beta: the decay factor (defaults to .9)
  Outputs:
    a_t: the average after moving (same shape as x_tm1, not b_t_)
    t: the internal timestep (used to correct initialization bias)
  """
  a_tm1 = self._zeros_slot(x_tm1, '%s' % name, self._name)
  a_tm1_ = array_ops.gather(a_tm1, idxs)
  tm1 = self._zeros_idx_slot(x_tm1, '%s/tm1' % name, self._name)
  tm1_ = array_ops.gather(tm1, idxs)
  t = state_ops.scatter_add(tm1, idxs, tm1_ * 0 + 1,
                            use_locking=self._use_locking)
  t_ = array_ops.gather(t, idxs)
  if beta < 1:
    beta_t = ops.convert_to_tensor(beta, name='%s/decay' % name)
    beta_t_ = beta_t * (1 - beta_t ** tm1_) / (1 - beta_t ** t_)
  else:
    beta_t_ = tm1_ / t_
    beta_t = beta_t_  # running mean: the accumulation weight becomes 1/t
  a_t = state_ops.scatter_update(a_tm1, idxs, beta_t_ * a_tm1_,
                                 use_locking=self._use_locking)
  a_t = state_ops.scatter_add(a_t, idxs, (1 - beta_t) * b_t_,
                              use_locking=self._use_locking)
  return a_t, t
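# A scalar sketch (not part of the source) of the decay schedule used by the
# `beta < 1` branch above: `beta_t_` is zero on a slot's first update, so the
# zero initialization gets no weight, and it rises toward the raw `beta` as
# the per-slot step count grows.
beta = 0.9
for t in range(1, 6):
  beta_t_ = beta * (1 - beta ** (t - 1)) / (1 - beta ** t)
  print(t, round(beta_t_, 4))  # 0.0, 0.4737, 0.631, 0.7092, 0.7558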
def _resource_apply_sparse(self, grad, var, indices):
  beta1_power, beta2_power = self._get_beta_accumulators()
  beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
  beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
  lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
  beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
  beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
  epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
  lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power))

  # \\(m := beta1 * m + (1 - beta1) * g_t\\)
  m = self.get_slot(var, "m")
  m_t_slice = beta1_t * array_ops.gather(m, indices) + (1 - beta1_t) * grad
  m_update_op = resource_variable_ops.resource_scatter_update(
      m.handle, indices, m_t_slice)

  # \\(v := beta2 * v + (1 - beta2) * (g_t * g_t)\\)
  v = self.get_slot(var, "v")
  v_t_slice = (beta2_t * array_ops.gather(v, indices) +
               (1 - beta2_t) * math_ops.square(grad))
  v_update_op = resource_variable_ops.resource_scatter_update(
      v.handle, indices, v_t_slice)

  # \\(variable -= learning_rate * m_t / (epsilon_t + sqrt(v_t))\\)
  var_slice = lr * m_t_slice / (math_ops.sqrt(v_t_slice) + epsilon_t)
  var_update_op = resource_variable_ops.resource_scatter_sub(
      var.handle, indices, var_slice)

  return control_flow_ops.group(var_update_op, m_update_op, v_update_op)
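# A NumPy reference (not part of the source) for the row-sparse Adam update
# performed above. Assumes `indices` has no duplicates; `t` is the step
# count used for bias correction.
import numpy as np

def sparse_adam_rows(var, m, v, grad, indices, lr_t, beta1, beta2, eps, t):
  lr = lr_t * np.sqrt(1 - beta2 ** t) / (1 - beta1 ** t)
  m[indices] = beta1 * m[indices] + (1 - beta1) * grad
  v[indices] = beta2 * v[indices] + (1 - beta2) * grad * grad
  var[indices] -= lr * m[indices] / (np.sqrt(v[indices]) + eps)
  return var, m, v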
def loop_fn(i):
  x1 = array_ops.gather(x, i)
  y1 = array_ops.gather(y, i)
  outputs = [op(x, y), op(x1, y), op(x, y1), op(x1, y1), op(x1, x1)]
  del output_dtypes[:]
  output_dtypes.extend([t.dtype for t in outputs])
  return outputs
def _apply_sparse(self, g_t, x_tm1, prepare):
  """Computes the sparse update step for the parameter x_tm1."""
  idxs, idxs_ = array_ops.unique(g_t.indices)
  g_t_ = math_ops.unsorted_segment_sum(g_t.values, idxs_,
                                       array_ops.size(idxs))
  updates = []

  if self._mu > 0:
    m_and_t = self._sparse_moving_average(x_tm1, idxs, g_t_, 'm', self._mu)
    m_t_ = array_ops.gather(m_and_t[0], idxs)
    gamma_t = ops.convert_to_tensor(self._gamma)
    m_bar_t_ = (1 - gamma_t) * m_t_ + gamma_t * g_t_
    updates.extend(m_and_t)
  else:
    m_bar_t_ = g_t_

  if self._ups > 0:
    v_and_t = self._sparse_moving_average(x_tm1, idxs, g_t_ ** 2, 'v',
                                          self._ups)
    v_t_ = array_ops.gather(v_and_t[0], idxs)
    eps_t = ops.convert_to_tensor(self._eps)
    v_bar_t_ = math_ops.sqrt(v_t_ + eps_t)
    updates.extend(v_and_t)
  else:
    v_bar_t_ = 1.

  lr_t = ops.convert_to_tensor(self._lr)
  s_t_ = lr_t * m_bar_t_ / v_bar_t_
  return [[s_t_, x_tm1, idxs, g_t]] + updates
def loop_fn(i):
  image = array_ops.gather(images, i)
  label = array_ops.gather(labels, i)
  logits = array_ops.reshape(model(image, training=training), [-1])
  loss = losses.softmax_cross_entropy(
      logits=logits, onehot_labels=label, reduction=losses.Reduction.NONE)
  return gradient_ops.gradients(loss, variables.trainable_variables())
def loop_fn(i):
  loop_inputs = [
      array_ops.expand_dims(array_ops.gather(x, i), 0) for x in inputs
  ]
  loop_init_state = rnn_cell.LSTMStateTuple(
      *[array_ops.expand_dims(array_ops.gather(x, i), 0)
        for x in init_state])
  return model_fn(loop_inputs, loop_init_state)
def _apply_sparse(self, grad, var):
  if len(grad.indices.get_shape()) == 1:
    grad_indices = grad.indices
    grad_values = grad.values
  else:
    grad_indices = array_ops.reshape(grad.indices, [-1])
    grad_values = array_ops.reshape(
        grad.values, [-1, grad.values.get_shape()[-1].value])
  gidxs, metagidxs = array_ops.unique(grad_indices)
  sizegidxs = array_ops.size(gidxs)
  gvals = math_ops.unsorted_segment_sum(grad_values, metagidxs, sizegidxs)
  # m_t = mu * m + (1 - mu) * g_t
  m = self.get_slot(var, "m")
  m_scaled_g_values = gvals * (1 - self._mu_t)
  m_t = state_ops.scatter_update(m, gidxs,
                                 array_ops.gather(m, gidxs) * self._mu_t,
                                 use_locking=self._use_locking)
  m_t = state_ops.scatter_add(m_t, gidxs, m_scaled_g_values,
                              use_locking=self._use_locking)
  m_t_ = array_ops.gather(m_t, gidxs) / (1 - self._mu2_t * self._mu_power)
  # m_bar = mu * m_t + (1 - mu) * g_t
  m_bar = self._mu2_t * m_t_ + m_scaled_g_values / (1 - self._mu_power)
  var_update = state_ops.scatter_sub(var, gidxs, self._lr_t * m_bar,
                                     use_locking=self._use_locking)
  return control_flow_ops.group(*[var_update, m_t])
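# The unique + unsorted_segment_sum prologue above merges duplicate gradient
# indices before the scatter ops run. The same dedup step in NumPy (a sketch,
# not part of the source; note np.unique sorts its output while TF's unique
# keeps first-occurrence order, but the index/value pairing is consistent
# either way):
import numpy as np

grad_indices = np.array([2, 0, 2, 1])
grad_values = np.array([[1.], [2.], [3.], [4.]])

gidxs, metagidxs = np.unique(grad_indices, return_inverse=True)
gvals = np.zeros((gidxs.size,) + grad_values.shape[1:])
np.add.at(gvals, metagidxs, grad_values)  # segment-sum duplicate rows
# gidxs -> [0 1 2]; gvals -> [[2.], [4.], [4.]] (rows for index 2: 1 + 3)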
def testHigherRankMaxNorm(self):
  np.random.seed(8)
  with self.test_session():
    for params_shape in (12,), (6, 3), (6, 2, 3):
      # Test embedding rank 0, 1, 2.
      # Note: the first dimension must be a common multiple of procs below.
      params = 2 * np.ones(params_shape)
      params_norm = params / np.sqrt(
          np.sum(
              params * params, tuple(range(params.ndim)[1:]), keepdims=True))
      for ids_shape in (), (3), (4, 3), (2, 3, 4):
        ids = np.random.randint(
            params.shape[0],
            size=np.prod(ids_shape, dtype=np.int64)).reshape(ids_shape)
        # Compare nonsharded to gather.
        simple = embedding_ops.embedding_lookup(
            params, ids, max_norm=1.0).eval()
        self.assertAllEqual(simple, array_ops.gather(params_norm, ids).eval())
        # Run a few different sharded versions.
        for procs in 1, 2, 3:
          stride = procs * math_ops.range(params.shape[0] // procs)
          split_params = [
              array_ops.gather(params, stride + p) for p in xrange(procs)
          ]
          sharded = embedding_ops.embedding_lookup(
              split_params, ids, max_norm=1.0).eval()
          self.assertAllEqual(simple, sharded)
def testTransform(self):
  # This tests all combinations of:
  #   - ids rank 0, 1, >1
  #   - params sharded/unsharded
  # It always applies max_norm.
  np.random.seed(8)
  l2_norm = 2.
  with self.test_session():
    # Param values are in [l2_norm, l2_norm+1) so it will always clip.
    params = np.random.rand(6, 3) + l2_norm
    params_norm = l2_norm * params / np.sqrt(
        np.sum(params * params, axis=1, keepdims=True))
    # Compute the norm of each embedding. This will change the embedding
    # rank to 0.
    params_norm = np.linalg.norm(params_norm, axis=1)
    transform = lambda x: linalg_ops.norm(x, axis=1)
    for ids_shape in (), (3), (4, 3), (2, 3, 4):
      # Test ids rank 0, 1, 2, 3.
      ids = np.random.randint(
          params.shape[0],
          size=np.prod(ids_shape, dtype=np.int64)).reshape(ids_shape)
      # Compare nonsharded to gather.
      simple = embedding_ops._embedding_lookup_and_transform(
          params, ids, max_norm=l2_norm, transform_fn=transform).eval()
      self.assertAllClose(simple, array_ops.gather(params_norm, ids).eval())
      # Run a few different sharded versions.
      for procs in 1, 2, 3:
        stride = procs * math_ops.range(params.shape[0] // procs)
        split_params = [
            array_ops.gather(params, stride + p) for p in xrange(procs)
        ]
        sharded = embedding_ops._embedding_lookup_and_transform(
            split_params, ids, max_norm=l2_norm,
            transform_fn=transform).eval()
        self.assertAllEqual(simple, sharded)
def _apply_sparse(self, grad, var):
  beta1_power, beta2_power = self._get_beta_accumulators()
  beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
  beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
  lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
  beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
  beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
  epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
  lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power))

  # m := beta1 * m + (1 - beta1) * g_t
  m = self.get_slot(var, "m")
  m_t = state_ops.scatter_update(
      m, grad.indices,
      beta1_t * array_ops.gather(m, grad.indices) +
      (1 - beta1_t) * grad.values,
      use_locking=self._use_locking)

  # v := beta2 * v + (1 - beta2) * (g_t * g_t)
  v = self.get_slot(var, "v")
  v_t = state_ops.scatter_update(
      v, grad.indices,
      beta2_t * array_ops.gather(v, grad.indices) +
      (1 - beta2_t) * math_ops.square(grad.values),
      use_locking=self._use_locking)

  # variable -= learning_rate * m_t / (epsilon_t + sqrt(v_t))
  m_t_slice = array_ops.gather(m_t, grad.indices)
  v_t_slice = array_ops.gather(v_t, grad.indices)
  denominator_slice = math_ops.sqrt(v_t_slice) + epsilon_t
  var_update = state_ops.scatter_sub(
      var, grad.indices, lr * m_t_slice / denominator_slice,
      use_locking=self._use_locking)
  return control_flow_ops.group(var_update, m_t, v_t)
def testHigherRank(self):
  # We check that scalar and empty indices shapes work as well.
  shape = (2, 1, 3, 2)
  for indices_shape in (), (0,), (2, 0), (2, 3):
    for dtype in _TEST_TYPES:
      for axis in range(len(shape)):
        params = self._buildParams(np.random.randn(*shape), dtype)
        indices = np.random.randint(shape[axis], size=indices_shape)
        with self.cached_session(use_gpu=True) as sess:
          tf_params = constant_op.constant(params)
          tf_indices = constant_op.constant(indices)
          # Check that both positive and negative indices for axis work.
          tf_axis = constant_op.constant(axis)
          tf_negative_axis = constant_op.constant(-len(shape) + axis)
          gather = array_ops.gather(tf_params, tf_indices, axis=tf_axis)
          gather_negative_axis = array_ops.gather(
              tf_params, tf_indices, axis=tf_negative_axis)
          gather_value, gather_negative_axis_value = sess.run(
              [gather, gather_negative_axis])
          gather_np = np.take(params, indices, axis)
          self.assertAllEqual(gather_np, gather_value)
          self.assertAllEqual(gather_np, gather_negative_axis_value)
          expected_shape = (params.shape[:axis] + indices.shape +
                            params.shape[axis + 1:])
          self.assertEqual(expected_shape, gather.shape)
          self.assertEqual(expected_shape, gather_negative_axis.shape)

          # Test gradients.
          gather_grad = np.random.randn(
              *gather.get_shape().as_list()).astype(dtype.as_numpy_dtype)
          if dtype.is_complex:
            gather_grad -= 1j * gather_grad
          params_grad, indices_grad, axis_grad = gradients_impl.gradients(
              gather, [tf_params, tf_indices, tf_axis], gather_grad)
          self.assertEqual(indices_grad, None)
          self.assertEqual(axis_grad, None)
          if dtype.is_integer:
            self.assertEqual(params_grad, None)
            continue
          # For axis 0, we are able to create an efficient IndexedSlices for
          # the gradient.
          if axis == 0:
            self.assertEqual(type(params_grad), ops.IndexedSlices)
            params_grad = ops.convert_to_tensor(params_grad)
          correct_params_grad = np.zeros(shape).astype(dtype.as_numpy_dtype)
          outer_dims = axis
          inner_dims = len(shape) - axis - 1
          gather_grad = gather_grad.reshape(
              shape[:axis] + (indices.size,) + shape[axis + 1:])
          for source_index, dest_index in enumerate(indices.flat):
            dest_slice = ((slice(None),) * outer_dims + (dest_index,) +
                          (slice(None),) * inner_dims)
            source_slice = ((slice(None),) * outer_dims + (source_index,) +
                            (slice(None),) * inner_dims)
            correct_params_grad[dest_slice] += gather_grad[source_slice]
          self.assertAllClose(
              correct_params_grad,
              self.evaluate(params_grad),
              atol=2e-6,
              rtol=2e-6)
def _resource_apply_sparse(self, grad, var, indices):
  var_dtype = var.dtype.base_dtype
  lr_t = self._decayed_lr(var_dtype)
  beta_1_t = self._get_hyper('beta_1', var_dtype)
  beta_2_t = self._get_hyper('beta_2', var_dtype)
  local_step = math_ops.cast(self.iterations + 1, var_dtype)
  beta_1_power = math_ops.pow(beta_1_t, local_step)
  epsilon_t = self._get_hyper('epsilon', var_dtype)

  # m_t = beta1 * m + (1 - beta1) * g_t
  m = self.get_slot(var, 'm')
  m_slice = array_ops.gather(m, indices)
  m_t_slice = m_slice * beta_1_t + grad * (1 - beta_1_t)
  with ops.control_dependencies([m_t_slice]):
    m_t = self._resource_scatter_update(m, indices, m_t_slice)

  # u_t = max(beta2 * u, abs(g_t))
  v = self.get_slot(var, 'v')
  v_slice = array_ops.gather(v, indices)
  v_t_slice = math_ops.maximum(v_slice * beta_2_t, math_ops.abs(grad))
  with ops.control_dependencies([v_t_slice]):
    v_t = self._resource_scatter_update(v, indices, v_t_slice)

  # theta_t = theta - lr / (1 - beta1^t) * m_t / u_t
  var_slice = -lr_t / (1 - beta_1_power) * (
      m_t_slice / (v_t_slice + epsilon_t))
  with ops.control_dependencies([var_slice]):
    var_update = self._resource_scatter_add(var, indices, var_slice)

  return control_flow_ops.group(*[var_update, m_t, v_t])
def testBadIndicesCPU(self):
  with test_util.force_cpu():
    params = [[0, 1, 2], [3, 4, 5]]
    with self.assertRaisesOpError(r"indices\[0,0\] = 7 is not in \[0, 2\)"):
      self.evaluate(array_ops.gather(params, [[7]], axis=0))
    with self.assertRaisesOpError(r"indices\[0,0\] = 7 is not in \[0, 3\)"):
      self.evaluate(array_ops.gather(params, [[7]], axis=1))
def _prepare(self, grads_and_vars):
  """Computes an initial learning rate from the step and gradient history."""
  if self._lr is None:
    sTy = 0
    sTs = 0
    yTy = 0
    for g_t, x_tm1 in grads_and_vars:
      if g_t is None:
        continue
      with ops.name_scope('update_' + x_tm1.op.name), \
           ops.device(x_tm1.device):
        if isinstance(g_t, ops.Tensor):
          g_tm1 = self.get_slot(x_tm1, 'g')
          s_tm1 = self.get_slot(x_tm1, 's')
          y_t = (g_t - g_tm1)
          sTy += math_ops.reduce_sum(s_tm1 * y_t)
          sTs += math_ops.reduce_sum(s_tm1 ** 2)
          yTy += math_ops.reduce_sum(y_t ** 2)
        else:
          idxs, idxs_ = array_ops.unique(g_t.indices)
          g_t_ = math_ops.unsorted_segment_sum(g_t.values, idxs_,
                                               array_ops.size(idxs))
          g_tm1 = self.get_slot(x_tm1, 'g')
          g_tm1_ = array_ops.gather(g_tm1, idxs)
          s_tm1 = self.get_slot(x_tm1, 's')
          s_tm1_ = array_ops.gather(s_tm1, idxs)
          y_t_ = (g_t_ - g_tm1_)
          sTy += math_ops.reduce_sum(s_tm1_ * y_t_)
          sTs += math_ops.reduce_sum(s_tm1_ ** 2)
          yTy += math_ops.reduce_sum(y_t_ ** 2)
    sTy = math_ops.abs(sTy)
    self._lr = sTs / (sTy + self._eps)
def testString(self):
  params = np.array([[b"asdf", b"zxcv"], [b"qwer", b"uiop"]])
  with self.cached_session():
    self.assertAllEqual([b"qwer", b"uiop"],
                        array_ops.gather(params, 1, axis=0).eval())
    self.assertAllEqual([b"asdf", b"qwer"],
                        array_ops.gather(params, 0, axis=1).eval())
def _resource_apply_sparse(self, grad, var, indices):
  var_dtype = var.dtype.base_dtype
  lr_t = self._decayed_lr(var_dtype)
  beta_1_t = self._get_hyper('beta_1', var_dtype)
  beta_2_t = self._get_hyper('beta_2', var_dtype)
  local_step = math_ops.cast(self.iterations + 1, var_dtype)
  beta_1_power = math_ops.pow(beta_1_t, local_step)
  beta_2_power = math_ops.pow(beta_2_t, local_step)
  epsilon_t = self._get_hyper('epsilon', var_dtype)
  lr = (lr_t * math_ops.sqrt(1 - beta_2_power) / (1 - beta_1_power))

  # m_t = beta1 * m + (1 - beta1) * g_t
  m = self.get_slot(var, 'm')
  m_scaled_g_values = grad * (1 - beta_1_t)
  m_t = state_ops.assign(m, m * beta_1_t, use_locking=self._use_locking)
  with ops.control_dependencies([m_t]):
    m_t = self._resource_scatter_add(m, indices, m_scaled_g_values)
  # m_bar = (1 - beta1) * g_t + beta1 * m_t
  m_bar = m_scaled_g_values + beta_1_t * array_ops.gather(m_t, indices)

  # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
  v = self.get_slot(var, 'v')
  v_scaled_g_values = (grad * grad) * (1 - beta_2_t)
  v_t = state_ops.assign(v, v * beta_2_t, use_locking=self._use_locking)
  with ops.control_dependencies([v_t]):
    v_t = self._resource_scatter_add(v, indices, v_scaled_g_values)

  v_t_slice = array_ops.gather(v_t, indices)
  v_sqrt = math_ops.sqrt(v_t_slice)
  var_update = self._resource_scatter_add(
      var, indices, -lr * m_bar / (v_sqrt + epsilon_t))
  return control_flow_ops.group(*[var_update, m_bar, v_t])
def testBadIndicesCPU(self):
  with self.session(use_gpu=False):
    params = [[0, 1, 2], [3, 4, 5]]
    with self.assertRaisesOpError(r"indices\[0,0\] = 7 is not in \[0, 2\)"):
      array_ops.gather(params, [[7]], axis=0).eval()
    with self.assertRaisesOpError(r"indices\[0,0\] = 7 is not in \[0, 3\)"):
      array_ops.gather(params, [[7]], axis=1).eval()
def testHigherRankMaxNorm(self):
  np.random.seed(8)
  with self.cached_session():
    for params_shape in (12,), (6, 3), (6, 2, 3):
      # Test embedding rank 0, 1, 2.
      # Note: the first dimension must be a common multiple of procs below.
      params = 2 * np.ones(params_shape)
      params_norm = params / np.sqrt(
          np.sum(
              params * params, tuple(range(params.ndim)[1:]), keepdims=True))
      for ids_shape in (), (3), (4, 3), (2, 3, 4):
        ids = np.random.randint(
            params.shape[0],
            size=np.prod(ids_shape, dtype=np.int64)).reshape(ids_shape)
        # Compare nonsharded to gather.
        simple = embedding_ops.embedding_lookup(
            params, ids, max_norm=1.0).eval()
        # assertAllClose is used here as different implementations of sqrt may
        # be used to compute each of the values being compared. For example,
        # on AVX512 builds the embedding operation makes use of Eigen's fast
        # vectorized square root algorithm for doubles. These different
        # implementations of sqrt are not guaranteed to produce exactly the
        # same results. Therefore, an exact comparison cannot be made.
        self.assertAllClose(simple, array_ops.gather(params_norm, ids).eval())
        # Run a few different sharded versions.
        for procs in 1, 2, 3:
          stride = procs * math_ops.range(params.shape[0] // procs)
          split_params = [
              array_ops.gather(params, stride + p) for p in xrange(procs)
          ]
          sharded = embedding_ops.embedding_lookup(
              split_params, ids, max_norm=1.0).eval()
          self.assertAllEqual(simple, sharded)
def _check_shapes_dynamic(self, operator, v, diag):
  """Return (v, diag) with Assert dependencies, which check shape."""
  checks = []
  with ops.op_scope([operator, v, diag], 'check_shapes'):
    s_v = array_ops.shape(v)
    r_op = operator.rank()
    r_v = array_ops.rank(v)
    if diag is not None:
      s_d = array_ops.shape(diag)
      r_d = array_ops.rank(diag)

    # Check tensor rank.
    checks.append(check_ops.assert_rank(v, r_op))
    if diag is not None:
      checks.append(check_ops.assert_rank(diag, r_op - 1))

    # Check batch shape.
    checks.append(check_ops.assert_equal(
        operator.batch_shape(), array_ops.slice(s_v, [0], [r_v - 2])))
    if diag is not None:
      checks.append(check_ops.assert_equal(
          operator.batch_shape(), array_ops.slice(s_d, [0], [r_d - 1])))

    # Check event shape.
    checks.append(check_ops.assert_equal(
        operator.vector_space_dimension(), array_ops.gather(s_v, r_v - 2)))
    if diag is not None:
      checks.append(check_ops.assert_equal(
          array_ops.gather(s_v, r_v - 1), array_ops.gather(s_d, r_d - 1)))

    v = control_flow_ops.with_dependencies(checks, v)
    if diag is not None:
      diag = control_flow_ops.with_dependencies(checks, diag)
    return v, diag
def _apply_sparse_shared(self, grad, var, indices, scatter_add):
  beta1_power, beta2_power = self._get_beta_accumulators()
  beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
  beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
  lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
  beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
  beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
  epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
  lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power))

  # m_t = beta1 * m + (1 - beta1) * g_t
  m = self.get_slot(var, "m")
  m_scaled_g_values = grad * (1 - beta1_t)
  m_t = state_ops.assign(m, m * beta1_t, use_locking=self._use_locking)
  with ops.control_dependencies([m_t]):
    m_t = scatter_add(m, indices, m_scaled_g_values)
  # m_bar = (1 - beta1) * g_t + beta1 * m_t
  m_bar = m_scaled_g_values + beta1_t * array_ops.gather(m_t, indices)

  # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
  v = self.get_slot(var, "v")
  v_scaled_g_values = (grad * grad) * (1 - beta2_t)
  v_t = state_ops.assign(v, v * beta2_t, use_locking=self._use_locking)
  with ops.control_dependencies([v_t]):
    v_t = scatter_add(v, indices, v_scaled_g_values)

  v_t_slice = array_ops.gather(v_t, indices)
  v_sqrt = math_ops.sqrt(v_t_slice)
  var_update = scatter_add(var, indices, -lr * m_bar / (v_sqrt + epsilon_t))
  return control_flow_ops.group(*[var_update, m_bar, v_t])
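# The m_bar line above is a Nesterov-style lookahead: it re-applies the
# fresh gradient on top of the already-updated first moment. A scalar
# illustration (not part of the source; hypothetical values):
beta1, g, m_old = 0.9, 0.5, 0.2
m_t = beta1 * m_old + (1 - beta1) * g    # 0.23, the stored first moment
m_bar = (1 - beta1) * g + beta1 * m_t    # 0.257, weights g more heavily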
def _apply_sparse_shared(self, grad, var, indices, scatter_update,
                         scatter_sub):
  beta1_power, beta2_power = self._get_beta_accumulators()
  beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
  beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
  lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
  beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
  beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
  epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
  lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power))

  # \\(m := beta1 * m + (1 - beta1) * g_t\\)
  m = self.get_slot(var, "m")
  m_t = scatter_update(m, indices,
                       beta1_t * array_ops.gather(m, indices) +
                       (1 - beta1_t) * grad)

  # \\(v := beta2 * v + (1 - beta2) * (g_t * g_t)\\)
  v = self.get_slot(var, "v")
  v_t = scatter_update(v, indices,
                       beta2_t * array_ops.gather(v, indices) +
                       (1 - beta2_t) * math_ops.square(grad))

  # \\(variable -= learning_rate * m_t / (epsilon_t + sqrt(v_t))\\)
  m_t_slice = array_ops.gather(m_t, indices)
  v_t_slice = array_ops.gather(v_t, indices)
  denominator_slice = math_ops.sqrt(v_t_slice) + epsilon_t
  var_update = scatter_sub(var, indices,
                           lr * m_t_slice / denominator_slice)
  return control_flow_ops.group(var_update, m_t, v_t)
def _SparseUpdate(variable, gradients, accum, linear, base_lr, lr_power, l1,
                  l2):
  """Sparse Update "variable", "accum", "linear" based on sparse "gradients".

  See the description in _Update.

  Args:
    variable: A Variable.
    gradients: A Sparse Tensor.
    accum: A Variable containing the sum of the squares of gradients.
    linear: A Variable containing approximation info.
    base_lr: A constant representing the base learning rate.
    lr_power: A constant used to adjust the learning rate.
    l1: A constant representing the l1 regularization strength.
    l2: A constant representing the l2 regularization strength.

  Returns:
    A group op including three ScatterUpdate ops:
      1. ScatterUpdate for "accum"
      2. ScatterUpdate for "linear"
      3. ScatterUpdate for "variable"
  """
  assert isinstance(gradients, ops.IndexedSlices)
  with ops.name_scope("sparse_update_" + variable.op.name) as scope:
    dtype = variable.dtype.base_dtype
    base_lr = ops.convert_to_tensor(base_lr, dtype=dtype)
    lr_power = ops.convert_to_tensor(lr_power, dtype=dtype)
    l1 = ops.convert_to_tensor(l1, dtype=dtype)
    l2 = ops.convert_to_tensor(l2, dtype=dtype)

    # Compute the new value for the accumulator.
    previous_accum = array_ops.gather(accum, gradients.indices)
    sqr_grad = gradients.values * gradients.values
    accum_updated = sqr_grad + previous_accum

    # Compute the new linear.
    neg_lr_power = math_ops.neg(lr_power)
    sigma = math_ops.pow(accum_updated, neg_lr_power) - math_ops.pow(
        previous_accum, neg_lr_power)
    sigma /= base_lr
    variable_slice = array_ops.gather(variable, gradients.indices)
    proximal_adjust = sigma * variable_slice
    linear_slice = array_ops.gather(linear, gradients.indices)
    linear_updated = linear_slice + gradients.values - proximal_adjust

    # Compute the new "variable".
    variable_updated = _Compute(accum_updated, linear_updated, base_lr,
                                lr_power, l1, l2)

    with ops.control_dependencies([sigma]):
      accum_update_op = state_ops.scatter_update(accum, gradients.indices,
                                                 accum_updated)
    linear_update_op = state_ops.scatter_update(linear, gradients.indices,
                                                linear_updated)
    variable_update_op = state_ops.scatter_update(variable, gradients.indices,
                                                  variable_updated)
    group_op = control_flow_ops.group(linear_update_op, accum_update_op,
                                      variable_update_op, name=scope)
    return group_op
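# A scalar view (not part of the source) of the `sigma` term above: it is
# the change in the FTRL adaptive-learning-rate denominator,
# (accum_new**-lr_power - accum_old**-lr_power) / base_lr. Hypothetical
# values; lr_power = -0.5 corresponds to a 1/sqrt(n) schedule.
base_lr, lr_power = 0.1, -0.5
n_old, g = 4.0, 3.0
n_new = n_old + g * g                                    # accum_updated
sigma = (n_new ** -lr_power - n_old ** -lr_power) / base_lr
# linear_updated = linear + g - sigma * variable, as in the code above.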
def testUInt32AndUInt64(self):
  for unsigned_type in (dtypes.uint32, dtypes.uint64):
    params = self._buildParams(
        np.array([[1, 2, 3], [7, 8, 9]]), unsigned_type)
    with self.cached_session():
      self.assertAllEqual([7, 8, 9],
                          array_ops.gather(params, 1, axis=0).eval())
      self.assertAllEqual([1, 7],
                          array_ops.gather(params, 0, axis=1).eval())
def testSkipEagerErrors(self):
  if context.executing_eagerly():
    return
  with self.assertRaisesRegexp(ValueError, r"tf\.gather does not allow.*"):
    array_ops.gather(
        params=[1, 2],
        batch_dims=1,
        indices=array_ops.placeholder(dtypes.int32))
def loop_fn(i, use_pfor):
  image = array_ops.gather(images, i)
  logits = array_ops.reshape(model(image, training=training), [-1])
  return gradients.jacobian(
      logits, variables.trainable_variables(), use_pfor=use_pfor)
def _forward(self, x):
  return array_ops.gather(x, self.permutation, axis=-1)
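# Gathering with a permutation along the last axis reorders features, which
# is what _forward does. A small sketch (not part of the source; eager
# TF 2.x, hypothetical permutation):
import tensorflow as tf

x = tf.constant([[10., 20., 30.]])
permutation = tf.constant([2, 0, 1])
print(tf.gather(x, permutation, axis=-1))  # [[30. 10. 20.]]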
def loop_fn(i, use_pfor):
  inp_i = array_ops.expand_dims(array_ops.gather(inp, i), 0)
  output = array_ops.reshape(model(inp_i), [-1])
  return gradients.jacobian(
      output, variables.trainable_variables(), use_pfor=use_pfor)
def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
  var_dtype = var.dtype.base_dtype
  lr_t = array_ops.identity(self._get_hyper('learning_rate', var_dtype))
  beta_1_t = array_ops.identity(self._get_hyper('beta_1', var_dtype))
  beta_2_t = array_ops.identity(self._get_hyper('beta_2', var_dtype))
  epsilon_t = ops.convert_to_tensor(self.epsilon, var_dtype)

  m = self.get_slot(var, 'm')
  v = self.get_slot(var, 'v')

  local_step = math_ops.cast(self.iterations + 1, var_dtype)
  next_step = math_ops.cast(self.iterations + 2, var_dtype)
  decay_base = math_ops.cast(0.96, var_dtype)

  # Learning rate multipliers.
  if self.lr_multipliers is not None:
    lr_t = _apply_lr_multiplier(self, lr_t, var)

  momentum_cache_t = beta_1_t * (1. - 0.5 * (
      math_ops.pow(decay_base, self._initial_decay * local_step)))
  momentum_cache_t_1 = beta_1_t * (1. - 0.5 * (
      math_ops.pow(decay_base, self._initial_decay * next_step)))
  m_schedule_new = math_ops.cast(self._m_cache_read,
                                 var_dtype) * momentum_cache_t
  if var_dtype is self._m_cache.dtype:
    m_schedule_new = array_ops.identity(
        state_ops.assign(self._m_cache, m_schedule_new,
                         use_locking=self._use_locking))
  m_schedule_next = m_schedule_new * momentum_cache_t_1

  m_scaled_g_values = grad * (1. - beta_1_t)
  m_t = state_ops.assign(m, m * beta_1_t, use_locking=self._use_locking)
  with ops.control_dependencies([m_t]):
    m_t = self._resource_scatter_add(m, indices, m_scaled_g_values)
    m_t_slice = array_ops.gather(m_t, indices)

  m_t_prime = m_t_slice / (1. - m_schedule_next)
  g_prime = grad / (1. - m_schedule_new)
  m_t_bar = (1. - momentum_cache_t) * g_prime + (
      momentum_cache_t_1 * m_t_prime)

  v_scaled_g_values = (grad * grad) * (1. - beta_2_t)
  v_t = state_ops.assign(v, v * beta_2_t, use_locking=self._use_locking)
  with ops.control_dependencies([v_t]):
    v_t = self._resource_scatter_add(v, indices, v_scaled_g_values)
    v_t_slice = array_ops.gather(v_t, indices)

  v_t_prime_denominator = 1. - math_ops.pow(beta_2_t, local_step)
  v_t_prime = v_t_slice / v_t_prime_denominator
  v_prime_sqrt_plus_eps = math_ops.sqrt(v_t_prime) + epsilon_t

  var_t = self._resource_scatter_add(
      var, indices, -self.eta_t * lr_t * m_t_bar / v_prime_sqrt_plus_eps)

  # Weight decays.
  if var.name in self.weight_decays.keys():
    var_t = _apply_weight_decays(self, var, var_t)

  var_update = state_ops.assign(var, var_t, use_locking=self._use_locking)

  # Cosine annealing.
  (iteration_done, t_cur_update, eta_t_update) = _update_t_cur_eta_t_v2(
      self, lr_t, var)
  if iteration_done and not self._init_notified:
    self._init_notified = True

  updates = [var_update, m_t_bar, v_t]
  if iteration_done:
    updates += [t_cur_update]
  if self.use_cosine_annealing and iteration_done:
    updates += [eta_t_update]
  return control_flow_ops.group(*updates)
def training_graph(self,
                   input_data,
                   input_labels,
                   num_trainers=1,
                   trainer_id=0,
                   **tree_kwargs):
  """Constructs a TF graph for training a random forest.

  Args:
    input_data: A tensor or dict of string->Tensor for input data.
    input_labels: A tensor or placeholder for labels associated with
      input_data.
    num_trainers: Number of parallel trainers to split trees among.
    trainer_id: Which trainer this instance is.
    **tree_kwargs: Keyword arguments passed to each tree's training_graph.

  Returns:
    The last op in the random forest training graph.

  Raises:
    NotImplementedError: If trying to use bagging with sparse features.
  """
  processed_dense_features, processed_sparse_features, data_spec = (
      data_ops.ParseDataTensorOrDict(input_data))

  if input_labels is not None:
    labels = data_ops.ParseLabelTensorOrDict(input_labels)

  data_spec = data_spec or self.get_default_data_spec(input_data)

  tree_graphs = []
  trees_per_trainer = self.params.num_trees / num_trainers
  tree_start = int(trainer_id * trees_per_trainer)
  tree_end = int((trainer_id + 1) * trees_per_trainer)
  for i in range(tree_start, tree_end):
    with ops.device(self.variables.device_dummies[i].device):
      seed = self.params.base_random_seed
      if seed != 0:
        seed += i
      # If using bagging, randomly select some of the input.
      tree_data = processed_dense_features
      tree_labels = labels
      if self.params.bagging_fraction < 1.0:
        # TODO(gilberth): Support bagging for sparse features.
        if processed_sparse_features is not None:
          raise NotImplementedError(
              'Bagging not supported with sparse features.')
        # TODO(thomaswc): This does sampling without replacement. Consider
        # also allowing sampling with replacement as an option.
        batch_size = array_ops.strided_slice(
            array_ops.shape(processed_dense_features), [0], [1])
        r = random_ops.random_uniform(batch_size, seed=seed)
        mask = math_ops.less(
            r, array_ops.ones_like(r) * self.params.bagging_fraction)
        gather_indices = array_ops.squeeze(array_ops.where(mask), axis=[1])
        # TODO(thomaswc): Calculate out-of-bag data and labels, and store
        # them for use in calculating statistics later.
        tree_data = array_ops.gather(processed_dense_features,
                                     gather_indices)
        tree_labels = array_ops.gather(labels, gather_indices)
      if self.params.bagged_features:
        if processed_sparse_features is not None:
          raise NotImplementedError(
              'Feature bagging not supported with sparse features.')
        tree_data = self._bag_features(i, tree_data)

      tree_graphs.append(self.trees[i].training_graph(
          tree_data,
          tree_labels,
          seed,
          data_spec=data_spec,
          sparse_features=processed_sparse_features,
          **tree_kwargs))

  return control_flow_ops.group(*tree_graphs, name='train')
def lu_solve(lower_upper, perm, rhs, validate_args=False, name=None):
  """Solves systems of linear eqns `A X = RHS`, given LU factorizations.

  Note: this function does not verify the implied matrix is actually
  invertible nor is this condition checked even when `validate_args=True`.

  Args:
    lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `lower_upper = L + U - eye`.
    perm: `p` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `perm = argmax(P)`.
    rhs: Matrix-shaped float `Tensor` representing targets for which to
      solve; `A X = RHS`. To handle vector cases, use:
      `lu_solve(..., rhs[..., tf.newaxis])[..., 0]`.
    validate_args: Python `bool` indicating whether arguments should be
      checked for correctness. Note: this function does not verify the
      implied matrix is actually invertible, even when `validate_args=True`.
      Default value: `False` (i.e., don't validate arguments).
    name: Python `str` name given to ops managed by this object.
      Default value: `None` (i.e., 'lu_solve').

  Returns:
    x: The `X` in `A @ X = RHS`.

  #### Examples

  ```python
  import numpy as np
  import tensorflow as tf
  import tensorflow_probability as tfp

  x = [[[1., 2],
        [3, 4]],
       [[7, 8],
        [3, 4]]]
  inv_x = tf.linalg.lu_solve(*tf.linalg.lu(x), rhs=tf.eye(2))
  tf.assert_near(tf.matrix_inverse(x), inv_x)
  # ==> True
  ```
  """
  with ops.name_scope(name or 'lu_solve'):
    lower_upper = ops.convert_to_tensor(
        lower_upper, dtype_hint=dtypes.float32, name='lower_upper')
    perm = ops.convert_to_tensor(perm, dtype_hint=dtypes.int32, name='perm')
    rhs = ops.convert_to_tensor(rhs, dtype_hint=lower_upper.dtype,
                                name='rhs')

    assertions = _lu_solve_assertions(lower_upper, perm, rhs, validate_args)
    if assertions:
      with ops.control_dependencies(assertions):
        lower_upper = array_ops.identity(lower_upper)
        perm = array_ops.identity(perm)
        rhs = array_ops.identity(rhs)

    if (rhs.shape.rank == 2 and perm.shape.rank == 1):
      # Both rhs and perm have scalar batch_shape.
      permuted_rhs = array_ops.gather(rhs, perm, axis=-2)
    else:
      # Either rhs or perm have non-scalar batch_shape or we can't determine
      # this information statically.
      rhs_shape = array_ops.shape(rhs)
      broadcast_batch_shape = array_ops.broadcast_dynamic_shape(
          rhs_shape[:-2],
          array_ops.shape(perm)[:-1])
      d, m = rhs_shape[-2], rhs_shape[-1]
      rhs_broadcast_shape = array_ops.concat(
          [broadcast_batch_shape, [d, m]], axis=0)

      # Tile out rhs.
      broadcast_rhs = array_ops.broadcast_to(rhs, rhs_broadcast_shape)
      broadcast_rhs = array_ops.reshape(broadcast_rhs, [-1, d, m])

      # Tile out perm and add batch indices.
      broadcast_perm = array_ops.broadcast_to(perm, rhs_broadcast_shape[:-1])
      broadcast_perm = array_ops.reshape(broadcast_perm, [-1, d])
      broadcast_batch_size = math_ops.reduce_prod(broadcast_batch_shape)
      broadcast_batch_indices = array_ops.broadcast_to(
          math_ops.range(broadcast_batch_size)[:, array_ops.newaxis],
          [broadcast_batch_size, d])
      broadcast_perm = array_ops.stack(
          [broadcast_batch_indices, broadcast_perm], axis=-1)

      permuted_rhs = array_ops.gather_nd(broadcast_rhs, broadcast_perm)
      permuted_rhs = array_ops.reshape(permuted_rhs, rhs_broadcast_shape)

    lower = set_diag(
        band_part(lower_upper, num_lower=-1, num_upper=0),
        array_ops.ones(
            array_ops.shape(lower_upper)[:-1], dtype=lower_upper.dtype))
    return triangular_solve(
        lower_upper,  # Only upper is accessed.
        triangular_solve(lower, permuted_rhs),
        lower=False)
def boolean_mask(data, mask, name=None):
  """Applies a boolean mask to `data` without flattening the mask dimensions.

  Returns a potentially ragged tensor that is formed by retaining the
  elements in `data` where the corresponding value in `mask` is `True`.

  * `output[a1...aA, i, b1...bB] = data[a1...aA, j, b1...bB]`

     Where `j` is the `i`th `True` entry of `mask[a1...aA]`.

  Note that `output` preserves the mask dimensions `a1...aA`; this differs
  from `tf.boolean_mask`, which flattens those dimensions.

  Args:
    data: A potentially ragged tensor.
    mask: A potentially ragged boolean tensor. `mask`'s shape must be a
      prefix of `data`'s shape. `rank(mask)` must be known statically.
    name: A name prefix for the returned tensor (optional).

  Returns:
    A potentially ragged tensor that is formed by retaining the elements in
    `data` where the corresponding value in `mask` is `True`.

    * `rank(output) = rank(data)`.
    * `output.ragged_rank = max(data.ragged_rank, rank(mask) - 1)`.

  Raises:
    ValueError: if `rank(mask)` is not known statically; or if `mask.shape`
      is not a prefix of `data.shape`.

  #### Examples:

  >>> # Aliases for True & False so data and mask line up.
  >>> T, F = (True, False)

  >>> tf.ragged.boolean_mask(  # Mask a 2D Tensor.
  ...     data=[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
  ...     mask=[[T, F, T], [F, F, F], [T, F, F]]).to_list()
  [[1, 3], [], [7]]

  >>> tf.ragged.boolean_mask(  # Mask a 2D RaggedTensor.
  ...     tf.ragged.constant([[1, 2, 3], [4], [5, 6]]),
  ...     tf.ragged.constant([[F, F, T], [F], [T, T]])).to_list()
  [[3], [], [5, 6]]

  >>> tf.ragged.boolean_mask(  # Mask rows of a 2D RaggedTensor.
  ...     tf.ragged.constant([[1, 2, 3], [4], [5, 6]]),
  ...     tf.ragged.constant([True, False, True])).to_list()
  [[1, 2, 3], [5, 6]]
  """
  with ops.name_scope(name, 'RaggedMask', [data, mask]):
    # Convert inputs to tensors.
    data = ragged_tensor.convert_to_tensor_or_ragged_tensor(data,
                                                            name='data')
    mask = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        mask, dtypes.bool, name='mask')
    row_splits_dtype, (data, mask) = ragged_tensor.match_row_splits_dtypes(
        data, mask, return_dtype=True)

    # Get static rank of mask.
    if mask.shape.ndims is None:
      raise ValueError('mask.shape.ndims must be known statically.')
    elif mask.shape.ndims == 0:
      raise ValueError('mask cannot be scalar.')

    # If mask is ragged, then recurse with a non-ragged mask.
    if ragged_tensor.is_ragged(mask):
      if not ragged_tensor.is_ragged(data):
        data = ragged_tensor.RaggedTensor.from_tensor(
            data,
            ragged_rank=mask.ragged_rank,
            row_splits_dtype=mask.row_splits.dtype)
      # Check that mask.nested_row_splits is a prefix of
      # data.nested_row_splits.
      splits_list = [
          mask.nested_row_splits, data.nested_row_splits[:mask.ragged_rank]
      ]
      with ops.control_dependencies(
          ragged_util.assert_splits_match(splits_list)):
        # Strip off ragged `splits` until `mask` is non-ragged.  Keep the
        # splits that we strip off in `splits`, so we can add them back on
        # after we recursively mask the non-ragged data.
        splits = []
        while ragged_tensor.is_ragged(mask):
          if mask.shape.ndims > 2:
            splits.append(mask.row_splits)
          else:
            # Count the number of True mask values in each row to find the
            # lengths of the filtered rows; then convert to splits.
            int_mask = ragged_functional_ops.map_flat_values(
                math_ops.cast, mask, dtype=row_splits_dtype)
            masked_row_lengths = ragged_math_ops.reduce_sum(int_mask, axis=1)
            splits.append(ragged_util.lengths_to_splits(masked_row_lengths))
          mask = mask.values
          data = data.values

        # Recursively apply the nested non-ragged mask to the nested data.
        masked_values = boolean_mask(data, mask)

        # Add the ragged `splits` back to the result.
        masked_values = ragged_tensor.RaggedTensor.from_nested_row_splits(
            masked_values, splits, validate=False)

        return masked_values

    # If mask is non-ragged and has rank 1, and data is ragged, then build a
    # ragged tensor with the indicated rows.
    elif ragged_tensor.is_ragged(data) and mask.shape.ndims == 1:
      # Get the masked splits: first get the length of each row, then filter
      # out the rows that we are deleting, and convert that filtered set of
      # masks back to a splits tensor.
      lengths = data.row_lengths()
      masked_lengths = array_ops.boolean_mask(lengths, mask)
      masked_splits = ragged_util.lengths_to_splits(masked_lengths)

      # Get the masked values: first get row ids corresponding to each
      # value, then use tf.gather to build a boolean mask that's false for
      # values that come from rows that we are deleting, and use that mask
      # to construct the masked values tensor.
      segment_ids = segment_id_ops.row_splits_to_segment_ids(data.row_splits)
      segment_mask = array_ops.gather(mask, segment_ids)
      masked_values = boolean_mask(data.values, segment_mask)

      return ragged_tensor.RaggedTensor.from_row_splits(
          masked_values, masked_splits, validate=False)

    # If mask is non-ragged and has rank>1, then convert it to be ragged,
    # with a ragged rank matching data.
    if ragged_tensor.is_ragged(data):
      mask = ragged_tensor.RaggedTensor.from_tensor(
          mask,
          ragged_rank=min(data.ragged_rank, mask.shape.ndims - 1),
          row_splits_dtype=data.row_splits.dtype)
      return boolean_mask(data, mask)

    # Otherwise, data and mask are both `Tensor`s.
    else:
      # Apply `boolean_mask` to get the masked values.
      masked_values = array_ops.boolean_mask(data, mask)

      if mask.shape.ndims >= 2:
        # Add the innermost ragged dimension.  For each innermost cell, get
        # the number of values it contains.  Then flatten that to get a list
        # of cell lengths, and convert it to splits.  Finally, combine the
        # splits and values to get the innermost ragged tensor.
        masked_lengths = math_ops.count_nonzero(
            mask, axis=-1, dtype=row_splits_dtype)
        flattened_masked_lengths = array_ops.reshape(masked_lengths, [-1])
        masked_values = ragged_tensor.RaggedTensor.from_row_lengths(
            masked_values, flattened_masked_lengths, validate=False)

        # Wrap remaining ragged dimensions.
        if mask.shape.ndims > 2:
          mask_shape = array_ops.shape(mask, out_type=row_splits_dtype)
          split_size = math_ops.cumprod(mask_shape) + 1
          for dim in range(mask.shape.ndims - 3, -1, -1):
            elt_size = mask_shape[dim + 1]
            masked_splits = math_ops.range(split_size[dim]) * elt_size
            masked_values = ragged_tensor.RaggedTensor.from_row_splits(
                masked_values, masked_splits, validate=False)

      return masked_values
def decoder_fn(time, cell_state, cell_input, cell_output, context_state):
  """Decoder function used in the `dynamic_rnn_decoder` for inference.

  The main difference between this decoder function and the `decoder_fn` in
  `attention_decoder_fn_train` is how `next_cell_input` is calculated. In
  decoder function we calculate the next input by applying an argmax across
  the feature dimension of the output from the decoder. This is a
  greedy-search approach. (Bahdanau et al., 2014) & (Sutskever et al., 2014)
  use beam-search instead.

  Args:
    time: positive integer constant reflecting the current timestep.
    cell_state: state of RNNCell.
    cell_input: input provided by `dynamic_rnn_decoder`.
    cell_output: output of RNNCell.
    context_state: context state provided by `dynamic_rnn_decoder`.

  Returns:
    A tuple (done, next state, next input, emit output, next context state)
    where:

    done: A boolean vector to indicate which sentences have reached a
    `end_of_sequence_id`. This is used for early stopping by the
    `dynamic_rnn_decoder`. When `time>=maximum_length` a boolean vector with
    all elements as `true` is returned.

    next state: `cell_state`, this decoder function does not modify the
    given state.

    next input: The embedding from argmax of the `cell_output` is used as
    `next_input`.

    emit output: If `output_fn is None` the supplied `cell_output` is
    returned, else the `output_fn` is used to update the `cell_output`
    before calculating `next_input` and returning `cell_output`.

    next context state: `context_state`, this decoder function does not
    modify the given context state. The context state could be modified when
    applying e.g. beam search.

  Raises:
    ValueError: if cell_input is not None.
  """
  with ops.name_scope(
      name, "attention_decoder_fn_inference",
      [time, cell_state, cell_input, cell_output, context_state]):
    if cell_input is not None:
      raise ValueError(
          "Expected cell_input to be None, but saw: %s" % cell_input)
    if cell_output is None:
      # invariant that this is time == 0
      next_input_id = array_ops.ones(
          [batch_size,], dtype=dtype) * (start_of_sequence_id)
      done = array_ops.zeros([batch_size,], dtype=dtypes.bool)
      cell_state = encoder_state
      cell_output = array_ops.zeros([num_decoder_symbols],
                                    dtype=dtypes.float32)
      cell_input = array_ops.gather(embeddings, next_input_id)

      # init attention
      attention = _init_attention(encoder_state)

      # init context state
      log_beam_probs = tensor_array_ops.TensorArray(
          dtype=dtypes.float32,
          tensor_array_name="log_beam_probs",
          size=maximum_length,
          dynamic_size=True,
          infer_shape=False)
      beam_parents = tensor_array_ops.TensorArray(
          dtype=dtypes.int32,
          tensor_array_name="beam_parents",
          size=maximum_length,
          dynamic_size=True,
          infer_shape=False)
      beam_symbols = tensor_array_ops.TensorArray(
          dtype=dtypes.int32,
          tensor_array_name="beam_symbols",
          size=maximum_length,
          dynamic_size=True,
          infer_shape=False)
      result_probs = tensor_array_ops.TensorArray(
          dtype=dtypes.float32,
          tensor_array_name="result_probs",
          size=maximum_length,
          dynamic_size=True,
          infer_shape=False)
      result_parents = tensor_array_ops.TensorArray(
          dtype=dtypes.int32,
          tensor_array_name="result_parents",
          size=maximum_length,
          dynamic_size=True,
          infer_shape=False)
      result_symbols = tensor_array_ops.TensorArray(
          dtype=dtypes.int32,
          tensor_array_name="result_symbols",
          size=maximum_length,
          dynamic_size=True,
          infer_shape=False)
      context_state = (log_beam_probs, beam_parents, beam_symbols,
                       result_probs, result_parents, result_symbols)
    else:
      # construct attention
      attention = attention_construct_fn(cell_output, attention_keys,
                                         attention_values)
      cell_output = attention

      # beam search decoder
      (log_beam_probs, beam_parents, beam_symbols, result_probs,
       result_parents, result_symbols) = context_state

      cell_output = output_fn(cell_output)  # logits
      cell_output = nn_ops.softmax(cell_output)

      cell_output = array_ops.split(cell_output,
                                    [2, num_decoder_symbols - 2], 1)[1]

      tmp_output = array_ops.gather(
          cell_output, math_ops.range(origin_batch) * beam_size)

      probs = control_flow_ops.cond(
          math_ops.equal(time, ops.convert_to_tensor(1, dtype)),
          lambda: math_ops.log(
              tmp_output + ops.convert_to_tensor(1e-20, dtypes.float32)),
          lambda: math_ops.log(
              cell_output + ops.convert_to_tensor(1e-20, dtypes.float32)
          ) + array_ops.reshape(log_beam_probs.read(time - 2), [-1, 1]))

      probs = array_ops.reshape(probs, [origin_batch, -1])
      best_probs, indices = nn_ops.top_k(probs, beam_size * 2)
      #indices = array_ops.reshape(indices, [-1])
      indices_flatten = array_ops.reshape(indices, [-1]) + array_ops.reshape(
          array_ops.concat([
              array_ops.reshape(
                  math_ops.range(origin_batch) *
                  ((num_decoder_symbols - 2) * beam_size), [-1, 1])
          ] * (beam_size * 2), 1), [origin_batch * beam_size * 2])
      best_probs_flatten = array_ops.reshape(best_probs, [-1])

      symbols = indices_flatten % (num_decoder_symbols - 2)
      symbols = symbols + 2
      parents = indices_flatten // (num_decoder_symbols - 2)

      probs_wo_eos = best_probs + 1e5 * math_ops.cast(
          math_ops.cast(
              (indices % (num_decoder_symbols - 2) + 2) - end_of_sequence_id,
              dtypes.bool), dtypes.float32)

      best_probs_wo_eos, indices_wo_eos = nn_ops.top_k(probs_wo_eos,
                                                       beam_size)

      indices_wo_eos = array_ops.reshape(
          indices_wo_eos, [-1]) + array_ops.reshape(
              array_ops.concat([
                  array_ops.reshape(
                      math_ops.range(origin_batch) * (beam_size * 2),
                      [-1, 1])
              ] * beam_size, 1), [origin_batch * beam_size])

      _probs = array_ops.gather(best_probs_flatten, indices_wo_eos)
      _symbols = array_ops.gather(symbols, indices_wo_eos)
      _parents = array_ops.gather(parents, indices_wo_eos)

      log_beam_probs = log_beam_probs.write(time - 1, _probs)
      beam_symbols = beam_symbols.write(time - 1, _symbols)
      beam_parents = beam_parents.write(time - 1, _parents)
      result_probs = result_probs.write(time - 1, best_probs_flatten)
      result_symbols = result_symbols.write(time - 1, symbols)
      result_parents = result_parents.write(time - 1, parents)

      next_input_id = array_ops.reshape(_symbols, [batch_size])

      state_size = int(cell_state[0].get_shape().with_rank(2)[1])
      attn_size = int(attention.get_shape().with_rank(2)[1])
      state = []
      for j in cell_state:
        state.append(
            array_ops.reshape(
                array_ops.gather(j, _parents), [-1, state_size]))
      cell_state = tuple(state)
      attention = array_ops.reshape(
          array_ops.gather(attention, _parents), [-1, attn_size])

      done = math_ops.equal(next_input_id, end_of_sequence_id)
      cell_input = array_ops.gather(embeddings, next_input_id)

    # combine cell_input and attention
    next_input = array_ops.concat([cell_input, attention], 1)

    # if time > maxlen, return all true vector
    done = control_flow_ops.cond(
        math_ops.greater(time, maximum_length),
        lambda: array_ops.ones([batch_size,], dtype=dtypes.bool),
        lambda: array_ops.zeros([batch_size,], dtype=dtypes.bool))
    return (done, cell_state, next_input, cell_output,
            (log_beam_probs, beam_parents, beam_symbols, result_probs,
             result_parents, result_symbols))  #context_state)
def loop_fn(i):
  return model_fn(array_ops.expand_dims(array_ops.gather(inp, i), 0))
def _SparseSegmentSumGrad(op, grad):
  """Gradient for SparseSegmentSum."""
  input_rows = array_ops.shape(op.inputs[0])[0]
  return (math_ops.unsorted_segment_sum(
      array_ops.gather(grad, op.inputs[2]), op.inputs[1], input_rows),
          None, None)
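# A small check (not part of the source; eager TF 2.x) that the
# SparseSegmentSum gradient w.r.t. `data` is the incoming gradient gathered
# back through `indices` and summed into the selected rows.
import tensorflow as tf

data = tf.Variable([[1.], [2.], [3.]])
indices = tf.constant([0, 2])      # rows of `data` that participate
segment_ids = tf.constant([0, 0])  # both rows summed into segment 0
with tf.GradientTape() as tape:
  y = tf.sparse.segment_sum(data, indices, segment_ids)
print(tape.gradient(y, data))  # [[1.], [0.], [1.]]: untouched row gets 0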
def get(self, iteration):
  return array_ops.gather(self._index, iteration)
def _UnsortedSegmentSumGrad(op, grad):
  """Gradient for UnsortedSegmentSum."""
  return array_ops.gather(grad, op.inputs[1]), None, None
def _event_shape(self):
  return array_ops.gather(array_ops.shape(self._mean_val),
                          [array_ops.rank(self._mean_val) - 1])
def __init__(self, mu, sigma=None, sigma_chol=None, name=None):
  """Multivariate Normal distributions on `R^k`.

  User must provide means `mu`, which are tensors of rank `N+1` (`N >= 0`)
  with the last dimension having length `k`.

  User must provide exactly one of `sigma` (the covariance matrices) or
  `sigma_chol` (the cholesky decompositions of the covariance matrices).
  `sigma` or `sigma_chol` must be of rank `N+2`. The last two dimensions
  must both have length `k`. The first `N` dimensions correspond to batch
  indices.

  If `sigma_chol` is not provided, the batch cholesky factorization of
  `sigma` is calculated for you.

  The shapes of `mu` and `sigma` must match for the first `N` dimensions.

  Regardless of which parameter is provided, the covariance matrices must
  all be **positive definite** (an error is raised if one of them is not).

  Args:
    mu: (N+1)-D. `float` or `double` tensor, the means of the distributions.
    sigma: (N+2)-D. (optional) `float` or `double` tensor, the covariances
      of the distribution(s). The first `N+1` dimensions must match those
      of `mu`. Must be batch-positive-definite.
    sigma_chol: (N+2)-D. (optional) `float` or `double` tensor, a
      lower-triangular factorization of `sigma`
      (`sigma = sigma_chol . sigma_chol^*`). The first `N+1` dimensions
      must match those of `mu`. The tensor itself need not be batch lower
      triangular: we ignore the upper triangular part. However, the batch
      diagonals must be positive (i.e., sigma_chol must be
      batch-positive-definite).
    name: The name to give Ops created by the initializer.

  Raises:
    ValueError: if neither sigma nor sigma_chol is provided.
    TypeError: if mu and sigma (resp. sigma_chol) are different dtypes.
  """
  if (sigma is None) == (sigma_chol is None):
    raise ValueError("Exactly one of sigma and sigma_chol must be provided")

  with ops.op_scope([mu, sigma, sigma_chol], name, "MultivariateNormal"):
    sigma_or_half = sigma_chol if sigma is None else sigma

    mu = ops.convert_to_tensor(mu)
    sigma_or_half = ops.convert_to_tensor(sigma_or_half)

    contrib_tensor_util.assert_same_float_dtype((mu, sigma_or_half))

    with ops.control_dependencies(
        [_assert_compatible_shapes(mu, sigma_or_half)]):
      mu = array_ops.identity(mu, name="mu")

      # Store the dimensionality of the MVNs.
      self._k = array_ops.gather(array_ops.shape(mu), array_ops.rank(mu) - 1)

      if sigma_chol is not None:
        # Ensure we only keep the lower triangular part.
        sigma_chol = array_ops.batch_matrix_band_part(
            sigma_chol, num_lower=-1, num_upper=0)
        sigma_det = _determinant_from_sigma_chol(sigma_chol)
        with ops.control_dependencies(
            [_assert_batch_positive_definite(sigma_chol)]):
          self._sigma = math_ops.batch_matmul(
              sigma_chol, sigma_chol, adj_y=True, name="sigma")
          self._sigma_chol = array_ops.identity(sigma_chol, "sigma_chol")
          self._sigma_det = array_ops.identity(sigma_det, "sigma_det")
          self._mu = array_ops.identity(mu, "mu")
      else:  # sigma is not None
        sigma_chol = linalg_ops.batch_cholesky(sigma)
        sigma_det = _determinant_from_sigma_chol(sigma_chol)
        # batch_cholesky checks for PSD; so we can just use it here.
        with ops.control_dependencies([sigma_chol]):
          self._sigma = array_ops.identity(sigma, "sigma")
          self._sigma_chol = array_ops.identity(sigma_chol, "sigma_chol")
          self._sigma_det = array_ops.identity(sigma_det, "sigma_det")
          self._mu = array_ops.identity(mu, "mu")
def loop_fn(i):
  image = array_ops.gather(images, i)
  return model(image, training=training)
def _rank_resample(weights, biases, inputs, sampled_values, num_resampled,
                   resampling_temperature, partition_strategy):
  """A helper function for rank_sampled_softmax_loss.

  This computes, for each i in `sampled_values`,

      log(sum_j exp((w_i * x_j + b_i) / resampling_temperature))

  where w_i, b_i are the weight and bias of the i-th class, respectively,
  and j ranges over the rows of `inputs`. For efficiency, we rearrange the
  computation to

      log(sum_j exp(w_i * (x_j / resampling_temperature))) +
          b_i / resampling_temperature.

  This translates to the following batched computation using tensorflow ops:

      reduce_logsumexp(matmul(embeddings,
                       transpose(inputs / resampling_temperature))) +
          biases / resampling_temperature

  The computation of the first term is colocated with the embeddings using
  `transform_fn` in `embedding_ops._embedding_lookup_and_transform`. The
  second term, not the bottleneck, is computed at the worker.

  Args:
    weights: From `rank_sampled_softmax_loss`.
    biases: From `rank_sampled_softmax_loss`.
    inputs: From `rank_sampled_softmax_loss`.
    sampled_values: A tuple of (`sampled_candidates`, `true_expected_count`,
      `sampled_expected_count`) returned by a `*_candidate_sampler` function.
    num_resampled: An `int`. This many values are selected from
      `sampled_values` using the adaptive resampling algorithm. The caller
      must ensure that `num_resampled` is less than the size of
      `sampled_values`.
    resampling_temperature: A scalar `Tensor` with the temperature parameter
      for the adaptive resampling algorithm.
    partition_strategy: From `rank_sampled_softmax_loss`.

  Returns:
    A tuple of (`resampled_candidates`, `true_expected_count`,
      `resampled_expected_count`), similar to `sampled_values` but sampled
      down to `num_resampled` values.
  """
  # This code supports passing a Tensor for num_resampled, but since it is
  # only called with an int, that's what we specify in the arg list. If this
  # function is ever externalized, we should change the doc to support
  # Tensor.

  sampled, true_expected_count, sampled_expected_count = sampled_values

  sampled = math_ops.cast(array_ops.stop_gradient(sampled), dtypes.int64)
  true_expected_count = array_ops.stop_gradient(true_expected_count)
  sampled_expected_count = array_ops.stop_gradient(sampled_expected_count)

  reweighted_inputs = inputs / resampling_temperature

  def logsumexp_logit(embeddings):
    return math_ops.reduce_logsumexp(
        math_ops.matmul(embeddings, reweighted_inputs, transpose_b=True),
        axis=1,
        keep_dims=False)

  # Calling this protected form of embedding_lookup allows co-locating
  # the logsumexp computation with the partitioned weights, which yields
  # a large speedup in practice.
  sampled_logits = embedding_ops._embedding_lookup_and_transform(  # pylint: disable=protected-access
      weights, sampled, partition_strategy, transform_fn=logsumexp_logit)
  sampled_b = array_ops.reshape(
      embedding_ops.embedding_lookup(biases, sampled, partition_strategy),
      [-1])
  sampled_logits += sampled_b / resampling_temperature

  _, resampled_indices = nn.top_k(sampled_logits, k=num_resampled,
                                  sorted=False)
  resampled = array_ops.gather(sampled, indices=resampled_indices)
  resampled_expected_count = array_ops.gather(
      sampled_expected_count, indices=resampled_indices)

  return resampled, true_expected_count, resampled_expected_count
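# A numerical check (not part of the source) of the rearrangement described
# in the docstring above, with hypothetical shapes: dividing the inputs by
# the temperature inside the matmul and adding b / T afterwards matches the
# direct form.
import numpy as np

rng = np.random.default_rng(0)
w = rng.normal(size=(5, 4))   # class embeddings
b = rng.normal(size=5)        # class biases
x = rng.normal(size=(7, 4))   # input activations
T = 0.7                       # resampling_temperature

lhs = np.log(np.exp((x @ w.T + b) / T).sum(axis=0))      # direct form
rhs = np.log(np.exp((x / T) @ w.T).sum(axis=0)) + b / T  # rearranged form
assert np.allclose(lhs, rhs)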
def decoder_fn(time, cell_state, cell_input, cell_output, context_state):
  """Decoder function used in the `dynamic_rnn_decoder` for inference.

  The main difference between this decoder function and the `decoder_fn` in
  `attention_decoder_fn_train` is how `next_cell_input` is calculated. In
  this decoder function we calculate the next input by applying an argmax
  across the feature dimension of the output from the decoder. This is a
  greedy-search approach. (Bahdanau et al., 2014) and (Sutskever et al.,
  2014) use beam search instead.

  Args:
    time: positive integer constant reflecting the current timestep.
    cell_state: state of RNNCell.
    cell_input: input provided by `dynamic_rnn_decoder`.
    cell_output: output of RNNCell.
    context_state: context state provided by `dynamic_rnn_decoder`.

  Returns:
    A tuple (done, next state, next input, emit output, next context state)
    where:

    done: A boolean vector indicating which sentences have reached the
    `end_of_sequence_id`. This is used for early stopping by the
    `dynamic_rnn_decoder`. When `time >= maximum_length` a boolean vector
    with all elements `true` is returned.

    next state: `cell_state`, this decoder function does not modify the
    given state.

    next input: The embedding from argmax of the `cell_output` is used as
    `next_input`.

    emit output: If `output_fn is None` the supplied `cell_output` is
    returned, else the `output_fn` is used to update the `cell_output`
    before calculating `next_input` and returning `cell_output`.

    next context state: `context_state`, this decoder function does not
    modify the given context state. The context state could be modified when
    applying e.g. beam search.

  Raises:
    ValueError: if cell_input is not None.
  """
  with ops.name_scope(
      name, "attention_decoder_fn_inference",
      [time, cell_state, cell_input, cell_output, context_state]):
    if cell_input is not None:
      raise ValueError(
          "Expected cell_input to be None, but saw: %s" % cell_input)
    if cell_output is None:
      # Invariant: this branch is only taken when time == 0.
      next_input_id = array_ops.ones(
          [batch_size], dtype=dtype) * start_of_sequence_id
      done = array_ops.zeros([batch_size], dtype=dtypes.bool)
      cell_state = encoder_state
      cell_output = array_ops.zeros([num_decoder_symbols],
                                    dtype=dtypes.float32)
      word_input = array_ops.gather(embeddings, next_input_id)
      naf_triple_id = array_ops.zeros([batch_size, 2], dtype=dtype)
      triple_input = array_ops.gather_nd(imem[1], naf_triple_id)
      cell_input = array_ops.concat([word_input, triple_input], axis=1)

      # Initialize attention.
      attention = _init_attention(encoder_state)
      if imem is not None:
        context_state = tensor_array_ops.TensorArray(
            dtype=dtypes.int32,
            tensor_array_name="output_ids_ta",
            size=maximum_length,
            dynamic_size=True,
            infer_shape=False)
    else:
      # Construct attention.
      attention = attention_construct_fn(cell_output, attention_keys,
                                         attention_values)
      if isinstance(attention, tuple):
        attention, alignment = attention
        cell_output = attention
        alignment = array_ops.reshape(alignment, [batch_size, -1])
        selector = selector_fn(cell_output)
        logit = output_fn(cell_output)
        word_prob = nn_ops.softmax(logit) * (1 - selector)
        entity_prob = alignment * selector
        # mask is 1 where the most likely word beats the most likely entity,
        # and 0 otherwise.
        mask = array_ops.reshape(
            math_ops.cast(
                math_ops.greater(
                    math_ops.reduce_max(word_prob, 1),
                    math_ops.reduce_max(entity_prob, 1)),
                dtype=dtypes.float32), [-1, 1])
        word_input = mask * array_ops.gather(
            embeddings,
            math_ops.cast(math_ops.argmax(word_prob, 1), dtype=dtype)
        ) + (1 - mask) * array_ops.gather_nd(
            imem[0],
            array_ops.concat([
                array_ops.reshape(
                    math_ops.range(batch_size, dtype=dtype), [-1, 1]),
                array_ops.reshape(
                    math_ops.cast(math_ops.argmax(entity_prob, 1),
                                  dtype=dtype), [-1, 1])
            ], axis=1))
        indices = array_ops.concat([
            array_ops.reshape(
                math_ops.range(batch_size, dtype=dtype), [-1, 1]),
            math_ops.cast(1 - mask, dtype=dtype) * array_ops.reshape(
                math_ops.cast(math_ops.argmax(alignment, 1), dtype=dtype),
                [-1, 1])
        ], axis=1)
        triple_input = array_ops.gather_nd(imem[1], indices)
        cell_input = array_ops.concat([word_input, triple_input], axis=1)
        mask = array_ops.reshape(math_ops.cast(mask, dtype=dtype), [-1])
        # Note the sign: (mask - 1) flips the sign of entity ids, so word
        # choices (non-negative) and entity choices (non-positive) can be
        # told apart when the written output ids are read back.
        input_id = mask * math_ops.cast(
            math_ops.argmax(word_prob, 1),
            dtype=dtype) + (mask - 1) * math_ops.cast(
                math_ops.argmax(entity_prob, 1), dtype=dtype)
        context_state = context_state.write(time - 1, input_id)
        done = array_ops.reshape(
            math_ops.equal(input_id, end_of_sequence_id), [-1])
        cell_output = logit
      else:
        cell_output = attention

        # Greedy (argmax) decoder.
        cell_output = output_fn(cell_output)  # logits
        next_input_id = math_ops.cast(
            math_ops.argmax(cell_output, 1), dtype=dtype)
        done = math_ops.equal(next_input_id, end_of_sequence_id)
        cell_input = array_ops.gather(embeddings, next_input_id)

    # Combine cell_input and attention.
    next_input = array_ops.concat([cell_input, attention], 1)
    # If time > maximum_length, return an all-true done vector.
    done = control_flow_ops.cond(
        math_ops.greater(time, maximum_length),
        lambda: array_ops.ones([batch_size], dtype=dtypes.bool),
        lambda: done)
    return (done, cell_state, next_input, cell_output, context_state)
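# A toy NumPy sketch (values and names illustrative) of the greedy step in
# the argmax branch above: pick the argmax id per row, and mark rows done
# once they emit the end-of-sequence id.
import numpy as np

end_of_sequence_id = 2
logits = np.array([[0.1, 0.3, 0.6],    # row 0 emits id 2 -> done
                   [0.7, 0.2, 0.1]])   # row 1 emits id 0 -> keep decoding
next_input_id = np.argmax(logits, axis=1)
done = next_input_id == end_of_sequence_id
assert done.tolist() == [True, False]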
def loop_fn(i): y = array_ops.gather(output, i, axis=1) return gradient_ops.gradients(y, inp)[0]
def tpu_function(sparse): # Assumes dense_shape is (2, *) looked_up = array_ops.gather(table, sparse.values) segment_sum = math_ops.unsorted_segment_sum( looked_up, sparse.indices[:, 0], 2) return {"sparse": sparse, "segment_sum": segment_sum}
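# An illustrative NumPy equivalent (not TPU code) of the gather + segment-sum
# pattern above: look up rows of a table by sparse value ids, then sum the
# looked-up rows per example, where the segment id is the example's row
# index. The shapes below are illustrative.
import numpy as np

table = np.arange(12.0).reshape(4, 3)          # 4 embeddings of width 3
values = np.array([1, 3, 0])                   # sparse.values
row_ids = np.array([0, 0, 1])                  # sparse.indices[:, 0]
looked_up = table[values]                      # gather
segment_sum = np.zeros((2, 3))
np.add.at(segment_sum, row_ids, looked_up)     # unsorted_segment_sum
assert np.allclose(segment_sum[0], table[1] + table[3])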
def create_batch(self):
  """Create queues to window and batch time series data.

  Returns:
    A `(features, targets)` tuple, where `targets` is always `None` and
    `features` is a dictionary of Tensors corresponding to the output of
    `self._reader` (from the `time_series_reader` constructor argument),
    each with shapes prefixed by [`batch_size`, `window_size`].
  """
  features = self._reader.read()
  if self._jitter:
    # TODO(agarwal, allenl): Figure out if more jitter is needed here.
    jitter = random_ops.random_uniform(shape=[], maxval=2,
                                       dtype=dtypes.int32)
  else:
    jitter = 0
  # To keep things efficient, we pass from the windowing batcher to the
  # batch-of-windows batcher in batches. This avoids the need for huge
  # numbers of threads, but does mean that jitter is only applied
  # occasionally.
  # TODO(allenl): Experiment with different internal passing sizes.
  internal_passing_size = self._batch_size
  features_windowed = input_lib.batch(
      features,
      batch_size=self._window_size * internal_passing_size + jitter,
      enqueue_many=True,
      capacity=(self._queue_capacity_multiplier
                * internal_passing_size * self._window_size),
      num_threads=self._num_threads)
  raw_features_windowed = features_windowed
  if self._jitter:
    features_windowed = {
        key: value[jitter:]
        for key, value in features_windowed.items()}
  features_windowed = {
      key: array_ops.reshape(
          value,
          array_ops.concat(
              [[internal_passing_size, self._window_size],
               array_ops.shape(value)[1:]],
              axis=0))
      for key, value in features_windowed.items()}
  batch_and_window_shape = tensor_shape.TensorShape(
      [internal_passing_size, self._window_size])
  for key in features_windowed.keys():
    features_windowed[key].set_shape(
        batch_and_window_shape.concatenate(
            raw_features_windowed[key].get_shape()[1:]))
  # When switching files, we may end up with windows where the time is not
  # decreasing, even if times within each file are sorted (and even if those
  # files are visited in order, when looping back around to the beginning of
  # the first file). This is hard for models to deal with, so we either
  # discard such examples, creating a bias where the beginning and end of the
  # series are under-sampled, or we sort the window, creating large gaps.
  times = features_windowed[feature_keys.TrainEvalFeatures.TIMES]
  if self._discard_out_of_order:
    non_decreasing = math_ops.reduce_all(
        times[:, 1:] >= times[:, :-1], axis=1)
    # Ensure that no more than self._discard_limit complete batches are
    # discarded contiguously (resetting the count when we find a single clean
    # window). This prevents infinite looping when the dataset is smaller
    # than the window size.
    # TODO(allenl): Figure out a way to return informative errors from
    # count_up_to.
discarded_windows_limiter = variable_scope.variable( initial_value=constant_op.constant(0, dtype=dtypes.int64), name="discarded_windows_limiter", trainable=False, collections=[ops.GraphKeys.LOCAL_VARIABLES]) def _initialized_limit_check(): return control_flow_ops.cond( math_ops.reduce_any(non_decreasing), lambda: state_ops.assign(discarded_windows_limiter, 0), lambda: discarded_windows_limiter.count_up_to(self._discard_limit)) discard_limit_op = control_flow_ops.cond( state_ops.is_variable_initialized(discarded_windows_limiter), _initialized_limit_check, lambda: constant_op.constant(0, dtype=dtypes.int64)) with ops.control_dependencies([discard_limit_op]): non_decreasing = array_ops.identity(non_decreasing) else: _, indices_descending = nn.top_k( times, k=array_ops.shape(times)[-1], sorted=True) indices = array_ops.reverse(indices_descending, axis=[0]) features_windowed = { key: array_ops.gather(params=value, indices=indices) for key, value in features_windowed.items() } non_decreasing = True features_batched = input_lib.maybe_shuffle_batch( features_windowed, num_threads=self._num_threads, seed=self._shuffle_seed, batch_size=self._batch_size, capacity=self._queue_capacity_multiplier * self._batch_size, min_after_dequeue=(self._shuffle_min_after_dequeue_multiplier * self._batch_size), keep_input=non_decreasing, enqueue_many=True) return (features_batched, None)
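# A NumPy sketch of the windowing arithmetic in create_batch above: read
# window_size * passing_size + jitter elements from the stream, drop the
# first `jitter` of them, and reshape the rest into
# [passing_size, window_size] windows. The values are illustrative.
import numpy as np

window_size, passing_size, jitter = 4, 3, 2
stream = np.arange(window_size * passing_size + jitter)
windows = stream[jitter:].reshape(passing_size, window_size)
# Each row is one window; jitter shifts where windows start in the stream.
assert windows[0].tolist() == [2, 3, 4, 5]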
def stack_dynamic_partitions(data, partitions, num_partitions, name=None):
  """Stacks dynamic partitions of a Tensor or RaggedTensor.

  Returns a RaggedTensor `output` with `num_partitions` rows, where the row
  `output[i]` is formed by stacking all slices `data[j1...jN]` such that
  `partitions[j1...jN] = i`. Slices of `data` are stacked in row-major
  order.

  If `num_partitions` is an `int` (not a `Tensor`), then this is equivalent
  to `tf.ragged.stack(tf.dynamic_partition(data, partitions,
  num_partitions))`.

  #### Example:

  >>> data           = ['a', 'b', 'c', 'd', 'e']
  >>> partitions     = [  3,   0,   2,   2,   3]
  >>> num_partitions = 5
  >>> tf.ragged.stack_dynamic_partitions(data, partitions, num_partitions)
  <tf.RaggedTensor [[b'b'], [], [b'c', b'd'], [b'a', b'e'], []]>

  Args:
    data: A `Tensor` or `RaggedTensor` containing the values to stack.
    partitions: An `int32` or `int64` `Tensor` or `RaggedTensor` specifying
      the partition that each slice of `data` should be added to.
      `partitions.shape` must be a prefix of `data.shape`. Values must be
      greater than or equal to zero, and less than `num_partitions`.
      `partitions` is not required to be sorted.
    num_partitions: An `int32` or `int64` scalar specifying the number of
      partitions to output. This determines the number of rows in `output`.
    name: A name prefix for the returned tensor (optional).

  Returns:
    A `RaggedTensor` containing the stacked partitions. The returned tensor
    has the same dtype as `data`, and its shape is
    `[num_partitions, (D)] + data.shape[partitions.rank:]`, where `(D)` is a
    ragged dimension whose length is the number of data slices stacked for
    each `partition`.
  """
  with ops.name_scope(name, 'SegmentStack',
                      [data, partitions, num_partitions]):
    # Convert inputs to tensors.
    data = ragged_tensor.convert_to_tensor_or_ragged_tensor(data, name='data')
    row_splits_dtype = (data.row_splits.dtype if isinstance(
        data, ragged_tensor.RaggedTensor) else None)
    partitions = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        partitions, name='partitions', preferred_dtype=row_splits_dtype)
    num_partitions = ops.convert_to_tensor(
        num_partitions, name='num_partitions',
        preferred_dtype=partitions.dtype)
    if row_splits_dtype is not None:
      partitions = math_ops.cast(partitions, row_splits_dtype)
    num_partitions = math_ops.cast(num_partitions, partitions.dtype)

    # Sanity-check the shapes.
    partitions_rank = partitions.shape.ndims
    if partitions_rank is None:
      raise ValueError('partitions must have known rank.')
    num_partitions.shape.assert_has_rank(0)
    partitions.shape.assert_is_compatible_with(data.shape[:partitions_rank])

    if partitions_rank == 0:
      # If partitions is a scalar, then just create a RaggedTensor containing
      # the complete `data` value as the single element of the specified row.
      return ragged_tensor.RaggedTensor.from_value_rowids(
          values=array_ops.stack([data]),
          value_rowids=array_ops.stack([partitions]),
          nrows=num_partitions,
          validate=False)

    elif partitions_rank == 1:
      # If partitions is a vector (the typical case): we can just use data
      # and partitions as the `values` and `value_rowids` for
      # `from_value_rowids`, as long as we sort them first.
permutation = sort_ops.argsort(partitions, stable=True) value_rowids = array_ops.gather(partitions, permutation) values = array_ops.gather(data, permutation) check = check_ops.assert_less( value_rowids[-1:], num_partitions, message='partitions must be less than num_partitions') with ops.control_dependencies([check]): return ragged_tensor.RaggedTensor.from_value_rowids( values, value_rowids, nrows=num_partitions, validate=False) else: # Handle higher-dimensional partitions via recursion. if not isinstance(data, ragged_tensor.RaggedTensor): data = ragged_tensor.RaggedTensor.from_tensor( data, row_splits_dtype=partitions.dtype, ragged_rank=1) if not isinstance(partitions, ragged_tensor.RaggedTensor): partitions = ragged_tensor.RaggedTensor.from_tensor( partitions, row_splits_dtype=partitions.dtype, ragged_rank=max(data.ragged_rank, partitions_rank - 1)) check = check_ops.assert_equal( data.row_splits, partitions.row_splits, message='data and partitions have incompatible ragged shapes') with ops.control_dependencies([check]): return stack_dynamic_partitions(data.values, partitions.values, num_partitions)
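# A NumPy sketch of the rank-1 case above: stable-sort by partition id, then
# group the sorted values by row id. This mirrors what from_value_rowids
# builds, using plain Python lists for the ragged result.
import numpy as np

data = np.array(['a', 'b', 'c', 'd', 'e'])
partitions = np.array([3, 0, 2, 2, 3])
num_partitions = 5

permutation = np.argsort(partitions, kind='stable')
value_rowids = partitions[permutation]
values = data[permutation]
rows = [values[value_rowids == i].tolist() for i in range(num_partitions)]
assert rows == [['b'], [], ['c', 'd'], ['a', 'e'], []]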
def stratified_sample(tensors, labels, target_probs, batch_size,
                      init_probs=None, enqueue_many=False, queue_capacity=16,
                      threads_per_queue=1, name=None):
  """Stochastically creates batches based on per-class probabilities.

  This method discards examples. Internally, it creates one queue to
  amortize the cost of disk reads, and one queue to hold the
  properly-proportioned batch.

  Args:
    tensors: List of tensors for data. All tensors are either one item or a
        batch, according to enqueue_many.
    labels: Tensor for label of data. Label is a single integer or a batch,
        depending on enqueue_many. It is not a one-hot vector.
    target_probs: Target class proportions in batch. An object whose type has
        a registered Tensor conversion function.
    batch_size: Size of batch to be returned.
    init_probs: Class proportions in the data. An object whose type has a
        registered Tensor conversion function, or `None` for estimating the
        initial distribution.
    enqueue_many: Bool. If true, interpret input tensors as having a batch
        dimension.
    queue_capacity: Capacity of the large queue that holds input examples.
    threads_per_queue: Number of threads for the large queue that holds input
        examples and for the final queue with the proper class proportions.
    name: Optional prefix for ops created by this function.
  Raises:
    ValueError: if enqueue_many is True and labels doesn't have a batch
        dimension, or if enqueue_many is False and labels isn't a scalar.
    ValueError: if enqueue_many is True, and the batch dimensions of data and
        labels don't match.
    ValueError: if probs don't sum to one.
    ValueError: if a zero initial probability class has a nonzero target
        probability.
    TFAssertion: if labels aren't integers in [0, num classes).
  Returns:
    (data_batch, label_batch), where data_batch is a list of tensors of the
        same length as `tensors`

  Example:
    # Get tensor for a single data and label example.
    data, label = data_provider.Get(['data', 'label'])

    # Get stratified batch according to per-class probabilities.
    target_probs = [...distribution you want...]
    [data_batch], labels = tf.contrib.training.stratified_sample(
        [data], label, target_probs)

    # Run batch through network.
    ...
  """
  with ops.name_scope(name, 'stratified_sample', tensors + [labels]):
    tensor_list = ops.convert_n_to_tensor_or_indexed_slices(tensors)
    labels = ops.convert_to_tensor(labels)
    target_probs = ops.convert_to_tensor(target_probs, dtype=dtypes.float32)
    # Reduce the case of a single example to that of a batch of size 1.
    if not enqueue_many:
      tensor_list = [array_ops.expand_dims(tensor, 0)
                     for tensor in tensor_list]
      labels = array_ops.expand_dims(labels, 0)

    # If `init_probs` is `None`, set up online estimation of the data
    # distribution.
    if init_probs is None:
      # We use `target_probs` to get the number of classes, so its shape must
      # be fully defined at graph construction time.
      target_probs.get_shape().assert_is_fully_defined()
      init_probs = _estimate_data_distribution(
          labels, target_probs.get_shape().num_elements())
    else:
      init_probs = ops.convert_to_tensor(init_probs, dtype=dtypes.float32)

    # Validate that the input is consistent.
    tensor_list, labels, [init_probs, target_probs] = _verify_input(
        tensor_list, labels, [init_probs, target_probs])

    # Check that all zero initial probabilities also have zero target
    # probabilities.
assert_op = control_flow_ops.Assert( math_ops.reduce_all(math_ops.logical_or( math_ops.not_equal(init_probs, 0), math_ops.equal(target_probs, 0))), ['All classes with zero initial probability must also have zero target ' 'probability: ', init_probs, target_probs]) init_probs = control_flow_ops.with_dependencies([assert_op], init_probs) # Calculate acceptance sampling probabilities. accept_probs = _calculate_acceptance_probabilities(init_probs, target_probs) proportion_rejected = math_ops.reduce_sum((1 - accept_probs) * init_probs) accept_probs = control_flow_ops.cond( math_ops.less(proportion_rejected, .5), lambda: accept_probs, lambda: logging_ops.Print( # pylint: disable=g-long-lambda accept_probs, [accept_probs], message='Proportion of examples rejected by sampler is high.', first_n=10)) # Make a single queue to hold input examples. Reshape output so examples # don't have singleton batch dimension. batched = input_ops.batch(tensor_list + [labels], batch_size=1, num_threads=threads_per_queue, capacity=queue_capacity, enqueue_many=True) val_list = [array_ops.squeeze(x, [0]) for x in batched[:-1]] label = array_ops.squeeze(batched[-1], [0]) # Set up second queue containing batches that have the desired class # proportions. cur_prob = array_ops.gather(accept_probs, label) keep_input = random_ops.random_uniform([]) < cur_prob batched = _conditional_batch( val_list + [label], keep_input, batch_size, num_threads=threads_per_queue) return batched[:-1], batched[-1]
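# A NumPy sketch of the acceptance-probability idea used above. The helper
# _calculate_acceptance_probabilities is defined elsewhere; the construction
# shown here is the standard rejection-sampling one, stated as an
# assumption: accept class i with probability
# (target_i / init_i) / max_j(target_j / init_j), so accepted examples
# follow the target distribution.
import numpy as np

init_probs = np.array([0.7, 0.2, 0.1])     # class proportions in the data
target_probs = np.array([0.2, 0.3, 0.5])   # desired batch proportions
ratio = target_probs / init_probs
accept_probs = ratio / ratio.max()

# Accepted mass per class is init * accept, which is proportional to target.
accepted = init_probs * accept_probs
assert np.allclose(accepted / accepted.sum(), target_probs)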
def _tile_ragged_splits(rt_input, multiples, const_multiples=None): """Builds nested_split tensors for a tiled `RaggedTensor`. Returns a list of split tensors that can be used to construct the `RaggedTensor` that tiles `rt_input` as specified by `multiples`. Args: rt_input: The `RaggedTensor` that is being tiled. multiples: A 1-D integer `tensor`, indicating how many times each dimension should be repeated. const_multiples: Optional constant value for multiples. Used to skip tiling dimensions where `multiples=1`. Returns: A list of 1-D integer `Tensor`s (one for each ragged dimension in `rt_input`). #### Example: >>> rt = tf.ragged.constant([[1, 2], [3]]) >>> _tile_ragged_splits(rt, [3, 2]) [<tf.Tensor: shape=(7,), dtype=int64, numpy=array([ 0, 4, 6, 10, 12, 16, 18])>] """ ragged_rank = rt_input.ragged_rank nested_splits = rt_input.nested_row_splits # projected_splits[src_axis, dst_axis] contains the split points that divide # the rows from src_axis in the list of dst_axis values. E.g., # projected_splits[i, i] = nested_splits[i], and # projected_splits[i, i+1] = gather(nested_splits[i+1], nested_splits[i]). projected_splits = [{i: nested_splits[i]} for i in range(ragged_rank)] for src_axis in range(ragged_rank): for dst_axis in range(src_axis + 1, ragged_rank - 1): projected_splits[src_axis][dst_axis] = array_ops.gather( nested_splits[dst_axis], projected_splits[src_axis][dst_axis - 1]) # For each ragged dimension: nested_splits[axis] -> result_splits[axis]. result_splits = [] for axis in range(ragged_rank): # Get the length of each row for the input tensor for this dimension. input_lengths = nested_splits[axis][1:] - nested_splits[axis][:-1] # Multiply those lengths by the `multiples` of dimension axis+1, since # each value will be repeated that number of times. output_lengths = input_lengths * multiples[axis + 1] # Repeat ranges of the row lengths as necessary for them to be tiled in # each ragged dimension `d < axis`. (Start with dimension d=axis-1, and # work our way up to dimension d=0.) repeats = 1 for d in range(axis - 1, -1, -1): if const_multiples is None or const_multiples[d + 1] != 1: splits = projected_splits[d][axis - 1] * repeats output_lengths = ragged_util.repeat_ranges( output_lengths, splits, multiples[d + 1]) repeats *= multiples[d + 1] # Tile splits for the outermost (uniform) dimension. output_lengths = array_ops.tile(output_lengths, multiples[:1]) # Convert to splits. result_splits.append(ragged_util.lengths_to_splits(output_lengths)) return result_splits
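# A NumPy check of the docstring example above: tiling [[1, 2], [3]] by
# [3, 2] repeats each row's values twice (inner multiple 2) and the whole
# batch of rows three times (outer multiple 3), giving row lengths
# [4, 2, 4, 2, 4, 2] and hence the splits shown in the docstring.
import numpy as np

row_lengths = np.array([2, 1])
multiples = [3, 2]
output_lengths = np.tile(row_lengths * multiples[1], multiples[0])
splits = np.concatenate([[0], np.cumsum(output_lengths)])
assert splits.tolist() == [0, 4, 6, 10, 12, 16, 18]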
def func(self, x): return array_ops.gather(self.shared_weights, x)
def fill_lower_triangular(x, validate_args=False,
                          name="fill_lower_triangular"):
  """Creates a (batch of) lower triangular matrix from a vector of inputs.

  If `x.get_shape()` is `[b1, b2, ..., bK, d]` then the output shape is
  `[b1, b2, ..., bK, n, n]` where `n` is such that `d = n(n+1)/2`, i.e.,
  `n = int(0.5 * (math.sqrt(1. + 8. * d) - 1.))`.

  Although the non-batch complexity is O(n^2), large constants and
  sub-optimal vectorization means the complexity of this function is 5x
  slower than zeroing out the upper triangular, i.e.,
  `tf.matrix_band_part(X, -1, 0)`. This function becomes competitive only
  when several matmul/cholesky/etc ops can be elided in constructing the
  input. Example: wiring a fully connected layer as a covariance matrix;
  this function reduces the final layer by 2x and possibly reduces the
  network arch complexity considerably. In most cases it is better to simply
  build a full matrix and zero out the upper triangular elements, e.g.,
  `tril = tf.matrix_band_part(full, -1, 0)`, rather than directly construct
  a lower triangular.

  Example:

  ```python
  fill_lower_triangular([1, 2, 3, 4, 5, 6])
  # Returns: [[1, 0, 0],
  #           [2, 3, 0],
  #           [4, 5, 6]]
  ```

  For comparison, a pure numpy version of this function can be found in
  `distribution_util_test.py`, function `_fill_lower_triangular`.

  Args:
    x: `Tensor` representing lower triangular elements.
    validate_args: `Boolean`, default `False`. Whether to ensure the shape
      of `x` can be mapped to a lower triangular matrix (controls non-static
      checks only).
    name: `String`. The name to give this op.

  Returns:
    tril: `Tensor` with lower triangular elements filled from `x`.

  Raises:
    ValueError: if `x` has a static shape which cannot be mapped to a lower
      triangular matrix.
  """
  # TODO(jvdillon): Replace this code with dedicated op when it exists.
  with ops.name_scope(name, values=(x,)):
    x = ops.convert_to_tensor(x, name="x")
    if (x.get_shape().ndims is not None and
        x.get_shape()[-1].value is not None):
      d = x.get_shape()[-1].value
      # d = n(n+1)/2 implies n is:
      n = int(0.5 * (math.sqrt(1. + 8. * d) - 1.))
      d_inferred = n * (n + 1) / 2
      if d != d_inferred:
        raise ValueError(
            "Input cannot be mapped to a lower triangular; "
            "n*(n+1)/2 = %d != %d" % (d_inferred, d))
      final_shape = x.get_shape()[:-1].concatenate(
          tensor_shape.TensorShape([n, n]))
    else:
      d = math_ops.cast(array_ops.shape(x)[-1], dtype=dtypes.float32)
      # d = n(n+1)/2 implies n is:
      n = math_ops.cast(0.5 * (math_ops.sqrt(1. + 8. * d) - 1.),
                        dtype=dtypes.int32)
      if validate_args:
        is_valid_input_shape = check_ops.assert_equal(
            n * (n + 1) / 2, d,
            message="Input cannot be mapped to a lower triangular.")
        n = control_flow_ops.with_dependencies([is_valid_input_shape], n)
      final_shape = x.get_shape()[:-1].concatenate(
          tensor_shape.TensorShape([None, None]))

    def tril_ids(n):
      """Internal helper to create vector of linear indices into y."""
      # Build the ids statically; we chose 512 because it implies a 1MiB
      # lookup table.
      if not contrib_framework.is_tensor(n) and n <= 512:
        ids = np.arange(n**2, dtype=np.int32)
        rows = (ids / n).astype(np.int32)  # Implicit floor.
        # We need to stop incrementing the index when we encounter
        # upper-triangular elements. The idea here is to compute the
        # lower-right number of zeros then by "symmetry" subtract this from
        # the total number of zeros, n(n-1)/2.
        # Then we note that: n(n-1)/2 - (n-r)*(n-r-1)/2 = r(2n-r-1)/2.
        offset = (rows * (2 * n - rows - 1) / 2).astype(np.int32)
        # We could also zero out when (rows < cols) == (rows < ids-n*rows).
        # mask = (ids <= (n + 1) * rows).astype(np.int32)
      else:
        ids = math_ops.range(n**2)
        rows = math_ops.cast(ids / n, dtype=dtypes.int32)
        offset = math_ops.cast(rows * (2 * n - rows - 1) / 2,
                               dtype=dtypes.int32)
      return ids - offset

    # Special-case non-batch case.
    if x.get_shape().ndims == 1:
      y = array_ops.gather(x, array_ops.reshape(tril_ids(n), [n, n]))
      y = array_ops.matrix_band_part(y, -1, 0)
      y.set_shape(y.get_shape().merge_with(final_shape))
      return y

    # Make ids for each batch dim.
    if (x.get_shape().ndims is not None and
        x.get_shape()[:-1].is_fully_defined()):
      batch_shape = np.asarray(x.get_shape()[:-1].as_list(), dtype=np.int32)
      m = np.prod(batch_shape).astype(np.int32)
    else:
      batch_shape = array_ops.shape(x)[:-1]
      m = math_ops.reduce_prod(array_ops.shape(x)[:-1])
    batch_ids = math_ops.range(m)

    # Assemble the tril_ids into batch,tril_id pairs.
    idx = array_ops.pack([
        array_ops.tile(array_ops.expand_dims(batch_ids, 1), [1, n * n]),
        array_ops.tile(array_ops.expand_dims(tril_ids(n), 0), [m, 1])
    ])
    idx = array_ops.transpose(idx, [1, 2, 0])

    # Gather up, reshape, and return.
    y = array_ops.reshape(x, [-1, d])
    y = array_ops.gather_nd(y, idx)
    y = array_ops.reshape(y, array_ops.concat(0, [batch_shape, [n, n]]))
    y = array_ops.matrix_band_part(y, -1, 0)
    y.set_shape(y.get_shape().merge_with(final_shape))
    return y
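# A NumPy sketch of the index trick in tril_ids above for n = 3: build a
# flat index map whose lower-triangular positions walk 0..d-1 through the
# input vector, then gather and zero out the upper triangle.
import numpy as np

n = 3
ids = np.arange(n ** 2)
rows = ids // n
offset = rows * (2 * n - rows - 1) // 2
tril_ids = (ids - offset).reshape(n, n)

x = np.array([1, 2, 3, 4, 5, 6])
y = np.tril(x[tril_ids])
assert y.tolist() == [[1, 0, 0], [2, 3, 0], [4, 5, 6]]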
def _done(t): # Note that we don't use tf.control_dependencies since that will not make # sure that the computation on GPU has actually finished. So we fetch the # first element of the output, and assume that this will not be called on # empty tensors. return array_ops.gather(array_ops.reshape(t, [-1]), 0)
def _ConcatGradHelper(op, grad, start_value_index, end_value_index,
                      dim_index):
  """Gradient for concat op.

  Args:
    op: An operation.
    grad: `Tensor` or `IndexedSlices` representing the gradients with respect
      to each output of the op.
    start_value_index: An integer index of the first value in the op.inputs.
    end_value_index: An integer index of the last value in the op.inputs.
    dim_index: An integer index of concat_dim or axis parameter in op.inputs.

  Returns:
    Tensors representing the partial gradients with respect to each input
    of the op.

  Raises:
    ValueError: if concat_dim/axis is not statically known.
  """

  def _CreateDenseMaskAndBegin(sizes, concat_dim):
    """Create variables for iteratively slicing a dense gradients tensor."""
    # Since shape is 1-D, shape_of_shape = [rank-of-inputs]
    shape_of_shape = array_ops.shape(sizes[0])
    # Make a vector of length equal to the input's dimensions,
    # with 0's everywhere and 1 in the concat dim position.
    # Note: Can't use sparse_to_dense since it isn't GPU-capable (for now)
    mask = array_ops.concat([
        array_ops.fill(array_ops.expand_dims(concat_dim, 0), 0), [1],
        array_ops.fill(shape_of_shape - concat_dim - 1, 0)
    ], 0)
    begin = array_ops.fill(shape_of_shape, 0)
    return mask, begin

  def _ExtractInputShapes(inputs):
    """Extract the shapes of a set of input tensors."""
    if not context.in_graph_mode():
      return array_ops.shape_n(inputs)
    sizes = []
    fully_known = True
    for x in inputs:
      input_shape = array_ops.shape(x)
      if not isinstance(input_shape,
                        ops.Tensor) or input_shape.op.type != "Const":
        fully_known = False
        break
      sizes.append(input_shape)

    if fully_known:
      return sizes
    else:
      return array_ops.shape_n(inputs)

  # Degenerate concatenation, just return grad.
  if len(op.inputs) == 2:
    return grad + [None] if end_value_index <= dim_index else [None] + grad

  concat_dim = op.inputs[dim_index]
  input_values = op.inputs[start_value_index:end_value_index]
  out_grads = []
  if isinstance(grad, ops.Tensor):
    if context.in_eager_mode():
      # Using mod here for convenience since concat_dim is already verified
      # in concat implementation to be within the allowed [-rank, rank)
      # range.
      non_neg_concat_dim = (
          concat_dim._numpy().item(0) % input_values[0]._rank())  # pylint: disable=protected-access
      # All inputs are guaranteed to be EagerTensors in eager mode.
      sizes = pywrap_tensorflow.TFE_Py_TensorShapeSlice(
          input_values, non_neg_concat_dim)
      out_grads = array_ops.split(grad, sizes, non_neg_concat_dim)
    else:
      # Using mod here for convenience since concat_dim is already verified
      # in concat implementation to be within the allowed [-rank, rank)
      # range.
      non_neg_concat_dim = concat_dim % array_ops.rank(input_values[0])

      # Get the inputs' tensor shapes.
      sizes = _ExtractInputShapes(input_values)
      # The magic number of 16 was found through benchmarking a range of
      # sizes on CPUs and a Maxwell TitanX. A speedup was seen in a large
      # majority of cases when switching implementations at N=16, but it is
      # possible that there will be a small number of performance
      # regressions.
# pylint: disable=protected-access if len(sizes) > 16: # extract the size of each input along the concat dimension sizes = array_ops.squeeze( array_ops.slice(array_ops.stack(sizes, axis=1), [non_neg_concat_dim, 0], [1, -1])) out_grads = array_ops.split(grad, sizes, non_neg_concat_dim) else: offset = gen_array_ops._concat_offset(non_neg_concat_dim, sizes) for (begin, size) in zip(offset, sizes): out_grads.append(array_ops.slice(grad, begin, size)) # pylint: enable=protected-access elif isinstance(grad, ops.IndexedSlices): # Using mod here for convenience since concat_dim is already verified # in concat implementation to be within the allowed [-rank, rank) range. non_neg_concat_dim = concat_dim % array_ops.rank(input_values[0]) concat_dim_static = tensor_util.constant_value(concat_dim) if concat_dim_static is None: raise ValueError("Can only compute IndexedSlices gradient with " "statically-known concat_dim") if concat_dim_static < 0: rank = tensor_util.constant_value(array_ops.rank(input_values[0])) if rank is None: raise ValueError( "Can only compute IndexedSlices gradient with " "negative concat_dim when first value rank is " "statically-known.") concat_dim_static %= rank # Get the inputs' tensor shapes sizes = [array_ops.shape(x) for x in input_values] if concat_dim_static > 0: # IndexedSlices, non_neg_concat_dim > 0. Each input gets IndexedSlices # gradients with all the indices, but with grad.values sliced accordingly. # This is like the Tensor case, except shape(grad.values)[0] is not equal # to shape(sizes[i])[0], since only a subset of the dim-0 values are # stored. mask, begin = _CreateDenseMaskAndBegin(sizes, non_neg_concat_dim) for size in sizes: new_values = array_ops.slice( grad.values, begin, array_ops.concat( [[-1], array_ops.slice(size, [1], [-1])], 0)) out_grads.append( ops.IndexedSlices(new_values, grad.indices, size)) # Lint complains begin = begin + ... begin = math_ops.add(begin, size * mask) else: # IndexedSlices, concat_dim == 0. Each input gets IndexedSlices gradients # only for the relevant indices. start = constant_op.constant(0, dtype=grad.indices.dtype) for size in sizes: size_concat_dim = array_ops.gather(size, non_neg_concat_dim) if size_concat_dim.dtype != grad.indices.dtype: size_concat_dim = math_ops.cast(size_concat_dim, dtype=grad.indices.dtype) end = start + size_concat_dim # Compute the 1-D Tensor of indices relevant for this input. indices_to_select = array_ops.squeeze(array_ops.where( math_ops.logical_and(grad.indices >= start, grad.indices < end)), squeeze_dims=[1]) new_indices = array_ops.gather(grad.indices, indices_to_select) - start new_values = array_ops.gather(grad.values, indices_to_select) out_grads.append( ops.IndexedSlices(new_values, new_indices, size)) start = end else: raise TypeError("Expected Tensor or IndexedSlices, got %s" % type(grad)) return (out_grads + [None] if end_value_index <= dim_index else [None] + out_grads)
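# A NumPy sketch of the dense case above: the gradient of concat is just a
# split of the incoming gradient along the concat axis, one piece per input,
# sized like that input. The shapes are illustrative.
import numpy as np

a = np.ones((2, 3))
b = np.ones((2, 5))
grad = np.arange(16.0).reshape(2, 8)   # gradient w.r.t. concat([a, b], 1)
sizes = [a.shape[1], b.shape[1]]
grad_a, grad_b = np.split(grad, np.cumsum(sizes)[:-1], axis=1)
assert grad_a.shape == a.shape and grad_b.shape == b.shape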
def loop_fn(i): y = array_ops.gather(output, i) return gradient_ops.gradients(y, flat_inputs)
def _inverse(self, y): return array_ops.gather(y, array_ops.invert_permutation(self.permutation), axis=-1)
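# A NumPy sketch of the inverse-permutation identity used above: gathering by
# a permutation is undone by gathering by its inverse (argsort of a
# permutation is its inverse, like invert_permutation).
import numpy as np

permutation = np.array([2, 0, 1])
inverse = np.argsort(permutation)      # invert_permutation
y = np.array([10, 20, 30])
assert y[permutation][inverse].tolist() == y.tolist()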
def lu_reconstruct(lower_upper, perm, validate_args=False, name=None):
  """Reconstructs one or more matrices from their LU decomposition(s).

  Args:
    lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `lower_upper = L + U - eye`.
    perm: `p` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `perm = argmax(P)`.
    validate_args: Python `bool` indicating whether arguments should be
      checked for correctness.
      Default value: `False` (i.e., don't validate arguments).
    name: Python `str` name given to ops managed by this object.
      Default value: `None` (i.e., 'lu_reconstruct').

  Returns:
    x: The original input to `tf.linalg.lu`, i.e., `x` as in,
      `lu_reconstruct(*tf.linalg.lu(x))`.

  #### Examples

  ```python
  import numpy as np
  import tensorflow as tf
  import tensorflow_probability as tfp

  x = [[[3., 4], [1, 2]],
       [[7., 8], [3, 4]]]
  x_reconstructed = tf.linalg.lu_reconstruct(*tf.linalg.lu(x))
  tf.assert_near(x, x_reconstructed)
  # ==> True
  ```
  """
  with ops.name_scope(name or 'lu_reconstruct'):
    lower_upper = ops.convert_to_tensor(
        lower_upper, dtype_hint=dtypes.float32, name='lower_upper')
    perm = ops.convert_to_tensor(perm, dtype_hint=dtypes.int32, name='perm')

    assertions = lu_reconstruct_assertions(lower_upper, perm, validate_args)
    if assertions:
      with ops.control_dependencies(assertions):
        lower_upper = array_ops.identity(lower_upper)
        perm = array_ops.identity(perm)

    shape = array_ops.shape(lower_upper)

    lower = set_diag(
        band_part(lower_upper, num_lower=-1, num_upper=0),
        array_ops.ones(shape[:-1], dtype=lower_upper.dtype))
    upper = band_part(lower_upper, num_lower=0, num_upper=-1)
    x = math_ops.matmul(lower, upper)

    if (lower_upper.shape is None or lower_upper.shape.rank is None or
        lower_upper.shape.rank != 2):
      # We either don't know the batch rank or there are >0 batch dims.
      batch_size = math_ops.reduce_prod(shape[:-2])
      d = shape[-1]
      x = array_ops.reshape(x, [batch_size, d, d])
      perm = array_ops.reshape(perm, [batch_size, d])
      perm = map_fn.map_fn(array_ops.invert_permutation, perm)
      batch_indices = array_ops.broadcast_to(
          math_ops.range(batch_size)[:, array_ops.newaxis],
          [batch_size, d])
      x = array_ops.gather_nd(
          x, array_ops.stack([batch_indices, perm], axis=-1))
      x = array_ops.reshape(x, shape)
    else:
      x = array_ops.gather(x, array_ops.invert_permutation(perm))

    x.set_shape(lower_upper.shape)
    return x
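# A NumPy/SciPy check of the reconstruction identity above. Note that
# scipy.linalg.lu returns an explicit permutation matrix P rather than the
# permutation vector used by tf.linalg.lu, so the reconstruction reads
# x == P @ L @ U.
import numpy as np
from scipy.linalg import lu

x = np.array([[3., 4.], [1., 2.]])
p, l, u = lu(x)
assert np.allclose(p @ l @ u, x)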