def build_params_proto(params):
  """Translate a V4ForestHParams object into a TensorForestParams proto.

  Plain hyperparameters are copied onto the proto field by field, with
  collection_type fixed to COLLECTION_BASIC and drop_final_class fixed to
  False. prune_every_samples, check_every_steps, split_after_samples and
  num_splits_to_consider are each filled by parse_number_or_string_to_proto
  from the matching params attribute.
  NOTE(review): judging by its name, that helper accepts either a number or a
  text-format string. Confirm against its definition.
  If params.param_file is truthy, that file's text is merged into the proto
  with text_format.Merge after everything else. Scalar fields it sets
  therefore override the values assigned here.

  Args:
    params: hyperparameter object exposing every attribute read below.

  Returns:
    The populated _params_proto.TensorForestParams message.
  """
  forest = _params_proto.TensorForestParams()

  def copy_fields(pairs):
    # Assign params.<src> to forest.<dst> for each (dst, src) pair, one pair
    # at a time, so each attribute is read just before its field is written.
    for dst, src in pairs:
      setattr(forest, dst, getattr(params, src))

  copy_fields((
      ('num_trees', 'num_trees'),
      ('max_nodes', 'max_nodes'),
      ('is_regression', 'regression'),
      ('num_outputs', 'num_classes'),
      ('num_features', 'num_features'),
      ('leaf_type', 'leaf_model_type'),
      ('stats_type', 'stats_model_type'),
  ))
  forest.collection_type = _params_proto.COLLECTION_BASIC
  forest.pruning_type.type = params.pruning_type
  forest.finish_type.type = params.finish_type
  forest.inequality_test_type = params.split_type
  forest.drop_final_class = False
  copy_fields((
      ('collate_examples', 'collate_examples'),
      ('checkpoint_stats', 'checkpoint_stats'),
      ('use_running_stats_method', 'use_running_stats_method'),
      ('initialize_average_splits', 'initialize_average_splits'),
      ('inference_tree_paths', 'inference_tree_paths'),
  ))

  # Sub-messages filled by the number-or-string parser, each paired with the
  # params attribute that supplies its value. The attribute is fetched only
  # when its turn comes.
  for target, attr in (
      (forest.pruning_type.prune_every_samples, 'prune_every_samples'),
      (forest.finish_type.check_every_steps,
       'early_finish_check_every_samples'),
      (forest.split_after_samples, 'split_after_samples'),
      (forest.num_splits_to_consider, 'num_splits_to_consider'),
  ):
    parse_number_or_string_to_proto(target, getattr(params, attr))

  forest.dominate_fraction.constant_value = params.dominate_fraction

  override_path = params.param_file
  if override_path:
    # Merged last so the file can override any field assigned above.
    with open(override_path) as handle:
      text_format.Merge(handle.read(), forest)

  return forest
def build_params_proto(params):
  """Translate a V4ForestHParams object into a TensorForestParams proto.

  Plain settings come from params.num_trees, max_nodes, regression,
  num_classes, num_features and the v4_* scalar attributes. drop_final_class
  is always False.

  Four sub-message settings are filled from a v4_* text-format string when
  that string is truthy. Otherwise they store a constant fallback:
    * pruning_type.prune_every_samples <- split_after_samples / 2
      (logs an error if v4_split_after_samples is also set)
    * finish_type.check_every_steps    <- int(split_after_samples / 4)
    * split_after_samples              <- split_after_samples
    * num_splits_to_consider           <- num_splits_to_consider
  min_split_samples always gets the scalar split_after_samples.

  If params.v4_param_file is truthy, that file's text is merged in last with
  text_format.Merge. Scalar fields it sets therefore override those above.

  Args:
    params: hyperparameter object exposing every attribute read below.

  Returns:
    The populated _params_proto.TensorForestParams message.
  """
  forest = _params_proto.TensorForestParams()

  for dst, src in (
      ('num_trees', 'num_trees'),
      ('max_nodes', 'max_nodes'),
      ('is_regression', 'regression'),
      ('num_outputs', 'num_classes'),
      ('num_features', 'num_features'),
      ('leaf_type', 'v4_leaf_model_type'),
      ('stats_type', 'v4_stats_model_type'),
      ('collection_type', 'v4_split_collection_type'),
  ):
    setattr(forest, dst, getattr(params, src))
  forest.pruning_type.type = params.v4_pruning_type
  forest.finish_type.type = params.v4_finish_type
  forest.inequality_test_type = params.v4_split_type
  forest.drop_final_class = False
  for dst, src in (
      ('collate_examples', 'v4_collate_examples'),
      ('checkpoint_stats', 'v4_checkpoint_stats'),
      ('use_running_stats_method', 'v4_use_running_stats_method'),
      ('initialize_average_splits', 'v4_initialize_average_splits'),
  ):
    setattr(forest, dst, getattr(params, src))

  def merge_or_fallback(text_value, field, fallback):
    # Parse text_value into field when it is truthy. Otherwise store
    # fallback() as field.constant_value. fallback is only called on that
    # path, so its reads and logging stay lazy.
    if text_value:
      text_format.Merge(text_value, field)
    else:
      field.constant_value = fallback()

  def default_prune_every():
    # Pruning half-way through split_after_samples seems like a decent default,
    # making it easy to select the number being pruned with v4_pruning_type
    # while not paying the cost of pruning too often. Note that this only holds
    # if not using a depth-dependent split_after_samples.
    if params.v4_split_after_samples:
      logging.error(
          'If using depth-dependent split_after_samples and also pruning, '
          'need to set v4_prune_every_samples')
    return params.split_after_samples / 2

  merge_or_fallback(params.v4_prune_every_samples,
                    forest.pruning_type.prune_every_samples,
                    default_prune_every)
  # Checking for finish every quarter through split_after_samples seems
  # like a decent default. We don't want to incur the checking cost too often,
  # but (at least for hoeffding) it's lower than the cost of pruning so
  # we can do it a little more frequently.
  merge_or_fallback(params.v4_finish_check_every_samples,
                    forest.finish_type.check_every_steps,
                    lambda: int(params.split_after_samples / 4))
  merge_or_fallback(params.v4_split_after_samples,
                    forest.split_after_samples,
                    lambda: params.split_after_samples)
  merge_or_fallback(params.v4_num_splits_to_consider,
                    forest.num_splits_to_consider,
                    lambda: params.num_splits_to_consider)

  forest.dominate_fraction.constant_value = params.dominate_fraction
  forest.min_split_samples.constant_value = params.split_after_samples

  override_path = params.v4_param_file
  if override_path:
    # Merged last so the file can override any field assigned above.
    with open(override_path) as handle:
      text_format.Merge(handle.read(), forest)

  return forest