def _apply_sparse(self, grad, var):
  lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
  beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
  beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
  epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)

  # the following equations given in [1]
  # m_t = beta1 * m + (1 - beta1) * g_t
  m = self.get_slot(var, "m")
  m_t = state_ops.scatter_update(m, grad.indices,
                                 beta1_t * array_ops.gather(m, grad.indices) +
                                 (1. - beta1_t) * grad.values,
                                 use_locking=self._use_locking)
  m_t_slice = tf.gather(m_t, grad.indices)

  # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
  v = self.get_slot(var, "v")
  v_t = state_ops.scatter_update(v, grad.indices,
                                 beta2_t * array_ops.gather(v, grad.indices) +
                                 (1. - beta2_t) * tf.square(grad.values),
                                 use_locking=self._use_locking)
  v_prime = self.get_slot(var, "v_prime")
  v_t_slice = tf.gather(v_t, grad.indices)
  v_prime_slice = tf.gather(v_prime, grad.indices)
  v_t_prime = state_ops.scatter_update(v_prime, grad.indices,
                                       tf.maximum(v_prime_slice, v_t_slice))
  v_t_prime_slice = array_ops.gather(v_t_prime, grad.indices)

  var_update = state_ops.scatter_sub(
      var, grad.indices,
      lr_t * m_t_slice / (math_ops.sqrt(v_t_prime_slice) + epsilon_t),
      use_locking=self._use_locking)
  return control_flow_ops.group(*[var_update, m_t, v_t, v_t_prime])
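# Note: every snippet in this collection builds on the same primitive:
# state_ops.scatter_update(ref, indices, updates) overwrites only the rows of
# `ref` named by `indices` and leaves all other rows untouched, which is why
# the sparse optimizer kernels here pair it with array_ops.gather on
# grad.indices to read back just the touched slices. A minimal standalone
# sketch of that semantics, assuming the TF 1.x-style graph APIs these
# excerpts are written against (via tf.compat.v1):
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

v = tf.Variable([[1., 1.], [2., 2.], [3., 3.]])
# Overwrite rows 0 and 2; row 1 is left untouched.
update = tf.scatter_update(v, [0, 2], [[0., 0.], [9., 9.]])
with tf.Session() as sess:
  sess.run(v.initializer)
  print(sess.run(update))  # [[0. 0.] [2. 2.] [9. 9.]]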
def _apply_sparse(self, grad, var):
  beta1_power, beta2_power = self._get_beta_accumulators()
  beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
  beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
  lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
  beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
  beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
  epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
  lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power))

  # m := beta1 * m + (1 - beta1) * g_t
  m = self.get_slot(var, "m")
  m_t = state_ops.scatter_update(m, grad.indices,
                                 beta1_t * array_ops.gather(m, grad.indices) +
                                 (1 - beta1_t) * grad.values,
                                 use_locking=self._use_locking)

  # v := beta2 * v + (1 - beta2) * (g_t * g_t)
  v = self.get_slot(var, "v")
  v_t = state_ops.scatter_update(v, grad.indices,
                                 beta2_t * array_ops.gather(v, grad.indices) +
                                 (1 - beta2_t) * math_ops.square(grad.values),
                                 use_locking=self._use_locking)

  # variable -= learning_rate * m_t / (epsilon_t + sqrt(v_t))
  m_t_slice = array_ops.gather(m_t, grad.indices)
  v_t_slice = array_ops.gather(v_t, grad.indices)
  denominator_slice = math_ops.sqrt(v_t_slice) + epsilon_t
  var_update = state_ops.scatter_sub(var, grad.indices,
                                     lr * m_t_slice / denominator_slice,
                                     use_locking=self._use_locking)
  return control_flow_ops.group(var_update, m_t, v_t)
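# A NumPy re-derivation of one row of the step above (a sketch with scalar
# hyperparameters; the variable names are illustrative only). It shows the
# folded bias correction lr_t * sqrt(1 - beta2^t) / (1 - beta1^t): instead of
# debiasing m and v separately, the correction is absorbed into the learning
# rate, matching TensorFlow's Adam convention where epsilon is added to the
# already-corrected denominator.
import numpy as np

lr, beta1, beta2, eps, t = 0.001, 0.9, 0.999, 1e-8, 10
m, v, g = 0.1, 0.01, 0.5  # state and gradient for one updated row

m = beta1 * m + (1 - beta1) * g
v = beta2 * v + (1 - beta2) * g * g
lr_fold = lr * np.sqrt(1 - beta2 ** t) / (1 - beta1 ** t)
step = lr_fold * m / (np.sqrt(v) + eps)  # amount subtracted from the row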
def testScatterUpdateInvalidArgs(self):
  v = resource_variable_ops.ResourceVariable([0, 1, 2, 3], name="update")
  # The exact error and message differ between graph construction (where the
  # error is realized during shape inference at graph construction time) and
  # eager execution (where the error is realized during kernel execution).
  with self.assertRaisesRegexp(Exception, r"shape.*2.*3"):
    state_ops.scatter_update(v, [0, 1], [0, 1, 2])
def _SparseUpdate(variable, gradients, accum, linear, base_lr, lr_power, l1,
                  l2):
  """Sparse update of "variable", "accum" and "linear" from sparse "gradients".

  See the description in _Update.

  Args:
    variable: A Variable.
    gradients: An ops.IndexedSlices containing the sparse gradients.
    accum: A Variable containing the sum of the squares of gradients.
    linear: A Variable containing approximation info.
    base_lr: A constant representing the base learning rate.
    lr_power: A constant used to adjust the learning rate.
    l1: A constant representing the l1 regularization strength.
    l2: A constant representing the l2 regularization strength.

  Returns:
    A group op including three ScatterUpdate ops:
      1. ScatterUpdate for "accum"
      2. ScatterUpdate for "linear"
      3. ScatterUpdate for "variable"
  """
  assert isinstance(gradients, ops.IndexedSlices)
  with ops.name_scope("sparse_update_" + variable.op.name) as scope:
    dtype = variable.dtype.base_dtype
    base_lr = ops.convert_to_tensor(base_lr, dtype=dtype)
    lr_power = ops.convert_to_tensor(lr_power, dtype=dtype)
    l1 = ops.convert_to_tensor(l1, dtype=dtype)
    l2 = ops.convert_to_tensor(l2, dtype=dtype)

    # Compute the new value for the accumulator.
    previous_accum = array_ops.gather(accum, gradients.indices)
    sqr_grad = gradients.values * gradients.values
    accum_updated = sqr_grad + previous_accum

    # Compute the new linear.
    neg_lr_power = math_ops.neg(lr_power)
    sigma = math_ops.pow(accum_updated, neg_lr_power) - math_ops.pow(
        previous_accum, neg_lr_power)
    sigma /= base_lr
    variable_slice = array_ops.gather(variable, gradients.indices)
    proximal_adjust = sigma * variable_slice
    linear_slice = array_ops.gather(linear, gradients.indices)
    linear_updated = linear_slice + gradients.values - proximal_adjust

    # Compute the new "variable".
    variable_updated = _Compute(accum_updated, linear_updated, base_lr,
                                lr_power, l1, l2)

    with ops.control_dependencies([sigma]):
      accum_update_op = state_ops.scatter_update(accum, gradients.indices,
                                                 accum_updated)
    linear_update_op = state_ops.scatter_update(linear, gradients.indices,
                                                linear_updated)
    variable_update_op = state_ops.scatter_update(variable, gradients.indices,
                                                  variable_updated)
    group_op = control_flow_ops.group(linear_update_op, accum_update_op,
                                      variable_update_op, name=scope)
    return group_op
def testScatterBool(self):
  with context.eager_mode():
    ref = resource_variable_ops.ResourceVariable([False, True, False],
                                                 trainable=False)
    indices = math_ops.range(3)
    updates = constant_op.constant([True, True, True])
    state_ops.scatter_update(ref, indices, updates)
    self.assertAllEqual(ref.read_value(), [True, True, True])
def testBooleanScatterUpdate(self):
  if not test.is_gpu_available():
    with self.session(use_gpu=False) as session:
      var = variables.Variable([True, False])
      update0 = state_ops.scatter_update(var, 1, True)
      update1 = state_ops.scatter_update(
          var, constant_op.constant(0, dtype=dtypes.int64), False)
      var.initializer.run()
      session.run([update0, update1])
      self.assertAllEqual([False, True], self.evaluate(var))
def _apply_sparse(self, grad, var):
  t = math_ops.cast(self._iterations, var.dtype.base_dtype) + 1.
  m_schedule = math_ops.cast(self._m_schedule, var.dtype.base_dtype)
  lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
  beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
  beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
  epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
  schedule_decay_t = math_ops.cast(self._schedule_decay_t,
                                   var.dtype.base_dtype)

  # Due to the recommendations in [2], i.e. warming momentum schedule
  momentum_cache_power = self._get_momentum_cache(schedule_decay_t, t)
  momentum_cache_t = beta1_t * (1. - 0.5 * momentum_cache_power)
  momentum_cache_t_1 = beta1_t * (
      1. - 0.5 * momentum_cache_power * self._momentum_cache_const)
  m_schedule_new = m_schedule * momentum_cache_t
  m_schedule_next = m_schedule_new * momentum_cache_t_1

  # the following equations given in [1]
  # m_t = beta1 * m + (1 - beta1) * g_t
  m = self.get_slot(var, "m")
  m_t = state_ops.scatter_update(
      m, grad.indices,
      beta1_t * array_ops.gather(m, grad.indices) +
      (1. - beta1_t) * grad.values,
      use_locking=self._use_locking)
  g_prime_slice = grad.values / (1. - m_schedule_new)
  m_t_prime_slice = array_ops.gather(m_t, grad.indices) / (
      1. - m_schedule_next)
  m_t_bar_slice = ((1. - momentum_cache_t) * g_prime_slice +
                   momentum_cache_t_1 * m_t_prime_slice)

  # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
  v = self.get_slot(var, "v")
  v_t = state_ops.scatter_update(
      v, grad.indices,
      beta2_t * array_ops.gather(v, grad.indices) +
      (1. - beta2_t) * tf.square(grad.values),
      use_locking=self._use_locking)
  v_t_prime_slice = array_ops.gather(v_t, grad.indices) / (
      1. - tf.pow(beta2_t, t))

  var_update = state_ops.scatter_sub(
      var, grad.indices,
      lr_t * m_t_bar_slice / (math_ops.sqrt(v_t_prime_slice) + epsilon_t),
      use_locking=self._use_locking)
  return control_flow_ops.group(*[var_update, m_t, v_t])
def _apply_sparse(self, grad, var):
  return self._apply_sparse_shared(
      grad.values,
      var,
      grad.indices,
      lambda x, i, v: state_ops.scatter_add(  # pylint: disable=g-long-lambda
          x, i, v, use_locking=self._use_locking),
      lambda x, i, v: state_ops.scatter_update(  # pylint: disable=g-long-lambda
          x, i, v, use_locking=self._use_locking))
def _sparse_moving_average(self, x_tm1, idxs, b_t_, name, beta=.9):
  """Creates a moving average for a sparse variable.

  Inputs:
    x_tm1: the associated parameter (e.g. a weight matrix)
    idxs: the tensor representing the indices used
    b_t_: the value to accumulate (e.g. slices of the gradient)
    name: a string to use to retrieve it later (e.g. 'm')
    beta: the decay factor (defaults to .9)
  Outputs:
    a_t: the average after moving (same shape as x_tm1, not b_t_)
    t: the internal timestep (used to correct initialization bias)
  """
  a_tm1 = self._zeros_slot(x_tm1, '%s' % name, self._name)
  a_tm1_ = array_ops.gather(a_tm1, idxs)
  tm1 = self._zeros_idx_slot(x_tm1, '%s/tm1' % name, self._name)
  tm1_ = array_ops.gather(tm1, idxs)
  t = state_ops.scatter_add(tm1, idxs, tm1_ * 0 + 1,
                            use_locking=self._use_locking)
  t_ = array_ops.gather(t, idxs)
  if beta < 1:
    beta_t = ops.convert_to_tensor(beta, name='%s/decay' % name)
    beta_t_ = beta_t * (1 - beta_t ** tm1_) / (1 - beta_t ** t_)
  else:
    beta_t_ = tm1_ / t_
  a_t = state_ops.scatter_update(a_tm1, idxs, beta_t_ * a_tm1_,
                                 use_locking=self._use_locking)
  # Weighting the new value by (1 - beta_t_) keeps the stored average
  # bias-corrected, and also covers the beta >= 1 running-mean branch.
  a_t = state_ops.scatter_add(a_t, idxs, (1 - beta_t_) * b_t_,
                              use_locking=self._use_locking)
  return a_t, t
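# A quick NumPy check of the debiased recurrence above (beta < 1 branch; a
# sketch): with beta_eff = beta * (1 - beta**(t-1)) / (1 - beta**t) and input
# weight (1 - beta_eff) = (1 - beta) / (1 - beta**t), the stored value equals
# the bias-corrected EMA, ema_t / (1 - beta**t), at every step.
import numpy as np

beta, a, ema = 0.9, 0.0, 0.0
for t, b in enumerate([1.0, 2.0, 3.0], start=1):
  beta_eff = beta * (1 - beta ** (t - 1)) / (1 - beta ** t)
  a = beta_eff * a + (1 - beta_eff) * b
  ema = beta * ema + (1 - beta) * b
  assert np.isclose(a, ema / (1 - beta ** t))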
def _apply_sparse(self, grad, var):
  if len(grad.indices.get_shape()) == 1:
    grad_indices = grad.indices
    grad_values = grad.values
  else:
    grad_indices = array_ops.reshape(grad.indices, [-1])
    grad_values = array_ops.reshape(
        grad.values, [-1, grad.values.get_shape()[-1].value])
  gidxs, metagidxs = array_ops.unique(grad_indices)
  sizegidxs = array_ops.size(gidxs)
  gvals = math_ops.unsorted_segment_sum(grad_values, metagidxs, sizegidxs)
  # m_t = mu * m + (1 - mu) * g_t
  m = self.get_slot(var, "m")
  m_scaled_g_values = gvals * (1 - self._mu_t)
  m_t = state_ops.scatter_update(m, gidxs,
                                 array_ops.gather(m, gidxs) * self._mu_t,
                                 use_locking=self._use_locking)
  m_t = state_ops.scatter_add(m_t, gidxs, m_scaled_g_values,
                              use_locking=self._use_locking)
  m_t_ = array_ops.gather(m_t, gidxs) / (1 - self._mu2_t * self._mu_power)
  # m_bar = mu * m_t + (1 - mu) * g_t
  m_bar = self._mu2_t * m_t_ + m_scaled_g_values / (1 - self._mu_power)
  var_update = state_ops.scatter_sub(var, gidxs, self._lr_t * m_bar,
                                     use_locking=self._use_locking)
  return control_flow_ops.group(*[var_update, m_t])
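# The unique/unsorted_segment_sum pair above collapses repeated gradient
# indices before the scatter ops run; with duplicates left in, the
# overwriting scatter_update would keep only one of the colliding rows.
# A NumPy sketch of the same deduplication:
import numpy as np

grad_indices = np.array([2, 0, 2])
grad_values = np.array([[1., 1.], [3., 3.], [5., 5.]])
uniq, inverse = np.unique(grad_indices, return_inverse=True)
summed = np.zeros((uniq.size, grad_values.shape[1]))
np.add.at(summed, inverse, grad_values)  # row for index 2 becomes [6., 6.]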
def insert(self, ids, scores):
  """Insert the ids and scores into the TopN."""
  with ops.control_dependencies(self.last_ops):
    scatter_op = state_ops.scatter_update(self.id_to_score, ids, scores)
    larger_scores = math_ops.greater(scores, self.sl_scores[0])

    def shortlist_insert():
      larger_ids = array_ops.boolean_mask(math_ops.to_int64(ids),
                                          larger_scores)
      larger_score_values = array_ops.boolean_mask(scores, larger_scores)
      shortlist_ids, new_ids, new_scores = tensor_forest_ops.top_n_insert(
          self.sl_ids, self.sl_scores, larger_ids, larger_score_values)
      u1 = state_ops.scatter_update(self.sl_ids, shortlist_ids, new_ids)
      u2 = state_ops.scatter_update(self.sl_scores, shortlist_ids, new_scores)
      return control_flow_ops.group(u1, u2)

    # We only need to insert into the shortlist if there are any
    # scores larger than the threshold.
    cond_op = control_flow_ops.cond(
        math_ops.reduce_any(larger_scores), shortlist_insert,
        control_flow_ops.no_op)
    with ops.control_dependencies([cond_op]):
      self.last_ops = [scatter_op, cond_op]
def _apply_sparse(self, grad, var):
  lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
  beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
  beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
  epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
  clip_multiplier_t = math_ops.cast(self.clip_multiplier_t,
                                    var.dtype.base_dtype)
  clip_epsilon_t = math_ops.cast(self.clip_epsilon_t, var.dtype.base_dtype)

  v = self.get_slot(var, "v")
  v_slice = array_ops.gather(v, grad.indices)

  # Clip the gradient so that each value exceeds its previous running maximum
  # by no more than clip_multiplier.
  clipped_values = grad.values
  if self.clip_gradients:
    clip_val = v_slice * clip_multiplier_t + clip_epsilon_t
    clipped_values = clip_ops.clip_by_value(grad.values, -clip_val, clip_val)

  # m := beta1 * m + (1 - beta1) * g_t
  m = self.get_slot(var, "m")
  m_t_values = beta1_t * array_ops.gather(m, grad.indices) + (
      1 - beta1_t) * clipped_values
  m_t = state_ops.scatter_update(m, grad.indices, m_t_values,
                                 use_locking=self._use_locking)

  # v := max(beta2 * v, abs(grad))
  v_t_values = math_ops.maximum(beta2_t * v_slice,
                                math_ops.abs(clipped_values))
  v_t = state_ops.scatter_update(v, grad.indices, v_t_values,
                                 use_locking=self._use_locking)

  # variable -= learning_rate * m_t / (epsilon_t + v_t)
  # We do not use the bias-correction term for the first moment; it does not
  # give an observable benefit.
  var_update = state_ops.scatter_sub(
      var, grad.indices,
      lr_t * m_t_values / (v_t_values + epsilon_t),
      use_locking=self._use_locking)
  return control_flow_ops.group(var_update, v_t, m_t)
def create_axis_ops(sp_input, num_items, update_fn, axis_name):
  """Creates book-keeping and training ops for a given axis.

  Args:
    sp_input: A SparseTensor corresponding to the row or column batch.
    num_items: An integer, the total number of items of this axis.
    update_fn: A function that takes one argument (`sp_input`), and that
      returns a tuple of
      * new_factors: A float Tensor of the factor values after update.
      * update_op: a TensorFlow op which updates the factors.
      * loss: A float Tensor, the unregularized loss.
      * reg_loss: A float Tensor, the regularization loss.
      * sum_weights: A float Tensor, the sum of factor weights.
    axis_name: A string that specifies the name of the axis.

  Returns:
    A tuple consisting of:
    * reset_processed_items_op: A TensorFlow op, to be run before the
      beginning of any sweep. It marks all items as not-processed.
    * axis_train_op: A TensorFlow op, to be run during this axis' sweeps.
  """
  processed_items_init = array_ops.fill(dims=[num_items], value=False)
  with ops.colocate_with(processed_items_init):
    processed_items = variable_scope.variable(
        processed_items_init,
        collections=[ops.GraphKeys.GLOBAL_VARIABLES],
        trainable=False,
        name="processed_" + axis_name)
  reset_processed_items_op = state_ops.assign(
      processed_items, processed_items_init,
      name="reset_processed_" + axis_name)
  _, update_op, loss, reg, sum_weights = update_fn(sp_input)
  input_indices = sp_input.indices[:, 0]
  with ops.control_dependencies([
      update_op,
      state_ops.assign(loss_var, loss + reg),
      state_ops.assign(rwse_var, math_ops.sqrt(loss / sum_weights))]):
    with ops.colocate_with(processed_items):
      update_processed_items = state_ops.scatter_update(
          processed_items,
          input_indices,
          array_ops.ones_like(input_indices, dtype=dtypes.bool),
          name="update_processed_{}_indices".format(axis_name))
    with ops.control_dependencies([update_processed_items]):
      is_sweep_done = math_ops.reduce_all(processed_items)
      axis_train_op = control_flow_ops.group(
          global_step_incr_op,
          state_ops.assign(is_sweep_done_var, is_sweep_done),
          state_ops.assign_add(
              completed_sweeps_var,
              math_ops.cast(is_sweep_done, dtypes.int32)),
          name="{}_sweep_train_op".format(axis_name))
  return reset_processed_items_op, axis_train_op
def _apply_sparse(self, grad, var):
  beta1_power = math_ops.cast(self._beta1_power, var.dtype.base_dtype)
  beta2_power = math_ops.cast(self._beta2_power, var.dtype.base_dtype)
  lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
  beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
  beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
  epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
  tp = math_ops.sqrt(1 - beta2_power) / (1 - beta1_power)
  tp1 = math_ops.sqrt(1 - beta2_power * beta2_t) / (
      1 - beta1_power * beta1_t)

  # m_t = beta1 * m + (1 - beta1) * g_t
  m = self.get_slot(var, "m")
  m_t = state_ops.scatter_update(
      m, grad.indices,
      beta1_t * array_ops.gather(m, grad.indices) +
      (tp1 / tp) * (1 - beta1_t) * grad.values,
      use_locking=self._use_locking)

  # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
  v = self.get_slot(var, "v")
  v_t = state_ops.scatter_update(
      v, grad.indices,
      beta2_t * array_ops.gather(v, grad.indices) +
      (1 - beta2_t) * math_ops.square(grad.values),
      use_locking=self._use_locking)

  # variable -= learning_rate * m_t / (epsilon_t + sqrt(v_t))
  m_t_slice = array_ops.gather(m_t, grad.indices)
  v_t_slice = array_ops.gather(v_t, grad.indices)
  denominator_slice = math_ops.sqrt(v_t_slice) + epsilon_t
  var_update = state_ops.scatter_sub(
      var, grad.indices,
      lr_t * tp * m_t_slice / denominator_slice,
      use_locking=self._use_locking)
  return control_flow_ops.group(var_update, m_t, v_t)
def scatter_update(cls, factor, indices, values, sharding_func, name=None):
  """Helper function for doing sharded scatter update."""
  assert isinstance(factor, list)
  if len(factor) == 1:
    with ops.colocate_with(factor[0]):
      # TODO(agarwal): assign instead of scatter update for full batch update.
      return state_ops.scatter_update(factor[0], indices, values,
                                      name=name).op
  else:
    num_shards = len(factor)
    assignments, new_ids = sharding_func(indices)
    assert assignments is not None
    assignments = math_ops.cast(assignments, dtypes.int32)
    sharded_ids = data_flow_ops.dynamic_partition(new_ids, assignments,
                                                  num_shards)
    sharded_values = data_flow_ops.dynamic_partition(values, assignments,
                                                     num_shards)
    updates = []
    for i in xrange(num_shards):
      updates.append(state_ops.scatter_update(factor[i], sharded_ids[i],
                                              sharded_values[i]))
    return control_flow_ops.group(*updates, name=name)
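# The helper above leaves sharding_func abstract. A hypothetical mod-sharding
# implementation (a sketch, not from the original source): `assignments` says
# which shard variable each global index lives on, and `new_ids` is the local
# row within that shard, matching what dynamic_partition expects.
def make_mod_sharding_func(num_shards):
  def sharding_func(indices):
    assignments = indices % num_shards  # shard that owns each global index
    new_ids = indices // num_shards     # local row inside that shard
    return assignments, new_ids
  return sharding_func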
def _resource_apply_dense(self, grad, var, state):
  self._variables.append(var)
  dim = self.shape_dict[var.name]
  start_index = self.index_dict[var.name]
  end_index = start_index + dim

  # Update flat_gradient at the index associated with the variable.
  flat_grad = self._get_flat_grad(state)
  new_flat_grad = array_ops.reshape(grad, [-1])
  flat_grad_updated = state_ops.scatter_update(
      flat_grad, math_ops.range(start_index, end_index), new_flat_grad)

  return flat_grad_updated
def remove(self, ids):
  """Remove the ids (and their associated scores) from the TopN."""
  with ops.control_dependencies(self.last_ops):
    scatter_op = state_ops.scatter_update(
        self.id_to_score, ids,
        array_ops.ones_like(ids, dtype=dtypes.float32) * dtypes.float32.min)
    # We assume that removed ids are almost always in the shortlist,
    # so it makes no sense to hide the Op behind a tf.cond
    shortlist_ids_to_remove, new_length = tensor_forest_ops.top_n_remove(
        self.sl_ids, ids)
    u1 = state_ops.scatter_update(
        self.sl_ids,
        array_ops.concat([[0], shortlist_ids_to_remove], 0),
        array_ops.concat(
            [new_length, array_ops.ones_like(shortlist_ids_to_remove) * -1],
            0))
    u2 = state_ops.scatter_update(
        self.sl_scores, shortlist_ids_to_remove,
        dtypes.float32.min * array_ops.ones_like(
            shortlist_ids_to_remove, dtype=dtypes.float32))
    self.last_ops = [scatter_op, u1, u2]
def _apply_sparse_shared(self, grad, var, indices, scatter_add, state):
  beta1_power, beta2_power = self._get_beta_accumulators(state)
  beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
  beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
  lr_t = state.get_hyper("learning_rate", var.dtype.base_dtype)
  beta1_t = state.get_hyper("beta1", var.dtype.base_dtype)
  beta2_t = state.get_hyper("beta2", var.dtype.base_dtype)
  epsilon_t = state.get_hyper("epsilon", var.dtype.base_dtype)
  lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power))

  # lazy Adam: m := beta1 * m + (1 - beta1) * g_t, on the touched rows only.
  m = state.get_slot(var, "m")
  m_scaled_g_values = grad * (1 - beta1_t)
  m_t = state_ops.scatter_update(m, indices,
                                 beta1_t * array_ops.gather(m, indices) +
                                 m_scaled_g_values,
                                 use_locking=self._use_locking)
  # Nesterov-adjusted first moment for the updated rows; m_bar is already
  # restricted to `indices`.
  m_bar = m_scaled_g_values + beta1_t * array_ops.gather(m_t, indices)

  # lazy Adam: v := beta2 * v + (1 - beta2) * (g_t * g_t), touched rows only.
  v = state.get_slot(var, "v")
  v_t = state_ops.scatter_update(v, indices,
                                 beta2_t * array_ops.gather(v, indices) +
                                 (1 - beta2_t) * math_ops.square(grad),
                                 use_locking=self._use_locking)

  v_t_slice = array_ops.gather(v_t, indices)
  denominator_slice = math_ops.sqrt(v_t_slice) + epsilon_t
  var_update = scatter_add(var, indices, -lr * m_bar / denominator_slice)
  return control_flow_ops.group(*[var_update, m_bar, v_t])
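# The "lazy" part of the update above: only rows named by `indices` have
# their m/v statistics decayed and rewritten; every other row keeps stale
# statistics, unlike dense Adam, which decays all rows on every step. A NumPy
# sketch of that row-local behavior (scalar hyperparameters assumed):
import numpy as np

beta1 = 0.9
m = np.ones((4, 2))
indices = np.array([1, 3])
grad = np.array([[0.5, 0.5], [0.25, 0.25]])
m[indices] = beta1 * m[indices] + (1 - beta1) * grad
# rows 0 and 2 are untouched; rows 1 and 3 were decayed and updated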
def _apply_sparse(self, grad, var):
  lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
  alpha_t = math_ops.cast(self._alpha_t, var.dtype.base_dtype)
  beta_t = math_ops.cast(self._beta_t, var.dtype.base_dtype)
  eps = 1e-7  # cap for the moving average

  m = self.get_slot(var, "m")
  m_slice = tf.gather(m, grad.indices)
  m_t = state_ops.scatter_update(m, grad.indices,
                                 tf.maximum(beta_t * m_slice + eps,
                                            tf.abs(grad.values)))
  m_t_slice = tf.gather(m_t, grad.indices)

  # Update 'ref' by subtracting 'value'.
  var_update = state_ops.scatter_sub(
      var, grad.indices,
      lr_t * grad.values * tf.exp(
          tf.log(alpha_t) * tf.sign(grad.values) * tf.sign(m_t_slice)))

  # Create an op that groups multiple operations.
  # When this op finishes, all ops in input have finished.
  return control_flow_ops.group(*[var_update, m_t])
def testResourceVariableScatterGather(self):
  c = constant_op.constant([1.0, 2.0], dtype=dtypes.float32)
  l = list_ops.tensor_list_from_tensor(c, element_shape=[])
  v = vs.get_variable("var", initializer=[l] * 10, use_resource=True)
  v_r_0_stacked = list_ops.tensor_list_stack(v[0], dtypes.float32)
  self.evaluate(v.initializer)
  self.assertAllEqual([1.0, 2.0], self.evaluate(v_r_0_stacked))
  v_r_sparse_stacked = list_ops.tensor_list_stack(
      v.sparse_read(0), dtypes.float32)
  self.assertAllEqual([1.0, 2.0], self.evaluate(v_r_sparse_stacked))
  l_new_0 = list_ops.tensor_list_from_tensor([3.0, 4.0], element_shape=[])
  l_new_1 = list_ops.tensor_list_from_tensor([5.0, 6.0], element_shape=[])
  updated_v = state_ops.scatter_update(v, [3, 5], [l_new_0, l_new_1])
  updated_v_elems = array_ops.unstack(updated_v)
  updated_v_stacked = [
      list_ops.tensor_list_stack(el, dtypes.float32)
      for el in updated_v_elems
  ]
  expected = ([[1.0, 2.0]] * 3 + [[3.0, 4.0], [1.0, 2.0], [5.0, 6.0]] +
              [[1.0, 2.0]] * 4)
  self.assertAllEqual(self.evaluate(updated_v_stacked), expected)
def ScatterUpdateGrads(op, grad):
  var, indices, updates = op.inputs
  updates_grad = array_ops.gather(grad, indices)

  # dynamic stitch approach (this seems to be a bit slower)
  # grad_range = math_ops.range(grad.get_shape()[0].value)
  # var_grad = data_flow_ops.dynamic_stitch(
  #     [grad_range, indices],
  #     [grad, array_ops.zeros(updates.get_shape())])

  if isinstance(grad, ops.IndexedSlices):
    # note: we could use this approach for everything, but the
    # temporary variable approach seems to be slightly faster (but we
    # can't use that on indexedslices)
    var_grad = grad - array_ops.scatter_nd(
        array_ops.expand_dims(indices, 1), updates_grad, var.get_shape())
  else:
    shape = tuple(grad.get_shape().as_list())
    dtype = grad.dtype.base_dtype
    with variable_scope.variable_scope(
        "gradient_vars", reuse=variable_scope.AUTO_REUSE):
      var_grad = variable_scope.get_variable(
          "tmp" + "_%s" * (len(grad.get_shape()) + 1) % (
              shape + (dtype.name,)),
          shape=shape,
          dtype=dtype,
          trainable=False,
          collections=["gradient_vars"])
    var_grad = state_ops.assign(var_grad, grad)
    var_grad = state_ops.scatter_update(var_grad, indices,
                                        array_ops.zeros_like(updates))
    # we need to force a copy so that any future assignments to the
    # variable will not affect the value we return here
    # TODO: check if this is still necessary in TensorFlow 2.0
    var_grad = var_grad + 0

  return var_grad, None, updates_grad
def ScatterUpdateGrads(op, grad):
  var, indices, updates = op.inputs
  updates_grad = array_ops.gather(grad, indices)

  # TODO: the dynamic_stitch approach might be faster if there were
  # a GPU dynamic_stitch implementation. should be available in tf 1.4
  # grad_range = math_ops.range(grad.get_shape()[0].value)
  # var_grad = data_flow_ops.dynamic_stitch(
  #     [grad_range, indices],
  #     [grad, array_ops.zeros(updates.get_shape())])

  if isinstance(grad, ops.IndexedSlices):
    # note: we could use this approach for everything, but the
    # temporary variable approach seems to be slightly faster (but we
    # can't use that on indexedslices)
    var_grad = grad - array_ops.scatter_nd(
        array_ops.expand_dims(indices, 1), updates_grad, var.get_shape())
  else:
    # pylint: disable=no-member
    if versions.__version__ < "1.7.0":
      temp_var = gen_state_ops._temporary_variable
      destroy_temp_var = gen_state_ops._destroy_temporary_variable
    else:
      temp_var = gen_state_ops.temporary_variable
      destroy_temp_var = gen_state_ops.destroy_temporary_variable
    var_grad = temp_var(grad.get_shape(), grad.dtype)
    var_name = var_grad.op.name
    var_grad = state_ops.assign(var_grad, grad)
    var_grad = state_ops.scatter_update(var_grad, indices,
                                        array_ops.zeros_like(updates))
    var_grad = destroy_temp_var(var_grad, var_name)

  return var_grad, None, updates_grad
def get_col_update_op():
  with ops.colocate_with(processed_cols):
    return state_ops.scatter_update(
        processed_cols, processed_col_indices,
        array_ops.ones_like(processed_col_indices, dtype=dtypes.bool))
def apply_gradients(self, grads_and_vars, global_step=None, name=None):
  """Apply gradients to variables.

  This contains most of the synchronization implementation and also wraps the
  apply_gradients() from the real optimizer.

  Args:
    grads_and_vars: List of (gradient, variable) pairs as returned by
      compute_gradients().
    global_step: Optional Variable to increment by one after the variables
      have been updated.
    name: Optional name for the returned operation. Default to the name
      passed to the Optimizer constructor.

  Returns:
    train_op: The op to dequeue a token so the replicas can exit this batch
      and start the next one. This is executed by each replica.

  Raises:
    ValueError: If the grads_and_vars is empty.
    ValueError: If global step is not provided, the staleness cannot be
      checked.
  """
  if not grads_and_vars:
    raise ValueError("Must supply at least one variable")
  if global_step is None:
    raise ValueError("Global step is required to check staleness")

  self._global_step = global_step
  train_ops = []
  aggregated_grad = []
  inputs = []
  var_list = []
  for x in grads_and_vars:
    inputs.extend(list(x))

  with ops.device(global_step.device):
    self._local_steps = variables.Variable(
        array_ops.zeros([self._total_num_replicas], dtype=global_step.dtype),
        trainable=False,
        name="local_steps")

  # Check staleness. Note that this has to be ref(), otherwise identity will
  # be accessed and it will be old values.
  local_step = array_ops.slice(self._local_steps.ref(),
                               array_ops.reshape(self._replica_id, (1,)),
                               [1],
                               name="get_local_step")
  local_step = array_ops.reshape(local_step, ())
  is_stale = math_ops.less(local_step, global_step)

  with ops.name_scope(None, self._name, inputs):
    for grad, var in grads_and_vars:
      var_list.append(var)
      with ops.device(var.device):
        if isinstance(grad, ops.Tensor):
          gradient_queue = data_flow_ops.FIFOQueue(
              self._tokens_per_step * 2,
              grad.dtype,
              shapes=var.get_shape(),
              shared_name=var.name)
          self._one_element_queue_list.append((gradient_queue, var.device))
          train_ops.append(gradient_queue.enqueue([grad]))

          # Aggregate all gradients
          gradients = gradient_queue.dequeue_many(self._replicas_to_aggregate)
          aggregated_grad.append(math_ops.reduce_sum(gradients, [0]))
        elif grad is None:
          aggregated_grad.append(None)  # pass-through.
        else:
          if not isinstance(grad, ops.IndexedSlices):
            raise ValueError("Unknown grad type!")
          aggregated_grad.append(
              self._aggregate_sparse_grad(grad, var, train_ops))

    aggregated_grads_and_vars = zip(aggregated_grad, var_list)

    # sync_op will be assigned to the same device as the global step.
    with ops.device(global_step.device), ops.name_scope(""):
      update_op = self._opt.apply_gradients(aggregated_grads_and_vars,
                                            global_step)

    # Create token queue.
    with ops.device(global_step.device), ops.name_scope(""):
      sync_token_queue = data_flow_ops.FIFOQueue(
          -1,
          global_step.dtype.base_dtype,
          shapes=(),
          shared_name="sync_token_q")
      self._sync_token_queue = sync_token_queue

      # dummy_queue is passed to the queue runner. Don't use the real queues
      # because the queue runner doesn't automatically reopen it once it
      # closed queues in PS devices.
      dummy_queue = data_flow_ops.FIFOQueue(
          1, types_pb2.DT_INT32, shapes=(), shared_name="dummy_queue")

    # Clear all the gradients queues in case there are stale gradients.
    clear_queue_ops = []
    with ops.control_dependencies([update_op]):
      for queue, dev in self._one_element_queue_list:
        with ops.device(dev):
          stale_grads = queue.dequeue_many(queue.size())
          clear_queue_ops.append(stale_grads)

      for queue, dev in self._sparse_grad_queues_and_devs:
        with ops.device(dev):
          _, stale_indices = queue.dequeue_many(queue.size())
          clear_queue_ops.append(stale_indices)

    with ops.device(global_step.device):
      self._clean_up_op = control_flow_ops.abort(
          error_msg="From sync_replicas")

    # According to the staleness, select between the enqueue op (real_grad)
    # or no-op (no_op_grad). Effectively dropping all the stale gradients.
    no_op_grad = lambda: [control_flow_ops.no_op(name="no_grad_enqueue")]
    real_grad = lambda: [control_flow_ops.group(*train_ops)]
    final_train_ops = control_flow_ops.cond(is_stale, no_op_grad, real_grad)

    with ops.device(global_step.device), ops.name_scope(""):
      # Replicas have to wait until they can get a token from the token queue.
      with ops.control_dependencies([final_train_ops]):
        token = sync_token_queue.dequeue()
      train_op = state_ops.scatter_update(self._local_steps,
                                          self._replica_id, token)

      with ops.control_dependencies(clear_queue_ops):
        # Sync_op needs to insert tokens to the token queue at the end of the
        # step so the replicas can fetch them to start the next step.
        # Note that ref() is used to avoid reading from the identity with the
        # old step.
        tokens = array_ops.fill([self._tokens_per_step], global_step.ref())
        sync_op = sync_token_queue.enqueue_many((tokens,))

      if self._variable_averages is not None:
        with ops.control_dependencies([sync_op]), ops.name_scope(""):
          sync_op = self._variable_averages.apply(self._variables_to_average)

      self._chief_queue_runner = queue_runner.QueueRunner(dummy_queue,
                                                          [sync_op])
    self._gradients_applied = True
    return train_op
def testScatterUpdateCast(self):
  with context.eager_mode():
    v = resource_variable_ops.ResourceVariable([1.0, 2.0], name="update")
    state_ops.scatter_update(v, [1], [3])
    self.assertAllEqual([1.0, 3.0], v.numpy())
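# For reference, a TF 2.x equivalent of the eager test above (a sketch using
# the public tf.Variable.scatter_update method, which takes an IndexedSlices
# rather than separate indices/updates arguments; the int update is cast to
# the variable dtype up front here):
import tensorflow as tf

v = tf.Variable([1.0, 2.0])
v.scatter_update(
    tf.IndexedSlices(values=tf.constant([3.0]), indices=tf.constant([1])))
print(v.numpy())  # [1. 3.]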
def _finish(self, update_ops, name_scope):
  """"""
  caches = [update_op[0] for update_op in update_ops]
  update_ops = [update_op[1:] for update_op in update_ops]
  if self._noise is not None:
    for cache in caches:
      s_t, x_tm1 = cache[:2]
      s_t += random_ops.random_normal(
          x_tm1.initialized_value().get_shape(), stddev=self._noise)
      cache[0] = s_t

  if self._clip > 0:
    S_t = [cache[0] for cache in caches]
    S_t, _ = clip_ops.clip_by_global_norm(S_t, self._clip)
    for cache, s_t in zip(caches, S_t):
      cache[0] = s_t

  new_update_ops = []
  for cache, update_op in zip(caches, update_ops):
    if len(cache) == 3:
      s_t, x_tm1 = cache[:2]
      with ops.name_scope('update_' + x_tm1.op.name), ops.device(
          x_tm1.device):
        x_t = state_ops.assign_sub(x_tm1, s_t, use_locking=self._use_locking)
        cache.append(x_t)
    else:
      s_t_, x_tm1, idxs = cache[:3]
      with ops.name_scope('update_' + x_tm1.op.name), ops.device(
          x_tm1.device):
        x_t = state_ops.scatter_sub(x_tm1, idxs, s_t_,
                                    use_locking=self._use_locking)
        cache.append(x_t)
    new_update_ops.append(control_flow_ops.group(*([x_t] + update_op)))

  with ops.control_dependencies(new_update_ops):
    more_update_ops = []
    if self._save_step:
      for cache in caches:
        if len(cache) == 4:
          s_t, x_tm1 = cache[:2]
          s_tm1 = self.get_slot(x_tm1, 's')
          with ops.name_scope('update_' + x_tm1.op.name), ops.device(
              x_tm1.device):
            new_step_and_grads = []
            s_t = state_ops.assign(s_tm1, -s_t, use_locking=self._use_locking)
        else:
          s_t_, x_tm1, idxs = cache[:3]
          s_tm1 = self.get_slot(x_tm1, 's')
          with ops.name_scope('update_' + x_tm1.op.name), ops.device(
              x_tm1.device):
            s_t = state_ops.scatter_update(s_tm1, idxs, -s_t_,
                                           use_locking=self._use_locking)
        more_update_ops.append(s_t)
    if self._save_grad:
      for cache in caches:
        if len(cache) == 4:
          x_tm1, g_t = cache[1:3]
          g_tm1 = self.get_slot(x_tm1, 'g')
          with ops.name_scope('update_' + x_tm1.op.name), ops.device(
              x_tm1.device):
            new_step_and_grads = []
            g_t = state_ops.assign(g_tm1, g_t, use_locking=self._use_locking)
        else:
          x_tm1, idxs, g_t_ = cache[1:4]
          g_tm1 = self.get_slot(x_tm1, 'g')
          with ops.name_scope('update_' + x_tm1.op.name), ops.device(
              x_tm1.device):
            g_t = state_ops.scatter_update(g_tm1, idxs, g_t_,
                                           use_locking=self._use_locking)
        more_update_ops.append(g_t)

    if self._chi > 0:
      for cache in caches:
        if len(cache) == 4:
          _, x_tm1, _, x_t = cache
          with ops.name_scope('update_' + x_tm1.op.name), ops.device(
              x_tm1.device):
            x_and_t = self._dense_moving_average(x_tm1, x_t, 'x', self._chi)
            more_update_ops.append(control_flow_ops.group(*x_and_t))
        else:
          _, x_tm1, idxs, _, x_t = cache
          with ops.name_scope('update_' + x_tm1.op.name), ops.device(
              x_tm1.device):
            x_t_ = array_ops.gather(x_t, idxs)
            x_and_t = self._sparse_moving_average(x_tm1, idxs, x_t_, 'x',
                                                  self._chi)
            more_update_ops.append(control_flow_ops.group(*x_and_t))

  return control_flow_ops.group(*(new_update_ops + more_update_ops),
                                name=name_scope)
def training_graph(self, input_data, input_labels, random_seed, data_spec,
                   epoch=None):
  """Constructs a TF graph for training a random tree.

  Args:
    input_data: A tensor or SparseTensor or placeholder for input data.
    input_labels: A tensor or placeholder for labels associated with
      input_data.
    random_seed: The random number generator seed to use for this tree.
      0 means use the current time as the seed.
    data_spec: A list of tf.dtype values specifying the original types of
      each column.
    epoch: A tensor or placeholder for the epoch the training data comes
      from.

  Returns:
    The last op in the random tree training graph.
  """
  epoch = [0] if epoch is None else epoch

  sparse_indices = []
  sparse_values = []
  sparse_shape = []
  if isinstance(input_data, ops.SparseTensor):
    sparse_indices = input_data.indices
    sparse_values = input_data.values
    sparse_shape = input_data.shape
    input_data = []

  # Count extremely random stats.
  (node_sums, node_squares, splits_indices, splits_sums, splits_squares,
   totals_indices, totals_sums, totals_squares,
   input_leaves) = self.training_ops.count_extremely_random_stats(
       input_data,
       sparse_indices,
       sparse_values,
       sparse_shape,
       data_spec,
       input_labels,
       self.variables.tree,
       self.variables.tree_thresholds,
       self.variables.node_to_accumulator_map,
       self.variables.candidate_split_features,
       self.variables.candidate_split_thresholds,
       self.variables.start_epoch,
       epoch,
       num_classes=self.params.num_output_columns,
       regression=self.params.regression)
  node_update_ops = []
  node_update_ops.append(
      state_ops.assign_add(self.variables.node_sums, node_sums))

  splits_update_ops = []
  splits_update_ops.append(
      self.training_ops.scatter_add_ndim(self.variables.candidate_split_sums,
                                         splits_indices, splits_sums))
  splits_update_ops.append(
      self.training_ops.scatter_add_ndim(self.variables.accumulator_sums,
                                         totals_indices, totals_sums))

  if self.params.regression:
    node_update_ops.append(
        state_ops.assign_add(self.variables.node_squares, node_squares))
    splits_update_ops.append(
        self.training_ops.scatter_add_ndim(
            self.variables.candidate_split_squares, splits_indices,
            splits_squares))
    splits_update_ops.append(
        self.training_ops.scatter_add_ndim(self.variables.accumulator_squares,
                                           totals_indices, totals_squares))

  # Sample inputs.
  update_indices, feature_updates, threshold_updates = (
      self.training_ops.sample_inputs(
          input_data,
          sparse_indices,
          sparse_values,
          sparse_shape,
          self.variables.node_to_accumulator_map,
          input_leaves,
          self.variables.candidate_split_features,
          self.variables.candidate_split_thresholds,
          split_initializations_per_input=(
              self.params.split_initializations_per_input),
          split_sampling_random_seed=random_seed))
  update_features_op = state_ops.scatter_update(
      self.variables.candidate_split_features, update_indices,
      feature_updates)
  update_thresholds_op = state_ops.scatter_update(
      self.variables.candidate_split_thresholds, update_indices,
      threshold_updates)

  # Calculate finished nodes.
  with ops.control_dependencies(splits_update_ops):
    children = array_ops.squeeze(
        array_ops.slice(self.variables.tree, [0, 0], [-1, 1]),
        squeeze_dims=[1])
    is_leaf = math_ops.equal(constants.LEAF_NODE, children)
    leaves = math_ops.to_int32(
        array_ops.squeeze(array_ops.where(is_leaf), squeeze_dims=[1]))
    finished, stale = self.training_ops.finished_nodes(
        leaves,
        self.variables.node_to_accumulator_map,
        self.variables.candidate_split_sums,
        self.variables.candidate_split_squares,
        self.variables.accumulator_sums,
        self.variables.accumulator_squares,
        self.variables.start_epoch,
        epoch,
        num_split_after_samples=self.params.split_after_samples,
        min_split_samples=self.params.min_split_samples)

  # Update leaf scores.
  non_fertile_leaves = array_ops.boolean_mask(
      leaves,
      math_ops.less(
          array_ops.gather(self.variables.node_to_accumulator_map, leaves),
          0))

  # TODO(gilberth): It should be possible to limit the number of non
  # fertile leaves we calculate scores for, especially since we can only take
  # at most array_ops.shape(finished)[0] of them.
  with ops.control_dependencies(node_update_ops):
    sums = array_ops.gather(self.variables.node_sums, non_fertile_leaves)
    if self.params.regression:
      squares = array_ops.gather(self.variables.node_squares,
                                 non_fertile_leaves)
      non_fertile_leaf_scores = self._variance(sums, squares)
    else:
      non_fertile_leaf_scores = self._weighted_gini(sums)

  # Calculate best splits.
  with ops.control_dependencies(splits_update_ops):
    split_indices = self.training_ops.best_splits(
        finished,
        self.variables.node_to_accumulator_map,
        self.variables.candidate_split_sums,
        self.variables.candidate_split_squares,
        self.variables.accumulator_sums,
        self.variables.accumulator_squares,
        regression=self.params.regression)

  # Grow tree.
  with ops.control_dependencies([update_features_op, update_thresholds_op]):
    (tree_update_indices, tree_children_updates, tree_threshold_updates,
     tree_depth_updates, new_eot) = self.training_ops.grow_tree(
         self.variables.end_of_tree,
         self.variables.tree_depths,
         self.variables.node_to_accumulator_map,
         finished,
         split_indices,
         self.variables.candidate_split_features,
         self.variables.candidate_split_thresholds)
    tree_update_op = state_ops.scatter_update(
        self.variables.tree, tree_update_indices, tree_children_updates)
    thresholds_update_op = state_ops.scatter_update(
        self.variables.tree_thresholds, tree_update_indices,
        tree_threshold_updates)
    depth_update_op = state_ops.scatter_update(
        self.variables.tree_depths, tree_update_indices, tree_depth_updates)
    # TODO(thomaswc): Only update the epoch on the new leaves.
    new_epoch_updates = epoch * array_ops.ones_like(tree_depth_updates)
    epoch_update_op = state_ops.scatter_update(
        self.variables.start_epoch, tree_update_indices, new_epoch_updates)

  # Update fertile slots.
  with ops.control_dependencies([depth_update_op]):
    (node_map_updates, accumulators_cleared,
     accumulators_allocated) = self.training_ops.update_fertile_slots(
         finished,
         non_fertile_leaves,
         non_fertile_leaf_scores,
         self.variables.end_of_tree,
         self.variables.tree_depths,
         self.variables.accumulator_sums,
         self.variables.node_to_accumulator_map,
         stale,
         max_depth=self.params.max_depth,
         regression=self.params.regression)

  # Ensure end_of_tree doesn't get updated until UpdateFertileSlots has
  # used it to calculate new leaves.
  gated_new_eot, = control_flow_ops.tuple(
      [new_eot], control_inputs=[node_map_updates])
  eot_update_op = state_ops.assign(self.variables.end_of_tree, gated_new_eot)

  updates = []
  updates.append(eot_update_op)
  updates.append(tree_update_op)
  updates.append(thresholds_update_op)
  updates.append(epoch_update_op)

  updates.append(
      state_ops.scatter_update(
          self.variables.node_to_accumulator_map,
          array_ops.squeeze(
              array_ops.slice(node_map_updates, [0, 0], [1, -1]),
              squeeze_dims=[0]),
          array_ops.squeeze(
              array_ops.slice(node_map_updates, [1, 0], [1, -1]),
              squeeze_dims=[0])))

  cleared_and_allocated_accumulators = array_ops.concat(
      0, [accumulators_cleared, accumulators_allocated])
  # Calculate values to put into scatter update for candidate counts.
  # Candidate split counts are always reset back to 0 for both cleared
  # and allocated accumulators. This means some accumulators might be doubly
  # reset to 0 if they were released and not allocated, then later allocated.
  split_values = array_ops.tile(
      array_ops.expand_dims(
          array_ops.expand_dims(
              array_ops.zeros_like(cleared_and_allocated_accumulators,
                                   dtype=dtypes.float32), 1), 2),
      [1, self.params.num_splits_to_consider, self.params.num_output_columns])
  updates.append(
      state_ops.scatter_update(self.variables.candidate_split_sums,
                               cleared_and_allocated_accumulators,
                               split_values))
  if self.params.regression:
    updates.append(
        state_ops.scatter_update(self.variables.candidate_split_squares,
                                 cleared_and_allocated_accumulators,
                                 split_values))

  # Calculate values to put into scatter update for total counts.
  total_cleared = array_ops.tile(
      array_ops.expand_dims(
          math_ops.neg(
              array_ops.ones_like(accumulators_cleared,
                                  dtype=dtypes.float32)), 1),
      [1, self.params.num_output_columns])
  total_reset = array_ops.tile(
      array_ops.expand_dims(
          array_ops.zeros_like(accumulators_allocated,
                               dtype=dtypes.float32), 1),
      [1, self.params.num_output_columns])
  accumulator_updates = array_ops.concat(0, [total_cleared, total_reset])
  updates.append(
      state_ops.scatter_update(self.variables.accumulator_sums,
                               cleared_and_allocated_accumulators,
                               accumulator_updates))
  if self.params.regression:
    updates.append(
        state_ops.scatter_update(self.variables.accumulator_squares,
                                 cleared_and_allocated_accumulators,
                                 accumulator_updates))

  # Calculate values to put into scatter update for candidate splits.
  split_features_updates = array_ops.tile(
      array_ops.expand_dims(
          math_ops.neg(
              array_ops.ones_like(cleared_and_allocated_accumulators)), 1),
      [1, self.params.num_splits_to_consider])
  updates.append(
      state_ops.scatter_update(self.variables.candidate_split_features,
                               cleared_and_allocated_accumulators,
                               split_features_updates))

  updates += self.finish_iteration()

  return control_flow_ops.group(*updates)
def _apply_sparse_shared(self, grad, var, indices, scatter_add, state):
  # TODO: use Nesterov here.
  beta1_power, beta2_power = self._get_beta_accumulators(state)
  beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
  beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
  lr_t = state.get_hyper("learning_rate", var.dtype.base_dtype)
  beta1_t = state.get_hyper("beta1", var.dtype.base_dtype)
  beta2_t = state.get_hyper("beta2", var.dtype.base_dtype)
  epsilon_t = state.get_hyper("epsilon", var.dtype.base_dtype)
  lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power))

  # m_t = beta1 * m + (1 - beta1) * g_t
  # default Adam:
  #   m_scaled_g_values = grad * (1 - beta1_t)
  #   m_t = state_ops.assign(m, m * beta1_t, use_locking=self._use_locking)
  #   with ops.control_dependencies([m_t]):
  #     m_t = scatter_add(m, indices, m_scaled_g_values)
  #   # w/ Nesterov:
  #   m_bar = m_scaled_g_values + beta1_t * m_t
  # lazy Adam:
  m = state.get_slot(var, "m")
  m_scaled_g_values = grad * (1 - beta1_t)
  m_t = state_ops.scatter_update(m, indices,
                                 beta1_t * array_ops.gather(m, indices) +
                                 m_scaled_g_values,
                                 use_locking=self._use_locking)
  # TODO: could this be done with just one gather?
  m_bar = m_scaled_g_values + beta1_t * array_ops.gather(m_t, indices)

  # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
  # default Adam:
  #   v_scaled_g_values = (grad * grad) * (1 - beta2_t)
  #   v_t = state_ops.assign(v, v * beta2_t, use_locking=self._use_locking)
  #   with ops.control_dependencies([v_t]):
  #     v_t = scatter_add(v, indices, v_scaled_g_values)
  #   v_sqrt = math_ops.sqrt(v_t)
  # lazy Adam:
  v = state.get_slot(var, "v")
  v_t = state_ops.scatter_update(v, indices,
                                 beta2_t * array_ops.gather(v, indices) +
                                 (1 - beta2_t) * math_ops.square(grad),
                                 use_locking=self._use_locking)

  # variable -= learning_rate * m_t / (epsilon_t + sqrt(v_t))
  # default Adam:
  #   var_update = state_ops.assign_sub(var,
  #                                     lr * m_t / (v_sqrt + epsilon_t),
  #                                     use_locking=self._use_locking)
  # w/ Nesterov:
  #   var_update = state_ops.assign_sub(var,
  #                                     lr * m_bar / (v_sqrt + epsilon_t),
  #                                     use_locking=self._use_locking)
  #   return control_flow_ops.group(*[var_update, m_bar, v_t])
  # lazy Adam:
  #   m_t_slice = array_ops.gather(m_t, indices)
  # m_bar is already gathered at `indices` (its first dimension is the number
  # of sparse rows), so gathering it again with `indices` would be wrong.
  m_bar_slice = m_bar
  v_t_slice = array_ops.gather(v_t, indices)
  denominator_slice = math_ops.sqrt(v_t_slice) + epsilon_t
  # var_update = state_ops.scatter_sub(var, indices,
  #                                    lr * m_t_slice / denominator_slice,
  #                                    use_locking=self._use_locking)
  var_update = scatter_add(var, indices,
                           -lr * m_bar_slice / denominator_slice)
  # return control_flow_ops.group(*[var_update, m_t, v_t])
  return control_flow_ops.group(*[var_update, m_bar, v_t])
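
# A numpy sketch of the lazy-Adam update above (illustrative values): only
# the rows named by `indices` are decayed and written back, which is the
# point of using scatter_update instead of a dense assign.
import numpy as np

beta1, beta2, lr, eps = 0.9, 0.999, 0.01, 1e-8
m, v = np.zeros((4, 2)), np.zeros((4, 2))
var = np.ones((4, 2))
indices = np.array([0, 2])              # rows touched by the sparse gradient
grad = np.array([[0.5, -0.5], [1.0, 2.0]])

m[indices] = beta1 * m[indices] + (1 - beta1) * grad
v[indices] = beta2 * v[indices] + (1 - beta2) * grad ** 2
m_bar = (1 - beta1) * grad + beta1 * m[indices]   # Nesterov-style momentum
var[indices] -= lr * m_bar / (np.sqrt(v[indices]) + eps)
print(var)                               # rows 1 and 3 are left untouched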
def _finish(self, state):
  var_dtype = self._variables[0].dtype.base_dtype
  # Update global step.
  global_step = self._get_global_step(state)
  update_global_step = state_ops.assign_add(global_step, 1.)

  # Update the first moment estimate.
  beta1 = state.get_hyper("beta1", dtype=var_dtype)
  moment1 = self._get_moment1(state)
  flat_grad = self._get_flat_grad(state)
  # moment1_t := beta1 * moment1_{t-1} + (1 - beta1) * flat_grad_t
  update_moment1 = moment1.assign(beta1 * moment1 + (1. - beta1) * flat_grad)

  # Update the gradient buffer.
  window = state.get_hyper("window")
  grad_buffer = self._get_grad_buffer(state)
  next_grad_index = math_ops.floormod(
      math_ops.to_int32(update_global_step - 1.), window)
  # grad_buffer[(t-1) % window] := moment1_t
  update_grad_buffer = state_ops.scatter_update(grad_buffer, next_grad_index,
                                                update_moment1)

  # Compute the update step.
  eps = state.get_hyper("eps", dtype=var_dtype)
  svd_eps = state.get_hyper("svd_eps", dtype=var_dtype)
  sigma_eps = state.get_hyper("sigma_eps", dtype=var_dtype)
  lr = state.get_hyper("lr", dtype=var_dtype)
  denom = math_ops.sqrt(
      math_ops.minimum(
          ops.convert_to_tensor(update_global_step),
          ops.convert_to_tensor(math_ops.cast(window, dtype=var_dtype))))
  moment1_2d = array_ops.expand_dims(update_moment1, -1)

  # m = grad_buffer^T / sqrt(min(t, window))
  # m has shape [model dimension, window], where model dimension is the sum
  # of the dimensions of the flattened variables.
  m = array_ops.transpose(math_ops.divide(update_grad_buffer, denom))

  # sigma, u, _ = SVD(m^Tm + I * svd_eps)
  mm = math_ops.matmul(m, m, transpose_a=True)
  damping = math_ops.cast(linalg_ops.eye(window), dtype=var_dtype) * svd_eps
  sigma, u, _ = linalg_ops.svd(mm + damping)
  sigma_sqrt = math_ops.sqrt(sigma)
  sigma_sqrt_min = math_ops.reduce_min(sigma_sqrt)

  # sigma_sqrt_inv = 1 / (\sqrt{sigma} + sigma_eps)^3
  # We add sigma_eps to alleviate numerical instability.
  # Note that (m^Tm)^(-3/2) = u diag(sigma_sqrt_inv) u^T.
  sigma_sqrt_inv = math_ops.divide(
      math_ops.cast(1.0, dtype=var_dtype),
      math_ops.pow(sigma_sqrt + sigma_eps, 3))

  # In full matrix AdaGrad, the update step computes (mm^T)^(-1/2)g, where the
  # inversion of a model dimension by model dimension matrix is needed. To
  # speed up this computation we calculate the following instead:
  #   m(m^Tm)^(-3/2)m^T moment1 = m u diag(sigma_sqrt_inv) u^T m^T moment1.
  new_step = array_ops.expand_dims(
      array_ops.zeros(flat_grad.get_shape(), dtype=var_dtype), -1)
  head = math_ops.matmul(
      m,
      math_ops.matmul(
          u,
          math_ops.matmul(
              array_ops.diag(sigma_sqrt_inv),
              math_ops.matmul(
                  u,
                  math_ops.matmul(m, moment1_2d, transpose_a=True),
                  transpose_a=True))))

  # When inverting (mm^T)^(1/2), we also add epsilon * I regularization for
  # degenerate cases. We expand ((mm^T)^(1/2) + epsilon * I)^(-1) using
  # Woodbury's identity. For the full derivation please see the paper at
  # https://arxiv.org/pdf/1806.02958.pdf
  tail = moment1_2d - math_ops.matmul(
      m,
      math_ops.matmul(
          u,
          math_ops.matmul(
              array_ops.diag(
                  math_ops.divide(
                      math_ops.cast(1.0, dtype=var_dtype), sigma)),
              math_ops.matmul(
                  u,
                  math_ops.matmul(m, moment1_2d, transpose_a=True),
                  transpose_a=True))))
  scaled_tail = math_ops.divide(tail, sigma_sqrt_min)

  update_new_step = control_flow_ops.cond(
      sigma_sqrt_min > eps, lambda: math_ops.add(head, scaled_tail),
      lambda: math_ops.add(new_step, head))

  # Update each variable.
  update_step = []
  for var in self._variables:
    dim = self.shape_dict[var.name]
    start_index = self.index_dict[var.name]
    end_index = start_index + dim
    var_update_correct_shape = array_ops.reshape(
        update_new_step[start_index:end_index], var.get_shape())
    var_updated = state_ops.assign_sub(var, lr * var_update_correct_shape)
    update_step.append(var_updated)

  return control_flow_ops.group(update_step)
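
# A numpy check of the identity the comments above rely on: on the range of
# m, m (m^T m)^(-3/2) m^T g equals the pseudo-inverse square root
# (m m^T)^(-1/2) g, so only a window-sized eigendecomposition is needed.
# Random data, illustrative only.
import numpy as np

rng = np.random.default_rng(0)
d, w = 50, 4                            # model dimension, window
m = rng.standard_normal((d, w))
g = m @ rng.standard_normal(w)          # a vector in the range of m

sigma, u = np.linalg.eigh(m.T @ m)      # window x window: cheap
small = m @ u @ np.diag(sigma ** -1.5) @ u.T @ (m.T @ g)

vals, vecs = np.linalg.eigh(m @ m.T)    # model-dim x model-dim: expensive
inv_sqrt = np.where(vals > 1e-9, vals, np.inf) ** -0.5
big = vecs @ np.diag(inv_sqrt) @ (vecs.T @ g)
print(np.allclose(small, big))          # True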
def training_graph(self, input_data, input_labels, random_seed, data_spec,
                   epoch=None):
  """Constructs a TF graph for training a random tree.

  Args:
    input_data: A tensor or SparseTensor or placeholder for input data.
    input_labels: A tensor or placeholder for labels associated with
      input_data.
    random_seed: The random number generator seed to use for this tree. 0
      means use the current time as the seed.
    data_spec: A list of tf.dtype values specifying the original types of
      each column.
    epoch: A tensor or placeholder for the epoch the training data comes from.

  Returns:
    The last op in the random tree training graph.
  """
  epoch = [0] if epoch is None else epoch

  sparse_indices = []
  sparse_values = []
  sparse_shape = []
  if isinstance(input_data, ops.SparseTensor):
    sparse_indices = input_data.indices
    sparse_values = input_data.values
    sparse_shape = input_data.shape
    input_data = []

  # Count extremely random stats.
  (node_sums, node_squares, splits_indices, splits_sums, splits_squares,
   totals_indices, totals_sums, totals_squares,
   input_leaves) = (self.training_ops.count_extremely_random_stats(
       input_data, sparse_indices, sparse_values, sparse_shape, data_spec,
       input_labels, self.variables.tree, self.variables.tree_thresholds,
       self.variables.node_to_accumulator_map,
       self.variables.candidate_split_features,
       self.variables.candidate_split_thresholds, self.variables.start_epoch,
       epoch,
       num_classes=self.params.num_output_columns,
       regression=self.params.regression))
  node_update_ops = []
  node_update_ops.append(
      state_ops.assign_add(self.variables.node_sums, node_sums))

  splits_update_ops = []
  splits_update_ops.append(
      self.training_ops.scatter_add_ndim(self.variables.candidate_split_sums,
                                         splits_indices, splits_sums))
  splits_update_ops.append(
      self.training_ops.scatter_add_ndim(self.variables.accumulator_sums,
                                         totals_indices, totals_sums))

  if self.params.regression:
    node_update_ops.append(
        state_ops.assign_add(self.variables.node_squares, node_squares))
    splits_update_ops.append(
        self.training_ops.scatter_add_ndim(
            self.variables.candidate_split_squares, splits_indices,
            splits_squares))
    splits_update_ops.append(
        self.training_ops.scatter_add_ndim(self.variables.accumulator_squares,
                                           totals_indices, totals_squares))

  # Sample inputs.
  update_indices, feature_updates, threshold_updates = (
      self.training_ops.sample_inputs(
          input_data, sparse_indices, sparse_values, sparse_shape,
          self.variables.node_to_accumulator_map, input_leaves,
          self.variables.candidate_split_features,
          self.variables.candidate_split_thresholds,
          split_initializations_per_input=(
              self.params.split_initializations_per_input),
          split_sampling_random_seed=random_seed))
  update_features_op = state_ops.scatter_update(
      self.variables.candidate_split_features, update_indices,
      feature_updates)
  update_thresholds_op = state_ops.scatter_update(
      self.variables.candidate_split_thresholds, update_indices,
      threshold_updates)

  # Calculate finished nodes.
  with ops.control_dependencies(splits_update_ops):
    children = array_ops.squeeze(
        array_ops.slice(self.variables.tree, [0, 0], [-1, 1]),
        squeeze_dims=[1])
    is_leaf = math_ops.equal(constants.LEAF_NODE, children)
    leaves = math_ops.to_int32(
        array_ops.squeeze(array_ops.where(is_leaf), squeeze_dims=[1]))
    finished, stale = self.training_ops.finished_nodes(
        leaves, self.variables.node_to_accumulator_map,
        self.variables.candidate_split_sums,
        self.variables.candidate_split_squares,
        self.variables.accumulator_sums, self.variables.accumulator_squares,
        self.variables.start_epoch, epoch,
        num_split_after_samples=self.params.split_after_samples,
        min_split_samples=self.params.min_split_samples)

  # Update leaf scores.
  non_fertile_leaves = array_ops.boolean_mask(
      leaves,
      math_ops.less(
          array_ops.gather(self.variables.node_to_accumulator_map, leaves),
          0))

  # TODO(gilberth): It should be possible to limit the number of non-fertile
  # leaves we calculate scores for, especially since we can only take at most
  # array_ops.shape(finished)[0] of them.
  with ops.control_dependencies(node_update_ops):
    sums = array_ops.gather(self.variables.node_sums, non_fertile_leaves)
    if self.params.regression:
      squares = array_ops.gather(self.variables.node_squares,
                                 non_fertile_leaves)
      non_fertile_leaf_scores = self._variance(sums, squares)
    else:
      non_fertile_leaf_scores = self._weighted_gini(sums)

  # Calculate best splits.
  with ops.control_dependencies(splits_update_ops):
    split_indices = self.training_ops.best_splits(
        finished, self.variables.node_to_accumulator_map,
        self.variables.candidate_split_sums,
        self.variables.candidate_split_squares,
        self.variables.accumulator_sums, self.variables.accumulator_squares,
        regression=self.params.regression)

  # Grow tree.
  with ops.control_dependencies([update_features_op, update_thresholds_op]):
    (tree_update_indices, tree_children_updates, tree_threshold_updates,
     new_eot) = (self.training_ops.grow_tree(
         self.variables.end_of_tree, self.variables.node_to_accumulator_map,
         finished, split_indices, self.variables.candidate_split_features,
         self.variables.candidate_split_thresholds))
    tree_update_op = state_ops.scatter_update(self.variables.tree,
                                              tree_update_indices,
                                              tree_children_updates)
    thresholds_update_op = state_ops.scatter_update(
        self.variables.tree_thresholds, tree_update_indices,
        tree_threshold_updates)
    # TODO(thomaswc): Only update the epoch on the new leaves.
    new_epoch_updates = epoch * array_ops.ones_like(tree_threshold_updates,
                                                    dtype=dtypes.int32)
    epoch_update_op = state_ops.scatter_update(self.variables.start_epoch,
                                               tree_update_indices,
                                               new_epoch_updates)

  # Update fertile slots.
  with ops.control_dependencies([tree_update_op]):
    (node_map_updates, accumulators_cleared,
     accumulators_allocated) = (self.training_ops.update_fertile_slots(
         finished, non_fertile_leaves, non_fertile_leaf_scores,
         self.variables.end_of_tree, self.variables.accumulator_sums,
         self.variables.node_to_accumulator_map, stale,
         regression=self.params.regression))

  # Ensure end_of_tree doesn't get updated until UpdateFertileSlots has
  # used it to calculate new leaves.
  gated_new_eot, = control_flow_ops.tuple([new_eot],
                                          control_inputs=[node_map_updates])
  eot_update_op = state_ops.assign(self.variables.end_of_tree, gated_new_eot)

  updates = []
  updates.append(eot_update_op)
  updates.append(tree_update_op)
  updates.append(thresholds_update_op)
  updates.append(epoch_update_op)

  updates.append(
      state_ops.scatter_update(
          self.variables.node_to_accumulator_map,
          array_ops.squeeze(
              array_ops.slice(node_map_updates, [0, 0], [1, -1]),
              squeeze_dims=[0]),
          array_ops.squeeze(
              array_ops.slice(node_map_updates, [1, 0], [1, -1]),
              squeeze_dims=[0])))

  cleared_and_allocated_accumulators = array_ops.concat(
      0, [accumulators_cleared, accumulators_allocated])

  # Calculate values to put into scatter update for candidate counts.
  # Candidate split counts are always reset back to 0 for both cleared
  # and allocated accumulators. This means some accumulators might be doubly
  # reset to 0 if they were released and not allocated, then later allocated.
  split_values = array_ops.tile(
      array_ops.expand_dims(
          array_ops.expand_dims(
              array_ops.zeros_like(cleared_and_allocated_accumulators,
                                   dtype=dtypes.float32), 1), 2),
      [1, self.params.num_splits_to_consider, self.params.num_output_columns])
  updates.append(
      state_ops.scatter_update(self.variables.candidate_split_sums,
                               cleared_and_allocated_accumulators,
                               split_values))
  if self.params.regression:
    updates.append(
        state_ops.scatter_update(self.variables.candidate_split_squares,
                                 cleared_and_allocated_accumulators,
                                 split_values))

  # Calculate values to put into scatter update for total counts.
  total_cleared = array_ops.tile(
      array_ops.expand_dims(
          math_ops.neg(
              array_ops.ones_like(accumulators_cleared,
                                  dtype=dtypes.float32)), 1),
      [1, self.params.num_output_columns])
  total_reset = array_ops.tile(
      array_ops.expand_dims(
          array_ops.zeros_like(accumulators_allocated,
                               dtype=dtypes.float32), 1),
      [1, self.params.num_output_columns])
  accumulator_updates = array_ops.concat(0, [total_cleared, total_reset])
  updates.append(
      state_ops.scatter_update(self.variables.accumulator_sums,
                               cleared_and_allocated_accumulators,
                               accumulator_updates))
  if self.params.regression:
    updates.append(
        state_ops.scatter_update(self.variables.accumulator_squares,
                                 cleared_and_allocated_accumulators,
                                 accumulator_updates))

  # Calculate values to put into scatter update for candidate splits.
  split_features_updates = array_ops.tile(
      array_ops.expand_dims(
          math_ops.neg(
              array_ops.ones_like(cleared_and_allocated_accumulators)), 1),
      [1, self.params.num_splits_to_consider])
  updates.append(
      state_ops.scatter_update(self.variables.candidate_split_features,
                               cleared_and_allocated_accumulators,
                               split_features_updates))

  updates += self.finish_iteration()

  return control_flow_ops.group(*updates)
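
# A numpy sketch of the node_map_updates unpacking above: update_fertile_slots
# returns a [2, n] tensor whose first row holds node indices and second row
# holds accumulator ids, and the two slices feed one scatter update.
import numpy as np

node_to_accumulator_map = np.full(8, -1)
node_map_updates = np.array([[2, 5, 6],     # row 0: node indices
                             [0, 1, -1]])   # row 1: accumulator ids

indices = node_map_updates[0]   # squeeze(slice(..., [0, 0], [1, -1]))
values = node_map_updates[1]    # squeeze(slice(..., [1, 0], [1, -1]))
node_to_accumulator_map[indices] = values   # the scatter_update equivalent
print(node_to_accumulator_map)  # [-1 -1  0 -1 -1  1 -1 -1]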
def _create_hook_ops(self, input_row_indices, input_col_indices, train_ops):
  """Creates ops to update is_row_sweep_var, global_step and completed_sweeps.

  Creates two boolean tensors `processed_rows` and `processed_cols`, which
  keep track of which rows/cols have been processed during the current sweep.
  Returns ops that should be run after each row / col update.
    - When `self._is_row_sweep_var` is True, it sets
      processed_rows[input_row_indices] to True.
    - When `self._is_row_sweep_var` is False, it sets
      processed_cols[input_col_indices] to True.

  Args:
    input_row_indices: A Tensor. The indices of the input rows that are
      processed during the current sweep.
    input_col_indices: A Tensor. The indices of the input columns that are
      processed during the current sweep.
    train_ops: A list of ops. The ops created by this function have control
      dependencies on `train_ops`.

  Returns:
    A tuple consisting of:
      update_op: An op to be run jointly with training. It updates the state
        and increments counters (global step and completed sweeps).
      is_sweep_done_var: A Boolean tf.Variable, specifies whether the sweep
        is done, i.e. all rows (during a row sweep) or all columns (during a
        column sweep) have been processed.
      switch_op: An op to be run in `self.before_run` when the sweep is done.
  """
  processed_rows_init = array_ops.fill(dims=[self._num_rows], value=False)
  with ops.colocate_with(processed_rows_init):
    processed_rows = variable_scope.variable(
        processed_rows_init,
        collections=[ops.GraphKeys.GLOBAL_VARIABLES],
        trainable=False,
        name="sweep_hook_processed_rows")
  processed_cols_init = array_ops.fill(dims=[self._num_cols], value=False)
  with ops.colocate_with(processed_cols_init):
    processed_cols = variable_scope.variable(
        processed_cols_init,
        collections=[ops.GraphKeys.GLOBAL_VARIABLES],
        trainable=False,
        name="sweep_hook_processed_cols")
  switch_ops = control_flow_ops.group(
      state_ops.assign(self._is_row_sweep_var,
                       math_ops.logical_not(self._is_row_sweep_var)),
      state_ops.assign(processed_rows, processed_rows_init),
      state_ops.assign(processed_cols, processed_cols_init))
  is_sweep_done_var = variable_scope.variable(
      False,
      collections=[ops.GraphKeys.GLOBAL_VARIABLES],
      trainable=False,
      name="is_sweep_done")

  # After running the `train_ops`, updates the `processed_rows` or
  # `processed_cols` tensor, depending on whether this is a row or col sweep.
  with ops.control_dependencies(train_ops):
    with ops.colocate_with(processed_rows):
      update_processed_rows = state_ops.scatter_update(
          processed_rows,
          input_row_indices,
          math_ops.logical_and(
              self._is_row_sweep_var,
              array_ops.ones_like(input_row_indices, dtype=dtypes.bool)))
    with ops.colocate_with(processed_cols):
      update_processed_cols = state_ops.scatter_update(
          processed_cols,
          input_col_indices,
          math_ops.logical_and(
              math_ops.logical_not(self._is_row_sweep_var),
              array_ops.ones_like(input_col_indices, dtype=dtypes.bool)))
    update_processed_op = control_flow_ops.group(update_processed_rows,
                                                 update_processed_cols)

    with ops.control_dependencies([update_processed_op]):
      is_sweep_done = math_ops.logical_or(
          math_ops.reduce_all(processed_rows),
          math_ops.reduce_all(processed_cols))
      # Increments global step.
      global_step = framework_variables.get_global_step()
      if global_step is not None:
        global_step_incr_op = state_ops.assign_add(
            global_step, 1, name="global_step_incr").op
      else:
        global_step_incr_op = control_flow_ops.no_op()
      # Increments completed sweeps.
      completed_sweeps_incr_op = state_ops.assign_add(
          self._completed_sweeps_var,
          math_ops.cast(is_sweep_done, dtypes.int32),
          use_locking=True).op
    update_ops = control_flow_ops.group(
        global_step_incr_op, completed_sweeps_incr_op,
        state_ops.assign(is_sweep_done_var, is_sweep_done))

  return update_ops, is_sweep_done_var, switch_ops
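
# A pure-Python sketch of the sweep bookkeeping above: a boolean mask marks
# the rows visited so far, and the sweep is done once every entry is True.
import numpy as np

processed_rows = np.zeros(6, dtype=bool)

def after_train_step(row_indices):
  processed_rows[row_indices] = True    # the scatter_update equivalent
  return processed_rows.all()           # is_sweep_done

print(after_train_step([0, 1, 2]))      # False: sweep still in progress
print(after_train_step([3, 4, 5]))      # True: time to run switch_ops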
def _init_tree():
  return state_ops.scatter_update(self.variables.tree, [0], [[-1, -1]]).op
def _scatter_update(self, x, i, v):
  return state_ops.scatter_update(x, i, v, use_locking=self._use_locking)
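
# An end-to-end example of the wrapped op through the public TF1-compatible
# API (tf.compat.v1.scatter_update is the public alias of the state_ops call
# above); requires TensorFlow with the v1 compatibility layer.
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

x = tf.Variable([10.0, 20.0, 30.0, 40.0])
update = tf.scatter_update(x, [1, 3], [2.0, 4.0], use_locking=True)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(update))  # [10.  2. 30.  4.]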
def training_graph(self, input_data, input_labels, random_seed, data_spec,
                   sparse_features=None, input_weights=None):
  """Constructs a TF graph for training a random tree.

  Args:
    input_data: A tensor or placeholder for input data.
    input_labels: A tensor or placeholder for labels associated with
      input_data.
    random_seed: The random number generator seed to use for this tree. 0
      means use the current time as the seed.
    data_spec: A data_ops.TensorForestDataSpec object specifying the original
      feature/columns of the data.
    sparse_features: A tf.SparseTensor for sparse input data.
    input_weights: A float tensor or placeholder holding per-input weights,
      or None if all inputs are to be weighted equally.

  Returns:
    The last op in the random tree training graph.
  """
  epoch = math_ops.to_int32(get_epoch_variable())
  serialized_input_spec = data_spec.SerializeToString()

  if input_weights is None:
    input_weights = []

  if input_data is None:
    input_data = []

  sparse_indices = []
  sparse_values = []
  sparse_shape = []
  if sparse_features is not None:
    sparse_indices = sparse_features.indices
    sparse_values = sparse_features.values
    sparse_shape = sparse_features.dense_shape

  # Count extremely random stats.
  (node_sums, node_squares, splits_indices, splits_sums, splits_squares,
   totals_indices, totals_sums, totals_squares,
   input_leaves) = (tensor_forest_ops.count_extremely_random_stats(
       input_data, sparse_indices, sparse_values, sparse_shape, input_labels,
       input_weights, self.variables.tree, self.variables.tree_thresholds,
       self.variables.node_to_accumulator_map,
       self.variables.candidate_split_features,
       self.variables.candidate_split_thresholds, self.variables.start_epoch,
       epoch,
       input_spec=serialized_input_spec,
       num_classes=self.params.num_output_columns,
       regression=self.params.regression))
  node_update_ops = []
  node_update_ops.append(
      state_ops.assign_add(self.variables.node_sums, node_sums))

  splits_update_ops = []
  splits_update_ops.append(
      tensor_forest_ops.scatter_add_ndim(self.variables.candidate_split_sums,
                                         splits_indices, splits_sums))
  splits_update_ops.append(
      tensor_forest_ops.scatter_add_ndim(self.variables.accumulator_sums,
                                         totals_indices, totals_sums))

  if self.params.regression:
    node_update_ops.append(
        state_ops.assign_add(self.variables.node_squares, node_squares))
    splits_update_ops.append(
        tensor_forest_ops.scatter_add_ndim(
            self.variables.candidate_split_squares, splits_indices,
            splits_squares))
    splits_update_ops.append(
        tensor_forest_ops.scatter_add_ndim(self.variables.accumulator_squares,
                                           totals_indices, totals_squares))

  # Sample inputs.
  update_indices, feature_updates, threshold_updates = (
      tensor_forest_ops.sample_inputs(
          input_data, sparse_indices, sparse_values, sparse_shape,
          input_weights, self.variables.node_to_accumulator_map,
          input_leaves, self.variables.candidate_split_features,
          self.variables.candidate_split_thresholds,
          input_spec=serialized_input_spec,
          split_initializations_per_input=(
              self.params.split_initializations_per_input),
          split_sampling_random_seed=random_seed))
  update_features_op = state_ops.scatter_update(
      self.variables.candidate_split_features, update_indices,
      feature_updates)
  update_thresholds_op = state_ops.scatter_update(
      self.variables.candidate_split_thresholds, update_indices,
      threshold_updates)

  # Calculate finished nodes.
  with ops.control_dependencies(splits_update_ops):
    # Passing input_leaves to finished_nodes here means that nodes that
    # have become stale won't be deallocated until an input reaches them,
    # because we're trying to avoid considering every fertile node for
    # performance reasons.
    finished, stale = tensor_forest_ops.finished_nodes(
        input_leaves, self.variables.node_to_accumulator_map,
        self.variables.candidate_split_sums,
        self.variables.candidate_split_squares,
        self.variables.accumulator_sums, self.variables.accumulator_squares,
        self.variables.start_epoch, epoch,
        num_split_after_samples=self.params.split_after_samples,
        min_split_samples=self.params.min_split_samples,
        dominate_method=self.params.dominate_method,
        dominate_fraction=self.params.dominate_fraction)

  # Update leaf scores.
  # TODO(thomaswc): Store the leaf scores in a TopN and only update the
  # scores of the leaves that were touched by this batch of input.
  children = array_ops.squeeze(
      array_ops.slice(self.variables.tree, [0, 0], [-1, 1]),
      squeeze_dims=[1])
  is_leaf = math_ops.equal(constants.LEAF_NODE, children)
  leaves = math_ops.to_int32(
      array_ops.squeeze(array_ops.where(is_leaf), squeeze_dims=[1]))
  non_fertile_leaves = array_ops.boolean_mask(
      leaves,
      math_ops.less(
          array_ops.gather(self.variables.node_to_accumulator_map, leaves),
          0))

  # TODO(gilberth): It should be possible to limit the number of non-fertile
  # leaves we calculate scores for, especially since we can only take at most
  # array_ops.shape(finished)[0] of them.
  with ops.control_dependencies(node_update_ops):
    sums = array_ops.gather(self.variables.node_sums, non_fertile_leaves)
    if self.params.regression:
      squares = array_ops.gather(self.variables.node_squares,
                                 non_fertile_leaves)
      non_fertile_leaf_scores = self._variance(sums, squares)
    else:
      non_fertile_leaf_scores = self._weighted_gini(sums)

  # Calculate best splits.
  with ops.control_dependencies(splits_update_ops):
    split_indices = tensor_forest_ops.best_splits(
        finished, self.variables.node_to_accumulator_map,
        self.variables.candidate_split_sums,
        self.variables.candidate_split_squares,
        self.variables.accumulator_sums, self.variables.accumulator_squares,
        regression=self.params.regression)

  # Grow tree.
  with ops.control_dependencies(
      [update_features_op, update_thresholds_op, non_fertile_leaves.op]):
    (tree_update_indices, tree_children_updates, tree_threshold_updates,
     new_eot) = (tensor_forest_ops.grow_tree(
         self.variables.end_of_tree, self.variables.node_to_accumulator_map,
         finished, split_indices, self.variables.candidate_split_features,
         self.variables.candidate_split_thresholds))
    tree_update_op = state_ops.scatter_update(self.variables.tree,
                                              tree_update_indices,
                                              tree_children_updates)
    thresholds_update_op = state_ops.scatter_update(
        self.variables.tree_thresholds, tree_update_indices,
        tree_threshold_updates)
    # TODO(thomaswc): Only update the epoch on the new leaves.
    new_epoch_updates = epoch * array_ops.ones_like(tree_threshold_updates,
                                                    dtype=dtypes.int32)
    epoch_update_op = state_ops.scatter_update(self.variables.start_epoch,
                                               tree_update_indices,
                                               new_epoch_updates)

  # Update fertile slots.
  with ops.control_dependencies([tree_update_op]):
    (n2a_map_updates, a2n_map_updates, accumulators_cleared,
     accumulators_allocated) = (tensor_forest_ops.update_fertile_slots(
         finished, non_fertile_leaves, non_fertile_leaf_scores,
         self.variables.end_of_tree, self.variables.accumulator_sums,
         self.variables.node_to_accumulator_map, stale,
         self.variables.node_sums,
         regression=self.params.regression))

  # Ensure end_of_tree doesn't get updated until UpdateFertileSlots has
  # used it to calculate new leaves.
  with ops.control_dependencies([n2a_map_updates.op]):
    eot_update_op = state_ops.assign(self.variables.end_of_tree, new_eot)

  updates = []
  updates.append(eot_update_op)
  updates.append(tree_update_op)
  updates.append(thresholds_update_op)
  updates.append(epoch_update_op)

  updates.append(
      state_ops.scatter_update(self.variables.node_to_accumulator_map,
                               n2a_map_updates[0], n2a_map_updates[1]))
  updates.append(
      state_ops.scatter_update(self.variables.accumulator_to_node_map,
                               a2n_map_updates[0], a2n_map_updates[1]))

  cleared_and_allocated_accumulators = array_ops.concat(
      [accumulators_cleared, accumulators_allocated], 0)

  # Calculate values to put into scatter update for candidate counts.
  # Candidate split counts are always reset back to 0 for both cleared
  # and allocated accumulators. This means some accumulators might be doubly
  # reset to 0 if they were released and not allocated, then later allocated.
  split_values = array_ops.tile(
      array_ops.expand_dims(
          array_ops.expand_dims(
              array_ops.zeros_like(cleared_and_allocated_accumulators,
                                   dtype=dtypes.float32), 1), 2),
      [1, self.params.num_splits_to_consider, self.params.num_output_columns])
  updates.append(
      state_ops.scatter_update(self.variables.candidate_split_sums,
                               cleared_and_allocated_accumulators,
                               split_values))
  if self.params.regression:
    updates.append(
        state_ops.scatter_update(self.variables.candidate_split_squares,
                                 cleared_and_allocated_accumulators,
                                 split_values))

  # Calculate values to put into scatter update for total counts.
  total_cleared = array_ops.tile(
      array_ops.expand_dims(
          math_ops.negative(
              array_ops.ones_like(accumulators_cleared,
                                  dtype=dtypes.float32)), 1),
      [1, self.params.num_output_columns])
  total_reset = array_ops.tile(
      array_ops.expand_dims(
          array_ops.zeros_like(accumulators_allocated,
                               dtype=dtypes.float32), 1),
      [1, self.params.num_output_columns])
  accumulator_updates = array_ops.concat([total_cleared, total_reset], 0)
  updates.append(
      state_ops.scatter_update(self.variables.accumulator_sums,
                               cleared_and_allocated_accumulators,
                               accumulator_updates))
  if self.params.regression:
    updates.append(
        state_ops.scatter_update(self.variables.accumulator_squares,
                                 cleared_and_allocated_accumulators,
                                 accumulator_updates))

  # Calculate values to put into scatter update for candidate splits.
  split_features_updates = array_ops.tile(
      array_ops.expand_dims(
          math_ops.negative(
              array_ops.ones_like(cleared_and_allocated_accumulators)), 1),
      [1, self.params.num_splits_to_consider])
  updates.append(
      state_ops.scatter_update(self.variables.candidate_split_features,
                               cleared_and_allocated_accumulators,
                               split_features_updates))

  updates += self.finish_iteration()

  return control_flow_ops.group(*updates)
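
# A numpy sketch of the accumulator reset above (hypothetical sizes): cleared
# accumulators get -1 (inactive), newly allocated ones get 0, and both are
# written in a single scatter over the concatenated index list.
import numpy as np

num_outputs = 3
accumulator_sums = np.full((5, num_outputs), 7.0)
cleared = np.array([0, 2])      # released this round -> filled with -1
allocated = np.array([4])       # newly assigned to a fertile leaf -> zeroed

rows = np.concatenate([cleared, allocated])
values = np.concatenate([np.full((len(cleared), num_outputs), -1.0),
                         np.zeros((len(allocated), num_outputs))])
accumulator_sums[rows] = values  # the state_ops.scatter_update equivalent
print(accumulator_sums)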