def erosion2d(value, kernel, strides, rates, padding, name=None):
  """Computes the grayscale erosion of 4-D `value` and 3-D `kernel` tensors.

  The `value` tensor has shape `[batch, in_height, in_width, depth]` and the
  `kernel` tensor has shape `[kernel_height, kernel_width, depth]`, i.e.,
  each input channel is processed independently of the others with its own
  structuring function. The `output` tensor has shape
  `[batch, out_height, out_width, depth]`. The spatial dimensions of the
  output tensor depend on the `padding` algorithm. We currently only support
  the default "NHWC" `data_format`.

  In detail, the grayscale morphological 2-D erosion is given by:

      output[b, y, x, c] =
         min_{dy, dx} value[b,
                            strides[1] * y - rates[1] * dy,
                            strides[2] * x - rates[2] * dx,
                            c] -
                      kernel[dy, dx, c]

  Duality: The erosion of `value` by the `kernel` is equal to the negation of
  the dilation of `-value` by the reflected `kernel`.

  Args:
    value: A `Tensor`. 4-D with shape `[batch, in_height, in_width, depth]`.
    kernel: A `Tensor`. Must have the same type as `value`.
      3-D with shape `[kernel_height, kernel_width, depth]`.
    strides: A list of `ints` that has length `>= 4`.
      1-D of length 4. The stride of the sliding window for each dimension of
      the input tensor. Must be: `[1, stride_height, stride_width, 1]`.
    rates: A list of `ints` that has length `>= 4`.
      1-D of length 4. The input stride for atrous morphological dilation.
      Must be: `[1, rate_height, rate_width, 1]`.
    padding: A `string` from: `"SAME", "VALID"`.
      The type of padding algorithm to use.
    name: A name for the operation (optional). If not specified "erosion2d"
      is used.

  Returns:
    A `Tensor`. Has the same type as `value`.
    4-D with shape `[batch, out_height, out_width, depth]`.

  Raises:
    ValueError: If the `value` depth does not match `kernel`'s shape, or if
      padding is other than `'VALID'` or `'SAME'`.
  """
  with ops.op_scope([value, kernel], name, "erosion2d") as name:
    # Reduce erosion to dilation by duality.
    return math_ops.neg(
        gen_nn_ops.dilation2d(
            input=math_ops.neg(value),
            filter=array_ops.reverse(kernel, [True, True, False]),
            strides=strides,
            rates=rates,
            padding=padding,
            name=name))
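# A minimal NumPy sketch (illustration only, not part of this module) of the
# duality noted in the docstring, on a 1-D slice with VALID padding and unit
# strides/rates. `erode_1d` and `dilate_1d` are hypothetical helpers that
# mirror the indexing conventions of erosion2d/dilation2d above.
import numpy as np

def erode_1d(v, k):
  # out[i] = min_j v[i + (K-1) - j] - k[j], the 1-D analog of the erosion
  # formula in the docstring (note the reversed indexing into v).
  K = len(k)
  return np.array([min(v[i + K - 1 - j] - k[j] for j in range(K))
                   for i in range(len(v) - K + 1)])

def dilate_1d(v, k):
  # out[i] = max_j v[i + j] + k[j], the 1-D analog of dilation2d.
  K = len(k)
  return np.array([max(v[i + j] + k[j] for j in range(K))
                   for i in range(len(v) - K + 1)])

v = np.array([3.0, 1.0, 4.0, 1.0, 5.0])
k = np.array([0.5, 0.0])
# Erosion of v by k equals the negated dilation of -v by the reflected k.
assert np.allclose(erode_1d(v, k), -dilate_1d(-v, k[::-1]))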
def testInitializerFunction(self):
  value = [[-42], [133.7]]
  shape = [2, 1]
  with self.test_session():
    initializer = lambda: constant_op.constant(value)

    v1 = variables.Variable(initializer, dtype=dtypes.float32)
    self.assertEqual(shape, v1.get_shape())
    self.assertAllClose(value, v1.initial_value.eval())
    with self.assertRaises(errors_impl.FailedPreconditionError):
      v1.eval()

    v2 = variables.Variable(
        math_ops.neg(v1.initialized_value()), dtype=dtypes.float32)
    self.assertEqual(v1.get_shape(), v2.get_shape())
    self.assertAllClose(np.negative(value), v2.initial_value.eval())

    # Once v2.initial_value.eval() has been called, v1 has effectively been
    # initialized.
    self.assertAllClose(value, v1.eval())

    with self.assertRaises(errors_impl.FailedPreconditionError):
      v2.eval()
    variables.global_variables_initializer().run()
    self.assertAllClose(np.negative(value), v2.eval())
def _SparseUpdate(variable, gradients, accum, linear, base_lr,
                  lr_power, l1, l2):
  """Sparse update of "variable", "accum" and "linear" from sparse "gradients".

  See the description in _Update.

  Args:
    variable: A Variable.
    gradients: An IndexedSlices of gradients.
    accum: A Variable containing the sum of the squares of gradients.
    linear: A Variable containing approximation info.
    base_lr: A constant representing the base learning rate.
    lr_power: A constant used to adjust the learning rate.
    l1: A constant representing the l1 regularization strength.
    l2: A constant representing the l2 regularization strength.

  Returns:
    A group op including three ScatterUpdate ops:
      1. ScatterUpdate for "accum"
      2. ScatterUpdate for "linear"
      3. ScatterUpdate for "variable"
  """
  assert isinstance(gradients, ops.IndexedSlices)
  with ops.name_scope("sparse_update_" + variable.op.name) as scope:
    dtype = variable.dtype.base_dtype
    base_lr = ops.convert_to_tensor(base_lr, dtype=dtype)
    lr_power = ops.convert_to_tensor(lr_power, dtype=dtype)
    l1 = ops.convert_to_tensor(l1, dtype=dtype)
    l2 = ops.convert_to_tensor(l2, dtype=dtype)

    # Compute the new value for the accumulator.
    previous_accum = array_ops.gather(accum, gradients.indices)
    sqr_grad = gradients.values * gradients.values
    accum_updated = sqr_grad + previous_accum

    # Compute the new linear.
    neg_lr_power = math_ops.neg(lr_power)
    sigma = math_ops.pow(accum_updated, neg_lr_power) - math_ops.pow(
        previous_accum, neg_lr_power)
    sigma /= base_lr
    variable_slice = array_ops.gather(variable, gradients.indices)
    proximal_adjust = sigma * variable_slice
    linear_slice = array_ops.gather(linear, gradients.indices)
    linear_updated = linear_slice + gradients.values - proximal_adjust

    # Compute the new "variable".
    variable_updated = _Compute(accum_updated, linear_updated, base_lr,
                                lr_power, l1, l2)

    with ops.control_dependencies([sigma]):
      accum_update_op = state_ops.scatter_update(accum, gradients.indices,
                                                 accum_updated)
    linear_update_op = state_ops.scatter_update(linear, gradients.indices,
                                                linear_updated)
    variable_update_op = state_ops.scatter_update(variable, gradients.indices,
                                                  variable_updated)
    group_op = control_flow_ops.group(linear_update_op, accum_update_op,
                                      variable_update_op, name=scope)
    return group_op
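# A small NumPy sketch (illustration only) of the gather -> update -> scatter
# pattern used above: only the rows named by `indices` are read, updated and
# written back; all other rows of the accumulator are left untouched. The
# names below are hypothetical stand-ins for the Variables in _SparseUpdate.
import numpy as np

accum = np.ones((4, 2))                  # stands in for the "accum" Variable
indices = np.array([0, 2])               # rows touched by this sparse step
grad_values = np.array([[0.5, -1.0],     # gradient values for those rows
                        [2.0, 0.0]])

previous_accum = accum[indices]                      # array_ops.gather
accum_updated = previous_accum + grad_values ** 2    # same math as above
accum[indices] = accum_updated                       # state_ops.scatter_update
assert np.allclose(accum[1], [1.0, 1.0])             # untouched row unchanged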
def natural_exp_decay(learning_rate, global_step, decay_steps, decay_rate,
                      staircase=False, name=None):
  """Applies natural exponential decay to the initial learning rate.

  When training a model, it is often recommended to lower the learning rate as
  the training progresses. This function applies an exponential decay function
  to a provided initial learning rate. It requires a `global_step` value to
  compute the decayed learning rate. You can just pass a TensorFlow variable
  that you increment at each training step.

  The function returns the decayed learning rate. It is computed as:

  ```python
  decayed_learning_rate = learning_rate * exp(-decay_rate * global_step /
                                              decay_steps)
  ```

  Example: decay exponentially with a rate of 0.5:

  ```python
  ...
  global_step = tf.Variable(0, trainable=False)
  learning_rate = 0.1
  decay_steps = 5
  k = 0.5
  learning_rate = tf.train.natural_exp_decay(learning_rate, global_step,
                                             decay_steps, k)

  # Passing global_step to minimize() will increment it at each step.
  learning_step = (
      tf.train.GradientDescentOptimizer(learning_rate)
      .minimize(...my loss..., global_step=global_step)
  )
  ```

  Args:
    learning_rate: A scalar `float32` or `float64` `Tensor` or a
      Python number.  The initial learning rate.
    global_step: A Python number.
      Global step to use for the decay computation.  Must not be negative.
    decay_steps: A Python number.  Number of steps used to scale
      `global_step` in the decay computation (see the formula above).
    decay_rate: A Python number.  The decay rate.
    staircase: Boolean.  If `True`, decay the learning rate at discrete
      intervals by flooring `global_step / decay_steps`.
    name: String.  Optional name of the operation.  Defaults to
      'NaturalExpDecay'.

  Returns:
    A scalar `Tensor` of the same type as `learning_rate`.  The decayed
    learning rate.
  """
  with ops.name_scope(name, "NaturalExpDecay",
                      [learning_rate, global_step, decay_rate]) as name:
    learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
    dtype = learning_rate.dtype
    global_step = math_ops.cast(global_step, dtype)
    decay_steps = math_ops.cast(decay_steps, dtype)
    decay_rate = math_ops.cast(decay_rate, dtype)
    p = global_step / decay_steps
    if staircase:
      p = math_ops.floor(p)
    exponent = math_ops.exp(math_ops.mul(math_ops.neg(decay_rate), p))
    return math_ops.mul(learning_rate, exponent, name=name)
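# A quick numeric check (illustration only) of the decay formula in the
# docstring, with hypothetical values learning_rate = 0.1, decay_rate = 0.5,
# decay_steps = 5 and staircase=False:
import math

learning_rate, decay_rate, decay_steps = 0.1, 0.5, 5.0
for step in (0, 5, 10):
  decayed = learning_rate * math.exp(-decay_rate * step / decay_steps)
  print(step, decayed)  # 0 -> 0.1, 5 -> ~0.0607, 10 -> ~0.0368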
def _Update(variable, gradients, accum, linear, base_lr, lr_power, l1, l2):
  """Update "variable", "accum" and "linear" based on "gradients".

  Some notation: "variable" is W, "accum" is N, "linear" is Z,
  "gradients" is G, and N(t) means "accum" at step t.

  Assuming lr_power = -0.5, which corresponds to the Adagrad learning rate:
  "accum" updates as: N = N + G^2
  "linear" updates as: Z = Z + G - W * (sqrt(N(t)) - sqrt(N(t-1))) / base_lr

  REQUIRES: variable, gradients, accum and linear must all have the same
  dimensionality.

  Args:
    variable: A Variable.
    gradients: A Tensor of the same shape as "variable".
    accum: A Variable containing the sum of the squares of gradients.
    linear: A Variable containing approximation info.
    base_lr: A constant representing the base learning rate.
    lr_power: A constant used to adjust the learning rate.
    l1: A constant representing the l1 regularization strength.
    l2: A constant representing the l2 regularization strength.

  Returns:
    A group op including three Assign ops:
      1. Assign for "accum"
      2. Assign for "linear"
      3. Assign for "variable"
  """
  dtype = variable.dtype.base_dtype
  base_lr = ops.convert_to_tensor(base_lr, dtype=dtype)
  lr_power = ops.convert_to_tensor(lr_power, dtype=dtype)
  l1 = ops.convert_to_tensor(l1, dtype=dtype)
  l2 = ops.convert_to_tensor(l2, dtype=dtype)

  # Compute the new accumulator.
  sqr_grad = math_ops.square(gradients)
  accum_updated = sqr_grad + accum

  # Compute the new linear.
  neg_lr_power = math_ops.neg(lr_power)
  sigma = math_ops.pow(accum_updated, neg_lr_power) - math_ops.pow(
      accum, neg_lr_power)
  sigma /= base_lr
  proximal_adjust = sigma * variable
  linear_updated = linear + gradients - proximal_adjust

  # Compute the new "variable".
  variable_updated = _Compute(accum_updated, linear_updated, base_lr,
                              lr_power, l1, l2)

  with ops.control_dependencies([sigma]):
    accum_update_op = state_ops.assign(accum, accum_updated)
  linear_update_op = state_ops.assign(linear, linear_updated)
  variable_update_op = state_ops.assign(variable, variable_updated)
  group_op = control_flow_ops.group(linear_update_op, accum_update_op,
                                    variable_update_op)
  return group_op
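# A minimal NumPy sketch (illustration only) of the update described in the
# docstring with lr_power = -0.5 (the Adagrad schedule), checking that sigma
# reduces to (sqrt(N(t)) - sqrt(N(t-1))) / base_lr. All names below are
# hypothetical stand-ins for the Variables passed to _Update.
import numpy as np

base_lr, lr_power = 0.1, -0.5
w = np.array([0.2, -0.3])   # "variable" W
n = np.array([1.0, 4.0])    # "accum" N(t-1)
z = np.array([0.0, 0.0])    # "linear" Z
g = np.array([0.5, -1.0])   # "gradients" G

n_new = n + g * g                                    # N = N + G^2
sigma = (n_new ** -lr_power - n ** -lr_power) / base_lr
z_new = z + g - sigma * w                            # Z = Z + G - W * sigma
assert np.allclose(sigma, (np.sqrt(n_new) - np.sqrt(n)) / base_lr)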
def _flip_gradient_grad(op, grad):
  """The gradients for `flip_gradient`.

  Args:
    op: The `flip_gradient` `Operation` that we are differentiating, which we
      can use to find the inputs and outputs of the original op.
    grad: Gradient with respect to the output of the `flip_gradient` op.

  Returns:
    Gradients with respect to the input of `flip_gradient`.
  """
  s = op.inputs[1]
  return [math_ops.neg(grad) * s, None]
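# A short NumPy sketch (illustration only) of what the gradient function
# above implies for a gradient-reversal layer: the forward pass is assumed
# to act as the identity on the input, while the backward pass multiplies
# the incoming gradient by -s and provides no gradient for s itself.
# `flip_gradient_backward` is a hypothetical helper, not a library function.
import numpy as np

def flip_gradient_backward(grad, s):
  # Mirrors the return value above: d(input) = -grad * s, None for s.
  return -grad * s, None

upstream_grad = np.array([1.0, -2.0])
dx, ds = flip_gradient_backward(upstream_grad, s=0.5)
assert np.allclose(dx, [-0.5, 1.0]) and ds is None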
def setUp(self):
  self.a = variables.Variable(2.0, name="a")
  self.b = variables.Variable(3.0, name="b")

  self.c = math_ops.mul(self.a, self.b, name="c")  # Should be 6.0.
  self.d = math_ops.mul(self.a, self.a, name="d")  # Should be 4.0.

  self.e = math_ops.mul(self.d, self.c, name="e")  # Should be 24.0.

  self.f_y = constant_op.constant(0.30, name="f_y")
  self.f = math_ops.div(self.b, self.f_y, name="f")  # Should be 10.0.

  # The three nodes x, y and z form a graph with "cross-links": x and y are
  # both direct inputs to z, but x is also a direct input to y.
  self.x = variables.Variable(2.0, name="x")  # Should be 2.0.
  self.y = math_ops.neg(self.x, name="y")  # Should be -2.0.
  self.z = math_ops.mul(self.x, self.y, name="z")  # Should be -4.0.

  self.sess = session.Session()
  self.sess.run(variables.global_variables_initializer())
def training_loss(self, features, labels, data_spec=None,
                  name='training_loss'):
  return math_ops.neg(self.average_size(), name=name)
def training_graph(self, input_data, input_labels, random_seed,
                   data_spec, epoch=None):
  """Constructs a TF graph for training a random tree.

  Args:
    input_data: A tensor or SparseTensor or placeholder for input data.
    input_labels: A tensor or placeholder for labels associated with
      input_data.
    random_seed: The random number generator seed to use for this tree.  0
      means use the current time as the seed.
    data_spec: A list of tf.dtype values specifying the original types of
      each column.
    epoch: A tensor or placeholder for the epoch the training data comes
      from.

  Returns:
    The last op in the random tree training graph.
  """
  epoch = [0] if epoch is None else epoch

  sparse_indices = []
  sparse_values = []
  sparse_shape = []
  if isinstance(input_data, ops.SparseTensor):
    sparse_indices = input_data.indices
    sparse_values = input_data.values
    sparse_shape = input_data.shape
    input_data = []

  # Count extremely random stats.
  (node_sums, node_squares, splits_indices, splits_sums, splits_squares,
   totals_indices, totals_sums, totals_squares,
   input_leaves) = (self.training_ops.count_extremely_random_stats(
       input_data, sparse_indices, sparse_values, sparse_shape, data_spec,
       input_labels, self.variables.tree, self.variables.tree_thresholds,
       self.variables.node_to_accumulator_map,
       self.variables.candidate_split_features,
       self.variables.candidate_split_thresholds,
       self.variables.start_epoch, epoch,
       num_classes=self.params.num_output_columns,
       regression=self.params.regression))
  node_update_ops = []
  node_update_ops.append(
      state_ops.assign_add(self.variables.node_sums, node_sums))

  splits_update_ops = []
  splits_update_ops.append(
      self.training_ops.scatter_add_ndim(self.variables.candidate_split_sums,
                                         splits_indices, splits_sums))
  splits_update_ops.append(
      self.training_ops.scatter_add_ndim(self.variables.accumulator_sums,
                                         totals_indices, totals_sums))

  if self.params.regression:
    node_update_ops.append(
        state_ops.assign_add(self.variables.node_squares, node_squares))
    splits_update_ops.append(
        self.training_ops.scatter_add_ndim(
            self.variables.candidate_split_squares,
            splits_indices, splits_squares))
    splits_update_ops.append(
        self.training_ops.scatter_add_ndim(
            self.variables.accumulator_squares,
            totals_indices, totals_squares))

  # Sample inputs.
  update_indices, feature_updates, threshold_updates = (
      self.training_ops.sample_inputs(
          input_data, sparse_indices, sparse_values, sparse_shape,
          self.variables.node_to_accumulator_map,
          input_leaves, self.variables.candidate_split_features,
          self.variables.candidate_split_thresholds,
          split_initializations_per_input=(
              self.params.split_initializations_per_input),
          split_sampling_random_seed=random_seed))
  update_features_op = state_ops.scatter_update(
      self.variables.candidate_split_features, update_indices,
      feature_updates)
  update_thresholds_op = state_ops.scatter_update(
      self.variables.candidate_split_thresholds, update_indices,
      threshold_updates)

  # Calculate finished nodes.
  with ops.control_dependencies(splits_update_ops):
    children = array_ops.squeeze(
        array_ops.slice(self.variables.tree, [0, 0], [-1, 1]),
        squeeze_dims=[1])
    is_leaf = math_ops.equal(constants.LEAF_NODE, children)
    leaves = math_ops.to_int32(
        array_ops.squeeze(array_ops.where(is_leaf), squeeze_dims=[1]))
    finished, stale = self.training_ops.finished_nodes(
        leaves, self.variables.node_to_accumulator_map,
        self.variables.candidate_split_sums,
        self.variables.candidate_split_squares,
        self.variables.accumulator_sums,
        self.variables.accumulator_squares,
        self.variables.start_epoch, epoch,
        num_split_after_samples=self.params.split_after_samples,
        min_split_samples=self.params.min_split_samples)

  # Update leaf scores.
  non_fertile_leaves = array_ops.boolean_mask(
      leaves, math_ops.less(array_ops.gather(
          self.variables.node_to_accumulator_map, leaves), 0))

  # TODO(gilberth): It should be possible to limit the number of non
  # fertile leaves we calculate scores for, especially since we can only take
  # at most array_ops.shape(finished)[0] of them.
  with ops.control_dependencies(node_update_ops):
    sums = array_ops.gather(self.variables.node_sums, non_fertile_leaves)
    if self.params.regression:
      squares = array_ops.gather(self.variables.node_squares,
                                 non_fertile_leaves)
      non_fertile_leaf_scores = self._variance(sums, squares)
    else:
      non_fertile_leaf_scores = self._weighted_gini(sums)

  # Calculate best splits.
  with ops.control_dependencies(splits_update_ops):
    split_indices = self.training_ops.best_splits(
        finished, self.variables.node_to_accumulator_map,
        self.variables.candidate_split_sums,
        self.variables.candidate_split_squares,
        self.variables.accumulator_sums,
        self.variables.accumulator_squares,
        regression=self.params.regression)

  # Grow tree.
  with ops.control_dependencies([update_features_op, update_thresholds_op]):
    (tree_update_indices, tree_children_updates, tree_threshold_updates,
     tree_depth_updates, new_eot) = (self.training_ops.grow_tree(
         self.variables.end_of_tree, self.variables.tree_depths,
         self.variables.node_to_accumulator_map, finished, split_indices,
         self.variables.candidate_split_features,
         self.variables.candidate_split_thresholds))
    tree_update_op = state_ops.scatter_update(
        self.variables.tree, tree_update_indices, tree_children_updates)
    thresholds_update_op = state_ops.scatter_update(
        self.variables.tree_thresholds, tree_update_indices,
        tree_threshold_updates)
    depth_update_op = state_ops.scatter_update(
        self.variables.tree_depths, tree_update_indices, tree_depth_updates)
    # TODO(thomaswc): Only update the epoch on the new leaves.
    new_epoch_updates = epoch * array_ops.ones_like(tree_depth_updates)
    epoch_update_op = state_ops.scatter_update(
        self.variables.start_epoch, tree_update_indices, new_epoch_updates)

  # Update fertile slots.
  with ops.control_dependencies([depth_update_op]):
    (node_map_updates, accumulators_cleared,
     accumulators_allocated) = (self.training_ops.update_fertile_slots(
         finished, non_fertile_leaves, non_fertile_leaf_scores,
         self.variables.end_of_tree, self.variables.tree_depths,
         self.variables.accumulator_sums,
         self.variables.node_to_accumulator_map, stale,
         max_depth=self.params.max_depth,
         regression=self.params.regression))

  # Ensure end_of_tree doesn't get updated until UpdateFertileSlots has
  # used it to calculate new leaves.
  gated_new_eot, = control_flow_ops.tuple(
      [new_eot], control_inputs=[node_map_updates])
  eot_update_op = state_ops.assign(self.variables.end_of_tree, gated_new_eot)

  updates = []
  updates.append(eot_update_op)
  updates.append(tree_update_op)
  updates.append(thresholds_update_op)
  updates.append(epoch_update_op)

  updates.append(state_ops.scatter_update(
      self.variables.node_to_accumulator_map,
      array_ops.squeeze(array_ops.slice(node_map_updates, [0, 0], [1, -1]),
                        squeeze_dims=[0]),
      array_ops.squeeze(array_ops.slice(node_map_updates, [1, 0], [1, -1]),
                        squeeze_dims=[0])))

  cleared_and_allocated_accumulators = array_ops.concat(
      0, [accumulators_cleared, accumulators_allocated])

  # Calculate values to put into scatter update for candidate counts.
  # Candidate split counts are always reset back to 0 for both cleared
  # and allocated accumulators. This means some accumulators might be doubly
  # reset to 0 if they were released and not allocated, then later allocated.
  split_values = array_ops.tile(
      array_ops.expand_dims(array_ops.expand_dims(
          array_ops.zeros_like(cleared_and_allocated_accumulators,
                               dtype=dtypes.float32), 1), 2),
      [1, self.params.num_splits_to_consider,
       self.params.num_output_columns])
  updates.append(state_ops.scatter_update(
      self.variables.candidate_split_sums,
      cleared_and_allocated_accumulators, split_values))
  if self.params.regression:
    updates.append(state_ops.scatter_update(
        self.variables.candidate_split_squares,
        cleared_and_allocated_accumulators, split_values))

  # Calculate values to put into scatter update for total counts.
  total_cleared = array_ops.tile(
      array_ops.expand_dims(
          math_ops.neg(array_ops.ones_like(accumulators_cleared,
                                           dtype=dtypes.float32)), 1),
      [1, self.params.num_output_columns])
  total_reset = array_ops.tile(
      array_ops.expand_dims(
          array_ops.zeros_like(accumulators_allocated,
                               dtype=dtypes.float32), 1),
      [1, self.params.num_output_columns])
  accumulator_updates = array_ops.concat(0, [total_cleared, total_reset])
  updates.append(state_ops.scatter_update(
      self.variables.accumulator_sums,
      cleared_and_allocated_accumulators, accumulator_updates))
  if self.params.regression:
    updates.append(state_ops.scatter_update(
        self.variables.accumulator_squares,
        cleared_and_allocated_accumulators, accumulator_updates))

  # Calculate values to put into scatter update for candidate splits.
  split_features_updates = array_ops.tile(
      array_ops.expand_dims(math_ops.neg(array_ops.ones_like(
          cleared_and_allocated_accumulators)), 1),
      [1, self.params.num_splits_to_consider])
  updates.append(state_ops.scatter_update(
      self.variables.candidate_split_features,
      cleared_and_allocated_accumulators, split_features_updates))

  updates += self.finish_iteration()

  return control_flow_ops.group(*updates)
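# A compact NumPy sketch (illustration only) of the expand_dims/tile pattern
# above: one reset row per accumulator, -1 for cleared slots and 0 for newly
# allocated ones, concatenated to feed a single scatter update. The names
# below are hypothetical stand-ins for the tensors in training_graph.
import numpy as np

accumulators_cleared = np.array([3, 7])    # accumulator ids being released
accumulators_allocated = np.array([1])     # accumulator ids being assigned
num_output_columns = 4

total_cleared = np.tile(
    np.expand_dims(-np.ones_like(accumulators_cleared, dtype=np.float32), 1),
    [1, num_output_columns])               # rows of -1
total_reset = np.tile(
    np.expand_dims(np.zeros_like(accumulators_allocated, dtype=np.float32), 1),
    [1, num_output_columns])               # rows of 0
accumulator_updates = np.concatenate([total_cleared, total_reset], axis=0)
assert accumulator_updates.shape == (3, num_output_columns)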
def validation_loss(self, features, labels):
  return math_ops.neg(self.average_size())
def training_loss(self):
  return math_ops.neg(self.average_size())
def training_loss(self, features, labels):
  return math_ops.neg(self.average_size())
def training_graph(self, input_data, input_labels, random_seed,
                   data_spec, input_weights=None):
  """Constructs a TF graph for training a random tree.

  Args:
    input_data: A tensor or SparseTensor or placeholder for input data.
    input_labels: A tensor or placeholder for labels associated with
      input_data.
    random_seed: The random number generator seed to use for this tree.  0
      means use the current time as the seed.
    data_spec: A list of tf.dtype values specifying the original types of
      each column.
    input_weights: A float tensor or placeholder holding per-input weights,
      or None if all inputs are to be weighted equally.

  Returns:
    The last op in the random tree training graph.
  """
  epoch = math_ops.to_int32(get_epoch_variable())

  if input_weights is None:
    input_weights = []

  sparse_indices = []
  sparse_values = []
  sparse_shape = []
  if isinstance(input_data, sparse_tensor.SparseTensor):
    sparse_indices = input_data.indices
    sparse_values = input_data.values
    sparse_shape = input_data.dense_shape
    input_data = []

  # Count extremely random stats.
  (node_sums, node_squares, splits_indices, splits_sums, splits_squares,
   totals_indices, totals_sums, totals_squares,
   input_leaves) = (self.training_ops.count_extremely_random_stats(
       input_data, sparse_indices, sparse_values, sparse_shape, data_spec,
       input_labels, input_weights, self.variables.tree,
       self.variables.tree_thresholds,
       self.variables.node_to_accumulator_map,
       self.variables.candidate_split_features,
       self.variables.candidate_split_thresholds,
       self.variables.start_epoch, epoch,
       num_classes=self.params.num_output_columns,
       regression=self.params.regression))
  node_update_ops = []
  node_update_ops.append(
      state_ops.assign_add(self.variables.node_sums, node_sums))

  splits_update_ops = []
  splits_update_ops.append(
      self.training_ops.scatter_add_ndim(self.variables.candidate_split_sums,
                                         splits_indices, splits_sums))
  splits_update_ops.append(
      self.training_ops.scatter_add_ndim(self.variables.accumulator_sums,
                                         totals_indices, totals_sums))

  if self.params.regression:
    node_update_ops.append(
        state_ops.assign_add(self.variables.node_squares, node_squares))
    splits_update_ops.append(
        self.training_ops.scatter_add_ndim(
            self.variables.candidate_split_squares,
            splits_indices, splits_squares))
    splits_update_ops.append(
        self.training_ops.scatter_add_ndim(
            self.variables.accumulator_squares,
            totals_indices, totals_squares))

  # Sample inputs.
  update_indices, feature_updates, threshold_updates = (
      self.training_ops.sample_inputs(
          input_data, sparse_indices, sparse_values, sparse_shape,
          input_weights,
          self.variables.node_to_accumulator_map,
          input_leaves, self.variables.candidate_split_features,
          self.variables.candidate_split_thresholds,
          split_initializations_per_input=(
              self.params.split_initializations_per_input),
          split_sampling_random_seed=random_seed))
  update_features_op = state_ops.scatter_update(
      self.variables.candidate_split_features, update_indices,
      feature_updates)
  update_thresholds_op = state_ops.scatter_update(
      self.variables.candidate_split_thresholds, update_indices,
      threshold_updates)

  # Calculate finished nodes.
  with ops.control_dependencies(splits_update_ops):
    # Passing input_leaves to finished nodes here means that nodes that
    # have become stale won't be deallocated until an input reaches them,
    # because we're trying to avoid considering every fertile node for
    # performance reasons.
    finished, stale = self.training_ops.finished_nodes(
        input_leaves, self.variables.node_to_accumulator_map,
        self.variables.candidate_split_sums,
        self.variables.candidate_split_squares,
        self.variables.accumulator_sums,
        self.variables.accumulator_squares,
        self.variables.start_epoch, epoch,
        num_split_after_samples=self.params.split_after_samples,
        min_split_samples=self.params.min_split_samples,
        dominate_method=self.params.dominate_method,
        dominate_fraction=self.params.dominate_fraction)

  # Update leaf scores.
  # TODO(thomaswc): Store the leaf scores in a TopN and only update the
  # scores of the leaves that were touched by this batch of input.
  children = array_ops.squeeze(
      array_ops.slice(self.variables.tree, [0, 0], [-1, 1]),
      squeeze_dims=[1])
  is_leaf = math_ops.equal(constants.LEAF_NODE, children)
  leaves = math_ops.to_int32(
      array_ops.squeeze(array_ops.where(is_leaf), squeeze_dims=[1]))
  non_fertile_leaves = array_ops.boolean_mask(
      leaves, math_ops.less(array_ops.gather(
          self.variables.node_to_accumulator_map, leaves), 0))

  # TODO(gilberth): It should be possible to limit the number of non
  # fertile leaves we calculate scores for, especially since we can only take
  # at most array_ops.shape(finished)[0] of them.
  with ops.control_dependencies(node_update_ops):
    sums = array_ops.gather(self.variables.node_sums, non_fertile_leaves)
    if self.params.regression:
      squares = array_ops.gather(self.variables.node_squares,
                                 non_fertile_leaves)
      non_fertile_leaf_scores = self._variance(sums, squares)
    else:
      non_fertile_leaf_scores = self._weighted_gini(sums)

  # Calculate best splits.
  with ops.control_dependencies(splits_update_ops):
    split_indices = self.training_ops.best_splits(
        finished, self.variables.node_to_accumulator_map,
        self.variables.candidate_split_sums,
        self.variables.candidate_split_squares,
        self.variables.accumulator_sums,
        self.variables.accumulator_squares,
        regression=self.params.regression)

  # Grow tree.
  with ops.control_dependencies([update_features_op, update_thresholds_op]):
    (tree_update_indices, tree_children_updates, tree_threshold_updates,
     new_eot) = (self.training_ops.grow_tree(
         self.variables.end_of_tree, self.variables.node_to_accumulator_map,
         finished, split_indices,
         self.variables.candidate_split_features,
         self.variables.candidate_split_thresholds))
    tree_update_op = state_ops.scatter_update(
        self.variables.tree, tree_update_indices, tree_children_updates)
    thresholds_update_op = state_ops.scatter_update(
        self.variables.tree_thresholds, tree_update_indices,
        tree_threshold_updates)
    # TODO(thomaswc): Only update the epoch on the new leaves.
    new_epoch_updates = epoch * array_ops.ones_like(
        tree_threshold_updates, dtype=dtypes.int32)
    epoch_update_op = state_ops.scatter_update(
        self.variables.start_epoch, tree_update_indices, new_epoch_updates)

  # Update fertile slots.
  with ops.control_dependencies([tree_update_op]):
    (n2a_map_updates, a2n_map_updates, accumulators_cleared,
     accumulators_allocated) = (self.training_ops.update_fertile_slots(
         finished, non_fertile_leaves, non_fertile_leaf_scores,
         self.variables.end_of_tree, self.variables.accumulator_sums,
         self.variables.node_to_accumulator_map, stale,
         self.variables.node_sums, regression=self.params.regression))

  # Ensure end_of_tree doesn't get updated until UpdateFertileSlots has
  # used it to calculate new leaves.
  gated_new_eot, = control_flow_ops.tuple(
      [new_eot], control_inputs=[n2a_map_updates])
  eot_update_op = state_ops.assign(self.variables.end_of_tree, gated_new_eot)

  updates = []
  updates.append(eot_update_op)
  updates.append(tree_update_op)
  updates.append(thresholds_update_op)
  updates.append(epoch_update_op)

  updates.append(state_ops.scatter_update(
      self.variables.node_to_accumulator_map,
      n2a_map_updates[0], n2a_map_updates[1]))
  updates.append(state_ops.scatter_update(
      self.variables.accumulator_to_node_map,
      a2n_map_updates[0], a2n_map_updates[1]))

  cleared_and_allocated_accumulators = array_ops.concat(
      0, [accumulators_cleared, accumulators_allocated])

  # Calculate values to put into scatter update for candidate counts.
  # Candidate split counts are always reset back to 0 for both cleared
  # and allocated accumulators. This means some accumulators might be doubly
  # reset to 0 if they were released and not allocated, then later allocated.
  split_values = array_ops.tile(
      array_ops.expand_dims(array_ops.expand_dims(
          array_ops.zeros_like(cleared_and_allocated_accumulators,
                               dtype=dtypes.float32), 1), 2),
      [1, self.params.num_splits_to_consider,
       self.params.num_output_columns])
  updates.append(state_ops.scatter_update(
      self.variables.candidate_split_sums,
      cleared_and_allocated_accumulators, split_values))
  if self.params.regression:
    updates.append(state_ops.scatter_update(
        self.variables.candidate_split_squares,
        cleared_and_allocated_accumulators, split_values))

  # Calculate values to put into scatter update for total counts.
  total_cleared = array_ops.tile(
      array_ops.expand_dims(
          math_ops.neg(array_ops.ones_like(accumulators_cleared,
                                           dtype=dtypes.float32)), 1),
      [1, self.params.num_output_columns])
  total_reset = array_ops.tile(
      array_ops.expand_dims(
          array_ops.zeros_like(accumulators_allocated,
                               dtype=dtypes.float32), 1),
      [1, self.params.num_output_columns])
  accumulator_updates = array_ops.concat(0, [total_cleared, total_reset])
  updates.append(state_ops.scatter_update(
      self.variables.accumulator_sums,
      cleared_and_allocated_accumulators, accumulator_updates))
  if self.params.regression:
    updates.append(state_ops.scatter_update(
        self.variables.accumulator_squares,
        cleared_and_allocated_accumulators, accumulator_updates))

  # Calculate values to put into scatter update for candidate splits.
  split_features_updates = array_ops.tile(
      array_ops.expand_dims(math_ops.neg(array_ops.ones_like(
          cleared_and_allocated_accumulators)), 1),
      [1, self.params.num_splits_to_consider])
  updates.append(state_ops.scatter_update(
      self.variables.candidate_split_features,
      cleared_and_allocated_accumulators, split_features_updates))

  updates += self.finish_iteration()

  return control_flow_ops.group(*updates)