def testMissingLabel(self):
  """Classification stats with one input labelled -1 (new input_spec API).

  The first column of node 0's sums is expected to be 3 for the 4 inputs,
  i.e. the -1-labelled input is not counted.
  """
  with self.test_session():
    labels = [0, 1, -1, 3]
    results = tensor_forest_ops.count_extremely_random_stats(
        self.input_data, [], [], [], labels, [], self.tree,
        self.tree_thresholds, self.node_map, self.split_features,
        self.split_thresholds, self.epochs, self.current_epoch,
        input_spec=self.data_spec, num_classes=5, regression=False)
    (node_sums, _, split_indices, split_sums, _, total_indices,
     total_sums, _, leaf_ids) = results

    expected_node_sums = [[3., 1., 1., 0., 1.],
                          [2., 1., 1., 0., 0.],
                          [1., 0., 0., 0., 1.]]
    self.assertAllEqual(expected_node_sums, node_sums.eval())
    self.assertAllEqual([[0, 0, 0], [0, 0, 1]], split_indices.eval())
    self.assertAllEqual([1., 1.], split_sums.eval())
    self.assertAllEqual([[0, 2], [0, 0], [0, 1]], total_indices.eval())
    self.assertAllEqual([1., 2., 1.], total_sums.eval())
    self.assertAllEqual([1, 1, 2, 2], leaf_ids.eval())
def testSimpleWeighted(self):
  """Weighted regression stats, passing data_spec positionally (old API)."""
  weights = [1.0, 2.0, 3.0, 4.0]
  with self.test_session():
    results = tensor_forest_ops.count_extremely_random_stats(
        self.input_data, [], [], [], self.data_spec, self.input_labels,
        weights, self.tree, self.tree_thresholds, self.node_map,
        self.split_features, self.split_thresholds, self.epochs,
        self.current_epoch, num_classes=2, regression=True)
    (node_sums, node_squares, split_indices, split_sums, split_squares,
     total_indices, total_sums, total_squares, leaf_ids) = results

    self.assertAllEqual([[10., 33.], [3., 15.], [7., 18.]],
                        node_sums.eval())
    self.assertAllEqual([[10., 129.], [3., 81.], [7., 48.]],
                        node_squares.eval())
    self.assertAllEqual([[0, 0]], split_indices.eval())
    self.assertAllEqual([[1., 3.]], split_sums.eval())
    self.assertAllEqual([[1., 9.]], split_squares.eval())
    self.assertAllEqual([[0]], total_indices.eval())
    self.assertAllEqual([[2., 9.]], total_sums.eval())
    self.assertAllEqual([[2., 45.]], total_squares.eval())
    self.assertAllEqual([1, 1, 2, 2], leaf_ids.eval())
def testSimpleWeighted(self):
  """Weighted regression stats, passing the spec via input_spec=."""
  weights = [1.0, 2.0, 3.0, 4.0]
  with self.test_session():
    results = tensor_forest_ops.count_extremely_random_stats(
        self.input_data, [], [], [], self.input_labels, weights,
        self.tree, self.tree_thresholds, self.node_map,
        self.split_features, self.split_thresholds, self.epochs,
        self.current_epoch, input_spec=self.data_spec, num_classes=2,
        regression=True)
    (node_sums, node_squares, split_indices, split_sums, split_squares,
     total_indices, total_sums, total_squares, leaf_ids) = results

    self.assertAllEqual([[10., 33.], [3., 15.], [7., 18.]],
                        node_sums.eval())
    self.assertAllEqual([[10., 129.], [3., 81.], [7., 48.]],
                        node_squares.eval())
    self.assertAllEqual([[0, 0]], split_indices.eval())
    self.assertAllEqual([[1., 3.]], split_sums.eval())
    self.assertAllEqual([[1., 9.]], split_squares.eval())
    self.assertAllEqual([[0]], total_indices.eval())
    self.assertAllEqual([[2., 9.]], total_sums.eval())
    self.assertAllEqual([[2., 45.]], total_squares.eval())
    self.assertAllEqual([1, 1, 2, 2], leaf_ids.eval())
def testNoAccumulators(self):
  """Classification stats when every node's accumulator index is -1.

  With node_to_accumulator set to [-1] * 3, only per-node sums are expected;
  the split and total outputs are empty with the shapes asserted below.
  """
  with self.test_session():
    (pcw_node_sums, _, pcw_splits_indices, pcw_splits_sums, _,
     pcw_totals_indices, pcw_totals_sums, _,
     leaves) = (tensor_forest_ops.count_extremely_random_stats(
         self.input_data, [], [], [], self.data_spec, self.input_labels, [],
         self.tree, self.tree_thresholds, [-1] * 3,
         self.split_features, self.split_thresholds, self.epochs,
         self.current_epoch,
         num_classes=5, regression=False))

    self.assertAllEqual([[4., 1., 1., 1., 1.], [2., 1., 1., 0., 0.],
                         [2., 0., 0., 1., 1.]],
                        pcw_node_sums.eval())
    # assertEqual instead of the deprecated assertEquals alias, which was
    # removed from unittest in Python 3.12.
    self.assertEqual((0, 3), pcw_splits_indices.eval().shape)
    self.assertAllEqual([], pcw_splits_sums.eval())
    self.assertEqual((0, 2), pcw_totals_indices.eval().shape)
    self.assertAllEqual([], pcw_totals_sums.eval())
    self.assertAllEqual([1, 1, 2, 2], leaves.eval())
def testSimpleWeighted(self):
  """Weighted classification stats, passing data_spec positionally."""
  weights = [1.5, 2.0, 3.0, 4.0]
  with self.test_session():
    outputs = tensor_forest_ops.count_extremely_random_stats(
        self.input_data, [], [], [], self.data_spec, self.input_labels,
        weights, self.tree, self.tree_thresholds, self.node_map,
        self.split_features, self.split_thresholds, self.epochs,
        self.current_epoch, num_classes=5, regression=False)
    node_sums, split_indices, split_sums = outputs[0], outputs[2], outputs[3]
    total_indices, total_sums, leaf_ids = outputs[5], outputs[6], outputs[8]

    self.assertAllEqual([[10.5, 1.5, 2., 3., 4.],
                         [3.5, 1.5, 2., 0., 0.],
                         [7., 0., 0., 3., 4.]], node_sums.eval())
    self.assertAllEqual([[0, 0, 0], [0, 0, 1]], split_indices.eval())
    self.assertAllEqual([1.5, 1.5], split_sums.eval())
    self.assertAllEqual([[0, 2], [0, 0], [0, 1]], total_indices.eval())
    self.assertAllEqual([2., 3.5, 1.5], total_sums.eval())
    self.assertAllEqual([1, 1, 2, 2], leaf_ids.eval())
def testMissingLabel(self):
  """Classification stats with one -1 label, data_spec passed positionally.

  The first column of node 0's sums is expected to be 3 for the 4 inputs,
  i.e. the -1-labelled input is not counted.
  """
  with self.test_session():
    partial_labels = [0, 1, -1, 3]
    results = tensor_forest_ops.count_extremely_random_stats(
        self.input_data, [], [], [], self.data_spec, partial_labels, [],
        self.tree, self.tree_thresholds, self.node_map,
        self.split_features, self.split_thresholds, self.epochs,
        self.current_epoch, num_classes=5, regression=False)
    (node_sums, _, split_indices, split_sums, _, total_indices,
     total_sums, _, leaf_ids) = results

    expected_node_sums = [[3., 1., 1., 0., 1.],
                          [2., 1., 1., 0., 0.],
                          [1., 0., 0., 0., 1.]]
    self.assertAllEqual(expected_node_sums, node_sums.eval())
    self.assertAllEqual([[0, 0, 0], [0, 0, 1]], split_indices.eval())
    self.assertAllEqual([1., 1.], split_sums.eval())
    self.assertAllEqual([[0, 2], [0, 0], [0, 1]], total_indices.eval())
    self.assertAllEqual([1., 2., 1.], total_sums.eval())
    self.assertAllEqual([1, 1, 2, 2], leaf_ids.eval())
def testSimple(self):
  """Unweighted regression stats, passing data_spec positionally."""
  with self.test_session():
    results = tensor_forest_ops.count_extremely_random_stats(
        self.input_data, [], [], [], self.data_spec, self.input_labels, [],
        self.tree, self.tree_thresholds, self.node_map,
        self.split_features, self.split_thresholds, self.epochs,
        self.current_epoch, num_classes=2, regression=True)
    (node_sums, node_squares, split_indices, split_sums, split_squares,
     total_indices, total_sums, total_squares, leaf_ids) = results

    self.assertAllEqual([[4., 14.], [2., 9.], [2., 5.]], node_sums.eval())
    self.assertAllEqual([[4., 58.], [2., 45.], [2., 13.]],
                        node_squares.eval())
    self.assertAllEqual([[0, 0]], split_indices.eval())
    self.assertAllEqual([[1., 3.]], split_sums.eval())
    self.assertAllEqual([[1., 9.]], split_squares.eval())
    self.assertAllEqual([[0]], total_indices.eval())
    self.assertAllEqual([[2., 9.]], total_sums.eval())
    self.assertAllEqual([[2., 45.]], total_squares.eval())
    self.assertAllEqual([1, 1, 2, 2], leaf_ids.eval())
def testThreaded(self):
  """Classification stats in a session with 2 intra-op threads (old API)."""
  session_config = tf.ConfigProto(intra_op_parallelism_threads=2)
  with self.test_session(config=session_config):
    results = tensor_forest_ops.count_extremely_random_stats(
        self.input_data, [], [], [], self.data_spec, self.input_labels, [],
        self.tree, self.tree_thresholds, self.node_map,
        self.split_features, self.split_thresholds, self.epochs,
        self.current_epoch, num_classes=5, regression=False)
    (node_sums, _, split_indices, split_sums, _, total_indices,
     total_sums, _, leaf_ids) = results

    self.assertAllEqual([[4., 1., 1., 1., 1.],
                         [2., 1., 1., 0., 0.],
                         [2., 0., 0., 1., 1.]], node_sums.eval())
    self.assertAllEqual([[0, 0, 0], [0, 0, 1]], split_indices.eval())
    self.assertAllEqual([1., 1.], split_sums.eval())
    self.assertAllEqual([[0, 2], [0, 0], [0, 1]], total_indices.eval())
    self.assertAllEqual([1., 2., 1.], total_sums.eval())
    self.assertAllEqual([1, 1, 2, 2], leaf_ids.eval())
def testSparseInput(self):
  """Classification stats from sparse features only (empty dense input)."""
  dense_shape = [4, 10]
  coords = [[0, 0], [0, 4], [0, 9], [1, 0], [1, 7], [2, 0], [3, 1], [3, 4]]
  vals = [3.0, -1.0, 0.5, 1.5, 6.0, -2.0, -0.5, 2.0]
  with self.test_session():
    results = tensor_forest_ops.count_extremely_random_stats(
        [], coords, vals, dense_shape, self.data_spec, self.input_labels, [],
        self.tree, self.tree_thresholds, self.node_map,
        self.split_features, self.split_thresholds, self.epochs,
        self.current_epoch, num_classes=5, regression=False)
    (node_sums, _, split_indices, split_sums, _, total_indices,
     total_sums, _, leaf_ids) = results

    self.assertAllEqual([[4., 1., 1., 1., 1.],
                         [2., 0., 0., 1., 1.],
                         [2., 1., 1., 0., 0.]], node_sums.eval())
    self.assertAllEqual([[0, 0, 4], [0, 0, 0], [0, 0, 3]],
                        split_indices.eval())
    self.assertAllEqual([1., 2., 1.], split_sums.eval())
    self.assertAllEqual([[0, 4], [0, 0], [0, 3]], total_indices.eval())
    self.assertAllEqual([1., 2., 1.], total_sums.eval())
    self.assertAllEqual([2, 2, 1, 1], leaf_ids.eval())
def testSimple(self):
  """Unweighted regression stats, passing the spec via input_spec=."""
  with self.test_session():
    results = tensor_forest_ops.count_extremely_random_stats(
        self.input_data, [], [], [], self.input_labels, [], self.tree,
        self.tree_thresholds, self.node_map, self.split_features,
        self.split_thresholds, self.epochs, self.current_epoch,
        input_spec=self.data_spec, num_classes=2, regression=True)
    (node_sums, node_squares, split_indices, split_sums, split_squares,
     total_indices, total_sums, total_squares, leaf_ids) = results

    self.assertAllEqual([[4., 14.], [2., 9.], [2., 5.]], node_sums.eval())
    self.assertAllEqual([[4., 58.], [2., 45.], [2., 13.]],
                        node_squares.eval())
    self.assertAllEqual([[0, 0]], split_indices.eval())
    self.assertAllEqual([[1., 3.]], split_sums.eval())
    self.assertAllEqual([[1., 9.]], split_squares.eval())
    self.assertAllEqual([[0]], total_indices.eval())
    self.assertAllEqual([[2., 9.]], total_sums.eval())
    self.assertAllEqual([[2., 45.]], total_squares.eval())
    self.assertAllEqual([1, 1, 2, 2], leaf_ids.eval())
def testSimpleWeighted(self):
  """Weighted classification stats, passing the spec via input_spec=."""
  weights = [1.5, 2.0, 3.0, 4.0]
  with self.test_session():
    outputs = tensor_forest_ops.count_extremely_random_stats(
        self.input_data, [], [], [], self.input_labels, weights,
        self.tree, self.tree_thresholds, self.node_map,
        self.split_features, self.split_thresholds, self.epochs,
        self.current_epoch, input_spec=self.data_spec, num_classes=5,
        regression=False)
    node_sums, split_indices, split_sums = outputs[0], outputs[2], outputs[3]
    total_indices, total_sums, leaf_ids = outputs[5], outputs[6], outputs[8]

    self.assertAllEqual([[10.5, 1.5, 2., 3., 4.],
                         [3.5, 1.5, 2., 0., 0.],
                         [7., 0., 0., 3., 4.]], node_sums.eval())
    self.assertAllEqual([[0, 0, 0], [0, 0, 1]], split_indices.eval())
    self.assertAllEqual([1.5, 1.5], split_sums.eval())
    self.assertAllEqual([[0, 2], [0, 0], [0, 1]], total_indices.eval())
    self.assertAllEqual([2., 3.5, 1.5], total_sums.eval())
    self.assertAllEqual([1, 1, 2, 2], leaf_ids.eval())
def testThreaded(self):
  """Classification stats in a session with 2 intra-op threads (new API)."""
  session_config = config_pb2.ConfigProto(intra_op_parallelism_threads=2)
  with self.test_session(config=session_config):
    results = tensor_forest_ops.count_extremely_random_stats(
        self.input_data, [], [], [], self.input_labels, [], self.tree,
        self.tree_thresholds, self.node_map, self.split_features,
        self.split_thresholds, self.epochs, self.current_epoch,
        input_spec=self.data_spec, num_classes=5, regression=False)
    (node_sums, _, split_indices, split_sums, _, total_indices,
     total_sums, _, leaf_ids) = results

    self.assertAllEqual([[4., 1., 1., 1., 1.],
                         [2., 1., 1., 0., 0.],
                         [2., 0., 0., 1., 1.]], node_sums.eval())
    self.assertAllEqual([[0, 0, 0], [0, 0, 1]], split_indices.eval())
    self.assertAllEqual([1., 1.], split_sums.eval())
    self.assertAllEqual([[0, 2], [0, 0], [0, 1]], total_indices.eval())
    self.assertAllEqual([1., 2., 1.], total_sums.eval())
    self.assertAllEqual([1, 1, 2, 2], leaf_ids.eval())
def testSparseInput(self):
  """Classification stats from sparse-only input with a one-feature spec.

  The spec declares a single sparse float feature 'f1' of size -1 and no
  dense features; it is serialized and passed as input_spec.
  """
  dense_shape = [4, 10]
  coords = [[0, 0], [0, 4], [0, 9], [1, 1], [1, 7], [2, 0], [3, 0], [3, 4]]
  vals = [3.0, -1.0, 0.5, -1.5, 6.0, -2.0, -0.5, 2.0]

  spec = data_ops.TensorForestDataSpec()
  sparse_feature = spec.sparse.add()
  sparse_feature.name = 'f1'
  sparse_feature.original_type = data_ops.DATA_FLOAT
  sparse_feature.size = -1
  spec.dense_features_size = 0
  serialized_spec = spec.SerializeToString()

  with self.test_session():
    results = tensor_forest_ops.count_extremely_random_stats(
        [], coords, vals, dense_shape, self.input_labels, [], self.tree,
        self.tree_thresholds, self.node_map, self.split_features,
        self.split_thresholds, self.epochs, self.current_epoch,
        input_spec=serialized_spec, num_classes=5, regression=False)
    (node_sums, _, split_indices, split_sums, _, total_indices,
     total_sums, _, leaf_ids) = results

    self.assertAllEqual([[4., 1., 1., 1., 1.],
                         [2., 0., 0., 1., 1.],
                         [2., 1., 1., 0., 0.]], node_sums.eval())
    self.assertAllEqual([[0, 0, 4], [0, 0, 0], [0, 0, 3]],
                        split_indices.eval())
    self.assertAllEqual([1., 2., 1.], split_sums.eval())
    self.assertAllEqual([[0, 4], [0, 0], [0, 3]], total_indices.eval())
    self.assertAllEqual([1., 2., 1.], total_sums.eval())
    self.assertAllEqual([2, 2, 1, 1], leaf_ids.eval())
def testFutureEpoch(self):
  """Stats with current_epoch = [3], data_spec passed positionally.

  All node sums are expected to be zero and no split/total sums produced
  (presumably because the nodes' epochs do not match current_epoch).
  """
  epoch_now = [3]
  with self.test_session():
    results = tensor_forest_ops.count_extremely_random_stats(
        self.input_data, [], [], [], self.data_spec, self.input_labels, [],
        self.tree, self.tree_thresholds, self.node_map,
        self.split_features, self.split_thresholds, self.epochs,
        epoch_now, num_classes=5, regression=False)
    node_sums, split_sums, total_sums, leaf_ids = (
        results[0], results[3], results[6], results[8])

    all_zero = [[0., 0., 0., 0., 0.] for _ in range(3)]
    self.assertAllEqual(all_zero, node_sums.eval())
    self.assertAllEqual([], split_sums.eval())
    self.assertAllEqual([], total_sums.eval())
    self.assertAllEqual([1, 1, 2, 2], leaf_ids.eval())
def testBadInput(self):
  """Expects an op error after dropping the last entry of node_map."""
  del self.node_map[-1]
  expected_message = ('Number of nodes should be the same in '
                      'tree, tree_thresholds, node_to_accumulator, and birth_epoch.')
  with self.test_session():
    with self.assertRaisesOpError(expected_message):
      results = tensor_forest_ops.count_extremely_random_stats(
          self.input_data, [], [], [], self.input_labels, [], self.tree,
          self.tree_thresholds, self.node_map, self.split_features,
          self.split_thresholds, self.epochs, self.current_epoch,
          input_spec=self.data_spec, num_classes=5, regression=False)
      node_sums, _, _, _, _, _, _, _, _ = results
      self.assertAllEqual([], node_sums.eval())
def testFutureEpoch(self):
  """Stats with current_epoch = [3], spec passed via input_spec=.

  All node sums are expected to be zero and no split/total sums produced
  (presumably because the nodes' epochs do not match current_epoch).
  """
  epoch_now = [3]
  with self.test_session():
    results = tensor_forest_ops.count_extremely_random_stats(
        self.input_data, [], [], [], self.input_labels, [], self.tree,
        self.tree_thresholds, self.node_map, self.split_features,
        self.split_thresholds, self.epochs, epoch_now,
        input_spec=self.data_spec, num_classes=5, regression=False)
    node_sums, split_sums, total_sums, leaf_ids = (
        results[0], results[3], results[6], results[8])

    all_zero = [[0., 0., 0., 0., 0.] for _ in range(3)]
    self.assertAllEqual(all_zero, node_sums.eval())
    self.assertAllEqual([], split_sums.eval())
    self.assertAllEqual([], total_sums.eval())
    self.assertAllEqual([1, 1, 2, 2], leaf_ids.eval())
def training_graph(self,
                   input_data,
                   input_labels,
                   random_seed,
                   data_spec,
                   sparse_features=None,
                   input_weights=None):
  """Constructs a TF graph for training a random tree.

  Args:
    input_data: A tensor or placeholder for input data, or None, in which
      case an empty list is passed to the ops in its place.
    input_labels: A tensor or placeholder for labels associated with
      input_data.
    random_seed: The random number generator seed to use for this tree.  0
      means use the current time as the seed.
    data_spec: A data_ops.TensorForestDataSpec object specifying the
      original features/columns of the data.  It is serialized and passed to
      the ops as input_spec.
    sparse_features: A tf.SparseTensor for sparse input data, or None.
    input_weights: A float tensor or placeholder holding per-input weights,
      or None if all inputs are to be weighted equally.

  Returns:
    A single op, built with control_flow_ops.group, that runs all of the
    variable updates of one training iteration, including the ops returned
    by self.finish_iteration().
  """
  # Current epoch as int32.  It is passed to the stats ops below together
  # with self.variables.start_epoch, and written into start_epoch for the
  # nodes touched by grow_tree.
  epoch = math_ops.to_int32(get_epoch_variable())

  serialized_input_spec = data_spec.SerializeToString()

  # Absent optional inputs are passed to the ops below as empty lists.
  if input_weights is None:
    input_weights = []

  if input_data is None:
    input_data = []

  sparse_indices = []
  sparse_values = []
  sparse_shape = []
  if sparse_features is not None:
    sparse_indices = sparse_features.indices
    sparse_values = sparse_features.values
    sparse_shape = sparse_features.dense_shape

  # Count extremely random stats.
  (node_sums, node_squares, splits_indices, splits_sums, splits_squares,
   totals_indices, totals_sums, totals_squares,
   input_leaves) = (tensor_forest_ops.count_extremely_random_stats(
       input_data,
       sparse_indices,
       sparse_values,
       sparse_shape,
       input_labels,
       input_weights,
       self.variables.tree,
       self.variables.tree_thresholds,
       self.variables.node_to_accumulator_map,
       self.variables.candidate_split_features,
       self.variables.candidate_split_thresholds,
       self.variables.start_epoch,
       epoch,
       input_spec=serialized_input_spec,
       num_classes=self.params.num_output_columns,
       regression=self.params.regression))
  # Per-node sums are added densely; candidate-split and accumulator sums
  # are scatter-added at the indices reported by the op.
  node_update_ops = []
  node_update_ops.append(
      state_ops.assign_add(self.variables.node_sums, node_sums))

  splits_update_ops = []
  splits_update_ops.append(
      tensor_forest_ops.scatter_add_ndim(self.variables.candidate_split_sums,
                                         splits_indices, splits_sums))
  splits_update_ops.append(
      tensor_forest_ops.scatter_add_ndim(self.variables.accumulator_sums,
                                         totals_indices, totals_sums))

  # The *_squares statistics are only updated for regression.
  if self.params.regression:
    node_update_ops.append(state_ops.assign_add(self.variables.node_squares,
                                                node_squares))
    splits_update_ops.append(
        tensor_forest_ops.scatter_add_ndim(
            self.variables.candidate_split_squares,
            splits_indices, splits_squares))
    splits_update_ops.append(
        tensor_forest_ops.scatter_add_ndim(self.variables.accumulator_squares,
                                           totals_indices, totals_squares))

  # Sample inputs.
  # sample_inputs proposes new candidate split features/thresholds; they are
  # written into the candidate variables at update_indices.
  update_indices, feature_updates, threshold_updates = (
      tensor_forest_ops.sample_inputs(
          input_data,
          sparse_indices,
          sparse_values,
          sparse_shape,
          input_weights,
          self.variables.node_to_accumulator_map,
          input_leaves,
          self.variables.candidate_split_features,
          self.variables.candidate_split_thresholds,
          input_spec=serialized_input_spec,
          split_initializations_per_input=(
              self.params.split_initializations_per_input),
          split_sampling_random_seed=random_seed))
  update_features_op = state_ops.scatter_update(
      self.variables.candidate_split_features, update_indices,
      feature_updates)
  update_thresholds_op = state_ops.scatter_update(
      self.variables.candidate_split_thresholds, update_indices,
      threshold_updates)

  # Calculate finished nodes.
  # The control dependency makes finished_nodes run only after this batch's
  # scatter-adds into the split/accumulator variables.
  with ops.control_dependencies(splits_update_ops):
    # Passing input_leaves to finished nodes here means that nodes that
    # have become stale won't be deallocated until an input reaches them,
    # because we're trying to avoid considering every fertile node for
    # performance reasons.
    finished, stale = tensor_forest_ops.finished_nodes(
        input_leaves,
        self.variables.node_to_accumulator_map,
        self.variables.candidate_split_sums,
        self.variables.candidate_split_squares,
        self.variables.accumulator_sums,
        self.variables.accumulator_squares,
        self.variables.start_epoch,
        epoch,
        num_split_after_samples=self.params.split_after_samples,
        min_split_samples=self.params.min_split_samples,
        dominate_method=self.params.dominate_method,
        dominate_fraction=self.params.dominate_fraction)

  # Update leaf scores.
  # TODO(thomaswc): Store the leaf scores in a TopN and only update the
  # scores of the leaves that were touched by this batch of input.
  # children is column 0 of the tree variable; nodes whose entry equals
  # constants.LEAF_NODE are treated as leaves.
  children = array_ops.squeeze(
      array_ops.slice(self.variables.tree, [0, 0], [-1, 1]), squeeze_dims=[1])
  is_leaf = math_ops.equal(constants.LEAF_NODE, children)
  leaves = math_ops.to_int32(
      array_ops.squeeze(
          array_ops.where(is_leaf), squeeze_dims=[1]))
  # Keep only the leaves whose node_to_accumulator_map entry is negative
  # (presumably: leaves with no accumulator attached).
  non_fertile_leaves = array_ops.boolean_mask(
      leaves, math_ops.less(array_ops.gather(
          self.variables.node_to_accumulator_map, leaves), 0))

  # TODO(gilberth): It should be possible to limit the number of non
  # fertile leaves we calculate scores for, especially since we can only take
  # at most array_ops.shape(finished)[0] of them.
  # Scores read node_sums/node_squares only after this batch's node updates.
  # Regression uses self._variance; classification uses self._weighted_gini.
  with ops.control_dependencies(node_update_ops):
    sums = array_ops.gather(self.variables.node_sums, non_fertile_leaves)
    if self.params.regression:
      squares = array_ops.gather(self.variables.node_squares,
                                 non_fertile_leaves)
      non_fertile_leaf_scores = self._variance(sums, squares)
    else:
      non_fertile_leaf_scores = self._weighted_gini(sums)

  # Calculate best splits.
  with ops.control_dependencies(splits_update_ops):
    split_indices = tensor_forest_ops.best_splits(
        finished,
        self.variables.node_to_accumulator_map,
        self.variables.candidate_split_sums,
        self.variables.candidate_split_squares,
        self.variables.accumulator_sums,
        self.variables.accumulator_squares,
        regression=self.params.regression)

  # Grow tree.
  # grow_tree reads the candidate split features/thresholds, so it waits for
  # this batch's sampled candidates.  Depending on non_fertile_leaves.op
  # orders that leaf computation (which reads self.variables.tree) before
  # the tree updates created in this scope.
  with ops.control_dependencies([update_features_op, update_thresholds_op,
                                 non_fertile_leaves.op]):
    (tree_update_indices, tree_children_updates, tree_threshold_updates,
     new_eot) = (tensor_forest_ops.grow_tree(
         self.variables.end_of_tree, self.variables.node_to_accumulator_map,
         finished, split_indices, self.variables.candidate_split_features,
         self.variables.candidate_split_thresholds))
    tree_update_op = state_ops.scatter_update(
        self.variables.tree, tree_update_indices, tree_children_updates)
    thresholds_update_op = state_ops.scatter_update(
        self.variables.tree_thresholds, tree_update_indices,
        tree_threshold_updates)
    # TODO(thomaswc): Only update the epoch on the new leaves.
    # The current epoch, broadcast to the shape of tree_threshold_updates.
    new_epoch_updates = epoch * array_ops.ones_like(tree_threshold_updates,
                                                    dtype=dtypes.int32)
    epoch_update_op = state_ops.scatter_update(
        self.variables.start_epoch, tree_update_indices,
        new_epoch_updates)

  # Update fertile slots.
  with ops.control_dependencies([tree_update_op]):
    (n2a_map_updates, a2n_map_updates, accumulators_cleared,
     accumulators_allocated) = (tensor_forest_ops.update_fertile_slots(
         finished,
         non_fertile_leaves,
         non_fertile_leaf_scores,
         self.variables.end_of_tree,
         self.variables.accumulator_sums,
         self.variables.node_to_accumulator_map,
         stale,
         self.variables.node_sums,
         regression=self.params.regression))

  # Ensure end_of_tree doesn't get updated until UpdateFertileSlots has
  # used it to calculate new leaves.
  with ops.control_dependencies([n2a_map_updates.op]):
    eot_update_op = state_ops.assign(self.variables.end_of_tree, new_eot)

  updates = []
  updates.append(eot_update_op)
  updates.append(tree_update_op)
  updates.append(thresholds_update_op)
  updates.append(epoch_update_op)

  # n2a_map_updates / a2n_map_updates are used as (indices, values) pairs.
  updates.append(
      state_ops.scatter_update(self.variables.node_to_accumulator_map,
                               n2a_map_updates[0], n2a_map_updates[1]))

  updates.append(
      state_ops.scatter_update(self.variables.accumulator_to_node_map,
                               a2n_map_updates[0], a2n_map_updates[1]))

  cleared_and_allocated_accumulators = array_ops.concat(
      [accumulators_cleared, accumulators_allocated], 0)

  # Calculate values to put into scatter update for candidate counts.
  # Candidate split counts are always reset back to 0 for both cleared
  # and allocated accumulators. This means some accumulators might be doubly
  # reset to 0 if they were released and not allocated, then later allocated.
  # split_values is a zero block of shape
  # [num accumulators, num_splits_to_consider, num_output_columns].
  split_values = array_ops.tile(
      array_ops.expand_dims(array_ops.expand_dims(
          array_ops.zeros_like(cleared_and_allocated_accumulators,
                               dtype=dtypes.float32), 1), 2),
      [1, self.params.num_splits_to_consider, self.params.num_output_columns])
  updates.append(state_ops.scatter_update(
      self.variables.candidate_split_sums,
      cleared_and_allocated_accumulators, split_values))
  if self.params.regression:
    updates.append(state_ops.scatter_update(
        self.variables.candidate_split_squares,
        cleared_and_allocated_accumulators, split_values))

  # Calculate values to put into scatter update for total counts.
  # Cleared accumulators get rows of -1 (presumably marking them unused);
  # newly allocated accumulators get rows of 0.
  total_cleared = array_ops.tile(
      array_ops.expand_dims(
          math_ops.negative(array_ops.ones_like(accumulators_cleared,
                                                dtype=dtypes.float32)), 1),
      [1, self.params.num_output_columns])
  total_reset = array_ops.tile(
      array_ops.expand_dims(
          array_ops.zeros_like(accumulators_allocated,
                               dtype=dtypes.float32), 1),
      [1, self.params.num_output_columns])
  accumulator_updates = array_ops.concat([total_cleared, total_reset], 0)
  updates.append(state_ops.scatter_update(
      self.variables.accumulator_sums,
      cleared_and_allocated_accumulators, accumulator_updates))
  if self.params.regression:
    updates.append(state_ops.scatter_update(
        self.variables.accumulator_squares,
        cleared_and_allocated_accumulators, accumulator_updates))

  # Calculate values to put into scatter update for candidate splits.
  # Every candidate split feature of a cleared or allocated accumulator is
  # set to -1.
  split_features_updates = array_ops.tile(
      array_ops.expand_dims(
          math_ops.negative(array_ops.ones_like(
              cleared_and_allocated_accumulators)), 1),
      [1, self.params.num_splits_to_consider])
  updates.append(state_ops.scatter_update(
      self.variables.candidate_split_features,
      cleared_and_allocated_accumulators, split_features_updates))

  # Append any extra ops returned by finish_iteration().
  updates += self.finish_iteration()

  return control_flow_ops.group(*updates)